1use crate::error::{TokenizerError, TokenizerResult};
28use std::collections::{BinaryHeap, HashMap};
29
/// A node in a Huffman tree: either a leaf carrying a symbol or an internal
/// node referencing its two children by index into the owning node arena.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HuffmanNode {
    // `Some(symbol)` for leaves, `None` for internal nodes.
    symbol: Option<u32>,
    // Total frequency of all symbols in this subtree.
    frequency: u64,
    // Arena index of the left child (bit 0 / `false`), if any.
    left: Option<usize>,
    // Arena index of the right child (bit 1 / `true`), if any.
    right: Option<usize>,
}
42
43impl Ord for HuffmanNode {
44 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
45 other.frequency.cmp(&self.frequency)
47 }
48}
49
50impl PartialOrd for HuffmanNode {
51 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
52 Some(self.cmp(other))
53 }
54}
55
/// Huffman encoder: maps symbols to variable-length bit codes derived from a
/// frequency table.
pub struct HuffmanEncoder {
    // Per-symbol bit codes; `false` = 0 bit, `true` = 1 bit.
    codebook: HashMap<u32, Vec<bool>>,
    // Arena of tree nodes; children are referenced by index into this vector.
    tree_nodes: Vec<HuffmanNode>,
    // Index of the tree root within `tree_nodes`.
    root_idx: usize,
}
68
69impl HuffmanEncoder {
    /// Builds a Huffman encoder from symbol frequencies.
    ///
    /// # Errors
    /// Returns an error if `frequencies` is empty.
    ///
    /// A single-symbol alphabet is special-cased: the symbol is assigned the
    /// one-bit code `[false]`, since a proper Huffman tree needs two leaves.
    pub fn from_frequencies(frequencies: &HashMap<u32, u64>) -> TokenizerResult<Self> {
        if frequencies.is_empty() {
            return Err(TokenizerError::encoding(
                "encoding",
                "Cannot build Huffman tree from empty frequencies",
            ));
        }

        if frequencies.len() == 1 {
            let symbol = *frequencies
                .keys()
                .next()
                .expect("Frequencies map is non-empty");
            let mut codebook = HashMap::new();
            // Degenerate tree: a single leaf with a one-bit code.
            codebook.insert(symbol, vec![false]);
            let node = HuffmanNode {
                symbol: Some(symbol),
                frequency: *frequencies
                    .get(&symbol)
                    .expect("Symbol exists in frequencies map"),
                left: None,
                right: None,
            };

            return Ok(Self {
                codebook,
                tree_nodes: vec![node],
                root_idx: 0,
            });
        }

        // Heap entry pairing a frequency with its node's arena index. Ordering
        // is reversed so the std max-heap pops the lowest frequency first; ties
        // break on the larger index, keeping construction deterministic.
        #[derive(Eq, PartialEq)]
        struct HeapEntry {
            frequency: u64,
            idx: usize,
        }

        impl Ord for HeapEntry {
            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
                other
                    .frequency
                    .cmp(&self.frequency)
                    .then_with(|| other.idx.cmp(&self.idx))
            }
        }

        impl PartialOrd for HeapEntry {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        let mut heap = BinaryHeap::new();
        let mut nodes = Vec::new();

        // One leaf node per symbol, all seeded into the heap.
        for (&symbol, &freq) in frequencies {
            let idx = nodes.len();
            nodes.push(HuffmanNode {
                symbol: Some(symbol),
                frequency: freq,
                left: None,
                right: None,
            });
            heap.push(HeapEntry {
                frequency: freq,
                idx,
            });
        }

        // Classic Huffman construction: repeatedly merge the two least-frequent
        // subtrees until a single root remains.
        while heap.len() > 1 {
            let entry1 = heap.pop().expect("Heap has at least 2 elements");
            let entry2 = heap.pop().expect("Heap has at least 2 elements");

            let combined_freq = entry1.frequency + entry2.frequency;
            let parent_idx = nodes.len();

            nodes.push(HuffmanNode {
                symbol: None,
                frequency: combined_freq,
                left: Some(entry1.idx),
                right: Some(entry2.idx),
            });

            heap.push(HeapEntry {
                frequency: combined_freq,
                idx: parent_idx,
            });
        }

        let root_idx = heap
            .pop()
            .expect("Heap has exactly 1 root element after loop")
            .idx;

        // Walk the tree iteratively (explicit stack), recording the bit path to
        // each leaf: `false` for a left edge, `true` for a right edge.
        let mut codebook = HashMap::new();
        let mut stack = vec![(root_idx, Vec::new())];

        while let Some((idx, code)) = stack.pop() {
            let node = &nodes[idx];

            if let Some(symbol) = node.symbol {
                codebook.insert(symbol, code);
            } else {
                if let Some(left_idx) = node.left {
                    let mut left_code = code.clone();
                    left_code.push(false);
                    stack.push((left_idx, left_code));
                }
                if let Some(right_idx) = node.right {
                    let mut right_code = code.clone();
                    right_code.push(true);
                    stack.push((right_idx, right_code));
                }
            }
        }

        Ok(Self {
            codebook,
            tree_nodes: nodes,
            root_idx,
        })
    }
221
222 pub fn encode(&self, symbols: &[u32]) -> TokenizerResult<Vec<u8>> {
232 let mut bits = Vec::new();
233
234 for &symbol in symbols {
236 let code = self.codebook.get(&symbol).ok_or_else(|| {
237 TokenizerError::encoding("serialization", format!("Unknown symbol: {}", symbol))
238 })?;
239 bits.extend_from_slice(code);
240 }
241
242 let num_bits = bits.len();
244 let num_bytes = num_bits.div_ceil(8);
245 let mut bytes = vec![0u8; num_bytes];
246
247 for (i, &bit) in bits.iter().enumerate() {
248 if bit {
249 bytes[i / 8] |= 1 << (7 - (i % 8));
250 }
251 }
252
253 let mut result = Vec::new();
255 result.extend_from_slice(&(symbols.len() as u32).to_le_bytes());
256 result.extend_from_slice(&(num_bits as u32).to_le_bytes());
257 result.extend_from_slice(&bytes);
258
259 Ok(result)
260 }
261
    /// Returns the symbol → bit-code table.
    pub fn codebook(&self) -> &HashMap<u32, Vec<bool>> {
        &self.codebook
    }
266
    /// Returns the node arena and root index, suitable for `HuffmanDecoder::new`.
    pub fn tree(&self) -> (&[HuffmanNode], usize) {
        (&self.tree_nodes, self.root_idx)
    }
271
272 pub fn average_code_length(&self, frequencies: &HashMap<u32, u64>) -> f64 {
274 let total: u64 = frequencies.values().sum();
275 if total == 0 {
276 return 0.0;
277 }
278
279 let mut weighted_sum = 0.0;
280 for (symbol, freq) in frequencies {
281 if let Some(code) = self.codebook.get(symbol) {
282 weighted_sum += code.len() as f64 * (*freq as f64);
283 }
284 }
285
286 weighted_sum / total as f64
287 }
288
289 pub fn entropy(frequencies: &HashMap<u32, u64>) -> f64 {
291 let total: u64 = frequencies.values().sum();
292 if total == 0 {
293 return 0.0;
294 }
295
296 let mut entropy = 0.0;
297 for freq in frequencies.values() {
298 if *freq > 0 {
299 let p = *freq as f64 / total as f64;
300 entropy -= p * p.log2();
301 }
302 }
303
304 entropy
305 }
306}
307
/// Decoder counterpart of `HuffmanEncoder`; owns a copy of the tree arena.
pub struct HuffmanDecoder {
    // Arena of tree nodes, cloned from the encoder.
    tree_nodes: Vec<HuffmanNode>,
    // Index of the tree root within `tree_nodes`.
    root_idx: usize,
}
315
316impl HuffmanDecoder {
317 pub fn new(tree: (&[HuffmanNode], usize)) -> Self {
319 Self {
320 tree_nodes: tree.0.to_vec(),
321 root_idx: tree.1,
322 }
323 }
324
    /// Decodes a byte stream produced by `HuffmanEncoder::encode`.
    ///
    /// Expected layout: 4 bytes LE symbol count, 4 bytes LE bit count, then the
    /// MSB-first packed bit payload.
    ///
    /// # Errors
    /// Fails if the header is truncated, the bitstream walks off a missing
    /// child, or fewer symbols than promised are decoded.
    pub fn decode(&self, encoded: &[u8]) -> TokenizerResult<Vec<u32>> {
        if encoded.len() < 8 {
            return Err(TokenizerError::decoding(
                "decoding",
                "Encoded data too short (missing metadata)",
            ));
        }

        let num_symbols =
            u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        let num_bits =
            u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]) as usize;

        // Unpack the payload MSB-first into individual bits, honoring num_bits
        // so trailing pad bits in the final byte are ignored.
        let bytes = &encoded[8..];
        let mut bits = Vec::with_capacity(num_bits);

        for (byte_idx, &byte) in bytes.iter().enumerate() {
            for bit_idx in 0..8 {
                if byte_idx * 8 + bit_idx >= num_bits {
                    break;
                }
                bits.push((byte & (1 << (7 - bit_idx))) != 0);
            }
        }

        let mut symbols = Vec::with_capacity(num_symbols);
        let mut current_idx = self.root_idx;

        // Degenerate single-leaf tree: every decoded symbol is the root's
        // symbol; the bitstream carries no per-symbol information.
        let root = &self.tree_nodes[self.root_idx];
        if root.left.is_none() && root.right.is_none() {
            if let Some(symbol) = root.symbol {
                for _ in 0..num_symbols {
                    symbols.push(symbol);
                }
                return Ok(symbols);
            }
        }

        // Walk the tree bit by bit: 0 = left, 1 = right. Emit on reaching a
        // leaf, then restart from the root.
        for &bit in &bits {
            let node = &self.tree_nodes[current_idx];

            current_idx = if bit {
                node.right.ok_or_else(|| {
                    TokenizerError::decoding(
                        "deserialization",
                        "Invalid bitstream: unexpected leaf",
                    )
                })?
            } else {
                node.left.ok_or_else(|| {
                    TokenizerError::decoding(
                        "deserialization",
                        "Invalid bitstream: unexpected leaf",
                    )
                })?
            };

            let current_node = &self.tree_nodes[current_idx];
            if let Some(symbol) = current_node.symbol {
                symbols.push(symbol);
                current_idx = self.root_idx;
                // Stop early once the promised count is reached (ignores any
                // residual pad bits).
                if symbols.len() == num_symbols {
                    break;
                }
            }
        }

        if symbols.len() != num_symbols {
            return Err(TokenizerError::decoding(
                "decoding",
                format!(
                    "Decoded {} symbols, expected {}",
                    symbols.len(),
                    num_symbols
                ),
            ));
        }

        Ok(symbols)
    }
