/// Binary arithmetic encoder over the inclusive 32-bit interval `[low, high]`.
///
/// Uses the classic E1/E2/E3 renormalisation: whenever the interval lies
/// entirely in the lower or upper half of the code space the leading bit is
/// emitted, and when it straddles the midpoint inside the middle half a
/// pending "follow" bit is recorded instead.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ArithmeticCoder {
    /// Inclusive lower bound of the current interval.
    pub low: u32,
    /// Inclusive upper bound of the current interval.
    pub high: u32,
    /// Pending E3 (middle-half) bits; they are emitted, inverted, right after
    /// the next decided bit.
    pub bits_to_follow: u32,
}

impl ArithmeticCoder {
    /// Creates a coder whose interval is the whole 32-bit code space.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            low: 0,
            high: 0xFFFF_FFFF,
            bits_to_follow: 0,
        }
    }

    /// Encodes `bit`, where `prob_one` is the modelled probability that the
    /// bit is `1`.
    ///
    /// The `0` symbol takes the lower part of the interval (about
    /// `1 - prob_one` of it) and the `1` symbol the upper part. The split is
    /// clamped so that both parts keep at least one code value, so extreme
    /// probabilities (`0.0`, `1.0`, or values so close that `1.0 - prob_one`
    /// rounds to them in `f32`) never overflow the bounds.
    ///
    /// Returns the bits decided by this step, MSB first, packed into bytes by
    /// `bits_to_bytes`. NOTE(review): a trailing partial byte is zero-padded
    /// on every call, so concatenating the results of several calls is not a
    /// contiguous bitstream — confirm callers account for this.
    ///
    /// Assumes `low < high` on entry, which holds after `new` and after every
    /// call to this method.
    #[allow(dead_code)]
    #[allow(clippy::cast_possible_truncation, clippy::same_item_push)]
    pub fn encode_bit(&mut self, prob_one: f32, bit: bool) -> Vec<u8> {
        let range = self.get_range();
        #[allow(clippy::cast_precision_loss)]
        let zero_width = (range as f64 * f64::from(1.0 - prob_one)) as u64;
        // `split` is the size of the `0` sub-interval minus one. Capping it at
        // `range - 2` keeps the `1` sub-interval `[mid + 1, high]` non-empty,
        // so `mid + 1` below can never overflow.
        let split = zero_width
            .saturating_sub(1)
            .min(range.saturating_sub(2));
        let mid = self.low.saturating_add(split as u32);

        if bit {
            self.low = mid + 1;
        } else {
            self.high = mid;
        }

        let mut emitted_bits: Vec<bool> = Vec::new();
        loop {
            if self.high < 0x8000_0000 {
                // E1: interval in the lower half -> the next bit is 0.
                emitted_bits.push(false);
                for _ in 0..self.bits_to_follow {
                    emitted_bits.push(true);
                }
                self.bits_to_follow = 0;
                self.low <<= 1;
                self.high = (self.high << 1) | 1;
            } else if self.low >= 0x8000_0000 {
                // E2: interval in the upper half -> the next bit is 1.
                emitted_bits.push(true);
                for _ in 0..self.bits_to_follow {
                    emitted_bits.push(false);
                }
                self.bits_to_follow = 0;
                self.low = (self.low - 0x8000_0000) << 1;
                self.high = ((self.high - 0x8000_0000) << 1) | 1;
            } else if self.low >= 0x4000_0000 && self.high < 0xC000_0000 {
                // E3: interval straddles the midpoint inside the middle half;
                // defer the decision and expand around the centre.
                self.bits_to_follow += 1;
                self.low = (self.low - 0x4000_0000) << 1;
                self.high = ((self.high - 0x4000_0000) << 1) | 1;
            } else {
                break;
            }
        }

        bits_to_bytes(&emitted_bits)
    }

    /// Returns the width of the current interval, `high - low + 1`.
    ///
    /// The value is a `u64` because the full interval has width `2^32`.
    #[allow(dead_code)]
    pub fn get_range(&self) -> u64 {
        u64::from(self.high) - u64::from(self.low) + 1
    }
}

/// Packs bits MSB-first into bytes; a trailing partial byte is padded with
/// zero bits on the right.
#[allow(dead_code)]
fn bits_to_bytes(bits: &[bool]) -> Vec<u8> {
    let mut bytes = Vec::new();
    let mut current: u8 = 0;
    let mut count = 0u8;
    for &b in bits {
        current = (current << 1) | u8::from(b);
        count += 1;
        if count == 8 {
            bytes.push(current);
            current = 0;
            count = 0;
        }
    }
    if count > 0 {
        bytes.push(current << (8 - count));
    }
    bytes
}
124
/// Minimal binary range-decoder state on an 8-bit probability scale.
///
/// `range` is the current interval width and `code` the offset of the coded
/// value inside it. Probabilities passed to [`RangeCoder::decode_symbol`] are
/// in units of 1/256.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct RangeCoder {
    /// Current interval width.
    pub range: u32,
    /// Offset of the coded value within the interval.
    pub code: u32,
}

impl RangeCoder {
    /// Creates a decoder with `range = 256` and `code = 0`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            range: 256,
            code: 0,
        }
    }

    /// Doubles `range` and `code` until `range >= 128` and returns how many
    /// doublings were performed.
    ///
    /// NOTE(review): `code` is shifted without new input bits; presumably the
    /// caller ORs in that many bits afterwards — confirm against callers.
    ///
    /// A `range` of 0 can never reach 128. It is left untouched and 0 is
    /// returned, rather than looping forever.
    #[allow(dead_code)]
    pub fn normalize(&mut self) -> u32 {
        if self.range == 0 {
            return 0;
        }
        let mut bits_consumed = 0;
        while self.range < 128 {
            self.range <<= 1;
            self.code <<= 1;
            bits_consumed += 1;
        }
        bits_consumed
    }

    /// Decodes one binary symbol, splitting the interval at `range * prob / 256`.
    ///
    /// Returns `true` when `code` lies at or above the split; `code` and
    /// `range` are then both reduced by the split. Otherwise returns `false`
    /// and `range` becomes the split. `prob` values above 256 are treated as
    /// 256, so the split never exceeds `range`.
    #[allow(dead_code)]
    pub fn decode_symbol(&mut self, prob: u32) -> bool {
        // Multiply in u64 so large ranges/probabilities cannot overflow.
        let prob = u64::from(prob.min(256));
        let split = ((u64::from(self.range) * prob) >> 8) as u32;
        if self.code >= split {
            self.code -= split;
            self.range -= split;
            true
        } else {
            self.range = split;
            false
        }
    }
}
181
/// Node of a binary Huffman code tree.
///
/// Leaves carry `symbol: Some(_)` and no children; internal nodes built by
/// [`build_huffman_tree`] carry `symbol: None`.
#[derive(Debug)]
#[allow(dead_code)]
pub struct HuffmanNode {
    /// Symbol stored at a leaf; `None` for internal nodes.
    pub symbol: Option<u8>,
    /// Leaf frequency, or the (saturating) sum of the subtree's frequencies.
    pub freq: u32,
    /// Child reached by appending bit `0`.
    pub left: Option<Box<HuffmanNode>>,
    /// Child reached by appending bit `1`.
    pub right: Option<Box<HuffmanNode>>,
}

