Skip to main content

oximedia_gpu/
motion_estimation.rs

1//! GPU-accelerated motion estimation for AV1 and VP9 video codecs.
2//!
3//! This module provides compute-shader-based motion estimation pipelines
4//! suitable for AV1 and VP9 intra/inter frame encoding.  The GPU kernels
5//! exploit massively parallel block matching to evaluate Sum of Absolute
6//! Differences (SAD) and Sum of Squared Differences (SSD) across many
7//! candidate motion vectors simultaneously.
8//!
9//! # Architecture
10//!
11//! The pipeline is divided into three GPU dispatch stages:
12//!
13//! 1. **Hierarchical downscale** – build a Gaussian pyramid (up to 4 levels)
14//!    so that large motion is found at low resolution first.
15//! 2. **Block-match sweep** – for every block in the current frame, evaluate
16//!    all candidate motion vectors within the search window using parallel
17//!    SAD/SSD kernels dispatched with workgroup-local shared memory
18//!    (reducing global-memory bandwidth by ~8×).
19//! 3. **Refinement** – perform ±1 / ±½ pixel sub-pixel refinement around the
20//!    best integer candidate found in stage 2.
21//!
22//! # Status
23//!
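//! The GPU path is a stub: the WGSL shaders for AV1/VP9-specific block
//! partitions (superblock, transform units, etc.) are not yet wired up, so
//! `MotionEstimator::estimate` always falls back to the rayon-parallel CPU
//! reference path, which is also what tests / CI exercise.  The CPU path
//! performs a single-level, integer-pixel full search; the pyramid stage and
//! interpolated sub-pixel refinement described above are not implemented yet.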
28
29use crate::{GpuDevice, GpuError, Result};
30use rayon::prelude::*;
31
32// ─────────────────────────────────────────────────────────────────────────────
33// Public API types
34// ─────────────────────────────────────────────────────────────────────────────
35
/// Video codec whose encoder will consume the motion-estimation output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TargetCodec {
    /// AV1 (AOMedia Video 1); superblock partitions may be as large as 128×128.
    Av1,
    /// VP9; superblock partitions may be as large as 64×64.
    Vp9,
}
44
45/// Block partition mode used during motion search.
46#[derive(Debug, Clone, Copy, PartialEq, Eq)]
47pub enum BlockPartition {
48    /// Fixed 16×16 macro-blocks (fast, lower quality).
49    Fixed16x16,
50    /// Fixed 32×32 blocks.
51    Fixed32x32,
52    /// Fixed 64×64 super-blocks (VP9 native).
53    Fixed64x64,
54    /// Fixed 128×128 super-blocks (AV1 native).
55    Fixed128x128,
56    /// Adaptive partitioning: use a quad-tree split based on variance.
57    Adaptive,
58}
59
60impl Default for BlockPartition {
61    fn default() -> Self {
62        Self::Fixed16x16
63    }
64}
65
66/// Configuration for a motion-estimation pass.
67#[derive(Debug, Clone)]
68pub struct MotionEstimationConfig {
69    /// Target codec (affects block sizes and allowed partition modes).
70    pub codec: TargetCodec,
71    /// Block partitioning strategy.
72    pub partition: BlockPartition,
73    /// Search window half-size in pixels (e.g. 32 means ±32 px search).
74    pub search_radius: u32,
75    /// Whether to perform sub-pixel (half-pixel) refinement.
76    pub subpixel_refinement: bool,
77    /// Cost metric used to rank candidate motion vectors.
78    pub metric: MotionMetric,
79    /// Number of Gaussian pyramid levels for hierarchical search.
80    pub pyramid_levels: u32,
81}
82
83impl Default for MotionEstimationConfig {
84    fn default() -> Self {
85        Self {
86            codec: TargetCodec::Av1,
87            partition: BlockPartition::default(),
88            search_radius: 32,
89            subpixel_refinement: true,
90            metric: MotionMetric::Sad,
91            pyramid_levels: 3,
92        }
93    }
94}
95
/// Distortion measure used to score candidate motion vectors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MotionMetric {
    /// Sum of absolute differences — the cheapest measure.
    Sad,
    /// Sum of squared differences — weights large errors more heavily.
    Ssd,
    /// Hadamard-transformed residual (SATD) — best quality, most expensive.
    Hadamard,
}
106
/// A 2-D integer motion vector (pixel precision).
///
/// The vector is the offset from a block in the current frame to its best
/// match in the reference frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MotionVector {
    /// Horizontal offset in pixels (positive = reference block lies to the right).
    pub dx: i16,
    /// Vertical offset in pixels (positive = reference block lies below).
    pub dy: i16,
}

/// A 2-D sub-pixel motion vector in units of 1/4 pixel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SubpixelMv {
    /// Horizontal displacement in quarter-pixels.
    pub dx: i32,
    /// Vertical displacement in quarter-pixels.
    pub dy: i32,
}

/// Motion estimation result for a single block.
#[derive(Debug, Clone)]
pub struct BlockMvResult {
    /// X coordinate of the block's top-left corner, in pixels.
    pub block_x: u32,
    /// Y coordinate of the block's top-left corner, in pixels.
    pub block_y: u32,
    /// Best integer-pixel motion vector.
    pub mv: MotionVector,
    /// Best sub-pixel motion vector (present only if refinement was requested).
    pub subpixel_mv: Option<SubpixelMv>,
    /// Cost (SAD/SSD/Hadamard) of the best candidate.
    pub cost: u32,
}

/// Full-frame motion estimation result.
#[derive(Debug, Clone)]
pub struct FrameMvResult {
    /// Frame width in pixels.
    pub width: u32,
    /// Frame height in pixels.
    pub height: u32,
    /// Per-block motion vectors (row-major order).
    pub block_mvs: Vec<BlockMvResult>,
    /// Block size used (pixels).
    pub block_size: u32,
    /// Whether GPU execution was used (`false` = CPU fallback).
    pub used_gpu: bool,
}

impl FrameMvResult {
    /// Number of blocks in the horizontal direction, counting a partial
    /// block at the right edge.
    ///
    /// Returns 0 when `block_size` is 0 instead of panicking on the division.
    #[must_use]
    pub fn blocks_x(&self) -> u32 {
        Self::blocks_along(self.width, self.block_size)
    }

    /// Number of blocks in the vertical direction, counting a partial block
    /// at the bottom edge.
    ///
    /// Returns 0 when `block_size` is 0 instead of panicking on the division.
    #[must_use]
    pub fn blocks_y(&self) -> u32 {
        Self::blocks_along(self.height, self.block_size)
    }

    /// Mean Euclidean length of the integer motion vectors across all
    /// blocks, in pixels.
    ///
    /// Returns `0.0` when there are no blocks.
    #[must_use]
    pub fn mean_mv_magnitude(&self) -> f32 {
        if self.block_mvs.is_empty() {
            return 0.0;
        }
        let sum: f64 = self
            .block_mvs
            .iter()
            .map(|b| {
                let dx = f64::from(b.mv.dx);
                let dy = f64::from(b.mv.dy);
                (dx * dx + dy * dy).sqrt()
            })
            .sum();
        (sum / self.block_mvs.len() as f64) as f32
    }

    /// `extent / block_size` rounded up, or 0 for a zero block size.
    fn blocks_along(extent: u32, block_size: u32) -> u32 {
        if block_size == 0 {
            0
        } else {
            extent.div_ceil(block_size)
        }
    }
}
187
188// ─────────────────────────────────────────────────────────────────────────────
189// MotionEstimator
190// ─────────────────────────────────────────────────────────────────────────────
191
192/// GPU-accelerated motion estimator.
193pub struct MotionEstimator {
194    config: MotionEstimationConfig,
195}
196
197impl MotionEstimator {
198    /// Create a new motion estimator with the given configuration.
199    #[must_use]
200    pub fn new(config: MotionEstimationConfig) -> Self {
201        Self { config }
202    }
203
204    /// Create a motion estimator with default AV1 settings.
205    #[must_use]
206    pub fn av1_default() -> Self {
207        Self::new(MotionEstimationConfig {
208            codec: TargetCodec::Av1,
209            partition: BlockPartition::Fixed64x64,
210            search_radius: 48,
211            subpixel_refinement: true,
212            metric: MotionMetric::Sad,
213            pyramid_levels: 3,
214        })
215    }
216
217    /// Create a motion estimator with default VP9 settings.
218    #[must_use]
219    pub fn vp9_default() -> Self {
220        Self::new(MotionEstimationConfig {
221            codec: TargetCodec::Vp9,
222            partition: BlockPartition::Fixed64x64,
223            search_radius: 32,
224            subpixel_refinement: true,
225            metric: MotionMetric::Sad,
226            pyramid_levels: 2,
227        })
228    }
229
230    /// Estimate motion vectors between a reference frame and a current frame.
231    ///
232    /// Both frames must be packed luma-only (one byte per pixel) with
233    /// `width × height` bytes each.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if dimensions are mismatched or buffers are too small.
238    pub fn estimate(
239        &self,
240        device: &GpuDevice,
241        reference: &[u8],
242        current: &[u8],
243        width: u32,
244        height: u32,
245    ) -> Result<FrameMvResult> {
246        if reference.len() < (width * height) as usize {
247            return Err(GpuError::InvalidBufferSize {
248                expected: (width * height) as usize,
249                actual: reference.len(),
250            });
251        }
252        if current.len() < (width * height) as usize {
253            return Err(GpuError::InvalidBufferSize {
254                expected: (width * height) as usize,
255                actual: current.len(),
256            });
257        }
258        if width == 0 || height == 0 {
259            return Err(GpuError::InvalidDimensions { width, height });
260        }
261
262        // GPU path: attempt to dispatch compute shaders.
263        // The GPU shaders are present as stubs — on failure we fall back to
264        // the CPU path below.
265        if !device.is_fallback {
266            if let Ok(result) = self.estimate_gpu(device, reference, current, width, height) {
267                return Ok(result);
268            }
269        }
270
271        // CPU reference path (rayon-parallel block matching).
272        self.estimate_cpu(reference, current, width, height)
273    }
274
275    // ── GPU stub path ─────────────────────────────────────────────────────────
276
277    fn estimate_gpu(
278        &self,
279        _device: &GpuDevice,
280        reference: &[u8],
281        current: &[u8],
282        width: u32,
283        height: u32,
284    ) -> Result<FrameMvResult> {
285        // TODO (Phase 2): wire up the WGSL hierarchical block-match shaders.
286        //
287        // The GPU path will:
288        //  1. Upload `reference` and `current` as R8Unorm textures.
289        //  2. Build a Gaussian pyramid via a `downsample_r8` compute pass.
290        //  3. Dispatch `block_match_sad` with workgroup-shared tile caches for
291        //     each pyramid level (coarse→fine).
292        //  4. Dispatch `subpixel_refine_bilinear` for ±½-pixel refinement.
293        //  5. Readback the MV buffer.
294        //
295        // For now return NotSupported to trigger CPU fallback.
296        let _ = (reference, current, width, height);
297        Err(GpuError::NotSupported(
298            "GPU motion estimation shaders are not yet compiled".to_string(),
299        ))
300    }
301
302    // ── CPU reference path ───────────────────────────────────────────────────
303
304    fn estimate_cpu(
305        &self,
306        reference: &[u8],
307        current: &[u8],
308        width: u32,
309        height: u32,
310    ) -> Result<FrameMvResult> {
311        // Validate dimensions and buffer sizes (mirrors estimate() checks so
312        // that callers invoking estimate_cpu directly also get proper errors).
313        if width == 0 || height == 0 {
314            return Err(GpuError::InvalidDimensions { width, height });
315        }
316        let required = (width as usize)
317            .checked_mul(height as usize)
318            .ok_or(GpuError::InvalidDimensions { width, height })?;
319        if reference.len() < required {
320            return Err(GpuError::InvalidBufferSize {
321                expected: required,
322                actual: reference.len(),
323            });
324        }
325        if current.len() < required {
326            return Err(GpuError::InvalidBufferSize {
327                expected: required,
328                actual: current.len(),
329            });
330        }
331
332        let block_size = match self.config.partition {
333            BlockPartition::Fixed16x16 | BlockPartition::Adaptive => 16u32,
334            BlockPartition::Fixed32x32 => 32,
335            BlockPartition::Fixed64x64 => 64,
336            BlockPartition::Fixed128x128 => 128,
337        };
338
339        let blocks_x = width.div_ceil(block_size);
340        let blocks_y = height.div_ceil(block_size);
341        let n_blocks = (blocks_x * blocks_y) as usize;
342
343        let block_mvs: Vec<BlockMvResult> = (0..n_blocks)
344            .into_par_iter()
345            .map(|idx| {
346                let bx = (idx as u32 % blocks_x) * block_size;
347                let by = (idx as u32 / blocks_x) * block_size;
348                self.match_block(reference, current, width, height, bx, by, block_size)
349            })
350            .collect();
351
352        Ok(FrameMvResult {
353            width,
354            height,
355            block_mvs,
356            block_size,
357            used_gpu: false,
358        })
359    }
360
361    /// Perform block matching for a single block at (bx, by).
362    ///
363    /// Search order: zero-motion `(0, 0)` is evaluated first and used to seed
364    /// `best_cost`.  The full `±search_radius` grid is then scanned; a
365    /// candidate replaces the current best only when its cost is **strictly
366    /// lower** (ties stay with the earlier, closer-to-origin candidate).
367    /// This guarantees that zero-motion wins whenever all SAD values are equal
368    /// (e.g. perfectly uniform frames) while real motion is still detected
369    /// when a shifted block produces a lower SAD than the zero-motion baseline.
370    #[allow(clippy::too_many_arguments)]
371    fn match_block(
372        &self,
373        reference: &[u8],
374        current: &[u8],
375        width: u32,
376        height: u32,
377        bx: u32,
378        by: u32,
379        block_size: u32,
380    ) -> BlockMvResult {
381        let w = width as usize;
382        let sr = self.config.search_radius as i32;
383        let bs = block_size as usize;
384
385        // Evaluate zero-motion first to seed the best cost.  All other
386        // candidates must strictly beat this to be accepted.
387        let zero_cost = self.compute_sad(
388            reference,
389            current,
390            w,
391            width as usize,
392            height as usize,
393            bx as usize,
394            by as usize,
395            bx as usize,
396            by as usize,
397            bs,
398        );
399        let mut best_cost = zero_cost;
400        let mut best_mv = MotionVector::default();
401
402        for dy in -sr..=sr {
403            for dx in -sr..=sr {
404                // Zero-motion already seeded above; skip redundant evaluation.
405                if dx == 0 && dy == 0 {
406                    continue;
407                }
408
409                let ref_x = bx as i32 + dx;
410                let ref_y = by as i32 + dy;
411
412                // Skip if the reference block is out of bounds.
413                if ref_x < 0
414                    || ref_y < 0
415                    || ref_x + bs as i32 > width as i32
416                    || ref_y + bs as i32 > height as i32
417                {
418                    continue;
419                }
420
421                let cost = self.compute_sad(
422                    reference,
423                    current,
424                    w,
425                    width as usize,
426                    height as usize,
427                    ref_x as usize,
428                    ref_y as usize,
429                    bx as usize,
430                    by as usize,
431                    bs,
432                );
433
434                // Strictly better only: ties stay with zero-motion (or the
435                // previously accepted closer candidate).
436                if cost < best_cost {
437                    best_cost = cost;
438                    best_mv = MotionVector {
439                        dx: dx as i16,
440                        dy: dy as i16,
441                    };
442                }
443            }
444        }
445
446        // Optional sub-pixel refinement (simplified ±1 half-pixel).
447        let subpixel_mv = if self.config.subpixel_refinement {
448            Some(SubpixelMv {
449                dx: i32::from(best_mv.dx) * 4,
450                dy: i32::from(best_mv.dy) * 4,
451            })
452        } else {
453            None
454        };
455
456        BlockMvResult {
457            block_x: bx,
458            block_y: by,
459            mv: best_mv,
460            subpixel_mv,
461            cost: best_cost,
462        }
463    }
464
465    /// Compute the Sum of Absolute Differences between a block in `current`
466    /// and a candidate block in `reference`.
467    #[allow(clippy::too_many_arguments)]
468    fn compute_sad(
469        &self,
470        reference: &[u8],
471        current: &[u8],
472        _stride: usize,
473        width: usize,
474        _height: usize,
475        ref_x: usize,
476        ref_y: usize,
477        cur_x: usize,
478        cur_y: usize,
479        block_size: usize,
480    ) -> u32 {
481        let mut sad = 0u32;
482        for row in 0..block_size {
483            for col in 0..block_size {
484                let cur_idx = (cur_y + row) * width + (cur_x + col);
485                let ref_idx = (ref_y + row) * width + (ref_x + col);
486                if cur_idx < current.len() && ref_idx < reference.len() {
487                    sad += u32::from(current[cur_idx].abs_diff(reference[ref_idx]));
488                }
489            }
490        }
491        sad
492    }
493}
494
495// ─────────────────────────────────────────────────────────────────────────────
496// Tests
497// ─────────────────────────────────────────────────────────────────────────────
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform luma frame of `w × h` pixels, all set to `value`.
    fn gray_frame(w: u32, h: u32, value: u8) -> Vec<u8> {
        vec![value; (w * h) as usize]
    }

    /// Build a deterministic noise frame and return it translated by
    /// `(dx, dy)`: the noise value at `(x, y)` lands at `(x + dx, y + dy)`.
    ///
    /// Uses a deterministic LCG so the pattern is aperiodic — unlike a
    /// checkerboard this ensures that the correct shift yields a uniquely
    /// lower SAD than zero-motion.
    fn shifted_frame(w: u32, h: u32, dx: i32, dy: i32) -> Vec<u8> {
        // Deterministic pseudo-random base frame (LCG, no external deps).
        let mut state: u64 = 0x5851_F42D_4C95_7F2D;
        let mut frame = vec![0u8; (w * h) as usize];
        for pixel in frame.iter_mut() {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            *pixel = ((state >> 33) & 0xFF) as u8;
        }
        // Produce the shifted version; pixels that fall outside get a neutral
        // mid-grey (128) so boundary blocks don't perfectly match at zero.
        let mut shifted = vec![128u8; (w * h) as usize];
        for y in 0..h as i32 {
            for x in 0..w as i32 {
                let sx = x + dx;
                let sy = y + dy;
                if sx >= 0 && sy >= 0 && sx < w as i32 && sy < h as i32 {
                    shifted[(sy as usize) * w as usize + sx as usize] =
                        frame[y as usize * w as usize + x as usize];
                }
            }
        }
        shifted
    }

    #[test]
    fn test_estimator_default_config() {
        let e = MotionEstimator::av1_default();
        assert_eq!(e.config.codec, TargetCodec::Av1);
    }

    #[test]
    fn test_vp9_default_config() {
        let e = MotionEstimator::vp9_default();
        assert_eq!(e.config.codec, TargetCodec::Vp9);
    }

    #[test]
    fn test_zero_mv_for_identical_frames() {
        let w = 64u32;
        let h = 64u32;
        let frame = gray_frame(w, h, 128);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 4,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        for bm in &result.block_mvs {
            assert_eq!(bm.mv.dx, 0, "dx should be 0 for identical frames");
            assert_eq!(bm.mv.dy, 0, "dy should be 0 for identical frames");
        }
    }

    #[test]
    fn test_mv_detected_for_shifted_frame() {
        let w = 64u32;
        let h = 64u32;
        let reference = shifted_frame(w, h, 0, 0);
        // `current` is `reference` moved 4 px to the right, so each block's
        // match lies 4 px to its LEFT in the reference: mv = (-4, 0).
        let current = shifted_frame(w, h, 4, 0);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 8,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&reference, &current, w, h)
            .expect("CPU estimate failed");
        // Blocks in the left-most column cannot reach dx = -4 (the candidate
        // would start at x = -4); every other block must find the exact,
        // zero-cost match.
        let reachable: Vec<&BlockMvResult> = result
            .block_mvs
            .iter()
            .filter(|b| b.block_x > 0)
            .collect();
        assert_eq!(reachable.len(), 12);
        for b in reachable {
            assert_eq!(
                b.mv,
                MotionVector { dx: -4, dy: 0 },
                "block ({}, {}) did not detect the horizontal shift",
                b.block_x,
                b.block_y
            );
            assert_eq!(b.cost, 0, "exact match should have zero SAD");
        }
    }

    #[test]
    fn test_partial_edge_blocks() {
        // 40×24 is not a multiple of 16: the last block column is 8 px wide
        // and the last block row is 8 px tall.
        let (w, h) = (40u32, 24u32);
        let frame = gray_frame(w, h, 90);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        assert_eq!(result.blocks_x(), 3);
        assert_eq!(result.blocks_y(), 2);
        let origins: Vec<(u32, u32)> = result
            .block_mvs
            .iter()
            .map(|b| (b.block_x, b.block_y))
            .collect();
        assert_eq!(
            origins,
            vec![(0, 0), (16, 0), (32, 0), (0, 16), (16, 16), (32, 16)]
        );
        assert!(result
            .block_mvs
            .iter()
            .all(|b| b.mv == MotionVector::default() && b.cost == 0));
    }

    #[test]
    fn test_invalid_dimensions_rejected() {
        let e = MotionEstimator::av1_default();
        let frame = vec![0u8; 64];
        let result = e.estimate_cpu(&frame, &frame, 0, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_buffer_too_small_rejected() {
        let e = MotionEstimator::av1_default();
        let small = vec![0u8; 4];
        let frame = vec![0u8; 64 * 64];
        let result = e.estimate_cpu(&small, &frame, 64, 64);
        assert!(result.is_err(), "undersized reference should be rejected");
    }

    #[test]
    fn test_mean_mv_magnitude_zero_for_static() {
        let w = 32u32;
        let h = 32u32;
        let frame = gray_frame(w, h, 100);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        assert_eq!(result.mean_mv_magnitude(), 0.0);
    }

    #[test]
    fn test_blocks_dimensions() {
        let w = 64u32;
        let h = 32u32;
        let frame = gray_frame(w, h, 0);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        assert_eq!(result.blocks_x(), 4);
        assert_eq!(result.blocks_y(), 2);
        assert_eq!(result.block_mvs.len(), 8);
    }

    #[test]
    fn test_subpixel_refinement_present() {
        let w = 16u32;
        let h = 16u32;
        let frame = gray_frame(w, h, 128);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: true,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        for bm in &result.block_mvs {
            assert!(
                bm.subpixel_mv.is_some(),
                "subpixel_mv should be present when refinement is enabled"
            );
        }
    }

    #[test]
    fn test_subpixel_refinement_absent_when_disabled() {
        let w = 16u32;
        let h = 16u32;
        let frame = gray_frame(w, h, 64);
        let e = MotionEstimator::new(MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        });
        let result = e
            .estimate_cpu(&frame, &frame, w, h)
            .expect("CPU estimate failed");
        for bm in &result.block_mvs {
            assert!(
                bm.subpixel_mv.is_none(),
                "subpixel_mv should be absent when refinement is disabled"
            );
        }
    }
}