1#![forbid(unsafe_code)]
28#![allow(dead_code)]
29#![allow(clippy::doc_markdown)]
30#![allow(clippy::unused_self)]
31#![allow(clippy::missing_errors_doc)]
32#![allow(clippy::similar_names)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::too_many_arguments)]
36#![allow(clippy::struct_excessive_bools)]
37#![allow(clippy::module_name_repetitions)]
38
39use super::block::{BlockModeInfo, BlockSize, InterMode, IntraMode, PartitionType};
40use super::entropy::{uniform_cdf, SymbolReader};
41use super::entropy_tables::CdfContext;
42use super::transform::{TxSize, TxType};
43use crate::error::{CodecError, CodecResult};
44
/// Number of contexts for the partition symbol; `compute_partition_context` returns `0..=3`.
pub const PARTITION_CONTEXTS: usize = 4;

/// Number of contexts for the skip flag; `compute_skip_context` returns `0..=2`.
pub const SKIP_CONTEXTS: usize = 3;

/// Number of contexts reserved for intra-mode symbols.
pub const INTRA_MODE_CONTEXTS: usize = 5;

/// Number of contexts reserved for inter-mode symbols.
pub const INTER_MODE_CONTEXTS: usize = 7;

/// Number of contexts reserved for reference-frame symbols.
pub const REF_CONTEXTS: usize = 3;

/// Number of motion-vector contexts; `MvPredictor::compute_context` returns `0` or `1`.
pub const MV_CONTEXTS: usize = 2;

/// Number of transform-size contexts; `compute_tx_size_context` is capped at `3`.
pub const TX_SIZE_CONTEXTS: usize = 4;

/// Number of contexts reserved for transform-type symbols.
pub const TX_TYPE_CONTEXTS: usize = 4;

/// Number of selectable reference frames; matches the 7-symbol alphabet used when a
/// single reference frame is read.
pub const NUM_REF_FRAMES: usize = 7;

