code_rs/coding/
trellis.rs

1//! Implements encoding and decoding of the "trellis" convolutional error correcting code
2//! specified by P25.
3//!
4//! Encoding is done with a state machine and decoding is done with the Viterbi algorithm,
//! adapted from *Coding Theory and Cryptography: The Essentials*, Hankerson, Hoffman, et
//! al., 2000.
7
8use std;
9use std::ops::{Deref, DerefMut};
10
11use collect_slice::CollectSlice;
12
13use crate::bits;
14
15use self::Decision::*;
16
17/// Half-rate convolutional ("trellis") code state machine.
18pub type DibitFSM = TrellisFSM<DibitStates>;
19
20/// 3/4-rate convolutional ("trellis") code state machine.
21pub type TribitFSM = TrellisFSM<TribitStates>;
22
23/// Half-rate convolution ("trellis") code decoder.
24pub type DibitDecoder<T> = ViterbiDecoder<DibitStates, DibitHistory, DibitWalks, T>;
25
26/// 3/4-rate convolution ("trellis") code decoder.
27pub type TribitDecoder<T> = ViterbiDecoder<TribitStates, TribitHistory, TribitWalks, T>;
28
29pub trait States {
30    /// Symbol type to use for states and input.
31    type Symbol;
32
33    /// Number of rows/columns in the state machine.
34    fn size() -> usize;
35
36    /// Get the "constellation point" on the transition from the current state to the next
37    /// state.
38    fn pair_idx(cur: usize, next: usize) -> usize;
39
40    /// Convert the given symbol to a state.
41    fn state(input: Self::Symbol) -> usize;
42    /// Convert the given state to a symbol.
43    fn symbol(state: usize) -> Self::Symbol;
44
45    /// Get the "flushing" symbol fed in at the end of a stream.
46    fn finisher() -> Self::Symbol;
47
48    /// Get the dibit pair on the transition from the current state to the next state.
49    fn pair(state: usize, next: usize) -> (bits::Dibit, bits::Dibit) {
50        const PAIRS: [(u8, u8); 16] = [
51            (0b00, 0b10),
52            (0b10, 0b10),
53            (0b01, 0b11),
54            (0b11, 0b11),
55            (0b11, 0b10),
56            (0b01, 0b10),
57            (0b10, 0b11),
58            (0b00, 0b11),
59            (0b11, 0b01),
60            (0b01, 0b01),
61            (0b10, 0b00),
62            (0b00, 0b00),
63            (0b00, 0b01),
64            (0b10, 0b01),
65            (0b01, 0b00),
66            (0b11, 0b00),
67        ];
68
69        let (hi, lo) = PAIRS[Self::pair_idx(state, next)];
70        (bits::Dibit::new(hi), bits::Dibit::new(lo))
71    }
72}
73
74/// Half-rate state machine (dibit input).
75pub struct DibitStates;
76
77impl States for DibitStates {
78    type Symbol = bits::Dibit;
79
80    fn size() -> usize {
81        4
82    }
83
84    fn pair_idx(cur: usize, next: usize) -> usize {
85        const STATES: [[usize; 4]; 4] =
86            [[0, 15, 12, 3], [4, 11, 8, 7], [13, 2, 1, 14], [9, 6, 5, 10]];
87
88        STATES[cur][next]
89    }
90
91    fn state(input: bits::Dibit) -> usize {
92        input.bits() as usize
93    }
94    fn finisher() -> Self::Symbol {
95        bits::Dibit::new(0b00)
96    }
97    fn symbol(state: usize) -> Self::Symbol {
98        bits::Dibit::new(state as u8)
99    }
100}
101
102/// 3/4-rate state machine (tribit input).
103pub struct TribitStates;
104
105impl States for TribitStates {
106    type Symbol = bits::Tribit;
107
108    fn size() -> usize {
109        8
110    }
111
112    fn pair_idx(cur: usize, next: usize) -> usize {
113        const STATES: [[usize; 8]; 8] = [
114            [0, 8, 4, 12, 2, 10, 6, 14],
115            [4, 12, 2, 10, 6, 14, 0, 8],
116            [1, 9, 5, 13, 3, 11, 7, 15],
117            [5, 13, 3, 11, 7, 15, 1, 9],
118            [3, 11, 7, 15, 1, 9, 5, 13],
119            [7, 15, 1, 9, 5, 13, 3, 11],
120            [2, 10, 6, 14, 0, 8, 4, 12],
121            [6, 14, 0, 8, 4, 12, 2, 10],
122        ];
123
124        STATES[cur][next]
125    }
126
127    fn state(input: bits::Tribit) -> usize {
128        input.bits() as usize
129    }
130    fn finisher() -> Self::Symbol {
131        bits::Tribit::new(0b000)
132    }
133    fn symbol(state: usize) -> Self::Symbol {
134        bits::Tribit::new(state as u8)
135    }
136}
137
138/// Convolutional code finite state machine with the given transition table. Each fed-in
139/// symbol is used as the next state.
140pub struct TrellisFSM<S: States> {
141    states: std::marker::PhantomData<S>,
142    /// Current state.
143    state: usize,
144}
145
146impl<S: States> Default for TrellisFSM<S> {
147    fn default() -> Self {
148        TrellisFSM {
149            states: std::marker::PhantomData,
150            state: 0,
151        }
152    }
153}
154
155impl<S: States> TrellisFSM<S> {
156    /// Construct a new `TrellisFSM` at the initial state.
157    pub fn new() -> TrellisFSM<S> {
158        TrellisFSM {
159            states: std::marker::PhantomData,
160            state: 0,
161        }
162    }
163
164    /// Apply the given symbol to the state machine and return the dibit pair on the
165    /// transition.
166    pub fn feed(&mut self, input: S::Symbol) -> (bits::Dibit, bits::Dibit) {
167        let next = S::state(input);
168        let pair = S::pair(self.state, next);
169
170        self.state = next;
171
172        pair
173    }
174
175    /// Flush the state machine with the finishing symbol and return the final transition.
176    pub fn finish(&mut self) -> (bits::Dibit, bits::Dibit) {
177        self.feed(S::finisher())
178    }
179}
180
/// Fixed-length record of the states a walk passed through, viewable as a slice of
/// `Option<usize>` entries.
pub trait WalkHistory:
    Copy + Clone + Default + Deref<Target = [Option<usize>]> + DerefMut
{
    /// Number of states remembered per walk. This also determines how many symbols the
    /// decoder reads ahead before it yields the first decoded symbol.
    fn history() -> usize;
}
186
/// Defines `$name`, a `WalkHistory` backed by an array of `$len` optional states.
macro_rules! history_type {
    ($name:ident, $len:expr) => {
        #[derive(Copy, Clone, Default)]
        pub struct $name([Option<usize>; $len]);

        impl Deref for $name {
            type Target = [Option<usize>];

            fn deref(&self) -> &[Option<usize>] {
                &self.0
            }
        }

        impl DerefMut for $name {
            fn deref_mut(&mut self) -> &mut [Option<usize>] {
                &mut self.0
            }
        }

        impl WalkHistory for $name {
            fn history() -> usize {
                $len
            }
        }
    };
}
212
213history_type!(DibitHistory, 4);
214history_type!(TribitHistory, 4);
215
216pub trait Walks<H: WalkHistory>:
217    Copy + Clone + Default + Deref<Target = [Walk<H>]> + DerefMut
218{
219    fn states() -> usize;
220}
221
/// Defines `$name`, a `Walks<$hist>` implementation holding one walk for each of the
/// `$states` trellis states.
macro_rules! impl_walks {
    ($name:ident, $hist:ident, $states:expr) => {
        #[derive(Copy, Clone)]
        pub struct $name([Walk<$hist>; $states]);

        impl Deref for $name {
            type Target = [Walk<$hist>];

            fn deref(&self) -> &[Walk<$hist>] {
                &self.0
            }
        }

        impl DerefMut for $name {
            fn deref_mut(&mut self) -> &mut [Walk<$hist>] {
                &mut self.0
            }
        }

        impl Walks<$hist> for $name {
            fn states() -> usize {
                $states
            }
        }

        impl Default for $name {
            /// Walk `i` is created with `Walk::new(i)` for every state index `i`.
            fn default() -> Self {
                let mut walks = [Walk::default(); $states];

                for (idx, walk) in walks.iter_mut().enumerate() {
                    *walk = Walk::new(idx);
                }

                $name(walks)
            }
        }
    };
}
259
260impl_walks!(DibitWalks, DibitHistory, 4);
261impl_walks!(TribitWalks, TribitHistory, 8);
262
263/// Decodes a received convolutional code dibit stream to a nearby codeword using the
264/// truncated Viterbi algorithm.
265pub struct ViterbiDecoder<S, H, W, T>
266where
267    S: States,
268    H: WalkHistory,
269    W: Walks<H>,
270    T: Iterator<Item = bits::Dibit>,
271{
272    states: std::marker::PhantomData<S>,
273    history: std::marker::PhantomData<H>,
274    /// Source of dibits.
275    src: T,
276    /// Walks associated with each state, for the current and previous tick.
277    cur: usize,
278    prev: usize,
279    walks: [W; 2],
280    /// Remaining symbols to yield.
281    remain: usize,
282}
283
284impl<S, H, W, T> ViterbiDecoder<S, H, W, T>
285where
286    S: States,
287    H: WalkHistory,
288    W: Walks<H>,
289    T: Iterator<Item = bits::Dibit>,
290{
291    /// Construct a new `ViterbiDecoder` over the given dibit source.
292    pub fn new(src: T) -> ViterbiDecoder<S, H, W, T> {
293        debug_assert!(S::size() == W::states());
294
295        ViterbiDecoder {
296            states: std::marker::PhantomData,
297            history: std::marker::PhantomData,
298            src,
299            walks: [W::default(); 2],
300            cur: 1,
301            prev: 0,
302            remain: 0,
303        }
304        .prime()
305    }
306
307    fn prime(mut self) -> Self {
308        for _ in 1..H::history() {
309            self.step();
310        }
311
312        self
313    }
314
315    fn switch_walk(&mut self) {
316        std::mem::swap(&mut self.cur, &mut self.prev);
317    }
318
319    fn step(&mut self) -> bool {
320        let input = Edge::new(match (self.src.next(), self.src.next()) {
321            (Some(hi), Some(lo)) => (hi, lo),
322            (None, None) => return false,
323            _ => panic!("dibits ended on boundary"),
324        });
325
326        self.remain += 1;
327        self.switch_walk();
328
329        for s in 0..S::size() {
330            let (walk, _) = self.search(s, input);
331            self.walks[self.cur][s].append(walk);
332        }
333
334        true
335    }
336
337    ///
338    fn search(&self, state: usize, input: Edge) -> (Walk<H>, bool) {
339        self.walks[self.prev]
340            .iter()
341            .enumerate()
342            .map(|(i, w)| (Edge::new(S::pair(i, state)), w))
343            .fold((Walk::default(), false), |(walk, amb), (e, w)| {
344                match w.distance.checked_add(input.distance(e)) {
345                    Some(sum) if sum < walk.distance => (walk.replace(w, sum), false),
346                    Some(sum) if sum == walk.distance => (walk.combine(w, sum), true),
347                    _ => (walk, amb),
348                }
349            })
350    }
351
352    ///
353    fn decode(&self) -> Decision {
354        self.walks[self.cur]
355            .iter()
356            .fold(Ambiguous(std::usize::MAX), |s, w| match s {
357                Ambiguous(min) | Definite(min, _) if w.distance < min => {
358                    Definite(w.distance, w[self.remain])
359                }
360                Definite(min, state) if w.distance == min && w[self.remain] != state => {
361                    Ambiguous(w.distance)
362                }
363                _ => s,
364            })
365    }
366}
367
368impl<S, H, W, T> Iterator for ViterbiDecoder<S, H, W, T>
369where
370    S: States,
371    H: WalkHistory,
372    W: Walks<H>,
373    T: Iterator<Item = bits::Dibit>,
374{
375    type Item = Result<S::Symbol, ()>;
376
377    fn next(&mut self) -> Option<Self::Item> {
378        // Stop on the symbol before last since the final symbol is always a dummy symbol
379        // used for flushing.
380        if !self.step() && self.remain == 1 {
381            return None;
382        }
383
384        self.remain -= 1;
385
386        Some(match self.decode() {
387            Ambiguous(_) | Definite(_, None) => Err(()),
388            Definite(_, Some(state)) => Ok(S::symbol(state)),
389        })
390    }
391}
392
/// Outcome of picking a decoded state from the current set of walks.
enum Decision {
    /// All minimum-distance walks agree: `(minimum distance, agreed history entry)`.
    Definite(usize, Option<usize>),
    /// Minimum-distance walks disagree (or none was reached): `(minimum distance)`.
    Ambiguous(usize),
}
398
399#[derive(Copy, Clone, Debug)]
400pub struct Walk<H: WalkHistory> {
401    history: H,
402    pub distance: usize,
403}
404
405impl<H: WalkHistory> Walk<H> {
406    pub fn new(state: usize) -> Walk<H> {
407        Walk {
408            history: H::default(),
409            distance: if state == 0 { 0 } else { std::usize::MAX },
410        }
411        .init(state)
412    }
413
414    fn init(mut self, state: usize) -> Self {
415        self.history[0] = Some(state);
416        self
417    }
418
419    pub fn append(&mut self, other: Self) {
420        self.distance = other.distance;
421        other.iter().cloned().collect_slice(&mut self[1..]);
422    }
423
424    pub fn combine(mut self, other: &Self, distance: usize) -> Self {
425        self.distance = distance;
426
427        for (dest, src) in self.iter_mut().zip(other.iter()) {
428            if src != dest {
429                *dest = None;
430            }
431        }
432
433        self
434    }
435
436    pub fn replace(mut self, other: &Self, distance: usize) -> Self {
437        self.distance = distance;
438        other.iter().cloned().collect_slice_checked(&mut self[..]);
439
440        self
441    }
442}
443
444impl<H: WalkHistory> Deref for Walk<H> {
445    type Target = [Option<usize>];
446    fn deref(&self) -> &Self::Target {
447        &self.history
448    }
449}
450
451impl<H: WalkHistory> DerefMut for Walk<H> {
452    fn deref_mut(&mut self) -> &mut Self::Target {
453        &mut self.history
454    }
455}
456
457impl<H: WalkHistory> Default for Walk<H> {
458    fn default() -> Self {
459        Walk::new(std::usize::MAX)
460    }
461}
462
463#[derive(Copy, Clone)]
464struct Edge(u8);
465
466impl Edge {
467    pub fn new((hi, lo): (bits::Dibit, bits::Dibit)) -> Edge {
468        Edge(hi.bits() << 2 | lo.bits())
469    }
470
471    pub fn distance(&self, other: Edge) -> usize {
472        (self.0 ^ other.0).count_ones() as usize
473    }
474}
475
#[cfg(test)]
mod test {
    use super::Edge;
    use super::*;
    use bits::*;

