Skip to main content

oximedia_codec/motion/
search.rs

1//! Motion search algorithms for video encoding.
2//!
3//! This module provides various motion estimation search algorithms including:
4//! - Full search (exhaustive)
5//! - Diamond search (SDSP/LDSP)
6//! - Hexagon search
7//! - UMH search (Unsymmetrical-cross Multi-Hexagon)
8//!
9//! All algorithms implement the [`MotionSearch`] trait for consistent interface.
10
11#![forbid(unsafe_code)]
12#![allow(dead_code)]
13#![allow(clippy::too_many_arguments)]
14#![allow(clippy::cast_sign_loss)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_possible_wrap)]
17#![allow(clippy::must_use_candidate)]
18#![allow(clippy::trivially_copy_pass_by_ref)]
19#![allow(clippy::unused_self)]
20#![allow(clippy::items_after_statements)]
21#![allow(unused_assignments)]
22
23use super::types::{BlockMatch, BlockSize, MotionVector, MvCost, MvPrecision, SearchRange};
24
/// Default SAD value below which a search may stop early.
///
/// Used as `SearchConfig::default().early_threshold`.
pub const EARLY_TERMINATION_SAD: u32 = 256_u32;

/// Minimum SAD improvement ratio for early termination.
///
/// NOTE(review): none of the searches in this file read this value — confirm whether
/// callers elsewhere depend on it.
pub const EARLY_TERMINATION_RATIO: f32 = 0.9_f32;
30
31/// Configuration for motion search algorithms.
32#[derive(Clone, Debug)]
33pub struct SearchConfig {
34    /// Search range in full pixels.
35    pub range: SearchRange,
36    /// Motion vector precision.
37    pub precision: MvPrecision,
38    /// Enable early termination.
39    pub early_termination: bool,
40    /// Early termination threshold.
41    pub early_threshold: u32,
42    /// MV cost calculator for RD optimization.
43    pub mv_cost: MvCost,
44    /// Maximum iterations for iterative algorithms.
45    pub max_iterations: u32,
46    /// Enable sub-pixel refinement.
47    pub subpel_refine: bool,
48}
49
50impl Default for SearchConfig {
51    fn default() -> Self {
52        Self {
53            range: SearchRange::default(),
54            precision: MvPrecision::QuarterPel,
55            early_termination: true,
56            early_threshold: EARLY_TERMINATION_SAD,
57            mv_cost: MvCost::default(),
58            max_iterations: 16,
59            subpel_refine: true,
60        }
61    }
62}
63
64impl SearchConfig {
65    /// Creates a new search config with the given range.
66    #[must_use]
67    pub fn with_range(range: SearchRange) -> Self {
68        Self {
69            range,
70            ..Default::default()
71        }
72    }
73
74    /// Sets the search range.
75    #[must_use]
76    pub const fn range(mut self, range: SearchRange) -> Self {
77        self.range = range;
78        self
79    }
80
81    /// Sets motion vector precision.
82    #[must_use]
83    pub const fn precision(mut self, precision: MvPrecision) -> Self {
84        self.precision = precision;
85        self
86    }
87
88    /// Enables or disables early termination.
89    #[must_use]
90    pub const fn early_termination(mut self, enable: bool) -> Self {
91        self.early_termination = enable;
92        self
93    }
94
95    /// Sets the reference motion vector for cost calculation.
96    #[must_use]
97    pub fn ref_mv(mut self, mv: MotionVector) -> Self {
98        self.mv_cost.set_ref_mv(mv);
99        self
100    }
101
102    /// Sets the lambda for rate-distortion optimization.
103    #[must_use]
104    pub fn lambda(mut self, lambda: f32) -> Self {
105        self.mv_cost.lambda = lambda;
106        self
107    }
108}
109
110/// Search context containing frame data and configuration.
111pub struct SearchContext<'a> {
112    /// Source block data.
113    pub src: &'a [u8],
114    /// Source stride.
115    pub src_stride: usize,
116    /// Reference frame data.
117    pub ref_frame: &'a [u8],
118    /// Reference stride.
119    pub ref_stride: usize,
120    /// Block size.
121    pub block_size: BlockSize,
122    /// Block position X in source.
123    pub block_x: usize,
124    /// Block position Y in source.
125    pub block_y: usize,
126    /// Reference frame width.
127    pub ref_width: usize,
128    /// Reference frame height.
129    pub ref_height: usize,
130}
131
132impl<'a> SearchContext<'a> {
133    /// Creates a new search context.
134    #[must_use]
135    #[allow(clippy::too_many_arguments)]
136    pub const fn new(
137        src: &'a [u8],
138        src_stride: usize,
139        ref_frame: &'a [u8],
140        ref_stride: usize,
141        block_size: BlockSize,
142        block_x: usize,
143        block_y: usize,
144        ref_width: usize,
145        ref_height: usize,
146    ) -> Self {
147        Self {
148            src,
149            src_stride,
150            ref_frame,
151            ref_stride,
152            block_size,
153            block_x,
154            block_y,
155            ref_width,
156            ref_height,
157        }
158    }
159
160    /// Returns the source block offset.
161    #[must_use]
162    pub const fn src_offset(&self) -> usize {
163        self.block_y * self.src_stride + self.block_x
164    }
165
166    /// Calculates SAD for a given motion vector.
167    #[must_use]
168    #[allow(clippy::cast_sign_loss)]
169    pub fn calculate_sad(&self, mv: &MotionVector) -> Option<u32> {
170        let ref_x = self.block_x as i32 + mv.full_pel_x();
171        let ref_y = self.block_y as i32 + mv.full_pel_y();
172
173        // Check bounds
174        if ref_x < 0 || ref_y < 0 {
175            return None;
176        }
177
178        let ref_x = ref_x as usize;
179        let ref_y = ref_y as usize;
180        let width = self.block_size.width();
181        let height = self.block_size.height();
182
183        if ref_x + width > self.ref_width || ref_y + height > self.ref_height {
184            return None;
185        }
186
187        let src_offset = self.src_offset();
188        let ref_offset = ref_y * self.ref_stride + ref_x;
189
190        // Check slice bounds
191        if src_offset + (height - 1) * self.src_stride + width > self.src.len() {
192            return None;
193        }
194        if ref_offset + (height - 1) * self.ref_stride + width > self.ref_frame.len() {
195            return None;
196        }
197
198        let mut sad = 0u32;
199        for row in 0..height {
200            let src_row_offset = src_offset + row * self.src_stride;
201            let ref_row_offset = ref_offset + row * self.ref_stride;
202
203            for col in 0..width {
204                let src_val = self.src[src_row_offset + col];
205                let ref_val = self.ref_frame[ref_row_offset + col];
206                let diff = i32::from(src_val) - i32::from(ref_val);
207                sad += diff.unsigned_abs();
208            }
209        }
210
211        Some(sad)
212    }
213
214    /// Checks if a motion vector is within valid bounds.
215    #[must_use]
216    pub fn is_valid_mv(&self, mv: &MotionVector, range: &SearchRange) -> bool {
217        let ref_x = self.block_x as i32 + mv.full_pel_x();
218        let ref_y = self.block_y as i32 + mv.full_pel_y();
219        let width = self.block_size.width() as i32;
220        let height = self.block_size.height() as i32;
221
222        ref_x >= 0
223            && ref_y >= 0
224            && ref_x + width <= self.ref_width as i32
225            && ref_y + height <= self.ref_height as i32
226            && range.contains(mv.full_pel_x(), mv.full_pel_y())
227    }
228}
229
230/// Trait for motion search algorithms.
231pub trait MotionSearch {
232    /// Performs motion search and returns the best match.
233    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch;
234
235    /// Performs motion search with a starting point prediction.
236    fn search_with_predictor(
237        &self,
238        ctx: &SearchContext,
239        config: &SearchConfig,
240        predictor: MotionVector,
241    ) -> BlockMatch;
242}
243
/// Exhaustive block-matching search.
///
/// Evaluates every full-pel vector in the configured search window and keeps the best
/// candidate according to `BlockMatch::update_if_better`. The result is the optimum over
/// that window only when early termination is disabled: with
/// `SearchConfig::early_termination` set, the scan stops at the first candidate whose SAD
/// falls below the threshold.
#[derive(Debug, Clone, Default)]
pub struct FullSearch;

