Skip to main content

oximedia_align/
motion_compensate.rs

1#![allow(dead_code)]
//! Motion compensation for temporal video alignment.
//!
//! This module provides frame-level motion compensation used to correct for camera movement
//! and subject motion when aligning video streams temporally.
//!
//! # Features
//!
//! - **Block-based motion estimation** using full search and diamond search
//! - **Motion vector field** representation and interpolation
//! - **Point warping** to compensate detected motion
//! - **Motion statistics** for alignment quality assessment
13
14use crate::{AlignError, AlignResult, Point2D};
15
/// Displacement of one block between a reference frame and a target frame.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MotionVector {
    /// Horizontal displacement in pixels.
    pub dx: f64,
    /// Vertical displacement in pixels.
    pub dy: f64,
    /// Match cost for this displacement (lower means a better match; the
    /// block-matching estimator in this module stores the block SAD here).
    pub cost: f64,
}

impl MotionVector {
    /// Create a new motion vector from its components.
    #[must_use]
    pub fn new(dx: f64, dy: f64, cost: f64) -> Self {
        Self { dx, dy, cost }
    }

    /// Create a zero motion vector (no displacement, zero cost).
    #[must_use]
    pub fn zero() -> Self {
        Self::new(0.0, 0.0, 0.0)
    }

    /// Euclidean length of the displacement, `sqrt(dx² + dy²)`.
    ///
    /// Computed with [`f64::hypot`], which avoids the spurious overflow to
    /// infinity (or underflow to zero) that squaring very large (or very
    /// small) components directly would cause.
    #[must_use]
    pub fn magnitude(&self) -> f64 {
        self.dx.hypot(self.dy)
    }

    /// Direction of the displacement in radians, measured from the +x axis
    /// towards +y, in the range `[-π, π]` (see [`f64::atan2`]).
    #[must_use]
    pub fn direction(&self) -> f64 {
        self.dy.atan2(self.dx)
    }

    /// Component-wise sum of two motion vectors.
    ///
    /// The displacements are added; the resulting `cost` is the mean of the
    /// two input costs.
    #[must_use]
    pub fn add(&self, other: &Self) -> Self {
        Self {
            dx: self.dx + other.dx,
            dy: self.dy + other.dy,
            cost: (self.cost + other.cost) / 2.0,
        }
    }

    /// Scale the displacement by `factor`, keeping the cost unchanged.
    #[must_use]
    pub fn scale(&self, factor: f64) -> Self {
        Self {
            dx: self.dx * factor,
            dy: self.dy * factor,
            cost: self.cost,
        }
    }
}
76
/// Block-matching strategy used when estimating motion.
///
/// The estimator in this module has dedicated code for
/// [`SearchStrategy::FullSearch`] and [`SearchStrategy::DiamondSearch`] only;
/// the remaining variants are currently served by the diamond search.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchStrategy {
    /// Exhaustive search over every displacement within the search range.
    FullSearch,
    /// Diamond-pattern search; much cheaper than an exhaustive search.
    DiamondSearch,
    /// Three-step search, trading some accuracy for speed.
    ThreeStepSearch,
    /// Hexagon-pattern search.
    HexagonalSearch,
}

/// Parameters that control block-based motion estimation.
#[derive(Clone, Debug)]
pub struct MotionEstimationConfig {
    /// Edge length of the square matching blocks, in pixels.
    pub block_size: u32,
    /// Largest displacement searched along each axis, in pixels.
    pub search_range: u32,
    /// Which block-matching strategy to run.
    pub search_strategy: SearchStrategy,
    /// Request sub-pixel refinement.
    ///
    /// NOTE: the estimator in this module does not currently read this flag;
    /// estimated displacements are whole pixels.
    pub sub_pixel: bool,
    /// Width of the frames being compared, in pixels.
    pub frame_width: u32,
    /// Height of the frames being compared, in pixels.
    pub frame_height: u32,
}

