1#![allow(dead_code)]
14
/// One observed drift sample: the drift seen at a given point in time.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DriftMeasurement {
    /// Time at which the sample was taken, in milliseconds.
    pub time_ms: u64,
    /// Drift observed at `time_ms`, in milliseconds.
    pub drift_ms: f64,
}

impl DriftMeasurement {
    /// Bundles a timestamp and the drift observed at it into a measurement.
    #[must_use]
    pub fn new(time_ms: u64, drift_ms: f64) -> Self {
        DriftMeasurement { time_ms, drift_ms }
    }
}
33
/// Selects the functional form a [`DriftCorrector`] uses for drift as a function of time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DriftModel {
    /// Straight line: `slope * t + intercept`.
    Linear,
    /// Quadratic polynomial: `a * t² + b * t + c`.
    Polynomial,
    /// Linear interpolation between measured points, held at the first/last point's drift
    /// outside the measured time range.
    PiecewiseLinear,
}
45
46pub struct LinearDriftEstimator;
48
49impl LinearDriftEstimator {
50 #[must_use]
57 pub fn fit(measurements: &[DriftMeasurement]) -> (f64, f64) {
58 let n = measurements.len();
59 if n < 2 {
60 return (0.0, 0.0);
61 }
62
63 let xs: Vec<f64> = measurements
65 .iter()
66 .map(|m| m.time_ms as f64 / 1000.0)
67 .collect();
68 let ys: Vec<f64> = measurements.iter().map(|m| m.drift_ms).collect();
69
70 let n_f = n as f64;
71 let sum_x: f64 = xs.iter().sum();
72 let sum_y: f64 = ys.iter().sum();
73 let sum_xy: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
74 let sum_xx: f64 = xs.iter().map(|x| x * x).sum();
75
76 let denom = n_f * sum_xx - sum_x * sum_x;
77 if denom.abs() < 1e-12 {
78 let intercept = sum_y / n_f;
80 return (0.0, intercept);
81 }
82
83 let slope = (n_f * sum_xy - sum_x * sum_y) / denom;
84 let intercept = (sum_y - slope * sum_x) / n_f;
85
86 (slope, intercept)
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct DriftCorrector {
93 pub model: DriftModel,
95 pub coefficients: Vec<f64>,
100}
101
102impl DriftCorrector {
103 #[must_use]
105 pub fn new(model: DriftModel, coefficients: Vec<f64>) -> Self {
106 Self {
107 model,
108 coefficients,
109 }
110 }
111
112 #[must_use]
114 pub fn from_measurements(measurements: &[DriftMeasurement], model: DriftModel) -> Self {
115 match model {
116 DriftModel::Linear => {
117 let (slope, intercept) = LinearDriftEstimator::fit(measurements);
118 Self::new(model, vec![slope, intercept])
119 }
120 DriftModel::Polynomial => {
121 let coeffs = fit_quadratic(measurements);
123 Self::new(model, coeffs)
124 }
125 DriftModel::PiecewiseLinear => {
126 let mut coeffs = Vec::with_capacity(measurements.len() * 2);
128 for m in measurements {
129 coeffs.push(m.time_ms as f64 / 1000.0);
130 coeffs.push(m.drift_ms);
131 }
132 Self::new(model, coeffs)
133 }
134 }
135 }
136
137 #[must_use]
143 pub fn correct(&self, time_ms: u64) -> i64 {
144 let t_s = time_ms as f64 / 1000.0;
145 let drift = match self.model {
146 DriftModel::Linear => {
147 let slope = self.coefficients.first().copied().unwrap_or(0.0);
148 let intercept = self.coefficients.get(1).copied().unwrap_or(0.0);
149 slope * t_s + intercept
150 }
151 DriftModel::Polynomial => {
152 let a = self.coefficients.first().copied().unwrap_or(0.0);
153 let b = self.coefficients.get(1).copied().unwrap_or(0.0);
154 let c = self.coefficients.get(2).copied().unwrap_or(0.0);
155 a * t_s * t_s + b * t_s + c
156 }
157 DriftModel::PiecewiseLinear => piecewise_linear_eval(&self.coefficients, t_s),
158 };
159 drift.round() as i64
160 }
161}
162
163#[derive(Debug, Clone, Copy)]
165pub struct DriftQuality {
166 pub rms_error_ms: f64,
168 pub max_error_ms: f64,
170 pub r_squared: f64,
172}
173
174impl DriftQuality {
175 #[must_use]
177 pub fn evaluate(model: &DriftCorrector, measurements: &[DriftMeasurement]) -> Self {
178 if measurements.is_empty() {
179 return Self {
180 rms_error_ms: 0.0,
181 max_error_ms: 0.0,
182 r_squared: 1.0,
183 };
184 }
185
186 let n = measurements.len() as f64;
187 let mean_drift = measurements.iter().map(|m| m.drift_ms).sum::<f64>() / n;
188
189 let mut ss_res = 0.0f64;
190 let mut ss_tot = 0.0f64;
191 let mut max_err = 0.0f64;
192
193 for m in measurements {
194 let predicted = model.correct(m.time_ms) as f64;
195 let residual = m.drift_ms - predicted;
196 ss_res += residual * residual;
197 ss_tot += (m.drift_ms - mean_drift) * (m.drift_ms - mean_drift);
198 max_err = max_err.max(residual.abs());
199 }
200
201 let rms = (ss_res / n).sqrt();
202 let r2 = if ss_tot < 1e-12 {
203 1.0
204 } else {
205 1.0 - ss_res / ss_tot
206 };
207
208 Self {
209 rms_error_ms: rms,
210 max_error_ms: max_err,
211 r_squared: r2,
212 }
213 }
214}
215
216fn fit_quadratic(measurements: &[DriftMeasurement]) -> Vec<f64> {
223 let n = measurements.len();
224 if n < 3 {
225 let (slope, intercept) = LinearDriftEstimator::fit(measurements);
226 return vec![0.0, slope, intercept];
227 }
228
229 let xs: Vec<f64> = measurements
232 .iter()
233 .map(|m| m.time_ms as f64 / 1000.0)
234 .collect();
235 let ys: Vec<f64> = measurements.iter().map(|m| m.drift_ms).collect();
236
237 let n_f = n as f64;
238 let s1: f64 = xs.iter().sum();
239 let s2: f64 = xs.iter().map(|x| x * x).sum();
240 let s3: f64 = xs.iter().map(|x| x * x * x).sum();
241 let s4: f64 = xs.iter().map(|x| x * x * x * x).sum();
242 let t0: f64 = ys.iter().sum();
243 let t1: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
244 let t2: f64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * x * y).sum();
245
246 let mat = [[s4, s3, s2], [s3, s2, s1], [s2, s1, n_f]];
248 let rhs = [t2, t1, t0];
249
250 if let Some([a, b, c]) = solve_3x3(&mat, &rhs) {
251 vec![a, b, c]
252 } else {
253 let (slope, intercept) = LinearDriftEstimator::fit(measurements);
255 vec![0.0, slope, intercept]
256 }
257}
258
/// Solves the 3×3 linear system `m · x = rhs` with Cramer's rule.
///
/// Returns `None` when `|det(m)| < 1e-12`, treating the system as singular. The tolerance
/// is absolute, so it does not scale with the magnitude of `m`'s entries.
fn solve_3x3(m: &[[f64; 3]; 3], rhs: &[f64; 3]) -> Option<[f64; 3]> {
    // Cofactor expansion along the first row.
    let det = |a: &[[f64; 3]; 3]| {
        a[0][0] * (a[1][1] * a[2][2] - a[1][2] * a[2][1])
            - a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
            + a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0])
    };

    let denom = det(m);
    if denom.abs() < 1e-12 {
        return None;
    }

    let mut solution = [0.0f64; 3];
    for (col, slot) in solution.iter_mut().enumerate() {
        // x[col] = det(m with column `col` replaced by rhs) / det(m).
        let mut replaced = *m;
        for (row, &value) in replaced.iter_mut().zip(rhs.iter()) {
            row[col] = value;
        }
        *slot = det(&replaced) / denom;
    }
    Some(solution)
}
282
/// Evaluates a piecewise-linear curve at `t_s`.
///
/// `coeffs` is a flattened list of `(t, value)` knots, expected in ascending `t` order. A
/// trailing unpaired element is ignored rather than causing a panic.
///
/// * No complete knot: returns `0.0`.
/// * `t_s` at or before the first knot / at or after the last: returns that knot's value.
/// * Otherwise: linearly interpolates within the first segment containing `t_s`.
fn piecewise_linear_eval(coeffs: &[f64], t_s: f64) -> f64 {
    let knots: Vec<(f64, f64)> = coeffs.chunks_exact(2).map(|c| (c[0], c[1])).collect();

    let (first, last) = match (knots.first(), knots.last()) {
        (Some(&first), Some(&last)) => (first, last),
        _ => return 0.0,
    };
    if t_s <= first.0 {
        return first.1;
    }
    if t_s >= last.0 {
        return last.1;
    }

    for pair in knots.windows(2) {
        let (t0, d0) = pair[0];
        let (t1, d1) = pair[1];
        if t_s >= t0 && t_s <= t1 {
            let span = t1 - t0;
            if span <= 0.0 {
                // Zero-width segment: avoid dividing by zero.
                return d0;
            }
            let alpha = (t_s - t0) / span;
            return d0 + alpha * (d1 - d0);
        }
    }
    // Only reachable when knots are not in ascending order.
    0.0
}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` measurements one second apart; the drift at `t` seconds is `drift(t)`.
    fn sampled(count: u64, drift: impl Fn(f64) -> f64) -> Vec<DriftMeasurement> {
        (0..count)
            .map(|i| DriftMeasurement::new(i * 1000, drift(i as f64)))
            .collect()
    }

    /// `new` stores both fields verbatim.
    #[test]
    fn test_measurement_creation() {
        let sample = DriftMeasurement::new(5000, 1.5);
        assert_eq!(sample.time_ms, 5000);
        assert!((sample.drift_ms - 1.5).abs() < f64::EPSILON);
    }

    /// With fewer than two measurements the fit degenerates to `(0.0, 0.0)`.
    #[test]
    fn test_linear_fit_insufficient_data() {
        assert_eq!(LinearDriftEstimator::fit(&[]), (0.0, 0.0));
        assert_eq!(
            LinearDriftEstimator::fit(&[DriftMeasurement::new(0, 1.0)]),
            (0.0, 0.0)
        );
    }

    /// A flat, zero-drift series yields zero slope and zero intercept.
    #[test]
    fn test_linear_fit_zero_drift() {
        let (slope, intercept) = LinearDriftEstimator::fit(&sampled(5, |_| 0.0));
        assert!(slope.abs() < 1e-9);
        assert!(intercept.abs() < 1e-9);
    }

    /// Noise-free `2t + 0.5` data is recovered to within 1e-6.
    #[test]
    fn test_linear_fit_perfect_linear() {
        let (slope, intercept) = LinearDriftEstimator::fit(&sampled(5, |t| 2.0 * t + 0.5));
        assert!((slope - 2.0).abs() < 1e-6, "slope: {slope}");
        assert!((intercept - 0.5).abs() < 1e-6, "intercept: {intercept}");
    }

    /// An all-zero linear model never corrects.
    #[test]
    fn test_corrector_linear_zero() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 0.0]);
        for t in [0, 60_000] {
            assert_eq!(corrector.correct(t), 0);
        }
    }

    /// A zero-slope linear model applies its intercept at every time.
    #[test]
    fn test_corrector_linear_constant_drift() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 10.0]);
        for t in [0, 30_000] {
            assert_eq!(corrector.correct(t), 10);
        }
    }

    /// A linear fit through drift = t extrapolates to roughly 3 ms at 3 s.
    #[test]
    fn test_corrector_from_measurements_linear() {
        let corrector = DriftCorrector::from_measurements(&sampled(3, |t| t), DriftModel::Linear);
        let correction_at_3s = corrector.correct(3_000);
        assert!((correction_at_3s - 3).abs() <= 1, "got {correction_at_3s}");
    }

    /// Piecewise interpolation clamps outside the knot range and interpolates inside it.
    #[test]
    fn test_corrector_piecewise_clamping() {
        let knots = [
            DriftMeasurement::new(1_000, 5.0),
            DriftMeasurement::new(3_000, 15.0),
        ];
        let corrector = DriftCorrector::from_measurements(&knots, DriftModel::PiecewiseLinear);
        assert_eq!(corrector.correct(0), 5);
        assert_eq!(corrector.correct(10_000), 15);
        let mid = corrector.correct(2_000);
        assert!(
            (mid - 10).abs() <= 1,
            "mid correction should be ~10, got {mid}"
        );
    }

    /// A linear fit scored on its own noise-free training data is near-perfect.
    #[test]
    fn test_quality_perfect_fit() {
        let measurements = sampled(3, |t| t);
        let corrector = DriftCorrector::from_measurements(&measurements, DriftModel::Linear);
        let quality = DriftQuality::evaluate(&corrector, &measurements);
        assert!(quality.rms_error_ms < 0.5, "rms: {}", quality.rms_error_ms);
        assert!(quality.r_squared > 0.99, "r²: {}", quality.r_squared);
    }

    /// No measurements: zero error and an R² of 1.
    #[test]
    fn test_quality_empty_measurements() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 0.0]);
        let quality = DriftQuality::evaluate(&corrector, &[]);
        assert_eq!(quality.rms_error_ms, 0.0);
        assert!((quality.r_squared - 1.0).abs() < f64::EPSILON);
    }

    /// A constant model matching constant data reports (near) zero error.
    #[test]
    fn test_quality_fields_exist() {
        let corrector = DriftCorrector::new(DriftModel::Linear, vec![0.0, 5.0]);
        let measurements = [
            DriftMeasurement::new(0, 5.0),
            DriftMeasurement::new(1000, 5.0),
        ];
        let q = DriftQuality::evaluate(&corrector, &measurements);
        assert!(q.rms_error_ms < 0.1);
        assert!(q.max_error_ms < 0.1);
    }

    /// A quadratic fit to drift = t² predicts roughly 9 ms at 3 s.
    #[test]
    fn test_corrector_polynomial_from_measurements() {
        let corrector =
            DriftCorrector::from_measurements(&sampled(5, |t| t * t), DriftModel::Polynomial);
        let c = corrector.correct(3_000);
        assert!((c - 9).abs() <= 2, "polynomial correction at 3s: {c}");
    }
}