// Source file: oximedia_codec/motion/predictor.rs

//! Motion vector prediction for video encoding.
//!
//! This module provides:
//! - Spatial predictors (left, top, top-right, top-left)
//! - Temporal predictors (co-located block from a reference frame)
//! - MVP (motion vector predictor) selection
//! - MV cost calculation using lambda-based rate-distortion (RD) optimization
//!
//! Good MV prediction reduces the bits needed to encode motion vectors
//! and provides better starting points for motion search.

12#![forbid(unsafe_code)]
13#![allow(dead_code)]
14#![allow(clippy::too_many_arguments)]
15#![allow(clippy::must_use_candidate)]
16#![allow(clippy::cast_sign_loss)]
17#![allow(clippy::cast_precision_loss)]
18#![allow(clippy::cast_possible_truncation)]
19#![allow(clippy::cast_possible_wrap)]
20#![allow(clippy::needless_range_loop)]
21#![allow(clippy::unused_self)]
22#![allow(clippy::redundant_closure_for_method_calls)]
23#![allow(clippy::trivially_copy_pass_by_ref)]
24
25use super::types::{BlockSize, MotionVector, MvCost};
26
/// Capacity of an [`MvPredictorList`]: candidates beyond this count are dropped.
pub const MAX_PREDICTORS: usize = 8;

/// Priority given to candidates taken from the left, top and top-right neighbors.
pub const SPATIAL_WEIGHT: u32 = 2;

/// Priority given to the co-located (temporal) candidate.
pub const TEMPORAL_WEIGHT: u32 = 1;

/// Identifies where a predictor candidate was taken from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NeighborPosition {
    /// Block immediately to the left.
    Left,
    /// Block immediately above.
    Top,
    /// Block above and to the right.
    TopRight,
    /// Block above and to the left.
    TopLeft,
    /// Block at the same position in the reference frame.
    CoLocated,
    /// Block below and to the left.
    BelowLeft,
    /// Synthesized candidate (component-wise median or zero fallback).
    Median,
}

