1use crate::error::{CodecError, CodecResult};
15
/// Base-2 logarithm of the frequency-table size used when no other size is given
/// (a table of 1024 slots).
const DEFAULT_LOG_TABLE_SIZE: u8 = 10;

/// Width, in bits, of each renormalization word moved between the coder state
/// and the byte stream.
const RENORM_WORD_BITS: u32 = u16::BITS;
21
/// Quantized symbol distribution shared by [`AnsEncoder`] and [`AnsDecoder`].
///
/// [`AnsDistribution::new`] rescales raw frequencies toward a table of
/// `1 << log_table_size` slots; each symbol owns the slot range
/// `cumulative[i]..cumulative[i + 1]`.
#[derive(Debug, Clone)]
pub struct AnsDistribution {
    /// Symbol values, in table order.
    pub symbols: Vec<u16>,
    /// Normalized slot count of each entry in `symbols`.
    pub frequencies: Vec<u32>,
    /// Prefix sums of `frequencies`, starting at 0 (one entry longer than
    /// `symbols` when built by `new`).
    pub cumulative: Vec<u32>,
    /// Base-2 logarithm of the table size.
    pub log_table_size: u8,
}
37
38impl AnsDistribution {
39 pub fn new(symbols: Vec<u16>, frequencies: Vec<u32>, log_table_size: u8) -> CodecResult<Self> {
43 if symbols.len() != frequencies.len() {
44 return Err(CodecError::InvalidParameter(
45 "Symbol and frequency vectors must have the same length".into(),
46 ));
47 }
48 if symbols.is_empty() {
49 return Err(CodecError::InvalidParameter(
50 "Distribution must have at least one symbol".into(),
51 ));
52 }
53
54 let total: u32 = frequencies.iter().sum();
55 if total == 0 {
56 return Err(CodecError::InvalidParameter(
57 "Total frequency must be non-zero".into(),
58 ));
59 }
60
61 let table_size = 1u32 << log_table_size;
62
63 let mut normalized: Vec<u32> = frequencies
65 .iter()
66 .map(|&f| {
67 if f == 0 {
68 0
69 } else {
70 let n = (f as u64 * table_size as u64 / total as u64) as u32;
71 if n == 0 {
72 1
73 } else {
74 n
75 }
76 }
77 })
78 .collect();
79
80 let current_sum: u32 = normalized.iter().sum();
82 if current_sum != table_size {
83 let diff = table_size as i64 - current_sum as i64;
84 if let Some(max_idx) = normalized
85 .iter()
86 .enumerate()
87 .filter(|(_, &f)| f > 0)
88 .max_by_key(|(_, &f)| f)
89 .map(|(i, _)| i)
90 {
91 let adjusted = normalized[max_idx] as i64 + diff;
92 if adjusted > 0 {
93 normalized[max_idx] = adjusted as u32;
94 }
95 }
96 }
97
98 let mut cumulative = Vec::with_capacity(normalized.len() + 1);
100 cumulative.push(0);
101 let mut sum = 0u32;
102 for &f in &normalized {
103 sum += f;
104 cumulative.push(sum);
105 }
106
107 Ok(Self {
108 symbols,
109 frequencies: normalized,
110 cumulative,
111 log_table_size,
112 })
113 }
114
115 pub fn table_size(&self) -> u32 {
117 1u32 << self.log_table_size
118 }
119
120 pub fn num_symbols(&self) -> usize {
122 self.symbols.len()
123 }
124
125 pub fn lookup(&self, value: u32) -> CodecResult<(usize, u32, u32)> {
127 let mut lo = 0usize;
128 let mut hi = self.symbols.len();
129 while lo < hi {
130 let mid = lo + (hi - lo) / 2;
131 if self.cumulative[mid + 1] <= value {
132 lo = mid + 1;
133 } else {
134 hi = mid;
135 }
136 }
137 if lo >= self.symbols.len() {
138 return Err(CodecError::InvalidBitstream(format!(
139 "ANS lookup failed: value {value} out of range"
140 )));
141 }
142 Ok((lo, self.cumulative[lo], self.frequencies[lo]))
143 }
144
145 fn find_symbol(&self, symbol: u16) -> CodecResult<usize> {
147 self.symbols
148 .iter()
149 .position(|&s| s == symbol)
150 .ok_or_else(|| {
151 CodecError::InvalidParameter(format!("Symbol {symbol} not found in distribution"))
152 })
153 }
154}
155
156pub fn uniform_distribution(n: u16) -> CodecResult<AnsDistribution> {
158 if n == 0 {
159 return Err(CodecError::InvalidParameter(
160 "Cannot create uniform distribution with 0 symbols".into(),
161 ));
162 }
163 let symbols: Vec<u16> = (0..n).collect();
164 let freq = vec![1u32; n as usize];
165 AnsDistribution::new(symbols, freq, DEFAULT_LOG_TABLE_SIZE)
166}
167
168pub fn distribution_from_counts(
170 counts: &[u32],
171 log_table_size: u8,
172) -> CodecResult<AnsDistribution> {
173 let mut symbols = Vec::new();
174 let mut frequencies = Vec::new();
175
176 for (i, &count) in counts.iter().enumerate() {
177 if count > 0 {
178 symbols.push(i as u16);
179 frequencies.push(count);
180 }
181 }
182
183 if symbols.is_empty() {
184 symbols.push(0);
185 frequencies.push(1);
186 }
187
188 AnsDistribution::new(symbols, frequencies, log_table_size)
189}
190
191pub struct AnsDecoder<'a> {
198 state: u32,
199 data: &'a [u8],
200 word_pos: usize,
202}
203
204impl<'a> AnsDecoder<'a> {
205 pub fn new(data: &'a [u8]) -> CodecResult<Self> {
207 if data.len() < 8 {
208 return Err(CodecError::InvalidBitstream("ANS data too short".into()));
209 }
210 let state = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
211 Ok(Self {
213 state,
214 data,
215 word_pos: 8,
216 })
217 }
218
219 fn read_word(&mut self) -> u16 {
221 if self.word_pos + 1 < self.data.len() {
222 let w = u16::from_le_bytes([self.data[self.word_pos], self.data[self.word_pos + 1]]);
223 self.word_pos += 2;
224 w
225 } else {
226 0
227 }
228 }
229
230 pub fn decode_symbol(&mut self, dist: &AnsDistribution) -> CodecResult<u16> {
232 let table_size = dist.table_size();
233 let mask = table_size - 1;
234
235 let slot = self.state & mask;
236 let (idx, start, freq) = dist.lookup(slot)?;
237 let symbol = dist.symbols[idx];
238
239 self.state = freq * (self.state >> dist.log_table_size) + slot - start;
241
242 if self.state < table_size {
244 let word = self.read_word() as u32;
245 self.state = (self.state << RENORM_WORD_BITS) | word;
246 }
247
248 Ok(symbol)
249 }
250}
251
252pub struct AnsEncoder {
261 state: u32,
262 words: Vec<u16>,
264 log_table_size: u8,
265}
266
267impl AnsEncoder {
268 pub fn new() -> Self {
270 let log_table_size = DEFAULT_LOG_TABLE_SIZE;
271 let table_size = 1u32 << log_table_size;
272 Self {
273 state: table_size, words: Vec::new(),
275 log_table_size,
276 }
277 }
278
279 pub fn encode_symbol(&mut self, symbol: u16, dist: &AnsDistribution) -> CodecResult<()> {
281 let idx = dist.find_symbol(symbol)?;
282 let start = dist.cumulative[idx];
283 let freq = dist.frequencies[idx];
284
285 if freq == 0 {
286 return Err(CodecError::InvalidParameter(format!(
287 "Symbol {symbol} has zero frequency"
288 )));
289 }
290
291 let table_size = dist.table_size();
292
293 let upper_bound = freq << RENORM_WORD_BITS;
297 while self.state >= upper_bound {
298 self.words.push(self.state as u16);
299 self.state >>= RENORM_WORD_BITS;
300 }
301
302 self.state = table_size * (self.state / freq) + (self.state % freq) + start;
304
305 Ok(())
306 }
307
308 pub fn finish(self) -> Vec<u8> {
310 let word_count = self.words.len() as u32;
311 let mut output = Vec::with_capacity(8 + self.words.len() * 2);
312
313 output.extend_from_slice(&self.state.to_le_bytes());
315 output.extend_from_slice(&word_count.to_le_bytes());
317 for &word in self.words.iter().rev() {
319 output.extend_from_slice(&word.to_le_bytes());
320 }
321
322 output
323 }
324}
325
326impl Default for AnsEncoder {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `message` to a fresh encoder back to front (the order rANS needs for
    /// the decoder to yield it front to back) and returns the finished stream.
    fn encode_reversed(dist: &AnsDistribution, message: &[u16]) -> Vec<u8> {
        let mut encoder = AnsEncoder::new();
        for &sym in message.iter().rev() {
            encoder.encode_symbol(sym, dist).expect("ok");
        }
        encoder.finish()
    }

    #[test]
    #[ignore]
    fn test_uniform_distribution() {
        let dist = uniform_distribution(4).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.table_size(), 1 << DEFAULT_LOG_TABLE_SIZE);
        let expected = i64::from(dist.table_size() / 4);
        for f in dist.frequencies.iter().map(|&f| i64::from(f)) {
            assert!((f - expected).unsigned_abs() <= 1);
        }
    }

    #[test]
    #[ignore]
    fn test_distribution_from_counts() {
        let dist = distribution_from_counts(&[10u32, 20, 30, 0, 40], 10).expect("ok");
        assert_eq!(dist.num_symbols(), 4);
        assert_eq!(dist.symbols, vec![0, 1, 2, 4]);
    }

    #[test]
    #[ignore]
    fn test_distribution_cumulative() {
        let dist = AnsDistribution::new(vec![0, 1, 2], vec![256, 512, 256], 10).expect("ok");
        assert_eq!(dist.cumulative[0], 0);
        let last = *dist.cumulative.last().expect("has last");
        assert_eq!(last, dist.table_size());
    }

    #[test]
    #[ignore]
    fn test_distribution_lookup() {
        let dist = AnsDistribution::new(vec![0, 1], vec![512, 512], 10).expect("ok");

        let (first_idx, first_start, first_freq) = dist.lookup(0).expect("ok");
        assert_eq!(first_idx, 0);
        assert_eq!(first_start, 0);
        assert!(first_freq > 0);

        let (last_idx, _, _) = dist.lookup(dist.table_size() - 1).expect("ok");
        assert_eq!(last_idx, 1);
    }

    #[test]
    #[ignore]
    fn test_ans_roundtrip_single_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let encoded = encode_reversed(&dist, &[2]);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        let decoded = decoder.decode_symbol(&dist).expect("ok");
        assert_eq!(decoded, 2);
    }

    #[test]
    #[ignore]
    fn test_ans_roundtrip_sequence() {
        let dist = uniform_distribution(8).expect("ok");
        let message: Vec<u16> = vec![0, 3, 7, 1, 5, 2, 6, 4];
        let encoded = encode_reversed(&dist, &message);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &message {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "ANS roundtrip mismatch");
        }
    }

    #[test]
    #[ignore]
    fn test_ans_roundtrip_skewed_distribution() {
        let dist = AnsDistribution::new(vec![0, 1, 2, 3], vec![700, 200, 80, 20], 10).expect("ok");
        let message: Vec<u16> = vec![0, 0, 0, 1, 0, 2, 0, 0, 3, 0, 1];
        let encoded = encode_reversed(&dist, &message);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &message {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    #[ignore]
    fn test_ans_roundtrip_repeated_symbol() {
        let dist = uniform_distribution(4).expect("ok");
        let message: Vec<u16> = vec![1; 5];
        let encoded = encode_reversed(&dist, &message);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for &expected in &message {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    #[ignore]
    fn test_ans_roundtrip_long_sequence() {
        let dist = uniform_distribution(16).expect("ok");
        let message: Vec<u16> = (0..100u16).map(|i| i % 16).collect();
        let encoded = encode_reversed(&dist, &message);

        let mut decoder = AnsDecoder::new(&encoded).expect("ok");
        for (i, &expected) in message.iter().enumerate() {
            let decoded = decoder.decode_symbol(&dist).expect("ok");
            assert_eq!(decoded, expected, "Mismatch at position {i}");
        }
    }

    #[test]
    #[ignore]
    fn test_empty_distribution_error() {
        let result = AnsDistribution::new(Vec::new(), Vec::new(), 10);
        assert!(result.is_err());
    }

    #[test]
    #[ignore]
    fn test_zero_symbol_uniform_error() {
        let result = uniform_distribution(0);
        assert!(result.is_err());
    }
}