Skip to main content

oximedia_align/
rolling_shutter.rs

//! Rolling shutter correction.
//!
//! This module provides tools for analysing and correcting rolling shutter
//! artifacts:
//!
//! - Block-based motion estimation per band of scanlines
//! - Wobble detection
//! - Skew removal
//! - Global shutter simulation
//! - Temporal smoothing of motion vectors

10use crate::{AlignError, AlignResult};
11// Vector2 removed - unused
12
/// Rolling shutter parameters
#[derive(Debug, Clone)]
pub struct RollingShutterParams {
    /// Readout time in seconds (time to read entire frame)
    pub readout_time: f64,
    /// Frame rate
    pub frame_rate: f64,
    /// Readout direction
    pub direction: ReadoutDirection,
}

/// Readout direction for rolling shutter.
///
/// `TopToBottom`/`BottomToTop` read rows; `LeftToRight`/`RightToLeft` read
/// columns. In every case the `scanline` index passed to
/// [`RollingShutterParams::compute_scanline_time`] counts in image order
/// (top row / left column = 0), not in readout order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadoutDirection {
    /// Top to bottom
    TopToBottom,
    /// Bottom to top
    BottomToTop,
    /// Left to right
    LeftToRight,
    /// Right to left
    RightToLeft,
}

impl RollingShutterParams {
    /// Create new rolling shutter parameters
    #[must_use]
    pub fn new(readout_time: f64, frame_rate: f64, direction: ReadoutDirection) -> Self {
        Self {
            readout_time,
            frame_rate,
            direction,
        }
    }

    /// Compute the readout time offset (in seconds) of a scanline, measured
    /// from the moment the first line of the frame is read.
    ///
    /// * `scanline` -- line index in image order (row for vertical readout,
    ///   column for horizontal readout).
    /// * `total_lines` -- number of lines in the readout dimension.
    ///
    /// For every direction the line read first gets offset `0.0` and the line
    /// read last gets `(total_lines - 1) / total_lines * readout_time`.
    /// Indices outside `0..total_lines` are extrapolated linearly.
    /// Returns `0.0` when `total_lines` is zero.
    #[must_use]
    pub fn compute_scanline_time(&self, scanline: usize, total_lines: usize) -> f64 {
        if total_lines == 0 {
            // Avoid 0/0 = NaN: there is no readout to offset into.
            return 0.0;
        }
        let n = total_lines as f64;
        let line = scanline as f64;
        let progress = match self.direction {
            ReadoutDirection::TopToBottom | ReadoutDirection::LeftToRight => line / n,
            // Exact mirror of the forward case: the last line in image order
            // is read first and must get offset 0 (not 1/n).
            ReadoutDirection::BottomToTop | ReadoutDirection::RightToLeft => {
                ((total_lines - 1) as f64 - line) / n
            }
        };

        progress * self.readout_time
    }
}
61
/// Displacement estimate for one band of scanlines, with a confidence score.
#[derive(Debug, Clone, Copy)]
pub struct MotionVector {
    /// Horizontal displacement (pixels)
    pub dx: f32,
    /// Vertical displacement (pixels)
    pub dy: f32,
    /// Confidence (0.0 to 1.0)
    pub confidence: f32,
}

impl MotionVector {
    /// Build a motion vector from its components.
    #[must_use]
    pub fn new(dx: f32, dy: f32, confidence: f32) -> Self {
        MotionVector { dx, dy, confidence }
    }

    /// No displacement, with full confidence.
    #[must_use]
    pub fn zero() -> Self {
        Self::new(0.0, 0.0, 1.0)
    }

