1#![allow(dead_code)]
28#![allow(clippy::cast_possible_truncation)]
29#![allow(clippy::cast_possible_wrap)]
30#![allow(clippy::manual_div_ceil)]
31#![allow(clippy::needless_range_loop)]
32
33#[allow(unused_imports)]
34use super::entropy_tables::{CDF_PROB_BITS, CDF_PROB_TOP};
35
/// Width, in bits, of the arithmetic coder's range register.
pub const RANGE_BITS: u8 = 16;

/// `1 << (RANGE_BITS - 1)`, i.e. `0x8000`.
///
/// This equals the literal `0x8000` that the coder constructors in this file
/// use as their initial range and that the encoder renormalizes against.
pub const RANGE_MIN: u32 = 1 << (RANGE_BITS - 1);

/// `1 << RANGE_BITS`, i.e. `0x1_0000`.
///
/// NOTE(review): not referenced by the code visible in this file; the coders
/// start from `0x8000` instead. Confirm the intended initial range.
pub const RANGE_INIT: u32 = 1 << RANGE_BITS;

/// Nominal width, in bits, of the decoder's value register.
///
/// NOTE(review): the decoder's `init` loads 15 bits, not 16 - confirm.
pub const VALUE_BITS: u8 = 16;

/// Nominal bit-window size.
///
/// NOTE(review): not referenced by the code visible in this file;
/// `SymbolReader` keeps its look-ahead window in a `u64`.
pub const WINDOW_SIZE: u8 = 32;
54
/// Arithmetic decoder state: the input bytes plus the `range`/`value`
/// registers used to pick symbols from adaptive CDF tables.
///
/// A CDF table is a slice whose last element is an adaptation counter and
/// whose preceding elements are cumulative frequencies.
///
/// NOTE(review): [`ArithmeticDecoder::decode_symbol`] selects a symbol and
/// adapts the table, but nothing in this type narrows `range` or consumes
/// `value` afterwards, so consecutive calls see the same registers. Confirm
/// whether the renormalization step is still to be written.
#[derive(Clone, Debug)]
pub struct ArithmeticDecoder {
    /// Current coding range; starts at `0x8000`.
    range: u32,
    /// Code value register, filled MSB-first by [`ArithmeticDecoder::init`].
    value: u32,
    /// Unread bits left in the most recently fetched byte of `data`.
    bits_remaining: u32,
    /// Complete compressed input.
    data: Vec<u8>,
    /// Number of bytes of `data` fetched so far.
    position: usize,
}

impl ArithmeticDecoder {
    /// Creates a decoder over `data` without reading anything yet.
    ///
    /// Call [`ArithmeticDecoder::init`] before decoding symbols.
    #[must_use]
    pub fn new(data: Vec<u8>) -> Self {
        Self {
            range: 0x8000,
            value: 0,
            bits_remaining: 0,
            data,
            position: 0,
        }
    }

    /// Shifts the next 15 input bits, MSB first, into the value register.
    ///
    /// Bits past the end of the input are read as zero.
    pub fn init(&mut self) {
        for _ in 0..15 {
            let bit = u32::from(self.read_bit());
            self.value = (self.value << 1) | bit;
        }
    }

    /// Returns the next input bit (MSB first), or `0` once the input is
    /// exhausted.
    ///
    /// The bit is taken straight from `data`. The value register is never used
    /// as scratch space, so it cannot be clobbered while it is being filled.
    fn read_bit(&mut self) -> u8 {
        if self.bits_remaining == 0 {
            if self.position >= self.data.len() {
                return 0;
            }
            self.position += 1;
            self.bits_remaining = 8;
        }
        self.bits_remaining -= 1;
        // `bits_remaining` only becomes non-zero after `position` has been
        // advanced, so `position - 1` is the byte currently being consumed.
        (self.data[self.position - 1] >> self.bits_remaining) & 1
    }

