1#![allow(dead_code)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::bool_to_int_with_if)]
36#![allow(clippy::needless_bool_assign)]
37#![allow(clippy::if_not_else)]
38#![allow(clippy::cast_possible_wrap)]
39#![allow(clippy::match_same_arms)]
40#![allow(clippy::doc_markdown)]
41#![allow(clippy::explicit_iter_loop)]
42#![allow(clippy::cast_precision_loss)]
43#![allow(clippy::comparison_chain)]
44#![allow(clippy::cast_lossless)]
45
46use super::transform::{TxClass, TxSize, TxType};
47use crate::error::{CodecError, CodecResult};
48
/// Largest coefficient count these tables are sized for (4096 = 64 * 64).
/// NOTE(review): not referenced in the visible code; presumably the maximum
/// end-of-block value for the largest transform — confirm.
pub const MAX_EOB: usize = 4096;

/// Number of EOB contexts; equals the number of distinct values (0..=8)
/// returned by `EobContext::get_eob_context`.
pub const EOB_COEF_CONTEXTS: usize = 9;

/// Number of base-level coefficient contexts.
/// NOTE(review): not referenced in the visible code — confirm where it is used.
pub const COEFF_BASE_CONTEXTS: usize = 42;

/// Number of base-level contexts for the coefficient at the EOB position.
/// NOTE(review): not referenced in the visible code.
pub const COEFF_BASE_EOB_CONTEXTS: usize = 3;

/// Number of DC-sign contexts; `compute_dc_sign_context` and
/// `CoeffContext::dc_sign_context` both return values in 0..=2.
pub const DC_SIGN_CONTEXTS: usize = 3;

/// Number of base-range contexts; `CoeffBaseRange::get_br_context` returns a
/// position offset of 0, 7 or 14 plus a magnitude class.
pub const COEFF_BR_CONTEXTS: usize = 21;

/// Maximum base-range increment.
/// NOTE(review): not referenced in the visible code — confirm semantics.
pub const COEFF_BASE_RANGE_MAX: u32 = 3;

/// Rice parameter for base-range coding.
/// NOTE(review): not referenced in the visible code.
pub const COEFF_BR_RICE_PARAM: u8 = 1;

/// Base-level cutoff table.
/// NOTE(review): not referenced in the visible code — confirm semantics.
pub const BASE_LEVEL_CUTOFFS: [u32; 5] = [0, 1, 2, 3, 4];

/// Number of transform classes (2D, horizontal, vertical), matching the three
/// `TxClass` variants handled by `get_scan_order`.
pub const TX_CLASSES: usize = 3;

/// Mask for coefficient context values.
/// NOTE(review): not referenced in the visible code; `update_level_context`
/// separately clamps magnitudes to 63 — confirm whether they are related.
pub const COEFF_CONTEXT_MASK: usize = 63;

/// Maximum neighbour count.
/// NOTE(review): not referenced in the visible code; `get_neighbor_positions`
/// returns five neighbours, not two — confirm intent.
pub const MAX_NEIGHBORS: usize = 2;

/// Per-entry EOB offsets (19 entries).
/// NOTE(review): not referenced in the visible code; presumably one entry per
/// `TxSize` (the scan cache also iterates 19 sizes) — confirm indexing.
pub const EOB_OFFSET: [u16; 19] = [
    0, 16, 80, 336, 1360, 16, 16, 80, 80, 336, 336, 1360, 1360, 48, 48, 176, 176, 592, 592,
];

/// Per-transform-size extra-bit count, indexed by `tx_size as usize` in
/// `EobContext::new` (out-of-range indices yield 0 there).
pub const EOB_EXTRA_BITS: [u8; 19] = [
    0, 1, 2, 3, 4, 1, 1, 2, 2, 3, 3, 4, 4, 2, 2, 3, 3, 4, 4,
];

/// First EOB position of each EOB group, indexed by EOB-multi symbol.
/// Used by `EobContext::compute_eob`; entry `i` equals
/// `EobPt::from_symbol(i).base_eob()`.
pub const EOB_GROUP_START: [u16; 12] = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];

