1use std::collections::HashMap;
13
/// Errors returned by [`ClusterIndex`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClusterError {
    /// The requested cluster count `k` cannot be used with `n` stored vectors.
    InvalidK { k: usize, n: usize },
    /// The index holds no vectors, or too few clusters/members for the requested operation.
    EmptyIndex,
    /// A cluster id outside the current range of clusters was supplied.
    UnknownCluster(usize),
    /// A cluster was asked to merge with itself.
    SameCluster,
    /// A vector with this id is already stored.
    DuplicateId(String),
    /// A vector (or query) length differs from the index dimension.
    DimMismatch { expected: usize, got: usize },
}

impl std::fmt::Display for ClusterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusterError::InvalidK { k, n } => {
                write!(f, "k={k} is invalid for {n} vectors")
            }
            ClusterError::EmptyIndex => write!(f, "the index is empty"),
            ClusterError::UnknownCluster(id) => write!(f, "unknown cluster id {id}"),
            ClusterError::SameCluster => write!(f, "cannot merge a cluster with itself"),
            ClusterError::DuplicateId(id) => write!(f, "duplicate vector id '{id}'"),
            ClusterError::DimMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

// Implementing `Error` lets callers propagate `ClusterError` with `?` into
// `Box<dyn std::error::Error>` (or any error type built on the trait).
impl std::error::Error for ClusterError {}
51
/// Squared Euclidean distance between `a` and `b`.
///
/// Coordinates are paired with `zip`, so if the slices differ in length the
/// extra trailing coordinates of the longer one are ignored.
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum::<f32>()
}

