Skip to main content

oximedia_align/
warp.rs

1//! Time warp / time-stretching alignment using Dynamic Time Warping (DTW).
2//!
3//! Provides:
4//! - [`DtwAligner`] – standard DTW with full cost matrix and Euclidean distance.
5//! - [`WarpPath`] – the DTW alignment path and timestamp remapping utilities.
6//! - [`WarpCurve`] – a continuous (`time_ms`, `offset_ms`) curve derived from a path.
7//! - [`WarpSmoothing`] – moving-average smoother for warp curves.
8
9#![allow(dead_code)]
10
/// A DTW alignment path represented as matched index pairs `(i_a, i_b)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WarpPath {
    /// Ordered list of aligned index pairs from sequence A and sequence B.
    pub pairs: Vec<(usize, usize)>,
}

impl WarpPath {
    /// Create a warp path from a list of index pairs.
    #[must_use]
    pub fn new(pairs: Vec<(usize, usize)>) -> Self {
        Self { pairs }
    }

    /// Remap positions on sequence A's index axis to sequence B's index axis.
    ///
    /// Each value in `original_ms` is compared directly against the A-indices
    /// stored in the path — no frame-duration scaling is applied. The pair whose
    /// A-index is nearest to the value is selected and its B-index is returned.
    /// When several pairs are equally near, the earliest one in `pairs` wins.
    ///
    /// # Arguments
    /// * `original_ms` – positions on A's axis, expressed in the same unit as the
    ///   path indices (i.e. one unit per frame). Converting to/from milliseconds
    ///   is the caller's job; [`WarpCurve`] does that conversion for a given fps.
    ///
    /// # Returns
    /// One B-index per input value, in input order. If the path is empty the
    /// input is returned unchanged.
    ///
    /// Cost is `O(original_ms.len() * pairs.len())`; the path is not assumed to be
    /// sorted, so a linear scan is used.
    #[must_use]
    pub fn apply_to_timestamps(&self, original_ms: &[u64]) -> Vec<u64> {
        if self.pairs.is_empty() {
            return original_ms.to_vec();
        }

        original_ms
            .iter()
            .map(|&t| {
                // `u64::abs_diff` cannot overflow; the previous signed form
                // `ia as i64 - t as i64` wrapped (or panicked in debug builds)
                // for timestamps >= 2^63.
                self.pairs
                    .iter()
                    .min_by_key(|&&(ia, _)| (ia as u64).abs_diff(t))
                    // `pairs` is non-empty here, so the fallback is unreachable.
                    .map_or(t, |&(_, ib)| ib as u64)
            })
            .collect()
    }