423}
424
/// Simplified arithmetic coder with an optionally adaptive frequency model.
pub struct ArithmeticEncoder {
    // Per-symbol occurrence counts — the probability model.
    frequencies: HashMap<u32, u64>,
    // Sum of all counts in `frequencies`.
    total_count: u64,
    // Floor applied to counts on rescale, keeping probabilities non-zero.
    min_count: u64,
}
437
438impl ArithmeticEncoder {
439 pub fn new(alphabet_size: usize) -> Self {
445 let mut frequencies = HashMap::new();
446 for symbol in 0..alphabet_size as u32 {
447 frequencies.insert(symbol, 1);
448 }
449
450 Self {
451 frequencies,
452 total_count: alphabet_size as u64,
453 min_count: 1,
454 }
455 }
456
457 pub fn from_frequencies(frequencies: HashMap<u32, u64>) -> Self {
459 let total_count = frequencies.values().sum();
460 Self {
461 frequencies,
462 total_count,
463 min_count: 1,
464 }
465 }
466
467 fn update_frequency(&mut self, symbol: u32) {
469 *self.frequencies.entry(symbol).or_insert(self.min_count) += 1;
470 self.total_count += 1;
471
472 if self.total_count > 1_000_000 {
474 self.rescale_frequencies();
475 }
476 }
477
478 fn rescale_frequencies(&mut self) {
480 self.total_count = 0;
481 for freq in self.frequencies.values_mut() {
482 *freq = (*freq / 2).max(self.min_count);
483 self.total_count += *freq;
484 }
485 }
486
    /// Returns the half-open cumulative range `[cum_low, cum_high)` of
    /// `symbol`, scanning symbols `0..symbol` in numeric order.
    ///
    /// NOTE(review): symbols absent from the table contribute 0 to the prefix
    /// sum, yet the target symbol itself falls back to `min_count` — so an
    /// unknown symbol gets a non-empty range that `total_count` does not
    /// account for. Confirm whether encoding unknown symbols is intended.
    fn cumulative_frequency(&self, symbol: u32) -> (u64, u64) {
        let mut cumulative = 0u64;

        for s in 0..symbol {
            cumulative += self.frequencies.get(&s).unwrap_or(&0);
        }

        let freq = self.frequencies.get(&symbol).unwrap_or(&self.min_count);
        (cumulative, cumulative + freq)
    }
498
    /// Encodes `symbols` with a simplified arithmetic coder.
    ///
    /// When `adaptive` is true the frequency model is updated after each
    /// symbol (mutating `self`), so the decoder must replay the same updates.
    ///
    /// Output layout: 4 bytes LE symbol count, then the 8-byte LE code value
    /// (midpoint of the final interval).
    ///
    /// NOTE(review): there is no renormalization — the [low, high] interval
    /// only shrinks as symbols are encoded, so precision is exhausted for
    /// longer inputs; confirm the intended maximum sequence length.
    pub fn encode(&mut self, symbols: &[u32], adaptive: bool) -> TokenizerResult<Vec<u8>> {
        // 32-bit fixed-point interval width.
        const PRECISION: u64 = 1u64 << 32;
        let mut low = 0u64;
        let mut high = PRECISION - 1;

        for &symbol in symbols {
            let range = high - low + 1;
            let (cum_low, cum_high) = self.cumulative_frequency(symbol);

            // Narrow the interval to this symbol's probability slice.
            high = low + (range * cum_high / self.total_count) - 1;
            low += range * cum_low / self.total_count;

            if adaptive {
                self.update_frequency(symbol);
            }

        }

        // Any value inside [low, high] identifies the sequence; emit midpoint.
        let value = (low + high) / 2;

        let mut result = Vec::new();
        result.extend_from_slice(&(symbols.len() as u32).to_le_bytes());
        result.extend_from_slice(&value.to_le_bytes());

        Ok(result)
    }
541
    /// Returns the current frequency model (useful for seeding a decoder).
    pub fn frequencies(&self) -> &HashMap<u32, u64> {
        &self.frequencies
    }
546}
547
/// Decoder counterpart of `ArithmeticEncoder` over a *static* frequency table.
pub struct ArithmeticDecoder {
    // Per-symbol counts; must match the table the encoder used.
    frequencies: HashMap<u32, u64>,
    // Sum of all counts in `frequencies`.
    total_count: u64,
    // Symbols sorted ascending, fixing the scan order for cumulative ranges.
    alphabet: Vec<u32>,
}
557
558impl ArithmeticDecoder {
559 pub fn new(frequencies: HashMap<u32, u64>) -> Self {
561 let total_count = frequencies.values().sum();
562 let mut alphabet: Vec<u32> = frequencies.keys().copied().collect();
563 alphabet.sort_unstable();
564
565 Self {
566 frequencies,
567 total_count,
568 alphabet,
569 }
570 }
571
    /// Decodes a stream produced by `ArithmeticEncoder::encode` with the
    /// *same, non-adaptive* frequency table.
    ///
    /// Expected layout: 4 bytes LE symbol count, then the 8-byte LE code value.
    ///
    /// NOTE(review): mirrors the encoder's simplified scheme (no
    /// renormalization), so it shares the same precision limits on long inputs.
    pub fn decode(&self, encoded: &[u8]) -> TokenizerResult<Vec<u32>> {
        if encoded.len() < 12 {
            return Err(TokenizerError::decoding(
                "decoding",
                "Encoded data too short",
            ));
        }

        let num_symbols =
            u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        let value = u64::from_le_bytes([
            encoded[4],
            encoded[5],
            encoded[6],
            encoded[7],
            encoded[8],
            encoded[9],
            encoded[10],
            encoded[11],
        ]);

        // Same 32-bit fixed-point interval width as the encoder.
        const PRECISION: u64 = 1u64 << 32;
        let mut symbols = Vec::with_capacity(num_symbols);
        let mut low = 0u64;
        let mut high = PRECISION - 1;
        let code_value = value;

        for _ in 0..num_symbols {
            let range = high - low + 1;

            // Project the code value back into cumulative-frequency space.
            let scaled = ((code_value - low + 1) * self.total_count - 1) / range;

            // Linear scan of the sorted alphabet for the slice holding `scaled`.
            let mut cumulative = 0u64;
            let mut found_symbol = None;

            for &symbol in &self.alphabet {
                let freq = self.frequencies.get(&symbol).unwrap_or(&0);
                if scaled >= cumulative && scaled < cumulative + freq {
                    found_symbol = Some(symbol);
                    break;
                }
                cumulative += freq;
            }

            let symbol = found_symbol.ok_or_else(|| {
                TokenizerError::decoding(
                    "decoding",
                    format!("Cannot decode symbol at position {}", symbols.len()),
                )
            })?;

            symbols.push(symbol);

            // Narrow the interval exactly as the encoder did for this symbol.
            let (cum_low, cum_high) = self.cumulative_frequency(symbol);
            high = low + (range * cum_high / self.total_count) - 1;
            low += range * cum_low / self.total_count;
        }

        Ok(symbols)
    }
635
636 fn cumulative_frequency(&self, symbol: u32) -> (u64, u64) {
637 let mut cumulative = 0u64;
638
639 for s in &self.alphabet {
640 if *s >= symbol {
641 break;
642 }
643 cumulative += self.frequencies.get(s).unwrap_or(&0);
644 }
645
646 let freq = self.frequencies.get(&symbol).unwrap_or(&0);
647 (cumulative, cumulative + freq)
648 }
649}
650
/// Counts how many times each symbol occurs in `symbols`.
pub fn compute_frequencies(symbols: &[u32]) -> HashMap<u32, u64> {
    symbols.iter().fold(HashMap::new(), |mut counts, &symbol| {
        *counts.entry(symbol).or_insert(0) += 1;
        counts
    })
}
659
/// Range coder over a static frequency table.
pub struct RangeEncoder {
    // Raw symbol counts.
    frequencies: HashMap<u32, u64>,
    // Sum of all counts.
    total_count: u64,
    // `(symbol, cum_low, cum_high)` over raw counts, ascending by symbol;
    // zero-frequency symbols are omitted.
    cumulative: Vec<(u32, u64, u64)>,
}
672
673impl RangeEncoder {
674 pub fn from_frequencies(frequencies: HashMap<u32, u64>) -> TokenizerResult<Self> {
676 if frequencies.is_empty() {
677 return Err(TokenizerError::encoding(
678 "encoding",
679 "Cannot create range encoder from empty frequencies",
680 ));
681 }
682
683 let total_count: u64 = frequencies.values().sum();
684
685 let mut symbols: Vec<u32> = frequencies.keys().copied().collect();
687 symbols.sort_unstable();
688
689 let mut cumulative = Vec::new();
690 let mut cum_freq = 0u64;
691
692 for symbol in symbols {
693 let freq = frequencies.get(&symbol).unwrap_or(&0);
694 if *freq > 0 {
695 cumulative.push((symbol, cum_freq, cum_freq + freq));
696 cum_freq += freq;
697 }
698 }
699
700 Ok(Self {
701 frequencies,
702 total_count,
703 cumulative,
704 })
705 }
706
    /// Range-encodes `symbols`.
    ///
    /// Cumulative frequencies are first rescaled to a fixed 2^14 total (every
    /// symbol keeps a range of at least 1). The coder then narrows a 32-bit
    /// interval per symbol and emits the settled top byte whenever the range
    /// drops below 2^24, flushing 4 final bytes at the end.
    /// Output layout: 4 bytes LE symbol count, then the code bytes.
    ///
    /// NOTE(review): `low` is masked to 32 bits after shifting, but a carry
    /// out of the top byte is never propagated into already-emitted bytes —
    /// the classic range-coder carry case. The `#[ignore]`d round-trip tests
    /// suggest this is a known limitation; confirm before long inputs.
    ///
    /// # Errors
    /// Fails if a symbol has no (non-zero) entry in the cumulative table.
    pub fn encode(&self, symbols: &[u32]) -> TokenizerResult<Vec<u8>> {
        // Rescale cumulative ranges to a fixed 2^14 total.
        let scale = 1u64 << 14;
        let total = self.total_count;

        let mut scaled_cum: Vec<(u32, u64, u64)> = Vec::new();
        for (sym, cum_low, cum_high) in &self.cumulative {
            let scaled_low = ((*cum_low as u128 * scale as u128) / total as u128) as u64;
            let scaled_high = ((*cum_high as u128 * scale as u128) / total as u128) as u64;
            // Guarantee every symbol a non-empty scaled range.
            let scaled_high = scaled_high.max(scaled_low + 1);
            scaled_cum.push((*sym, scaled_low, scaled_high));
        }

        let mut low: u64 = 0;
        let mut range: u64 = 1u64 << 32;
        let mut output = Vec::new();

        for &symbol in symbols {
            let (_, cum_low, cum_high) = scaled_cum
                .iter()
                .find(|(s, _, _)| *s == symbol)
                .ok_or_else(|| {
                    TokenizerError::encoding("serialization", format!("Unknown symbol: {}", symbol))
                })?;

            // Narrow the interval to this symbol's slice.
            let step = range / scale;
            low += step * cum_low;
            range = step * (cum_high - cum_low);

            // Renormalize: emit the settled top byte while the range is small.
            while range < (1u64 << 24) {
                output.push((low >> 24) as u8);
                low <<= 8;
                low &= 0xFFFFFFFF;
                range <<= 8;
            }
        }

        // Flush the remaining 32 bits of `low`.
        for _ in 0..4 {
            output.push((low >> 24) as u8);
            low <<= 8;
        }

        let mut result = Vec::new();
        result.extend_from_slice(&(symbols.len() as u32).to_le_bytes());
        result.extend_from_slice(&output);

        Ok(result)
    }
774
    /// Returns the raw frequency table (useful for building a matching decoder).
    pub fn frequencies(&self) -> &HashMap<u32, u64> {
        &self.frequencies
    }
779}
780
/// Decoder counterpart of `RangeEncoder`; must be built from the same table.
pub struct RangeDecoder {
    // `(symbol, cum_low, cum_high)` over raw counts, ascending by symbol.
    cumulative: Vec<(u32, u64, u64)>,
    // Sum of all counts.
    total_count: u64,
}
788
789impl RangeDecoder {
790 pub fn from_frequencies(frequencies: HashMap<u32, u64>) -> TokenizerResult<Self> {
792 if frequencies.is_empty() {
793 return Err(TokenizerError::decoding(
794 "decoding",
795 "Cannot create range decoder from empty frequencies",
796 ));
797 }
798
799 let total_count: u64 = frequencies.values().sum();
800
801 let mut symbols: Vec<u32> = frequencies.keys().copied().collect();
803 symbols.sort_unstable();
804
805 let mut cumulative = Vec::new();
806 let mut cum_freq = 0u64;
807
808 for symbol in symbols {
809 let freq = frequencies.get(&symbol).unwrap_or(&0);
810 if *freq > 0 {
811 cumulative.push((symbol, cum_freq, cum_freq + freq));
812 cum_freq += freq;
813 }
814 }
815
816 Ok(Self {
817 cumulative,
818 total_count,
819 })
820 }
821
    /// Decodes a stream produced by `RangeEncoder::encode` built from the same
    /// frequency table.
    ///
    /// Mirrors the encoder: rescales cumulative ranges to a 2^14 total, primes
    /// a 32-bit code register from the first 4 payload bytes, then repeatedly
    /// locates the symbol slice containing the code value and renormalizes in
    /// lock-step with the encoder. Missing payload bytes read as 0.
    ///
    /// # Errors
    /// Fails if the header is truncated or the code value falls outside every
    /// symbol's scaled range.
    pub fn decode(&self, encoded: &[u8]) -> TokenizerResult<Vec<u32>> {
        if encoded.len() < 4 {
            return Err(TokenizerError::decoding(
                "decoding",
                "Encoded data too short",
            ));
        }

        let num_symbols =
            u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;

        // Rescale cumulative ranges exactly as the encoder does.
        let scale = 1u64 << 14;
        let total = self.total_count;

        let mut scaled_cum: Vec<(u32, u64, u64)> = Vec::new();
        for (sym, cum_low, cum_high) in &self.cumulative {
            let scaled_low = ((*cum_low as u128 * scale as u128) / total as u128) as u64;
            let scaled_high = ((*cum_high as u128 * scale as u128) / total as u128) as u64;
            let scaled_high = scaled_high.max(scaled_low + 1);
            scaled_cum.push((*sym, scaled_low, scaled_high));
        }

        let data = &encoded[4..];
        let mut data_idx = 0;

        // Prime the 32-bit code register from the first 4 payload bytes.
        let mut code: u64 = 0;
        for _ in 0..4 {
            code = (code << 8) | (data.get(data_idx).copied().unwrap_or(0) as u64);
            data_idx += 1;
        }

        let mut low: u64 = 0;
        let mut range: u64 = 1u64 << 32;
        let mut symbols = Vec::with_capacity(num_symbols);

        for _ in 0..num_symbols {
            let step = range / scale;
            // Position of the code inside the current interval, in scaled units.
            let value = code.wrapping_sub(low) / step;

            let (symbol, cum_low, cum_high) = scaled_cum
                .iter()
                .find(|(_, cl, ch)| value >= *cl && value < *ch)
                .ok_or_else(|| {
                    TokenizerError::decoding(
                        "decoding",
                        format!("Invalid encoded data at symbol {}", symbols.len()),
                    )
                })?;

            symbols.push(*symbol);

            // Narrow the interval, then renormalize in lock-step with encoder.
            low += step * cum_low;
            range = step * (cum_high - cum_low);

            while range < (1u64 << 24) {
                code <<= 8;
                code &= 0xFFFFFFFF;
                code |= data.get(data_idx).copied().unwrap_or(0) as u64;
                data_idx += 1;
                low <<= 8;
                low &= 0xFFFFFFFF;
                range <<= 8;
            }
        }

        Ok(symbols)
    }
897}
898
/// PI (proportional-integral) controller that adapts a quantization step
/// toward a target bitrate.
pub struct BitrateController {
    // Desired bits per symbol.
    target_bits_per_symbol: f64,
    // Most recently observed bits per symbol.
    current_bits_per_symbol: f64,
    // Proportional gain.
    kp: f64,
    // Integral gain.
    ki: f64,
    // Accumulated error driving the integral term.
    integral_error: f64,
    // Current quantization step — the controller's output.
    quantization_step: f64,
    // Lower clamp for the step (0.1 × initial step).
    min_step: f64,
    // Upper clamp for the step (10 × initial step).
    max_step: f64,
}
920
921impl BitrateController {
922 pub fn new(
931 target_bits_per_symbol: f64,
932 initial_step: f64,
933 kp: f64,
934 ki: f64,
935 ) -> TokenizerResult<Self> {
936 if target_bits_per_symbol <= 0.0 {
937 return Err(TokenizerError::InvalidConfig(
938 "Target bits per symbol must be positive".into(),
939 ));
940 }
941
942 if initial_step <= 0.0 {
943 return Err(TokenizerError::InvalidConfig(
944 "Initial step must be positive".into(),
945 ));
946 }
947
948 Ok(Self {
949 target_bits_per_symbol,
950 current_bits_per_symbol: target_bits_per_symbol,
951 kp,
952 ki,
953 integral_error: 0.0,
954 quantization_step: initial_step,
955 min_step: initial_step * 0.1,
956 max_step: initial_step * 10.0,
957 })
958 }
959
960 pub fn update(&mut self, actual_bits_per_symbol: f64) -> f64 {
970 let error = actual_bits_per_symbol - self.target_bits_per_symbol;
972
973 self.integral_error += error;
975
976 let adjustment = self.kp * error + self.ki * self.integral_error;
978
979 self.quantization_step *= (1.0 + adjustment).clamp(0.5, 2.0);
981
982 self.quantization_step = self.quantization_step.max(self.min_step).min(self.max_step);
984
985 self.current_bits_per_symbol = actual_bits_per_symbol;
987
988 self.quantization_step
989 }
990
    /// Current quantization step (the controller's output).
    pub fn current_step(&self) -> f64 {
        self.quantization_step
    }
995
    /// Target bitrate in bits per symbol.
    pub fn target_bitrate(&self) -> f64 {
        self.target_bits_per_symbol
    }
1000
    /// Most recently observed bitrate (bits per symbol) passed to `update`.
    pub fn current_bitrate(&self) -> f64 {
        self.current_bits_per_symbol
    }
1005
    /// Clears the integral error and resets the observed bitrate to the
    /// target. The quantization step itself is left unchanged.
    pub fn reset(&mut self) {
        self.integral_error = 0.0;
        self.current_bits_per_symbol = self.target_bits_per_symbol;
    }
1011
    /// Changes the target bitrate.
    ///
    /// # Errors
    /// Returns `InvalidConfig` for a non-positive target.
    pub fn set_target(&mut self, target_bits_per_symbol: f64) -> TokenizerResult<()> {
        if target_bits_per_symbol <= 0.0 {
            return Err(TokenizerError::InvalidConfig(
                "Target bits per symbol must be positive".into(),
            ));
        }
        self.target_bits_per_symbol = target_bits_per_symbol;
        Ok(())
    }
1022}
1023
/// Ratio of original size (in bits) to compressed size (in bytes, converted
/// to bits). Returns `f64::INFINITY` when the compressed size is zero.
pub fn compression_ratio(original_bits: usize, compressed_bytes: usize) -> f64 {
    match compressed_bytes {
        0 => f64::INFINITY,
        bytes => original_bits as f64 / (bytes * 8) as f64,
    }
}
1040
#[cfg(test)]
mod tests {
    use super::*;

