Skip to main content

oximedia_codec/motion/
diamond.rs

1//! Diamond search patterns for motion estimation.
2//!
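//! This module provides implementations of the Small Diamond Search Pattern (SDSP)
//! and Large Diamond Search Pattern (LDSP) for efficient motion estimation. It also
//! provides extended, adaptive, predictor-based and cross diamond searches, plus the
//! hexagonal (HEX) and UMHex searches built on the same primitives.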
5//!
6//! The diamond search is one of the most widely used fast motion estimation
7//! algorithms due to its good balance between quality and speed.
8
9#![forbid(unsafe_code)]
10#![allow(dead_code)]
11#![allow(clippy::too_many_arguments)]
12#![allow(clippy::must_use_candidate)]
13
14use super::search::{MotionSearch, SearchConfig, SearchContext};
15use super::types::{BlockMatch, MotionVector, MvPrecision};
16
/// Small Diamond Search Pattern (SDSP).
///
/// The four unit-step neighbours of the centre, used for the final,
/// fine-grained refinement stage:
/// ```text
///       *
///     * O *
///       *
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SmallDiamond {
    /// `(dx, dy)` offset of every pattern point relative to the centre.
    pub points: [(i32, i32); 4],
}
30
31impl Default for SmallDiamond {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl SmallDiamond {
38    /// Standard SDSP offsets.
39    pub const PATTERN: [(i32, i32); 4] = [(0, -1), (-1, 0), (1, 0), (0, 1)];
40
41    /// Creates a new small diamond pattern.
42    #[must_use]
43    pub const fn new() -> Self {
44        Self {
45            points: Self::PATTERN,
46        }
47    }
48
49    /// Returns the number of points in the pattern.
50    #[must_use]
51    pub const fn size(&self) -> usize {
52        4
53    }
54
55    /// Gets the offset at a given index.
56    #[must_use]
57    pub const fn get(&self, index: usize) -> Option<(i32, i32)> {
58        if index < 4 {
59            Some(self.points[index])
60        } else {
61            None
62        }
63    }
64
65    /// Searches using the small diamond pattern.
66    pub fn search(
67        &self,
68        ctx: &SearchContext,
69        config: &SearchConfig,
70        center: MotionVector,
71    ) -> (MotionVector, u32, usize) {
72        let mut best_mv = center;
73        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
74        let mut best_idx = 4; // Center
75
76        for (idx, &(dx, dy)) in self.points.iter().enumerate() {
77            let mv =
78                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
79
80            if !ctx.is_valid_mv(&mv, &config.range) {
81                continue;
82            }
83
84            if let Some(sad) = ctx.calculate_sad(&mv) {
85                if sad < best_sad {
86                    best_sad = sad;
87                    best_mv = mv;
88                    best_idx = idx;
89                }
90            }
91        }
92
93        (best_mv, best_sad, best_idx)
94    }
95}
96
/// Large Diamond Search Pattern (LDSP).
///
/// An 8-point pattern for coarse search. It holds the four axis points at
/// distance 2 plus the four diagonal neighbours at `(±1, ±1)`:
/// ```text
///       *
///     *   *
///   *   O   *
///     *   *
///       *
/// ```
#[derive(Clone, Copy, Debug)]
pub struct LargeDiamond {
    /// Pattern offsets (dx, dy) for each point, relative to the centre.
    pub points: [(i32, i32); 8],
}
112
113impl Default for LargeDiamond {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl LargeDiamond {
120    /// Standard LDSP offsets.
121    pub const PATTERN: [(i32, i32); 8] = [
122        (0, -2),  // Top
123        (-1, -1), // Top-left
124        (1, -1),  // Top-right
125        (-2, 0),  // Left
126        (2, 0),   // Right
127        (-1, 1),  // Bottom-left
128        (1, 1),   // Bottom-right
129        (0, 2),   // Bottom
130    ];
131
132    /// Creates a new large diamond pattern.
133    #[must_use]
134    pub const fn new() -> Self {
135        Self {
136            points: Self::PATTERN,
137        }
138    }
139
140    /// Returns the number of points in the pattern.
141    #[must_use]
142    pub const fn size(&self) -> usize {
143        8
144    }
145
146    /// Gets the offset at a given index.
147    #[must_use]
148    pub const fn get(&self, index: usize) -> Option<(i32, i32)> {
149        if index < 8 {
150            Some(self.points[index])
151        } else {
152            None
153        }
154    }
155
156    /// Searches using the large diamond pattern.
157    pub fn search(
158        &self,
159        ctx: &SearchContext,
160        config: &SearchConfig,
161        center: MotionVector,
162    ) -> (MotionVector, u32, usize) {
163        let mut best_mv = center;
164        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
165        let mut best_idx = 8; // Center
166
167        for (idx, &(dx, dy)) in self.points.iter().enumerate() {
168            let mv =
169                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
170
171            if !ctx.is_valid_mv(&mv, &config.range) {
172                continue;
173            }
174
175            if let Some(sad) = ctx.calculate_sad(&mv) {
176                if sad < best_sad {
177                    best_sad = sad;
178                    best_mv = mv;
179                    best_idx = idx;
180                }
181            }
182        }
183
184        (best_mv, best_sad, best_idx)
185    }
186}
187
/// Extended Diamond Search Pattern.
///
/// A 16-point pattern made of three rings around the centre. The rings are the
/// 4 SDSP points (Manhattan distance 1), the 8 LDSP points (Manhattan
/// distance 2) and 4 axis points at distance 3:
/// ```text
///         *
///         *
///       * * *
///   * * * O * * *
///       * * *
///         *
///         *
/// ```
#[derive(Clone, Copy, Debug)]
pub struct ExtendedDiamond {
    /// Inner ring (4 points, Manhattan distance 1).
    pub inner: [(i32, i32); 4],
    /// Middle ring (8 points, Manhattan distance 2).
    pub middle: [(i32, i32); 8],
    /// Outer ring (4 points, Manhattan distance 3).
    pub outer: [(i32, i32); 4],
}
209
210impl Default for ExtendedDiamond {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216impl ExtendedDiamond {
217    /// Creates a new extended diamond pattern.
218    #[must_use]
219    pub const fn new() -> Self {
220        Self {
221            inner: SmallDiamond::PATTERN,
222            middle: LargeDiamond::PATTERN,
223            outer: [(0, -3), (-3, 0), (3, 0), (0, 3)],
224        }
225    }
226
227    /// Searches using all rings of the extended diamond.
228    pub fn search(
229        &self,
230        ctx: &SearchContext,
231        config: &SearchConfig,
232        center: MotionVector,
233    ) -> (MotionVector, u32) {
234        let mut best_mv = center;
235        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
236
237        // Search all rings
238        for &(dx, dy) in self.outer.iter().chain(&self.middle).chain(&self.inner) {
239            let mv =
240                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
241
242            if !ctx.is_valid_mv(&mv, &config.range) {
243                continue;
244            }
245
246            if let Some(sad) = ctx.calculate_sad(&mv) {
247                if sad < best_sad {
248                    best_sad = sad;
249                    best_mv = mv;
250                }
251            }
252        }
253
254        (best_mv, best_sad)
255    }
256}
257
258/// Adaptive diamond search that switches between SDSP and LDSP.
259///
260/// This implementation uses LDSP initially and switches to SDSP when:
261/// 1. The best point is at the center (convergence)
262/// 2. A threshold number of iterations has passed
263#[derive(Clone, Debug)]
264pub struct AdaptiveDiamond {
265    /// Small diamond pattern.
266    sdsp: SmallDiamond,
267    /// Large diamond pattern.
268    ldsp: LargeDiamond,
269    /// Maximum LDSP iterations before switching to SDSP.
270    max_ldsp_iterations: u32,
271    /// SAD threshold for early switch to SDSP.
272    switch_threshold: u32,
273}
274
275impl Default for AdaptiveDiamond {
276    fn default() -> Self {
277        Self::new()
278    }
279}
280
281impl AdaptiveDiamond {
282    /// Creates a new adaptive diamond search.
283    #[must_use]
284    pub const fn new() -> Self {
285        Self {
286            sdsp: SmallDiamond::new(),
287            ldsp: LargeDiamond::new(),
288            max_ldsp_iterations: 8,
289            switch_threshold: 512,
290        }
291    }
292
293    /// Sets the maximum LDSP iterations.
294    #[must_use]
295    pub const fn max_iterations(mut self, max: u32) -> Self {
296        self.max_ldsp_iterations = max;
297        self
298    }
299
300    /// Sets the SAD threshold for early switch.
301    #[must_use]
302    pub const fn switch_threshold(mut self, threshold: u32) -> Self {
303        self.switch_threshold = threshold;
304        self
305    }
306}
307
308impl MotionSearch for AdaptiveDiamond {
309    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
310        self.search_with_predictor(ctx, config, MotionVector::zero())
311    }
312
313    fn search_with_predictor(
314        &self,
315        ctx: &SearchContext,
316        config: &SearchConfig,
317        predictor: MotionVector,
318    ) -> BlockMatch {
319        let mut center = predictor.to_precision(MvPrecision::FullPel);
320        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
321
322        // Phase 1: Large diamond search
323        for iteration in 0..self.max_ldsp_iterations {
324            let (new_center, new_sad, best_idx) = self.ldsp.search(ctx, config, center);
325
326            // Check for convergence (center is best)
327            if best_idx >= self.ldsp.size() {
328                break;
329            }
330
331            // Check for early switch
332            if new_sad < self.switch_threshold {
333                center = new_center;
334                best_sad = new_sad;
335                break;
336            }
337
338            center = new_center;
339            best_sad = new_sad;
340
341            // Early termination
342            if config.early_termination && best_sad < config.early_threshold {
343                let cost = config.mv_cost.rd_cost(&center, best_sad);
344                return BlockMatch::new(center, best_sad, cost);
345            }
346
347            // Maximum iterations check (avoid infinite loop)
348            if iteration >= self.max_ldsp_iterations - 1 {
349                break;
350            }
351        }
352
353        // Phase 2: Small diamond refinement
354        loop {
355            let (new_center, new_sad, best_idx) = self.sdsp.search(ctx, config, center);
356
357            // Check for convergence
358            if best_idx >= self.sdsp.size() {
359                break;
360            }
361
362            center = new_center;
363            best_sad = new_sad;
364        }
365
366        let cost = config.mv_cost.rd_cost(&center, best_sad);
367        BlockMatch::new(center, best_sad, cost)
368    }
369}
370
371/// Predictor-based diamond search.
372///
373/// Uses multiple predictors (spatial/temporal) to initialize search
374/// from the most promising starting point.
375#[derive(Clone, Debug)]
376pub struct PredictorDiamond {
377    /// Underlying diamond search.
378    diamond: AdaptiveDiamond,
379    /// Maximum number of predictors to try.
380    max_predictors: usize,
381}
382
383impl Default for PredictorDiamond {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
389impl PredictorDiamond {
390    /// Creates a new predictor-based diamond search.
391    #[must_use]
392    pub const fn new() -> Self {
393        Self {
394            diamond: AdaptiveDiamond::new(),
395            max_predictors: 5,
396        }
397    }
398
399    /// Sets the maximum number of predictors.
400    #[must_use]
401    pub const fn max_predictors(mut self, max: usize) -> Self {
402        self.max_predictors = max;
403        self
404    }
405
406    /// Searches with multiple predictors.
407    pub fn search_multi(
408        &self,
409        ctx: &SearchContext,
410        config: &SearchConfig,
411        predictors: &[MotionVector],
412    ) -> BlockMatch {
413        let mut best = BlockMatch::worst();
414
415        // Try zero MV first
416        if let Some(sad) = ctx.calculate_sad(&MotionVector::zero()) {
417            let cost = config.mv_cost.rd_cost(&MotionVector::zero(), sad);
418            let candidate = BlockMatch::new(MotionVector::zero(), sad, cost);
419            best.update_if_better(&candidate);
420
421            // Early termination for perfect match
422            if sad == 0 {
423                return best;
424            }
425        }
426
427        // Evaluate each predictor
428        for (i, &pred) in predictors.iter().take(self.max_predictors).enumerate() {
429            if i > 0 && pred.is_zero() {
430                continue; // Skip duplicate zero MV
431            }
432
433            // Quick evaluation of predictor
434            let pred_fp = pred.to_precision(MvPrecision::FullPel);
435            if let Some(sad) = ctx.calculate_sad(&pred_fp) {
436                if sad < best.sad {
437                    // Full search from this predictor
438                    let result = self.diamond.search_with_predictor(ctx, config, pred);
439                    best.update_if_better(&result);
440                }
441            }
442        }
443
444        // If no predictor worked well, search from best point so far
445        if best.sad > config.early_threshold {
446            let result = self.diamond.search_with_predictor(ctx, config, best.mv);
447            best.update_if_better(&result);
448        }
449
450        best
451    }
452}
453
454impl MotionSearch for PredictorDiamond {
455    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
456        self.diamond.search(ctx, config)
457    }
458
459    fn search_with_predictor(
460        &self,
461        ctx: &SearchContext,
462        config: &SearchConfig,
463        predictor: MotionVector,
464    ) -> BlockMatch {
465        self.diamond.search_with_predictor(ctx, config, predictor)
466    }
467}
468
469/// Cross diamond search pattern.
470///
471/// Combines cross pattern with diamond for better coverage of
472/// horizontal/vertical motion.
473#[derive(Clone, Debug)]
474pub struct CrossDiamond {
475    /// Cross pattern range.
476    cross_range: i32,
477    /// Diamond search for refinement.
478    diamond: AdaptiveDiamond,
479}
480
481impl Default for CrossDiamond {
482    fn default() -> Self {
483        Self::new()
484    }
485}
486
487impl CrossDiamond {
488    /// Creates a new cross diamond search.
489    #[must_use]
490    pub const fn new() -> Self {
491        Self {
492            cross_range: 4,
493            diamond: AdaptiveDiamond::new(),
494        }
495    }
496
497    /// Sets the cross pattern range.
498    #[must_use]
499    pub const fn cross_range(mut self, range: i32) -> Self {
500        self.cross_range = range;
501        self
502    }
503
504    /// Performs cross pattern search.
505    fn cross_search(
506        &self,
507        ctx: &SearchContext,
508        config: &SearchConfig,
509        center: MotionVector,
510    ) -> (MotionVector, u32) {
511        let mut best_mv = center;
512        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
513
514        // Horizontal cross
515        for dx in -self.cross_range..=self.cross_range {
516            if dx == 0 {
517                continue;
518            }
519            let mv = MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y());
520
521            if !ctx.is_valid_mv(&mv, &config.range) {
522                continue;
523            }
524
525            if let Some(sad) = ctx.calculate_sad(&mv) {
526                if sad < best_sad {
527                    best_sad = sad;
528                    best_mv = mv;
529                }
530            }
531        }
532
533        // Vertical cross
534        for dy in -self.cross_range..=self.cross_range {
535            if dy == 0 {
536                continue;
537            }
538            let mv = MotionVector::from_full_pel(center.full_pel_x(), center.full_pel_y() + dy);
539
540            if !ctx.is_valid_mv(&mv, &config.range) {
541                continue;
542            }
543
544            if let Some(sad) = ctx.calculate_sad(&mv) {
545                if sad < best_sad {
546                    best_sad = sad;
547                    best_mv = mv;
548                }
549            }
550        }
551
552        (best_mv, best_sad)
553    }
554}
555
556impl MotionSearch for CrossDiamond {
557    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
558        self.search_with_predictor(ctx, config, MotionVector::zero())
559    }
560
561    fn search_with_predictor(
562        &self,
563        ctx: &SearchContext,
564        config: &SearchConfig,
565        predictor: MotionVector,
566    ) -> BlockMatch {
567        let center = predictor.to_precision(MvPrecision::FullPel);
568
569        // Phase 1: Cross search
570        let (cross_best, _) = self.cross_search(ctx, config, center);
571
572        // Phase 2: Diamond refinement
573        self.diamond.search_with_predictor(ctx, config, cross_best)
574    }
575}
576
577// =============================================================================
578// Hexagonal Search Pattern
579// =============================================================================
580
/// Hexagonal search pattern (HEX) for motion estimation.
///
/// Six probes arranged in a hexagon around the centre, in the spirit of the
/// H.264 HEX search:
/// ```text
///     *   *
///   *   O   *
///     *   *
/// ```
///
/// A hexagon spreads its six probes over more directions per step than a
/// four-point diamond. This is often reported to help with complex motion.
#[derive(Debug, Clone, Copy)]
pub struct HexagonalSearch {
    /// Fine hexagon offsets (6 points, radius ≈ 2).
    pub inner: [(i32, i32); 6],
    /// Coarse hexagon offsets (6 points, radius ≈ 4).
    pub outer: [(i32, i32); 6],
    /// Iteration cap applied separately to the coarse and the fine hexagon phase.
    pub max_iterations: u32,
}
601
602impl Default for HexagonalSearch {
603    fn default() -> Self {
604        Self::new()
605    }
606}
607
608impl HexagonalSearch {
609    /// Standard inner hexagon offsets (radius ≈ 2).
610    pub const INNER_PATTERN: [(i32, i32); 6] =
611        [(-2, 0), (-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2)];
612
613    /// Standard outer hexagon offsets (radius ≈ 4).
614    pub const OUTER_PATTERN: [(i32, i32); 6] =
615        [(-4, 0), (-2, -4), (2, -4), (4, 0), (2, 4), (-2, 4)];
616
617    /// Create a new hexagonal search.
618    #[must_use]
619    pub const fn new() -> Self {
620        Self {
621            inner: Self::INNER_PATTERN,
622            outer: Self::OUTER_PATTERN,
623            max_iterations: 8,
624        }
625    }
626
627    /// Set maximum iterations.
628    #[must_use]
629    pub const fn max_iterations(mut self, max: u32) -> Self {
630        self.max_iterations = max;
631        self
632    }
633
634    /// Search one hexagon ring centered on `center`.
635    fn search_hex(
636        &self,
637        ctx: &SearchContext,
638        config: &SearchConfig,
639        center: MotionVector,
640        pattern: &[(i32, i32)],
641    ) -> (MotionVector, u32) {
642        let mut best_mv = center;
643        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
644
645        for &(dx, dy) in pattern {
646            let mv =
647                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
648            if !ctx.is_valid_mv(&mv, &config.range) {
649                continue;
650            }
651            if let Some(sad) = ctx.calculate_sad(&mv) {
652                if sad < best_sad {
653                    best_sad = sad;
654                    best_mv = mv;
655                }
656            }
657        }
658
659        (best_mv, best_sad)
660    }
661}
662
663impl MotionSearch for HexagonalSearch {
664    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
665        self.search_with_predictor(ctx, config, MotionVector::zero())
666    }
667
668    fn search_with_predictor(
669        &self,
670        ctx: &SearchContext,
671        config: &SearchConfig,
672        predictor: MotionVector,
673    ) -> BlockMatch {
674        let center = predictor.to_precision(MvPrecision::FullPel);
675        let mut current = center;
676        let mut current_sad = ctx.calculate_sad(&current).unwrap_or(u32::MAX);
677
678        // Phase 1: coarse outer hex search
679        for _ in 0..self.max_iterations {
680            let (best_mv, best_sad) = self.search_hex(ctx, config, current, &self.outer);
681            if best_sad >= current_sad {
682                break;
683            }
684            current = best_mv;
685            current_sad = best_sad;
686
687            if current_sad == 0 {
688                break;
689            }
690        }
691
692        // Phase 2: fine inner hex refinement
693        for _ in 0..self.max_iterations {
694            let (best_mv, best_sad) = self.search_hex(ctx, config, current, &self.inner);
695            if best_sad >= current_sad {
696                break;
697            }
698            current = best_mv;
699            current_sad = best_sad;
700
701            if current_sad == 0 {
702                break;
703            }
704        }
705
706        // Phase 3: small diamond final refinement
707        let sdsp = SmallDiamond::new();
708        let (refined_mv, refined_sad, _) = sdsp.search(ctx, config, current);
709        if refined_sad < current_sad {
710            current = refined_mv;
711            current_sad = refined_sad;
712        }
713
714        let cost = config.mv_cost.rd_cost(&current, current_sad);
715        BlockMatch::new(current, current_sad, cost)
716    }
717}
718
719// =============================================================================
720// UMHex (Unsymmetric Multi-Hexagon) Search
721// =============================================================================
722
/// UMHex — Unsymmetric Multi-Hexagon grid search.
///
/// A fast integer-pel motion estimation algorithm in the style of UMHexagonS
/// (Chen, Zhou & He, JVT-F017, 2002). UMHexagonS ships in the JM H.264
/// reference software; this version is adapted for AV1/VP9 quality targets.
///
/// # Algorithm
///
/// 1. **Predictor check** — start from the better (lower SAD) of the supplied
///    predictor and the zero vector.
/// 2. **Cross scan** — probe power-of-two offsets along the horizontal and
///    vertical axes. Unlike the original UMHex, both arms currently use the
///    same range.
/// 3. **Hexagon expansion** — repeat the large hexagon step until it stops
///    improving (at most `max_hex_steps` times).
/// 4. **Hexagon refinement** — up to four small hexagon steps.
/// 5. **Small diamond refinement** — one integer-pel SDSP pass (no sub-pixel
///    search is performed here).
///
/// Once the SAD is at or below `early_exit_threshold`, the search returns
/// immediately after step 1 or 2, or stops expanding in step 3.
///
/// # Performance
///
/// Much cheaper than an exhaustive search. The actual speed/quality trade-off
/// depends on content and parameters and is not benchmarked in this module.
#[derive(Clone, Debug)]
pub struct UMHexSearch {
    /// Maximum hexagon expansion steps.
    max_hex_steps: u32,
    /// Cross search range (applied to both arms).
    cross_range: i32,
    /// SAD threshold for early termination.
    early_exit_threshold: u32,
}
747
748impl Default for UMHexSearch {
749    fn default() -> Self {
750        Self::new()
751    }
752}
753
754impl UMHexSearch {
755    /// Inner hexagon (radius ≈ 1 pel).
756    const HEX1: [(i32, i32); 6] = [(-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2), (-2, 0)];
757    /// Outer hexagon (radius ≈ 2 pel).
758    const HEX2: [(i32, i32); 12] = [
759        (-1, -2),
760        (1, -2), // top
761        (2, -1),
762        (2, 1), // right
763        (1, 2),
764        (-1, 2), // bottom
765        (-2, 1),
766        (-2, -1), // left
767        (0, -4),
768        (4, 0), // extended top/right
769        (0, 4),
770        (-4, 0), // extended bottom/left
771    ];
772
773    /// Create a new UMHex search with default parameters.
774    #[must_use]
775    pub const fn new() -> Self {
776        Self {
777            max_hex_steps: 16,
778            cross_range: 8,
779            early_exit_threshold: 4,
780        }
781    }
782
783    /// Set maximum hexagon expansion steps.
784    #[must_use]
785    pub const fn max_hex_steps(mut self, steps: u32) -> Self {
786        self.max_hex_steps = steps;
787        self
788    }
789
790    /// Set unsymmetric-cross search range.
791    #[must_use]
792    pub const fn cross_range(mut self, range: i32) -> Self {
793        self.cross_range = range;
794        self
795    }
796
797    /// Set SAD early-exit threshold.
798    #[must_use]
799    pub const fn early_exit_threshold(mut self, threshold: u32) -> Self {
800        self.early_exit_threshold = threshold;
801        self
802    }
803
804    /// Unsymmetric-cross scan: rapid horizontal then vertical scan.
805    fn cross_scan(
806        &self,
807        ctx: &SearchContext,
808        config: &SearchConfig,
809        center: MotionVector,
810    ) -> (MotionVector, u32) {
811        let mut best_mv = center;
812        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
813
814        // Horizontal pass (non-uniform spacing: small near center, larger far)
815        let h_offsets: &[i32] = &[-8, -4, -2, -1, 1, 2, 4, 8];
816        for &dx in h_offsets {
817            if dx.abs() > self.cross_range {
818                continue;
819            }
820            let mv = MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y());
821            if ctx.is_valid_mv(&mv, &config.range) {
822                if let Some(sad) = ctx.calculate_sad(&mv) {
823                    if sad < best_sad {
824                        best_sad = sad;
825                        best_mv = mv;
826                    }
827                }
828            }
829        }
830
831        // Vertical pass
832        let v_offsets: &[i32] = &[-8, -4, -2, -1, 1, 2, 4, 8];
833        for &dy in v_offsets {
834            if dy.abs() > self.cross_range {
835                continue;
836            }
837            let mv = MotionVector::from_full_pel(center.full_pel_x(), center.full_pel_y() + dy);
838            if ctx.is_valid_mv(&mv, &config.range) {
839                if let Some(sad) = ctx.calculate_sad(&mv) {
840                    if sad < best_sad {
841                        best_sad = sad;
842                        best_mv = mv;
843                    }
844                }
845            }
846        }
847
848        (best_mv, best_sad)
849    }
850
851    /// Hexagon grid expansion step.
852    fn hex_step(
853        ctx: &SearchContext,
854        config: &SearchConfig,
855        center: MotionVector,
856        pattern: &[(i32, i32)],
857    ) -> (MotionVector, u32, bool) {
858        let mut best_mv = center;
859        let mut best_sad = ctx.calculate_sad(&center).unwrap_or(u32::MAX);
860        let mut improved = false;
861
862        for &(dx, dy) in pattern {
863            let mv =
864                MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
865            if !ctx.is_valid_mv(&mv, &config.range) {
866                continue;
867            }
868            if let Some(sad) = ctx.calculate_sad(&mv) {
869                if sad < best_sad {
870                    best_sad = sad;
871                    best_mv = mv;
872                    improved = true;
873                }
874            }
875        }
876
877        (best_mv, best_sad, improved)
878    }
879}
880
881impl MotionSearch for UMHexSearch {
882    fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
883        self.search_with_predictor(ctx, config, MotionVector::zero())
884    }
885
886    fn search_with_predictor(
887        &self,
888        ctx: &SearchContext,
889        config: &SearchConfig,
890        predictor: MotionVector,
891    ) -> BlockMatch {
892        // Step 1: initialise from predictor
893        let pred_fp = predictor.to_precision(MvPrecision::FullPel);
894        let pred_sad = ctx.calculate_sad(&pred_fp).unwrap_or(u32::MAX);
895
896        let zero_sad = ctx.calculate_sad(&MotionVector::zero()).unwrap_or(u32::MAX);
897
898        let (mut current, mut current_sad) = if pred_sad <= zero_sad {
899            (pred_fp, pred_sad)
900        } else {
901            (MotionVector::zero(), zero_sad)
902        };
903
904        // Early exit for trivial match
905        if current_sad <= self.early_exit_threshold {
906            let cost = config.mv_cost.rd_cost(&current, current_sad);
907            return BlockMatch::new(current, current_sad, cost);
908        }
909
910        // Step 2: unsymmetric-cross scan
911        let (cross_mv, cross_sad) = self.cross_scan(ctx, config, current);
912        if cross_sad < current_sad {
913            current = cross_mv;
914            current_sad = cross_sad;
915        }
916
917        if current_sad <= self.early_exit_threshold {
918            let cost = config.mv_cost.rd_cost(&current, current_sad);
919            return BlockMatch::new(current, current_sad, cost);
920        }
921
922        // Step 3: HEX2 expansion until convergence
923        for _ in 0..self.max_hex_steps {
924            let (mv, sad, improved) = Self::hex_step(ctx, config, current, &Self::HEX2);
925            if !improved {
926                break;
927            }
928            current = mv;
929            current_sad = sad;
930
931            if current_sad <= self.early_exit_threshold {
932                break;
933            }
934        }
935
936        // Step 4: HEX1 refinement
937        for _ in 0..4 {
938            let (mv, sad, improved) = Self::hex_step(ctx, config, current, &Self::HEX1);
939            if !improved {
940                break;
941            }
942            current = mv;
943            current_sad = sad;
944        }
945
946        // Step 5: small diamond final refinement
947        let sdsp = SmallDiamond::new();
948        let (refined, rsad, _) = sdsp.search(ctx, config, current);
949        if rsad < current_sad {
950            current = refined;
951            current_sad = rsad;
952        }
953
954        let cost = config.mv_cost.rd_cost(&current, current_sad);
955        BlockMatch::new(current, current_sad, cost)
956    }
957}
958
#[cfg(test)]
mod tests {
    use super::*;
    use crate::motion::types::{BlockSize, SearchRange};

    /// Side length of the square source block used by every test. The source
    /// buffers are tightly packed, so this is also their row stride.
    const BLOCK_DIM: usize = 8;

    /// Builds a search context for a packed 8x8 source block positioned at the
    /// origin of a `width` x `height` reference frame.
    ///
    /// The source stride is always [`BLOCK_DIM`] because the source buffer holds
    /// only the block itself; `width` is used as the reference-frame stride and
    /// as the frame width.
    fn create_test_context<'a>(
        src: &'a [u8],
        ref_frame: &'a [u8],
        width: usize,
        height: usize,
    ) -> SearchContext<'a> {
        SearchContext::new(
            src,
            BLOCK_DIM,
            ref_frame,
            width,
            BlockSize::Block8x8,
            0,
            0,
            width,
            height,
        )
    }

    /// Fills an 8x8 patch of `frame` (row stride `stride`) with `value`, with the
    /// patch's top-left corner at column `x`, row `y`.
    ///
    /// # Panics
    ///
    /// Panics if the patch would wrap past the end of a row or run past the end
    /// of `frame`.
    fn place_block(frame: &mut [u8], stride: usize, x: usize, y: usize, value: u8) {
        // A patch that wraps onto the next row would put the "match" somewhere
        // other than the offset the test asserts on, so reject it loudly.
        assert!(
            x + BLOCK_DIM <= stride,
            "patch at x={x} does not fit in stride {stride}"
        );
        for row in y..y + BLOCK_DIM {
            let start = row * stride + x;
            frame[start..start + BLOCK_DIM].fill(value);
        }
    }

    #[test]
    fn test_small_diamond_pattern() {
        let sdsp = SmallDiamond::new();
        assert_eq!(sdsp.size(), 4);
        assert_eq!(sdsp.get(0), Some((0, -1)));
        assert_eq!(sdsp.get(4), None);
    }

    #[test]
    fn test_large_diamond_pattern() {
        let ldsp = LargeDiamond::new();
        assert_eq!(ldsp.size(), 8);
        assert_eq!(ldsp.get(0), Some((0, -2)));
        assert_eq!(ldsp.get(8), None);
    }

    #[test]
    fn test_small_diamond_search() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 144]; // 12x12

        // Place match at offset (1, 0)
        place_block(&mut ref_frame, 12, 1, 0, 100);

        let ctx = create_test_context(&src, &ref_frame, 12, 12);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let sdsp = SmallDiamond::new();
        let (mv, sad, _) = sdsp.search(&ctx, &config, MotionVector::zero());

        // Should find match at (1, 0)
        assert_eq!(mv.full_pel_x(), 1);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_large_diamond_search() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 256]; // 16x16

        // Place match at offset (2, 0)
        place_block(&mut ref_frame, 16, 2, 0, 100);

        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let ldsp = LargeDiamond::new();
        let (mv, sad, _) = ldsp.search(&ctx, &config, MotionVector::zero());

        // Should find match at (2, 0)
        assert_eq!(mv.full_pel_x(), 2);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_extended_diamond_search() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 256]; // 16x16

        // Place match at offset (3, 0)
        place_block(&mut ref_frame, 16, 3, 0, 100);

        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let extended = ExtendedDiamond::new();
        let (mv, sad) = extended.search(&ctx, &config, MotionVector::zero());

        // Should find match at (3, 0)
        assert_eq!(mv.full_pel_x(), 3);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_adaptive_diamond_convergence() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 256]; // 16x16

        // Place match at (4, 4)
        place_block(&mut ref_frame, 16, 4, 4, 100);

        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let adaptive = AdaptiveDiamond::new();
        let result = adaptive.search(&ctx, &config);

        // Should find close to the optimal
        assert!(result.sad < 500);
    }

    #[test]
    fn test_predictor_diamond() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 256]; // 16x16

        // Place match at (4, 4)
        place_block(&mut ref_frame, 16, 4, 4, 100);

        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let predictor = PredictorDiamond::new();
        let predictors = [
            MotionVector::from_full_pel(3, 3), // Close to optimal
            MotionVector::from_full_pel(0, 0),
        ];

        let result = predictor.search_multi(&ctx, &config, &predictors);

        // Good predictor should help find optimal
        assert!(result.sad < 200);
    }

    #[test]
    fn test_cross_diamond() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 256]; // 16x16

        // Place match at (4, 0) - horizontal motion
        place_block(&mut ref_frame, 16, 4, 0, 100);

        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let cross = CrossDiamond::new();
        let result = cross.search(&ctx, &config);

        // Cross pattern should handle horizontal motion well
        assert!(result.sad < 300);
    }

    #[test]
    fn test_adaptive_diamond_early_switch() {
        let src = vec![100u8; 64];
        let ref_frame = vec![100u8; 256]; // 16x16

        // Both frames are uniform, so every candidate (including the origin)
        // is an exact match.
        let ctx = create_test_context(&src, &ref_frame, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let adaptive = AdaptiveDiamond::new().switch_threshold(100);
        let result = adaptive.search(&ctx, &config);

        // Should find match quickly
        assert_eq!(result.sad, 0);
    }

    #[test]
    fn test_diamond_builder_pattern() {
        let adaptive = AdaptiveDiamond::new()
            .max_iterations(16)
            .switch_threshold(1000);

        assert_eq!(adaptive.max_ldsp_iterations, 16);
        assert_eq!(adaptive.switch_threshold, 1000);
    }

    #[test]
    fn test_hexagonal_search_pattern_constants() {
        assert_eq!(HexagonalSearch::INNER_PATTERN.len(), 6);
        assert_eq!(HexagonalSearch::OUTER_PATTERN.len(), 6);
    }

    #[test]
    fn test_hexagonal_search_finds_match() {
        let src = vec![100u8; 64];
        let mut ref_frame = vec![50u8; 400]; // 20x20

        // Place match at (2, 0) — within inner hex
        place_block(&mut ref_frame, 20, 2, 0, 100);

        let ctx = create_test_context(&src, &ref_frame, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let hex = HexagonalSearch::new();
        let result = hex.search(&ctx, &config);
        assert!(result.sad < 500, "Hex search should find a good match");
    }

    #[test]
    fn test_hexagonal_search_zero_match() {
        let src = vec![128u8; 64];
        let ref_frame = vec![128u8; 400]; // 20x20

        let ctx = create_test_context(&src, &ref_frame, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let hex = HexagonalSearch::new();
        let result = hex.search(&ctx, &config);
        assert_eq!(result.sad, 0, "Perfect match should have SAD=0");
    }

    #[test]
    fn test_hexagonal_search_max_iterations_builder() {
        let hex = HexagonalSearch::new().max_iterations(16);
        assert_eq!(hex.max_iterations, 16);
    }

    #[test]
    fn test_umhex_search_finds_match() {
        let src = vec![200u8; 64];
        let mut ref_frame = vec![50u8; 400]; // 20x20

        // Place match at (4, 2)
        place_block(&mut ref_frame, 20, 4, 2, 200);

        let ctx = create_test_context(&src, &ref_frame, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let umhex = UMHexSearch::new();
        let result = umhex.search(&ctx, &config);
        assert!(result.sad < 1000, "UMHex should find a reasonable match");
    }

    #[test]
    fn test_umhex_zero_sad_early_exit() {
        let src = vec![77u8; 64];
        let ref_frame = vec![77u8; 400]; // 20x20

        let ctx = create_test_context(&src, &ref_frame, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let umhex = UMHexSearch::new();
        let result = umhex.search(&ctx, &config);
        assert_eq!(result.sad, 0);
    }

    #[test]
    fn test_umhex_predictor_helps() {
        let src = vec![150u8; 64];
        let mut ref_frame = vec![0u8; 400]; // 20x20

        // Place match far from origin at (6, 6)
        place_block(&mut ref_frame, 20, 6, 6, 150);

        let ctx = create_test_context(&src, &ref_frame, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(10));

        let umhex = UMHexSearch::new();
        let predictor = MotionVector::from_full_pel(5, 5);
        let result = umhex.search_with_predictor(&ctx, &config, predictor);
        // With a good predictor, SAD should be low
        assert!(result.sad < 2000);
    }

    #[test]
    fn test_umhex_builder_pattern() {
        let umhex = UMHexSearch::new()
            .max_hex_steps(32)
            .cross_range(16)
            .early_exit_threshold(8);

        assert_eq!(umhex.max_hex_steps, 32);
        assert_eq!(umhex.cross_range, 16);
        assert_eq!(umhex.early_exit_threshold, 8);
    }
}