1use haagenti_core::{Error, Result};
14
15fn read_bits_from_slice(data: &[u8], bit_pos: &mut usize, n: usize) -> Result<u32> {
18 if n == 0 {
19 return Ok(0);
20 }
21 if n > 32 {
22 return Err(Error::corrupted("Cannot read more than 32 bits at once"));
23 }
24
25 let mut result = 0u32;
26 let mut bits_read = 0;
27
28 while bits_read < n {
29 let byte_idx = *bit_pos / 8;
30 let bit_offset = *bit_pos % 8;
31
32 if byte_idx >= data.len() {
33 return Err(Error::unexpected_eof(byte_idx));
34 }
35
36 let byte = data[byte_idx];
37 let available = 8 - bit_offset;
38 let to_read = (n - bits_read).min(available);
39
40 let mask = ((1u32 << to_read) - 1) as u8;
42 let bits = (byte >> bit_offset) & mask;
43
44 result |= (bits as u32) << bits_read;
45 bits_read += to_read;
46 *bit_pos += to_read;
47 }
48
49 Ok(result)
50}
51
/// A single decoder state of an [`FseTable`].
///
/// `symbol` is the symbol emitted in this state. `num_bits` and `baseline` are the
/// state-transition parameters computed by `FseTable::build` (zstd's `nbBits` /
/// `newState`). `seq_base` / `seq_extra_bits` carry the sequence baseline and its
/// extra-bit count; they are only filled in by the predefined match-length table and
/// are zero otherwise.
///
/// The layout is `repr(C)` with explicit padding, so field order must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct FseTableEntry {
    pub baseline: u16,
    pub num_bits: u8,
    pub symbol: u8,
    pub seq_base: u32,
    pub seq_extra_bits: u8,
    _pad: [u8; 3],
}

impl FseTableEntry {
    /// Creates an entry with no sequence-baseline information attached.
    #[inline]
    pub const fn new(symbol: u8, num_bits: u8, baseline: u16) -> Self {
        Self::new_seq(symbol, num_bits, baseline, 0, 0)
    }

    /// Creates an entry that also carries a sequence baseline and its extra-bit count.
    #[inline]
    pub const fn new_seq(
        symbol: u8,
        num_bits: u8,
        baseline: u16,
        seq_base: u32,
        seq_extra_bits: u8,
    ) -> Self {
        Self {
            baseline,
            num_bits,
            symbol,
            seq_base,
            seq_extra_bits,
            _pad: [0u8; 3],
        }
    }
}

