Skip to main content

oximedia_codec/motion/
hierarchical.rs

1//! Hierarchical motion estimation using image pyramids.
2//!
3//! This module provides multi-resolution motion estimation that:
4//! 1. Builds a pyramid of downsampled images
5//! 2. Searches at coarse resolution first
6//! 3. Refines motion vectors at finer resolutions
7//!
8//! This approach significantly reduces computational cost while
9//! maintaining good accuracy for large motion vectors.
10
11#![forbid(unsafe_code)]
12#![allow(dead_code)]
13#![allow(clippy::too_many_arguments)]
14#![allow(clippy::cast_possible_truncation)]
15#![allow(clippy::cast_sign_loss)]
16#![allow(clippy::cast_possible_wrap)]
17#![allow(clippy::missing_panics_doc)]
18#![allow(clippy::must_use_candidate)]
19#![allow(clippy::let_and_return)]
20#![allow(clippy::manual_let_else)]
21
22use super::diamond::AdaptiveDiamond;
23use super::search::{MotionSearch, SearchConfig, SearchContext};
24use super::types::{BlockMatch, BlockSize, MotionVector, SearchRange};
25
/// Maximum number of pyramid levels.
///
/// `HierarchicalSearch::build_pyramids` clamps the configured level count to
/// this value before building either pyramid.
pub const MAX_PYRAMID_LEVELS: usize = 4;

/// Minimum dimension for pyramid level.
///
/// `ImagePyramid::build` stops adding downsampled levels as soon as the next
/// level's width or height would drop below this many pixels.
pub const MIN_PYRAMID_DIMENSION: usize = 16;
31
32/// Configuration for hierarchical search.
33#[derive(Clone, Debug)]
34pub struct HierarchicalConfig {
35    /// Number of pyramid levels.
36    pub levels: usize,
37    /// Search range at coarsest level.
38    pub coarse_range: SearchRange,
39    /// Refinement search range at each finer level.
40    pub refine_range: SearchRange,
41    /// Enable adaptive level selection.
42    pub adaptive_levels: bool,
43}
44
45impl Default for HierarchicalConfig {
46    fn default() -> Self {
47        Self {
48            levels: 3,
49            coarse_range: SearchRange::symmetric(16),
50            refine_range: SearchRange::symmetric(4),
51            adaptive_levels: true,
52        }
53    }
54}
55
56impl HierarchicalConfig {
57    /// Creates a new hierarchical config.
58    #[must_use]
59    pub const fn new(levels: usize) -> Self {
60        Self {
61            levels,
62            coarse_range: SearchRange::symmetric(16),
63            refine_range: SearchRange::symmetric(4),
64            adaptive_levels: true,
65        }
66    }
67
68    /// Sets the number of levels.
69    #[must_use]
70    pub const fn levels(mut self, levels: usize) -> Self {
71        self.levels = levels;
72        self
73    }
74
75    /// Sets the coarse level search range.
76    #[must_use]
77    pub const fn coarse_range(mut self, range: SearchRange) -> Self {
78        self.coarse_range = range;
79        self
80    }
81
82    /// Sets the refinement range.
83    #[must_use]
84    pub const fn refine_range(mut self, range: SearchRange) -> Self {
85        self.refine_range = range;
86        self
87    }
88}
89
/// A single level of the image pyramid.
///
/// Samples are stored one byte each in row-major order; the sample at
/// `(x, y)` lives at index `y * stride + x` of `data`.
#[derive(Clone, Debug)]
pub struct PyramidLevel {
    /// Pixel data.
    pub data: Vec<u8>,
    /// Width in pixels.
    pub width: usize,
    /// Height in pixels.
    pub height: usize,
    /// Stride (bytes per row).
    pub stride: usize,
    /// Scale factor from original (1, 2, 4, ...).
    pub scale: usize,
}

impl PyramidLevel {
    /// Creates a zero-filled level of `width` x `height` pixels.
    ///
    /// The stride is set equal to `width`.
    #[must_use]
    pub fn new(width: usize, height: usize, scale: usize) -> Self {
        let stride = width;
        Self {
            data: vec![0u8; stride * height],
            width,
            height,
            stride,
            scale,
        }
    }

