use crate::distance::base::DistanceCalculator;
/// Computes the Dynamic Time Warping (DTW) cost between the two sequences
/// exposed by `calculator`.
///
/// Let `n0 = calculator.len_seq1()` and `n1 = calculator.len_seq2()`. The
/// cumulative-cost table has `(n0 + 1) x (n1 + 1)` cells. Cell `(0, 0)` is
/// `0.0`, and every other cell in row 0 or column 0 is `f64::INFINITY`. Each
/// interior cell `(i, j)` is `calculator.dis_between(i - 1, j - 1)` plus the
/// smallest of its diagonal, upper and left neighbours. The DTW cost is the
/// bottom-right cell `(n0, n1)`.
///
/// # Arguments
/// * `calculator` - supplies both sequence lengths and the pairwise point
///   distances.
/// * `use_full_matrix` - when `true`, the whole table is kept and returned
///   through `DpResult::with_matrix`, flattened row-major with a row stride of
///   `n1 + 1`. When `false`, only two rows are kept (O(n1) memory) and the
///   result is built with `DpResult::new`.
///
/// # Returns
/// A `DpResult` carrying the DTW cost. If either sequence is empty, the cost
/// is `f64::MAX` and the result is built with `DpResult::new`, whatever the
/// value of `use_full_matrix`.
pub fn dtw<D: DistanceCalculator>(
    calculator: &D,
    use_full_matrix: bool,
) -> crate::distance::DpResult {
    let (n0, n1) = (calculator.len_seq1(), calculator.len_seq2());
    if n0 == 0 || n1 == 0 {
        return crate::distance::DpResult::new(f64::MAX);
    }

    // Every table row spans column 0 (the boundary) plus one column per
    // element of the second sequence.
    let width = n1 + 1;

    if !use_full_matrix {
        // Rolling variant: `above` holds row i - 1 while `row` is filled for row i.
        let mut above = vec![f64::INFINITY; width];
        let mut row = vec![f64::INFINITY; width];
        above[0] = 0.0;
        for i in 1..=n0 {
            // `row` is recycled from two rows back; column 0 must be the
            // infinite boundary again for every row after row 0.
            row[0] = f64::INFINITY;
            for j in 1..=n1 {
                let best = above[j - 1].min(above[j]).min(row[j - 1]);
                row[j] = calculator.dis_between(i - 1, j - 1) + best;
            }
            std::mem::swap(&mut above, &mut row);
        }
        // After the final swap, `above` holds row n0.
        return crate::distance::DpResult::new(above[n1]);
    }

    let mut table = vec![f64::INFINITY; width * (n0 + 1)];
    table[0] = 0.0;
    for i in 1..=n0 {
        // Split the flat table so row i - 1 can be read while row i is written.
        let (before, from_here) = table.split_at_mut(i * width);
        let above = &before[(i - 1) * width..];
        let row = &mut from_here[..width];
        for j in 1..=n1 {
            let best = above[j - 1].min(above[j]).min(row[j - 1]);
            row[j] = calculator.dis_between(i - 1, j - 1) + best;
        }
    }
    // Cell (n0, n1) is the last element of the flattened table.
    let total = table[width * (n0 + 1) - 1];
    crate::distance::DpResult::with_matrix(total, table)
}
#[cfg(test)]
mod tests {
    // Unit tests for `dtw`.
    //
    // Most fixtures compare `[[0,0],[1,1]]` with `[[0,1],[1,0]]`. Every one of
    // their four cross-pair Euclidean distances is exactly 1.0, so the DTW
    // cost (2.0, along the path (0,0) -> (1,1)) and the whole cumulative-cost
    // table can be checked by hand instead of only asserting `> 0.0`.
    use super::*;
    use crate::distance::base::TrajectoryCalculator;
    use crate::distance::distance_type::DistanceType;