impl HuffmanNode {
    /// Returns `true` when the node has no children.
    #[allow(dead_code)]
    pub fn is_leaf(&self) -> bool {
        self.left.is_none() && self.right.is_none()
    }
}

/// Builds a Huffman tree from per-symbol frequencies (`freqs[i]` is the
/// frequency of symbol `i`).
///
/// - Zero frequencies are skipped.
/// - Only indices `0..=255` are used, because symbols are stored as `u8`;
///   longer slices have their tail ignored instead of aliasing onto smaller
///   symbols.
/// - With no symbols, a placeholder leaf `Some(0)` with frequency 0 is returned.
/// - With a single symbol, it is placed as the left child of a root, so it
///   still receives a one-bit code.
///
/// The two lightest nodes are merged repeatedly; the first becomes the left
/// child. Ties are broken by creation order: leaves in symbol order, then
/// merged nodes in the order they were built. Parent frequencies saturate at
/// `u32::MAX`.
#[allow(dead_code)]
pub fn build_huffman_tree(freqs: &[u32]) -> HuffmanNode {
    // Node storage indexed by creation id; a slot is emptied when its node is
    // moved into a parent.
    let mut slots: Vec<Option<HuffmanNode>> = freqs
        .iter()
        .take(usize::from(u8::MAX) + 1)
        .enumerate()
        .filter(|&(_, &f)| f > 0)
        .map(|(i, &f)| {
            Some(HuffmanNode {
                // `i <= 255` thanks to `take` above, so the cast is lossless.
                symbol: Some(i as u8),
                freq: f,
                left: None,
                right: None,
            })
        })
        .collect();

    if slots.is_empty() {
        return HuffmanNode {
            symbol: Some(0),
            freq: 0,
            left: None,
            right: None,
        };
    }

    if slots.len() == 1 {
        let leaf = slots
            .pop()
            .flatten()
            .expect("exactly one leaf was collected");
        return HuffmanNode {
            symbol: None,
            freq: leaf.freq,
            left: Some(Box::new(leaf)),
            right: None,
        };
    }

    // Min-heap over (frequency, creation id): O(n log n) instead of re-sorting
    // the whole node list on every merge.
    let mut queue: std::collections::BinaryHeap<std::cmp::Reverse<(u32, usize)>> = slots
        .iter()
        .enumerate()
        .filter_map(|(id, slot)| slot.as_ref().map(|node| std::cmp::Reverse((node.freq, id))))
        .collect();

    while queue.len() > 1 {
        let left = take_lightest_huffman_node(&mut queue, &mut slots);
        let right = take_lightest_huffman_node(&mut queue, &mut slots);
        let parent = HuffmanNode {
            symbol: None,
            freq: left.freq.saturating_add(right.freq),
            left: Some(Box::new(left)),
            right: Some(Box::new(right)),
        };
        queue.push(std::cmp::Reverse((parent.freq, slots.len())));
        slots.push(Some(parent));
    }

    take_lightest_huffman_node(&mut queue, &mut slots)
}

/// Pops the lowest `(frequency, id)` entry from `queue` and moves that node
/// out of `slots`.
fn take_lightest_huffman_node(
    queue: &mut std::collections::BinaryHeap<std::cmp::Reverse<(u32, usize)>>,
    slots: &mut [Option<HuffmanNode>],
) -> HuffmanNode {
    let std::cmp::Reverse((_, id)) = queue.pop().expect("queue is non-empty");
    slots[id]
        .take()
        .expect("each node id is queued exactly once")
}

/// Walks the tree and returns `(symbol, code)` for every leaf that holds a
/// symbol.
///
/// Codes are sequences of `0`/`1` bytes: `0` for a left edge, `1` for a right
/// edge, appended to `prefix`. Leaves are listed left-to-right (depth-first).
#[allow(dead_code)]
pub fn compute_huffman_codes(node: &HuffmanNode, prefix: Vec<u8>) -> Vec<(u8, Vec<u8>)> {
    if node.is_leaf() {
        return match node.symbol {
            Some(sym) => vec![(sym, prefix)],
            None => Vec::new(),
        };
    }

    let mut codes = Vec::new();
    if let Some(left) = &node.left {
        let mut left_prefix = prefix.clone();
        left_prefix.push(0);
        codes.extend(compute_huffman_codes(left, left_prefix));
    }
    if let Some(right) = &node.right {
        let mut right_prefix = prefix;
        right_prefix.push(1);
        codes.extend(compute_huffman_codes(right, right_prefix));
    }
    codes
}
295
/// Number of bits of probability precision used by the table coder.
const TABLE_PROB_BITS: u32 = 8;
/// Total width of the probability scale, `2^TABLE_PROB_BITS` (= 256).
const TABLE_SIZE: usize = 1 << TABLE_PROB_BITS;
/// `TABLE_SIZE - 1`.
#[allow(dead_code)]
const TABLE_MASK: u32 = (TABLE_SIZE as u32) - 1;

/// Binary split of the `TABLE_SIZE`-wide probability scale.
///
/// The "low" partition covers `[cum_prob_low, cum_prob_low + width_low)` and
/// the "high" partition covers the next `width_high` slots. For a non-zero
/// total, entries built by [`build_prob_table`] satisfy
/// `cum_prob_low + width_low + width_high == TABLE_SIZE`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProbTableEntry {
    /// Start of the low partition.
    pub cum_prob_low: u32,
    /// Width of the low partition.
    pub width_low: u32,
    /// Width of the high partition, which directly follows the low one.
    pub width_high: u32,
}

