Skip to main content

oximedia_align/
klt_tracker.rs

//! Kanade-Lucas-Tomasi (KLT) sparse optical flow tracker.
//!
//! Implements the classic pyramidal Lucas-Kanade tracker for tracking sparse
//! feature points across consecutive frames. This is the workhorse of most
//! video stabilization and motion estimation pipelines.
//!
//! # Algorithm
//!
//! 1. Build image pyramids for both frames (2x2 box-filter downsampling per level).
//! 2. At the coarsest level, initialize the estimated displacement to zero.
//! 3. At each level, refine the displacement using the Lucas-Kanade optical flow
//!    equation (solve the 2x2 structure tensor system). Points whose structure
//!    tensor has too small a minimum eigenvalue are reported as untracked.
//! 4. Propagate the refined displacement to the next finer level (multiply by 2).
//!
//! # References
//!
//! - Bouguet, J-Y. "Pyramidal Implementation of the Lucas Kanade Feature Tracker"
//!   Intel Corporation, 2001.
19
20#![allow(clippy::cast_precision_loss)]
21
22use crate::{AlignError, AlignResult, Point2D};
23
/// Tunable parameters for the pyramidal KLT tracker.
///
/// The defaults are a 15x15 integration window, three pyramid levels and at
/// most 20 refinement steps per level.
#[derive(Debug, Clone)]
pub struct KltConfig {
    /// Radius of the square integration window; the window spans
    /// `2 * window_half_size + 1` pixels per side.
    pub window_half_size: usize,
    /// How many pyramid levels to build; `1` tracks at full resolution only.
    pub pyramid_levels: usize,
    /// Upper bound on Lucas-Kanade refinement steps at each pyramid level.
    pub max_iterations: usize,
    /// A level stops refining once the update step is shorter than this
    /// many pixels.
    pub epsilon: f64,
    /// Smallest acceptable minimum eigenvalue of the windowed structure
    /// tensor; weaker windows are reported as untrackable.
    pub min_eigenvalue: f64,
}