    /// Return the number of index pairs in the path.
    #[must_use]
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// Return `true` if the path contains no index pairs.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
71
72/// Dynamic Time Warping aligner.
73///
74/// Uses the standard full-matrix DTW algorithm with Euclidean (absolute value)
75/// distance between scalar samples.
76pub struct DtwAligner;
77
78impl DtwAligner {
79    /// Create a new DTW aligner.
80    #[must_use]
81    pub fn new() -> Self {
82        Self
83    }
84
85    /// Compute the DTW distance and alignment path between two sequences.
86    ///
87    /// Returns `(distance, path)` where `distance` is the normalised DTW cost
88    /// (divided by path length) and `path` contains the matched index pairs.
89    ///
90    /// # Panics
91    /// Does not panic; returns empty path and `0.0` distance for empty inputs.
92    #[must_use]
93    pub fn compute(seq_a: &[f32], seq_b: &[f32]) -> (f32, WarpPath) {
94        let na = seq_a.len();
95        let nb = seq_b.len();
96
97        if na == 0 || nb == 0 {
98            return (0.0, WarpPath::new(vec![]));
99        }
100
101        // Build the DTW cost matrix (na × nb).
102        let inf = f32::INFINITY;
103        let mut dtw = vec![vec![inf; nb]; na];
104
105        dtw[0][0] = (seq_a[0] - seq_b[0]).abs();
106
107        for j in 1..nb {
108            dtw[0][j] = dtw[0][j - 1] + (seq_a[0] - seq_b[j]).abs();
109        }
110        for i in 1..na {
111            dtw[i][0] = dtw[i - 1][0] + (seq_a[i] - seq_b[0]).abs();
112        }
113        for i in 1..na {
114            for j in 1..nb {
115                let cost = (seq_a[i] - seq_b[j]).abs();
116                let min_prev = dtw[i - 1][j].min(dtw[i][j - 1]).min(dtw[i - 1][j - 1]);
117                dtw[i][j] = cost + min_prev;
118            }
119        }
120
121        // Back-track to recover the path.
122        let mut path = Vec::new();
123        let mut i = na - 1;
124        let mut j = nb - 1;
125        path.push((i, j));
126
127        while i > 0 || j > 0 {
128            if i == 0 {
129                j -= 1;
130            } else if j == 0 {
131                i -= 1;
132            } else {
133                let diag = dtw[i - 1][j - 1];
134                let left = dtw[i][j - 1];
135                let up = dtw[i - 1][j];
136                if diag <= left && diag <= up {
137                    i -= 1;
138                    j -= 1;
139                } else if left < up {
140                    j -= 1;
141                } else {
142                    i -= 1;
143                }
144            }
145            path.push((i, j));
146        }
147
148        path.reverse();
149
150        let total_cost = dtw[na - 1][nb - 1];
151        let norm_cost = if path.is_empty() {
152            0.0
153        } else {
154            total_cost / path.len() as f32
155        };
156
157        (norm_cost, WarpPath::new(path))
158    }
159}
160
161impl Default for DtwAligner {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167/// A continuous warp curve mapping original timestamps (ms) to signed offsets
168/// (ms).  Each point is `(original_ms, offset_ms)`.
169#[derive(Debug, Clone)]
170pub struct WarpCurve {
171    /// Ordered control points: `(original_time_ms, offset_ms)`.
172    pub points: Vec<(u64, i64)>,
173}
174
175impl WarpCurve {
176    /// Create a warp curve from a [`WarpPath`] and a frames-per-second value.
177    ///
178    /// Each path pair `(ia, ib)` is converted: the A-time is `ia * frame_ms`
179    /// and the offset is `(ib as i64 - ia as i64) * frame_ms`.
180    #[must_use]
181    pub fn from_path(path: &WarpPath, fps: f32) -> Self {
182        if fps <= 0.0 || path.is_empty() {
183            return Self { points: vec![] };
184        }
185
186        let frame_ms = (1000.0 / fps) as i64;
187        let mut points: Vec<(u64, i64)> = path
188            .pairs
189            .iter()
190            .map(|&(ia, ib)| {
191                let t = ia as u64 * frame_ms as u64;
192                let offset = (ib as i64 - ia as i64) * frame_ms;
193                (t, offset)
194            })
195            .collect();
196
197        // Deduplicate by time, keeping the last (should already be monotone).
198        points.dedup_by_key(|p| p.0);
199        Self { points }
200    }
201
202    /// Linearly interpolate the offset at `time_ms`.
203    ///
204    /// Clamps to the first/last point outside the curve's range.
205    #[must_use]
206    pub fn interpolate(&self, time_ms: u64) -> i64 {
207        if self.points.is_empty() {
208            return 0;
209        }
210        if time_ms <= self.points[0].0 {
211            return self.points[0].1;
212        }
213        let last = self.points[self.points.len() - 1];
214        if time_ms >= last.0 {
215            return last.1;
216        }
217
218        // Binary-search for the surrounding segment.
219        let idx = self.points.partition_point(|&(t, _)| t <= time_ms);
220        let (t0, o0) = self.points[idx - 1];
221        let (t1, o1) = self.points[idx];
222
223        let alpha = (time_ms - t0) as f64 / (t1 - t0) as f64;
224        let interpolated = o0 as f64 + alpha * (o1 as f64 - o0 as f64);
225        interpolated.round() as i64
226    }
227}
228
229/// Moving-average smoother for [`WarpCurve`]s.
230pub struct WarpSmoothing;
231
232impl WarpSmoothing {
233    /// Smooth a warp curve using a symmetric moving average of `window` samples.
234    ///
235    /// Points at the boundaries use a reduced window (causal/anticausal
236    /// clamping).
237    #[must_use]
238    pub fn smooth(curve: &WarpCurve, window: usize) -> WarpCurve {
239        let n = curve.points.len();
240        if n == 0 || window <= 1 {
241            return curve.clone();
242        }
243
244        let half = window / 2;
245        let smoothed_points: Vec<(u64, i64)> = (0..n)
246            .map(|i| {
247                let start = i.saturating_sub(half);
248                let end = (i + half + 1).min(n);
249                let count = end - start;
250                let sum: i64 = curve.points[start..end].iter().map(|p| p.1).sum();
251                let avg = (sum as f64 / count as f64).round() as i64;
252                (curve.points[i].0, avg)
253            })
254            .collect();
255
256        WarpCurve {
257            points: smoothed_points,
258        }
259    }
260}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Unit tests
264// ─────────────────────────────────────────────────────────────────────────────
265
#[cfg(test)]
mod tests {
    use super::*;

    // ── DtwAligner ────────────────────────────────────────────────────────────

    #[test]
    fn test_dtw_empty_inputs() {
        let (dist, path) = DtwAligner::compute(&[], &[1.0]);
        assert_eq!(dist, 0.0);
        assert!(path.is_empty());

        let (dist2, path2) = DtwAligner::compute(&[1.0], &[]);
        assert_eq!(dist2, 0.0);
        assert!(path2.is_empty());
    }

    #[test]
    fn test_dtw_identical_sequences() {
        let seq = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let (dist, path) = DtwAligner::compute(&seq, &seq);
        assert_eq!(
            dist, 0.0,
            "identical sequences should have zero DTW distance"
        );
        // The path must be the full diagonal; checking the length first keeps an
        // empty path from passing the loop below vacuously.
        assert_eq!(path.len(), seq.len());
        for &(ia, ib) in &path.pairs {
            assert_eq!(ia, ib, "diagonal path expected for identical sequences");
        }
    }

    #[test]
    fn test_dtw_shifted_sequence() {
        // seq_b is seq_a shifted by one; DTW should find a near-zero alignment.
        let seq_a = vec![0.0f32, 1.0, 2.0, 3.0, 4.0];
        let seq_b = vec![0.0f32, 0.0, 1.0, 2.0, 3.0, 4.0];
        let (dist, path) = DtwAligner::compute(&seq_a, &seq_b);
        assert!(
            dist < 1.0,
            "shifted sequence should have low DTW distance: {dist}"
        );
        assert!(!path.is_empty());
    }

    #[test]
    fn test_dtw_path_starts_at_origin_ends_at_corner() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![1.0f32, 2.5, 3.0, 3.5];
        let (_, path) = DtwAligner::compute(&a, &b);
        assert_eq!(path.pairs[0], (0, 0), "path must start at (0,0)");
        let last = *path.pairs.last().expect("last should be valid");
        assert_eq!(
            last,
            (a.len() - 1, b.len() - 1),
            "path must end at (na-1, nb-1)"
        );
    }

    #[test]
    fn test_dtw_single_elements() {
        let (dist, path) = DtwAligner::compute(&[3.0], &[5.0]);
        assert!((dist - 2.0).abs() < 1e-6);
        assert_eq!(path.pairs, vec![(0, 0)]);
    }

    // ── WarpPath ─────────────────────────────────────────────────────────────

    #[test]
    fn test_warp_path_apply_timestamps_empty_path() {
        let path = WarpPath::new(vec![]);
        let ts = vec![100u64, 200, 300];
        let result = path.apply_to_timestamps(&ts);
        assert_eq!(result, ts);
    }

    #[test]
    fn test_warp_path_apply_timestamps() {
        // Path: A[0]→B[0], A[1]→B[2], A[2]→B[3]
        let path = WarpPath::new(vec![(0, 0), (1, 2), (2, 3)]);
        // timestamp 1 → closest A-index 1 → B-index 2
        let result = path.apply_to_timestamps(&[1]);
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn test_warp_path_len() {
        let path = WarpPath::new(vec![(0, 0), (1, 1), (2, 2)]);
        assert_eq!(path.len(), 3);
        assert!(!path.is_empty());
    }

    // ── WarpCurve ─────────────────────────────────────────────────────────────

    #[test]
    fn test_warp_curve_from_path_empty() {
        let path = WarpPath::new(vec![]);
        let curve = WarpCurve::from_path(&path, 25.0);
        assert!(curve.points.is_empty());
    }

    #[test]
    fn test_warp_curve_from_path_diagonal() {
        // Diagonal path → all offsets zero.
        let pairs: Vec<(usize, usize)> = (0..5).map(|i| (i, i)).collect();
        let path = WarpPath::new(pairs);
        let curve = WarpCurve::from_path(&path, 25.0);
        // One point per frame (distinct times at 40 ms spacing); guards against
        // the loop below passing on an empty curve.
        assert_eq!(curve.points.len(), 5);
        for &(_, offset) in &curve.points {
            assert_eq!(offset, 0, "diagonal path should produce zero offsets");
        }
    }

    #[test]
    fn test_warp_curve_interpolate_clamp() {
        let curve = WarpCurve {
            points: vec![(0, 10), (1000, 20)],
        };
        assert_eq!(curve.interpolate(0), 10);
        assert_eq!(curve.interpolate(2000), 20); // clamp to last
    }

    #[test]
    fn test_warp_curve_interpolate_midpoint() {
        let curve = WarpCurve {
            points: vec![(0, 0), (1000, 100)],
        };
        let mid = curve.interpolate(500);
        assert!(
            (mid - 50).abs() <= 1,
            "midpoint offset should be ~50, got {mid}"
        );
    }

    // ── WarpSmoothing ─────────────────────────────────────────────────────────

    #[test]
    fn test_warp_smoothing_constant_curve() {
        let curve = WarpCurve {
            points: vec![(0, 5), (100, 5), (200, 5), (300, 5)],
        };
        let smoothed = WarpSmoothing::smooth(&curve, 3);
        assert_eq!(smoothed.points.len(), curve.points.len());
        for &(_, v) in &smoothed.points {
            assert_eq!(
                v, 5,
                "constant curve should remain unchanged after smoothing"
            );
        }
    }

    #[test]
    fn test_warp_smoothing_reduces_spike() {
        let curve = WarpCurve {
            points: vec![(0, 0), (100, 0), (200, 100), (300, 0), (400, 0)],
        };
        let smoothed = WarpSmoothing::smooth(&curve, 3);
        // The spike at index 2 should be reduced.
        let spike_val = smoothed.points[2].1;
        assert!(spike_val < 100, "spike should be attenuated: {spike_val}");
    }

    #[test]
    fn test_warp_smoothing_empty() {
        let curve = WarpCurve { points: vec![] };
        let smoothed = WarpSmoothing::smooth(&curve, 5);
        assert!(smoothed.points.is_empty());
    }
}