impl FullSearch {
    /// Returns a full-search instance (identical to `FullSearch::default()`).
    #[must_use]
    pub const fn new() -> Self {
        FullSearch
    }
}
258
259impl MotionSearch for FullSearch {
260    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
261        self.search_with_predictor(ctx, config, MotionVector::zero())
262    }
263
264    fn search_with_predictor(
265        &self,
266        ctx: &SearchContext,
267        config: &SearchConfig,
268        _predictor: MotionVector,
269    ) -> BlockMatch {
270        let mut best = BlockMatch::worst();
271        let range = &config.range;
272
273        for dy in -range.vertical..=range.vertical {
274            for dx in -range.horizontal..=range.horizontal {
275                let mv = MotionVector::from_full_pel(dx, dy);
276
277                if let Some(sad) = ctx.calculate_sad(&mv) {
278                    let cost = config.mv_cost.rd_cost(&mv, sad);
279                    let candidate = BlockMatch::new(mv, sad, cost);
280
281                    best.update_if_better(&candidate);
282
283                    // Early termination
284                    if config.early_termination && sad < config.early_threshold {
285                        return best;
286                    }
287                }
288            }
289        }
290
291        best
292    }
293}
294
295/// Diamond search algorithm.
296///
297/// Uses small diamond pattern (SDSP) for refinement and large diamond
298/// pattern (LDSP) for initial coarse search.
299#[derive(Clone, Debug, Default)]
300pub struct DiamondSearch {
301    /// Use large diamond for initial search.
302    use_large_diamond: bool,
303}
304
305impl DiamondSearch {
306    /// Small diamond pattern offsets (4 points).
307    const SDSP: [(i32, i32); 4] = [(0, -1), (0, 1), (-1, 0), (1, 0)];
308
309    /// Large diamond pattern offsets (8 points).
310    const LDSP: [(i32, i32); 8] = [
311        (0, -2),
312        (0, 2),
313        (-2, 0),
314        (2, 0),
315        (-1, -1),
316        (-1, 1),
317        (1, -1),
318        (1, 1),
319    ];
320
321    /// Creates a new diamond search instance.
322    #[must_use]
323    pub const fn new() -> Self {
324        Self {
325            use_large_diamond: true,
326        }
327    }
328
329    /// Sets whether to use large diamond pattern.
330    #[must_use]
331    pub const fn use_large_diamond(mut self, enable: bool) -> Self {
332        self.use_large_diamond = enable;
333        self
334    }
335
336    /// Performs diamond pattern search around a center point.
337    fn diamond_step(
338        &self,
339        ctx: &SearchContext,
340        config: &SearchConfig,
341        center: MotionVector,
342        pattern: &[(i32, i32)],
343    ) -> (MotionVector, u32) {
344        let mut best_mv = center;
345        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
346
347        for &(dx, dy) in pattern {
348            let mv =
349                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
350
351            if !ctx.is_valid_mv(&mv, &config.range) {
352                continue;
353            }
354
355            if let Some(sad) = ctx.calculate_sad(&mv) {
356                if sad < best_sad {
357                    best_sad = sad;
358                    best_mv = mv;
359                }
360            }
361        }
362
363        (best_mv, best_sad)
364    }
365}
366
367impl MotionSearch for DiamondSearch {
368    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
369        self.search_with_predictor(ctx, config, MotionVector::zero())
370    }
371
372    fn search_with_predictor(
373        &self,
374        ctx: &SearchContext,
375        config: &SearchConfig,
376        predictor: MotionVector,
377    ) -> BlockMatch {
378        let mut center = predictor.to_precision(MvPrecision::FullPel);
379        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
380
381        // Initial large diamond search
382        if self.use_large_diamond {
383            for _ in 0..config.max_iterations {
384                let (new_center, new_sad) = self.diamond_step(ctx, config, center, &Self::LDSP);
385
386                if new_center == center {
387                    break;
388                }
389
390                center = new_center;
391                best_sad = new_sad;
392
393                if config.early_termination && best_sad < config.early_threshold {
394                    let cost = config.mv_cost.rd_cost(&center, best_sad);
395                    return BlockMatch::new(center, best_sad, cost);
396                }
397            }
398        }
399
400        // Refinement with small diamond
401        loop {
402            let (new_center, new_sad) = self.diamond_step(ctx, config, center, &Self::SDSP);
403
404            if new_center == center {
405                break;
406            }
407
408            center = new_center;
409            best_sad = new_sad;
410        }
411
412        let cost = config.mv_cost.rd_cost(&center, best_sad);
413        BlockMatch::new(center, best_sad, cost)
414    }
415}
416
/// Hexagon-based block-matching search.
///
/// Repeatedly moves a six-point hexagon towards lower SAD, then finishes with a refinement
/// pass over points one pixel away from the final centre.
#[derive(Debug, Clone, Default)]
pub struct HexagonSearch;
423
424impl HexagonSearch {
425    /// Hexagon pattern offsets (6 points).
426    const HEXAGON: [(i32, i32); 6] = [(-2, 0), (-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2)];
427
428    /// Square pattern for final refinement (4 points).
429    const SQUARE: [(i32, i32); 4] = [(-1, -1), (-1, 1), (1, -1), (1, 1)];
430
431    /// Creates a new hexagon search instance.
432    #[must_use]
433    pub const fn new() -> Self {
434        Self
435    }
436
437    /// Performs hexagon pattern search around a center point.
438    fn hexagon_step(
439        &self,
440        ctx: &SearchContext,
441        config: &SearchConfig,
442        center: MotionVector,
443    ) -> (MotionVector, u32) {
444        let mut best_mv = center;
445        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
446
447        for &(dx, dy) in &Self::HEXAGON {
448            let mv =
449                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
450
451            if !ctx.is_valid_mv(&mv, &config.range) {
452                continue;
453            }
454
455            if let Some(sad) = ctx.calculate_sad(&mv) {
456                if sad < best_sad {
457                    best_sad = sad;
458                    best_mv = mv;
459                }
460            }
461        }
462
463        (best_mv, best_sad)
464    }
465
466    /// Final square refinement.
467    fn square_refine(
468        &self,
469        ctx: &SearchContext,
470        config: &SearchConfig,
471        center: MotionVector,
472    ) -> (MotionVector, u32) {
473        let mut best_mv = center;
474        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
475
476        for &(dx, dy) in &Self::SQUARE {
477            let mv =
478                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
479
480            if !ctx.is_valid_mv(&mv, &config.range) {
481                continue;
482            }
483
484            if let Some(sad) = ctx.calculate_sad(&mv) {
485                if sad < best_sad {
486                    best_sad = sad;
487                    best_mv = mv;
488                }
489            }
490        }
491
492        (best_mv, best_sad)
493    }
494}
495
496impl MotionSearch for HexagonSearch {
497    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
498        self.search_with_predictor(ctx, config, MotionVector::zero())
499    }
500
501    fn search_with_predictor(
502        &self,
503        ctx: &SearchContext,
504        config: &SearchConfig,
505        predictor: MotionVector,
506    ) -> BlockMatch {
507        let mut center = predictor.to_precision(MvPrecision::FullPel);
508        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
509
510        // Hexagon search iterations
511        for _ in 0..config.max_iterations {
512            let (new_center, new_sad) = self.hexagon_step(ctx, config, center);
513
514            if new_center == center {
515                break;
516            }
517
518            center = new_center;
519            best_sad = new_sad;
520
521            if config.early_termination && best_sad < config.early_threshold {
522                let cost = config.mv_cost.rd_cost(&center, best_sad);
523                return BlockMatch::new(center, best_sad, cost);
524            }
525        }
526
527        // Final square refinement
528        let (final_center, final_sad) = self.square_refine(ctx, config, center);
529
530        let cost = config.mv_cost.rd_cost(&final_center, final_sad);
531        BlockMatch::new(final_center, final_sad, cost)
532    }
533}
534
535/// Unsymmetrical-cross Multi-Hexagon search algorithm.
536///
537/// Combines multiple patterns for efficient search:
538/// 1. Unsymmetrical cross
539/// 2. Multi-hexagon grid
540/// 3. Extended hexagon
541/// 4. Small diamond refinement
542#[derive(Clone, Debug, Default)]
543pub struct UmhSearch {
544    /// Cross search range multiplier.
545    cross_range: i32,
546}
547
548impl UmhSearch {
549    /// Creates a new UMH search instance.
550    #[must_use]
551    pub const fn new() -> Self {
552        Self { cross_range: 2 }
553    }
554
555    /// Sets the cross search range multiplier.
556    #[must_use]
557    pub const fn cross_range(mut self, range: i32) -> Self {
558        self.cross_range = range;
559        self
560    }
561
562    /// Performs unsymmetrical cross search.
563    fn cross_search(
564        &self,
565        ctx: &SearchContext,
566        config: &SearchConfig,
567        center: MotionVector,
568    ) -> (MotionVector, u32) {
569        let mut best_mv = center;
570        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
571        let range = config.range.horizontal.min(config.range.vertical) * self.cross_range;
572
573        // Horizontal cross (more points)
574        for dx in (-range..=range).step_by(2) {
575            if dx == 0 {
576                continue;
577            }
578            let mv = MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y());
579
580            if !ctx.is_valid_mv(&mv, &config.range) {
581                continue;
582            }
583
584            if let Some(sad) = ctx.calculate_sad(&mv) {
585                if sad < best_sad {
586                    best_sad = sad;
587                    best_mv = mv;
588                }
589            }
590        }
591
592        // Vertical cross (fewer points - unsymmetrical)
593        for dy in (-range / 2..=range / 2).step_by(2) {
594            if dy == 0 {
595                continue;
596            }
597            let mv = MotionVector::from_full_pel(center.full_pel_x(), center.full_pel_y() + dy);
598
599            if !ctx.is_valid_mv(&mv, &config.range) {
600                continue;
601            }
602
603            if let Some(sad) = ctx.calculate_sad(&mv) {
604                if sad < best_sad {
605                    best_sad = sad;
606                    best_mv = mv;
607                }
608            }
609        }
610
611        (best_mv, best_sad)
612    }
613
614    /// Multi-hexagon grid search.
615    fn multi_hexagon(
616        &self,
617        ctx: &SearchContext,
618        config: &SearchConfig,
619        center: MotionVector,
620    ) -> (MotionVector, u32) {
621        let mut best_mv = center;
622        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
623
624        // 16 points in a multi-hexagon pattern
625        const MULTI_HEX: [(i32, i32); 16] = [
626            (-4, 0),
627            (-2, -2),
628            (0, -4),
629            (2, -2),
630            (4, 0),
631            (2, 2),
632            (0, 4),
633            (-2, 2),
634            (-2, -4),
635            (2, -4),
636            (4, -2),
637            (4, 2),
638            (2, 4),
639            (-2, 4),
640            (-4, 2),
641            (-4, -2),
642        ];
643
644        for &(dx, dy) in &MULTI_HEX {
645            let mv =
646                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
647
648            if !ctx.is_valid_mv(&mv, &config.range) {
649                continue;
650            }
651
652            if let Some(sad) = ctx.calculate_sad(&mv) {
653                if sad < best_sad {
654                    best_sad = sad;
655                    best_mv = mv;
656                }
657            }
658        }
659
660        (best_mv, best_sad)
661    }
662}
663
664impl MotionSearch for UmhSearch {
665    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
666        self.search_with_predictor(ctx, config, MotionVector::zero())
667    }
668
669    fn search_with_predictor(
670        &self,
671        ctx: &SearchContext,
672        config: &SearchConfig,
673        predictor: MotionVector,
674    ) -> BlockMatch {
675        let mut center = predictor.to_precision(MvPrecision::FullPel);
676        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
677
678        // Step 1: Unsymmetrical cross search
679        let (cross_mv, cross_sad) = self.cross_search(ctx, config, center);
680        if cross_sad < best_sad {
681            center = cross_mv;
682            best_sad = cross_sad;
683        }
684
685        // Early termination check
686        if config.early_termination && best_sad < config.early_threshold {
687            let cost = config.mv_cost.rd_cost(&center, best_sad);
688            return BlockMatch::new(center, best_sad, cost);
689        }
690
691        // Step 2: Multi-hexagon grid
692        let (hex_mv, hex_sad) = self.multi_hexagon(ctx, config, center);
693        if hex_sad < best_sad {
694            center = hex_mv;
695            best_sad = hex_sad;
696        }
697
698        // Step 3: Extended hexagon search
699        let hex_search = HexagonSearch::new();
700        for _ in 0..config.max_iterations / 2 {
701            let (new_center, new_sad) = hex_search.hexagon_step(ctx, config, center);
702            if new_center == center {
703                break;
704            }
705            center = new_center;
706            best_sad = new_sad;
707        }
708
709        // Step 4: Small diamond refinement
710        let diamond = DiamondSearch::new().use_large_diamond(false);
711        let result = diamond.search_with_predictor(ctx, config, center);
712
713        if result.sad < best_sad {
714            result
715        } else {
716            let cost = config.mv_cost.rd_cost(&center, best_sad);
717            BlockMatch::new(center, best_sad, cost)
718        }
719    }
720}
721
/// Three-step (logarithmic) search.
///
/// Probes an eight-point square around the centre and halves the step size whenever no
/// point improves, narrowing the search space geometrically.
#[derive(Debug, Clone, Default)]
pub struct ThreeStepSearch;
727
728impl ThreeStepSearch {
729    /// Creates a new three-step search instance.
730    #[must_use]
731    pub const fn new() -> Self {
732        Self
733    }
734
735    /// Square pattern for each step.
736    const SQUARE_8: [(i32, i32); 8] = [
737        (-1, -1),
738        (0, -1),
739        (1, -1),
740        (-1, 0),
741        (1, 0),
742        (-1, 1),
743        (0, 1),
744        (1, 1),
745    ];
746}
747
748impl MotionSearch for ThreeStepSearch {
749    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
750        self.search_with_predictor(ctx, config, MotionVector::zero())
751    }
752
753    fn search_with_predictor(
754        &self,
755        ctx: &SearchContext,
756        config: &SearchConfig,
757        predictor: MotionVector,
758    ) -> BlockMatch {
759        let mut center = predictor.to_precision(MvPrecision::FullPel);
760        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
761
762        // Initial step size (half of search range)
763        let initial_step = config.range.horizontal.min(config.range.vertical) / 2;
764        let mut step = initial_step.max(4);
765
766        // Three steps with decreasing step size
767        while step >= 1 {
768            let mut moved = false;
769
770            for &(dx, dy) in &Self::SQUARE_8 {
771                let mv = MotionVector::from_full_pel(
772                    center.full_pel_x() + dx * step,
773                    center.full_pel_y() + dy * step,
774                );
775
776                if !ctx.is_valid_mv(&mv, &config.range) {
777                    continue;
778                }
779
780                if let Some(sad) = ctx.calculate_sad(&mv) {
781                    if sad < best_sad {
782                        best_sad = sad;
783                        center = mv;
784                        moved = true;
785                    }
786                }
787            }
788
789            // If no movement, reduce step size
790            if !moved {
791                step /= 2;
792            }
793
794            // Early termination
795            if config.early_termination && best_sad < config.early_threshold {
796                break;
797            }
798        }
799
800        let cost = config.mv_cost.rd_cost(&center, best_sad);
801        BlockMatch::new(center, best_sad, cost)
802    }
803}
804
/// Search that chooses between diamond and UMH search from an estimate of the source
/// block's complexity.
#[derive(Debug, Clone)]
pub struct AdaptiveSearch {
    /// Blocks whose estimated complexity is below this use the cheaper diamond search.
    complexity_threshold: u32,
}
811
812impl Default for AdaptiveSearch {
813    fn default() -> Self {
814        Self::new()
815    }
816}
817
818impl AdaptiveSearch {
819    /// Creates a new adaptive search instance.
820    #[must_use]
821    pub const fn new() -> Self {
822        Self {
823            complexity_threshold: 1000,
824        }
825    }
826
827    /// Sets the complexity threshold.
828    #[must_use]
829    pub const fn threshold(mut self, threshold: u32) -> Self {
830        self.complexity_threshold = threshold;
831        self
832    }
833
834    /// Estimates block complexity (variance-like measure).
835    fn estimate_complexity(&self, ctx: &SearchContext) -> u32 {
836        let src_offset = ctx.src_offset();
837        let width = ctx.block_size.width();
838        let height = ctx.block_size.height();
839
840        let mut sum = 0u32;
841        let mut sum_sq = 0u64;
842        let mut count = 0u32;
843
844        for row in 0..height {
845            let row_offset = src_offset + row * ctx.src_stride;
846            for col in 0..width {
847                if row_offset + col < ctx.src.len() {
848                    let val = u32::from(ctx.src[row_offset + col]);
849                    sum += val;
850                    sum_sq += u64::from(val * val);
851                    count += 1;
852                }
853            }
854        }
855
856        if count == 0 {
857            return 0;
858        }
859
860        let mean = sum / count;
861
862        (sum_sq / u64::from(count)) as u32 - mean * mean
863    }
864}
865
866impl MotionSearch for AdaptiveSearch {
867    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
868        self.search_with_predictor(ctx, config, MotionVector::zero())
869    }
870
871    fn search_with_predictor(
872        &self,
873        ctx: &SearchContext,
874        config: &SearchConfig,
875        predictor: MotionVector,
876    ) -> BlockMatch {
877        let complexity = self.estimate_complexity(ctx);
878
879        if complexity < self.complexity_threshold {
880            // Low complexity: use faster diamond search
881            DiamondSearch::new().search_with_predictor(ctx, config, predictor)
882        } else {
883            // High complexity: use more thorough UMH search
884            UmhSearch::new().search_with_predictor(ctx, config, predictor)
885        }
886    }
887}
888
889#[cfg(test)]
890mod tests {
891    use super::*;
892
893    fn create_test_context<'a>(
894        src: &'a [u8],
895        ref_frame: &'a [u8],
896        width: usize,
897        height: usize,
898    ) -> SearchContext<'a> {
899        SearchContext::new(
900            src,
901            width,
902            ref_frame,
903            width,
904            BlockSize::Block8x8,
905            0,
906            0,
907            width,
908            height,
909        )
910    }
911
912    #[test]
913    fn test_search_config_default() {
914        let config = SearchConfig::default();
915        assert_eq!(config.precision, MvPrecision::QuarterPel);
916        assert!(config.early_termination);
917    }
918
919    #[test]
920    fn test_search_config_builder() {
921        let config = SearchConfig::default()
922            .range(SearchRange::symmetric(32))
923            .precision(MvPrecision::HalfPel)
924            .early_termination(false);
925
926        assert_eq!(config.range.horizontal, 32);
927        assert_eq!(config.precision, MvPrecision::HalfPel);
928        assert!(!config.early_termination);
929    }
930
931    #[test]
932    fn test_search_context_sad_calculation() {
933        let src = vec![100u8; 64]; // 8x8 block
934        let ref_frame = vec![110u8; 64]; // 8x8 with offset
935
936        let ctx = create_test_context(&src, &ref_frame, 8, 8);
937        let mv = MotionVector::zero();
938
939        let sad = ctx.calculate_sad(&mv);
940        assert!(sad.is_some());
941        assert_eq!(sad.expect("should succeed"), 640); // 64 * 10
942    }
943
944    #[test]
945    fn test_search_context_identical_blocks() {
946        let data = vec![128u8; 64];
947        let ctx = create_test_context(&data, &data, 8, 8);
948        let mv = MotionVector::zero();
949
950        let sad = ctx.calculate_sad(&mv);
951        assert_eq!(sad, Some(0));
952    }
953
954    #[test]
955    fn test_full_search() {
956        let src = vec![100u8; 64];
957        let mut ref_frame = vec![50u8; 256]; // 16x16
958
959        // Place matching block at (4, 4)
960        for row in 0..8 {
961            for col in 0..8 {
962                ref_frame[(row + 4) * 16 + col + 4] = 100;
963            }
964        }
965
966        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
967        let config = SearchConfig::with_range(SearchRange::symmetric(8));
968
969        let searcher = FullSearch::new();
970        let result = searcher.search(&ctx, &config);
971
972        assert_eq!(result.mv.full_pel_x(), 4);
973        assert_eq!(result.mv.full_pel_y(), 4);
974        assert_eq!(result.sad, 0);
975    }
976
977    #[test]
978    fn test_diamond_search() {
979        let src = vec![100u8; 64];
980        let mut ref_frame = vec![50u8; 256];
981
982        // Place matching block at (4, 4)
983        for row in 0..8 {
984            for col in 0..8 {
985                ref_frame[(row + 4) * 16 + col + 4] = 100;
986            }
987        }
988
989        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
990        let config = SearchConfig::with_range(SearchRange::symmetric(8));
991
992        let searcher = DiamondSearch::new();
993        let result = searcher.search(&ctx, &config);
994
995        // Diamond search should find reasonably close match
996        assert!(result.sad < 1000);
997    }
998
999    #[test]
1000    fn test_hexagon_search() {
1001        let src = vec![100u8; 64];
1002        let mut ref_frame = vec![50u8; 256];
1003
1004        for row in 0..8 {
1005            for col in 0..8 {
1006                ref_frame[(row + 4) * 16 + col + 4] = 100;
1007            }
1008        }
1009
1010        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1011        let config = SearchConfig::with_range(SearchRange::symmetric(8));
1012
1013        let searcher = HexagonSearch::new();
1014        let result = searcher.search(&ctx, &config);
1015
1016        assert!(result.sad < 1000);
1017    }
1018
1019    #[test]
1020    fn test_umh_search() {
1021        let src = vec![100u8; 64];
1022        let mut ref_frame = vec![50u8; 256];
1023
1024        for row in 0..8 {
1025            for col in 0..8 {
1026                ref_frame[(row + 4) * 16 + col + 4] = 100;
1027            }
1028        }
1029
1030        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1031        let config = SearchConfig::with_range(SearchRange::symmetric(8));
1032
1033        let searcher = UmhSearch::new();
1034        let result = searcher.search(&ctx, &config);
1035
1036        assert!(result.sad < 1000);
1037    }
1038
1039    #[test]
1040    fn test_three_step_search() {
1041        let src = vec![100u8; 64];
1042        let mut ref_frame = vec![50u8; 256];
1043
1044        for row in 0..8 {
1045            for col in 0..8 {
1046                ref_frame[(row + 4) * 16 + col + 4] = 100;
1047            }
1048        }
1049
1050        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1051        let config = SearchConfig::with_range(SearchRange::symmetric(8));
1052
1053        let searcher = ThreeStepSearch::new();
1054        let result = searcher.search(&ctx, &config);
1055
1056        assert!(result.sad < 1000);
1057    }
1058
1059    #[test]
1060    fn test_search_with_predictor() {
1061        let src = vec![100u8; 64];
1062        let mut ref_frame = vec![50u8; 256];
1063
1064        for row in 0..8 {
1065            for col in 0..8 {
1066                ref_frame[(row + 4) * 16 + col + 4] = 100;
1067            }
1068        }
1069
1070        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1071        let config = SearchConfig::with_range(SearchRange::symmetric(8));
1072
1073        // Good predictor should help find the match faster
1074        let predictor = MotionVector::from_full_pel(3, 3);
1075        let searcher = DiamondSearch::new();
1076        let result = searcher.search_with_predictor(&ctx, &config, predictor);
1077
1078        assert!(result.sad < 500);
1079    }
1080
1081    #[test]
1082    fn test_early_termination() {
1083        let data = vec![128u8; 64];
1084        let mut ref_frame = vec![128u8; 256];
1085
1086        // Matching block at origin
1087        for row in 0..8 {
1088            for col in 0..8 {
1089                ref_frame[row * 16 + col] = 128;
1090            }
1091        }
1092
1093        let ctx = SearchContext::new(&data, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1094        let config = SearchConfig::default().early_termination(true);
1095
1096        let searcher = FullSearch::new();
1097        let result = searcher.search(&ctx, &config);
1098
1099        // Should find perfect match at origin
1100        assert_eq!(result.sad, 0);
1101        assert_eq!(result.mv.full_pel_x(), 0);
1102        assert_eq!(result.mv.full_pel_y(), 0);
1103    }
1104
1105    #[test]
1106    fn test_adaptive_search() {
1107        let src = vec![100u8; 64];
1108        let ref_frame = vec![100u8; 256];
1109
1110        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1111        let config = SearchConfig::with_range(SearchRange::symmetric(8));
1112
1113        let searcher = AdaptiveSearch::new();
1114        let result = searcher.search(&ctx, &config);
1115
1116        // Should work regardless of which algorithm is chosen
1117        assert!(result.cost < u32::MAX);
1118    }
1119
1120    #[test]
1121    fn test_out_of_bounds_mv() {
1122        let src = vec![100u8; 64];
1123        let ref_frame = vec![100u8; 64];
1124
1125        let ctx = SearchContext::new(&src, 8, &ref_frame, 8, BlockSize::Block8x8, 0, 0, 8, 8);
1126
1127        // MV that would go out of bounds
1128        let mv = MotionVector::from_full_pel(5, 5);
1129        let sad = ctx.calculate_sad(&mv);
1130
1131        assert!(sad.is_none());
1132    }
1133
1134    #[test]
1135    fn test_is_valid_mv() {
1136        let src = vec![100u8; 64];
1137        let ref_frame = vec![100u8; 256];
1138
1139        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
1140        let range = SearchRange::symmetric(8);
1141
1142        // Valid MV
1143        assert!(ctx.is_valid_mv(&MotionVector::from_full_pel(4, 4), &range));
1144
1145        // Invalid: out of range
1146        assert!(!ctx.is_valid_mv(&MotionVector::from_full_pel(10, 0), &range));
1147
1148        // Invalid: would go out of frame bounds
1149        assert!(!ctx.is_valid_mv(&MotionVector::from_full_pel(9, 9), &range));
1150    }
1151}