Skip to main content

oximedia_gpu/
motion_estimation.rs

1//! GPU-accelerated motion estimation for AV1 and VP9 video codecs.
2//!
3//! This module provides compute-shader-based motion estimation pipelines
4//! suitable for AV1 and VP9 intra/inter frame encoding.  The GPU kernels
5//! exploit massively parallel block matching to evaluate Sum of Absolute
6//! Differences (SAD) and Sum of Squared Differences (SSD) across many
7//! candidate motion vectors simultaneously.
8//!
9//! # Architecture
10//!
11//! The pipeline is divided into three GPU dispatch stages:
12//!
13//! 1. **Hierarchical downscale** – build a Gaussian pyramid (up to 4 levels)
14//!    so that large motion is found at low resolution first.
15//! 2. **Block-match sweep** – for every block in the current frame, evaluate
16//!    all candidate motion vectors within the search window using parallel
17//!    SAD/SSD kernels dispatched with workgroup-local shared memory
18//!    (reducing global-memory bandwidth by ~8×).
19//! 3. **Refinement** – perform ±1 / ±½ pixel sub-pixel refinement around the
20//!    best integer candidate found in stage 2.
21//!
22//! # Status
23//!
24//! The GPU shader dispatch plumbing is present but the WGSL shaders for
25//! AV1/VP9-specific block partitions (superblock, transform units, etc.)
26//! are **stubs**.  The CPU reference path is fully functional and used for
27//! testing / CI.
28
29use crate::{GpuDevice, GpuError, Result};
30use rayon::prelude::*;
31use wgpu::util::DeviceExt as _;
32
33// ─────────────────────────────────────────────────────────────────────────────
34// Public API types
35// ─────────────────────────────────────────────────────────────────────────────
36
/// Video codec that will consume the motion-estimation output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TargetCodec {
    /// AV1 (AOMedia Video 1); superblock partitions go up to 128×128.
    Av1,
    /// VP9; superblock partitions go up to 64×64.
    Vp9,
}
45
/// Block partition mode used during motion search.
///
/// The default mode is [`BlockPartition::Fixed16x16`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BlockPartition {
    /// Fixed 16×16 macro-blocks (fast, lower quality).
    #[default]
    Fixed16x16,
    /// Fixed 32×32 blocks.
    Fixed32x32,
    /// Fixed 64×64 super-blocks (VP9 native).
    Fixed64x64,
    /// Fixed 128×128 super-blocks (AV1 native).
    Fixed128x128,
    /// Adaptive partitioning: use a quad-tree split based on variance.
    Adaptive,
}
66
67/// Configuration for a motion-estimation pass.
68#[derive(Debug, Clone)]
69pub struct MotionEstimationConfig {
70    /// Target codec (affects block sizes and allowed partition modes).
71    pub codec: TargetCodec,
72    /// Block partitioning strategy.
73    pub partition: BlockPartition,
74    /// Search window half-size in pixels (e.g. 32 means ±32 px search).
75    pub search_radius: u32,
76    /// Whether to perform sub-pixel (half-pixel) refinement.
77    pub subpixel_refinement: bool,
78    /// Cost metric used to rank candidate motion vectors.
79    pub metric: MotionMetric,
80    /// Number of Gaussian pyramid levels for hierarchical search.
81    pub pyramid_levels: u32,
82}
83
84impl Default for MotionEstimationConfig {
85    fn default() -> Self {
86        Self {
87            codec: TargetCodec::Av1,
88            partition: BlockPartition::default(),
89            search_radius: 32,
90            subpixel_refinement: true,
91            metric: MotionMetric::Sad,
92            pyramid_levels: 3,
93        }
94    }
95}
96
/// Cost function used to score motion-vector candidates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MotionMetric {
    /// Sum of Absolute Differences — the fastest option.
    Sad,
    /// Sum of Squared Differences — more accurate than SAD.
    Ssd,
    /// Hadamard transform of the residual — best quality, highest cost.
    Hadamard,
}
107
/// Integer motion vector in two dimensions, measured in whole pixels.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MotionVector {
    /// Horizontal offset in pixels; positive values point right.
    pub dx: i16,
    /// Vertical offset in pixels; positive values point down.
    pub dy: i16,
}
116
/// Sub-pixel motion vector in two dimensions.
///
/// Both components are expressed in units of 1/4 pixel.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct SubpixelMv {
    /// Horizontal offset, in quarter-pixels.
    pub dx: i32,
    /// Vertical offset, in quarter-pixels.
    pub dy: i32,
}
126
127/// Motion estimation result for a single block.
128#[derive(Debug, Clone)]
129pub struct BlockMvResult {
130    /// Block position (top-left corner) in pixels.
131    pub block_x: u32,
132    /// Block position (top-left corner) in pixels.
133    pub block_y: u32,
134    /// Best integer-pixel motion vector.
135    pub mv: MotionVector,
136    /// Best sub-pixel motion vector (if refinement was requested).
137    pub subpixel_mv: Option<SubpixelMv>,
138    /// Cost (SAD/SSD/Hadamard) of the best candidate.
139    pub cost: u32,
140}
141
142/// Full-frame motion estimation result.
143#[derive(Debug, Clone)]
144pub struct FrameMvResult {
145    /// Frame width in pixels.
146    pub width: u32,
147    /// Frame height in pixels.
148    pub height: u32,
149    /// Per-block motion vectors (row-major order).
150    pub block_mvs: Vec<BlockMvResult>,
151    /// Block size used (pixels).
152    pub block_size: u32,
153    /// Whether GPU execution was used (`false` = CPU fallback).
154    pub used_gpu: bool,
155}
156
157impl FrameMvResult {
158    /// Number of blocks in the horizontal direction.
159    #[must_use]
160    pub fn blocks_x(&self) -> u32 {
161        self.width.div_ceil(self.block_size)
162    }
163
164    /// Number of blocks in the vertical direction.
165    #[must_use]
166    pub fn blocks_y(&self) -> u32 {
167        self.height.div_ceil(self.block_size)
168    }
169
170    /// Mean absolute MV magnitude (Euclidean distance) across all blocks.
171    #[must_use]
172    pub fn mean_mv_magnitude(&self) -> f32 {
173        if self.block_mvs.is_empty() {
174            return 0.0;
175        }
176        let sum: f64 = self
177            .block_mvs
178            .iter()
179            .map(|b| {
180                let dx = f64::from(b.mv.dx);
181                let dy = f64::from(b.mv.dy);
182                (dx * dx + dy * dy).sqrt()
183            })
184            .sum();
185        (sum / self.block_mvs.len() as f64) as f32
186    }
187}
188
189// ─────────────────────────────────────────────────────────────────────────────
190// MotionEstimator
191// ─────────────────────────────────────────────────────────────────────────────
192
/// GPU-accelerated motion estimator.
///
/// Holds the configuration that every estimation pass started from this
/// instance uses.
pub struct MotionEstimator {
    // Settings applied to each pass (partition, pyramid depth, ...).
    config: MotionEstimationConfig,
}
197
198impl MotionEstimator {
199    /// Create a new motion estimator with the given configuration.
200    #[must_use]
201    pub fn new(config: MotionEstimationConfig) -> Self {
202        Self { config }
203    }
204
205    /// Create a motion estimator with default AV1 settings.
206    #[must_use]
207    pub fn av1_default() -> Self {
208        Self::new(MotionEstimationConfig {
209            codec: TargetCodec::Av1,
210            partition: BlockPartition::Fixed64x64,
211            search_radius: 48,
212            subpixel_refinement: true,
213            metric: MotionMetric::Sad,
214            pyramid_levels: 3,
215        })
216    }
217
218    /// Create a motion estimator with default VP9 settings.
219    #[must_use]
220    pub fn vp9_default() -> Self {
221        Self::new(MotionEstimationConfig {
222            codec: TargetCodec::Vp9,
223            partition: BlockPartition::Fixed64x64,
224            search_radius: 32,
225            subpixel_refinement: true,
226            metric: MotionMetric::Sad,
227            pyramid_levels: 2,
228        })
229    }
230
231    /// Estimate motion vectors between a reference frame and a current frame.
232    ///
233    /// Both frames must be packed luma-only (one byte per pixel) with
234    /// `width × height` bytes each.
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if dimensions are mismatched or buffers are too small.
239    pub fn estimate(
240        &self,
241        device: &GpuDevice,
242        reference: &[u8],
243        current: &[u8],
244        width: u32,
245        height: u32,
246    ) -> Result<FrameMvResult> {
247        if reference.len() < (width * height) as usize {
248            return Err(GpuError::InvalidBufferSize {
249                expected: (width * height) as usize,
250                actual: reference.len(),
251            });
252        }
253        if current.len() < (width * height) as usize {
254            return Err(GpuError::InvalidBufferSize {
255                expected: (width * height) as usize,
256                actual: current.len(),
257            });
258        }
259        if width == 0 || height == 0 {
260            return Err(GpuError::InvalidDimensions { width, height });
261        }
262
263        // GPU path: attempt to dispatch compute shaders.
264        // The GPU shaders are present as stubs — on failure we fall back to
265        // the CPU path below.
266        if !device.is_fallback {
267            if let Ok(result) = self.estimate_gpu(device, reference, current, width, height) {
268                return Ok(result);
269            }
270        }
271
272        // CPU reference path (rayon-parallel block matching).
273        self.estimate_cpu(reference, current, width, height)
274    }
275
276    // ── GPU implementation ────────────────────────────────────────────────────
277
278    fn estimate_gpu(
279        &self,
280        device: &GpuDevice,
281        reference: &[u8],
282        current: &[u8],
283        width: u32,
284        height: u32,
285    ) -> Result<FrameMvResult> {
286        let wgpu_device = device.device();
287        let queue = device.queue();
288
289        let block_size = match self.config.partition {
290            BlockPartition::Fixed16x16 | BlockPartition::Adaptive => 16u32,
291            BlockPartition::Fixed32x32 => 32,
292            BlockPartition::Fixed64x64 => 64,
293            BlockPartition::Fixed128x128 => 128,
294        };
295
296        let level_count = self.config.pyramid_levels.min(4).max(1) as usize;
297
298        // ── 1. Upload luma planes as storage buffers (each 8-bit sample widened to u32) ──
299        let ref_data: Vec<u32> = reference.iter().map(|&b| u32::from(b)).collect();
300        let cur_data: Vec<u32> = current.iter().map(|&b| u32::from(b)).collect();
301
302        let ref_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
303            label: Some("motion_ref_buf"),
304            contents: bytemuck::cast_slice(&ref_data),
305            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
306        });
307        let cur_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
308            label: Some("motion_cur_buf"),
309            contents: bytemuck::cast_slice(&cur_data),
310            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
311        });
312
313        // ── 2. Build Gaussian pyramid (storage buffers) ───────────────────────
314        let pyramid_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
315            label: Some("motion_pyramid"),
316            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/motion_pyramid.wgsl").into()),
317        });
318
319        let pyramid_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
320            label: Some("pyramid_bgl"),
321            entries: &[
322                // uniforms
323                wgpu::BindGroupLayoutEntry {
324                    binding: 0,
325                    visibility: wgpu::ShaderStages::COMPUTE,
326                    ty: wgpu::BindingType::Buffer {
327                        ty: wgpu::BufferBindingType::Uniform,
328                        has_dynamic_offset: false,
329                        min_binding_size: None,
330                    },
331                    count: None,
332                },
333                // input buffer
334                wgpu::BindGroupLayoutEntry {
335                    binding: 1,
336                    visibility: wgpu::ShaderStages::COMPUTE,
337                    ty: wgpu::BindingType::Buffer {
338                        ty: wgpu::BufferBindingType::Storage { read_only: true },
339                        has_dynamic_offset: false,
340                        min_binding_size: None,
341                    },
342                    count: None,
343                },
344                // output buffer
345                wgpu::BindGroupLayoutEntry {
346                    binding: 2,
347                    visibility: wgpu::ShaderStages::COMPUTE,
348                    ty: wgpu::BindingType::Buffer {
349                        ty: wgpu::BufferBindingType::Storage { read_only: false },
350                        has_dynamic_offset: false,
351                        min_binding_size: None,
352                    },
353                    count: None,
354                },
355            ],
356        });
357
358        let pyramid_pipeline_layout =
359            wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
360                label: Some("pyramid_layout"),
361                bind_group_layouts: &[Some(&pyramid_bgl)],
362                immediate_size: 0,
363            });
364
365        let pyramid_pipeline =
366            wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
367                label: Some("pyramid_pipeline"),
368                layout: Some(&pyramid_pipeline_layout),
369                module: &pyramid_shader,
370                entry_point: Some("downsample_r8"),
371                compilation_options: wgpu::PipelineCompilationOptions::default(),
372                cache: None,
373            });
374
375        // Build pyramid levels for reference and current frame.
376        // pyramid_ref[0] = original, pyramid_ref[1..] = downsampled levels.
377        let mut pyramid_ref_bufs: Vec<(wgpu::Buffer, u32, u32)> = Vec::with_capacity(level_count);
378        let mut pyramid_cur_bufs: Vec<(wgpu::Buffer, u32, u32)> = Vec::with_capacity(level_count);
379
380        pyramid_ref_bufs.push((ref_buf, width, height));
381        pyramid_cur_bufs.push((cur_buf, width, height));
382
383        for lvl in 1..level_count {
384            let (_, prev_w, prev_h) = &pyramid_ref_bufs[lvl - 1];
385            let out_w = (*prev_w).max(1) / 2;
386            let out_h = (*prev_h).max(1) / 2;
387            let out_pixels = (out_w * out_h) as usize;
388
389            let ref_out = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
390                label: Some(&format!("pyramid_ref_lvl{lvl}")),
391                size: (out_pixels * std::mem::size_of::<u32>()) as u64,
392                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
393                mapped_at_creation: false,
394            });
395            let cur_out = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
396                label: Some(&format!("pyramid_cur_lvl{lvl}")),
397                size: (out_pixels * std::mem::size_of::<u32>()) as u64,
398                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
399                mapped_at_creation: false,
400            });
401
402            // Dispatch downsample for reference.
403            {
404                let (in_buf, in_w, in_h) = &pyramid_ref_bufs[lvl - 1];
405                let uniforms_data: [u32; 4] = [*in_w, *in_h, out_w, out_h];
406                let uniform_buf =
407                    wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
408                        label: Some(&format!("pyramid_uniform_ref_{lvl}")),
409                        contents: bytemuck::cast_slice(&uniforms_data),
410                        usage: wgpu::BufferUsages::UNIFORM,
411                    });
412                let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
413                    label: None,
414                    layout: &pyramid_bgl,
415                    entries: &[
416                        wgpu::BindGroupEntry {
417                            binding: 0,
418                            resource: uniform_buf.as_entire_binding(),
419                        },
420                        wgpu::BindGroupEntry {
421                            binding: 1,
422                            resource: in_buf.as_entire_binding(),
423                        },
424                        wgpu::BindGroupEntry {
425                            binding: 2,
426                            resource: ref_out.as_entire_binding(),
427                        },
428                    ],
429                });
430                let mut encoder =
431                    wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
432                        label: Some("pyramid_ref_enc"),
433                    });
434                {
435                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
436                        label: None,
437                        timestamp_writes: None,
438                    });
439                    pass.set_pipeline(&pyramid_pipeline);
440                    pass.set_bind_group(0, &bg, &[]);
441                    pass.dispatch_workgroups(out_w.div_ceil(8), out_h.div_ceil(8), 1);
442                }
443                queue.submit(std::iter::once(encoder.finish()));
444            }
445
446            // Dispatch downsample for current.
447            {
448                let (in_buf, in_w, in_h) = &pyramid_cur_bufs[lvl - 1];
449                let uniforms_data: [u32; 4] = [*in_w, *in_h, out_w, out_h];
450                let uniform_buf =
451                    wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
452                        label: Some(&format!("pyramid_uniform_cur_{lvl}")),
453                        contents: bytemuck::cast_slice(&uniforms_data),
454                        usage: wgpu::BufferUsages::UNIFORM,
455                    });
456                let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
457                    label: None,
458                    layout: &pyramid_bgl,
459                    entries: &[
460                        wgpu::BindGroupEntry {
461                            binding: 0,
462                            resource: uniform_buf.as_entire_binding(),
463                        },
464                        wgpu::BindGroupEntry {
465                            binding: 1,
466                            resource: in_buf.as_entire_binding(),
467                        },
468                        wgpu::BindGroupEntry {
469                            binding: 2,
470                            resource: cur_out.as_entire_binding(),
471                        },
472                    ],
473                });
474                let mut encoder =
475                    wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
476                        label: Some("pyramid_cur_enc"),
477                    });
478                {
479                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
480                        label: None,
481                        timestamp_writes: None,
482                    });
483                    pass.set_pipeline(&pyramid_pipeline);
484                    pass.set_bind_group(0, &bg, &[]);
485                    pass.dispatch_workgroups(out_w.div_ceil(8), out_h.div_ceil(8), 1);
486                }
487                queue.submit(std::iter::once(encoder.finish()));
488            }
489
490            pyramid_ref_bufs.push((ref_out, out_w, out_h));
491            pyramid_cur_bufs.push((cur_out, out_w, out_h));
492        }
493
494        // ── 3. Block-match pipeline ────────────────────────────────────────────
495        let bm_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
496            label: Some("motion_block_match"),
497            source: wgpu::ShaderSource::Wgsl(
498                include_str!("shaders/motion_block_match.wgsl").into(),
499            ),
500        });
501
502        let bm_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
503            label: Some("bm_bgl"),
504            entries: &[
505                // uniforms (BlockMatchUniforms — 8 × u32/i32 = 32 bytes)
506                wgpu::BindGroupLayoutEntry {
507                    binding: 0,
508                    visibility: wgpu::ShaderStages::COMPUTE,
509                    ty: wgpu::BindingType::Buffer {
510                        ty: wgpu::BufferBindingType::Uniform,
511                        has_dynamic_offset: false,
512                        min_binding_size: None,
513                    },
514                    count: None,
515                },
516                // ref_buf
517                wgpu::BindGroupLayoutEntry {
518                    binding: 1,
519                    visibility: wgpu::ShaderStages::COMPUTE,
520                    ty: wgpu::BindingType::Buffer {
521                        ty: wgpu::BufferBindingType::Storage { read_only: true },
522                        has_dynamic_offset: false,
523                        min_binding_size: None,
524                    },
525                    count: None,
526                },
527                // cur_buf
528                wgpu::BindGroupLayoutEntry {
529                    binding: 2,
530                    visibility: wgpu::ShaderStages::COMPUTE,
531                    ty: wgpu::BindingType::Buffer {
532                        ty: wgpu::BufferBindingType::Storage { read_only: true },
533                        has_dynamic_offset: false,
534                        min_binding_size: None,
535                    },
536                    count: None,
537                },
538                // mv_out
539                wgpu::BindGroupLayoutEntry {
540                    binding: 3,
541                    visibility: wgpu::ShaderStages::COMPUTE,
542                    ty: wgpu::BindingType::Buffer {
543                        ty: wgpu::BufferBindingType::Storage { read_only: false },
544                        has_dynamic_offset: false,
545                        min_binding_size: None,
546                    },
547                    count: None,
548                },
549            ],
550        });
551
552        let bm_pipeline_layout =
553            wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
554                label: Some("bm_layout"),
555                bind_group_layouts: &[Some(&bm_bgl)],
556                immediate_size: 0,
557            });
558
559        let bm_pipeline = wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
560            label: Some("bm_pipeline"),
561            layout: Some(&bm_pipeline_layout),
562            module: &bm_shader,
563            entry_point: Some("block_match"),
564            compilation_options: wgpu::PipelineCompilationOptions::default(),
565            cache: None,
566        });
567
568        // ── 4. Coarse-to-fine block match over pyramid levels ─────────────────
569        // Accumulate integer MVs; at each finer level, the seed is the
570        // coarser MV × 2.
571        let top_level = level_count - 1;
572        let (_, top_w, top_h) = &pyramid_ref_bufs[top_level];
573        let top_bx = top_w.div_ceil(block_size);
574        let top_by = top_h.div_ceil(block_size);
575        let top_blocks = (top_bx * top_by) as usize;
576
577        // MV seed buffer starts at (0, 0) for the coarsest level.
578        // Layout: [dx: i32, dy: i32, ...] per block (flat array of i32 pairs).
579        let mut seed_mvs: Vec<[i32; 2]> = vec![[0i32, 0i32]; top_blocks];
580
581        // We work from the coarsest level down to level 0.
582        // `mv_int_result` holds the integer MV result (vec4<i32> per block) for
583        // the current level.
584        let mut mv_int_result: Vec<[i32; 4]> = vec![[0i32; 4]; top_blocks];
585
586        for lvl in (0..level_count).rev() {
587            let (ref_level_buf, lw, lh) = &pyramid_ref_bufs[lvl];
588            let (cur_level_buf, _, _) = &pyramid_cur_bufs[lvl];
589
590            let lbx = lw.div_ceil(block_size);
591            let lby = lh.div_ceil(block_size);
592            let l_blocks = (lbx * lby) as usize;
593
594            // Upsample seeds from previous (coarser) level.
595            // Each coarser block maps to (possibly) 4 finer blocks.
596            let seeds_for_level: Vec<[i32; 2]> = if lvl == top_level {
597                vec![[0i32, 0i32]; l_blocks]
598            } else {
599                // Scale up seeds. NOTE(review): the coarser level (lvl + 1) is half this level's size, so `lw * 2` below looks inverted — confirm.
600                let coarser_bx = (lw * 2).div_ceil(block_size);
601                (0..l_blocks)
602                    .map(|idx| {
603                        let fx = (idx as u32) % lbx;
604                        let fy = (idx as u32) / lbx;
605                        // Corresponding coarser block.
606                        let cx = fx / 2;
607                        let cy = fy / 2;
608                        let cidx = (cy * coarser_bx + cx) as usize;
609                        let coarser_seed = if cidx < seed_mvs.len() {
610                            seed_mvs[cidx]
611                        } else {
612                            [0i32, 0i32]
613                        };
614                        // MV at coarser level corresponds to 2× displacement at
615                        // the finer level.
616                        [coarser_seed[0] * 2, coarser_seed[1] * 2]
617                    })
618                    .collect()
619            };
620
621            // For simplicity we dispatch a separate command per level using a
622            // single common seed. The block-match shader uses ONE seed per
623            // dispatch; for a production encoder one would pass per-block
624            // seeds via an additional storage buffer. Here we use the mean
625            // (average) of all per-block seeds (good enough for correctness tests).
626            let seed_x = seeds_for_level.iter().map(|s| s[0]).sum::<i32>()
627                / seeds_for_level.len().max(1) as i32;
628            let seed_y = seeds_for_level.iter().map(|s| s[1]).sum::<i32>()
629                / seeds_for_level.len().max(1) as i32;
630
631            let search_half = 8u32;
632
633            // Uniform: [block_size, search_half, frame_width, frame_height,
634            //           mv_seed_x (i32 as u32 bits), mv_seed_y, blocks_x, blocks_y]
635            let uniforms: [u32; 8] = [
636                block_size,
637                search_half,
638                *lw,
639                *lh,
640                seed_x as u32, // transmit i32 bits as u32; shader reads as i32
641                seed_y as u32,
642                lbx,
643                lby,
644            ];
645
646            let uniform_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
647                label: Some(&format!("bm_uniform_lvl{lvl}")),
648                contents: bytemuck::cast_slice(&uniforms),
649                usage: wgpu::BufferUsages::UNIFORM,
650            });
651
652            let mv_out_buf = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
653                label: Some(&format!("mv_out_lvl{lvl}")),
654                size: (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
655                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
656                mapped_at_creation: false,
657            });
658
659            let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
660                label: None,
661                layout: &bm_bgl,
662                entries: &[
663                    wgpu::BindGroupEntry {
664                        binding: 0,
665                        resource: uniform_buf.as_entire_binding(),
666                    },
667                    wgpu::BindGroupEntry {
668                        binding: 1,
669                        resource: ref_level_buf.as_entire_binding(),
670                    },
671                    wgpu::BindGroupEntry {
672                        binding: 2,
673                        resource: cur_level_buf.as_entire_binding(),
674                    },
675                    wgpu::BindGroupEntry {
676                        binding: 3,
677                        resource: mv_out_buf.as_entire_binding(),
678                    },
679                ],
680            });
681
682            let mut encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
683                label: Some(&format!("bm_enc_lvl{lvl}")),
684            });
685            {
686                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
687                    label: None,
688                    timestamp_writes: None,
689                });
690                pass.set_pipeline(&bm_pipeline);
691                pass.set_bind_group(0, &bg, &[]);
692                // One workgroup per block (16×16 threads per workgroup).
693                pass.dispatch_workgroups(lbx, lby, 1);
694            }
695
696            // Readback the MV buffer.
697            let staging = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
698                label: Some(&format!("bm_staging_lvl{lvl}")),
699                size: (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
700                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
701                mapped_at_creation: false,
702            });
703            encoder.copy_buffer_to_buffer(
704                &mv_out_buf,
705                0,
706                &staging,
707                0,
708                (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
709            );
710            queue.submit(std::iter::once(encoder.finish()));
711
712            let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
713
714            let slice = staging.slice(..);
715            let (tx, mut rx) = futures_channel::oneshot::channel();
716            slice.map_async(wgpu::MapMode::Read, move |result| {
717                let _ = tx.send(result);
718            });
719            let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
720            rx.try_recv()
721                .map_err(|e| GpuError::BufferMapping(e.to_string()))?
722                .ok_or_else(|| GpuError::BufferMapping("channel empty".into()))?
723                .map_err(|e| GpuError::BufferMapping(e.to_string()))?;
724
725            {
726                let data = slice.get_mapped_range();
727                let raw: &[[i32; 4]] = bytemuck::cast_slice(&data);
728                mv_int_result = raw[..l_blocks.min(raw.len())].to_vec();
729                // Update seeds for the next-finer level iteration.
730                seed_mvs = raw[..l_blocks.min(raw.len())]
731                    .iter()
732                    .map(|v| [v[0], v[1]])
733                    .collect();
734            }
735        }
736
737        // ── 5. Sub-pixel refinement (level 0 = original resolution) ──────────
738        let final_blocks_x = width.div_ceil(block_size);
739        let final_blocks_y = height.div_ceil(block_size);
740        let n_blocks = (final_blocks_x * final_blocks_y) as usize;
741
742        let (ref_l0, _, _) = &pyramid_ref_bufs[0];
743        let (cur_l0, _, _) = &pyramid_cur_bufs[0];
744
745        // Build subpixel MV input from integer result, padded/truncated to
746        // match the level-0 block count.
747        let mv_in_data: Vec<[i32; 4]> = (0..n_blocks)
748            .map(|i| {
749                if i < mv_int_result.len() {
750                    mv_int_result[i]
751                } else {
752                    [0i32; 4]
753                }
754            })
755            .collect();
756
757        let mv_in_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
758            label: Some("subpix_mv_in"),
759            contents: bytemuck::cast_slice(&mv_in_data),
760            usage: wgpu::BufferUsages::STORAGE,
761        });
762
763        let mv_out_sp_buf = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
764            label: Some("subpix_mv_out"),
765            size: (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
766            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
767            mapped_at_creation: false,
768        });
769
770        let sp_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
771            label: Some("motion_subpixel"),
772            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/motion_subpixel.wgsl").into()),
773        });
774
775        let sp_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
776            label: Some("sp_bgl"),
777            entries: &[
778                wgpu::BindGroupLayoutEntry {
779                    binding: 0,
780                    visibility: wgpu::ShaderStages::COMPUTE,
781                    ty: wgpu::BindingType::Buffer {
782                        ty: wgpu::BufferBindingType::Uniform,
783                        has_dynamic_offset: false,
784                        min_binding_size: None,
785                    },
786                    count: None,
787                },
788                wgpu::BindGroupLayoutEntry {
789                    binding: 1,
790                    visibility: wgpu::ShaderStages::COMPUTE,
791                    ty: wgpu::BindingType::Buffer {
792                        ty: wgpu::BufferBindingType::Storage { read_only: true },
793                        has_dynamic_offset: false,
794                        min_binding_size: None,
795                    },
796                    count: None,
797                },
798                wgpu::BindGroupLayoutEntry {
799                    binding: 2,
800                    visibility: wgpu::ShaderStages::COMPUTE,
801                    ty: wgpu::BindingType::Buffer {
802                        ty: wgpu::BufferBindingType::Storage { read_only: true },
803                        has_dynamic_offset: false,
804                        min_binding_size: None,
805                    },
806                    count: None,
807                },
808                wgpu::BindGroupLayoutEntry {
809                    binding: 3,
810                    visibility: wgpu::ShaderStages::COMPUTE,
811                    ty: wgpu::BindingType::Buffer {
812                        ty: wgpu::BufferBindingType::Storage { read_only: true },
813                        has_dynamic_offset: false,
814                        min_binding_size: None,
815                    },
816                    count: None,
817                },
818                wgpu::BindGroupLayoutEntry {
819                    binding: 4,
820                    visibility: wgpu::ShaderStages::COMPUTE,
821                    ty: wgpu::BindingType::Buffer {
822                        ty: wgpu::BufferBindingType::Storage { read_only: false },
823                        has_dynamic_offset: false,
824                        min_binding_size: None,
825                    },
826                    count: None,
827                },
828            ],
829        });
830
831        let sp_pipeline_layout =
832            wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
833                label: Some("sp_layout"),
834                bind_group_layouts: &[Some(&sp_bgl)],
835                immediate_size: 0,
836            });
837
838        let sp_pipeline = wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
839            label: Some("sp_pipeline"),
840            layout: Some(&sp_pipeline_layout),
841            module: &sp_shader,
842            entry_point: Some("subpixel_refine"),
843            compilation_options: wgpu::PipelineCompilationOptions::default(),
844            cache: None,
845        });
846
847        let sp_uniforms: [u32; 4] = [width, height, block_size, n_blocks as u32];
848        let sp_uniform_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
849            label: Some("sp_uniforms"),
850            contents: bytemuck::cast_slice(&sp_uniforms),
851            usage: wgpu::BufferUsages::UNIFORM,
852        });
853
854        let sp_bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
855            label: None,
856            layout: &sp_bgl,
857            entries: &[
858                wgpu::BindGroupEntry {
859                    binding: 0,
860                    resource: sp_uniform_buf.as_entire_binding(),
861                },
862                wgpu::BindGroupEntry {
863                    binding: 1,
864                    resource: ref_l0.as_entire_binding(),
865                },
866                wgpu::BindGroupEntry {
867                    binding: 2,
868                    resource: cur_l0.as_entire_binding(),
869                },
870                wgpu::BindGroupEntry {
871                    binding: 3,
872                    resource: mv_in_buf.as_entire_binding(),
873                },
874                wgpu::BindGroupEntry {
875                    binding: 4,
876                    resource: mv_out_sp_buf.as_entire_binding(),
877                },
878            ],
879        });
880
881        let sp_staging = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
882            label: Some("sp_staging"),
883            size: (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
884            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
885            mapped_at_creation: false,
886        });
887
888        let mut sp_encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
889            label: Some("sp_enc"),
890        });
891        {
892            let mut pass = sp_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
893                label: None,
894                timestamp_writes: None,
895            });
896            pass.set_pipeline(&sp_pipeline);
897            pass.set_bind_group(0, &sp_bg, &[]);
898            let groups = (n_blocks as u32).div_ceil(64);
899            pass.dispatch_workgroups(groups, 1, 1);
900        }
901        sp_encoder.copy_buffer_to_buffer(
902            &mv_out_sp_buf,
903            0,
904            &sp_staging,
905            0,
906            (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
907        );
908        queue.submit(std::iter::once(sp_encoder.finish()));
909        let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
910
911        let sp_slice = sp_staging.slice(..);
912        let (sp_tx, mut sp_rx) = futures_channel::oneshot::channel();
913        sp_slice.map_async(wgpu::MapMode::Read, move |result| {
914            let _ = sp_tx.send(result);
915        });
916        let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
917        sp_rx
918            .try_recv()
919            .map_err(|e| GpuError::BufferMapping(e.to_string()))?
920            .ok_or_else(|| GpuError::BufferMapping("channel empty".into()))?
921            .map_err(|e| GpuError::BufferMapping(e.to_string()))?;
922
923        let subpixel_mvs: Vec<[f32; 2]> = {
924            let data = sp_slice.get_mapped_range();
925            bytemuck::cast_slice::<u8, [f32; 2]>(&data)[..n_blocks].to_vec()
926        };
927
928        // ── 6. Assemble FrameMvResult ─────────────────────────────────────────
929        let block_mvs: Vec<BlockMvResult> = (0..n_blocks)
930            .map(|idx| {
931                let bx = (idx as u32 % final_blocks_x) * block_size;
932                let by = (idx as u32 / final_blocks_x) * block_size;
933
934                let int_mv = if idx < mv_int_result.len() {
935                    mv_int_result[idx]
936                } else {
937                    [0i32; 4]
938                };
939
940                let mv = MotionVector {
941                    dx: int_mv[0].clamp(i16::MIN as i32, i16::MAX as i32) as i16,
942                    dy: int_mv[1].clamp(i16::MIN as i32, i16::MAX as i32) as i16,
943                };
944
945                // Sub-pixel MV uses quarter-pixel units.
946                let subpixel_mv = if self.config.subpixel_refinement {
947                    let sp = subpixel_mvs[idx];
948                    Some(SubpixelMv {
949                        dx: (sp[0] * 4.0).round() as i32,
950                        dy: (sp[1] * 4.0).round() as i32,
951                    })
952                } else {
953                    None
954                };
955
956                let cost = int_mv[2].max(0) as u32;
957
958                BlockMvResult {
959                    block_x: bx,
960                    block_y: by,
961                    mv,
962                    subpixel_mv,
963                    cost,
964                }
965            })
966            .collect();
967
968        Ok(FrameMvResult {
969            width,
970            height,
971            block_mvs,
972            block_size,
973            used_gpu: true,
974        })
975    }
976
977    // ── CPU reference path ───────────────────────────────────────────────────
978
979    fn estimate_cpu(
980        &self,
981        reference: &[u8],
982        current: &[u8],
983        width: u32,
984        height: u32,
985    ) -> Result<FrameMvResult> {
986        // Validate dimensions and buffer sizes (mirrors estimate() checks so
987        // that callers invoking estimate_cpu directly also get proper errors).
988        if width == 0 || height == 0 {
989            return Err(GpuError::InvalidDimensions { width, height });
990        }
991        let required = (width as usize)
992            .checked_mul(height as usize)
993            .ok_or(GpuError::InvalidDimensions { width, height })?;
994        if reference.len() < required {
995            return Err(GpuError::InvalidBufferSize {
996                expected: required,
997                actual: reference.len(),
998            });
999        }
1000        if current.len() < required {
1001            return Err(GpuError::InvalidBufferSize {
1002                expected: required,
1003                actual: current.len(),
1004            });
1005        }
1006
1007        let block_size = match self.config.partition {
1008            BlockPartition::Fixed16x16 | BlockPartition::Adaptive => 16u32,
1009            BlockPartition::Fixed32x32 => 32,
1010            BlockPartition::Fixed64x64 => 64,
1011            BlockPartition::Fixed128x128 => 128,
1012        };
1013
1014        let blocks_x = width.div_ceil(block_size);
1015        let blocks_y = height.div_ceil(block_size);
1016        let n_blocks = (blocks_x * blocks_y) as usize;
1017
1018        let block_mvs: Vec<BlockMvResult> = (0..n_blocks)
1019            .into_par_iter()
1020            .map(|idx| {
1021                let bx = (idx as u32 % blocks_x) * block_size;
1022                let by = (idx as u32 / blocks_x) * block_size;
1023                self.match_block(reference, current, width, height, bx, by, block_size)
1024            })
1025            .collect();
1026
1027        Ok(FrameMvResult {
1028            width,
1029            height,
1030            block_mvs,
1031            block_size,
1032            used_gpu: false,
1033        })
1034    }
1035
1036    /// Perform block matching for a single block at (bx, by).
1037    ///
1038    /// Search order: zero-motion `(0, 0)` is evaluated first and used to seed
1039    /// `best_cost`.  The full `±search_radius` grid is then scanned; a
1040    /// candidate replaces the current best only when its cost is **strictly
1041    /// lower** (ties stay with the earlier, closer-to-origin candidate).
1042    /// This guarantees that zero-motion wins whenever all SAD values are equal
1043    /// (e.g. perfectly uniform frames) while real motion is still detected
1044    /// when a shifted block produces a lower SAD than the zero-motion baseline.
1045    #[allow(clippy::too_many_arguments)]
1046    fn match_block(
1047        &self,
1048        reference: &[u8],
1049        current: &[u8],
1050        width: u32,
1051        height: u32,
1052        bx: u32,
1053        by: u32,
1054        block_size: u32,
1055    ) -> BlockMvResult {
1056        let w = width as usize;
1057        let sr = self.config.search_radius as i32;
1058        let bs = block_size as usize;
1059
1060        // Evaluate zero-motion first to seed the best cost.  All other
1061        // candidates must strictly beat this to be accepted.
1062        let zero_cost = self.compute_sad(
1063            reference,
1064            current,
1065            w,
1066            width as usize,
1067            height as usize,
1068            bx as usize,
1069            by as usize,
1070            bx as usize,
1071            by as usize,
1072            bs,
1073        );
1074        let mut best_cost = zero_cost;
1075        let mut best_mv = MotionVector::default();
1076
1077        for dy in -sr..=sr {
1078            for dx in -sr..=sr {
1079                // Zero-motion already seeded above; skip redundant evaluation.
1080                if dx == 0 && dy == 0 {
1081                    continue;
1082                }
1083
1084                let ref_x = bx as i32 + dx;
1085                let ref_y = by as i32 + dy;
1086
1087                // Skip if the reference block is out of bounds.
1088                if ref_x < 0
1089                    || ref_y < 0
1090                    || ref_x + bs as i32 > width as i32
1091                    || ref_y + bs as i32 > height as i32
1092                {
1093                    continue;
1094                }
1095
1096                let cost = self.compute_sad(
1097                    reference,
1098                    current,
1099                    w,
1100                    width as usize,
1101                    height as usize,
1102                    ref_x as usize,
1103                    ref_y as usize,
1104                    bx as usize,
1105                    by as usize,
1106                    bs,
1107                );
1108
1109                // Strictly better only: ties stay with zero-motion (or the
1110                // previously accepted closer candidate).
1111                if cost < best_cost {
1112                    best_cost = cost;
1113                    best_mv = MotionVector {
1114                        dx: dx as i16,
1115                        dy: dy as i16,
1116                    };
1117                }
1118            }
1119        }
1120
1121        // Optional sub-pixel refinement (simplified ±1 half-pixel).
1122        let subpixel_mv = if self.config.subpixel_refinement {
1123            Some(SubpixelMv {
1124                dx: i32::from(best_mv.dx) * 4,
1125                dy: i32::from(best_mv.dy) * 4,
1126            })
1127        } else {
1128            None
1129        };
1130
1131        BlockMvResult {
1132            block_x: bx,
1133            block_y: by,
1134            mv: best_mv,
1135            subpixel_mv,
1136            cost: best_cost,
1137        }
1138    }
1139
1140    /// Compute the Sum of Absolute Differences between a block in `current`
1141    /// and a candidate block in `reference`.
1142    #[allow(clippy::too_many_arguments)]
1143    fn compute_sad(
1144        &self,
1145        reference: &[u8],
1146        current: &[u8],
1147        _stride: usize,
1148        width: usize,
1149        _height: usize,
1150        ref_x: usize,
1151        ref_y: usize,
1152        cur_x: usize,
1153        cur_y: usize,
1154        block_size: usize,
1155    ) -> u32 {
1156        let mut sad = 0u32;
1157        for row in 0..block_size {
1158            for col in 0..block_size {
1159                let cur_idx = (cur_y + row) * width + (cur_x + col);
1160                let ref_idx = (ref_y + row) * width + (ref_x + col);
1161                if cur_idx < current.len() && ref_idx < reference.len() {
1162                    sad += u32::from(current[cur_idx].abs_diff(reference[ref_idx]));
1163                }
1164            }
1165        }
1166        sad
1167    }
1168}
1169
1170// ─────────────────────────────────────────────────────────────────────────────
1171// Tests
1172// ─────────────────────────────────────────────────────────────────────────────
1173
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat frame of `w * h` pixels, all set to `value`.
    fn gray_frame(w: u32, h: u32, value: u8) -> Vec<u8> {
        let pixel_count = (w * h) as usize;
        vec![value; pixel_count]
    }

    /// Deterministic noise frame translated by `(dx, dy)`.
    ///
    /// The base pattern comes from a fixed-seed LCG, so it is aperiodic and
    /// identical on every call; the correct shift therefore gives a uniquely
    /// low SAD, which a checkerboard would not.  Destination pixels with no
    /// source inside the frame are filled with mid-grey (128) so boundary
    /// blocks do not match perfectly at zero motion.
    fn shifted_frame(w: u32, h: u32, dx: i32, dy: i32) -> Vec<u8> {
        let pixel_count = (w * h) as usize;
        let pitch = w as usize;

        // Base frame from the LCG (no external dependencies).
        let mut state: u64 = 0x5851_F42D_4C95_7F2D;
        let base: Vec<u8> = (0..pixel_count)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 33) & 0xFF) as u8
            })
            .collect();

        // Walk destination pixels and pull each from its source location.
        let (cols, rows) = (w as i32, h as i32);
        let mut out = vec![128u8; pixel_count];
        for ty in 0..rows {
            let sy = ty - dy;
            if !(0..rows).contains(&sy) {
                continue;
            }
            for tx in 0..cols {
                let sx = tx - dx;
                if (0..cols).contains(&sx) {
                    out[ty as usize * pitch + tx as usize] =
                        base[sy as usize * pitch + sx as usize];
                }
            }
        }
        out
    }

    /// Build an estimator from `config` and run the CPU path, panicking on
    /// failure.
    fn run_cpu(
        config: MotionEstimationConfig,
        reference: &[u8],
        current: &[u8],
        w: u32,
        h: u32,
    ) -> FrameMvResult {
        MotionEstimator::new(config)
            .estimate_cpu(reference, current, w, h)
            .expect("CPU estimate failed")
    }

    #[test]
    fn test_estimator_default_config() {
        let estimator = MotionEstimator::av1_default();
        assert_eq!(estimator.config.codec, TargetCodec::Av1);
    }

    #[test]
    fn test_vp9_default_config() {
        let estimator = MotionEstimator::vp9_default();
        assert_eq!(estimator.config.codec, TargetCodec::Vp9);
    }

    #[test]
    fn test_zero_mv_for_identical_frames() {
        let (w, h) = (64u32, 64u32);
        let frame = gray_frame(w, h, 128);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 4,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &frame, &frame, w, h);
        result.block_mvs.iter().for_each(|block| {
            assert_eq!(block.mv.dx, 0, "dx should be 0 for identical frames");
            assert_eq!(block.mv.dy, 0, "dy should be 0 for identical frames");
        });
    }

    #[test]
    fn test_mv_detected_for_shifted_frame() {
        let (w, h) = (64u32, 64u32);
        let reference = shifted_frame(w, h, 0, 0);
        let current = shifted_frame(w, h, 4, 0);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 8,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &reference, &current, w, h);
        // Most blocks should report a horizontal vector of magnitude ~4; the
        // sign is not checked, only the magnitude.
        let total = result.block_mvs.len();
        let matched = result
            .block_mvs
            .iter()
            .map(|block| block.mv.dx.abs())
            .filter(|&magnitude| magnitude >= 3)
            .count();
        assert!(
            matched > total / 2,
            "expected most blocks to detect horizontal shift"
        );
    }

    #[test]
    fn test_invalid_dimensions_rejected() {
        let frame = vec![0u8; 64];
        let outcome = MotionEstimator::av1_default().estimate_cpu(&frame, &frame, 0, 8);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_buffer_too_small_rejected() {
        let undersized = vec![0u8; 4];
        let full = vec![0u8; 64 * 64];
        let outcome = MotionEstimator::av1_default().estimate_cpu(&undersized, &full, 64, 64);
        assert!(outcome.is_err(), "undersized reference should be rejected");
    }

    #[test]
    fn test_mean_mv_magnitude_zero_for_static() {
        let (w, h) = (32u32, 32u32);
        let frame = gray_frame(w, h, 100);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &frame, &frame, w, h);
        assert_eq!(result.mean_mv_magnitude(), 0.0);
    }

    #[test]
    fn test_blocks_dimensions() {
        let (w, h) = (64u32, 32u32);
        let frame = gray_frame(w, h, 0);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &frame, &frame, w, h);
        assert_eq!(result.blocks_x(), 4);
        assert_eq!(result.blocks_y(), 2);
        assert_eq!(result.block_mvs.len(), 8);
    }

    #[test]
    fn test_subpixel_refinement_present() {
        let (w, h) = (16u32, 16u32);
        let frame = gray_frame(w, h, 128);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: true,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &frame, &frame, w, h);
        result.block_mvs.iter().for_each(|block| {
            assert!(
                block.subpixel_mv.is_some(),
                "subpixel_mv should be present when refinement is enabled"
            );
        });
    }

    #[test]
    fn test_subpixel_refinement_absent_when_disabled() {
        let (w, h) = (16u32, 16u32);
        let frame = gray_frame(w, h, 64);
        let config = MotionEstimationConfig {
            partition: BlockPartition::Fixed16x16,
            search_radius: 2,
            subpixel_refinement: false,
            ..MotionEstimationConfig::default()
        };
        let result = run_cpu(config, &frame, &frame, w, h);
        result.block_mvs.iter().for_each(|block| {
            assert!(
                block.subpixel_mv.is_none(),
                "subpixel_mv should be absent when refinement is disabled"
            );
        });
    }
}