1use super::simd_distance::l2_squared_simd;
7
/// Outcome of a clustering run, returned by both [`kmeans`] and [`dbscan`].
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// Cluster id of each input vector, in input order. `dbscan` uses `-1`
    /// for points it classifies as noise.
    pub assignments: Vec<i32>,
    /// One centroid per cluster, indexed by cluster id.
    pub centroids: Vec<Vec<f32>>,
    /// Number of clusters described by this result (`0` for empty input).
    pub k: usize,
    /// Number of vectors assigned to each cluster, indexed by cluster id.
    pub cluster_sizes: Vec<usize>,
    /// Number of k-means rounds that were executed; `dbscan` always reports `0`.
    pub iterations: usize,
    /// `true` when k-means stopped because the largest centroid movement fell
    /// below the threshold (also `true` for empty input); `dbscan` always
    /// reports `true`.
    pub converged: bool,
}
25
26pub fn kmeans(
34 vectors: &[Vec<f32>],
35 k: usize,
36 max_iterations: usize,
37 convergence_threshold: f32,
38) -> ClusterResult {
39 if vectors.is_empty() || k == 0 {
40 return ClusterResult {
41 assignments: Vec::new(),
42 centroids: Vec::new(),
43 k: 0,
44 cluster_sizes: Vec::new(),
45 iterations: 0,
46 converged: true,
47 };
48 }
49
50 let k = k.min(vectors.len());
51 let dim = vectors[0].len();
52
53 let mut centroids = kmeans_plusplus_init(vectors, k);
55
56 let mut assignments = vec![0i32; vectors.len()];
57 let mut iterations = 0;
58 let mut converged = false;
59 let use_parallel = vectors.len() > 1000
60 && std::thread::available_parallelism()
61 .map(|p| p.get())
62 .unwrap_or(1)
63 > 1;
64
65 for iter in 0..max_iterations {
66 iterations = iter + 1;
67
68 if use_parallel {
70 std::thread::scope(|s| {
71 let chunk_size = (vectors.len() / 4).max(256);
72 let chunks: Vec<_> = assignments.chunks_mut(chunk_size).enumerate().collect();
73 let handles: Vec<_> = chunks
74 .into_iter()
75 .map(|(chunk_idx, chunk)| {
76 let centroids = ¢roids;
77 let vectors = &vectors;
78 let offset = chunk_idx * chunk_size;
79 s.spawn(move || {
80 for (j, assignment) in chunk.iter_mut().enumerate() {
81 let i = offset + j;
82 if i < vectors.len() {
83 *assignment =
84 find_nearest_centroid(&vectors[i], centroids) as i32;
85 }
86 }
87 })
88 })
89 .collect();
90 for h in handles {
91 let _ = h.join();
92 }
93 });
94 } else {
95 for (i, vector) in vectors.iter().enumerate() {
96 assignments[i] = find_nearest_centroid(vector, ¢roids) as i32;
97 }
98 }
99
100 let mut cluster_groups: Vec<Vec<usize>> = vec![Vec::new(); k];
101 for (i, &a) in assignments.iter().enumerate() {
102 cluster_groups[a as usize].push(i);
103 }
104
105 let mut max_shift: f32 = 0.0;
107 let mut new_centroids = Vec::with_capacity(k);
108
109 for (cluster_idx, indices) in cluster_groups.iter().enumerate() {
110 if indices.is_empty() {
111 new_centroids.push(centroids[cluster_idx].clone());
112 continue;
113 }
114
115 let mut new_centroid = vec![0.0f32; dim];
116 for &idx in indices {
117 for (j, val) in vectors[idx].iter().enumerate() {
118 if j < dim {
119 new_centroid[j] += val;
120 }
121 }
122 }
123 for val in &mut new_centroid {
124 *val /= indices.len() as f32;
125 }
126
127 let shift = l2_squared_simd(&new_centroid, ¢roids[cluster_idx]).sqrt();
128 max_shift = max_shift.max(shift);
129 new_centroids.push(new_centroid);
130 }
131
132 centroids = new_centroids;
133
134 if max_shift < convergence_threshold {
135 converged = true;
136 break;
137 }
138 }
139
140 let cluster_sizes: Vec<usize> = (0..k)
141 .map(|c| assignments.iter().filter(|&&a| a == c as i32).count())
142 .collect();
143
144 ClusterResult {
145 assignments,
146 centroids,
147 k,
148 cluster_sizes,
149 iterations,
150 converged,
151 }
152}
153
154fn kmeans_plusplus_init(vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
155 let mut centroids = Vec::with_capacity(k);
156 if vectors.is_empty() || k == 0 {
157 return centroids;
158 }
159
160 centroids.push(vectors[vectors.len() / 2].clone());
161
162 for _ in 1..k {
163 let distances: Vec<f32> = vectors
164 .iter()
165 .map(|v| {
166 centroids
167 .iter()
168 .map(|c| l2_squared_simd(v, c))
169 .fold(f32::MAX, f32::min)
170 })
171 .collect();
172
173 let max_idx = distances
174 .iter()
175 .enumerate()
176 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
177 .map(|(i, _)| i)
178 .unwrap_or(0);
179
180 centroids.push(vectors[max_idx].clone());
181 }
182
183 centroids
184}
185
186fn find_nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
187 centroids
188 .iter()
189 .enumerate()
190 .map(|(i, c)| (i, l2_squared_simd(vector, c)))
191 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
192 .map(|(i, _)| i)
193 .unwrap_or(0)
194}
195
196pub fn dbscan(vectors: &[Vec<f32>], eps: f32, min_points: usize) -> ClusterResult {
206 const UNVISITED: i32 = -2;
207 const NOISE: i32 = -1;
208
209 let n = vectors.len();
210 if n == 0 {
211 return ClusterResult {
212 assignments: Vec::new(),
213 centroids: Vec::new(),
214 k: 0,
215 cluster_sizes: Vec::new(),
216 iterations: 0,
217 converged: true,
218 };
219 }
220
221 let eps_sq = eps * eps;
222 let mut assignments = vec![UNVISITED; n];
223 let mut visited = vec![false; n];
224 let mut cluster_id: i32 = 0;
225
226 for i in 0..n {
227 if visited[i] {
228 continue;
229 }
230
231 visited[i] = true;
232 let neighbors = range_query(vectors, i, eps_sq);
233
234 if neighbors.len() < min_points {
235 assignments[i] = NOISE;
236 continue;
237 }
238
239 assignments[i] = cluster_id;
241 let mut seed_set: Vec<usize> = neighbors;
242 let mut j = 0;
243
244 while j < seed_set.len() {
245 let q = seed_set[j];
246 j += 1;
247
248 if !visited[q] {
249 visited[q] = true;
250
251 let q_neighbors = range_query(vectors, q, eps_sq);
252 if q_neighbors.len() >= min_points {
253 for &neighbor in &q_neighbors {
254 if matches!(assignments[neighbor], UNVISITED | NOISE)
255 && !seed_set.contains(&neighbor)
256 {
257 seed_set.push(neighbor);
258 }
259 }
260 }
261 }
262
263 if matches!(assignments[q], UNVISITED | NOISE) {
264 assignments[q] = cluster_id;
265 }
266 }
267
268 cluster_id += 1;
269 }
270
271 for assignment in &mut assignments {
272 if *assignment == UNVISITED {
273 *assignment = NOISE;
274 }
275 }
276
277 let k = cluster_id as usize;
278
279 let dim = vectors[0].len();
281 let mut centroids = Vec::with_capacity(k);
282 let mut cluster_sizes = Vec::with_capacity(k);
283
284 for c in 0..k {
285 let members: Vec<usize> = assignments
286 .iter()
287 .enumerate()
288 .filter(|(_, &a)| a == c as i32)
289 .map(|(i, _)| i)
290 .collect();
291
292 cluster_sizes.push(members.len());
293
294 if members.is_empty() {
295 centroids.push(vec![0.0; dim]);
296 continue;
297 }
298
299 let mut centroid = vec![0.0f32; dim];
300 for &idx in &members {
301 for (j, val) in vectors[idx].iter().enumerate() {
302 if j < dim {
303 centroid[j] += val;
304 }
305 }
306 }
307 for val in &mut centroid {
308 *val /= members.len() as f32;
309 }
310 centroids.push(centroid);
311 }
312
313 ClusterResult {
314 assignments,
315 centroids,
316 k,
317 cluster_sizes,
318 iterations: 0,
319 converged: true,
320 }
321}
322
323fn range_query(vectors: &[Vec<f32>], idx: usize, eps_sq: f32) -> Vec<usize> {
325 let point = &vectors[idx];
326 vectors
327 .iter()
328 .enumerate()
329 .filter(|(_, v)| l2_squared_simd(point, v) <= eps_sq)
330 .map(|(i, _)| i)
331 .collect()
332}
333
// Unit tests for `kmeans` and `dbscan`, including degenerate inputs
// (empty input, `k == 0`, `k` larger than the number of points).
#[cfg(test)]
mod tests {
    use super::*;