/// Euclidean distance between `a` and `b` (the square root of [`sq_dist`]).
fn dist(a: &[f32], b: &[f32]) -> f32 {
    let squared = sq_dist(a, b);
    squared.sqrt()
}
65
// No caller in this file uses this helper; the allow silences the
// unused-function lint while it is kept as a utility.
#[allow(dead_code)]
/// Component-wise mean of `vectors`, or `None` when the slice is empty.
///
/// The output length is taken from the first vector; components of later
/// vectors beyond that length are ignored (`zip` truncation).
fn mean_vector(vectors: &[Vec<f32>]) -> Option<Vec<f32>> {
    let first = vectors.first()?;
    let mut totals = vec![0.0f32; first.len()];
    vectors
        .iter()
        .for_each(|v| totals.iter_mut().zip(v).for_each(|(t, x)| *t += x));
    let count = vectors.len() as f32;
    Some(totals.iter().map(|t| t / count).collect())
}
83
/// Summary statistics for one cluster, as produced by `ClusterIndex::cluster_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusterStats {
    /// Id of the cluster these statistics describe.
    pub cluster_id: usize,
    /// Number of stored vectors currently assigned to the cluster.
    pub size: usize,
    /// Mean squared Euclidean distance of the members from the centroid
    /// (0 for an empty cluster).
    pub variance: f32,
    /// Euclidean distance between the current centroid and the previous
    /// centroid recorded for the same id (0 when there is none).
    pub centroid_drift: f32,
    /// Copy of the cluster's current centroid.
    pub centroid: Vec<f32>,
}
102
/// One stored vector plus the id of the cluster it is currently assigned to.
#[derive(Clone, Debug)]
struct Entry {
    /// The stored vector; `insert` keeps every vector at the index dimension.
    vector: Vec<f32>,
    /// Position of the assigned cluster in the owning index's centroid list.
    cluster_id: usize,
}
113
114#[derive(Debug, Clone)]
116pub struct ClusterIndex {
117 k: usize,
119 max_iter: usize,
121 dim: Option<usize>,
123 entries: HashMap<String, Entry>,
125 id_order: Vec<String>,
127 centroids: Vec<Vec<f32>>,
129 prev_centroids: Vec<Vec<f32>>,
131 #[allow(dead_code)]
133 next_cluster_id: usize,
134}
135
136impl ClusterIndex {
137 pub fn new(k: usize, max_iter: usize) -> Self {
142 Self {
143 k,
144 max_iter,
145 dim: None,
146 entries: HashMap::new(),
147 id_order: Vec::new(),
148 centroids: Vec::new(),
149 prev_centroids: Vec::new(),
150 next_cluster_id: k,
151 }
152 }
153
154 pub fn num_clusters(&self) -> usize {
156 self.centroids.len()
157 }
158
159 pub fn len(&self) -> usize {
161 self.entries.len()
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.entries.is_empty()
167 }
168
169 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<(), ClusterError> {
183 match self.dim {
185 None => self.dim = Some(vector.len()),
186 Some(d) if d != vector.len() => {
187 return Err(ClusterError::DimMismatch {
188 expected: d,
189 got: vector.len(),
190 })
191 }
192 _ => {}
193 }
194 if self.entries.contains_key(&id) {
195 return Err(ClusterError::DuplicateId(id));
196 }
197 let cluster_id = if self.centroids.is_empty() {
198 0 } else {
200 self.nearest_centroid_idx(&vector)
201 };
202 self.entries
203 .insert(id.clone(), Entry { vector, cluster_id });
204 self.id_order.push(id);
205 Ok(())
206 }
207
208 pub fn build(&mut self) -> Result<(), ClusterError> {
217 let n = self.entries.len();
218 if n == 0 {
219 return Err(ClusterError::EmptyIndex);
220 }
221 let effective_k = self.k.min(n);
222 if effective_k == 0 {
223 return Err(ClusterError::InvalidK { k: self.k, n });
224 }
225
226 let all_vecs: Vec<Vec<f32>> = self
227 .id_order
228 .iter()
229 .map(|id| self.entries[id].vector.clone())
230 .collect();
231 let dim = all_vecs[0].len();
232
233 let mut centroids = vec![all_vecs[0].clone()];
235 while centroids.len() < effective_k {
236 let mut max_dist = f32::NEG_INFINITY;
237 let mut best_idx = 0usize;
238 for (i, v) in all_vecs.iter().enumerate() {
239 let min_d = centroids
240 .iter()
241 .map(|c| sq_dist(v, c))
242 .fold(f32::INFINITY, f32::min);
243 if min_d > max_dist {
244 max_dist = min_d;
245 best_idx = i;
246 }
247 }
248 centroids.push(all_vecs[best_idx].clone());
249 }
250
251 for _ in 0..self.max_iter {
253 let assignments: Vec<usize> = all_vecs
255 .iter()
256 .map(|v| {
257 centroids
258 .iter()
259 .enumerate()
260 .min_by(|(_, a), (_, b)| {
261 sq_dist(v, a)
262 .partial_cmp(&sq_dist(v, b))
263 .unwrap_or(std::cmp::Ordering::Equal)
264 })
265 .map(|(i, _)| i)
266 .unwrap_or(0)
267 })
268 .collect();
269
270 let mut new_centroids = vec![vec![0.0f32; dim]; effective_k];
272 let mut counts = vec![0usize; effective_k];
273 for (v, &c) in all_vecs.iter().zip(assignments.iter()) {
274 for (nc, x) in new_centroids[c].iter_mut().zip(v.iter()) {
275 *nc += x;
276 }
277 counts[c] += 1;
278 }
279 let mut converged = true;
280 for (i, nc) in new_centroids.iter_mut().enumerate() {
281 let cnt = counts[i].max(1);
282 for x in nc.iter_mut() {
283 *x /= cnt as f32;
284 }
285 if dist(nc, ¢roids[i]) > 1e-6 {
286 converged = false;
287 }
288 }
289 centroids = new_centroids;
290 if converged {
291 break;
292 }
293 }
294
295 let assignments: Vec<usize> = all_vecs
297 .iter()
298 .map(|v| {
299 centroids
300 .iter()
301 .enumerate()
302 .min_by(|(_, a), (_, b)| {
303 sq_dist(v, a)
304 .partial_cmp(&sq_dist(v, b))
305 .unwrap_or(std::cmp::Ordering::Equal)
306 })
307 .map(|(i, _)| i)
308 .unwrap_or(0)
309 })
310 .collect();
311
312 self.prev_centroids = self.centroids.clone();
313 self.centroids = centroids;
314 for (i, id) in self.id_order.iter().enumerate() {
315 if let Some(entry) = self.entries.get_mut(id) {
316 entry.cluster_id = assignments[i];
317 }
318 }
319 Ok(())
320 }
321
322 pub fn assign(&self, query: &[f32]) -> Option<usize> {
328 if self.centroids.is_empty() {
329 return None;
330 }
331 Some(self.nearest_centroid_idx(query))
332 }
333
334 fn nearest_centroid_idx(&self, query: &[f32]) -> usize {
335 self.centroids
336 .iter()
337 .enumerate()
338 .min_by(|(_, a), (_, b)| {
339 sq_dist(query, a)
340 .partial_cmp(&sq_dist(query, b))
341 .unwrap_or(std::cmp::Ordering::Equal)
342 })
343 .map(|(i, _)| i)
344 .unwrap_or(0)
345 }
346
347 pub fn cluster_stats(&self, cluster_id: usize) -> Result<ClusterStats, ClusterError> {
353 if cluster_id >= self.centroids.len() {
354 return Err(ClusterError::UnknownCluster(cluster_id));
355 }
356 let centroid = &self.centroids[cluster_id];
357 let members: Vec<&Vec<f32>> = self
358 .entries
359 .values()
360 .filter(|e| e.cluster_id == cluster_id)
361 .map(|e| &e.vector)
362 .collect();
363 let size = members.len();
364 let variance = if size == 0 {
365 0.0
366 } else {
367 members.iter().map(|v| sq_dist(v, centroid)).sum::<f32>() / size as f32
368 };
369 let drift = if self.prev_centroids.len() > cluster_id {
370 dist(centroid, &self.prev_centroids[cluster_id])
371 } else {
372 0.0
373 };
374 Ok(ClusterStats {
375 cluster_id,
376 size,
377 variance,
378 centroid_drift: drift,
379 centroid: centroid.clone(),
380 })
381 }
382
383 pub fn all_cluster_stats(&self) -> Vec<ClusterStats> {
385 (0..self.centroids.len())
386 .filter_map(|id| self.cluster_stats(id).ok())
387 .collect()
388 }
389
390 pub fn merge_clusters(&mut self, a: usize, b: usize) -> Result<(), ClusterError> {
400 if a == b {
401 return Err(ClusterError::SameCluster);
402 }
403 let n = self.centroids.len();
404 if a >= n {
405 return Err(ClusterError::UnknownCluster(a));
406 }
407 if b >= n {
408 return Err(ClusterError::UnknownCluster(b));
409 }
410 let (keep, remove) = if a < b { (a, b) } else { (b, a) };
411
412 let count_keep = self
414 .entries
415 .values()
416 .filter(|e| e.cluster_id == keep)
417 .count();
418 let count_remove = self
419 .entries
420 .values()
421 .filter(|e| e.cluster_id == remove)
422 .count();
423 let total = count_keep + count_remove;
424
425 let dim = self.centroids[0].len();
426 let mut merged = vec![0.0f32; dim];
427 let w_keep = count_keep as f32 / total.max(1) as f32;
428 let w_remove = count_remove as f32 / total.max(1) as f32;
429 for (m, (ck, cr)) in merged.iter_mut().zip(
430 self.centroids[keep]
431 .iter()
432 .zip(self.centroids[remove].iter()),
433 ) {
434 *m = ck * w_keep + cr * w_remove;
435 }
436 self.centroids[keep] = merged;
437
438 let last = n - 1;
440 for entry in self.entries.values_mut() {
441 if entry.cluster_id == remove {
442 entry.cluster_id = keep;
443 }
444 if remove != last && entry.cluster_id == last {
446 entry.cluster_id = remove;
447 }
448 }
449
450 self.centroids.swap(remove, last);
452 self.centroids.pop();
453 if !self.prev_centroids.is_empty() {
454 if self.prev_centroids.len() > last {
455 self.prev_centroids.swap(remove, last);
456 self.prev_centroids.pop();
457 } else {
458 self.prev_centroids.clear();
459 }
460 }
461 Ok(())
462 }
463
464 pub fn merge_closest_clusters(&mut self) -> Result<(), ClusterError> {
466 let n = self.centroids.len();
467 if n < 2 {
468 return Err(ClusterError::EmptyIndex);
469 }
470 let (mut best_i, mut best_j) = (0, 1);
471 let mut best_dist = f32::INFINITY;
472 for i in 0..n {
473 for j in (i + 1)..n {
474 let d = dist(&self.centroids[i], &self.centroids[j]);
475 if d < best_dist {
476 best_dist = d;
477 best_i = i;
478 best_j = j;
479 }
480 }
481 }
482 self.merge_clusters(best_i, best_j)
483 }
484
485 pub fn split_largest_cluster(&mut self) -> Result<(), ClusterError> {
496 if self.centroids.is_empty() {
497 return Err(ClusterError::EmptyIndex);
498 }
499 let mut counts = vec![0usize; self.centroids.len()];
501 for entry in self.entries.values() {
502 if entry.cluster_id < counts.len() {
503 counts[entry.cluster_id] += 1;
504 }
505 }
506 let largest = counts
507 .iter()
508 .enumerate()
509 .max_by_key(|(_, &c)| c)
510 .map(|(i, _)| i);
511 let cluster_id = match largest {
512 Some(id) if counts[id] >= 2 => id,
513 _ => return Err(ClusterError::EmptyIndex),
514 };
515
516 let members: Vec<String> = self
517 .id_order
518 .iter()
519 .filter(|id| {
520 self.entries
521 .get(*id)
522 .map(|e| e.cluster_id == cluster_id)
523 .unwrap_or(false)
524 })
525 .cloned()
526 .collect();
527
528 let member_vecs: Vec<&Vec<f32>> = members
529 .iter()
530 .filter_map(|id| self.entries.get(id).map(|e| &e.vector))
531 .collect();
532
533 let dim = self.centroids[0].len();
534 let centroid = self.centroids[cluster_id].clone();
536 let mut axis_var = vec![0.0f32; dim];
537 for v in &member_vecs {
538 for (d, x) in v.iter().enumerate() {
539 let diff = x - centroid[d];
540 axis_var[d] += diff * diff;
541 }
542 }
543 let n_members = member_vecs.len() as f32;
544 for ax in &mut axis_var {
545 *ax /= n_members;
546 }
547 let split_axis = axis_var
548 .iter()
549 .enumerate()
550 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
551 .map(|(i, _)| i)
552 .unwrap_or(0);
553 let spread = axis_var[split_axis].sqrt() * 0.5;
554
555 let mut c1 = centroid.clone();
556 let mut c2 = centroid.clone();
557 c1[split_axis] -= spread;
558 c2[split_axis] += spread;
559
560 let new_id = self.centroids.len();
561 self.centroids[cluster_id] = c1.clone();
562 self.centroids.push(c2.clone());
563
564 let half = members.len() / 2;
566 for (i, member_id) in members.iter().enumerate() {
567 if let Some(entry) = self.entries.get_mut(member_id) {
568 entry.cluster_id = if i < half { cluster_id } else { new_id };
569 }
570 }
571 Ok(())
572 }
573
574 pub fn search(
583 &self,
584 query: &[f32],
585 top_k: usize,
586 n_probes: usize,
587 ) -> Result<Vec<(String, f32)>, ClusterError> {
588 if self.entries.is_empty() {
589 return Err(ClusterError::EmptyIndex);
590 }
591 let n_probes = n_probes.min(self.centroids.len());
592
593 let mut cluster_dists: Vec<(usize, f32)> = self
595 .centroids
596 .iter()
597 .enumerate()
598 .map(|(i, c)| (i, dist(query, c)))
599 .collect();
600 cluster_dists
601 .sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
602
603 let probe_set: std::collections::HashSet<usize> = cluster_dists
605 .iter()
606 .take(n_probes)
607 .map(|(i, _)| *i)
608 .collect();
609
610 let mut candidates: Vec<(String, f32)> = self
611 .entries
612 .iter()
613 .filter(|(_, e)| probe_set.contains(&e.cluster_id))
614 .map(|(id, e)| (id.clone(), dist(query, &e.vector)))
615 .collect();
616
617 candidates.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
618 candidates.truncate(top_k);
619 Ok(candidates)
620 }
621}
622
#[cfg(test)]
mod tests {
    // Unit tests for `ClusterIndex`. Each test builds a small index by hand so the
    // expected clustering is obvious from the coordinates.
    use super::*;

