1#![forbid(unsafe_code)]
10#![allow(dead_code)]
11#![allow(clippy::too_many_arguments)]
12#![allow(clippy::must_use_candidate)]
13
14use super::search::{MotionSearch, SearchConfig, SearchContext};
15use super::types::{BlockMatch, MotionVector, MvPrecision};
16
17#[derive(Clone, Copy, Debug)]
26pub struct SmallDiamond {
27 pub points: [(i32, i32); 4],
29}
30
31impl Default for SmallDiamond {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl SmallDiamond {
38 pub const PATTERN: [(i32, i32); 4] = [(0, -1), (-1, 0), (1, 0), (0, 1)];
40
41 #[must_use]
43 pub const fn new() -> Self {
44 Self {
45 points: Self::PATTERN,
46 }
47 }
48
49 #[must_use]
51 pub const fn size(&self) -> usize {
52 4
53 }
54
55 #[must_use]
57 pub const fn get(&self, index: usize) -> Option<(i32, i32)> {
58 if index < 4 {
59 Some(self.points[index])
60 } else {
61 None
62 }
63 }
64
65 pub fn search(
67 &self,
68 ctx: &SearchContext,
69 config: &SearchConfig,
70 center: MotionVector,
71 ) -> (MotionVector, u32, usize) {
72 let mut best_mv = center;
73 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
74 let mut best_idx = 4; for (idx, &(dx, dy)) in self.points.iter().enumerate() {
77 let mv =
78 MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
79
80 if !ctx.is_valid_mv(&mv, &config.range) {
81 continue;
82 }
83
84 if let Some(sad) = ctx.calculate_sad(&mv) {
85 if sad < best_sad {
86 best_sad = sad;
87 best_mv = mv;
88 best_idx = idx;
89 }
90 }
91 }
92
93 (best_mv, best_sad, best_idx)
94 }
95}
96
97#[derive(Clone, Copy, Debug)]
108pub struct LargeDiamond {
109 pub points: [(i32, i32); 8],
111}
112
113impl Default for LargeDiamond {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl LargeDiamond {
120 pub const PATTERN: [(i32, i32); 8] = [
122 (0, -2), (-1, -1), (1, -1), (-2, 0), (2, 0), (-1, 1), (1, 1), (0, 2), ];
131
132 #[must_use]
134 pub const fn new() -> Self {
135 Self {
136 points: Self::PATTERN,
137 }
138 }
139
140 #[must_use]
142 pub const fn size(&self) -> usize {
143 8
144 }
145
146 #[must_use]
148 pub const fn get(&self, index: usize) -> Option<(i32, i32)> {
149 if index < 8 {
150 Some(self.points[index])
151 } else {
152 None
153 }
154 }
155
156 pub fn search(
158 &self,
159 ctx: &SearchContext,
160 config: &SearchConfig,
161 center: MotionVector,
162 ) -> (MotionVector, u32, usize) {
163 let mut best_mv = center;
164 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
165 let mut best_idx = 8; for (idx, &(dx, dy)) in self.points.iter().enumerate() {
168 let mv =
169 MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
170
171 if !ctx.is_valid_mv(&mv, &config.range) {
172 continue;
173 }
174
175 if let Some(sad) = ctx.calculate_sad(&mv) {
176 if sad < best_sad {
177 best_sad = sad;
178 best_mv = mv;
179 best_idx = idx;
180 }
181 }
182 }
183
184 (best_mv, best_sad, best_idx)
185 }
186}
187
188#[derive(Clone, Copy, Debug)]
201pub struct ExtendedDiamond {
202 pub inner: [(i32, i32); 4],
204 pub middle: [(i32, i32); 8],
206 pub outer: [(i32, i32); 4],
208}
209
210impl Default for ExtendedDiamond {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl ExtendedDiamond {
217 #[must_use]
219 pub const fn new() -> Self {
220 Self {
221 inner: SmallDiamond::PATTERN,
222 middle: LargeDiamond::PATTERN,
223 outer: [(0, -3), (-3, 0), (3, 0), (0, 3)],
224 }
225 }
226
227 pub fn search(
229 &self,
230 ctx: &SearchContext,
231 config: &SearchConfig,
232 center: MotionVector,
233 ) -> (MotionVector, u32) {
234 let mut best_mv = center;
235 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
236
237 for &(dx, dy) in self.outer.iter().chain(&self.middle).chain(&self.inner) {
239 let mv =
240 MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
241
242 if !ctx.is_valid_mv(&mv, &config.range) {
243 continue;
244 }
245
246 if let Some(sad) = ctx.calculate_sad(&mv) {
247 if sad < best_sad {
248 best_sad = sad;
249 best_mv = mv;
250 }
251 }
252 }
253
254 (best_mv, best_sad)
255 }
256}
257
258#[derive(Clone, Debug)]
264pub struct AdaptiveDiamond {
265 sdsp: SmallDiamond,
267 ldsp: LargeDiamond,
269 max_ldsp_iterations: u32,
271 switch_threshold: u32,
273}
274
275impl Default for AdaptiveDiamond {
276 fn default() -> Self {
277 Self::new()
278 }
279}
280
281impl AdaptiveDiamond {
282 #[must_use]
284 pub const fn new() -> Self {
285 Self {
286 sdsp: SmallDiamond::new(),
287 ldsp: LargeDiamond::new(),
288 max_ldsp_iterations: 8,
289 switch_threshold: 512,
290 }
291 }
292
293 #[must_use]
295 pub const fn max_iterations(mut self, max: u32) -> Self {
296 self.max_ldsp_iterations = max;
297 self
298 }
299
300 #[must_use]
302 pub const fn switch_threshold(mut self, threshold: u32) -> Self {
303 self.switch_threshold = threshold;
304 self
305 }
306}
307
308impl MotionSearch for AdaptiveDiamond {
309 fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
310 self.search_with_predictor(ctx, config, MotionVector::zero())
311 }
312
313 fn search_with_predictor(
314 &self,
315 ctx: &SearchContext,
316 config: &SearchConfig,
317 predictor: MotionVector,
318 ) -> BlockMatch {
319 let mut center = predictor.to_precision(MvPrecision::FullPel);
320 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
321
322 for iteration in 0..self.max_ldsp_iterations {
324 let (new_center, new_sad, best_idx) = self.ldsp.search(ctx, config, center);
325
326 if best_idx >= self.ldsp.size() {
328 break;
329 }
330
331 if new_sad < self.switch_threshold {
333 center = new_center;
334 best_sad = new_sad;
335 break;
336 }
337
338 center = new_center;
339 best_sad = new_sad;
340
341 if config.early_termination && best_sad < config.early_threshold {
343 let cost = config.mv_cost.rd_cost(¢er, best_sad);
344 return BlockMatch::new(center, best_sad, cost);
345 }
346
347 if iteration >= self.max_ldsp_iterations - 1 {
349 break;
350 }
351 }
352
353 loop {
355 let (new_center, new_sad, best_idx) = self.sdsp.search(ctx, config, center);
356
357 if best_idx >= self.sdsp.size() {
359 break;
360 }
361
362 center = new_center;
363 best_sad = new_sad;
364 }
365
366 let cost = config.mv_cost.rd_cost(¢er, best_sad);
367 BlockMatch::new(center, best_sad, cost)
368 }
369}
370
371#[derive(Clone, Debug)]
376pub struct PredictorDiamond {
377 diamond: AdaptiveDiamond,
379 max_predictors: usize,
381}
382
383impl Default for PredictorDiamond {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
389impl PredictorDiamond {
390 #[must_use]
392 pub const fn new() -> Self {
393 Self {
394 diamond: AdaptiveDiamond::new(),
395 max_predictors: 5,
396 }
397 }
398
399 #[must_use]
401 pub const fn max_predictors(mut self, max: usize) -> Self {
402 self.max_predictors = max;
403 self
404 }
405
406 pub fn search_multi(
408 &self,
409 ctx: &SearchContext,
410 config: &SearchConfig,
411 predictors: &[MotionVector],
412 ) -> BlockMatch {
413 let mut best = BlockMatch::worst();
414
415 if let Some(sad) = ctx.calculate_sad(&MotionVector::zero()) {
417 let cost = config.mv_cost.rd_cost(&MotionVector::zero(), sad);
418 let candidate = BlockMatch::new(MotionVector::zero(), sad, cost);
419 best.update_if_better(&candidate);
420
421 if sad == 0 {
423 return best;
424 }
425 }
426
427 for (i, &pred) in predictors.iter().take(self.max_predictors).enumerate() {
429 if i > 0 && pred.is_zero() {
430 continue; }
432
433 let pred_fp = pred.to_precision(MvPrecision::FullPel);
435 if let Some(sad) = ctx.calculate_sad(&pred_fp) {
436 if sad < best.sad {
437 let result = self.diamond.search_with_predictor(ctx, config, pred);
439 best.update_if_better(&result);
440 }
441 }
442 }
443
444 if best.sad > config.early_threshold {
446 let result = self.diamond.search_with_predictor(ctx, config, best.mv);
447 best.update_if_better(&result);
448 }
449
450 best
451 }
452}
453
454impl MotionSearch for PredictorDiamond {
455 fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
456 self.diamond.search(ctx, config)
457 }
458
459 fn search_with_predictor(
460 &self,
461 ctx: &SearchContext,
462 config: &SearchConfig,
463 predictor: MotionVector,
464 ) -> BlockMatch {
465 self.diamond.search_with_predictor(ctx, config, predictor)
466 }
467}
468
469#[derive(Clone, Debug)]
474pub struct CrossDiamond {
475 cross_range: i32,
477 diamond: AdaptiveDiamond,
479}
480
481impl Default for CrossDiamond {
482 fn default() -> Self {
483 Self::new()
484 }
485}
486
487impl CrossDiamond {
488 #[must_use]
490 pub const fn new() -> Self {
491 Self {
492 cross_range: 4,
493 diamond: AdaptiveDiamond::new(),
494 }
495 }
496
497 #[must_use]
499 pub const fn cross_range(mut self, range: i32) -> Self {
500 self.cross_range = range;
501 self
502 }
503
504 fn cross_search(
506 &self,
507 ctx: &SearchContext,
508 config: &SearchConfig,
509 center: MotionVector,
510 ) -> (MotionVector, u32) {
511 let mut best_mv = center;
512 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
513
514 for dx in -self.cross_range..=self.cross_range {
516 if dx == 0 {
517 continue;
518 }
519 let mv = MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y());
520
521 if !ctx.is_valid_mv(&mv, &config.range) {
522 continue;
523 }
524
525 if let Some(sad) = ctx.calculate_sad(&mv) {
526 if sad < best_sad {
527 best_sad = sad;
528 best_mv = mv;
529 }
530 }
531 }
532
533 for dy in -self.cross_range..=self.cross_range {
535 if dy == 0 {
536 continue;
537 }
538 let mv = MotionVector::from_full_pel(center.full_pel_x(), center.full_pel_y() + dy);
539
540 if !ctx.is_valid_mv(&mv, &config.range) {
541 continue;
542 }
543
544 if let Some(sad) = ctx.calculate_sad(&mv) {
545 if sad < best_sad {
546 best_sad = sad;
547 best_mv = mv;
548 }
549 }
550 }
551
552 (best_mv, best_sad)
553 }
554}
555
556impl MotionSearch for CrossDiamond {
557 fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
558 self.search_with_predictor(ctx, config, MotionVector::zero())
559 }
560
561 fn search_with_predictor(
562 &self,
563 ctx: &SearchContext,
564 config: &SearchConfig,
565 predictor: MotionVector,
566 ) -> BlockMatch {
567 let center = predictor.to_precision(MvPrecision::FullPel);
568
569 let (cross_best, _) = self.cross_search(ctx, config, center);
571
572 self.diamond.search_with_predictor(ctx, config, cross_best)
574 }
575}
576
577#[derive(Clone, Copy, Debug)]
593pub struct HexagonalSearch {
594 pub inner: [(i32, i32); 6],
596 pub outer: [(i32, i32); 6],
598 pub max_iterations: u32,
600}
601
602impl Default for HexagonalSearch {
603 fn default() -> Self {
604 Self::new()
605 }
606}
607
608impl HexagonalSearch {
609 pub const INNER_PATTERN: [(i32, i32); 6] =
611 [(-2, 0), (-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2)];
612
613 pub const OUTER_PATTERN: [(i32, i32); 6] =
615 [(-4, 0), (-2, -4), (2, -4), (4, 0), (2, 4), (-2, 4)];
616
617 #[must_use]
619 pub const fn new() -> Self {
620 Self {
621 inner: Self::INNER_PATTERN,
622 outer: Self::OUTER_PATTERN,
623 max_iterations: 8,
624 }
625 }
626
627 #[must_use]
629 pub const fn max_iterations(mut self, max: u32) -> Self {
630 self.max_iterations = max;
631 self
632 }
633
634 fn search_hex(
636 &self,
637 ctx: &SearchContext,
638 config: &SearchConfig,
639 center: MotionVector,
640 pattern: &[(i32, i32)],
641 ) -> (MotionVector, u32) {
642 let mut best_mv = center;
643 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
644
645 for &(dx, dy) in pattern {
646 let mv =
647 MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
648 if !ctx.is_valid_mv(&mv, &config.range) {
649 continue;
650 }
651 if let Some(sad) = ctx.calculate_sad(&mv) {
652 if sad < best_sad {
653 best_sad = sad;
654 best_mv = mv;
655 }
656 }
657 }
658
659 (best_mv, best_sad)
660 }
661}
662
663impl MotionSearch for HexagonalSearch {
664 fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
665 self.search_with_predictor(ctx, config, MotionVector::zero())
666 }
667
668 fn search_with_predictor(
669 &self,
670 ctx: &SearchContext,
671 config: &SearchConfig,
672 predictor: MotionVector,
673 ) -> BlockMatch {
674 let center = predictor.to_precision(MvPrecision::FullPel);
675 let mut current = center;
676 let mut current_sad = ctx.calculate_sad(¤t).unwrap_or(u32::MAX);
677
678 for _ in 0..self.max_iterations {
680 let (best_mv, best_sad) = self.search_hex(ctx, config, current, &self.outer);
681 if best_sad >= current_sad {
682 break;
683 }
684 current = best_mv;
685 current_sad = best_sad;
686
687 if current_sad == 0 {
688 break;
689 }
690 }
691
692 for _ in 0..self.max_iterations {
694 let (best_mv, best_sad) = self.search_hex(ctx, config, current, &self.inner);
695 if best_sad >= current_sad {
696 break;
697 }
698 current = best_mv;
699 current_sad = best_sad;
700
701 if current_sad == 0 {
702 break;
703 }
704 }
705
706 let sdsp = SmallDiamond::new();
708 let (refined_mv, refined_sad, _) = sdsp.search(ctx, config, current);
709 if refined_sad < current_sad {
710 current = refined_mv;
711 current_sad = refined_sad;
712 }
713
714 let cost = config.mv_cost.rd_cost(¤t, current_sad);
715 BlockMatch::new(current, current_sad, cost)
716 }
717}
718
719#[derive(Clone, Debug)]
739pub struct UMHexSearch {
740 max_hex_steps: u32,
742 cross_range: i32,
744 early_exit_threshold: u32,
746}
747
748impl Default for UMHexSearch {
749 fn default() -> Self {
750 Self::new()
751 }
752}
753
754impl UMHexSearch {
755 const HEX1: [(i32, i32); 6] = [(-1, -2), (1, -2), (2, 0), (1, 2), (-1, 2), (-2, 0)];
757 const HEX2: [(i32, i32); 12] = [
759 (-1, -2),
760 (1, -2), (2, -1),
762 (2, 1), (1, 2),
764 (-1, 2), (-2, 1),
766 (-2, -1), (0, -4),
768 (4, 0), (0, 4),
770 (-4, 0), ];
772
773 #[must_use]
775 pub const fn new() -> Self {
776 Self {
777 max_hex_steps: 16,
778 cross_range: 8,
779 early_exit_threshold: 4,
780 }
781 }
782
783 #[must_use]
785 pub const fn max_hex_steps(mut self, steps: u32) -> Self {
786 self.max_hex_steps = steps;
787 self
788 }
789
790 #[must_use]
792 pub const fn cross_range(mut self, range: i32) -> Self {
793 self.cross_range = range;
794 self
795 }
796
797 #[must_use]
799 pub const fn early_exit_threshold(mut self, threshold: u32) -> Self {
800 self.early_exit_threshold = threshold;
801 self
802 }
803
804 fn cross_scan(
806 &self,
807 ctx: &SearchContext,
808 config: &SearchConfig,
809 center: MotionVector,
810 ) -> (MotionVector, u32) {
811 let mut best_mv = center;
812 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
813
814 let h_offsets: &[i32] = &[-8, -4, -2, -1, 1, 2, 4, 8];
816 for &dx in h_offsets {
817 if dx.abs() > self.cross_range {
818 continue;
819 }
820 let mv = MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y());
821 if ctx.is_valid_mv(&mv, &config.range) {
822 if let Some(sad) = ctx.calculate_sad(&mv) {
823 if sad < best_sad {
824 best_sad = sad;
825 best_mv = mv;
826 }
827 }
828 }
829 }
830
831 let v_offsets: &[i32] = &[-8, -4, -2, -1, 1, 2, 4, 8];
833 for &dy in v_offsets {
834 if dy.abs() > self.cross_range {
835 continue;
836 }
837 let mv = MotionVector::from_full_pel(center.full_pel_x(), center.full_pel_y() + dy);
838 if ctx.is_valid_mv(&mv, &config.range) {
839 if let Some(sad) = ctx.calculate_sad(&mv) {
840 if sad < best_sad {
841 best_sad = sad;
842 best_mv = mv;
843 }
844 }
845 }
846 }
847
848 (best_mv, best_sad)
849 }
850
851 fn hex_step(
853 ctx: &SearchContext,
854 config: &SearchConfig,
855 center: MotionVector,
856 pattern: &[(i32, i32)],
857 ) -> (MotionVector, u32, bool) {
858 let mut best_mv = center;
859 let mut best_sad = ctx.calculate_sad(¢er).unwrap_or(u32::MAX);
860 let mut improved = false;
861
862 for &(dx, dy) in pattern {
863 let mv =
864 MotionVector::from_full_pel(center.full_pel_x() + dx, center.full_pel_y() + dy);
865 if !ctx.is_valid_mv(&mv, &config.range) {
866 continue;
867 }
868 if let Some(sad) = ctx.calculate_sad(&mv) {
869 if sad < best_sad {
870 best_sad = sad;
871 best_mv = mv;
872 improved = true;
873 }
874 }
875 }
876
877 (best_mv, best_sad, improved)
878 }
879}
880
881impl MotionSearch for UMHexSearch {
882 fn search(&self, ctx: &SearchContext, config: &SearchConfig) -> BlockMatch {
883 self.search_with_predictor(ctx, config, MotionVector::zero())
884 }
885
886 fn search_with_predictor(
887 &self,
888 ctx: &SearchContext,
889 config: &SearchConfig,
890 predictor: MotionVector,
891 ) -> BlockMatch {
892 let pred_fp = predictor.to_precision(MvPrecision::FullPel);
894 let pred_sad = ctx.calculate_sad(&pred_fp).unwrap_or(u32::MAX);
895
896 let zero_sad = ctx.calculate_sad(&MotionVector::zero()).unwrap_or(u32::MAX);
897
898 let (mut current, mut current_sad) = if pred_sad <= zero_sad {
899 (pred_fp, pred_sad)
900 } else {
901 (MotionVector::zero(), zero_sad)
902 };
903
904 if current_sad <= self.early_exit_threshold {
906 let cost = config.mv_cost.rd_cost(¤t, current_sad);
907 return BlockMatch::new(current, current_sad, cost);
908 }
909
910 let (cross_mv, cross_sad) = self.cross_scan(ctx, config, current);
912 if cross_sad < current_sad {
913 current = cross_mv;
914 current_sad = cross_sad;
915 }
916
917 if current_sad <= self.early_exit_threshold {
918 let cost = config.mv_cost.rd_cost(¤t, current_sad);
919 return BlockMatch::new(current, current_sad, cost);
920 }
921
922 for _ in 0..self.max_hex_steps {
924 let (mv, sad, improved) = Self::hex_step(ctx, config, current, &Self::HEX2);
925 if !improved {
926 break;
927 }
928 current = mv;
929 current_sad = sad;
930
931 if current_sad <= self.early_exit_threshold {
932 break;
933 }
934 }
935
936 for _ in 0..4 {
938 let (mv, sad, improved) = Self::hex_step(ctx, config, current, &Self::HEX1);
939 if !improved {
940 break;
941 }
942 current = mv;
943 current_sad = sad;
944 }
945
946 let sdsp = SmallDiamond::new();
948 let (refined, rsad, _) = sdsp.search(ctx, config, current);
949 if rsad < current_sad {
950 current = refined;
951 current_sad = rsad;
952 }
953
954 let cost = config.mv_cost.rd_cost(¤t, current_sad);
955 BlockMatch::new(current, current_sad, cost)
956 }
957}
958
#[cfg(test)]
mod tests {
    use super::*;
    use crate::motion::types::{BlockSize, SearchRange};