    // Two well-separated blobs must end up in two different clusters.
    #[test]
    fn test_kmeans_basic() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
            vec![10.2, 10.0],
        ];
        let result = kmeans(&vectors, 2, 100, 0.001);
        assert_eq!(result.k, 2);
        assert_eq!(result.assignments.len(), 6);
        assert_eq!(result.assignments[0], result.assignments[1]);
        assert_eq!(result.assignments[1], result.assignments[2]);
        assert_eq!(result.assignments[3], result.assignments[4]);
        assert_eq!(result.assignments[4], result.assignments[5]);
        assert_ne!(result.assignments[0], result.assignments[3]);
    }

    // With `k == 1` every point belongs to cluster 0.
    #[test]
    fn test_kmeans_single_cluster() {
        let vectors = vec![vec![1.0, 1.0], vec![1.1, 1.1], vec![0.9, 0.9]];
        let result = kmeans(&vectors, 1, 10, 0.001);
        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&a| a == 0));
    }

    // Empty input short-circuits to an empty result.
    #[test]
    fn test_kmeans_empty() {
        let result = kmeans(&[], 5, 10, 0.001);
        assert_eq!(result.k, 0);
    }

    // `k == 0` short-circuits to an empty result even for non-empty input.
    #[test]
    fn test_kmeans_zero_k() {
        let vectors = vec![vec![1.0, 2.0]];
        let result = kmeans(&vectors, 0, 10, 0.001);
        assert_eq!(result.k, 0);
        assert!(result.assignments.is_empty());
        assert!(result.centroids.is_empty());
        assert!(result.cluster_sizes.is_empty());
    }

    // A requested `k` above the number of points is clamped, and the reported
    // sizes still account for every point.
    #[test]
    fn test_kmeans_clamps_k_to_point_count() {
        let vectors = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let result = kmeans(&vectors, 5, 10, 0.001);
        assert_eq!(result.k, 2);
        assert_eq!(result.centroids.len(), 2);
        assert_eq!(result.cluster_sizes.len(), 2);
        assert_eq!(result.assignments.len(), vectors.len());
        assert_eq!(result.cluster_sizes.iter().sum::<usize>(), vectors.len());
    }

    // Two dense groups form two clusters; the isolated point is noise.
    #[test]
    fn test_dbscan_basic() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.0],
            vec![10.0, 10.1],
            // Isolated point, expected to be labelled as noise.
            vec![100.0, 100.0],
        ];
        let result = dbscan(&vectors, 0.5, 2);
        assert_eq!(result.k, 2);
        assert_eq!(result.assignments[0], result.assignments[1]);
        assert_eq!(result.assignments[1], result.assignments[2]);
        assert_eq!(result.assignments[3], result.assignments[4]);
        assert_eq!(result.assignments[4], result.assignments[5]);
        assert_ne!(result.assignments[0], result.assignments[3]);
        assert_eq!(result.assignments[6], -1);
    }

    // Points farther apart than `eps` never form a cluster.
    #[test]
    fn test_dbscan_all_noise() {
        let vectors = vec![vec![0.0, 0.0], vec![100.0, 100.0], vec![200.0, 200.0]];
        let result = dbscan(&vectors, 0.1, 2);
        assert_eq!(result.k, 0);
        assert!(result.assignments.iter().all(|&a| a == -1));
    }

    // A chain of close points collapses into one cluster.
    #[test]
    fn test_dbscan_single_cluster() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.2, 0.0],
            vec![0.3, 0.0],
        ];
        let result = dbscan(&vectors, 0.15, 2);
        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&a| a == 0));
    }

    // Empty input short-circuits to an empty, converged result.
    #[test]
    fn test_dbscan_empty() {
        let result = dbscan(&[], 0.5, 2);
        assert_eq!(result.k, 0);
        assert!(result.assignments.is_empty());
        assert!(result.centroids.is_empty());
        assert!(result.converged);
    }

    // Point 0 is visited first and marked noise (only two neighbours); the
    // core point 1 must later pull it into the cluster.
    #[test]
    fn test_dbscan_relabels_noise_point_when_later_core_expands_cluster() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.08, 0.0],
            vec![0.16, 0.0],
            vec![10.0, 10.0],
        ];

        let result = dbscan(&vectors, 0.09, 3);

        assert_eq!(result.k, 1);
        assert_eq!(result.assignments[0], 0);
        assert_eq!(result.assignments[1], 0);
        assert_eq!(result.assignments[2], 0);
        assert_eq!(result.assignments[3], -1);
        assert_eq!(result.cluster_sizes, vec![3]);
    }

    // Expansion must continue through core points discovered during growth,
    // so the far end of the chain joins the same cluster.
    #[test]
    fn test_dbscan_expands_density_connected_chain() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.08, 0.0],
            vec![0.16, 0.0],
            vec![0.24, 0.0],
        ];

        let result = dbscan(&vectors, 0.09, 3);

        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&assignment| assignment == 0));
        assert_eq!(result.cluster_sizes, vec![4]);
    }
}