Skip to main content

oximedia_codec/motion/
hierarchical.rs

1//! Hierarchical motion estimation using image pyramids.
2//!
3//! This module provides multi-resolution motion estimation that:
4//! 1. Builds a pyramid of downsampled images
5//! 2. Searches at coarse resolution first
6//! 3. Refines motion vectors at finer resolutions
7//!
8//! This approach significantly reduces computational cost while
9//! maintaining good accuracy for large motion vectors.
10
11#![forbid(unsafe_code)]
12#![allow(dead_code)]
13#![allow(clippy::too_many_arguments)]
14#![allow(clippy::cast_possible_truncation)]
15#![allow(clippy::cast_sign_loss)]
16#![allow(clippy::cast_possible_wrap)]
17#![allow(clippy::missing_panics_doc)]
18#![allow(clippy::must_use_candidate)]
19#![allow(clippy::let_and_return)]
20#![allow(clippy::manual_let_else)]
21
22use super::diamond::AdaptiveDiamond;
23use super::search::{MotionSearch, SearchConfig, SearchContext};
24use super::types::{BlockMatch, BlockSize, MotionVector, SearchRange};
25
/// Upper bound on the number of pyramid levels a hierarchical search builds.
///
/// [`HierarchicalSearch::build_pyramids`] clamps the configured level count
/// to this value.
pub const MAX_PYRAMID_LEVELS: usize = 4;

/// Smallest width or height, in pixels, that a pyramid level may have.
///
/// [`ImagePyramid::build`] stops adding levels as soon as halving would take
/// either dimension below this value.
pub const MIN_PYRAMID_DIMENSION: usize = 16;
31
32/// Configuration for hierarchical search.
33#[derive(Clone, Debug)]
34pub struct HierarchicalConfig {
35    /// Number of pyramid levels.
36    pub levels: usize,
37    /// Search range at coarsest level.
38    pub coarse_range: SearchRange,
39    /// Refinement search range at each finer level.
40    pub refine_range: SearchRange,
41    /// Enable adaptive level selection.
42    pub adaptive_levels: bool,
43}
44
45impl Default for HierarchicalConfig {
46    fn default() -> Self {
47        Self {
48            levels: 3,
49            coarse_range: SearchRange::symmetric(16),
50            refine_range: SearchRange::symmetric(4),
51            adaptive_levels: true,
52        }
53    }
54}
55
56impl HierarchicalConfig {
57    /// Creates a new hierarchical config.
58    #[must_use]
59    pub const fn new(levels: usize) -> Self {
60        Self {
61            levels,
62            coarse_range: SearchRange::symmetric(16),
63            refine_range: SearchRange::symmetric(4),
64            adaptive_levels: true,
65        }
66    }
67
68    /// Sets the number of levels.
69    #[must_use]
70    pub const fn levels(mut self, levels: usize) -> Self {
71        self.levels = levels;
72        self
73    }
74
75    /// Sets the coarse level search range.
76    #[must_use]
77    pub const fn coarse_range(mut self, range: SearchRange) -> Self {
78        self.coarse_range = range;
79        self
80    }
81
82    /// Sets the refinement range.
83    #[must_use]
84    pub const fn refine_range(mut self, range: SearchRange) -> Self {
85        self.refine_range = range;
86        self
87    }
88}
89
/// A single level of the image pyramid.
///
/// Pixels are stored row-major, one byte each: the pixel at `(x, y)` lives at
/// `data[y * stride + x]`.
#[derive(Clone, Debug)]
pub struct PyramidLevel {
    /// Pixel data.
    pub data: Vec<u8>,
    /// Width in pixels.
    pub width: usize,
    /// Height in pixels.
    pub height: usize,
    /// Stride (bytes per row).
    pub stride: usize,
    /// Scale factor from original (1, 2, 4, ...).
    pub scale: usize,
}

impl PyramidLevel {
    /// Creates a zero-filled pyramid level of `width` x `height` pixels.
    ///
    /// The stride is set equal to `width`.
    #[must_use]
    pub fn new(width: usize, height: usize, scale: usize) -> Self {
        let stride = width;
        Self {
            data: vec![0u8; stride * height],
            width,
            height,
            stride,
            scale,
        }
    }