    /// Context whose source and reference share `width` as stride, for an 8x8
    /// block at the origin.
    fn create_test_context<'a>(
        src: &'a [u8],
        ref_frame: &'a [u8],
        width: usize,
        height: usize,
    ) -> SearchContext<'a> {
        SearchContext::new(
            src,
            width,
            ref_frame,
            width,
            BlockSize::Block8x8,
            0,
            0,
            width,
            height,
        )
    }

    /// Square `side * side` frame filled with `background`, containing one 8x8
    /// block of `value` whose top-left corner is at `(x0, y0)`.
    fn frame_with_block(side: usize, background: u8, x0: usize, y0: usize, value: u8) -> Vec<u8> {
        let mut frame = vec![background; side * side];
        for y in y0..y0 + 8 {
            let start = y * side + x0;
            for px in &mut frame[start..start + 8] {
                *px = value;
            }
        }
        frame
    }

    #[test]
    fn test_small_diamond_pattern() {
        let pattern = SmallDiamond::new();
        assert_eq!(pattern.size(), 4);
        assert_eq!(pattern.get(0), Some((0, -1)));
        assert_eq!(pattern.get(4), None);
    }

    #[test]
    fn test_large_diamond_pattern() {
        let pattern = LargeDiamond::new();
        assert_eq!(pattern.size(), 8);
        assert_eq!(pattern.get(0), Some((0, -2)));
        assert_eq!(pattern.get(8), None);
    }

    #[test]
    fn test_small_diamond_search() {
        let src = vec![100u8; 64];
        // Exact match one pixel to the right of the origin.
        let ref_frame = frame_with_block(12, 50, 1, 0, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 12, BlockSize::Block8x8, 0, 0, 12, 12);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let (mv, sad, _) = SmallDiamond::new().search(&ctx, &config, MotionVector::zero());

        assert_eq!(mv.full_pel_x(), 1);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_large_diamond_search() {
        let src = vec![100u8; 64];
        // Exact match two pixels to the right of the origin.
        let ref_frame = frame_with_block(16, 50, 2, 0, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let (mv, sad, _) = LargeDiamond::new().search(&ctx, &config, MotionVector::zero());

        assert_eq!(mv.full_pel_x(), 2);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_extended_diamond_search() {
        let src = vec![100u8; 64];
        // Exact match three pixels to the right (an outer-ring point).
        let ref_frame = frame_with_block(16, 50, 3, 0, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(4));

        let (mv, sad) = ExtendedDiamond::new().search(&ctx, &config, MotionVector::zero());

        assert_eq!(mv.full_pel_x(), 3);
        assert_eq!(mv.full_pel_y(), 0);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_adaptive_diamond_convergence() {
        let src = vec![100u8; 64];
        let ref_frame = frame_with_block(16, 50, 4, 4, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = AdaptiveDiamond::new().search(&ctx, &config);

        assert!(result.sad < 500);
    }

    #[test]
    fn test_predictor_diamond() {
        let src = vec![100u8; 64];
        let ref_frame = frame_with_block(16, 50, 4, 4, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let predictors = [
            MotionVector::from_full_pel(3, 3),
            MotionVector::from_full_pel(0, 0),
        ];
        let result = PredictorDiamond::new().search_multi(&ctx, &config, &predictors);

        assert!(result.sad < 200);
    }

    #[test]
    fn test_cross_diamond() {
        let src = vec![100u8; 64];
        let ref_frame = frame_with_block(16, 50, 4, 0, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = CrossDiamond::new().search(&ctx, &config);

        assert!(result.sad < 300);
    }

    #[test]
    fn test_adaptive_diamond_early_switch() {
        let src = vec![100u8; 64];
        // Uniform reference: every position is a perfect match.
        let ref_frame = vec![100u8; 256];

        let ctx = SearchContext::new(&src, 8, &ref_frame, 16, BlockSize::Block8x8, 0, 0, 16, 16);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = AdaptiveDiamond::new()
            .switch_threshold(100)
            .search(&ctx, &config);

        assert_eq!(result.sad, 0);
    }

    #[test]
    fn test_diamond_builder_pattern() {
        let configured = AdaptiveDiamond::new()
            .max_iterations(16)
            .switch_threshold(1000);

        assert_eq!(configured.max_ldsp_iterations, 16);
        assert_eq!(configured.switch_threshold, 1000);
    }

    #[test]
    fn test_hexagonal_search_pattern_constants() {
        assert_eq!(HexagonalSearch::INNER_PATTERN.len(), 6);
        assert_eq!(HexagonalSearch::OUTER_PATTERN.len(), 6);
    }

    #[test]
    fn test_hexagonal_search_finds_match() {
        let src = vec![100u8; 64];
        let ref_frame = frame_with_block(20, 50, 2, 0, 100);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 20, BlockSize::Block8x8, 0, 0, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = HexagonalSearch::new().search(&ctx, &config);
        assert!(result.sad < 500, "Hex search should find a good match");
    }

    #[test]
    fn test_hexagonal_search_zero_match() {
        let src = vec![128u8; 64];
        let ref_frame = vec![128u8; 400];

        let ctx = SearchContext::new(&src, 8, &ref_frame, 20, BlockSize::Block8x8, 0, 0, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = HexagonalSearch::new().search(&ctx, &config);
        assert_eq!(result.sad, 0, "Perfect match should have SAD=0");
    }

    #[test]
    fn test_hexagonal_search_max_iterations_builder() {
        let configured = HexagonalSearch::new().max_iterations(16);
        assert_eq!(configured.max_iterations, 16);
    }

    #[test]
    fn test_umhex_search_finds_match() {
        let src = vec![200u8; 64];
        let ref_frame = frame_with_block(20, 50, 4, 2, 200);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 20, BlockSize::Block8x8, 0, 0, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = UMHexSearch::new().search(&ctx, &config);
        assert!(result.sad < 1000, "UMHex should find a reasonable match");
    }

    #[test]
    fn test_umhex_zero_sad_early_exit() {
        let src = vec![77u8; 64];
        let ref_frame = vec![77u8; 400];

        let ctx = SearchContext::new(&src, 8, &ref_frame, 20, BlockSize::Block8x8, 0, 0, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(8));

        let result = UMHexSearch::new().search(&ctx, &config);
        assert_eq!(result.sad, 0);
    }

    #[test]
    fn test_umhex_predictor_helps() {
        let src = vec![150u8; 64];
        let ref_frame = frame_with_block(20, 0, 6, 6, 150);

        let ctx = SearchContext::new(&src, 8, &ref_frame, 20, BlockSize::Block8x8, 0, 0, 20, 20);
        let config = SearchConfig::with_range(SearchRange::symmetric(10));

        let hint = MotionVector::from_full_pel(5, 5);
        let result = UMHexSearch::new().search_with_predictor(&ctx, &config, hint);
        assert!(result.sad < 2000);
    }

    #[test]
    fn test_umhex_builder_pattern() {
        let configured = UMHexSearch::new()
            .max_hex_steps(32)
            .cross_range(16)
            .early_exit_threshold(8);

        assert_eq!(configured.max_hex_steps, 32);
        assert_eq!(configured.cross_range, 16);
        assert_eq!(configured.early_exit_threshold, 8);
    }
}