1use crate::bit_writer::BitWriter;
11use crate::error::{Error, Result};
12
/// Log2 of the ANS table size; every normalized distribution sums to
/// `1 << ANS_LOG_TAB_SIZE`.
pub const ANS_LOG_TAB_SIZE: u32 = 12;
/// Number of slots in the ANS table (4096).
pub const ANS_TAB_SIZE: u32 = 2u32.pow(ANS_LOG_TAB_SIZE);
/// Mask keeping the low `ANS_LOG_TAB_SIZE` bits (0xFFF).
pub const ANS_TAB_MASK: u32 = !(u32::MAX << ANS_LOG_TAB_SIZE);

/// Maximum supported ANS alphabet size.
pub const ANS_MAX_ALPHABET_SIZE: usize = 256;

/// Value placed in the upper half of a fresh encoder state
/// (`ANS_SIGNATURE << 16`).
pub const ANS_SIGNATURE: u32 = 0x13;

/// Index of the run-length marker in [`LOGCOUNT_PREFIX_CODE`]; it follows the
/// largest real logcount (`ANS_LOG_TAB_SIZE`).
const RLE_MARKER_SYM: u8 = ANS_LOG_TAB_SIZE as u8 + 1;

/// Prefix code for logcounts in general histograms, as
/// `(code length in bits, code value)`. Entries 0..=12 are logcounts; entry 13
/// ([`RLE_MARKER_SYM`]) is the run-length marker.
const LOGCOUNT_PREFIX_CODE: [(u8, u8); 14] = [
    (5, 0b10001),   // logcount 0 (symbol absent)
    (4, 0b1011),    // logcount 1
    (4, 0b1111),    // logcount 2
    (4, 0b0011),    // logcount 3
    (4, 0b1001),    // logcount 4
    (4, 0b0111),    // logcount 5
    (3, 0b100),     // logcount 6
    (3, 0b010),     // logcount 7
    (3, 0b101),     // logcount 8
    (3, 0b110),     // logcount 9
    (3, 0b000),     // logcount 10
    (6, 0b100001),  // logcount 11
    (7, 0b0000001), // logcount 12
    (7, 0b1000001), // RLE marker
];
46
47fn build_allowed_counts(shift: u32) -> Vec<i32> {
51 let mut counts = Vec::with_capacity(256);
52 counts.push(1i32);
54 for bits in 1..ANS_LOG_TAB_SIZE {
55 let precision = get_population_count_precision(bits, shift);
56 let drop_bits = bits.saturating_sub(precision);
57 let num_mantissa = 1u32 << precision;
58 for mantissa in 0..num_mantissa {
59 let count = (1i32 << bits) | ((mantissa as i32) << drop_bits);
60 if count > 0 && count < ANS_TAB_SIZE as i32 {
61 counts.push(count);
62 }
63 }
64 }
65 counts.sort_unstable();
66 counts.dedup();
67 counts.reverse(); counts
69}
70
71pub struct AllowedCountsCache {
75 tables: [Vec<i32>; ANS_LOG_TAB_SIZE as usize + 1],
77}
78
79impl Default for AllowedCountsCache {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl AllowedCountsCache {
86 pub fn new() -> Self {
88 Self {
89 tables: core::array::from_fn(|shift| build_allowed_counts(shift as u32)),
90 }
91 }
92
93 #[inline]
95 pub fn get(&self, shift: u32) -> &[i32] {
96 &self.tables[shift as usize]
97 }
98}
99
/// Returns the index of the largest entry of `allowed` that is `<= target`.
///
/// `allowed` must be non-empty and sorted in descending order. If every entry
/// exceeds `target`, the index of the last (smallest) entry is returned.
fn find_allowed_leq(allowed: &[i32], target: i32) -> usize {
    // Descending order makes the entries greater than `target` a prefix.
    let first_not_greater = allowed.partition_point(|&count| count > target);
    if first_not_greater >= allowed.len() {
        allowed.len() - 1
    } else {
        first_not_greater
    }
}
122
123fn estimate_data_bits_normalized(
127 histo_counts: &[i32],
128 norm_counts: &[i32],
129 total_count: usize,
130 alphabet_size: usize,
131) -> f64 {
132 let mut sum = 0.0f64;
133 for (actual, norm) in histo_counts
134 .iter()
135 .zip(norm_counts.iter())
136 .take(alphabet_size)
137 {
138 if *actual > 0 && *norm > 0 {
139 sum += *actual as f64 * jxl_simd::fast_log2f(*norm as f32) as f64;
140 }
141 }
142 total_count as f64 * ANS_LOG_TAB_SIZE as f64 - sum
143}
144
/// Fixed-point precision, in bits, of [`AnsEncSymbolInfo::ifreq`].
const RECIPROCAL_PRECISION: u32 = 44;

/// Per-symbol data needed by the ANS encoder.
#[derive(Debug, Clone)]
pub struct AnsEncSymbolInfo {
    /// Normalized frequency of the symbol.
    pub freq: u16,
    /// `ceil(2^RECIPROCAL_PRECISION / freq)` (0 when `freq == 0`), letting the
    /// encoder replace division by `freq` with a multiply and shift.
    pub ifreq: u64,
    /// Maps the symbol's slot (`state % freq`) to a table position; empty
    /// until filled in after construction.
    pub reverse_map: Vec<u16>,
}

impl AnsEncSymbolInfo {
    /// Creates symbol info for `freq` with an empty reverse map.
    pub fn new(freq: u16) -> Self {
        let ifreq = match freq {
            0 => 0,
            nonzero => (1u64 << RECIPROCAL_PRECISION).div_ceil(u64::from(nonzero)),
        };
        AnsEncSymbolInfo {
            freq,
            ifreq,
            reverse_map: Vec::new(),
        }
    }
}
175
176pub struct AnsEncoder {
178 state: u32,
180 bits: Vec<(u32, u8)>, }
183
184impl AnsEncoder {
185 pub fn new() -> Self {
187 Self {
188 state: ANS_SIGNATURE << 16,
189 bits: Vec::new(),
190 }
191 }
192
193 pub fn with_capacity(num_tokens: usize) -> Self {
195 Self {
196 state: ANS_SIGNATURE << 16,
197 bits: Vec::with_capacity(num_tokens * 2), }
199 }
200
201 #[inline]
205 pub fn put_symbol(&mut self, info: &AnsEncSymbolInfo) {
206 let freq = info.freq as u32;
207
208 if (self.state >> (32 - ANS_LOG_TAB_SIZE)) >= freq {
210 self.bits.push((self.state & 0xFFFF, 16));
211 self.state >>= 16;
212 }
213
214 let v = ((self.state as u64 * info.ifreq) >> RECIPROCAL_PRECISION) as u32;
217 let remainder = self.state - v * freq;
218
219 let offset = info.reverse_map[remainder as usize] as u32;
221
222 self.state = (v << ANS_LOG_TAB_SIZE) + offset;
224 }
225
226 #[inline]
232 pub fn push_bits(&mut self, bits: u32, nbits: u8) {
233 if nbits > 0 {
234 self.bits.push((bits, nbits));
235 }
236 }
237
238 pub fn finalize(self, writer: &mut BitWriter) -> Result<()> {
242 #[cfg(feature = "debug-tokens")]
244 eprintln!(
245 "ANS finalize: state=0x{:08x}, {} bit chunks",
246 self.state,
247 self.bits.len()
248 );
249
250 writer.write(32, self.state as u64)?;
252
253 for &(bits, nbits) in self.bits.iter().rev() {
255 writer.write(nbits as usize, bits as u64)?;
256 }
257
258 Ok(())
259 }
260
261 pub fn state(&self) -> u32 {
263 self.state
264 }
265}
266
267impl Default for AnsEncoder {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
273#[derive(Debug, Clone)]
275pub struct AnsDistribution {
276 pub symbols: Vec<AnsEncSymbolInfo>,
278 pub log_alpha_size: u32,
280 pub total: u32,
282}
283
284impl AnsDistribution {
285 pub fn from_frequencies(freqs: &[u32]) -> Result<Self> {
289 if freqs.is_empty() {
290 return Err(Error::InvalidHistogram("empty distribution".to_string()));
291 }
292
293 let total_count: u64 = freqs.iter().map(|&f| f as u64).sum();
294 if total_count == 0 {
295 return Err(Error::InvalidHistogram("all zero frequencies".to_string()));
296 }
297
298 let mut normalized: Vec<u16> = Vec::with_capacity(freqs.len());
300 let mut running_total: u32 = 0;
301
302 for &freq in freqs.iter() {
303 let normalized_freq = if freq == 0 {
304 0
305 } else {
306 ((freq as u64 * ANS_TAB_SIZE as u64) / total_count).max(1) as u16
308 };
309 normalized.push(normalized_freq);
310 running_total += normalized_freq as u32;
311 }
312
313 let diff = running_total as i32 - ANS_TAB_SIZE as i32;
315 if diff != 0 {
316 if let Some((max_idx, _)) = normalized
318 .iter()
319 .enumerate()
320 .filter(|&(_, &f)| f > 0)
321 .max_by_key(|&(_, &f)| f)
322 {
323 let new_val = (normalized[max_idx] as i32 - diff).max(1) as u16;
324 normalized[max_idx] = new_val;
325 }
326 }
327
328 let mut symbols: Vec<AnsEncSymbolInfo> = normalized
330 .iter()
331 .map(|&f| AnsEncSymbolInfo::new(f))
332 .collect();
333
334 let log_alpha_size = Self::default_log_alpha_size(symbols.len());
336 Self::build_reverse_maps(&mut symbols, log_alpha_size)?;
337
338 Ok(Self {
339 symbols,
340 log_alpha_size: ANS_LOG_TAB_SIZE,
341 total: ANS_TAB_SIZE,
342 })
343 }
344
345 pub fn flat(alphabet_size: usize) -> Result<Self> {
347 if alphabet_size == 0 || alphabet_size > ANS_TAB_SIZE as usize {
348 return Err(Error::InvalidHistogram(format!(
349 "invalid alphabet size: {}",
350 alphabet_size
351 )));
352 }
353
354 let base_freq = ANS_TAB_SIZE as usize / alphabet_size;
355 let remainder = ANS_TAB_SIZE as usize % alphabet_size;
356
357 let mut freqs = vec![base_freq as u32; alphabet_size];
358 for freq in freqs.iter_mut().take(remainder) {
359 *freq += 1;
360 }
361
362 Self::from_frequencies(&freqs)
363 }
364
365 pub fn from_normalized_counts(counts: &[i32]) -> Result<Self> {
371 let log_alpha_size = Self::default_log_alpha_size(counts.len());
372 Self::from_normalized_counts_with_log_alpha(counts, log_alpha_size)
373 }
374
375 pub fn from_normalized_counts_with_log_alpha(
382 counts: &[i32],
383 log_alpha_size: usize,
384 ) -> Result<Self> {
385 if counts.is_empty() {
386 return Err(Error::InvalidHistogram("empty distribution".to_string()));
387 }
388
389 let total: i32 = counts.iter().sum();
391 if total != ANS_TAB_SIZE as i32 {
392 return Err(Error::InvalidHistogram(format!(
393 "normalized counts sum to {} instead of {}",
394 total, ANS_TAB_SIZE
395 )));
396 }
397
398 let mut symbols: Vec<AnsEncSymbolInfo> = counts
400 .iter()
401 .map(|&c| AnsEncSymbolInfo::new(c.max(0) as u16))
402 .collect();
403
404 Self::build_reverse_maps(&mut symbols, log_alpha_size)?;
406
407 Ok(Self {
408 symbols,
409 log_alpha_size: ANS_LOG_TAB_SIZE,
410 total: ANS_TAB_SIZE,
411 })
412 }
413
414 fn default_log_alpha_size(alphabet_size: usize) -> usize {
420 use super::encode_ans::ANS_LOG_ALPHA_SIZE;
421 if alphabet_size <= (1 << ANS_LOG_ALPHA_SIZE) {
422 ANS_LOG_ALPHA_SIZE
423 } else {
424 let min_bits = if alphabet_size <= 1 {
425 5
426 } else {
427 (alphabet_size - 1).ilog2() as usize + 1
428 };
429 min_bits.clamp(5, 8)
430 }
431 }
432
433 fn build_reverse_maps(symbols: &mut [AnsEncSymbolInfo], log_alpha_size: usize) -> Result<()> {
446 let alphabet_size = symbols.len();
447 if alphabet_size == 0 {
448 return Ok(());
449 }
450
451 let total: u32 = symbols.iter().map(|s| s.freq as u32).sum();
453 if total != ANS_TAB_SIZE {
454 return Err(Error::InvalidHistogram(format!(
455 "frequencies sum to {} instead of {}",
456 total, ANS_TAB_SIZE
457 )));
458 }
459
460 if let Some(single_sym_idx) = symbols.iter().position(|s| s.freq == ANS_TAB_SIZE as u16) {
464 for sym in symbols.iter_mut() {
466 sym.reverse_map.clear();
467 }
468 let map = &mut symbols[single_sym_idx].reverse_map;
470 map.resize(ANS_TAB_SIZE as usize, 0);
471 for (i, v) in map.iter_mut().enumerate() {
472 *v = i as u16;
473 }
474 return Ok(());
475 }
476
477 let table_size = 1usize << log_alpha_size;
478 let log_bucket_size = ANS_LOG_TAB_SIZE as usize - log_alpha_size;
479 let bucket_size = 1u16 << log_bucket_size;
480
481 #[derive(Clone)]
483 #[allow(dead_code)]
484 struct WorkingBucket {
485 dist: u16, alias_symbol: u16, alias_offset: u16, alias_cutoff: u16, }
490
491 let mut buckets: Vec<WorkingBucket> = (0..table_size)
492 .map(|i| {
493 let dist = if i < alphabet_size {
494 symbols[i].freq
495 } else {
496 0
497 };
498 WorkingBucket {
499 dist,
500 alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
501 alias_offset: 0,
502 alias_cutoff: dist,
503 }
504 })
505 .collect();
506
507 let mut underfull: Vec<usize> = Vec::with_capacity(table_size);
509 let mut overfull: Vec<usize> = Vec::with_capacity(table_size);
510 for (i, bucket) in buckets.iter().enumerate() {
511 if bucket.alias_cutoff < bucket_size {
512 underfull.push(i);
513 } else if bucket.alias_cutoff > bucket_size {
514 overfull.push(i);
515 }
516 }
517
518 while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
520 let by = bucket_size - buckets[u].alias_cutoff;
521 buckets[o].alias_cutoff -= by;
522 buckets[u].alias_symbol = o as u16;
523 buckets[u].alias_offset = buckets[o].alias_cutoff;
524
525 match buckets[o].alias_cutoff.cmp(&bucket_size) {
526 std::cmp::Ordering::Less => underfull.push(o),
527 std::cmp::Ordering::Greater => overfull.push(o),
528 std::cmp::Ordering::Equal => {}
529 }
530 }
531
532 for sym in symbols.iter_mut() {
534 sym.reverse_map.clear();
535 sym.reverse_map.resize(sym.freq as usize, 0);
536 }
537
538 for idx in 0..ANS_TAB_SIZE {
541 let bucket_idx = (idx >> log_bucket_size) as usize;
542 let pos = (idx as u16) & (bucket_size - 1);
543
544 let bucket = &buckets[bucket_idx.min(table_size - 1)];
545 let alias_cutoff = bucket.alias_cutoff;
546
547 let (symbol, offset) = if pos < alias_cutoff {
548 (bucket_idx, pos)
549 } else {
550 let alias_sym = bucket.alias_symbol as usize;
551 let offset = bucket.alias_offset - alias_cutoff + pos;
552 (alias_sym, offset)
553 };
554
555 if symbol < alphabet_size {
556 symbols[symbol].reverse_map[offset as usize] = idx as u16;
557 }
558 }
559
560 Ok(())
561 }
562
563 pub fn alphabet_size(&self) -> usize {
565 self.symbols.len()
566 }
567
568 pub fn get(&self, symbol: usize) -> Option<&AnsEncSymbolInfo> {
570 self.symbols.get(symbol)
571 }
572
573 pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
575 let is_flat = self.is_flat();
577
578 writer.write(1, 0)?; writer.write(1, u64::from(is_flat))?;
580
581 if is_flat {
582 write_var_len_uint8(writer, (self.alphabet_size() - 1) as u8)?;
584 } else {
585 self.write_general(writer)?;
588 }
589
590 Ok(())
591 }
592
593 fn is_flat(&self) -> bool {
595 let first_freq = self.symbols.first().map(|s| s.freq).unwrap_or(0);
596 if first_freq == 0 {
597 return false;
598 }
599 self.symbols
600 .iter()
601 .all(|s| s.freq == first_freq || s.freq == first_freq - 1)
602 }
603
604 fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
606 let method: u64 = 13; let upper_bound_log = 4; let log = floor_log2(method as u32);
610
611 writer.write(log as usize, (1u64 << log) - 1)?;
613 if log != upper_bound_log {
614 writer.write(1, 0)?;
615 }
616 writer.write(log as usize, ((1u64 << log) - 1) & method)?;
618
619 write_var_len_uint8(writer, (self.alphabet_size() - 3) as u8)?;
621
622 for sym in &self.symbols {
625 let freq = sym.freq;
627 if freq == 0 {
628 writer.write(1, 0)?;
629 } else {
630 writer.write(1, 1)?;
631 let bits = 16 - freq.leading_zeros();
632 writer.write(4, bits as u64)?;
633 if bits > 0 {
634 writer.write(bits as usize, freq as u64)?;
635 }
636 }
637 }
638
639 Ok(())
640 }
641}
642
643fn write_var_len_uint8(writer: &mut BitWriter, n: u8) -> Result<()> {
645 if n == 0 {
646 writer.write(1, 0)?;
647 } else {
648 writer.write(1, 1)?;
649 let nbits = 8 - n.leading_zeros();
650 writer.write(3, (nbits - 1) as u64)?;
651 writer.write((nbits - 1) as usize, (n as u64) - (1u64 << (nbits - 1)))?;
652 }
653 Ok(())
654}
655
/// Floor of log2 of `n`, with `floor_log2_ans(0) == 0`.
#[inline]
pub fn floor_log2_ans(n: u32) -> u32 {
    n.checked_ilog2().unwrap_or(0)
}
661
/// Floor of log2 of `n`, treating 0 as 0 (same result as `floor_log2_ans`).
#[inline]
fn floor_log2(n: u32) -> u32 {
    match n.checked_ilog2() {
        Some(log) => log,
        None => 0,
    }
}
666
667pub fn get_population_count_precision(logcount: u32, shift: u32) -> u32 {
674 let logcount_i = logcount as i32;
675 let shift_i = shift as i32;
676 let r = logcount_i.min(shift_i - ((ANS_LOG_TAB_SIZE as i32 - logcount_i) >> 1));
677 r.max(0) as u32
678}
679
/// How many shift values the histogram search tries.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ANSHistogramStrategy {
    /// Fewest candidate shifts; fastest.
    Fast,
    /// Every other candidate shift.
    Approximate,
    /// Every candidate shift; the default.
    #[default]
    Precise,
}
691
692#[derive(Clone, Debug)]
696pub struct ANSEncodingHistogram {
697 pub counts: Vec<i32>,
699 pub alphabet_size: usize,
701 pub cost: f32,
703 pub method: u32,
708 pub omit_pos: usize,
710 num_symbols: usize,
712 symbols: [usize; 2],
714}
715
716impl ANSEncodingHistogram {
717 pub fn new() -> Self {
719 Self {
720 counts: Vec::new(),
721 alphabet_size: 0,
722 cost: f32::MAX,
723 method: 0,
724 omit_pos: 0,
725 num_symbols: 0,
726 symbols: [0, 0],
727 }
728 }
729
730 pub fn from_histogram(
736 histo: &super::histogram::Histogram,
737 strategy: ANSHistogramStrategy,
738 ) -> Result<Self> {
739 let cache = AllowedCountsCache::new();
740 Self::from_histogram_cached(histo, strategy, &cache)
741 }
742
743 pub fn from_histogram_cached(
749 histo: &super::histogram::Histogram,
750 strategy: ANSHistogramStrategy,
751 cache: &AllowedCountsCache,
752 ) -> Result<Self> {
753 if histo.total_count == 0 {
754 return Ok(Self {
756 counts: vec![0i32; histo.counts.len().max(1)],
757 alphabet_size: 1,
758 cost: 0.0,
759 method: 0, omit_pos: 0,
761 num_symbols: 0,
762 symbols: [0, 0],
763 });
764 }
765
766 let alphabet_size = histo.alphabet_size();
767
768 let mut num_symbols = 0;
770 let mut symbols = [0usize; 2];
771 for (i, &count) in histo.counts.iter().enumerate() {
772 if count > 0 {
773 if num_symbols < 2 {
774 symbols[num_symbols] = i;
775 }
776 num_symbols += 1;
777 }
778 }
779
780 if num_symbols <= 2 {
782 let mut counts = vec![0i32; alphabet_size];
783 if num_symbols == 1 {
784 counts[symbols[0]] = ANS_TAB_SIZE as i32;
785 } else {
786 let total = histo.total_count as f64;
788 let count0 = histo.counts[symbols[0]] as f64;
789 let norm0 = ((count0 / total) * ANS_TAB_SIZE as f64).round() as i32;
790 let norm0 = norm0.clamp(1, (ANS_TAB_SIZE - 1) as i32);
791 counts[symbols[0]] = norm0;
792 counts[symbols[1]] = ANS_TAB_SIZE as i32 - norm0;
793 }
794
795 let cost = if num_symbols <= 1 { 4.0 } else { 4.0 + 12.0 }; return Ok(Self {
799 counts,
800 alphabet_size,
801 cost,
802 method: 1, omit_pos: symbols[0],
804 num_symbols,
805 symbols,
806 });
807 }
808
809 let flat_data_cost = {
813 let log2_alpha = jxl_simd::fast_log2f(alphabet_size as f32);
814 histo.total_count as f32 * log2_alpha
815 };
816 let flat_header_cost = 2.0 + 8.0; let mut best = Self {
818 counts: {
819 let alpha = alphabet_size as u32;
820 let per = ANS_TAB_SIZE / alpha;
821 let remainder = (ANS_TAB_SIZE % alpha) as usize;
822 let mut c = vec![per as i32; alphabet_size];
823 for c in c.iter_mut().take(remainder) {
825 *c += 1;
826 }
827 c
828 },
829 alphabet_size,
830 cost: flat_header_cost + flat_data_cost,
831 method: 0, omit_pos: 0,
833 num_symbols,
834 symbols,
835 };
836
837 let mut candidate_counts = vec![0i32; alphabet_size];
840
841 let shift_iter: &[u32] = match strategy {
843 ANSHistogramStrategy::Fast => &[0, 6, 12],
844 ANSHistogramStrategy::Approximate => &[0, 2, 4, 6, 8, 10, 12],
845 ANSHistogramStrategy::Precise => &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
846 };
847
848 for &shift in shift_iter {
849 candidate_counts.fill(0);
851
852 let mut candidate = Self {
853 counts: Vec::new(), alphabet_size,
855 cost: f32::MAX,
856 method: shift.min(ANS_LOG_TAB_SIZE - 1) + 1,
857 omit_pos: 0,
858 num_symbols,
859 symbols,
860 };
861
862 core::mem::swap(&mut candidate.counts, &mut candidate_counts);
864
865 if candidate.rebalance_histogram_cached(histo, shift, cache.get(shift)) {
866 candidate.cost = candidate.estimate_cost(histo);
867 if candidate.cost < best.cost {
868 core::mem::swap(&mut candidate_counts, &mut best.counts);
871 best = candidate;
872 candidate_counts.resize(alphabet_size, 0);
875 } else {
876 core::mem::swap(&mut candidate.counts, &mut candidate_counts);
878 }
879 } else {
880 core::mem::swap(&mut candidate.counts, &mut candidate_counts);
882 }
883 }
884
885 if best.cost == f32::MAX {
886 eprintln!(
888 "ANS rebalance FAILED: alphabet_size={}, num_symbols={}, total_count={}",
889 alphabet_size, num_symbols, histo.total_count
890 );
891 for (i, &c) in histo.counts.iter().enumerate() {
892 if c > 0 {
893 eprintln!(" symbol {}: count={}", i, c);
894 }
895 }
896 return Err(Error::InvalidHistogram(
897 "Failed to rebalance histogram".to_string(),
898 ));
899 }
900
901 Ok(best)
902 }
903
904 fn rebalance_histogram_cached(
907 &mut self,
908 histo: &super::histogram::Histogram,
909 _shift: u32,
910 allowed: &[i32],
911 ) -> bool {
912 let total_count = histo.total_count;
913 if total_count == 0 {
914 return false;
915 }
916
917 let norm = ANS_TAB_SIZE as f64 / total_count as f64;
918
919 let mut remainder_pos = 0;
922 let mut max_freq = 0i32;
923
924 let mut bins: Vec<(i32, usize, usize)> = Vec::with_capacity(self.alphabet_size);
926 let mut rest = ANS_TAB_SIZE as i32;
927
928 for (n, &freq) in histo.counts.iter().enumerate().take(self.alphabet_size) {
929 if freq > max_freq {
930 remainder_pos = n;
931 max_freq = freq;
932 }
933
934 if freq == 0 {
935 self.counts[n] = 0;
936 continue;
937 }
938
939 let target = freq as f64 * norm;
940 let rounded = target.round().max(1.0).min((ANS_TAB_SIZE - 1) as f64) as i32;
942 let ai = find_allowed_leq(allowed, rounded);
943 let count = allowed[ai];
944
945 self.counts[n] = count;
946 rest -= count;
947
948 if target > 1.0 {
950 bins.push((freq, ai, n));
951 }
952 }
953
954 if let Some(pos) = bins.iter().position(|b| b.2 == remainder_pos) {
956 bins.remove(pos);
957 }
958
959 rest += self.counts[remainder_pos];
962
963 if !bins.is_empty() {
967 let max_freq_f = max_freq as f64;
968 let lg2 = |v: i32| -> f64 {
971 if v <= 0 {
972 0.0
973 } else {
974 jxl_simd::fast_log2f(v as f32) as f64
975 }
976 };
977
978 loop {
979 let mut best_inc_net = 0.0f64; let mut best_inc_bi = None;
982
983 let mut best_dec_net = 0.0f64; let mut best_dec_bi = None;
986
987 for (bi, &(freq, ai, _bin)) in bins.iter().enumerate() {
988 let count = allowed[ai];
989 let freq_f = freq as f64;
990 let lg2_count = lg2(count);
991
992 if ai > 0 {
994 let new_count = allowed[ai - 1];
995 let step = new_count - count;
996 let new_rest = rest - step;
997 if new_rest > 0 || rest >= ANS_TAB_SIZE as i32 {
998 let gain = freq_f * (lg2(new_count) - lg2_count);
999 let cost = if rest >= ANS_TAB_SIZE as i32 {
1000 0.0 } else if rest > 0 && new_rest > 0 {
1002 max_freq_f * (lg2(rest) - lg2(new_rest))
1003 } else {
1004 f64::MAX
1005 };
1006 let net = gain - cost;
1007 let step_log = floor_log2(step as u32);
1009 let norm_net = if step_log > 0 {
1010 net / (1u32 << step_log) as f64
1011 } else {
1012 net
1013 };
1014 if norm_net > best_inc_net {
1015 best_inc_net = norm_net;
1016 best_inc_bi = Some(bi);
1017 }
1018 }
1019 }
1020
1021 if ai + 1 < allowed.len() && allowed[ai + 1] > 0 {
1023 let new_count = allowed[ai + 1];
1024 let step = count - new_count;
1025 let new_rest = rest + step;
1026 if new_rest < ANS_TAB_SIZE as i32 || rest <= 1 {
1027 let loss = freq_f * (lg2_count - lg2(new_count));
1028 let gain = if rest <= 1 {
1029 f64::MAX } else if rest > 0 && new_rest < ANS_TAB_SIZE as i32 {
1031 max_freq_f * (lg2(new_rest) - lg2(rest))
1032 } else {
1033 0.0
1034 };
1035 let net = gain - loss;
1036 let step_log = floor_log2(step as u32);
1037 let norm_net = if step_log > 0 {
1038 net / (1u32 << step_log) as f64
1039 } else {
1040 net
1041 };
1042 if norm_net > best_dec_net {
1043 best_dec_net = norm_net;
1044 best_dec_bi = Some(bi);
1045 }
1046 }
1047 }
1048 }
1049
1050 if best_inc_net > 0.0 {
1052 if let Some(bi) = best_inc_bi {
1053 let step = allowed[bins[bi].1 - 1] - allowed[bins[bi].1];
1054 bins[bi].1 -= 1; rest -= step;
1056 }
1057 } else if best_dec_net > 0.0 {
1058 if let Some(bi) = best_dec_bi {
1059 let step = allowed[bins[bi].1] - allowed[bins[bi].1 + 1];
1060 bins[bi].1 += 1; rest += step;
1062 }
1063 } else {
1064 break; }
1066 }
1067
1068 for &(_freq, ai, bin) in &bins {
1070 self.counts[bin] = allowed[ai];
1071 }
1072
1073 for n in 0..remainder_pos {
1077 if self.counts[n] >= 2048 {
1078 self.counts[remainder_pos] = self.counts[n];
1079 remainder_pos = n;
1080 break;
1081 }
1082 }
1083 }
1084
1085 self.counts[remainder_pos] = rest;
1087 self.omit_pos = remainder_pos;
1088
1089 if rest <= 0 {
1090 return false;
1091 }
1092
1093 for _ in 0..10 {
1098 let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
1099 let mut adjusted = false;
1100 for i in 0..self.alphabet_size {
1101 if i == remainder_pos || self.counts[i] <= 0 {
1102 continue;
1103 }
1104 let logcount = floor_log2(self.counts[i] as u32) + 1;
1105 let needs_fix =
1106 logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos);
1107 if needs_fix {
1108 let target_logcount = if i < remainder_pos {
1112 omit_logcount.saturating_sub(1)
1113 } else {
1114 omit_logcount
1115 };
1116 let max_value = (1i32 << target_logcount) - 1;
1117 let new_ai = find_allowed_leq(allowed, max_value);
1118 let new_count = allowed[new_ai].max(1);
1119 let reduction = self.counts[i] - new_count;
1120 if reduction > 0 {
1121 self.counts[i] = new_count;
1122 self.counts[remainder_pos] += reduction;
1123 adjusted = true;
1124 }
1125 }
1126 }
1127 if !adjusted {
1128 break;
1129 }
1130 }
1131
1132 let omit_logcount = floor_log2(self.counts[remainder_pos] as u32) + 1;
1134 for (i, &count) in self.counts.iter().enumerate().take(self.alphabet_size) {
1135 if i == remainder_pos || count <= 0 {
1136 continue;
1137 }
1138 let logcount = floor_log2(count as u32) + 1;
1139 if logcount > omit_logcount || (logcount == omit_logcount && i < remainder_pos) {
1140 return false;
1141 }
1142 }
1143
1144 let sum: i32 = self.counts.iter().sum();
1146 sum == ANS_TAB_SIZE as i32
1147 }
1148
1149 fn estimate_cost(&self, histo: &super::histogram::Histogram) -> f32 {
1152 let header_cost = self.estimate_header_cost();
1153 let data_cost = estimate_data_bits_normalized(
1154 &histo.counts,
1155 &self.counts,
1156 histo.total_count,
1157 self.alphabet_size,
1158 ) as f32;
1159 header_cost + data_cost
1160 }
1161
1162 fn estimate_header_cost(&self) -> f32 {
1164 if self.method == 0 {
1165 2.0 + 8.0
1167 } else if self.num_symbols <= 2 {
1168 if self.num_symbols <= 1 {
1170 3.0 + 8.0 } else {
1172 3.0 + 16.0 + 12.0 }
1174 } else {
1175 let method_bits = 4.0; let alphabet_bits = 8.0;
1178 let freq_bits = self.alphabet_size as f32 * 5.0; method_bits + alphabet_bits + freq_bits
1180 }
1181 }
1182
1183 pub fn write(&self, writer: &mut BitWriter) -> Result<()> {
1185 if self.method == 0 {
1186 writer.write(1, 0)?; writer.write(1, 1)?; write_var_len_uint8(writer, (self.alphabet_size - 1) as u8)?;
1190 return Ok(());
1191 }
1192
1193 if self.num_symbols <= 2 {
1194 writer.write(1, 1)?; if self.num_symbols == 0 {
1197 writer.write(1, 0)?;
1198 write_var_len_uint8(writer, 0)?;
1199 } else {
1200 writer.write(1, (self.num_symbols - 1) as u64)?;
1201 for i in 0..self.num_symbols {
1202 write_var_len_uint8(writer, self.symbols[i] as u8)?;
1203 }
1204 if self.num_symbols == 2 {
1205 writer.write(
1206 ANS_LOG_TAB_SIZE as usize,
1207 self.counts[self.symbols[0]] as u64,
1208 )?;
1209 }
1210 }
1211 return Ok(());
1212 }
1213
1214 self.write_general(writer)
1216 }
1217
1218 fn write_general(&self, writer: &mut BitWriter) -> Result<()> {
1222 writer.write(1, 0)?; writer.write(1, 0)?; let shift = (self.method - 1) as i32;
1228 let shift_val = (shift + 1) as u32; let mut len = 0u32;
1232 while len < 3 && shift_val >= (1u32 << (len + 1)) {
1233 len += 1;
1234 }
1235
1236 for _ in 0..len {
1238 writer.write(1, 1)?;
1239 }
1240 if len < 3 {
1242 writer.write(1, 0)?;
1243 }
1244 if len > 0 {
1246 let suffix = shift_val - (1u32 << len);
1247 writer.write(len as usize, suffix as u64)?;
1248 }
1249
1250 if self.alphabet_size < 3 {
1252 return Err(Error::InvalidHistogram(
1253 "General histogram needs at least 3 symbols".to_string(),
1254 ));
1255 }
1256 write_var_len_uint8(writer, (self.alphabet_size - 3) as u8)?;
1257
1258 let logcounts: Vec<u32> = (0..self.alphabet_size)
1260 .map(|i| {
1261 let count = self.counts[i];
1262 if count <= 0 {
1263 0
1264 } else {
1265 floor_log2(count as u32) + 1
1266 }
1267 })
1268 .collect();
1269
1270 let mut same = vec![0usize; self.alphabet_size];
1278 #[allow(clippy::needless_range_loop)]
1279 for i in 0..self.alphabet_size {
1280 if i == self.omit_pos {
1281 continue;
1282 }
1283 let mut run = 0;
1284 let mut j = i + 1;
1285 while j < self.alphabet_size && self.counts[j] == self.counts[i] {
1286 if j == self.omit_pos {
1287 break; }
1289 run += 1;
1290 j += 1;
1291 }
1292 same[i] = run;
1293 }
1294
1295 const MIN_REPS: usize = 4; let mut i = 0;
1299 while i < self.alphabet_size {
1300 let (nbits, code) = LOGCOUNT_PREFIX_CODE[logcounts[i] as usize];
1302 writer.write(nbits as usize, code as u64)?;
1303
1304 if same[i] >= MIN_REPS && i + 1 != self.omit_pos + 1 {
1307 let (rle_nbits, rle_code) = LOGCOUNT_PREFIX_CODE[RLE_MARKER_SYM as usize];
1308 writer.write(rle_nbits as usize, rle_code as u64)?;
1309 write_var_len_uint8(writer, (same[i] - MIN_REPS) as u8)?;
1310 i += same[i]; }
1312 i += 1;
1313 }
1314
1315 let mut rle_covered = vec![false; self.alphabet_size];
1318 {
1319 let mut i = 0;
1320 while i < self.alphabet_size {
1321 if same[i] >= MIN_REPS && i + 1 != self.omit_pos + 1 {
1322 for item in rle_covered.iter_mut().take(i + same[i] + 1).skip(i + 1) {
1324 *item = true;
1325 }
1326 i += same[i];
1327 }
1328 i += 1;
1329 }
1330 }
1331
1332 for i in 0..self.alphabet_size {
1335 if i == self.omit_pos || rle_covered[i] {
1336 continue;
1337 }
1338
1339 let count = self.counts[i];
1340 if count <= 0 {
1341 continue;
1342 }
1343
1344 let logcount = logcounts[i];
1345 if logcount <= 1 {
1346 continue;
1348 }
1349
1350 let zeros = (logcount - 1) as i32;
1352 let bitcount = (shift - ((ANS_LOG_TAB_SIZE as i32 - zeros) >> 1)).clamp(0, zeros);
1354
1355 if bitcount > 0 {
1356 let base = 1i32 << zeros;
1359 let extra = ((count - base) >> (zeros - bitcount)) as u32;
1360 writer.write(bitcount as usize, extra as u64)?;
1361 }
1362 }
1363
1364 Ok(())
1365 }
1366}
1367
1368impl Default for ANSEncodingHistogram {
1369 fn default() -> Self {
1370 Self::new()
1371 }
1372}
1373
1374pub fn encode_tokens_ans(
1376 tokens: &[(u32, u32)], distributions: &[AnsDistribution],
1378 context_map: &[usize],
1379 writer: &mut BitWriter,
1380) -> Result<()> {
1381 let mut encoder = AnsEncoder::new();
1382
1383 for &(context, value) in tokens.iter().rev() {
1385 let dist_idx = context_map.get(context as usize).copied().unwrap_or(0);
1386 if let Some(dist) = distributions.get(dist_idx)
1387 && let Some(info) = dist.get(value as usize)
1388 {
1389 encoder.put_symbol(info);
1390 }
1391 }
1392
1393 encoder.finalize(writer)
1394}
1395
1396#[cfg(test)]
1397mod tests {
1398 use super::*;
1399 use crate::entropy_coding::histogram::Histogram;
1400
1401 #[test]
1402 fn test_ans_encoding_histogram_single_symbol() {
1403 let h = Histogram::from_counts(&[100, 0, 0, 0]);
1404 let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1405
1406 assert_eq!(encoded.num_symbols, 1);
1407 assert_eq!(encoded.method, 1); assert_eq!(encoded.counts[0], ANS_TAB_SIZE as i32);
1409 assert!(encoded.cost < 100.0); }
1411
1412 #[test]
1413 fn test_ans_encoding_histogram_two_symbols() {
1414 let h = Histogram::from_counts(&[100, 100, 0, 0]);
1415 let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1416
1417 assert_eq!(encoded.num_symbols, 2);
1418 assert_eq!(encoded.method, 1); let sum: i32 = encoded.counts.iter().sum();
1421 assert_eq!(sum, ANS_TAB_SIZE as i32);
1422 assert!(encoded.counts[0] > 0);
1423 assert!(encoded.counts[1] > 0);
1424 }
1425
1426 #[test]
1427 fn test_ans_encoding_histogram_general() {
1428 let h = Histogram::from_counts(&[100, 50, 25, 10, 5, 3, 2, 1]);
1429 let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1430
1431 assert!(encoded.method >= 2 || encoded.method == 0);
1433
1434 let sum: i32 = encoded.counts.iter().sum();
1436 assert_eq!(sum, ANS_TAB_SIZE as i32);
1437
1438 for (i, &orig) in h.counts.iter().enumerate() {
1440 if orig > 0 {
1441 assert!(
1442 encoded.counts.get(i).copied().unwrap_or(0) > 0,
1443 "Symbol {} had count {} but normalized to 0",
1444 i,
1445 orig
1446 );
1447 }
1448 }
1449 }
1450
1451 #[test]
1452 fn test_ans_encoding_histogram_empty() {
1453 let h = Histogram::new();
1454 let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1455
1456 assert_eq!(encoded.cost, 0.0);
1457 assert_eq!(encoded.method, 0); }
1459
1460 #[test]
1461 fn test_get_population_count_precision() {
1462 assert_eq!(get_population_count_precision(0, 12), 0);
1464
1465 assert_eq!(get_population_count_precision(12, 12), 12);
1467
1468 assert_eq!(get_population_count_precision(6, 6), 3);
1470
1471 assert_eq!(get_population_count_precision(1, 0), 0);
1473 }
1474
1475 #[test]
1476 fn test_ans_encoding_histogram_write() {
1477 let h = Histogram::from_counts(&[100, 0, 0, 0]);
1478 let encoded = ANSEncodingHistogram::from_histogram(&h, ANSHistogramStrategy::Fast).unwrap();
1479
1480 let mut writer = BitWriter::new();
1481 encoded.write(&mut writer).unwrap();
1482
1483 let bytes = writer.finish_with_padding();
1484 assert!(!bytes.is_empty());
1485 }
1486
1487 #[test]
1488 fn test_flat_distribution() {
1489 let dist = AnsDistribution::flat(16).unwrap();
1490 assert_eq!(dist.alphabet_size(), 16);
1491
1492 for sym in &dist.symbols {
1494 assert_eq!(sym.freq, 256);
1495 }
1496 }
1497
1498 #[test]
1499 fn test_from_frequencies() {
1500 let freqs = vec![100, 200, 300, 400];
1501 let dist = AnsDistribution::from_frequencies(&freqs).unwrap();
1502 assert_eq!(dist.alphabet_size(), 4);
1503
1504 let total: u32 = dist.symbols.iter().map(|s| s.freq as u32).sum();
1506 assert_eq!(total, ANS_TAB_SIZE);
1507 }
1508
1509 #[test]
1510 fn test_ans_encoder_basic() {
1511 let dist = AnsDistribution::flat(4).unwrap();
1512 let mut encoder = AnsEncoder::new();
1513
1514 encoder.put_symbol(&dist.symbols[0]);
1516 encoder.put_symbol(&dist.symbols[1]);
1517 encoder.put_symbol(&dist.symbols[2]);
1518
1519 assert_ne!(encoder.state(), ANS_SIGNATURE << 16);
1521 }
1522
/// For a flat 4-symbol distribution, each symbol's `reverse_map` holds
/// `freq` entries, and together the maps cover every table position exactly once.
#[test]
fn test_reverse_map() {
    let uniform = AnsDistribution::flat(4).unwrap();

    // Each symbol owns exactly `freq` table slots.
    for info in &uniform.symbols {
        assert_eq!(info.reverse_map.len(), info.freq as usize);
    }

    // The slots of all symbols must tile the whole table with no overlap.
    let mut seen = vec![false; ANS_TAB_SIZE as usize];
    for slot in uniform
        .symbols
        .iter()
        .flat_map(|info| info.reverse_map.iter().copied())
    {
        let slot = slot as usize;
        assert!(!seen[slot], "position {} covered twice", slot);
        seen[slot] = true;
    }
    assert!(!seen.contains(&false), "not all positions covered");
}
1542
/// Serializing a flat 16-symbol distribution succeeds and produces
/// non-empty output.
#[test]
fn test_write_distribution() {
    let uniform = AnsDistribution::flat(16).unwrap();
    let mut sink = BitWriter::new();
    uniform.write(&mut sink).unwrap();

    let written = sink.finish_with_padding();
    assert!(!written.is_empty());
}
1553
/// Encodes one symbol from a flat two-symbol distribution and decodes it by
/// hand, checking two things: the decoded symbol is 0, and the decoded state
/// returns to the encoder's initial state `ANS_SIGNATURE << 16`.
///
/// The manual decode assumes symbol 0 occupies the lower half of the table
/// (slot indices `0..ANS_TAB_SIZE / 2`) and symbol 1 the upper half.
#[test]
fn test_ans_roundtrip_manual() {
    let dist = AnsDistribution::flat(2).unwrap();

    println!("Distribution: {} symbols", dist.alphabet_size());
    for (i, sym) in dist.symbols.iter().enumerate() {
        println!(" Symbol {}: freq={}", i, sym.freq);
    }

    // The encoder starts from the signature state; a full roundtrip must return to it.
    let signature_state = ANS_SIGNATURE << 16;
    let mut encoder = AnsEncoder::new();
    let initial_state = encoder.state();
    println!("\nInitial state: 0x{:08x}", initial_state);
    assert_eq!(initial_state, signature_state, "Initial state should be 0x130000");

    encoder.put_symbol(&dist.symbols[0]);
    let encoded_state = encoder.state();
    println!("After encoding symbol 0: state=0x{:08x}", encoded_state);

    // Decoder step: the low ANS_LOG_TAB_SIZE bits of the state select a table slot.
    let idx = encoded_state & ANS_TAB_MASK;
    println!("Decode: idx = {}", idx);

    // Map the slot to (symbol, offset within that symbol's slots) in one place
    // instead of testing the same condition twice.
    let half = ANS_TAB_SIZE / 2;
    let (decoded_symbol, offset_in_symbol) = if idx < half {
        (0usize, idx)
    } else {
        (1, idx - half)
    };
    // Take the frequency from the distribution rather than hard-coding it.
    let freq = dist.symbols[decoded_symbol].freq as u32;

    println!("Decoded symbol: {}", decoded_symbol);
    println!("Offset in symbol: {}", offset_in_symbol);

    let quotient = encoded_state >> ANS_LOG_TAB_SIZE;
    let next_state = quotient * freq + offset_in_symbol;
    println!(
        "next_state = {} * {} + {} = 0x{:08x}",
        quotient, freq, offset_in_symbol, next_state
    );

    assert_eq!(next_state, signature_state, "Decoded state should be 0x130000");
    assert_eq!(decoded_symbol, 0, "Decoded symbol should be 0");
}
1601
/// Encodes a short symbol sequence with a uniform four-symbol distribution.
/// It then decodes the sequence using a histogram that was written with
/// `ANSEncodingHistogram::write` and read back with `AnsHistogram::decode`.
/// Finally it checks that the symbols match and the reader's final state is valid.
#[test]
fn test_ans_roundtrip_multiple_symbols() {
    use crate::bit_writer::BitWriter;
    use crate::entropy_coding::ans_decode::{AnsHistogram, AnsReader, BitReader};

    let counts = [1024i32, 1024, 1024, 1024];
    let dist = AnsDistribution::from_normalized_counts(&counts).unwrap();

    let message: Vec<usize> = vec![0, 1, 2, 3, 0, 1];
    println!("Encoding {} symbols: {:?}", message.len(), message);

    // Symbols are pushed in reverse order; the decoder below reads them front to back.
    let mut encoder = AnsEncoder::new();
    for &s in message.iter().rev() {
        encoder.put_symbol(&dist.symbols[s]);
    }
    println!("Final state after encoding: 0x{:08x}", encoder.state());

    let mut payload_writer = BitWriter::new();
    encoder.finalize(&mut payload_writer).unwrap();
    let payload = payload_writer.finish_with_padding();
    println!("Encoded bytes: {:02x?}", payload);

    // Serialize a histogram for the same counts and read it back for decoding.
    let table = ANSEncodingHistogram::from_histogram(
        &Histogram::from_counts(&counts),
        ANSHistogramStrategy::Precise,
    )
    .unwrap();
    let mut table_writer = BitWriter::new();
    table.write(&mut table_writer).unwrap();
    let table_bytes = table_writer.finish_with_padding();

    let mut table_reader = BitReader::new(&table_bytes);
    let decode_hist = AnsHistogram::decode(&mut table_reader, 6).unwrap();
    println!(
        "Decoded histogram frequencies: {:?}",
        &decode_hist.frequencies[..4]
    );

    let mut stream = BitReader::new(&payload);
    let mut state_reader = AnsReader::init(&mut stream).unwrap();

    println!("Decoding:");
    let decoded: Vec<usize> = (0..message.len())
        .map(|step| {
            let symbol = decode_hist.read(&mut stream, &mut state_reader.0) as usize;
            println!(
                " step {}: symbol={}, state=0x{:08x}",
                step, symbol, state_reader.0
            );
            symbol
        })
        .collect();

    println!("Original: {:?}", message);
    println!("Decoded: {:?}", decoded);
    println!("Final state: 0x{:08x}", state_reader.0);

    assert_eq!(
        decoded, message,
        "Decoded symbols should match original"
    );
    assert!(
        state_reader.check_final_state().is_ok(),
        "Final state should be 0x130000, got 0x{:08x}",
        state_reader.0
    );
}
1682
/// Builds a `Precise` ANS histogram and checks its counts sum to the table
/// size. It then writes the histogram and decodes it back with
/// `AnsHistogram::decode`, requiring the decoded frequencies to equal the
/// encoded counts.
#[test]
fn test_ans_histogram_write_decode_roundtrip() {
    use crate::bit_writer::BitWriter;
    use crate::entropy_coding::ans_decode::{AnsHistogram, BitReader};
    use crate::entropy_coding::histogram::Histogram;

    let histo = Histogram::from_counts(&[100, 50, 25, 10]);

    let encoded =
        ANSEncodingHistogram::from_histogram(&histo, ANSHistogramStrategy::Precise).unwrap();

    println!("Histogram: {:?}", histo.counts);
    println!("Encoded counts: {:?}", encoded.counts);
    println!(
        "Method: {}, alphabet_size: {}, omit_pos: {}",
        encoded.method, encoded.alphabet_size, encoded.omit_pos
    );

    let sum: i32 = encoded.counts.iter().sum();
    assert_eq!(sum, ANS_TAB_SIZE as i32, "Sum should be 4096");

    let mut writer = BitWriter::new();
    encoded.write(&mut writer).unwrap();
    let bytes = writer.finish_with_padding();

    println!("Encoded histogram: {} bytes", bytes.len());
    println!("Bytes: {:02x?}", bytes);

    // Complete the roundtrip the test name promises: read the histogram back
    // and compare it with what was written.
    // NOTE(review): log_alpha_size = 6 (up to 64 symbols) is assumed to be
    // valid for this 4-symbol alphabet — confirm against the decoder.
    let log_alpha_size = 6;
    let mut reader = BitReader::new(&bytes);
    let decoded = AnsHistogram::decode(&mut reader, log_alpha_size).unwrap();

    // `zip` compares only the overlapping prefix, so differing container
    // lengths (e.g. zero padding) cannot cause an out-of-bounds panic.
    for (i, (&written, &read)) in encoded
        .counts
        .iter()
        .zip(decoded.frequencies.iter())
        .enumerate()
    {
        assert_eq!(read as i32, written, "frequency mismatch for symbol {}", i);
    }
}
1713}