// oxigdal_analytics/clustering/kmeans.rs

//! K-Means Clustering Algorithm
//!
//! Implementation of the K-means clustering algorithm using Lloyd's algorithm.
//! Suitable for image classification and general clustering tasks.

use crate::error::{AnalyticsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::Rng;

/// Result of a K-means clustering run.
#[derive(Debug, Clone)]
pub struct KMeansResult {
    /// Cluster assignment for each input point (values in `0..k`)
    pub labels: Array1<i32>,
    /// Final cluster centers, shape (k × n_features)
    pub centers: Array2<f64>,
    /// Within-cluster sum of squared distances (lower = tighter clusters)
    pub inertia: f64,
    /// Number of iterations performed
    pub n_iterations: usize,
    /// Whether the algorithm converged (true) or hit the iteration cap (false)
    pub converged: bool,
}

/// K-means clustering algorithm (Lloyd's algorithm)
pub struct KMeansClusterer {
    // Number of clusters (k)
    n_clusters: usize,
    // Hard cap on Lloyd iterations
    max_iterations: usize,
    // Convergence threshold on maximum center movement (Euclidean distance)
    tolerance: f64,
    // How initial centers are chosen
    init_method: InitMethod,
}

/// Initialization methods for K-means
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitMethod {
    /// Random initialization: pick k distinct data points uniformly at random
    Random,
    /// K-means++ initialization (better-spread initial centers, usually
    /// faster convergence and lower final inertia)
    KMeansPlusPlus,
}

42impl KMeansClusterer {
43    /// Create a new K-means clusterer
44    ///
45    /// # Arguments
46    /// * `n_clusters` - Number of clusters
47    /// * `max_iterations` - Maximum number of iterations
48    /// * `tolerance` - Convergence tolerance for center movement
49    pub fn new(n_clusters: usize, max_iterations: usize, tolerance: f64) -> Self {
50        Self {
51            n_clusters,
52            max_iterations,
53            tolerance,
54            init_method: InitMethod::KMeansPlusPlus,
55        }
56    }
57
58    /// Set initialization method
59    pub fn with_init_method(mut self, method: InitMethod) -> Self {
60        self.init_method = method;
61        self
62    }
63
64    /// Fit K-means clustering to data
65    ///
66    /// # Arguments
67    /// * `data` - Feature matrix (n_samples × n_features)
68    ///
69    /// # Errors
70    /// Returns error if clustering fails or data is invalid
71    pub fn fit(&self, data: &ArrayView2<f64>) -> Result<KMeansResult> {
72        let (n_samples, _n_features) = data.dim();
73
74        if n_samples < self.n_clusters {
75            return Err(AnalyticsError::insufficient_data(format!(
76                "Need at least {} samples for {} clusters",
77                self.n_clusters, self.n_clusters
78            )));
79        }
80
81        // Initialize centers
82        let mut centers = match self.init_method {
83            InitMethod::Random => self.initialize_random(data)?,
84            InitMethod::KMeansPlusPlus => self.initialize_kmeans_plus_plus(data)?,
85        };
86
87        let mut labels = Array1::zeros(n_samples);
88        let mut converged = false;
89
90        // Lloyd's algorithm
91        for iteration in 0..self.max_iterations {
92            // Assignment step: assign each point to nearest center
93            let mut changed = false;
94            for i in 0..n_samples {
95                let point = data.row(i);
96                let nearest = self.find_nearest_center(&point, &centers)?;
97                if labels[i] != nearest {
98                    labels[i] = nearest;
99                    changed = true;
100                }
101            }
102
103            if !changed {
104                converged = true;
105                tracing::debug!("K-means converged after {} iterations", iteration);
106                break;
107            }
108
109            // Update step: recalculate centers
110            let old_centers = centers.clone();
111            centers = self.update_centers(data, &labels)?;
112
113            // Check convergence based on center movement
114            let max_movement = self.max_center_movement(&old_centers, &centers)?;
115            if max_movement < self.tolerance {
116                converged = true;
117                tracing::debug!(
118                    "K-means converged after {} iterations (max movement: {})",
119                    iteration,
120                    max_movement
121                );
122                break;
123            }
124        }
125
126        // Calculate inertia (within-cluster sum of squares)
127        let inertia = self.calculate_inertia(data, &labels, &centers)?;
128
129        Ok(KMeansResult {
130            labels,
131            centers,
132            inertia,
133            n_iterations: self.max_iterations,
134            converged,
135        })
136    }
137
138    /// Initialize centers randomly
139    fn initialize_random(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
140        let (n_samples, n_features) = data.dim();
141        let mut rng = scirs2_core::random::thread_rng();
142
143        let mut centers = Array2::zeros((self.n_clusters, n_features));
144        let mut used_indices = Vec::new();
145
146        for i in 0..self.n_clusters {
147            // Select a random sample that hasn't been used
148            let idx = loop {
149                let candidate = rng.gen_range(0..n_samples);
150                if !used_indices.contains(&candidate) {
151                    break candidate;
152                }
153            };
154            used_indices.push(idx);
155
156            centers.row_mut(i).assign(&data.row(idx));
157        }
158
159        Ok(centers)
160    }
161
162    /// Initialize centers using K-means++ algorithm
163    ///
164    /// This gives better initial centers by choosing them probabilistically
165    /// based on distance from existing centers.
166    fn initialize_kmeans_plus_plus(&self, data: &ArrayView2<f64>) -> Result<Array2<f64>> {
167        let (n_samples, n_features) = data.dim();
168        let mut rng = scirs2_core::random::thread_rng();
169
170        let mut centers = Array2::zeros((self.n_clusters, n_features));
171
172        // Choose first center randomly
173        let first_idx = rng.gen_range(0..n_samples);
174        centers.row_mut(0).assign(&data.row(first_idx));
175
176        // Choose remaining centers
177        for i in 1..self.n_clusters {
178            // Calculate distances to nearest existing center
179            let mut distances = Vec::with_capacity(n_samples);
180            let mut distance_sum = 0.0;
181
182            for j in 0..n_samples {
183                let point = data.row(j);
184                let mut min_dist = f64::INFINITY;
185
186                for k in 0..i {
187                    let center = centers.row(k);
188                    let dist = euclidean_distance_squared(&point, &center)?;
189                    min_dist = min_dist.min(dist);
190                }
191
192                distances.push(min_dist);
193                distance_sum += min_dist;
194            }
195
196            // Choose next center with probability proportional to squared distance
197            let threshold = rng.gen_range(0.0..distance_sum);
198            let mut cumsum = 0.0;
199            let mut next_idx = 0;
200
201            for (j, &dist) in distances.iter().enumerate() {
202                cumsum += dist;
203                if cumsum >= threshold {
204                    next_idx = j;
205                    break;
206                }
207            }
208
209            centers.row_mut(i).assign(&data.row(next_idx));
210        }
211
212        Ok(centers)
213    }
214
215    /// Find nearest center for a point
216    fn find_nearest_center(
217        &self,
218        point: &scirs2_core::ndarray::ArrayView1<f64>,
219        centers: &Array2<f64>,
220    ) -> Result<i32> {
221        let mut min_dist = f64::INFINITY;
222        let mut nearest = 0;
223
224        for (i, center) in centers.axis_iter(Axis(0)).enumerate() {
225            let dist = euclidean_distance_squared(point, &center)?;
226            if dist < min_dist {
227                min_dist = dist;
228                nearest = i;
229            }
230        }
231
232        Ok(nearest as i32)
233    }
234
235    /// Update cluster centers based on current assignments
236    fn update_centers(&self, data: &ArrayView2<f64>, labels: &Array1<i32>) -> Result<Array2<f64>> {
237        let (n_samples, n_features) = data.dim();
238        let mut new_centers = Array2::zeros((self.n_clusters, n_features));
239        let mut counts = vec![0; self.n_clusters];
240
241        // Sum up points in each cluster
242        for i in 0..n_samples {
243            let cluster = labels[i] as usize;
244            if cluster < self.n_clusters {
245                for j in 0..n_features {
246                    new_centers[[cluster, j]] += data[[i, j]];
247                }
248                counts[cluster] += 1;
249            }
250        }
251
252        // Average to get new centers
253        for i in 0..self.n_clusters {
254            if counts[i] > 0 {
255                for j in 0..n_features {
256                    new_centers[[i, j]] /= counts[i] as f64;
257                }
258            } else {
259                // Handle empty cluster by reinitializing
260                tracing::warn!("Cluster {} is empty, reinitializing", i);
261                // Keep the old center or initialize randomly
262            }
263        }
264
265        Ok(new_centers)
266    }
267
268    /// Calculate maximum movement of centers
269    fn max_center_movement(
270        &self,
271        old_centers: &Array2<f64>,
272        new_centers: &Array2<f64>,
273    ) -> Result<f64> {
274        let mut max_dist: f64 = 0.0;
275
276        for i in 0..self.n_clusters {
277            let dist = euclidean_distance_squared(&old_centers.row(i), &new_centers.row(i))?;
278            max_dist = max_dist.max(dist);
279        }
280
281        Ok(max_dist.sqrt())
282    }
283
284    /// Calculate within-cluster sum of squares (inertia)
285    fn calculate_inertia(
286        &self,
287        data: &ArrayView2<f64>,
288        labels: &Array1<i32>,
289        centers: &Array2<f64>,
290    ) -> Result<f64> {
291        let mut inertia = 0.0;
292
293        for (i, &label) in labels.iter().enumerate() {
294            let cluster = label as usize;
295            if cluster < self.n_clusters {
296                let point = data.row(i);
297                let center = centers.row(cluster);
298                inertia += euclidean_distance_squared(&point, &center)?;
299            }
300        }
301
302        Ok(inertia)
303    }
304}
305
306/// Calculate squared euclidean distance between two points
307fn euclidean_distance_squared(
308    p1: &scirs2_core::ndarray::ArrayView1<f64>,
309    p2: &scirs2_core::ndarray::ArrayView1<f64>,
310) -> Result<f64> {
311    if p1.len() != p2.len() {
312        return Err(AnalyticsError::dimension_mismatch(
313            format!("{}", p1.len()),
314            format!("{}", p2.len()),
315        ));
316    }
317
318    let dist_sq: f64 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
319
320    Ok(dist_sq)
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two well-separated blobs should be recovered as two clusters.
    #[test]
    fn test_kmeans_simple() {
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.0],
            [10.0, 10.0],
            [10.1, 10.1],
            [10.0, 10.2],
        ];

        let clusterer = KMeansClusterer::new(2, 100, 1e-4);
        let result = clusterer
            .fit(&data.view())
            .expect("K-means clustering should succeed for valid data");

        assert_eq!(result.labels.len(), 6);
        assert_eq!(result.centers.nrows(), 2);
        assert!(result.converged);

        // Points within the same blob share a label; blobs get distinct labels.
        assert_eq!(result.labels[0], result.labels[1]);
        assert_eq!(result.labels[3], result.labels[4]);
        assert_ne!(result.labels[0], result.labels[3]);
    }

    /// Fewer samples than clusters must be rejected with an error.
    #[test]
    fn test_kmeans_insufficient_data() {
        let data = array![[1.0, 2.0]];
        let clusterer = KMeansClusterer::new(2, 100, 1e-4);
        let result = clusterer.fit(&data.view());

        assert!(result.is_err());
    }

    /// K-means++ initialization should also converge on simple data.
    #[test]
    fn test_kmeans_plus_plus_init() {
        let data = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0],];

        let clusterer =
            KMeansClusterer::new(2, 100, 1e-4).with_init_method(InitMethod::KMeansPlusPlus);
        let result = clusterer
            .fit(&data.view())
            .expect("K-means++ initialization should succeed");

        assert!(result.converged);
        assert_eq!(result.labels.len(), 4);
    }
}