    /// Euclidean length of the displacement, `sqrt(dx^2 + dy^2)`.
    /// The confidence does not participate.
    #[must_use]
    pub fn magnitude(&self) -> f32 {
        let Self { dx, dy, .. } = *self;
        (dx * dx + dy * dy).sqrt()
    }
}
96
97/// Rolling shutter motion estimator
98pub struct RollingShutterEstimator {
99    /// Block size for motion estimation
100    pub block_size: usize,
101    /// Search range for motion
102    pub search_range: isize,
103}
104
105impl Default for RollingShutterEstimator {
106    fn default() -> Self {
107        Self {
108            block_size: 16,
109            search_range: 16,
110        }
111    }
112}
113
114impl RollingShutterEstimator {
115    /// Create a new motion estimator
116    #[must_use]
117    pub fn new(block_size: usize, search_range: isize) -> Self {
118        Self {
119            block_size,
120            search_range,
121        }
122    }
123
124    /// Estimate motion vectors for each scanline
125    ///
126    /// # Errors
127    /// Returns error if frames are invalid
128    pub fn estimate_motion(
129        &self,
130        frame1: &[u8],
131        frame2: &[u8],
132        width: usize,
133        height: usize,
134    ) -> AlignResult<Vec<MotionVector>> {
135        if frame1.len() != width * height * 3 || frame2.len() != width * height * 3 {
136            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
137        }
138
139        let mut motion_vectors = Vec::new();
140
141        // Estimate motion for each row
142        for y in (0..height).step_by(self.block_size) {
143            let mv = self.estimate_row_motion(frame1, frame2, width, height, y);
144            motion_vectors.push(mv);
145        }
146
147        Ok(motion_vectors)
148    }
149
150    /// Estimate motion for a single row
151    fn estimate_row_motion(
152        &self,
153        frame1: &[u8],
154        frame2: &[u8],
155        width: usize,
156        height: usize,
157        y: usize,
158    ) -> MotionVector {
159        let mut best_dx = 0;
160        let mut best_dy = 0;
161        let mut best_sad = u32::MAX;
162
163        // Search in a window around the current position
164        for dy in -self.search_range..=self.search_range {
165            for dx in -self.search_range..=self.search_range {
166                let sad = self.compute_sad(frame1, frame2, width, height, 0, y, dx, dy);
167
168                if sad < best_sad {
169                    best_sad = sad;
170                    best_dx = dx;
171                    best_dy = dy;
172                }
173            }
174        }
175
176        // Compute confidence based on SAD
177        let confidence = if best_sad == 0 {
178            1.0
179        } else {
180            1.0 / (1.0 + (best_sad as f32 / 1000.0))
181        };
182
183        MotionVector::new(best_dx as f32, best_dy as f32, confidence)
184    }
185
186    /// Compute sum of absolute differences
187    #[allow(clippy::too_many_arguments)]
188    fn compute_sad(
189        &self,
190        frame1: &[u8],
191        frame2: &[u8],
192        width: usize,
193        height: usize,
194        x: usize,
195        y: usize,
196        dx: isize,
197        dy: isize,
198    ) -> u32 {
199        let mut sad = 0u32;
200        let block_height = self.block_size.min(height - y);
201
202        for by in 0..block_height {
203            for bx in 0..self.block_size.min(width) {
204                let x1 = x + bx;
205                let y1 = y + by;
206
207                let x2 = (x1 as isize + dx).max(0).min((width - 1) as isize) as usize;
208                let y2 = (y1 as isize + dy).max(0).min((height - 1) as isize) as usize;
209
210                let idx1 = (y1 * width + x1) * 3;
211                let idx2 = (y2 * width + x2) * 3;
212
213                if idx1 + 2 < frame1.len() && idx2 + 2 < frame2.len() {
214                    for c in 0..3 {
215                        sad += u32::from(
216                            (i16::from(frame1[idx1 + c]) - i16::from(frame2[idx2 + c]))
217                                .unsigned_abs(),
218                        );
219                    }
220                }
221            }
222        }
223
224        sad
225    }
226}
227
228/// Rolling shutter corrector
229pub struct RollingShutterCorrector {
230    /// Camera parameters
231    pub params: RollingShutterParams,
232    /// Motion estimator
233    estimator: RollingShutterEstimator,
234}
235
236impl RollingShutterCorrector {
237    /// Create a new rolling shutter corrector
238    #[must_use]
239    pub fn new(params: RollingShutterParams) -> Self {
240        Self {
241            params,
242            estimator: RollingShutterEstimator::default(),
243        }
244    }
245
246    /// Correct rolling shutter in a frame
247    ///
248    /// # Errors
249    /// Returns error if correction fails
250    pub fn correct(
251        &self,
252        frame: &[u8],
253        motion_vectors: &[MotionVector],
254        width: usize,
255        height: usize,
256    ) -> AlignResult<Vec<u8>> {
257        if frame.len() != width * height * 3 {
258            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
259        }
260
261        let mut corrected = vec![0u8; width * height * 3];
262
263        // Apply motion compensation per scanline
264        for (block_idx, mv) in motion_vectors.iter().enumerate() {
265            let y_start = block_idx * self.estimator.block_size;
266            let y_end = (y_start + self.estimator.block_size).min(height);
267
268            for y in y_start..y_end {
269                self.correct_scanline(frame, &mut corrected, width, y, mv);
270            }
271        }
272
273        Ok(corrected)
274    }
275
276    /// Correct a single scanline
277    fn correct_scanline(
278        &self,
279        input: &[u8],
280        output: &mut [u8],
281        width: usize,
282        y: usize,
283        mv: &MotionVector,
284    ) {
285        for x in 0..width {
286            let src_x = (x as f32 - mv.dx).round() as isize;
287            let src_y = (y as f32 - mv.dy).round() as isize;
288
289            if src_x >= 0 && src_x < width as isize && src_y >= 0 {
290                let src_idx = (src_y as usize * width + src_x as usize) * 3;
291                let dst_idx = (y * width + x) * 3;
292
293                if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
294                    output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
295                }
296            }
297        }
298    }
299
300    /// Estimate and correct rolling shutter in one step
301    ///
302    /// # Errors
303    /// Returns error if correction fails
304    pub fn estimate_and_correct(
305        &self,
306        frame1: &[u8],
307        frame2: &[u8],
308        width: usize,
309        height: usize,
310    ) -> AlignResult<Vec<u8>> {
311        let motion_vectors = self
312            .estimator
313            .estimate_motion(frame1, frame2, width, height)?;
314        self.correct(frame2, &motion_vectors, width, height)
315    }
316}
317
318/// Wobble detector for rolling shutter artifacts
319pub struct WobbleDetector {
320    /// Threshold for wobble detection
321    pub threshold: f32,
322}
323
324impl Default for WobbleDetector {
325    fn default() -> Self {
326        Self { threshold: 5.0 }
327    }
328}
329
330impl WobbleDetector {
331    /// Create a new wobble detector
332    #[must_use]
333    pub fn new(threshold: f32) -> Self {
334        Self { threshold }
335    }
336
337    /// Detect wobble in motion vectors
338    #[must_use]
339    pub fn detect_wobble(&self, motion_vectors: &[MotionVector]) -> bool {
340        if motion_vectors.len() < 3 {
341            return false;
342        }
343
344        // Check for oscillating motion
345        let mut sign_changes = 0;
346
347        for i in 2..motion_vectors.len() {
348            let d1 = motion_vectors[i - 1].dx - motion_vectors[i - 2].dx;
349            let d2 = motion_vectors[i].dx - motion_vectors[i - 1].dx;
350
351            if d1 * d2 < 0.0 && d1.abs() > self.threshold {
352                sign_changes += 1;
353            }
354        }
355
356        // If motion changes direction frequently, it's wobble
357        sign_changes > motion_vectors.len() / 4
358    }
359
360    /// Compute wobble metric (0.0 = no wobble, 1.0 = severe wobble)
361    #[must_use]
362    pub fn compute_wobble_metric(&self, motion_vectors: &[MotionVector]) -> f32 {
363        if motion_vectors.len() < 2 {
364            return 0.0;
365        }
366
367        let mut total_variation = 0.0f32;
368
369        for i in 1..motion_vectors.len() {
370            let ddx = motion_vectors[i].dx - motion_vectors[i - 1].dx;
371            let ddy = motion_vectors[i].dy - motion_vectors[i - 1].dy;
372            total_variation += (ddx * ddx + ddy * ddy).sqrt();
373        }
374
375        let avg_variation = total_variation / (motion_vectors.len() - 1) as f32;
376
377        // Normalize to 0-1 range (assuming max variation of 20 pixels)
378        (avg_variation / 20.0).min(1.0)
379    }
380}
381
382/// Skew corrector for rolling shutter-induced distortion
383pub struct SkewCorrector {
384    /// Angular velocity (radians per second)
385    pub angular_velocity: f64,
386}
387
388impl SkewCorrector {
389    /// Create a new skew corrector
390    #[must_use]
391    pub fn new(angular_velocity: f64) -> Self {
392        Self { angular_velocity }
393    }
394
395    /// Correct skew in image
396    ///
397    /// # Errors
398    /// Returns error if correction fails
399    pub fn correct(
400        &self,
401        frame: &[u8],
402        width: usize,
403        height: usize,
404        params: &RollingShutterParams,
405    ) -> AlignResult<Vec<u8>> {
406        if frame.len() != width * height * 3 {
407            return Err(AlignError::InvalidConfig("Frame size mismatch".to_string()));
408        }
409
410        let mut corrected = vec![0u8; width * height * 3];
411
412        for y in 0..height {
413            let time = params.compute_scanline_time(y, height);
414            let angle = self.angular_velocity * time;
415
416            // Compute horizontal offset due to rotation
417            let offset = (angle * (height as f64 / 2.0)) as isize;
418
419            self.shift_scanline(frame, &mut corrected, width, y, offset);
420        }
421
422        Ok(corrected)
423    }
424
425    /// Shift a scanline horizontally
426    fn shift_scanline(
427        &self,
428        input: &[u8],
429        output: &mut [u8],
430        width: usize,
431        y: usize,
432        offset: isize,
433    ) {
434        for x in 0..width {
435            let src_x = (x as isize - offset).max(0).min((width - 1) as isize) as usize;
436
437            let src_idx = (y * width + src_x) * 3;
438            let dst_idx = (y * width + x) * 3;
439
440            if src_idx + 2 < input.len() && dst_idx + 2 < output.len() {
441                output[dst_idx..dst_idx + 3].copy_from_slice(&input[src_idx..src_idx + 3]);
442            }
443        }
444    }
445}
446
447/// Temporal interpolator for global shutter simulation
448pub struct GlobalShutterSimulator {
449    /// Number of virtual sub-frames
450    pub sub_frames: usize,
451}
452
453impl Default for GlobalShutterSimulator {
454    fn default() -> Self {
455        Self { sub_frames: 10 }
456    }
457}
458
459impl GlobalShutterSimulator {
460    /// Create a new global shutter simulator
461    #[must_use]
462    pub fn new(sub_frames: usize) -> Self {
463        Self { sub_frames }
464    }
465
466    /// Simulate global shutter by averaging virtual sub-frames
467    ///
468    /// # Errors
469    /// Returns error if simulation fails
470    pub fn simulate(
471        &self,
472        frames: &[&[u8]],
473        width: usize,
474        height: usize,
475        params: &RollingShutterParams,
476    ) -> AlignResult<Vec<u8>> {
477        if frames.is_empty() {
478            return Err(AlignError::InsufficientData(
479                "Need at least one frame".to_string(),
480            ));
481        }
482
483        let mut output = vec![0u32; width * height * 3];
484
485        // For each scanline, average contributions from multiple frames
486        for y in 0..height {
487            let _time = params.compute_scanline_time(y, height);
488
489            for frame in frames {
490                if frame.len() != width * height * 3 {
491                    continue;
492                }
493
494                for x in 0..width {
495                    let idx = (y * width + x) * 3;
496                    if idx + 2 < frame.len() {
497                        output[idx] += u32::from(frame[idx]);
498                        output[idx + 1] += u32::from(frame[idx + 1]);
499                        output[idx + 2] += u32::from(frame[idx + 2]);
500                    }
501                }
502            }
503        }
504
505        // Average
506        let n = frames.len() as u32;
507        let result = output.iter().map(|&v| (v / n) as u8).collect();
508
509        Ok(result)
510    }
511}
512
513/// Temporal smoother for rolling shutter motion vectors.
514///
515/// Applies an exponentially weighted moving average (EWMA) across consecutive
516/// frames to suppress frame-to-frame jitter in the per-scanline motion
517/// estimates. This is critical for preventing flickering artifacts that occur
518/// when raw per-frame motion vectors vary erratically.
519///
520/// # Algorithm
521///
522/// For each block index `i`, the smoother maintains a running estimate:
523///
524/// ```text
525/// mv_smoothed[i] = alpha * mv_new[i] + (1 - alpha) * mv_prev[i]
526/// ```
527///
528/// A lower `alpha` produces more temporal smoothing (more lag), while a higher
529/// `alpha` responds faster to genuine motion changes.
530pub struct TemporalSmoother {
531    /// Smoothing factor in (0, 1].  Lower = smoother.
532    alpha: f64,
533    /// Previous smoothed motion vectors (one per block).
534    state: Vec<MotionVector>,
535}
536
537impl TemporalSmoother {
538    /// Create a new temporal smoother.
539    ///
540    /// `alpha` is clamped to `[0.01, 1.0]`.
541    #[must_use]
542    pub fn new(alpha: f64) -> Self {
543        Self {
544            alpha: alpha.clamp(0.01, 1.0),
545            state: Vec::new(),
546        }
547    }
548
549    /// Smooth a new frame's motion vectors against the running average.
550    ///
551    /// On the first call the input is returned as-is (there is no history).
552    /// Subsequent calls blend the new vectors with the accumulated state.
553    ///
554    /// If the number of blocks changes between calls (e.g. resolution change)
555    /// the state is reset.
556    pub fn smooth(&mut self, motion_vectors: &[MotionVector]) -> Vec<MotionVector> {
557        if self.state.len() != motion_vectors.len() {
558            // First frame or resolution change: initialise state
559            self.state = motion_vectors.to_vec();
560            return motion_vectors.to_vec();
561        }
562
563        let alpha = self.alpha as f32;
564        let one_minus = 1.0 - alpha;
565
566        let mut result = Vec::with_capacity(motion_vectors.len());
567        for (prev, new) in self.state.iter_mut().zip(motion_vectors.iter()) {
568            let dx = alpha * new.dx + one_minus * prev.dx;
569            let dy = alpha * new.dy + one_minus * prev.dy;
570            let conf = alpha * new.confidence + one_minus * prev.confidence;
571
572            prev.dx = dx;
573            prev.dy = dy;
574            prev.confidence = conf;
575
576            result.push(MotionVector::new(dx, dy, conf));
577        }
578
579        result
580    }
581
582    /// Reset the internal state so the next call starts fresh.
583    pub fn reset(&mut self) {
584        self.state.clear();
585    }
586
587    /// Current smoothing factor.
588    #[must_use]
589    pub fn alpha(&self) -> f64 {
590        self.alpha
591    }
592
593    /// Number of blocks tracked.
594    #[must_use]
595    pub fn num_blocks(&self) -> usize {
596        self.state.len()
597    }
598}
599
600/// Gaussian temporal smoother that keeps a window of past frames and applies
601/// a weighted average across time.
602///
603/// This is more expensive than EWMA but produces less phase lag because the
604/// kernel is symmetric (it uses future context when available via lookahead).
605pub struct GaussianTemporalSmoother {
606    /// Kernel half-size (total window = 2 * radius + 1).
607    radius: usize,
608    /// Precomputed 1-D Gaussian kernel weights.
609    kernel: Vec<f64>,
610    /// Ring buffer of recent motion vector frames.
611    history: Vec<Vec<MotionVector>>,
612    /// Maximum number of frames to store (= 2 * radius + 1).
613    capacity: usize,
614}
615
616impl GaussianTemporalSmoother {
617    /// Create a new Gaussian temporal smoother.
618    ///
619    /// * `radius` -- half-size of the Gaussian kernel.
620    /// * `sigma` -- standard deviation (in frames).
621    #[must_use]
622    pub fn new(radius: usize, sigma: f64) -> Self {
623        let sigma = sigma.max(0.1);
624        let cap = 2 * radius + 1;
625        let mut kernel = Vec::with_capacity(cap);
626        for i in 0..cap {
627            let x = i as f64 - radius as f64;
628            kernel.push((-0.5 * x * x / (sigma * sigma)).exp());
629        }
630        // Normalise
631        let sum: f64 = kernel.iter().sum();
632        if sum > 1e-15 {
633            for v in &mut kernel {
634                *v /= sum;
635            }
636        }
637
638        Self {
639            radius,
640            kernel,
641            history: Vec::with_capacity(cap),
642            capacity: cap,
643        }
644    }
645
646    /// Push a new frame of motion vectors and return the smoothed result for
647    /// the centre frame (i.e. with `radius` frames of look-ahead/look-behind
648    /// when available).
649    ///
650    /// Until the buffer is full the result uses whatever history is available.
651    pub fn push(&mut self, motion_vectors: &[MotionVector]) -> Vec<MotionVector> {
652        self.history.push(motion_vectors.to_vec());
653        if self.history.len() > self.capacity {
654            self.history.remove(0);
655        }
656
657        let num_blocks = motion_vectors.len();
658        let num_frames = self.history.len();
659
660        // The "centre" index in the available history
661        let centre = if num_frames > self.radius {
662            num_frames - 1 - self.radius.min(num_frames - 1)
663        } else {
664            0
665        };
666
667        let mut result = Vec::with_capacity(num_blocks);
668        for block_idx in 0..num_blocks {
669            let mut sum_dx = 0.0_f64;
670            let mut sum_dy = 0.0_f64;
671            let mut sum_conf = 0.0_f64;
672            let mut weight_total = 0.0_f64;
673
674            for (frame_offset, frame) in self.history.iter().enumerate() {
675                if block_idx >= frame.len() {
676                    continue;
677                }
678                // Map frame_offset to kernel index relative to centre
679                let ki = frame_offset as isize - centre as isize + self.radius as isize;
680                if ki < 0 || ki >= self.kernel.len() as isize {
681                    continue;
682                }
683                let w = self.kernel[ki as usize];
684                let mv = &frame[block_idx];
685                sum_dx += f64::from(mv.dx) * w;
686                sum_dy += f64::from(mv.dy) * w;
687                sum_conf += f64::from(mv.confidence) * w;
688                weight_total += w;
689            }
690
691            if weight_total > 1e-15 {
692                result.push(MotionVector::new(
693                    (sum_dx / weight_total) as f32,
694                    (sum_dy / weight_total) as f32,
695                    (sum_conf / weight_total) as f32,
696                ));
697            } else {
698                result.push(
699                    motion_vectors
700                        .get(block_idx)
701                        .copied()
702                        .unwrap_or(MotionVector::zero()),
703                );
704            }
705        }
706
707        result
708    }
709
710    /// Reset the history buffer.
711    pub fn reset(&mut self) {
712        self.history.clear();
713    }
714
715    /// Number of frames currently in the history buffer.
716    #[must_use]
717    pub fn history_len(&self) -> usize {
718        self.history.len()
719    }
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor used throughout the tests.
    fn mv(dx: f32, dy: f32, confidence: f32) -> MotionVector {
        MotionVector::new(dx, dy, confidence)
    }

    #[test]
    fn test_rolling_shutter_params() {
        let p = RollingShutterParams::new(0.033, 30.0, ReadoutDirection::TopToBottom);
        assert_eq!((p.readout_time, p.frame_rate), (0.033, 30.0));
    }

    #[test]
    fn test_scanline_time() {
        let p = RollingShutterParams::new(0.01, 100.0, ReadoutDirection::TopToBottom);
        let halfway = p.compute_scanline_time(500, 1000);
        assert!((halfway - 0.005).abs() < 1e-10);
    }

    #[test]
    fn test_motion_vector() {
        let v = mv(10.0, 20.0, 0.9);
        assert_eq!((v.dx, v.dy, v.confidence), (10.0, 20.0, 0.9));

        let expected = (10.0f32 * 10.0 + 20.0 * 20.0).sqrt();
        assert!((v.magnitude() - expected).abs() < 1e-6);
    }

    #[test]
    fn test_zero_motion_vector() {
        let v = MotionVector::zero();
        assert_eq!([v.dx, v.dy, v.magnitude()], [0.0; 3]);
    }

    #[test]
    fn test_wobble_detector() {
        assert_eq!(WobbleDetector::new(5.0).threshold, 5.0);
    }

    #[test]
    fn test_wobble_metric() {
        let vectors: Vec<MotionVector> = [0.0f32, 10.0, 0.0, 10.0]
            .iter()
            .map(|&dx| mv(dx, 0.0, 1.0))
            .collect();
        assert!(WobbleDetector::default().compute_wobble_metric(&vectors) > 0.0);
    }

    #[test]
    fn test_skew_corrector() {
        assert_eq!(SkewCorrector::new(1.0).angular_velocity, 1.0);
    }

    #[test]
    fn test_global_shutter_simulator() {
        assert_eq!(GlobalShutterSimulator::new(10).sub_frames, 10);
    }

    #[test]
    fn test_readout_direction() {
        let top_down = ReadoutDirection::TopToBottom;
        assert_eq!(top_down, ReadoutDirection::TopToBottom);
        assert_ne!(top_down, ReadoutDirection::BottomToTop);
    }

    // ── TemporalSmoother (EWMA) ─────────────────────────────────────────────

    #[test]
    fn test_temporal_smoother_first_frame_passthrough() {
        let mut smoother = TemporalSmoother::new(0.5);
        let out = smoother.smooth(&[mv(10.0, 5.0, 0.9), mv(-3.0, 2.0, 0.8)]);
        assert_eq!(out.len(), 2);
        assert!((out[0].dx - 10.0).abs() < 1e-5);
        assert!((out[1].dy - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_temporal_smoother_convergence() {
        let mut smoother = TemporalSmoother::new(0.3);
        // Constant motion: the running average must converge to it.
        let input = [mv(4.0, -2.0, 1.0)];
        for _ in 0..50 {
            smoother.smooth(&input);
        }
        let out = smoother.smooth(&input);
        assert!(
            (out[0].dx - 4.0).abs() < 0.01,
            "should converge to 4.0, got {}",
            out[0].dx
        );
        assert!(
            (out[0].dy + 2.0).abs() < 0.01,
            "should converge to -2.0, got {}",
            out[0].dy
        );
    }

    #[test]
    fn test_temporal_smoother_dampens_jitter() {
        let mut smoother = TemporalSmoother::new(0.2);
        let up = [mv(10.0, 0.0, 1.0)];
        let down = [mv(-10.0, 0.0, 1.0)];

        // High-frequency +10 / -10 alternation.
        smoother.smooth(&up);
        for _ in 0..20 {
            smoother.smooth(&down);
            smoother.smooth(&up);
        }
        let out = smoother.smooth(&down);
        assert!(
            out[0].dx.abs() < 5.0,
            "jitter should be dampened, got {}",
            out[0].dx
        );
    }

    #[test]
    fn test_temporal_smoother_alpha_clamping() {
        assert!((TemporalSmoother::new(0.0).alpha() - 0.01).abs() < 1e-10);
        assert!((TemporalSmoother::new(2.0).alpha() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_temporal_smoother_reset() {
        let mut smoother = TemporalSmoother::new(0.5);
        smoother.smooth(&[mv(5.0, 5.0, 1.0)]);
        assert_eq!(smoother.num_blocks(), 1);
        smoother.reset();
        assert_eq!(smoother.num_blocks(), 0);
    }

    // ── GaussianTemporalSmoother ────────────────────────────────────────────

    #[test]
    fn test_gaussian_smoother_constant_input() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        let input = [mv(3.0, -1.0, 0.9)];
        for _ in 0..10 {
            let out = smoother.push(&input);
            assert_eq!(out.len(), 1);
            // Constant input: the output should stay close to it.
            assert!((out[0].dx - 3.0).abs() < 0.5, "dx={}", out[0].dx);
        }
    }

    #[test]
    fn test_gaussian_smoother_dampens_spike() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        let calm = [mv(0.0, 0.0, 1.0)];
        smoother.push(&calm);
        smoother.push(&calm);

        // The spike arrives last and is averaged with the earlier zero frames.
        let out = smoother.push(&[mv(100.0, 0.0, 1.0)]);
        assert!(
            out[0].dx < 100.0,
            "spike should be dampened: dx={}",
            out[0].dx
        );
    }

    #[test]
    fn test_gaussian_smoother_history_len() {
        let mut smoother = GaussianTemporalSmoother::new(1, 0.5);
        let input = [MotionVector::zero()];
        assert_eq!(smoother.history_len(), 0);

        // Capacity is 2*1+1 = 3: the buffer grows to 3, then evicts the oldest.
        for expected in [1usize, 2, 3, 3] {
            smoother.push(&input);
            assert_eq!(smoother.history_len(), expected);
        }
    }

    #[test]
    fn test_gaussian_smoother_reset() {
        let mut smoother = GaussianTemporalSmoother::new(2, 1.0);
        smoother.push(&[MotionVector::zero()]);
        assert_eq!(smoother.history_len(), 1);
        smoother.reset();
        assert_eq!(smoother.history_len(), 0);
    }
}