    /// Creates a pyramid level that takes ownership of existing pixel data.
    ///
    /// The stride is set equal to `width`. If `data` holds fewer than
    /// `width * height` bytes it is padded with zeros, so every in-range
    /// coordinate maps to a valid byte. Extra trailing bytes are kept as-is.
    #[must_use]
    pub fn from_data(mut data: Vec<u8>, width: usize, height: usize, scale: usize) -> Self {
        let stride = width;
        // Skip padding when the product overflows: no buffer could satisfy it
        // and the checked accessors below stay safe regardless.
        if let Some(required) = stride.checked_mul(height) {
            if data.len() < required {
                data.resize(required, 0);
            }
        }
        Self {
            data,
            width,
            height,
            stride,
            scale,
        }
    }

    /// Gets the pixel value at (x, y).
    ///
    /// Returns 0 for coordinates outside the level, and also when the backing
    /// buffer is too short to contain the requested sample (the fields are
    /// public, so the buffer may not match the stated dimensions).
    #[must_use]
    pub fn get_pixel(&self, x: usize, y: usize) -> u8 {
        if x < self.width && y < self.height {
            self.data.get(y * self.stride + x).copied().unwrap_or(0)
        } else {
            0
        }
    }

    /// Sets the pixel value at (x, y).
    ///
    /// Writes to coordinates outside the level, or beyond the end of the
    /// backing buffer, are ignored.
    pub fn set_pixel(&mut self, x: usize, y: usize, value: u8) {
        if x < self.width && y < self.height {
            if let Some(pixel) = self.data.get_mut(y * self.stride + x) {
                *pixel = value;
            }
        }
    }

    /// Downsamples from another level (2:1).
    ///
    /// Each destination pixel is the rounded average of the 2x2 source
    /// neighbourhood whose top-left corner is `(2x, 2y)`. Source samples that
    /// fall outside `src` read as 0 via [`PyramidLevel::get_pixel`].
    pub fn downsample_from(&mut self, src: &PyramidLevel) {
        for y in 0..self.height {
            for x in 0..self.width {
                let src_x = x * 2;
                let src_y = y * 2;

                // 2x2 box filter
                let p00 = u32::from(src.get_pixel(src_x, src_y));
                let p01 = u32::from(src.get_pixel(src_x + 1, src_y));
                let p10 = u32::from(src.get_pixel(src_x, src_y + 1));
                let p11 = u32::from(src.get_pixel(src_x + 1, src_y + 1));

                // +2 rounds to nearest; the sum of four bytes divided by four
                // always fits in a u8.
                let avg = ((p00 + p01 + p10 + p11 + 2) / 4) as u8;
                self.set_pixel(x, y, avg);
            }
        }
    }

