Skip to main content

oxigdal_algorithms/vector/clustering/
kmeans.rs

1//! K-means clustering algorithm
2//!
3//! Partition points into K clusters by minimizing within-cluster variance.
4
5use crate::error::{AlgorithmError, Result};
6use crate::vector::clustering::dbscan::{DistanceMetric, calculate_distance};
7use oxigdal_core::vector::Point;
8use std::collections::HashMap;
9
10/// Options for K-means clustering
11#[derive(Debug, Clone)]
12pub struct KmeansOptions {
13    /// Number of clusters
14    pub k: usize,
15    /// Maximum number of iterations
16    pub max_iterations: usize,
17    /// Convergence tolerance
18    pub tolerance: f64,
19    /// Distance metric
20    pub metric: DistanceMetric,
21    /// Initialization method
22    pub init_method: InitMethod,
23    /// Random seed for reproducibility
24    pub seed: Option<u64>,
25}
26
27impl Default for KmeansOptions {
28    fn default() -> Self {
29        Self {
30            k: 3,
31            max_iterations: 100,
32            tolerance: 1e-6,
33            metric: DistanceMetric::Euclidean,
34            init_method: InitMethod::KMeansPlusPlus,
35            seed: None,
36        }
37    }
38}
39
/// Strategy for choosing the initial cluster centroids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitMethod {
    /// Use the first `k` input points, in input order, as the starting
    /// centroids (the implementation in this module involves no randomness).
    Random,
    /// K-means++-style seeding: start from the first point and repeatedly
    /// add the point farthest from the centroids chosen so far.
    KMeansPlusPlus,
    /// Place the centroids at cell centers of a square grid laid over the
    /// bounding box of the input points.
    Grid,
}
50
51/// Result of K-means clustering
52#[derive(Debug, Clone)]
53pub struct KmeansResult {
54    /// Cluster assignments for each point
55    pub labels: Vec<usize>,
56    /// Final centroid positions
57    pub centroids: Vec<Point>,
58    /// Sum of squared distances (inertia)
59    pub inertia: f64,
60    /// Number of iterations performed
61    pub iterations: usize,
62    /// Whether algorithm converged
63    pub converged: bool,
64    /// Cluster sizes
65    pub cluster_sizes: HashMap<usize, usize>,
66}
67
68/// Perform K-means clustering
69///
70/// # Arguments
71///
72/// * `points` - Points to cluster
73/// * `options` - K-means options
74///
75/// # Returns
76///
77/// Clustering result with labels and centroids
78///
79/// # Examples
80///
81/// ```
82/// use oxigdal_algorithms::vector::clustering::{kmeans_cluster, KmeansOptions};
83/// use oxigdal_algorithms::Point;
84/// # use oxigdal_algorithms::error::Result;
85///
86/// # fn main() -> Result<()> {
87/// let points = vec![
88///     Point::new(0.0, 0.0),
89///     Point::new(0.1, 0.1),
90///     Point::new(5.0, 5.0),
91///     Point::new(5.1, 5.1),
92/// ];
93///
94/// let options = KmeansOptions {
95///     k: 2,
96///     max_iterations: 100,
97///     ..Default::default()
98/// };
99///
100/// let result = kmeans_cluster(&points, &options)?;
101/// assert_eq!(result.centroids.len(), 2);
102/// # Ok(())
103/// # }
104/// ```
105pub fn kmeans_cluster(points: &[Point], options: &KmeansOptions) -> Result<KmeansResult> {
106    if points.is_empty() {
107        return Err(AlgorithmError::InvalidInput(
108            "Cannot cluster empty point set".to_string(),
109        ));
110    }
111
112    if options.k == 0 {
113        return Err(AlgorithmError::InvalidInput(
114            "Number of clusters must be positive".to_string(),
115        ));
116    }
117
118    if options.k > points.len() {
119        return Err(AlgorithmError::InvalidInput(format!(
120            "Number of clusters ({}) exceeds number of points ({})",
121            options.k,
122            points.len()
123        )));
124    }
125
126    // Initialize centroids
127    let mut centroids = match options.init_method {
128        InitMethod::KMeansPlusPlus => kmeans_plus_plus_init(points, options.k, options.metric)?,
129        InitMethod::Random => random_init(points, options.k),
130        InitMethod::Grid => grid_init(points, options.k),
131    };
132
133    let mut labels = vec![0; points.len()];
134    let mut converged = false;
135    let mut iteration = 0;
136
137    for iter in 0..options.max_iterations {
138        iteration = iter + 1;
139
140        // Assignment step: assign each point to nearest centroid
141        let mut changed = false;
142        for (i, point) in points.iter().enumerate() {
143            let nearest = find_nearest_centroid(point, &centroids, options.metric);
144            if labels[i] != nearest {
145                labels[i] = nearest;
146                changed = true;
147            }
148        }
149
150        if !changed {
151            converged = true;
152            break;
153        }
154
155        // Update step: recalculate centroids
156        let old_centroids = centroids.clone();
157        centroids = update_centroids(points, &labels, options.k);
158
159        // Check convergence
160        let max_movement = old_centroids
161            .iter()
162            .zip(&centroids)
163            .map(|(old, new)| calculate_distance(old, new, options.metric))
164            .fold(0.0, f64::max);
165
166        if max_movement < options.tolerance {
167            converged = true;
168            break;
169        }
170    }
171
172    // Calculate inertia (sum of squared distances)
173    let mut inertia = 0.0;
174    for (point, &label) in points.iter().zip(&labels) {
175        let centroid = &centroids[label];
176        let dist = calculate_distance(point, centroid, options.metric);
177        inertia += dist * dist;
178    }
179
180    // Calculate cluster sizes
181    let mut cluster_sizes: HashMap<usize, usize> = HashMap::new();
182    for &label in &labels {
183        *cluster_sizes.entry(label).or_insert(0) += 1;
184    }
185
186    Ok(KmeansResult {
187        labels,
188        centroids,
189        inertia,
190        iterations: iteration,
191        converged,
192        cluster_sizes,
193    })
194}
195
196/// K-means++ initialization for better starting centroids
197pub fn kmeans_plus_plus_init(
198    points: &[Point],
199    k: usize,
200    metric: DistanceMetric,
201) -> Result<Vec<Point>> {
202    if k > points.len() {
203        return Err(AlgorithmError::InvalidInput(
204            "k exceeds number of points".to_string(),
205        ));
206    }
207
208    let mut centroids = Vec::with_capacity(k);
209
210    // Choose first centroid randomly
211    // Using first point as deterministic "random" choice
212    centroids.push(points[0].clone());
213
214    // Choose remaining centroids
215    for _ in 1..k {
216        // Calculate D^2 for each point (squared distance to nearest centroid)
217        let mut weights: Vec<f64> = points
218            .iter()
219            .map(|point| {
220                let min_dist = centroids
221                    .iter()
222                    .map(|centroid| calculate_distance(point, centroid, metric))
223                    .fold(f64::INFINITY, f64::min);
224                min_dist * min_dist
225            })
226            .collect();
227
228        // Normalize weights
229        let total_weight: f64 = weights.iter().sum();
230        if total_weight > 0.0 {
231            for w in &mut weights {
232                *w /= total_weight;
233            }
234        }
235
236        // Choose next centroid with probability proportional to D^2
237        // For deterministic behavior, choose the point with maximum weight
238        let next_idx = weights
239            .iter()
240            .enumerate()
241            .max_by(|(_, a): &(usize, &f64), (_, b): &(usize, &f64)| {
242                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
243            })
244            .map(|(idx, _)| idx)
245            .unwrap_or(centroids.len());
246
247        centroids.push(points[next_idx].clone());
248    }
249
250    Ok(centroids)
251}
252
253/// Random initialization (first k points)
254fn random_init(points: &[Point], k: usize) -> Vec<Point> {
255    points.iter().take(k).cloned().collect()
256}
257
258/// Grid-based initialization
259fn grid_init(points: &[Point], k: usize) -> Vec<Point> {
260    if points.is_empty() {
261        return Vec::new();
262    }
263
264    // Find bounding box
265    let mut min_x = f64::INFINITY;
266    let mut max_x = f64::NEG_INFINITY;
267    let mut min_y = f64::INFINITY;
268    let mut max_y = f64::NEG_INFINITY;
269
270    for point in points {
271        min_x = min_x.min(point.coord.x);
272        max_x = max_x.max(point.coord.x);
273        min_y = min_y.min(point.coord.y);
274        max_y = max_y.max(point.coord.y);
275    }
276
277    // Create grid of k centroids
278    let grid_size = (k as f64).sqrt().ceil() as usize;
279    let mut centroids = Vec::new();
280
281    for i in 0..grid_size {
282        for j in 0..grid_size {
283            if centroids.len() >= k {
284                break;
285            }
286
287            let x = min_x + (max_x - min_x) * (i as f64 + 0.5) / grid_size as f64;
288            let y = min_y + (max_y - min_y) * (j as f64 + 0.5) / grid_size as f64;
289
290            centroids.push(Point::new(x, y));
291        }
292
293        if centroids.len() >= k {
294            break;
295        }
296    }
297
298    centroids
299}
300
301/// Find nearest centroid for a point
302fn find_nearest_centroid(point: &Point, centroids: &[Point], metric: DistanceMetric) -> usize {
303    centroids
304        .iter()
305        .enumerate()
306        .map(|(idx, centroid)| (idx, calculate_distance(point, centroid, metric)))
307        .min_by(|(_, d1): &(usize, f64), (_, d2): &(usize, f64)| {
308            d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal)
309        })
310        .map(|(idx, _)| idx)
311        .unwrap_or(0)
312}
313
314/// Update centroids based on current assignments
315fn update_centroids(points: &[Point], labels: &[usize], k: usize) -> Vec<Point> {
316    let mut sums_x = vec![0.0; k];
317    let mut sums_y = vec![0.0; k];
318    let mut counts = vec![0; k];
319
320    for (point, &label) in points.iter().zip(labels) {
321        sums_x[label] += point.coord.x;
322        sums_y[label] += point.coord.y;
323        counts[label] += 1;
324    }
325
326    (0..k)
327        .map(|i| {
328            if counts[i] > 0 {
329                Point::new(sums_x[i] / counts[i] as f64, sums_y[i] / counts[i] as f64)
330            } else {
331                // Empty cluster, keep old centroid or use first point
332                Point::new(
333                    sums_x[0] / counts[0].max(1) as f64,
334                    sums_y[0] / counts[0].max(1) as f64,
335                )
336            }
337        })
338        .collect()
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight pairs of points, one near the origin and one near (5, 5).
    fn two_pairs() -> Vec<Point> {
        vec![
            Point::new(0.0, 0.0),
            Point::new(0.1, 0.1),
            Point::new(5.0, 5.0),
            Point::new(5.1, 5.1),
        ]
    }

    /// A full run with `k = 2` yields two centroids and one label per point.
    #[test]
    fn test_kmeans_simple() {
        let points = two_pairs();
        let options = KmeansOptions {
            k: 2,
            max_iterations: 100,
            ..Default::default()
        };

        let clustering = kmeans_cluster(&points, &options).expect("Clustering failed");

        assert_eq!(clustering.centroids.len(), 2);
        assert_eq!(clustering.labels.len(), 4);
    }

    /// The K-means++ initializer returns exactly the requested number of seeds.
    #[test]
    fn test_kmeans_plus_plus() {
        let points = two_pairs();

        let init = kmeans_plus_plus_init(&points, 2, DistanceMetric::Euclidean)
            .expect("Init failed");

        assert_eq!(init.len(), 2);
    }

    /// Grid initialization produces `k` centroids even with fewer points.
    #[test]
    fn test_grid_init() {
        let corners = vec![Point::new(0.0, 0.0), Point::new(10.0, 10.0)];

        assert_eq!(grid_init(&corners, 4).len(), 4);
    }

    /// Two pairs of coincident points converge with default settings.
    #[test]
    fn test_kmeans_convergence() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(10.0, 10.0),
        ];
        let options = KmeansOptions {
            k: 2,
            tolerance: 0.01,
            ..Default::default()
        };

        let clustering = kmeans_cluster(&points, &options).expect("Clustering failed");

        assert!(clustering.converged);
    }
}