impl Default for FseTableEntry {
    /// An all-zero entry (symbol 0, no bits, zero baselines).
    fn default() -> Self {
        Self::new_seq(0, 0, 0, 0, 0)
    }
}
115
116#[derive(Debug, Clone)]
121pub struct FseTable {
122 entries: Vec<FseTableEntry>,
124 accuracy_log: u8,
126 max_symbol: u8,
128}
129
130impl FseTable {
131 pub fn build(normalized_freqs: &[i16], accuracy_log: u8, max_symbol: u8) -> Result<Self> {
141 if accuracy_log > 15 {
142 return Err(Error::corrupted("FSE accuracy log exceeds maximum of 15"));
143 }
144
145 let table_size = 1usize << accuracy_log;
146
147 let mut freq_sum: i32 = 0;
150 for &f in normalized_freqs.iter() {
151 if f == -1 {
152 freq_sum += 1; } else {
154 freq_sum += f as i32;
155 }
156 }
157 if freq_sum != table_size as i32 {
158 return Err(Error::corrupted(format!(
159 "FSE frequencies sum to {} but expected {}",
160 freq_sum, table_size
161 )));
162 }
163
164 let mut entries = vec![FseTableEntry::new(0, 0, 0); table_size];
165
166 let mut high_threshold = table_size;
169 for (symbol, &freq) in normalized_freqs.iter().enumerate() {
170 if freq == -1 {
171 high_threshold -= 1;
172 entries[high_threshold] = FseTableEntry::new(symbol as u8, accuracy_log, 0);
173 }
174 }
175
176 let mut position = 0;
178 let step = (table_size >> 1) + (table_size >> 3) + 3;
179 let mask = table_size - 1;
180
181 for (symbol, &freq) in normalized_freqs.iter().enumerate() {
182 if freq <= 0 {
183 continue; }
185
186 for _ in 0..freq {
187 entries[position].symbol = symbol as u8;
188 loop {
190 position = (position + step) & mask;
191 if position < high_threshold {
192 break;
193 }
194 }
195 }
196 }
197
198 let mut symbol_next: Vec<u32> = normalized_freqs
217 .iter()
218 .map(|&f| if f == -1 { 1 } else { f.max(0) as u32 })
219 .collect();
220
221 for entry in entries.iter_mut() {
223 let symbol = entry.symbol as usize;
224 let freq = normalized_freqs.get(symbol).copied().unwrap_or(0);
225
226 if freq == -1 {
227 entry.num_bits = accuracy_log;
230 entry.baseline = 0;
231 } else if freq > 0 && symbol < symbol_next.len() {
232 let next_state = symbol_next[symbol];
234 symbol_next[symbol] += 1;
235
236 let high_bit = 31 - next_state.leading_zeros();
240 let nb_bits = (accuracy_log as u32).saturating_sub(high_bit) as u8;
241
242 let baseline = ((next_state << nb_bits) as i32 - table_size as i32).max(0) as u16;
244
245 entry.num_bits = nb_bits;
246 entry.baseline = baseline;
247 }
248 }
249
250 Ok(Self {
251 entries,
252 accuracy_log,
253 max_symbol,
254 })
255 }
256
257 pub fn from_predefined(distribution: &[i16], accuracy_log: u8) -> Result<Self> {
263 if accuracy_log == 5 && distribution.len() == 29 {
265 return Self::from_hardcoded_of();
267 }
268 if accuracy_log == 6 && distribution.len() == 36 {
269 return Self::from_hardcoded_ll();
271 }
272 if accuracy_log == 6 && distribution.len() == 53 {
273 return Self::from_hardcoded_ml();
275 }
276
277 let max_symbol = distribution.len().saturating_sub(1) as u8;
279 Self::build(distribution, accuracy_log, max_symbol)
280 }
281
282 pub fn from_hardcoded_of() -> Result<Self> {
284 let entries: Vec<FseTableEntry> = OF_PREDEFINED_TABLE
285 .iter()
286 .map(|&(symbol, num_bits, baseline)| FseTableEntry::new(symbol, num_bits, baseline))
287 .collect();
288 Ok(Self {
289 entries,
290 accuracy_log: 5,
291 max_symbol: 31,
292 })
293 }
294
295 pub fn from_hardcoded_ll() -> Result<Self> {
297 let entries: Vec<FseTableEntry> = LL_PREDEFINED_TABLE
298 .iter()
299 .map(|&(symbol, num_bits, baseline)| FseTableEntry::new(symbol, num_bits, baseline))
300 .collect();
301 Ok(Self {
302 entries,
303 accuracy_log: 6,
304 max_symbol: 35,
305 })
306 }
307
308 pub fn from_hardcoded_ml() -> Result<Self> {
316 let entries: Vec<FseTableEntry> = ML_PREDEFINED_TABLE
317 .iter()
318 .map(|&(symbol, num_bits, baseline)| {
319 let (seq_extra_bits, seq_base) = if (symbol as usize) < ML_BASELINE_TABLE.len() {
321 ML_BASELINE_TABLE[symbol as usize]
322 } else {
323 (0, 3) };
325 FseTableEntry::new_seq(symbol, num_bits, baseline, seq_base, seq_extra_bits)
326 })
327 .collect();
328 Ok(Self {
329 entries,
330 accuracy_log: 6,
331 max_symbol: 52,
332 })
333 }
334
335 pub fn parse(data: &[u8], max_symbol: u8) -> Result<(Self, usize)> {
346 if data.is_empty() {
347 return Err(Error::corrupted("Empty FSE table header"));
348 }
349
350 let mut bit_pos: usize = 0;
351
352 let accuracy_log_raw = read_bits_from_slice(data, &mut bit_pos, 4)? as u8;
354 let accuracy_log = accuracy_log_raw + 5;
355
356 if accuracy_log > 15 {
357 return Err(Error::corrupted(format!(
358 "FSE accuracy log {} exceeds maximum 15",
359 accuracy_log
360 )));
361 }
362
363 let table_size = 1i32 << accuracy_log;
364 let mut remaining = table_size;
365 let mut probabilities = Vec::with_capacity((max_symbol + 1) as usize);
366 let mut symbol = 0u8;
367
368 while remaining > 0 && symbol <= max_symbol {
370 let max_bits = 32 - (remaining + 1).leading_zeros();
372 let threshold = (1i32 << max_bits) - 1 - remaining;
373
374 let small = read_bits_from_slice(data, &mut bit_pos, (max_bits - 1) as usize)? as i32;
376
377 let prob = if small < threshold {
378 small
379 } else {
380 let extra = read_bits_from_slice(data, &mut bit_pos, 1)? as i32;
381 let large = (small << 1) + extra - threshold;
382 if large < (1 << (max_bits - 1)) {
383 large
384 } else {
385 large - (1 << max_bits)
386 }
387 };
388
389 let normalized_prob = if prob == 0 {
391 remaining -= 1;
393 -1i16
394 } else {
395 remaining -= prob;
396 prob as i16
397 };
398
399 probabilities.push(normalized_prob);
400 symbol += 1;
401
402 if prob == 0 {
404 loop {
406 let repeat = read_bits_from_slice(data, &mut bit_pos, 2)? as usize;
407 for _ in 0..repeat {
408 if symbol <= max_symbol {
409 probabilities.push(0);
410 symbol += 1;
411 }
412 }
413 if repeat < 3 {
414 break;
415 }
416 }
417 }
418 }
419
420 while probabilities.len() <= max_symbol as usize {
422 probabilities.push(0);
423 }
424
425 if remaining != 0 {
427 return Err(Error::corrupted(format!(
428 "FSE table probabilities don't sum correctly: remaining={}",
429 remaining
430 )));
431 }
432
433 let bytes_consumed = bit_pos.div_ceil(8);
435
436 let table = Self::build(&probabilities, accuracy_log, max_symbol)?;
437 Ok((table, bytes_consumed))
438 }
439
440 #[inline]
442 pub fn size(&self) -> usize {
443 self.entries.len()
444 }
445
446 #[inline]
448 pub fn accuracy_log(&self) -> u8 {
449 self.accuracy_log
450 }
451
452 #[inline]
454 pub fn decode(&self, state: usize) -> &FseTableEntry {
455 &self.entries[state]
456 }
457
458 #[inline]
460 pub fn state_mask(&self) -> usize {
461 (1 << self.accuracy_log) - 1
462 }
463
464 #[inline]
471 pub fn is_valid(&self) -> bool {
472 if self.entries.is_empty() {
473 return false;
474 }
475 if self.accuracy_log == 0 || self.accuracy_log > 15 {
476 return false;
477 }
478 self.entries.iter().all(|e| e.symbol <= self.max_symbol)
480 }
481
482 #[inline]
484 pub fn max_symbol(&self) -> u8 {
485 self.max_symbol
486 }
487
488 pub fn is_rle_mode(&self) -> bool {
493 if self.entries.is_empty() {
494 return false;
495 }
496 let first_symbol = self.entries[0].symbol;
497 self.entries.iter().all(|e| e.symbol == first_symbol)
498 }
499
500 pub fn from_frequencies(frequencies: &[u32], min_accuracy_log: u8) -> Result<(Self, Vec<i16>)> {
504 let max_symbol = frequencies
505 .iter()
506 .enumerate()
507 .rev()
508 .find(|&(_, f)| *f > 0)
509 .map(|(i, _)| i)
510 .unwrap_or(0);
511
512 let total: u32 = frequencies.iter().sum();
513 if total == 0 {
514 return Err(Error::corrupted("No symbols to encode"));
515 }
516
517 let accuracy_log = min_accuracy_log.clamp(5, FSE_MAX_ACCURACY_LOG);
520 let table_size = 1u32 << accuracy_log;
521
522 let mut normalized = vec![0i16; max_symbol + 1];
524 let mut distributed = 0u32;
525
526 for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
528 if freq > 0 {
529 let share = ((freq as u64 * table_size as u64) / total as u64) as u32;
531 if share == 0 {
532 normalized[i] = -1;
534 distributed += 1;
535 } else {
536 normalized[i] = share as i16;
537 distributed += share;
538 }
539 }
540 }
541
542 while distributed < table_size {
544 let mut best_idx = 0;
546 let mut best_freq = 0;
547 for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
548 if freq > best_freq && normalized[i] > 0 {
549 best_freq = freq;
550 best_idx = i;
551 }
552 }
553 if best_freq == 0 {
554 break;
555 }
556 normalized[best_idx] += 1;
557 distributed += 1;
558 }
559
560 while distributed > table_size {
561 let mut best_idx = 0;
563 let mut best_assigned = 0i16;
564 for (i, &n) in normalized.iter().enumerate() {
565 if n > best_assigned {
566 best_assigned = n;
567 best_idx = i;
568 }
569 }
570 if best_assigned <= 1 {
571 break;
572 }
573 normalized[best_idx] -= 1;
574 distributed -= 1;
575 }
576
577 let table = Self::build(&normalized, accuracy_log, max_symbol as u8)?;
578 Ok((table, normalized))
579 }
580
581 pub fn from_frequencies_serializable(
593 frequencies: &[u32],
594 min_accuracy_log: u8,
595 ) -> Result<(Self, Vec<i16>)> {
596 let max_symbol = frequencies
597 .iter()
598 .enumerate()
599 .rev()
600 .find(|&(_, f)| *f > 0)
601 .map(|(i, _)| i)
602 .unwrap_or(0);
603
604 let total: u32 = frequencies.iter().sum();
605 if total == 0 {
606 return Err(Error::corrupted("No symbols to encode"));
607 }
608
609 let accuracy_log = min_accuracy_log.clamp(5, FSE_MAX_ACCURACY_LOG);
610 let table_size = 1u32 << accuracy_log;
611
612 let mut normalized = vec![0i16; max_symbol + 1];
614 let mut distributed = 0u32;
615
616 for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
617 if freq > 0 {
618 let share = ((freq as u64 * table_size as u64) / total as u64) as u32;
619 if share == 0 {
620 normalized[i] = -1;
621 distributed += 1;
622 } else {
623 normalized[i] = share as i16;
624 distributed += share;
625 }
626 }
627 }
628
629 while distributed < table_size {
631 let mut best_idx = 0;
632 let mut best_freq = 0;
633 for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
634 if freq > best_freq && normalized[i] > 0 {
635 best_freq = freq;
636 best_idx = i;
637 }
638 }
639 if best_freq == 0 {
640 break;
641 }
642 normalized[best_idx] += 1;
643 distributed += 1;
644 }
645
646 while distributed > table_size {
647 let mut best_idx = 0;
648 let mut best_assigned = 0i16;
649 for (i, &n) in normalized.iter().enumerate() {
650 if n > best_assigned {
651 best_assigned = n;
652 best_idx = i;
653 }
654 }
655 if best_assigned <= 1 {
656 break;
657 }
658 normalized[best_idx] -= 1;
659 distributed -= 1;
660 }
661
662 let mut gaps_to_fill = Vec::new();
666 let mut in_gap = false;
667 for (i, &norm_val) in normalized.iter().enumerate() {
668 if norm_val == 0 {
669 if !in_gap {
670 gaps_to_fill.push(i);
671 in_gap = true;
672 }
673 } else {
674 in_gap = false;
675 }
676 }
677
678 for gap_start in gaps_to_fill {
679 let mut donor_idx = None;
681 for (i, &p) in normalized.iter().enumerate() {
682 if p > 1 {
683 donor_idx = Some(i);
684 break;
685 }
686 }
687 if let Some(donor) = donor_idx {
688 normalized[donor] -= 1;
689 normalized[gap_start] = -1;
690 }
691 }
692
693 let last_positive_idx = normalized
696 .iter()
697 .enumerate()
698 .rev()
699 .find(|&(_, &p)| p > 0)
700 .map(|(i, _)| i);
701
702 if let Some(last_idx) = last_positive_idx {
703 let last_prob = normalized[last_idx] as i32;
704
705 let needs_padding = {
707 let mut remaining = table_size as i32;
708 let mut need_fix = false;
709 for &prob in &normalized {
710 if prob == 0 {
711 continue;
712 }
713 let prob_val = if prob == -1 { 1 } else { prob as i32 };
714 let max_bits = 32 - (remaining + 1).leading_zeros();
715 let max_positive = (1i32 << (max_bits - 1)) - 1;
716 if prob > 0 && prob as i32 > max_positive {
717 need_fix = true;
718 break;
719 }
720 remaining -= prob_val;
721 }
722 need_fix
723 };
724
725 if needs_padding && last_prob > 0 {
726 let trailing_count = last_prob as usize;
727
728 let mut donor_idx = None;
730 for (i, &p) in normalized.iter().enumerate() {
731 if p > trailing_count as i16 {
732 donor_idx = Some(i);
733 break;
734 }
735 }
736
737 if let Some(donor) = donor_idx {
738 normalized[donor] -= trailing_count as i16;
739 for _ in 0..trailing_count {
740 normalized.push(-1);
741 }
742 let new_max_symbol = normalized.len() - 1;
743 let table = Self::build(&normalized, accuracy_log, new_max_symbol as u8)?;
744 return Ok((table, normalized));
745 }
746 }
747 }
748
749 let table = Self::build(&normalized, accuracy_log, max_symbol as u8)?;
750 Ok((table, normalized))
751 }
752
753 pub fn serialize(&self, normalized: &[i16]) -> Vec<u8> {
759 let mut bits = FseTableSerializer::new();
760
761 bits.write_bits((self.accuracy_log - 5) as u32, 4);
763
764 let table_size = 1i32 << self.accuracy_log;
765 let mut remaining = table_size;
766 let mut symbol = 0usize;
767
768 while symbol < normalized.len() && remaining > 0 {
774 let prob = normalized[symbol];
775
776 let max_bits = 32 - (remaining + 1).leading_zeros();
778 let threshold = (1i32 << max_bits) - 1 - remaining;
779
780 let encoded_prob = if prob == -1 { 0 } else { prob as i32 };
782
783 if encoded_prob < threshold {
788 bits.write_bits(encoded_prob as u32, (max_bits - 1) as u8);
789 } else {
790 let combined = encoded_prob + threshold;
794 let small = combined >> 1;
795 let extra = combined & 1;
796 bits.write_bits(small as u32, (max_bits - 1) as u8);
797 bits.write_bits(extra as u32, 1);
798 }
799
800 if prob == -1 {
802 remaining -= 1;
803 } else if prob > 0 {
804 remaining -= prob as i32;
805 }
806
807 symbol += 1;
808
809 if prob == -1 || prob == 0 {
813 let mut zeros = 0usize;
815 while symbol + zeros < normalized.len() && normalized[symbol + zeros] == 0 {
816 zeros += 1;
817 }
818
819 let mut zeros_left = zeros;
823 loop {
824 if zeros_left >= 3 {
825 bits.write_bits(3, 2);
826 zeros_left -= 3;
827 } else {
828 bits.write_bits(zeros_left as u32, 2);
829 break;
830 }
831 }
832
833 symbol += zeros;
835 }
836 }
837
838 bits.finish()
839 }
840}
841
/// Largest FSE accuracy log (table of `1 << 15` states) this module works with;
/// `FseTable::from_frequencies` clamps its requested log to at most this value.
pub const FSE_MAX_ACCURACY_LOG: u8 = 15;
844
/// Little-endian bit writer used to emit FSE table descriptions.
///
/// Bits are packed from the least-significant end of each byte; a partially filled
/// final byte is zero-padded by `finish`.
struct FseTableSerializer {
    /// Completed bytes.
    buffer: Vec<u8>,
    /// Byte currently being filled.
    current_byte: u8,
    /// How many low bits of `current_byte` are already used (0..=7 between calls).
    bits_in_byte: u8,
}