/// Largest magnitude of a single motion-vector component, `2^10 - 1`.
pub const MAX_MV_COMPONENT: i16 = (1 << 10) - 1;
78
79#[derive(Debug)]
85pub struct SymbolDecoder {
86 reader: SymbolReader,
88 cdf_context: CdfContext,
90 frame_is_intra: bool,
92 allow_intrabc: bool,
94 segment_id: u8,
96}
97
98impl SymbolDecoder {
99 pub fn new(data: Vec<u8>, frame_is_intra: bool) -> Self {
101 Self {
102 reader: SymbolReader::new(data),
103 cdf_context: CdfContext::new(),
104 frame_is_intra,
105 allow_intrabc: false,
106 segment_id: 0,
107 }
108 }
109
110 pub fn read_partition(&mut self, bsize: BlockSize, _ctx: u8) -> CodecResult<PartitionType> {
112 if bsize == BlockSize::Block4x4 {
113 return Ok(PartitionType::None);
115 }
116
117 let mut cdf = uniform_cdf(10);
118 let symbol = self.reader.read_symbol(&mut cdf);
119 PartitionType::from_u8(symbol as u8)
120 .ok_or_else(|| CodecError::InvalidBitstream("Invalid partition type".to_string()))
121 }
122
123 pub fn read_skip(&mut self, _ctx: u8) -> bool {
125 let mut cdf = [16384u16, 32768, 0];
127 self.reader.read_bool(&mut cdf)
128 }
129
130 pub fn read_skip_mode(&mut self, _ctx: u8) -> bool {
132 if self.frame_is_intra {
133 return false;
134 }
135
136 let mut cdf = [16384u16, 32768, 0];
138 self.reader.read_bool(&mut cdf)
139 }
140
141 pub fn read_segment_id(&mut self, _ctx: u8, max_segments: u8) -> u8 {
143 if max_segments == 1 {
144 return 0;
145 }
146
147 let mut cdf = uniform_cdf(max_segments as usize);
148 let segment_id = self.reader.read_symbol(&mut cdf) as u8;
149 self.segment_id = segment_id;
150 segment_id
151 }
152
153 pub fn read_is_inter(&mut self, _ctx: u8) -> bool {
155 if self.frame_is_intra {
156 return false;
157 }
158
159 let mut cdf = [16384u16, 32768, 0];
160 self.reader.read_bool(&mut cdf)
161 }
162
163 pub fn read_intra_mode(&mut self, _ctx: u8, _bsize: BlockSize) -> CodecResult<IntraMode> {
165 let mut cdf = uniform_cdf(13);
166 let symbol = self.reader.read_symbol(&mut cdf);
167 IntraMode::from_u8(symbol as u8)
168 .ok_or_else(|| CodecError::InvalidBitstream("Invalid intra mode".to_string()))
169 }
170
171 pub fn read_uv_mode(&mut self, _y_mode: IntraMode, _ctx: u8) -> CodecResult<IntraMode> {
173 let mut cdf = uniform_cdf(13);
174 let symbol = self.reader.read_symbol(&mut cdf);
175 IntraMode::from_u8(symbol as u8)
176 .ok_or_else(|| CodecError::InvalidBitstream("Invalid UV mode".to_string()))
177 }
178
179 pub fn read_angle_delta(&mut self, mode: IntraMode) -> i8 {
181 if !mode.is_directional() {
182 return 0;
183 }
184
185 let mut cdf = uniform_cdf(7);
186 let symbol = self.reader.read_symbol(&mut cdf);
187
188 (symbol as i8) - 3
190 }
191
192 pub fn read_use_palette(&mut self, bsize: BlockSize, _ctx: u8) -> bool {
194 if bsize == BlockSize::Block4x4
195 || bsize == BlockSize::Block4x8
196 || bsize == BlockSize::Block8x4
197 {
198 return false;
199 }
200
201 let mut cdf = [16384u16, 32768, 0];
202 self.reader.read_bool(&mut cdf)
203 }
204
205 pub fn read_filter_intra_mode(&mut self) -> u8 {
207 let mut cdf = uniform_cdf(5);
208 self.reader.read_symbol(&mut cdf) as u8
209 }
210
211 pub fn read_inter_mode(&mut self, _ctx: u8) -> CodecResult<InterMode> {
213 let mut cdf = uniform_cdf(4);
214 let symbol = self.reader.read_symbol(&mut cdf);
215 InterMode::from_u8(symbol as u8)
216 .ok_or_else(|| CodecError::InvalidBitstream("Invalid inter mode".to_string()))
217 }
218
219 pub fn read_ref_frames(&mut self, _ctx: u8) -> [i8; 2] {
221 if self.frame_is_intra {
222 return [-1, -1];
223 }
224
225 let mut compound_cdf = [16384u16, 32768, 0];
227 let is_compound = self.reader.read_bool(&mut compound_cdf);
228
229 if is_compound {
230 let ref0 = self.read_single_ref_frame(0);
232 let ref1 = self.read_single_ref_frame(1);
233 [ref0, ref1]
234 } else {
235 let ref0 = self.read_single_ref_frame(0);
237 [ref0, -1]
238 }
239 }
240
241 fn read_single_ref_frame(&mut self, _idx: usize) -> i8 {
243 let mut cdf = uniform_cdf(7);
244 self.reader.read_symbol(&mut cdf) as i8
245 }
246
247 pub fn read_mv(&mut self, ctx: u8) -> [i16; 2] {
249 let row = self.read_mv_component(ctx, true);
250 let col = self.read_mv_component(ctx, false);
251 [row, col]
252 }
253
254 fn read_mv_component(&mut self, _ctx: u8, _is_row: bool) -> i16 {
256 let mut sign_cdf = [16384u16, 32768, 0];
258 let sign = self.reader.read_bool(&mut sign_cdf);
259
260 let mut class_cdf = uniform_cdf(11);
262 let class = self.reader.read_symbol(&mut class_cdf) as u8;
263
264 let mag = self.read_mv_magnitude(class);
266
267 if sign {
268 -(mag as i16)
269 } else {
270 mag as i16
271 }
272 }
273
274 fn read_mv_magnitude(&mut self, class: u8) -> u16 {
276 match class {
277 0 => 0, 1 => 1, _ => {
280 let offset_bits = class - 2;
282 let mut mag = 1u16 << (offset_bits + 1);
283
284 for _ in 0..offset_bits {
285 let mut bit_cdf = [16384u16, 32768, 0];
286 let bit = self.reader.read_bool(&mut bit_cdf);
287 mag |= u16::from(bit);
288 mag <<= 1;
289 }
290
291 mag >> 1
292 }
293 }
294 }
295
296 pub fn read_tx_size(&mut self, bsize: BlockSize, _ctx: u8) -> TxSize {
298 let max_tx_size = bsize.max_tx_size();
299
300 let mut cdf = uniform_cdf(5);
302 let symbol = self.reader.read_symbol(&mut cdf);
303
304 self.map_tx_size_symbol(symbol, max_tx_size)
306 }
307
308 fn map_tx_size_symbol(&self, symbol: usize, max_tx_size: TxSize) -> TxSize {
310 match symbol {
311 0 => TxSize::Tx4x4,
312 1 => TxSize::Tx8x8.min(max_tx_size),
313 2 => TxSize::Tx16x16.min(max_tx_size),
314 3 => TxSize::Tx32x32.min(max_tx_size),
315 _ => max_tx_size,
316 }
317 }
318
319 pub fn read_tx_type(&mut self, _tx_size: TxSize, _is_inter: bool, _ctx: u8) -> TxType {
321 let mut cdf = uniform_cdf(16);
323 let symbol = self.reader.read_symbol(&mut cdf);
324 TxType::from_u8(symbol as u8).unwrap_or(TxType::DctDct)
325 }
326
327 pub fn read_compound_type(&mut self, _ctx: u8) -> u8 {
329 let mut cdf = uniform_cdf(3);
330 self.reader.read_symbol(&mut cdf) as u8
331 }
332
333 pub fn read_interp_filter(&mut self, _ctx: u8) -> u8 {
335 let mut cdf = uniform_cdf(4);
336 self.reader.read_symbol(&mut cdf) as u8
337 }
338
339 pub fn read_motion_mode(&mut self, bsize: BlockSize, _ctx: u8) -> u8 {
341 if bsize == BlockSize::Block4x4
342 || bsize == BlockSize::Block4x8
343 || bsize == BlockSize::Block8x4
344 {
345 return 0; }
347
348 let mut cdf = uniform_cdf(3);
349 self.reader.read_symbol(&mut cdf) as u8
350 }
351
352 pub fn decode_block_mode(
354 &mut self,
355 bsize: BlockSize,
356 ctx_skip: u8,
357 ctx_mode: u8,
358 ) -> CodecResult<BlockModeInfo> {
359 let mut mode_info = BlockModeInfo::new();
360 mode_info.block_size = bsize;
361
362 mode_info.skip = self.read_skip(ctx_skip);
364
365 mode_info.segment_id = self.segment_id;
367
368 mode_info.is_inter = self.read_is_inter(ctx_mode);
370
371 if mode_info.is_inter {
372 mode_info.inter_mode = self.read_inter_mode(ctx_mode)?;
374 mode_info.ref_frames = self.read_ref_frames(ctx_mode);
375
376 if mode_info.inter_mode.has_newmv() {
378 let mv = self.read_mv(ctx_mode);
379 mode_info.mv[0] = mv;
380 }
381
382 mode_info.interp_filter = [
384 self.read_interp_filter(ctx_mode),
385 self.read_interp_filter(ctx_mode),
386 ];
387
388 mode_info.motion_mode = self.read_motion_mode(bsize, ctx_mode);
390
391 if mode_info.is_compound() {
393 mode_info.compound_type = self.read_compound_type(ctx_mode);
394 }
395 } else {
396 mode_info.intra_mode = self.read_intra_mode(ctx_mode, bsize)?;
398 mode_info.uv_mode = self.read_uv_mode(mode_info.intra_mode, ctx_mode)?;
399 mode_info.angle_delta = [
400 self.read_angle_delta(mode_info.intra_mode),
401 self.read_angle_delta(mode_info.uv_mode),
402 ];
403
404 mode_info.use_palette = self.read_use_palette(bsize, ctx_mode);
406
407 if bsize.width() <= 32 && bsize.height() <= 32 {
409 mode_info.filter_intra_mode = self.read_filter_intra_mode();
410 }
411 }
412
413 mode_info.tx_size = self.read_tx_size(bsize, ctx_mode);
415 Ok(mode_info)
418 }
419
420 pub fn has_more_data(&self) -> bool {
422 self.reader.has_more_data()
423 }
424
425 pub fn position(&self) -> usize {
427 self.reader.position()
428 }
429
430 pub fn remaining(&self) -> usize {
432 self.reader.remaining()
433 }
434}
435
436impl TxSize {
441 #[must_use]
443 pub const fn min(self, other: Self) -> Self {
444 let self_area = self.area();
445 let other_area = other.area();
446 if self_area <= other_area {
447 self
448 } else {
449 other
450 }
451 }
452}
453
454#[must_use]
460pub fn compute_partition_context(above: u8, left: u8, bsize: BlockSize) -> u8 {
461 let bs = bsize.width_log2();
462 let above_split = above < bs;
463 let left_split = left < bs;
464
465 match (above_split, left_split) {
466 (false, false) => 0,
467 (true, false) => 1,
468 (false, true) => 2,
469 (true, true) => 3,
470 }
471}
472
/// Skip-flag context: the number of neighbours (above, left) that are skipped, `0..=2`.
#[must_use]
pub fn compute_skip_context(above_skip: bool, left_skip: bool) -> u8 {
    u8::from(above_skip) + u8::from(left_skip)
}
482
/// Is-inter context: how many of the two neighbours (above, left) are inter blocks, `0..=2`.
#[must_use]
pub fn compute_is_inter_context(above_inter: bool, left_inter: bool) -> u8 {
    [above_inter, left_inter]
        .iter()
        .filter(|&&is_inter| is_inter)
        .count() as u8
}
492
493#[must_use]
495pub fn compute_tx_size_context(above_tx: TxSize, left_tx: TxSize, max_tx: TxSize) -> u8 {
496 let above_cat = tx_size_category(above_tx, max_tx);
497 let left_cat = tx_size_category(left_tx, max_tx);
498
499 (above_cat + left_cat).min(3)
500}
501
502fn tx_size_category(tx: TxSize, max_tx: TxSize) -> u8 {
504 if tx == max_tx {
505 0
506 } else if tx.width() * 2 >= max_tx.width() && tx.height() * 2 >= max_tx.height() {
507 1
508 } else {
509 2
510 }
511}
512
/// Fixed-capacity list (up to three) of candidate motion vectors for a block.
///
/// Candidates keep insertion order; only the first `count` entries of `candidates` are
/// meaningful.
#[derive(Clone, Copy, Debug, Default)]
pub struct MvPredictor {
    /// Candidate motion vectors; entries at or beyond `count` are unused.
    pub candidates: [[i16; 2]; 3],
    /// Number of valid entries in `candidates`.
    pub count: usize,
}

