1use crate::bit_writer::BitWriter;
11use crate::error::{Error, Result};
12
/// Log2 of the ANS table size: normalized frequencies have this many bits of precision.
pub const ANS_LOG_TAB_SIZE: u32 = 12;
/// Number of slots in the ANS table; every normalized distribution sums to this.
pub const ANS_TAB_SIZE: u32 = 1 << ANS_LOG_TAB_SIZE;
/// Mask that selects a slot index (the low `ANS_LOG_TAB_SIZE` bits of a state).
pub const ANS_TAB_MASK: u32 = ANS_TAB_SIZE - 1;

/// Largest alphabet an ANS distribution may describe.
pub const ANS_MAX_ALPHABET_SIZE: usize = 256;

/// Placed in the upper 16 bits of a fresh encoder state (`ANS_SIGNATURE << 16`).
pub const ANS_SIGNATURE: u32 = 0x13;

/// Logcount value reserved as the run-length marker in histogram headers
/// (the last entry of `LOGCOUNT_PREFIX_CODE`).
// Not referenced by the encoders in this module, which never emit run-length
// coded logcounts; kept so the reserved value is documented next to the table.
#[allow(dead_code)]
const RLE_MARKER_SYM: u8 = 13;

/// Prefix code for histogram logcounts, indexed by logcount: `(bit length, code)`.
const LOGCOUNT_PREFIX_CODE: [(u8, u8); 14] = [
    (5, 0b10001),   // 0
    (4, 0b1011),    // 1
    (4, 0b1111),    // 2
    (4, 0b0011),    // 3
    (4, 0b1001),    // 4
    (4, 0b0111),    // 5
    (3, 0b100),     // 6
    (3, 0b010),     // 7
    (3, 0b101),     // 8
    (3, 0b110),     // 9
    (3, 0b000),     // 10
    (6, 0b100001),  // 11
    (7, 0b0000001), // 12
    (7, 0b1000001), // 13 (RLE_MARKER_SYM)
];