/// Builds one [`ProbTableEntry`] per frequency.
///
/// Each symbol's width is `floor(freq * TABLE_SIZE / total)`. Its low
/// partition starts at the running sum of the previous widths, and its high
/// partition covers the rest of the scale. An all-zero (or empty) input yields
/// default entries.
///
/// Because widths are floored, a very rare symbol can get a zero-width
/// partition. Such a partition cannot be encoded meaningfully.
#[allow(dead_code)]
pub fn build_prob_table(freqs: &[u32]) -> Vec<ProbTableEntry> {
    let total: u64 = freqs.iter().map(|&f| u64::from(f)).sum();
    if total == 0 {
        return vec![ProbTableEntry::default(); freqs.len()];
    }
    let mut table = Vec::with_capacity(freqs.len());
    let mut cum: u32 = 0;
    for &freq in freqs {
        let width = ((u64::from(freq) * u64::from(TABLE_SIZE as u32)) / total) as u32;
        table.push(ProbTableEntry {
            cum_prob_low: cum,
            width_low: width,
            // Floored widths sum to at most TABLE_SIZE, so this cannot underflow.
            width_high: TABLE_SIZE as u32 - cum - width,
        });
        cum += width;
    }
    table
}

/// Binary range encoder driven by [`ProbTableEntry`] splits.
///
/// `low`/`range` describe the current interval in a 32-bit window; bytes
/// above the window are already in `output`. Carries out of the window are
/// propagated back into `output`, so the emitted bytes are the exact code
/// value that [`TableArithmeticDecoder`] reads.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TableArithmeticCoder {
    /// Low end of the interval (the 32 bits not yet shifted out).
    low: u32,
    /// Width of the interval.
    range: u32,
    /// Bytes shifted out so far.
    output: Vec<u8>,
}

impl TableArithmeticCoder {
    /// Renormalisation threshold: bytes are shifted out while `range` is below it.
    const RANGE_MIN: u32 = 0x0100_0000;
    /// Upper range constant; not used by the methods below.
    #[allow(dead_code)]
    const RANGE_MAX: u32 = 0xFF00_0000;

    /// Creates an encoder with an empty output and `range = 0xFFFF_FF00`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 0xFFFF_FF00,
            output: Vec::new(),
        }
    }

    /// Encodes one binary symbol using `entry`.
    ///
    /// If `sym_is_high` is true the high partition is used; otherwise the low
    /// one. The decoder must be given the same `entry` for this symbol.
    #[allow(dead_code)]
    pub fn encode_symbol(&mut self, sym_is_high: bool, entry: &ProbTableEntry) {
        let (cum, width) = if sym_is_high {
            (entry.cum_prob_low + entry.width_low, entry.width_high)
        } else {
            (entry.cum_prob_low, entry.width_low)
        };

        let r = self.range >> TABLE_PROB_BITS;
        // Add in 64 bits so a carry out of the 32-bit window is detected and
        // applied to the bytes already emitted, instead of being lost.
        let sum = u64::from(self.low) + u64::from(r.saturating_mul(cum));
        if sum > u64::from(u32::MAX) {
            self.propagate_carry();
        }
        // Keep the in-window part; the carry (if any) is now in `output`.
        self.low = sum as u32;
        self.range = r.saturating_mul(width).max(1);

        while self.range < Self::RANGE_MIN {
            self.output.push((self.low >> 24) as u8);
            self.low <<= 8;
            self.range <<= 8;
        }
    }

    /// Adds one to the big-endian number stored in `output`, turning trailing
    /// `0xFF` bytes into `0x00` until a byte absorbs the carry.
    fn propagate_carry(&mut self) {
        for byte in self.output.iter_mut().rev() {
            if *byte == 0xFF {
                *byte = 0;
            } else {
                *byte += 1;
                return;
            }
        }
    }

    /// Finishes encoding: writes the four bytes of `low` and returns the stream.
    #[allow(dead_code)]
    pub fn flush(mut self) -> Vec<u8> {
        for _ in 0..4 {
            self.output.push((self.low >> 24) as u8);
            self.low <<= 8;
        }
        self.output
    }

    /// Number of bytes written so far (excluding the bytes `flush` appends).
    #[allow(dead_code)]
    pub fn bytes_emitted(&self) -> usize {
        self.output.len()
    }
}

/// Decoder matching [`TableArithmeticCoder`].
///
/// `code` is the offset of the coded value from the encoder's `low`, taken
/// modulo 2^32.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TableArithmeticDecoder<'a> {
    /// Encoded stream.
    data: &'a [u8],
    /// Index of the next byte to read.
    pos: usize,
    /// Offset of the coded value inside the current interval.
    code: u32,
    /// Width of the current interval (mirrors the encoder).
    range: u32,
}

impl<'a> TableArithmeticDecoder<'a> {
    /// Creates a decoder over `data`, priming `code` with the first four bytes.
    #[allow(dead_code)]
    pub fn new(data: &'a [u8]) -> Self {
        let mut dec = Self {
            data,
            pos: 0,
            code: 0,
            range: 0xFFFF_FF00,
        };
        for _ in 0..4 {
            dec.code = (dec.code << 8) | u32::from(dec.read_byte());
        }
        dec
    }

    /// Returns the next input byte, or `0xFF` once the data is exhausted.
    fn read_byte(&mut self) -> u8 {
        if self.pos < self.data.len() {
            let b = self.data[self.pos];
            self.pos += 1;
            b
        } else {
            0xFF
        }
    }

