1#![allow(dead_code)]
2use crate::{AlignError, AlignResult, Point2D};
15
/// A 2-D displacement together with the matching cost that produced it.
///
/// The block searches in this file store whole-pixel offsets in `dx`/`dy` and a
/// sum of absolute differences in `cost`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MotionVector {
    /// Horizontal displacement.
    pub dx: f64,
    /// Vertical displacement.
    pub dy: f64,
    /// Matching cost associated with this displacement.
    pub cost: f64,
}

impl MotionVector {
    /// Creates a vector from its displacement components and matching cost.
    #[must_use]
    pub const fn new(dx: f64, dy: f64, cost: f64) -> Self {
        Self { dx, dy, cost }
    }

    /// Returns the vector with no displacement and zero cost.
    #[must_use]
    pub const fn zero() -> Self {
        Self::new(0.0, 0.0, 0.0)
    }

    /// Returns the Euclidean length of the displacement.
    ///
    /// Uses [`f64::hypot`] so that very large or very small components do not
    /// overflow or underflow in the intermediate squares.
    #[must_use]
    pub fn magnitude(&self) -> f64 {
        self.dx.hypot(self.dy)
    }

    /// Returns the angle of the displacement in radians, computed as
    /// `atan2(dy, dx)`.
    #[must_use]
    pub fn direction(&self) -> f64 {
        self.dy.atan2(self.dx)
    }

    /// Returns the component-wise sum of the two displacements.
    ///
    /// The resulting cost is the arithmetic mean of both costs, not their sum.
    #[must_use]
    pub fn add(&self, other: &Self) -> Self {
        Self::new(
            self.dx + other.dx,
            self.dy + other.dy,
            (self.cost + other.cost) / 2.0,
        )
    }

    /// Returns the displacement multiplied by `factor`; the cost is left unchanged.
    #[must_use]
    pub fn scale(&self, factor: f64) -> Self {
        Self::new(self.dx * factor, self.dy * factor, self.cost)
    }
}
76
/// Block-matching search strategy selected through `MotionEstimationConfig`.
///
/// NOTE(review): in `MotionCompensator::estimate` only `FullSearch` has its own
/// code path; every other variant is dispatched to the diamond search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchStrategy {
    /// Exhaustive search over every integer offset within the search range.
    FullSearch,
    /// Diamond-pattern search.
    DiamondSearch,
    /// Three-step search (currently handled by the diamond search path).
    ThreeStepSearch,
    /// Hexagonal search (currently handled by the diamond search path).
    HexagonalSearch,
}
89
/// Parameters controlling block-matching motion estimation.
#[derive(Debug, Clone)]
pub struct MotionEstimationConfig {
    /// Side length, in pixels, of the square blocks that are matched.
    pub block_size: u32,
    /// Maximum displacement, in pixels, searched in each direction around a block.
    pub search_range: u32,
    /// Strategy used to explore candidate displacements.
    pub search_strategy: SearchStrategy,
    /// Sub-pixel refinement flag.
    ///
    /// NOTE(review): nothing in the visible part of this file reads this field —
    /// confirm whether it is consumed elsewhere.
    pub sub_pixel: bool,
    /// Frame width in pixels.
    pub frame_width: u32,
    /// Frame height in pixels.
    pub frame_height: u32,
}