    /// Returns the pixel data starting at (x, y) and running to the end of
    /// the buffer.
    ///
    /// Returns an empty slice when the computed offset lies beyond the buffer
    /// (or overflows) instead of panicking.
    #[must_use]
    pub fn block_data(&self, x: usize, y: usize) -> &[u8] {
        y.checked_mul(self.stride)
            .and_then(|row_start| row_start.checked_add(x))
            .and_then(|offset| self.data.get(offset..))
            .unwrap_or(&[])
    }
}
175
176/// Image pyramid for multi-resolution search.
177#[derive(Clone, Debug)]
178pub struct ImagePyramid {
179    /// Pyramid levels (index 0 = original resolution).
180    levels: Vec<PyramidLevel>,
181}
182
183impl ImagePyramid {
184    /// Creates a new empty pyramid.
185    #[must_use]
186    pub const fn new() -> Self {
187        Self { levels: Vec::new() }
188    }
189
190    /// Builds the pyramid from source image data.
191    pub fn build(&mut self, src: &[u8], width: usize, height: usize, num_levels: usize) {
192        self.levels.clear();
193
194        // Level 0: original resolution (copy)
195        let level0 = PyramidLevel::from_data(src.to_vec(), width, height, 1);
196        self.levels.push(level0);
197
198        // Build downsampled levels
199        let mut cur_width = width;
200        let mut cur_height = height;
201        let mut cur_scale = 1;
202
203        for _ in 1..num_levels {
204            cur_width /= 2;
205            cur_height /= 2;
206            cur_scale *= 2;
207
208            if cur_width < MIN_PYRAMID_DIMENSION || cur_height < MIN_PYRAMID_DIMENSION {
209                break;
210            }
211
212            let mut level = PyramidLevel::new(cur_width, cur_height, cur_scale);
213            level.downsample_from(
214                self.levels
215                    .last()
216                    .expect("levels is non-empty by construction"),
217            );
218            self.levels.push(level);
219        }
220    }
221
222    /// Returns the number of levels.
223    #[must_use]
224    pub fn num_levels(&self) -> usize {
225        self.levels.len()
226    }
227
228    /// Gets a pyramid level.
229    #[must_use]
230    pub fn get_level(&self, index: usize) -> Option<&PyramidLevel> {
231        self.levels.get(index)
232    }
233
234    /// Gets the coarsest level.
235    #[must_use]
236    pub fn coarsest(&self) -> Option<&PyramidLevel> {
237        self.levels.last()
238    }
239
240    /// Gets the finest level (original).
241    #[must_use]
242    pub fn finest(&self) -> Option<&PyramidLevel> {
243        self.levels.first()
244    }
245}
246
247impl Default for ImagePyramid {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253/// Hierarchical motion search using image pyramids.
254#[derive(Clone, Debug)]
255pub struct HierarchicalSearch {
256    /// Source image pyramid.
257    src_pyramid: ImagePyramid,
258    /// Reference image pyramid.
259    ref_pyramid: ImagePyramid,
260    /// Search configuration.
261    config: HierarchicalConfig,
262    /// Underlying search algorithm.
263    searcher: AdaptiveDiamond,
264}
265
266impl Default for HierarchicalSearch {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl HierarchicalSearch {
273    /// Creates a new hierarchical search.
274    #[must_use]
275    pub fn new() -> Self {
276        Self {
277            src_pyramid: ImagePyramid::new(),
278            ref_pyramid: ImagePyramid::new(),
279            config: HierarchicalConfig::default(),
280            searcher: AdaptiveDiamond::new(),
281        }
282    }
283
284    /// Sets the configuration.
285    #[must_use]
286    pub fn with_config(mut self, config: HierarchicalConfig) -> Self {
287        self.config = config;
288        self
289    }
290
291    /// Builds pyramids from source and reference frames.
292    pub fn build_pyramids(
293        &mut self,
294        src: &[u8],
295        src_width: usize,
296        src_height: usize,
297        reference: &[u8],
298        ref_width: usize,
299        ref_height: usize,
300    ) {
301        let levels = self.config.levels.min(MAX_PYRAMID_LEVELS);
302        self.src_pyramid.build(src, src_width, src_height, levels);
303        self.ref_pyramid
304            .build(reference, ref_width, ref_height, levels);
305    }
306
307    /// Performs hierarchical search.
308    ///
309    /// Searches from coarsest to finest level, using the result from
310    /// each level as the starting point for the next.
311    pub fn search_hierarchical(
312        &self,
313        block_x: usize,
314        block_y: usize,
315        block_size: BlockSize,
316        search_config: &SearchConfig,
317    ) -> BlockMatch {
318        let num_levels = self
319            .src_pyramid
320            .num_levels()
321            .min(self.ref_pyramid.num_levels());
322
323        if num_levels == 0 {
324            return BlockMatch::worst();
325        }
326
327        // Start at coarsest level
328        let mut current_mv = MotionVector::zero();
329
330        // Search from coarsest to finest
331        for level_idx in (0..num_levels).rev() {
332            let src_level = match self.src_pyramid.get_level(level_idx) {
333                Some(l) => l,
334                None => continue,
335            };
336            let ref_level = match self.ref_pyramid.get_level(level_idx) {
337                Some(l) => l,
338                None => continue,
339            };
340
341            // Scale block position for this level
342            let scale = src_level.scale;
343            let scaled_x = block_x / scale;
344            let scaled_y = block_y / scale;
345            let scaled_width = block_size.width() / scale;
346            let scaled_height = block_size.height() / scale;
347
348            // Skip if block too small
349            if scaled_width < 4 || scaled_height < 4 {
350                continue;
351            }
352
353            // Determine search range for this level
354            let level_range = if level_idx == num_levels - 1 {
355                self.config.coarse_range
356            } else {
357                self.config.refine_range
358            };
359
360            // Create search config for this level
361            let level_config = SearchConfig {
362                range: level_range,
363                ..search_config.clone()
364            };
365
366            // Scale MV from previous level
367            let scaled_mv = MotionVector::from_full_pel(
368                current_mv.full_pel_x() / scale as i32,
369                current_mv.full_pel_y() / scale as i32,
370            );
371
372            // Create context for this level
373            let src_offset = scaled_y * src_level.stride + scaled_x;
374            if src_offset >= src_level.data.len() {
375                continue;
376            }
377
378            let ctx = SearchContext::new(
379                &src_level.data[src_offset..],
380                src_level.stride,
381                &ref_level.data,
382                ref_level.stride,
383                BlockSize::Block8x8, // Use fixed size for pyramid levels
384                scaled_x,
385                scaled_y,
386                ref_level.width,
387                ref_level.height,
388            );
389
390            // Search at this level
391            let result = self
392                .searcher
393                .search_with_predictor(&ctx, &level_config, scaled_mv);
394
395            // Scale MV back up for next level
396            current_mv = MotionVector::from_full_pel(
397                result.mv.full_pel_x() * scale as i32,
398                result.mv.full_pel_y() * scale as i32,
399            );
400        }
401
402        // Final search at full resolution
403        if let (Some(src_level), Some(ref_level)) =
404            (self.src_pyramid.finest(), self.ref_pyramid.finest())
405        {
406            let src_offset = block_y * src_level.stride + block_x;
407            if src_offset < src_level.data.len() {
408                let ctx = SearchContext::new(
409                    &src_level.data[src_offset..],
410                    src_level.stride,
411                    &ref_level.data,
412                    ref_level.stride,
413                    block_size,
414                    block_x,
415                    block_y,
416                    ref_level.width,
417                    ref_level.height,
418                );
419
420                let final_config = SearchConfig {
421                    range: self.config.refine_range,
422                    ..search_config.clone()
423                };
424
425                return self
426                    .searcher
427                    .search_with_predictor(&ctx, &final_config, current_mv);
428            }
429        }
430
431        let cost = search_config.mv_cost.rd_cost(&current_mv, u32::MAX);
432        BlockMatch::new(current_mv, u32::MAX, cost)
433    }
434
435    /// Calculates the optimal number of pyramid levels.
436    #[must_use]
437    pub fn calculate_levels(width: usize, height: usize, max_levels: usize) -> usize {
438        let min_dim = width.min(height);
439        let mut levels = 1;
440
441        let mut size = min_dim;
442        while size >= MIN_PYRAMID_DIMENSION * 2 && levels < max_levels {
443            size /= 2;
444            levels += 1;
445        }
446
447        levels
448    }
449}
450
451/// Coarse-to-fine refinement helper.
452#[derive(Clone, Debug, Default)]
453pub struct CoarseToFineRefiner {
454    /// Refinement steps at each scale.
455    steps: Vec<RefinementStep>,
456}
457
458/// A single refinement step.
459#[derive(Clone, Debug)]
460pub struct RefinementStep {
461    /// Scale factor (1 = full resolution).
462    pub scale: usize,
463    /// Search range for this step.
464    pub range: SearchRange,
465    /// Number of iterations.
466    pub iterations: u32,
467}
468
469impl Default for RefinementStep {
470    fn default() -> Self {
471        Self {
472            scale: 1,
473            range: SearchRange::symmetric(2),
474            iterations: 4,
475        }
476    }
477}
478
479impl CoarseToFineRefiner {
480    /// Creates a new refiner.
481    #[must_use]
482    pub fn new() -> Self {
483        Self { steps: Vec::new() }
484    }
485
486    /// Adds a refinement step.
487    #[must_use]
488    pub fn add_step(mut self, scale: usize, range: i32, iterations: u32) -> Self {
489        self.steps.push(RefinementStep {
490            scale,
491            range: SearchRange::symmetric(range),
492            iterations,
493        });
494        self
495    }
496
497    /// Creates default coarse-to-fine steps.
498    #[must_use]
499    pub fn default_steps() -> Self {
500        Self::new()
501            .add_step(4, 8, 8) // 1/4 resolution, wide search
502            .add_step(2, 4, 6) // 1/2 resolution, medium search
503            .add_step(1, 2, 4) // Full resolution, fine search
504    }
505
506    /// Returns the refinement steps.
507    #[must_use]
508    pub fn steps(&self) -> &[RefinementStep] {
509        &self.steps
510    }
511
512    /// Scales a motion vector between levels.
513    #[must_use]
514    pub const fn scale_mv(mv: MotionVector, from_scale: usize, to_scale: usize) -> MotionVector {
515        if from_scale == to_scale {
516            return mv;
517        }
518
519        if from_scale > to_scale {
520            // Upscaling (coarse to fine)
521            let factor = (from_scale / to_scale) as i32;
522            MotionVector::new(mv.dx * factor, mv.dy * factor)
523        } else {
524            // Downscaling (fine to coarse)
525            let factor = (to_scale / from_scale) as i32;
526            MotionVector::new(mv.dx / factor, mv.dy / factor)
527        }
528    }
529}
530
531/// Resolution scaling utilities.
532pub struct ResolutionScaler;
533
534impl ResolutionScaler {
535    /// Downsamples image by factor of 2.
536    pub fn downsample_2x(src: &[u8], width: usize, height: usize) -> Vec<u8> {
537        let new_width = width / 2;
538        let new_height = height / 2;
539        let mut dst = vec![0u8; new_width * new_height];
540
541        for y in 0..new_height {
542            for x in 0..new_width {
543                let src_x = x * 2;
544                let src_y = y * 2;
545
546                let p00 = u32::from(src[src_y * width + src_x]);
547                let p01 = u32::from(src[src_y * width + src_x + 1]);
548                let p10 = u32::from(src[(src_y + 1) * width + src_x]);
549                let p11 = u32::from(src[(src_y + 1) * width + src_x + 1]);
550
551                dst[y * new_width + x] = ((p00 + p01 + p10 + p11 + 2) / 4) as u8;
552            }
553        }
554
555        dst
556    }
557
558    /// Downsamples image by factor of 4.
559    pub fn downsample_4x(src: &[u8], width: usize, height: usize) -> Vec<u8> {
560        let half = Self::downsample_2x(src, width, height);
561        Self::downsample_2x(&half, width / 2, height / 2)
562    }
563
564    /// Upsamples motion vector coordinates.
565    #[must_use]
566    pub const fn upsample_mv(mv: MotionVector, factor: i32) -> MotionVector {
567        MotionVector::new(mv.dx * factor, mv.dy * factor)
568    }
569
570    /// Downsamples motion vector coordinates.
571    #[must_use]
572    pub const fn downsample_mv(mv: MotionVector, factor: i32) -> MotionVector {
573        MotionVector::new(mv.dx / factor, mv.dy / factor)
574    }
575
576    /// Scales block position.
577    #[must_use]
578    pub const fn scale_position(x: usize, y: usize, scale: usize) -> (usize, usize) {
579        (x / scale, y / scale)
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh level reports its geometry and owns `width * height` bytes.
    #[test]
    fn test_pyramid_level_creation() {
        let level = PyramidLevel::new(64, 64, 1);
        assert_eq!((level.width, level.height, level.scale), (64, 64, 1));
        assert_eq!(level.data.len(), 64 * 64);
    }

    /// A written pixel reads back; untouched pixels stay zero.
    #[test]
    fn test_pyramid_level_pixel_access() {
        let mut level = PyramidLevel::new(8, 8, 1);
        level.set_pixel(3, 4, 128);

        let written = level.get_pixel(3, 4);
        let untouched = level.get_pixel(0, 0);
        assert_eq!(written, 128);
        assert_eq!(untouched, 0);
    }

    /// Downsampling a diagonal gradient keeps it increasing along the diagonal.
    #[test]
    fn test_pyramid_level_downsample() {
        let mut src_level = PyramidLevel::new(8, 8, 1);
        for y in 0..8 {
            for x in 0..8 {
                let value = ((x + y) * 16) as u8;
                src_level.set_pixel(x, y, value);
            }
        }

        let mut dst_level = PyramidLevel::new(4, 4, 2);
        dst_level.downsample_from(&src_level);

        let origin = dst_level.get_pixel(0, 0);
        let inner = dst_level.get_pixel(1, 1);
        assert!(origin > 0);
        assert!(inner > origin);
    }

    /// A 64x64 frame yields three levels, each half as wide as the previous.
    #[test]
    fn test_image_pyramid_build() {
        let src = vec![128u8; 64 * 64];
        let mut pyramid = ImagePyramid::new();
        pyramid.build(&src, 64, 64, 3);

        assert_eq!(pyramid.num_levels(), 3);

        for (index, expected_width) in [64usize, 32, 16].into_iter().enumerate() {
            let width = pyramid.get_level(index).map(|l| l.width);
            assert_eq!(width, Some(expected_width));
        }
    }

    /// Building stops once a level would be smaller than the minimum dimension.
    #[test]
    fn test_image_pyramid_min_size() {
        let src = vec![128u8; 32 * 32];
        let mut pyramid = ImagePyramid::new();
        pyramid.build(&src, 32, 32, 5);

        let built = pyramid.num_levels();
        assert!(built <= 2);
    }

    /// Builder methods override the ranges and keep the level count.
    #[test]
    fn test_hierarchical_config() {
        let coarse = SearchRange::symmetric(32);
        let refine = SearchRange::symmetric(8);
        let config = HierarchicalConfig::new(4)
            .coarse_range(coarse)
            .refine_range(refine);

        assert_eq!(config.levels, 4);
        assert_eq!(config.coarse_range.horizontal, 32);
        assert_eq!(config.refine_range.horizontal, 8);
    }

    /// `with_config` installs the supplied configuration.
    #[test]
    fn test_hierarchical_search_creation() {
        let config = HierarchicalConfig::new(3);
        let search = HierarchicalSearch::new().with_config(config);

        assert_eq!(search.config.levels, 3);
    }

    /// The level count follows the frame size, capped at `max_levels`.
    #[test]
    fn test_calculate_pyramid_levels() {
        let cases = [(128usize, 4usize), (64, 3), (32, 2)];
        for (dimension, expected) in cases {
            let levels = HierarchicalSearch::calculate_levels(dimension, dimension, 4);
            assert_eq!(levels, expected);
        }
    }

    /// The default refiner carries three steps.
    #[test]
    fn test_coarse_to_fine_refiner() {
        let refiner = CoarseToFineRefiner::default_steps();
        assert_eq!(refiner.steps().len(), 3);
    }

    /// Vectors double going coarse-to-fine and halve going fine-to-coarse.
    #[test]
    fn test_scale_mv() {
        let mv = MotionVector::new(16, 32);

        let scaled_up = CoarseToFineRefiner::scale_mv(mv, 2, 1);
        assert_eq!((scaled_up.dx, scaled_up.dy), (32, 64));

        let scaled_down = CoarseToFineRefiner::scale_mv(mv, 1, 2);
        assert_eq!((scaled_down.dx, scaled_down.dy), (8, 16));
    }

    /// Each 2x2 block of a 4x4 image collapses to its average.
    #[test]
    fn test_resolution_scaler_downsample() {
        #[rustfmt::skip]
        let src = vec![
            100, 100, 200, 200,
            100, 100, 200, 200,
             50,  50, 150, 150,
             50,  50, 150, 150,
        ];

        let dst = ResolutionScaler::downsample_2x(&src, 4, 4);
        assert_eq!(dst.len(), 4);

        // One uniform 2x2 block per output sample.
        let expected = [100u8, 200, 50, 150];
        for (index, want) in expected.into_iter().enumerate() {
            assert_eq!(dst[index], want);
        }
    }

    /// Motion vectors scale component-wise in both directions.
    #[test]
    fn test_resolution_scaler_mv() {
        let mv = MotionVector::new(8, 16);

        let up = ResolutionScaler::upsample_mv(mv, 2);
        assert_eq!((up.dx, up.dy), (16, 32));

        let down = ResolutionScaler::downsample_mv(mv, 2);
        assert_eq!((down.dx, down.dy), (4, 8));
    }

    /// Identical flat frames give a zero motion vector at the origin block.
    #[test]
    fn test_hierarchical_search_integration() {
        let src = vec![100u8; 64 * 64];
        let reference = vec![100u8; 64 * 64];

        let mut search = HierarchicalSearch::new().with_config(HierarchicalConfig::new(3));
        search.build_pyramids(&src, 64, 64, &reference, 64, 64);

        let config = SearchConfig::default();
        let result = search.search_hierarchical(0, 0, BlockSize::Block8x8, &config);

        // Perfect match at origin
        let found = (result.mv.full_pel_x(), result.mv.full_pel_y());
        assert_eq!(found, (0, 0));
    }

    /// The default step is a full-resolution pass with four iterations.
    #[test]
    fn test_refinement_step() {
        let step = RefinementStep::default();
        assert_eq!((step.scale, step.iterations), (1, 4));
    }
}