/// Fixed-point precision (in bits) of the reciprocal `AnsEncSymbolInfo::ifreq`.
const RECIPROCAL_PRECISION: u32 = 44;
51#[derive(Debug, Clone)]
53pub struct AnsEncSymbolInfo {
54 pub freq: u16,
56 pub ifreq: u64,
58 pub reverse_map: Vec<u16>,
60}
61
62impl AnsEncSymbolInfo {
63 pub fn new(freq: u16) -> Self {
65 let ifreq = if freq > 0 {
66 (1u64 << RECIPROCAL_PRECISION).div_ceil(freq as u64)
67 } else {
68 0
69 };
70
71 Self {
72 freq,
73 ifreq,
74 reverse_map: Vec::new(), }
76 }
77}
78
79pub struct AnsEncoder {
81 state: u32,
83 bits: Vec<(u32, u8)>, }
86
87impl AnsEncoder {
88 pub fn new() -> Self {
90 Self {
91 state: ANS_SIGNATURE << 16,
92 bits: Vec::new(),
93 }
94 }
95
96 pub fn with_capacity(num_tokens: usize) -> Self {
98 Self {
99 state: ANS_SIGNATURE << 16,
100 bits: Vec::with_capacity(num_tokens * 2), }
102 }
103
104 #[inline]
108 pub fn put_symbol(&mut self, info: &AnsEncSymbolInfo) {
109 let freq = info.freq as u32;
110
111 if (self.state >> (32 - ANS_LOG_TAB_SIZE)) >= freq {
113 self.bits.push((self.state & 0xFFFF, 16));
114 self.state >>= 16;
115 }
116
117 let v = ((self.state as u64 * info.ifreq) >> RECIPROCAL_PRECISION) as u32;
120 let remainder = self.state - v * freq;
121
122 let offset = info.reverse_map[remainder as usize] as u32;
124
125 self.state = (v << ANS_LOG_TAB_SIZE) + offset;
127 }
128
129 #[inline]
135 pub fn push_bits(&mut self, bits: u32, nbits: u8) {
136 if nbits > 0 {
137 self.bits.push((bits, nbits));
138 }
139 }
140
141 pub fn finalize(self, writer: &mut BitWriter) -> Result<()> {
145 #[cfg(feature = "debug-tokens")]
147 eprintln!(
148 "ANS finalize: state=0x{:08x}, {} bit chunks",
149 self.state,
150 self.bits.len()
151 );
152
153 writer.write(32, self.state as u64)?;
155
156 for &(bits, nbits) in self.bits.iter().rev() {
158 writer.write(nbits as usize, bits as u64)?;
159 }
160
161 Ok(())
162 }
163
164 pub fn state(&self) -> u32 {
166 self.state
167 }
168}
169
170impl Default for AnsEncoder {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176#[derive(Debug, Clone)]
178pub struct AnsDistribution {
179 pub symbols: Vec<AnsEncSymbolInfo>,
181 pub log_alpha_size: u32,
183 pub total: u32,
185}
186
187impl AnsDistribution {
188 pub fn from_frequencies(freqs: &[u32]) -> Result<Self> {
192 if freqs.is_empty() {
193 return Err(Error::InvalidHistogram("empty distribution".to_string()));
194 }
195
196 let total_count: u64 = freqs.iter().map(|&f| f as u64).sum();
197 if total_count == 0 {
198 return Err(Error::InvalidHistogram("all zero frequencies".to_string()));
199 }
200
201 let mut normalized: Vec<u16> = Vec::with_capacity(freqs.len());
203 let mut running_total: u32 = 0;
204
205 for &freq in freqs.iter() {
206 let normalized_freq = if freq == 0 {
207 0
208 } else {
209 ((freq as u64 * ANS_TAB_SIZE as u64) / total_count).max(1) as u16
211 };
212 normalized.push(normalized_freq);
213 running_total += normalized_freq as u32;
214 }
215
216 let diff = running_total as i32 - ANS_TAB_SIZE as i32;
218 if diff != 0 {
219 if let Some((max_idx, _)) = normalized
221 .iter()
222 .enumerate()
223 .filter(|&(_, &f)| f > 0)
224 .max_by_key(|&(_, &f)| f)
225 {
226 let new_val = (normalized[max_idx] as i32 - diff).max(1) as u16;
227 normalized[max_idx] = new_val;
228 }
229 }
230
231 let mut symbols: Vec<AnsEncSymbolInfo> = normalized
233 .iter()
234 .map(|&f| AnsEncSymbolInfo::new(f))
235 .collect();
236
237 Self::build_reverse_maps(&mut symbols)?;
239
240 Ok(Self {
241 symbols,
242 log_alpha_size: ANS_LOG_TAB_SIZE,
243 total: ANS_TAB_SIZE,
244 })
245 }
246
247 pub fn flat(alphabet_size: usize) -> Result<Self> {
249 if alphabet_size == 0 || alphabet_size > ANS_TAB_SIZE as usize {
250 return Err(Error::InvalidHistogram(format!(
251 "invalid alphabet size: {}",
252 alphabet_size
253 )));
254 }
255
256 let base_freq = ANS_TAB_SIZE as usize / alphabet_size;
257 let remainder = ANS_TAB_SIZE as usize % alphabet_size;
258
259 let mut freqs = vec![base_freq as u32; alphabet_size];
260 for freq in freqs.iter_mut().take(remainder) {
261 *freq += 1;
262 }
263
264 Self::from_frequencies(&freqs)
265 }
266
267 pub fn from_normalized_counts(counts: &[i32]) -> Result<Self> {
273 if counts.is_empty() {
274 return Err(Error::InvalidHistogram("empty distribution".to_string()));
275 }
276
277 let total: i32 = counts.iter().sum();
279 if total != ANS_TAB_SIZE as i32 {
280 return Err(Error::InvalidHistogram(format!(
281 "normalized counts sum to {} instead of {}",
282 total, ANS_TAB_SIZE
283 )));
284 }
285
286 let mut symbols: Vec<AnsEncSymbolInfo> = counts
288 .iter()
289 .map(|&c| AnsEncSymbolInfo::new(c.max(0) as u16))
290 .collect();
291
292 Self::build_reverse_maps(&mut symbols)?;
294
295 Ok(Self {
296 symbols,
297 log_alpha_size: ANS_LOG_TAB_SIZE,
298 total: ANS_TAB_SIZE,
299 })
300 }
301
302 fn build_reverse_maps(symbols: &mut [AnsEncSymbolInfo]) -> Result<()> {
310 let alphabet_size = symbols.len();
311 if alphabet_size == 0 {
312 return Ok(());
313 }
314
315 let total: u32 = symbols.iter().map(|s| s.freq as u32).sum();
317 if total != ANS_TAB_SIZE {
318 return Err(Error::InvalidHistogram(format!(
319 "frequencies sum to {} instead of {}",
320 total, ANS_TAB_SIZE
321 )));
322 }
323
324 if let Some(single_sym_idx) = symbols.iter().position(|s| s.freq == ANS_TAB_SIZE as u16) {
328 for sym in symbols.iter_mut() {
330 sym.reverse_map.clear();
331 }
332 let map = &mut symbols[single_sym_idx].reverse_map;
334 map.resize(ANS_TAB_SIZE as usize, 0);
335 for (i, v) in map.iter_mut().enumerate() {
336 *v = i as u16;
337 }
338 return Ok(());
339 }
340
341 let log_alpha_size = if alphabet_size <= 64 {
345 6 } else {
347 let min_bits = (alphabet_size - 1).ilog2() as usize + 1;
348 min_bits.min(ANS_LOG_TAB_SIZE as usize)
349 };
350 let table_size = 1usize << log_alpha_size;
351 let log_bucket_size = ANS_LOG_TAB_SIZE as usize - log_alpha_size;
352 let bucket_size = 1u16 << log_bucket_size;
353
354 #[derive(Clone)]
356 #[allow(dead_code)]
357 struct WorkingBucket {
358 dist: u16, alias_symbol: u16, alias_offset: u16, alias_cutoff: u16, }
363
364 let mut buckets: Vec<WorkingBucket> = (0..table_size)
365 .map(|i| {
366 let dist = if i < alphabet_size {
367 symbols[i].freq
368 } else {
369 0
370 };
371 WorkingBucket {
372 dist,
373 alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
374 alias_offset: 0,
375 alias_cutoff: dist,
376 }
377 })
378 .collect();
379
380 let mut underfull: Vec<usize> = Vec::with_capacity(table_size);
382 let mut overfull: Vec<usize> = Vec::with_capacity(table_size);
383 for (i, bucket) in buckets.iter().enumerate() {
384 if bucket.alias_cutoff < bucket_size {
385 underfull.push(i);
386 } else if bucket.alias_cutoff > bucket_size {
387 overfull.push(i);
388 }
389 }
390
391 while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
393 let by = bucket_size - buckets[u].alias_cutoff;
394 buckets[o].alias_cutoff -= by;
395 buckets[u].alias_symbol = o as u16;
396 buckets[u].alias_offset = buckets[o].alias_cutoff;
397
398 match buckets[o].alias_cutoff.cmp(&bucket_size) {
399 std::cmp::Ordering::Less => underfull.push(o),
400 std::cmp::Ordering::Greater => overfull.push(o),
401 std::cmp::Ordering::Equal => {}
402 }
403 }
404
405 for sym in symbols.iter_mut() {
407 sym.reverse_map.clear();
408 sym.reverse_map.resize(sym.freq as usize, 0);
409 }
410
411 for idx in 0..ANS_TAB_SIZE {
414 let bucket_idx = (idx >> log_bucket_size) as usize;
415 let pos = (idx as u16) & (bucket_size - 1);
416
417 let bucket = &buckets[bucket_idx.min(table_size - 1)];
418 let alias_cutoff = bucket.alias_cutoff;
419
420 let (symbol, offset) = if pos < alias_cutoff {
421 (bucket_idx, pos)
422 } else {
423 let alias_sym = bucket.alias_symbol as usize;
424 let offset = bucket.alias_offset - alias_cutoff + pos;
425 (alias_sym, offset)
426 };
427
428 if symbol < alphabet_size {
429 symbols[symbol].reverse_map[offset as usize] = idx as u16;
430 }
431 }
432
433 Ok(())
434 }
435
436 pub fn alphabet_size(&self) -> usize {
438 self.symbols.len()
439 }
440
441 pub fn get(&self, symbol: usize) -> Option<&AnsEncSymbolInfo> {
443 self.symbols.get(symbol)
444 }
445
446 pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
448 let is_flat = self.is_flat();
450
451 writer.write(1, 0)?; writer.write(1, u64::from(is_flat))?;
453
454 if is_flat {
455 write_var_len_uint8(writer, (self.alphabet_size() - 1) as u8)?;
457 } else {
458 self.write_general(writer)?;
461 }
462
463 Ok(())
464 }
465
466 fn is_flat(&self) -> bool {
468 let first_freq = self.symbols.first().map(|s| s.freq).unwrap_or(0);
469 if first_freq == 0 {
470 return false;
471 }
472 self.symbols
473 .iter()
474 .all(|s| s.freq == first_freq || s.freq == first_freq - 1)
475 }
476
477 fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
479 let method: u64 = 13; let upper_bound_log = 4; let log = floor_log2(method as u32);
483
484 writer.write(log as usize, (1u64 << log) - 1)?;
486 if log != upper_bound_log {
487 writer.write(1, 0)?;
488 }
489 writer.write(log as usize, ((1u64 << log) - 1) & method)?;
491
492 write_var_len_uint8(writer, (self.alphabet_size() - 3) as u8)?;
494
495 for sym in &self.symbols {
498 let freq = sym.freq;
500 if freq == 0 {
501 writer.write(1, 0)?;
502 } else {
503 writer.write(1, 1)?;
504 let bits = 16 - freq.leading_zeros();
505 writer.write(4, bits as u64)?;
506 if bits > 0 {
507 writer.write(bits as usize, freq as u64)?;
508 }
509 }
510 }
511
512 Ok(())
513 }
514}
515
516fn write_var_len_uint8(writer: &mut BitWriter, n: u8) -> Result<()> {
518 if n == 0 {
519 writer.write(1, 0)?;
520 } else {
521 writer.write(1, 1)?;
522 let nbits = 8 - n.leading_zeros();
523 writer.write(3, (nbits - 1) as u64)?;
524 writer.write((nbits - 1) as usize, (n as u64) - (1u64 << (nbits - 1)))?;
525 }
526 Ok(())
527}
528
/// Returns `floor(log2(n))`, defining the result for `n == 0` as 0.
#[inline]
pub fn floor_log2_ans(n: u32) -> u32 {
    n.checked_ilog2().unwrap_or(0)
}