impl Default for MotionEstimationConfig {
    /// Returns 16-pixel blocks, a search range of 32, diamond search,
    /// `sub_pixel` enabled and a 1920x1080 frame.
    fn default() -> Self {
        Self {
            block_size: 16,
            search_range: 32,
            search_strategy: SearchStrategy::DiamondSearch,
            sub_pixel: true,
            frame_width: 1920,
            frame_height: 1080,
        }
    }
}
119
120#[derive(Debug, Clone)]
122pub struct MotionField {
123 pub vectors: Vec<MotionVector>,
125 pub cols: u32,
127 pub rows: u32,
129 pub block_size: u32,
131}
132
133impl MotionField {
134 #[must_use]
136 #[allow(clippy::cast_precision_loss)]
137 pub fn new(frame_width: u32, frame_height: u32, block_size: u32) -> Self {
138 let cols = frame_width.div_ceil(block_size);
139 let rows = frame_height.div_ceil(block_size);
140 let count = (cols * rows) as usize;
141 Self {
142 vectors: vec![MotionVector::zero(); count],
143 cols,
144 rows,
145 block_size,
146 }
147 }
148
149 #[must_use]
151 pub fn get(&self, bx: u32, by: u32) -> Option<&MotionVector> {
152 if bx < self.cols && by < self.rows {
153 Some(&self.vectors[(by * self.cols + bx) as usize])
154 } else {
155 None
156 }
157 }
158
159 pub fn set(&mut self, bx: u32, by: u32, mv: MotionVector) {
161 if bx < self.cols && by < self.rows {
162 self.vectors[(by * self.cols + bx) as usize] = mv;
163 }
164 }
165
166 #[must_use]
168 #[allow(clippy::cast_precision_loss)]
169 pub fn interpolate(&self, x: f64, y: f64) -> MotionVector {
170 let bs = f64::from(self.block_size);
171 let bx = (x / bs).floor();
172 let by = (y / bs).floor();
173
174 let bxi = bx as u32;
175 let byi = by as u32;
176
177 let fx = x / bs - bx;
179 let fy = y / bs - by;
180
181 let get_mv = |cx: u32, cy: u32| -> MotionVector {
182 self.get(
183 cx.min(self.cols.saturating_sub(1)),
184 cy.min(self.rows.saturating_sub(1)),
185 )
186 .copied()
187 .unwrap_or_else(MotionVector::zero)
188 };
189
190 let tl = get_mv(bxi, byi);
191 let tr = get_mv(bxi + 1, byi);
192 let bl = get_mv(bxi, byi + 1);
193 let br = get_mv(bxi + 1, byi + 1);
194
195 let dx = tl.dx * (1.0 - fx) * (1.0 - fy)
196 + tr.dx * fx * (1.0 - fy)
197 + bl.dx * (1.0 - fx) * fy
198 + br.dx * fx * fy;
199
200 let dy = tl.dy * (1.0 - fx) * (1.0 - fy)
201 + tr.dy * fx * (1.0 - fy)
202 + bl.dy * (1.0 - fx) * fy
203 + br.dy * fx * fy;
204
205 let cost = tl.cost * (1.0 - fx) * (1.0 - fy)
206 + tr.cost * fx * (1.0 - fy)
207 + bl.cost * (1.0 - fx) * fy
208 + br.cost * fx * fy;
209
210 MotionVector::new(dx, dy, cost)
211 }
212
213 #[must_use]
215 pub fn global_motion(&self) -> MotionVector {
216 if self.vectors.is_empty() {
217 return MotionVector::zero();
218 }
219
220 let mut dxs: Vec<f64> = self.vectors.iter().map(|v| v.dx).collect();
221 let mut dys: Vec<f64> = self.vectors.iter().map(|v| v.dy).collect();
222
223 dxs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
224 dys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
225
226 let mid = dxs.len() / 2;
227 MotionVector::new(dxs[mid], dys[mid], 0.0)
228 }
229
230 #[must_use]
232 #[allow(clippy::cast_precision_loss)]
233 pub fn average_magnitude(&self) -> f64 {
234 if self.vectors.is_empty() {
235 return 0.0;
236 }
237 let total: f64 = self.vectors.iter().map(MotionVector::magnitude).sum();
238 total / self.vectors.len() as f64
239 }
240
241 #[must_use]
243 pub fn count_above_threshold(&self, threshold: f64) -> usize {
244 self.vectors
245 .iter()
246 .filter(|v| v.magnitude() > threshold)
247 .count()
248 }
249}
250
/// Summary statistics of a motion field, as produced by
/// `MotionCompensator::compute_stats`.
#[derive(Debug, Clone)]
pub struct MotionStats {
    /// Mean magnitude of all block vectors.
    pub avg_magnitude: f64,
    /// Largest magnitude among all block vectors.
    pub max_magnitude: f64,
    /// Population standard deviation of the magnitudes (variance divided by the vector count).
    pub std_magnitude: f64,
    /// Horizontal component of the field's global motion estimate.
    pub global_dx: f64,
    /// Vertical component of the field's global motion estimate.
    pub global_dy: f64,
    /// Fraction of block vectors whose magnitude exceeds `1.0`.
    pub motion_fraction: f64,
}
267
268#[derive(Debug, Clone)]
270pub struct MotionCompensator {
271 config: MotionEstimationConfig,
273}
274
275impl MotionCompensator {
276 #[must_use]
278 pub fn new(config: MotionEstimationConfig) -> Self {
279 Self { config }
280 }
281
282 #[must_use]
284 pub fn with_defaults() -> Self {
285 Self {
286 config: MotionEstimationConfig::default(),
287 }
288 }
289
290 #[allow(clippy::cast_precision_loss)]
294 pub fn estimate(&self, reference: &[u8], target: &[u8]) -> AlignResult<MotionField> {
295 let expected_size = (self.config.frame_width * self.config.frame_height) as usize;
296 if reference.len() != expected_size || target.len() != expected_size {
297 return Err(AlignError::InsufficientData(format!(
298 "Expected frame size {}, got ref={} target={}",
299 expected_size,
300 reference.len(),
301 target.len()
302 )));
303 }
304
305 let mut field = MotionField::new(
306 self.config.frame_width,
307 self.config.frame_height,
308 self.config.block_size,
309 );
310
311 let bs = self.config.block_size;
312 let sr = self.config.search_range as i32;
313 let w = self.config.frame_width;
314 let h = self.config.frame_height;
315
316 for by in 0..field.rows {
317 for bx in 0..field.cols {
318 let orig_x = (bx * bs) as i32;
319 let orig_y = (by * bs) as i32;
320
321 let mv = match self.config.search_strategy {
322 SearchStrategy::FullSearch => {
323 self.full_search(reference, target, orig_x, orig_y, bs, sr, w, h)
324 }
325 _ => {
326 self.diamond_search(reference, target, orig_x, orig_y, bs, sr, w, h)
328 }
329 };
330
331 field.set(bx, by, mv);
332 }
333 }
334
335 Ok(field)
336 }
337
338 #[must_use]
340 #[allow(clippy::cast_precision_loss)]
341 pub fn compute_stats(field: &MotionField) -> MotionStats {
342 if field.vectors.is_empty() {
343 return MotionStats {
344 avg_magnitude: 0.0,
345 max_magnitude: 0.0,
346 std_magnitude: 0.0,
347 global_dx: 0.0,
348 global_dy: 0.0,
349 motion_fraction: 0.0,
350 };
351 }
352
353 let magnitudes: Vec<f64> = field.vectors.iter().map(MotionVector::magnitude).collect();
354 let n = magnitudes.len() as f64;
355 let avg = magnitudes.iter().sum::<f64>() / n;
356 let max = magnitudes.iter().copied().fold(0.0_f64, f64::max);
357 let variance = magnitudes.iter().map(|m| (m - avg).powi(2)).sum::<f64>() / n;
358 let std_dev = variance.sqrt();
359
360 let global = field.global_motion();
361 let motion_count = field.count_above_threshold(1.0);
362
363 MotionStats {
364 avg_magnitude: avg,
365 max_magnitude: max,
366 std_magnitude: std_dev,
367 global_dx: global.dx,
368 global_dy: global.dy,
369 motion_fraction: motion_count as f64 / n,
370 }
371 }
372
373 #[must_use]
375 pub fn compensate_points(field: &MotionField, points: &[Point2D]) -> Vec<Point2D> {
376 points
377 .iter()
378 .map(|p| {
379 let mv = field.interpolate(p.x, p.y);
380 Point2D::new(p.x + mv.dx, p.y + mv.dy)
381 })
382 .collect()
383 }
384
385 #[allow(clippy::too_many_arguments)]
387 #[allow(clippy::cast_precision_loss)]
388 fn full_search(
389 &self,
390 reference: &[u8],
391 target: &[u8],
392 bx: i32,
393 by: i32,
394 bs: u32,
395 sr: i32,
396 w: u32,
397 h: u32,
398 ) -> MotionVector {
399 let mut best_dx = 0i32;
400 let mut best_dy = 0i32;
401 let mut best_cost = f64::MAX;
402
403 for dy in -sr..=sr {
404 for dx in -sr..=sr {
405 let cost = self.compute_sad(reference, target, bx, by, bx + dx, by + dy, bs, w, h);
406 if cost < best_cost
407 || (cost == best_cost
408 && (dx.unsigned_abs() + dy.unsigned_abs())
409 < (best_dx.unsigned_abs() + best_dy.unsigned_abs()))
410 {
411 best_cost = cost;
412 best_dx = dx;
413 best_dy = dy;
414 }
415 }
416 }
417
418 MotionVector::new(f64::from(best_dx), f64::from(best_dy), best_cost)
419 }
420
421 #[allow(clippy::too_many_arguments)]
423 #[allow(clippy::cast_precision_loss)]
424 fn diamond_search(
425 &self,
426 reference: &[u8],
427 target: &[u8],
428 bx: i32,
429 by: i32,
430 bs: u32,
431 sr: i32,
432 w: u32,
433 h: u32,
434 ) -> MotionVector {
435 let large_diamond: [(i32, i32); 9] = [
436 (0, 0),
437 (0, -2),
438 (1, -1),
439 (2, 0),
440 (1, 1),
441 (0, 2),
442 (-1, 1),
443 (-2, 0),
444 (-1, -1),
445 ];
446
447 let mut cx = 0i32;
448 let mut cy = 0i32;
449 let mut best_cost = f64::MAX;
450
451 for _ in 0..sr {
452 let mut found_better = false;
453 let mut new_cx = cx;
454 let mut new_cy = cy;
455
456 for &(ddx, ddy) in &large_diamond {
457 let tx = cx + ddx;
458 let ty = cy + ddy;
459 if tx.abs() > sr || ty.abs() > sr {
460 continue;
461 }
462 let cost = self.compute_sad(reference, target, bx, by, bx + tx, by + ty, bs, w, h);
463 if cost < best_cost {
464 best_cost = cost;
465 new_cx = tx;
466 new_cy = ty;
467 found_better = true;
468 }
469 }
470
471 if !found_better || (new_cx == cx && new_cy == cy) {
472 break;
473 }
474 cx = new_cx;
475 cy = new_cy;
476 }
477
478 MotionVector::new(f64::from(cx), f64::from(cy), best_cost)
479 }
480
481 #[allow(clippy::too_many_arguments)]
483 #[allow(clippy::cast_precision_loss)]
484 fn compute_sad(
485 &self,
486 reference: &[u8],
487 target: &[u8],
488 rx: i32,
489 ry: i32,
490 tx: i32,
491 ty: i32,
492 bs: u32,
493 w: u32,
494 h: u32,
495 ) -> f64 {
496 let mut sad = 0u64;
497 let bs_i = bs as i32;
498 let w_i = w as i32;
499 let h_i = h as i32;
500
501 for row in 0..bs_i {
502 for col in 0..bs_i {
503 let ref_x = rx + col;
504 let ref_y = ry + row;
505 let tgt_x = tx + col;
506 let tgt_y = ty + row;
507
508 if ref_x < 0 || ref_x >= w_i || ref_y < 0 || ref_y >= h_i {
509 sad += 128;
510 continue;
511 }
512 if tgt_x < 0 || tgt_x >= w_i || tgt_y < 0 || tgt_y >= h_i {
513 sad += 128;
514 continue;
515 }
516
517 let ref_idx = (ref_y as u32 * w + ref_x as u32) as usize;
518 let tgt_idx = (tgt_y as u32 * w + tgt_x as u32) as usize;
519
520 let diff = i32::from(reference[ref_idx]) - i32::from(target[tgt_idx]);
521 sad += u64::from(diff.unsigned_abs());
522 }
523 }
524
525 sad as f64
526 }
527
528 #[must_use]
530 pub fn config(&self) -> &MotionEstimationConfig {
531 &self.config
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_motion_vector_creation() {
        let MotionVector { dx, dy, cost } = MotionVector::new(3.0, 4.0, 100.0);
        assert_close(dx, 3.0, f64::EPSILON);
        assert_close(dy, 4.0, f64::EPSILON);
        assert_close(cost, 100.0, f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_magnitude() {
        // Classic 3-4-5 triangle.
        assert_close(MotionVector::new(3.0, 4.0, 0.0).magnitude(), 5.0, 1e-10);
    }

    #[test]
    fn test_motion_vector_direction() {
        let along_x = MotionVector::new(1.0, 0.0, 0.0);
        assert_close(along_x.direction(), 0.0, 1e-10);

        let along_y = MotionVector::new(0.0, 1.0, 0.0);
        assert_close(along_y.direction(), std::f64::consts::FRAC_PI_2, 1e-10);
    }

    #[test]
    fn test_motion_vector_zero() {
        assert_close(MotionVector::zero().magnitude(), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_add() {
        let lhs = MotionVector::new(1.0, 2.0, 10.0);
        let rhs = MotionVector::new(3.0, 4.0, 20.0);
        let sum = lhs.add(&rhs);
        assert_close(sum.dx, 4.0, f64::EPSILON);
        assert_close(sum.dy, 6.0, f64::EPSILON);
        // Costs are averaged rather than summed.
        assert_close(sum.cost, 15.0, f64::EPSILON);
    }

    #[test]
    fn test_motion_vector_scale() {
        let halved = MotionVector::new(2.0, 3.0, 10.0).scale(0.5);
        assert_close(halved.dx, 1.0, f64::EPSILON);
        assert_close(halved.dy, 1.5, f64::EPSILON);
    }

    #[test]
    fn test_motion_field_creation() {
        let MotionField {
            vectors,
            cols,
            rows,
            ..
        } = MotionField::new(320, 240, 16);
        assert_eq!(cols, 20);
        assert_eq!(rows, 15);
        assert_eq!(vectors.len(), 300);
    }

    #[test]
    fn test_motion_field_get_set() {
        let mut field = MotionField::new(64, 64, 16);
        field.set(1, 2, MotionVector::new(5.0, -3.0, 50.0));

        let stored = field.get(1, 2).expect("retrieved should be valid");
        assert_close(stored.dx, 5.0, f64::EPSILON);
        assert_close(stored.dy, -3.0, f64::EPSILON);
    }

    #[test]
    fn test_motion_field_global_motion() {
        let mut field = MotionField::new(64, 64, 16);
        let (rows, cols) = (field.rows, field.cols);
        // Give every block the same motion so the median is unambiguous.
        for (bx, by) in (0..rows).flat_map(|by| (0..cols).map(move |bx| (bx, by))) {
            field.set(bx, by, MotionVector::new(2.0, 1.0, 0.0));
        }

        let global = field.global_motion();
        assert_close(global.dx, 2.0, f64::EPSILON);
        assert_close(global.dy, 1.0, f64::EPSILON);
    }

    #[test]
    fn test_motion_field_average_magnitude() {
        let mut field = MotionField::new(32, 32, 16);
        let assignments = [
            (0, 0, MotionVector::new(3.0, 4.0, 0.0)),
            (1, 0, MotionVector::new(0.0, 0.0, 0.0)),
            (0, 1, MotionVector::new(0.0, 0.0, 0.0)),
            (1, 1, MotionVector::new(0.0, 0.0, 0.0)),
        ];
        for (bx, by, mv) in assignments {
            field.set(bx, by, mv);
        }

        // One vector of length 5 among four blocks.
        assert_close(field.average_magnitude(), 1.25, 1e-10);
    }

    #[test]
    fn test_estimate_static_frames() {
        let comp = MotionCompensator::new(MotionEstimationConfig {
            block_size: 8,
            search_range: 4,
            search_strategy: SearchStrategy::FullSearch,
            sub_pixel: false,
            frame_width: 32,
            frame_height: 32,
        });

        let frame = vec![128u8; 32 * 32];
        let field = comp
            .estimate(&frame, &frame)
            .expect("field should be valid");

        for mv in field.vectors.iter() {
            assert_close(mv.dx, 0.0, f64::EPSILON);
            assert_close(mv.dy, 0.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_estimate_wrong_size() {
        let config = MotionEstimationConfig {
            frame_width: 64,
            frame_height: 64,
            ..MotionEstimationConfig::default()
        };
        let comp = MotionCompensator::new(config);

        let small_frame = vec![0u8; 10];
        assert!(comp.estimate(&small_frame, &small_frame).is_err());
    }

    #[test]
    fn test_compensate_points() {
        // A single block covering the whole frame.
        let mut field = MotionField::new(64, 64, 64);
        field.set(0, 0, MotionVector::new(10.0, -5.0, 0.0));

        let points = vec![Point2D::new(10.0, 20.0)];
        let compensated = MotionCompensator::compensate_points(&field, &points);
        let moved = &compensated[0];
        assert_close(moved.x, 20.0, 1e-10);
        assert_close(moved.y, 15.0, 1e-10);
    }

    #[test]
    fn test_motion_stats_static() {
        let stats = MotionCompensator::compute_stats(&MotionField::new(64, 64, 16));
        assert_close(stats.avg_magnitude, 0.0, f64::EPSILON);
        assert_close(stats.max_magnitude, 0.0, f64::EPSILON);
        assert_close(stats.motion_fraction, 0.0, f64::EPSILON);
    }
}