    /// Decodes one binary symbol; `entry` must match the one used to encode it.
    ///
    /// Returns `true` for the high partition and `false` for the low one.
    #[allow(dead_code)]
    pub fn decode_symbol(&mut self, entry: &ProbTableEntry) -> bool {
        let r = self.range >> TABLE_PROB_BITS;
        let split = r.saturating_mul(entry.cum_prob_low + entry.width_low);
        let is_high = self.code >= split;

        if is_high {
            self.code = self.code.wrapping_sub(split);
            self.range = r.saturating_mul(entry.width_high).max(1);
        } else {
            // The encoder also raised `low` by `r * cum_prob_low` for a low
            // symbol; mirror that so `code` stays relative to `low`.
            self.code = self
                .code
                .wrapping_sub(r.saturating_mul(entry.cum_prob_low));
            self.range = r.saturating_mul(entry.width_low).max(1);
        }

        while self.range < TableArithmeticCoder::RANGE_MIN {
            self.code = (self.code << 8) | u32::from(self.read_byte());
            self.range <<= 8;
        }

        is_high
    }
}
509
510#[derive(Clone, Debug)]
519pub struct CabacContext {
520 pub state: u8,
522 pub mps: bool,
524}
525
526impl CabacContext {
527 pub fn new() -> Self {
529 Self {
530 state: 64, mps: false,
532 }
533 }
534
535 pub fn with_state(init_state: u8, mps: bool) -> Self {
540 Self {
541 state: init_state.min(127).max(1),
542 mps,
543 }
544 }
545
546 pub fn update(&mut self, bin: bool) {
552 if bin == self.mps {
553 self.state = self.state.saturating_add(((127 - self.state) >> 3).max(1));
555 if self.state > 127 {
556 self.state = 127;
557 }
558 } else {
559 if self.state <= 1 {
561 self.mps = !self.mps;
563 self.state = 2;
564 } else {
565 self.state = self.state.saturating_sub((self.state >> 3).max(1));
566 }
567 }
568 }
569
570 pub fn mps_probability(&self) -> f64 {
572 self.state as f64 / 128.0
573 }
574}
575
576#[derive(Clone, Debug)]
578pub struct CabacEncoder {
579 pub contexts: Vec<CabacContext>,
581 pub coder: ArithmeticCoder,
583 pub bins_encoded: u64,
585}
586
587impl CabacEncoder {
588 pub fn new(num_contexts: usize) -> Self {
590 Self {
591 contexts: (0..num_contexts).map(|_| CabacContext::new()).collect(),
592 coder: ArithmeticCoder::new(),
593 bins_encoded: 0,
594 }
595 }
596
597 pub fn encode_bin(&mut self, ctx_id: usize, bin: bool) -> Vec<u8> {
601 let ctx = if ctx_id < self.contexts.len() {
602 &self.contexts[ctx_id]
603 } else {
604 return self.coder.encode_bit(0.5, bin);
606 };
607
608 let prob_one = if ctx.mps {
609 ctx.mps_probability()
610 } else {
611 1.0 - ctx.mps_probability()
612 };
613
614 let bytes = self.coder.encode_bit(prob_one as f32, bin);
615
616 if ctx_id < self.contexts.len() {
618 self.contexts[ctx_id].update(bin);
619 }
620 self.bins_encoded += 1;
621 bytes
622 }
623
624 pub fn encode_bypass(&mut self, bin: bool) -> Vec<u8> {
626 self.bins_encoded += 1;
627 self.coder.encode_bit(0.5, bin)
628 }
629}
630
/// Multi-symbol range encoder with LZMA-style carry handling.
///
/// Each symbol is given as `(cum_freq, sym_freq, total_freq)`. The interval
/// `[low, low + range)` lives in a 32-bit window. A carry out of that window
/// is resolved through `cache` and a run of pending `0xFF` bytes, so the
/// output is the exact big-endian code value. The placeholder zero byte that
/// an LZMA-style encoder writes first is omitted.
#[derive(Clone, Debug)]
pub struct RangeEncoder {
    /// Low end of the interval; bit 32 holds a carry not yet applied to `cache`.
    low: u64,
    /// Interval width; in `[TOP, 2^32)` between calls.
    range: u64,
    /// Finalised output bytes.
    output: Vec<u8>,
    /// Last byte shifted out of `low`, held back because a carry may still
    /// change it.
    cache: u8,
    /// Number of `0xFF` bytes queued after `cache` (a carry turns them into `0x00`).
    pending_ff: u64,
    /// `true` until the leading placeholder byte has been discarded.
    first_byte: bool,
}

impl RangeEncoder {
    /// Renormalisation threshold: bytes are shifted out while `range < TOP`.
    /// It is also the largest supported `total_freq`, which keeps
    /// `range / total_freq >= 1`.
    const TOP: u64 = 1 << 24;

    /// Creates an encoder covering the full 32-bit interval.
    pub fn new() -> Self {
        Self {
            low: 0,
            range: u64::from(u32::MAX),
            output: Vec::new(),
            cache: 0,
            pending_ff: 0,
            first_byte: true,
        }
    }

    /// Narrows the interval to the symbol occupying
    /// `[cum_freq, cum_freq + sym_freq)` out of `total_freq`.
    ///
    /// # Panics
    ///
    /// - if `total_freq` is 0 or greater than `2^24`;
    /// - if `sym_freq` is 0;
    /// - if `cum_freq + sym_freq > total_freq`.
    ///
    /// These inputs would otherwise divide by zero, collapse the range to 0
    /// (an endless renormalisation loop), or corrupt the stream.
    pub fn encode(&mut self, cum_freq: u64, sym_freq: u64, total_freq: u64) {
        assert!(
            total_freq > 0 && total_freq <= Self::TOP,
            "total_freq must be in 1..=2^24"
        );
        assert!(sym_freq > 0, "sym_freq must be non-zero");
        assert!(
            matches!(cum_freq.checked_add(sym_freq), Some(end) if end <= total_freq),
            "symbol interval exceeds total_freq"
        );

        let r = self.range / total_freq;
        self.low += r * cum_freq;
        self.range = r * sym_freq;
        self.renormalize();
    }

    /// Shifts bytes out until `range` is back in `[TOP, 2^32)`.
    fn renormalize(&mut self) {
        while self.range < Self::TOP {
            self.range <<= 8;
            self.shift_low();
        }
    }

    /// Moves the top byte of the 32-bit `low` window into the cache/output.
    fn shift_low(&mut self) {
        if self.low < 0xFF00_0000 || self.low > 0xFFFF_FFFF {
            // Either no future carry can reach the held bytes, or a carry just
            // arrived: resolve `cache` and the queued 0xFF run with it.
            let carry = (self.low >> 32) as u8;
            let cached = self.cache;
            self.emit(cached.wrapping_add(carry));
            for _ in 0..self.pending_ff {
                self.emit(0xFF_u8.wrapping_add(carry));
            }
            self.pending_ff = 0;
            self.cache = ((self.low >> 24) & 0xFF) as u8;
        } else {
            // Top byte is 0xFF: a later carry could still turn it into 0x00.
            self.pending_ff += 1;
        }
        self.low = (self.low & 0x00FF_FFFF) << 8;
    }