    /// Absolute tolerance for comparing finite floating-point results.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` matches `expected`. Equal infinities count as a
    /// match, because `inf - inf` is NaN and would fail a plain tolerance check.
    fn assert_close(actual: f64, expected: f64) {
        if expected.is_infinite() {
            assert_eq!(actual, expected);
        } else {
            assert!(
                (actual - expected).abs() < EPS,
                "expected {}, got {}",
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_dtw_euclidean_simple() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result = dtw(&calculator, false);
        // All pairwise distances are 1, and the optimal path visits two cells.
        assert_close(result.distance, 2.0);
        assert!(result.matrix.is_none());
    }

    #[test]
    fn test_dtw_euclidean_identical() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result = dtw(&calculator, false);
        assert!(result.distance < 1e-6);
    }

    #[test]
    fn test_dtw_spherical_simple() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Spherical);
        let result = dtw(&calculator, false);
        assert!(result.distance > 0.0);
        assert!(result.distance.is_finite());
    }

    #[test]
    fn test_dtw_with_both_distance_types() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let euclidean_calc = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let spherical_calc = TrajectoryCalculator::new(&t0, &t1, DistanceType::Spherical);
        let euclidean_result = dtw(&euclidean_calc, false);
        let spherical_result = dtw(&spherical_calc, false);
        assert_close(euclidean_result.distance, 2.0);
        assert!(spherical_result.distance > 0.0);
        assert!(spherical_result.distance.is_finite());
    }

    #[test]
    fn test_dtw_consistency_between_modes() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0], [2.0, 3.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result_optimized = dtw(&calculator, false);
        let result_full = dtw(&calculator, true);
        assert_close(result_optimized.distance, result_full.distance);
        assert!(result_optimized.matrix.is_none());
        let matrix = result_full
            .matrix
            .expect("full-matrix mode must return the cost table");
        // (3 + 1) x (3 + 1) cells, including the boundary row and column.
        assert_eq!(matrix.len(), 16);
        assert_close(matrix[15], result_full.distance);
    }

    #[test]
    fn test_dtw_consistency_unequal_lengths() {
        // Unequal lengths exercise the row stride (n1 + 1) of the flat table.
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 0.0], [2.0, 2.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result_optimized = dtw(&calculator, false);
        let result_full = dtw(&calculator, true);
        // Optimal path: (0,0) cost 0, (1,0) or (1,1) cost sqrt(2), (2,1) cost 0.
        assert_close(result_optimized.distance, 2f64.sqrt());
        assert_close(result_full.distance, result_optimized.distance);
        let matrix = result_full
            .matrix
            .expect("full-matrix mode must return the cost table");
        assert_eq!(matrix.len(), (t0.len() + 1) * (t1.len() + 1));
        assert_close(matrix[matrix.len() - 1], result_full.distance);
    }

    #[test]
    fn test_dtw_empty_trajectories() {
        let t0: Vec<[f64; 2]> = vec![];
        let t1: Vec<[f64; 2]> = vec![[0.0, 0.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result = dtw(&calculator, false);
        assert_eq!(result.distance, f64::MAX);
        assert!(result.matrix.is_none());
        // Empty input short-circuits before any table is built, even in full mode.
        let result_full = dtw(&calculator, true);
        assert_eq!(result_full.distance, f64::MAX);
        assert!(result_full.matrix.is_none());
    }

    #[test]
    fn test_dtw_empty_second_trajectory() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0]];
        let t1: Vec<[f64; 2]> = vec![];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        assert_eq!(dtw(&calculator, false).distance, f64::MAX);
        let result_full = dtw(&calculator, true);
        assert_eq!(result_full.distance, f64::MAX);
        assert!(result_full.matrix.is_none());
    }

    #[test]
    fn test_dtw_matrix_content() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let calculator = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result = dtw(&calculator, true);
        let matrix = result
            .matrix
            .expect("full-matrix mode must return the cost table");
        // Row-major 3x3 table: boundary row/column are infinite except (0,0).
        let inf = f64::INFINITY;
        let expected = [0.0, inf, inf, inf, 1.0, 2.0, inf, 2.0, 2.0];
        assert_eq!(matrix.len(), expected.len());
        for (&got, &want) in matrix.iter().zip(expected.iter()) {
            assert_close(got, want);
        }
        assert_close(matrix[8], result.distance);
    }

    #[test]
    fn test_dtw_with_precomputed_distances() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let distance_matrix =
            crate::distance::utils::precompute_distance_matrix(&t0, &t1, DistanceType::Euclidean);
        let calculator = crate::distance::base::PrecomputedDistanceCalculator::new(
            &distance_matrix,
            t0.len(),
            t1.len(),
        );
        let result = dtw(&calculator, false);
        assert_close(result.distance, 2.0);
    }

    #[test]
    fn test_dtw_with_precomputed_distances_spherical() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0]];
        let distance_matrix =
            crate::distance::utils::precompute_distance_matrix(&t0, &t1, DistanceType::Spherical);
        let calculator = crate::distance::base::PrecomputedDistanceCalculator::new(
            &distance_matrix,
            t0.len(),
            t1.len(),
        );
        let result = dtw(&calculator, false);
        assert!(result.distance > 0.0);
        // Precomputed spherical distances should reproduce the on-the-fly ones.
        let traj_calc = TrajectoryCalculator::new(&t0, &t1, DistanceType::Spherical);
        assert_close(result.distance, dtw(&traj_calc, false).distance);
    }

    #[test]
    fn test_dtw_precomputed_vs_trajectory_calculator() {
        let t0: Vec<[f64; 2]> = vec![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let t1: Vec<[f64; 2]> = vec![[0.0, 1.0], [1.0, 0.0], [2.0, 3.0]];
        let traj_calc = TrajectoryCalculator::new(&t0, &t1, DistanceType::Euclidean);
        let result_traj = dtw(&traj_calc, false);
        let distance_matrix =
            crate::distance::utils::precompute_distance_matrix(&t0, &t1, DistanceType::Euclidean);
        let precomp_calc = crate::distance::base::PrecomputedDistanceCalculator::new(
            &distance_matrix,
            t0.len(),
            t1.len(),
        );
        let result_precomp = dtw(&precomp_calc, false);
        assert_close(result_traj.distance, result_precomp.distance);
    }
}