Skip to main content

reddb_server/storage/engine/
clustering.rs

1//! Standalone vector clustering: K-Means and DBSCAN.
2//!
3//! Extracted from IVF internals and extended with DBSCAN for
4//! density-based clustering without a pre-specified K.
5
6use super::simd_distance::l2_squared_simd;
7
/// Result of a clustering operation.
///
/// Produced by both `kmeans` and `dbscan`; fields that only make sense for one
/// of the two algorithms are filled with a fixed value by the other.
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// Cluster assignment for each input vector (index → cluster_id).
    /// -1 means noise (DBSCAN only).
    pub assignments: Vec<i32>,
    /// Centroid vectors, one per cluster, indexed by cluster id.
    /// DBSCAN noise points (-1) do not contribute to any centroid.
    pub centroids: Vec<Vec<f32>>,
    /// Number of clusters found (equals `centroids.len()`).
    pub k: usize,
    /// Per-cluster sizes, indexed by cluster id.
    pub cluster_sizes: Vec<usize>,
    /// Iterations used (K-Means) or 0 (DBSCAN).
    pub iterations: usize,
    /// Whether K-Means stopped because the centroid shift fell below the
    /// convergence threshold. DBSCAN and the empty-input early returns always
    /// report `true`.
    pub converged: bool,
}
25
26// ── K-Means ─────────────────────────────────────────────────────────────────
27
28/// K-Means++ clustering.
29///
30/// Partitions `vectors` into `k` clusters by iteratively assigning each vector
31/// to its nearest centroid and recomputing centroids as cluster means.
32/// Uses K-Means++ initialization for better starting centroids.
33pub fn kmeans(
34    vectors: &[Vec<f32>],
35    k: usize,
36    max_iterations: usize,
37    convergence_threshold: f32,
38) -> ClusterResult {
39    if vectors.is_empty() || k == 0 {
40        return ClusterResult {
41            assignments: Vec::new(),
42            centroids: Vec::new(),
43            k: 0,
44            cluster_sizes: Vec::new(),
45            iterations: 0,
46            converged: true,
47        };
48    }
49
50    let k = k.min(vectors.len());
51    let dim = vectors[0].len();
52
53    // K-Means++ initialization
54    let mut centroids = kmeans_plusplus_init(vectors, k);
55
56    let mut assignments = vec![0i32; vectors.len()];
57    let mut iterations = 0;
58    let mut converged = false;
59    let use_parallel = vectors.len() > 1000
60        && std::thread::available_parallelism()
61            .map(|p| p.get())
62            .unwrap_or(1)
63            > 1;
64
65    for iter in 0..max_iterations {
66        iterations = iter + 1;
67
68        // Assign each vector to nearest centroid (parallel for large datasets)
69        if use_parallel {
70            std::thread::scope(|s| {
71                let chunk_size = (vectors.len() / 4).max(256);
72                let chunks: Vec<_> = assignments.chunks_mut(chunk_size).enumerate().collect();
73                let handles: Vec<_> = chunks
74                    .into_iter()
75                    .map(|(chunk_idx, chunk)| {
76                        let centroids = &centroids;
77                        let vectors = &vectors;
78                        let offset = chunk_idx * chunk_size;
79                        s.spawn(move || {
80                            for (j, assignment) in chunk.iter_mut().enumerate() {
81                                let i = offset + j;
82                                if i < vectors.len() {
83                                    *assignment =
84                                        find_nearest_centroid(&vectors[i], centroids) as i32;
85                                }
86                            }
87                        })
88                    })
89                    .collect();
90                for h in handles {
91                    let _ = h.join();
92                }
93            });
94        } else {
95            for (i, vector) in vectors.iter().enumerate() {
96                assignments[i] = find_nearest_centroid(vector, &centroids) as i32;
97            }
98        }
99
100        let mut cluster_groups: Vec<Vec<usize>> = vec![Vec::new(); k];
101        for (i, &a) in assignments.iter().enumerate() {
102            cluster_groups[a as usize].push(i);
103        }
104
105        // Recompute centroids
106        let mut max_shift: f32 = 0.0;
107        let mut new_centroids = Vec::with_capacity(k);
108
109        for (cluster_idx, indices) in cluster_groups.iter().enumerate() {
110            if indices.is_empty() {
111                new_centroids.push(centroids[cluster_idx].clone());
112                continue;
113            }
114
115            let mut new_centroid = vec![0.0f32; dim];
116            for &idx in indices {
117                for (j, val) in vectors[idx].iter().enumerate() {
118                    if j < dim {
119                        new_centroid[j] += val;
120                    }
121                }
122            }
123            for val in &mut new_centroid {
124                *val /= indices.len() as f32;
125            }
126
127            let shift = l2_squared_simd(&new_centroid, &centroids[cluster_idx]).sqrt();
128            max_shift = max_shift.max(shift);
129            new_centroids.push(new_centroid);
130        }
131
132        centroids = new_centroids;
133
134        if max_shift < convergence_threshold {
135            converged = true;
136            break;
137        }
138    }
139
140    let cluster_sizes: Vec<usize> = (0..k)
141        .map(|c| assignments.iter().filter(|&&a| a == c as i32).count())
142        .collect();
143
144    ClusterResult {
145        assignments,
146        centroids,
147        k,
148        cluster_sizes,
149        iterations,
150        converged,
151    }
152}
153
154fn kmeans_plusplus_init(vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
155    let mut centroids = Vec::with_capacity(k);
156    if vectors.is_empty() || k == 0 {
157        return centroids;
158    }
159
160    centroids.push(vectors[vectors.len() / 2].clone());
161
162    for _ in 1..k {
163        let distances: Vec<f32> = vectors
164            .iter()
165            .map(|v| {
166                centroids
167                    .iter()
168                    .map(|c| l2_squared_simd(v, c))
169                    .fold(f32::MAX, f32::min)
170            })
171            .collect();
172
173        let max_idx = distances
174            .iter()
175            .enumerate()
176            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
177            .map(|(i, _)| i)
178            .unwrap_or(0);
179
180        centroids.push(vectors[max_idx].clone());
181    }
182
183    centroids
184}
185
186fn find_nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
187    centroids
188        .iter()
189        .enumerate()
190        .map(|(i, c)| (i, l2_squared_simd(vector, c)))
191        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
192        .map(|(i, _)| i)
193        .unwrap_or(0)
194}
195
196// ── DBSCAN ──────────────────────────────────────────────────────────────────
197
198/// DBSCAN (Density-Based Spatial Clustering of Applications with Noise).
199///
200/// Finds clusters of arbitrary shape based on density. Points that are not
201/// reachable from any dense region are labeled as noise (-1).
202///
203/// - `eps`: maximum distance between two points to be neighbors (L2)
204/// - `min_points`: minimum neighbors to form a dense region
205pub fn dbscan(vectors: &[Vec<f32>], eps: f32, min_points: usize) -> ClusterResult {
206    const UNVISITED: i32 = -2;
207    const NOISE: i32 = -1;
208
209    let n = vectors.len();
210    if n == 0 {
211        return ClusterResult {
212            assignments: Vec::new(),
213            centroids: Vec::new(),
214            k: 0,
215            cluster_sizes: Vec::new(),
216            iterations: 0,
217            converged: true,
218        };
219    }
220
221    let eps_sq = eps * eps;
222    let mut assignments = vec![UNVISITED; n];
223    let mut visited = vec![false; n];
224    let mut cluster_id: i32 = 0;
225
226    for i in 0..n {
227        if visited[i] {
228            continue;
229        }
230
231        visited[i] = true;
232        let neighbors = range_query(vectors, i, eps_sq);
233
234        if neighbors.len() < min_points {
235            assignments[i] = NOISE;
236            continue;
237        }
238
239        // Start new cluster
240        assignments[i] = cluster_id;
241        let mut seed_set: Vec<usize> = neighbors;
242        let mut j = 0;
243
244        while j < seed_set.len() {
245            let q = seed_set[j];
246            j += 1;
247
248            if !visited[q] {
249                visited[q] = true;
250
251                let q_neighbors = range_query(vectors, q, eps_sq);
252                if q_neighbors.len() >= min_points {
253                    for &neighbor in &q_neighbors {
254                        if matches!(assignments[neighbor], UNVISITED | NOISE)
255                            && !seed_set.contains(&neighbor)
256                        {
257                            seed_set.push(neighbor);
258                        }
259                    }
260                }
261            }
262
263            if matches!(assignments[q], UNVISITED | NOISE) {
264                assignments[q] = cluster_id;
265            }
266        }
267
268        cluster_id += 1;
269    }
270
271    for assignment in &mut assignments {
272        if *assignment == UNVISITED {
273            *assignment = NOISE;
274        }
275    }
276
277    let k = cluster_id as usize;
278
279    // Compute centroids for each cluster
280    let dim = vectors[0].len();
281    let mut centroids = Vec::with_capacity(k);
282    let mut cluster_sizes = Vec::with_capacity(k);
283
284    for c in 0..k {
285        let members: Vec<usize> = assignments
286            .iter()
287            .enumerate()
288            .filter(|(_, &a)| a == c as i32)
289            .map(|(i, _)| i)
290            .collect();
291
292        cluster_sizes.push(members.len());
293
294        if members.is_empty() {
295            centroids.push(vec![0.0; dim]);
296            continue;
297        }
298
299        let mut centroid = vec![0.0f32; dim];
300        for &idx in &members {
301            for (j, val) in vectors[idx].iter().enumerate() {
302                if j < dim {
303                    centroid[j] += val;
304                }
305            }
306        }
307        for val in &mut centroid {
308            *val /= members.len() as f32;
309        }
310        centroids.push(centroid);
311    }
312
313    ClusterResult {
314        assignments,
315        centroids,
316        k,
317        cluster_sizes,
318        iterations: 0,
319        converged: true,
320    }
321}
322
323/// Find all points within eps_sq (squared L2) distance of vectors[idx].
324fn range_query(vectors: &[Vec<f32>], idx: usize, eps_sq: f32) -> Vec<usize> {
325    let point = &vectors[idx];
326    vectors
327        .iter()
328        .enumerate()
329        .filter(|(_, v)| l2_squared_simd(point, v) <= eps_sq)
330        .map(|(i, _)| i)
331        .collect()
332}
333
334// ── Tests ───────────────────────────────────────────────────────────────────
335
/// Unit tests for `kmeans` and `dbscan` on small hand-built 2-D point sets.
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight, well-separated blobs must end up in two different clusters.
    #[test]
    fn test_kmeans_basic() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
            vec![10.2, 10.0],
        ];
        let result = kmeans(&vectors, 2, 100, 0.001);
        assert_eq!(result.k, 2);
        assert_eq!(result.assignments.len(), 6);
        // First 3 should be in one cluster, last 3 in another.
        // Only co-membership is checked; the numeric cluster ids are not pinned.
        assert_eq!(result.assignments[0], result.assignments[1]);
        assert_eq!(result.assignments[1], result.assignments[2]);
        assert_eq!(result.assignments[3], result.assignments[4]);
        assert_eq!(result.assignments[4], result.assignments[5]);
        assert_ne!(result.assignments[0], result.assignments[3]);
    }

    /// With `k == 1` every vector is assigned to cluster 0.
    #[test]
    fn test_kmeans_single_cluster() {
        let vectors = vec![vec![1.0, 1.0], vec![1.1, 1.1], vec![0.9, 0.9]];
        let result = kmeans(&vectors, 1, 10, 0.001);
        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&a| a == 0));
    }

    /// Empty input short-circuits to an empty result, whatever `k` was requested.
    #[test]
    fn test_kmeans_empty() {
        let result = kmeans(&[], 5, 10, 0.001);
        assert_eq!(result.k, 0);
    }

    /// Two dense groups plus one far-away outlier: two clusters and one noise point.
    #[test]
    fn test_dbscan_basic() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.0],
            vec![10.0, 10.1],
            vec![100.0, 100.0], // noise
        ];
        let result = dbscan(&vectors, 0.5, 2);
        assert_eq!(result.k, 2);
        // First 3 in one cluster
        assert_eq!(result.assignments[0], result.assignments[1]);
        assert_eq!(result.assignments[1], result.assignments[2]);
        // Last 3 (minus noise) in another cluster
        assert_eq!(result.assignments[3], result.assignments[4]);
        assert_eq!(result.assignments[4], result.assignments[5]);
        // Two distinct clusters
        assert_ne!(result.assignments[0], result.assignments[3]);
        // Noise point
        assert_eq!(result.assignments[6], -1);
    }

    /// Points farther apart than `eps` never form a cluster; all are noise.
    #[test]
    fn test_dbscan_all_noise() {
        let vectors = vec![vec![0.0, 0.0], vec![100.0, 100.0], vec![200.0, 200.0]];
        let result = dbscan(&vectors, 0.1, 2);
        assert_eq!(result.k, 0);
        assert!(result.assignments.iter().all(|&a| a == -1));
    }

    /// Evenly spaced points (0.1 apart, `eps` 0.15) chain into a single cluster.
    #[test]
    fn test_dbscan_single_cluster() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.2, 0.0],
            vec![0.3, 0.0],
        ];
        let result = dbscan(&vectors, 0.15, 2);
        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&a| a == 0));
    }

    /// Point 0 has only two points in its neighborhood (itself and point 1), so
    /// it is first marked noise; point 1 then turns out to be a core point and
    /// must pull point 0 into its cluster as a border point.
    #[test]
    fn test_dbscan_relabels_noise_point_when_later_core_expands_cluster() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.08, 0.0],
            vec![0.16, 0.0],
            vec![10.0, 10.0],
        ];

        let result = dbscan(&vectors, 0.09, 3);

        assert_eq!(result.k, 1);
        assert_eq!(result.assignments[0], 0);
        assert_eq!(result.assignments[1], 0);
        assert_eq!(result.assignments[2], 0);
        assert_eq!(result.assignments[3], -1);
        assert_eq!(result.cluster_sizes, vec![3]);
    }

    /// Expansion must continue through core points discovered while growing the
    /// cluster: point 3 is only reachable via point 2, which is itself reached
    /// from point 1.
    #[test]
    fn test_dbscan_expands_density_connected_chain() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.08, 0.0],
            vec![0.16, 0.0],
            vec![0.24, 0.0],
        ];

        let result = dbscan(&vectors, 0.09, 3);

        assert_eq!(result.k, 1);
        assert!(result.assignments.iter().all(|&assignment| assignment == 0));
        assert_eq!(result.cluster_sizes, vec![4]);
    }
}
455}