    /// Appends a resolved byte, dropping the always-zero leading placeholder.
    fn emit(&mut self, byte: u8) {
        if self.first_byte {
            self.first_byte = false;
        } else {
            self.output.push(byte);
        }
    }

    /// Finishes encoding and returns the stream.
    ///
    /// Five shifts push out the cached byte, any pending 0xFF run, and all
    /// four bytes of `low`.
    pub fn flush(mut self) -> Vec<u8> {
        for _ in 0..5 {
            self.shift_low();
        }
        self.output
    }

    /// Number of bytes finalised so far (bytes still held back for a possible
    /// carry are not counted).
    pub fn bytes_emitted(&self) -> usize {
        self.output.len()
    }
}
719
/// Computes length-limited Huffman code lengths for every symbol with a
/// non-zero frequency (`freqs[i]` is the frequency of symbol `i`).
///
/// Returns `(symbol_index, length)` pairs sorted by symbol index. A lone
/// symbol gets length 1; an input with no symbols gives an empty vector.
///
/// `max_length` is clamped to `1..=30`. If more symbols are present than a
/// prefix code of that depth can hold (`n > 2^max_length`), the limit is
/// raised to the smallest feasible depth. The result therefore always
/// satisfies the Kraft inequality, so a prefix code with these lengths exists.
///
/// When limiting is needed, the following steps are taken:
/// 1. Over-long codes are clamped to the limit.
/// 2. Codes are lengthened one bit at a time until the Kraft inequality holds,
///    always picking the deepest code still below the limit (least frequent
///    among equals).
///
/// This is a greedy heuristic, not the package-merge optimum.
pub fn optimal_code_lengths(freqs: &[u32], max_length: u8) -> Vec<(usize, u8)> {
    let max_length = max_length.max(1).min(30);

    let symbols: Vec<(usize, u32)> = freqs
        .iter()
        .enumerate()
        .filter(|&(_, &f)| f > 0)
        .map(|(i, &f)| (i, f))
        .collect();

    if symbols.is_empty() {
        return vec![];
    }
    if symbols.len() == 1 {
        return vec![(symbols[0].0, 1)];
    }

    let weights: Vec<u32> = symbols.iter().map(|&(_, f)| f).collect();
    let depths = huffman_code_depths(&weights);
    let lengths = limit_code_lengths(&depths, &weights, u32::from(max_length));

    // `symbols` is already in ascending index order.
    symbols
        .iter()
        .zip(lengths)
        .map(|(&(sym, _), len)| (sym, len as u8))
        .collect()
}

/// Returns the leaf depth of every weight in a Huffman tree over `weights`
/// (which must hold at least one entry).
///
/// The two lightest nodes are merged first. Ties go to the node created
/// earliest: leaves in input order, then merged nodes in creation order.
/// Weights are summed in `u64`, so they cannot overflow.
fn huffman_code_depths(weights: &[u32]) -> Vec<u32> {
    let leaves = weights.len();
    let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<(u64, usize)>> = weights
        .iter()
        .enumerate()
        .map(|(id, &w)| std::cmp::Reverse((u64::from(w), id)))
        .collect();

    let mut parent = vec![0usize; 2 * leaves - 1];
    let mut next_id = leaves;
    while heap.len() > 1 {
        let std::cmp::Reverse((wa, a)) = heap.pop().expect("heap holds two nodes");
        let std::cmp::Reverse((wb, b)) = heap.pop().expect("heap holds two nodes");
        parent[a] = next_id;
        parent[b] = next_id;
        heap.push(std::cmp::Reverse((wa + wb, next_id)));
        next_id += 1;
    }

    // Every parent has a larger id than its children, so walking ids downward
    // from the root sets each parent's depth before its children's.
    let root = next_id - 1;
    let mut depth = vec![0u32; next_id];
    for id in (0..root).rev() {
        depth[id] = depth[parent[id]] + 1;
    }
    depth.truncate(leaves);
    depth
}

/// Clamps `depths` to a length limit and lengthens codes until the Kraft
/// inequality holds.
///
/// The limit starts at `max_length` and is raised if the symbol count would
/// make it infeasible. `weights` is used only to break ties among equally
/// deep candidates (the lighter one is lengthened first).
fn limit_code_lengths(depths: &[u32], weights: &[u32], max_length: u32) -> Vec<u32> {
    let count = depths.len() as u64;
    let mut limit = max_length;
    while (1u64 << limit) < count {
        limit += 1;
    }

    let mut lengths: Vec<u32> = depths.iter().map(|&d| d.min(limit)).collect();

    // Kraft sum scaled by 2^limit: a prefix code exists iff it is <= 2^limit.
    let capacity = 1u64 << limit;
    let mut kraft: u64 = lengths.iter().map(|&l| 1u64 << (limit - l)).sum();
    while kraft > capacity {
        // If every code were at `limit`, kraft == count <= capacity, so some
        // code below the limit must exist here.
        let idx = (0..lengths.len())
            .filter(|&i| lengths[i] < limit)
            .max_by_key(|&i| (lengths[i], std::cmp::Reverse(weights[i])))
            .expect("a code below the limit exists while the Kraft sum overflows");
        kraft -= 1u64 << (limit - lengths[idx] - 1);
        lengths[idx] += 1;
    }
    lengths
}
797
/// Estimates how many bits an ideal order-0 entropy coder would need for
/// `symbols`.
///
/// The result is `len * H`, where `H` is the empirical Shannon entropy of the
/// block in bits per symbol. An empty block costs `0.0` bits.
pub fn estimate_block_entropy(symbols: &[u8]) -> f64 {
    if symbols.is_empty() {
        return 0.0;
    }

    let mut histogram = [0u32; 256];
    symbols
        .iter()
        .for_each(|&byte| histogram[usize::from(byte)] += 1);

    let len = symbols.len() as f64;
    let bits_per_symbol = histogram
        .iter()
        .filter(|&&count| count != 0)
        .fold(0.0_f64, |acc, &count| {
            let p = f64::from(count) / len;
            acc - p * p.log2()
        });

    bits_per_symbol * len
}
828
/// Shannon entropy, in bits per symbol, of the distribution described by
/// `freqs`.
///
/// Returns `0.0` when `freqs` is empty or every count is zero.
pub fn estimate_entropy_from_freqs(freqs: &[u32]) -> f64 {
    let total = freqs.iter().copied().map(u64::from).sum::<u64>();
    if total == 0 {
        return 0.0;
    }

    let total = total as f64;
    freqs
        .iter()
        .filter(|&&f| f != 0)
        .fold(0.0_f64, |acc, &f| {
            let p = f64::from(f) / total;
            acc - p * p.log2()
        })
}