    /// Creates a pyramid level that takes ownership of existing pixel data.
    ///
    /// The stride is set equal to `width`. The length of `data` is not
    /// checked against `width * height`; the accessors below treat any byte
    /// that is missing from a short buffer as out of range.
    #[must_use]
    pub fn from_data(data: Vec<u8>, width: usize, height: usize, scale: usize) -> Self {
        let stride = width;
        Self {
            data,
            width,
            height,
            stride,
            scale,
        }
    }

    /// Gets the pixel value at (x, y).
    ///
    /// Returns 0 when the coordinate is outside `width` x `height`, or when
    /// the buffer is too short to contain that pixel.
    #[must_use]
    pub fn get_pixel(&self, x: usize, y: usize) -> u8 {
        if x >= self.width || y >= self.height {
            return 0;
        }
        // Checked access: `data` is a public field and `from_data` accepts any
        // length, so an in-range coordinate may still point past the buffer.
        self.data.get(y * self.stride + x).copied().unwrap_or(0)
    }

    /// Sets the pixel value at (x, y).
    ///
    /// Does nothing when the coordinate is outside `width` x `height`, or
    /// when the buffer is too short to contain that pixel.
    pub fn set_pixel(&mut self, x: usize, y: usize, value: u8) {
        if x >= self.width || y >= self.height {
            return;
        }
        if let Some(pixel) = self.data.get_mut(y * self.stride + x) {
            *pixel = value;
        }
    }

    /// Downsamples from another level (2:1).
    ///
    /// Every pixel of `self` becomes the rounded average of the 2x2 block of
    /// `src` whose top-left corner is at `(2 * x, 2 * y)`. Source pixels that
    /// `src.get_pixel` cannot supply count as 0 in the average.
    pub fn downsample_from(&mut self, src: &PyramidLevel) {
        for y in 0..self.height {
            for x in 0..self.width {
                let src_x = x * 2;
                let src_y = y * 2;

                // 2x2 box filter
                let p00 = u32::from(src.get_pixel(src_x, src_y));
                let p01 = u32::from(src.get_pixel(src_x + 1, src_y));
                let p10 = u32::from(src.get_pixel(src_x, src_y + 1));
                let p11 = u32::from(src.get_pixel(src_x + 1, src_y + 1));

                // +2 rounds to nearest; the sum of four bytes divided by 4
                // always fits in a u8.
                let avg = ((p00 + p01 + p10 + p11 + 2) / 4) as u8;
                self.set_pixel(x, y, avg);
            }
        }
    }