impl Default for KltConfig {
    fn default() -> KltConfig {
        let window_half_size = 7;
        let pyramid_levels = 3;
        let max_iterations = 20;
        KltConfig {
            window_half_size,
            pyramid_levels,
            max_iterations,
            epsilon: 0.03,
            min_eigenvalue: 1e-4,
        }
    }
}
51
52/// Result of tracking a single point.
53#[derive(Debug, Clone)]
54pub struct TrackResult {
55    /// Original point in the first frame.
56    pub origin: Point2D,
57    /// Tracked position in the second frame (may be `None` if tracking failed).
58    pub tracked: Option<Point2D>,
59    /// Tracking error (SSD over the window), lower is better.
60    pub error: f64,
61    /// Whether the point was successfully tracked.
62    pub success: bool,
63}
64
65/// Pyramidal KLT tracker.
66pub struct KltTracker {
67    /// Tracker configuration.
68    pub config: KltConfig,
69}
70
71impl Default for KltTracker {
72    fn default() -> Self {
73        Self {
74            config: KltConfig::default(),
75        }
76    }
77}
78
79impl KltTracker {
80    /// Create a new KLT tracker with the given configuration.
81    #[must_use]
82    pub fn new(config: KltConfig) -> Self {
83        Self { config }
84    }
85
86    /// Track sparse feature points from `prev_frame` to `curr_frame` using
87    /// pyramidal Lucas-Kanade optical flow.
88    ///
89    /// This is the primary high-level API for KLT tracking. It accepts f32
90    /// pixel coordinates (common in feature-detection pipelines) and returns
91    /// `Some((x, y))` for successfully tracked points or `None` when tracking
92    /// failed (low texture, out-of-bounds, etc.).
93    ///
94    /// Both frames must be single-channel (grayscale), row-major byte images
95    /// of size `width * height`.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if image dimensions are inconsistent or the frame is
100    /// smaller than 8×8 pixels.
101    pub fn track_features(
102        &self,
103        prev_frame: &[u8],
104        curr_frame: &[u8],
105        width: u32,
106        height: u32,
107        points: &[(f32, f32)],
108    ) -> AlignResult<Vec<Option<(f32, f32)>>> {
109        let w = width as usize;
110        let h = height as usize;
111
112        // Convert (f32, f32) points to Point2D for the internal API.
113        let pts: Vec<Point2D> = points
114            .iter()
115            .map(|&(x, y)| Point2D::new(f64::from(x), f64::from(y)))
116            .collect();
117
118        let results = self.track(prev_frame, curr_frame, w, h, &pts)?;
119
120        Ok(results
121            .into_iter()
122            .map(|r| r.tracked.map(|p| (p.x as f32, p.y as f32)))
123            .collect())
124    }
125
126    /// Track a set of points from `prev_image` to `curr_image`.
127    ///
128    /// Both images must be single-channel (grayscale), row-major, with the
129    /// same `width` and `height`.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the image dimensions are inconsistent.
134    pub fn track(
135        &self,
136        prev_image: &[u8],
137        curr_image: &[u8],
138        width: usize,
139        height: usize,
140        points: &[Point2D],
141    ) -> AlignResult<Vec<TrackResult>> {
142        if prev_image.len() != width * height || curr_image.len() != width * height {
143            return Err(AlignError::InvalidConfig(
144                "Image size does not match width*height".to_string(),
145            ));
146        }
147        if width < 8 || height < 8 {
148            return Err(AlignError::InvalidConfig(
149                "Image must be at least 8x8".to_string(),
150            ));
151        }
152
153        // Build pyramids
154        let prev_pyr = build_pyramid(prev_image, width, height, self.config.pyramid_levels);
155        let curr_pyr = build_pyramid(curr_image, width, height, self.config.pyramid_levels);
156
157        let results: Vec<TrackResult> = points
158            .iter()
159            .map(|pt| self.track_point(pt, &prev_pyr, &curr_pyr))
160            .collect();
161
162        Ok(results)
163    }
164
165    /// Track a single point through the pyramid.
166    fn track_point(
167        &self,
168        point: &Point2D,
169        prev_pyr: &[PyramidLevel],
170        curr_pyr: &[PyramidLevel],
171    ) -> TrackResult {
172        let num_levels = prev_pyr.len();
173        let win = self.config.window_half_size as f64;
174
175        // Initial guess at coarsest level: zero displacement
176        let mut gx = 0.0_f64;
177        let mut gy = 0.0_f64;
178
179        let mut last_error = f64::MAX;
180        let mut success = true;
181
182        // Iterate from coarsest to finest
183        for level in (0..num_levels).rev() {
184            let scale = 1.0 / (1 << level) as f64;
185            let px = point.x * scale;
186            let py = point.y * scale;
187
188            let prev_level = &prev_pyr[level];
189            let curr_level = &curr_pyr[level];
190
191            let w = prev_level.width;
192            let h = prev_level.height;
193
194            // Compute gradients of the previous image at this level
195            let (grad_x, grad_y) = compute_gradients(&prev_level.data, w, h);
196
197            // Build the structure tensor G and mismatch vector b
198            let wi = self.config.window_half_size as isize;
199
200            // Iterative Lucas-Kanade
201            let mut vx = gx;
202            let mut vy = gy;
203
204            for _iter in 0..self.config.max_iterations {
205                let mut g_xx = 0.0_f64;
206                let mut g_yy = 0.0_f64;
207                let mut g_xy = 0.0_f64;
208                let mut b_x = 0.0_f64;
209                let mut b_y = 0.0_f64;
210
211                for dy in -wi..=wi {
212                    for dx in -wi..=wi {
213                        let sx = px + dx as f64;
214                        let sy = py + dy as f64;
215                        let tx = px + dx as f64 + vx;
216                        let ty = py + dy as f64 + vy;
217
218                        if sx < 0.0
219                            || sy < 0.0
220                            || sx >= (w - 1) as f64
221                            || sy >= (h - 1) as f64
222                            || tx < 0.0
223                            || ty < 0.0
224                            || tx >= (w - 1) as f64
225                            || ty >= (h - 1) as f64
226                        {
227                            continue;
228                        }
229
230                        let ix = bilinear_sample_f64(&grad_x, w, sx, sy);
231                        let iy = bilinear_sample_f64(&grad_y, w, sx, sy);
232                        let prev_val = bilinear_sample(&prev_level.data, w, sx, sy);
233                        let curr_val = bilinear_sample(&curr_level.data, w, tx, ty);
234
235                        let dt = prev_val - curr_val;
236
237                        g_xx += ix * ix;
238                        g_yy += iy * iy;
239                        g_xy += ix * iy;
240                        b_x += dt * ix;
241                        b_y += dt * iy;
242                    }
243                }
244
245                // Check minimum eigenvalue
246                let trace = g_xx + g_yy;
247                let det = g_xx * g_yy - g_xy * g_xy;
248                let discriminant = trace * trace - 4.0 * det;
249                let min_eig = if discriminant >= 0.0 {
250                    (trace - discriminant.sqrt()) / 2.0
251                } else {
252                    0.0
253                };
254
255                if min_eig < self.config.min_eigenvalue {
256                    success = false;
257                    break;
258                }
259
260                // Solve 2x2 system: G * [dvx; dvy] = [bx; by]
261                if det.abs() < 1e-12 {
262                    success = false;
263                    break;
264                }
265
266                let dvx = (g_yy * b_x - g_xy * b_y) / det;
267                let dvy = (-g_xy * b_x + g_xx * b_y) / det;
268
269                vx += dvx;
270                vy += dvy;
271
272                if dvx * dvx + dvy * dvy < self.config.epsilon * self.config.epsilon {
273                    break;
274                }
275            }
276
277            // Propagate to next finer level (scale by 2)
278            if level > 0 {
279                gx = vx * 2.0;
280                gy = vy * 2.0;
281            } else {
282                gx = vx;
283                gy = vy;
284            }
285
286            // Compute error at finest level
287            if level == 0 {
288                last_error = self.compute_tracking_error(
289                    &prev_pyr[0].data,
290                    &curr_pyr[0].data,
291                    prev_pyr[0].width,
292                    prev_pyr[0].height,
293                    point.x,
294                    point.y,
295                    gx,
296                    gy,
297                    win as isize,
298                );
299            }
300        }
301
302        // Check if tracked point is within image bounds (at original resolution)
303        let tracked_x = point.x + gx;
304        let tracked_y = point.y + gy;
305        let orig_w = prev_pyr[0].width as f64;
306        let orig_h = prev_pyr[0].height as f64;
307
308        if !success
309            || tracked_x < 0.0
310            || tracked_y < 0.0
311            || tracked_x >= orig_w
312            || tracked_y >= orig_h
313        {
314            return TrackResult {
315                origin: *point,
316                tracked: None,
317                error: last_error,
318                success: false,
319            };
320        }
321
322        TrackResult {
323            origin: *point,
324            tracked: Some(Point2D::new(tracked_x, tracked_y)),
325            error: last_error,
326            success: true,
327        }
328    }
329
330    /// Compute SSD tracking error over the window.
331    #[allow(clippy::too_many_arguments, clippy::manual_checked_ops)]
332    fn compute_tracking_error(
333        &self,
334        prev: &[u8],
335        curr: &[u8],
336        w: usize,
337        h: usize,
338        px: f64,
339        py: f64,
340        vx: f64,
341        vy: f64,
342        half_win: isize,
343    ) -> f64 {
344        let mut ssd = 0.0_f64;
345        let mut count = 0u32;
346
347        for dy in -half_win..=half_win {
348            for dx in -half_win..=half_win {
349                let sx = px + dx as f64;
350                let sy = py + dy as f64;
351                let tx = sx + vx;
352                let ty = sy + vy;
353
354                if sx >= 0.0
355                    && sy >= 0.0
356                    && sx < (w - 1) as f64
357                    && sy < (h - 1) as f64
358                    && tx >= 0.0
359                    && ty >= 0.0
360                    && tx < (w - 1) as f64
361                    && ty < (h - 1) as f64
362                {
363                    let a = bilinear_sample(prev, w, sx, sy);
364                    let b = bilinear_sample(curr, w, tx, ty);
365                    let d = a - b;
366                    ssd += d * d;
367                    count += 1;
368                }
369            }
370        }
371
372        if count == 0 {
373            return f64::MAX;
374        }
375        ssd / f64::from(count)
376    }
377}
378
379// -- Pyramid ------------------------------------------------------------------
380
/// One resolution of an image pyramid: a row-major grayscale buffer plus its
/// dimensions.
#[derive(Debug, Clone)]
struct PyramidLevel {
    /// Row-major 8-bit pixels; expected to hold `width * height` entries.
    data: Vec<u8>,
    /// Number of columns at this level.
    width: usize,
    /// Number of rows at this level.
    height: usize,
}
388
389/// Build a Gaussian image pyramid (factor-of-2 downsample at each level).
390fn build_pyramid(image: &[u8], width: usize, height: usize, levels: usize) -> Vec<PyramidLevel> {
391    let mut pyramid = Vec::with_capacity(levels);
392    pyramid.push(PyramidLevel {
393        data: image.to_vec(),
394        width,
395        height,
396    });
397
398    let mut cur = image.to_vec();
399    let mut cw = width;
400    let mut ch = height;
401
402    for _ in 1..levels {
403        let nw = cw / 2;
404        let nh = ch / 2;
405        if nw < 4 || nh < 4 {
406            break;
407        }
408        let down = downsample_2x(&cur, cw, ch, nw, nh);
409        pyramid.push(PyramidLevel {
410            data: down.clone(),
411            width: nw,
412            height: nh,
413        });
414        cur = down;
415        cw = nw;
416        ch = nh;
417    }
418
419    pyramid
420}
421
/// Halve an image with a 2x2 box filter.
///
/// Destination pixel `(dx, dy)` is the mean of source pixels
/// `x in 2dx..=2dx+1`, `y in 2dy..=2dy+1`, skipping any outside the `sw x sh`
/// source. The mean is rounded to the nearest integer, with halves rounding up.
/// A destination pixel with no valid source pixel is set to 0.
fn downsample_2x(src: &[u8], sw: usize, sh: usize, dw: usize, dh: usize) -> Vec<u8> {
    let mut dst = vec![0u8; dw * dh];
    for dy in 0..dh {
        for dx in 0..dw {
            let mut sum = 0u16;
            let mut count = 0u16;
            for ry in (2 * dy)..(2 * dy + 2) {
                for rx in (2 * dx)..(2 * dx + 2) {
                    if rx < sw && ry < sh {
                        sum += u16::from(src[ry * sw + rx]);
                        count += 1;
                    }
                }
            }
            // Round to nearest instead of truncating. Truncation always rounds
            // down, darkening every level slightly, and each coarser level
            // compounds the bias. The result is at most (4*255 + 2) / 4 = 255,
            // so the cast to u8 cannot truncate.
            dst[dy * dw + dx] = (sum + count / 2).checked_div(count).unwrap_or(0) as u8;
        }
    }
    dst
}
446
447// -- Bilinear interpolation ---------------------------------------------------
448
/// Bilinearly interpolate a row-major 8-bit image with `width` columns at
/// `(x, y)`, returning the value as `f64`.
///
/// This function does not check bounds. Callers keep `x` and `y` non-negative
/// and strictly below the last column/row so that the four neighbours exist.
fn bilinear_sample(image: &[u8], width: usize, x: f64, y: f64) -> f64 {
    let col = x.floor() as usize;
    let row = y.floor() as usize;
    let tx = x - col as f64;
    let ty = y - row as f64;
    let pixel = |r: usize, c: usize| f64::from(image[r * width + c]);

    pixel(row, col) * (1.0 - tx) * (1.0 - ty)
        + pixel(row, col + 1) * tx * (1.0 - ty)
        + pixel(row + 1, col) * (1.0 - tx) * ty
        + pixel(row + 1, col + 1) * tx * ty
}
465
/// Bilinearly interpolate a row-major `f64` buffer with `width` columns at
/// `(x, y)`.
///
/// This function does not check bounds. Callers keep `x` and `y` non-negative
/// and strictly below the last column/row so that the four neighbours exist.
fn bilinear_sample_f64(buf: &[f64], width: usize, x: f64, y: f64) -> f64 {
    let col = x.floor() as usize;
    let row = y.floor() as usize;
    let tx = x - col as f64;
    let ty = y - row as f64;

    let top = row * width;
    let bottom = top + width;
    let (v00, v10) = (buf[top + col], buf[top + col + 1]);
    let (v01, v11) = (buf[bottom + col], buf[bottom + col + 1]);

    v00 * (1.0 - tx) * (1.0 - ty) + v10 * tx * (1.0 - ty) + v01 * (1.0 - tx) * ty + v11 * tx * ty
}
482
/// Compute horizontal and vertical Sobel gradients of a row-major 8-bit image.
///
/// Each gradient is the 3x3 Sobel response divided by 8, i.e. an estimate in
/// intensity units per pixel. Pixels on the outermost rows and columns are
/// left at zero. Returns `(d/dx, d/dy)`, each with `width * height` entries.
fn compute_gradients(image: &[u8], width: usize, height: usize) -> (Vec<f64>, Vec<f64>) {
    let len = width * height;
    let mut grad_x = vec![0.0_f64; len];
    let mut grad_y = vec![0.0_f64; len];
    let pixel = |col: usize, row: usize| f64::from(image[row * width + col]);

    let last_row = height.saturating_sub(1);
    let last_col = width.saturating_sub(1);
    for row in 1..last_row {
        let (up, down) = (row - 1, row + 1);
        for col in 1..last_col {
            let (left, right) = (col - 1, col + 1);
            // Operands are integer-valued, so regrouping the sums is exact.
            let horizontal = (pixel(right, up) - pixel(left, up))
                + 2.0 * (pixel(right, row) - pixel(left, row))
                + (pixel(right, down) - pixel(left, down));
            let vertical = (pixel(left, down) - pixel(left, up))
                + 2.0 * (pixel(col, down) - pixel(col, up))
                + (pixel(right, down) - pixel(right, up));

            let idx = row * width + col;
            grad_x[idx] = horizontal / 8.0;
            grad_y[idx] = vertical / 8.0;
        }
    }

    (grad_x, grad_y)
}
509
#[cfg(test)]
mod tests {
    use super::*;