/// Returns `true` when coding `symbol_count` symbols under distribution
/// `freqs_a` costs no more bits (by entropy estimate) than under `freqs_b`.
pub fn compare_coding_strategies(freqs_a: &[u32], freqs_b: &[u32], symbol_count: u64) -> bool {
    let scale = symbol_count as f64;
    let estimated_bits = |freqs: &[u32]| estimate_entropy_from_freqs(freqs) * scale;
    estimated_bits(freqs_a) <= estimated_bits(freqs_b)
}
856
/// Sliding-window symbol histogram over the most recent `capacity` observations.
///
/// Observations live in a ring buffer. Once the buffer is full, each new
/// symbol evicts the oldest one from the counts.
#[derive(Clone, Debug)]
pub struct AdaptiveFrequencyTracker {
    /// Ring buffer of observed symbols.
    window: Vec<u8>,
    /// Slot the next observation will be written to.
    pos: usize,
    /// Observations currently held (never more than `capacity`).
    count: usize,
    /// Window length (at least 1).
    capacity: usize,
    /// Per-symbol counts over the current window.
    freq: [u32; 256],
}

impl AdaptiveFrequencyTracker {
    /// Creates an empty tracker; a `window_size` of 0 is treated as 1.
    pub fn new(window_size: usize) -> Self {
        let capacity = window_size.max(1);
        Self {
            window: vec![0; capacity],
            pos: 0,
            count: 0,
            capacity,
            freq: [0; 256],
        }
    }

    /// Records `symbol`, first evicting the oldest observation if the window
    /// is full.
    pub fn observe(&mut self, symbol: u8) {
        let slot = self.pos;
        if self.count < self.capacity {
            self.count += 1;
        } else {
            let evicted = usize::from(self.window[slot]);
            self.freq[evicted] = self.freq[evicted].saturating_sub(1);
        }
        self.window[slot] = symbol;
        self.freq[usize::from(symbol)] += 1;

        let next = slot + 1;
        self.pos = if next == self.capacity { 0 } else { next };
    }

    /// Occurrences of `symbol` in the current window.
    pub fn frequency(&self, symbol: u8) -> u32 {
        self.freq[usize::from(symbol)]
    }

    /// Number of observations currently in the window.
    pub fn total(&self) -> usize {
        self.count
    }

    /// Relative frequency of `symbol` in the window, or `0.0` when empty.
    pub fn probability(&self, symbol: u8) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            f64::from(self.frequency(symbol)) / self.count as f64
        }
    }

    /// Copy of the full 256-entry count table.
    pub fn frequency_table(&self) -> [u32; 256] {
        self.freq
    }

    /// Forgets all observations.
    pub fn reset(&mut self) {
        for slot in self.window.iter_mut() {
            *slot = 0;
        }
        self.freq = [0; 256];
        self.count = 0;
        self.pos = 0;
    }
}
939
#[cfg(test)]
mod tests {
    use super::*;

    // ----- ArithmeticCoder -----

    #[test]
    fn arithmetic_coder_new_initial_range() {
        let coder = ArithmeticCoder::new();
        assert_eq!(coder.low, 0);
        assert_eq!(coder.high, 0xFFFF_FFFF);
        assert_eq!(coder.get_range(), 0x1_0000_0000u64);
    }

    #[test]
    fn arithmetic_coder_get_range() {
        let c = ArithmeticCoder::new();
        let initial_range = c.get_range();
        assert!(initial_range > 0);
        assert_eq!(initial_range, 0x1_0000_0000u64);
        let mut c2 = ArithmeticCoder::new();
        c2.encode_bit(0.9, true);
        assert!(c2.get_range() > 0);
        assert!(c2.low <= c2.high);
    }

    #[test]
    fn arithmetic_coder_encode_bit_does_not_panic() {
        let mut c = ArithmeticCoder::new();
        let _bytes = c.encode_bit(0.5, true);
        let _bytes = c.encode_bit(0.5, false);
        let _bytes = c.encode_bit(0.9, true);
    }

    #[test]
    fn arithmetic_coder_bits_to_follow_increments() {
        let mut c = ArithmeticCoder::new();
        for _ in 0..16 {
            c.encode_bit(0.5, true);
        }
        assert!(c.low <= c.high);
    }

    #[test]
    fn arithmetic_coder_encode_sequence_returns_bytes() {
        let mut c = ArithmeticCoder::new();
        let mut all_bytes = Vec::new();
        for _ in 0..32 {
            all_bytes.extend(c.encode_bit(0.95, true));
        }
        assert!(all_bytes.len() <= 32 * 2);
    }

    // ----- bits_to_bytes -----

    #[test]
    fn bits_to_bytes_empty() {
        let b = bits_to_bytes(&[]);
        assert!(b.is_empty());
    }

    #[test]
    fn bits_to_bytes_full_byte() {
        let bits = [true, false, true, false, true, false, true, false];
        let b = bits_to_bytes(&bits);
        assert_eq!(b, vec![0xAA]);
    }

    // ----- RangeCoder -----

    #[test]
    fn range_coder_new() {
        let rc = RangeCoder::new();
        assert_eq!(rc.range, 256);
        assert_eq!(rc.code, 0);
    }

    #[test]
    fn range_coder_normalize_already_normalised() {
        let mut rc = RangeCoder::new();
        let bits = rc.normalize();
        assert_eq!(bits, 0);
    }

    #[test]
    fn range_coder_normalize_below_128() {
        let mut rc = RangeCoder { range: 32, code: 0 };
        let bits = rc.normalize();
        assert!(rc.range >= 128);
        assert_eq!(bits, 2);
    }

