oxigdal_analytics/clustering/
kmeans.rs1use crate::error::{AnalyticsError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::random::Rng;
9
10#[derive(Debug, Clone)]
12pub struct KMeansResult {
13 pub labels: Array1<i32>,
15 pub centers: Array2<f64>,
17 pub inertia: f64,
19 pub n_iterations: usize,
21 pub converged: bool,
23}
24
25pub struct KMeansClusterer {
27 n_clusters: usize,
28 max_iterations: usize,
29 tolerance: f64,
30 init_method: InitMethod,
31}
32
/// Strategy for choosing the initial cluster centers.
///
/// The default, [`InitMethod::KMeansPlusPlus`], is also what
/// `KMeansClusterer::new` selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum InitMethod {
    /// Use `n_clusters` distinct samples picked uniformly at random.
    Random,
    /// K-means++ seeding: the first center is a uniformly random sample, and
    /// each further center is drawn with probability proportional to its
    /// squared distance from the nearest center already chosen.
    #[default]
    KMeansPlusPlus,
}
41
42impl KMeansClusterer {
43 pub fn new(n_clusters: usize, max_iterations: usize, tolerance: f64) -> Self {
50 Self {
51 n_clusters,
52 max_iterations,
53 tolerance,
54 init_method: InitMethod::KMeansPlusPlus,
55 }
56 }
57
58 pub fn with_init_method(mut self, method: InitMethod) -> Self {
60 self.init_method = method;
61 self
62 }
63
64 pub fn fit(&self, data: &ArrayView2<f64>) -> Result<KMeansResult> {
72 let (n_samples, _n_features) = data.dim();
73
74 if n_samples < self.n_clusters {
75 return Err(AnalyticsError::insufficient_data(format!(
76 "Need at least {} samples for {} clusters",
77 self.n_clusters, self.n_clusters
78 )));
79 }
80
81 let mut centers = match self.init_method {
83 InitMethod::Random => self.initialize_random(data)?,
84 InitMethod::KMeansPlusPlus => self.initialize_kmeans_plus_plus(data)?,
85 };
86
87 let mut labels = Array1::zeros(n_samples);
88 let mut converged = false;
89
90 for iteration in 0..self.max_iterations {
92 let mut changed = false;
94 for i in 0..n_samples {
95 let point = data.row(i);
96 let nearest = self.find_nearest_center(&point, ¢ers)?;
97 if labels[i] != nearest {
98 labels[i] = nearest;
99 changed = true;
100 }
101 }
102
103 if !changed {
104 converged = true;
105 tracing::debug!("K-means converged after {} iterations", iteration);
106 break;
107 }
108
109 let old_centers = centers.clone();
111 centers = self.update_centers(data, &labels)?;
112
113 let max_movement = self.max_center_movement(&old_centers, ¢ers)?;
115 if max_movement < self.tolerance {
116 converged = true;
117 tracing::debug!(
118 "K-means converged after {} iterations (max movement: {})",
119 iteration,
120 max_movement
121 );
122 break;
123 }
124 }
125
126 let inertia = self.calculate_inertia(data, &labels, ¢ers)?;
128
129 Ok(KMeansResult {
130 labels,
131 centers,
132 inertia,
133 n_iterations: self.max_iterations,
134 converged,
135 })
136 }
137
138 fn initialize_random(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
140 let (n_samples, n_features) = data.dim();
141 let mut rng = scirs2_core::random::thread_rng();
142
143 let mut centers = Array2::zeros((self.n_clusters, n_features));
144 let mut used_indices = Vec::new();
145
146 for i in 0..self.n_clusters {
147 let idx = loop {
149 let candidate = rng.gen_range(0..n_samples);
150 if !used_indices.contains(&candidate) {
151 break candidate;
152 }
153 };
154 used_indices.push(idx);
155
156 centers.row_mut(i).assign(&data.row(idx));
157 }
158
159 Ok(centers)
160 }
161
162 fn initialize_kmeans_plus_plus(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
167 let (n_samples, n_features) = data.dim();
168 let mut rng = scirs2_core::random::thread_rng();
169
170 let mut centers = Array2::zeros((self.n_clusters, n_features));
171
172 let first_idx = rng.gen_range(0..n_samples);
174 centers.row_mut(0).assign(&data.row(first_idx));
175
176 for i in 1..self.n_clusters {
178 let mut distances = Vec::with_capacity(n_samples);
180 let mut distance_sum = 0.0;
181
182 for j in 0..n_samples {
183 let point = data.row(j);
184 let mut min_dist = f64::INFINITY;
185
186 for k in 0..i {
187 let center = centers.row(k);
188 let dist = euclidean_distance_squared(&point, ¢er)?;
189 min_dist = min_dist.min(dist);
190 }
191
192 distances.push(min_dist);
193 distance_sum += min_dist;
194 }
195
196 let threshold = rng.gen_range(0.0..distance_sum);
198 let mut cumsum = 0.0;
199 let mut next_idx = 0;
200
201 for (j, &dist) in distances.iter().enumerate() {
202 cumsum += dist;
203 if cumsum >= threshold {
204 next_idx = j;
205 break;
206 }
207 }
208
209 centers.row_mut(i).assign(&data.row(next_idx));
210 }
211
212 Ok(centers)
213 }
214
215 fn find_nearest_center(
217 &self,
218 point: &scirs2_core::ndarray::ArrayView1<f64>,
219 centers: &Array2<f64>,
220 ) -> Result<i32> {
221 let mut min_dist = f64::INFINITY;
222 let mut nearest = 0;
223
224 for (i, center) in centers.axis_iter(Axis(0)).enumerate() {
225 let dist = euclidean_distance_squared(point, ¢er)?;
226 if dist < min_dist {
227 min_dist = dist;
228 nearest = i;
229 }
230 }
231
232 Ok(nearest as i32)
233 }
234
235 fn update_centers(&self, data: &ArrayView2<f64>, labels: &Array1<i32>) -> Result<Array2<f64>> {
237 let (n_samples, n_features) = data.dim();
238 let mut new_centers = Array2::zeros((self.n_clusters, n_features));
239 let mut counts = vec![0; self.n_clusters];
240
241 for i in 0..n_samples {
243 let cluster = labels[i] as usize;
244 if cluster < self.n_clusters {
245 for j in 0..n_features {
246 new_centers[[cluster, j]] += data[[i, j]];
247 }
248 counts[cluster] += 1;
249 }
250 }
251
252 for i in 0..self.n_clusters {
254 if counts[i] > 0 {
255 for j in 0..n_features {
256 new_centers[[i, j]] /= counts[i] as f64;
257 }
258 } else {
259 tracing::warn!("Cluster {} is empty, reinitializing", i);
261 }
263 }
264
265 Ok(new_centers)
266 }
267
268 fn max_center_movement(
270 &self,
271 old_centers: &Array2<f64>,
272 new_centers: &Array2<f64>,
273 ) -> Result<f64> {
274 let mut max_dist: f64 = 0.0;
275
276 for i in 0..self.n_clusters {
277 let dist = euclidean_distance_squared(&old_centers.row(i), &new_centers.row(i))?;
278 max_dist = max_dist.max(dist);
279 }
280
281 Ok(max_dist.sqrt())
282 }
283
284 fn calculate_inertia(
286 &self,
287 data: &ArrayView2<f64>,
288 labels: &Array1<i32>,
289 centers: &Array2<f64>,
290 ) -> Result<f64> {
291 let mut inertia = 0.0;
292
293 for (i, &label) in labels.iter().enumerate() {
294 let cluster = label as usize;
295 if cluster < self.n_clusters {
296 let point = data.row(i);
297 let center = centers.row(cluster);
298 inertia += euclidean_distance_squared(&point, ¢er)?;
299 }
300 }
301
302 Ok(inertia)
303 }
304}
305
306fn euclidean_distance_squared(
308 p1: &scirs2_core::ndarray::ArrayView1<f64>,
309 p2: &scirs2_core::ndarray::ArrayView1<f64>,
310) -> Result<f64> {
311 if p1.len() != p2.len() {
312 return Err(AnalyticsError::dimension_mismatch(
313 format!("{}", p1.len()),
314 format!("{}", p2.len()),
315 ));
316 }
317
318 let dist_sq: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
319
320 Ok(dist_sq)
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Six 2-D points in two tight, well-separated groups: rows 0-2 near the
    /// origin and rows 3-5 near (10, 10).
    fn two_blobs() -> scirs2_core::ndarray::Array2<f64> {
        array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.0],
            [10.0, 10.0],
            [10.1, 10.1],
            [10.0, 10.2],
        ]
    }

    /// Checks that rows 0-2 share one label, rows 3-5 share another, and the
    /// two labels differ.
    fn assert_two_blob_labels(result: &KMeansResult) {
        let labels = &result.labels;
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[0], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[3], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_kmeans_simple() {
        let data = two_blobs();

        let clusterer = KMeansClusterer::new(2, 100, 1e-4);
        let result = clusterer
            .fit(&data.view())
            .expect("K-means clustering should succeed for valid data");

        assert_eq!(result.labels.len(), 6);
        assert_eq!(result.centers.nrows(), 2);
        assert!(result.converged);
        assert!(result.n_iterations <= 100);
        assert_two_blob_labels(&result);

        // With each center at its blob's mean, each blob's squared deviations
        // sum to 0.08 / 3, giving 0.16 / 3 in total.
        assert_abs_diff_eq!(result.inertia, 0.16 / 3.0, epsilon = 1e-9);
    }

    #[test]
    fn test_kmeans_random_init() {
        let data = two_blobs();

        let clusterer = KMeansClusterer::new(2, 100, 1e-4).with_init_method(InitMethod::Random);
        let result = clusterer
            .fit(&data.view())
            .expect("K-means with random initialization should succeed");

        assert!(result.converged);
        assert_eq!(result.centers.nrows(), 2);
        assert_two_blob_labels(&result);
    }

    #[test]
    fn test_kmeans_insufficient_data() {
        let data = array![[1.0, 2.0]];
        let clusterer = KMeansClusterer::new(2, 100, 1e-4);
        let result = clusterer.fit(&data.view());

        assert!(result.is_err());
    }

    #[test]
    fn test_kmeans_plus_plus_init() {
        let data = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let clusterer =
            KMeansClusterer::new(2, 100, 1e-4).with_init_method(InitMethod::KMeansPlusPlus);
        let result = clusterer
            .fit(&data.view())
            .expect("K-means++ initialization should succeed");

        assert!(result.converged);
        assert_eq!(result.labels.len(), 4);
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[2], result.labels[3]);
        assert_ne!(result.labels[0], result.labels[2]);
    }
}