oxigdal_algorithms/vector/clustering/
kmeans.rs1use crate::error::{AlgorithmError, Result};
6use crate::vector::clustering::dbscan::{DistanceMetric, calculate_distance};
7use oxigdal_core::vector::Point;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub struct KmeansOptions {
13 pub k: usize,
15 pub max_iterations: usize,
17 pub tolerance: f64,
19 pub metric: DistanceMetric,
21 pub init_method: InitMethod,
23 pub seed: Option<u64>,
25}
26
27impl Default for KmeansOptions {
28 fn default() -> Self {
29 Self {
30 k: 3,
31 max_iterations: 100,
32 tolerance: 1e-6,
33 metric: DistanceMetric::Euclidean,
34 init_method: InitMethod::KMeansPlusPlus,
35 seed: None,
36 }
37 }
38}
39
/// How [`kmeans_cluster`] picks its starting centroids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitMethod {
    /// Use the first `k` input points as the starting centroids.
    ///
    /// NOTE(review): despite the name, the implementation in this module is deterministic and
    /// performs no random sampling.
    Random,
    /// Greedy farthest-point seeding: start from the first input point, then repeatedly add the
    /// point that is farthest from its nearest already-chosen centroid.
    KMeansPlusPlus,
    /// Spread the centroids over a regular grid covering the bounding box of the input points.
    Grid,
}
50
51#[derive(Debug, Clone)]
53pub struct KmeansResult {
54 pub labels: Vec<usize>,
56 pub centroids: Vec<Point>,
58 pub inertia: f64,
60 pub iterations: usize,
62 pub converged: bool,
64 pub cluster_sizes: HashMap<usize, usize>,
66}
67
68pub fn kmeans_cluster(points: &[Point], options: &KmeansOptions) -> Result<KmeansResult> {
106 if points.is_empty() {
107 return Err(AlgorithmError::InvalidInput(
108 "Cannot cluster empty point set".to_string(),
109 ));
110 }
111
112 if options.k == 0 {
113 return Err(AlgorithmError::InvalidInput(
114 "Number of clusters must be positive".to_string(),
115 ));
116 }
117
118 if options.k > points.len() {
119 return Err(AlgorithmError::InvalidInput(format!(
120 "Number of clusters ({}) exceeds number of points ({})",
121 options.k,
122 points.len()
123 )));
124 }
125
126 let mut centroids = match options.init_method {
128 InitMethod::KMeansPlusPlus => kmeans_plus_plus_init(points, options.k, options.metric)?,
129 InitMethod::Random => random_init(points, options.k),
130 InitMethod::Grid => grid_init(points, options.k),
131 };
132
133 let mut labels = vec![0; points.len()];
134 let mut converged = false;
135 let mut iteration = 0;
136
137 for iter in 0..options.max_iterations {
138 iteration = iter + 1;
139
140 let mut changed = false;
142 for (i, point) in points.iter().enumerate() {
143 let nearest = find_nearest_centroid(point, ¢roids, options.metric);
144 if labels[i] != nearest {
145 labels[i] = nearest;
146 changed = true;
147 }
148 }
149
150 if !changed {
151 converged = true;
152 break;
153 }
154
155 let old_centroids = centroids.clone();
157 centroids = update_centroids(points, &labels, options.k);
158
159 let max_movement = old_centroids
161 .iter()
162 .zip(¢roids)
163 .map(|(old, new)| calculate_distance(old, new, options.metric))
164 .fold(0.0, f64::max);
165
166 if max_movement < options.tolerance {
167 converged = true;
168 break;
169 }
170 }
171
172 let mut inertia = 0.0;
174 for (point, &label) in points.iter().zip(&labels) {
175 let centroid = ¢roids[label];
176 let dist = calculate_distance(point, centroid, options.metric);
177 inertia += dist * dist;
178 }
179
180 let mut cluster_sizes: HashMap<usize, usize> = HashMap::new();
182 for &label in &labels {
183 *cluster_sizes.entry(label).or_insert(0) += 1;
184 }
185
186 Ok(KmeansResult {
187 labels,
188 centroids,
189 inertia,
190 iterations: iteration,
191 converged,
192 cluster_sizes,
193 })
194}
195
196pub fn kmeans_plus_plus_init(
198 points: &[Point],
199 k: usize,
200 metric: DistanceMetric,
201) -> Result<Vec<Point>> {
202 if k > points.len() {
203 return Err(AlgorithmError::InvalidInput(
204 "k exceeds number of points".to_string(),
205 ));
206 }
207
208 let mut centroids = Vec::with_capacity(k);
209
210 centroids.push(points[0].clone());
213
214 for _ in 1..k {
216 let mut weights: Vec<f64> = points
218 .iter()
219 .map(|point| {
220 let min_dist = centroids
221 .iter()
222 .map(|centroid| calculate_distance(point, centroid, metric))
223 .fold(f64::INFINITY, f64::min);
224 min_dist * min_dist
225 })
226 .collect();
227
228 let total_weight: f64 = weights.iter().sum();
230 if total_weight > 0.0 {
231 for w in &mut weights {
232 *w /= total_weight;
233 }
234 }
235
236 let next_idx = weights
239 .iter()
240 .enumerate()
241 .max_by(|(_, a): &(usize, &f64), (_, b): &(usize, &f64)| {
242 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
243 })
244 .map(|(idx, _)| idx)
245 .unwrap_or(centroids.len());
246
247 centroids.push(points[next_idx].clone());
248 }
249
250 Ok(centroids)
251}
252
253fn random_init(points: &[Point], k: usize) -> Vec<Point> {
255 points.iter().take(k).cloned().collect()
256}
257
258fn grid_init(points: &[Point], k: usize) -> Vec<Point> {
260 if points.is_empty() {
261 return Vec::new();
262 }
263
264 let mut min_x = f64::INFINITY;
266 let mut max_x = f64::NEG_INFINITY;
267 let mut min_y = f64::INFINITY;
268 let mut max_y = f64::NEG_INFINITY;
269
270 for point in points {
271 min_x = min_x.min(point.coord.x);
272 max_x = max_x.max(point.coord.x);
273 min_y = min_y.min(point.coord.y);
274 max_y = max_y.max(point.coord.y);
275 }
276
277 let grid_size = (k as f64).sqrt().ceil() as usize;
279 let mut centroids = Vec::new();
280
281 for i in 0..grid_size {
282 for j in 0..grid_size {
283 if centroids.len() >= k {
284 break;
285 }
286
287 let x = min_x + (max_x - min_x) * (i as f64 + 0.5) / grid_size as f64;
288 let y = min_y + (max_y - min_y) * (j as f64 + 0.5) / grid_size as f64;
289
290 centroids.push(Point::new(x, y));
291 }
292
293 if centroids.len() >= k {
294 break;
295 }
296 }
297
298 centroids
299}
300
301fn find_nearest_centroid(point: &Point, centroids: &[Point], metric: DistanceMetric) -> usize {
303 centroids
304 .iter()
305 .enumerate()
306 .map(|(idx, centroid)| (idx, calculate_distance(point, centroid, metric)))
307 .min_by(|(_, d1): &(usize, f64), (_, d2): &(usize, f64)| {
308 d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal)
309 })
310 .map(|(idx, _)| idx)
311 .unwrap_or(0)
312}
313
314fn update_centroids(points: &[Point], labels: &[usize], k: usize) -> Vec<Point> {
316 let mut sums_x = vec![0.0; k];
317 let mut sums_y = vec![0.0; k];
318 let mut counts = vec![0; k];
319
320 for (point, &label) in points.iter().zip(labels) {
321 sums_x[label] += point.coord.x;
322 sums_y[label] += point.coord.y;
323 counts[label] += 1;
324 }
325
326 (0..k)
327 .map(|i| {
328 if counts[i] > 0 {
329 Point::new(sums_x[i] / counts[i] as f64, sums_y[i] / counts[i] as f64)
330 } else {
331 Point::new(
333 sums_x[0] / counts[0].max(1) as f64,
334 sums_y[0] / counts[0].max(1) as f64,
335 )
336 }
337 })
338 .collect()
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight, well-separated pairs of points: one near (0, 0) and one near (5, 5).
    fn two_blobs() -> Vec<Point> {
        vec![
            Point::new(0.0, 0.0),
            Point::new(0.1, 0.1),
            Point::new(5.0, 5.0),
            Point::new(5.1, 5.1),
        ]
    }

    #[test]
    fn test_kmeans_simple() {
        let points = two_blobs();

        let options = KmeansOptions {
            k: 2,
            max_iterations: 100,
            ..Default::default()
        };

        let result = kmeans_cluster(&points, &options);
        assert!(result.is_ok());

        let clustering = result.expect("Clustering failed");
        assert_eq!(clustering.centroids.len(), 2);
        assert_eq!(clustering.labels.len(), 4);
    }

    #[test]
    fn test_kmeans_plus_plus() {
        let points = two_blobs();

        let centroids = kmeans_plus_plus_init(&points, 2, DistanceMetric::Euclidean);
        assert!(centroids.is_ok());

        let init = centroids.expect("Init failed");
        assert_eq!(init.len(), 2);
    }

    #[test]
    fn test_grid_init() {
        let points = vec![Point::new(0.0, 0.0), Point::new(10.0, 10.0)];

        let centroids = grid_init(&points, 4);
        assert_eq!(centroids.len(), 4);
    }

    #[test]
    fn test_kmeans_convergence() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(10.0, 10.0),
        ];

        let options = KmeansOptions {
            k: 2,
            tolerance: 0.01,
            ..Default::default()
        };

        let result = kmeans_cluster(&points, &options);
        assert!(result.is_ok());

        let clustering = result.expect("Clustering failed");
        assert!(clustering.converged);
    }

    #[test]
    fn test_every_init_method_separates_blobs() {
        let points = two_blobs();

        for init_method in [InitMethod::KMeansPlusPlus, InitMethod::Random, InitMethod::Grid] {
            let options = KmeansOptions {
                k: 2,
                init_method,
                ..Default::default()
            };

            let clustering = kmeans_cluster(&points, &options).expect("Clustering failed");
            assert_eq!(clustering.labels[0], clustering.labels[1], "{init_method:?}");
            assert_eq!(clustering.labels[2], clustering.labels[3], "{init_method:?}");
            assert_ne!(clustering.labels[0], clustering.labels[2], "{init_method:?}");
            assert!(clustering.converged, "{init_method:?}");
            assert!(clustering.inertia < 0.1, "{init_method:?}");
            assert_eq!(clustering.cluster_sizes.values().sum::<usize>(), 4);
        }
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let points = two_blobs();

        let two = KmeansOptions {
            k: 2,
            ..Default::default()
        };
        assert!(kmeans_cluster(&[], &two).is_err());

        let zero = KmeansOptions {
            k: 0,
            ..Default::default()
        };
        assert!(kmeans_cluster(&points, &zero).is_err());

        let too_many = KmeansOptions {
            k: 5,
            ..Default::default()
        };
        assert!(kmeans_cluster(&points, &too_many).is_err());

        assert!(kmeans_plus_plus_init(&points, 5, DistanceMetric::Euclidean).is_err());
    }
}