/// Module-private shorthand for [`floor_log2_ans`].
#[inline]
fn floor_log2(n: u32) -> u32 {
    floor_log2_ans(n)
}
539
540pub fn get_population_count_precision(logcount: u32, shift: u32) -> u32 {
547 let logcount_i = logcount as i32;
548 let shift_i = shift as i32;
549 let r = logcount_i.min(shift_i - ((ANS_LOG_TAB_SIZE as i32 - logcount_i) >> 1));
550 r.max(0) as u32
551}
552
/// How exhaustively `ANSEncodingHistogram::from_histogram` searches precision shifts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ANSHistogramStrategy {
    /// Try shifts 0, 6 and 12 only.
    Fast,
    /// Try every even shift from 0 through 12.
    Approximate,
    /// Try every shift from 0 through 11. This is the default.
    #[default]
    Precise,
}
564
565#[derive(Clone, Debug)]
569pub struct ANSEncodingHistogram {
570 pub counts: Vec<i32>,
572 pub alphabet_size: usize,
574 pub cost: f32,
576 pub method: u32,
581 pub omit_pos: usize,
583 num_symbols: usize,
585 symbols: [usize; 2],
587}
588
589impl ANSEncodingHistogram {
590 pub fn new() -> Self {
592 Self {
593 counts: Vec::new(),
594 alphabet_size: 0,
595 cost: f32::MAX,
596 method: 0,
597 omit_pos: 0,
598 num_symbols: 0,
599 symbols: [0, 0],
600 }
601 }
602
603 pub fn from_histogram(
607 histo: &super::histogram::Histogram,
608 strategy: ANSHistogramStrategy,
609 ) -> Result<Self> {
610 if histo.total_count == 0 {
611 return Ok(Self {
613 counts: vec![0i32; histo.counts.len().max(1)],
614 alphabet_size: 1,
615 cost: 0.0,
616 method: 0, omit_pos: 0,
618 num_symbols: 0,
619 symbols: [0, 0],
620 });
621 }
622
623 let alphabet_size = histo.alphabet_size();
624
625 let mut num_symbols = 0;
627 let mut symbols = [0usize; 2];
628 for (i, &count) in histo.counts.iter().enumerate() {
629 if count > 0 {
630 if num_symbols < 2 {
631 symbols[num_symbols] = i;
632 }
633 num_symbols += 1;
634 }
635 }
636
637 if num_symbols <= 2 {
639 let mut counts = vec![0i32; alphabet_size];
640 if num_symbols == 1 {
641 counts[symbols[0]] = ANS_TAB_SIZE as i32;
642 } else {
643 let total = histo.total_count as f64;
645 let count0 = histo.counts[symbols[0]] as f64;
646 let norm0 = ((count0 / total) * ANS_TAB_SIZE as f64).round() as i32;
647 let norm0 = norm0.clamp(1, (ANS_TAB_SIZE - 1) as i32);
648 counts[symbols[0]] = norm0;
649 counts[symbols[1]] = ANS_TAB_SIZE as i32 - norm0;
650 }
651
652 let cost = if num_symbols <= 1 { 4.0 } else { 4.0 + 12.0 }; return Ok(Self {
656 counts,
657 alphabet_size,
658 cost,
659 method: 1, omit_pos: symbols[0],
661 num_symbols,
662 symbols,
663 });
664 }
665
666 let mut best = Self::new();
668
669 let shifts: Vec<u32> = match strategy {
670 ANSHistogramStrategy::Fast => vec![0, 6, 12],
671 ANSHistogramStrategy::Approximate => (0..=ANS_LOG_TAB_SIZE).step_by(2).collect(),
672 ANSHistogramStrategy::Precise => (0..ANS_LOG_TAB_SIZE).collect(),
673 };
674
675 for shift in shifts {
676 let mut candidate = Self {
677 counts: vec![0i32; alphabet_size],
678 alphabet_size,
679 cost: f32::MAX,
680 method: shift.min(ANS_LOG_TAB_SIZE - 1) + 1,
681 omit_pos: 0,
682 num_symbols,
683 symbols,
684 };
685
686 if candidate.rebalance_histogram(histo, shift) {
687 candidate.cost = candidate.estimate_cost(histo);
688 if candidate.cost < best.cost {
689 best = candidate;
690 }
691 }
692 }
693
694 if best.cost == f32::MAX {
695 return Err(Error::InvalidHistogram(
696 "Failed to rebalance histogram".to_string(),
697 ));
698 }
699
700 Ok(best)
701 }
702
703 fn rebalance_histogram(&mut self, histo: &super::histogram::Histogram, shift: u32) -> bool {
710 let total_count = histo.total_count;
711 if total_count == 0 {
712 return false;
713 }
714
715 let norm = ANS_TAB_SIZE as f64 / total_count as f64;
716
717 for (i, &count) in histo.counts.iter().enumerate().take(self.alphabet_size) {
719 if count == 0 {
720 self.counts[i] = 0;
721 continue;
722 }
723
724 let target = count as f64 * norm;
725 let mut normalized = target.round() as i32;
726 normalized = normalized.max(1);
727 normalized = normalized.min((ANS_TAB_SIZE - 1) as i32);
728 self.counts[i] = normalized;
729 }
730
731 let mut max_logcount = 0u32;
733 let mut omit_pos = 0;
734 for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
735 if count > 0 {
736 let logcount = floor_log2(count as u32) + 1;
737 if logcount > max_logcount {
738 max_logcount = logcount;
739 omit_pos = i;
740 }
741 }
742 }
743 self.omit_pos = omit_pos;
744
745 let mut running_total = 0i32;
747 for i in 0..self.alphabet_size {
748 if i == omit_pos || self.counts[i] == 0 {
749 if i != omit_pos {
750 running_total += self.counts[i];
751 }
752 continue;
753 }
754
755 let mut normalized = self.counts[i];
756
757 if shift < ANS_LOG_TAB_SIZE && normalized > 1 {
759 let logcount = floor_log2(normalized as u32);
760 let precision = get_population_count_precision(logcount, shift);
761 let drop_bits = logcount.saturating_sub(precision);
762 let mask = (1i32 << drop_bits) - 1;
763 normalized &= !mask;
764 if normalized == 0 {
765 normalized = 1i32 << drop_bits;
766 }
767 }
768
769 self.counts[i] = normalized;
770 running_total += normalized;
771 }
772
773 let remainder = ANS_TAB_SIZE as i32 - running_total;
775 if remainder <= 0 || remainder > ANS_TAB_SIZE as i32 {
776 return false;
777 }
778 self.counts[omit_pos] = remainder;
779
780 let omit_logcount = floor_log2(self.counts[omit_pos] as u32) + 1;
785 for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
786 if i == omit_pos {
787 continue;
788 }
789 if count > 0 {
790 let logcount = floor_log2(count as u32) + 1;
791 if logcount > omit_logcount {
792 return false;
794 }
795 if logcount == omit_logcount && i < omit_pos {
796 return false;
798 }
799 }
800 }
801
802 let sum: i32 = self.counts.iter().sum();
804 sum == ANS_TAB_SIZE as i32
805 }
806
807 fn estimate_cost(&self, histo: &super::histogram::Histogram) -> f32 {
809 let header_cost = self.estimate_header_cost();
811
812 let data_cost = self.estimate_data_cost(histo);
814
815 header_cost + data_cost
816 }
817
818 fn estimate_header_cost(&self) -> f32 {
820 if self.method == 0 {
821 2.0 + 8.0
823 } else if self.num_symbols <= 2 {
824 if self.num_symbols <= 1 {
826 3.0 + 8.0 } else {
828 3.0 + 16.0 + 12.0 }
830 } else {
831 let method_bits = 4.0; let alphabet_bits = 8.0;
834 let freq_bits = self.alphabet_size as f32 * 5.0; method_bits + alphabet_bits + freq_bits
836 }
837 }
838
839 fn estimate_data_cost(&self, histo: &super::histogram::Histogram) -> f32 {
841 let mut cost = 0.0f32;
842
843 for (i, &count) in histo.counts.iter().enumerate() {
844 if count > 0 {
845 let normalized = self.counts.get(i).copied().unwrap_or(1).max(1);
846 let prob = normalized as f32 / ANS_TAB_SIZE as f32;
847 cost -= count as f32 * prob.log2();
848 }
849 }
850
851 cost
852 }
853
854 pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
856 if self.method == 0 {
857 writer.write(1, 0)?; writer.write(1, 1)?; write_var_len_uint8(writer, (self.alphabet_size - 1) as u8)?;
861 return Ok(());
862 }
863
864 if self.num_symbols <= 2 {
865 writer.write(1, 1)?; if self.num_symbols == 0 {
868 writer.write(1, 0)?;
869 write_var_len_uint8(writer, 0)?;
870 } else {
871 writer.write(1, (self.num_symbols - 1) as u64)?;
872 for i in 0..self.num_symbols {
873 write_var_len_uint8(writer, self.symbols[i] as u8)?;
874 }
875 if self.num_symbols == 2 {
876 writer.write(
877 ANS_LOG_TAB_SIZE as usize,
878 self.counts[self.symbols[0]] as u64,
879 )?;
880 }
881 }
882 return Ok(());
883 }
884
885 self.write_general(writer)
887 }
888
889 fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
893 writer.write(1, 0)?; writer.write(1, 0)?; let shift = (self.method - 1) as i32;
899 let shift_val = (shift + 1) as u32; let mut len = 0u32;
903 while len < 3 && shift_val >= (1u32 << (len + 1)) {
904 len += 1;
905 }
906
907 for _ in 0..len {
909 writer.write(1, 1)?;
910 }
911 if len < 3 {
913 writer.write(1, 0)?;
914 }
915 if len > 0 {
917 let suffix = shift_val - (1u32 << len);
918 writer.write(len as usize, suffix as u64)?;
919 }
920
921 if self.alphabet_size < 3 {
923 return Err(Error::InvalidHistogram(
924 "General histogram needs at least 3 symbols".to_string(),
925 ));
926 }
927 write_var_len_uint8(writer, (self.alphabet_size - 3) as u8)?;
928
929 for i in 0..self.alphabet_size {
932 let count = self.counts[i];
933
934 let logcount = if count <= 0 {
936 0
937 } else {
938 floor_log2(count as u32) + 1
939 };
940
941 let (nbits, code) = LOGCOUNT_PREFIX_CODE[logcount as usize];
943 writer.write(nbits as usize, code as u64)?;
944 }
945
946 for i in 0..self.alphabet_size {
949 if i == self.omit_pos {
950 continue;
951 }
952
953 let count = self.counts[i];
954 if count <= 0 {
955 continue;
956 }
957
958 let logcount = floor_log2(count as u32) + 1;
959 if logcount <= 1 {
960 continue;
962 }
963
964 let zeros = (logcount - 1) as i32;
966 let bitcount = (shift - ((ANS_LOG_TAB_SIZE as i32 - zeros) >> 1)).clamp(0, zeros);
968
969 if bitcount > 0 {
970 let base = 1i32 << zeros;
973 let extra = ((count - base) >> (zeros - bitcount)) as u32;
974 writer.write(bitcount as usize, extra as u64)?;
975 }
976 }
977
978 Ok(())
979 }
980}
981
982impl Default for ANSEncodingHistogram {
983 fn default() -> Self {
984 Self::new()
985 }
986}
987
988pub fn encode_tokens_ans(
990 tokens: &[(u32, u32)], distributions: &[AnsDistribution],
992 context_map: &[usize],
993 writer: &mut BitWriter,
994) -> Result<()> {
995 let mut encoder = AnsEncoder::new();
996
997 for &(context, value) in tokens.iter().rev() {
999 let dist_idx = context_map.get(context as usize).copied().unwrap_or(0);
1000 if let Some(dist) = distributions.get(dist_idx)
1001 && let Some(info) = dist.get(value as usize)
1002 {
1003 encoder.put_symbol(info);
1004 }
1005 }
1006
1007 encoder.finalize(writer)
1008}
1009
// Unit tests for the ANS encoder, distribution builders and histogram headers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entropy_coding::histogram::Histogram;

    // A single nonzero symbol takes the small-tree path and owns the whole table.
    #[test]
    fn test_ans_encoding_histogram_single_symbol() {
        let h = Histogram::from_counts(&[100, 0, 0, 0]);
        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();

        assert_eq!(encoded.num_symbols, 1);
        assert_eq!(encoded.method, 1);
        assert_eq!(encoded.counts[0], ANS_TAB_SIZE as i32);
        assert!(encoded.cost < 100.0);
    }

    // Two nonzero symbols split the table; both keep at least one slot.
    #[test]
    fn test_ans_encoding_histogram_two_symbols() {
        let h = Histogram::from_counts(&[100, 100, 0, 0]);
        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();

        assert_eq!(encoded.num_symbols, 2);
        assert_eq!(encoded.method, 1);
        let sum: i32 = encoded.counts.iter().sum();
        assert_eq!(sum, ANS_TAB_SIZE as i32);
        assert!(encoded.counts[0] > 0);
        assert!(encoded.counts[1] > 0);
    }

    // General path: counts sum to the table size and no used symbol is dropped.
    #[test]
    fn test_ans_encoding_histogram_general() {
        let h = Histogram::from_counts(&[100, 50, 25, 10, 5, 3, 2, 1]);
        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();

        // NOTE(review): a shift-0 winner would have method 1 and fail this check;
        // it currently relies on shift 0 never being the cheapest candidate here.
        assert!(encoded.method >= 2 || encoded.method == 0);

        let sum: i32 = encoded.counts.iter().sum();
        assert_eq!(sum, ANS_TAB_SIZE as i32);

        for (i, &orig) in h.counts.iter().enumerate() {
            if orig > 0 {
                assert!(
                    encoded.counts.get(i).copied().unwrap_or(0) > 0,
                    "Symbol {} had count {} but normalized to 0",
                    i,
                    orig
                );
            }
        }
    }

    // An empty histogram yields the zero-cost flat (method 0) encoding.
    #[test]
    fn test_ans_encoding_histogram_empty() {
        let h = Histogram::new();
        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();

        assert_eq!(encoded.cost, 0.0);
        assert_eq!(encoded.method, 0);
    }

    // Spot checks of min(logcount, shift - ((12 - logcount) >> 1)), clamped at 0.
    #[test]
    fn test_get_population_count_precision() {
        assert_eq!(get_population_count_precision(0, 12), 0);

        assert_eq!(get_population_count_precision(12, 12), 12);

        assert_eq!(get_population_count_precision(6, 6), 3);

        assert_eq!(get_population_count_precision(1, 0), 0);
    }

    // Writing a single-symbol histogram header produces some output.
    #[test]
    fn test_ans_encoding_histogram_write() {
        let h = Histogram::from_counts(&[100, 0, 0, 0]);
        let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();

        let mut writer = BitWriter::new();
        encoded.write(&mut writer).unwrap();

        let bytes = writer.finish_with_padding();
        assert!(!bytes.is_empty());
    }

    // 4096 / 16 = 256 slots per symbol.
    #[test]
    fn test_flat_distribution() {
        let dist = AnsDistribution::flat(16).unwrap();
        assert_eq!(dist.alphabet_size(), 16);

        for sym in &dist.symbols {
            assert_eq!(sym.freq, 256);
        }
    }

    // Raw counts are normalized to exactly ANS_TAB_SIZE.
    #[test]
    fn test_from_frequencies() {
        let freqs = vec![100, 200, 300, 400];
        let dist = AnsDistribution::from_frequencies(&freqs).unwrap();
        assert_eq!(dist.alphabet_size(), 4);

        let total: u32 = dist.symbols.iter().map(|s| s.freq as u32).sum();
        assert_eq!(total, ANS_TAB_SIZE);
    }

    // Encoding symbols moves the state away from its initial value.
    #[test]
    fn test_ans_encoder_basic() {
        let dist = AnsDistribution::flat(4).unwrap();
        let mut encoder = AnsEncoder::new();

        encoder.put_symbol(&dist.symbols[0]);
        encoder.put_symbol(&dist.symbols[1]);
        encoder.put_symbol(&dist.symbols[2]);

        assert_ne!(encoder.state(), ANS_SIGNATURE << 16);
    }

    // Reverse maps partition the table: every slot belongs to exactly one symbol.
    #[test]
    fn test_reverse_map() {
        let dist = AnsDistribution::flat(4).unwrap();

        for sym in &dist.symbols {
            assert_eq!(sym.reverse_map.len(), sym.freq as usize);
        }

        let mut covered = vec![false; ANS_TAB_SIZE as usize];
        for sym in &dist.symbols {
            for &pos in &sym.reverse_map {
                assert!(!covered[pos as usize], "position {} covered twice", pos);
                covered[pos as usize] = true;
            }
        }
        assert!(covered.iter().all(|&c| c), "not all positions covered");
    }

    // A flat distribution's header serializes to a non-empty byte string.
    #[test]
    fn test_write_distribution() {
        let dist = AnsDistribution::flat(16).unwrap();
        let mut writer = BitWriter::new();
        dist.write(&mut writer).unwrap();

        let bytes = writer.finish_with_padding();
        assert!(!bytes.is_empty());
    }

    // Encodes one symbol, then decodes it by hand and checks the state returns to 0x130000.
    #[test]
    fn test_ans_roundtrip_manual() {
        let dist = AnsDistribution::flat(2).unwrap();

        println!("Distribution: {} symbols", dist.alphabet_size());
        for (i, sym) in dist.symbols.iter().enumerate() {
            println!(" Symbol {}: freq={}", i, sym.freq);
        }

        let mut encoder = AnsEncoder::new();
        let initial_state = encoder.state();
        println!("\nInitial state: 0x{:08x}", initial_state);
        assert_eq!(initial_state, 0x130000, "Initial state should be 0x130000");

        let info = &dist.symbols[0];
        encoder.put_symbol(info);
        let encoded_state = encoder.state();
        println!("After encoding symbol 0: state=0x{:08x}", encoded_state);

        // Slot index = low ANS_LOG_TAB_SIZE bits of the state.
        let idx = encoded_state & 0xFFF;
        println!("Decode: idx = {}", idx);

        // NOTE(review): this treats the table as two contiguous halves.
        // build_reverse_maps uses an alias table, so that only holds for the
        // slot reached here; verify before extending this test to other symbols.
        let decoded_symbol = if idx < 2048 { 0 } else { 1 };
        let offset_in_symbol = if idx < 2048 { idx } else { idx - 2048 };
        let freq = 2048u32;

        println!("Decoded symbol: {}", decoded_symbol);
        println!("Offset in symbol: {}", offset_in_symbol);

        // Inverse of the encoder step: state = (state >> 12) * freq + offset.
        let quotient = encoded_state >> 12;
        let next_state = quotient * freq + offset_in_symbol;
        println!(
            "next_state = {} * {} + {} = 0x{:08x}",
            quotient, freq, offset_in_symbol, next_state
        );

        assert_eq!(next_state, 0x130000, "Decoded state should be 0x130000");
        assert_eq!(decoded_symbol, 0, "Decoded symbol should be 0");
    }

    // Full round trip: encode with AnsEncoder, write the header with
    // ANSEncodingHistogram, and decode both with the ans_decode module.
    #[test]
    fn test_ans_roundtrip_multiple_symbols() {
        use crate::bit_writer::BitWriter;
        use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};

        let counts = [1024i32, 1024, 1024, 1024];
        let dist = AnsDistribution::from_normalized_counts(&counts).unwrap();

        let symbols_to_encode: Vec<usize> = vec![0, 1, 2, 3, 0, 1];
        println!(
            "Encoding {} symbols: {:?}",
            symbols_to_encode.len(),
            symbols_to_encode
        );

        // Symbols are pushed in reverse so the decoder reads them forward.
        let mut encoder = AnsEncoder::new();
        for &sym in symbols_to_encode.iter().rev() {
            encoder.put_symbol(&dist.symbols[sym]);
        }

        println!("Final state after encoding: 0x{:08x}", encoder.state());

        let mut writer = BitWriter::new();
        encoder.finalize(&mut writer).unwrap();
        let encoded_bytes = writer.finish_with_padding();
        println!("Encoded bytes: {:02x?}", encoded_bytes);

        let ans_histo = ANSEncodingHistogram::from_histogram(
            &Histogram::from_counts(&counts),
            ANSHistogramStrategy::Precise,
        )
        .unwrap();
        let mut hist_writer = BitWriter::new();
        ans_histo.write(&mut hist_writer).unwrap();
        let hist_bytes = hist_writer.finish_with_padding();

        // 6 must match the log_alpha_size build_reverse_maps uses for alphabets <= 64.
        let mut hist_br = BitReader::new(&hist_bytes);
        let decoded_hist = AnsHistogram::decode(&mut hist_br, 6).unwrap();

        println!(
            "Decoded histogram frequencies: {:?}",
            &decoded_hist.frequencies[..4]
        );

        let mut br = BitReader::new(&encoded_bytes);
        let mut ans_reader = AnsReader::init(&mut br).unwrap();

        println!("Decoding:");
        let mut decoded = Vec::new();
        for i in 0..symbols_to_encode.len() {
            let symbol = decoded_hist.read(&mut br, &mut ans_reader.0) as usize;
            println!(
                " step {}: symbol={}, state=0x{:08x}",
                i, symbol, ans_reader.0
            );
            decoded.push(symbol);
        }

        println!("Original: {:?}", symbols_to_encode);
        println!("Decoded: {:?}", decoded);
        println!("Final state: 0x{:08x}", ans_reader.0);

        assert_eq!(
            decoded, symbols_to_encode,
            "Decoded symbols should match original"
        );
        assert!(
            ans_reader.check_final_state().is_ok(),
            "Final state should be 0x130000, got 0x{:08x}",
            ans_reader.0
        );
    }

    // Checks normalization and that the header writes without error.
    // NOTE(review): despite its name, this test never decodes the written header.
    #[test]
    fn test_ans_histogram_write_decode_roundtrip() {
        use crate::bit_writer::BitWriter;
        use crate::entropy_coding::histogram::Histogram;

        let histo = Histogram::from_counts(&[100, 50, 25, 10]);

        let encoded =
            ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

        println!("Histogram: {:?}", histo.counts);
        println!("Encoded counts: {:?}", encoded.counts);
        println!(
            "Method: {}, alphabet_size: {}, omit_pos: {}",
            encoded.method, encoded.alphabet_size, encoded.omit_pos
        );

        let sum: i32 = encoded.counts.iter().sum();
        assert_eq!(sum, ANS_TAB_SIZE as i32, "Sum should be 4096");

        let mut writer = BitWriter::new();
        encoded.write(&mut writer).unwrap();
        let bytes = writer.finish_with_padding();

        println!("Encoded histogram: {} bytes", bytes.len());
        println!("Bytes: {:02x?}", bytes);
    }
}