1use std;
9use std::ops::{Deref, DerefMut};
10
11use collect_slice::CollectSlice;
12
13use crate::bits;
14
15use self::Decision::*;
16
/// Trellis encoder for the dibit code: each input dibit produces two output dibits
/// (rate 1/2).
pub type DibitFSM = TrellisFSM<DibitStates>;

/// Trellis encoder for the tribit code: each input tribit produces two output dibits
/// (rate 3/4).
pub type TribitFSM = TrellisFSM<TribitStates>;

/// Viterbi decoder for the dibit code, reading received dibits from the iterator `T`.
pub type DibitDecoder<T> = ViterbiDecoder<DibitStates, DibitHistory, DibitWalks, T>;

/// Viterbi decoder for the tribit code, reading received dibits from the iterator `T`.
pub type TribitDecoder<T> = ViterbiDecoder<TribitStates, TribitHistory, TribitWalks, T>;
28
29pub trait States {
30 type Symbol;
32
33 fn size() -> usize;
35
36 fn pair_idx(cur: usize, next: usize) -> usize;
39
40 fn state(input: Self::Symbol) -> usize;
42 fn symbol(state: usize) -> Self::Symbol;
44
45 fn finisher() -> Self::Symbol;
47
48 fn pair(state: usize, next: usize) -> (bits::Dibit, bits::Dibit) {
50 const PAIRS: [(u8, u8); 16] = [
51 (0b00, 0b10),
52 (0b10, 0b10),
53 (0b01, 0b11),
54 (0b11, 0b11),
55 (0b11, 0b10),
56 (0b01, 0b10),
57 (0b10, 0b11),
58 (0b00, 0b11),
59 (0b11, 0b01),
60 (0b01, 0b01),
61 (0b10, 0b00),
62 (0b00, 0b00),
63 (0b00, 0b01),
64 (0b10, 0b01),
65 (0b01, 0b00),
66 (0b11, 0b00),
67 ];
68
69 let (hi, lo) = PAIRS[Self::pair_idx(state, next)];
70 (bits::Dibit::new(hi), bits::Dibit::new(lo))
71 }
72}
73
74pub struct DibitStates;
76
77impl States for DibitStates {
78 type Symbol = bits::Dibit;
79
80 fn size() -> usize {
81 4
82 }
83
84 fn pair_idx(cur: usize, next: usize) -> usize {
85 const STATES: [[usize; 4]; 4] =
86 [[0, 15, 12, 3], [4, 11, 8, 7], [13, 2, 1, 14], [9, 6, 5, 10]];
87
88 STATES[cur][next]
89 }
90
91 fn state(input: bits::Dibit) -> usize {
92 input.bits() as usize
93 }
94 fn finisher() -> Self::Symbol {
95 bits::Dibit::new(0b00)
96 }
97 fn symbol(state: usize) -> Self::Symbol {
98 bits::Dibit::new(state as u8)
99 }
100}
101
102pub struct TribitStates;
104
105impl States for TribitStates {
106 type Symbol = bits::Tribit;
107
108 fn size() -> usize {
109 8
110 }
111
112 fn pair_idx(cur: usize, next: usize) -> usize {
113 const STATES: [[usize; 8]; 8] = [
114 [0, 8, 4, 12, 2, 10, 6, 14],
115 [4, 12, 2, 10, 6, 14, 0, 8],
116 [1, 9, 5, 13, 3, 11, 7, 15],
117 [5, 13, 3, 11, 7, 15, 1, 9],
118 [3, 11, 7, 15, 1, 9, 5, 13],
119 [7, 15, 1, 9, 5, 13, 3, 11],
120 [2, 10, 6, 14, 0, 8, 4, 12],
121 [6, 14, 0, 8, 4, 12, 2, 10],
122 ];
123
124 STATES[cur][next]
125 }
126
127 fn state(input: bits::Tribit) -> usize {
128 input.bits() as usize
129 }
130 fn finisher() -> Self::Symbol {
131 bits::Tribit::new(0b000)
132 }
133 fn symbol(state: usize) -> Self::Symbol {
134 bits::Tribit::new(state as u8)
135 }
136}
137
138pub struct TrellisFSM<S: States> {
141 states: std::marker::PhantomData<S>,
142 state: usize,
144}
145
146impl<S: States> Default for TrellisFSM<S> {
147 fn default() -> Self {
148 TrellisFSM {
149 states: std::marker::PhantomData,
150 state: 0,
151 }
152 }
153}
154
155impl<S: States> TrellisFSM<S> {
156 pub fn new() -> TrellisFSM<S> {
158 TrellisFSM {
159 states: std::marker::PhantomData,
160 state: 0,
161 }
162 }
163
164 pub fn feed(&mut self, input: S::Symbol) -> (bits::Dibit, bits::Dibit) {
167 let next = S::state(input);
168 let pair = S::pair(self.state, next);
169
170 self.state = next;
171
172 pair
173 }
174
175 pub fn finish(&mut self) -> (bits::Dibit, bits::Dibit) {
177 self.feed(S::finisher())
178 }
179}
180
/// Fixed-length record of the trellis states a walk has visited, newest first.
///
/// The types generated by `history_type!` dereference to a slice of `history()`
/// entries; `None` marks an entry that is unknown or was blanked because tied walks
/// disagreed on it.
pub trait WalkHistory:
    Copy + Clone + Default + Deref<Target = [Option<usize>]> + DerefMut
{
    /// Number of entries held in the history.
    fn history() -> usize;
}

