use crate::core::config::DtwMethod;
use crate::core::traits::DistanceMetric;
use crate::metrics::constraint_bands::{itakura_parallelogram, sakoe_chiba_band};
/// Dynamic time warping distance, configurable through a [`DtwMethod`].
///
/// The stored `method` selects which DTW routine is run when the value is
/// used through the [`DistanceMetric`] trait.
#[derive(Debug, Clone)]
pub struct Dtw {
/// Algorithm variant (together with its parameters) used by `distance`.
pub method: DtwMethod,
}
impl Dtw {
pub fn new() -> Self {
Self {
method: DtwMethod::Classic,
}
}
pub fn with_method(method: DtwMethod) -> Self {
Self { method }
}
}
impl Default for Dtw {
fn default() -> Self {
Self::new()
}
}
impl DistanceMetric for Dtw {
    /// Computes the DTW distance between `a` and `b` with the configured method.
    ///
    /// Each variant of `DtwMethod` is forwarded, together with its parameters,
    /// to the matching free function in this module.
    ///
    /// # Panics
    ///
    /// Panics if either series is empty (every dispatched routine asserts this).
    fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
        match self.method {
            DtwMethod::Classic => dtw_classic(a, b),
            DtwMethod::SakoeChibaBand {
                window_size: window,
            } => dtw_sakoe_chiba(a, b, window),
            DtwMethod::ItakuraParallelogram { max_slope: slope } => dtw_itakura(a, b, slope),
            DtwMethod::Multiscale {
                resolution: res,
                radius: rad,
            } => dtw_multiscale(a, b, res, rad),
            DtwMethod::Fast {
                resolution: res,
                radius: rad,
            } => dtw_fast(a, b, res, rad),
        }
    }
}
/// Unconstrained dynamic time warping distance between `a` and `b`.
///
/// Runs the standard dynamic program over squared point differences, where
/// each cell adds its local cost to the cheapest of the cell above, the cell
/// to the left and the diagonal predecessor. Only two rows of length
/// `b.len()` are kept in memory. The result is the square root of the
/// accumulated cost in the final cell.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_classic(a: &[f64], b: &[f64]) -> f64 {
    assert!(
        !a.is_empty() && !b.is_empty(),
        "Time series must be non-empty"
    );
    let width = b.len();
    let sq = |x: f64, y: f64| {
        let diff = x - y;
        diff * diff
    };

    // First row: the only way to reach (0, j) is from (0, j - 1), so the
    // costs are a running sum of the local squared differences.
    let head = a[0];
    let mut above: Vec<f64> = Vec::with_capacity(width);
    for (j, &y) in b.iter().enumerate() {
        let local = sq(head, y);
        let total = if j == 0 { local } else { above[j - 1] + local };
        above.push(total);
    }

    // Remaining rows: `above` holds row i - 1 while `row` is filled with row i.
    let mut row = vec![f64::INFINITY; width];
    for &x in &a[1..] {
        // Column 0 can only be reached from directly above.
        row[0] = above[0] + sq(x, b[0]);
        for j in 1..width {
            row[j] = sq(x, b[j]) + above[j].min(row[j - 1]).min(above[j - 1]);
        }
        std::mem::swap(&mut above, &mut row);
    }
    above[width - 1].sqrt()
}
/// DTW distance restricted to the band produced by `sakoe_chiba_band`.
///
/// The band is built from the two series lengths and `window_size`, then the
/// banded dynamic program is evaluated inside it.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_sakoe_chiba(a: &[f64], b: &[f64], window_size: usize) -> f64 {
    let (rows, cols) = (a.len(), b.len());
    assert!(rows > 0 && cols > 0, "Time series must be non-empty");
    dtw_with_band(a, b, &sakoe_chiba_band(rows, cols, window_size))
}
/// DTW distance restricted to the band produced by `itakura_parallelogram`.
///
/// The band is built from the two series lengths and `max_slope`, then the
/// banded dynamic program is evaluated inside it.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_itakura(a: &[f64], b: &[f64], max_slope: f64) -> f64 {
    let (rows, cols) = (a.len(), b.len());
    assert!(rows > 0 && cols > 0, "Time series must be non-empty");
    dtw_with_band(a, b, &itakura_parallelogram(rows, cols, max_slope))
}
/// DTW distance where row `i` may only use columns `band[i].0..=band[i].1`.
///
/// `band[i]` is the inclusive `(lo, hi)` column range allowed for `a[i]`; a
/// pair with `lo > hi` leaves that row empty. Cells outside the band count
/// as unreachable (infinite cost). The result is the square root of the
/// accumulated squared differences along the cheapest admissible path, or
/// `f64::INFINITY` when no admissible path reaches the final cell (including
/// when `band` has fewer entries than `a`). Entries of `band` beyond
/// `a.len()` are ignored.
///
/// Only two rows of `b.len() + 1` costs are kept, so memory use is linear in
/// `b.len()` instead of proportional to `a.len() * b.len()`.
///
/// # Panics
///
/// Panics if a non-empty band range contains a column index `>= b.len()`.
fn dtw_with_band(a: &[f64], b: &[f64], band: &[(usize, usize)]) -> f64 {
    let n = a.len();
    let m = b.len();
    let cols = m + 1;

    // Both rows are shifted by one column so that index 0 acts as the virtual
    // border; `prev[0] = 0.0` is the single zero-cost start state.
    let mut prev = vec![f64::INFINITY; cols];
    let mut curr = vec![f64::INFINITY; cols];
    prev[0] = 0.0;

    for (&ai, &(lo, hi)) in a.iter().zip(band) {
        // The buffer still holds the row from two iterations ago; everything
        // outside the current band must read as unreachable.
        curr.fill(f64::INFINITY);
        for j in lo..=hi {
            let d = ai - b[j];
            curr[j + 1] = d * d + prev[j + 1].min(curr[j]).min(prev[j]);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // With too few band rows the final row was never computed.
    if band.len() < n {
        return f64::INFINITY;
    }
    prev[m].sqrt()
}
/// Approximate DTW distance computed in two passes.
///
/// 1. Both series are down-sampled with `coarsen` (means over chunks of
///    `resolution` points).
/// 2. The warping path of the coarse series is computed with `dtw_path`.
/// 3. That path is widened into a per-row band by `project_path_to_band`
///    using `resolution` and `radius`.
/// 4. The banded dynamic program is evaluated on the original series.
///
/// Falls back to [`dtw_classic`] when `resolution` is zero (no coarsening is
/// possible) or when either series has at most `2 * resolution` points.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_multiscale(a: &[f64], b: &[f64], resolution: usize, radius: usize) -> f64 {
    let n = a.len();
    let m = b.len();
    assert!(n > 0 && m > 0, "Time series must be non-empty");

    // `saturating_mul` keeps an absurdly large `resolution` from overflowing;
    // it then simply selects the exact fallback below.
    let min_len = resolution.saturating_mul(2);
    // `resolution == 0` must never reach `coarsen`, whose `chunks(0)` panics.
    if resolution == 0 || n <= min_len || m <= min_len {
        return dtw_classic(a, b);
    }

    let a_coarse = coarsen(a, resolution);
    let b_coarse = coarsen(b, resolution);
    let coarse_path = dtw_path(&a_coarse, &b_coarse);
    let band = project_path_to_band(&coarse_path, resolution, radius, n, m);
    dtw_with_band(a, b, &band)
}
/// Approximate DTW distance for the `Fast` method.
///
/// Currently a thin alias: all arguments are forwarded unchanged to
/// [`dtw_multiscale`], so the result and panics are those of that function.
pub fn dtw_fast(a: &[f64], b: &[f64], resolution: usize, radius: usize) -> f64 {
dtw_multiscale(a, b, resolution, radius)
}
/// Down-samples `ts` by replacing each run of `factor` consecutive values
/// with its arithmetic mean.
///
/// The last run may be shorter than `factor`; its mean is taken over the
/// values actually present.
///
/// # Panics
///
/// Panics if `factor` is zero.
fn coarsen(ts: &[f64], factor: usize) -> Vec<f64> {
    let mut means = Vec::new();
    for block in ts.chunks(factor) {
        let total: f64 = block.iter().sum();
        means.push(total / block.len() as f64);
    }
    means
}
/// Fills `cost`, a row-major `n x m` buffer, with accumulated DTW costs.
///
/// Cell `(i, j)` (stored at `cost[i * m + j]`) receives the squared
/// difference `(a[i] - b[j])^2` plus the cheapest of its upper, left and
/// diagonal predecessors. The first row and first column have a single
/// predecessor each and are therefore running sums.
///
/// # Panics
///
/// Panics if `a` or `b` is empty, or if an index derived from `n` and `m`
/// falls outside `a`, `b` or `cost`.
fn fill_cost_matrix(a: &[f64], b: &[f64], cost: &mut [f64], n: usize, m: usize) {
    let sq = |x: f64, y: f64| {
        let diff = x - y;
        diff * diff
    };

    // Row 0: reachable only from the left.
    cost[0] = sq(a[0], b[0]);
    for j in 1..m {
        cost[j] = cost[j - 1] + sq(a[0], b[j]);
    }

    for i in 1..n {
        let above = (i - 1) * m;
        let here = i * m;
        // Column 0: reachable only from above.
        cost[here] = cost[above] + sq(a[i], b[0]);
        for j in 1..m {
            cost[here + j] = sq(a[i], b[j])
                + cost[above + j]
                    .min(cost[here + j - 1])
                    .min(cost[above + j - 1]);
        }
    }
}
/// Recovers a warping path from a filled row-major `n x m` cost matrix.
///
/// Walks from cell `(n - 1, m - 1)` back to `(0, 0)`, at each step moving to
/// the cheapest predecessor. Ties prefer the diagonal, then the left
/// neighbour, then the upper neighbour. The returned path is ordered from
/// `(0, 0)` to `(n - 1, m - 1)`.
fn backtrack_path(cost: &[f64], n: usize, m: usize) -> Vec<(usize, usize)> {
    let (mut row, mut col) = (n - 1, m - 1);
    let mut steps = vec![(row, col)];
    loop {
        match (row, col) {
            (0, 0) => break,
            // On the top edge only leftward moves remain.
            (0, _) => col -= 1,
            // On the left edge only upward moves remain.
            (_, 0) => row -= 1,
            _ => {
                let above = (row - 1) * m;
                let here = row * m;
                let diag = cost[above + col - 1];
                let left = cost[here + col - 1];
                let up = cost[above + col];
                if diag <= left && diag <= up {
                    row -= 1;
                    col -= 1;
                } else if left <= up {
                    col -= 1;
                } else {
                    row -= 1;
                }
            }
        }
        steps.push((row, col));
    }
    steps.reverse();
    steps
}
/// Computes a DTW warping path between `a` and `b`.
///
/// Builds the full `a.len() x b.len()` accumulated-cost matrix and backtracks
/// through it. The returned index pairs `(i, j)` run from `(0, 0)` to
/// `(a.len() - 1, b.len() - 1)`.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_path(a: &[f64], b: &[f64]) -> Vec<(usize, usize)> {
    let n = a.len();
    let m = b.len();
    // Same precondition and message as the distance functions in this module;
    // without it empty input fails later with an opaque index panic.
    assert!(n > 0 && m > 0, "Time series must be non-empty");
    let mut cost = vec![f64::INFINITY; n * m];
    fill_cost_matrix(a, b, &mut cost, n, m);
    backtrack_path(&cost, n, m)
}
/// Expands a warping path found at coarse resolution into a per-row column
/// band for the full-resolution `n x m` problem.
///
/// Each coarse cell `(ci, cj)` covers fine rows
/// `ci * resolution .. (ci + 1) * resolution` (clipped to `n`). For those
/// rows the band is widened to include every column within
/// `radius + resolution` of the centre of coarse column `cj`, clipped to
/// `0..=m - 1`. Afterwards the first row is forced to start at column 0 and
/// the last row to end at column `m - 1`. Any row that is still empty
/// (`lo > hi`) receives a window of `radius` columns around the straight
/// diagonal.
///
/// Returns one inclusive `(lo, hi)` pair per fine row.
fn project_path_to_band(
    coarse_path: &[(usize, usize)],
    resolution: usize,
    radius: usize,
    n: usize,
    m: usize,
) -> Vec<(usize, usize)> {
    // `(m, 0)` is an inverted range, which marks a row as not yet covered.
    let mut band = vec![(m, 0usize); n];

    for &(ci, cj) in coarse_path {
        let first_row = ci * resolution;
        let end_row = ((ci + 1) * resolution).min(n);
        let j_center = cj * resolution + resolution / 2;
        for row in band.iter_mut().take(end_row).skip(first_row) {
            let lo = j_center.saturating_sub(radius + resolution);
            let hi = (j_center + radius + resolution).min(m - 1);
            *row = (row.0.min(lo), row.1.max(hi));
        }
    }

    // The path must be able to start at (0, 0) and finish at (n - 1, m - 1).
    band[0].0 = 0;
    band[n - 1].1 = m - 1;

    let row_span = (n - 1).max(1) as f64;
    let col_span = (m - 1) as f64;
    for (i, row) in band.iter_mut().enumerate() {
        let (lo, hi) = *row;
        if lo <= hi {
            continue;
        }
        // Uncovered row: fall back to a window around the main diagonal.
        let center = (i as f64 * col_span / row_span).round() as usize;
        *row = (center.saturating_sub(radius), (center + radius).min(m - 1));
    }
    band
}
/// Returns the full accumulated DTW cost matrix for `a` and `b`.
///
/// The result has `a.len()` rows of `b.len()` columns; entry `[i][j]` is the
/// accumulated squared-difference cost of the cheapest warping path from
/// `(0, 0)` to `(i, j)`. No square root is applied.
///
/// # Panics
///
/// Panics if either series is empty.
pub fn dtw_cost_matrix(a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
    let n = a.len();
    let m = b.len();
    // Same precondition and message as the distance functions in this module;
    // without it empty input fails later with an opaque index panic.
    assert!(n > 0 && m > 0, "Time series must be non-empty");
    let mut cost = vec![f64::INFINITY; n * m];
    fill_cost_matrix(a, b, &mut cost, n, m);
    cost.chunks_exact(m).map(<[f64]>::to_vec).collect()
}
// Unit tests for the DTW routines defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// A series compared with itself must have (numerically) zero distance.
#[test]
fn test_dtw_identical() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let d = dtw_classic(&a, &a);
assert!(
d.abs() < 1e-10,
"DTW of identical series should be 0, got {d}"
);
}
// Inputs of different lengths are accepted; here the trailing 4.0 in `b`
// has no exact counterpart in `a`, so the distance must be positive.
#[test]
fn test_dtw_simple() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0, 4.0];
let d = dtw_classic(&a, &b);
assert!(d > 0.0);
}
// Swapping the arguments must not change the classic DTW distance.
#[test]
fn test_dtw_symmetric() {
let a = vec![1.0, 3.0, 5.0, 2.0];
let b = vec![2.0, 4.0, 3.0, 1.0];
let d1 = dtw_classic(&a, &b);
let d2 = dtw_classic(&b, &a);
assert!((d1 - d2).abs() < 1e-10, "DTW should be symmetric");
}
// Smoke test on two alternating series: only positivity is checked.
// NOTE(review): despite the name, no comparison with a Euclidean distance
// is made here.
#[test]
fn test_dtw_euclidean_special_case() {
let a = vec![0.0, 1.0, 0.0];
let b = vec![1.0, 0.0, 1.0];
let d = dtw_classic(&a, &b);
assert!(d > 0.0);
}
// Identical series stay at zero distance under a Sakoe-Chiba window of 1.
#[test]
fn test_dtw_sakoe_chiba() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let d = dtw_sakoe_chiba(&a, &b, 1);
assert!(d.abs() < 1e-10);
}
// Restricting the warping path can never yield a smaller distance than the
// unconstrained algorithm (checked up to a small tolerance).
#[test]
fn test_dtw_sakoe_chiba_constrained() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let d_classic = dtw_classic(&a, &b);
let d_constrained = dtw_sakoe_chiba(&a, &b, 1);
assert!(d_constrained >= d_classic - 1e-10);
}
// Identical series stay at zero distance under an Itakura band with
// max_slope 2.0.
#[test]
fn test_dtw_itakura() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let d = dtw_itakura(&a, &b, 2.0);
assert!(d.abs() < 1e-10);
}
// Identical series align along the main diagonal.
#[test]
fn test_dtw_path_basic() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0];
let path = dtw_path(&a, &b);
assert_eq!(path, vec![(0, 0), (1, 1), (2, 2)]);
}
// A path always starts at (0, 0) and ends at (len(a) - 1, len(b) - 1).
#[test]
fn test_dtw_path_endpoints() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![1.0, 3.0, 4.0];
let path = dtw_path(&a, &b);
assert_eq!(path[0], (0, 0));
assert_eq!(*path.last().unwrap(), (3, 2));
}
// On two phase-shifted sine waves of length 100 the multiscale result must
// be positive and within 50% relative error of the classic distance.
#[test]
fn test_dtw_multiscale() {
let a: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1).sin()).collect();
let b: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1 + 0.5).sin()).collect();
let d_classic = dtw_classic(&a, &b);
let d_multi = dtw_multiscale(&a, &b, 4, 2);
assert!(d_multi > 0.0);
assert!((d_multi - d_classic).abs() / d_classic < 0.5);
}
// `Dtw::new()` used through the `DistanceMetric` trait gives zero distance
// for identical series.
#[test]
fn test_dtw_metric_trait() {
let metric = Dtw::new();
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0];
let d = metric.distance(&a, &b);
assert!(d.abs() < 1e-10);
}
// The cost matrix has one row per element of `a` and one column per
// element of `b`.
#[test]
fn test_cost_matrix_shape() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0, 4.0];
let cost = dtw_cost_matrix(&a, &b);
assert_eq!(cost.len(), 3);
assert_eq!(cost[0].len(), 4);
}
}