/// Identity position map for the first 16 entries.
/// NOTE(review): not referenced in the visible code.
pub const EOB_TO_POS: [u16; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
144
/// Aggregated information about already-coded neighbouring coefficients, used
/// to pick an entropy-coding context for the next coefficient.
#[derive(Clone, Debug, Default)]
pub struct LevelContext {
    /// Sum of absolute neighbour levels.
    pub mag: u32,
    /// Number of non-zero neighbours that were counted.
    pub count: u8,
    /// Position class of the coefficient within the block.
    pub pos_ctx: u8,
}

impl LevelContext {
    /// Returns a context with every field zeroed.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            pos_ctx: 0,
            count: 0,
            mag: 0,
        }
    }

    /// Buckets `mag` into five classes: `0..=64 -> 0`, `65..=128 -> 1`,
    /// `129..=256 -> 2`, `257..=512 -> 3`, anything larger -> 4.
    #[must_use]
    pub fn mag_context(&self) -> u8 {
        match self.mag {
            0..=64 => 0,
            65..=128 => 1,
            129..=256 => 2,
            257..=512 => 3,
            _ => 4,
        }
    }

    /// Combined context index `3 * mag_context() + min(count, 2)` (range 0..=14).
    #[must_use]
    pub fn context(&self) -> u8 {
        let capped_count = self.count.min(2);
        3 * self.mag_context() + capped_count
    }
}
194
195#[derive(Clone, Debug)]
201pub struct CoeffContext {
202 pub tx_size: TxSize,
204 pub tx_type: TxType,
206 pub plane: u8,
208 pub scan_pos: u16,
210 pub eob: u16,
212 pub levels: Vec<i32>,
214 pub signs: Vec<bool>,
216 pub left_ctx: Vec<u8>,
218 pub above_ctx: Vec<u8>,
220 pub block_width: u8,
222 pub block_height: u8,
224}
225
226impl CoeffContext {
227 #[must_use]
229 pub fn new(tx_size: TxSize, tx_type: TxType, plane: u8) -> Self {
230 let area = tx_size.area() as usize;
231 let width = (tx_size.width() / 4) as u8;
232 let height = (tx_size.height() / 4) as u8;
233
234 Self {
235 tx_size,
236 tx_type,
237 plane,
238 scan_pos: 0,
239 eob: 0,
240 levels: vec![0; area],
241 signs: vec![false; area],
242 left_ctx: vec![0; height as usize * 4],
243 above_ctx: vec![0; width as usize * 4],
244 block_width: width,
245 block_height: height,
246 }
247 }
248
249 pub fn reset(&mut self) {
251 self.scan_pos = 0;
252 self.eob = 0;
253 self.levels.fill(0);
254 self.signs.fill(false);
255 self.left_ctx.fill(0);
256 self.above_ctx.fill(0);
257 }
258
259 #[must_use]
261 pub fn tx_class(&self) -> TxClass {
262 self.tx_type.tx_class()
263 }
264
265 #[must_use]
267 pub fn get_scan_position(&self, idx: usize) -> (u32, u32) {
268 let width = self.tx_size.width();
269 let row = (idx as u32) / width;
270 let col = (idx as u32) % width;
271 (row, col)
272 }
273
274 #[must_use]
276 pub fn get_coeff_index(&self, row: u32, col: u32) -> usize {
277 (row * self.tx_size.width() + col) as usize
278 }
279
280 #[must_use]
282 pub fn compute_level_context(&self, pos: usize) -> LevelContext {
283 let width = self.tx_size.width() as usize;
284 let _height = self.tx_size.height() as usize;
285 let row = pos / width;
286 let col = pos % width;
287
288 let mut ctx = LevelContext::new();
289
290 if col > 0 {
292 let left = self.levels[row * width + col - 1].unsigned_abs();
293 ctx.mag += left;
294 if left > 0 {
295 ctx.count += 1;
296 }
297 }
298
299 if row > 0 {
300 let above = self.levels[(row - 1) * width + col].unsigned_abs();
301 ctx.mag += above;
302 if above > 0 {
303 ctx.count += 1;
304 }
305 }
306
307 if row > 0 && col > 0 {
309 let diag = self.levels[(row - 1) * width + col - 1].unsigned_abs();
310 ctx.mag += diag;
311 }
312
313 ctx.pos_ctx = if row + col == 0 {
315 0
316 } else if row + col < 2 {
317 1
318 } else if row + col < 4 {
319 2
320 } else {
321 3
322 };
323
324 ctx
325 }
326
327 #[must_use]
329 pub fn dc_sign_context(&self) -> u8 {
330 let left_sign = if !self.left_ctx.is_empty() {
331 (self.left_ctx[0] as i8 - 1).signum()
332 } else {
333 0
334 };
335
336 let above_sign = if !self.above_ctx.is_empty() {
337 (self.above_ctx[0] as i8 - 1).signum()
338 } else {
339 0
340 };
341
342 let sign_sum = left_sign + above_sign;
343
344 if sign_sum < 0 {
345 0
346 } else if sign_sum > 0 {
347 2
348 } else {
349 1
350 }
351 }
352
353 pub fn set_coeff(&mut self, pos: usize, level: i32, sign: bool) {
355 if pos < self.levels.len() {
356 self.levels[pos] = if sign { -level } else { level };
357 self.signs[pos] = sign;
358 }
359 }
360
361 #[must_use]
363 pub fn get_coeff(&self, pos: usize) -> i32 {
364 self.levels.get(pos).copied().unwrap_or(0)
365 }
366
367 #[must_use]
369 pub fn has_nonzero(&self) -> bool {
370 self.eob > 0
371 }
372
373 #[must_use]
375 pub fn count_nonzero(&self) -> u16 {
376 self.levels.iter().filter(|&&l| l != 0).count() as u16
377 }
378}
379
380impl Default for CoeffContext {
381 fn default() -> Self {
382 Self::new(TxSize::Tx4x4, TxType::DctDct, 0)
383 }
384}
385
/// Generates a diagonal scan of a `width` x `height` block as raster indices
/// (`row * width + col`).
///
/// Anti-diagonals (`row + col == d`) are visited in increasing `d`; within a
/// diagonal, rows increase (so columns decrease). A block with a zero dimension
/// yields an empty scan instead of underflowing. Indices are stored as `u16`, so
/// blocks must have at most 65 536 cells.
#[must_use]
pub fn generate_diagonal_scan(width: usize, height: usize) -> Vec<u16> {
    if width == 0 || height == 0 {
        return Vec::new();
    }

    let mut scan = Vec::with_capacity(width * height);

    for diag in 0..(width + height - 1) {
        // Rows on this diagonal are limited by the right edge (col <= width - 1)
        // and by the bottom edge (row <= height - 1).
        let row_start = diag.saturating_sub(width - 1);
        let row_end = diag.min(height - 1);

        for row in row_start..=row_end {
            let col = diag - row;
            scan.push((row * width + col) as u16);
        }
    }

    scan
}
413
/// Generates a row-by-row scan: the raster indices `0..width * height` in order.
#[must_use]
pub fn generate_horizontal_scan(width: usize, height: usize) -> Vec<u16> {
    let cells = width * height;
    (0..cells).map(|idx| idx as u16).collect()
}
427
/// Generates a column-by-column scan: for each column left to right, the
/// raster indices (`row * width + col`) from top to bottom.
#[must_use]
pub fn generate_vertical_scan(width: usize, height: usize) -> Vec<u16> {
    (0..width)
        .flat_map(|col| (0..height).map(move |row| (row * width + col) as u16))
        .collect()
}
441
442#[must_use]
444pub fn get_scan_order(tx_size: TxSize, tx_class: TxClass) -> Vec<u16> {
445 let width = tx_size.width() as usize;
446 let height = tx_size.height() as usize;
447
448 match tx_class {
449 TxClass::Class2D => generate_diagonal_scan(width, height),
450 TxClass::ClassHoriz => generate_horizontal_scan(width, height),
451 TxClass::ClassVert => generate_vertical_scan(width, height),
452 }
453}
454
455#[derive(Clone, Debug)]
457pub struct ScanOrderCache {
458 cache: Vec<Vec<Vec<u16>>>,
460}
461
462impl ScanOrderCache {
463 #[must_use]
465 pub fn new() -> Self {
466 let mut cache = Vec::with_capacity(19);
467
468 for tx_size_idx in 0..19 {
469 let tx_size = TxSize::from_u8(tx_size_idx as u8).unwrap_or_default();
470 let mut class_scans = Vec::with_capacity(3);
471
472 for tx_class_idx in 0..3 {
473 let tx_class = TxClass::from_u8(tx_class_idx as u8).unwrap_or_default();
474 class_scans.push(get_scan_order(tx_size, tx_class));
475 }
476
477 cache.push(class_scans);
478 }
479
480 Self { cache }
481 }
482
483 #[must_use]
485 pub fn get(&self, tx_size: TxSize, tx_class: TxClass) -> &[u16] {
486 let size_idx = tx_size as usize;
487 let class_idx = tx_class as usize;
488
489 if size_idx < self.cache.len() && class_idx < self.cache[size_idx].len() {
490 &self.cache[size_idx][class_idx]
491 } else {
492 &[]
493 }
494 }
495}
496
497impl Default for ScanOrderCache {
498 fn default() -> Self {
499 Self::new()
500 }
501}
502
503#[derive(Clone, Debug, Default)]
509pub struct EobContext {
510 pub eob_multi: u8,
512 pub eob_extra: u8,
514 pub base_ctx: u8,
516}
517
518impl EobContext {
519 #[must_use]
521 pub fn new(tx_size: TxSize) -> Self {
522 let size_idx = tx_size as usize;
523 let extra_bits = if size_idx < EOB_EXTRA_BITS.len() {
524 EOB_EXTRA_BITS[size_idx]
525 } else {
526 0
527 };
528
529 Self {
530 eob_multi: 0,
531 eob_extra: extra_bits,
532 base_ctx: 0,
533 }
534 }
535
536 #[must_use]
538 pub fn get_eob_context(eob: u16) -> u8 {
539 if eob <= 1 {
540 0
541 } else if eob <= 2 {
542 1
543 } else if eob <= 4 {
544 2
545 } else if eob <= 8 {
546 3
547 } else if eob <= 16 {
548 4
549 } else if eob <= 32 {
550 5
551 } else if eob <= 64 {
552 6
553 } else if eob <= 128 {
554 7
555 } else {
556 8
557 }
558 }
559
560 #[must_use]
562 pub fn compute_eob(eob_multi: u8, eob_extra: u16) -> u16 {
563 let group_idx = eob_multi as usize;
564 if group_idx >= EOB_GROUP_START.len() {
565 return 0;
566 }
567
568 let base = EOB_GROUP_START[group_idx];
569 base + eob_extra
570 }
571}
572
573#[derive(Clone, Copy, Debug, PartialEq, Eq)]
575pub enum EobPt {
576 EobPt0 = 0,
578 EobPt1 = 1,
580 EobPt2 = 2,
582 EobPt3To4 = 3,
584 EobPt5To8 = 4,
586 EobPt9To16 = 5,
588 EobPt17To32 = 6,
590 EobPt33To64 = 7,
592 EobPt65To128 = 8,
594 EobPt129To256 = 9,
596 EobPt257To512 = 10,
598 EobPt513To1024 = 11,
600}
601
602impl EobPt {
603 #[must_use]
617 pub fn from_eob(eob: u16) -> Self {
618 match eob {
619 0 => Self::EobPt0,
620 1 => Self::EobPt1,
621 2 => Self::EobPt2,
622 3..=4 => Self::EobPt3To4,
623 5..=8 => Self::EobPt5To8,
624 9..=16 => Self::EobPt9To16,
625 17..=32 => Self::EobPt17To32,
626 33..=64 => Self::EobPt33To64,
627 65..=128 => Self::EobPt65To128,
628 129..=256 => Self::EobPt129To256,
629 257..=512 => Self::EobPt257To512,
630 _ => Self::EobPt513To1024,
631 }
632 }
633
634 pub fn from_symbol(symbol: usize) -> CodecResult<Self> {
652 match symbol {
653 0 => Ok(Self::EobPt0),
654 1 => Ok(Self::EobPt1),
655 2 => Ok(Self::EobPt2),
656 3 => Ok(Self::EobPt3To4),
657 4 => Ok(Self::EobPt5To8),
658 5 => Ok(Self::EobPt9To16),
659 6 => Ok(Self::EobPt17To32),
660 7 => Ok(Self::EobPt33To64),
661 8 => Ok(Self::EobPt65To128),
662 9 => Ok(Self::EobPt129To256),
663 10 => Ok(Self::EobPt257To512),
664 11 => Ok(Self::EobPt513To1024),
665 other => Err(CodecError::InvalidBitstream(format!(
666 "EOB-multi symbol {other} out of range 0..=11"
667 ))),
668 }
669 }
670
671 #[must_use]
673 pub const fn base_eob(self) -> u16 {
674 match self {
675 Self::EobPt0 => 0,
676 Self::EobPt1 => 1,
677 Self::EobPt2 => 2,
678 Self::EobPt3To4 => 3,
679 Self::EobPt5To8 => 5,
680 Self::EobPt9To16 => 9,
681 Self::EobPt17To32 => 17,
682 Self::EobPt33To64 => 33,
683 Self::EobPt65To128 => 65,
684 Self::EobPt129To256 => 129,
685 Self::EobPt257To512 => 257,
686 Self::EobPt513To1024 => 513,
687 }
688 }
689
690 #[must_use]
692 pub const fn extra_bits(self) -> u8 {
693 match self {
694 Self::EobPt0 | Self::EobPt1 | Self::EobPt2 => 0,
695 Self::EobPt3To4 => 1,
696 Self::EobPt5To8 => 2,
697 Self::EobPt9To16 => 3,
698 Self::EobPt17To32 => 4,
699 Self::EobPt33To64 => 5,
700 Self::EobPt65To128 => 6,
701 Self::EobPt129To256 => 7,
702 Self::EobPt257To512 => 8,
703 Self::EobPt513To1024 => 9,
704 }
705 }
706}
707
708#[derive(Clone, Copy, Debug, Default)]
714pub struct CoeffBaseRange {
715 pub base_level: u8,
717 pub range_ctx: u8,
719}
720
721impl CoeffBaseRange {
722 #[must_use]
724 pub fn get_br_context(level_ctx: &LevelContext, pos: usize, width: usize) -> u8 {
725 let row = pos / width;
726 let col = pos % width;
727
728 let pos_ctx = if row + col == 0 {
730 0
731 } else if row + col < 2 {
732 7
733 } else {
734 14
735 };
736
737 pos_ctx + level_ctx.mag_context().min(6)
739 }
740
741 #[must_use]
743 pub fn compute_level(base: u8, range: u16) -> u32 {
744 u32::from(base) + u32::from(range)
745 }
746}
747
/// Dequantizes one quantized level: `sign(level) * ((|level| * dequant) >> shift)`.
///
/// The product is formed in `i64`, so `|i32::MIN|` and large `level * dequant`
/// products no longer overflow. The scaled magnitude is saturated to
/// `±i32::MAX` before the sign is reapplied. Results for inputs that did not
/// overflow are unchanged.
#[must_use]
pub fn dequantize_coeff(level: i32, dequant: i16, shift: u8) -> i32 {
    let product = i64::from(level.unsigned_abs()) * i64::from(dequant);
    let scaled = product >> shift;
    let limit = i64::from(i32::MAX);
    let magnitude = scaled.clamp(-limit, limit) as i32;

    if level < 0 {
        -magnitude
    } else {
        magnitude
    }
}
764
765pub fn dequantize_block(coeffs: &mut [i32], dc_dequant: i16, ac_dequant: i16, shift: u8) {
767 if coeffs.is_empty() {
768 return;
769 }
770
771 coeffs[0] = dequantize_coeff(coeffs[0], dc_dequant, shift);
773
774 for coeff in coeffs.iter_mut().skip(1) {
776 *coeff = dequantize_coeff(*coeff, ac_dequant, shift);
777 }
778}
779
/// Dequantization shift for a bit depth: 2 for 10-bit, 4 for 12-bit, and 0 for
/// 8-bit or any other value.
#[must_use]
pub const fn get_dequant_shift(bit_depth: u8) -> u8 {
    match bit_depth {
        10 => 2,
        12 => 4,
        _ => 0,
    }
}
790
791#[derive(Clone, Debug)]
797pub struct CoeffBuffer {
798 coeffs: Vec<i32>,
800 width: usize,
802 height: usize,
804}
805
806impl CoeffBuffer {
807 #[must_use]
809 pub fn new(width: usize, height: usize) -> Self {
810 Self {
811 coeffs: vec![0; width * height],
812 width,
813 height,
814 }
815 }
816
817 #[must_use]
819 pub fn from_tx_size(tx_size: TxSize) -> Self {
820 Self::new(tx_size.width() as usize, tx_size.height() as usize)
821 }
822
823 #[must_use]
825 pub fn get(&self, row: usize, col: usize) -> i32 {
826 if row < self.height && col < self.width {
827 self.coeffs[row * self.width + col]
828 } else {
829 0
830 }
831 }
832
833 pub fn set(&mut self, row: usize, col: usize, value: i32) {
835 if row < self.height && col < self.width {
836 self.coeffs[row * self.width + col] = value;
837 }
838 }
839
840 pub fn clear(&mut self) {
842 self.coeffs.fill(0);
843 }
844
845 #[must_use]
847 pub const fn width(&self) -> usize {
848 self.width
849 }
850
851 #[must_use]
853 pub const fn height(&self) -> usize {
854 self.height
855 }
856
857 pub fn as_mut_slice(&mut self) -> &mut [i32] {
859 &mut self.coeffs
860 }
861
862 #[must_use]
864 pub fn as_slice(&self) -> &[i32] {
865 &self.coeffs
866 }
867
868 pub fn copy_from_scan(&mut self, src: &[i32], scan: &[u16]) {
870 for (i, &pos) in scan.iter().enumerate() {
871 if i < src.len() && (pos as usize) < self.coeffs.len() {
872 self.coeffs[pos as usize] = src[i];
873 }
874 }
875 }
876
877 pub fn copy_to_scan(&self, dst: &mut [i32], scan: &[u16]) {
879 for (i, &pos) in scan.iter().enumerate() {
880 if i < dst.len() && (pos as usize) < self.coeffs.len() {
881 dst[i] = self.coeffs[pos as usize];
882 }
883 }
884 }
885}
886
887impl Default for CoeffBuffer {
888 fn default() -> Self {
889 Self::new(4, 4)
890 }
891}
892
/// Returns five neighbour slots for raster position `pos` in a block `width`
/// wide, each as `(raster_index, valid)`. Invalid slots are `(0, false)`.
///
/// Order: left, above, above-left, above-right, two-to-the-left.
/// `width` must be non-zero; `_height` is unused.
#[must_use]
pub fn get_neighbor_positions(pos: usize, width: usize, _height: usize) -> [(usize, bool); 5] {
    let (row, col) = (pos / width, pos % width);
    let at = |r: usize, c: usize| (r * width + c, true);
    let absent = (0usize, false);

    [
        if col > 0 { at(row, col - 1) } else { absent },
        if row > 0 { at(row - 1, col) } else { absent },
        if row > 0 && col > 0 { at(row - 1, col - 1) } else { absent },
        if row > 0 && col + 1 < width { at(row - 1, col + 1) } else { absent },
        if col > 1 { at(row, col - 2) } else { absent },
    ]
}
932
/// Computes a level context from the valid, in-range entries of `neighbors`.
///
/// Sums their absolute levels into `mag` and counts the non-zero ones. It then
/// returns `3 * mag_class + min(count, 2)`, where `mag_class` buckets `mag` as
/// `0..=64 -> 0`, `65..=128 -> 1`, `129..=256 -> 2`, `257..=512 -> 3`, else 4.
///
/// The magnitude sum saturates: five neighbours of large magnitude could exceed
/// `u32::MAX`, which previously panicked in debug builds and wrapped to a small,
/// wrong context in release builds.
#[must_use]
pub fn compute_context_from_neighbors(levels: &[i32], neighbors: &[(usize, bool); 5]) -> u8 {
    let mut mag = 0u32;
    let mut count = 0u8;

    for &(pos, valid) in neighbors.iter() {
        if !valid {
            continue;
        }
        if let Some(level) = levels.get(pos) {
            let level = level.unsigned_abs();
            mag = mag.saturating_add(level);
            if level > 0 {
                count += 1;
            }
        }
    }

    let mag_ctx: u8 = match mag {
        0..=64 => 0,
        65..=128 => 1,
        129..=256 => 2,
        257..=512 => 3,
        _ => 4,
    };

    mag_ctx * 3 + count.min(2)
}
964
/// DC-sign context from the left and above DC values. The signs are summed:
/// a negative sum gives 0, zero gives 1, and a positive sum gives 2.
#[must_use]
pub fn compute_dc_sign_context(left_dc: i32, above_dc: i32) -> u8 {
    match left_dc.signum() + above_dc.signum() {
        total if total < 0 => 0,
        0 => 1,
        _ => 2,
    }
}
985
/// Records a coded level in the left/above context arrays as
/// `min(|level|, 63) + 1` at `left_ctx[row]` and `above_ctx[col]`.
///
/// Each write is skipped when its index is out of range.
pub fn update_level_context(
    left_ctx: &mut [u8],
    above_ctx: &mut [u8],
    level: i32,
    row: usize,
    col: usize,
) {
    let stored = level.unsigned_abs().min(63) as u8 + 1;

    if let Some(slot) = left_ctx.get_mut(row) {
        *slot = stored;
    }
    if let Some(slot) = above_ctx.get_mut(col) {
        *slot = stored;
    }
}
1004
/// Histogram-style statistics over a set of coefficients.
#[derive(Clone, Debug, Default)]
pub struct CoeffStats {
    /// Coefficients equal to 0.
    pub zero_count: u32,
    /// Coefficients with magnitude 1.
    pub level1_count: u32,
    /// Coefficients with magnitude 2.
    pub level2_count: u32,
    /// Coefficients with magnitude 3 or more.
    pub high_level_count: u32,
    /// Sum of all magnitudes.
    pub level_sum: u64,
    /// Largest magnitude seen.
    pub max_level: u32,
}