    // -- Helpers ---------------------------------------------------------------

    /// Create a `w` x `h` image with background value 30 and a square of
    /// `value` covering `[cx - half, cx + half] x [cy - half, cy + half]`,
    /// clipped to the image bounds.
    fn make_bright_square(
        w: usize,
        h: usize,
        cx: usize,
        cy: usize,
        half: usize,
        value: u8,
    ) -> Vec<u8> {
        let mut img = vec![30u8; w * h];
        for y in cy.saturating_sub(half)..=(cy + half).min(h - 1) {
            for x in cx.saturating_sub(half)..=(cx + half).min(w - 1) {
                img[y * w + x] = value;
            }
        }
        img
    }

    /// Create a test image with a single bright (200) square at a given offset.
    fn make_square_image(w: usize, h: usize, cx: usize, cy: usize, half: usize) -> Vec<u8> {
        make_bright_square(w, h, cx, cy, half, 200)
    }

    /// Create a synthetic image with a bright (210) square patch.
    fn make_patch_image(w: usize, h: usize, cx: usize, cy: usize, half: usize) -> Vec<u8> {
        make_bright_square(w, h, cx, cy, half, 210)
    }

    // -- PyramidLevel ---------------------------------------------------------

    #[test]
    fn test_build_pyramid_levels() {
        let img = vec![128u8; 64 * 64];
        let pyr = build_pyramid(&img, 64, 64, 3);
        assert_eq!(pyr.len(), 3);
        assert_eq!(pyr[0].width, 64);
        assert_eq!(pyr[1].width, 32);
        assert_eq!(pyr[2].width, 16);
    }