    #[test]
    fn test_dibit_code() {
        let mut fsm = DibitFSM::new();
        assert_eq!(
            fsm.feed(Dibit::new(0b00)),
            (Dibit::new(0b00), Dibit::new(0b10))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b00)),
            (Dibit::new(0b00), Dibit::new(0b10))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b01)),
            (Dibit::new(0b11), Dibit::new(0b00))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b01)),
            (Dibit::new(0b00), Dibit::new(0b00))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b10)),
            (Dibit::new(0b11), Dibit::new(0b01))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b10)),
            (Dibit::new(0b10), Dibit::new(0b10))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b11)),
            (Dibit::new(0b01), Dibit::new(0b00))
        );
        assert_eq!(
            fsm.feed(Dibit::new(0b11)),
            (Dibit::new(0b10), Dibit::new(0b00))
        );
    }

    #[test]
    fn test_tribit_code() {
        let mut fsm = TribitFSM::new();
        assert_eq!(
            fsm.feed(Tribit::new(0b000)),
            (Dibit::new(0b00), Dibit::new(0b10))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b000)),
            (Dibit::new(0b00), Dibit::new(0b10))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b001)),
            (Dibit::new(0b11), Dibit::new(0b01))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b010)),
            (Dibit::new(0b01), Dibit::new(0b11))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b100)),
            (Dibit::new(0b11), Dibit::new(0b11))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b101)),
            (Dibit::new(0b01), Dibit::new(0b01))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b110)),
            (Dibit::new(0b11), Dibit::new(0b11))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b111)),
            (Dibit::new(0b00), Dibit::new(0b01))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b000)),
            (Dibit::new(0b10), Dibit::new(0b11))
        );
        assert_eq!(
            fsm.feed(Tribit::new(0b111)),
            (Dibit::new(0b01), Dibit::new(0b00))
        );
    }

    #[test]
    fn test_edge() {
        assert_eq!(
            Edge::new((Dibit::new(0b11), Dibit::new(0b01)))
                .distance(Edge::new((Dibit::new(0b11), Dibit::new(0b01)))),
            0
        );

        assert_eq!(
            Edge::new((Dibit::new(0b11), Dibit::new(0b01)))
                .distance(Edge::new((Dibit::new(0b00), Dibit::new(0b10)))),
            4
        );
    }

    #[test]
    fn test_dibit_decoder() {
        let input = [1, 2, 2, 2, 2, 1, 3, 3, 0, 2];
        let stream = input.iter().map(|&b| Dibit::new(b));

        let mut dibits = vec![];
        let mut fsm = DibitFSM::new();

        for dibit in stream {
            let (hi, lo) = fsm.feed(dibit);
            dibits.push(hi);
            dibits.push(lo);
        }

        let (hi, lo) = fsm.finish();
        dibits.push(hi);
        dibits.push(lo);

        // Inject bit errors into two of the transmitted dibits.
        dibits[2] = Dibit::new(0b10);
        dibits[4] = Dibit::new(0b10);

        let mut dec = DibitDecoder::new(dibits.iter().cloned());

        for &expected in input.iter() {
            assert_eq!(dec.next().unwrap().unwrap().bits(), expected);
        }

        // The flushing symbol appended by `finish` is never yielded.
        assert!(dec.next().is_none());
    }

    #[test]
    fn test_tribit_decoder() {
        let input = [1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0];
        let stream = input.iter().map(|&b| Tribit::new(b));

        let mut dibits = vec![];
        let mut fsm = TribitFSM::new();

        for tribit in stream {
            let (hi, lo) = fsm.feed(tribit);
            dibits.push(hi);
            dibits.push(lo);
        }

        let (hi, lo) = fsm.finish();
        dibits.push(hi);
        dibits.push(lo);

        // Inject bit errors into three of the transmitted dibits.
        dibits[6] = Dibit::new(0b10);
        dibits[4] = Dibit::new(0b10);
        dibits[14] = Dibit::new(0b10);

        let mut dec = TribitDecoder::new(dibits.iter().cloned());

        for &expected in input.iter() {
            assert_eq!(dec.next().unwrap().unwrap().bits(), expected);
        }

        // The flushing symbol appended by `finish` is never yielded.
        assert!(dec.next().is_none());
    }
}