1use crate::error::{Error, Result};
8
/// Base-2 logarithm of the total that every decoded distribution sums to.
const LOG_SUM_PROBS: usize = 12;
/// Total of all symbol frequencies in one distribution (`1 << LOG_SUM_PROBS`).
const SUM_PROBS: u16 = 1u16 << LOG_SUM_PROBS;
/// Prefix-code value that introduces a run in the complex histogram encoding;
/// it is one larger than `LOG_SUM_PROBS`.
const RLE_MARKER_SYM: u16 = (LOG_SUM_PROBS + 1) as u16;
12
13pub struct BitReader<'a> {
15 data: &'a [u8],
16 bit_pos: usize,
17}
18
19impl<'a> BitReader<'a> {
20 pub fn new(data: &'a [u8]) -> Self {
21 Self { data, bit_pos: 0 }
22 }
23
24 pub fn read(&mut self, n: usize) -> Result<u64> {
25 let mut val = 0u64;
26 for i in 0..n {
27 let byte_idx = self.bit_pos / 8;
28 let bit_idx = self.bit_pos % 8;
29 if byte_idx >= self.data.len() {
30 return Err(Error::Bitstream("unexpected end of data".to_string()));
31 }
32 let bit = ((self.data[byte_idx] >> bit_idx) & 1) as u64;
33 val |= bit << i;
34 self.bit_pos += 1;
35 }
36 Ok(val)
37 }
38
39 pub fn peek(&mut self, n: usize) -> u64 {
40 let old_pos = self.bit_pos;
41 let val = self.read(n).unwrap_or(0);
42 self.bit_pos = old_pos;
43 val
44 }
45
46 pub fn consume(&mut self, n: usize) -> Result<()> {
47 self.bit_pos += n;
48 if self.bit_pos > self.data.len() * 8 {
49 return Err(Error::Bitstream("unexpected end of data".to_string()));
50 }
51 Ok(())
52 }
53
54 pub fn bits_read(&self) -> usize {
55 self.bit_pos
56 }
57}
58
/// Decoded ANS distribution together with the alias table used to map a
/// decoder state to a symbol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnsHistogram {
    /// Alias table with one entry per table slot (`1 << log_alpha_size`).
    pub buckets: Vec<Bucket>,
    /// Base-2 logarithm of the number of state positions covered by a bucket.
    pub log_bucket_size: usize,
    /// `bucket_size - 1`; masks a state down to its position inside a bucket.
    pub bucket_mask: u32,
    /// Index of the only symbol when one symbol carries the whole
    /// probability mass, `None` otherwise.
    pub single_symbol: Option<u32>,
    /// Decoded frequency of every table slot.
    pub frequencies: Vec<u16>,
}