impl MvPredictor {
    /// Component-sum magnitude below which a candidate counts as small motion.
    const SMALL_MOTION_THRESHOLD: u16 = 16;

    /// Creates an empty predictor (same as `Default`, but usable in `const` contexts).
    #[must_use]
    pub const fn new() -> Self {
        Self {
            candidates: [[0, 0]; 3],
            count: 0,
        }
    }

    /// Appends `mv` if there is a free slot; silently ignores it once the list is full.
    pub fn add_candidate(&mut self, mv: [i16; 2]) {
        if let Some(slot) = self.candidates.get_mut(self.count) {
            *slot = mv;
            self.count += 1;
        }
    }

    /// First candidate, or `[0, 0]` when there is none.
    #[must_use]
    pub fn nearest(&self) -> [i16; 2] {
        if self.count > 0 {
            self.candidates[0]
        } else {
            [0, 0]
        }
    }

    /// Second candidate, falling back to `nearest()` when fewer than two exist.
    #[must_use]
    pub fn near(&self) -> [i16; 2] {
        if self.count > 1 {
            self.candidates[1]
        } else {
            self.nearest()
        }
    }

    /// Motion context: 0 if the first two candidates (those that exist) are all small
    /// motion, 1 otherwise. An empty predictor yields 0.
    #[must_use]
    pub fn compute_context(&self) -> u8 {
        let considered = &self.candidates[..self.count.min(2)];
        let all_small = considered
            .iter()
            .all(|&mv| self.mv_magnitude(mv) < Self::SMALL_MOTION_THRESHOLD);
        u8::from(!all_small)
    }

