1use crate::{AlignError, AlignResult};
8
/// One motion vector: a 2-D displacement plus a confidence weight.
///
/// The aggregate statistics on [`FlowField`] use `confidence` as the weight of each vector.
#[derive(Clone, Copy, Debug, PartialEq)]
#[allow(dead_code)]
pub struct FlowVector {
    /// Horizontal displacement (whole pixels for vectors produced by [`block_match_flow`]).
    pub dx: f32,
    /// Vertical displacement (whole pixels for vectors produced by [`block_match_flow`]).
    pub dy: f32,
    /// Weight of this vector. [`block_match_flow`] produces values in `[0, 1]`; the range is
    /// not enforced by this type.
    pub confidence: f32,
}
20
21impl FlowVector {
22 #[must_use]
24 pub fn new(dx: f32, dy: f32, confidence: f32) -> Self {
25 Self { dx, dy, confidence }
26 }
27
28 #[must_use]
30 pub fn magnitude(&self) -> f32 {
31 (self.dx * self.dx + self.dy * self.dy).sqrt()
32 }
33
34 #[must_use]
36 pub fn angle_radians(&self) -> f32 {
37 self.dy.atan2(self.dx)
38 }
39}
40
41#[derive(Debug, Clone)]
43#[allow(dead_code)]
44pub struct FlowField {
45 pub vectors: Vec<FlowVector>,
47 pub width: u32,
49 pub height: u32,
51 pub block_size: u32,
53}
54
55impl FlowField {
56 #[must_use]
58 pub fn new(width: u32, height: u32, block_size: u32) -> Self {
59 let cols = cols(width, block_size);
60 let rows = rows(height, block_size);
61 Self {
62 vectors: vec![FlowVector::new(0.0, 0.0, 0.0); (cols * rows) as usize],
63 width,
64 height,
65 block_size,
66 }
67 }
68
69 pub fn set(&mut self, x: u32, y: u32, flow: FlowVector) {
74 let c = cols(self.width, self.block_size);
75 let r = rows(self.height, self.block_size);
76 if x < c && y < r {
77 self.vectors[(y * c + x) as usize] = flow;
78 }
79 }
80
81 #[must_use]
85 pub fn get(&self, x: u32, y: u32) -> Option<&FlowVector> {
86 let c = cols(self.width, self.block_size);
87 let r = rows(self.height, self.block_size);
88 if x < c && y < r {
89 Some(&self.vectors[(y * c + x) as usize])
90 } else {
91 None
92 }
93 }
94
95 #[must_use]
97 pub fn block_cols(&self) -> u32 {
98 cols(self.width, self.block_size)
99 }
100
101 #[must_use]
103 pub fn block_rows(&self) -> u32 {
104 rows(self.height, self.block_size)
105 }
106
107 #[must_use]
111 pub fn avg_magnitude(&self) -> f32 {
112 let mut weighted_sum = 0.0_f32;
113 let mut weight_total = 0.0_f32;
114
115 for v in &self.vectors {
116 weighted_sum += v.magnitude() * v.confidence;
117 weight_total += v.confidence;
118 }
119
120 if weight_total == 0.0 {
121 return 0.0;
122 }
123 weighted_sum / weight_total
124 }
125
126 #[must_use]
130 pub fn dominant_direction(&self) -> (f32, f32) {
131 let mut sum_dx = 0.0_f32;
132 let mut sum_dy = 0.0_f32;
133 let mut weight_total = 0.0_f32;
134
135 for v in &self.vectors {
136 sum_dx += v.dx * v.confidence;
137 sum_dy += v.dy * v.confidence;
138 weight_total += v.confidence;
139 }
140
141 if weight_total == 0.0 {
142 return (0.0, 0.0);
143 }
144 (sum_dx / weight_total, sum_dy / weight_total)
145 }
146}
147
/// Number of `block_size`-wide blocks needed to cover `width` pixels.
///
/// Rounds up, so a partial block at the right edge counts. Returns 0 when `block_size` is 0.
fn cols(width: u32, block_size: u32) -> u32 {
    match block_size {
        0 => 0,
        size => width.div_ceil(size),
    }
}
156
/// Number of `block_size`-tall blocks needed to cover `height` pixels.
///
/// Rounds up, so a partial block at the bottom edge counts. Returns 0 when `block_size` is 0.
fn rows(height: u32, block_size: u32) -> u32 {
    if block_size > 0 {
        height.div_ceil(block_size)
    } else {
        0
    }
}
163
/// Sum of squared element-wise differences between two byte slices.
///
/// Only the overlapping prefix is compared. If the slices differ in length, the trailing
/// elements of the longer one are ignored.
#[must_use]
pub fn sum_squared_diff(a: &[u8], b: &[u8]) -> u64 {
    let mut total = 0_u64;
    for (&x, &y) in a.iter().zip(b) {
        let diff = u64::from(x.abs_diff(y));
        total += diff * diff;
    }
    total
}
175
176#[allow(clippy::too_many_arguments)]
190#[must_use]
191pub fn block_match_flow(
192 prev: &[u8],
193 curr: &[u8],
194 width: u32,
195 height: u32,
196 block_size: u32,
197 search_range: i32,
198) -> FlowField {
199 let mut field = FlowField::new(width, height, block_size);
200
201 let bsize = block_size as i32;
202 let w = width as i32;
203 let h = height as i32;
204
205 for by in 0..field.block_rows() {
206 for bx in 0..field.block_cols() {
207 let px0 = (bx * block_size) as i32;
208 let py0 = (by * block_size) as i32;
209
210 let mut best_ssd = u64::MAX;
211 let mut best_dx = 0_i32;
212 let mut best_dy = 0_i32;
213
214 for dy in -search_range..=search_range {
216 for dx in -search_range..=search_range {
217 let cx0 = px0 + dx;
218 let cy0 = py0 + dy;
219
220 if cx0 < 0 || cy0 < 0 || cx0 + bsize > w || cy0 + bsize > h {
222 continue;
223 }
224
225 let mut ssd = 0_u64;
227 for row in 0..bsize {
228 let prev_row_start = (py0 + row) * w + px0;
229 let curr_row_start = (cy0 + row) * w + cx0;
230
231 let p_row =
232 &prev[prev_row_start as usize..(prev_row_start + bsize) as usize];
233 let c_row =
234 &curr[curr_row_start as usize..(curr_row_start + bsize) as usize];
235
236 ssd += sum_squared_diff(p_row, c_row);
237 }
238
239 if ssd < best_ssd
240 || (ssd == best_ssd
241 && (dx.unsigned_abs() + dy.unsigned_abs())
242 < (best_dx.unsigned_abs() + best_dy.unsigned_abs()))
243 {
244 best_ssd = ssd;
245 best_dx = dx;
246 best_dy = dy;
247 }
248 }
249 }
250
251 let max_ssd = 255_u64 * 255 * (bsize * bsize) as u64;
254 let confidence = if max_ssd == 0 {
255 0.0
256 } else {
257 1.0 - (best_ssd as f32 / max_ssd as f32).min(1.0)
258 };
259
260 field.set(
261 bx,
262 by,
263 FlowVector::new(best_dx as f32, best_dy as f32, confidence),
264 );
265 }
266 }
267
268 field
269}
270
271pub fn compute_dense_flow(
288 prev: &[f32],
289 curr: &[f32],
290 width: u32,
291 height: u32,
292) -> AlignResult<Vec<(f32, f32)>> {
293 let w = width as usize;
294 let h = height as usize;
295
296 if prev.len() != w * h || curr.len() != w * h {
297 return Err(AlignError::InvalidConfig(
298 "dense flow: slice length does not match width * height".to_string(),
299 ));
300 }
301
302 let min_val = prev
305 .iter()
306 .chain(curr.iter())
307 .copied()
308 .fold(f32::INFINITY, f32::min);
309 let max_val = prev
310 .iter()
311 .chain(curr.iter())
312 .copied()
313 .fold(f32::NEG_INFINITY, f32::max);
314 let range = max_val - min_val;
315
316 let to_u8 = |v: f32| -> u8 {
317 if range < 1e-8 {
318 128u8
319 } else {
320 (((v - min_val) / range) * 255.0).round().clamp(0.0, 255.0) as u8
321 }
322 };
323
324 let prev_u8: Vec<u8> = prev.iter().copied().map(to_u8).collect();
325 let curr_u8: Vec<u8> = curr.iter().copied().map(to_u8).collect();
326
327 let config = crate::farneback_flow::FarnebackConfig::default();
328 let field = crate::farneback_flow::compute_farneback_flow(&prev_u8, &curr_u8, w, h, &config)?;
329
330 Ok(field.dx.into_iter().zip(field.dy).collect())
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flow_vector_magnitude_zero() {
        let v = FlowVector::new(0.0, 0.0, 1.0);
        assert_eq!(v.magnitude(), 0.0);
    }

    #[test]
    fn test_flow_vector_magnitude_pythagorean() {
        let v = FlowVector::new(3.0, 4.0, 1.0);
        assert!((v.magnitude() - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_flow_vector_angle() {
        let v = FlowVector::new(1.0, 0.0, 1.0);
        assert!(v.angle_radians().abs() < 1e-5);
        let v2 = FlowVector::new(0.0, 1.0, 1.0);
        assert!((v2.angle_radians() - std::f32::consts::FRAC_PI_2).abs() < 1e-5);
    }

    #[test]
    fn test_flow_field_dimensions() {
        let f = FlowField::new(64, 48, 8);
        assert_eq!(f.block_cols(), 8);
        assert_eq!(f.block_rows(), 6);
        assert_eq!(f.vectors.len(), 48);
    }

    #[test]
    fn test_flow_field_set_get() {
        let mut f = FlowField::new(32, 32, 8);
        f.set(1, 2, FlowVector::new(3.0, -1.0, 0.8));
        let v = f.get(1, 2).expect("valid position");
        assert_eq!(v.dx, 3.0);
        assert_eq!(v.dy, -1.0);
        assert_eq!(v.confidence, 0.8);
    }

    #[test]
    fn test_flow_field_get_out_of_bounds() {
        let f = FlowField::new(32, 32, 8);
        assert!(f.get(100, 100).is_none());
    }

    #[test]
    fn test_flow_field_avg_magnitude_all_zero() {
        let f = FlowField::new(32, 32, 8);
        assert_eq!(f.avg_magnitude(), 0.0);
    }

    #[test]
    fn test_flow_field_avg_magnitude_single_vector() {
        let mut f = FlowField::new(8, 8, 8);
        f.set(0, 0, FlowVector::new(3.0, 4.0, 1.0));
        assert!((f.avg_magnitude() - 5.0).abs() < 1e-4);
    }

    #[test]
    fn test_flow_field_dominant_direction_zero_confidence() {
        let f = FlowField::new(16, 16, 8);
        assert_eq!(f.dominant_direction(), (0.0, 0.0));
    }

    #[test]
    fn test_flow_field_dominant_direction() {
        let mut f = FlowField::new(16, 8, 8);
        f.set(0, 0, FlowVector::new(2.0, 1.0, 1.0));
        f.set(1, 0, FlowVector::new(2.0, -1.0, 1.0));
        let (ddx, ddy) = f.dominant_direction();
        assert!((ddx - 2.0).abs() < 1e-4);
        assert!(ddy.abs() < 1e-4);
    }

    #[test]
    fn test_ssd_identical() {
        let a = [10_u8, 20, 30];
        assert_eq!(sum_squared_diff(&a, &a), 0);
    }

    #[test]
    fn test_ssd_known_value() {
        let a = [0_u8, 0, 0];
        let b = [3_u8, 4, 0];
        assert_eq!(sum_squared_diff(&a, &b), 25);
    }

    #[test]
    fn test_block_match_identical_frames() {
        let frame = vec![100_u8; 64 * 48];
        let field = block_match_flow(&frame, &frame, 64, 48, 8, 4);
        for v in &field.vectors {
            assert_eq!(v.dx, 0.0);
            assert_eq!(v.dy, 0.0);
        }
    }

    #[test]
    fn test_block_match_returns_correct_field_size() {
        let prev = vec![0_u8; 32 * 32];
        let curr = vec![0_u8; 32 * 32];
        let field = block_match_flow(&prev, &curr, 32, 32, 8, 2);
        assert_eq!(field.block_cols(), 4);
        assert_eq!(field.block_rows(), 4);
    }

    #[test]
    fn test_block_match_detects_known_shift() {
        // Quadratic texture mod 251: no nonzero offset within the search window reproduces
        // an 8x8 patch exactly, so only the true displacement yields SSD 0.
        fn texture(x: u32, y: u32) -> u8 {
            ((7 * x * x + 13 * y * y + 3 * x * y) % 251) as u8
        }

        let (w, h) = (32_u32, 32_u32);
        let prev: Vec<u8> = (0..h)
            .flat_map(|y| (0..w).map(move |x| texture(x, y)))
            .collect();
        // Content moves right by 1 and down by 2; uncovered pixels are zero.
        let curr: Vec<u8> = (0..h)
            .flat_map(|y| {
                (0..w).map(move |x| {
                    if x >= 1 && y >= 2 {
                        texture(x - 1, y - 2)
                    } else {
                        0
                    }
                })
            })
            .collect();

        let field = block_match_flow(&prev, &curr, w, h, 8, 4);

        // Blocks whose shifted copy stays fully inside the frame must find the exact shift.
        for by in 0..3 {
            for bx in 0..3 {
                let v = field.get(bx, by).expect("block in range");
                assert_eq!((v.dx, v.dy), (1.0, 2.0), "block ({bx}, {by})");
                assert_eq!(v.confidence, 1.0, "block ({bx}, {by})");
            }
        }
    }

    #[test]
    fn test_compute_dense_flow_identical_frames() {
        let w = 32u32;
        let h = 32u32;
        let img: Vec<f32> = (0..w * h).map(|i| (i % 50) as f32).collect();

        let flow = compute_dense_flow(&img, &img, w, h).expect("should succeed");
        assert_eq!(flow.len(), (w * h) as usize);

        let avg_mag: f32 = flow
            .iter()
            .map(|(dx, dy)| (dx * dx + dy * dy).sqrt())
            .sum::<f32>()
            / flow.len() as f32;

        assert!(
            avg_mag < 1.0,
            "avg magnitude on identical frames: {avg_mag}"
        );
    }

    #[test]
    fn test_compute_dense_flow_length_mismatch() {
        let a = vec![0.0_f32; 64 * 64];
        let b = vec![0.0_f32; 32 * 32];
        let result = compute_dense_flow(&a, &b, 64, 64);
        assert!(result.is_err(), "mismatched lengths should error");
    }

    #[test]
    fn test_compute_dense_flow_too_small() {
        let a = vec![0.0_f32; 4 * 4];
        let b = vec![0.0_f32; 4 * 4];
        let result = compute_dense_flow(&a, &b, 4, 4);
        assert!(result.is_err(), "images < 8x8 should error");
    }

    #[test]
    fn test_compute_dense_flow_non_zero_motion() {
        let w = 64u32;
        let h = 64u32;
        let n = (w * h) as usize;
        let mut prev = vec![0.0_f32; n];
        let mut curr = vec![0.0_f32; n];

        for y in 0..h as usize {
            for x in 0..w as usize {
                prev[y * w as usize + x] = if (x / 8) % 2 == 0 { 200.0 } else { 50.0 };
                let sx = (x + 2).min(w as usize - 1);
                curr[y * w as usize + sx] = if (x / 8) % 2 == 0 { 200.0 } else { 50.0 };
            }
        }

        let flow = compute_dense_flow(&prev, &curr, w, h).expect("should succeed");
        let max_mag = flow
            .iter()
            .map(|(dx, dy)| (dx * dx + dy * dy).sqrt())
            .fold(0.0_f32, f32::max);

        assert!(
            max_mag > 0.0,
            "shifted stripe pattern should produce non-zero flow"
        );
    }

    #[test]
    fn test_compute_dense_flow_constant_frame() {
        let w = 16u32;
        let h = 16u32;
        let img = vec![1.0_f32; (w * h) as usize];
        let flow = compute_dense_flow(&img, &img, w, h).expect("should succeed");
        for (dx, dy) in &flow {
            assert!(dx.abs() < 1.0 && dy.abs() < 1.0, "dx={dx}, dy={dy}");
        }
    }
}