1use std::collections::{HashMap, HashSet};
10
11use parking_lot::RwLock;
12use rand::seq::SliceRandom;
13use serde::{Deserialize, Serialize};
14
15use common::{DistanceMetric, Vector, VectorId};
16
17use crate::distance::calculate_distance;
18
/// Configuration for the SPFresh-style IVF index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpFreshConfig {
    /// Number of clusters (centroids) created at training time.
    pub num_clusters: usize,
    /// Live-vector count above which a cluster is split in two.
    pub max_cluster_size: usize,
    /// Live-vector count below which a cluster becomes a merge candidate.
    pub min_cluster_size: usize,
    /// Number of nearest clusters probed per search.
    pub n_probe: usize,
    /// Tombstone ratio (tombstones / stored vectors) at or above which a
    /// cluster is physically compacted.
    pub compaction_threshold: f32,
    /// Metric passed to every `calculate_distance` call.
    pub distance_metric: DistanceMetric,
}
35
36impl Default for SpFreshConfig {
37 fn default() -> Self {
38 Self {
39 num_clusters: 16,
40 max_cluster_size: 1000,
41 min_cluster_size: 50,
42 n_probe: 4,
43 compaction_threshold: 0.3,
44 distance_metric: DistanceMetric::Cosine,
45 }
46 }
47}
48
/// A single IVF cluster: a centroid plus the vectors assigned to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Position of this cluster in the index's cluster list.
    pub id: usize,
    /// Mean of the live vectors (refreshed by `recompute_centroid`).
    pub centroid: Vec<f32>,
    /// All stored vectors, including logically deleted ones.
    pub vectors: Vec<Vector>,
    /// Ids of vectors in `vectors` that have been logically deleted.
    pub tombstones: HashSet<VectorId>,
    /// Cached count of non-tombstoned vectors.
    pub live_count: usize,
}
63
64impl Cluster {
65 fn new(id: usize, centroid: Vec<f32>) -> Self {
66 Self {
67 id,
68 centroid,
69 vectors: Vec::new(),
70 tombstones: HashSet::new(),
71 live_count: 0,
72 }
73 }
74
75 fn live_vectors(&self) -> impl Iterator<Item = &Vector> {
77 self.vectors
78 .iter()
79 .filter(|v| !self.tombstones.contains(&v.id))
80 }
81
82 fn tombstone_ratio(&self) -> f32 {
84 if self.vectors.is_empty() {
85 0.0
86 } else {
87 self.tombstones.len() as f32 / self.vectors.len() as f32
88 }
89 }
90
91 fn recompute_centroid(&mut self) {
93 let live: Vec<&Vector> = self.live_vectors().collect();
94 if live.is_empty() {
95 return;
96 }
97
98 let dim = live[0].values.len();
99 let mut new_centroid = vec![0.0f32; dim];
100
101 for vector in &live {
102 for (i, &val) in vector.values.iter().enumerate() {
103 new_centroid[i] += val;
104 }
105 }
106
107 let count = live.len() as f32;
108 for val in &mut new_centroid {
109 *val /= count;
110 }
111
112 self.centroid = new_centroid;
113 }
114
115 fn compact(&mut self) {
117 self.vectors.retain(|v| !self.tombstones.contains(&v.id));
118 self.tombstones.clear();
119 self.live_count = self.vectors.len();
120 }
121}
122
/// One hit returned from a search.
#[derive(Debug, Clone)]
pub struct SpFreshSearchResult {
    /// Id of the matched vector.
    pub id: VectorId,
    /// Raw `calculate_distance` value between the query and this vector.
    pub score: f32,
    /// The matched vector itself, when available.
    pub vector: Option<Vector>,
}
130
/// An updatable IVF (inverted-file) vector index with tombstone-based deletes.
pub struct SpFreshIndex {
    /// Static configuration (immutable after construction).
    config: SpFreshConfig,
    /// All clusters; read-locked during search, write-locked on mutation.
    clusters: RwLock<Vec<Cluster>>,
    /// Maps each vector id to the index of the cluster holding it.
    vector_cluster_map: RwLock<HashMap<VectorId, usize>>,
    /// Ids removed while untrained; later inserts of these ids are skipped.
    global_tombstones: RwLock<HashSet<VectorId>>,
    /// Vectors buffered by `add` before the index is trained.
    pending_vectors: RwLock<Vec<Vector>>,
    /// Whether `train` has completed.
    trained: RwLock<bool>,
    /// Dimensionality, fixed by the first vectors seen.
    dimension: RwLock<Option<usize>>,
}
146
147impl SpFreshIndex {
148 pub fn new(config: SpFreshConfig) -> Self {
150 Self {
151 config,
152 clusters: RwLock::new(Vec::new()),
153 vector_cluster_map: RwLock::new(HashMap::new()),
154 global_tombstones: RwLock::new(HashSet::new()),
155 pending_vectors: RwLock::new(Vec::new()),
156 trained: RwLock::new(false),
157 dimension: RwLock::new(None),
158 }
159 }
160
    /// Returns true once `train` has completed successfully.
    pub fn is_trained(&self) -> bool {
        *self.trained.read()
    }

    /// Returns the vector dimensionality, if one has been established yet.
    pub fn dimension(&self) -> Option<usize> {
        *self.dimension.read()
    }
170
171 pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
173 if vectors.is_empty() {
174 return Err("Cannot train with empty vectors".to_string());
175 }
176
177 let dim = vectors[0].values.len();
178 *self.dimension.write() = Some(dim);
179
180 let centroids = self.kmeans_plus_plus_init(vectors);
182
183 let final_centroids = self.kmeans_iterate(vectors, centroids, 20);
185
186 let mut clusters = Vec::with_capacity(self.config.num_clusters);
188 for (i, centroid) in final_centroids.into_iter().enumerate() {
189 clusters.push(Cluster::new(i, centroid));
190 }
191
192 let mut vector_cluster_map = HashMap::new();
194 for vector in vectors {
195 let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
196 clusters[cluster_id].vectors.push(vector.clone());
197 clusters[cluster_id].live_count += 1;
198 vector_cluster_map.insert(vector.id.clone(), cluster_id);
199 }
200
201 for cluster in &mut clusters {
203 cluster.recompute_centroid();
204 }
205
206 *self.clusters.write() = clusters;
207 *self.vector_cluster_map.write() = vector_cluster_map;
208 *self.trained.write() = true;
209
210 Ok(())
211 }
212
213 fn kmeans_plus_plus_init(&self, vectors: &[Vector]) -> Vec<Vec<f32>> {
215 let mut rng = rand::thread_rng();
216 let k = self.config.num_clusters.min(vectors.len());
217 let mut centroids = Vec::with_capacity(k);
218
219 let first = vectors.choose(&mut rng).unwrap();
221 centroids.push(first.values.clone());
222
223 for _ in 1..k {
225 let mut distances: Vec<f32> = vectors
226 .iter()
227 .map(|v| {
228 centroids
229 .iter()
230 .map(|c| calculate_distance(&v.values, c, self.config.distance_metric))
231 .fold(f32::MAX, f32::min)
232 })
233 .collect();
234
235 let total: f32 = distances.iter().sum();
237 if total == 0.0 {
238 break;
239 }
240
241 for d in &mut distances {
242 *d /= total;
243 }
244
245 let threshold: f32 = rand::random();
247 let mut cumsum = 0.0;
248 for (i, d) in distances.iter().enumerate() {
249 cumsum += d;
250 if cumsum >= threshold {
251 centroids.push(vectors[i].values.clone());
252 break;
253 }
254 }
255 }
256
257 centroids
258 }
259
260 fn kmeans_iterate(
262 &self,
263 vectors: &[Vector],
264 mut centroids: Vec<Vec<f32>>,
265 max_iters: usize,
266 ) -> Vec<Vec<f32>> {
267 let dim = vectors[0].values.len();
268
269 for _ in 0..max_iters {
270 let mut assignments: Vec<Vec<&Vector>> = vec![Vec::new(); centroids.len()];
272 for vector in vectors {
273 let mut best_idx = 0;
274 let mut best_dist = f32::MAX;
275 for (i, centroid) in centroids.iter().enumerate() {
276 let dist =
277 calculate_distance(&vector.values, centroid, self.config.distance_metric);
278 if dist < best_dist {
279 best_dist = dist;
280 best_idx = i;
281 }
282 }
283 assignments[best_idx].push(vector);
284 }
285
286 let mut new_centroids = Vec::with_capacity(centroids.len());
288 for (i, assigned) in assignments.iter().enumerate() {
289 if assigned.is_empty() {
290 new_centroids.push(centroids[i].clone());
291 } else {
292 let mut new_centroid = vec![0.0f32; dim];
293 for vector in assigned {
294 for (j, &val) in vector.values.iter().enumerate() {
295 new_centroid[j] += val;
296 }
297 }
298 let count = assigned.len() as f32;
299 for val in &mut new_centroid {
300 *val /= count;
301 }
302 new_centroids.push(new_centroid);
303 }
304 }
305
306 centroids = new_centroids;
307 }
308
309 centroids
310 }
311
312 fn find_nearest_cluster_idx(&self, vector: &[f32], clusters: &[Cluster]) -> usize {
314 let mut best_idx = 0;
315 let mut best_dist = f32::MAX;
316
317 for (i, cluster) in clusters.iter().enumerate() {
318 let dist = calculate_distance(vector, &cluster.centroid, self.config.distance_metric);
319 if dist < best_dist {
320 best_dist = dist;
321 best_idx = i;
322 }
323 }
324
325 best_idx
326 }
327
328 pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
330 if vectors.is_empty() {
331 return Ok(0);
332 }
333
334 let dim = vectors[0].values.len();
336 {
337 let current_dim = *self.dimension.read();
338 if let Some(expected) = current_dim {
339 if dim != expected {
340 return Err(format!(
341 "Dimension mismatch: expected {}, got {}",
342 expected, dim
343 ));
344 }
345 } else {
346 *self.dimension.write() = Some(dim);
347 }
348 }
349
350 let count = vectors.len();
351
352 if !self.is_trained() {
354 let mut pending = self.pending_vectors.write();
355 for vector in vectors {
356 if !self.global_tombstones.read().contains(&vector.id) {
357 pending.push(vector);
358 }
359 }
360 return Ok(count);
361 }
362
363 let mut clusters = self.clusters.write();
365 let mut vector_map = self.vector_cluster_map.write();
366 let global_tombstones = self.global_tombstones.read();
367
368 for vector in vectors {
369 if global_tombstones.contains(&vector.id) {
370 continue;
371 }
372
373 let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
374
375 if let Some(&old_cluster_id) = vector_map.get(&vector.id) {
377 if old_cluster_id != cluster_id {
378 clusters[old_cluster_id]
379 .tombstones
380 .insert(vector.id.clone());
381 clusters[old_cluster_id].live_count =
382 clusters[old_cluster_id].live_count.saturating_sub(1);
383 }
384 }
385
386 clusters[cluster_id].vectors.push(vector.clone());
387 clusters[cluster_id].live_count += 1;
388 vector_map.insert(vector.id.clone(), cluster_id);
389 }
390
391 drop(vector_map);
393 self.check_splits(&mut clusters);
394
395 Ok(count)
396 }
397
398 fn check_splits(&self, clusters: &mut Vec<Cluster>) {
400 let mut new_clusters = Vec::new();
401 let max_size = self.config.max_cluster_size;
402 let base_len = clusters.len();
403
404 for cluster in clusters.iter_mut().take(base_len) {
405 if cluster.live_count > max_size {
406 let new_id = base_len + new_clusters.len();
408 if let Some(new_cluster) = self.split_cluster(cluster, new_id) {
409 new_clusters.push(new_cluster);
410 }
411 }
412 }
413
414 clusters.extend(new_clusters);
415 }
416
417 fn split_cluster(&self, cluster: &mut Cluster, new_id: usize) -> Option<Cluster> {
419 let live_vectors: Vec<Vector> = cluster.live_vectors().cloned().collect();
420 if live_vectors.len() < 2 {
421 return None;
422 }
423
424 let mut max_dist = 0.0f32;
426 let mut idx1 = 0;
427 let mut idx2 = 1;
428
429 for (i, v1) in live_vectors.iter().enumerate() {
430 for (j, v2) in live_vectors.iter().enumerate().skip(i + 1) {
431 let dist = calculate_distance(&v1.values, &v2.values, self.config.distance_metric);
432 if dist > max_dist {
433 max_dist = dist;
434 idx1 = i;
435 idx2 = j;
436 }
437 }
438 }
439
440 let centroid1 = live_vectors[idx1].values.clone();
441 let centroid2 = live_vectors[idx2].values.clone();
442
443 let mut vectors1 = Vec::new();
445 let mut vectors2 = Vec::new();
446
447 for vector in live_vectors {
448 let dist1 = calculate_distance(&vector.values, ¢roid1, self.config.distance_metric);
449 let dist2 = calculate_distance(&vector.values, ¢roid2, self.config.distance_metric);
450
451 if dist1 <= dist2 {
452 vectors1.push(vector);
453 } else {
454 vectors2.push(vector);
455 }
456 }
457
458 cluster.vectors = vectors1;
460 cluster.tombstones.clear();
461 cluster.live_count = cluster.vectors.len();
462 cluster.recompute_centroid();
463
464 let mut new_cluster = Cluster::new(new_id, centroid2);
466 new_cluster.vectors = vectors2;
467 new_cluster.live_count = new_cluster.vectors.len();
468 new_cluster.recompute_centroid();
469
470 let mut vector_map = self.vector_cluster_map.write();
472 for v in &cluster.vectors {
473 vector_map.insert(v.id.clone(), cluster.id);
474 }
475 for v in &new_cluster.vectors {
476 vector_map.insert(v.id.clone(), new_cluster.id);
477 }
478
479 Some(new_cluster)
480 }
481
482 pub fn remove(&self, ids: &[VectorId]) -> usize {
484 if !self.is_trained() {
485 let mut pending = self.pending_vectors.write();
487 let mut global_tombstones = self.global_tombstones.write();
488 let before = pending.len();
489 pending.retain(|v| !ids.contains(&v.id));
490 for id in ids {
491 global_tombstones.insert(id.clone());
492 }
493 return before - pending.len();
494 }
495
496 let mut clusters = self.clusters.write();
497 let vector_map = self.vector_cluster_map.read();
498 let mut count = 0;
499
500 for id in ids {
501 if let Some(&cluster_id) = vector_map.get(id) {
502 if cluster_id < clusters.len() {
503 clusters[cluster_id].tombstones.insert(id.clone());
504 clusters[cluster_id].live_count =
505 clusters[cluster_id].live_count.saturating_sub(1);
506 count += 1;
507 }
508 }
509 }
510
511 count
512 }
513
514 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
516 if !self.is_trained() {
517 return self.search_pending(query, k);
519 }
520
521 let clusters = self.clusters.read();
522 if clusters.is_empty() {
523 return Ok(Vec::new());
524 }
525
526 let mut cluster_distances: Vec<(usize, f32)> = clusters
528 .iter()
529 .enumerate()
530 .map(|(i, c)| {
531 (
532 i,
533 calculate_distance(query, &c.centroid, self.config.distance_metric),
534 )
535 })
536 .collect();
537
538 cluster_distances
539 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
540
541 let n_probe = self.config.n_probe.min(clusters.len());
542
543 let mut results: Vec<SpFreshSearchResult> = Vec::new();
545
546 for (cluster_idx, _) in cluster_distances.iter().take(n_probe) {
547 let cluster = &clusters[*cluster_idx];
548 for vector in cluster.live_vectors() {
549 let score = calculate_distance(query, &vector.values, self.config.distance_metric);
550 results.push(SpFreshSearchResult {
551 id: vector.id.clone(),
552 score,
553 vector: Some(vector.clone()),
554 });
555 }
556 }
557
558 results.sort_by(|a, b| {
560 b.score
561 .partial_cmp(&a.score)
562 .unwrap_or(std::cmp::Ordering::Equal)
563 });
564 results.truncate(k);
565
566 Ok(results)
567 }
568
569 fn search_pending(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
571 let pending = self.pending_vectors.read();
572 let tombstones = self.global_tombstones.read();
573
574 let mut results: Vec<SpFreshSearchResult> = pending
575 .iter()
576 .filter(|v| !tombstones.contains(&v.id))
577 .map(|v| SpFreshSearchResult {
578 id: v.id.clone(),
579 score: calculate_distance(query, &v.values, self.config.distance_metric),
580 vector: Some(v.clone()),
581 })
582 .collect();
583
584 results.sort_by(|a, b| {
585 b.score
586 .partial_cmp(&a.score)
587 .unwrap_or(std::cmp::Ordering::Equal)
588 });
589 results.truncate(k);
590
591 Ok(results)
592 }
593
594 pub fn compact(&self) -> usize {
596 if !self.is_trained() {
597 return 0;
598 }
599
600 let mut clusters = self.clusters.write();
601 let mut compacted = 0;
602
603 for cluster in clusters.iter_mut() {
604 if cluster.tombstone_ratio() >= self.config.compaction_threshold {
605 cluster.compact();
606 compacted += 1;
607 }
608 }
609
610 if compacted > 0 {
612 let mut vector_map = self.vector_cluster_map.write();
613 vector_map.clear();
614 for cluster in clusters.iter() {
615 for vector in &cluster.vectors {
616 vector_map.insert(vector.id.clone(), cluster.id);
617 }
618 }
619 }
620
621 compacted
622 }
623
624 pub fn merge_small_clusters(&self) -> usize {
626 if !self.is_trained() {
627 return 0;
628 }
629
630 let mut clusters = self.clusters.write();
631 let min_size = self.config.min_cluster_size;
632
633 let small_clusters: Vec<usize> = clusters
635 .iter()
636 .enumerate()
637 .filter(|(_, c)| c.live_count < min_size && c.live_count > 0)
638 .map(|(i, _)| i)
639 .collect();
640
641 if small_clusters.len() < 2 {
642 return 0;
643 }
644
645 let mut merged = 0;
646
647 for chunk in small_clusters.chunks(2) {
649 if chunk.len() == 2 {
650 let (idx1, idx2) = (chunk[0], chunk[1]);
651
652 let vectors_to_move: Vec<Vector> = clusters[idx2].live_vectors().cloned().collect();
654
655 for vector in vectors_to_move {
656 clusters[idx1].vectors.push(vector);
657 clusters[idx1].live_count += 1;
658 }
659
660 clusters[idx2].vectors.clear();
662 clusters[idx2].tombstones.clear();
663 clusters[idx2].live_count = 0;
664
665 clusters[idx1].recompute_centroid();
667
668 merged += 1;
669 }
670 }
671
672 if merged > 0 {
674 let mut vector_map = self.vector_cluster_map.write();
675 for cluster in clusters.iter() {
676 for vector in &cluster.vectors {
677 if !cluster.tombstones.contains(&vector.id) {
678 vector_map.insert(vector.id.clone(), cluster.id);
679 }
680 }
681 }
682 }
683
684 merged
685 }
686
687 pub fn stats(&self) -> SpFreshStats {
689 let clusters = self.clusters.read();
690 let pending = self.pending_vectors.read();
691
692 let total_vectors: usize = clusters.iter().map(|c| c.live_count).sum();
693 let total_tombstones: usize = clusters.iter().map(|c| c.tombstones.len()).sum();
694
695 SpFreshStats {
696 num_clusters: clusters.len(),
697 total_vectors,
698 total_tombstones,
699 pending_vectors: pending.len(),
700 trained: *self.trained.read(),
701 dimension: *self.dimension.read(),
702 }
703 }
704
    /// Returns the index configuration.
    pub fn config(&self) -> &SpFreshConfig {
        &self.config
    }

    /// Deep-copies the cluster list (snapshot helper; see `from_snapshot`).
    pub(crate) fn clusters_read(&self) -> Vec<Cluster> {
        self.clusters.read().clone()
    }

    /// Deep-copies the id -> cluster map (snapshot helper).
    pub(crate) fn vector_cluster_map_read(&self) -> HashMap<VectorId, usize> {
        self.vector_cluster_map.read().clone()
    }

    /// Deep-copies the pre-training delete blacklist (snapshot helper).
    pub(crate) fn global_tombstones_read(&self) -> HashSet<VectorId> {
        self.global_tombstones.read().clone()
    }

    /// Deep-copies the pre-training vector buffer (snapshot helper).
    pub(crate) fn pending_vectors_read(&self) -> Vec<Vector> {
        self.pending_vectors.read().clone()
    }