/// Defines `$name`, a [`WalkHistory`] backed by an array of `$history` entries.
macro_rules! history_type {
    ($name:ident, $history:expr) => {
        /// Walk history backed by a fixed-size array of optional state indices.
        #[derive(Copy, Clone, Default)]
        pub struct $name([Option<usize>; $history]);

        impl WalkHistory for $name {
            fn history() -> usize {
                $history
            }
        }

        impl Deref for $name {
            type Target = [Option<usize>];

            fn deref(&self) -> &[Option<usize>] {
                &self.0
            }
        }

        impl DerefMut for $name {
            fn deref_mut(&mut self) -> &mut [Option<usize>] {
                &mut self.0
            }
        }
    };
}

// Both decoders keep four entries of history per walk.
history_type!(DibitHistory, 4);
history_type!(TribitHistory, 4);
215
216pub trait Walks<H: WalkHistory>:
217 Copy + Clone + Default + Deref<Target = [Walk<H>]> + DerefMut
218{
219 fn states() -> usize;
220}
221
222macro_rules! impl_walks {
223 ($name:ident, $hist:ident, $states:expr) => {
224 #[derive(Copy, Clone)]
225 pub struct $name([Walk<$hist>; $states]);
226
227 impl Deref for $name {
228 type Target = [Walk<$hist>];
229 fn deref<'a>(&'a self) -> &'a Self::Target {
230 &self.0[..]
231 }
232 }
233
234 impl DerefMut for $name {
235 fn deref_mut<'a>(&'a mut self) -> &'a mut Self::Target {
236 &mut self.0[..]
237 }
238 }
239
240 impl Walks<$hist> for $name {
241 fn states() -> usize {
242 $states
243 }
244 }
245
246 impl Default for $name {
247 fn default() -> Self {
248 let mut walks = [Walk::default(); $states];
249
250 (0..Self::states())
251 .map(Walk::new)
252 .collect_slice_checked(&mut walks[..]);
253
254 $name(walks)
255 }
256 }
257 };
258}
259
260impl_walks!(DibitWalks, DibitHistory, 4);
261impl_walks!(TribitWalks, TribitHistory, 8);
262
263pub struct ViterbiDecoder<S, H, W, T>
266where
267 S: States,
268 H: WalkHistory,
269 W: Walks<H>,
270 T: Iterator<Item = bits::Dibit>,
271{
272 states: std::marker::PhantomData<S>,
273 history: std::marker::PhantomData<H>,
274 src: T,
276 cur: usize,
278 prev: usize,
279 walks: [W; 2],
280 remain: usize,
282}
283
284impl<S, H, W, T> ViterbiDecoder<S, H, W, T>
285where
286 S: States,
287 H: WalkHistory,
288 W: Walks<H>,
289 T: Iterator<Item = bits::Dibit>,
290{
291 pub fn new(src: T) -> ViterbiDecoder<S, H, W, T> {
293 debug_assert!(S::size() == W::states());
294
295 ViterbiDecoder {
296 states: std::marker::PhantomData,
297 history: std::marker::PhantomData,
298 src,
299 walks: [W::default(); 2],
300 cur: 1,
301 prev: 0,
302 remain: 0,
303 }
304 .prime()
305 }
306
307 fn prime(mut self) -> Self {
308 for _ in 1..H::history() {
309 self.step();
310 }
311
312 self
313 }
314
315 fn switch_walk(&mut self) {
316 std::mem::swap(&mut self.cur, &mut self.prev);
317 }
318
319 fn step(&mut self) -> bool {
320 let input = Edge::new(match (self.src.next(), self.src.next()) {
321 (Some(hi), Some(lo)) => (hi, lo),
322 (None, None) => return false,
323 _ => panic!("dibits ended on boundary"),
324 });
325
326 self.remain += 1;
327 self.switch_walk();
328
329 for s in 0..S::size() {
330 let (walk, _) = self.search(s, input);
331 self.walks[self.cur][s].append(walk);
332 }
333
334 true
335 }
336
337 fn search(&self, state: usize, input: Edge) -> (Walk<H>, bool) {
339 self.walks[self.prev]
340 .iter()
341 .enumerate()
342 .map(|(i, w)| (Edge::new(S::pair(i, state)), w))
343 .fold((Walk::default(), false), |(walk, amb), (e, w)| {
344 match w.distance.checked_add(input.distance(e)) {
345 Some(sum) if sum < walk.distance => (walk.replace(w, sum), false),
346 Some(sum) if sum == walk.distance => (walk.combine(w, sum), true),
347 _ => (walk, amb),
348 }
349 })
350 }
351
352 fn decode(&self) -> Decision {
354 self.walks[self.cur]
355 .iter()
356 .fold(Ambiguous(std::usize::MAX), |s, w| match s {
357 Ambiguous(min) | Definite(min, _) if w.distance < min => {
358 Definite(w.distance, w[self.remain])
359 }
360 Definite(min, state) if w.distance == min && w[self.remain] != state => {
361 Ambiguous(w.distance)
362 }
363 _ => s,
364 })
365 }
366}
367
368impl<S, H, W, T> Iterator for ViterbiDecoder<S, H, W, T>
369where
370 S: States,
371 H: WalkHistory,
372 W: Walks<H>,
373 T: Iterator<Item = bits::Dibit>,
374{
375 type Item = Result<S::Symbol, ()>;
376
377 fn next(&mut self) -> Option<Self::Item> {
378 if !self.step() && self.remain == 1 {
381 return None;
382 }
383
384 self.remain -= 1;
385
386 Some(match self.decode() {
387 Ambiguous(_) | Definite(_, None) => Err(()),
388 Definite(_, Some(state)) => Ok(S::symbol(state)),
389 })
390 }
391}
392
/// Outcome of comparing the lowest-distance walks for one output position.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Decision {
    /// The lowest-distance walks disagree on the state, or no walk had a distance
    /// below the starting `usize::MAX`; carries that distance.
    Ambiguous(usize),
    /// All lowest-distance walks agree; carries the distance and the agreed history
    /// entry, which may itself be `None` if it was blanked by an earlier tie.
    Definite(usize, Option<usize>),
}
398
399#[derive(Copy, Clone, Debug)]
400pub struct Walk<H: WalkHistory> {
401 history: H,
402 pub distance: usize,
403}
404
405impl<H: WalkHistory> Walk<H> {
406 pub fn new(state: usize) -> Walk<H> {
407 Walk {
408 history: H::default(),
409 distance: if state == 0 { 0 } else { std::usize::MAX },
410 }
411 .init(state)
412 }
413
414 fn init(mut self, state: usize) -> Self {
415 self.history[0] = Some(state);
416 self
417 }
418
419 pub fn append(&mut self, other: Self) {
420 self.distance = other.distance;
421 other.iter().cloned().collect_slice(&mut self[1..]);
422 }
423
424 pub fn combine(mut self, other: &Self, distance: usize) -> Self {
425 self.distance = distance;
426
427 for (dest, src) in self.iter_mut().zip(other.iter()) {
428 if src != dest {
429 *dest = None;
430 }
431 }
432
433 self
434 }
435
436 pub fn replace(mut self, other: &Self, distance: usize) -> Self {
437 self.distance = distance;
438 other.iter().cloned().collect_slice_checked(&mut self[..]);
439
440 self
441 }
442}
443
444impl<H: WalkHistory> Deref for Walk<H> {
445 type Target = [Option<usize>];
446 fn deref(&self) -> &Self::Target {
447 &self.history
448 }
449}
450
451impl<H: WalkHistory> DerefMut for Walk<H> {
452 fn deref_mut(&mut self) -> &mut Self::Target {
453 &mut self.history
454 }
455}
456
457impl<H: WalkHistory> Default for Walk<H> {
458 fn default() -> Self {
459 Walk::new(std::usize::MAX)
460 }
461}
462
463#[derive(Copy, Clone)]
464struct Edge(u8);
465
466impl Edge {
467 pub fn new((hi, lo): (bits::Dibit, bits::Dibit)) -> Edge {
468 Edge(hi.bits() << 2 | lo.bits())
469 }
470
471 pub fn distance(&self, other: Edge) -> usize {
472 (self.0 ^ other.0).count_ones() as usize
473 }
474}
475
#[cfg(test)]
mod test {
    use super::Edge;
    use super::*;
    use bits::*;