    /// L1 magnitude `|mv[0]| + |mv[1]|`, saturating at `u16::MAX`.
    ///
    /// `unsigned_abs` keeps `i16::MIN` representable and the addition is done in `u16`,
    /// so extreme vectors neither panic in debug builds nor wrap in release builds.
    fn mv_magnitude(&self, mv: [i16; 2]) -> u16 {
        mv[0].unsigned_abs().saturating_add(mv[1].unsigned_abs())
    }
}
590
591#[derive(Debug)]
597pub struct SymbolEncoder {
598 writer: super::entropy::SymbolWriter,
600 cdf_context: CdfContext,
602}
603
604impl SymbolEncoder {
605 #[must_use]
607 pub fn new() -> Self {
608 Self {
609 writer: super::entropy::SymbolWriter::new(),
610 cdf_context: CdfContext::new(),
611 }
612 }
613
614 pub fn write_partition(&mut self, partition: PartitionType, _ctx: u8) {
616 let mut cdf = uniform_cdf(10);
617 self.writer.write_symbol(partition as usize, &mut cdf);
618 }
619
620 pub fn write_skip(&mut self, skip: bool, _ctx: u8) {
622 let mut cdf = [16384u16, 32768, 0];
624 self.writer.write_bool(skip, &mut cdf);
625 }
626
627 #[must_use]
629 pub fn finish(self) -> Vec<u8> {
630 self.writer.finish()
631 }
632}
633
634impl Default for SymbolEncoder {
635 fn default() -> Self {
636 Self::new()
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbol_decoder_creation() {
        let data = vec![0u8; 128];
        let decoder = SymbolDecoder::new(data, false);
        assert!(decoder.has_more_data());
    }

    #[test]
    fn test_partition_context() {
        // Block16x16 has width_log2() == 4, so neighbours below 4 count as "split".
        assert_eq!(compute_partition_context(0, 0, BlockSize::Block16x16), 3);
        assert_eq!(compute_partition_context(4, 4, BlockSize::Block16x16), 0);
        assert_eq!(compute_partition_context(4, 3, BlockSize::Block16x16), 2);
        assert_eq!(compute_partition_context(3, 4, BlockSize::Block16x16), 1);
    }

    #[test]
    fn test_skip_context() {
        assert_eq!(compute_skip_context(false, false), 0);
        assert_eq!(compute_skip_context(true, false), 1);
        assert_eq!(compute_skip_context(false, true), 1);
        assert_eq!(compute_skip_context(true, true), 2);
    }

    #[test]
    fn test_is_inter_context() {
        assert_eq!(compute_is_inter_context(false, false), 0);
        assert_eq!(compute_is_inter_context(true, false), 1);
        assert_eq!(compute_is_inter_context(false, true), 1);
        assert_eq!(compute_is_inter_context(true, true), 2);
    }

    #[test]
    fn test_tx_size_context() {
        let max_tx = TxSize::Tx16x16;
        // Each 8x8 neighbour is exactly half of 16x16 -> category 1 each.
        assert_eq!(compute_tx_size_context(TxSize::Tx8x8, TxSize::Tx8x8, max_tx), 2);
        assert_eq!(compute_tx_size_context(max_tx, max_tx, max_tx), 0);
        // Two category-2 neighbours sum to 4, which is capped at 3.
        assert_eq!(
            compute_tx_size_context(TxSize::Tx4x4, TxSize::Tx4x4, TxSize::Tx32x32),
            3
        );
    }

    #[test]
    fn test_mv_predictor() {
        let mut pred = MvPredictor::new();
        assert_eq!(pred.count, 0);
        assert_eq!(pred.nearest(), [0, 0]);

        pred.add_candidate([10, 20]);
        assert_eq!(pred.count, 1);
        assert_eq!(pred.nearest(), [10, 20]);
        // With a single candidate, `near` falls back to `nearest`.
        assert_eq!(pred.near(), [10, 20]);

        pred.add_candidate([5, 15]);
        assert_eq!(pred.count, 2);
        assert_eq!(pred.near(), [5, 15]);

        pred.add_candidate([1, 1]);
        pred.add_candidate([2, 2]);
        assert_eq!(pred.count, 3);
    }

    #[test]
    fn test_mv_predictor_context() {
        let mut pred = MvPredictor::new();
        assert_eq!(pred.compute_context(), 0);

        pred.add_candidate([5, 10]);
        pred.add_candidate([8, 12]);
        // |8| + |12| = 20 >= 16.
        assert_eq!(pred.compute_context(), 1);

        let mut small = MvPredictor::new();
        small.add_candidate([1, 2]);
        small.add_candidate([3, 4]);
        assert_eq!(small.compute_context(), 0);
    }

    #[test]
    fn test_tx_size_min() {
        assert_eq!(TxSize::Tx8x8.min(TxSize::Tx16x16), TxSize::Tx8x8);
        assert_eq!(TxSize::Tx16x16.min(TxSize::Tx8x8), TxSize::Tx8x8);
        assert_eq!(TxSize::Tx4x4.min(TxSize::Tx4x4), TxSize::Tx4x4);
    }

    #[test]
    fn test_symbol_encoder() {
        let mut encoder = SymbolEncoder::new();
        encoder.write_skip(true, 0);
        encoder.write_partition(PartitionType::None, 0);
        let output = encoder.finish();
        assert!(!output.is_empty());
    }

    #[test]
    fn test_mv_magnitude() {
        let pred = MvPredictor::new();
        assert_eq!(pred.mv_magnitude([0, 0]), 0);
        assert_eq!(pred.mv_magnitude([10, 20]), 30);
        assert_eq!(pred.mv_magnitude([-10, 20]), 30);
    }

    #[test]
    fn test_constants() {
        assert_eq!(PARTITION_CONTEXTS, 4);
        assert_eq!(SKIP_CONTEXTS, 3);
        assert_eq!(INTRA_MODE_CONTEXTS, 5);
        assert_eq!(INTER_MODE_CONTEXTS, 7);
        assert_eq!(REF_CONTEXTS, 3);
        assert_eq!(MV_CONTEXTS, 2);
        assert_eq!(TX_SIZE_CONTEXTS, 4);
        assert_eq!(TX_TYPE_CONTEXTS, 4);
        assert_eq!(NUM_REF_FRAMES, 7);
        assert_eq!(MAX_MV_COMPONENT, 1023);
    }

    #[test]
    fn test_symbol_decoder_position() {
        let data = vec![0u8; 128];
        let decoder = SymbolDecoder::new(data, false);
        assert_eq!(decoder.remaining(), 126);
    }

    #[test]
    fn test_intra_frame_reads_no_inter_syntax() {
        let mut decoder = SymbolDecoder::new(vec![0u8; 16], true);
        let before = decoder.position();
        assert!(!decoder.read_skip_mode(0));
        assert!(!decoder.read_is_inter(0));
        assert_eq!(decoder.read_ref_frames(0), [-1, -1]);
        assert_eq!(decoder.position(), before);
    }

    #[test]
    fn test_early_returns_read_nothing() {
        let mut decoder = SymbolDecoder::new(vec![0u8; 16], false);
        let before = decoder.position();
        assert_eq!(decoder.read_segment_id(0, 1), 0);
        assert!(matches!(
            decoder.read_partition(BlockSize::Block4x4, 0),
            Ok(PartitionType::None)
        ));
        assert!(!decoder.read_use_palette(BlockSize::Block4x4, 0));
        assert_eq!(decoder.read_motion_mode(BlockSize::Block4x4, 0), 0);
        assert_eq!(decoder.position(), before);
    }

    #[test]
    fn test_tx_size_category() {
        let max_tx = TxSize::Tx32x32;
        assert_eq!(tx_size_category(TxSize::Tx32x32, max_tx), 0);
        assert_eq!(tx_size_category(TxSize::Tx16x16, max_tx), 1);
        assert_eq!(tx_size_category(TxSize::Tx4x4, max_tx), 2);
    }
}