Skip to main content

oxirs_vec/
cluster_index.rs

1//! K-means clustering index for approximate nearest-neighbour search.
2//!
3//! Features:
4//! * Lloyd's algorithm with configurable `k` and maximum iterations
5//! * Cluster assignment (nearest centroid)
6//! * Centroid tracking (incremental updates as vectors are inserted)
7//! * Cluster statistics (size, intra-cluster variance, centroid drift)
8//! * Cluster merge (merge two closest clusters)
9//! * Cluster split (split the largest cluster)
10//! * ANN index search (probe nearest clusters for a query vector)
11
12use std::collections::HashMap;
13
14// ---------------------------------------------------------------------------
15// Error type
16// ---------------------------------------------------------------------------
17
/// Errors produced by the cluster index.
///
/// Implements both [`std::fmt::Display`] (human-readable message) and
/// [`std::error::Error`], so values can be propagated with `?` into
/// `Box<dyn std::error::Error>` or any error-reporting wrapper.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClusterError {
    /// Requested `k` cannot be used with the `n` stored vectors (e.g. `k == 0`).
    InvalidK { k: usize, n: usize },
    /// Operation requested on an empty index.
    EmptyIndex,
    /// The cluster id does not exist.
    UnknownCluster(usize),
    /// Cannot merge a cluster with itself.
    SameCluster,
    /// A vector with this id is already in the index.
    DuplicateId(String),
    /// Vectors have incompatible dimensionalities.
    DimMismatch { expected: usize, got: usize },
}

impl std::fmt::Display for ClusterError {
    /// Writes a short, lower-case description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusterError::InvalidK { k, n } => {
                write!(f, "k={k} is invalid for {n} vectors")
            }
            ClusterError::EmptyIndex => write!(f, "the index is empty"),
            ClusterError::UnknownCluster(id) => write!(f, "unknown cluster id {id}"),
            ClusterError::SameCluster => write!(f, "cannot merge a cluster with itself"),
            ClusterError::DuplicateId(id) => write!(f, "duplicate vector id '{id}'"),
            ClusterError::DimMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