    /// Encodes `symbols` with a fresh `TrellisFSM<S>`, appends the pair produced by
    /// `finish()`, and returns the flattened dibit stream.
    fn encode<S, I>(symbols: I) -> Vec<Dibit>
    where
        S: States,
        I: IntoIterator<Item = S::Symbol>,
    {
        let mut fsm = TrellisFSM::<S>::new();
        let mut dibits = Vec::new();

        for symbol in symbols {
            let (hi, lo) = fsm.feed(symbol);
            dibits.push(hi);
            dibits.push(lo);
        }

        let (hi, lo) = fsm.finish();
        dibits.push(hi);
        dibits.push(lo);

        dibits
    }

    #[test]
    fn test_dibit_code() {
        // (input, expected hi, expected lo), fed in order through a single encoder.
        let steps = [
            (0b00, 0b00, 0b10),
            (0b00, 0b00, 0b10),
            (0b01, 0b11, 0b00),
            (0b01, 0b00, 0b00),
            (0b10, 0b11, 0b01),
            (0b10, 0b10, 0b10),
            (0b11, 0b01, 0b00),
            (0b11, 0b10, 0b00),
        ];

        let mut fsm = DibitFSM::new();

        for &(input, hi, lo) in steps.iter() {
            assert_eq!(
                fsm.feed(Dibit::new(input)),
                (Dibit::new(hi), Dibit::new(lo))
            );
        }
    }

