1#![forbid(unsafe_code)]
22#![allow(dead_code)]
23#![allow(clippy::cast_possible_truncation)]
24#![allow(clippy::cast_precision_loss)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::similar_names)]
27#![allow(clippy::too_many_arguments)]
28
29use super::entropy_tables::{CDF_PROB_BITS, CDF_PROB_TOP};
30
/// Probability shift amount. Not referenced by the code visible in this module.
const EC_PROB_SHIFT: u32 = 6;

/// Interval width a fresh (or reset) `ArithmeticEncoder` starts with.
const EC_WINDOW_SIZE: u32 = 0x1_0000;

/// Floor applied to the coding range after each symbol; also the threshold
/// below which `ArithmeticEncoder::renormalize` shifts bytes out.
const EC_MIN_RANGE: u32 = 0x80;

/// Largest symbol of the fallback alphabet `SymbolEncoder::encode` uses for an
/// unknown context id (the fallback table has this value plus one entries).
const MAX_SYMBOL_VALUE: u16 = 0x0F;

/// Right shift applied to the distance to `CDF_PROB_TOP` on every
/// `CdfContext::update`; a larger value means slower adaptation.
const CDF_UPDATE_RATE: u16 = 5;
49
50#[derive(Clone, Debug)]
56pub struct ArithmeticEncoder {
57 range: u32,
59 low: u32,
61 cnt: i32,
63 buffer: Vec<u8>,
65}
66
67impl Default for ArithmeticEncoder {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl ArithmeticEncoder {
74 #[must_use]
76 pub fn new() -> Self {
77 Self {
78 range: EC_WINDOW_SIZE,
79 low: 0,
80 cnt: -9,
81 buffer: Vec::with_capacity(4096),
82 }
83 }
84
85 pub fn encode_symbol(&mut self, symbol: u16, cdf: &[u16]) {
92 assert!(symbol < cdf.len() as u16 - 1, "Symbol out of range");
93
94 let fl = u32::from(if symbol == 0 {
95 0
96 } else {
97 cdf[symbol as usize - 1]
98 });
99 let fh = u32::from(cdf[symbol as usize]);
100 let _ft = u32::from(cdf[cdf.len() - 1]);
101
102 let u = self.range;
104 let v = ((u >> 8) * (fh - fl)) >> (CDF_PROB_BITS - 8);
105 let r_new = if v < EC_MIN_RANGE { EC_MIN_RANGE } else { v };
106
107 self.low += ((u >> 8) * fl) >> (CDF_PROB_BITS - 8);
109 self.range = r_new;
110
111 self.renormalize();
113 }
114
115 pub fn encode_bool(&mut self, symbol: bool, prob: u16) {
117 let cdf = [CDF_PROB_TOP - prob, CDF_PROB_TOP, CDF_PROB_TOP];
120 let symbol_val = u16::from(symbol);
121 self.encode_symbol(symbol_val, &cdf);
122 }
123
124 pub fn encode_literal(&mut self, value: u32, num_bits: u8) {
126 for i in (0..num_bits).rev() {
127 let bit = (value >> i) & 1;
128 self.encode_bool(bit != 0, CDF_PROB_TOP / 2);
129 }
130 }
131
132 fn renormalize(&mut self) {
134 while self.range < EC_MIN_RANGE {
135 let c = (self.low >> 23) as u8;
136 self.buffer.push(c);
137
138 self.low = (self.low << 8) & 0x7F_FF_FF;
139 self.range <<= 8;
140 self.cnt += 8;
141 }
142 }
143
144 pub fn flush(&mut self) -> Vec<u8> {
146 while self.cnt >= 0 {
148 let c = (self.low >> 23) as u8;
149 self.buffer.push(c);
150 self.low = (self.low << 8) & 0x7F_FF_FF;
151 self.cnt -= 8;
152 }
153
154 let c = (self.low >> 23) as u8;
156 self.buffer.push(c);
157
158 while self.buffer.len() % 4 != 0 {
160 self.buffer.push(0);
161 }
162
163 std::mem::take(&mut self.buffer)
164 }
165
166 #[must_use]
168 pub fn buffer(&self) -> &[u8] {
169 &self.buffer
170 }
171
172 pub fn reset(&mut self) {
174 self.range = EC_WINDOW_SIZE;
175 self.low = 0;
176 self.cnt = -9;
177 self.buffer.clear();
178 }
179}
180
181#[derive(Clone, Debug)]
187pub struct CdfContext {
188 cdf: Vec<u16>,
190 nsymb: usize,
192}
193
194impl CdfContext {
195 #[must_use]
197 pub fn new(nsymb: usize) -> Self {
198 let mut cdf = Vec::with_capacity(nsymb + 1);
199 let step = CDF_PROB_TOP / nsymb as u16;
200
201 for i in 0..nsymb {
202 cdf.push(step * (i as u16 + 1));
203 }
204 cdf[nsymb - 1] = CDF_PROB_TOP;
205
206 Self { cdf, nsymb }
207 }
208
209 #[must_use]
211 pub fn cdf(&self) -> &[u16] {
212 &self.cdf
213 }
214
215 pub fn update(&mut self, symbol: u16) {
217 if symbol >= self.nsymb as u16 {
218 return;
219 }
220
221 for i in symbol as usize..self.nsymb {
223 let delta = CDF_PROB_TOP.saturating_sub(self.cdf[i]) >> CDF_UPDATE_RATE;
224 self.cdf[i] = self.cdf[i].saturating_add(delta);
225 }
226
227 self.cdf[self.nsymb - 1] = CDF_PROB_TOP;
229 }
230
231 pub fn reset(&mut self) {
233 let step = CDF_PROB_TOP / self.nsymb as u16;
234 for i in 0..self.nsymb {
235 self.cdf[i] = step * (i as u16 + 1);
236 }
237 self.cdf[self.nsymb - 1] = CDF_PROB_TOP;
238 }
239}
240
241#[derive(Clone, Debug)]
247pub struct SymbolEncoder {
248 encoder: ArithmeticEncoder,
250 contexts: Vec<CdfContext>,
252}
253
254impl Default for SymbolEncoder {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl SymbolEncoder {
261 #[must_use]
263 pub fn new() -> Self {
264 Self {
265 encoder: ArithmeticEncoder::new(),
266 contexts: Vec::new(),
267 }
268 }
269
270 pub fn init_contexts(&mut self, num_contexts: usize, nsymb: usize) {
272 self.contexts.clear();
273 for _ in 0..num_contexts {
274 self.contexts.push(CdfContext::new(nsymb));
275 }
276 }
277
278 pub fn encode(&mut self, symbol: u16, context_id: usize) {
280 if context_id >= self.contexts.len() {
281 let cdf = CdfContext::new(MAX_SYMBOL_VALUE as usize + 1);
283 self.encoder.encode_symbol(symbol, cdf.cdf());
284 return;
285 }
286
287 let cdf = self.contexts[context_id].cdf().to_vec();
288 self.encoder.encode_symbol(symbol, &cdf);
289
290 self.contexts[context_id].update(symbol);
292 }
293
294 pub fn encode_bool(&mut self, value: bool) {
296 self.encoder.encode_bool(value, CDF_PROB_TOP / 2);
297 }
298
299 pub fn encode_literal(&mut self, value: u32, num_bits: u8) {
301 self.encoder.encode_literal(value, num_bits);
302 }
303
304 pub fn finish(&mut self) -> Vec<u8> {
306 self.encoder.flush()
307 }
308
309 #[must_use]
311 pub fn buffer(&self) -> &[u8] {
312 self.encoder.buffer()
313 }
314
315 pub fn reset(&mut self) {
317 self.encoder.reset();
318 for ctx in &mut self.contexts {
319 ctx.reset();
320 }
321 }
322}
323
/// Growable byte buffer that accepts individual bits, most significant bit
/// first.
#[derive(Clone, Debug, Default)]
pub struct BitstreamWriter {
    /// Completed bytes.
    buffer: Vec<u8>,
    /// Byte under construction, filled from the most significant bit down.
    current_byte: u8,
    /// Number of bits already placed in `current_byte` (0..=7).
    bit_pos: u8,
}

impl BitstreamWriter {
    /// Creates an empty writer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Moves the byte under construction into the buffer and starts a new one.
    fn push_current(&mut self) {
        self.buffer.push(self.current_byte);
        self.current_byte = 0;
        self.bit_pos = 0;
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        if bit {
            self.current_byte |= 0x80 >> self.bit_pos;
        }

        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.push_current();
        }
    }

    /// Appends the low `num_bits` bits of `value`, most significant first.
    ///
    /// When `num_bits` exceeds 32 the positions that do not exist in `value`
    /// are written as zero bits rather than overflowing the shift.
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        for pos in (0..u32::from(num_bits)).rev() {
            let bit = value.checked_shr(pos).map_or(false, |v| (v & 1) != 0);
            self.write_bit(bit);
        }
    }

    /// Pads to the next byte boundary, then appends `byte`.
    pub fn write_byte(&mut self, byte: u8) {
        self.align();
        self.buffer.push(byte);
    }

    /// Completes a partially written byte with zero bits; does nothing when
    /// the writer is already on a byte boundary.
    pub fn align(&mut self) {
        if self.bit_pos != 0 {
            self.push_current();
        }
    }

    /// Pads to the next byte boundary, then appends `bytes`.
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        self.align();
        self.buffer.extend_from_slice(bytes);
    }

    /// Completed bytes; a partially written byte is not included.
    #[must_use]
    pub fn buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Pads any partial byte and returns the written bytes.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        self.align();
        self.buffer
    }

    /// Number of bytes `finish` would return (a partial byte counts as one).
    #[must_use]
    pub fn len(&self) -> usize {
        self.buffer.len() + usize::from(self.bit_pos > 0)
    }

    /// `true` when neither a complete byte nor a single bit has been written.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty() && self.bit_pos == 0
    }

    /// Discards everything written so far, keeping the buffer's allocation.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.current_byte = 0;
        self.bit_pos = 0;
    }
}
425
426#[derive(Clone, Debug)]
432pub struct ObuWriter {
433 writer: BitstreamWriter,
435}
436
437impl Default for ObuWriter {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
443impl ObuWriter {
444 #[must_use]
446 pub fn new() -> Self {
447 Self {
448 writer: BitstreamWriter::new(),
449 }
450 }
451
452 pub fn write_obu_header(&mut self, obu_type: u8, has_size: bool) {
454 self.writer.write_bit(false);
456
457 self.writer.write_bits(u32::from(obu_type), 4);
459
460 self.writer.write_bit(false);
462
463 self.writer.write_bit(has_size);
465
466 self.writer.write_bit(false);
468 }
469
470 pub fn write_leb128(&mut self, mut value: u64) {
472 loop {
473 let mut byte = (value & 0x7F) as u8;
474 value >>= 7;
475
476 if value != 0 {
477 byte |= 0x80;
478 }
479
480 self.writer.write_byte(byte);
481
482 if value == 0 {
483 break;
484 }
485 }
486 }
487
488 pub fn write_obu(&mut self, obu_type: u8, payload: &[u8]) {
490 self.write_obu_header(obu_type, true);
491 self.write_leb128(payload.len() as u64);
492 self.writer.write_bytes(payload);
493 }
494
495 #[must_use]
497 pub fn buffer(&self) -> &[u8] {
498 self.writer.buffer()
499 }
500
501 #[must_use]
503 pub fn finish(self) -> Vec<u8> {
504 self.writer.finish()
505 }
506}
507
508#[must_use]
514pub fn pmf_to_cdf(pmf: &[u16]) -> Vec<u16> {
515 let mut cdf = Vec::with_capacity(pmf.len());
516 let mut cumsum = 0u16;
517
518 for &p in pmf {
519 cumsum = cumsum.saturating_add(p);
520 cdf.push(cumsum);
521 }
522
523 if let Some(&last) = cdf.last() {
525 if last > 0 && last != CDF_PROB_TOP {
526 for val in &mut cdf {
527 *val = (u32::from(*val) * u32::from(CDF_PROB_TOP) / u32::from(last)) as u16;
528 }
529 }
530 }
531
532 cdf
533}
534
535#[must_use]
537pub fn estimate_symbol_rate(symbol: u16, cdf: &[u16]) -> f32 {
538 if symbol >= cdf.len() as u16 {
539 return 8.0; }
541
542 let fl = if symbol == 0 {
543 0
544 } else {
545 cdf[symbol as usize - 1]
546 };
547 let fh = cdf[symbol as usize];
548 let prob = fh.saturating_sub(fl);
549
550 if prob == 0 {
551 16.0 } else {
553 -(f32::from(prob) / f32::from(CDF_PROB_TOP)).log2()
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Probability used for equiprobable booleans throughout these tests.
    const HALF: u16 = CDF_PROB_TOP / 2;

    // ---- ArithmeticEncoder ------------------------------------------------

    #[test]
    fn test_arithmetic_encoder_creation() {
        let fresh = ArithmeticEncoder::new();
        assert_eq!(fresh.range, EC_WINDOW_SIZE);
        assert_eq!(fresh.low, 0);
        assert!(fresh.buffer.is_empty());
    }

    #[test]
    fn test_arithmetic_encoder_bool() {
        let mut enc = ArithmeticEncoder::new();
        for &flag in &[true, false] {
            enc.encode_bool(flag, HALF);
        }

        assert!(!enc.flush().is_empty());
    }

    #[test]
    fn test_arithmetic_encoder_literal() {
        let mut enc = ArithmeticEncoder::new();
        enc.encode_literal(0xFF, 8);

        assert!(!enc.flush().is_empty());
    }

    // ---- CdfContext -------------------------------------------------------

    #[test]
    fn test_cdf_context_creation() {
        let ctx = CdfContext::new(4);
        assert_eq!(ctx.nsymb, 4);
        assert_eq!(ctx.cdf().len(), 4);
        assert_eq!(ctx.cdf().last().copied(), Some(CDF_PROB_TOP));
    }

    #[test]
    fn test_cdf_context_update() {
        let mut ctx = CdfContext::new(4);
        let before = ctx.cdf().to_vec();

        ctx.update(1);

        // Adaptation must change the table but keep the total pinned.
        assert_ne!(ctx.cdf(), &before[..]);
        assert_eq!(ctx.cdf().last().copied(), Some(CDF_PROB_TOP));
    }

    #[test]
    fn test_cdf_context_reset() {
        let mut ctx = CdfContext::new(4);
        let pristine = ctx.cdf().to_vec();

        for &sym in &[1u16, 2] {
            ctx.update(sym);
        }
        ctx.reset();

        assert_eq!(ctx.cdf(), &pristine[..]);
    }

    // ---- SymbolEncoder ----------------------------------------------------

    #[test]
    fn test_symbol_encoder() {
        let mut enc = SymbolEncoder::new();
        enc.init_contexts(4, 8);

        for &(symbol, context) in &[(0u16, 0usize), (1, 0), (2, 1)] {
            enc.encode(symbol, context);
        }

        assert!(!enc.finish().is_empty());
    }

    #[test]
    fn test_symbol_encoder_bool() {
        let mut enc = SymbolEncoder::new();
        for &flag in &[true, false, true] {
            enc.encode_bool(flag);
        }

        assert!(!enc.finish().is_empty());
    }

    // ---- BitstreamWriter --------------------------------------------------

    #[test]
    fn test_bitstream_writer_bit() {
        let mut bits = BitstreamWriter::new();
        for &bit in &[true, false, true, true, false, false, false, true] {
            bits.write_bit(bit);
        }

        let bytes = bits.finish();
        assert_eq!(bytes.len(), 1);
        assert_eq!(bytes[0], 0b1011_0001);
    }

    #[test]
    fn test_bitstream_writer_bits() {
        let mut bits = BitstreamWriter::new();
        bits.write_bits(0xFF, 8);

        let bytes = bits.finish();
        assert_eq!(bytes.len(), 1);
        assert_eq!(bytes[0], 0xFF);
    }

    #[test]
    fn test_bitstream_writer_align() {
        let mut bits = BitstreamWriter::new();
        bits.write_bit(true);
        bits.write_bit(false);
        bits.align();

        assert_eq!(bits.finish().len(), 1);
    }

    #[test]
    fn test_bitstream_writer_bytes() {
        let mut bits = BitstreamWriter::new();
        bits.write_bytes(&[0xAB, 0xCD, 0xEF]);

        assert_eq!(bits.finish(), &[0xAB, 0xCD, 0xEF]);
    }

    // ---- ObuWriter --------------------------------------------------------

    #[test]
    fn test_obu_writer_header() {
        let mut obu = ObuWriter::new();
        obu.write_obu_header(1, true);

        assert!(!obu.buffer().is_empty());
    }

    #[test]
    fn test_obu_writer_leb128() {
        // A value that fits in seven bits takes a single byte.
        let mut short = ObuWriter::new();
        short.write_leb128(127);
        assert_eq!(short.buffer().len(), 1);
        assert_eq!(short.buffer()[0], 127);

        // The next value up needs a continuation byte.
        let mut long = ObuWriter::new();
        long.write_leb128(128);
        assert_eq!(long.buffer().len(), 2);
    }

    #[test]
    fn test_obu_writer_complete() {
        let payload = [1u8, 2, 3, 4];
        let mut obu = ObuWriter::new();
        obu.write_obu(1, &payload);

        assert!(obu.finish().len() > payload.len());
    }

    // ---- Free functions ---------------------------------------------------

    #[test]
    fn test_pmf_to_cdf() {
        let table = pmf_to_cdf(&[100u16, 200, 300, 400]);

        assert_eq!(table.len(), 4);
        assert!(table.last().copied().expect("should have last element") > 0);
        // The cumulative table must never decrease.
        assert!(table.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_estimate_symbol_rate() {
        let table = [100u16, 300, 600, CDF_PROB_TOP];

        let first = estimate_symbol_rate(0, &table);
        let second = estimate_symbol_rate(1, &table);

        assert!(first > 0.0);
        assert!(second > 0.0);
        assert!(first < second * 2.0);
    }

    #[test]
    fn test_bitstream_writer_len() {
        let mut bits = BitstreamWriter::new();
        assert!(bits.is_empty());
        assert_eq!(bits.len(), 0);

        bits.write_byte(0xFF);
        assert!(!bits.is_empty());
        assert_eq!(bits.len(), 1);
    }

    // ---- Reset behaviour --------------------------------------------------

    #[test]
    fn test_symbol_encoder_reset() {
        let mut enc = SymbolEncoder::new();
        enc.init_contexts(2, 4);
        enc.encode(1, 0);

        enc.reset();
        assert!(enc.buffer().is_empty());
    }

    #[test]
    fn test_arithmetic_encoder_reset() {
        let mut enc = ArithmeticEncoder::new();
        enc.encode_bool(true, HALF);

        enc.reset();
        assert!(enc.buffer.is_empty());
        assert_eq!(enc.range, EC_WINDOW_SIZE);
    }
}