/// One entry of the alias table.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Bucket {
    /// Symbol produced for positions at or above `alias_cutoff`.
    pub alias_symbol: u8,
    /// First position inside the bucket that maps to `alias_symbol`.
    pub alias_cutoff: u8,
    /// Frequency of the bucket's own symbol.
    pub dist: u16,
    /// Value added to the position when the alias symbol is selected.
    pub alias_offset: u16,
    /// XOR of `dist` with the alias symbol's frequency, so that
    /// `dist ^ alias_dist_xor` yields the alias frequency.
    pub alias_dist_xor: u16,
}
77
78impl AnsHistogram {
79 pub fn decode(br: &mut BitReader, log_alpha_size: usize) -> Result<Self> {
80 debug_assert!((5..=8).contains(&log_alpha_size));
81 let table_size = 1usize << log_alpha_size;
82 let log_bucket_size = LOG_SUM_PROBS - log_alpha_size;
83 let bucket_size = 1u16 << log_bucket_size;
84 let bucket_mask = bucket_size as u32 - 1;
85
86 let mut dist = vec![0u16; table_size];
87 let alphabet_size = if br.read(1)? != 0 {
88 if br.read(1)? != 0 {
89 Self::decode_dist_two_symbols(br, &mut dist)?
90 } else {
91 Self::decode_dist_single_symbol(br, &mut dist)?
92 }
93 } else if br.read(1)? != 0 {
94 Self::decode_dist_evenly_distributed(br, &mut dist)?
95 } else {
96 Self::decode_dist_complex(br, &mut dist)?
97 };
98
99 let frequencies = dist.clone();
100
101 if let Some(single_sym_idx) = dist.iter().position(|&d| d == SUM_PROBS) {
102 let buckets = dist
103 .into_iter()
104 .enumerate()
105 .map(|(i, dist)| Bucket {
106 dist,
107 alias_symbol: single_sym_idx as u8,
108 alias_offset: bucket_size * i as u16,
109 alias_cutoff: 0,
110 alias_dist_xor: dist ^ SUM_PROBS,
111 })
112 .collect();
113 return Ok(Self {
114 buckets,
115 log_bucket_size,
116 bucket_mask,
117 single_symbol: Some(single_sym_idx as u32),
118 frequencies,
119 });
120 }
121
122 Ok(Self {
123 buckets: Self::build_alias_map(alphabet_size, log_bucket_size, &dist),
124 log_bucket_size,
125 bucket_mask,
126 single_symbol: None,
127 frequencies,
128 })
129 }
130
131 fn decode_dist_two_symbols(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
132 let table_size = dist.len();
133
134 let v0 = Self::read_u8(br)? as usize;
135 let v1 = Self::read_u8(br)? as usize;
136 if v0 == v1 {
137 return Err(Error::InvalidHistogram(
138 "two symbols are the same".to_string(),
139 ));
140 }
141
142 let alphabet_size = v0.max(v1) + 1;
143 if alphabet_size > table_size {
144 return Err(Error::InvalidHistogram("alphabet too large".to_string()));
145 }
146
147 let prob = br.read(LOG_SUM_PROBS)? as u16;
148 dist[v0] = prob;
149 dist[v1] = SUM_PROBS - prob;
150
151 Ok(alphabet_size)
152 }
153
154 fn decode_dist_single_symbol(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
155 let table_size = dist.len();
156
157 let val = Self::read_u8(br)? as usize;
158 let alphabet_size = val + 1;
159 if alphabet_size > table_size {
160 return Err(Error::InvalidHistogram("alphabet too large".to_string()));
161 }
162
163 dist[val] = SUM_PROBS;
164
165 Ok(alphabet_size)
166 }
167
168 fn decode_dist_evenly_distributed(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
169 let table_size = dist.len();
170
171 let alphabet_size = Self::read_u8(br)? as usize + 1;
172 if alphabet_size > table_size {
173 return Err(Error::InvalidHistogram("alphabet too large".to_string()));
174 }
175
176 let base = SUM_PROBS as usize / alphabet_size;
177 let remainder = SUM_PROBS as usize % alphabet_size;
178 dist[0..remainder].fill(base as u16 + 1);
179 dist[remainder..alphabet_size].fill(base as u16);
180
181 Ok(alphabet_size)
182 }
183
184 fn decode_dist_complex(br: &mut BitReader, dist: &mut [u16]) -> Result<usize> {
185 let table_size = dist.len();
186
187 let mut len = 0usize;
188 while len < 3 {
189 if br.read(1)? != 0 {
190 len += 1;
191 } else {
192 break;
193 }
194 }
195
196 let shift = (br.read(len)? + (1 << len) - 1) as i16;
197 if shift > 13 {
198 return Err(Error::InvalidHistogram("shift too large".to_string()));
199 }
200
201 let alphabet_size = Self::read_u8(br)? as usize + 3;
202 if alphabet_size > table_size {
203 return Err(Error::InvalidHistogram("alphabet too large".to_string()));
204 }
205
206 let mut repeat_ranges = Vec::new();
207 let mut omit_data: Option<(u16, usize)> = None;
208 let mut idx = 0;
209 while idx < alphabet_size {
210 dist[idx] = Self::read_prefix(br)?;
211 if dist[idx] == RLE_MARKER_SYM {
212 let repeat_count = Self::read_u8(br)? as usize + 4;
213 if idx + repeat_count > alphabet_size {
214 return Err(Error::InvalidHistogram("RLE overflow".to_string()));
215 }
216 repeat_ranges.push(idx..(idx + repeat_count));
217 idx += repeat_count;
218 continue;
219 }
220 match &mut omit_data {
221 Some((log, pos)) => {
222 if dist[idx] > *log {
223 *log = dist[idx];
224 *pos = idx;
225 }
226 }
227 data => {
228 *data = Some((dist[idx], idx));
229 }
230 }
231 idx += 1;
232 }
233 let Some((_, omit_pos)) = omit_data else {
234 return Err(Error::InvalidHistogram("no omit position".to_string()));
235 };
236 if dist.get(omit_pos + 1) == Some(&RLE_MARKER_SYM) {
237 return Err(Error::InvalidHistogram("RLE after omit".to_string()));
238 }
239
240 let mut repeat_range_idx = 0usize;
241 let mut acc = 0;
242 let mut prev_dist = 0u16;
243 for (idx, code) in dist.iter_mut().enumerate() {
244 if repeat_range_idx < repeat_ranges.len()
245 && repeat_ranges[repeat_range_idx].start <= idx
246 {
247 if repeat_ranges[repeat_range_idx].end == idx {
248 repeat_range_idx += 1;
249 } else {
250 *code = prev_dist;
251 acc += *code;
252 if acc >= SUM_PROBS {
253 return Err(Error::InvalidHistogram("sum overflow".to_string()));
254 }
255 continue;
256 }
257 }
258
259 if *code == 0 {
260 prev_dist = 0;
261 continue;
262 }
263 if idx == omit_pos {
264 prev_dist = 0;
265 continue;
266 }
267 if *code > 1 {
268 let zeros = (*code - 1) as i16;
269 let bitcount = (shift - ((LOG_SUM_PROBS as i16 - zeros) >> 1)).clamp(0, zeros);
270 *code = (1 << zeros) + ((br.read(bitcount as usize)? as u16) << (zeros - bitcount));
271 }
272
273 prev_dist = *code;
274 acc += *code;
275 if acc >= SUM_PROBS {
276 return Err(Error::InvalidHistogram("sum overflow".to_string()));
277 }
278 }
279 dist[omit_pos] = SUM_PROBS - acc;
280
281 Ok(alphabet_size)
282 }
283
284 pub fn build_alias_map_from_freqs(
286 alphabet_size: usize,
287 log_bucket_size: usize,
288 dist: &[u16],
289 ) -> Vec<Bucket> {
290 Self::build_alias_map(alphabet_size, log_bucket_size, dist)
291 }
292
293 fn build_alias_map(alphabet_size: usize, log_bucket_size: usize, dist: &[u16]) -> Vec<Bucket> {
294 struct WorkingBucket {
295 dist: u16,
296 alias_symbol: u16,
297 alias_offset: u16,
298 alias_cutoff: u16,
299 }
300
301 let bucket_size = 1u16 << log_bucket_size;
302 let mut buckets: Vec<_> = dist
303 .iter()
304 .enumerate()
305 .map(|(i, &dist)| WorkingBucket {
306 dist,
307 alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
308 alias_offset: 0,
309 alias_cutoff: dist,
310 })
311 .collect();
312
313 let mut underfull = Vec::new();
314 let mut overfull = Vec::new();
315 for (idx, bucket) in buckets.iter().enumerate() {
316 match bucket.dist.cmp(&bucket_size) {
317 std::cmp::Ordering::Less => underfull.push(idx),
318 std::cmp::Ordering::Equal => {}
319 std::cmp::Ordering::Greater => overfull.push(idx),
320 }
321 }
322 while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
323 let by = bucket_size - buckets[u].alias_cutoff;
324 buckets[o].alias_cutoff -= by;
325 buckets[u].alias_symbol = o as u16;
326 buckets[u].alias_offset = buckets[o].alias_cutoff;
327 match buckets[o].alias_cutoff.cmp(&bucket_size) {
328 std::cmp::Ordering::Less => underfull.push(o),
329 std::cmp::Ordering::Equal => {}
330 std::cmp::Ordering::Greater => overfull.push(o),
331 }
332 }
333
334 buckets
335 .iter()
336 .enumerate()
337 .map(|(idx, bucket)| {
338 if bucket.alias_cutoff == bucket_size {
339 Bucket {
340 dist: bucket.dist,
341 alias_symbol: idx as u8,
342 alias_offset: 0,
343 alias_cutoff: 0,
344 alias_dist_xor: 0,
345 }
346 } else {
347 Bucket {
348 dist: bucket.dist,
349 alias_symbol: bucket.alias_symbol as u8,
350 alias_offset: bucket.alias_offset - bucket.alias_cutoff,
351 alias_cutoff: bucket.alias_cutoff as u8,
352 alias_dist_xor: bucket.dist ^ buckets[bucket.alias_symbol as usize].dist,
353 }
354 }
355 })
356 .collect()
357 }
358
359 fn read_u8(br: &mut BitReader) -> Result<u8> {
360 Ok(if br.read(1)? != 0 {
361 let n = br.read(3)?;
362 ((1 << n) + br.read(n as usize)?) as u8
363 } else {
364 0
365 })
366 }
367
368 fn read_prefix(br: &mut BitReader) -> Result<u16> {
369 #[rustfmt::skip]
370 const TABLE: [(u8, u8); 128] = [
371 (10, 3), (12, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
372 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
373 (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
374 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
375 (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
376 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
377 (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
378 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
379 (10, 3), (13, 7), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
380 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
381 (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
382 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
383 (10, 3), (11, 6), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
384 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
385 (10, 3), ( 0, 5), (7, 3), (3, 4), (6, 3), (8, 3), (9, 3), (5, 4),
386 (10, 3), ( 4, 4), (7, 3), (1, 4), (6, 3), (8, 3), (9, 3), (2, 4),
387 ];
388
389 let index = br.peek(7) as usize;
390 let (sym, bits) = TABLE[index];
391 br.consume(bits as usize)?;
392 Ok(sym as u16)
393 }
394
395 pub fn read(&self, br: &mut BitReader, state: &mut u32) -> u32 {
397 let idx = *state & 0xfff;
398 let i = (idx >> self.log_bucket_size) as usize;
399 let pos = idx & self.bucket_mask;
400
401 let bucket = &self.buckets[i & (self.buckets.len() - 1)];
402 let alias_symbol = bucket.alias_symbol as u32;
403 let alias_cutoff = bucket.alias_cutoff as u32;
404 let dist = bucket.dist as u32;
405
406 let map_to_alias = (pos >= alias_cutoff) as u32;
407 let offset = (bucket.alias_offset as u32) * map_to_alias;
408 let dist_xor = (bucket.alias_dist_xor as u32) * map_to_alias;
409
410 let dist = dist ^ dist_xor;
411 let symbol = (alias_symbol * map_to_alias) | (i as u32 * (1 - map_to_alias));
412 let offset = offset + pos;
413
414 let next_state = (*state >> LOG_SUM_PROBS) * dist + offset;
415 let select_appended = (next_state < (1 << 16)) as u32;
416 let appended_bits = br.peek(16) as u32;
417 let appended_state = (next_state << 16) | appended_bits;
418 *state = (appended_state * select_appended) | (next_state * (1 - select_appended));
419 if select_appended != 0 {
420 br.consume(16).ok();
421 }
422 symbol
423 }
424}
425
426pub struct AnsReader(pub u32);
428
429impl AnsReader {
430 pub const CHECKSUM: u32 = 0x130000;
431
432 pub fn init(br: &mut BitReader) -> Result<Self> {
433 let initial_state = br.read(32)? as u32;
434 Ok(Self(initial_state))
435 }
436
437 pub fn check_final_state(&self) -> Result<()> {
438 if self.0 == Self::CHECKSUM {
439 Ok(())
440 } else {
441 Err(Error::Bitstream(format!(
442 "ANS checksum mismatch: got 0x{:08x}, expected 0x{:08x}",
443 self.0,
444 Self::CHECKSUM
445 )))
446 }
447 }
448
449 pub fn state(&self) -> u32 {
450 self.0
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_writer::BitWriter;
    use crate::entropy_coding::ans::{ANSEncodingHistogram, ANSHistogramStrategy};
    use crate::entropy_coding::histogram::Histogram;

    /// A histogram with one used symbol decodes to the single-symbol form.
    #[test]
    fn test_decode_single_symbol() {
        let source = Histogram::from_counts(&[100, 0, 0, 0]);
        let encoded =
            ANSEncodingHistogram::from_histogram(&source, ANSHistogramStrategy::Precise).unwrap();

        let mut sink = BitWriter::new();
        encoded.write(&mut sink).unwrap();
        let payload = sink.finish_with_padding();

        println!("Single symbol histogram bytes: {:02x?}", payload);

        let mut reader = BitReader::new(&payload);
        let table = AnsHistogram::decode(&mut reader, 6).unwrap();

        println!("Decoded frequencies: {:?}", &table.frequencies[..4]);
        println!("Single symbol: {:?}", table.single_symbol);

        assert_eq!(table.single_symbol, Some(0));
        assert_eq!(table.frequencies[0], 4096);
    }

    /// Two equally frequent symbols survive an encode/decode round trip.
    #[test]
    fn test_decode_two_symbols() {
        let source = Histogram::from_counts(&[100, 100, 0, 0]);
        let encoded =
            ANSEncodingHistogram::from_histogram(&source, ANSHistogramStrategy::Precise).unwrap();

        println!("Two symbol histogram: {:?}", encoded.counts);

        let mut sink = BitWriter::new();
        encoded.write(&mut sink).unwrap();
        let payload = sink.finish_with_padding();

        println!("Two symbol histogram bytes: {:02x?}", payload);

        let mut reader = BitReader::new(&payload);
        let table = AnsHistogram::decode(&mut reader, 6).unwrap();

        println!("Decoded frequencies: {:?}", &table.frequencies[..4]);

        let total: u16 = table.frequencies.iter().sum();
        assert_eq!(total, 4096, "Sum should be 4096");

        assert_eq!(table.frequencies[0], encoded.counts[0] as u16);
        assert_eq!(table.frequencies[1], encoded.counts[1] as u16);
    }

    /// A four-symbol histogram decodes to the frequencies the encoder wrote.
    #[test]
    fn test_decode_general_histogram() {
        let source = Histogram::from_counts(&[100, 50, 25, 10]);
        let encoded =
            ANSEncodingHistogram::from_histogram(&source, ANSHistogramStrategy::Precise).unwrap();

        println!("General histogram:");
        println!("  counts: {:?}", encoded.counts);
        println!(
            "  method: {}, alphabet_size: {}, omit_pos: {}",
            encoded.method, encoded.alphabet_size, encoded.omit_pos
        );

        let mut sink = BitWriter::new();
        encoded.write(&mut sink).unwrap();
        let payload = sink.finish_with_padding();

        println!("  bytes ({} bytes): {:02x?}", payload.len(), payload);

        let mut reader = BitReader::new(&payload);
        let table = AnsHistogram::decode(&mut reader, 6).unwrap();

        println!(
            "Decoded frequencies: {:?}",
            &table.frequencies[..encoded.alphabet_size]
        );

        let total: u16 = table.frequencies.iter().sum();
        assert_eq!(total, 4096, "Sum should be 4096");

        for sym in 0..encoded.alphabet_size {
            assert_eq!(
                table.frequencies[sym], encoded.counts[sym] as u16,
                "Frequency mismatch at symbol {}",
                sym
            );
        }
    }

    /// A sparse histogram (one dominant symbol and two rare ones far apart)
    /// survives an encode/decode round trip.
    #[test]
    fn test_decode_sparse_histogram_roundtrip() {
        let mut raw_counts = vec![0i32; 40];
        raw_counts[1] = 196000;
        raw_counts[31] = 100;
        raw_counts[35] = 100;
        let source = Histogram::from_counts(&raw_counts);

        let encoded =
            ANSEncodingHistogram::from_histogram(&source, ANSHistogramStrategy::Precise).unwrap();

        println!("Sparse histogram:");
        println!(
            "  method={}, alphabet_size={}, omit_pos={}",
            encoded.method, encoded.alphabet_size, encoded.omit_pos
        );
        println!("  non-zero counts:");
        for (sym, &count) in encoded.counts.iter().enumerate() {
            if count == 0 {
                continue;
            }
            println!("    [{}] = {}", sym, count);
        }

        let mut sink = BitWriter::new();
        encoded.write(&mut sink).unwrap();
        // One extra zero byte after the histogram, then byte alignment.
        sink.write(8, 0).unwrap();
        sink.zero_pad_to_byte();
        let payload = sink.finish();

        println!(
            "  encoded bytes ({} bytes): {:02x?}",
            payload.len(),
            &payload[..payload.len().min(32)]
        );

        let mut reader = BitReader::new(&payload);
        let table = AnsHistogram::decode(&mut reader, 6).unwrap();

        println!("  decoded frequencies:");
        for (sym, &freq) in table.frequencies.iter().enumerate() {
            if freq == 0 {
                continue;
            }
            println!("    [{}] = {}", sym, freq);
        }

        let total: u16 = table.frequencies.iter().sum();
        assert_eq!(total, 4096, "Sum should be 4096 but got {}", total);

        for sym in 0..encoded.alphabet_size {
            assert_eq!(
                table.frequencies[sym], encoded.counts[sym] as u16,
                "Frequency mismatch at symbol {}: encoder wrote {}, decoder read {}",
                sym, encoded.counts[sym], table.frequencies[sym]
            );
        }
    }
}