55/// Information about a neighboring block.
56#[derive(Clone, Copy, Debug, Default)]
57pub struct NeighborInfo {
58    /// Motion vector.
59    pub mv: MotionVector,
60    /// Reference frame index.
61    pub ref_idx: i8,
62    /// Is this neighbor available?
63    pub available: bool,
64    /// Is this an inter-predicted block?
65    pub is_inter: bool,
66}
67
68impl NeighborInfo {
69    /// Creates unavailable neighbor info.
70    #[must_use]
71    pub const fn unavailable() -> Self {
72        Self {
73            mv: MotionVector::zero(),
74            ref_idx: -1,
75            available: false,
76            is_inter: false,
77        }
78    }
79
80    /// Creates neighbor info with MV.
81    #[must_use]
82    pub const fn with_mv(mv: MotionVector, ref_idx: i8) -> Self {
83        Self {
84            mv,
85            ref_idx,
86            available: true,
87            is_inter: true,
88        }
89    }
90
91    /// Creates intra neighbor (no MV).
92    #[must_use]
93    pub const fn intra() -> Self {
94        Self {
95            mv: MotionVector::zero(),
96            ref_idx: -1,
97            available: true,
98            is_inter: false,
99        }
100    }
101}
102
103/// Context for motion vector prediction.
104#[derive(Clone, Debug, Default)]
105pub struct MvPredContext {
106    /// Left neighbor.
107    pub left: NeighborInfo,
108    /// Top neighbor.
109    pub top: NeighborInfo,
110    /// Top-right neighbor.
111    pub top_right: NeighborInfo,
112    /// Top-left neighbor.
113    pub top_left: NeighborInfo,
114    /// Co-located block in reference.
115    pub co_located: NeighborInfo,
116    /// Current reference frame index.
117    pub ref_idx: i8,
118    /// Block position (x in 4x4 units).
119    pub mi_col: usize,
120    /// Block position (y in 4x4 units).
121    pub mi_row: usize,
122    /// Block size.
123    pub block_size: BlockSize,
124}
125
126impl MvPredContext {
127    /// Creates a new prediction context.
128    #[must_use]
129    pub const fn new() -> Self {
130        Self {
131            left: NeighborInfo::unavailable(),
132            top: NeighborInfo::unavailable(),
133            top_right: NeighborInfo::unavailable(),
134            top_left: NeighborInfo::unavailable(),
135            co_located: NeighborInfo::unavailable(),
136            ref_idx: 0,
137            mi_col: 0,
138            mi_row: 0,
139            block_size: BlockSize::Block8x8,
140        }
141    }
142
143    /// Sets the block position.
144    #[must_use]
145    pub const fn at_position(mut self, mi_row: usize, mi_col: usize) -> Self {
146        self.mi_row = mi_row;
147        self.mi_col = mi_col;
148        self
149    }
150
151    /// Sets the block size.
152    #[must_use]
153    pub const fn with_size(mut self, size: BlockSize) -> Self {
154        self.block_size = size;
155        self
156    }
157
158    /// Sets the reference frame index.
159    #[must_use]
160    pub const fn with_ref(mut self, ref_idx: i8) -> Self {
161        self.ref_idx = ref_idx;
162        self
163    }
164
165    /// Sets the left neighbor.
166    #[must_use]
167    pub const fn with_left(mut self, info: NeighborInfo) -> Self {
168        self.left = info;
169        self
170    }
171
172    /// Sets the top neighbor.
173    #[must_use]
174    pub const fn with_top(mut self, info: NeighborInfo) -> Self {
175        self.top = info;
176        self
177    }
178
179    /// Sets the top-right neighbor.
180    #[must_use]
181    pub const fn with_top_right(mut self, info: NeighborInfo) -> Self {
182        self.top_right = info;
183        self
184    }
185
186    /// Sets the top-left neighbor.
187    #[must_use]
188    pub const fn with_top_left(mut self, info: NeighborInfo) -> Self {
189        self.top_left = info;
190        self
191    }
192
193    /// Sets the co-located block.
194    #[must_use]
195    pub const fn with_co_located(mut self, info: NeighborInfo) -> Self {
196        self.co_located = info;
197        self
198    }
199}
200
201/// MV predictor candidate.
202#[derive(Clone, Copy, Debug)]
203pub struct MvCandidate {
204    /// Predicted motion vector.
205    pub mv: MotionVector,
206    /// Weight/priority of this candidate.
207    pub weight: u32,
208    /// Source position.
209    pub source: NeighborPosition,
210}
211
212impl MvCandidate {
213    /// Creates a new MV candidate.
214    #[must_use]
215    pub const fn new(mv: MotionVector, weight: u32, source: NeighborPosition) -> Self {
216        Self { mv, weight, source }
217    }
218
219    /// Creates a zero MV candidate.
220    #[must_use]
221    pub const fn zero() -> Self {
222        Self {
223            mv: MotionVector::zero(),
224            weight: 0,
225            source: NeighborPosition::Median,
226        }
227    }
228}
229
230/// Motion vector predictor list.
231#[derive(Clone, Debug)]
232pub struct MvPredictorList {
233    /// Candidate predictors.
234    candidates: [MvCandidate; MAX_PREDICTORS],
235    /// Number of valid candidates.
236    count: usize,
237}
238
239impl Default for MvPredictorList {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245impl MvPredictorList {
246    /// Creates a new empty predictor list.
247    #[must_use]
248    pub fn new() -> Self {
249        Self {
250            candidates: [MvCandidate::zero(); MAX_PREDICTORS],
251            count: 0,
252        }
253    }
254
255    /// Adds a candidate to the list.
256    pub fn add(&mut self, candidate: MvCandidate) {
257        if self.count < MAX_PREDICTORS {
258            // Check for duplicates
259            for i in 0..self.count {
260                if self.candidates[i].mv == candidate.mv {
261                    // Update weight if higher
262                    if candidate.weight > self.candidates[i].weight {
263                        self.candidates[i].weight = candidate.weight;
264                    }
265                    return;
266                }
267            }
268            self.candidates[self.count] = candidate;
269            self.count += 1;
270        }
271    }
272
273    /// Adds a candidate from neighbor info.
274    pub fn add_from_neighbor(&mut self, info: &NeighborInfo, source: NeighborPosition) {
275        if info.available && info.is_inter {
276            let weight = match source {
277                NeighborPosition::Left | NeighborPosition::Top | NeighborPosition::TopRight => {
278                    SPATIAL_WEIGHT
279                }
280                NeighborPosition::CoLocated => TEMPORAL_WEIGHT,
281                _ => 1,
282            };
283            self.add(MvCandidate::new(info.mv, weight, source));
284        }
285    }
286
287    /// Sorts candidates by weight (descending).
288    pub fn sort_by_weight(&mut self) {
289        // Simple insertion sort for small array
290        for i in 1..self.count {
291            let key = self.candidates[i];
292            let mut j = i;
293            while j > 0 && self.candidates[j - 1].weight < key.weight {
294                self.candidates[j] = self.candidates[j - 1];
295                j -= 1;
296            }
297            self.candidates[j] = key;
298        }
299    }
300
301    /// Returns the number of candidates.
302    #[must_use]
303    pub const fn len(&self) -> usize {
304        self.count
305    }
306
307    /// Returns true if empty.
308    #[must_use]
309    pub const fn is_empty(&self) -> bool {
310        self.count == 0
311    }
312
313    /// Gets a candidate by index.
314    #[must_use]
315    pub const fn get(&self, index: usize) -> Option<&MvCandidate> {
316        if index < self.count {
317            Some(&self.candidates[index])
318        } else {
319            None
320        }
321    }
322
323    /// Returns the best (first) predictor.
324    #[must_use]
325    pub fn best(&self) -> MotionVector {
326        if self.count > 0 {
327            self.candidates[0].mv
328        } else {
329            MotionVector::zero()
330        }
331    }
332
333    /// Returns all predictors as a slice.
334    #[must_use]
335    pub fn as_slice(&self) -> &[MvCandidate] {
336        &self.candidates[..self.count]
337    }
338
339    /// Extracts motion vectors only.
340    pub fn motion_vectors(&self) -> Vec<MotionVector> {
341        self.candidates[..self.count].iter().map(|c| c.mv).collect()
342    }
343}
344
345/// Spatial MV predictor calculator.
346#[derive(Clone, Copy, Debug, Default)]
347pub struct SpatialPredictor;
348
349impl SpatialPredictor {
350    /// Creates a new spatial predictor.
351    #[must_use]
352    pub const fn new() -> Self {
353        Self
354    }
355
356    /// Gets the left neighbor MV.
357    #[must_use]
358    pub fn get_left(ctx: &MvPredContext) -> Option<MotionVector> {
359        if ctx.left.available && ctx.left.is_inter {
360            Some(ctx.left.mv)
361        } else {
362            None
363        }
364    }
365
366    /// Gets the top neighbor MV.
367    #[must_use]
368    pub fn get_top(ctx: &MvPredContext) -> Option<MotionVector> {
369        if ctx.top.available && ctx.top.is_inter {
370            Some(ctx.top.mv)
371        } else {
372            None
373        }
374    }
375
376    /// Gets the top-right neighbor MV.
377    #[must_use]
378    pub fn get_top_right(ctx: &MvPredContext) -> Option<MotionVector> {
379        if ctx.top_right.available && ctx.top_right.is_inter {
380            Some(ctx.top_right.mv)
381        } else {
382            None
383        }
384    }
385
386    /// Gets the top-left neighbor MV.
387    #[must_use]
388    pub fn get_top_left(ctx: &MvPredContext) -> Option<MotionVector> {
389        if ctx.top_left.available && ctx.top_left.is_inter {
390            Some(ctx.top_left.mv)
391        } else {
392            None
393        }
394    }
395
396    /// Calculates the median predictor from three MVs.
397    #[must_use]
398    pub fn median(a: MotionVector, b: MotionVector, c: MotionVector) -> MotionVector {
399        // Component-wise median
400        let dx = Self::median_of_3(a.dx, b.dx, c.dx);
401        let dy = Self::median_of_3(a.dy, b.dy, c.dy);
402        MotionVector::new(dx, dy)
403    }
404
405    /// Median of three values.
406    #[must_use]
407    fn median_of_3(a: i32, b: i32, c: i32) -> i32 {
408        a.max(b.min(c)).min(b.max(c))
409    }
410
411    /// Calculates the median MVP from spatial neighbors.
412    #[must_use]
413    pub fn calculate_median(ctx: &MvPredContext) -> MotionVector {
414        let left = Self::get_left(ctx).unwrap_or_else(MotionVector::zero);
415        let top = Self::get_top(ctx).unwrap_or_else(MotionVector::zero);
416        let top_right = Self::get_top_right(ctx)
417            .or_else(|| Self::get_top_left(ctx))
418            .unwrap_or_else(MotionVector::zero);
419
420        Self::median(left, top, top_right)
421    }
422
423    /// Builds spatial predictor list.
424    pub fn build_predictors(ctx: &MvPredContext, list: &mut MvPredictorList) {
425        list.add_from_neighbor(&ctx.left, NeighborPosition::Left);
426        list.add_from_neighbor(&ctx.top, NeighborPosition::Top);
427        list.add_from_neighbor(&ctx.top_right, NeighborPosition::TopRight);
428        list.add_from_neighbor(&ctx.top_left, NeighborPosition::TopLeft);
429
430        // Add median if we have multiple neighbors
431        let mut neighbor_count = 0;
432        if ctx.left.available && ctx.left.is_inter {
433            neighbor_count += 1;
434        }
435        if ctx.top.available && ctx.top.is_inter {
436            neighbor_count += 1;
437        }
438        if ctx.top_right.available && ctx.top_right.is_inter {
439            neighbor_count += 1;
440        }
441
442        if neighbor_count >= 2 {
443            let median = Self::calculate_median(ctx);
444            list.add(MvCandidate::new(
445                median,
446                SPATIAL_WEIGHT + 1,
447                NeighborPosition::Median,
448            ));
449        }
450    }
451}
452
453/// Temporal MV predictor calculator.
454#[derive(Clone, Copy, Debug, Default)]
455pub struct TemporalPredictor;
456
457impl TemporalPredictor {
458    /// Creates a new temporal predictor.
459    #[must_use]
460    pub const fn new() -> Self {
461        Self
462    }
463
464    /// Gets the co-located MV from reference frame.
465    #[must_use]
466    pub fn get_co_located(ctx: &MvPredContext) -> Option<MotionVector> {
467        if ctx.co_located.available && ctx.co_located.is_inter {
468            Some(ctx.co_located.mv)
469        } else {
470            None
471        }
472    }
473
474    /// Scales MV for different temporal distances.
475    ///
476    /// If the co-located block references a frame at distance `src_dist`,
477    /// and we want to predict for target at distance `dst_dist`,
478    /// scale the MV proportionally.
479    #[must_use]
480    #[allow(clippy::cast_possible_truncation)]
481    pub fn scale_mv(mv: MotionVector, src_dist: i32, dst_dist: i32) -> MotionVector {
482        if src_dist == 0 || src_dist == dst_dist {
483            return mv;
484        }
485
486        let scale_x = (i64::from(mv.dx) * i64::from(dst_dist)) / i64::from(src_dist);
487        let scale_y = (i64::from(mv.dy) * i64::from(dst_dist)) / i64::from(src_dist);
488
489        MotionVector::new(scale_x as i32, scale_y as i32)
490    }
491
492    /// Builds temporal predictor.
493    pub fn build_predictors(ctx: &MvPredContext, list: &mut MvPredictorList) {
494        list.add_from_neighbor(&ctx.co_located, NeighborPosition::CoLocated);
495    }
496}
497
498/// Combined MV predictor that uses both spatial and temporal information.
499#[derive(Clone, Debug, Default)]
500pub struct MvPredictor {
501    /// Spatial predictor.
502    spatial: SpatialPredictor,
503    /// Temporal predictor.
504    temporal: TemporalPredictor,
505    /// Predictor list.
506    predictors: MvPredictorList,
507}
508
509impl MvPredictor {
510    /// Creates a new MV predictor.
511    #[must_use]
512    pub fn new() -> Self {
513        Self {
514            spatial: SpatialPredictor::new(),
515            temporal: TemporalPredictor::new(),
516            predictors: MvPredictorList::new(),
517        }
518    }
519
520    /// Builds all predictors from context.
521    pub fn build(&mut self, ctx: &MvPredContext) {
522        self.predictors = MvPredictorList::new();
523
524        // Always add zero MV as fallback
525        self.predictors.add(MvCandidate::new(
526            MotionVector::zero(),
527            1,
528            NeighborPosition::Median,
529        ));
530
531        // Add spatial predictors
532        SpatialPredictor::build_predictors(ctx, &mut self.predictors);
533
534        // Add temporal predictors
535        TemporalPredictor::build_predictors(ctx, &mut self.predictors);
536
537        // Sort by weight
538        self.predictors.sort_by_weight();
539    }
540
541    /// Returns the best MVP.
542    #[must_use]
543    pub fn best_mvp(&self) -> MotionVector {
544        self.predictors.best()
545    }
546
547    /// Returns all predictors.
548    #[must_use]
549    pub fn all_predictors(&self) -> &MvPredictorList {
550        &self.predictors
551    }
552
553    /// Returns predictors as motion vectors.
554    pub fn motion_vectors(&self) -> Vec<MotionVector> {
555        self.predictors.motion_vectors()
556    }
557}
558
559/// MV cost calculator for rate-distortion optimization.
560#[derive(Clone, Debug)]
561pub struct MvCostCalculator {
562    /// Lambda for RD tradeoff.
563    lambda: f32,
564    /// MV cost tables (optional precomputed).
565    cost_table: Option<Vec<u32>>,
566}
567
568impl Default for MvCostCalculator {
569    fn default() -> Self {
570        Self::new(1.0)
571    }
572}
573
574impl MvCostCalculator {
575    /// Creates a new cost calculator.
576    #[must_use]
577    pub const fn new(lambda: f32) -> Self {
578        Self {
579            lambda,
580            cost_table: None,
581        }
582    }
583
584    /// Builds cost table for fast lookup.
585    pub fn build_cost_table(&mut self, max_mv: i32) {
586        let size = (2 * max_mv + 1) as usize;
587        let mut table = vec![0u32; size];
588
589        for i in 0..size {
590            let mv_component = i as i32 - max_mv;
591            table[i] = self.component_cost(mv_component);
592        }
593
594        self.cost_table = Some(table);
595    }
596
597    /// Calculates cost for a single MV component.
598    #[must_use]
599    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
600    fn component_cost(&self, value: i32) -> u32 {
601        if value == 0 {
602            return (self.lambda * 1.0) as u32;
603        }
604
605        let abs_val = value.unsigned_abs();
606        // Approximate bit cost: log2(abs) * 2 + constant
607        let log2 = 32 - abs_val.leading_zeros();
608        let bits = log2 * 2 + 2;
609        ((bits as f32) * self.lambda) as u32
610    }
611
612    /// Calculates the bit cost of an MV differential.
613    #[must_use]
614    pub fn cost(&self, mv: &MotionVector, mvp: &MotionVector) -> u32 {
615        let diff = *mv - *mvp;
616        self.component_cost(diff.dx) + self.component_cost(diff.dy)
617    }
618
619    /// Calculates full RD cost (distortion + rate).
620    #[must_use]
621    pub fn rd_cost(&self, mv: &MotionVector, mvp: &MotionVector, distortion: u32) -> u32 {
622        distortion.saturating_add(self.cost(mv, mvp))
623    }
624
625    /// Creates an MvCost instance from this calculator.
626    #[must_use]
627    pub fn to_mv_cost(&self, mvp: MotionVector) -> MvCost {
628        MvCost::with_ref_mv(self.lambda, mvp)
629    }
630}
631
/// Strategy for choosing a single motion vector predictor.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum MvpMode {
    /// Component-wise median of the spatial neighbors (the default).
    #[default]
    Median,
    /// MV of the left neighbor.
    Left,
    /// MV of the top neighbor.
    Top,
    /// MV of the co-located block in the reference frame.
    Temporal,
    /// Always the zero vector.
    Zero,
}