    #[test]
    fn range_coder_decode_symbol_high_partition() {
        let mut rc = RangeCoder {
            range: 256,
            code: 200,
        };
        let sym = rc.decode_symbol(128);
        assert!(sym);
        assert_eq!(rc.range, 256 - 128);
        assert_eq!(rc.code, 200 - 128);
    }

    #[test]
    fn range_coder_decode_symbol_low_partition() {
        let mut rc = RangeCoder {
            range: 256,
            code: 50,
        };
        let sym = rc.decode_symbol(128);
        assert!(!sym);
        assert_eq!(rc.range, 128);
        assert_eq!(rc.code, 50);
    }

    // ----- Huffman -----

    #[test]
    fn huffman_node_is_leaf_true() {
        let leaf = HuffmanNode {
            symbol: Some(42),
            freq: 10,
            left: None,
            right: None,
        };
        assert!(leaf.is_leaf());
    }

    #[test]
    fn huffman_node_is_leaf_false() {
        let inner = HuffmanNode {
            symbol: None,
            freq: 20,
            left: Some(Box::new(HuffmanNode {
                symbol: Some(0),
                freq: 10,
                left: None,
                right: None,
            })),
            right: None,
        };
        assert!(!inner.is_leaf());
    }

    #[test]
    fn build_huffman_tree_two_symbols() {
        let freqs = [10u32, 20];
        let tree = build_huffman_tree(&freqs);
        assert!(!tree.is_leaf());
        assert_eq!(tree.freq, 30);
        let codes = compute_huffman_codes(&tree, vec![]);
        assert_eq!(codes.len(), 2);
    }

    #[test]
    fn build_huffman_tree_multiple_symbols() {
        let freqs = [5u32, 9, 12, 13, 16, 45];
        let tree = build_huffman_tree(&freqs);
        let codes = compute_huffman_codes(&tree, vec![]);
        assert_eq!(codes.len(), 6);
        let mut code_map = std::collections::HashMap::new();
        for (sym, code) in &codes {
            code_map.insert(*sym, code.len());
        }
        // The most frequent symbol never gets a longer code than the rarest.
        assert!(code_map[&5] <= code_map[&0]);
    }

    #[test]
    fn build_huffman_tree_empty_freqs() {
        let tree = build_huffman_tree(&[]);
        assert!(tree.is_leaf());
        assert_eq!(tree.symbol, Some(0));
    }

    #[test]
    fn build_huffman_tree_single_symbol() {
        let freqs = [0u32, 7, 0];
        let tree = build_huffman_tree(&freqs);
        assert!(!tree.is_leaf());
        let codes = compute_huffman_codes(&tree, vec![]);
        assert_eq!(codes.len(), 1);
        assert_eq!(codes[0].0, 1);
    }