    /// Returns the pixel data starting at (x, y) and running to the end of
    /// the buffer.
    ///
    /// The coordinate is not checked against `width`/`height`. If the
    /// computed offset lies past the end of the buffer an empty slice is
    /// returned.
    #[must_use]
    pub fn block_data(&self, x: usize, y: usize) -> &[u8] {
        let offset = y * self.stride + x;
        self.data.get(offset..).unwrap_or_default()
    }
}
175
176/// Image pyramid for multi-resolution search.
177#[derive(Clone, Debug)]
178pub struct ImagePyramid {
179    /// Pyramid levels (index 0 = original resolution).
180    levels: Vec<PyramidLevel>,
181}
182
183impl ImagePyramid {
184    /// Creates a new empty pyramid.
185    #[must_use]
186    pub const fn new() -> Self {
187        Self { levels: Vec::new() }
188    }
189
190    /// Builds the pyramid from source image data.
191    pub fn build(&mut self, src: &[u8], width: usize, height: usize, num_levels: usize) {
192        self.levels.clear();
193
194        // Level 0: original resolution (copy)
195        let level0 = PyramidLevel::from_data(src.to_vec(), width, height, 1);
196        self.levels.push(level0);
197
198        // Build downsampled levels
199        let mut cur_width = width;
200        let mut cur_height = height;
201        let mut cur_scale = 1;
202
203        for _ in 1..num_levels {
204            cur_width /= 2;
205            cur_height /= 2;
206            cur_scale *= 2;
207
208            if cur_width < MIN_PYRAMID_DIMENSION || cur_height < MIN_PYRAMID_DIMENSION {
209                break;
210            }
211
212            if let Some(prev) = self.levels.last() {
213                let mut level = PyramidLevel::new(cur_width, cur_height, cur_scale);
214                level.downsample_from(prev);
215                self.levels.push(level);
216            }
217        }
218    }
219
220    /// Returns the number of levels.
221    #[must_use]
222    pub fn num_levels(&self) -> usize {
223        self.levels.len()
224    }
225
226    /// Gets a pyramid level.
227    #[must_use]
228    pub fn get_level(&self, index: usize) -> Option<&PyramidLevel> {
229        self.levels.get(index)
230    }
231
232    /// Gets the coarsest level.
233    #[must_use]
234    pub fn coarsest(&self) -> Option<&PyramidLevel> {
235        self.levels.last()
236    }
237
238    /// Gets the finest level (original).
239    #[must_use]
240    pub fn finest(&self) -> Option<&PyramidLevel> {
241        self.levels.first()
242    }
243}
244
245impl Default for ImagePyramid {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
251/// Hierarchical motion search using image pyramids.
252#[derive(Clone, Debug)]
253pub struct HierarchicalSearch {
254    /// Source image pyramid.
255    src_pyramid: ImagePyramid,
256    /// Reference image pyramid.
257    ref_pyramid: ImagePyramid,
258    /// Search configuration.
259    config: HierarchicalConfig,
260    /// Underlying search algorithm.
261    searcher: AdaptiveDiamond,
262}
263
264impl Default for HierarchicalSearch {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270impl HierarchicalSearch {
271    /// Creates a new hierarchical search.
272    #[must_use]
273    pub fn new() -> Self {
274        Self {
275            src_pyramid: ImagePyramid::new(),
276            ref_pyramid: ImagePyramid::new(),
277            config: HierarchicalConfig::default(),
278            searcher: AdaptiveDiamond::new(),
279        }
280    }
281
282    /// Sets the configuration.
283    #[must_use]
284    pub fn with_config(mut self, config: HierarchicalConfig) -> Self {
285        self.config = config;
286        self
287    }
288
289    /// Builds pyramids from source and reference frames.
290    pub fn build_pyramids(
291        &mut self,
292        src: &[u8],
293        src_width: usize,
294        src_height: usize,
295        reference: &[u8],
296        ref_width: usize,
297        ref_height: usize,
298    ) {
299        let levels = self.config.levels.min(MAX_PYRAMID_LEVELS);
300        self.src_pyramid.build(src, src_width, src_height, levels);
301        self.ref_pyramid
302            .build(reference, ref_width, ref_height, levels);
303    }
304
305    /// Performs hierarchical search.
306    ///
307    /// Searches from coarsest to finest level, using the result from
308    /// each level as the starting point for the next.
309    pub fn search_hierarchical(
310        &self,
311        block_x: usize,
312        block_y: usize,
313        block_size: BlockSize,
314        search_config: &SearchConfig,
315    ) -> BlockMatch {
316        let num_levels = self
317            .src_pyramid
318            .num_levels()
319            .min(self.ref_pyramid.num_levels());
320
321        if num_levels == 0 {
322            return BlockMatch::worst();
323        }
324
325        // Start at coarsest level
326        let mut current_mv = MotionVector::zero();
327
328        // Search from coarsest to finest
329        for level_idx in (0..num_levels).rev() {
330            let src_level = match self.src_pyramid.get_level(level_idx) {
331                Some(l) => l,
332                None => continue,
333            };
334            let ref_level = match self.ref_pyramid.get_level(level_idx) {
335                Some(l) => l,
336                None => continue,
337            };
338
339            // Scale block position for this level
340            let scale = src_level.scale;
341            let scaled_x = block_x / scale;
342            let scaled_y = block_y / scale;
343            let scaled_width = block_size.width() / scale;
344            let scaled_height = block_size.height() / scale;
345
346            // Skip if block too small
347            if scaled_width < 4 || scaled_height < 4 {
348                continue;
349            }
350
351            // Determine search range for this level
352            let level_range = if level_idx == num_levels - 1 {
353                self.config.coarse_range
354            } else {
355                self.config.refine_range
356            };
357
358            // Create search config for this level
359            let level_config = SearchConfig {
360                range: level_range,
361                ..search_config.clone()
362            };
363
364            // Scale MV from previous level
365            let scaled_mv = MotionVector::from_full_pel(
366                current_mv.full_pel_x() / scale as i32,
367                current_mv.full_pel_y() / scale as i32,
368            );
369
370            // Create context for this level
371            let src_offset = scaled_y * src_level.stride + scaled_x;
372            if src_offset >= src_level.data.len() {
373                continue;
374            }
375
376            let ctx = SearchContext::new(
377                &src_level.data[src_offset..],
378                src_level.stride,
379                &ref_level.data,
380                ref_level.stride,
381                BlockSize::Block8x8, // Use fixed size for pyramid levels
382                scaled_x,
383                scaled_y,
384                ref_level.width,
385                ref_level.height,
386            );
387
388            // Search at this level
389            let result = self
390                .searcher
391                .search_with_predictor(&ctx, &level_config, scaled_mv);
392
393            // Scale MV back up for next level
394            current_mv = MotionVector::from_full_pel(
395                result.mv.full_pel_x() * scale as i32,
396                result.mv.full_pel_y() * scale as i32,
397            );
398        }
399
400        // Final search at full resolution
401        if let (Some(src_level), Some(ref_level)) =
402            (self.src_pyramid.finest(), self.ref_pyramid.finest())
403        {
404            let src_offset = block_y * src_level.stride + block_x;
405            if src_offset < src_level.data.len() {
406                let ctx = SearchContext::new(
407                    &src_level.data[src_offset..],
408                    src_level.stride,
409                    &ref_level.data,
410                    ref_level.stride,
411                    block_size,
412                    block_x,
413                    block_y,
414                    ref_level.width,
415                    ref_level.height,
416                );
417
418                let final_config = SearchConfig {
419                    range: self.config.refine_range,
420                    ..search_config.clone()
421                };
422
423                return self
424                    .searcher
425                    .search_with_predictor(&ctx, &final_config, current_mv);
426            }
427        }
428
429        let cost = search_config.mv_cost.rd_cost(&current_mv, u32::MAX);
430        BlockMatch::new(current_mv, u32::MAX, cost)
431    }
432
433    /// Calculates the optimal number of pyramid levels.
434    #[must_use]
435    pub fn calculate_levels(width: usize, height: usize, max_levels: usize) -> usize {
436        let min_dim = width.min(height);
437        let mut levels = 1;
438
439        let mut size = min_dim;
440        while size >= MIN_PYRAMID_DIMENSION * 2 && levels < max_levels {
441            size /= 2;
442            levels += 1;
443        }
444
445        levels
446    }
447}
448
449/// Coarse-to-fine refinement helper.
450#[derive(Clone, Debug, Default)]
451pub struct CoarseToFineRefiner {
452    /// Refinement steps at each scale.
453    steps: Vec<RefinementStep>,
454}
455
456/// A single refinement step.
457#[derive(Clone, Debug)]
458pub struct RefinementStep {
459    /// Scale factor (1 = full resolution).
460    pub scale: usize,
461    /// Search range for this step.
462    pub range: SearchRange,
463    /// Number of iterations.
464    pub iterations: u32,
465}
466
467impl Default for RefinementStep {
468    fn default() -> Self {
469        Self {
470            scale: 1,
471            range: SearchRange::symmetric(2),
472            iterations: 4,
473        }
474    }
475}
476
477impl CoarseToFineRefiner {
478    /// Creates a new refiner.
479    #[must_use]
480    pub fn new() -> Self {
481        Self { steps: Vec::new() }
482    }
483
484    /// Adds a refinement step.
485    #[must_use]
486    pub fn add_step(mut self, scale: usize, range: i32, iterations: u32) -> Self {
487        self.steps.push(RefinementStep {
488            scale,
489            range: SearchRange::symmetric(range),
490            iterations,
491        });
492        self
493    }
494
495    /// Creates default coarse-to-fine steps.
496    #[must_use]
497    pub fn default_steps() -> Self {
498        Self::new()
499            .add_step(4, 8, 8) // 1/4 resolution, wide search
500            .add_step(2, 4, 6) // 1/2 resolution, medium search
501            .add_step(1, 2, 4) // Full resolution, fine search
502    }
503
504    /// Returns the refinement steps.
505    #[must_use]
506    pub fn steps(&self) -> &[RefinementStep] {
507        &self.steps
508    }
509
510    /// Scales a motion vector between levels.
511    #[must_use]
512    pub const fn scale_mv(mv: MotionVector, from_scale: usize, to_scale: usize) -> MotionVector {
513        if from_scale == to_scale {
514            return mv;
515        }
516
517        if from_scale > to_scale {
518            // Upscaling (coarse to fine)
519            let factor = (from_scale / to_scale) as i32;
520            MotionVector::new(mv.dx * factor, mv.dy * factor)
521        } else {
522            // Downscaling (fine to coarse)
523            let factor = (to_scale / from_scale) as i32;
524            MotionVector::new(mv.dx / factor, mv.dy / factor)
525        }
526    }
527}
528
529/// Resolution scaling utilities.
530pub struct ResolutionScaler;
531
532impl ResolutionScaler {
533    /// Downsamples image by factor of 2.
534    pub fn downsample_2x(src: &[u8], width: usize, height: usize) -> Vec<u8> {
535        let new_width = width / 2;
536        let new_height = height / 2;
537        let mut dst = vec![0u8; new_width * new_height];
538
539        for y in 0..new_height {
540            for x in 0..new_width {
541                let src_x = x * 2;
542                let src_y = y * 2;
543
544                let p00 = u32::from(src[src_y * width + src_x]);
545                let p01 = u32::from(src[src_y * width + src_x + 1]);
546                let p10 = u32::from(src[(src_y + 1) * width + src_x]);
547                let p11 = u32::from(src[(src_y + 1) * width + src_x + 1]);
548
549                dst[y * new_width + x] = ((p00 + p01 + p10 + p11 + 2) / 4) as u8;
550            }
551        }
552
553        dst
554    }
555
556    /// Downsamples image by factor of 4.
557    pub fn downsample_4x(src: &[u8], width: usize, height: usize) -> Vec<u8> {
558        let half = Self::downsample_2x(src, width, height);
559        Self::downsample_2x(&half, width / 2, height / 2)
560    }
561
562    /// Upsamples motion vector coordinates.
563    #[must_use]
564    pub const fn upsample_mv(mv: MotionVector, factor: i32) -> MotionVector {
565        MotionVector::new(mv.dx * factor, mv.dy * factor)
566    }
567
568    /// Downsamples motion vector coordinates.
569    #[must_use]
570    pub const fn downsample_mv(mv: MotionVector, factor: i32) -> MotionVector {
571        MotionVector::new(mv.dx / factor, mv.dy / factor)
572    }
573
574    /// Scales block position.
575    #[must_use]
576    pub const fn scale_position(x: usize, y: usize, scale: usize) -> (usize, usize) {
577        (x / scale, y / scale)
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// `PyramidLevel::new` stores the requested geometry and sizes the buffer
    /// to match.
    #[test]
    fn test_pyramid_level_creation() {
        let level = PyramidLevel::new(64, 64, 1);
        assert_eq!((level.width, level.height, level.scale), (64, 64, 1));
        assert_eq!(level.data.len(), 64 * 64);
    }

    /// A written pixel reads back; an untouched pixel stays zero.
    #[test]
    fn test_pyramid_level_pixel_access() {
        let mut level = PyramidLevel::new(8, 8, 1);
        level.set_pixel(3, 4, 128);

        let written = level.get_pixel(3, 4);
        let untouched = level.get_pixel(0, 0);
        assert_eq!(written, 128);
        assert_eq!(untouched, 0);
    }

    /// Downsampling a diagonal gradient keeps the gradient direction.
    #[test]
    fn test_pyramid_level_downsample() {
        let mut src_level = PyramidLevel::new(8, 8, 1);
        // Fill with gradient
        let coords = (0..8usize).flat_map(|y| (0..8usize).map(move |x| (x, y)));
        for (x, y) in coords {
            src_level.set_pixel(x, y, ((x + y) * 16) as u8);
        }

        let mut dst_level = PyramidLevel::new(4, 4, 2);
        dst_level.downsample_from(&src_level);

        // Check that values are averaged
        let origin = dst_level.get_pixel(0, 0);
        let diagonal = dst_level.get_pixel(1, 1);
        assert!(origin > 0);
        assert!(diagonal > origin);
    }

    /// A 64x64 image yields three levels whose widths halve each time.
    #[test]
    fn test_image_pyramid_build() {
        let src = vec![128u8; 64 * 64];
        let mut pyramid = ImagePyramid::new();
        pyramid.build(&src, 64, 64, 3);

        assert_eq!(pyramid.num_levels(), 3);

        // Check level dimensions
        for (index, expected_width) in [64usize, 32, 16].iter().copied().enumerate() {
            let actual_width = pyramid.get_level(index).map(|level| level.width);
            assert_eq!(actual_width, Some(expected_width));
        }
    }

    /// Asking for more levels than the image allows stops early.
    #[test]
    fn test_image_pyramid_min_size() {
        let src = vec![128u8; 32 * 32];
        let mut pyramid = ImagePyramid::new();
        pyramid.build(&src, 32, 32, 5);

        // Should stop at MIN_PYRAMID_DIMENSION
        let built = pyramid.num_levels();
        assert!(built <= 2);
    }

    /// Builder setters override the ranges without touching the level count.
    #[test]
    fn test_hierarchical_config() {
        let coarse = SearchRange::symmetric(32);
        let refine = SearchRange::symmetric(8);
        let config = HierarchicalConfig::new(4)
            .coarse_range(coarse)
            .refine_range(refine);

        assert_eq!(config.levels, 4);
        assert_eq!(config.coarse_range.horizontal, 32);
        assert_eq!(config.refine_range.horizontal, 8);
    }

    /// `with_config` installs the supplied configuration.
    #[test]
    fn test_hierarchical_search_creation() {
        let config = HierarchicalConfig::new(3);
        let search = HierarchicalSearch::new().with_config(config);

        assert_eq!(search.config.levels, 3);
    }

    /// Level count follows the smaller dimension, capped by `max_levels`.
    #[test]
    fn test_calculate_pyramid_levels() {
        for &(side, expected) in &[(128usize, 4usize), (64, 3), (32, 2)] {
            assert_eq!(HierarchicalSearch::calculate_levels(side, side, 4), expected);
        }
    }

    /// The default schedule has three steps.
    #[test]
    fn test_coarse_to_fine_refiner() {
        let step_count = CoarseToFineRefiner::default_steps().steps().len();
        assert_eq!(step_count, 3);
    }

    /// `scale_mv` multiplies when going finer and divides when going coarser.
    #[test]
    fn test_scale_mv() {
        let mv = MotionVector::new(16, 32);

        // Coarse to fine (upscale)
        let scaled_up = CoarseToFineRefiner::scale_mv(mv, 2, 1);
        assert_eq!((scaled_up.dx, scaled_up.dy), (32, 64));

        // Fine to coarse (downscale)
        let scaled_down = CoarseToFineRefiner::scale_mv(mv, 1, 2);
        assert_eq!((scaled_down.dx, scaled_down.dy), (8, 16));
    }

    /// Each output pixel is the average of one uniform 2x2 quadrant.
    #[test]
    fn test_resolution_scaler_downsample() {
        // Create 4x4 image with known values
        #[rustfmt::skip]
        let src: Vec<u8> = vec![
            100, 100, 200, 200,
            100, 100, 200, 200,
             50,  50, 150, 150,
             50,  50, 150, 150,
        ];

        let dst = ResolutionScaler::downsample_2x(&src, 4, 4);

        // Check length and averaged values in one comparison
        assert_eq!(dst, vec![100u8, 200, 50, 150]);
    }

    /// Motion vector up/down sampling multiplies and divides both components.
    #[test]
    fn test_resolution_scaler_mv() {
        let mv = MotionVector::new(8, 16);

        let up = ResolutionScaler::upsample_mv(mv, 2);
        assert_eq!((up.dx, up.dy), (16, 32));

        let down = ResolutionScaler::downsample_mv(mv, 2);
        assert_eq!((down.dx, down.dy), (4, 8));
    }

    /// Identical flat frames must resolve to the zero vector.
    #[test]
    fn test_hierarchical_search_integration() {
        let src = vec![100u8; 64 * 64];
        let reference = src.clone();

        let mut search = HierarchicalSearch::new().with_config(HierarchicalConfig::new(3));
        search.build_pyramids(&src, 64, 64, &reference, 64, 64);

        let config = SearchConfig::default();
        let result = search.search_hierarchical(0, 0, BlockSize::Block8x8, &config);

        // Perfect match at origin
        assert_eq!(
            (result.mv.full_pel_x(), result.mv.full_pel_y()),
            (0, 0)
        );
    }

    /// The default step is a full-resolution step with four iterations.
    #[test]
    fn test_refinement_step() {
        let RefinementStep {
            scale, iterations, ..
        } = RefinementStep::default();
        assert_eq!(scale, 1);
        assert_eq!(iterations, 4);
    }
}