Skip to main content

oximedia_align/
drift_correction.rs

//! Long-form timing drift correction.
//!
//! Provides tools to model and correct gradual clock drift between recording
//! devices over extended recording sessions:
//!
//! - [`DriftMeasurement`] – an observed drift sample at a given time.
//! - [`LinearDriftEstimator`] – least-squares linear drift model.
//! - [`DriftModel`] – choice of correction model.
//! - [`DriftCorrector`] – applies the fitted model to compute per-timestamp
//!   corrections.
//! - [`DriftQuality`] – evaluates the quality of a fitted model.
12
13#![allow(dead_code)]
14
/// One sampled clock offset between the secondary device and the reference,
/// taken `time_ms` milliseconds into the recording.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DriftMeasurement {
    /// Milliseconds elapsed since the recording started.
    pub time_ms: u64,
    /// Drift (in milliseconds) observed at `time_ms`.  A positive value means
    /// the secondary device is ahead of the reference.
    pub drift_ms: f64,
}

impl DriftMeasurement {
    /// Build a measurement from an elapsed time and the drift observed at
    /// that time.
    #[must_use]
    pub fn new(time_ms: u64, drift_ms: f64) -> Self {
        DriftMeasurement { time_ms, drift_ms }
    }
}
33
/// The mathematical form used to describe drift as a function of time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriftModel {
    /// Straight line: `drift = slope * t + intercept`.
    Linear,
    /// Quadratic polynomial: `drift = a*t² + b*t + c`.
    Polynomial,
    /// Linear interpolation between consecutive measurement points.
    PiecewiseLinear,
}
45
46/// Fits a linear drift model via ordinary least squares.
47pub struct LinearDriftEstimator;
48
49impl LinearDriftEstimator {
50    /// Fit a linear model to the provided measurements.
51    ///
52    /// Returns `(slope_ms_per_sec, intercept_ms)` where slope is the rate of
53    /// drift in milliseconds per second, and intercept is the drift at `t=0`.
54    ///
55    /// Falls back to `(0.0, 0.0)` when fewer than 2 measurements are provided.
56    #[must_use]
57    pub fn fit(measurements: &[DriftMeasurement]) -> (f64, f64) {
58        let n = measurements.len();
59        if n < 2 {
60            return (0.0, 0.0);
61        }
62
63        // Convert time to seconds for numerically stable regression.
64        let xs: Vec<f64> = measurements
65            .iter()
66            .map(|m| m.time_ms as f64 / 1000.0)
67            .collect();
68        let ys: Vec<f64> = measurements.iter().map(|m| m.drift_ms).collect();
69
70        let n_f = n as f64;
71        let sum_x: f64 = xs.iter().sum();
72        let sum_y: f64 = ys.iter().sum();
73        let sum_xy: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
74        let sum_xx: f64 = xs.iter().map(|x| x * x).sum();
75
76        let denom = n_f * sum_xx - sum_x * sum_x;
77        if denom.abs() < 1e-12 {
78            // All time values are identical → only an intercept can be estimated.
79            let intercept = sum_y / n_f;
80            return (0.0, intercept);
81        }
82
83        let slope = (n_f * sum_xy - sum_x * sum_y) / denom;
84        let intercept = (sum_y - slope * sum_x) / n_f;
85
86        (slope, intercept)
87    }
88}
89
90/// Applies a fitted drift model to compute per-timestamp corrections.
91#[derive(Debug, Clone)]
92pub struct DriftCorrector {
93    /// The model variant in use.
94    pub model: DriftModel,
95    /// Model coefficients (interpretation depends on `model`):
96    /// - `Linear`: `[slope_ms_per_sec, intercept_ms]`
97    /// - `Polynomial`: `[a, b, c]` for `a*t²+b*t+c` (t in seconds)
98    /// - `PiecewiseLinear`: interleaved `[t0_s, d0, t1_s, d1, …]`
99    pub coefficients: Vec<f64>,
100}
101
102impl DriftCorrector {
103    /// Create a new corrector with explicit model and coefficients.
104    #[must_use]
105    pub fn new(model: DriftModel, coefficients: Vec<f64>) -> Self {
106        Self {
107            model,
108            coefficients,
109        }
110    }
111
112    /// Fit a corrector from observations using the specified model.
113    #[must_use]
114    pub fn from_measurements(measurements: &[DriftMeasurement], model: DriftModel) -> Self {
115        match model {
116            DriftModel::Linear => {
117                let (slope, intercept) = LinearDriftEstimator::fit(measurements);
118                Self::new(model, vec![slope, intercept])
119            }
120            DriftModel::Polynomial => {
121                // Simple quadratic fit via normal equations (3×3 system).
122                let coeffs = fit_quadratic(measurements);
123                Self::new(model, coeffs)
124            }
125            DriftModel::PiecewiseLinear => {
126                // Store measurement pairs directly.
127                let mut coeffs = Vec::with_capacity(measurements.len() * 2);
128                for m in measurements {
129                    coeffs.push(m.time_ms as f64 / 1000.0);
130                    coeffs.push(m.drift_ms);
131                }
132                Self::new(model, coeffs)
133            }
134        }
135    }
136
137    /// Compute the drift correction (in milliseconds, rounded to integer) at
138    /// `time_ms` milliseconds.
139    ///
140    /// The correction to *apply* to the secondary device's timestamp is the
141    /// negative of the predicted drift: `corrected = original - correction`.
142    #[must_use]
143    pub fn correct(&self, time_ms: u64) -> i64 {
144        let t_s = time_ms as f64 / 1000.0;
145        let drift = match self.model {
146            DriftModel::Linear => {
147                let slope = self.coefficients.first().copied().unwrap_or(0.0);
148                let intercept = self.coefficients.get(1).copied().unwrap_or(0.0);
149                slope * t_s + intercept
150            }
151            DriftModel::Polynomial => {
152                let a = self.coefficients.first().copied().unwrap_or(0.0);
153                let b = self.coefficients.get(1).copied().unwrap_or(0.0);
154                let c = self.coefficients.get(2).copied().unwrap_or(0.0);
155                a * t_s * t_s + b * t_s + c
156            }
157            DriftModel::PiecewiseLinear => piecewise_linear_eval(&self.coefficients, t_s),
158        };
159        drift.round() as i64
160    }
161}
162
163/// Quality metrics for a fitted drift model.
164#[derive(Debug, Clone, Copy)]
165pub struct DriftQuality {
166    /// Root-mean-square of the residuals (in ms).
167    pub rms_error_ms: f64,
168    /// Maximum absolute residual (in ms).
169    pub max_error_ms: f64,
170    /// Coefficient of determination R².
171    pub r_squared: f64,
172}
173
174impl DriftQuality {
175    /// Evaluate the quality of `model` against the provided measurements.
176    #[must_use]
177    pub fn evaluate(model: &DriftCorrector, measurements: &[DriftMeasurement]) -> Self {
178        if measurements.is_empty() {
179            return Self {
180                rms_error_ms: 0.0,
181                max_error_ms: 0.0,
182                r_squared: 1.0,
183            };
184        }
185
186        let n = measurements.len() as f64;
187        let mean_drift = measurements.iter().map(|m| m.drift_ms).sum::<f64>() / n;
188
189        let mut ss_res = 0.0f64;
190        let mut ss_tot = 0.0f64;
191        let mut max_err = 0.0f64;
192
193        for m in measurements {
194            let predicted = model.correct(m.time_ms) as f64;
195            let residual = m.drift_ms - predicted;
196            ss_res += residual * residual;
197            ss_tot += (m.drift_ms - mean_drift) * (m.drift_ms - mean_drift);
198            max_err = max_err.max(residual.abs());
199        }
200
201        let rms = (ss_res / n).sqrt();
202        let r2 = if ss_tot < 1e-12 {
203            1.0
204        } else {
205            1.0 - ss_res / ss_tot
206        };
207
208        Self {
209            rms_error_ms: rms,
210            max_error_ms: max_err,
211            r_squared: r2,
212        }
213    }
214}
215
216// ─────────────────────────────────────────────────────────────────────────────
217// Internal helpers
218// ─────────────────────────────────────────────────────────────────────────────
219
220/// Fit a quadratic `y = a*t² + b*t + c` to the measurements using normal
221/// equations.  Falls back to linear when fewer than 3 points.
222fn fit_quadratic(measurements: &[DriftMeasurement]) -> Vec<f64> {
223    let n = measurements.len();
224    if n < 3 {
225        let (slope, intercept) = LinearDriftEstimator::fit(measurements);
226        return vec![0.0, slope, intercept];
227    }
228
229    // Build sums for the 3×3 normal equations: X'X β = X'y
230    //   X = [t², t, 1],  y = drift
231    let xs: Vec<f64> = measurements
232        .iter()
233        .map(|m| m.time_ms as f64 / 1000.0)
234        .collect();
235    let ys: Vec<f64> = measurements.iter().map(|m| m.drift_ms).collect();
236
237    let n_f = n as f64;
238    let s1: f64 = xs.iter().sum();
239    let s2: f64 = xs.iter().map(|x| x * x).sum();
240    let s3: f64 = xs.iter().map(|x| x * x * x).sum();
241    let s4: f64 = xs.iter().map(|x| x * x * x * x).sum();
242    let t0: f64 = ys.iter().sum();
243    let t1: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
244    let t2: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * x * y).sum();
245
246    // 3×3 system: [[s4, s3, s2], [s3, s2, s1], [s2, s1, n]] [a,b,c]' = [t2,t1,t0]
247    let mat = [[s4, s3, s2], [s3, s2, s1], [s2, s1, n_f]];
248    let rhs = [t2, t1, t0];
249
250    if let Some([a, b, c]) = solve_3x3(&mat, &rhs) {
251        vec![a, b, c]
252    } else {
253        // Singular matrix – fall back to linear.
254        let (slope, intercept) = LinearDriftEstimator::fit(measurements);
255        vec![0.0, slope, intercept]
256    }
257}
258
/// Solve the 3×3 linear system `m · x = rhs` with Cramer's rule.
///
/// Returns `None` when the determinant of `m` has magnitude below `1e-12`,
/// i.e. the system is treated as singular.
fn solve_3x3(m: &[[f64; 3]; 3], rhs: &[f64; 3]) -> Option<[f64; 3]> {
    /// Determinant of a 3×3 matrix by cofactor expansion along the first row.
    fn det3(a: &[[f64; 3]; 3]) -> f64 {
        a[0][0] * (a[1][1] * a[2][2] - a[1][2] * a[2][1])
            - a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
            + a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0])
    }

    let det = det3(m);
    if det.abs() < 1e-12 {
        return None;
    }

    let mut solution = [0.0f64; 3];
    for (col, slot) in solution.iter_mut().enumerate() {
        // Cramer: replace column `col` with the right-hand side.
        let mut replaced = *m;
        for (row, value) in replaced.iter_mut().zip(rhs.iter()) {
            row[col] = *value;
        }
        *slot = det3(&replaced) / det;
    }
    Some(solution)
}
282
/// Evaluate a piecewise-linear function stored as interleaved `[t, v, …]` pairs
/// at time `t_s`.
///
/// - Fewer than one complete pair → `0.0`.
/// - A trailing unpaired value (odd-length slice) is ignored.
/// - `t_s` at or before the first knot → the first knot's value; at or after
///   the last knot → the last knot's value.
/// - Otherwise the value is linearly interpolated within the first segment
///   (in storage order) that contains `t_s`.
///
/// The knots are expected in ascending time order.  If they are not, `t_s`
/// may fall into no segment, in which case `0.0` is returned.
fn piecewise_linear_eval(coeffs: &[f64], t_s: f64) -> f64 {
    // `chunks_exact` yields only complete pairs, so indexing [0] and [1]
    // cannot go out of bounds even for an odd-length slice.
    let knots = || coeffs.chunks_exact(2).map(|pair| (pair[0], pair[1]));

    let (first, last) = match (knots().next(), knots().next_back()) {
        (Some(first), Some(last)) => (first, last),
        _ => return 0.0,
    };

    // Clamp outside the covered time range.
    if t_s <= first.0 {
        return first.1;
    }
    if t_s >= last.0 {
        return last.1;
    }

    // Walk consecutive knot pairs and interpolate in the first one that
    // brackets `t_s`.
    for ((t0, d0), (t1, d1)) in knots().zip(knots().skip(1)) {
        if t_s >= t0 && t_s <= t1 {
            let span = t1 - t0;
            if span == 0.0 {
                // Zero-width segment: avoid 0/0 and use the knot value.
                return d0;
            }
            let alpha = (t_s - t0) / span;
            return d0 + alpha * (d1 - d0);
        }
    }
    0.0
}
311
312// ─────────────────────────────────────────────────────────────────────────────
313// Unit tests
314// ─────────────────────────────────────────────────────────────────────────────
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a measurement list from `(time_ms, drift_ms)` pairs.
    fn samples(points: &[(u64, f64)]) -> Vec<DriftMeasurement> {
        points
            .iter()
            .map(|&(time_ms, drift_ms)| DriftMeasurement::new(time_ms, drift_ms))
            .collect()
    }

    // ── DriftMeasurement ─────────────────────────────────────────────────────

    #[test]
    fn test_measurement_creation() {
        let DriftMeasurement { time_ms, drift_ms } = DriftMeasurement::new(5000, 1.5);
        assert_eq!(time_ms, 5000);
        assert!((drift_ms - 1.5).abs() < f64::EPSILON);
    }

    // ── LinearDriftEstimator ─────────────────────────────────────────────────

    #[test]
    fn test_linear_fit_insufficient_data() {
        // No data at all.
        let (slope, intercept) = LinearDriftEstimator::fit(&[]);
        assert_eq!(slope, 0.0);
        assert_eq!(intercept, 0.0);

        // A single point is still not enough to fit a line.
        let single = samples(&[(0, 1.0)]);
        let (s, i) = LinearDriftEstimator::fit(&single);
        assert_eq!(s, 0.0);
        assert_eq!(i, 0.0);
    }

    #[test]
    fn test_linear_fit_zero_drift() {
        let measurements = samples(&[
            (0, 0.0),
            (1_000, 0.0),
            (2_000, 0.0),
            (3_000, 0.0),
            (4_000, 0.0),
        ]);
        let (slope, intercept) = LinearDriftEstimator::fit(&measurements);
        assert!(slope.abs() < 1e-9);
        assert!(intercept.abs() < 1e-9);
    }

    #[test]
    fn test_linear_fit_perfect_linear() {
        // drift = 2 * t_s + 0.5  (t in seconds), sampled at 0..=4 s.
        let mut measurements = Vec::new();
        for second in 0..5 {
            let t_s = second as f64;
            measurements.push(DriftMeasurement::new(
                (t_s * 1000.0) as u64,
                2.0 * t_s + 0.5,
            ));
        }
        let (slope, intercept) = LinearDriftEstimator::fit(&measurements);
        assert!((slope - 2.0).abs() < 1e-6, "slope: {slope}");
        assert!((intercept - 0.5).abs() < 1e-6, "intercept: {intercept}");
    }

    // ── DriftCorrector (linear) ───────────────────────────────────────────────

    #[test]
    fn test_corrector_linear_zero() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 0.0]);
        for time_ms in [0, 60_000] {
            assert_eq!(corrector.correct(time_ms), 0);
        }
    }

    #[test]
    fn test_corrector_linear_constant_drift() {
        // slope = 0, intercept = 10 ms
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 10.0]);
        for time_ms in [0, 30_000] {
            assert_eq!(corrector.correct(time_ms), 10);
        }
    }

    #[test]
    fn test_corrector_from_measurements_linear() {
        let measurements = samples(&[(0, 0.0), (1_000, 1.0), (2_000, 2.0)]);
        let corrector = DriftCorrector::from_measurements(&measurements, DriftModel::Linear);
        // slope ≈ 1 ms/s, intercept ≈ 0
        let correction_at_3s = corrector.correct(3_000);
        assert!((correction_at_3s - 3).abs() <= 1, "got {correction_at_3s}");
    }

    // ── DriftCorrector (piecewise linear) ────────────────────────────────────

    #[test]
    fn test_corrector_piecewise_clamping() {
        let measurements = samples(&[(1_000, 5.0), (3_000, 15.0)]);
        let corrector =
            DriftCorrector::from_measurements(&measurements, DriftModel::PiecewiseLinear);

        // Before first point → clamp to first value.
        assert_eq!(corrector.correct(0), 5);
        // After last point → clamp to last value.
        assert_eq!(corrector.correct(10_000), 15);

        // At mid-point (2 s) → ~10.
        let mid = corrector.correct(2_000);
        assert!(
            (mid - 10).abs() <= 1,
            "mid correction should be ~10, got {mid}"
        );
    }

    // ── DriftQuality ─────────────────────────────────────────────────────────

    #[test]
    fn test_quality_perfect_fit() {
        let measurements = samples(&[(0, 0.0), (1_000, 1.0), (2_000, 2.0)]);
        let corrector = DriftCorrector::from_measurements(&measurements, DriftModel::Linear);
        let quality = DriftQuality::evaluate(&corrector, &measurements);
        assert!(quality.rms_error_ms < 0.5, "rms: {}", quality.rms_error_ms);
        assert!(quality.r_squared > 0.99, "r²: {}", quality.r_squared);
    }

    #[test]
    fn test_quality_empty_measurements() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 0.0]);
        let quality = DriftQuality::evaluate(&corrector, &[]);
        assert_eq!(quality.rms_error_ms, 0.0);
        assert!((quality.r_squared - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_quality_fields_exist() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 5.0]);
        let measurements = samples(&[(0, 5.0), (1000, 5.0)]);
        let q = DriftQuality::evaluate(&corrector, &measurements);
        // Both measurements predict 5 ms perfectly.
        assert!(q.rms_error_ms < 0.1);
        assert!(q.max_error_ms < 0.1);
    }

    // ── DriftCorrector (polynomial) ───────────────────────────────────────────

    #[test]
    fn test_corrector_polynomial_from_measurements() {
        // drift = t_s², sampled at 0..=4 s.
        let measurements = samples(&[
            (0, 0.0),
            (1_000, 1.0),
            (2_000, 4.0),
            (3_000, 9.0),
            (4_000, 16.0),
        ]);
        let corrector = DriftCorrector::from_measurements(&measurements, DriftModel::Polynomial);
        // At t=3 s, drift should be ≈ 9 ms.
        let c = corrector.correct(3_000);
        assert!((c - 9).abs() <= 2, "polynomial correction at 3s: {c}");
    }
}