    fn make_index(k: usize) -> ClusterIndex {
        ClusterIndex::new(k, 20)
    }

    fn insert_and_build(k: usize, vectors: Vec<(&str, Vec<f32>)>) -> ClusterIndex {
        let mut idx = make_index(k);
        for (id, v) in vectors {
            idx.insert(id.to_string(), v).expect("insert");
        }
        idx.build().expect("build");
        idx
    }

    #[test]
    fn test_insert_single_vector() {
        let mut idx = make_index(1);
        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_insert_duplicate_id_error() {
        let mut idx = make_index(1);
        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
        let result = idx.insert("v0".into(), vec![0.0, 1.0]);
        assert!(matches!(result, Err(ClusterError::DuplicateId(_))));
    }

    #[test]
    fn test_insert_dim_mismatch_error() {
        let mut idx = make_index(1);
        idx.insert("v0".into(), vec![1.0, 0.0]).expect("insert");
        let result = idx.insert("v1".into(), vec![1.0, 0.0, 0.0]);
        assert!(matches!(result, Err(ClusterError::DimMismatch { .. })));
    }

    #[test]
    fn test_build_single_cluster() {
        let idx = insert_and_build(
            1,
            vec![
                ("a", vec![1.0, 0.0]),
                ("b", vec![2.0, 0.0]),
                ("c", vec![3.0, 0.0]),
            ],
        );
        assert_eq!(idx.num_clusters(), 1);
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn test_build_two_clusters() {
        let idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0, 0.0]),
                ("b", vec![0.1, 0.0]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.1, 0.0]),
            ],
        );
        assert_eq!(idx.num_clusters(), 2);
    }

    #[test]
    fn test_build_empty_index_error() {
        let mut idx = make_index(3);
        let result = idx.build();
        assert!(matches!(result, Err(ClusterError::EmptyIndex)));
    }

    #[test]
    fn test_assign_returns_cluster() {
        let idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.0, 0.1]),
            ],
        );
        let cluster_near_origin = idx.assign(&[0.05, 0.0]).expect("assign");
        let cluster_far = idx.assign(&[10.05, 0.0]).expect("assign");
        assert_ne!(cluster_near_origin, cluster_far);
    }

    #[test]
    fn test_assign_empty_returns_none() {
        let idx = make_index(2);
        assert!(idx.assign(&[1.0, 0.0]).is_none());
    }

    #[test]
    fn test_cluster_stats_size() {
        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0]), ("b", vec![2.0, 0.0])]);
        let stats = idx.cluster_stats(0).expect("stats");
        assert_eq!(stats.cluster_id, 0);
        assert_eq!(stats.size, 2);
    }

    #[test]
    fn test_cluster_stats_unknown_error() {
        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
        let result = idx.cluster_stats(99);
        assert!(matches!(result, Err(ClusterError::UnknownCluster(99))));
    }

    #[test]
    fn test_cluster_stats_variance_non_negative() {
        let idx = insert_and_build(
            1,
            vec![
                ("a", vec![1.0f32, 0.0]),
                ("b", vec![2.0, 0.0]),
                ("c", vec![3.0, 0.0]),
            ],
        );
        let stats = idx.cluster_stats(0).expect("stats");
        assert!(stats.variance >= 0.0);
    }

    #[test]
    fn test_all_cluster_stats_count() {
        let idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.0, 0.1]),
            ],
        );
        assert_eq!(idx.all_cluster_stats().len(), 2);
    }

    #[test]
    fn test_merge_clusters_reduces_count() {
        let mut idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.0, 0.1]),
            ],
        );
        idx.merge_clusters(0, 1).expect("merge");
        assert_eq!(idx.num_clusters(), 1);
    }

    #[test]
    fn test_merge_same_cluster_error() {
        let mut idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
        let result = idx.merge_clusters(0, 0);
        assert!(matches!(result, Err(ClusterError::SameCluster)));
    }

    #[test]
    fn test_merge_closest_clusters() {
        let mut idx = insert_and_build(
            3,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![5.0, 0.0]),
                ("d", vec![100.0, 0.0]),
                ("e", vec![100.0, 0.1]),
            ],
        );
        let before = idx.num_clusters();
        idx.merge_closest_clusters().expect("merge");
        assert_eq!(idx.num_clusters(), before - 1);
    }

    #[test]
    fn test_merge_unknown_cluster_error() {
        let mut idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0])]);
        let result = idx.merge_clusters(0, 99);
        assert!(matches!(result, Err(ClusterError::UnknownCluster(99))));
    }

    #[test]
    fn test_split_largest_cluster_increases_count() {
        let mut idx = insert_and_build(
            1,
            vec![
                ("a", vec![1.0f32, 0.0]),
                ("b", vec![2.0, 0.0]),
                ("c", vec![3.0, 0.0]),
                ("d", vec![4.0, 0.0]),
            ],
        );
        let before = idx.num_clusters();
        idx.split_largest_cluster().expect("split");
        assert_eq!(idx.num_clusters(), before + 1);
    }

    #[test]
    fn test_split_empty_error() {
        let mut idx = make_index(1);
        assert!(matches!(
            idx.split_largest_cluster(),
            Err(ClusterError::EmptyIndex)
        ));
    }

    #[test]
    fn test_search_returns_nearest() {
        let idx = insert_and_build(
            2,
            vec![
                ("origin", vec![0.0f32, 0.0]),
                ("near", vec![0.1, 0.0]),
                ("far", vec![10.0, 0.0]),
            ],
        );
        let results = idx.search(&[0.0, 0.0], 1, 2).expect("search");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "origin");
    }

    #[test]
    fn test_search_top_k_limit() {
        let idx = insert_and_build(
            1,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![1.0, 0.0]),
                ("c", vec![2.0, 0.0]),
            ],
        );
        let results = idx.search(&[0.0, 0.0], 2, 1).expect("search");
        // All three vectors share the probed cluster, so exactly `top_k` come back.
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_search_empty_error() {
        let idx = make_index(2);
        let result = idx.search(&[0.0, 0.0], 3, 1);
        assert!(matches!(result, Err(ClusterError::EmptyIndex)));
    }

    #[test]
    fn test_search_results_sorted_by_distance() {
        let idx = insert_and_build(
            1,
            vec![
                ("near", vec![0.1f32, 0.0]),
                ("mid", vec![1.0, 0.0]),
                ("far", vec![5.0, 0.0]),
            ],
        );
        let results = idx.search(&[0.0, 0.0], 3, 1).expect("search");
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1);
        }
    }

    #[test]
    fn test_is_empty_initial() {
        let idx = make_index(2);
        assert!(idx.is_empty());
    }

    #[test]
    fn test_len_after_inserts() {
        let mut idx = make_index(2);
        idx.insert("v0".into(), vec![0.0f32, 0.0]).expect("insert");
        idx.insert("v1".into(), vec![1.0, 0.0]).expect("insert");
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn test_build_k_equals_n() {
        let idx = insert_and_build(
            3,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![5.0, 0.0]),
                ("c", vec![10.0, 0.0]),
            ],
        );
        assert_eq!(idx.num_clusters(), 3);
    }

    #[test]
    fn test_build_more_k_than_vectors_clamped() {
        let idx = insert_and_build(10, vec![("a", vec![0.0f32, 0.0]), ("b", vec![1.0, 0.0])]);
        assert_eq!(idx.num_clusters(), 2);
    }

    #[test]
    fn test_assign_after_build_consistent() {
        let idx = insert_and_build(1, vec![("a", vec![3.0f32, 3.0]), ("b", vec![3.1, 3.1])]);
        assert_eq!(idx.assign(&[3.05, 3.05]), Some(0));
    }

    #[test]
    fn test_cluster_stats_single_member_zero_variance() {
        let mut idx = make_index(1);
        idx.insert("only".into(), vec![7.0f32, 7.0]).expect("insert");
        idx.build().expect("build");
        let stats = idx.cluster_stats(0).expect("stats");
        assert_eq!(stats.size, 1);
        assert!(stats.variance < 1e-6);
    }

    #[test]
    fn test_all_stats_total_size_equals_len() {
        let idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.0, 0.1]),
            ],
        );
        let total: usize = idx.all_cluster_stats().iter().map(|s| s.size).sum();
        assert_eq!(total, idx.len());
    }

    #[test]
    fn test_merge_all_members_accounted_for() {
        let mut idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.0, 0.1]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.0, 0.1]),
            ],
        );
        idx.merge_clusters(0, 1).expect("merge");
        let stats = idx.cluster_stats(0).expect("stats");
        assert_eq!(stats.size, 4);
    }

    #[test]
    fn test_split_two_disjoint_halves() {
        let mut idx = insert_and_build(
            1,
            vec![
                ("lo1", vec![-5.0f32, 0.0]),
                ("lo2", vec![-4.0, 0.0]),
                ("hi1", vec![4.0, 0.0]),
                ("hi2", vec![5.0, 0.0]),
            ],
        );
        idx.split_largest_cluster().expect("split");
        assert_eq!(idx.num_clusters(), 2);
        let cluster_of = |id: &str| idx.entries[id].cluster_id;
        assert_eq!(cluster_of("lo1"), cluster_of("lo2"));
        assert_eq!(cluster_of("hi1"), cluster_of("hi2"));
        assert_ne!(cluster_of("lo1"), cluster_of("hi1"));
        let stats = idx.all_cluster_stats();
        assert!(stats.iter().all(|s| s.size == 2));
    }

    #[test]
    fn test_search_n_probes_one_finds_all_in_cluster() {
        let idx = insert_and_build(
            1,
            vec![
                ("a", vec![1.0f32, 0.0]),
                ("b", vec![2.0, 0.0]),
                ("c", vec![3.0, 0.0]),
            ],
        );
        let results = idx.search(&[2.0, 0.0], 3, 1).expect("search");
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_cluster_index_3d_vectors() {
        let idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0, 0.0]),
                ("b", vec![0.1, 0.0, 0.0]),
                ("c", vec![10.0, 10.0, 10.0]),
                ("d", vec![10.1, 10.0, 10.0]),
            ],
        );
        let r = idx.search(&[0.05, 0.0, 0.0], 1, 1).expect("search");
        assert!(!r.is_empty());
    }

    #[test]
    fn test_assign_without_build_returns_none() {
        let mut idx = make_index(2);
        idx.insert("v0".into(), vec![1.0f32, 0.0]).expect("insert");
        assert!(idx.assign(&[1.0, 0.0]).is_none());
    }

    #[test]
    fn test_cluster_error_display() {
        let e = ClusterError::InvalidK { k: 0, n: 5 };
        assert!(e.to_string().contains("invalid"));
        let e2 = ClusterError::EmptyIndex;
        assert!(e2.to_string().contains("empty"));
        let e3 = ClusterError::DuplicateId("x".into());
        assert!(e3.to_string().contains("x"));
        let e4 = ClusterError::DimMismatch {
            expected: 3,
            got: 2,
        };
        assert!(e4.to_string().contains("mismatch"));
        assert!(ClusterError::SameCluster.to_string().contains("itself"));
    }

    #[test]
    fn test_search_multi_probe() {
        let idx = insert_and_build(
            3,
            vec![
                ("c1a", vec![0.0f32, 0.0]),
                ("c1b", vec![0.1, 0.0]),
                ("c2a", vec![5.0, 0.0]),
                ("c2b", vec![5.1, 0.0]),
                ("c3a", vec![10.0, 0.0]),
                ("c3b", vec![10.1, 0.0]),
            ],
        );
        let results = idx.search(&[0.0, 0.0], 5, 2).expect("search");
        // Two probes cover the c1* and c2* clusters: four candidates in total.
        assert_eq!(results.len(), 4);
        assert_eq!(results[0].0, "c1a");
    }

    #[test]
    fn test_split_then_search() {
        let mut idx = insert_and_build(
            1,
            vec![
                ("a", vec![-3.0f32, 0.0]),
                ("b", vec![-2.0, 0.0]),
                ("c", vec![2.0, 0.0]),
                ("d", vec![3.0, 0.0]),
            ],
        );
        idx.split_largest_cluster().expect("split");
        let results = idx.search(&[-3.0, 0.0], 2, 2).expect("search");
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "a");
    }

    #[test]
    fn test_merge_then_build_consistent() {
        let mut idx = insert_and_build(
            2,
            vec![
                ("a", vec![0.0f32, 0.0]),
                ("b", vec![0.1, 0.0]),
                ("c", vec![10.0, 0.0]),
                ("d", vec![10.1, 0.0]),
            ],
        );
        idx.merge_clusters(0, 1).expect("merge");
        assert_eq!(idx.num_clusters(), 1);
        let r = idx.search(&[5.0, 0.0], 2, 1).expect("search");
        assert_eq!(r.len(), 2);

        // Rebuilding restores the configured cluster count and keeps every vector.
        idx.build().expect("rebuild");
        assert_eq!(idx.num_clusters(), 2);
        let total: usize = idx.all_cluster_stats().iter().map(|s| s.size).sum();
        assert_eq!(total, 4);
    }

    #[test]
    fn test_num_clusters_zero_before_build() {
        let idx = make_index(3);
        assert_eq!(idx.num_clusters(), 0);
    }

    #[test]
    fn test_build_single_vector() {
        let idx = insert_and_build(1, vec![("solo", vec![1.0f32, 2.0])]);
        assert_eq!(idx.num_clusters(), 1);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_search_returns_correct_id() {
        let idx = insert_and_build(
            1,
            vec![("x_near", vec![0.01f32, 0.0]), ("x_far", vec![100.0, 0.0])],
        );
        let r = idx.search(&[0.0, 0.0], 1, 1).expect("search");
        assert_eq!(r[0].0, "x_near");
    }

    #[test]
    fn test_search_distance_is_non_negative() {
        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 0.0]), ("b", vec![2.0, 0.0])]);
        let r = idx.search(&[0.0, 0.0], 2, 1).expect("search");
        for (_, d) in &r {
            assert!(*d >= 0.0);
        }
    }

    #[test]
    fn test_cluster_stats_centroid_len() {
        let idx = insert_and_build(1, vec![("a", vec![1.0f32, 2.0, 3.0])]);
        let s = idx.cluster_stats(0).expect("stats");
        assert_eq!(s.centroid.len(), 3);
    }

    #[test]
    fn test_merge_closest_with_two_clusters() {
        let mut idx = insert_and_build(
            2,
            vec![
                ("near_a", vec![0.0f32, 0.0]),
                ("near_b", vec![0.1, 0.0]),
                ("far_a", vec![100.0, 0.0]),
                ("far_b", vec![100.1, 0.0]),
            ],
        );
        idx.merge_closest_clusters().expect("merge");
        assert_eq!(idx.num_clusters(), 1);
    }

    #[test]
    fn test_cluster_error_unknown_cluster_display() {
        let e = ClusterError::UnknownCluster(42);
        assert!(e.to_string().contains("42"));
    }
}