1use crate::{GpuError, Result};
33use rayon::prelude::*;
34
/// Tuning parameters for the coarse-to-fine Lucas–Kanade optical-flow estimator.
#[derive(Debug, Clone)]
pub struct OpticalFlowConfig {
    /// Number of pyramid levels processed coarse-to-fine (the estimator clamps it to `1..=8`).
    pub pyramid_levels: u32,
    /// Lucas–Kanade refinement passes performed at every pyramid level.
    pub iterations: u32,
    /// Side length of the square matching window; the estimator raises it to at least 3 and
    /// forces it to be odd.
    pub window_size: u32,
    /// Standard deviation of the Gaussian blur applied to the final flow field;
    /// values `<= 0.0` disable the smoothing pass.
    pub smoothing_sigma: f32,
    /// Upper bound on the magnitude of each per-iteration flow increment, in pixels of the
    /// pyramid level being refined.
    pub max_displacement: f32,
}

impl Default for OpticalFlowConfig {
    /// Four levels, three iterations per level, a 15-pixel window, sigma 1.5 smoothing and
    /// a 4-pixel cap on each refinement step.
    fn default() -> Self {
        OpticalFlowConfig {
            max_displacement: 4.0,
            smoothing_sigma: 1.5,
            window_size: 15,
            iterations: 3,
            pyramid_levels: 4,
        }
    }
}
67
/// Displacement of a single pixel between two frames, measured in pixels.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FlowVector {
    /// Horizontal component (positive values point towards larger `x`).
    pub dx: f32,
    /// Vertical component (positive values point towards larger `y`).
    pub dy: f32,
}

impl FlowVector {
    /// The null displacement `(0, 0)`; usable in `const` contexts.
    #[must_use]
    pub const fn zero() -> Self {
        FlowVector { dx: 0.0, dy: 0.0 }
    }