    #[test]
    fn test_build_pyramid_single_level() {
        let img = vec![128u8; 32 * 32];
        let pyr = build_pyramid(&img, 32, 32, 1);
        assert_eq!(pyr.len(), 1);
    }

    #[test]
    fn test_build_pyramid_zero_levels_keeps_base() {
        let img = vec![128u8; 32 * 32];
        let pyr = build_pyramid(&img, 32, 32, 0);
        assert_eq!(pyr.len(), 1);
        assert_eq!(pyr[0].data, img);
    }

    #[test]
    fn test_build_pyramid_stops_below_minimum_size() {
        // 64 -> 32 -> 16 -> 8 -> 4; the next level (2x2) is too small to build.
        let img = vec![128u8; 64 * 64];
        let pyr = build_pyramid(&img, 64, 64, 10);
        assert_eq!(pyr.len(), 5);
        assert_eq!(pyr[4].width, 4);
        assert_eq!(pyr[4].height, 4);
    }

    #[test]
    fn test_downsample_preserves_constant() {
        let img = vec![100u8; 64 * 64];
        let down = downsample_2x(&img, 64, 64, 32, 32);
        for &v in &down {
            assert_eq!(v, 100);
        }
    }

    // -- Bilinear interpolation -----------------------------------------------

    #[test]
    fn test_bilinear_integer_coords() {
        let img: Vec<u8> = vec![10, 20, 30, 40];
        let val = bilinear_sample(&img, 2, 0.0, 0.0);
        assert!((val - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_bilinear_midpoint() {
        // 2x2 image: [0, 100; 0, 100]
        let img: Vec<u8> = vec![0, 100, 0, 100];
        let val = bilinear_sample(&img, 2, 0.5, 0.0);
        assert!((val - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_bilinear_center_of_cell() {
        // Centre of a 2x2 cell is the mean of its four corners.
        let img: Vec<u8> = vec![10, 20, 30, 40];
        let val = bilinear_sample(&img, 2, 0.5, 0.5);
        assert!((val - 25.0).abs() < 1e-12);
    }

    // -- KLT tracker ----------------------------------------------------------

    #[test]
    fn test_klt_stationary_point() {
        let w = 64usize;
        let h = 64usize;
        let img = make_square_image(w, h, 32, 32, 5);

        let config = KltConfig {
            window_half_size: 5,
            pyramid_levels: 2,
            max_iterations: 20,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts = vec![Point2D::new(32.0, 32.0)];

        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("track should succeed");
        assert_eq!(results.len(), 1);
        assert!(
            results[0].success,
            "tracking a stationary point should succeed"
        );
        let tracked = results[0].tracked.expect("should have a tracked point");
        assert!(
            (tracked.x - 32.0).abs() < 1.0 && (tracked.y - 32.0).abs() < 1.0,
            "stationary point should not move: got ({:.2}, {:.2})",
            tracked.x,
            tracked.y,
        );
    }

    #[test]
    fn test_klt_translated_square() {
        let w = 128usize;
        let h = 128usize;
        let shift = 4;

        let prev = make_square_image(w, h, 60, 60, 10);
        let curr = make_square_image(w, h, 60 + shift, 60, 10);

        let config = KltConfig {
            window_half_size: 10,
            pyramid_levels: 3,
            max_iterations: 30,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts = vec![Point2D::new(60.0, 60.0)];

        let results = tracker
            .track(&prev, &curr, w, h, &pts)
            .expect("track should succeed");
        assert!(results[0].success, "should successfully track the square");
        if let Some(tracked) = &results[0].tracked {
            let dx = tracked.x - 60.0;
            // Should detect roughly +4 pixels horizontal shift
            assert!(
                (dx - shift as f64).abs() < 2.0,
                "expected ~{shift} px shift, got dx={dx:.2}"
            );
        }
    }

    #[test]
    fn test_klt_image_size_mismatch() {
        let tracker = KltTracker::default();
        let pts = vec![Point2D::new(5.0, 5.0)];
        let result = tracker.track(&[0u8; 100], &[0u8; 200], 10, 10, &pts);
        assert!(result.is_err());
    }

    #[test]
    fn test_klt_too_small_image() {
        let tracker = KltTracker::default();
        let pts = vec![Point2D::new(1.0, 1.0)];
        let result = tracker.track(&[0u8; 4], &[0u8; 4], 2, 2, &pts);
        assert!(result.is_err());
    }

    #[test]
    fn test_klt_multiple_points() {
        let w = 128usize;
        let h = 128usize;
        let img = make_square_image(w, h, 64, 64, 15);

        let tracker = KltTracker::default();
        let pts = vec![
            Point2D::new(50.0, 50.0),
            Point2D::new(70.0, 64.0),
            Point2D::new(64.0, 70.0),
        ];

        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("should succeed");
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_klt_point_out_of_bounds_does_not_crash() {
        let w = 64usize;
        let h = 64usize;
        let img = vec![128u8; w * h];

        let tracker = KltTracker::default();
        // Point near the border
        let pts = vec![Point2D::new(0.0, 0.0)];
        let results = tracker
            .track(&img, &img, w, h, &pts)
            .expect("should not crash");
        assert_eq!(results.len(), 1);
        // May or may not succeed depending on window size, but should not panic
    }

    #[test]
    fn test_klt_default_config() {
        let config = KltConfig::default();
        assert_eq!(config.window_half_size, 7);
        assert_eq!(config.pyramid_levels, 3);
        assert_eq!(config.max_iterations, 20);
    }

    #[test]
    fn test_track_result_fields() {
        let tr = TrackResult {
            origin: Point2D::new(10.0, 20.0),
            tracked: Some(Point2D::new(12.0, 22.0)),
            error: 0.5,
            success: true,
        };
        assert!(tr.success);
        assert!((tr.error - 0.5).abs() < f64::EPSILON);
    }

    // -- Gradient computation -------------------------------------------------

    #[test]
    fn test_gradients_constant_image() {
        let img = vec![100u8; 32 * 32];
        let (gx, gy) = compute_gradients(&img, 32, 32);
        // Interior pixels should have zero gradient on a constant image.
        for y in 2..30 {
            for x in 2..30 {
                assert!(gx[y * 32 + x].abs() < 1e-10);
                assert!(gy[y * 32 + x].abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_gradients_horizontal_ramp() {
        // Image where each column has a constant value equal to x.
        let w = 32usize;
        let h = 32usize;
        let mut img = vec![0u8; w * h];
        for y in 0..h {
            for x in 0..w {
                img[y * w + x] = (x * 8).min(255) as u8;
            }
        }
        let (gx, _gy) = compute_gradients(&img, w, h);
        // Interior horizontal gradient should be positive.
        let mid = 16 * w + 16;
        assert!(gx[mid] > 0.0, "horizontal ramp should produce positive gx");
    }

    #[test]
    fn test_gradients_sobel_values() {
        // 3x3 image [1 2 3; 4 5 6; 7 8 9]: only the centre pixel is interior.
        let img: Vec<u8> = (1..=9).collect();
        let (gx, gy) = compute_gradients(&img, 3, 3);
        assert!((gx[4] - 1.0).abs() < 1e-12);
        assert!((gy[4] - 3.0).abs() < 1e-12);
        for i in (0..9).filter(|&i| i != 4) {
            assert!(gx[i].abs() < 1e-12 && gy[i].abs() < 1e-12);
        }
    }

    // -- track_features (f32 API) ---------------------------------------------

    #[test]
    fn test_track_features_stationary() {
        let w = 64u32;
        let h = 64u32;
        let img = make_patch_image(64, 64, 32, 32, 6);
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(32.0, 32.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("track_features should not error");

        assert_eq!(results.len(), 1);
        if let Some((tx, ty)) = results[0] {
            assert!((tx - 32.0).abs() < 1.5, "tx={tx}");
            assert!((ty - 32.0).abs() < 1.5, "ty={ty}");
        }
        // The point may also be None if it falls on a flat region — that is
        // acceptable; we only care that the call did not panic or error.
    }

    #[test]
    fn test_track_features_translation() {
        let w = 128u32;
        let h = 128u32;
        let shift = 5usize;

        let prev = make_patch_image(128, 128, 60, 60, 12);
        let curr = make_patch_image(128, 128, 60 + shift, 60, 12);

        let config = KltConfig {
            window_half_size: 10,
            pyramid_levels: 3,
            max_iterations: 30,
            epsilon: 0.01,
            min_eigenvalue: 1e-6,
        };
        let tracker = KltTracker::new(config);
        let pts: Vec<(f32, f32)> = vec![(60.0, 60.0)];

        let results = tracker
            .track_features(&prev, &curr, w, h, &pts)
            .expect("track_features should succeed");

        assert_eq!(results.len(), 1);
        if let Some((tx, _ty)) = results[0] {
            let dx = tx - 60.0;
            assert!(
                (dx - shift as f32).abs() < 3.0,
                "expected ~{shift} px shift, got dx={dx:.2}"
            );
        }
    }

    #[test]
    fn test_track_features_returns_none_for_flat_region() {
        // A completely flat (constant-value) image has no texture, so any
        // interior point should fail the eigenvalue test and return None.
        let w = 64u32;
        let h = 64u32;
        let img = vec![128u8; 64 * 64];
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(32.0, 32.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("should not error");

        assert_eq!(results.len(), 1);
        // Flat image → no trackable texture → None expected.
        assert!(
            results[0].is_none(),
            "flat region should return None, got {:?}",
            results[0]
        );
    }

    #[test]
    fn test_track_features_invalid_size() {
        let tracker = KltTracker::default();
        let pts = vec![(5.0_f32, 5.0_f32)];
        // Mismatched slice length
        let err = tracker.track_features(&[0u8; 100], &[0u8; 200], 10, 10, &pts);
        assert!(err.is_err());
    }

    #[test]
    fn test_track_features_multiple_points() {
        let w = 64u32;
        let h = 64u32;
        let img = make_patch_image(64, 64, 32, 32, 8);
        let tracker = KltTracker::default();
        let pts: Vec<(f32, f32)> = vec![(28.0, 28.0), (32.0, 32.0), (36.0, 36.0)];

        let results = tracker
            .track_features(&img, &img, w, h, &pts)
            .expect("should succeed");

        // Length must match input regardless of tracking outcome.
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_track_features_empty_points() {
        let img = make_patch_image(64, 64, 32, 32, 6);
        let results = KltTracker::default()
            .track_features(&img, &img, 64, 64, &[])
            .expect("an empty point list is valid input");
        assert!(results.is_empty());
    }
}