impl FseTableSerializer {
    /// Creates an empty writer.
    fn new() -> Self {
        Self {
            buffer: Vec::new(),
            current_byte: 0,
            bits_in_byte: 0,
        }
    }

    /// Appends the low `num_bits` bits of `value`, least-significant bit first.
    /// Any bits of `value` above `num_bits` are ignored.
    fn write_bits(&mut self, value: u32, num_bits: u8) {
        let mut pending = value;
        let mut left = num_bits;
        while left != 0 {
            let room = 8 - self.bits_in_byte;
            let take = if left < room { left } else { room };
            let chunk = (pending & ((1u32 << take) - 1)) as u8;

            self.current_byte |= chunk << self.bits_in_byte;
            self.bits_in_byte += take;
            pending >>= take;
            left -= take;

            if self.bits_in_byte == 8 {
                self.flush_byte();
            }
        }
    }

    /// Moves the byte under construction into `buffer` and starts a fresh one.
    fn flush_byte(&mut self) {
        self.buffer.push(self.current_byte);
        self.current_byte = 0;
        self.bits_in_byte = 0;
    }

    /// Flushes any partial byte (zero-padded) and returns the written bytes.
    fn finish(mut self) -> Vec<u8> {
        if self.bits_in_byte != 0 {
            self.flush_byte();
        }
        self.buffer
    }
}
891
/// Default literal-length distribution (zstd predefined mode): 36 symbols at
/// accuracy log 6. `-1` marks a "less than one" probability that occupies one cell,
/// so the cells sum to 64.
pub const LITERAL_LENGTH_DEFAULT_DISTRIBUTION: [i16; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
    -1, -1, -1, -1,
];

