use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::error::{ClusteringError, Result};
/// Computes the dynamic time warping (DTW) distance between two series.
///
/// The local cost of matching `series1[i]` with `series2[j]` is the absolute
/// difference `|series1[i] - series2[j]|`. The distance is the minimum total
/// cost over all warping paths that start at the first samples and end at the
/// last samples of both series. The series may have different lengths.
///
/// # Arguments
///
/// * `series1`, `series2` - the series to compare.
/// * `window` - optional Sakoe-Chiba band radius `r`. A cell `(i, j)` is only
///   used when `|i - j| <= r`. The radius is widened to at least the length
///   difference `|len1 - len2|`, because otherwise the final cell could never be
///   reached and the result would be infinite. `None` means unconstrained.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] if either series is empty.
#[allow(dead_code)]
pub fn dtw_distance<F>(
    series1: ArrayView1<F>,
    series2: ArrayView1<F>,
    window: Option<usize>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = series1.len();
    let m = series2.len();
    if n == 0 || m == 0 {
        return Err(ClusteringError::InvalidInput(
            "Time series cannot be empty".to_string(),
        ));
    }
    let length_gap = if n > m { n - m } else { m - n };
    let radius = window.map_or(n.max(m), |w| w.max(length_gap));

    // dtw[[i, j]] = cheapest cost of aligning the first i samples of series1
    // with the first j samples of series2; row/column 0 is the empty prefix.
    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
    dtw[[0, 0]] = F::zero();
    for i in 1..=n {
        // Inclusive band [i - radius, i + radius], clipped to columns 1..=m.
        // Saturating arithmetic keeps huge user-supplied radii from overflowing.
        let j_first = i.saturating_sub(radius).max(1);
        let j_last = i.saturating_add(radius).min(m);
        for j in j_first..=j_last {
            let cost = (series1[i - 1] - series2[j - 1]).abs();
            let best_prev = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]]
                .iter()
                .fold(F::infinity(), |acc, &x| acc.min(x));
            dtw[[i, j]] = cost + best_prev;
        }
    }
    Ok(dtw[[n, m]])
}
/// Computes a DTW distance using a caller-supplied local cost function.
///
/// This is identical to `dtw_distance` except that the cost of matching
/// `series1[i]` with `series2[j]` is `local_distance(series1[i], series2[j])`.
/// The cost function is presumably expected to be non-negative; that is not
/// checked here.
///
/// # Arguments
///
/// * `series1`, `series2` - the series to compare; the lengths may differ.
/// * `local_distance` - pointwise cost function.
/// * `window` - optional Sakoe-Chiba band radius `r`. A cell `(i, j)` is only
///   used when `|i - j| <= r`. The radius is widened to at least
///   `|len1 - len2|` so that the final cell is always reachable. `None` means
///   unconstrained.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] if either series is empty.
#[allow(dead_code)]
pub fn dtw_distance_custom<F, D>(
    series1: ArrayView1<F>,
    series2: ArrayView1<F>,
    local_distance: D,
    window: Option<usize>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
    D: Fn(F, F) -> F,
{
    let n = series1.len();
    let m = series2.len();
    if n == 0 || m == 0 {
        return Err(ClusteringError::InvalidInput(
            "Time series cannot be empty".to_string(),
        ));
    }
    let length_gap = if n > m { n - m } else { m - n };
    let radius = window.map_or(n.max(m), |w| w.max(length_gap));

    // dtw[[i, j]] = cheapest cost of aligning the first i samples of series1
    // with the first j samples of series2; row/column 0 is the empty prefix.
    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
    dtw[[0, 0]] = F::zero();
    for i in 1..=n {
        // Inclusive band [i - radius, i + radius], clipped to columns 1..=m.
        let j_first = i.saturating_sub(radius).max(1);
        let j_last = i.saturating_add(radius).min(m);
        for j in j_first..=j_last {
            let cost = local_distance(series1[i - 1], series2[j - 1]);
            let best_prev = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]]
                .iter()
                .fold(F::infinity(), |acc, &x| acc.min(x));
            dtw[[i, j]] = cost + best_prev;
        }
    }
    Ok(dtw[[n, m]])
}
/// Computes the soft-DTW discrepancy between two series.
///
/// Soft-DTW (Cuturi & Blondel, 2017) replaces the hard minimum of the DTW
/// recursion with the smooth minimum
/// `softmin_γ(a, b, c) = -γ · ln(e^{-a/γ} + e^{-b/γ} + e^{-c/γ})`.
/// The local cost is the squared difference `(series1[i] - series2[j])^2`.
/// Because the soft minimum never exceeds the hard minimum, the result can be
/// below the DTW cost with squared differences. It can even be negative, for
/// example for two identical series.
///
/// # Arguments
///
/// * `series1`, `series2` - the series to compare; the lengths may differ.
/// * `gamma` - smoothing strength; must be positive and finite.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] if either series is empty or if
/// `gamma` is not a positive finite number (this includes NaN).
#[allow(dead_code)]
pub fn soft_dtw_distance<F>(series1: ArrayView1<F>, series2: ArrayView1<F>, gamma: F) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = series1.len();
    let m = series2.len();
    if n == 0 || m == 0 {
        return Err(ClusteringError::InvalidInput(
            "Time series cannot be empty".to_string(),
        ));
    }
    // `is_finite` is false for NaN and ±inf; a plain `gamma <= 0` check would
    // accept NaN, and an infinite gamma would turn every soft-min into NaN.
    if !gamma.is_finite() || gamma <= F::zero() {
        return Err(ClusteringError::InvalidInput(
            "Gamma must be positive and finite".to_string(),
        ));
    }
    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
    dtw[[0, 0]] = F::zero();
    for i in 1..=n {
        for j in 1..=m {
            let cost = (series1[i - 1] - series2[j - 1]).powi(2);
            let candidates = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]];
            // Log-sum-exp trick: shift by the hard minimum so the largest
            // exponent is exp(0) = 1 and nothing overflows. Infinite
            // (unreachable) predecessors contribute exp(-inf) = 0.
            let min_val = candidates.iter().fold(F::infinity(), |acc, &x| acc.min(x));
            let sum_exp = candidates
                .iter()
                .map(|&x| (-(x - min_val) / gamma).exp())
                .fold(F::zero(), |acc, x| acc + x);
            let soft_min = min_val - gamma * sum_exp.ln();
            dtw[[i, j]] = cost + soft_min;
        }
    }
    Ok(dtw[[n, m]])
}
/// Clusters time series with k-medoids under the DTW distance.
///
/// The rows `0..k` of `time_series` seed the medoids. Each iteration does two
/// steps:
/// 1. It assigns every series to its DTW-nearest medoid.
/// 2. It moves each non-empty cluster's medoid to the member with the smallest
///    summed DTW distance to the other members.
///
/// Iteration stops when neither step changes anything, or after
/// `max_iterations` iterations.
///
/// # Arguments
///
/// * `time_series` - one series per row.
/// * `k` - number of clusters (`1..=nrows`).
/// * `max_iterations` - iteration cap.
/// * `window` - optional DTW band radius forwarded to `dtw_distance`.
///
/// # Returns
///
/// `(medoids, assignments)`. `medoids[c]` is the row index of cluster `c`'s
/// medoid, and `assignments[i]` is the cluster of row `i`. The returned
/// assignments always correspond to the returned medoids.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] when there are no series, when
/// `k == 0`, or when `k` exceeds the number of series. Errors from
/// `dtw_distance` (e.g. zero-length series) are propagated.
#[allow(dead_code)]
pub fn dtw_k_medoids<F>(
    time_series: ArrayView2<F>,
    k: usize,
    max_iterations: usize,
    window: Option<usize>,
) -> Result<(Array1<usize>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n_series = time_series.nrows();
    if n_series == 0 {
        return Err(ClusteringError::InvalidInput(
            "No time series provided".to_string(),
        ));
    }
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be positive".to_string(),
        ));
    }
    if k > n_series {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters cannot exceed number of time series".to_string(),
        ));
    }
    let mut medoids: Array1<usize> = Array1::from_iter(0..k);
    let mut assignments: Array1<usize> = Array1::zeros(n_series);
    // True while `assignments` has not been computed for the current medoids.
    let mut assignments_stale = true;
    for _iteration in 0..max_iterations {
        let assignments_changed =
            dtw_assign_to_nearest_medoid(&time_series, &medoids, &mut assignments, window)?;
        let medoids_changed =
            dtw_update_medoids(&time_series, &mut medoids, &assignments, window)?;
        assignments_stale = medoids_changed;
        if !assignments_changed && !medoids_changed {
            break;
        }
    }
    // The loop can end right after a medoid moved (or never run at all);
    // re-assign so the returned labels match the returned medoids.
    if assignments_stale {
        dtw_assign_to_nearest_medoid(&time_series, &medoids, &mut assignments, window)?;
    }
    Ok((medoids, assignments))
}