    /// Selects a symbol for the current `range`/`value` and adapts `cdf`.
    ///
    /// `cdf` holds cumulative frequencies followed by one adaptation counter.
    /// The symbol is found by binary search over scaled thresholds. Entries
    /// below the symbol then decay towards zero and the remaining entries
    /// move towards `0x7FFF`, at a rate that slows as the counter grows. The
    /// counter itself stops at 32.
    ///
    /// Returns the selected index, which may equal `cdf.len() - 1` when the
    /// value is not below any threshold. An empty `cdf` yields `0` and is left
    /// untouched.
    #[allow(clippy::cast_possible_truncation)]
    pub fn decode_symbol(&mut self, cdf: &mut [u16]) -> usize {
        let Some(counter_idx) = cdf.len().checked_sub(1) else {
            return 0;
        };

        let range = self.range;
        let value = self.value;

        let mut low = 0;
        let mut high = counter_idx;
        while low < high {
            let mid = (low + high) >> 1;
            let scaled = ((range >> 8) * u32::from(cdf[mid] >> 6)) >> 7;
            let threshold = scaled + 4 * (mid as u32 + 1);

            if value < threshold {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        let symbol = low;

        let count = u32::from(cdf[counter_idx]);
        // Clamp to 15 so the shift stays within the width of a `u16`.
        let rate = (4 + (count >> 4)).min(15);

        for (i, entry) in cdf[..counter_idx].iter_mut().enumerate() {
            *entry = if i < symbol {
                entry.saturating_sub(*entry >> rate)
            } else {
                entry.saturating_add(0x7FFF_u16.saturating_sub(*entry) >> rate)
            };
        }

        if count < 32 {
            cdf[counter_idx] += 1;
        }

        symbol
    }
}
158
/// Arithmetic encoder state paired with adaptive CDF tables.
///
/// A CDF table is a slice whose last element is an adaptation counter and
/// whose preceding elements are cumulative frequencies on a 15-bit scale.
///
/// NOTE(review): `output_bit` appends one whole byte per renormalization step,
/// and `carry_count` is only ever reset to zero inside this type, so the
/// carry loop never runs. Confirm the intended output format before relying
/// on the produced bytes.
#[derive(Clone, Debug)]
pub struct ArithmeticEncoder {
    /// Low end of the current coding interval.
    low: u64,
    /// Width of the current coding interval; kept at or above `0x8000`.
    range: u32,
    /// Bytes emitted so far.
    output: Vec<u8>,
    /// Pending carry bytes to flush on the next emitted byte.
    carry_count: u32,
    /// `true` until the first non-zero byte has been emitted.
    first_byte: bool,
}

impl ArithmeticEncoder {
    /// Creates an encoder with an empty output buffer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 0x8000,
            output: Vec::new(),
            carry_count: 0,
            first_byte: true,
        }
    }

    /// Encodes `symbol` against `cdf`, then adapts `cdf`.
    ///
    /// The interval for `symbol` runs from `cdf[symbol - 1]` (or `0` for
    /// symbol zero) to `cdf[symbol]`. Adaptation decays entries below the
    /// symbol and moves the remaining entries towards `0x7FFF`. The trailing
    /// counter stops at 32.
    ///
    /// An empty or inverted interval (a zero-probability symbol or a
    /// non-monotonic table) is widened to a single unit so that
    /// renormalization always terminates.
    ///
    /// # Panics
    ///
    /// Panics if `symbol >= cdf.len()`.
    #[allow(clippy::similar_names)]
    pub fn encode_symbol(&mut self, symbol: usize, cdf: &mut [u16]) {
        let range = self.range;

        let fl = if symbol > 0 { cdf[symbol - 1] } else { 0 };
        let fh = cdf[symbol];
        let range_fl = (range * u32::from(fl)) >> 15;
        let range_fh = (range * u32::from(fh)) >> 15;

        // `low` is shifted left on every renormalization step and is never
        // masked, so its top bits may already be set; wrap rather than panic.
        self.low = self.low.wrapping_add(u64::from(range_fl));
        // A zero range could never be doubled back above 0x8000, which would
        // make `renormalize` spin forever while growing `output`.
        self.range = range_fh.saturating_sub(range_fl).max(1);

        self.renormalize();

        let counter_idx = cdf.len() - 1;
        let count = u32::from(cdf[counter_idx]);
        let rate = (4 + (count >> 4)).min(15);

        for (i, entry) in cdf[..counter_idx].iter_mut().enumerate() {
            *entry = if i < symbol {
                entry.saturating_sub(*entry >> rate)
            } else {
                entry.saturating_add(0x7FFF_u16.saturating_sub(*entry) >> rate)
            };
        }

        if count < 32 {
            cdf[counter_idx] += 1;
        }
    }