648impl MvpMode {
649    /// Gets the MVP for this mode.
650    #[must_use]
651    pub fn get_mvp(&self, ctx: &MvPredContext) -> MotionVector {
652        match self {
653            Self::Median => SpatialPredictor::calculate_median(ctx),
654            Self::Left => SpatialPredictor::get_left(ctx).unwrap_or_else(MotionVector::zero),
655            Self::Top => SpatialPredictor::get_top(ctx).unwrap_or_else(MotionVector::zero),
656            Self::Temporal => {
657                TemporalPredictor::get_co_located(ctx).unwrap_or_else(MotionVector::zero)
658            }
659            Self::Zero => MotionVector::zero(),
660        }
661    }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neighbor_info_unavailable() {
        let info = NeighborInfo::unavailable();
        assert!(!info.available);
        assert!(!info.is_inter);
    }

    #[test]
    fn test_neighbor_info_with_mv() {
        let mv = MotionVector::new(10, 20);
        let info = NeighborInfo::with_mv(mv, 0);
        assert!(info.available);
        assert!(info.is_inter);
        assert_eq!(info.mv.dx, 10);
        assert_eq!(info.mv.dy, 20);
    }

    #[test]
    fn test_mv_pred_context_builder() {
        let left = NeighborInfo::with_mv(MotionVector::new(5, 5), 0);
        let ctx = MvPredContext::new()
            .at_position(10, 20)
            .with_size(BlockSize::Block16x16)
            .with_ref(0)
            .with_left(left);

        assert_eq!(ctx.mi_row, 10);
        assert_eq!(ctx.mi_col, 20);
        assert_eq!(ctx.block_size, BlockSize::Block16x16);
        assert!(ctx.left.available);
    }

    #[test]
    fn test_mv_candidate() {
        let mv = MotionVector::new(10, 20);
        let candidate = MvCandidate::new(mv, 5, NeighborPosition::Left);

        assert_eq!(candidate.mv.dx, 10);
        assert_eq!(candidate.weight, 5);
        assert_eq!(candidate.source, NeighborPosition::Left);
    }

    #[test]
    fn test_mv_predictor_list_add() {
        let mut list = MvPredictorList::new();

        list.add(MvCandidate::new(
            MotionVector::new(10, 20),
            2,
            NeighborPosition::Left,
        ));
        list.add(MvCandidate::new(
            MotionVector::new(30, 40),
            3,
            NeighborPosition::Top,
        ));

        assert_eq!(list.len(), 2);
        assert_eq!(list.get(0).expect("get should return value").mv.dx, 10);
        assert_eq!(list.get(1).expect("get should return value").mv.dx, 30);
        assert!(list.get(2).is_none());
    }

    #[test]
    fn test_mv_predictor_list_dedup() {
        let mut list = MvPredictorList::new();

        list.add(MvCandidate::new(
            MotionVector::new(10, 20),
            2,
            NeighborPosition::Left,
        ));
        list.add(MvCandidate::new(
            MotionVector::new(10, 20),
            3,
            NeighborPosition::Top,
        ));

        // Duplicates merge; the higher weight is kept.
        assert_eq!(list.len(), 1);
        assert_eq!(list.get(0).expect("get should return value").weight, 3);
    }

    #[test]
    fn test_mv_predictor_list_capacity() {
        let mut list = MvPredictorList::new();
        let extra = MAX_PREDICTORS as i32 + 3;
        for i in 0..extra {
            list.add(MvCandidate::new(
                MotionVector::new(i, 0),
                1,
                NeighborPosition::Left,
            ));
        }

        // Distinct candidates beyond capacity are dropped; the first ones stay.
        assert_eq!(list.len(), MAX_PREDICTORS);
        let last = list.as_slice().last().expect("list is full");
        assert_eq!(last.mv.dx, MAX_PREDICTORS as i32 - 1);
    }

    #[test]
    fn test_add_from_neighbor_filters_and_weights() {
        let mut list = MvPredictorList::new();
        list.add_from_neighbor(&NeighborInfo::intra(), NeighborPosition::Left);
        list.add_from_neighbor(&NeighborInfo::unavailable(), NeighborPosition::Top);
        assert!(list.is_empty());
        assert_eq!(list.best().dx, 0);

        list.add_from_neighbor(
            &NeighborInfo::with_mv(MotionVector::new(4, 4), 0),
            NeighborPosition::CoLocated,
        );
        list.add_from_neighbor(
            &NeighborInfo::with_mv(MotionVector::new(8, 8), 0),
            NeighborPosition::TopRight,
        );
        assert_eq!(list.get(0).expect("co-located added").weight, TEMPORAL_WEIGHT);
        assert_eq!(list.get(1).expect("top-right added").weight, SPATIAL_WEIGHT);
    }

    #[test]
    fn test_mv_predictor_list_sort() {
        let mut list = MvPredictorList::new();

        list.add(MvCandidate::new(
            MotionVector::new(10, 20),
            1,
            NeighborPosition::Left,
        ));
        list.add(MvCandidate::new(
            MotionVector::new(30, 40),
            3,
            NeighborPosition::Top,
        ));
        list.add(MvCandidate::new(
            MotionVector::new(50, 60),
            2,
            NeighborPosition::TopRight,
        ));

        list.sort_by_weight();

        assert_eq!(list.get(0).expect("get should return value").weight, 3);
        assert_eq!(list.get(1).expect("get should return value").weight, 2);
        assert_eq!(list.get(2).expect("get should return value").weight, 1);
    }

    #[test]
    fn test_mv_predictor_list_sort_is_stable() {
        let mut list = MvPredictorList::new();
        for (dx, weight) in [(1, 1), (2, 2), (3, 1), (4, 2)] {
            list.add(MvCandidate::new(
                MotionVector::new(dx, 0),
                weight,
                NeighborPosition::Left,
            ));
        }

        list.sort_by_weight();

        let order: Vec<i32> = list.as_slice().iter().map(|c| c.mv.dx).collect();
        assert_eq!(order, vec![2, 4, 1, 3]);
    }

    #[test]
    fn test_spatial_predictor_median() {
        let a = MotionVector::new(10, 30);
        let b = MotionVector::new(20, 20);
        let c = MotionVector::new(30, 10);

        let median = SpatialPredictor::median(a, b, c);

        assert_eq!(median.dx, 20); // Median of 10, 20, 30
        assert_eq!(median.dy, 20); // Median of 30, 20, 10
    }

    #[test]
    fn test_spatial_median_is_order_independent() {
        let perms = [
            [10, 20, 30],
            [10, 30, 20],
            [20, 10, 30],
            [20, 30, 10],
            [30, 10, 20],
            [30, 20, 10],
        ];
        for [a, b, c] in perms {
            let median = SpatialPredictor::median(
                MotionVector::new(a, -a),
                MotionVector::new(b, -b),
                MotionVector::new(c, -c),
            );
            assert_eq!(median.dx, 20);
            assert_eq!(median.dy, -20);
        }

        // Repeated values.
        let repeated = SpatialPredictor::median(
            MotionVector::new(5, 1),
            MotionVector::new(5, 1),
            MotionVector::new(1, 5),
        );
        assert_eq!(repeated.dx, 5);
        assert_eq!(repeated.dy, 1);
    }

    #[test]
    fn test_spatial_predictor_build() {
        let ctx = MvPredContext::new()
            .with_left(NeighborInfo::with_mv(MotionVector::new(10, 10), 0))
            .with_top(NeighborInfo::with_mv(MotionVector::new(20, 20), 0));

        let mut list = MvPredictorList::new();
        SpatialPredictor::build_predictors(&ctx, &mut list);

        assert!(list.len() >= 2);
    }

    #[test]
    fn test_temporal_predictor_scale() {
        let mv = MotionVector::new(100, 200);

        // Same distance - no change
        let same = TemporalPredictor::scale_mv(mv, 1, 1);
        assert_eq!(same.dx, 100);
        assert_eq!(same.dy, 200);

        // Double distance
        let doubled = TemporalPredictor::scale_mv(mv, 1, 2);
        assert_eq!(doubled.dx, 200);
        assert_eq!(doubled.dy, 400);

        // Half distance
        let halved = TemporalPredictor::scale_mv(mv, 2, 1);
        assert_eq!(halved.dx, 50);
        assert_eq!(halved.dy, 100);
    }

    #[test]
    fn test_temporal_predictor_scale_sign_and_zero_distance() {
        let mv = MotionVector::new(-100, 50);

        let halved = TemporalPredictor::scale_mv(mv, 2, 1);
        assert_eq!((halved.dx, halved.dy), (-50, 25));

        // Opposite temporal direction flips the vector.
        let reversed = TemporalPredictor::scale_mv(mv, 1, -1);
        assert_eq!((reversed.dx, reversed.dy), (100, -50));

        // Zero source distance leaves the MV untouched.
        let unchanged = TemporalPredictor::scale_mv(mv, 0, 3);
        assert_eq!((unchanged.dx, unchanged.dy), (-100, 50));
    }

    #[test]
    fn test_mv_predictor_build() {
        let ctx = MvPredContext::new()
            .with_left(NeighborInfo::with_mv(MotionVector::new(10, 10), 0))
            .with_top(NeighborInfo::with_mv(MotionVector::new(20, 20), 0))
            .with_co_located(NeighborInfo::with_mv(MotionVector::new(15, 15), 0));

        let mut predictor = MvPredictor::new();
        predictor.build(&ctx);

        // The candidates are: zero (w1), left (w2), top (w2), and the median of
        // (left, top, zero) = (10, 10), which merges into left and raises it to
        // w3. Co-located (w1) comes last. A stable sort keeps ties in insertion
        // order.
        let list = predictor.all_predictors();
        assert_eq!(list.len(), 4);

        let weights: Vec<u32> = list.as_slice().iter().map(|c| c.weight).collect();
        assert_eq!(weights, vec![3, 2, 1, 1]);

        let coords: Vec<(i32, i32)> = predictor
            .motion_vectors()
            .iter()
            .map(|mv| (mv.dx, mv.dy))
            .collect();
        assert_eq!(coords, vec![(10, 10), (20, 20), (0, 0), (15, 15)]);

        let best = predictor.best_mvp();
        assert_eq!((best.dx, best.dy), (10, 10));
    }

    #[test]
    fn test_mv_cost_calculator() {
        let calc = MvCostCalculator::new(1.0);
        let mv = MotionVector::new(16, 16);
        let mvp = MotionVector::zero();

        let cost = calc.cost(&mv, &mvp);
        assert!(cost > 0);

        // Same MV as predictor should have lower cost
        let same_cost = calc.cost(&mvp, &mvp);
        assert!(same_cost < cost);
    }

    #[test]
    fn test_mv_cost_table_matches_direct_cost() {
        let plain = MvCostCalculator::new(4.0);
        let mut tabled = MvCostCalculator::new(4.0);
        tabled.build_cost_table(64);

        let mvp = MotionVector::new(3, -2);
        // Differences inside and outside the table range.
        for (dx, dy) in [(0, 0), (1, -1), (64, -64), (100, 7), (-200, 0)] {
            let mv = MotionVector::new(dx, dy);
            assert_eq!(plain.cost(&mv, &mvp), tabled.cost(&mv, &mvp));
        }
    }

    #[test]
    fn test_mv_cost_rd() {
        let calc = MvCostCalculator::new(1.0);
        let mv = MotionVector::new(16, 16);
        let mvp = MotionVector::zero();

        let rd = calc.rd_cost(&mv, &mvp, 100);
        assert!(rd > 100); // Should include MV cost
    }

    #[test]
    fn test_mvp_mode() {
        let ctx =
            MvPredContext::new().with_left(NeighborInfo::with_mv(MotionVector::new(10, 10), 0));

        assert_eq!(MvpMode::Left.get_mvp(&ctx).dx, 10);
        assert_eq!(MvpMode::Zero.get_mvp(&ctx).dx, 0);
    }

    #[test]
    fn test_mvp_mode_median_and_temporal_fallback() {
        let ctx = MvPredContext::new()
            .with_left(NeighborInfo::with_mv(MotionVector::new(10, 10), 0))
            .with_top(NeighborInfo::with_mv(MotionVector::new(20, 20), 0))
            .with_top_right(NeighborInfo::with_mv(MotionVector::new(30, 30), 0));

        assert_eq!(MvpMode::Median.get_mvp(&ctx).dx, 20);
        // No co-located block is set, so Temporal falls back to zero.
        assert_eq!(MvpMode::Temporal.get_mvp(&ctx).dx, 0);
        assert_eq!(MvpMode::default(), MvpMode::Median);
    }

    #[test]
    fn test_motion_vectors_extraction() {
        let mut list = MvPredictorList::new();
        list.add(MvCandidate::new(
            MotionVector::new(10, 20),
            2,
            NeighborPosition::Left,
        ));
        list.add(MvCandidate::new(
            MotionVector::new(30, 40),
            1,
            NeighborPosition::Top,
        ));

        let mvs = list.motion_vectors();
        assert_eq!(mvs.len(), 2);
        assert_eq!(mvs[0].dx, 10);
        assert_eq!(mvs[1].dx, 30);
    }
}