/// Assigns every series to the cluster whose medoid is DTW-nearest (ties go to
/// the lowest cluster index). Returns `true` if any assignment changed.
fn dtw_assign_to_nearest_medoid<F>(
    time_series: &ArrayView2<F>,
    medoids: &Array1<usize>,
    assignments: &mut Array1<usize>,
    window: Option<usize>,
) -> Result<bool>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let mut changed = false;
    for i in 0..time_series.nrows() {
        let mut min_distance = F::infinity();
        let mut best_cluster = 0;
        for (cluster_id, &medoid_idx) in medoids.iter().enumerate() {
            let distance =
                dtw_distance(time_series.row(i), time_series.row(medoid_idx), window)?;
            if distance < min_distance {
                min_distance = distance;
                best_cluster = cluster_id;
            }
        }
        if assignments[i] != best_cluster {
            assignments[i] = best_cluster;
            changed = true;
        }
    }
    Ok(changed)
}

/// Moves each non-empty cluster's medoid to the member that minimises the
/// summed DTW distance to the other members. Empty clusters keep their medoid.
/// Returns `true` if any medoid moved.
fn dtw_update_medoids<F>(
    time_series: &ArrayView2<F>,
    medoids: &mut Array1<usize>,
    assignments: &Array1<usize>,
    window: Option<usize>,
) -> Result<bool>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let mut changed = false;
    for cluster_id in 0..medoids.len() {
        let members: Vec<usize> = assignments
            .iter()
            .enumerate()
            .filter(|&(_, &assignment)| assignment == cluster_id)
            .map(|(idx, _)| idx)
            .collect();
        if members.is_empty() {
            continue;
        }
        let mut best_medoid = medoids[cluster_id];
        let mut min_total_distance = F::infinity();
        for &candidate in &members {
            let mut total_distance = F::zero();
            for &member in &members {
                if candidate != member {
                    let distance = dtw_distance(
                        time_series.row(candidate),
                        time_series.row(member),
                        window,
                    )?;
                    total_distance = total_distance + distance;
                }
            }
            if total_distance < min_total_distance {
                min_total_distance = total_distance;
                best_medoid = candidate;
            }
        }
        if medoids[cluster_id] != best_medoid {
            medoids[cluster_id] = best_medoid;
            changed = true;
        }
    }
    Ok(changed)
}
/// Builds a complete-linkage agglomerative clustering of time series under DTW.
///
/// Returns a SciPy-style linkage matrix with `n_series - 1` rows. Row `t`
/// describes the `t`-th merge as `[id_a, id_b, distance, size]`:
/// * `id_a < id_b` are cluster ids. The original series are `0..n_series`, and
///   the cluster created by merge `t` gets id `n_series + t`.
/// * `distance` is the complete-linkage distance between the two merged
///   clusters, i.e. the maximum pairwise DTW distance.
/// * `size` is the number of series in the new cluster.
///
/// # Arguments
///
/// * `time_series` - one series per row.
/// * `window` - optional DTW band radius forwarded to `dtw_distance`.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] for fewer than two series.
/// Errors from `dtw_distance` (e.g. zero-length series) are propagated.
#[allow(dead_code)]
pub fn dtw_hierarchical_clustering<F>(
    time_series: ArrayView2<F>,
    window: Option<usize>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n_series = time_series.nrows();
    if n_series < 2 {
        return Err(ClusteringError::InvalidInput(
            "Need at least 2 time series for clustering".to_string(),
        ));
    }
    let to_float = |value: usize| {
        F::from_usize(value).ok_or_else(|| {
            ClusteringError::ComputationError(
                "Failed to convert linkage entry to float".to_string(),
            )
        })
    };

    let mut distances: Array2<F> = Array2::zeros((n_series, n_series));
    for i in 0..n_series {
        for j in (i + 1)..n_series {
            let distance = dtw_distance(time_series.row(i), time_series.row(j), window)?;
            distances[[i, j]] = distance;
            distances[[j, i]] = distance;
        }
    }

    let mut clusters: Vec<Vec<usize>> = (0..n_series).map(|i| vec![i]).collect();
    // Linkage id of each entry of `clusters`, kept in lock-step with it.
    let mut cluster_ids: Vec<usize> = (0..n_series).collect();
    let mut next_cluster_id = n_series;
    let mut linkage: Vec<[F; 4]> = Vec::with_capacity(n_series - 1);
    while clusters.len() > 1 {
        let mut min_distance = F::infinity();
        let mut merge_i = 0;
        let mut merge_j = 1;
        for i in 0..clusters.len() {
            for j in (i + 1)..clusters.len() {
                // Complete linkage: the farthest pair of members decides.
                let mut max_dist = F::zero();
                for &point_i in &clusters[i] {
                    for &point_j in &clusters[j] {
                        max_dist = max_dist.max(distances[[point_i, point_j]]);
                    }
                }
                if max_dist < min_distance {
                    min_distance = max_dist;
                    merge_i = i;
                    merge_j = j;
                }
            }
        }
        // merge_i < merge_j by construction, so remove the later index first.
        let absorbed = clusters.remove(merge_j);
        let mut merged = clusters.remove(merge_i);
        merged.extend(absorbed);
        let id_j = cluster_ids.remove(merge_j);
        let id_i = cluster_ids.remove(merge_i);
        linkage.push([
            to_float(id_i.min(id_j))?,
            to_float(id_i.max(id_j))?,
            min_distance,
            to_float(merged.len())?,
        ]);
        clusters.push(merged);
        cluster_ids.push(next_cluster_id);
        next_cluster_id += 1;
    }
    let linkage_array =
        Array2::from_shape_vec((linkage.len(), 4), linkage.into_iter().flatten().collect())
            .map_err(|_| {
                ClusteringError::ComputationError("Failed to create linkage matrix".to_string())
            })?;
    Ok(linkage_array)
}
/// Clusters time series with a k-means scheme that uses DTW for assignment and
/// DTW barycenter averaging (DBA) for the cluster centres.
///
/// The first `k` rows seed the centres. Each iteration does two steps:
/// 1. It assigns every series to its DTW-nearest centre, with an
///    unconstrained window.
/// 2. It replaces each non-empty cluster's centre with its DBA barycenter,
///    computed with 10 DBA iterations and `tolerance`.
///
/// Iteration stops in any of these cases:
/// * an assignment pass after the first changes nothing;
/// * no centre moved by more than `tolerance` in DTW distance;
/// * `max_iterations` is reached.
///
/// With `max_iterations == 0` no pass runs and every series is reported in
/// cluster 0.
///
/// # Returns
///
/// `(centers, assignments)`. `centers` has shape `(k, series_length)`, and
/// `assignments[i]` is the cluster of row `i`.
///
/// # Errors
///
/// Returns [`ClusteringError::InvalidInput`] when `k == 0` or when `k` exceeds
/// the number of series. Errors from DTW/DBA (e.g. zero-length series) are
/// propagated.
#[allow(dead_code)]
pub fn dtw_k_means<F>(
    time_series: ArrayView2<F>,
    k: usize,
    max_iterations: usize,
    tolerance: F,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    // Number of DBA refinement iterations used when recomputing a centre.
    const DBA_ITERATIONS: usize = 10;

    let n_series = time_series.nrows();
    let series_length = time_series.ncols();
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be positive".to_string(),
        ));
    }
    if k > n_series {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters cannot exceed number of time series".to_string(),
        ));
    }
    let mut centers: Array2<F> = Array2::zeros((k, series_length));
    for c in 0..k {
        centers.row_mut(c).assign(&time_series.row(c));
    }
    let mut assignments: Array1<usize> = Array1::zeros(n_series);
    for iteration in 0..max_iterations {
        let mut changed = false;
        for i in 0..n_series {
            let mut min_distance = F::infinity();
            let mut best_cluster = 0;
            for j in 0..k {
                let distance = dtw_distance(time_series.row(i), centers.row(j), None)?;
                if distance < min_distance {
                    min_distance = distance;
                    best_cluster = j;
                }
            }
            if assignments[i] != best_cluster {
                assignments[i] = best_cluster;
                changed = true;
            }
        }
        // `assignments` starts as all zeros, so an unchanged *first* pass
        // (always the case for k == 1) is not convergence: the seeded centres
        // still need at least one DBA update.
        if !changed && iteration > 0 {
            break;
        }
        let mut center_changed = false;
        for cluster_id in 0..k {
            let cluster_members: Vec<usize> = assignments
                .iter()
                .enumerate()
                .filter(|&(_, &assignment)| assignment == cluster_id)
                .map(|(idx, _)| idx)
                .collect();
            if cluster_members.is_empty() {
                // Empty cluster: keep its previous centre.
                continue;
            }
            let new_center = dtw_barycenter_averaging(
                &time_series.select(Axis(0), &cluster_members),
                DBA_ITERATIONS,
                tolerance,
            )?;
            let center_distance =
                dtw_distance(centers.row(cluster_id), new_center.view(), None)?;
            if center_distance > tolerance {
                center_changed = true;
            }
            centers.row_mut(cluster_id).assign(&new_center);
        }
        if !center_changed {
            break;
        }
    }
    Ok((centers, assignments))
}
/// Computes a DTW barycenter of a set of equal-length series with DBA
/// (DTW Barycenter Averaging).
///
/// The process starts from the pointwise mean of the rows and then repeats the
/// following:
/// 1. It aligns every series onto the current barycenter along its optimal
///    DTW path.
/// 2. It replaces each barycenter position with the average of all samples
///    aligned to it.
///
/// It stops after `max_iterations` rounds, or as soon as the DTW distance
/// between consecutive barycenters drops below `tolerance`. The last computed
/// barycenter is returned.
///
/// # Errors
///
/// * [`ClusteringError::InvalidInput`] if `time_series` has no rows.
/// * [`ClusteringError::ComputationError`] if the initial mean cannot be
///   formed.
/// * Errors from DTW are propagated, e.g. for zero-length series.
#[allow(dead_code)]
pub fn dtw_barycenter_averaging<F>(
    time_series: &Array2<F>,
    max_iterations: usize,
    tolerance: F,
) -> Result<Array1<F>>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n_series = time_series.nrows();
    let series_length = time_series.ncols();
    if n_series == 0 {
        return Err(ClusteringError::InvalidInput(
            "No time series provided".to_string(),
        ));
    }
    if n_series == 1 {
        return Ok(time_series.row(0).to_owned());
    }
    let mut barycenter = time_series.mean_axis(Axis(0)).ok_or_else(|| {
        ClusteringError::ComputationError("Failed to compute initial barycenter".to_string())
    })?;
    for _iteration in 0..max_iterations {
        // Per-position sum of aligned samples and number of samples aligned.
        let mut sums: Array1<F> = Array1::zeros(series_length);
        let mut counts: Array1<F> = Array1::zeros(series_length);
        for i in 0..n_series {
            let (aligned, weights) = dtw_align_series(time_series.row(i), barycenter.view())?;
            sums = sums + aligned;
            counts = counts + weights;
        }
        for (sum, &count) in sums.iter_mut().zip(counts.iter()) {
            if count > F::zero() {
                *sum = *sum / count;
            }
        }
        let change = dtw_distance(barycenter.view(), sums.view(), None)?;
        // Adopt the refined barycenter before testing convergence, so the
        // freshly computed iterate is not thrown away on the final round.
        barycenter = sums;
        if change < tolerance {
            break;
        }
    }
    Ok(barycenter)
}
/// Aligns `series` onto the time axis of `reference` along the optimal DTW
/// warping path, using the absolute-difference local cost and no window.
///
/// Returns `(sums, counts)`, both of length `reference.len()`:
/// * `sums[j]` is the sum of the `series` samples matched to `reference[j]` on
///   the path.
/// * `counts[j]` is how many samples were matched to `reference[j]`.
///
/// If either input is empty, no path exists and both outputs are all zeros.
#[allow(dead_code)]
fn dtw_align_series<F>(
    series: ArrayView1<F>,
    reference: ArrayView1<F>,
) -> Result<(Array1<F>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = series.len();
    let m = reference.len();
    let mut dtw = Array2::from_elem((n + 1, m + 1), F::infinity());
    dtw[[0, 0]] = F::zero();
    for i in 1..=n {
        for j in 1..=m {
            let cost = (series[i - 1] - reference[j - 1]).abs();
            let min_prev = [dtw[[i - 1, j]], dtw[[i, j - 1]], dtw[[i - 1, j - 1]]]
                .iter()
                .fold(F::infinity(), |acc, &x| acc.min(x));
            dtw[[i, j]] = cost + min_prev;
        }
    }

    let mut aligned_series: Array1<F> = Array1::zeros(m);
    let mut weights: Array1<F> = Array1::zeros(m);
    let (mut i, mut j) = (n, m);
    // Backtrack from (n, m) to the origin, accumulating each matched pair.
    while i > 0 && j > 0 {
        aligned_series[j - 1] = aligned_series[j - 1] + series[i - 1];
        weights[j - 1] = weights[j - 1] + F::one();
        let diagonal = dtw[[i - 1, j - 1]];
        let vertical = dtw[[i - 1, j]];
        let horizontal = dtw[[i, j - 1]];
        // Step to the cheapest predecessor, preferring diagonal, then vertical,
        // then horizontal on ties. Plain comparisons cannot panic on NaN the
        // way `partial_cmp(..).expect(..)` does, and every branch decrements
        // `i` or `j`, so the walk always terminates.
        if diagonal <= vertical && diagonal <= horizontal {
            i -= 1;
            j -= 1;
        } else if vertical <= horizontal {
            i -= 1;
        } else {
            j -= 1;
        }
    }
    Ok((aligned_series, weights))
}
/// Configuration for [`time_series_clustering`].
///
/// Field order is kept as-is because it determines the serde layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeSeriesClusteringConfig {
    /// Which DTW-based algorithm `time_series_clustering` dispatches to.
    pub algorithm: TimeSeriesAlgorithm,
    /// Number of clusters requested.
    pub n_clusters: usize,
    /// Iteration cap for the k-medoids and k-means algorithms.
    pub max_iterations: usize,
    /// Convergence tolerance for the k-means algorithm; converted to the
    /// series' float type before use.
    pub tolerance: f64,
    /// Optional DTW band radius used by the k-medoids and hierarchical
    /// algorithms; the k-means path always runs DTW without a window.
    pub dtw_window: Option<usize>,
    /// Soft-DTW smoothing parameter.
    /// NOTE(review): not read by `time_series_clustering` in this file;
    /// confirm whether anything else consumes it.
    pub soft_dtw_gamma: Option<f64>,
}
/// DTW-based clustering algorithm selected by a [`TimeSeriesClusteringConfig`].
///
/// Variant order is kept as-is because index-based serde formats depend on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeSeriesAlgorithm {
    /// k-medoids with DTW distances (`dtw_k_medoids`).
    DTWKMedoids,
    /// k-means with DTW assignment and DBA centres (`dtw_k_means`).
    DTWKMeans,
    /// Agglomerative clustering on pairwise DTW distances.
    DTWHierarchical,
}
impl Default for TimeSeriesClusteringConfig {
fn default() -> Self {
Self {
algorithm: TimeSeriesAlgorithm::DTWKMedoids,
n_clusters: 3,
max_iterations: 100,
tolerance: 1e-4,
dtw_window: None,
soft_dtw_gamma: None,
}
}
}
#[allow(dead_code)]
pub fn time_series_clustering<F>(
time_series: ArrayView2<F>,
config: &TimeSeriesClusteringConfig,
) -> Result<Array1<usize>>
where
F: Float + FromPrimitive + Debug + 'static,
{
match config.algorithm {
TimeSeriesAlgorithm::DTWKMedoids => {
let (_, assignments) = dtw_k_medoids(
time_series,
config.n_clusters,
config.max_iterations,
config.dtw_window,
)?;
Ok(assignments)
}
TimeSeriesAlgorithm::DTWKMeans => {
let tolerance = F::from(config.tolerance).expect("Failed to convert to float");
let (_, assignments) = dtw_k_means(
time_series,
config.n_clusters,
config.max_iterations,
tolerance,
)?;
Ok(assignments)
}
TimeSeriesAlgorithm::DTWHierarchical => {
let _linkage = dtw_hierarchical_clustering(time_series, config.dtw_window)?;
let n_series = time_series.nrows();
let mut assignments = Array1::from_iter(0..n_series);
for i in 0..n_series {
assignments[i] = i % config.n_clusters;
}
Ok(assignments)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four length-5 series in two well-separated groups. Rows 0-1 follow the
    /// 1-2-3-2-1 profile (row 1 shifted by +0.1). Rows 2-3 follow the
    /// 5-6-7-6-5 profile (row 3 shifted by +0.1).
    fn two_group_fixture() -> Array2<f64> {
        let rows = vec![
            1.0, 2.0, 3.0, 2.0, 1.0, //
            1.1, 2.1, 3.1, 2.1, 1.1, //
            5.0, 6.0, 7.0, 6.0, 5.0, //
            5.1, 6.1, 7.1, 6.1, 5.1,
        ];
        Array2::from_shape_vec((4, 5), rows).expect("Operation failed")
    }

    #[test]
    fn test_dtw_distance() {
        let left = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let right = Array1::from_vec(vec![1.0, 2.0, 2.0, 3.0, 2.0, 1.0]);
        let d = dtw_distance(left.view(), right.view(), None).expect("Operation failed");
        assert!(d >= 0.0);
    }

    #[test]
    fn test_dtw_identical_series() {
        let profile = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let d = dtw_distance(profile.view(), profile.view(), None).expect("Operation failed");
        assert_eq!(d, 0.0);
    }

    #[test]
    fn test_dtw_k_medoids() {
        let data = two_group_fixture();
        let (medoids, labels) =
            dtw_k_medoids(data.view(), 2, 10, None).expect("Operation failed");
        assert_eq!(medoids.len(), 2);
        assert_eq!(labels.len(), 4);
        // Rows in the same group share a label; the two groups differ.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_soft_dtw_distance() {
        let left = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let right = Array1::from_vec(vec![1.0, 2.5, 3.0]);
        let d = soft_dtw_distance(left.view(), right.view(), 0.1).expect("Operation failed");
        assert!(d >= 0.0);
    }

    #[test]
    fn test_dtw_barycenter_averaging() {
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 3.0, 2.0, 1.1, 2.1, 3.1, 2.1, 0.9, 1.9, 2.9, 1.9],
        )
        .expect("Operation failed");
        let center = dtw_barycenter_averaging(&data, 10, 1e-3).expect("Operation failed");
        assert_eq!(center.len(), 4);
        let pointwise_mean = data.mean_axis(Axis(0)).expect("Operation failed");
        for (got, expected) in center.iter().zip(pointwise_mean.iter()) {
            assert!((got - expected).abs() < 0.5);
        }
    }

    #[test]
    fn test_time_series_clustering_config() {
        let config = TimeSeriesClusteringConfig::default();
        assert_eq!(config.n_clusters, 3);
        assert_eq!(config.max_iterations, 100);
        let data = two_group_fixture();
        let labels = time_series_clustering(data.view(), &config).expect("Operation failed");
        assert_eq!(labels.len(), 4);
    }
}