// No underlying source error exists for any variant, so the default
// `Error` methods are sufficient.
impl std::error::Error for ClusterError {}
51
52// ---------------------------------------------------------------------------
53// Internal helpers
54// ---------------------------------------------------------------------------
55
/// Squared Euclidean distance between `a` and `b`.
///
/// The slices are walked pairwise; if their lengths differ, the surplus
/// tail of the longer one is ignored.
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum()
}
60
61/// Euclidean distance.
62fn dist(a: &[f32], b: &[f32]) -> f32 {
63    sq_dist(a, b).sqrt()
64}
65
/// Component-wise arithmetic mean of `vectors`.
///
/// The output has the dimensionality of the first vector. Returns `None`
/// when `vectors` is empty.
#[allow(dead_code)]
fn mean_vector(vectors: &[Vec<f32>]) -> Option<Vec<f32>> {
    let first = vectors.first()?;
    let mut acc = vec![0.0f32; first.len()];
    for row in vectors {
        acc.iter_mut()
            .zip(row)
            .for_each(|(slot, &value)| *slot += value);
    }
    let count = vectors.len() as f32;
    for slot in &mut acc {
        *slot /= count;
    }
    Some(acc)
}
83
84// ---------------------------------------------------------------------------
85// Cluster statistics
86// ---------------------------------------------------------------------------
87
/// Runtime statistics for a single cluster.
///
/// Derives `PartialEq` so snapshots can be compared directly (for example
/// with `assert_eq!`); `Eq` is not available because of the `f32` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusterStats {
    /// Cluster identifier.
    pub cluster_id: usize,
    /// Number of vectors assigned to this cluster.
    pub size: usize,
    /// Mean squared distance from members to centroid (intra-cluster variance).
    pub variance: f32,
    /// Euclidean distance the centroid moved on the last update.
    pub centroid_drift: f32,
    /// The centroid coordinates.
    pub centroid: Vec<f32>,
}
102
103// ---------------------------------------------------------------------------
104// Cluster index
105// ---------------------------------------------------------------------------
106
/// A stored vector entry.
///
/// One `Entry` is kept per vector id in `ClusterIndex::entries`.
#[derive(Debug, Clone)]
struct Entry {
    /// The vector coordinates exactly as passed to `insert`.
    vector: Vec<f32>,
    /// Index into `ClusterIndex::centroids` of the cluster this vector is
    /// assigned to. `insert` stores `0` as a placeholder while no centroids
    /// exist yet; `build` overwrites it with the real assignment.
    cluster_id: usize,
}
113
/// K-means clustering index supporting ANN search by cluster probing.
///
/// Vectors are added with `insert`, clustered with `build`, and queried with
/// `search`, which only scans the members of the clusters whose centroids are
/// closest to the query.
#[derive(Debug, Clone)]
pub struct ClusterIndex {
    /// Number of clusters requested at construction; `build` caps it at the
    /// number of stored vectors.
    k: usize,
    /// Maximum Lloyd iterations during `build`.
    max_iter: usize,
    /// Dimensionality (set on first insertion).
    dim: Option<usize>,
    /// All stored vectors keyed by their string id.
    entries: HashMap<String, Entry>,
    /// Vector ids in insertion order (gives `build` a deterministic traversal
    /// order independent of `HashMap` iteration).
    id_order: Vec<String>,
    /// Centroids: one row per cluster; empty until `build` has run.
    centroids: Vec<Vec<f32>>,
    /// Centroids as they were before the most recent `build`, used for drift
    /// computation in `cluster_stats`.
    prev_centroids: Vec<Vec<f32>>,
    /// Next cluster id counter (reserved for future merge/split expansion).
    #[allow(dead_code)]
    next_cluster_id: usize,
}
135
136impl ClusterIndex {
137    /// Create a new index.
138    ///
139    /// * `k`        — number of clusters (must be ≥ 1)
140    /// * `max_iter` — maximum Lloyd iterations when `build` is called
141    pub fn new(k: usize, max_iter: usize) -> Self {
142        Self {
143            k,
144            max_iter,
145            dim: None,
146            entries: HashMap::new(),
147            id_order: Vec::new(),
148            centroids: Vec::new(),
149            prev_centroids: Vec::new(),
150            next_cluster_id: k,
151        }
152    }
153
154    /// Current number of clusters.
155    pub fn num_clusters(&self) -> usize {
156        self.centroids.len()
157    }
158
159    /// Current number of stored vectors.
160    pub fn len(&self) -> usize {
161        self.entries.len()
162    }
163
164    /// Return `true` when no vectors have been inserted.
165    pub fn is_empty(&self) -> bool {
166        self.entries.is_empty()
167    }
168
169    // -----------------------------------------------------------------------
170    // Insertion
171    // -----------------------------------------------------------------------
172
173    /// Insert a vector into the index, assigning it to the nearest centroid.
174    ///
175    /// If no clustering has been built yet (`centroids` is empty), the vector
176    /// is stored without a valid cluster assignment and an initial cluster will
177    /// be assigned during `build`.
178    ///
179    /// Returns an error when:
180    /// * The id is already present.
181    /// * The vector dimensionality is inconsistent with existing vectors.
182    pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<(), ClusterError> {
183        // Dimension check
184        match self.dim {
185            None => self.dim = Some(vector.len()),
186            Some(d) if d != vector.len() => {
187                return Err(ClusterError::DimMismatch {
188                    expected: d,
189                    got: vector.len(),
190                })
191            }
192            _ => {}
193        }
194        if self.entries.contains_key(&id) {
195            return Err(ClusterError::DuplicateId(id));
196        }
197        let cluster_id = if self.centroids.is_empty() {
198            0 // placeholder; cluster assigned after build()
199        } else {
200            self.nearest_centroid_idx(&vector)
201        };
202        self.entries
203            .insert(id.clone(), Entry { vector, cluster_id });
204        self.id_order.push(id);
205        Ok(())
206    }
207
208    // -----------------------------------------------------------------------
209    // Build (Lloyd's algorithm)
210    // -----------------------------------------------------------------------
211
212    /// Run Lloyd's K-means algorithm to assign all vectors to clusters.
213    ///
214    /// Initialisation: k-means++ style — first centroid chosen as the first
215    /// vector, subsequent centroids chosen as the furthest from the current set.
216    pub fn build(&mut self) -> Result<(), ClusterError> {
217        let n = self.entries.len();
218        if n == 0 {
219            return Err(ClusterError::EmptyIndex);
220        }
221        let effective_k = self.k.min(n);
222        if effective_k == 0 {
223            return Err(ClusterError::InvalidK { k: self.k, n });
224        }
225
226        let all_vecs: Vec<Vec<f32>> = self
227            .id_order
228            .iter()
229            .map(|id| self.entries[id].vector.clone())
230            .collect();
231        let dim = all_vecs[0].len();
232
233        // K-means++ initialisation
234        let mut centroids = vec![all_vecs[0].clone()];
235        while centroids.len() < effective_k {
236            let mut max_dist = f32::NEG_INFINITY;
237            let mut best_idx = 0usize;
238            for (i, v) in all_vecs.iter().enumerate() {
239                let min_d = centroids
240                    .iter()
241                    .map(|c| sq_dist(v, c))
242                    .fold(f32::INFINITY, f32::min);
243                if min_d > max_dist {
244                    max_dist = min_d;
245                    best_idx = i;
246                }
247            }
248            centroids.push(all_vecs[best_idx].clone());
249        }
250
251        // Lloyd iterations
252        for _ in 0..self.max_iter {
253            // Assignment step
254            let assignments: Vec<usize> = all_vecs
255                .iter()
256                .map(|v| {
257                    centroids
258                        .iter()
259                        .enumerate()
260                        .min_by(|(_, a), (_, b)| {
261                            sq_dist(v, a)
262                                .partial_cmp(&sq_dist(v, b))
263                                .unwrap_or(std::cmp::Ordering::Equal)
264                        })
265                        .map(|(i, _)| i)
266                        .unwrap_or(0)
267                })
268                .collect();
269
270            // Update step
271            let mut new_centroids = vec![vec![0.0f32; dim]; effective_k];
272            let mut counts = vec![0usize; effective_k];
273            for (v, &c) in all_vecs.iter().zip(assignments.iter()) {
274                for (nc, x) in new_centroids[c].iter_mut().zip(v.iter()) {
275                    *nc += x;
276                }
277                counts[c] += 1;
278            }
279            let mut converged = true;
280            for (i, nc) in new_centroids.iter_mut().enumerate() {
281                let cnt = counts[i].max(1);
282                for x in nc.iter_mut() {
283                    *x /= cnt as f32;
284                }
285                if dist(nc, &centroids[i]) > 1e-6 {
286                    converged = false;
287                }
288            }
289            centroids = new_centroids;
290            if converged {
291                break;
292            }
293        }
294
295        // Final assignment of all entries
296        let assignments: Vec<usize> = all_vecs
297            .iter()
298            .map(|v| {
299                centroids
300                    .iter()
301                    .enumerate()
302                    .min_by(|(_, a), (_, b)| {
303                        sq_dist(v, a)
304                            .partial_cmp(&sq_dist(v, b))
305                            .unwrap_or(std::cmp::Ordering::Equal)
306                    })
307                    .map(|(i, _)| i)
308                    .unwrap_or(0)
309            })
310            .collect();
311
312        self.prev_centroids = self.centroids.clone();
313        self.centroids = centroids;
314        for (i, id) in self.id_order.iter().enumerate() {
315            if let Some(entry) = self.entries.get_mut(id) {
316                entry.cluster_id = assignments[i];
317            }
318        }
319        Ok(())
320    }
321
322    // -----------------------------------------------------------------------
323    // Cluster assignment
324    // -----------------------------------------------------------------------
325
326    /// Return the index of the centroid nearest to `query`.
327    pub fn assign(&self, query: &[f32]) -> Option<usize> {
328        if self.centroids.is_empty() {
329            return None;
330        }
331        Some(self.nearest_centroid_idx(query))
332    }
333
334    fn nearest_centroid_idx(&self, query: &[f32]) -> usize {
335        self.centroids
336            .iter()
337            .enumerate()
338            .min_by(|(_, a), (_, b)| {
339                sq_dist(query, a)
340                    .partial_cmp(&sq_dist(query, b))
341                    .unwrap_or(std::cmp::Ordering::Equal)
342            })
343            .map(|(i, _)| i)
344            .unwrap_or(0)
345    }
346
347    // -----------------------------------------------------------------------
348    // Cluster statistics
349    // -----------------------------------------------------------------------
350
351    /// Return statistics for the cluster with the given id.
352    pub fn cluster_stats(&self, cluster_id: usize) -> Result<ClusterStats, ClusterError> {
353        if cluster_id >= self.centroids.len() {
354            return Err(ClusterError::UnknownCluster(cluster_id));
355        }
356        let centroid = &self.centroids[cluster_id];
357        let members: Vec<&Vec<f32>> = self
358            .entries
359            .values()
360            .filter(|e| e.cluster_id == cluster_id)
361            .map(|e| &e.vector)
362            .collect();
363        let size = members.len();
364        let variance = if size == 0 {
365            0.0
366        } else {
367            members.iter().map(|v| sq_dist(v, centroid)).sum::<f32>() / size as f32
368        };
369        let drift = if self.prev_centroids.len() > cluster_id {
370            dist(centroid, &self.prev_centroids[cluster_id])
371        } else {
372            0.0
373        };
374        Ok(ClusterStats {
375            cluster_id,
376            size,
377            variance,
378            centroid_drift: drift,
379            centroid: centroid.clone(),
380        })
381    }
382
383    /// Return statistics for all clusters.
384    pub fn all_cluster_stats(&self) -> Vec<ClusterStats> {
385        (0..self.centroids.len())
386            .filter_map(|id| self.cluster_stats(id).ok())
387            .collect()
388    }
389
390    // -----------------------------------------------------------------------
391    // Cluster merge
392    // -----------------------------------------------------------------------
393
394    /// Merge clusters `a` and `b` into a single cluster.
395    ///
396    /// The merged centroid is the weighted mean of the two cluster centroids.
397    /// All members of both clusters are reassigned to the lower id; the higher
398    /// id slot is removed by swapping with the last centroid and truncating.
399    pub fn merge_clusters(&mut self, a: usize, b: usize) -> Result<(), ClusterError> {
400        if a == b {
401            return Err(ClusterError::SameCluster);
402        }
403        let n = self.centroids.len();
404        if a >= n {
405            return Err(ClusterError::UnknownCluster(a));
406        }
407        if b >= n {
408            return Err(ClusterError::UnknownCluster(b));
409        }
410        let (keep, remove) = if a < b { (a, b) } else { (b, a) };
411
412        // Count members in each cluster for weighted centroid
413        let count_keep = self
414            .entries
415            .values()
416            .filter(|e| e.cluster_id == keep)
417            .count();
418        let count_remove = self
419            .entries
420            .values()
421            .filter(|e| e.cluster_id == remove)
422            .count();
423        let total = count_keep + count_remove;
424
425        let dim = self.centroids[0].len();
426        let mut merged = vec![0.0f32; dim];
427        let w_keep = count_keep as f32 / total.max(1) as f32;
428        let w_remove = count_remove as f32 / total.max(1) as f32;
429        for (m, (ck, cr)) in merged.iter_mut().zip(
430            self.centroids[keep]
431                .iter()
432                .zip(self.centroids[remove].iter()),
433        ) {
434            *m = ck * w_keep + cr * w_remove;
435        }
436        self.centroids[keep] = merged;
437
438        // Reassign members of `remove` to `keep`
439        let last = n - 1;
440        for entry in self.entries.values_mut() {
441            if entry.cluster_id == remove {
442                entry.cluster_id = keep;
443            }
444            // Fix up entries pointing to the swapped-in last centroid
445            if remove != last && entry.cluster_id == last {
446                entry.cluster_id = remove;
447            }
448        }
449
450        // Swap `remove` with last and truncate
451        self.centroids.swap(remove, last);
452        self.centroids.pop();
453        if !self.prev_centroids.is_empty() {
454            if self.prev_centroids.len() > last {
455                self.prev_centroids.swap(remove, last);
456                self.prev_centroids.pop();
457            } else {
458                self.prev_centroids.clear();
459            }
460        }
461        Ok(())
462    }
463
464    /// Merge the two closest clusters (by centroid distance).
465    pub fn merge_closest_clusters(&mut self) -> Result<(), ClusterError> {
466        let n = self.centroids.len();
467        if n < 2 {
468            return Err(ClusterError::EmptyIndex);
469        }
470        let (mut best_i, mut best_j) = (0, 1);
471        let mut best_dist = f32::INFINITY;
472        for i in 0..n {
473            for j in (i + 1)..n {
474                let d = dist(&self.centroids[i], &self.centroids[j]);
475                if d < best_dist {
476                    best_dist = d;
477                    best_i = i;
478                    best_j = j;
479                }
480            }
481        }
482        self.merge_clusters(best_i, best_j)
483    }
484
485    // -----------------------------------------------------------------------
486    // Cluster split
487    // -----------------------------------------------------------------------
488
489    /// Split the largest cluster into two by creating a new centroid displaced
490    /// by the principal direction (first PCA component approximated as the
491    /// direction of greatest variance along each axis).
492    ///
493    /// The new centroid is appended to the end; half the members are
494    /// re-assigned to it.
495    pub fn split_largest_cluster(&mut self) -> Result<(), ClusterError> {
496        if self.centroids.is_empty() {
497            return Err(ClusterError::EmptyIndex);
498        }
499        // Find the largest cluster
500        let mut counts = vec![0usize; self.centroids.len()];
501        for entry in self.entries.values() {
502            if entry.cluster_id < counts.len() {
503                counts[entry.cluster_id] += 1;
504            }
505        }
506        let largest = counts
507            .iter()
508            .enumerate()
509            .max_by_key(|(_, &c)| c)
510            .map(|(i, _)| i);
511        let cluster_id = match largest {
512            Some(id) if counts[id] >= 2 => id,
513            _ => return Err(ClusterError::EmptyIndex),
514        };
515
516        let members: Vec<String> = self
517            .id_order
518            .iter()
519            .filter(|id| {
520                self.entries
521                    .get(*id)
522                    .map(|e| e.cluster_id == cluster_id)
523                    .unwrap_or(false)
524            })
525            .cloned()
526            .collect();
527
528        let member_vecs: Vec<&Vec<f32>> = members
529            .iter()
530            .filter_map(|id| self.entries.get(id).map(|e| &e.vector))
531            .collect();
532
533        let dim = self.centroids[0].len();
534        // Compute per-axis variance; split along highest-variance axis
535        let centroid = self.centroids[cluster_id].clone();
536        let mut axis_var = vec![0.0f32; dim];
537        for v in &member_vecs {
538            for (d, x) in v.iter().enumerate() {
539                let diff = x - centroid[d];
540                axis_var[d] += diff * diff;
541            }
542        }
543        let n_members = member_vecs.len() as f32;
544        for ax in &mut axis_var {
545            *ax /= n_members;
546        }
547        let split_axis = axis_var
548            .iter()
549            .enumerate()
550            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
551            .map(|(i, _)| i)
552            .unwrap_or(0);
553        let spread = axis_var[split_axis].sqrt() * 0.5;
554
555        let mut c1 = centroid.clone();
556        let mut c2 = centroid.clone();
557        c1[split_axis] -= spread;
558        c2[split_axis] += spread;
559
560        let new_id = self.centroids.len();
561        self.centroids[cluster_id] = c1.clone();
562        self.centroids.push(c2.clone());
563
564        // Reassign each member to the nearer of the two new centroids
565        let half = members.len() / 2;
566        for (i, member_id) in members.iter().enumerate() {
567            if let Some(entry) = self.entries.get_mut(member_id) {
568                entry.cluster_id = if i < half { cluster_id } else { new_id };
569            }
570        }
571        Ok(())
572    }
573
574    // -----------------------------------------------------------------------
575    // ANN search by cluster probing
576    // -----------------------------------------------------------------------
577
578    /// Search for the `top_k` nearest vectors to `query` by probing the
579    /// `n_probes` closest cluster centroids.
580    ///
581    /// Returns a list of `(id, distance)` pairs sorted by ascending distance.
582    pub fn search(
583        &self,
584        query: &[f32],
585        top_k: usize,
586        n_probes: usize,
587    ) -> Result<Vec<(String, f32)>, ClusterError> {
588        if self.entries.is_empty() {
589            return Err(ClusterError::EmptyIndex);
590        }
591        let n_probes = n_probes.min(self.centroids.len());
592
593        // Rank clusters by distance from query
594        let mut cluster_dists: Vec<(usize, f32)> = self
595            .centroids
596            .iter()
597            .enumerate()
598            .map(|(i, c)| (i, dist(query, c)))
599            .collect();
600        cluster_dists
601            .sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
602
603        // Collect candidates from the nearest n_probes clusters
604        let probe_set: std::collections::HashSet<usize> = cluster_dists
605            .iter()
606            .take(n_probes)
607            .map(|(i, _)| *i)
608            .collect();
609
610        let mut candidates: Vec<(String, f32)> = self
611            .entries
612            .iter()
613            .filter(|(_, e)| probe_set.contains(&e.cluster_id))
614            .map(|(id, e)| (id.clone(), dist(query, &e.vector)))
615            .collect();
616
617        candidates.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
618        candidates.truncate(top_k);
619        Ok(candidates)
620    }
621}
622
623// ---------------------------------------------------------------------------
624// Tests
625// ---------------------------------------------------------------------------
626
627#[cfg(test)]
628mod tests {
629    use super::*;
630
631    fn make_index(k: usize) -> ClusterIndex {
632        ClusterIndex::new(k, 20)
633    }
634
635    fn insert_and_build(k: usize, vectors: Vec<(&str, Vec<f32>)>) -> ClusterIndex {
636        let mut idx = make_index(k);
637        for (id, v) in vectors {
638            idx.insert(id.to_string(), v).expect("insert");
639        }
640        idx.build().expect("build");
641        idx
642    }
643
644    // -----------------------------------------------------------------------
645    // Basic insertion
646    // -----------------------------------------------------------------------
647
648    #[test]
649    fn test_insert_single_vector() {
650        let mut idx = make_index(1);
651        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
652        assert_eq!(idx.len(), 1);
653    }
654
655    #[test]
656    fn test_insert_duplicate_id_error() {
657        let mut idx = make_index(1);
658        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
659        let result = idx.insert("v0".into(), vec![0.0, 1.0]);
660        assert!(matches!(result, Err(ClusterError::DuplicateId(_))));
661    }
662
663    #[test]
664    fn test_insert_dim_mismatch_error() {
665        let mut idx = make_index(1);
666        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
667        let result = idx.insert("v1".into(), vec![1.0, 0.0, 0.0]);
668        assert!(matches!(result, Err(ClusterError::DimMismatch { .. })));
669    }
670
671    // -----------------------------------------------------------------------
672    // Build
673    // -----------------------------------------------------------------------
674
675    #[test]
676    fn test_build_single_cluster() {
677        let idx = insert_and_build(
678            1,
679            vec![
680                ("a", vec![1.0, 0.0]),
681                ("b", vec![2.0, 0.0]),
682                ("c", vec![3.0, 0.0]),
683            ],
684        );
685        assert_eq!(idx.num_clusters(), 1);
686        assert_eq!(idx.len(), 3);
687    }
688
689    #[test]
690    fn test_build_two_clusters() {
691        // Two well-separated groups
692        let idx = insert_and_build(
693            2,
694            vec![
695                ("a", vec![0.0, 0.0]),
696                ("b", vec![0.1, 0.0]),
697                ("c", vec![10.0, 0.0]),
698                ("d", vec![10.1, 0.0]),
699            ],
700        );
701        assert_eq!(idx.num_clusters(), 2);
702    }
703
704    #[test]
705    fn test_build_empty_index_error() {
706        let mut idx = make_index(3);
707        let result = idx.build();
708        assert!(matches!(result, Err(ClusterError::EmptyIndex)));
709    }
710
711    // -----------------------------------------------------------------------
712    // Cluster assignment
713    // -----------------------------------------------------------------------
714
715    #[test]
716    fn test_assign_returns_cluster() {
717        let idx = insert_and_build(
718            2,
719            vec![
720                ("a", vec![0.0f32, 0.0]),
721                ("b", vec![0.0, 0.1]),
722                ("c", vec![10.0, 0.0]),
723                ("d", vec![10.0, 0.1]),
724            ],
725        );
726        let cluster_near_origin = idx.assign(&[0.05, 0.0]).expect("assign");
727        let cluster_far = idx.assign(&[10.05, 0.0]).expect("assign");
728        assert_ne!(cluster_near_origin, cluster_far);
729    }
730
731    #[test]
732    fn test_assign_empty_returns_none() {
733        let idx = make_index(2);
734        assert!(idx.assign(&[1.0, 0.0]).is_none());
735    }
736
737    // -----------------------------------------------------------------------
738    // Cluster statistics
739    // -----------------------------------------------------------------------
740
741    #[test]
742    fn test_cluster_stats_size() {
743        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0]), ("b", vec![2.0, 0.0])]);
744        let stats = idx.cluster_stats(0).expect("stats");
745        assert_eq!(stats.cluster_id, 0);
746        assert_eq!(stats.size, 2);
747    }
748
749    #[test]
750    fn test_cluster_stats_unknown_error() {
751        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
752        let result = idx.cluster_stats(99);
753        assert!(matches!(result, Err(ClusterError::UnknownCluster(99))));
754    }
755
756    #[test]
757    fn test_cluster_stats_variance_non_negative() {
758        let idx = insert_and_build(
759            1,
760            vec![
761                ("a", vec![1.0f32, 0.0]),
762                ("b", vec![2.0, 0.0]),
763                ("c", vec![3.0, 0.0]),
764            ],
765        );
766        let stats = idx.cluster_stats(0).expect("stats");
767        assert!(stats.variance >= 0.0);
768    }
769
770    #[test]
771    fn test_all_cluster_stats_count() {
772        let idx = insert_and_build(
773            2,
774            vec![
775                ("a", vec![0.0f32, 0.0]),
776                ("b", vec![0.0, 0.1]),
777                ("c", vec![10.0, 0.0]),
778                ("d", vec![10.0, 0.1]),
779            ],
780        );
781        assert_eq!(idx.all_cluster_stats().len(), 2);
782    }
783
784    // -----------------------------------------------------------------------
785    // Cluster merge
786    // -----------------------------------------------------------------------
787
788    #[test]
789    fn test_merge_clusters_reduces_count() {
790        let mut idx = insert_and_build(
791            2,
792            vec![
793                ("a", vec![0.0f32, 0.0]),
794                ("b", vec![0.0, 0.1]),
795                ("c", vec![10.0, 0.0]),
796                ("d", vec![10.0, 0.1]),
797            ],
798        );
799        idx.merge_clusters(0, 1).expect("merge");
800        assert_eq!(idx.num_clusters(), 1);
801    }
802
803    #[test]
804    fn test_merge_same_cluster_error() {
805        let mut idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
806        let result = idx.merge_clusters(0, 0);
807        assert!(matches!(result, Err(ClusterError::SameCluster)));
808    }
809
810    #[test]
811    fn test_merge_closest_clusters() {
812        let mut idx = insert_and_build(
813            3,
814            vec![
815                ("a", vec![0.0f32, 0.0]),
816                ("b", vec![0.0, 0.1]),
817                ("c", vec![5.0, 0.0]),
818                ("d", vec![100.0, 0.0]),
819                ("e", vec![100.0, 0.1]),
820            ],
821        );
822        let before = idx.num_clusters();
823        idx.merge_closest_clusters().expect("merge");
824        assert_eq!(idx.num_clusters(), before - 1);
825    }
826
827    #[test]
828    fn test_merge_unknown_cluster_error() {
829        let mut idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
830        let result = idx.merge_clusters(0, 99);
831        assert!(matches!(result, Err(ClusterError::UnknownCluster(99))));
832    }
833
834    // -----------------------------------------------------------------------
835    // Cluster split
836    // -----------------------------------------------------------------------
837
838    #[test]
839    fn test_split_largest_cluster_increases_count() {
840        let mut idx = insert_and_build(
841            1,
842            vec![
843                ("a", vec![1.0f32, 0.0]),
844                ("b", vec![2.0, 0.0]),
845                ("c", vec![3.0, 0.0]),
846                ("d", vec![4.0, 0.0]),
847            ],
848        );
849        let before = idx.num_clusters();
850        idx.split_largest_cluster().expect("split");
851        assert_eq!(idx.num_clusters(), before + 1);
852    }
853
854    #[test]
855    fn test_split_empty_error() {
856        let mut idx = make_index(1);
857        assert!(matches!(
858            idx.split_largest_cluster(),
859            Err(ClusterError::EmptyIndex)
860        ));
861    }
862
863    // -----------------------------------------------------------------------
864    // ANN search
865    // -----------------------------------------------------------------------
866
867    #[test]
868    fn test_search_returns_nearest() {
869        let idx = insert_and_build(
870            2,
871            vec![
872                ("origin", vec![0.0f32, 0.0]),
873                ("near", vec![0.1, 0.0]),
874                ("far", vec![10.0, 0.0]),
875            ],
876        );
877        let results = idx.search(&[0.0, 0.0], 1, 2).expect("search");
878        assert_eq!(results.len(), 1);
879        assert_eq!(results[0].0, "origin");
880    }
881
882    #[test]
883    fn test_search_top_k_limit() {
884        let idx = insert_and_build(
885            1,
886            vec![
887                ("a", vec![0.0f32, 0.0]),
888                ("b", vec![1.0, 0.0]),
889                ("c", vec![2.0, 0.0]),
890            ],
891        );
892        let results = idx.search(&[0.0, 0.0], 2, 1).expect("search");
893        assert!(results.len() <= 2);
894    }
895
896    #[test]
897    fn test_search_empty_error() {
898        let idx = make_index(2);
899        let result = idx.search(&[0.0, 0.0], 3, 1);
900        assert!(matches!(result, Err(ClusterError::EmptyIndex)));
901    }
902
903    #[test]
904    fn test_search_results_sorted_by_distance() {
905        let idx = insert_and_build(
906            1,
907            vec![
908                ("near", vec![0.1f32, 0.0]),
909                ("mid", vec![1.0, 0.0]),
910                ("far", vec![5.0, 0.0]),
911            ],
912        );
913        let results = idx.search(&[0.0, 0.0], 3, 1).expect("search");
914        for i in 1..results.len() {
915            assert!(results[i - 1].1 <= results[i].1);
916        }
917    }
918
919    // -----------------------------------------------------------------------
920    // is_empty / len
921    // -----------------------------------------------------------------------
922
923    #[test]
924    fn test_is_empty_initial() {
925        let idx = make_index(2);
926        assert!(idx.is_empty());
927    }
928
929    #[test]
930    fn test_len_after_inserts() {
931        let mut idx = make_index(2);
932        idx.insert("v0".into(), vec![0.0f32, 0.0]).expect("insert");
933        idx.insert("v1".into(), vec![1.0, 0.0]).expect("insert");
934        assert_eq!(idx.len(), 2);
935    }
936
937    // -----------------------------------------------------------------------
938    // Additional edge-case tests
939    // -----------------------------------------------------------------------
940
941    #[test]
942    fn test_build_k_equals_n() {
943        // k == number of vectors → each vector is its own centroid
944        let idx = insert_and_build(
945            3,
946            vec![
947                ("a", vec![0.0f32, 0.0]),
948                ("b", vec![5.0, 0.0]),
949                ("c", vec![10.0, 0.0]),
950            ],
951        );
952        assert_eq!(idx.num_clusters(), 3);
953    }
954
955    #[test]
956    fn test_build_more_k_than_vectors_clamped() {
957        // k > n → clamped to n
958        let idx = insert_and_build(10, vec![("a", vec![0.0f32, 0.0]), ("b", vec![1.0, 0.0])]);
959        assert_eq!(idx.num_clusters(), 2);
960    }
961
962    #[test]
963    fn test_assign_after_build_consistent() {
964        let idx = insert_and_build(1, vec![("a", vec![3.0f32, 3.0]), ("b", vec![3.1, 3.1])]);
965        // Both should map to cluster 0 (only cluster)
966        assert_eq!(idx.assign(&[3.05, 3.05]), Some(0));
967    }
968
969    #[test]
970    fn test_cluster_stats_single_member_zero_variance() {
971        let mut idx = make_index(1);
972        idx.insert("only".into(), vec![7.0f32, 7.0])
973            .expect("insert");
974        idx.build().expect("build");
975        let stats = idx.cluster_stats(0).expect("stats");
976        assert_eq!(stats.size, 1);
977        assert!(stats.variance < 1e-6);
978    }
979
980    #[test]
981    fn test_all_stats_total_size_equals_len() {
982        let idx = insert_and_build(
983            2,
984            vec![
985                ("a", vec![0.0f32, 0.0]),
986                ("b", vec![0.0, 0.1]),
987                ("c", vec![10.0, 0.0]),
988                ("d", vec![10.0, 0.1]),
989            ],
990        );
991        let total: usize = idx.all_cluster_stats().iter().map(|s| s.size).sum();
992        assert_eq!(total, idx.len());
993    }
994
995    #[test]
996    fn test_merge_all_members_accounted_for() {
997        let mut idx = insert_and_build(
998            2,
999            vec![
1000                ("a", vec![0.0f32, 0.0]),
1001                ("b", vec![0.0, 0.1]),
1002                ("c", vec![10.0, 0.0]),
1003                ("d", vec![10.0, 0.1]),
1004            ],
1005        );
1006        idx.merge_clusters(0, 1).expect("merge");
1007        let stats = idx.cluster_stats(0).expect("stats");
1008        assert_eq!(stats.size, 4); // all members merged
1009    }
1010
1011    #[test]
1012    fn test_split_two_disjoint_halves() {
1013        let idx = insert_and_build(
1014            1,
1015            vec![
1016                ("lo1", vec![-5.0f32, 0.0]),
1017                ("lo2", vec![-4.0, 0.0]),
1018                ("hi1", vec![4.0, 0.0]),
1019                ("hi2", vec![5.0, 0.0]),
1020            ],
1021        );
1022        let stats = idx.all_cluster_stats();
1023        assert!(!stats.is_empty());
1024    }
1025
1026    #[test]
1027    fn test_search_n_probes_one_finds_all_in_cluster() {
1028        let idx = insert_and_build(
1029            1,
1030            vec![
1031                ("a", vec![1.0f32, 0.0]),
1032                ("b", vec![2.0, 0.0]),
1033                ("c", vec![3.0, 0.0]),
1034            ],
1035        );
1036        let results = idx.search(&[2.0, 0.0], 3, 1).expect("search");
1037        assert_eq!(results.len(), 3);
1038    }
1039
1040    #[test]
1041    fn test_cluster_index_3d_vectors() {
1042        let idx = insert_and_build(
1043            2,
1044            vec![
1045                ("a", vec![0.0f32, 0.0, 0.0]),
1046                ("b", vec![0.1, 0.0, 0.0]),
1047                ("c", vec![10.0, 10.0, 10.0]),
1048                ("d", vec![10.1, 10.0, 10.0]),
1049            ],
1050        );
1051        let r = idx.search(&[0.05, 0.0, 0.0], 1, 1).expect("search");
1052        assert!(!r.is_empty());
1053    }
1054
1055    #[test]
1056    fn test_assign_unknown_without_build_returns_none_or_some() {
1057        let mut idx = make_index(2);
1058        idx.insert("v0".into(), vec![1.0f32, 0.0]).expect("insert");
1059        // Before build: centroids empty → None
1060        assert!(idx.assign(&[1.0, 0.0]).is_none());
1061    }
1062
1063    #[test]
1064    fn test_cluster_error_display() {
1065        let e = ClusterError::InvalidK { k: 0, n: 5 };
1066        assert!(e.to_string().contains("invalid"));
1067        let e2 = ClusterError::EmptyIndex;
1068        assert!(e2.to_string().contains("empty"));
1069        let e3 = ClusterError::DuplicateId("x".into());
1070        assert!(e3.to_string().contains("x"));
1071        let e4 = ClusterError::DimMismatch {
1072            expected: 3,
1073            got: 2,
1074        };
1075        assert!(e4.to_string().contains("mismatch"));
1076    }
1077
1078    #[test]
1079    fn test_search_multi_probe() {
1080        let idx = insert_and_build(
1081            3,
1082            vec![
1083                ("c1a", vec![0.0f32, 0.0]),
1084                ("c1b", vec![0.1, 0.0]),
1085                ("c2a", vec![5.0, 0.0]),
1086                ("c2b", vec![5.1, 0.0]),
1087                ("c3a", vec![10.0, 0.0]),
1088                ("c3b", vec![10.1, 0.0]),
1089            ],
1090        );
1091        // Probe 2 clusters should find at least 2 results
1092        let results = idx.search(&[0.0, 0.0], 5, 2).expect("search");
1093        assert!(!results.is_empty());
1094    }
1095
1096    #[test]
1097    fn test_split_then_search() {
1098        let mut idx = insert_and_build(
1099            1,
1100            vec![
1101                ("a", vec![-3.0f32, 0.0]),
1102                ("b", vec![-2.0, 0.0]),
1103                ("c", vec![2.0, 0.0]),
1104                ("d", vec![3.0, 0.0]),
1105            ],
1106        );
1107        idx.split_largest_cluster().expect("split");
1108        let results = idx.search(&[-3.0, 0.0], 2, 2).expect("search");
1109        assert!(!results.is_empty());
1110    }
1111
1112    #[test]
1113    fn test_merge_then_build_consistent() {
1114        let mut idx = insert_and_build(
1115            2,
1116            vec![
1117                ("a", vec![0.0f32, 0.0]),
1118                ("b", vec![0.1, 0.0]),
1119                ("c", vec![10.0, 0.0]),
1120                ("d", vec![10.1, 0.0]),
1121            ],
1122        );
1123        idx.merge_clusters(0, 1).expect("merge");
1124        assert_eq!(idx.num_clusters(), 1);
1125        // After merge, should still be searchable
1126        let r = idx.search(&[5.0, 0.0], 2, 1).expect("search");
1127        assert_eq!(r.len(), 2);
1128    }
1129
1130    #[test]
1131    fn test_num_clusters_zero_before_build() {
1132        let idx = make_index(3);
1133        assert_eq!(idx.num_clusters(), 0);
1134    }
1135
1136    #[test]
1137    fn test_build_single_vector() {
1138        let idx = insert_and_build(1, vec![("solo", vec![1.0f32, 2.0])]);
1139        assert_eq!(idx.num_clusters(), 1);
1140        assert_eq!(idx.len(), 1);
1141    }
1142
1143    #[test]
1144    fn test_search_returns_correct_id() {
1145        let idx = insert_and_build(
1146            1,
1147            vec![("x_near", vec![0.01f32, 0.0]), ("x_far", vec![100.0, 0.0])],
1148        );
1149        let r = idx.search(&[0.0, 0.0], 1, 1).expect("search");
1150        assert_eq!(r[0].0, "x_near");
1151    }
1152
1153    #[test]
1154    fn test_search_distance_is_non_negative() {
1155        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0]), ("b", vec![2.0, 0.0])]);
1156        let r = idx.search(&[0.0, 0.0], 2, 1).expect("search");
1157        for (_, d) in &r {
1158            assert!(*d >= 0.0);
1159        }
1160    }
1161
1162    #[test]
1163    fn test_cluster_stats_centroid_len() {
1164        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 2.0, 3.0])]);
1165        let s = idx.cluster_stats(0).expect("stats");
1166        assert_eq!(s.centroid.len(), 3);
1167    }
1168
1169    #[test]
1170    fn test_merge_closest_with_two_clusters() {
1171        let mut idx = insert_and_build(
1172            2,
1173            vec![
1174                ("near_a", vec![0.0f32, 0.0]),
1175                ("near_b", vec![0.1, 0.0]),
1176                ("far_a", vec![100.0, 0.0]),
1177                ("far_b", vec![100.1, 0.0]),
1178            ],
1179        );
1180        // Merging the two clusters (they are the only two) should work
1181        idx.merge_closest_clusters().expect("merge");
1182        assert_eq!(idx.num_clusters(), 1);
1183    }
1184
1185    #[test]
1186    fn test_cluster_error_unknown_cluster_display() {
1187        let e = ClusterError::UnknownCluster(42);
1188        assert!(e.to_string().contains("42"));
1189    }
1190}