    /// Doubles `range` (and `low`) until `range` is at least `0x8000`,
    /// emitting output on every step.
    fn renormalize(&mut self) {
        while self.range < 0x8000 {
            self.output_bit();
            self.low <<= 1;
            self.range <<= 1;
        }
    }

    /// Emits the bits of `low` above bit 15, truncated to a byte.
    ///
    /// Leading zero bytes are suppressed until the first non-zero byte.
    #[allow(clippy::cast_possible_truncation)]
    fn output_bit(&mut self) {
        let bit = (self.low >> 15) as u8;
        if bit != 0 || !self.first_byte {
            self.output.push(bit);
            for _ in 0..self.carry_count {
                self.output.push(0xFF ^ bit);
            }
            self.carry_count = 0;
            self.first_byte = false;
        }
    }

    /// Renormalizes one last time and returns the emitted bytes.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        self.renormalize();
        self.output
    }
}

impl Default for ArithmeticEncoder {
    fn default() -> Self {
        Self::new()
    }
}
262
263#[derive(Clone, Debug)]
269pub struct SymbolReader {
270 decoder: ArithmeticDecoder,
272 bit_pos: u32,
274 window: u64,
276 window_bits: u8,
278}
279
280impl SymbolReader {
281 #[must_use]
283 pub fn new(data: Vec<u8>) -> Self {
284 let mut reader = Self {
285 decoder: ArithmeticDecoder::new(data),
286 bit_pos: 0,
287 window: 0,
288 window_bits: 0,
289 };
290 reader.decoder.init();
291 reader
292 }
293
294 pub fn read_symbol(&mut self, cdf: &mut [u16]) -> usize {
298 self.decoder.decode_symbol(cdf)
299 }
300
301 #[allow(clippy::cast_possible_truncation)]
303 pub fn read_symbol_no_update(&mut self, cdf: &[u16]) -> usize {
304 let range = self.decoder.range;
305 let value = self.decoder.value;
306
307 let mut low = 0;
309 let mut high = cdf.len() - 1;
310 let mut mid;
311 let mut threshold;
312
313 while low < high {
314 mid = (low + high) >> 1;
315 threshold = ((range >> 8) * u32::from(cdf[mid] >> 6)) >> 7;
316 threshold += 4 * (mid as u32 + 1);
317
318 if value < threshold {
319 high = mid;
320 } else {
321 low = mid + 1;
322 }
323 }
324
325 low
326 }
327
328 pub fn read_bool(&mut self, cdf: &mut [u16; 3]) -> bool {
330 self.read_symbol(cdf) == 1
331 }
332
333 #[allow(clippy::cast_possible_truncation)]
335 pub fn read_bool_eq(&mut self) -> bool {
336 let mut cdf = [16384u16, 32768, 0];
337 self.read_symbol(&mut cdf) == 1
338 }
339
340 #[allow(clippy::cast_possible_truncation)]
342 pub fn read_literal(&mut self, n: u8) -> u32 {
343 let mut value = 0u32;
344 for _ in 0..n {
345 value = (value << 1) | u32::from(self.read_bit());
346 }
347 value
348 }
349
350 fn read_bit(&mut self) -> u8 {
352 if self.window_bits == 0 {
353 self.refill_window();
354 }
355
356 self.window_bits -= 1;
357 ((self.window >> self.window_bits) & 1) as u8
358 }
359
360 fn refill_window(&mut self) {
362 while self.window_bits < 56 && self.bit_pos < self.decoder.data.len() as u32 * 8 {
363 let byte_idx = (self.bit_pos / 8) as usize;
364 if byte_idx < self.decoder.data.len() {
365 self.window = (self.window << 8) | u64::from(self.decoder.data[byte_idx]);
366 self.window_bits += 8;
367 }
368 self.bit_pos += 8;
369 }
370 }
371
372 #[allow(clippy::cast_possible_truncation)]
374 pub fn read_subexp(&mut self, k: u8, max_val: u32) -> u32 {
375 let mut b = 0u8;
376 let mk = max_val as i32;
377
378 loop {
379 let range = 1i32 << (b + k);
380 if mk <= range {
381 return self.read_literal(((mk + 1).ilog2() + 1) as u8);
382 }
383
384 let bit = self.read_bit();
385 if bit == 0 {
386 return self.read_literal(b + k);
387 }
388
389 b += 1;
390 if b >= 24 {
391 break;
392 }
393 }
394
395 0
396 }
397
398 #[allow(clippy::cast_possible_wrap)]
400 pub fn read_signed_subexp(&mut self, k: u8, max_val: u32) -> i32 {
401 let unsigned = self.read_subexp(k, 2 * max_val);
402 if unsigned == 0 {
403 0
404 } else if unsigned & 1 == 1 {
405 -((unsigned + 1) as i32 / 2)
406 } else {
407 (unsigned / 2) as i32
408 }
409 }
410
411 pub fn read_inv_recenter(&mut self, r: u32, max_val: u32) -> u32 {
413 let v = self.read_subexp(3, max_val);
414 if v == 0 {
415 r
416 } else if v <= 2 * r {
417 if v & 1 == 1 {
418 r + (v + 1) / 2
419 } else {
420 r - v / 2
421 }
422 } else {
423 v
424 }
425 }
426
427 #[allow(clippy::cast_possible_truncation)]
429 pub fn read_ns(&mut self, n: u32) -> u32 {
430 if n <= 1 {
431 return 0;
432 }
433
434 let w = n.ilog2() as u8;
435 let m = (1u32 << (w + 1)) - n;
436 let v = self.read_literal(w);
437
438 if v < m {
439 v
440 } else {
441 let extra = self.read_bit();
442 (v << 1) - m + u32::from(extra)
443 }
444 }
445
446 #[must_use]
448 pub fn has_more_data(&self) -> bool {
449 self.decoder.position < self.decoder.data.len()
450 }
451
452 #[must_use]
454 pub fn position(&self) -> usize {
455 self.decoder.position
456 }
457
458 #[must_use]
460 pub fn remaining(&self) -> usize {
461 self.decoder
462 .data
463 .len()
464 .saturating_sub(self.decoder.position)
465 }
466}
467
468#[derive(Clone, Debug)]
474pub struct SymbolWriter {
475 encoder: ArithmeticEncoder,
477 bit_buffer: u64,
479 bit_count: u8,
481}
482
483impl SymbolWriter {
484 #[must_use]
486 pub fn new() -> Self {
487 Self {
488 encoder: ArithmeticEncoder::new(),
489 bit_buffer: 0,
490 bit_count: 0,
491 }
492 }
493
494 pub fn write_symbol(&mut self, symbol: usize, cdf: &mut [u16]) {
498 self.encoder.encode_symbol(symbol, cdf);
499 }
500
501 pub fn write_bool(&mut self, value: bool, cdf: &mut [u16; 3]) {
503 self.write_symbol(usize::from(value), cdf);
504 }
505
506 #[allow(clippy::cast_possible_truncation)]
508 pub fn write_literal(&mut self, value: u32, n: u8) {
509 for i in (0..n).rev() {
510 let bit = ((value >> i) & 1) as u8;
511 self.write_bit(bit);
512 }
513 }
514
515 fn write_bit(&mut self, bit: u8) {
517 self.bit_buffer = (self.bit_buffer << 1) | u64::from(bit & 1);
518 self.bit_count += 1;
519
520 if self.bit_count >= 8 {
521 self.flush_bits();
522 }
523 }
524
525 #[allow(clippy::cast_possible_truncation)]
527 fn flush_bits(&mut self) {
528 while self.bit_count >= 8 {
529 let byte = (self.bit_buffer >> (self.bit_count - 8)) as u8;
530 self.encoder.output.push(byte);
531 self.bit_count -= 8;
532 }
533 }
534
535 #[allow(clippy::cast_possible_truncation)]
537 pub fn write_ns(&mut self, v: u32, n: u32) {
538 if n <= 1 {
539 return;
540 }
541
542 let w = n.ilog2() as u8;
543 let m = (1u32 << (w + 1)) - n;
544
545 if v < m {
546 self.write_literal(v, w);
547 } else {
548 let adjusted = v + m;
549 self.write_literal(adjusted >> 1, w);
550 self.write_bit((adjusted & 1) as u8);
551 }
552 }
553
554 #[must_use]
556 pub fn finish(mut self) -> Vec<u8> {
557 if self.bit_count > 0 {
559 let remaining = 8 - self.bit_count;
560 self.bit_buffer <<= remaining;
561 self.bit_count = 8;
562 self.flush_bits();
563 }
564
565 self.encoder.finish()
566 }
567}
568
569impl Default for SymbolWriter {
570 fn default() -> Self {
571 Self::new()
572 }
573}
574
575#[allow(clippy::cast_possible_truncation)]
581pub fn update_cdf(cdf: &mut [u16], symbol: usize) {
582 let n = cdf.len() - 1;
583 if n == 0 {
584 return;
585 }
586
587 let count = u32::from(cdf[n]);
588 let rate = 3 + (count >> 4);
589 let rate = rate.min(32);
590
591 for i in 0..n {
592 if i < symbol {
593 let diff = cdf[i] >> rate;
594 cdf[i] = cdf[i].saturating_sub(diff);
595 } else {
596 let diff = (CDF_PROB_TOP - cdf[i]) >> rate;
597 cdf[i] = cdf[i].saturating_add(diff);
598 }
599 }
600
601 if count < 32 {
602 cdf[n] += 1;
603 }
604}
605
606#[allow(clippy::cast_possible_truncation)]
608pub fn reset_cdf(cdf: &mut [u16]) {
609 let n = cdf.len() - 1;
610 if n == 0 {
611 return;
612 }
613
614 for i in 0..n {
615 cdf[i] = (((i + 1) * (CDF_PROB_TOP as usize)) / n) as u16;
616 }
617 cdf[n] = 0; }
619
/// Default table for a two-symbol (boolean) alphabet: two cumulative
/// frequencies followed by a zeroed adaptation counter.
///
/// NOTE(review): the second entry is `0x7FFF`, whereas `uniform_cdf(2)` and
/// the table built inside `SymbolReader::read_bool_eq` end at `0x8000`.
/// Confirm which top value is intended.
pub const DEFAULT_BOOL_CDF: [u16; 3] = [0x4000, 0x7FFF, 0];
626
/// Builds an evenly spaced table for `n` symbols.
///
/// The result has `n + 1` elements: entry `i` is `(i + 1) * 0x8000 / n`, and
/// the trailing element is the adaptation counter, initialised to zero.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn uniform_cdf(n: usize) -> Vec<u16> {
    (1..=n)
        .map(|step| ((step * 0x8000) / n) as u16)
        .chain(std::iter::once(0))
        .collect()
}
638
/// Returns the probability mass of `symbol` in `cdf`, on the table's own scale.
///
/// The last element of `cdf` is the adaptation counter and is never treated
/// as a frequency. The result is `0` when `symbol` is outside the table
/// (including empty and counter-only tables) and when the table is not
/// monotonic at `symbol`.
#[must_use]
pub fn cdf_to_prob(cdf: &[u16], symbol: usize) -> u16 {
    let symbol_count = cdf.len().saturating_sub(1);
    if symbol >= symbol_count {
        return 0;
    }

    let upper = cdf[symbol];
    let lower = if symbol == 0 { 0 } else { cdf[symbol - 1] };
    upper.saturating_sub(lower)
}
650
651#[must_use]
653pub fn cdf_entropy(cdf: &[u16]) -> f64 {
654 let n = cdf.len() - 1;
655 if n == 0 {
656 return 0.0;
657 }
658
659 let mut entropy = 0.0;
660 let scale = f64::from(CDF_PROB_TOP);
661
662 for i in 0..n {
663 let prob = cdf_to_prob(cdf, i);
664 if prob > 0 {
665 let p = f64::from(prob) / scale;
666 entropy -= p * p.log2();
667 }
668 }
669
670 entropy
671}
672
/// Unit tests for the coder types and CDF helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arithmetic_decoder_new() {
        let decoder = ArithmeticDecoder::new(vec![0x12, 0x34]);
        assert_eq!(decoder.position, 0);
    }

    #[test]
    fn test_arithmetic_encoder_new() {
        let encoder = ArithmeticEncoder::new();
        assert!(encoder.output.is_empty());
    }

    #[test]
    fn test_uniform_cdf() {
        let cdf = uniform_cdf(4);
        assert_eq!(cdf.len(), 5);
        assert_eq!(cdf[0], 0x2000);
        assert_eq!(cdf[1], 0x4000);
        assert_eq!(cdf[2], 0x6000);
        assert_eq!(cdf[3], 0x8000);
        assert_eq!(cdf[4], 0);
    }

    #[test]
    fn test_symbol_reader_new() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        assert!(reader.has_more_data());
    }

    #[test]
    fn test_symbol_writer_new() {
        let writer = SymbolWriter::new();
        let output = writer.finish();
        // Nothing was written, so nothing may be emitted.
        assert!(output.is_empty());
    }

    #[test]
    fn test_update_cdf() {
        let mut cdf = uniform_cdf(4);
        let orig_0 = cdf[0];

        update_cdf(&mut cdf, 0);

        // Coding symbol 0 must strictly raise its cumulative frequency and
        // advance the adaptation counter.
        assert!(cdf[0] > orig_0);
        assert_eq!(cdf[4], 1);
    }

    #[test]
    fn test_reset_cdf() {
        let mut cdf = vec![100u16, 200, 300, 32768, 10];

        reset_cdf(&mut cdf);

        assert_eq!(cdf[0], 8192);
        assert_eq!(cdf[3], 32768);
        assert_eq!(cdf[4], 0);
    }

    #[test]
    fn test_cdf_to_prob() {
        let cdf = uniform_cdf(4);

        let prob0 = cdf_to_prob(&cdf, 0);
        let prob1 = cdf_to_prob(&cdf, 1);

        assert_eq!(prob0, 0x2000);
        assert_eq!(prob1, 0x2000);
    }

    #[test]
    fn test_cdf_entropy() {
        let cdf = uniform_cdf(4);
        let entropy = cdf_entropy(&cdf);

        assert!((entropy - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_symbol_reader_read_literal() {
        let mut reader = SymbolReader::new(vec![0xFF, 0x00, 0xFF, 0x00]);

        // Literal reads start at the first input byte.
        let val = reader.read_literal(8);
        assert_eq!(val, 0xFF);
    }

    #[test]
    fn test_symbol_reader_remaining() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        // Construction loads 15 bits, i.e. touches the first two bytes.
        assert_eq!(reader.remaining(), 2);
    }

    #[test]
    fn test_symbol_reader_position() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        assert_eq!(reader.position(), 2);
    }

    #[test]
    fn test_default_bool_cdf() {
        assert_eq!(DEFAULT_BOOL_CDF[0], 0x4000);
        assert_eq!(DEFAULT_BOOL_CDF[1], 0x7FFF);
        assert_eq!(DEFAULT_BOOL_CDF[2], 0);
    }

    #[test]
    fn test_constants() {
        assert_eq!(RANGE_BITS, 16);
        assert_eq!(RANGE_MIN, 0x8000);
        assert_eq!(VALUE_BITS, 16);
    }

    #[test]
    fn test_symbol_writer_write_literal() {
        let mut writer = SymbolWriter::new();
        writer.write_literal(0xAB, 8);
        let output = writer.finish();

        assert_eq!(output, vec![0xAB]);
    }

    #[test]
    fn test_symbol_reader_read_ns() {
        let mut reader = SymbolReader::new(vec![0x00, 0x00, 0x00, 0x00]);

        let val = reader.read_ns(1);
        assert_eq!(val, 0);
    }

    #[test]
    fn test_symbol_writer_write_ns() {
        let mut writer = SymbolWriter::new();
        writer.write_ns(5, 10);
        let output = writer.finish();

        // 5 is below the short-form threshold (6), so it is written as the
        // 3-bit literal `101`, zero-padded to a full byte.
        assert_eq!(output, vec![0xA0]);
    }

    #[test]
    fn test_ns_round_trip() {
        // 7 takes the long form: literal `110` followed by the extra bit `1`.
        let mut writer = SymbolWriter::new();
        writer.write_ns(7, 10);
        let bytes = writer.finish();
        assert_eq!(bytes, vec![0xD0]);

        let mut reader = SymbolReader::new(bytes);
        assert_eq!(reader.read_ns(10), 7);
    }

    #[test]
    fn test_literal_round_trip() {
        let mut writer = SymbolWriter::new();
        writer.write_literal(0xAB, 8);
        let bytes = writer.finish();

        let mut reader = SymbolReader::new(bytes);
        assert_eq!(reader.read_literal(8), 0xAB);
    }
}