1use crate::{GpuDevice, GpuError, Result};
30use rayon::prelude::*;
31use wgpu::util::DeviceExt as _;
32
/// Codec that the motion vectors are being estimated for.
///
/// Carried in [`MotionEstimationConfig::codec`] and selected by the
/// `av1_default` / `vp9_default` presets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TargetCodec {
    /// Target the AV1 preset.
    Av1,
    /// Target the VP9 preset.
    Vp9,
}
45
/// Block size used when tiling a frame for motion search.
///
/// The default is [`BlockPartition::Fixed16x16`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BlockPartition {
    /// Square blocks of 16×16 pixels.
    #[default]
    Fixed16x16,
    /// Square blocks of 32×32 pixels.
    Fixed32x32,
    /// Square blocks of 64×64 pixels.
    Fixed64x64,
    /// Square blocks of 128×128 pixels.
    Fixed128x128,
    /// Adaptive partitioning.
    ///
    /// NOTE(review): both estimator paths in this file currently map this
    /// variant to 16×16 blocks; no per-block size decision is made.
    Adaptive,
}
66
67#[derive(Debug, Clone)]
69pub struct MotionEstimationConfig {
70 pub codec: TargetCodec,
72 pub partition: BlockPartition,
74 pub search_radius: u32,
76 pub subpixel_refinement: bool,
78 pub metric: MotionMetric,
80 pub pyramid_levels: u32,
82}
83
84impl Default for MotionEstimationConfig {
85 fn default() -> Self {
86 Self {
87 codec: TargetCodec::Av1,
88 partition: BlockPartition::default(),
89 search_radius: 32,
90 subpixel_refinement: true,
91 metric: MotionMetric::Sad,
92 pyramid_levels: 3,
93 }
94 }
95}
96
/// Cost metric requested for block matching.
///
/// NOTE(review): the CPU matcher in this file always computes a sum of
/// absolute differences; confirm whether the other variants are honoured
/// anywhere else.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MotionMetric {
    /// Sum of absolute differences.
    Sad,
    /// Sum of squared differences.
    Ssd,
    /// Hadamard-transform based cost.
    Hadamard,
}
107
/// Integer-pixel motion vector; the default value is the zero vector.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MotionVector {
    /// Horizontal displacement in whole pixels.
    pub dx: i16,
    /// Vertical displacement in whole pixels.
    pub dy: i16,
}
116
/// Motion vector in fractional-pixel units; the default value is the zero vector.
///
/// The estimator paths in this file fill it with the pixel displacement
/// multiplied by four, i.e. quarter-pixel units.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct SubpixelMv {
    /// Horizontal displacement in fractional-pixel units.
    pub dx: i32,
    /// Vertical displacement in fractional-pixel units.
    pub dy: i32,
}
126
127#[derive(Debug, Clone)]
129pub struct BlockMvResult {
130 pub block_x: u32,
132 pub block_y: u32,
134 pub mv: MotionVector,
136 pub subpixel_mv: Option<SubpixelMv>,
138 pub cost: u32,
140}
141
142#[derive(Debug, Clone)]
144pub struct FrameMvResult {
145 pub width: u32,
147 pub height: u32,
149 pub block_mvs: Vec<BlockMvResult>,
151 pub block_size: u32,
153 pub used_gpu: bool,
155}
156
157impl FrameMvResult {
158 #[must_use]
160 pub fn blocks_x(&self) -> u32 {
161 self.width.div_ceil(self.block_size)
162 }
163
164 #[must_use]
166 pub fn blocks_y(&self) -> u32 {
167 self.height.div_ceil(self.block_size)
168 }
169
170 #[must_use]
172 pub fn mean_mv_magnitude(&self) -> f32 {
173 if self.block_mvs.is_empty() {
174 return 0.0;
175 }
176 let sum: f64 = self
177 .block_mvs
178 .iter()
179 .map(|b| {
180 let dx = f64::from(b.mv.dx);
181 let dy = f64::from(b.mv.dy);
182 (dx * dx + dy * dy).sqrt()
183 })
184 .sum();
185 (sum / self.block_mvs.len() as f64) as f32
186 }
187}
188
189pub struct MotionEstimator {
195 config: MotionEstimationConfig,
196}
197
198impl MotionEstimator {
199 #[must_use]
201 pub fn new(config: MotionEstimationConfig) -> Self {
202 Self { config }
203 }
204
205 #[must_use]
207 pub fn av1_default() -> Self {
208 Self::new(MotionEstimationConfig {
209 codec: TargetCodec::Av1,
210 partition: BlockPartition::Fixed64x64,
211 search_radius: 48,
212 subpixel_refinement: true,
213 metric: MotionMetric::Sad,
214 pyramid_levels: 3,
215 })
216 }
217
218 #[must_use]
220 pub fn vp9_default() -> Self {
221 Self::new(MotionEstimationConfig {
222 codec: TargetCodec::Vp9,
223 partition: BlockPartition::Fixed64x64,
224 search_radius: 32,
225 subpixel_refinement: true,
226 metric: MotionMetric::Sad,
227 pyramid_levels: 2,
228 })
229 }
230
231 pub fn estimate(
240 &self,
241 device: &GpuDevice,
242 reference: &[u8],
243 current: &[u8],
244 width: u32,
245 height: u32,
246 ) -> Result<FrameMvResult> {
247 if reference.len() < (width * height) as usize {
248 return Err(GpuError::InvalidBufferSize {
249 expected: (width * height) as usize,
250 actual: reference.len(),
251 });
252 }
253 if current.len() < (width * height) as usize {
254 return Err(GpuError::InvalidBufferSize {
255 expected: (width * height) as usize,
256 actual: current.len(),
257 });
258 }
259 if width == 0 || height == 0 {
260 return Err(GpuError::InvalidDimensions { width, height });
261 }
262
263 if !device.is_fallback {
267 if let Ok(result) = self.estimate_gpu(device, reference, current, width, height) {
268 return Ok(result);
269 }
270 }
271
272 self.estimate_cpu(reference, current, width, height)
274 }
275
276 fn estimate_gpu(
279 &self,
280 device: &GpuDevice,
281 reference: &[u8],
282 current: &[u8],
283 width: u32,
284 height: u32,
285 ) -> Result<FrameMvResult> {
286 let wgpu_device = device.device();
287 let queue = device.queue();
288
289 let block_size = match self.config.partition {
290 BlockPartition::Fixed16x16 | BlockPartition::Adaptive => 16u32,
291 BlockPartition::Fixed32x32 => 32,
292 BlockPartition::Fixed64x64 => 64,
293 BlockPartition::Fixed128x128 => 128,
294 };
295
296 let level_count = self.config.pyramid_levels.min(4).max(1) as usize;
297
298 let ref_data: Vec<u32> = reference.iter().map(|&b| u32::from(b)).collect();
300 let cur_data: Vec<u32> = current.iter().map(|&b| u32::from(b)).collect();
301
302 let ref_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
303 label: Some("motion_ref_buf"),
304 contents: bytemuck::cast_slice(&ref_data),
305 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
306 });
307 let cur_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
308 label: Some("motion_cur_buf"),
309 contents: bytemuck::cast_slice(&cur_data),
310 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
311 });
312
313 let pyramid_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
315 label: Some("motion_pyramid"),
316 source: wgpu::ShaderSource::Wgsl(include_str!("shaders/motion_pyramid.wgsl").into()),
317 });
318
319 let pyramid_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
320 label: Some("pyramid_bgl"),
321 entries: &[
322 wgpu::BindGroupLayoutEntry {
324 binding: 0,
325 visibility: wgpu::ShaderStages::COMPUTE,
326 ty: wgpu::BindingType::Buffer {
327 ty: wgpu::BufferBindingType::Uniform,
328 has_dynamic_offset: false,
329 min_binding_size: None,
330 },
331 count: None,
332 },
333 wgpu::BindGroupLayoutEntry {
335 binding: 1,
336 visibility: wgpu::ShaderStages::COMPUTE,
337 ty: wgpu::BindingType::Buffer {
338 ty: wgpu::BufferBindingType::Storage { read_only: true },
339 has_dynamic_offset: false,
340 min_binding_size: None,
341 },
342 count: None,
343 },
344 wgpu::BindGroupLayoutEntry {
346 binding: 2,
347 visibility: wgpu::ShaderStages::COMPUTE,
348 ty: wgpu::BindingType::Buffer {
349 ty: wgpu::BufferBindingType::Storage { read_only: false },
350 has_dynamic_offset: false,
351 min_binding_size: None,
352 },
353 count: None,
354 },
355 ],
356 });
357
358 let pyramid_pipeline_layout =
359 wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
360 label: Some("pyramid_layout"),
361 bind_group_layouts: &[Some(&pyramid_bgl)],
362 immediate_size: 0,
363 });
364
365 let pyramid_pipeline =
366 wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
367 label: Some("pyramid_pipeline"),
368 layout: Some(&pyramid_pipeline_layout),
369 module: &pyramid_shader,
370 entry_point: Some("downsample_r8"),
371 compilation_options: wgpu::PipelineCompilationOptions::default(),
372 cache: None,
373 });
374
375 let mut pyramid_ref_bufs: Vec<(wgpu::Buffer, u32, u32)> = Vec::with_capacity(level_count);
378 let mut pyramid_cur_bufs: Vec<(wgpu::Buffer, u32, u32)> = Vec::with_capacity(level_count);
379
380 pyramid_ref_bufs.push((ref_buf, width, height));
381 pyramid_cur_bufs.push((cur_buf, width, height));
382
383 for lvl in 1..level_count {
384 let (_, prev_w, prev_h) = &pyramid_ref_bufs[lvl - 1];
385 let out_w = (*prev_w).max(1) / 2;
386 let out_h = (*prev_h).max(1) / 2;
387 let out_pixels = (out_w * out_h) as usize;
388
389 let ref_out = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
390 label: Some(&format!("pyramid_ref_lvl{lvl}")),
391 size: (out_pixels * std::mem::size_of::<u32>()) as u64,
392 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
393 mapped_at_creation: false,
394 });
395 let cur_out = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
396 label: Some(&format!("pyramid_cur_lvl{lvl}")),
397 size: (out_pixels * std::mem::size_of::<u32>()) as u64,
398 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
399 mapped_at_creation: false,
400 });
401
402 {
404 let (in_buf, in_w, in_h) = &pyramid_ref_bufs[lvl - 1];
405 let uniforms_data: [u32; 4] = [*in_w, *in_h, out_w, out_h];
406 let uniform_buf =
407 wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
408 label: Some(&format!("pyramid_uniform_ref_{lvl}")),
409 contents: bytemuck::cast_slice(&uniforms_data),
410 usage: wgpu::BufferUsages::UNIFORM,
411 });
412 let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
413 label: None,
414 layout: &pyramid_bgl,
415 entries: &[
416 wgpu::BindGroupEntry {
417 binding: 0,
418 resource: uniform_buf.as_entire_binding(),
419 },
420 wgpu::BindGroupEntry {
421 binding: 1,
422 resource: in_buf.as_entire_binding(),
423 },
424 wgpu::BindGroupEntry {
425 binding: 2,
426 resource: ref_out.as_entire_binding(),
427 },
428 ],
429 });
430 let mut encoder =
431 wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
432 label: Some("pyramid_ref_enc"),
433 });
434 {
435 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
436 label: None,
437 timestamp_writes: None,
438 });
439 pass.set_pipeline(&pyramid_pipeline);
440 pass.set_bind_group(0, &bg, &[]);
441 pass.dispatch_workgroups(out_w.div_ceil(8), out_h.div_ceil(8), 1);
442 }
443 queue.submit(std::iter::once(encoder.finish()));
444 }
445
446 {
448 let (in_buf, in_w, in_h) = &pyramid_cur_bufs[lvl - 1];
449 let uniforms_data: [u32; 4] = [*in_w, *in_h, out_w, out_h];
450 let uniform_buf =
451 wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
452 label: Some(&format!("pyramid_uniform_cur_{lvl}")),
453 contents: bytemuck::cast_slice(&uniforms_data),
454 usage: wgpu::BufferUsages::UNIFORM,
455 });
456 let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
457 label: None,
458 layout: &pyramid_bgl,
459 entries: &[
460 wgpu::BindGroupEntry {
461 binding: 0,
462 resource: uniform_buf.as_entire_binding(),
463 },
464 wgpu::BindGroupEntry {
465 binding: 1,
466 resource: in_buf.as_entire_binding(),
467 },
468 wgpu::BindGroupEntry {
469 binding: 2,
470 resource: cur_out.as_entire_binding(),
471 },
472 ],
473 });
474 let mut encoder =
475 wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
476 label: Some("pyramid_cur_enc"),
477 });
478 {
479 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
480 label: None,
481 timestamp_writes: None,
482 });
483 pass.set_pipeline(&pyramid_pipeline);
484 pass.set_bind_group(0, &bg, &[]);
485 pass.dispatch_workgroups(out_w.div_ceil(8), out_h.div_ceil(8), 1);
486 }
487 queue.submit(std::iter::once(encoder.finish()));
488 }
489
490 pyramid_ref_bufs.push((ref_out, out_w, out_h));
491 pyramid_cur_bufs.push((cur_out, out_w, out_h));
492 }
493
494 let bm_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
496 label: Some("motion_block_match"),
497 source: wgpu::ShaderSource::Wgsl(
498 include_str!("shaders/motion_block_match.wgsl").into(),
499 ),
500 });
501
502 let bm_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
503 label: Some("bm_bgl"),
504 entries: &[
505 wgpu::BindGroupLayoutEntry {
507 binding: 0,
508 visibility: wgpu::ShaderStages::COMPUTE,
509 ty: wgpu::BindingType::Buffer {
510 ty: wgpu::BufferBindingType::Uniform,
511 has_dynamic_offset: false,
512 min_binding_size: None,
513 },
514 count: None,
515 },
516 wgpu::BindGroupLayoutEntry {
518 binding: 1,
519 visibility: wgpu::ShaderStages::COMPUTE,
520 ty: wgpu::BindingType::Buffer {
521 ty: wgpu::BufferBindingType::Storage { read_only: true },
522 has_dynamic_offset: false,
523 min_binding_size: None,
524 },
525 count: None,
526 },
527 wgpu::BindGroupLayoutEntry {
529 binding: 2,
530 visibility: wgpu::ShaderStages::COMPUTE,
531 ty: wgpu::BindingType::Buffer {
532 ty: wgpu::BufferBindingType::Storage { read_only: true },
533 has_dynamic_offset: false,
534 min_binding_size: None,
535 },
536 count: None,
537 },
538 wgpu::BindGroupLayoutEntry {
540 binding: 3,
541 visibility: wgpu::ShaderStages::COMPUTE,
542 ty: wgpu::BindingType::Buffer {
543 ty: wgpu::BufferBindingType::Storage { read_only: false },
544 has_dynamic_offset: false,
545 min_binding_size: None,
546 },
547 count: None,
548 },
549 ],
550 });
551
552 let bm_pipeline_layout =
553 wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
554 label: Some("bm_layout"),
555 bind_group_layouts: &[Some(&bm_bgl)],
556 immediate_size: 0,
557 });
558
559 let bm_pipeline = wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
560 label: Some("bm_pipeline"),
561 layout: Some(&bm_pipeline_layout),
562 module: &bm_shader,
563 entry_point: Some("block_match"),
564 compilation_options: wgpu::PipelineCompilationOptions::default(),
565 cache: None,
566 });
567
568 let top_level = level_count - 1;
572 let (_, top_w, top_h) = &pyramid_ref_bufs[top_level];
573 let top_bx = top_w.div_ceil(block_size);
574 let top_by = top_h.div_ceil(block_size);
575 let top_blocks = (top_bx * top_by) as usize;
576
577 let mut seed_mvs: Vec<[i32; 2]> = vec![[0i32, 0i32]; top_blocks];
580
581 let mut mv_int_result: Vec<[i32; 4]> = vec![[0i32; 4]; top_blocks];
585
586 for lvl in (0..level_count).rev() {
587 let (ref_level_buf, lw, lh) = &pyramid_ref_bufs[lvl];
588 let (cur_level_buf, _, _) = &pyramid_cur_bufs[lvl];
589
590 let lbx = lw.div_ceil(block_size);
591 let lby = lh.div_ceil(block_size);
592 let l_blocks = (lbx * lby) as usize;
593
594 let seeds_for_level: Vec<[i32; 2]> = if lvl == top_level {
597 vec![[0i32, 0i32]; l_blocks]
598 } else {
599 let coarser_bx = (lw * 2).div_ceil(block_size);
601 (0..l_blocks)
602 .map(|idx| {
603 let fx = (idx as u32) % lbx;
604 let fy = (idx as u32) / lbx;
605 let cx = fx / 2;
607 let cy = fy / 2;
608 let cidx = (cy * coarser_bx + cx) as usize;
609 let coarser_seed = if cidx < seed_mvs.len() {
610 seed_mvs[cidx]
611 } else {
612 [0i32, 0i32]
613 };
614 [coarser_seed[0] * 2, coarser_seed[1] * 2]
617 })
618 .collect()
619 };
620
621 let seed_x = seeds_for_level.iter().map(|s| s[0]).sum::<i32>()
627 / seeds_for_level.len().max(1) as i32;
628 let seed_y = seeds_for_level.iter().map(|s| s[1]).sum::<i32>()
629 / seeds_for_level.len().max(1) as i32;
630
631 let search_half = 8u32;
632
633 let uniforms: [u32; 8] = [
636 block_size,
637 search_half,
638 *lw,
639 *lh,
640 seed_x as u32, seed_y as u32,
642 lbx,
643 lby,
644 ];
645
646 let uniform_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
647 label: Some(&format!("bm_uniform_lvl{lvl}")),
648 contents: bytemuck::cast_slice(&uniforms),
649 usage: wgpu::BufferUsages::UNIFORM,
650 });
651
652 let mv_out_buf = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
653 label: Some(&format!("mv_out_lvl{lvl}")),
654 size: (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
655 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
656 mapped_at_creation: false,
657 });
658
659 let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
660 label: None,
661 layout: &bm_bgl,
662 entries: &[
663 wgpu::BindGroupEntry {
664 binding: 0,
665 resource: uniform_buf.as_entire_binding(),
666 },
667 wgpu::BindGroupEntry {
668 binding: 1,
669 resource: ref_level_buf.as_entire_binding(),
670 },
671 wgpu::BindGroupEntry {
672 binding: 2,
673 resource: cur_level_buf.as_entire_binding(),
674 },
675 wgpu::BindGroupEntry {
676 binding: 3,
677 resource: mv_out_buf.as_entire_binding(),
678 },
679 ],
680 });
681
682 let mut encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
683 label: Some(&format!("bm_enc_lvl{lvl}")),
684 });
685 {
686 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
687 label: None,
688 timestamp_writes: None,
689 });
690 pass.set_pipeline(&bm_pipeline);
691 pass.set_bind_group(0, &bg, &[]);
692 pass.dispatch_workgroups(lbx, lby, 1);
694 }
695
696 let staging = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
698 label: Some(&format!("bm_staging_lvl{lvl}")),
699 size: (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
700 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
701 mapped_at_creation: false,
702 });
703 encoder.copy_buffer_to_buffer(
704 &mv_out_buf,
705 0,
706 &staging,
707 0,
708 (l_blocks * std::mem::size_of::<[i32; 4]>()) as u64,
709 );
710 queue.submit(std::iter::once(encoder.finish()));
711
712 let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
713
714 let slice = staging.slice(..);
715 let (tx, mut rx) = futures_channel::oneshot::channel();
716 slice.map_async(wgpu::MapMode::Read, move |result| {
717 let _ = tx.send(result);
718 });
719 let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
720 rx.try_recv()
721 .map_err(|e| GpuError::BufferMapping(e.to_string()))?
722 .ok_or_else(|| GpuError::BufferMapping("channel empty".into()))?
723 .map_err(|e| GpuError::BufferMapping(e.to_string()))?;
724
725 {
726 let data = slice.get_mapped_range();
727 let raw: &[[i32; 4]] = bytemuck::cast_slice(&data);
728 mv_int_result = raw[..l_blocks.min(raw.len())].to_vec();
729 seed_mvs = raw[..l_blocks.min(raw.len())]
731 .iter()
732 .map(|v| [v[0], v[1]])
733 .collect();
734 }
735 }
736
737 let final_blocks_x = width.div_ceil(block_size);
739 let final_blocks_y = height.div_ceil(block_size);
740 let n_blocks = (final_blocks_x * final_blocks_y) as usize;
741
742 let (ref_l0, _, _) = &pyramid_ref_bufs[0];
743 let (cur_l0, _, _) = &pyramid_cur_bufs[0];
744
745 let mv_in_data: Vec<[i32; 4]> = (0..n_blocks)
748 .map(|i| {
749 if i < mv_int_result.len() {
750 mv_int_result[i]
751 } else {
752 [0i32; 4]
753 }
754 })
755 .collect();
756
757 let mv_in_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
758 label: Some("subpix_mv_in"),
759 contents: bytemuck::cast_slice(&mv_in_data),
760 usage: wgpu::BufferUsages::STORAGE,
761 });
762
763 let mv_out_sp_buf = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
764 label: Some("subpix_mv_out"),
765 size: (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
766 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
767 mapped_at_creation: false,
768 });
769
770 let sp_shader = wgpu_device.create_shader_module(wgpu::ShaderModuleDescriptor {
771 label: Some("motion_subpixel"),
772 source: wgpu::ShaderSource::Wgsl(include_str!("shaders/motion_subpixel.wgsl").into()),
773 });
774
775 let sp_bgl = wgpu_device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
776 label: Some("sp_bgl"),
777 entries: &[
778 wgpu::BindGroupLayoutEntry {
779 binding: 0,
780 visibility: wgpu::ShaderStages::COMPUTE,
781 ty: wgpu::BindingType::Buffer {
782 ty: wgpu::BufferBindingType::Uniform,
783 has_dynamic_offset: false,
784 min_binding_size: None,
785 },
786 count: None,
787 },
788 wgpu::BindGroupLayoutEntry {
789 binding: 1,
790 visibility: wgpu::ShaderStages::COMPUTE,
791 ty: wgpu::BindingType::Buffer {
792 ty: wgpu::BufferBindingType::Storage { read_only: true },
793 has_dynamic_offset: false,
794 min_binding_size: None,
795 },
796 count: None,
797 },
798 wgpu::BindGroupLayoutEntry {
799 binding: 2,
800 visibility: wgpu::ShaderStages::COMPUTE,
801 ty: wgpu::BindingType::Buffer {
802 ty: wgpu::BufferBindingType::Storage { read_only: true },
803 has_dynamic_offset: false,
804 min_binding_size: None,
805 },
806 count: None,
807 },
808 wgpu::BindGroupLayoutEntry {
809 binding: 3,
810 visibility: wgpu::ShaderStages::COMPUTE,
811 ty: wgpu::BindingType::Buffer {
812 ty: wgpu::BufferBindingType::Storage { read_only: true },
813 has_dynamic_offset: false,
814 min_binding_size: None,
815 },
816 count: None,
817 },
818 wgpu::BindGroupLayoutEntry {
819 binding: 4,
820 visibility: wgpu::ShaderStages::COMPUTE,
821 ty: wgpu::BindingType::Buffer {
822 ty: wgpu::BufferBindingType::Storage { read_only: false },
823 has_dynamic_offset: false,
824 min_binding_size: None,
825 },
826 count: None,
827 },
828 ],
829 });
830
831 let sp_pipeline_layout =
832 wgpu_device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
833 label: Some("sp_layout"),
834 bind_group_layouts: &[Some(&sp_bgl)],
835 immediate_size: 0,
836 });
837
838 let sp_pipeline = wgpu_device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
839 label: Some("sp_pipeline"),
840 layout: Some(&sp_pipeline_layout),
841 module: &sp_shader,
842 entry_point: Some("subpixel_refine"),
843 compilation_options: wgpu::PipelineCompilationOptions::default(),
844 cache: None,
845 });
846
847 let sp_uniforms: [u32; 4] = [width, height, block_size, n_blocks as u32];
848 let sp_uniform_buf = wgpu_device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
849 label: Some("sp_uniforms"),
850 contents: bytemuck::cast_slice(&sp_uniforms),
851 usage: wgpu::BufferUsages::UNIFORM,
852 });
853
854 let sp_bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor {
855 label: None,
856 layout: &sp_bgl,
857 entries: &[
858 wgpu::BindGroupEntry {
859 binding: 0,
860 resource: sp_uniform_buf.as_entire_binding(),
861 },
862 wgpu::BindGroupEntry {
863 binding: 1,
864 resource: ref_l0.as_entire_binding(),
865 },
866 wgpu::BindGroupEntry {
867 binding: 2,
868 resource: cur_l0.as_entire_binding(),
869 },
870 wgpu::BindGroupEntry {
871 binding: 3,
872 resource: mv_in_buf.as_entire_binding(),
873 },
874 wgpu::BindGroupEntry {
875 binding: 4,
876 resource: mv_out_sp_buf.as_entire_binding(),
877 },
878 ],
879 });
880
881 let sp_staging = wgpu_device.create_buffer(&wgpu::BufferDescriptor {
882 label: Some("sp_staging"),
883 size: (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
884 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
885 mapped_at_creation: false,
886 });
887
888 let mut sp_encoder = wgpu_device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
889 label: Some("sp_enc"),
890 });
891 {
892 let mut pass = sp_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
893 label: None,
894 timestamp_writes: None,
895 });
896 pass.set_pipeline(&sp_pipeline);
897 pass.set_bind_group(0, &sp_bg, &[]);
898 let groups = (n_blocks as u32).div_ceil(64);
899 pass.dispatch_workgroups(groups, 1, 1);
900 }
901 sp_encoder.copy_buffer_to_buffer(
902 &mv_out_sp_buf,
903 0,
904 &sp_staging,
905 0,
906 (n_blocks * std::mem::size_of::<[f32; 2]>()) as u64,
907 );
908 queue.submit(std::iter::once(sp_encoder.finish()));
909 let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
910
911 let sp_slice = sp_staging.slice(..);
912 let (sp_tx, mut sp_rx) = futures_channel::oneshot::channel();
913 sp_slice.map_async(wgpu::MapMode::Read, move |result| {
914 let _ = sp_tx.send(result);
915 });
916 let _ = wgpu_device.poll(wgpu::PollType::wait_indefinitely());
917 sp_rx
918 .try_recv()
919 .map_err(|e| GpuError::BufferMapping(e.to_string()))?
920 .ok_or_else(|| GpuError::BufferMapping("channel empty".into()))?
921 .map_err(|e| GpuError::BufferMapping(e.to_string()))?;
922
923 let subpixel_mvs: Vec<[f32; 2]> = {
924 let data = sp_slice.get_mapped_range();
925 bytemuck::cast_slice::<u8, [f32; 2]>(&data)[..n_blocks].to_vec()
926 };
927
928 let block_mvs: Vec<BlockMvResult> = (0..n_blocks)
930 .map(|idx| {
931 let bx = (idx as u32 % final_blocks_x) * block_size;
932 let by = (idx as u32 / final_blocks_x) * block_size;
933
934 let int_mv = if idx < mv_int_result.len() {
935 mv_int_result[idx]
936 } else {
937 [0i32; 4]
938 };
939
940 let mv = MotionVector {
941 dx: int_mv[0].clamp(i16::MIN as i32, i16::MAX as i32) as i16,
942 dy: int_mv[1].clamp(i16::MIN as i32, i16::MAX as i32) as i16,
943 };
944
945 let subpixel_mv = if self.config.subpixel_refinement {
947 let sp = subpixel_mvs[idx];
948 Some(SubpixelMv {
949 dx: (sp[0] * 4.0).round() as i32,
950 dy: (sp[1] * 4.0).round() as i32,
951 })
952 } else {
953 None
954 };
955
956 let cost = int_mv[2].max(0) as u32;
957
958 BlockMvResult {
959 block_x: bx,
960 block_y: by,
961 mv,
962 subpixel_mv,
963 cost,
964 }
965 })
966 .collect();
967
968 Ok(FrameMvResult {
969 width,
970 height,
971 block_mvs,
972 block_size,
973 used_gpu: true,
974 })
975 }
976
977 fn estimate_cpu(
980 &self,
981 reference: &[u8],
982 current: &[u8],
983 width: u32,
984 height: u32,
985 ) -> Result<FrameMvResult> {
986 if width == 0 || height == 0 {
989 return Err(GpuError::InvalidDimensions { width, height });
990 }
991 let required = (width as usize)
992 .checked_mul(height as usize)
993 .ok_or(GpuError::InvalidDimensions { width, height })?;
994 if reference.len() < required {
995 return Err(GpuError::InvalidBufferSize {
996 expected: required,
997 actual: reference.len(),
998 });
999 }
1000 if current.len() < required {
1001 return Err(GpuError::InvalidBufferSize {
1002 expected: required,
1003 actual: current.len(),
1004 });
1005 }
1006
1007 let block_size = match self.config.partition {
1008 BlockPartition::Fixed16x16 | BlockPartition::Adaptive => 16u32,
1009 BlockPartition::Fixed32x32 => 32,
1010 BlockPartition::Fixed64x64 => 64,
1011 BlockPartition::Fixed128x128 => 128,
1012 };
1013
1014 let blocks_x = width.div_ceil(block_size);
1015 let blocks_y = height.div_ceil(block_size);
1016 let n_blocks = (blocks_x * blocks_y) as usize;
1017
1018 let block_mvs: Vec<BlockMvResult> = (0..n_blocks)
1019 .into_par_iter()
1020 .map(|idx| {
1021 let bx = (idx as u32 % blocks_x) * block_size;
1022 let by = (idx as u32 / blocks_x) * block_size;
1023 self.match_block(reference, current, width, height, bx, by, block_size)
1024 })
1025 .collect();
1026
1027 Ok(FrameMvResult {
1028 width,
1029 height,
1030 block_mvs,
1031 block_size,
1032 used_gpu: false,
1033 })
1034 }
1035
1036 #[allow(clippy::too_many_arguments)]
1046 fn match_block(
1047 &self,
1048 reference: &[u8],
1049 current: &[u8],
1050 width: u32,
1051 height: u32,
1052 bx: u32,
1053 by: u32,
1054 block_size: u32,
1055 ) -> BlockMvResult {
1056 let w = width as usize;
1057 let sr = self.config.search_radius as i32;
1058 let bs = block_size as usize;
1059
1060 let zero_cost = self.compute_sad(
1063 reference,
1064 current,
1065 w,
1066 width as usize,
1067 height as usize,
1068 bx as usize,
1069 by as usize,
1070 bx as usize,
1071 by as usize,
1072 bs,
1073 );
1074 let mut best_cost = zero_cost;
1075 let mut best_mv = MotionVector::default();
1076
1077 for dy in -sr..=sr {
1078 for dx in -sr..=sr {
1079 if dx == 0 && dy == 0 {
1081 continue;
1082 }
1083
1084 let ref_x = bx as i32 + dx;
1085 let ref_y = by as i32 + dy;
1086
1087 if ref_x < 0
1089 || ref_y < 0
1090 || ref_x + bs as i32 > width as i32
1091 || ref_y + bs as i32 > height as i32
1092 {
1093 continue;
1094 }
1095
1096 let cost = self.compute_sad(
1097 reference,
1098 current,
1099 w,
1100 width as usize,
1101 height as usize,
1102 ref_x as usize,
1103 ref_y as usize,
1104 bx as usize,
1105 by as usize,
1106 bs,
1107 );
1108
1109 if cost < best_cost {
1112 best_cost = cost;
1113 best_mv = MotionVector {
1114 dx: dx as i16,
1115 dy: dy as i16,
1116 };
1117 }
1118 }
1119 }
1120
1121 let subpixel_mv = if self.config.subpixel_refinement {
1123 Some(SubpixelMv {
1124 dx: i32::from(best_mv.dx) * 4,
1125 dy: i32::from(best_mv.dy) * 4,
1126 })
1127 } else {
1128 None
1129 };
1130
1131 BlockMvResult {
1132 block_x: bx,
1133 block_y: by,
1134 mv: best_mv,
1135 subpixel_mv,
1136 cost: best_cost,
1137 }
1138 }
1139
1140 #[allow(clippy::too_many_arguments)]
1143 fn compute_sad(
1144 &self,
1145 reference: &[u8],
1146 current: &[u8],
1147 _stride: usize,
1148 width: usize,
1149 _height: usize,
1150 ref_x: usize,
1151 ref_y: usize,
1152 cur_x: usize,
1153 cur_y: usize,
1154 block_size: usize,
1155 ) -> u32 {
1156 let mut sad = 0u32;
1157 for row in 0..block_size {
1158 for col in 0..block_size {
1159 let cur_idx = (cur_y + row) * width + (cur_x + col);
1160 let ref_idx = (ref_y + row) * width + (ref_x + col);
1161 if cur_idx < current.len() && ref_idx < reference.len() {
1162 sad += u32::from(current[cur_idx].abs_diff(reference[ref_idx]));
1163 }
1164 }
1165 }
1166 sad
1167 }
1168}
1169
1170#[cfg(test)]
1175mod tests {
1176 use super::*;
1177
1178 fn gray_frame(w: u32, h: u32, value: u8) -> Vec<u8> {
1179 vec![value; (w * h) as usize]
1180 }
1181
1182 fn shifted_frame(w: u32, h: u32, dx: i32, dy: i32) -> Vec<u8> {
1188 let mut state: u64 = 0x5851_F42D_4C95_7F2D;
1190 let mut frame = vec![0u8; (w * h) as usize];
1191 for pixel in frame.iter_mut() {
1192 state = state
1193 .wrapping_mul(6364136223846793005)
1194 .wrapping_add(1442695040888963407);
1195 *pixel = ((state >> 33) & 0xFF) as u8;
1196 }
1197 let mut shifted = vec![128u8; (w * h) as usize];
1200 for y in 0..h as i32 {
1201 for x in 0..w as i32 {
1202 let sx = x + dx;
1203 let sy = y + dy;
1204 if sx >= 0 && sy >= 0 && sx < w as i32 && sy < h as i32 {
1205 shifted[(sy as usize) * w as usize + sx as usize] =
1206 frame[y as usize * w as usize + x as usize];
1207 }
1208 }
1209 }
1210 shifted
1211 }
1212
1213 #[test]
1214 fn test_estimator_default_config() {
1215 let e = MotionEstimator::av1_default();
1216 assert_eq!(e.config.codec, TargetCodec::Av1);
1217 }
1218
1219 #[test]
1220 fn test_vp9_default_config() {
1221 let e = MotionEstimator::vp9_default();
1222 assert_eq!(e.config.codec, TargetCodec::Vp9);
1223 }
1224
1225 #[test]
1226 fn test_zero_mv_for_identical_frames() {
1227 let w = 64u32;
1228 let h = 64u32;
1229 let frame = gray_frame(w, h, 128);
1230 let e = MotionEstimator::new(MotionEstimationConfig {
1231 partition: BlockPartition::Fixed16x16,
1232 search_radius: 4,
1233 subpixel_refinement: false,
1234 ..MotionEstimationConfig::default()
1235 });
1236 let result = e
1237 .estimate_cpu(&frame, &frame, w, h)
1238 .expect("CPU estimate failed");
1239 for bm in &result.block_mvs {
1240 assert_eq!(bm.mv.dx, 0, "dx should be 0 for identical frames");
1241 assert_eq!(bm.mv.dy, 0, "dy should be 0 for identical frames");
1242 }
1243 }
1244
1245 #[test]
1246 fn test_mv_detected_for_shifted_frame() {
1247 let w = 64u32;
1248 let h = 64u32;
1249 let reference = shifted_frame(w, h, 0, 0);
1250 let current = shifted_frame(w, h, 4, 0);
1251 let e = MotionEstimator::new(MotionEstimationConfig {
1252 partition: BlockPartition::Fixed16x16,
1253 search_radius: 8,
1254 subpixel_refinement: false,
1255 ..MotionEstimationConfig::default()
1256 });
1257 let result = e
1258 .estimate_cpu(&reference, ¤t, w, h)
1259 .expect("CPU estimate failed");
1260 let matched = result
1262 .block_mvs
1263 .iter()
1264 .filter(|b| b.mv.dx.abs() >= 3)
1265 .count();
1266 assert!(
1267 matched > result.block_mvs.len() / 2,
1268 "expected most blocks to detect horizontal shift"
1269 );
1270 }
1271
1272 #[test]
1273 fn test_invalid_dimensions_rejected() {
1274 let e = MotionEstimator::av1_default();
1275 let frame = vec![0u8; 64];
1276 let result = e.estimate_cpu(&frame, &frame, 0, 8);
1277 assert!(result.is_err());
1278 }
1279
1280 #[test]
1281 fn test_buffer_too_small_rejected() {
1282 let e = MotionEstimator::av1_default();
1283 let small = vec![0u8; 4];
1284 let frame = vec![0u8; 64 * 64];
1285 let result = e.estimate_cpu(&small, &frame, 64, 64);
1286 assert!(result.is_err(), "undersized reference should be rejected");
1287 }
1288
1289 #[test]
1290 fn test_mean_mv_magnitude_zero_for_static() {
1291 let w = 32u32;
1292 let h = 32u32;
1293 let frame = gray_frame(w, h, 100);
1294 let e = MotionEstimator::new(MotionEstimationConfig {
1295 partition: BlockPartition::Fixed16x16,
1296 search_radius: 2,
1297 subpixel_refinement: false,
1298 ..MotionEstimationConfig::default()
1299 });
1300 let result = e
1301 .estimate_cpu(&frame, &frame, w, h)
1302 .expect("CPU estimate failed");
1303 assert_eq!(result.mean_mv_magnitude(), 0.0);
1304 }
1305
1306 #[test]
1307 fn test_blocks_dimensions() {
1308 let w = 64u32;
1309 let h = 32u32;
1310 let frame = gray_frame(w, h, 0);
1311 let e = MotionEstimator::new(MotionEstimationConfig {
1312 partition: BlockPartition::Fixed16x16,
1313 search_radius: 2,
1314 subpixel_refinement: false,
1315 ..MotionEstimationConfig::default()
1316 });
1317 let result = e
1318 .estimate_cpu(&frame, &frame, w, h)
1319 .expect("CPU estimate failed");
1320 assert_eq!(result.blocks_x(), 4);
1321 assert_eq!(result.blocks_y(), 2);
1322 assert_eq!(result.block_mvs.len(), 8);
1323 }
1324
1325 #[test]
1326 fn test_subpixel_refinement_present() {
1327 let w = 16u32;
1328 let h = 16u32;
1329 let frame = gray_frame(w, h, 128);
1330 let e = MotionEstimator::new(MotionEstimationConfig {
1331 partition: BlockPartition::Fixed16x16,
1332 search_radius: 2,
1333 subpixel_refinement: true,
1334 ..MotionEstimationConfig::default()
1335 });
1336 let result = e
1337 .estimate_cpu(&frame, &frame, w, h)
1338 .expect("CPU estimate failed");
1339 for bm in &result.block_mvs {
1340 assert!(
1341 bm.subpixel_mv.is_some(),
1342 "subpixel_mv should be present when refinement is enabled"
1343 );
1344 }
1345 }
1346
1347 #[test]
1348 fn test_subpixel_refinement_absent_when_disabled() {
1349 let w = 16u32;
1350 let h = 16u32;
1351 let frame = gray_frame(w, h, 64);
1352 let e = MotionEstimator::new(MotionEstimationConfig {
1353 partition: BlockPartition::Fixed16x16,
1354 search_radius: 2,
1355 subpixel_refinement: false,
1356 ..MotionEstimationConfig::default()
1357 });
1358 let result = e
1359 .estimate_cpu(&frame, &frame, w, h)
1360 .expect("CPU estimate failed");
1361 for bm in &result.block_mvs {
1362 assert!(
1363 bm.subpixel_mv.is_none(),
1364 "subpixel_mv should be absent when refinement is disabled"
1365 );
1366 }
1367 }
1368}