    #[test]
    fn compute_huffman_codes_all_unique() {
        let freqs = [1u32, 2, 4, 8];
        let tree = build_huffman_tree(&freqs);
        let codes = compute_huffman_codes(&tree, vec![]);
        let symbols: Vec<u8> = codes.iter().map(|(s, _)| *s).collect();
        let mut sorted = symbols.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), symbols.len());
    }

    // ----- Table arithmetic coder -----

    #[test]
    fn table_coder_build_prob_table_basic() {
        let freqs = [10u32, 30, 20, 5];
        let table = build_prob_table(&freqs);
        assert_eq!(table.len(), 4);
        for entry in &table {
            assert!(entry.cum_prob_low <= TABLE_SIZE as u32);
        }
    }

    #[test]
    fn table_coder_build_prob_table_empty() {
        let table = build_prob_table(&[]);
        assert!(table.is_empty());
    }

    #[test]
    fn table_coder_encode_produces_bytes() {
        let freqs = [128u32, 128u32];
        let table = build_prob_table(&freqs);
        let mut enc = TableArithmeticCoder::new();
        for _ in 0..32 {
            enc.encode_symbol(false, &table[0]);
        }
        let data = enc.flush();
        assert!(!data.is_empty());
    }

    #[test]
    fn table_coder_encode_decode_roundtrip() {
        let freqs = [128u32, 128u32];
        let table = build_prob_table(&freqs);
        let symbols: Vec<bool> = vec![false, false, true, false, true, true, false];

        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[0]);
        }
        let data = enc.flush();

        let mut dec = TableArithmeticDecoder::new(&data);
        let mut decoded = Vec::new();
        for _ in 0..symbols.len() {
            decoded.push(dec.decode_symbol(&table[0]));
        }

        assert_eq!(
            decoded, symbols,
            "Round-trip must reproduce the original symbols"
        );
    }

    #[test]
    fn table_coder_bytes_emitted_before_flush() {
        let freqs = [1u32, 255u32];
        let table = build_prob_table(&freqs);
        let mut enc = TableArithmeticCoder::new();
        for _ in 0..100 {
            enc.encode_symbol(true, &table[1]);
        }
        let mid_count = enc.bytes_emitted();
        let data = enc.flush();
        assert!(data.len() >= mid_count);
    }

    #[test]
    fn table_coder_all_high_partition() {
        let freqs = [50u32, 206u32];
        let table = build_prob_table(&freqs);
        let symbols = vec![true; 20];

        // Encoder and decoder must use the same entry; table[0] splits the
        // scale 50 (low) / 206 (high).
        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[0]);
        }
        let data = enc.flush();

        let mut dec = TableArithmeticDecoder::new(&data);
        for _ in 0..symbols.len() {
            let sym = dec.decode_symbol(&table[0]);
            assert!(sym, "should decode as high partition");
        }
    }

    #[test]
    fn table_coder_all_low_partition() {
        let freqs = [200u32, 56u32];
        let table = build_prob_table(&freqs);
        let symbols = vec![false; 20];

        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[0]);
        }
        let data = enc.flush();

        let mut dec = TableArithmeticDecoder::new(&data);
        for _ in 0..symbols.len() {
            let sym = dec.decode_symbol(&table[0]);
            assert!(!sym, "should decode as low partition");
        }
    }

    // ----- CABAC -----

    #[test]
    fn cabac_context_initial_equi_probable() {
        let ctx = CabacContext::new();
        let p = ctx.mps_probability();
        assert!((p - 0.5).abs() < 0.01);
    }

    #[test]
    fn cabac_context_adapts_towards_mps() {
        let mut ctx = CabacContext::new();
        for _ in 0..20 {
            ctx.update(ctx.mps);
        }
        assert!(
            ctx.mps_probability() > 0.7,
            "should converge towards high confidence"
        );
    }

    #[test]
    fn cabac_context_adapts_towards_lps() {
        let mut ctx = CabacContext::new();
        let lps = !ctx.mps;
        for _ in 0..30 {
            ctx.update(lps);
        }
        assert!(ctx.mps == lps || ctx.state <= 10);
    }

    #[test]
    fn cabac_context_with_biased_state() {
        let ctx = CabacContext::with_state(120, true);
        assert!(ctx.mps_probability() > 0.9);
        assert!(ctx.mps);
    }

    #[test]
    fn cabac_encoder_basic() {
        let mut enc = CabacEncoder::new(4);
        let mut bytes = Vec::new();
        for i in 0..16 {
            bytes.extend(enc.encode_bin(i % 4, i % 2 == 0));
        }
        assert_eq!(enc.bins_encoded, 16);
    }

    #[test]
    fn cabac_encoder_bypass_mode() {
        let mut enc = CabacEncoder::new(1);
        let _bytes = enc.encode_bypass(true);
        assert_eq!(enc.bins_encoded, 1);
        // Bypass coding must not adapt any context.
        let p = enc.contexts[0].mps_probability();
        assert!((p - 0.5).abs() < 0.01);
    }

    // ----- RangeEncoder -----

    #[test]
    fn range_encoder_encode_flush() {
        let mut enc = RangeEncoder::new();
        enc.encode(0, 50, 100);
        enc.encode(50, 50, 100);
        let data = enc.flush();
        assert!(!data.is_empty());
    }

    #[test]
    fn range_encoder_bytes_emitted() {
        let mut enc = RangeEncoder::new();
        for _ in 0..100 {
            enc.encode(0, 128, 256);
        }
        let mid = enc.bytes_emitted();
        let data = enc.flush();
        assert!(data.len() >= mid);
    }

    // ----- Code lengths -----

    #[test]
    fn optimal_code_lengths_basic() {
        let freqs = [10u32, 20, 40, 80];
        let lengths = optimal_code_lengths(&freqs, 15);
        assert_eq!(lengths.len(), 4);
        let len_map: std::collections::HashMap<usize, u8> = lengths.iter().cloned().collect();
        assert!(len_map[&3] <= len_map[&0]);
    }

    #[test]
    fn optimal_code_lengths_max_length_respected() {
        let freqs = [1u32, 1, 1, 1, 1, 1, 1, 1, 100];
        let lengths = optimal_code_lengths(&freqs, 4);
        for (_, l) in &lengths {
            assert!(*l <= 4, "code length {} exceeds max 4", l);
        }
    }

    #[test]
    fn optimal_code_lengths_single_symbol() {
        let freqs = [0u32, 0, 42];
        let lengths = optimal_code_lengths(&freqs, 10);
        assert_eq!(lengths.len(), 1);
        assert_eq!(lengths[0], (2, 1));
    }

    #[test]
    fn optimal_code_lengths_empty() {
        let lengths = optimal_code_lengths(&[], 10);
        assert!(lengths.is_empty());
    }

    // ----- Entropy estimates -----

    #[test]
    fn estimate_block_entropy_uniform() {
        let block = vec![42u8; 100];
        let bits = estimate_block_entropy(&block);
        assert!(
            bits < 1.0,
            "uniform block entropy should be ~0, got {}",
            bits
        );
    }

    #[test]
    fn estimate_block_entropy_binary() {
        let mut block = vec![0u8; 100];
        for b in block.iter_mut().step_by(2) {
            *b = 1;
        }
        let bits = estimate_block_entropy(&block);
        let bits_per_sym = bits / 100.0;
        assert!((bits_per_sym - 1.0).abs() < 0.1);
    }

    #[test]
    fn estimate_entropy_from_freqs_uniform() {
        let freqs = vec![1u32; 256];
        let entropy = estimate_entropy_from_freqs(&freqs);
        assert!((entropy - 8.0).abs() < 0.01);
    }

    #[test]
    fn compare_coding_strategies_picks_better() {
        let a = [100u32, 1, 1, 1];
        let b = [25u32, 25, 25, 25];
        assert!(compare_coding_strategies(&a, &b, 1000));
        assert!(!compare_coding_strategies(&b, &a, 1000));
    }

    // ----- AdaptiveFrequencyTracker -----

    #[test]
    fn adaptive_tracker_basic() {
        let mut tracker = AdaptiveFrequencyTracker::new(10);
        tracker.observe(5);
        tracker.observe(5);
        tracker.observe(3);
        assert_eq!(tracker.frequency(5), 2);
        assert_eq!(tracker.frequency(3), 1);
        assert_eq!(tracker.total(), 3);
    }

    #[test]
    fn adaptive_tracker_window_eviction() {
        let mut tracker = AdaptiveFrequencyTracker::new(3);
        tracker.observe(1);
        tracker.observe(2);
        tracker.observe(3);
        assert_eq!(tracker.frequency(1), 1);

        tracker.observe(4);
        assert_eq!(tracker.frequency(1), 0);
        assert_eq!(tracker.frequency(4), 1);
        assert_eq!(tracker.total(), 3);
    }

    #[test]
    fn adaptive_tracker_probability() {
        let mut tracker = AdaptiveFrequencyTracker::new(100);
        for _ in 0..75 {
            tracker.observe(0);
        }
        for _ in 0..25 {
            tracker.observe(1);
        }
        let p0 = tracker.probability(0);
        assert!((p0 - 0.75).abs() < 0.01);
    }

    #[test]
    fn adaptive_tracker_reset() {
        let mut tracker = AdaptiveFrequencyTracker::new(10);
        tracker.observe(42);
        tracker.reset();
        assert_eq!(tracker.frequency(42), 0);
        assert_eq!(tracker.total(), 0);
    }

    #[test]
    fn adaptive_tracker_frequency_table() {
        let mut tracker = AdaptiveFrequencyTracker::new(100);
        tracker.observe(10);
        tracker.observe(10);
        tracker.observe(20);
        let table = tracker.frequency_table();
        assert_eq!(table[10], 2);
        assert_eq!(table[20], 1);
        assert_eq!(table[0], 0);
    }
}