    // Round trip through the degenerate single-symbol Huffman tree.
    #[test]
    fn test_huffman_single_symbol() {
        let mut freqs = HashMap::new();
        freqs.insert(42, 100);

        let encoder = HuffmanEncoder::from_frequencies(&freqs).unwrap();
        let symbols = vec![42, 42, 42];
        let encoded = encoder.encode(&symbols).unwrap();

        let decoder = HuffmanDecoder::new(encoder.tree());
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_huffman_basic() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 10);
        freqs.insert(1, 5);
        freqs.insert(2, 2);
        freqs.insert(3, 1);

        let encoder = HuffmanEncoder::from_frequencies(&freqs).unwrap();

        // More frequent symbols must not get longer codes than rarer ones.
        let code_0 = encoder.codebook().get(&0).unwrap();
        let code_3 = encoder.codebook().get(&3).unwrap();
        assert!(code_0.len() <= code_3.len());

        let symbols = vec![0, 1, 2, 3, 0, 0, 1];
        let encoded = encoder.encode(&symbols).unwrap();

        let decoder = HuffmanDecoder::new(encoder.tree());
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_huffman_compression() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 50);
        freqs.insert(1, 25);
        freqs.insert(2, 15);
        freqs.insert(3, 10);

        let encoder = HuffmanEncoder::from_frequencies(&freqs).unwrap();

        // 100 symbols distributed to match the frequency table above.
        let symbols: Vec<u32> = (0..100)
            .map(|i| {
                if i < 50 {
                    0
                } else if i < 75 {
                    1
                } else if i < 90 {
                    2
                } else {
                    3
                }
            })
            .collect();

        let encoded = encoder.encode(&symbols).unwrap();

        // Baseline: 2 bits/symbol for a 4-symbol alphabet; subtract the
        // 8-byte header from the compressed size.
        let original_bits = symbols.len() * 2;
        let compressed_bits = (encoded.len() - 8) * 8;
        assert!(compressed_bits < original_bits);

        let decoder = HuffmanDecoder::new(encoder.tree());
        let decoded = decoder.decode(&encoded).unwrap();
        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_huffman_average_code_length() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 8);
        freqs.insert(1, 4);
        freqs.insert(2, 2);
        freqs.insert(3, 1);

        let encoder = HuffmanEncoder::from_frequencies(&freqs).unwrap();
        let avg_len = encoder.average_code_length(&freqs);

        // Huffman coding should land close to the source entropy.
        let entropy = HuffmanEncoder::entropy(&freqs);
        assert!((avg_len - entropy).abs() < 0.5);
    }

    #[test]
    fn test_arithmetic_basic() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 10);
        freqs.insert(1, 5);
        freqs.insert(2, 2);

        let mut encoder = ArithmeticEncoder::from_frequencies(freqs.clone());
        let symbols = vec![0, 1, 2, 0, 0];

        let encoded = encoder.encode(&symbols, false).unwrap();

        let decoder = ArithmeticDecoder::new(freqs);
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_arithmetic_adaptive() {
        let mut encoder = ArithmeticEncoder::new(4);
        let symbols = vec![0, 0, 0, 1, 1, 2, 3];

        let encoded = encoder.encode(&symbols, true).unwrap();

        // Non-adaptive encoding of the same input, for comparison.
        let mut encoder2 = ArithmeticEncoder::new(4);
        let encoded2 = encoder2.encode(&symbols, false).unwrap();

        // Both outputs carry the 12-byte header (4-byte count + 8-byte value).
        assert!(encoded.len() >= 12);
        assert!(encoded2.len() >= 12);
    }

    #[test]
    fn test_compute_frequencies() {
        let symbols = vec![0, 0, 1, 2, 0, 1];
        let freqs = compute_frequencies(&symbols);

        assert_eq!(*freqs.get(&0).unwrap(), 3);
        assert_eq!(*freqs.get(&1).unwrap(), 2);
        assert_eq!(*freqs.get(&2).unwrap(), 1);
    }

    #[test]
    fn test_compression_ratio() {
        // 800 bits vs 50 bytes (= 400 bits) => ratio 2.0.
        let ratio = compression_ratio(800, 50);
        assert!((ratio - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_entropy() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 2);
        freqs.insert(1, 2);

        // Two equiprobable symbols carry exactly 1 bit each.
        let entropy = HuffmanEncoder::entropy(&freqs);
        assert!((entropy - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_range_coding_basic() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 10);
        freqs.insert(1, 5);
        freqs.insert(2, 2);

        let encoder = RangeEncoder::from_frequencies(freqs.clone()).unwrap();
        let symbols = vec![0, 1, 2, 0, 0, 1];

        let encoded = encoder.encode(&symbols).unwrap();

        let decoder = RangeDecoder::from_frequencies(freqs).unwrap();
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_range_coding_single_symbol() {
        let mut freqs = HashMap::new();
        freqs.insert(42, 100);

        let encoder = RangeEncoder::from_frequencies(freqs.clone()).unwrap();
        let symbols = vec![42, 42, 42, 42];

        let encoded = encoder.encode(&symbols).unwrap();

        let decoder = RangeDecoder::from_frequencies(freqs).unwrap();
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    // NOTE(review): ignored — presumably the range coder's round trip fails
    // on this input (carry/precision limits); confirm before re-enabling.
    #[ignore]
    fn test_range_coding_compression() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 50);
        freqs.insert(1, 30);
        freqs.insert(2, 15);
        freqs.insert(3, 5);

        let encoder = RangeEncoder::from_frequencies(freqs.clone()).unwrap();

        // 100 symbols distributed to match the frequency table above.
        let symbols: Vec<u32> = (0..100)
            .map(|i| {
                if i < 50 {
                    0
                } else if i < 80 {
                    1
                } else if i < 95 {
                    2
                } else {
                    3
                }
            })
            .collect();

        let encoded = encoder.encode(&symbols).unwrap();

        // Baseline 2 bits/symbol; subtract the 4-byte count header.
        let original_bits = symbols.len() * 2;
        let compressed_bytes = encoded.len() - 4;
        assert!(compressed_bytes * 8 < original_bits);

        let decoder = RangeDecoder::from_frequencies(freqs).unwrap();
        let decoded = decoder.decode(&encoded).unwrap();
        assert_eq!(decoded, symbols);
    }

    #[test]
    // NOTE(review): ignored — see test_range_coding_compression.
    #[ignore]
    fn test_range_coding_long_sequence() {
        let mut freqs = HashMap::new();
        freqs.insert(0, 40);
        freqs.insert(1, 30);
        freqs.insert(2, 20);
        freqs.insert(3, 10);

        let encoder = RangeEncoder::from_frequencies(freqs.clone()).unwrap();

        // 1000 symbols cycling over the 4-symbol alphabet.
        let symbols: Vec<u32> = (0..1000).map(|i| (i % 4) as u32).collect();

        let encoded = encoder.encode(&symbols).unwrap();

        let decoder = RangeDecoder::from_frequencies(freqs).unwrap();
        let decoded = decoder.decode(&encoded).unwrap();

        assert_eq!(decoded, symbols);
    }

    #[test]
    fn test_bitrate_controller_basic() {
        let controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        assert_eq!(controller.target_bitrate(), 4.0);
        assert_eq!(controller.current_step(), 1.0);
    }

    #[test]
    fn test_bitrate_controller_update_increase() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        // Overshooting the target bitrate must raise the quantization step.
        let initial_step = controller.current_step();
        let new_step = controller.update(5.0);
        assert!(new_step > initial_step);
    }

    #[test]
    fn test_bitrate_controller_update_decrease() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        // Undershooting the target bitrate must lower the quantization step.
        let initial_step = controller.current_step();
        let new_step = controller.update(3.0);
        assert!(new_step < initial_step);
    }

    #[test]
    fn test_bitrate_controller_convergence() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        // Repeated overshoot keeps pushing the step upward.
        for _ in 0..10 {
            controller.update(4.5);
        }

        assert!(controller.current_step() > 1.0);
    }

    #[test]
    fn test_bitrate_controller_reset() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        controller.update(5.0);
        controller.update(6.0);

        controller.reset();

        assert_eq!(controller.current_bitrate(), 4.0);
    }

    #[test]
    fn test_bitrate_controller_set_target() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.1, 0.01).unwrap();

        controller.set_target(8.0).unwrap();
        assert_eq!(controller.target_bitrate(), 8.0);
    }

    #[test]
    fn test_bitrate_controller_invalid_target() {
        assert!(BitrateController::new(0.0, 1.0, 0.1, 0.01).is_err());
        assert!(BitrateController::new(-1.0, 1.0, 0.1, 0.01).is_err());
    }

    #[test]
    fn test_bitrate_controller_invalid_step() {
        assert!(BitrateController::new(4.0, 0.0, 0.1, 0.01).is_err());
        assert!(BitrateController::new(4.0, -1.0, 0.1, 0.01).is_err());
    }

    #[test]
    fn test_bitrate_controller_step_clamping() {
        let mut controller = BitrateController::new(4.0, 1.0, 0.5, 0.1).unwrap();

        // Sustained large overshoot: step must saturate at max (10 × initial).
        for _ in 0..100 {
            controller.update(20.0);
        }

        assert!(controller.current_step() <= 10.0);

        controller.reset();

        // Sustained undershoot: step must saturate at min (0.1 × initial).
        for _ in 0..100 {
            controller.update(0.5);
        }

        assert!(controller.current_step() >= 0.1);
    }
}