    /// Euclidean length `sqrt(dx² + dy²)` of the displacement.
    #[must_use]
    pub fn magnitude(self) -> f32 {
        let FlowVector { dx, dy } = self;
        let squared_length = dx * dx + dy * dy;
        squared_length.sqrt()
    }
}
94
95#[derive(Debug, Clone)]
99pub struct FlowField {
100 pub width: u32,
102 pub height: u32,
104 pub vectors: Vec<FlowVector>,
106}
107
108impl FlowField {
109 #[must_use]
111 pub fn zeros(width: u32, height: u32) -> Self {
112 let n = (width as usize) * (height as usize);
113 Self {
114 width,
115 height,
116 vectors: vec![FlowVector::zero(); n],
117 }
118 }
119
120 #[must_use]
122 pub fn get(&self, x: u32, y: u32) -> Option<FlowVector> {
123 if x >= self.width || y >= self.height {
124 return None;
125 }
126 self.vectors
127 .get((y as usize) * (self.width as usize) + (x as usize))
128 .copied()
129 }
130
131 #[must_use]
133 pub fn mean_magnitude(&self) -> f32 {
134 if self.vectors.is_empty() {
135 return 0.0;
136 }
137 let sum: f32 = self.vectors.iter().map(|v| v.magnitude()).sum();
138 sum / self.vectors.len() as f32
139 }
140
141 pub fn warp_frame(&self, target: &[u8]) -> Result<Vec<u8>> {
156 let expected = (self.width as usize) * (self.height as usize) * 4;
157 if target.len() != expected {
158 return Err(GpuError::InvalidBufferSize {
159 expected,
160 actual: target.len(),
161 });
162 }
163
164 let w = self.width as usize;
165 let h = self.height as usize;
166 let mut output = vec![0u8; expected];
167
168 output
169 .par_chunks_exact_mut(4)
170 .enumerate()
171 .for_each(|(idx, pix)| {
172 let px = (idx % w) as f32;
173 let py = (idx / w) as f32;
174
175 let fv = self.vectors[idx];
176 let src_x = px - fv.dx;
177 let src_y = py - fv.dy;
178
179 let x0 = src_x.floor() as isize;
181 let y0 = src_y.floor() as isize;
182 let tx = src_x - x0 as f32;
183 let ty = src_y - y0 as f32;
184
185 let sample = |xi: isize, yi: isize| -> [f32; 4] {
186 if xi < 0 || yi < 0 || xi >= w as isize || yi >= h as isize {
187 return [0.0; 4];
188 }
189 let off = (yi as usize * w + xi as usize) * 4;
190 [
191 target[off] as f32,
192 target[off + 1] as f32,
193 target[off + 2] as f32,
194 target[off + 3] as f32,
195 ]
196 };
197
198 let c00 = sample(x0, y0);
199 let c10 = sample(x0 + 1, y0);
200 let c01 = sample(x0, y0 + 1);
201 let c11 = sample(x0 + 1, y0 + 1);
202
203 for ch in 0..4 {
204 let v = c00[ch] * (1.0 - tx) * (1.0 - ty)
205 + c10[ch] * tx * (1.0 - ty)
206 + c01[ch] * (1.0 - tx) * ty
207 + c11[ch] * tx * ty;
208 pix[ch] = v.clamp(0.0, 255.0) as u8;
209 }
210 });
211
212 Ok(output)
213 }
214}
215
216#[derive(Debug, Clone)]
225pub struct OpticalFlowEstimator {
226 config: OpticalFlowConfig,
227}
228
229impl OpticalFlowEstimator {
230 #[must_use]
232 pub fn new(config: OpticalFlowConfig) -> Self {
233 Self { config }
234 }
235
236 #[must_use]
238 pub fn default_config() -> Self {
239 Self::new(OpticalFlowConfig::default())
240 }
241
242 pub fn estimate(&self, prev: &[u8], next: &[u8], width: u32, height: u32) -> Result<FlowField> {
255 let expected = (width as usize) * (height as usize);
256 if prev.len() != expected {
257 return Err(GpuError::InvalidBufferSize {
258 expected,
259 actual: prev.len(),
260 });
261 }
262 if next.len() != expected {
263 return Err(GpuError::InvalidBufferSize {
264 expected,
265 actual: next.len(),
266 });
267 }
268
269 if width == 0 || height == 0 {
270 return Err(GpuError::InvalidDimensions { width, height });
271 }
272
273 let window = self.config.window_size.max(3) | 1; let levels = self.config.pyramid_levels.max(1).min(8);
275
276 let prev_pyr = build_gaussian_pyramid(prev, width, height, levels);
278 let next_pyr = build_gaussian_pyramid(next, width, height, levels);
279
280 let (cw, ch) = pyramid_dims(width, height, levels - 1);
282 let mut flow = FlowField::zeros(cw, ch);
283
284 for level in (0..levels).rev() {
286 let (lw, lh) = pyramid_dims(width, height, level);
287 let prev_lvl = &prev_pyr[level as usize];
288 let next_lvl = &next_pyr[level as usize];
289
290 if level + 1 < levels {
292 flow = upscale_flow(&flow, lw, lh);
293 let scale = 2.0f32;
295 for v in &mut flow.vectors {
296 v.dx *= scale;
297 v.dy *= scale;
298 }
299 } else {
300 flow = FlowField::zeros(lw, lh);
301 }
302
303 for _ in 0..self.config.iterations {
305 flow = refine_flow(
306 flow,
307 prev_lvl,
308 next_lvl,
309 lw,
310 lh,
311 window,
312 self.config.max_displacement,
313 );
314 }
315 }
316
317 if self.config.smoothing_sigma > 0.0 {
319 flow = smooth_flow(flow, width, height, self.config.smoothing_sigma);
320 }
321
322 Ok(flow)
323 }
324}
325
/// Returns the `(width, height)` of pyramid level `level` for a `w` × `h` base image.
///
/// Each level halves both dimensions with integer division (`w >> level`), never going
/// below 1. Levels of 32 or more collapse to `(1, 1)`; computing `1u32 << level` there
/// would overflow the shift.
fn pyramid_dims(w: u32, h: u32, level: u32) -> (u32, u32) {
    // `checked_shr` is `None` for shifts >= 32; every dimension is 0 by then, clamped to 1.
    let shrink = |dim: u32| dim.checked_shr(level).unwrap_or(0).max(1);
    (shrink(w), shrink(h))
}
337
338fn build_gaussian_pyramid(frame: &[u8], width: u32, height: u32, levels: u32) -> Vec<Vec<u8>> {
340 let mut pyramid = Vec::with_capacity(levels as usize);
341 pyramid.push(frame.to_vec());
342
343 for l in 1..levels {
344 let (pw, ph) = pyramid_dims(width, height, l - 1);
345 let (cw, ch) = pyramid_dims(width, height, l);
346 let prev = &pyramid[(l - 1) as usize];
347 let downsampled = downsample_2x(prev, pw, ph, cw, ch);
348 pyramid.push(downsampled);
349 }
350
351 pyramid
352}
353
/// Halves a single-channel 8-bit image by averaging 2×2 blocks.
///
/// `src` is `sw` × `sh`, row-major. Destination pixel `(x, y)` of the `dw` × `dh` output is
/// the truncated mean of the source pixels at columns `2x`, `2x + 1` and rows `2y`, `2y + 1`.
/// Indices are clamped to the last column/row, so a 1-pixel-wide or 1-pixel-tall source
/// replicates its edge.
fn downsample_2x(src: &[u8], sw: u32, sh: u32, dw: u32, dh: u32) -> Vec<u8> {
    let stride = sw as usize;
    let mut out = Vec::with_capacity((dw * dh) as usize);

    for row in 0..dh {
        for col in 0..dw {
            let top = (2 * row).min(sh - 1) as usize * stride;
            let bottom = (2 * row + 1).min(sh - 1) as usize * stride;
            let left = (2 * col).min(sw - 1) as usize;
            let right = (2 * col + 1).min(sw - 1) as usize;

            let total = u32::from(src[top + left])
                + u32::from(src[top + right])
                + u32::from(src[bottom + left])
                + u32::from(src[bottom + right]);
            out.push((total / 4) as u8);
        }
    }

    out
}
372
373fn upscale_flow(flow: &FlowField, new_w: u32, new_h: u32) -> FlowField {
375 let ow = flow.width as usize;
376 let oh = flow.height as usize;
377 let mut new_vectors = Vec::with_capacity((new_w * new_h) as usize);
378
379 for ny in 0..new_h as usize {
380 for nx in 0..new_w as usize {
381 let sx = ((nx * ow) / new_w as usize).min(ow - 1);
382 let sy = ((ny * oh) / new_h as usize).min(oh - 1);
383 let idx = sy * ow + sx;
384 new_vectors.push(flow.vectors[idx]);
385 }
386 }
387
388 FlowField {
389 width: new_w,
390 height: new_h,
391 vectors: new_vectors,
392 }
393}
394
/// Performs one Lucas–Kanade refinement step over every vector of `flow` at one pyramid level.
///
/// For each pixel, this visits a `(2 * half + 1)`² neighbourhood (`half = window / 2`) and
/// skips positions on the outermost 1-pixel border, where central differences would leave the
/// image. Over the remaining positions it accumulates:
/// * the 2×2 structure tensor `[a11 a12; a12 a22]` of `prev`'s central-difference gradients,
/// * `b = -Σ ∇prev · (next(x + v) - prev(x))`, with `next` bilinearly sampled at the
///   current vector `v`.
///
/// It then solves the 2×2 system for an increment, caps the increment's length at
/// `max_disp`, and adds it to the existing vector. Pixels whose determinant magnitude is below
/// `1e-6` (e.g. uniform regions, or windows that lie entirely on the skipped border) keep
/// their previous vector unchanged.
///
/// `prev` and `next` are read as `w` × `h` row-major single-channel images, and
/// `flow.vectors.len()` is assumed to equal `w * h` (assumed from the callers in this file —
/// not checked here).
fn refine_flow(
    flow: FlowField,
    prev: &[u8],
    next: &[u8],
    w: u32,
    h: u32,
    window: u32,
    max_disp: f32,
) -> FlowField {
    let w_usize = w as usize;
    let h_usize = h as usize;
    let half = (window / 2) as isize;

    // Convert once per call so the inner loop works in f32 without per-sample casts.
    let prev_f: Vec<f32> = prev.iter().map(|&v| v as f32).collect();
    let next_f: Vec<f32> = next.iter().map(|&v| v as f32).collect();

    // Each output vector depends only on the inputs, so pixels are processed in parallel.
    let new_vectors: Vec<FlowVector> = (0..flow.vectors.len())
        .into_par_iter()
        .map(|idx| {
            let px = (idx % w_usize) as isize;
            let py = (idx / w_usize) as isize;

            let old_v = flow.vectors[idx];

            // Structure tensor entries (a11, a12, a22) and right-hand side (b1, b2).
            let mut a11 = 0.0f32;
            let mut a12 = 0.0f32;
            let mut a22 = 0.0f32;
            let mut b1 = 0.0f32;
            let mut b2 = 0.0f32;

            for wy in -half..=half {
                for wx in -half..=half {
                    let x = px + wx;
                    let y = py + wy;

                    // Central differences need x±1 and y±1 in bounds, so the 1-pixel border
                    // is excluded.
                    if x < 1 || y < 1 || x >= w_usize as isize - 1 || y >= h_usize as isize - 1 {
                        continue;
                    }

                    // Spatial gradient of `prev` by central difference.
                    let ix = (prev_f[y as usize * w_usize + (x + 1) as usize]
                        - prev_f[y as usize * w_usize + (x - 1) as usize])
                        * 0.5;
                    let iy = (prev_f[(y + 1) as usize * w_usize + x as usize]
                        - prev_f[(y - 1) as usize * w_usize + x as usize])
                        * 0.5;

                    // Temporal difference: `next` sampled at the current displaced position
                    // minus `prev` at the original position.
                    let nx_f = x as f32 + old_v.dx;
                    let ny_f = y as f32 + old_v.dy;
                    let next_val = sample_bilinear(&next_f, nx_f, ny_f, w_usize, h_usize);
                    let prev_val = prev_f[y as usize * w_usize + x as usize];

                    let it = next_val - prev_val;

                    a11 += ix * ix;
                    a12 += ix * iy;
                    a22 += iy * iy;
                    b1 -= ix * it;
                    b2 -= iy * it;
                }
            }

            // NOTE(review): this is an absolute threshold on an f32 determinant. With 8-bit
            // intensities the tensor entries can be large, so nearly rank-1 (edge-only)
            // windows may pass it after cancellation and yield large increments bounded only
            // by `max_disp`. A scale-relative test may be more robust — confirm before changing.
            let det = a11 * a22 - a12 * a12;
            if det.abs() < 1e-6 {
                return old_v;
            }

            // Closed-form inverse of the symmetric 2×2 system applied to (b1, b2).
            let ddx = (a22 * b1 - a12 * b2) / det;
            let ddy = (a11 * b2 - a12 * b1) / det;

            // Cap the increment's length at `max_disp`, keeping its direction.
            // NOTE(review): a negative `max_disp` makes this branch always taken and flips
            // the increment (NaN for a zero increment); callers should pass a non-negative cap.
            let mag = (ddx * ddx + ddy * ddy).sqrt();
            let (ddx, ddy) = if mag > max_disp {
                (ddx * max_disp / mag, ddy * max_disp / mag)
            } else {
                (ddx, ddy)
            };

            FlowVector {
                dx: old_v.dx + ddx,
                dy: old_v.dy + ddy,
            }
        })
        .collect();

    FlowField {
        width: w,
        height: h,
        vectors: new_vectors,
    }
}
493
/// Bilinearly interpolates the `w` × `h` row-major image `frame` at `(x, y)`.
///
/// Neighbour coordinates outside the image are clamped to the nearest edge pixel (edge
/// replication), so any finite `(x, y)` yields a value.
///
/// # Panics
///
/// Panics if `w` or `h` is zero (the clamp range would be empty), or if `frame` holds fewer
/// than `w * h` samples.
fn sample_bilinear(frame: &[f32], x: f32, y: f32, w: usize, h: usize) -> f32 {
    let left = x.floor() as isize;
    let top = y.floor() as isize;
    let fx = x - left as f32;
    let fy = y - top as f32;
    let gx = 1.0 - fx;
    let gy = 1.0 - fy;

    let max_col = w as isize - 1;
    let max_row = h as isize - 1;
    let pixel = |col: isize, row: isize| -> f32 {
        let c = col.clamp(0, max_col) as usize;
        let r = row.clamp(0, max_row) as usize;
        frame[r * w + c]
    };

    pixel(left, top) * gx * gy
        + pixel(left + 1, top) * fx * gy
        + pixel(left, top + 1) * gx * fy
        + pixel(left + 1, top + 1) * fx * fy
}
512
513fn smooth_flow(flow: FlowField, w: u32, h: u32, sigma: f32) -> FlowField {
515 let kernel = gaussian_kernel_1d(sigma);
516 let half = (kernel.len() / 2) as isize;
517 let w_usize = w as usize;
518 let h_usize = h as usize;
519
520 let mut temp = vec![FlowVector::zero(); (w * h) as usize];
522 for y in 0..h_usize {
523 for x in 0..w_usize {
524 let mut dx_sum = 0.0f32;
525 let mut dy_sum = 0.0f32;
526 for (ki, &kv) in kernel.iter().enumerate() {
527 let sx = (x as isize + ki as isize - half).clamp(0, w_usize as isize - 1) as usize;
528 let v = flow.vectors[y * w_usize + sx];
529 dx_sum += kv * v.dx;
530 dy_sum += kv * v.dy;
531 }
532 temp[y * w_usize + x] = FlowVector {
533 dx: dx_sum,
534 dy: dy_sum,
535 };
536 }
537 }
538
539 let mut out = vec![FlowVector::zero(); (w * h) as usize];
541 for y in 0..h_usize {
542 for x in 0..w_usize {
543 let mut dx_sum = 0.0f32;
544 let mut dy_sum = 0.0f32;
545 for (ki, &kv) in kernel.iter().enumerate() {
546 let sy = (y as isize + ki as isize - half).clamp(0, h_usize as isize - 1) as usize;
547 let v = temp[sy * w_usize + x];
548 dx_sum += kv * v.dx;
549 dy_sum += kv * v.dy;
550 }
551 out[y * w_usize + x] = FlowVector {
552 dx: dx_sum,
553 dy: dy_sum,
554 };
555 }
556 }
557
558 FlowField {
559 width: w,
560 height: h,
561 vectors: out,
562 }
563}
564
/// Builds a normalized, symmetric 1-D Gaussian kernel.
///
/// The kernel has radius `ceil(3 * sigma)`, i.e. `2 * radius + 1` taps. Tap `i` is
/// proportional to `exp(-(i - radius)² / (2 * sigma²))`, and all taps are divided by their
/// sum so they add up to 1 (up to `f32` rounding).
///
/// Degenerate `sigma` values return the identity kernel `[1.0]`. These are zero, negative,
/// NaN, infinite, or so small that `2 * sigma²` underflows to zero in `f32`. Evaluating the
/// formula for them would compute `0.0 / 0.0` at the centre tap and make the whole kernel
/// NaN, or overflow the tap count.
fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    let two_sigma_sq = 2.0 * sigma * sigma;
    // Written as a positive condition so NaN, which fails every comparison, is rejected too.
    if !(sigma > 0.0 && sigma.is_finite() && two_sigma_sq > 0.0) {
        return vec![1.0];
    }

    let radius = (3.0 * sigma).ceil() as usize;
    let mut kernel: Vec<f32> = (0..=2 * radius)
        .map(|i| {
            let x = i as f32 - radius as f32;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();

    // The centre tap is exp(0) = 1, so the sum is at least 1 and never zero.
    let sum: f32 = kernel.iter().sum();
    for v in &mut kernel {
        *v /= sum;
    }

    kernel
}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-channel `w` × `h` frame filled with `fill`.
    fn gray_frame(w: usize, h: usize, fill: u8) -> Vec<u8> {
        vec![fill; w * h]
    }

    #[test]
    fn test_flow_field_zeros() {
        let f = FlowField::zeros(4, 4);
        assert_eq!(f.width, 4);
        assert_eq!(f.height, 4);
        assert_eq!(f.vectors.len(), 16);
        assert!((f.mean_magnitude() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_flow_field_get_oob_returns_none() {
        let f = FlowField::zeros(4, 4);
        assert!(f.get(4, 0).is_none());
        assert!(f.get(0, 4).is_none());
    }

    #[test]
    fn test_flow_field_get_in_bounds_is_row_major() {
        let mut f = FlowField::zeros(3, 2);
        // Pixel (2, 1) lives at index 1 * 3 + 2 = 5.
        f.vectors[5] = FlowVector { dx: 1.0, dy: -1.0 };
        assert_eq!(f.get(2, 1), Some(FlowVector { dx: 1.0, dy: -1.0 }));
        assert_eq!(f.get(0, 0), Some(FlowVector::zero()));
    }

    #[test]
    fn test_mean_magnitude_averages_lengths() {
        let f = FlowField {
            width: 2,
            height: 1,
            vectors: vec![FlowVector { dx: 3.0, dy: 4.0 }, FlowVector::zero()],
        };
        assert!((f.mean_magnitude() - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_static_scene_zero_flow() {
        let w = 32;
        let h = 32;
        let frame = gray_frame(w, h, 128);
        let estimator = OpticalFlowEstimator::new(OpticalFlowConfig {
            pyramid_levels: 2,
            iterations: 2,
            window_size: 5,
            smoothing_sigma: 0.0,
            max_displacement: 2.0,
        });
        let flow = estimator
            .estimate(&frame, &frame, w as u32, h as u32)
            .expect("should succeed");
        assert!(
            flow.mean_magnitude() < 0.5,
            "mean magnitude = {}",
            flow.mean_magnitude()
        );
    }

    #[test]
    fn test_invalid_dimensions_rejected() {
        let estimator = OpticalFlowEstimator::default_config();
        let res = estimator.estimate(&[], &[], 0, 0);
        assert!(res.is_err());
    }

    #[test]
    fn test_zero_width_with_data_rejected() {
        let estimator = OpticalFlowEstimator::default_config();
        assert!(estimator.estimate(&[1, 2], &[1, 2], 0, 2).is_err());
    }

    #[test]
    fn test_buffer_size_mismatch_rejected() {
        let estimator = OpticalFlowEstimator::default_config();
        let res = estimator.estimate(&[0u8; 4], &[0u8; 8], 2, 2);
        assert!(res.is_err());
    }

    #[test]
    fn test_warp_frame_identity() {
        let w = 4u32;
        let h = 4u32;
        let flow = FlowField::zeros(w, h);
        let frame: Vec<u8> = (0..(w * h * 4) as usize).map(|i| (i % 256) as u8).collect();
        let warped = flow.warp_frame(&frame).expect("warp should succeed");
        assert_eq!(warped.len(), frame.len());
        // With zero flow every weight except the top-left tap is zero, so the result is exact.
        assert_eq!(warped, frame, "zero flow must reproduce the input exactly");
    }

    #[test]
    fn test_warp_frame_integer_shift() {
        let w = 4usize;
        let h = 2usize;
        let frame: Vec<u8> = (0..w * h * 4).map(|i| (i * 3 % 256) as u8).collect();
        let flow = FlowField {
            width: w as u32,
            height: h as u32,
            vectors: vec![FlowVector { dx: 1.0, dy: 0.0 }; w * h],
        };
        let warped = flow.warp_frame(&frame).expect("warp should succeed");
        for y in 0..h {
            for x in 0..w {
                let out = &warped[(y * w + x) * 4..(y * w + x) * 4 + 4];
                if x == 0 {
                    // Pulled from column -1, which lies outside the frame.
                    assert_eq!(out, &[0u8; 4]);
                } else {
                    let src = (y * w + x - 1) * 4;
                    assert_eq!(out, &frame[src..src + 4]);
                }
            }
        }
    }

    #[test]
    fn test_warp_frame_wrong_buffer_size_rejected() {
        let flow = FlowField::zeros(2, 2);
        assert!(flow.warp_frame(&[0u8; 15]).is_err());
    }

    #[test]
    fn test_pyramid_dims_decreasing() {
        assert_eq!(pyramid_dims(64, 48, 0), (64, 48));
        assert_eq!(pyramid_dims(64, 48, 1), (32, 24));
        assert_eq!(pyramid_dims(64, 48, 2), (16, 12));
    }

    #[test]
    fn test_pyramid_of_constant_frame_stays_constant() {
        let frame = gray_frame(8, 6, 7);
        let pyr = build_gaussian_pyramid(&frame, 8, 6, 3);
        assert_eq!(pyr.len(), 3);
        assert_eq!(pyr[1].len(), 4 * 3);
        assert_eq!(pyr[2].len(), 2);
        assert!(pyr.iter().all(|lvl| lvl.iter().all(|&p| p == 7)));
    }

    #[test]
    fn test_downsample_2x_averages_blocks() {
        let src: Vec<u8> = (0..16u32).map(|i| (i * 4) as u8).collect();
        assert_eq!(downsample_2x(&src, 4, 4, 2, 2), vec![10, 18, 42, 50]);
    }

    #[test]
    fn test_upscale_flow_nearest_neighbour() {
        let a = FlowVector { dx: 1.0, dy: 0.0 };
        let b = FlowVector { dx: 0.0, dy: 2.0 };
        let flow = FlowField {
            width: 2,
            height: 1,
            vectors: vec![a, b],
        };
        let up = upscale_flow(&flow, 4, 2);
        assert_eq!((up.width, up.height), (4, 2));
        assert_eq!(up.vectors, vec![a, a, b, b, a, a, b, b]);
    }

    #[test]
    fn test_sample_bilinear_midpoint() {
        let frame = [0.0f32, 10.0, 20.0, 30.0];
        assert!((sample_bilinear(&frame, 0.5, 0.5, 2, 2) - 15.0).abs() < 1e-5);
    }

    #[test]
    fn test_gaussian_kernel_normalized() {
        let k = gaussian_kernel_1d(1.5);
        assert_eq!(k.len(), 11);
        let sum: f32 = k.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
    }

    #[test]
    fn test_smooth_flow_preserves_constant_field() {
        let v = FlowVector { dx: 1.0, dy: -2.0 };
        let flow = FlowField {
            width: 5,
            height: 4,
            vectors: vec![v; 20],
        };
        let out = smooth_flow(flow, 5, 4, 1.0);
        assert_eq!(out.vectors.len(), 20);
        for o in &out.vectors {
            assert!((o.dx - 1.0).abs() < 1e-5 && (o.dy + 2.0).abs() < 1e-5, "{o:?}");
        }
    }

    #[test]
    fn test_flow_vector_magnitude() {
        let v = FlowVector { dx: 3.0, dy: 4.0 };
        assert!((v.magnitude() - 5.0).abs() < 1e-5);
    }
}