1#![allow(clippy::cast_precision_loss)]
21
22use crate::{AlignError, AlignResult, Point2D};
23
/// Tuning parameters for [`KltTracker`].
#[derive(Debug, Clone, PartialEq)]
pub struct KltConfig {
    /// Half-width of the square tracking window; the window spans
    /// `2 * window_half_size + 1` pixels per side.
    pub window_half_size: usize,
    /// Requested number of pyramid levels, counting the full-resolution image.
    /// Fewer levels are built when a level would drop below 4 pixels per side, and
    /// the full-resolution level is always used (so `0` behaves like `1`).
    pub pyramid_levels: usize,
    /// Maximum number of refinement iterations per pyramid level.
    pub max_iterations: usize,
    /// Convergence threshold: refinement at a level stops once an update step is
    /// shorter than this many pixels (measured at that level's resolution).
    pub epsilon: f64,
    /// Minimum accepted value of the smaller eigenvalue of the window's gradient
    /// structure tensor. The tensor is a raw sum over the window (not normalized by
    /// window area), so sensible values grow with `window_half_size`.
    pub min_eigenvalue: f64,
}
39
40impl Default for KltConfig {
41 fn default() -> Self {
42 Self {
43 window_half_size: 7,
44 pyramid_levels: 3,
45 max_iterations: 20,
46 epsilon: 0.03,
47 min_eigenvalue: 1e-4,
48 }
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct TrackResult {
55 pub origin: Point2D,
57 pub tracked: Option<Point2D>,
59 pub error: f64,
61 pub success: bool,
63}
64
65pub struct KltTracker {
67 pub config: KltConfig,
69}
70
71impl Default for KltTracker {
72 fn default() -> Self {
73 Self {
74 config: KltConfig::default(),
75 }
76 }
77}
78
79impl KltTracker {
80 #[must_use]
82 pub fn new(config: KltConfig) -> Self {
83 Self { config }
84 }
85
86 pub fn track_features(
102 &self,
103 prev_frame: &[u8],
104 curr_frame: &[u8],
105 width: u32,
106 height: u32,
107 points: &[(f32, f32)],
108 ) -> AlignResult<Vec<Option<(f32, f32)>>> {
109 let w = width as usize;
110 let h = height as usize;
111
112 let pts: Vec<Point2D> = points
114 .iter()
115 .map(|&(x, y)| Point2D::new(f64::from(x), f64::from(y)))
116 .collect();
117
118 let results = self.track(prev_frame, curr_frame, w, h, &pts)?;
119
120 Ok(results
121 .into_iter()
122 .map(|r| r.tracked.map(|p| (p.x as f32, p.y as f32)))
123 .collect())
124 }
125
126 pub fn track(
135 &self,
136 prev_image: &[u8],
137 curr_image: &[u8],
138 width: usize,
139 height: usize,
140 points: &[Point2D],
141 ) -> AlignResult<Vec<TrackResult>> {
142 if prev_image.len() != width * height || curr_image.len() != width * height {
143 return Err(AlignError::InvalidConfig(
144 "Image size does not match width*height".to_string(),
145 ));
146 }
147 if width < 8 || height < 8 {
148 return Err(AlignError::InvalidConfig(
149 "Image must be at least 8x8".to_string(),
150 ));
151 }
152
153 let prev_pyr = build_pyramid(prev_image, width, height, self.config.pyramid_levels);
155 let curr_pyr = build_pyramid(curr_image, width, height, self.config.pyramid_levels);
156
157 let results: Vec<TrackResult> = points
158 .iter()
159 .map(|pt| self.track_point(pt, &prev_pyr, &curr_pyr))
160 .collect();
161
162 Ok(results)
163 }
164
165 fn track_point(
167 &self,
168 point: &Point2D,
169 prev_pyr: &[PyramidLevel],
170 curr_pyr: &[PyramidLevel],
171 ) -> TrackResult {
172 let num_levels = prev_pyr.len();
173 let win = self.config.window_half_size as f64;
174
175 let mut gx = 0.0_f64;
177 let mut gy = 0.0_f64;
178
179 let mut last_error = f64::MAX;
180 let mut success = true;
181
182 for level in (0..num_levels).rev() {
184 let scale = 1.0 / (1 << level) as f64;
185 let px = point.x * scale;
186 let py = point.y * scale;
187
188 let prev_level = &prev_pyr[level];
189 let curr_level = &curr_pyr[level];
190
191 let w = prev_level.width;
192 let h = prev_level.height;
193
194 let (grad_x, grad_y) = compute_gradients(&prev_level.data, w, h);
196
197 let wi = self.config.window_half_size as isize;
199
200 let mut vx = gx;
202 let mut vy = gy;
203
204 for _iter in 0..self.config.max_iterations {
205 let mut g_xx = 0.0_f64;
206 let mut g_yy = 0.0_f64;
207 let mut g_xy = 0.0_f64;
208 let mut b_x = 0.0_f64;
209 let mut b_y = 0.0_f64;
210
211 for dy in -wi..=wi {
212 for dx in -wi..=wi {
213 let sx = px + dx as f64;
214 let sy = py + dy as f64;
215 let tx = px + dx as f64 + vx;
216 let ty = py + dy as f64 + vy;
217
218 if sx < 0.0
219 || sy < 0.0
220 || sx >= (w - 1) as f64
221 || sy >= (h - 1) as f64
222 || tx < 0.0
223 || ty < 0.0
224 || tx >= (w - 1) as f64
225 || ty >= (h - 1) as f64
226 {
227 continue;
228 }
229
230 let ix = bilinear_sample_f64(&grad_x, w, sx, sy);
231 let iy = bilinear_sample_f64(&grad_y, w, sx, sy);
232 let prev_val = bilinear_sample(&prev_level.data, w, sx, sy);
233 let curr_val = bilinear_sample(&curr_level.data, w, tx, ty);
234
235 let dt = prev_val - curr_val;
236
237 g_xx += ix * ix;
238 g_yy += iy * iy;
239 g_xy += ix * iy;
240 b_x += dt * ix;
241 b_y += dt * iy;
242 }
243 }
244
245 let trace = g_xx + g_yy;
247 let det = g_xx * g_yy - g_xy * g_xy;
248 let discriminant = trace * trace - 4.0 * det;
249 let min_eig = if discriminant >= 0.0 {
250 (trace - discriminant.sqrt()) / 2.0
251 } else {
252 0.0
253 };
254
255 if min_eig < self.config.min_eigenvalue {
256 success = false;
257 break;
258 }
259
260 if det.abs() < 1e-12 {
262 success = false;
263 break;
264 }
265
266 let dvx = (g_yy * b_x - g_xy * b_y) / det;
267 let dvy = (-g_xy * b_x + g_xx * b_y) / det;
268
269 vx += dvx;
270 vy += dvy;
271
272 if dvx * dvx + dvy * dvy < self.config.epsilon * self.config.epsilon {
273 break;
274 }
275 }
276
277 if level > 0 {
279 gx = vx * 2.0;
280 gy = vy * 2.0;
281 } else {
282 gx = vx;
283 gy = vy;
284 }
285
286 if level == 0 {
288 last_error = self.compute_tracking_error(
289 &prev_pyr[0].data,
290 &curr_pyr[0].data,
291 prev_pyr[0].width,
292 prev_pyr[0].height,
293 point.x,
294 point.y,
295 gx,
296 gy,
297 win as isize,
298 );
299 }
300 }
301
302 let tracked_x = point.x + gx;
304 let tracked_y = point.y + gy;
305 let orig_w = prev_pyr[0].width as f64;
306 let orig_h = prev_pyr[0].height as f64;
307
308 if !success
309 || tracked_x < 0.0
310 || tracked_y < 0.0
311 || tracked_x >= orig_w
312 || tracked_y >= orig_h
313 {
314 return TrackResult {
315 origin: *point,
316 tracked: None,
317 error: last_error,
318 success: false,
319 };
320 }
321
322 TrackResult {
323 origin: *point,
324 tracked: Some(Point2D::new(tracked_x, tracked_y)),
325 error: last_error,
326 success: true,
327 }
328 }
329
330 #[allow(clippy::too_many_arguments, clippy::manual_checked_ops)]
332 fn compute_tracking_error(
333 &self,
334 prev: &[u8],
335 curr: &[u8],
336 w: usize,
337 h: usize,
338 px: f64,
339 py: f64,
340 vx: f64,
341 vy: f64,
342 half_win: isize,
343 ) -> f64 {
344 let mut ssd = 0.0_f64;
345 let mut count = 0u32;
346
347 for dy in -half_win..=half_win {
348 for dx in -half_win..=half_win {
349 let sx = px + dx as f64;
350 let sy = py + dy as f64;
351 let tx = sx + vx;
352 let ty = sy + vy;
353
354 if sx >= 0.0
355 && sy >= 0.0
356 && sx < (w - 1) as f64
357 && sy < (h - 1) as f64
358 && tx >= 0.0
359 && ty >= 0.0
360 && tx < (w - 1) as f64
361 && ty < (h - 1) as f64
362 {
363 let a = bilinear_sample(prev, w, sx, sy);
364 let b = bilinear_sample(curr, w, tx, ty);
365 let d = a - b;
366 ssd += d * d;
367 count += 1;
368 }
369 }
370 }
371
372 if count == 0 {
373 return f64::MAX;
374 }
375 ssd / f64::from(count)
376 }
377}
378
/// One level of an image pyramid: a row-major, one-byte-per-pixel image together with
/// its dimensions.
#[derive(Clone, Debug)]
struct PyramidLevel {
    /// Pixel data, `width * height` bytes, row-major.
    data: Vec<u8>,
    /// Pixels per row.
    width: usize,
    /// Number of rows.
    height: usize,
}
388
389fn build_pyramid(image: &[u8], width: usize, height: usize, levels: usize) -> Vec<PyramidLevel> {
391 let mut pyramid = Vec::with_capacity(levels);
392 pyramid.push(PyramidLevel {
393 data: image.to_vec(),
394 width,
395 height,
396 });
397
398 let mut cur = image.to_vec();
399 let mut cw = width;
400 let mut ch = height;
401
402 for _ in 1..levels {
403 let nw = cw / 2;
404 let nh = ch / 2;
405 if nw < 4 || nh < 4 {
406 break;
407 }
408 let down = downsample_2x(&cur, cw, ch, nw, nh);
409 pyramid.push(PyramidLevel {
410 data: down.clone(),
411 width: nw,
412 height: nh,
413 });
414 cur = down;
415 cw = nw;
416 ch = nh;
417 }
418
419 pyramid
420}
421
/// Downsamples a row-major `sw`x`sh` image to `dw`x`dh` by averaging 2x2 blocks.
///
/// Destination pixel `(dx, dy)` averages the source pixels `(2dx..2dx+2, 2dy..2dy+2)`
/// that lie inside the source image; a pixel with no such source pixels is 0. The
/// average is rounded to nearest (half up) rather than truncated, so repeated
/// downsampling does not drift intensities downward.
fn downsample_2x(src: &[u8], sw: usize, sh: usize, dw: usize, dh: usize) -> Vec<u8> {
    let mut dst = vec![0u8; dw * dh];
    for dy in 0..dh {
        for dx in 0..dw {
            let (sx, sy) = (dx * 2, dy * 2);
            let mut sum = 0u16;
            let mut count = 0u16;
            for ry in sy..(sy + 2).min(sh) {
                for rx in sx..(sx + 2).min(sw) {
                    sum += u16::from(src[ry * sw + rx]);
                    count += 1;
                }
            }
            // (sum + count/2) / count never exceeds 255 because sum <= 255 * count.
            dst[dy * dw + dx] = if count == 0 {
                0
            } else {
                ((sum + count / 2) / count) as u8
            };
        }
    }
    dst
}
446
/// Bilinearly interpolates a row-major 8-bit image at `(x, y)`.
///
/// Reads the pixel at `(floor(x), floor(y))` and its right, bottom and bottom-right
/// neighbours, so the caller must keep `x < width - 1` and `y < height - 1`;
/// otherwise the read may panic or wrap into the next row.
fn bilinear_sample(image: &[u8], width: usize, x: f64, y: f64) -> f64 {
    let (col, row) = (x.floor() as usize, y.floor() as usize);
    let (fx, fy) = (x - col as f64, y - row as f64);
    let top = row * width + col;
    let bottom = top + width;
    let px = |idx: usize| f64::from(image[idx]);

    px(top) * (1.0 - fx) * (1.0 - fy)
        + px(top + 1) * fx * (1.0 - fy)
        + px(bottom) * (1.0 - fx) * fy
        + px(bottom + 1) * fx * fy
}
465
/// Bilinearly interpolates a row-major `f64` buffer (such as a gradient map) at
/// `(x, y)`.
///
/// Like [`bilinear_sample`], it reads the 2x2 neighbourhood whose top-left corner is
/// `(floor(x), floor(y))`, so the caller must keep the coordinates strictly inside the
/// last column/row.
fn bilinear_sample_f64(buf: &[f64], width: usize, x: f64, y: f64) -> f64 {
    let col = x.floor() as usize;
    let row = y.floor() as usize;
    let frac_x = x - col as f64;
    let frac_y = y - row as f64;
    let (inv_x, inv_y) = (1.0 - frac_x, 1.0 - frac_y);
    let at = |r: usize, c: usize| buf[r * width + c];

    let top_left = at(row, col) * inv_x * inv_y;
    let top_right = at(row, col + 1) * frac_x * inv_y;
    let bottom_left = at(row + 1, col) * inv_x * frac_y;
    let bottom_right = at(row + 1, col + 1) * frac_x * frac_y;
    top_left + top_right + bottom_left + bottom_right
}
482
/// Computes horizontal and vertical Sobel gradients (normalized by 1/8) of a row-major
/// 8-bit image.
///
/// Returns `(gx, gy)`, each `width * height` long. Pixels on the outer border keep a
/// gradient of zero because their 3x3 neighbourhood is incomplete.
fn compute_gradients(image: &[u8], width: usize, height: usize) -> (Vec<f64>, Vec<f64>) {
    let len = width * height;
    let mut grad_x = vec![0.0_f64; len];
    let mut grad_y = vec![0.0_f64; len];
    let px = |row: usize, col: usize| f64::from(image[row * width + col]);

    for row in 1..height.saturating_sub(1) {
        let (up, down) = (row - 1, row + 1);
        for col in 1..width.saturating_sub(1) {
            let (left, right) = (col - 1, col + 1);

            let tl = px(up, left);
            let t = px(up, col);
            let tr = px(up, right);
            let l = px(row, left);
            let r = px(row, right);
            let bl = px(down, left);
            let b = px(down, col);
            let br = px(down, right);

            let at = row * width + col;
            grad_x[at] = (-tl + tr - 2.0 * l + 2.0 * r - bl + br) / 8.0;
            grad_y[at] = (-tl - 2.0 * t - tr + bl + 2.0 * b + br) / 8.0;
        }
    }

    (grad_x, grad_y)
}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w`x`h` image with background 30 and an axis-aligned box of `value`
    /// covering `[c - half, c + half]` on each axis, clipped to the image.
    fn make_box_image(
        w: usize,
        h: usize,
        cx: usize,
        cy: usize,
        half: usize,
        value: u8,
    ) -> Vec<u8> {
        let mut img = vec![30u8; w * h];
        for y in cy.saturating_sub(half)..=(cy + half).min(h - 1) {
            for x in cx.saturating_sub(half)..=(cx + half).min(w - 1) {
                img[y * w + x] = value;
            }
        }
        img
    }

    /// Box with intensity 200, used by the `track` tests.
    fn make_square_image(w: usize, h: usize, cx: usize, cy: usize, half: usize) -> Vec<u8> {
        make_box_image(w, h, cx, cy, half, 200)
    }

    /// Box with intensity 210, used by the `track_features` tests.
    fn make_patch_image(w: usize, h: usize, cx: usize, cy: usize, half: usize) -> Vec<u8> {
        make_box_image(w, h, cx, cy, half, 210)
    }

    #[test]
    fn test_build_pyramid_levels() {
        let img = vec![128u8; 64 * 64];
        let pyr = build_pyramid(&img, 64, 64, 3);
        assert_eq!(pyr.len(), 3);
        assert_eq!(pyr[0].width, 64);
        assert_eq!(pyr[1].width, 32);
        assert_eq!(pyr[2].width, 16);
    }

    #[test]
    fn test_build_pyramid_single_level() {
        let img = vec![128u8; 32 * 32];
        let pyr = build_pyramid(&img, 32, 32, 1);
        assert_eq!(pyr.len(), 1);
    }

    #[test]
    fn test_build_pyramid_stops_below_min_size() {
        let img = vec![0u8; 16 * 16];
        let pyr = build_pyramid(&img, 16, 16, 5);
        let widths: Vec<usize> = pyr.iter().map(|l| l.width).collect();
        assert_eq!(widths, vec![16, 8, 4]);
    }

    #[test]
    fn test_build_pyramid_zero_levels_keeps_base() {
        let img = vec![0u8; 16 * 16];
        assert_eq!(build_pyramid(&img, 16, 16, 0).len(), 1);
    }

    #[test]
    fn test_downsample_preserves_constant() {
        let img = vec![100u8; 64 * 64];
        let down = downsample_2x(&img, 64, 64, 32, 32);
        for &v in &down {
            assert_eq!(v, 100);
        }
    }

    #[test]
    fn test_bilinear_integer_coords() {
        let img: Vec<u8> = vec![10, 20, 30, 40];
        let val = bilinear_sample(&img, 2, 0.0, 0.0);
        assert!((val - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_bilinear_midpoint() {
        let img: Vec<u8> = vec![0, 100, 0, 100];
        let val = bilinear_sample(&img, 2, 0.5, 0.0);
        assert!((val - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_bilinear_center_of_quad() {
        let img: Vec<u8> = vec![10, 20, 30, 40];
        let val = bilinear_sample(&img, 2, 0.5, 0.5);
        assert!((val - 25.0).abs() < 1e-9);
    }

    #[test]
    fn test_klt_stationary_point() {
        let w = 64usize;
        let h = 64usize;
        let img = make_square_image(w, h, 32, 32, 5);

        let config = KltConfig {
            window_half_size: 5,
            pyramid_levels: 2,
            max_iterations: 20,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts = vec![Point2D::new(32.0, 32.0)];

        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("track should succeed");
        assert_eq!(results.len(), 1);
        assert!(
            results[0].success,
            "tracking a stationary point should succeed"
        );
        let tracked = results[0].tracked.expect("should have a tracked point");
        assert!(
            (tracked.x - 32.0).abs() < 1.0 && (tracked.y - 32.0).abs() < 1.0,
            "stationary point should not move: got ({:.2}, {:.2})",
            tracked.x,
            tracked.y,
        );
    }

    #[test]
    fn test_klt_translated_square() {
        let w = 128usize;
        let h = 128usize;
        let shift = 4usize;

        let prev = make_square_image(w, h, 60, 60, 10);
        let curr = make_square_image(w, h, 60 + shift, 60, 10);

        let config = KltConfig {
            window_half_size: 10,
            pyramid_levels: 3,
            max_iterations: 30,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts = vec![Point2D::new(60.0, 60.0)];

        let results = tracker
            .track(&prev, &curr, w, h, &pts)
            .expect("track should succeed");
        assert!(results[0].success, "should successfully track the square");
        // A successful result always carries a position, so don't let an `if let`
        // silently skip the shift check.
        let tracked = results[0]
            .tracked
            .expect("successful track should carry a position");
        let dx = tracked.x - 60.0;
        assert!(
            (dx - shift as f64).abs() < 2.0,
            "expected ~{shift} px shift, got dx={dx:.2}"
        );
    }

    #[test]
    fn test_klt_image_size_mismatch() {
        let tracker = KltTracker::default();
        let pts = vec![Point2D::new(5.0, 5.0)];
        let result = tracker.track(&[0u8; 100], &[0u8; 200], 10, 10, &pts);
        assert!(result.is_err());
    }

    #[test]
    fn test_klt_too_small_image() {
        let tracker = KltTracker::default();
        let pts = vec![Point2D::new(1.0, 1.0)];
        let result = tracker.track(&[0u8; 4], &[0u8; 4], 2, 2, &pts);
        assert!(result.is_err());
    }

    #[test]
    fn test_klt_empty_points() {
        let img = vec![0u8; 16 * 16];
        let results = KltTracker::default()
            .track(&img, &img, 16, 16, &[])
            .expect("empty input should succeed");
        assert!(results.is_empty());
    }

    #[test]
    fn test_klt_multiple_points() {
        let w = 128usize;
        let h = 128usize;
        let img = make_square_image(w, h, 64, 64, 15);

        let tracker = KltTracker::default();
        let pts = vec![
            Point2D::new(50.0, 50.0),
            Point2D::new(70.0, 64.0),
            Point2D::new(64.0, 70.0),
        ];

        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("should succeed");
        assert_eq!(results.len(), 3);
        // Results come back in input order with the origin preserved.
        for (result, pt) in results.iter().zip(&pts) {
            assert_eq!(result.origin.x, pt.x);
            assert_eq!(result.origin.y, pt.y);
        }
    }

    #[test]
    fn test_klt_point_out_of_bounds_does_not_crash() {
        let w = 64usize;
        let h = 64usize;
        let img = vec![128u8; w * h];

        let tracker = KltTracker::default();
        let pts = vec![Point2D::new(0.0, 0.0)];
        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("should not crash");
        assert_eq!(results.len(), 1);
        // A flat image has no texture to track.
        assert!(results[0].tracked.is_none());
    }

    #[test]
    fn test_klt_default_config() {
        let config = KltConfig::default();
        assert_eq!(config.window_half_size, 7);
        assert_eq!(config.pyramid_levels, 3);
        assert_eq!(config.max_iterations, 20);
    }

    #[test]
    fn test_track_result_fields() {
        let tr = TrackResult {
            origin: Point2D::new(10.0, 20.0),
            tracked: Some(Point2D::new(12.0, 22.0)),
            error: 0.5,
            success: true,
        };
        assert!(tr.success);
        assert!((tr.error - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gradients_constant_image() {
        let img = vec![100u8; 32 * 32];
        let (gx, gy) = compute_gradients(&img, 32, 32);
        for y in 2..30 {
            for x in 2..30 {
                assert!(gx[y * 32 + x].abs() < 1e-10);
                assert!(gy[y * 32 + x].abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_gradients_horizontal_ramp() {
        let w = 32usize;
        let h = 32usize;
        let mut img = vec![0u8; w * h];
        for y in 0..h {
            for x in 0..w {
                img[y * w + x] = (x * 8).min(255) as u8;
            }
        }
        let (gx, _gy) = compute_gradients(&img, w, h);
        let mid = 16 * w + 16;
        assert!(gx[mid] > 0.0, "horizontal ramp should produce positive gx");
    }

    #[test]
    fn test_gradients_vertical_ramp() {
        let w = 32usize;
        let h = 32usize;
        let mut img = vec![0u8; w * h];
        for y in 0..h {
            for x in 0..w {
                img[y * w + x] = (y * 8).min(255) as u8;
            }
        }
        let (gx, gy) = compute_gradients(&img, w, h);
        let mid = 16 * w + 16;
        assert!(gy[mid] > 0.0, "vertical ramp should produce positive gy");
        assert!(gx[mid].abs() < 1e-10, "vertical ramp should have no x gradient");
    }

    #[test]
    fn test_track_features_stationary() {
        let w = 64u32;
        let h = 64u32;
        let img = make_patch_image(64, 64, 32, 32, 6);
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(32.0, 32.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("track_features should not error");

        assert_eq!(results.len(), 1);
        // The default window straddles the patch edges, so tracking must succeed here.
        let (tx, ty) = results[0].expect("textured stationary patch should be tracked");
        assert!((tx - 32.0).abs() < 1.5, "tx={tx}");
        assert!((ty - 32.0).abs() < 1.5, "ty={ty}");
    }

    #[test]
    fn test_track_features_translation() {
        let w = 128u32;
        let h = 128u32;
        let shift = 5usize;

        let prev = make_patch_image(128, 128, 60, 60, 12);
        let curr = make_patch_image(128, 128, 60 + shift, 60, 12);

        let config = KltConfig {
            window_half_size: 10,
            pyramid_levels: 3,
            max_iterations: 30,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts: Vec<(f32, f32)> = vec![(60.0, 60.0)];

        let results = tracker
            .track_features(&prev, &curr, w, h, &pts)
            .expect("track_features should succeed");

        assert_eq!(results.len(), 1);
        // NOTE(review): at full resolution the window [50, 70] appears to lie inside the
        // flat interior of the patch [48, 72], so level 0 sees no gradient. Tracking then
        // likely returns `None` and this `if let` skips the shift check. Widening the
        // window past the patch edges would make the assertion actually run.
        if let Some((tx, _ty)) = results[0] {
            let dx = tx - 60.0;
            assert!(
                (dx - shift as f32).abs() < 3.0,
                "expected ~{shift} px shift, got dx={dx:.2}"
            );
        }
    }

    #[test]
    fn test_track_features_returns_none_for_flat_region() {
        let w = 64u32;
        let h = 64u32;
        let img = vec![128u8; 64 * 64];
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(32.0, 32.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("should not error");

        assert_eq!(results.len(), 1);
        assert!(
            results[0].is_none(),
            "flat region should return None, got {:?}",
            results[0]
        );
    }

    #[test]
    fn test_track_features_invalid_size() {
        let tracker = KltTracker::default();
        let pts = vec![(5.0_f32, 5.0_f32)];
        let err = tracker.track_features(&[0u8; 100], &[0u8; 200], 10, 10, &pts);
        assert!(err.is_err());
    }

    #[test]
    fn test_track_features_multiple_points() {
        let w = 64u32;
        let h = 64u32;
        let img = make_patch_image(64, 64, 32, 32, 8);
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(28.0, 28.0), (32.0, 32.0), (36.0, 36.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("should succeed");

        assert_eq!(results.len(), 3);
    }
}