1#![forbid(unsafe_code)]
13#![allow(dead_code)]
14#![allow(clippy::too_many_arguments)]
15#![allow(clippy::must_use_candidate)]
16#![allow(clippy::cast_sign_loss)]
17#![allow(clippy::cast_precision_loss)]
18#![allow(clippy::cast_possible_truncation)]
19#![allow(clippy::cast_possible_wrap)]
20#![allow(clippy::needless_range_loop)]
21#![allow(clippy::unused_self)]
22#![allow(clippy::redundant_closure_for_method_calls)]
23#![allow(clippy::trivially_copy_pass_by_ref)]
24
25use super::types::{BlockSize, MotionVector, MvCost};
26
/// Capacity of [`MvPredictorList`]: the most candidates a single list can hold.
pub const MAX_PREDICTORS: usize = 8;

/// Weight given to candidates taken from the left, top and top-right neighbors.
pub const SPATIAL_WEIGHT: u32 = 2;

/// Weight given to the candidate taken from the co-located block.
pub const TEMPORAL_WEIGHT: u32 = 1;
35
/// Tags a motion vector candidate with the place it was taken from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NeighborPosition {
    /// The `left` neighbor of an [`MvPredContext`].
    Left,
    /// The `top` neighbor of an [`MvPredContext`].
    Top,
    /// The `top_right` neighbor of an [`MvPredContext`].
    TopRight,
    /// The `top_left` neighbor of an [`MvPredContext`].
    TopLeft,
    /// The `co_located` entry of an [`MvPredContext`]; weighted with [`TEMPORAL_WEIGHT`].
    CoLocated,
    /// Not produced by any code in this file.
    /// NOTE(review): presumably the below-left spatial neighbor — confirm against callers.
    BelowLeft,
    /// A derived candidate rather than a real neighbor: used for the median
    /// predictor and for the zero-vector candidate.
    Median,
}
54
55#[derive(Clone, Copy, Debug, Default)]
57pub struct NeighborInfo {
58 pub mv: MotionVector,
60 pub ref_idx: i8,
62 pub available: bool,
64 pub is_inter: bool,
66}
67
68impl NeighborInfo {
69 #[must_use]
71 pub const fn unavailable() -> Self {
72 Self {
73 mv: MotionVector::zero(),
74 ref_idx: -1,
75 available: false,
76 is_inter: false,
77 }
78 }
79
80 #[must_use]
82 pub const fn with_mv(mv: MotionVector, ref_idx: i8) -> Self {
83 Self {
84 mv,
85 ref_idx,
86 available: true,
87 is_inter: true,
88 }
89 }
90
91 #[must_use]
93 pub const fn intra() -> Self {
94 Self {
95 mv: MotionVector::zero(),
96 ref_idx: -1,
97 available: true,
98 is_inter: false,
99 }
100 }
101}
102
103#[derive(Clone, Debug, Default)]
105pub struct MvPredContext {
106 pub left: NeighborInfo,
108 pub top: NeighborInfo,
110 pub top_right: NeighborInfo,
112 pub top_left: NeighborInfo,
114 pub co_located: NeighborInfo,
116 pub ref_idx: i8,
118 pub mi_col: usize,
120 pub mi_row: usize,
122 pub block_size: BlockSize,
124}
125
126impl MvPredContext {
127 #[must_use]
129 pub const fn new() -> Self {
130 Self {
131 left: NeighborInfo::unavailable(),
132 top: NeighborInfo::unavailable(),
133 top_right: NeighborInfo::unavailable(),
134 top_left: NeighborInfo::unavailable(),
135 co_located: NeighborInfo::unavailable(),
136 ref_idx: 0,
137 mi_col: 0,
138 mi_row: 0,
139 block_size: BlockSize::Block8x8,
140 }
141 }
142
143 #[must_use]
145 pub const fn at_position(mut self, mi_row: usize, mi_col: usize) -> Self {
146 self.mi_row = mi_row;
147 self.mi_col = mi_col;
148 self
149 }
150
151 #[must_use]
153 pub const fn with_size(mut self, size: BlockSize) -> Self {
154 self.block_size = size;
155 self
156 }
157
158 #[must_use]
160 pub const fn with_ref(mut self, ref_idx: i8) -> Self {
161 self.ref_idx = ref_idx;
162 self
163 }
164
165 #[must_use]
167 pub const fn with_left(mut self, info: NeighborInfo) -> Self {
168 self.left = info;
169 self
170 }
171
172 #[must_use]
174 pub const fn with_top(mut self, info: NeighborInfo) -> Self {
175 self.top = info;
176 self
177 }
178
179 #[must_use]
181 pub const fn with_top_right(mut self, info: NeighborInfo) -> Self {
182 self.top_right = info;
183 self
184 }
185
186 #[must_use]
188 pub const fn with_top_left(mut self, info: NeighborInfo) -> Self {
189 self.top_left = info;
190 self
191 }
192
193 #[must_use]
195 pub const fn with_co_located(mut self, info: NeighborInfo) -> Self {
196 self.co_located = info;
197 self
198 }
199}
200
201#[derive(Clone, Copy, Debug)]
203pub struct MvCandidate {
204 pub mv: MotionVector,
206 pub weight: u32,
208 pub source: NeighborPosition,
210}
211
212impl MvCandidate {
213 #[must_use]
215 pub const fn new(mv: MotionVector, weight: u32, source: NeighborPosition) -> Self {
216 Self { mv, weight, source }
217 }
218
219 #[must_use]
221 pub const fn zero() -> Self {
222 Self {
223 mv: MotionVector::zero(),
224 weight: 0,
225 source: NeighborPosition::Median,
226 }
227 }
228}
229
230#[derive(Clone, Debug)]
232pub struct MvPredictorList {
233 candidates: [MvCandidate; MAX_PREDICTORS],
235 count: usize,
237}
238
239impl Default for MvPredictorList {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245impl MvPredictorList {
246 #[must_use]
248 pub fn new() -> Self {
249 Self {
250 candidates: [MvCandidate::zero(); MAX_PREDICTORS],
251 count: 0,
252 }
253 }
254
255 pub fn add(&mut self, candidate: MvCandidate) {
257 if self.count < MAX_PREDICTORS {
258 for i in 0..self.count {
260 if self.candidates[i].mv == candidate.mv {
261 if candidate.weight > self.candidates[i].weight {
263 self.candidates[i].weight = candidate.weight;
264 }
265 return;
266 }
267 }
268 self.candidates[self.count] = candidate;
269 self.count += 1;
270 }
271 }
272
273 pub fn add_from_neighbor(&mut self, info: &NeighborInfo, source: NeighborPosition) {
275 if info.available && info.is_inter {
276 let weight = match source {
277 NeighborPosition::Left | NeighborPosition::Top | NeighborPosition::TopRight => {
278 SPATIAL_WEIGHT
279 }
280 NeighborPosition::CoLocated => TEMPORAL_WEIGHT,
281 _ => 1,
282 };
283 self.add(MvCandidate::new(info.mv, weight, source));
284 }
285 }
286
287 pub fn sort_by_weight(&mut self) {
289 for i in 1..self.count {
291 let key = self.candidates[i];
292 let mut j = i;
293 while j > 0 && self.candidates[j - 1].weight < key.weight {
294 self.candidates[j] = self.candidates[j - 1];
295 j -= 1;
296 }
297 self.candidates[j] = key;
298 }
299 }
300
301 #[must_use]
303 pub const fn len(&self) -> usize {
304 self.count
305 }
306
307 #[must_use]
309 pub const fn is_empty(&self) -> bool {
310 self.count == 0
311 }
312
313 #[must_use]
315 pub const fn get(&self, index: usize) -> Option<&MvCandidate> {
316 if index < self.count {
317 Some(&self.candidates[index])
318 } else {
319 None
320 }
321 }
322
323 #[must_use]
325 pub fn best(&self) -> MotionVector {
326 if self.count > 0 {
327 self.candidates[0].mv
328 } else {
329 MotionVector::zero()
330 }
331 }
332
333 #[must_use]
335 pub fn as_slice(&self) -> &[MvCandidate] {
336 &self.candidates[..self.count]
337 }
338
339 pub fn motion_vectors(&self) -> Vec<MotionVector> {
341 self.candidates[..self.count].iter().map(|c| c.mv).collect()
342 }
343}
344
345#[derive(Clone, Copy, Debug, Default)]
347pub struct SpatialPredictor;
348
349impl SpatialPredictor {
350 #[must_use]
352 pub const fn new() -> Self {
353 Self
354 }
355
356 #[must_use]
358 pub fn get_left(ctx: &MvPredContext) -> Option<MotionVector> {
359 if ctx.left.available && ctx.left.is_inter {
360 Some(ctx.left.mv)
361 } else {
362 None
363 }
364 }
365
366 #[must_use]
368 pub fn get_top(ctx: &MvPredContext) -> Option<MotionVector> {
369 if ctx.top.available && ctx.top.is_inter {
370 Some(ctx.top.mv)
371 } else {
372 None
373 }
374 }
375
376 #[must_use]
378 pub fn get_top_right(ctx: &MvPredContext) -> Option<MotionVector> {
379 if ctx.top_right.available && ctx.top_right.is_inter {
380 Some(ctx.top_right.mv)
381 } else {
382 None
383 }
384 }
385
386 #[must_use]
388 pub fn get_top_left(ctx: &MvPredContext) -> Option<MotionVector> {
389 if ctx.top_left.available && ctx.top_left.is_inter {
390 Some(ctx.top_left.mv)
391 } else {
392 None
393 }
394 }
395
396 #[must_use]
398 pub fn median(a: MotionVector, b: MotionVector, c: MotionVector) -> MotionVector {
399 let dx = Self::median_of_3(a.dx, b.dx, c.dx);
401 let dy = Self::median_of_3(a.dy, b.dy, c.dy);
402 MotionVector::new(dx, dy)
403 }
404
405 #[must_use]
407 fn median_of_3(a: i32, b: i32, c: i32) -> i32 {
408 a.max(b.min(c)).min(b.max(c))
409 }
410
411 #[must_use]
413 pub fn calculate_median(ctx: &MvPredContext) -> MotionVector {
414 let left = Self::get_left(ctx).unwrap_or_else(MotionVector::zero);
415 let top = Self::get_top(ctx).unwrap_or_else(MotionVector::zero);
416 let top_right = Self::get_top_right(ctx)
417 .or_else(|| Self::get_top_left(ctx))
418 .unwrap_or_else(MotionVector::zero);
419
420 Self::median(left, top, top_right)
421 }
422
423 pub fn build_predictors(ctx: &MvPredContext, list: &mut MvPredictorList) {
425 list.add_from_neighbor(&ctx.left, NeighborPosition::Left);
426 list.add_from_neighbor(&ctx.top, NeighborPosition::Top);
427 list.add_from_neighbor(&ctx.top_right, NeighborPosition::TopRight);
428 list.add_from_neighbor(&ctx.top_left, NeighborPosition::TopLeft);
429
430 let mut neighbor_count = 0;
432 if ctx.left.available && ctx.left.is_inter {
433 neighbor_count += 1;
434 }
435 if ctx.top.available && ctx.top.is_inter {
436 neighbor_count += 1;
437 }
438 if ctx.top_right.available && ctx.top_right.is_inter {
439 neighbor_count += 1;
440 }
441
442 if neighbor_count >= 2 {
443 let median = Self::calculate_median(ctx);
444 list.add(MvCandidate::new(
445 median,
446 SPATIAL_WEIGHT + 1,
447 NeighborPosition::Median,
448 ));
449 }
450 }
451}
452
453#[derive(Clone, Copy, Debug, Default)]
455pub struct TemporalPredictor;
456
457impl TemporalPredictor {
458 #[must_use]
460 pub const fn new() -> Self {
461 Self
462 }
463
464 #[must_use]
466 pub fn get_co_located(ctx: &MvPredContext) -> Option<MotionVector> {
467 if ctx.co_located.available && ctx.co_located.is_inter {
468 Some(ctx.co_located.mv)
469 } else {
470 None
471 }
472 }
473
474 #[must_use]
480 #[allow(clippy::cast_possible_truncation)]
481 pub fn scale_mv(mv: MotionVector, src_dist: i32, dst_dist: i32) -> MotionVector {
482 if src_dist == 0 || src_dist == dst_dist {
483 return mv;
484 }
485
486 let scale_x = (i64::from(mv.dx) * i64::from(dst_dist)) / i64::from(src_dist);
487 let scale_y = (i64::from(mv.dy) * i64::from(dst_dist)) / i64::from(src_dist);
488
489 MotionVector::new(scale_x as i32, scale_y as i32)
490 }
491
492 pub fn build_predictors(ctx: &MvPredContext, list: &mut MvPredictorList) {
494 list.add_from_neighbor(&ctx.co_located, NeighborPosition::CoLocated);
495 }
496}
497
498#[derive(Clone, Debug, Default)]
500pub struct MvPredictor {
501 spatial: SpatialPredictor,
503 temporal: TemporalPredictor,
505 predictors: MvPredictorList,
507}
508
509impl MvPredictor {
510 #[must_use]
512 pub fn new() -> Self {
513 Self {
514 spatial: SpatialPredictor::new(),
515 temporal: TemporalPredictor::new(),
516 predictors: MvPredictorList::new(),
517 }
518 }
519
520 pub fn build(&mut self, ctx: &MvPredContext) {
522 self.predictors = MvPredictorList::new();
523
524 self.predictors.add(MvCandidate::new(
526 MotionVector::zero(),
527 1,
528 NeighborPosition::Median,
529 ));
530
531 SpatialPredictor::build_predictors(ctx, &mut self.predictors);
533
534 TemporalPredictor::build_predictors(ctx, &mut self.predictors);
536
537 self.predictors.sort_by_weight();
539 }
540
541 #[must_use]
543 pub fn best_mvp(&self) -> MotionVector {
544 self.predictors.best()
545 }
546
547 #[must_use]
549 pub fn all_predictors(&self) -> &MvPredictorList {
550 &self.predictors
551 }
552
553 pub fn motion_vectors(&self) -> Vec<MotionVector> {
555 self.predictors.motion_vectors()
556 }
557}
558
559#[derive(Clone, Debug)]
561pub struct MvCostCalculator {
562 lambda: f32,
564 cost_table: Option<Vec<u32>>,
566}
567
568impl Default for MvCostCalculator {
569 fn default() -> Self {
570 Self::new(1.0)
571 }
572}
573
574impl MvCostCalculator {
575 #[must_use]
577 pub const fn new(lambda: f32) -> Self {
578 Self {
579 lambda,
580 cost_table: None,
581 }
582 }
583
584 pub fn build_cost_table(&mut self, max_mv: i32) {
586 let size = (2 * max_mv + 1) as usize;
587 let mut table = vec![0u32; size];
588
589 for i in 0..size {
590 let mv_component = i as i32 - max_mv;
591 table[i] = self.component_cost(mv_component);
592 }
593
594 self.cost_table = Some(table);
595 }
596
597 #[must_use]
599 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
600 fn component_cost(&self, value: i32) -> u32 {
601 if value == 0 {
602 return (self.lambda * 1.0) as u32;
603 }
604
605 let abs_val = value.unsigned_abs();
606 let log2 = 32 - abs_val.leading_zeros();
608 let bits = log2 * 2 + 2;
609 ((bits as f32) * self.lambda) as u32
610 }
611
612 #[must_use]
614 pub fn cost(&self, mv: &MotionVector, mvp: &MotionVector) -> u32 {
615 let diff = *mv - *mvp;
616 self.component_cost(diff.dx) + self.component_cost(diff.dy)
617 }
618
619 #[must_use]
621 pub fn rd_cost(&self, mv: &MotionVector, mvp: &MotionVector, distortion: u32) -> u32 {
622 distortion.saturating_add(self.cost(mv, mvp))
623 }
624
625 #[must_use]
627 pub fn to_mv_cost(&self, mvp: MotionVector) -> MvCost {
628 MvCost::with_ref_mv(self.lambda, mvp)
629 }
630}
631
632#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
634pub enum MvpMode {
635 #[default]
637 Median,
638 Left,
640 Top,
642 Temporal,
644 Zero,
646}
647
648impl MvpMode {
649 #[must_use]
651 pub fn get_mvp(&self, ctx: &MvPredContext) -> MotionVector {
652 match self {
653 Self::Median => SpatialPredictor::calculate_median(ctx),
654 Self::Left => SpatialPredictor::get_left(ctx).unwrap_or_else(MotionVector::zero),
655 Self::Top => SpatialPredictor::get_top(ctx).unwrap_or_else(MotionVector::zero),
656 Self::Temporal => {
657 TemporalPredictor::get_co_located(ctx).unwrap_or_else(MotionVector::zero)
658 }
659 Self::Zero => MotionVector::zero(),
660 }
661 }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a motion vector.
    fn mv(dx: i32, dy: i32) -> MotionVector {
        MotionVector::new(dx, dy)
    }

    /// An available, inter-coded neighbor with reference index 0.
    fn inter(dx: i32, dy: i32) -> NeighborInfo {
        NeighborInfo::with_mv(mv(dx, dy), 0)
    }

    /// Builds a list by adding `(dx, dy, weight, source)` entries in order.
    fn list_of(entries: &[(i32, i32, u32, NeighborPosition)]) -> MvPredictorList {
        let mut list = MvPredictorList::new();
        for &(dx, dy, weight, source) in entries {
            list.add(MvCandidate::new(mv(dx, dy), weight, source));
        }
        list
    }

    /// Fetches the candidate at `index`, failing the test if it is missing.
    fn candidate_at(list: &MvPredictorList, index: usize) -> MvCandidate {
        *list.get(index).expect("get should return value")
    }

    #[test]
    fn test_neighbor_info_unavailable() {
        let NeighborInfo {
            available,
            is_inter,
            ..
        } = NeighborInfo::unavailable();

        assert!(!available);
        assert!(!is_inter);
    }

    #[test]
    fn test_neighbor_info_with_mv() {
        let info = NeighborInfo::with_mv(mv(10, 20), 0);

        assert!(info.available);
        assert!(info.is_inter);
        assert_eq!(info.mv.dx, 10);
        assert_eq!(info.mv.dy, 20);
    }

    #[test]
    fn test_mv_pred_context_builder() {
        let ctx = MvPredContext::new()
            .at_position(10, 20)
            .with_size(BlockSize::Block16x16)
            .with_ref(0)
            .with_left(inter(5, 5));

        assert_eq!(ctx.mi_row, 10);
        assert_eq!(ctx.mi_col, 20);
        assert_eq!(ctx.block_size, BlockSize::Block16x16);
        assert!(ctx.left.available);
    }

    #[test]
    fn test_mv_candidate() {
        let candidate = MvCandidate::new(mv(10, 20), 5, NeighborPosition::Left);

        assert_eq!(candidate.mv.dx, 10);
        assert_eq!(candidate.weight, 5);
        assert_eq!(candidate.source, NeighborPosition::Left);
    }

    #[test]
    fn test_mv_predictor_list_add() {
        let list = list_of(&[
            (10, 20, 2, NeighborPosition::Left),
            (30, 40, 3, NeighborPosition::Top),
        ]);

        assert_eq!(list.len(), 2);
        assert_eq!(candidate_at(&list, 0).mv.dx, 10);
        assert_eq!(candidate_at(&list, 1).mv.dx, 30);
    }

    #[test]
    fn test_mv_predictor_list_dedup() {
        // Same vector twice: the entries merge and the higher weight wins.
        let list = list_of(&[
            (10, 20, 2, NeighborPosition::Left),
            (10, 20, 3, NeighborPosition::Top),
        ]);

        assert_eq!(list.len(), 1);
        assert_eq!(candidate_at(&list, 0).weight, 3);
    }

    #[test]
    fn test_mv_predictor_list_sort() {
        let mut list = list_of(&[
            (10, 20, 1, NeighborPosition::Left),
            (30, 40, 3, NeighborPosition::Top),
            (50, 60, 2, NeighborPosition::TopRight),
        ]);

        list.sort_by_weight();

        for (index, expected_weight) in [3, 2, 1].into_iter().enumerate() {
            assert_eq!(candidate_at(&list, index).weight, expected_weight);
        }
    }

    #[test]
    fn test_spatial_predictor_median() {
        let median = SpatialPredictor::median(mv(10, 30), mv(20, 20), mv(30, 10));

        assert_eq!(median.dx, 20);
        assert_eq!(median.dy, 20);
    }

    #[test]
    fn test_spatial_predictor_build() {
        let ctx = MvPredContext::new()
            .with_left(inter(10, 10))
            .with_top(inter(20, 20));

        let mut list = MvPredictorList::new();
        SpatialPredictor::build_predictors(&ctx, &mut list);

        assert!(list.len() >= 2);
    }

    #[test]
    fn test_temporal_predictor_scale() {
        let base = mv(100, 200);

        // (src_dist, dst_dist, expected dx, expected dy): identity, doubled, halved.
        let cases = [(1, 1, 100, 200), (1, 2, 200, 400), (2, 1, 50, 100)];
        for (src_dist, dst_dist, dx, dy) in cases {
            let scaled = TemporalPredictor::scale_mv(base, src_dist, dst_dist);
            assert_eq!(scaled.dx, dx);
            assert_eq!(scaled.dy, dy);
        }
    }

    #[test]
    fn test_mv_predictor_build() {
        let ctx = MvPredContext::new()
            .with_left(inter(10, 10))
            .with_top(inter(20, 20))
            .with_co_located(inter(15, 15));

        let mut predictor = MvPredictor::new();
        predictor.build(&ctx);

        let candidate_count = predictor.all_predictors().len();
        assert!(candidate_count >= 3);

        let mvp = predictor.best_mvp();
        assert!(mvp.dx != 0 || mvp.dy != 0 || predictor.all_predictors().len() == 1);
    }

    #[test]
    fn test_mv_cost_calculator() {
        let calc = MvCostCalculator::new(1.0);
        let origin = MotionVector::zero();
        let far = mv(16, 16);

        let far_cost = calc.cost(&far, &origin);
        assert!(far_cost > 0);

        // A vector equal to its predictor must be cheaper than a distant one.
        let zero_diff_cost = calc.cost(&origin, &origin);
        assert!(zero_diff_cost < far_cost);
    }

    #[test]
    fn test_mv_cost_rd() {
        let calc = MvCostCalculator::new(1.0);
        let origin = MotionVector::zero();

        let rd = calc.rd_cost(&mv(16, 16), &origin, 100);
        assert!(rd > 100);
    }

    #[test]
    fn test_mvp_mode() {
        let ctx = MvPredContext::new().with_left(inter(10, 10));

        assert_eq!(MvpMode::Left.get_mvp(&ctx).dx, 10);
        assert_eq!(MvpMode::Zero.get_mvp(&ctx).dx, 0);
    }

    #[test]
    fn test_motion_vectors_extraction() {
        let list = list_of(&[
            (10, 20, 2, NeighborPosition::Left),
            (30, 40, 1, NeighborPosition::Top),
        ]);

        let mvs = list.motion_vectors();
        assert_eq!(mvs.len(), 2);
        assert_eq!(mvs[0].dx, 10);
        assert_eq!(mvs[1].dx, 30);
    }
}