flow_clustering/clustering/
kmeans.rs1use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::KMeans as LinfaKMeans;
6use ndarray::Array2;
7
/// Parameters for a [`KMeans`] run.
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansConfig {
    /// Number of clusters (`k`) to fit.
    pub n_clusters: usize,
    /// Iteration cap forwarded to linfa's `max_n_iterations`.
    pub max_iterations: usize,
    /// Convergence tolerance forwarded to linfa's `tolerance`.
    pub tolerance: f64,
    /// Optional RNG seed.
    ///
    /// NOTE(review): `KMeans::fit` does not currently read this field, so setting it does not
    /// change the result — TODO wire it into linfa's seeded parameter constructor.
    pub seed: Option<u64>,
}

impl Default for KMeansConfig {
    /// Two clusters, at most 300 iterations, tolerance `1e-4`, and no seed.
    fn default() -> Self {
        Self {
            n_clusters: 2,
            max_iterations: 300,
            tolerance: 1e-4,
            seed: None,
        }
    }
}

32#[derive(Debug, Clone)]
34pub struct KMeansResult {
35 pub assignments: Vec<usize>,
37 pub centroids: Array2<f64>,
39 pub iterations: usize,
41 pub inertia: f64,
43}
44
45pub struct KMeans;
47
48impl KMeans {
49 pub fn fit_from_rows(
60 data_rows: Vec<Vec<f64>>,
61 config: &KMeansConfig,
62 ) -> ClusteringResult<KMeansResult> {
63 if data_rows.is_empty() {
64 return Err(ClusteringError::EmptyData);
65 }
66 let n_features = data_rows[0].len();
67 let n_samples = data_rows.len();
68
69 let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
71 let data = Array2::from_shape_vec((n_samples, n_features), flat).map_err(|e| {
72 ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e))
73 })?;
74
75 Self::fit(&data, config)
76 }
77
78 pub fn fit(data: &Array2<f64>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
87 if data.nrows() == 0 {
88 return Err(ClusteringError::EmptyData);
89 }
90
91 if data.nrows() < config.n_clusters {
92 return Err(ClusteringError::InsufficientData {
93 min: config.n_clusters,
94 actual: data.nrows(),
95 });
96 }
97
98 let dataset = DatasetBase::new(data.clone(), ());
103 let model = LinfaKMeans::params(config.n_clusters)
104 .max_n_iterations(config.max_iterations as u64)
105 .tolerance(config.tolerance)
106 .fit(&dataset)
107 .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
108
109 let assignments: Vec<usize> = (0..data.nrows())
111 .map(|i| {
112 let point = data.row(i);
113 let mut min_dist = f64::INFINITY;
114 let mut best_cluster = 0;
115 for (j, centroid) in model.centroids().rows().into_iter().enumerate() {
116 let dist: f64 = point
117 .iter()
118 .zip(centroid.iter())
119 .map(|(a, b)| (a - b).powi(2))
120 .sum();
121 if dist < min_dist {
122 min_dist = dist;
123 best_cluster = j;
124 }
125 }
126 best_cluster
127 })
128 .collect();
129
130 let centroids = model.centroids().to_owned();
132
133 let inertia = Self::calculate_inertia(data, ¢roids, &assignments);
135
136 Ok(KMeansResult {
137 assignments,
138 centroids,
139 iterations: config.max_iterations, inertia,
141 })
142 }
143
144 fn calculate_inertia(
146 data: &Array2<f64>,
147 centroids: &Array2<f64>,
148 assignments: &[usize],
149 ) -> f64 {
150 let mut inertia = 0.0;
151 for (i, assignment) in assignments.iter().enumerate() {
152 let point = data.row(i);
153 let centroid = centroids.row(*assignment);
154 let dist_sq: f64 = point
155 .iter()
156 .zip(centroid.iter())
157 .map(|(a, b)| (a - b).powi(2))
158 .sum();
159 inertia += dist_sq;
160 }
161 inertia
162 }
163}