impl Default for MotionEstimationConfig {
    /// 16×16 blocks, ±32 px search, diamond search, sub-pixel flag set,
    /// sized for 1920×1080 frames.
    fn default() -> Self {
        let (frame_width, frame_height) = (1920, 1080);
        Self {
            search_strategy: SearchStrategy::DiamondSearch,
            block_size: 16,
            search_range: 32,
            sub_pixel: true,
            frame_width,
            frame_height,
        }
    }
}
119
120/// A field of motion vectors covering an entire frame.
121#[derive(Debug, Clone)]
122pub struct MotionField {
123    /// Motion vectors in row-major order.
124    pub vectors: Vec<MotionVector>,
125    /// Number of blocks horizontally.
126    pub cols: u32,
127    /// Number of blocks vertically.
128    pub rows: u32,
129    /// Block size used for estimation.
130    pub block_size: u32,
131}
132
133impl MotionField {
134    /// Create a new motion field with zero vectors.
135    #[must_use]
136    #[allow(clippy::cast_precision_loss)]
137    pub fn new(frame_width: u32, frame_height: u32, block_size: u32) -> Self {
138        let cols = frame_width.div_ceil(block_size);
139        let rows = frame_height.div_ceil(block_size);
140        let count = (cols * rows) as usize;
141        Self {
142            vectors: vec![MotionVector::zero(); count],
143            cols,
144            rows,
145            block_size,
146        }
147    }
148
149    /// Get the motion vector at block position (bx, by).
150    #[must_use]
151    pub fn get(&self, bx: u32, by: u32) -> Option<&MotionVector> {
152        if bx < self.cols && by < self.rows {
153            Some(&self.vectors[(by * self.cols + bx) as usize])
154        } else {
155            None
156        }
157    }
158
159    /// Set the motion vector at block position (bx, by).
160    pub fn set(&mut self, bx: u32, by: u32, mv: MotionVector) {
161        if bx < self.cols && by < self.rows {
162            self.vectors[(by * self.cols + bx) as usize] = mv;
163        }
164    }
165
166    /// Interpolate a motion vector at a continuous pixel position.
167    #[must_use]
168    #[allow(clippy::cast_precision_loss)]
169    pub fn interpolate(&self, x: f64, y: f64) -> MotionVector {
170        let bs = f64::from(self.block_size);
171        let bx = (x / bs).floor();
172        let by = (y / bs).floor();
173
174        let bxi = bx as u32;
175        let byi = by as u32;
176
177        // Bilinear interpolation
178        let fx = x / bs - bx;
179        let fy = y / bs - by;
180
181        let get_mv = |cx: u32, cy: u32| -> MotionVector {
182            self.get(
183                cx.min(self.cols.saturating_sub(1)),
184                cy.min(self.rows.saturating_sub(1)),
185            )
186            .copied()
187            .unwrap_or_else(MotionVector::zero)
188        };
189
190        let tl = get_mv(bxi, byi);
191        let tr = get_mv(bxi + 1, byi);
192        let bl = get_mv(bxi, byi + 1);
193        let br = get_mv(bxi + 1, byi + 1);
194
195        let dx = tl.dx * (1.0 - fx) * (1.0 - fy)
196            + tr.dx * fx * (1.0 - fy)
197            + bl.dx * (1.0 - fx) * fy
198            + br.dx * fx * fy;
199
200        let dy = tl.dy * (1.0 - fx) * (1.0 - fy)
201            + tr.dy * fx * (1.0 - fy)
202            + bl.dy * (1.0 - fx) * fy
203            + br.dy * fx * fy;
204
205        let cost = tl.cost * (1.0 - fx) * (1.0 - fy)
206            + tr.cost * fx * (1.0 - fy)
207            + bl.cost * (1.0 - fx) * fy
208            + br.cost * fx * fy;
209
210        MotionVector::new(dx, dy, cost)
211    }
212
213    /// Compute global motion from the motion field (median of all vectors).
214    #[must_use]
215    pub fn global_motion(&self) -> MotionVector {
216        if self.vectors.is_empty() {
217            return MotionVector::zero();
218        }
219
220        let mut dxs: Vec<f64> = self.vectors.iter().map(|v| v.dx).collect();
221        let mut dys: Vec<f64> = self.vectors.iter().map(|v| v.dy).collect();
222
223        dxs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
224        dys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
225
226        let mid = dxs.len() / 2;
227        MotionVector::new(dxs[mid], dys[mid], 0.0)
228    }
229
230    /// Compute the average magnitude of all motion vectors.
231    #[must_use]
232    #[allow(clippy::cast_precision_loss)]
233    pub fn average_magnitude(&self) -> f64 {
234        if self.vectors.is_empty() {
235            return 0.0;
236        }
237        let total: f64 = self.vectors.iter().map(MotionVector::magnitude).sum();
238        total / self.vectors.len() as f64
239    }
240
241    /// Count the number of vectors exceeding a magnitude threshold.
242    #[must_use]
243    pub fn count_above_threshold(&self, threshold: f64) -> usize {
244        self.vectors
245            .iter()
246            .filter(|v| v.magnitude() > threshold)
247            .count()
248    }
249}
250
/// Summary statistics describing the motion between a pair of frames.
#[derive(Clone, Debug)]
pub struct MotionStats {
    /// Mean motion-vector magnitude, in pixels.
    pub avg_magnitude: f64,
    /// Largest motion-vector magnitude, in pixels.
    pub max_magnitude: f64,
    /// Standard deviation of the motion-vector magnitudes, in pixels.
    pub std_magnitude: f64,
    /// Global horizontal motion: the median `dx` over all blocks.
    pub global_dx: f64,
    /// Global vertical motion: the median `dy` over all blocks.
    pub global_dy: f64,
    /// Share of blocks whose motion magnitude exceeds one pixel, in `[0, 1]`.
    pub motion_fraction: f64,
}
267
268/// Motion compensator that estimates and applies motion compensation.
269#[derive(Debug, Clone)]
270pub struct MotionCompensator {
271    /// Configuration for motion estimation.
272    config: MotionEstimationConfig,
273}
274
275impl MotionCompensator {
276    /// Create a new motion compensator with the given configuration.
277    #[must_use]
278    pub fn new(config: MotionEstimationConfig) -> Self {
279        Self { config }
280    }
281
282    /// Create with default configuration.
283    #[must_use]
284    pub fn with_defaults() -> Self {
285        Self {
286            config: MotionEstimationConfig::default(),
287        }
288    }
289
290    /// Estimate motion field between a reference frame and a target frame.
291    ///
292    /// Both frames are provided as grayscale pixel data (one byte per pixel, row-major).
293    #[allow(clippy::cast_precision_loss)]
294    pub fn estimate(&self, reference: &[u8], target: &[u8]) -> AlignResult<MotionField> {
295        let expected_size = (self.config.frame_width * self.config.frame_height) as usize;
296        if reference.len() != expected_size || target.len() != expected_size {
297            return Err(AlignError::InsufficientData(format!(
298                "Expected frame size {}, got ref={} target={}",
299                expected_size,
300                reference.len(),
301                target.len()
302            )));
303        }
304
305        let mut field = MotionField::new(
306            self.config.frame_width,
307            self.config.frame_height,
308            self.config.block_size,
309        );
310
311        let bs = self.config.block_size;
312        let sr = self.config.search_range as i32;
313        let w = self.config.frame_width;
314        let h = self.config.frame_height;
315
316        for by in 0..field.rows {
317            for bx in 0..field.cols {
318                let orig_x = (bx * bs) as i32;
319                let orig_y = (by * bs) as i32;
320
321                let mv = match self.config.search_strategy {
322                    SearchStrategy::FullSearch => {
323                        self.full_search(reference, target, orig_x, orig_y, bs, sr, w, h)
324                    }
325                    _ => {
326                        // Use diamond search as default fast path
327                        self.diamond_search(reference, target, orig_x, orig_y, bs, sr, w, h)
328                    }
329                };
330
331                field.set(bx, by, mv);
332            }
333        }
334
335        Ok(field)
336    }
337
338    /// Compute motion statistics from a motion field.
339    #[must_use]
340    #[allow(clippy::cast_precision_loss)]
341    pub fn compute_stats(field: &MotionField) -> MotionStats {
342        if field.vectors.is_empty() {
343            return MotionStats {
344                avg_magnitude: 0.0,
345                max_magnitude: 0.0,
346                std_magnitude: 0.0,
347                global_dx: 0.0,
348                global_dy: 0.0,
349                motion_fraction: 0.0,
350            };
351        }
352
353        let magnitudes: Vec<f64> = field.vectors.iter().map(MotionVector::magnitude).collect();
354        let n = magnitudes.len() as f64;
355        let avg = magnitudes.iter().sum::<f64>() / n;
356        let max = magnitudes.iter().copied().fold(0.0_f64, f64::max);
357        let variance = magnitudes.iter().map(|m| (m - avg).powi(2)).sum::<f64>() / n;
358        let std_dev = variance.sqrt();
359
360        let global = field.global_motion();
361        let motion_count = field.count_above_threshold(1.0);
362
363        MotionStats {
364            avg_magnitude: avg,
365            max_magnitude: max,
366            std_magnitude: std_dev,
367            global_dx: global.dx,
368            global_dy: global.dy,
369            motion_fraction: motion_count as f64 / n,
370        }
371    }
372
373    /// Apply motion compensation to warp a set of points.
374    #[must_use]
375    pub fn compensate_points(field: &MotionField, points: &[Point2D]) -> Vec<Point2D> {
376        points
377            .iter()
378            .map(|p| {
379                let mv = field.interpolate(p.x, p.y);
380                Point2D::new(p.x + mv.dx, p.y + mv.dy)
381            })
382            .collect()
383    }
384
385    /// Full search block matching (exhaustive).
386    #[allow(clippy::too_many_arguments)]
387    #[allow(clippy::cast_precision_loss)]
388    fn full_search(
389        &self,
390        reference: &[u8],
391        target: &[u8],
392        bx: i32,
393        by: i32,
394        bs: u32,
395        sr: i32,
396        w: u32,
397        h: u32,
398    ) -> MotionVector {
399        let mut best_dx = 0i32;
400        let mut best_dy = 0i32;
401        let mut best_cost = f64::MAX;
402
403        for dy in -sr..=sr {
404            for dx in -sr..=sr {
405                let cost = self.compute_sad(reference, target, bx, by, bx + dx, by + dy, bs, w, h);
406                if cost < best_cost
407                    || (cost == best_cost
408                        && (dx.unsigned_abs() + dy.unsigned_abs())
409                            < (best_dx.unsigned_abs() + best_dy.unsigned_abs()))
410                {
411                    best_cost = cost;
412                    best_dx = dx;
413                    best_dy = dy;
414                }
415            }
416        }
417
418        MotionVector::new(f64::from(best_dx), f64::from(best_dy), best_cost)
419    }
420
421    /// Diamond search pattern block matching.
422    #[allow(clippy::too_many_arguments)]
423    #[allow(clippy::cast_precision_loss)]
424    fn diamond_search(
425        &self,
426        reference: &[u8],
427        target: &[u8],
428        bx: i32,
429        by: i32,
430        bs: u32,
431        sr: i32,
432        w: u32,
433        h: u32,
434    ) -> MotionVector {
435        let large_diamond: [(i32, i32); 9] = [
436            (0, 0),
437            (0, -2),
438            (1, -1),
439            (2, 0),
440            (1, 1),
441            (0, 2),
442            (-1, 1),
443            (-2, 0),
444            (-1, -1),
445        ];
446
447        let mut cx = 0i32;
448        let mut cy = 0i32;
449        let mut best_cost = f64::MAX;
450
451        for _ in 0..sr {
452            let mut found_better = false;
453            let mut new_cx = cx;
454            let mut new_cy = cy;
455
456            for &(ddx, ddy) in &large_diamond {
457                let tx = cx + ddx;
458                let ty = cy + ddy;
459                if tx.abs() > sr || ty.abs() > sr {
460                    continue;
461                }
462                let cost = self.compute_sad(reference, target, bx, by, bx + tx, by + ty, bs, w, h);
463                if cost < best_cost {
464                    best_cost = cost;
465                    new_cx = tx;
466                    new_cy = ty;
467                    found_better = true;
468                }
469            }
470
471            if !found_better || (new_cx == cx && new_cy == cy) {
472                break;
473            }
474            cx = new_cx;
475            cy = new_cy;
476        }
477
478        MotionVector::new(f64::from(cx), f64::from(cy), best_cost)
479    }
480
481    /// Compute Sum of Absolute Differences (SAD) for a block.
482    #[allow(clippy::too_many_arguments)]
483    #[allow(clippy::cast_precision_loss)]
484    fn compute_sad(
485        &self,
486        reference: &[u8],
487        target: &[u8],
488        rx: i32,
489        ry: i32,
490        tx: i32,
491        ty: i32,
492        bs: u32,
493        w: u32,
494        h: u32,
495    ) -> f64 {
496        let mut sad = 0u64;
497        let bs_i = bs as i32;
498        let w_i = w as i32;
499        let h_i = h as i32;
500
501        for row in 0..bs_i {
502            for col in 0..bs_i {
503                let ref_x = rx + col;
504                let ref_y = ry + row;
505                let tgt_x = tx + col;
506                let tgt_y = ty + row;
507
508                if ref_x < 0 || ref_x >= w_i || ref_y < 0 || ref_y >= h_i {
509                    sad += 128;
510                    continue;
511                }
512                if tgt_x < 0 || tgt_x >= w_i || tgt_y < 0 || tgt_y >= h_i {
513                    sad += 128;
514                    continue;
515                }
516
517                let ref_idx = (ref_y as u32 * w + ref_x as u32) as usize;
518                let tgt_idx = (tgt_y as u32 * w + tgt_x as u32) as usize;
519
520                let diff = i32::from(reference[ref_idx]) - i32::from(target[tgt_idx]);
521                sad += u64::from(diff.unsigned_abs());
522            }
523        }
524
525        sad as f64
526    }
527
528    /// Get the current configuration.
529    #[must_use]
530    pub fn config(&self) -> &MotionEstimationConfig {
531        &self.config
532    }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_motion_vector_creation() {
        let mv = MotionVector::new(3.0, 4.0, 100.0);
        assert!((mv.dx - 3.0).abs() < f64::EPSILON);
        assert!((mv.dy - 4.0).abs() < f64::EPSILON);
        assert!((mv.cost - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_magnitude() {
        let mv = MotionVector::new(3.0, 4.0, 0.0);
        assert!((mv.magnitude() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_motion_vector_direction() {
        let mv = MotionVector::new(1.0, 0.0, 0.0);
        assert!((mv.direction()).abs() < 1e-10);

        let mv_up = MotionVector::new(0.0, 1.0, 0.0);
        assert!((mv_up.direction() - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
    }

    #[test]
    fn test_motion_vector_zero() {
        let mv = MotionVector::zero();
        assert!((mv.magnitude()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_add() {
        let a = MotionVector::new(1.0, 2.0, 10.0);
        let b = MotionVector::new(3.0, 4.0, 20.0);
        let c = a.add(&b);
        assert!((c.dx - 4.0).abs() < f64::EPSILON);
        assert!((c.dy - 6.0).abs() < f64::EPSILON);
        assert!((c.cost - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_scale() {
        let mv = MotionVector::new(2.0, 3.0, 10.0);
        let scaled = mv.scale(0.5);
        assert!((scaled.dx - 1.0).abs() < f64::EPSILON);
        assert!((scaled.dy - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_field_creation() {
        let field = MotionField::new(320, 240, 16);
        assert_eq!(field.cols, 20);
        assert_eq!(field.rows, 15);
        assert_eq!(field.vectors.len(), 300);
    }

    #[test]
    fn test_motion_field_get_set() {
        let mut field = MotionField::new(64, 64, 16);
        let mv = MotionVector::new(5.0, -3.0, 50.0);
        field.set(1, 2, mv);
        let retrieved = field.get(1, 2).expect("retrieved should be valid");
        assert!((retrieved.dx - 5.0).abs() < f64::EPSILON);
        assert!((retrieved.dy - (-3.0)).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_field_get_out_of_range() {
        let mut field = MotionField::new(64, 64, 16);
        assert!(field.get(4, 0).is_none());
        assert!(field.get(0, 4).is_none());
        // Out-of-range writes are ignored rather than panicking.
        field.set(10, 10, MotionVector::new(1.0, 1.0, 1.0));
        assert!(field.vectors.iter().all(|mv| *mv == MotionVector::zero()));
    }

    #[test]
    fn test_motion_field_count_above_threshold() {
        let mut field = MotionField::new(32, 32, 16);
        field.set(0, 0, MotionVector::new(3.0, 4.0, 0.0));
        field.set(1, 0, MotionVector::new(0.5, 0.0, 0.0));
        assert_eq!(field.count_above_threshold(1.0), 1);
        assert_eq!(field.count_above_threshold(0.1), 2);
        assert_eq!(field.count_above_threshold(6.0), 0);
    }

    #[test]
    fn test_motion_field_interpolate_blends_neighbours() {
        let mut field = MotionField::new(32, 32, 16);
        field.set(1, 0, MotionVector::new(4.0, 0.0, 0.0));
        field.set(0, 1, MotionVector::new(0.0, 8.0, 0.0));
        field.set(1, 1, MotionVector::new(4.0, 8.0, 0.0));

        // Halfway between the four block anchors.
        let mv = field.interpolate(8.0, 8.0);
        assert!((mv.dx - 2.0).abs() < 1e-10);
        assert!((mv.dy - 4.0).abs() < 1e-10);

        // Exactly on the anchor of block (1, 1).
        let corner = field.interpolate(16.0, 16.0);
        assert!((corner.dx - 4.0).abs() < 1e-10);
        assert!((corner.dy - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_motion_field_global_motion() {
        let mut field = MotionField::new(64, 64, 16);
        // Set all vectors to (2, 1).
        for by in 0..field.rows {
            for bx in 0..field.cols {
                field.set(bx, by, MotionVector::new(2.0, 1.0, 0.0));
            }
        }
        let global = field.global_motion();
        assert!((global.dx - 2.0).abs() < f64::EPSILON);
        assert!((global.dy - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_field_average_magnitude() {
        let mut field = MotionField::new(32, 32, 16);
        field.set(0, 0, MotionVector::new(3.0, 4.0, 0.0));
        field.set(1, 0, MotionVector::new(0.0, 0.0, 0.0));
        field.set(0, 1, MotionVector::new(0.0, 0.0, 0.0));
        field.set(1, 1, MotionVector::new(0.0, 0.0, 0.0));
        let avg = field.average_magnitude();
        // One vector has mag 5.0, three have 0.0 => avg = 5.0/4 = 1.25
        assert!((avg - 1.25).abs() < 1e-10);
    }

    #[test]
    fn test_estimate_static_frames() {
        let config = MotionEstimationConfig {
            block_size: 8,
            search_range: 4,
            search_strategy: SearchStrategy::FullSearch,
            sub_pixel: false,
            frame_width: 32,
            frame_height: 32,
        };
        let comp = MotionCompensator::new(config);

        // Both frames identical => all zero motion
        let frame = vec![128u8; 32 * 32];
        let field = comp
            .estimate(&frame, &frame)
            .expect("field should be valid");

        for mv in &field.vectors {
            assert!((mv.dx).abs() < f64::EPSILON);
            assert!((mv.dy).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_estimate_static_frames_diamond() {
        let config = MotionEstimationConfig {
            block_size: 8,
            search_range: 4,
            search_strategy: SearchStrategy::DiamondSearch,
            sub_pixel: false,
            frame_width: 32,
            frame_height: 32,
        };
        let comp = MotionCompensator::new(config);

        // Identical flat frames: the zero displacement is an exact match.
        let frame = vec![128u8; 32 * 32];
        let field = comp
            .estimate(&frame, &frame)
            .expect("field should be valid");

        for mv in &field.vectors {
            assert!((mv.dx).abs() < f64::EPSILON);
            assert!((mv.dy).abs() < f64::EPSILON);
            assert!((mv.cost).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_estimate_full_search_recovers_shift() {
        // (37x + 91y + 17xy) mod 251 has no non-zero shift that maps a block
        // onto itself, so the true displacement is the unique zero-SAD match.
        fn texture(x: u32, y: u32) -> u8 {
            ((37 * x + 91 * y + 17 * x * y) % 251) as u8
        }

        let mut reference = Vec::with_capacity(32 * 32);
        let mut target = Vec::with_capacity(32 * 32);
        for y in 0..32u32 {
            for x in 0..32u32 {
                reference.push(texture(x, y));
                // Content moves right by 2 and down by 1 between the frames.
                target.push(if x >= 2 && y >= 1 {
                    texture(x - 2, y - 1)
                } else {
                    0
                });
            }
        }

        let comp = MotionCompensator::new(MotionEstimationConfig {
            block_size: 8,
            search_range: 4,
            search_strategy: SearchStrategy::FullSearch,
            sub_pixel: false,
            frame_width: 32,
            frame_height: 32,
        });
        let field = comp
            .estimate(&reference, &target)
            .expect("field should be valid");

        for &(bx, by) in &[(1u32, 1u32), (2, 2)] {
            let mv = field.get(bx, by).expect("block should exist");
            assert!((mv.dx - 2.0).abs() < f64::EPSILON);
            assert!((mv.dy - 1.0).abs() < f64::EPSILON);
            assert!((mv.cost).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_estimate_wrong_size() {
        let comp = MotionCompensator::new(MotionEstimationConfig {
            frame_width: 64,
            frame_height: 64,
            ..MotionEstimationConfig::default()
        });
        let small_frame = vec![0u8; 10];
        let result = comp.estimate(&small_frame, &small_frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_compensate_points() {
        let mut field = MotionField::new(64, 64, 64);
        field.set(0, 0, MotionVector::new(10.0, -5.0, 0.0));

        let points = vec![Point2D::new(10.0, 20.0)];
        let compensated = MotionCompensator::compensate_points(&field, &points);
        assert!((compensated[0].x - 20.0).abs() < 1e-10);
        assert!((compensated[0].y - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_motion_stats_static() {
        let field = MotionField::new(64, 64, 16);
        let stats = MotionCompensator::compute_stats(&field);
        assert!((stats.avg_magnitude).abs() < f64::EPSILON);
        assert!((stats.max_magnitude).abs() < f64::EPSILON);
        assert!((stats.motion_fraction).abs() < f64::EPSILON);
    }

    #[test]
    fn test_motion_stats_single_moving_block() {
        let mut field = MotionField::new(32, 32, 16);
        field.set(0, 0, MotionVector::new(3.0, 4.0, 0.0));
        let stats = MotionCompensator::compute_stats(&field);
        // Magnitudes are [5, 0, 0, 0].
        assert!((stats.avg_magnitude - 1.25).abs() < 1e-10);
        assert!((stats.max_magnitude - 5.0).abs() < 1e-10);
        assert!((stats.std_magnitude - 4.6875_f64.sqrt()).abs() < 1e-10);
        assert!((stats.motion_fraction - 0.25).abs() < 1e-10);
        // Median of [0, 0, 0, 3] and [0, 0, 0, 4] is zero motion.
        assert!(stats.global_dx.abs() < f64::EPSILON);
        assert!(stats.global_dy.abs() < f64::EPSILON);
    }
}