729
    /// Rebuilds an index from a previously captured persistence snapshot.
    ///
    /// Currently infallible; the `Result` return leaves room for future
    /// snapshot validation without breaking callers.
    pub fn from_snapshot(
        snapshot: crate::persistence::SpFreshFullSnapshot,
    ) -> Result<Self, String> {
        Ok(Self {
            config: snapshot.config,
            clusters: RwLock::new(snapshot.clusters),
            vector_cluster_map: RwLock::new(snapshot.vector_cluster_map),
            global_tombstones: RwLock::new(snapshot.global_tombstones),
            pending_vectors: RwLock::new(snapshot.pending_vectors),
            trained: RwLock::new(snapshot.trained),
            dimension: RwLock::new(snapshot.dimension),
        })
    }
744}
745
/// Aggregate counters describing the current state of an [`SpFreshIndex`].
#[derive(Debug, Clone)]
pub struct SpFreshStats {
    /// Number of clusters (including drained/empty ones).
    pub num_clusters: usize,
    /// Total live (non-tombstoned) vectors across all clusters.
    pub total_vectors: usize,
    /// Total tombstoned entries awaiting compaction.
    pub total_tombstones: usize,
    /// Vectors buffered while the index is untrained.
    pub pending_vectors: usize,
    /// Whether the index has been trained.
    pub trained: bool,
    /// Established dimensionality, if any.
    pub dimension: Option<usize>,
}
756
757#[cfg(test)]
758mod tests {
759 use super::*;
760
761 fn test_vectors(n: usize, dim: usize) -> Vec<Vector> {
762 (0..n)
764 .map(|i| Vector {
765 id: format!("v{}", i),
766 values: (0..dim)
767 .map(|j| {
768 (i as f32) + (j as f32 * 0.01)
770 })
771 .collect(),
772 metadata: None,
773 ttl_seconds: None,
774 expires_at: None,
775 })
776 .collect()
777 }
778
779 #[test]
780 fn test_train_and_search() {
781 let config = SpFreshConfig {
783 num_clusters: 1,
784 n_probe: 1,
785 distance_metric: DistanceMetric::Euclidean,
786 ..Default::default()
787 };
788 let index = SpFreshIndex::new(config);
789
790 let vectors = test_vectors(50, 4);
791 index.train(&vectors).unwrap();
792
793 assert!(index.is_trained());
794 assert_eq!(index.dimension(), Some(4));
795
796 let results = index.search(&vectors[25].values, 5).unwrap();
798 assert!(!results.is_empty());
799
800 assert_eq!(results[0].id, "v25");
802 assert!(results[0].score < 0.001, "Exact match should have score ~0");
803
804 for i in 1..results.len() {
806 assert!(
807 results[i - 1].score >= results[i].score,
808 "Results should be sorted by score descending"
809 );
810 }
811 }
812
813 #[test]
814 fn test_multi_cluster_search() {
815 let config = SpFreshConfig {
816 num_clusters: 4,
817 n_probe: 4, distance_metric: DistanceMetric::Euclidean,
819 ..Default::default()
820 };
821 let index = SpFreshIndex::new(config);
822
823 let vectors = test_vectors(100, 8);
824 index.train(&vectors).unwrap();
825
826 let results = index.search(&vectors[50].values, 10).unwrap();
828 assert!(!results.is_empty());
829 assert!(results.len() <= 10);
830
831 for i in 1..results.len() {
833 assert!(results[i - 1].score >= results[i].score);
834 }
835
836 let stats = index.stats();
838 assert_eq!(stats.num_clusters, 4);
839 assert_eq!(stats.total_vectors, 100);
840 }
841
842 #[test]
843 fn test_add_after_train() {
844 let config = SpFreshConfig {
845 num_clusters: 4,
846 ..Default::default()
847 };
848 let index = SpFreshIndex::new(config);
849
850 let vectors = test_vectors(50, 8);
851 index.train(&vectors).unwrap();
852
853 let new_vectors = vec![Vector {
854 id: "new1".to_string(),
855 values: vec![0.5; 8],
856 metadata: None,
857 ttl_seconds: None,
858 expires_at: None,
859 }];
860
861 let added = index.add(new_vectors).unwrap();
862 assert_eq!(added, 1);
863
864 let stats = index.stats();
865 assert_eq!(stats.total_vectors, 51);
866 }
867
868 #[test]
869 fn test_remove_tombstone() {
870 let config = SpFreshConfig {
871 num_clusters: 4,
872 ..Default::default()
873 };
874 let index = SpFreshIndex::new(config);
875
876 let vectors = test_vectors(50, 8);
877 index.train(&vectors).unwrap();
878
879 let removed = index.remove(&["v0".to_string(), "v1".to_string()]);
880 assert_eq!(removed, 2);
881
882 let stats = index.stats();
883 assert_eq!(stats.total_vectors, 48);
884 assert_eq!(stats.total_tombstones, 2);
885 }
886
887 #[test]
888 fn test_compaction() {
889 let config = SpFreshConfig {
890 num_clusters: 2,
891 compaction_threshold: 0.1,
892 ..Default::default()
893 };
894 let index = SpFreshIndex::new(config);
895
896 let vectors = test_vectors(20, 4);
897 index.train(&vectors).unwrap();
898
899 let ids: Vec<String> = (0..10).map(|i| format!("v{}", i)).collect();
901 index.remove(&ids);
902
903 let compacted = index.compact();
904 assert!(compacted > 0);
905
906 let stats = index.stats();
907 assert_eq!(stats.total_tombstones, 0);
908 }
909
910 #[test]
911 fn test_pending_before_train() {
912 let config = SpFreshConfig::default();
913 let index = SpFreshIndex::new(config);
914
915 let vectors = test_vectors(10, 4);
916 index.add(vectors.clone()).unwrap();
917
918 assert!(!index.is_trained());
919 let stats = index.stats();
920 assert_eq!(stats.pending_vectors, 10);
921
922 let results = index.search(&vectors[0].values, 3).unwrap();
924 assert!(!results.is_empty());
925 }
926
927 #[test]
928 fn test_dimension_mismatch() {
929 let config = SpFreshConfig {
930 num_clusters: 2,
931 ..Default::default()
932 };
933 let index = SpFreshIndex::new(config);
934
935 let vectors = test_vectors(10, 4);
936 index.train(&vectors).unwrap();
937
938 let bad_vectors = vec![Vector {
939 id: "bad".to_string(),
940 values: vec![1.0, 2.0], metadata: None,
942 ttl_seconds: None,
943 expires_at: None,
944 }];
945
946 let result = index.add(bad_vectors);
947 assert!(result.is_err());
948 }
949
950 #[test]
951 fn test_cluster_split() {
952 let config = SpFreshConfig {
953 num_clusters: 1,
954 max_cluster_size: 10,
955 ..Default::default()
956 };
957 let index = SpFreshIndex::new(config);
958
959 let vectors = test_vectors(15, 4);
960 index.train(&vectors).unwrap();
961
962 let more_vectors = test_vectors(20, 4)
964 .into_iter()
965 .enumerate()
966 .map(|(i, mut v)| {
967 v.id = format!("new{}", i);
968 v
969 })
970 .collect();
971
972 index.add(more_vectors).unwrap();
973
974 let stats = index.stats();
975 assert!(stats.num_clusters > 1);
976 }
977
978 #[test]
979 fn test_stats() {
980 let config = SpFreshConfig {
981 num_clusters: 4,
982 ..Default::default()
983 };
984 let index = SpFreshIndex::new(config);
985
986 let vectors = test_vectors(100, 8);
987 index.train(&vectors).unwrap();
988
989 let stats = index.stats();
990 assert_eq!(stats.total_vectors, 100);
991 assert_eq!(stats.num_clusters, 4);
992 assert!(stats.trained);
993 assert_eq!(stats.dimension, Some(8));
994 }
995}