/// Default match-length distribution: 53 symbols at accuracy log 6 (cells sum to 64).
pub const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i16; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];

/// Default offset-code distribution: 29 symbols at accuracy log 5 (cells sum to 32).
pub const OFFSET_DEFAULT_DISTRIBUTION: [i16; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
915
/// Match-length code -> `(extra bits, baseline)`, indexed by code (0..=52).
///
/// `FseTable::from_hardcoded_ml` copies these into each entry's
/// `seq_extra_bits` / `seq_base`. The match length is presumably `baseline` plus the
/// value of `extra bits` read from the stream (RFC 8878 match-length codes) — confirm
/// against the sequence decoder.
const ML_BASELINE_TABLE: [(u8, u32); 53] = [
    // Codes 0..=31: match lengths 3..=34, no extra bits.
    (0, 3),
    (0, 4),
    (0, 5),
    (0, 6),
    (0, 7),
    (0, 8),
    (0, 9),
    (0, 10),
    (0, 11),
    (0, 12),
    (0, 13),
    (0, 14),
    (0, 15),
    (0, 16),
    (0, 17),
    (0, 18),
    (0, 19),
    (0, 20),
    (0, 21),
    (0, 22),
    (0, 23),
    (0, 24),
    (0, 25),
    (0, 26),
    (0, 27),
    (0, 28),
    (0, 29),
    (0, 30),
    (0, 31),
    (0, 32),
    (0, 33),
    (0, 34),
    // Codes 32..=52: growing baselines with 1..=16 extra bits.
    (1, 35),
    (1, 37),
    (1, 39),
    (1, 41),
    (2, 43),
    (2, 47),
    (3, 51),
    (3, 59),
    (4, 67),
    (4, 83),
    (5, 99),
    (7, 131),
    (8, 259),
    (9, 515),
    (10, 1027),
    (11, 2051),
    (12, 4099),
    (13, 8195),
    (14, 16387),
    (15, 32771),
    (16, 65539),
];
1006
1007#[allow(dead_code)]
1012fn ml_code_from_direct(seq_base: u32, seq_extra_bits: u8) -> u8 {
1013 for (code, &(bits, baseline)) in ML_BASELINE_TABLE.iter().enumerate() {
1015 if bits == seq_extra_bits && baseline == seq_base {
1016 return code as u8;
1017 }
1018 }
1019
1020 if seq_extra_bits == 0 && (3..=34).contains(&seq_base) {
1022 return (seq_base - 3) as u8;
1023 }
1024
1025 for (code, &(bits, baseline)) in ML_BASELINE_TABLE.iter().enumerate() {
1028 if bits == seq_extra_bits {
1029 if baseline == seq_base {
1031 return code as u8;
1032 }
1033 }
1034 }
1035
1036 match seq_extra_bits {
1039 0 => ((seq_base.saturating_sub(3)).min(31)) as u8,
1040 1 => 32 + ((seq_base.saturating_sub(35)) / 2).min(3) as u8,
1041 2 => 36 + if seq_base >= 47 { 1 } else { 0 },
1042 3 => 38 + if seq_base >= 59 { 1 } else { 0 },
1043 4 => 40 + if seq_base >= 83 { 1 } else { 0 },
1044 5 => 42, 7 => 43, 8 => 44, 9 => 45, 10 => 46,
1049 11 => 47,
1050 12 => 48,
1051 13 => 49,
1052 14 => 50,
1053 15 => 51,
1054 16 => 52,
1055 _ => 52.min(42 + seq_extra_bits.saturating_sub(5)),
1056 }
1057}
1058
/// Decoding table for the default offset-code distribution (accuracy log 5),
/// indexed by state; each tuple is `(symbol, num_bits, baseline)`.
///
/// Matches what `FseTable::build` produces from `OFFSET_DEFAULT_DISTRIBUTION`; the
/// trailing five states hold the "less than one" symbols 28..=24.
const OF_PREDEFINED_TABLE: [(u8, u8, u16); 32] = [
    // states 0..=7
    (0, 5, 0), (6, 4, 0), (9, 5, 0), (15, 5, 0), (21, 5, 0), (3, 5, 0), (7, 4, 0), (12, 5, 0),
    // states 8..=15
    (18, 5, 0), (23, 5, 0), (5, 5, 0), (8, 4, 0), (14, 5, 0), (20, 5, 0), (2, 5, 0), (7, 4, 16),
    // states 16..=23
    (11, 5, 0), (17, 5, 0), (22, 5, 0), (4, 5, 0), (8, 4, 16), (13, 5, 0), (19, 5, 0), (1, 5, 0),
    // states 24..=31
    (6, 4, 16), (10, 5, 0), (16, 5, 0), (28, 5, 0), (27, 5, 0), (26, 5, 0), (25, 5, 0), (24, 5, 0),
];
1096
/// Decoding table for the default literal-length distribution (accuracy log 6),
/// indexed by state; each tuple is `(symbol, num_bits, baseline)`.
///
/// Every entry equals what the FSE construction produces from
/// `LITERAL_LENGTH_DEFAULT_DISTRIBUTION` (RFC 8878 Appendix A). State 17 is the first
/// of symbol 25's three states, so it takes next-state counter 3. That gives
/// 6 - 1 = 5 bits and a baseline of (3 << 5) - 64 = 32. The previous `(25, 6, 0)`
/// entry jumped to the wrong next state.
const LL_PREDEFINED_TABLE: [(u8, u8, u16); 64] = [
    // states 0..=7
    (0, 4, 0), (0, 4, 16), (1, 5, 32), (3, 5, 0), (4, 5, 0), (6, 5, 0), (7, 5, 0), (9, 5, 0),
    // states 8..=15
    (10, 5, 0), (12, 5, 0), (14, 6, 0), (16, 5, 0), (18, 5, 0), (19, 5, 0), (21, 5, 0), (22, 5, 0),
    // states 16..=23
    (24, 5, 0), (25, 5, 32), (26, 5, 0), (27, 6, 0), (29, 6, 0), (31, 6, 0), (0, 4, 32), (1, 4, 0),
    // states 24..=31
    (2, 5, 0), (4, 5, 32), (5, 5, 0), (7, 5, 32), (8, 5, 0), (10, 5, 32), (11, 5, 0), (13, 6, 0),
    // states 32..=39
    (16, 5, 32), (17, 5, 0), (19, 5, 32), (20, 5, 0), (22, 5, 32), (23, 5, 0), (25, 4, 0), (25, 4, 16),
    // states 40..=47
    (26, 5, 32), (28, 6, 0), (30, 6, 0), (0, 4, 48), (1, 4, 16), (2, 5, 32), (3, 5, 32), (5, 5, 32),
    // states 48..=55
    (6, 5, 32), (8, 5, 32), (9, 5, 32), (11, 5, 32), (12, 5, 32), (15, 6, 0), (17, 5, 32), (18, 5, 32),
    // states 56..=63 (the last four are the "less than one" symbols 35..=32)
    (20, 5, 32), (21, 5, 32), (23, 5, 32), (24, 5, 32), (35, 6, 0), (34, 6, 0), (33, 6, 0), (32, 6, 0),
];
1166
/// Decoding table for the default match-length distribution (accuracy log 6),
/// indexed by state; each tuple is `(symbol, num_bits, baseline)`.
///
/// Matches what `FseTable::build` produces from `MATCH_LENGTH_DEFAULT_DISTRIBUTION`;
/// the last seven states hold the "less than one" symbols 52..=46.
/// `FseTable::from_hardcoded_ml` additionally attaches each symbol's
/// `ML_BASELINE_TABLE` entry.
const ML_PREDEFINED_TABLE: [(u8, u8, u16); 64] = [
    // states 0..=7
    (0, 6, 0), (1, 4, 0), (2, 5, 32), (3, 5, 0), (5, 5, 0), (6, 5, 0), (8, 5, 0), (10, 6, 0),
    // states 8..=15
    (13, 6, 0), (16, 6, 0), (19, 6, 0), (22, 6, 0), (25, 6, 0), (28, 6, 0), (31, 6, 0), (33, 6, 0),
    // states 16..=23
    (35, 6, 0), (37, 6, 0), (39, 6, 0), (41, 6, 0), (43, 6, 0), (45, 6, 0), (1, 4, 16), (2, 4, 0),
    // states 24..=31
    (3, 5, 32), (4, 5, 0), (6, 5, 32), (7, 5, 0), (9, 6, 0), (12, 6, 0), (15, 6, 0), (18, 6, 0),
    // states 32..=39
    (21, 6, 0), (24, 6, 0), (27, 6, 0), (30, 6, 0), (32, 6, 0), (34, 6, 0), (36, 6, 0), (38, 6, 0),
    // states 40..=47
    (40, 6, 0), (42, 6, 0), (44, 6, 0), (1, 4, 32), (1, 4, 48), (2, 4, 16), (4, 5, 32), (5, 5, 32),
    // states 48..=55
    (7, 5, 32), (8, 5, 32), (11, 6, 0), (14, 6, 0), (17, 6, 0), (20, 6, 0), (23, 6, 0), (26, 6, 0),
    // states 56..=63
    (29, 6, 0), (52, 6, 0), (51, 6, 0), (50, 6, 0), (49, 6, 0), (48, 6, 0), (47, 6, 0), (46, 6, 0),
];
1238
1239#[cfg(test)]
1244mod tests {
1245 use super::*;
1246
1247 #[test]
1248 fn test_fse_table_entry_creation() {
1249 let entry = FseTableEntry::new(5, 3, 100);
1250 assert_eq!(entry.symbol, 5);
1251 assert_eq!(entry.num_bits, 3);
1252 assert_eq!(entry.baseline, 100);
1253 }
1254
1255 #[test]
1256 fn test_simple_distribution() {
1257 let distribution = [2i16, 2];
1261 let table = FseTable::build(&distribution, 2, 1).unwrap();
1262
1263 assert_eq!(table.size(), 4);
1264 assert_eq!(table.accuracy_log(), 2);
1265
1266 for i in 0..4 {
1268 let entry = table.decode(i);
1269 assert!(entry.symbol <= 1);
1270 }
1271 }
1272
1273 #[test]
1274 fn test_unequal_distribution() {
1275 let distribution = [6i16, 2];
1278 let table = FseTable::build(&distribution, 3, 1).unwrap();
1279
1280 assert_eq!(table.size(), 8);
1281
1282 let mut counts = [0usize; 2];
1284 for i in 0..8 {
1285 let entry = table.decode(i);
1286 counts[entry.symbol as usize] += 1;
1287 }
1288 assert_eq!(counts[0] + counts[1], 8);
1291 assert!(counts[0] >= counts[1]);
1293 }
1294
1295 #[test]
1296 fn test_less_than_one_probability() {
1297 let distribution = [8i16]; let table = FseTable::build(&distribution, 3, 0).unwrap();
1304
1305 assert_eq!(table.size(), 8);
1306
1307 for i in 0..8 {
1309 let entry = table.decode(i);
1310 assert_eq!(entry.symbol, 0);
1311 }
1312 }
1313
1314 #[test]
1315 fn test_predefined_literal_length_distribution() {
1316 let slot_sum: i32 = LITERAL_LENGTH_DEFAULT_DISTRIBUTION
1319 .iter()
1320 .map(|&f| if f == -1 { 1 } else { f as i32 })
1321 .sum();
1322 assert_eq!(slot_sum, 64); }
1324
1325 #[test]
1326 fn test_predefined_match_length_distribution() {
1327 let slot_sum: i32 = MATCH_LENGTH_DEFAULT_DISTRIBUTION
1328 .iter()
1329 .map(|&f| if f == -1 { 1 } else { f as i32 })
1330 .sum();
1331 assert_eq!(slot_sum, 64); }
1333
1334 #[test]
1335 fn test_predefined_offset_distribution() {
1336 let slot_sum: i32 = OFFSET_DEFAULT_DISTRIBUTION
1337 .iter()
1338 .map(|&f| if f == -1 { 1 } else { f as i32 })
1339 .sum();
1340 assert_eq!(slot_sum, 32); }
1342
1343 #[test]
1344 fn test_accuracy_log_too_high() {
1345 let distribution = [1i16; 65536];
1346 let result = FseTable::build(&distribution, 16, 255);
1347 assert!(result.is_err());
1348 }
1349
1350 #[test]
1351 fn test_frequency_sum_mismatch() {
1352 let distribution = [2i16, 1];
1354 let result = FseTable::build(&distribution, 2, 1);
1355 assert!(result.is_err());
1356 }
1357
1358 #[test]
1359 fn test_state_mask() {
1360 let distribution = [4i16, 4];
1361 let table = FseTable::build(&distribution, 3, 1).unwrap();
1362 assert_eq!(table.state_mask(), 0b111); }
1364
1365 #[test]
1366 fn test_decode_roundtrip_state_transitions() {
1367 let distribution = [4i16, 2, 2]; let table = FseTable::build(&distribution, 3, 2).unwrap();
1370
1371 for state in 0..table.size() {
1373 let entry = table.decode(state);
1374
1375 assert!(
1377 entry.symbol <= 2,
1378 "Invalid symbol {} at state {}",
1379 entry.symbol,
1380 state
1381 );
1382
1383 assert!(
1385 entry.num_bits <= table.accuracy_log(),
1386 "num_bits {} exceeds accuracy_log {} at state {}",
1387 entry.num_bits,
1388 table.accuracy_log(),
1389 state
1390 );
1391 }
1392 }
1393
1394 #[test]
1399 fn test_read_bits_from_slice_simple() {
1400 let data = [0b10110100];
1401 let mut pos = 0;
1402
1403 let low4 = super::read_bits_from_slice(&data, &mut pos, 4).unwrap();
1405 assert_eq!(low4, 0b0100);
1406 assert_eq!(pos, 4);
1407
1408 let high4 = super::read_bits_from_slice(&data, &mut pos, 4).unwrap();
1410 assert_eq!(high4, 0b1011);
1411 assert_eq!(pos, 8);
1412 }
1413
1414 #[test]
1415 fn test_read_bits_from_slice_cross_byte() {
1416 let data = [0xFF, 0x00];
1417 let mut pos = 4;
1418
1419 let cross = super::read_bits_from_slice(&data, &mut pos, 8).unwrap();
1421 assert_eq!(cross, 0x0F); }
1423
1424 #[test]
1425 fn test_read_bits_from_slice_zero() {
1426 let data = [0xFF];
1427 let mut pos = 0;
1428
1429 let zero = super::read_bits_from_slice(&data, &mut pos, 0).unwrap();
1430 assert_eq!(zero, 0);
1431 assert_eq!(pos, 0);
1432 }
1433
1434 #[test]
1435 fn test_fse_parse_empty() {
1436 let result = FseTable::parse(&[], 1);
1438 assert!(result.is_err());
1439 }
1440
1441 #[test]
1442 fn test_fse_parse_accuracy_log_too_high() {
1443 let data = [0x0B]; let result = FseTable::parse(&data, 1);
1446 assert!(result.is_err());
1447 }
1448
1449 #[test]
1466 #[ignore = "Fundamental FSE limitation: last symbol cannot use 100% of remaining"]
1467 fn test_serialize_parse_roundtrip_simple() {
1468 let distribution = [22i16, 10]; let table = FseTable::build(&distribution, 5, 1).unwrap();
1472
1473 println!("Simple test: accuracy_log={}", table.accuracy_log());
1474 println!("Distribution: {:?}", distribution);
1475
1476 let bytes = table.serialize(&distribution);
1477 println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
1478
1479 let result = FseTable::parse(&bytes, 1);
1484 match &result {
1485 Ok((parsed, consumed)) => {
1486 println!(
1487 "Parsed OK: consumed {} bytes, table size {}",
1488 consumed,
1489 parsed.size()
1490 );
1491 }
1492 Err(e) => println!("Parse error: {:?}", e),
1493 }
1494 assert!(result.is_ok(), "Simple parse should succeed");
1495 }
1496
1497 #[test]
1498 #[ignore = "Fundamental FSE limitation: sparse distributions hit 100% remaining issue"]
1499 fn test_serialize_parse_roundtrip_sparse() {
1500 let mut ll_freq = [0u32; 36];
1503 ll_freq[0] = 100; ll_freq[16] = 50; let (table, normalized) = FseTable::from_frequencies(&ll_freq, 5).unwrap();
1507
1508 println!("Table built: accuracy_log={}", table.accuracy_log());
1509 println!("Normalized: {:?}", normalized);
1510
1511 let sum: i32 = normalized
1513 .iter()
1514 .map(|&p| if p == -1 { 1 } else { p as i32 })
1515 .sum();
1516 let table_size = 1 << table.accuracy_log();
1517 println!("Sum: {}, table_size: {}", sum, table_size);
1518 assert_eq!(sum, table_size, "Normalized should sum to table_size");
1519
1520 let bytes = table.serialize(&normalized);
1522 println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
1523
1524 for (i, b) in bytes.iter().enumerate() {
1526 println!(" byte {}: {:02x} = {:08b}", i, b, b);
1527 }
1528
1529 let result = FseTable::parse(&bytes, 35);
1531 match &result {
1532 Ok((_, consumed)) => println!("Parsed OK: consumed {} bytes", consumed),
1533 Err(e) => println!("Parse error: {:?}", e),
1534 }
1535 assert!(result.is_ok(), "Parse should succeed");
1536 }
1537
1538 #[test]
1546 fn test_serialize_parse_roundtrip_with_padding() {
1547 let mut ll_freq = [0u32; 36];
1549 ll_freq[0] = 100; ll_freq[16] = 50; let (table, normalized) = FseTable::from_frequencies_serializable(&ll_freq, 5).unwrap();
1554
1555 println!("Table built: accuracy_log={}", table.accuracy_log());
1556 println!("Normalized (with padding): {:?}", normalized);
1557 println!("Symbol count: {} (original: 17)", normalized.len());
1558
1559 let sum: i32 = normalized
1561 .iter()
1562 .map(|&p| if p == -1 { 1 } else { p as i32 })
1563 .sum();
1564 let table_size = 1 << table.accuracy_log();
1565 println!("Sum: {}, table_size: {}", sum, table_size);
1566 assert_eq!(sum, table_size, "Normalized should sum to table_size");
1567
1568 let bytes = table.serialize(&normalized);
1570 println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
1571
1572 let max_symbol = (normalized.len() - 1) as u8;
1574 let result = FseTable::parse(&bytes, max_symbol);
1575 match &result {
1576 Ok((parsed, consumed)) => {
1577 println!(
1578 "Parsed OK: consumed {} bytes, table size {}",
1579 consumed,
1580 parsed.size()
1581 );
1582 }
1583 Err(e) => println!("Parse error: {:?}", e),
1584 }
1585 assert!(
1586 result.is_ok(),
1587 "Parse should succeed with padded distribution"
1588 );
1589
1590 let (parsed_table, _) = result.unwrap();
1592 assert_eq!(parsed_table.accuracy_log(), table.accuracy_log());
1593 assert_eq!(parsed_table.size(), table.size());
1594 }
1595
1596 #[test]
1597 fn test_serialize_parse_roundtrip_2symbol() {
1598 let frequencies = [22u32, 10];
1600
1601 let (table, normalized) = FseTable::from_frequencies_serializable(&frequencies, 5).unwrap();
1602
1603 println!("2-symbol test: accuracy_log={}", table.accuracy_log());
1604 println!("Normalized: {:?}", normalized);
1605
1606 let sum: i32 = normalized
1607 .iter()
1608 .map(|&p| if p == -1 { 1 } else { p as i32 })
1609 .sum();
1610 assert_eq!(sum, 32, "Should sum to 32");
1611
1612 let bytes = table.serialize(&normalized);
1613 println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
1614
1615 let max_symbol = (normalized.len() - 1) as u8;
1616 println!("Parsing with max_symbol={}", max_symbol);
1617 let result = FseTable::parse(&bytes, max_symbol);
1618 match &result {
1619 Ok((parsed, consumed)) => {
1620 println!(
1621 "Parsed OK: consumed {} bytes, table size {}",
1622 consumed,
1623 parsed.size()
1624 );
1625 }
1626 Err(e) => println!("Parse error: {:?}", e),
1627 }
1628 assert!(result.is_ok(), "2-symbol with padding should parse");
1629 }
1630
1631 #[test]
1636 fn test_custom_table_from_frequencies_zipf() {
1637 let frequencies = [100u32, 50, 25, 12, 6, 3, 2, 1, 1];
1639
1640 let (table, normalized) = FseTable::from_frequencies(&frequencies, 9).unwrap();
1642
1643 assert!(table.is_valid());
1645 assert_eq!(table.max_symbol() as usize, frequencies.len() - 1);
1646
1647 let sum: i32 = normalized
1649 .iter()
1650 .map(|&p| if p == -1 { 1 } else { p as i32 })
1651 .sum();
1652 assert_eq!(sum, 1 << 9); }
1654
1655 #[test]
1656 fn test_custom_table_serialization_roundtrip() {
1657 let frequencies = [100u32, 50, 25, 12, 6, 4, 2, 1];
1660
1661 let (table, normalized) = FseTable::from_frequencies_serializable(&frequencies, 8).unwrap();
1663
1664 println!("Normalized: {:?}", normalized);
1666 println!("Accuracy log: {}", table.accuracy_log());
1667
1668 let bytes = table.serialize(&normalized);
1670 println!("Serialized {} bytes: {:02x?}", bytes.len(), bytes);
1671
1672 let max_symbol = (normalized.len() - 1) as u8;
1675 let result = FseTable::parse(&bytes, max_symbol);
1676
1677 match result {
1678 Ok((restored, consumed)) => {
1679 println!("Parsed {} bytes, table size {}", consumed, restored.size());
1680 assert_eq!(table.accuracy_log(), restored.accuracy_log());
1682 assert_eq!(table.size(), restored.size());
1683 }
1684 Err(e) => {
1685 println!("Parse failed (expected limitation): {:?}", e);
1688 assert!(
1691 table.is_valid(),
1692 "Table should be valid even if serialization fails"
1693 );
1694 }
1695 }
1696 }
1697
1698 #[test]
1699 fn test_custom_table_encode_decode_roundtrip() {
1700 use crate::fse::{BitReader, FseBitWriter, FseDecoder, FseEncoder};
1701
1702 let frequencies = [100u32, 50, 25, 12];
1704 let (table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
1705
1706 let mut encoder = FseEncoder::from_decode_table(&table);
1708 let symbols = vec![0u8, 1, 2, 3, 0, 0, 1, 2, 0, 1, 0, 0, 0];
1709
1710 encoder.init_state(symbols[0]);
1712 let mut writer = FseBitWriter::new();
1713
1714 for &sym in &symbols[1..] {
1715 let (bits, num_bits) = encoder.encode_symbol(sym);
1716 writer.write_bits(bits, num_bits);
1717 }
1718
1719 let final_state = encoder.get_state();
1721 writer.write_bits(final_state as u32, table.accuracy_log());
1722
1723 let encoded = writer.finish();
1724
1725 let mut decoder = FseDecoder::new(&table);
1727 let mut reader = BitReader::new(&encoded);
1728
1729 assert!(encoded.len() > 0, "Encoding produced data");
1733
1734 for state in 0..table.size() {
1736 let entry = table.decode(state);
1737 assert!(entry.symbol < frequencies.len() as u8);
1738 }
1739 }
1740
1741 #[test]
1742 fn test_custom_table_beats_predefined_for_skewed_data() {
1743 let frequencies = [1000u32, 1, 1, 1];
1745 let (custom_table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
1746
1747 let predefined =
1749 FseTable::from_predefined(&LITERAL_LENGTH_DEFAULT_DISTRIBUTION, 6).unwrap();
1750
1751 let custom_symbol0_count = (0..custom_table.size())
1753 .filter(|&s| custom_table.decode(s).symbol == 0)
1754 .count();
1755
1756 let predefined_symbol0_count = (0..predefined.size())
1758 .filter(|&s| predefined.decode(s).symbol == 0)
1759 .count();
1760
1761 assert!(
1763 custom_symbol0_count > predefined_symbol0_count * 10,
1764 "Custom: {} states for symbol 0, Predefined: {}",
1765 custom_symbol0_count,
1766 predefined_symbol0_count
1767 );
1768
1769 let custom_avg_bits: f64 = (0..custom_table.size())
1772 .filter(|&s| custom_table.decode(s).symbol == 0)
1773 .map(|s| custom_table.decode(s).num_bits as f64)
1774 .sum::<f64>()
1775 / custom_symbol0_count as f64;
1776
1777 assert!(
1778 custom_avg_bits < 4.0,
1779 "Symbol 0 should use few bits: {}",
1780 custom_avg_bits
1781 );
1782 }
1783
1784 #[test]
1785 fn test_table_accuracy_log_selection() {
1786 let frequencies = [100u32, 50, 25, 12, 6, 3, 2, 1];
1787
1788 for log in [5, 6, 7, 8, 9, 10, 11] {
1790 let (table, _) = FseTable::from_frequencies(&frequencies, log).unwrap();
1791 assert_eq!(
1792 table.accuracy_log(),
1793 log,
1794 "Table should use accuracy_log={}",
1795 log
1796 );
1797 assert_eq!(table.size(), 1 << log, "Table size should be 2^{}", log);
1798 }
1799 }
1800
1801 #[test]
1802 fn test_invalid_frequencies_rejected() {
1803 let result = FseTable::from_frequencies(&[0, 0, 0], 8);
1805 assert!(result.is_err(), "All-zero frequencies should be rejected");
1806
1807 let result = FseTable::from_frequencies(&[], 8);
1809 assert!(result.is_err(), "Empty frequencies should be rejected");
1810
1811 let result = FseTable::from_frequencies(&[0], 8);
1813 assert!(result.is_err(), "Single zero frequency should be rejected");
1814 }
1815
1816 #[test]
1817 fn test_rle_mode_detection() {
1818 let frequencies = [1000u32, 0, 0, 0];
1820 let (table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
1821
1822 assert!(
1824 table.is_rle_mode(),
1825 "Single-symbol table should be RLE mode"
1826 );
1827
1828 for state in 0..table.size() {
1830 assert_eq!(table.decode(state).symbol, 0);
1831 }
1832 }
1833
1834 #[test]
1835 fn test_non_rle_mode() {
1836 let frequencies = [50u32, 50];
1838 let (table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
1839
1840 assert!(
1841 !table.is_rle_mode(),
1842 "Multi-symbol table should not be RLE mode"
1843 );
1844 }
1845
1846 #[test]
1847 fn test_is_valid_positive() {
1848 let frequencies = [100u32, 50, 25, 12];
1849 let (table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
1850
1851 assert!(table.is_valid(), "Well-formed table should be valid");
1852 }
1853
1854 #[test]
1855 fn test_predefined_tables_are_valid() {
1856 let ll_table = FseTable::from_predefined(&LITERAL_LENGTH_DEFAULT_DISTRIBUTION, 6).unwrap();
1858 assert!(ll_table.is_valid(), "Predefined LL table should be valid");
1859
1860 let ml_table = FseTable::from_predefined(&MATCH_LENGTH_DEFAULT_DISTRIBUTION, 6).unwrap();
1861 assert!(ml_table.is_valid(), "Predefined ML table should be valid");
1862
1863 let of_table = FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, 5).unwrap();
1864 assert!(of_table.is_valid(), "Predefined OF table should be valid");
1865 }
1866
1867 #[test]
1868 fn test_custom_table_symbol_distribution() {
1869 let frequencies = [64u32, 32, 16, 8, 4, 4]; let (table, normalized) = FseTable::from_frequencies(&frequencies, 7).unwrap();
1872
1873 let mut symbol_counts = [0usize; 6];
1875 for state in 0..table.size() {
1876 let sym = table.decode(state).symbol;
1877 if (sym as usize) < 6 {
1878 symbol_counts[sym as usize] += 1;
1879 }
1880 }
1881
1882 for (i, &norm) in normalized.iter().enumerate() {
1884 let expected = if norm == -1 { 1 } else { norm as usize };
1885 assert_eq!(
1886 symbol_counts[i], expected,
1887 "Symbol {} should have {} states, got {}",
1888 i, expected, symbol_counts[i]
1889 );
1890 }
1891 }
1892}