impl CoeffStats {
    /// Collects statistics over the absolute values of `coeffs`.
    #[must_use]
    pub fn from_coeffs(coeffs: &[i32]) -> Self {
        coeffs
            .iter()
            .map(|c| c.unsigned_abs())
            .fold(Self::default(), |mut stats, level| {
                let bucket = match level {
                    0 => &mut stats.zero_count,
                    1 => &mut stats.level1_count,
                    2 => &mut stats.level2_count,
                    _ => &mut stats.high_level_count,
                };
                *bucket += 1;
                stats.level_sum += u64::from(level);
                stats.max_level = stats.max_level.max(level);
                stats
            })
    }

    /// Number of coefficients with non-zero magnitude.
    #[must_use]
    pub fn nonzero_count(&self) -> u32 {
        [self.level1_count, self.level2_count, self.high_level_count]
            .iter()
            .sum()
    }

    /// Mean magnitude over the non-zero coefficients, or 0.0 when there are none.
    #[must_use]
    pub fn average_level(&self) -> f64 {
        match self.nonzero_count() {
            0 => 0.0,
            n => self.level_sum as f64 / f64::from(n),
        }
    }
}
1066
/// Unit tests for the coefficient context helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // LevelContext starts zeroed; mag 100 falls in the 65..=128 class.
    #[test]
    fn test_level_context() {
        let mut ctx = LevelContext::new();
        assert_eq!(ctx.mag, 0);
        assert_eq!(ctx.count, 0);

        ctx.mag = 100;
        ctx.count = 2;
        assert_eq!(ctx.mag_context(), 1);
        assert_eq!(ctx.context(), 1 * 3 + 2);
    }

    // An 8x8 block holds 64 levels; DctDct is expected to be a 2D-class transform.
    #[test]
    fn test_coeff_context_new() {
        let ctx = CoeffContext::new(TxSize::Tx8x8, TxType::DctDct, 0);
        assert_eq!(ctx.levels.len(), 64);
        assert_eq!(ctx.tx_class(), TxClass::Class2D);
    }

    // set_coeff negates the stored level when the sign flag is set.
    #[test]
    fn test_coeff_context_set_get() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.set_coeff(5, 100, false);
        assert_eq!(ctx.get_coeff(5), 100);

        ctx.set_coeff(10, 50, true);
        assert_eq!(ctx.get_coeff(10), -50);
    }

    // The diagonal scan starts at DC and then visits row 0 of the first anti-diagonal.
    #[test]
    fn test_diagonal_scan_4x4() {
        let scan = generate_diagonal_scan(4, 4);
        assert_eq!(scan.len(), 16);
        assert_eq!(scan[0], 0);
        assert_eq!(scan[1], 1);
    }

    // The horizontal scan is plain raster order.
    #[test]
    fn test_horizontal_scan() {
        let scan = generate_horizontal_scan(4, 4);
        assert_eq!(scan.len(), 16);
        for i in 0..16 {
            assert_eq!(scan[i], i as u16);
        }
    }

    // The vertical scan walks down column 0 first (stride = width).
    #[test]
    fn test_vertical_scan() {
        let scan = generate_vertical_scan(4, 4);
        assert_eq!(scan.len(), 16);
        assert_eq!(scan[0], 0);
        assert_eq!(scan[1], 4);
        assert_eq!(scan[2], 8);
        assert_eq!(scan[3], 12);
    }

    #[test]
    fn test_scan_order_cache() {
        let cache = ScanOrderCache::new();
        let scan = cache.get(TxSize::Tx4x4, TxClass::Class2D);
        assert_eq!(scan.len(), 16);
    }

    // EOB_EXTRA_BITS is expected to hold a non-zero entry for Tx8x8's index.
    #[test]
    fn test_eob_context() {
        let ctx = EobContext::new(TxSize::Tx8x8);
        assert!(ctx.eob_extra > 0);
    }

    // from_eob classifies positions; base_eob/extra_bits describe the group.
    #[test]
    fn test_eob_pt() {
        assert_eq!(EobPt::from_eob(0), EobPt::EobPt0);
        assert_eq!(EobPt::from_eob(1), EobPt::EobPt1);
        assert_eq!(EobPt::from_eob(5), EobPt::EobPt5To8);
        assert_eq!(EobPt::from_eob(100), EobPt::EobPt65To128);

        assert_eq!(EobPt::EobPt5To8.extra_bits(), 2);
        assert_eq!(EobPt::EobPt5To8.base_eob(), 5);
    }

    // Every legal symbol maps to the variant whose discriminant equals it.
    #[test]
    fn test_eob_pt_from_symbol_basic() {
        let expected = [
            EobPt::EobPt0,
            EobPt::EobPt1,
            EobPt::EobPt2,
            EobPt::EobPt3To4,
            EobPt::EobPt5To8,
            EobPt::EobPt9To16,
            EobPt::EobPt17To32,
            EobPt::EobPt33To64,
            EobPt::EobPt65To128,
            EobPt::EobPt129To256,
            EobPt::EobPt257To512,
            EobPt::EobPt513To1024,
        ];
        for (symbol, &want) in expected.iter().enumerate() {
            let got = EobPt::from_symbol(symbol).expect("symbol in 0..=11 must succeed");
            assert_eq!(got, want, "symbol {symbol} should map to {want:?}");
            assert_eq!(
                got as usize, symbol,
                "discriminant of {got:?} should equal symbol {symbol}"
            );
        }
    }

    // Symbols past the last group are rejected with an EOB-multi error.
    #[test]
    fn test_eob_pt_from_symbol_rejects_out_of_range() {
        for bad in 12..=15 {
            let err = EobPt::from_symbol(bad).expect_err("symbols 12..=15 must be rejected");
            let msg = format!("{err}");
            assert!(
                msg.contains("EOB-multi"),
                "error message should mention EOB-multi: {msg}"
            );
        }
        assert!(EobPt::from_symbol(usize::MAX).is_err());
    }

    // from_symbol treats its input as a group index, from_eob as a position.
    // The two agree only for 0..=3; using from_eob on a decoded symbol would
    // read the wrong number of extra bits for symbols 4..=11.
    #[test]
    fn test_eob_pt_from_symbol_differs_from_from_eob() {
        for symbol in 0..=3usize {
            let from_sym = EobPt::from_symbol(symbol).expect("low symbol always succeeds");
            let from_pos = EobPt::from_eob(symbol as u16);
            assert_eq!(from_sym, from_pos, "symbols 0..=3 coincidentally match");
        }
        for symbol in 4..=11usize {
            let from_sym = EobPt::from_symbol(symbol).expect("symbols 4..=11 must succeed");
            let from_pos = EobPt::from_eob(symbol as u16);
            assert_ne!(
                from_sym, from_pos,
                "symbol {symbol} must NOT coincide with from_eob (the bug)",
            );
            assert_ne!(
                from_sym.extra_bits(),
                from_pos.extra_bits(),
                "symbol {symbol} extra_bits mismatch is the desync source",
            );
        }
    }

    // base_eob matches EOB_GROUP_START, and the top position of each group
    // (base + 2^extra_bits - 1) classifies back into the same group.
    #[test]
    fn test_eob_pt_symbol_chain_covers_group_start() {
        for symbol in 0..=11usize {
            let pt = EobPt::from_symbol(symbol).expect("legal symbol");
            assert_eq!(pt.base_eob(), EOB_GROUP_START[symbol]);

            let extra = pt.extra_bits();
            let range_size: u16 = 1u16 << extra;
            if symbol <= 2 {
                assert_eq!(range_size, 1);
            } else {
                let max_offset = range_size - 1;
                let max_pos = EobPt::from_eob(pt.base_eob() + max_offset);
                assert_eq!(
                    max_pos, pt,
                    "EOB position at top of group must classify as same EobPt",
                );
            }
        }
    }

    #[test]
    fn test_dequantize_coeff() {
        let level = 10;
        let dequant = 16;
        let result = dequantize_coeff(level, dequant, 0);
        assert_eq!(result, 160);

        let neg_result = dequantize_coeff(-level, dequant, 0);
        assert_eq!(neg_result, -160);
    }

    // The first coefficient uses the DC dequantizer, the rest use AC.
    #[test]
    fn test_dequantize_block() {
        let mut coeffs = vec![10, 5, 5, 5, 5, 5, 5, 5];
        dequantize_block(&mut coeffs, 20, 10, 0);

        assert_eq!(coeffs[0], 200);
        assert_eq!(coeffs[1], 50);
    }

    #[test]
    fn test_get_dequant_shift() {
        assert_eq!(get_dequant_shift(8), 0);
        assert_eq!(get_dequant_shift(10), 2);
        assert_eq!(get_dequant_shift(12), 4);
    }

    #[test]
    fn test_coeff_buffer() {
        let mut buf = CoeffBuffer::new(4, 4);
        buf.set(1, 2, 100);
        assert_eq!(buf.get(1, 2), 100);
        assert_eq!(buf.get(0, 0), 0);

        buf.clear();
        assert_eq!(buf.get(1, 2), 0);
    }

    #[test]
    fn test_coeff_buffer_from_tx_size() {
        let buf = CoeffBuffer::from_tx_size(TxSize::Tx8x8);
        assert_eq!(buf.as_slice().len(), 64);
    }

    // Position 5 in a 4-wide block is (row 1, col 1): left is 4, above is 1.
    #[test]
    fn test_neighbor_positions() {
        let neighbors = get_neighbor_positions(5, 4, 4);

        assert!(neighbors[0].1);
        assert_eq!(neighbors[0].0, 4);

        assert!(neighbors[1].1);
        assert_eq!(neighbors[1].0, 1);
    }

    // Both negative -> 0, both positive -> 2, mixed or zero -> 1.
    #[test]
    fn test_compute_dc_sign_context() {
        assert_eq!(compute_dc_sign_context(-5, -3), 0);
        assert_eq!(compute_dc_sign_context(5, 3), 2);
        assert_eq!(compute_dc_sign_context(-5, 3), 1);
        assert_eq!(compute_dc_sign_context(0, 0), 1);
    }

    #[test]
    fn test_coeff_stats() {
        let coeffs = vec![0, 1, 2, 3, 0, 1, 5, 0];
        let stats = CoeffStats::from_coeffs(&coeffs);

        assert_eq!(stats.zero_count, 3);
        assert_eq!(stats.level1_count, 2);
        assert_eq!(stats.level2_count, 1);
        assert_eq!(stats.high_level_count, 2);
        assert_eq!(stats.max_level, 5);
        assert_eq!(stats.nonzero_count(), 5);
    }

    // Only checks that the context stays within the three valid values.
    #[test]
    fn test_coeff_context_dc_sign() {
        let ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        let dc_ctx = ctx.dc_sign_context();
        assert!(dc_ctx <= 2);
    }

    // Position 5 (row 1, col 1 for a 4-wide block) sees non-zero left, above
    // and diagonal neighbours.
    #[test]
    fn test_coeff_context_level_context() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.levels[0] = 5;
        ctx.levels[1] = 3;
        ctx.levels[4] = 2;

        let level_ctx = ctx.compute_level_context(5);
        assert!(level_ctx.mag > 0);
    }

    // With no extra offset, the EOB is the group's start position.
    #[test]
    fn test_eob_compute() {
        assert_eq!(EobContext::compute_eob(0, 0), 0);
        assert_eq!(EobContext::compute_eob(1, 0), 1);
        assert_eq!(EobContext::compute_eob(2, 0), 2);
    }

    #[test]
    fn test_coeff_base_range() {
        let level_ctx = LevelContext {
            mag: 100,
            count: 2,
            pos_ctx: 1,
        };

        let br_ctx = CoeffBaseRange::get_br_context(&level_ctx, 5, 4);
        assert!(br_ctx > 0);

        let level = CoeffBaseRange::compute_level(2, 5);
        assert_eq!(level, 7);
    }

    #[test]
    fn test_constants() {
        assert_eq!(MAX_EOB, 4096);
        assert_eq!(EOB_COEF_CONTEXTS, 9);
        assert_eq!(TX_CLASSES, 3);
    }

    #[test]
    fn test_coeff_context_reset() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.eob = 10;
        ctx.levels[5] = 100;

        ctx.reset();
        assert_eq!(ctx.eob, 0);
        assert_eq!(ctx.levels[5], 0);
    }

    // Negative levels count as non-zero too.
    #[test]
    fn test_coeff_context_count_nonzero() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.levels[0] = 5;
        ctx.levels[5] = 3;
        ctx.levels[10] = -2;

        assert_eq!(ctx.count_nonzero(), 3);
    }

    // Every scan covers each cell of the transform exactly once by length.
    #[test]
    fn test_scan_order_all_sizes() {
        for size_idx in 0..19 {
            if let Some(tx_size) = TxSize::from_u8(size_idx) {
                for class_idx in 0..3 {
                    if let Some(tx_class) = TxClass::from_u8(class_idx) {
                        let scan = get_scan_order(tx_size, tx_class);
                        assert_eq!(scan.len(), tx_size.area() as usize);
                    }
                }
            }
        }
    }
}