    #[test]
    fn test_tribit_code() {
        // (input, expected hi, expected lo), fed in order through a single encoder.
        let steps = [
            (0b000, 0b00, 0b10),
            (0b000, 0b00, 0b10),
            (0b001, 0b11, 0b01),
            (0b010, 0b01, 0b11),
            (0b100, 0b11, 0b11),
            (0b101, 0b01, 0b01),
            (0b110, 0b11, 0b11),
            (0b111, 0b00, 0b01),
            (0b000, 0b10, 0b11),
            (0b111, 0b01, 0b00),
        ];

        let mut fsm = TribitFSM::new();

        for &(input, hi, lo) in steps.iter() {
            assert_eq!(
                fsm.feed(Tribit::new(input)),
                (Dibit::new(hi), Dibit::new(lo))
            );
        }
    }

    #[test]
    fn test_edge() {
        let point = Edge::new((Dibit::new(0b11), Dibit::new(0b01)));
        let same = Edge::new((Dibit::new(0b11), Dibit::new(0b01)));
        let opposite = Edge::new((Dibit::new(0b00), Dibit::new(0b10)));

        assert_eq!(point.distance(same), 0);
        assert_eq!(point.distance(opposite), 4);
    }

    #[test]
    fn test_dibit_decoder() {
        let bits = [1, 2, 2, 2, 2, 1, 3, 3, 0, 2];
        let mut dibits = encode::<DibitStates, _>(bits.iter().map(|&b| Dibit::new(b)));

        // Inject channel errors the decoder must correct.
        dibits[2] = Dibit::new(0b10);
        dibits[4] = Dibit::new(0b10);

        let mut dec = DibitDecoder::new(dibits.iter().cloned());

        for &expected in bits.iter() {
            let decoded = dec.next().unwrap().unwrap();
            assert_eq!(decoded.bits(), Dibit::new(expected).bits());
        }
    }

    #[test]
    fn test_tribit_decoder() {
        let bits = [1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0];
        let mut dibits = encode::<TribitStates, _>(bits.iter().map(|&b| Tribit::new(b)));

        // Inject channel errors the decoder must correct.
        dibits[6] = Dibit::new(0b10);
        dibits[4] = Dibit::new(0b10);
        dibits[14] = Dibit::new(0b10);

        let mut dec = TribitDecoder::new(dibits.iter().cloned());

        for &expected in bits.iter() {
            let decoded = dec.next().unwrap().unwrap();
            assert_eq!(decoded.bits(), Tribit::new(expected).bits());
        }
    }
}