// engine/spfresh.rs
1//! SPFresh Index with LIRE (Lazy Index Reorganization and Expansion)
2//!
3//! Optimized for serverless vector databases on cold storage (S3/MinIO).
4//! Key features:
5//! - Lazy cluster splitting and merging
6//! - Tombstone-based deletion with background compaction
7//! - Optimized for batch operations and cold storage patterns
8
9use std::collections::{HashMap, HashSet};
10
11use parking_lot::RwLock;
12use rand::seq::SliceRandom;
13use serde::{Deserialize, Serialize};
14
15use common::{DistanceMetric, Vector, VectorId};
16
17use crate::distance::calculate_distance;
18
/// Configuration for SPFresh index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpFreshConfig {
    /// Target number of clusters (also the k used for k-means training;
    /// capped at the number of training vectors)
    pub num_clusters: usize,
    /// Maximum live vectors per cluster before a lazy split is triggered
    pub max_cluster_size: usize,
    /// Minimum vectors per cluster before merge consideration
    pub min_cluster_size: usize,
    /// Number of clusters to probe during search (higher = better recall,
    /// slower queries)
    pub n_probe: usize,
    /// Tombstone ratio threshold for compaction (0.0 - 1.0)
    pub compaction_threshold: f32,
    /// Distance metric to use
    pub distance_metric: DistanceMetric,
}
35
impl Default for SpFreshConfig {
    /// Conservative defaults suitable for small-to-medium collections.
    fn default() -> Self {
        Self {
            num_clusters: 16,
            max_cluster_size: 1000,
            min_cluster_size: 50,
            n_probe: 4,
            // Compact once 30% of a cluster's stored entries are tombstoned.
            compaction_threshold: 0.3,
            distance_metric: DistanceMetric::Cosine,
        }
    }
}
48
/// A cluster in the SPFresh index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Cluster ID (index into the index's cluster list)
    pub id: usize,
    /// Centroid vector
    pub centroid: Vec<f32>,
    /// Vectors in this cluster; may still contain tombstoned entries until
    /// `compact` physically removes them
    pub vectors: Vec<Vector>,
    /// Tombstones (deleted vector IDs)
    pub tombstones: HashSet<VectorId>,
    /// Number of live vectors (vectors.len() - tombstones that are in vectors)
    pub live_count: usize,
}
63
64impl Cluster {
65    fn new(id: usize, centroid: Vec<f32>) -> Self {
66        Self {
67            id,
68            centroid,
69            vectors: Vec::new(),
70            tombstones: HashSet::new(),
71            live_count: 0,
72        }
73    }
74
75    /// Get live vectors (excluding tombstones)
76    fn live_vectors(&self) -> impl Iterator<Item = &Vector> {
77        self.vectors
78            .iter()
79            .filter(|v| !self.tombstones.contains(&v.id))
80    }
81
82    /// Tombstone ratio
83    fn tombstone_ratio(&self) -> f32 {
84        if self.vectors.is_empty() {
85            0.0
86        } else {
87            self.tombstones.len() as f32 / self.vectors.len() as f32
88        }
89    }
90
91    /// Recompute centroid from live vectors
92    fn recompute_centroid(&mut self) {
93        let live: Vec<&Vector> = self.live_vectors().collect();
94        if live.is_empty() {
95            return;
96        }
97
98        let dim = live[0].values.len();
99        let mut new_centroid = vec![0.0f32; dim];
100
101        for vector in &live {
102            for (i, &val) in vector.values.iter().enumerate() {
103                new_centroid[i] += val;
104            }
105        }
106
107        let count = live.len() as f32;
108        for val in &mut new_centroid {
109            *val /= count;
110        }
111
112        self.centroid = new_centroid;
113    }
114
115    /// Compact the cluster by removing tombstoned vectors
116    fn compact(&mut self) {
117        self.vectors.retain(|v| !self.tombstones.contains(&v.id));
118        self.tombstones.clear();
119        self.live_count = self.vectors.len();
120    }
121}
122
/// Search result from SPFresh index
#[derive(Debug, Clone)]
pub struct SpFreshSearchResult {
    /// Id of the matched vector
    pub id: VectorId,
    /// Raw value of `calculate_distance` between query and vector;
    /// results are currently ordered by this descending (see `search`)
    pub score: f32,
    /// Full vector payload; always `Some` in the current search paths
    pub vector: Option<Vector>,
}
130
/// SPFresh Index implementation
///
/// All mutable state lives behind `parking_lot::RwLock`s, so a shared
/// reference is enough to add/remove/search.
pub struct SpFreshIndex {
    config: SpFreshConfig,
    /// The trained clusters (empty until `train` succeeds)
    clusters: RwLock<Vec<Cluster>>,
    /// Vector ID to cluster ID mapping for fast lookup
    vector_cluster_map: RwLock<HashMap<VectorId, usize>>,
    /// Global tombstones (for vectors not yet assigned to clusters)
    global_tombstones: RwLock<HashSet<VectorId>>,
    /// Pending vectors not yet assigned to clusters (buffered pre-training)
    pending_vectors: RwLock<Vec<Vector>>,
    /// Whether the index has been trained
    trained: RwLock<bool>,
    /// Vector dimension, fixed by the first `train` or `add` call
    dimension: RwLock<Option<usize>>,
}
146
147impl SpFreshIndex {
    /// Create a new SPFresh index
    ///
    /// The index starts untrained: vectors added before `train` are buffered
    /// in `pending_vectors` and searched by brute force.
    pub fn new(config: SpFreshConfig) -> Self {
        Self {
            config,
            clusters: RwLock::new(Vec::new()),
            vector_cluster_map: RwLock::new(HashMap::new()),
            global_tombstones: RwLock::new(HashSet::new()),
            pending_vectors: RwLock::new(Vec::new()),
            trained: RwLock::new(false),
            dimension: RwLock::new(None),
        }
    }

    /// Check if index is trained
    pub fn is_trained(&self) -> bool {
        *self.trained.read()
    }

    /// Get vector dimension (`None` until the first train/add establishes it)
    pub fn dimension(&self) -> Option<usize> {
        *self.dimension.read()
    }
170
171    /// Train the index with initial vectors using k-means
172    pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
173        if vectors.is_empty() {
174            return Err("Cannot train with empty vectors".to_string());
175        }
176
177        let dim = vectors[0].values.len();
178        *self.dimension.write() = Some(dim);
179
180        // Initialize centroids using k-means++
181        let centroids = self.kmeans_plus_plus_init(vectors);
182
183        // Run k-means iterations
184        let final_centroids = self.kmeans_iterate(vectors, centroids, 20);
185
186        // Create clusters
187        let mut clusters = Vec::with_capacity(self.config.num_clusters);
188        for (i, centroid) in final_centroids.into_iter().enumerate() {
189            clusters.push(Cluster::new(i, centroid));
190        }
191
192        // Assign vectors to clusters
193        let mut vector_cluster_map = HashMap::new();
194        for vector in vectors {
195            let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
196            clusters[cluster_id].vectors.push(vector.clone());
197            clusters[cluster_id].live_count += 1;
198            vector_cluster_map.insert(vector.id.clone(), cluster_id);
199        }
200
201        // Update centroids based on assigned vectors
202        for cluster in &mut clusters {
203            cluster.recompute_centroid();
204        }
205
206        *self.clusters.write() = clusters;
207        *self.vector_cluster_map.write() = vector_cluster_map;
208        *self.trained.write() = true;
209
210        Ok(())
211    }
212
213    /// K-means++ initialization
214    fn kmeans_plus_plus_init(&self, vectors: &[Vector]) -> Vec<Vec<f32>> {
215        let mut rng = rand::thread_rng();
216        let k = self.config.num_clusters.min(vectors.len());
217        let mut centroids = Vec::with_capacity(k);
218
219        // First centroid: random
220        let first = vectors.choose(&mut rng).unwrap();
221        centroids.push(first.values.clone());
222
223        // Remaining centroids: weighted by distance squared
224        for _ in 1..k {
225            let mut distances: Vec<f32> = vectors
226                .iter()
227                .map(|v| {
228                    centroids
229                        .iter()
230                        .map(|c| calculate_distance(&v.values, c, self.config.distance_metric))
231                        .fold(f32::MAX, f32::min)
232                })
233                .collect();
234
235            // Convert to cumulative distribution
236            let total: f32 = distances.iter().sum();
237            if total == 0.0 {
238                break;
239            }
240
241            for d in &mut distances {
242                *d /= total;
243            }
244
245            // Sample from distribution
246            let threshold: f32 = rand::random();
247            let mut cumsum = 0.0;
248            for (i, d) in distances.iter().enumerate() {
249                cumsum += d;
250                if cumsum >= threshold {
251                    centroids.push(vectors[i].values.clone());
252                    break;
253                }
254            }
255        }
256
257        centroids
258    }
259
260    /// Run k-means iterations
261    fn kmeans_iterate(
262        &self,
263        vectors: &[Vector],
264        mut centroids: Vec<Vec<f32>>,
265        max_iters: usize,
266    ) -> Vec<Vec<f32>> {
267        let dim = vectors[0].values.len();
268
269        for _ in 0..max_iters {
270            // Assign vectors to nearest centroid
271            let mut assignments: Vec<Vec<&Vector>> = vec![Vec::new(); centroids.len()];
272            for vector in vectors {
273                let mut best_idx = 0;
274                let mut best_dist = f32::MAX;
275                for (i, centroid) in centroids.iter().enumerate() {
276                    let dist =
277                        calculate_distance(&vector.values, centroid, self.config.distance_metric);
278                    if dist < best_dist {
279                        best_dist = dist;
280                        best_idx = i;
281                    }
282                }
283                assignments[best_idx].push(vector);
284            }
285
286            // Recompute centroids
287            let mut new_centroids = Vec::with_capacity(centroids.len());
288            for (i, assigned) in assignments.iter().enumerate() {
289                if assigned.is_empty() {
290                    new_centroids.push(centroids[i].clone());
291                } else {
292                    let mut new_centroid = vec![0.0f32; dim];
293                    for vector in assigned {
294                        for (j, &val) in vector.values.iter().enumerate() {
295                            new_centroid[j] += val;
296                        }
297                    }
298                    let count = assigned.len() as f32;
299                    for val in &mut new_centroid {
300                        *val /= count;
301                    }
302                    new_centroids.push(new_centroid);
303                }
304            }
305
306            centroids = new_centroids;
307        }
308
309        centroids
310    }
311
312    /// Find nearest cluster index
313    fn find_nearest_cluster_idx(&self, vector: &[f32], clusters: &[Cluster]) -> usize {
314        let mut best_idx = 0;
315        let mut best_dist = f32::MAX;
316
317        for (i, cluster) in clusters.iter().enumerate() {
318            let dist = calculate_distance(vector, &cluster.centroid, self.config.distance_metric);
319            if dist < best_dist {
320                best_dist = dist;
321                best_idx = i;
322            }
323        }
324
325        best_idx
326    }
327
328    /// Add vectors to the index (LIRE: lazy insertion)
329    pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
330        if vectors.is_empty() {
331            return Ok(0);
332        }
333
334        // Validate dimension
335        let dim = vectors[0].values.len();
336        {
337            let current_dim = *self.dimension.read();
338            if let Some(expected) = current_dim {
339                if dim != expected {
340                    return Err(format!(
341                        "Dimension mismatch: expected {}, got {}",
342                        expected, dim
343                    ));
344                }
345            } else {
346                *self.dimension.write() = Some(dim);
347            }
348        }
349
350        let count = vectors.len();
351
352        // If not trained yet, add to pending
353        if !self.is_trained() {
354            let mut pending = self.pending_vectors.write();
355            for vector in vectors {
356                if !self.global_tombstones.read().contains(&vector.id) {
357                    pending.push(vector);
358                }
359            }
360            return Ok(count);
361        }
362
363        // Add to appropriate clusters
364        let mut clusters = self.clusters.write();
365        let mut vector_map = self.vector_cluster_map.write();
366        let global_tombstones = self.global_tombstones.read();
367
368        for vector in vectors {
369            if global_tombstones.contains(&vector.id) {
370                continue;
371            }
372
373            let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
374
375            // Remove from old cluster if exists
376            if let Some(&old_cluster_id) = vector_map.get(&vector.id) {
377                if old_cluster_id != cluster_id {
378                    clusters[old_cluster_id]
379                        .tombstones
380                        .insert(vector.id.clone());
381                    clusters[old_cluster_id].live_count =
382                        clusters[old_cluster_id].live_count.saturating_sub(1);
383                }
384            }
385
386            clusters[cluster_id].vectors.push(vector.clone());
387            clusters[cluster_id].live_count += 1;
388            vector_map.insert(vector.id.clone(), cluster_id);
389        }
390
391        // Check for splits needed (LIRE: lazy)
392        drop(vector_map);
393        self.check_splits(&mut clusters);
394
395        Ok(count)
396    }
397
398    /// Check and perform cluster splits if needed
399    fn check_splits(&self, clusters: &mut Vec<Cluster>) {
400        let mut new_clusters = Vec::new();
401        let max_size = self.config.max_cluster_size;
402        let base_len = clusters.len();
403
404        for cluster in clusters.iter_mut().take(base_len) {
405            if cluster.live_count > max_size {
406                // Split cluster
407                let new_id = base_len + new_clusters.len();
408                if let Some(new_cluster) = self.split_cluster(cluster, new_id) {
409                    new_clusters.push(new_cluster);
410                }
411            }
412        }
413
414        clusters.extend(new_clusters);
415    }
416
417    /// Split a cluster into two
418    fn split_cluster(&self, cluster: &mut Cluster, new_id: usize) -> Option<Cluster> {
419        let live_vectors: Vec<Vector> = cluster.live_vectors().cloned().collect();
420        if live_vectors.len() < 2 {
421            return None;
422        }
423
424        // Simple split: use two furthest vectors as new centroids
425        let mut max_dist = 0.0f32;
426        let mut idx1 = 0;
427        let mut idx2 = 1;
428
429        for (i, v1) in live_vectors.iter().enumerate() {
430            for (j, v2) in live_vectors.iter().enumerate().skip(i + 1) {
431                let dist = calculate_distance(&v1.values, &v2.values, self.config.distance_metric);
432                if dist > max_dist {
433                    max_dist = dist;
434                    idx1 = i;
435                    idx2 = j;
436                }
437            }
438        }
439
440        let centroid1 = live_vectors[idx1].values.clone();
441        let centroid2 = live_vectors[idx2].values.clone();
442
443        // Assign vectors to new clusters
444        let mut vectors1 = Vec::new();
445        let mut vectors2 = Vec::new();
446
447        for vector in live_vectors {
448            let dist1 = calculate_distance(&vector.values, &centroid1, self.config.distance_metric);
449            let dist2 = calculate_distance(&vector.values, &centroid2, self.config.distance_metric);
450
451            if dist1 <= dist2 {
452                vectors1.push(vector);
453            } else {
454                vectors2.push(vector);
455            }
456        }
457
458        // Update original cluster
459        cluster.vectors = vectors1;
460        cluster.tombstones.clear();
461        cluster.live_count = cluster.vectors.len();
462        cluster.recompute_centroid();
463
464        // Create new cluster
465        let mut new_cluster = Cluster::new(new_id, centroid2);
466        new_cluster.vectors = vectors2;
467        new_cluster.live_count = new_cluster.vectors.len();
468        new_cluster.recompute_centroid();
469
470        // Update vector-cluster map
471        let mut vector_map = self.vector_cluster_map.write();
472        for v in &cluster.vectors {
473            vector_map.insert(v.id.clone(), cluster.id);
474        }
475        for v in &new_cluster.vectors {
476            vector_map.insert(v.id.clone(), new_cluster.id);
477        }
478
479        Some(new_cluster)
480    }
481
482    /// Remove vectors by ID (LIRE: tombstone-based)
483    pub fn remove(&self, ids: &[VectorId]) -> usize {
484        if !self.is_trained() {
485            // Remove from pending
486            let mut pending = self.pending_vectors.write();
487            let mut global_tombstones = self.global_tombstones.write();
488            let before = pending.len();
489            pending.retain(|v| !ids.contains(&v.id));
490            for id in ids {
491                global_tombstones.insert(id.clone());
492            }
493            return before - pending.len();
494        }
495
496        let mut clusters = self.clusters.write();
497        let vector_map = self.vector_cluster_map.read();
498        let mut count = 0;
499
500        for id in ids {
501            if let Some(&cluster_id) = vector_map.get(id) {
502                if cluster_id < clusters.len() {
503                    clusters[cluster_id].tombstones.insert(id.clone());
504                    clusters[cluster_id].live_count =
505                        clusters[cluster_id].live_count.saturating_sub(1);
506                    count += 1;
507                }
508            }
509        }
510
511        count
512    }
513
514    /// Search for nearest neighbors
515    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
516        if !self.is_trained() {
517            // Search pending vectors
518            return self.search_pending(query, k);
519        }
520
521        let clusters = self.clusters.read();
522        if clusters.is_empty() {
523            return Ok(Vec::new());
524        }
525
526        // Find n_probe nearest clusters
527        let mut cluster_distances: Vec<(usize, f32)> = clusters
528            .iter()
529            .enumerate()
530            .map(|(i, c)| {
531                (
532                    i,
533                    calculate_distance(query, &c.centroid, self.config.distance_metric),
534                )
535            })
536            .collect();
537
538        cluster_distances
539            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
540
541        let n_probe = self.config.n_probe.min(clusters.len());
542
543        // Search in top clusters
544        let mut results: Vec<SpFreshSearchResult> = Vec::new();
545
546        for (cluster_idx, _) in cluster_distances.iter().take(n_probe) {
547            let cluster = &clusters[*cluster_idx];
548            for vector in cluster.live_vectors() {
549                let score = calculate_distance(query, &vector.values, self.config.distance_metric);
550                results.push(SpFreshSearchResult {
551                    id: vector.id.clone(),
552                    score,
553                    vector: Some(vector.clone()),
554                });
555            }
556        }
557
558        // Sort by score descending (higher = more similar)
559        results.sort_by(|a, b| {
560            b.score
561                .partial_cmp(&a.score)
562                .unwrap_or(std::cmp::Ordering::Equal)
563        });
564        results.truncate(k);
565
566        Ok(results)
567    }
568
569    /// Search pending vectors (before training)
570    fn search_pending(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
571        let pending = self.pending_vectors.read();
572        let tombstones = self.global_tombstones.read();
573
574        let mut results: Vec<SpFreshSearchResult> = pending
575            .iter()
576            .filter(|v| !tombstones.contains(&v.id))
577            .map(|v| SpFreshSearchResult {
578                id: v.id.clone(),
579                score: calculate_distance(query, &v.values, self.config.distance_metric),
580                vector: Some(v.clone()),
581            })
582            .collect();
583
584        results.sort_by(|a, b| {
585            b.score
586                .partial_cmp(&a.score)
587                .unwrap_or(std::cmp::Ordering::Equal)
588        });
589        results.truncate(k);
590
591        Ok(results)
592    }
593
594    /// Trigger compaction on clusters exceeding tombstone threshold
595    pub fn compact(&self) -> usize {
596        if !self.is_trained() {
597            return 0;
598        }
599
600        let mut clusters = self.clusters.write();
601        let mut compacted = 0;
602
603        for cluster in clusters.iter_mut() {
604            if cluster.tombstone_ratio() >= self.config.compaction_threshold {
605                cluster.compact();
606                compacted += 1;
607            }
608        }
609
610        // Rebuild vector-cluster map after compaction
611        if compacted > 0 {
612            let mut vector_map = self.vector_cluster_map.write();
613            vector_map.clear();
614            for cluster in clusters.iter() {
615                for vector in &cluster.vectors {
616                    vector_map.insert(vector.id.clone(), cluster.id);
617                }
618            }
619        }
620
621        compacted
622    }
623
624    /// Merge small clusters
625    pub fn merge_small_clusters(&self) -> usize {
626        if !self.is_trained() {
627            return 0;
628        }
629
630        let mut clusters = self.clusters.write();
631        let min_size = self.config.min_cluster_size;
632
633        // Find small clusters
634        let small_clusters: Vec<usize> = clusters
635            .iter()
636            .enumerate()
637            .filter(|(_, c)| c.live_count < min_size && c.live_count > 0)
638            .map(|(i, _)| i)
639            .collect();
640
641        if small_clusters.len() < 2 {
642            return 0;
643        }
644
645        let mut merged = 0;
646
647        // Simple merge: combine pairs of small clusters
648        for chunk in small_clusters.chunks(2) {
649            if chunk.len() == 2 {
650                let (idx1, idx2) = (chunk[0], chunk[1]);
651
652                // Move vectors from idx2 to idx1
653                let vectors_to_move: Vec<Vector> = clusters[idx2].live_vectors().cloned().collect();
654
655                for vector in vectors_to_move {
656                    clusters[idx1].vectors.push(vector);
657                    clusters[idx1].live_count += 1;
658                }
659
660                // Clear idx2
661                clusters[idx2].vectors.clear();
662                clusters[idx2].tombstones.clear();
663                clusters[idx2].live_count = 0;
664
665                // Recompute centroid for merged cluster
666                clusters[idx1].recompute_centroid();
667
668                merged += 1;
669            }
670        }
671
672        // Update vector-cluster map
673        if merged > 0 {
674            let mut vector_map = self.vector_cluster_map.write();
675            for cluster in clusters.iter() {
676                for vector in &cluster.vectors {
677                    if !cluster.tombstones.contains(&vector.id) {
678                        vector_map.insert(vector.id.clone(), cluster.id);
679                    }
680                }
681            }
682        }
683
684        merged
685    }
686
687    /// Get index statistics
688    pub fn stats(&self) -> SpFreshStats {
689        let clusters = self.clusters.read();
690        let pending = self.pending_vectors.read();
691
692        let total_vectors: usize = clusters.iter().map(|c| c.live_count).sum();
693        let total_tombstones: usize = clusters.iter().map(|c| c.tombstones.len()).sum();
694
695        SpFreshStats {
696            num_clusters: clusters.len(),
697            total_vectors,
698            total_tombstones,
699            pending_vectors: pending.len(),
700            trained: *self.trained.read(),
701            dimension: *self.dimension.read(),
702        }
703    }
704
    /// Get configuration
    pub fn config(&self) -> &SpFreshConfig {
        &self.config
    }

    /// Get read access to clusters for persistence
    ///
    /// Returns a deep clone so the caller can serialize without holding the
    /// lock (same pattern for the three accessors below).
    pub(crate) fn clusters_read(&self) -> Vec<Cluster> {
        self.clusters.read().clone()
    }

    /// Get read access to vector-cluster map for persistence
    pub(crate) fn vector_cluster_map_read(&self) -> HashMap<VectorId, usize> {
        self.vector_cluster_map.read().clone()
    }

    /// Get read access to global tombstones for persistence
    pub(crate) fn global_tombstones_read(&self) -> HashSet<VectorId> {
        self.global_tombstones.read().clone()
    }

    /// Get read access to pending vectors for persistence
    pub(crate) fn pending_vectors_read(&self) -> Vec<Vector> {
        self.pending_vectors.read().clone()
    }
729
    /// Restore SPFresh index from a full snapshot
    ///
    /// Rehydrates every field verbatim. No validation is performed here, so
    /// the snapshot is trusted to be internally consistent (map entries,
    /// live counts, tombstones).
    pub fn from_snapshot(
        snapshot: crate::persistence::SpFreshFullSnapshot,
    ) -> Result<Self, String> {
        Ok(Self {
            config: snapshot.config,
            clusters: RwLock::new(snapshot.clusters),
            vector_cluster_map: RwLock::new(snapshot.vector_cluster_map),
            global_tombstones: RwLock::new(snapshot.global_tombstones),
            pending_vectors: RwLock::new(snapshot.pending_vectors),
            trained: RwLock::new(snapshot.trained),
            dimension: RwLock::new(snapshot.dimension),
        })
    }
744}
745
746/// Statistics for SPFresh index
747#[derive(Debug, Clone)]
748pub struct SpFreshStats {
749    pub num_clusters: usize,
750    pub total_vectors: usize,
751    pub total_tombstones: usize,
752    pub pending_vectors: usize,
753    pub trained: bool,
754    pub dimension: Option<usize>,
755}
756
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` vectors of dimension `dim` with ids "v0".."v{n-1}".
    fn test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        // Generate unique vectors - each has a distinct "peak" dimension
        (0..n)
            .map(|i| Vector {
                id: format!("v{}", i),
                values: (0..dim)
                    .map(|j| {
                        // Base value plus unique offset for this vector
                        (i as f32) + (j as f32 * 0.01)
                    })
                    .collect(),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            })
            .collect()
    }

    #[test]
    fn test_train_and_search() {
        // Use single cluster to guarantee all vectors are searchable
        let config = SpFreshConfig {
            num_clusters: 1,
            n_probe: 1,
            distance_metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(50, 4);
        index.train(&vectors).unwrap();

        assert!(index.is_trained());
        assert_eq!(index.dimension(), Some(4));

        // Search for exact vector - should find it
        let results = index.search(&vectors[25].values, 5).unwrap();
        assert!(!results.is_empty());

        // With single cluster, exact match must be first
        // NOTE(review): this pair of assertions ("exact match first, score
        // ~0") contradicts the descending-order assertion below if
        // `calculate_distance` returns a true distance — an exact match
        // (score 0) would sort LAST under descending order. One of the two
        // expectations (or `search`'s sort direction) is wrong; verify
        // against `calculate_distance`'s convention.
        assert_eq!(results[0].id, "v25");
        assert!(results[0].score < 0.001, "Exact match should have score ~0");

        // Verify results are sorted by score descending (higher = more similar)
        for i in 1..results.len() {
            assert!(
                results[i - 1].score >= results[i].score,
                "Results should be sorted by score descending"
            );
        }
    }

    #[test]
    fn test_multi_cluster_search() {
        let config = SpFreshConfig {
            num_clusters: 4,
            n_probe: 4, // Search all clusters
            distance_metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(100, 8);
        index.train(&vectors).unwrap();

        // Search should return results
        let results = index.search(&vectors[50].values, 10).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 10);

        // Results should be sorted by score descending (higher = more similar)
        // NOTE(review): same ordering-convention caveat as in
        // test_train_and_search.
        for i in 1..results.len() {
            assert!(results[i - 1].score >= results[i].score);
        }

        // Stats should show 4 clusters with 100 total vectors
        let stats = index.stats();
        assert_eq!(stats.num_clusters, 4);
        assert_eq!(stats.total_vectors, 100);
    }

    #[test]
    fn test_add_after_train() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(50, 8);
        index.train(&vectors).unwrap();

        let new_vectors = vec![Vector {
            id: "new1".to_string(),
            values: vec![0.5; 8],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }];

        let added = index.add(new_vectors).unwrap();
        assert_eq!(added, 1);

        // Live count should grow by exactly the added batch size.
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 51);
    }

    #[test]
    fn test_remove_tombstone() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(50, 8);
        index.train(&vectors).unwrap();

        // Removal is tombstone-based: live count drops, storage stays.
        let removed = index.remove(&["v0".to_string(), "v1".to_string()]);
        assert_eq!(removed, 2);

        let stats = index.stats();
        assert_eq!(stats.total_vectors, 48);
        assert_eq!(stats.total_tombstones, 2);
    }

    #[test]
    fn test_compaction() {
        let config = SpFreshConfig {
            num_clusters: 2,
            compaction_threshold: 0.1,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(20, 4);
        index.train(&vectors).unwrap();

        // Remove many vectors to push clusters past the 10% threshold.
        let ids: Vec<String> = (0..10).map(|i| format!("v{}", i)).collect();
        index.remove(&ids);

        let compacted = index.compact();
        assert!(compacted > 0);

        // Compaction physically drops tombstoned entries.
        let stats = index.stats();
        assert_eq!(stats.total_tombstones, 0);
    }

    #[test]
    fn test_pending_before_train() {
        let config = SpFreshConfig::default();
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(10, 4);
        index.add(vectors.clone()).unwrap();

        // Untrained index buffers everything in the pending list.
        assert!(!index.is_trained());
        let stats = index.stats();
        assert_eq!(stats.pending_vectors, 10);

        // Search pending
        let results = index.search(&vectors[0].values, 3).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_dimension_mismatch() {
        let config = SpFreshConfig {
            num_clusters: 2,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(10, 4);
        index.train(&vectors).unwrap();

        let bad_vectors = vec![Vector {
            id: "bad".to_string(),
            values: vec![1.0, 2.0], // Wrong dimension
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }];

        let result = index.add(bad_vectors);
        assert!(result.is_err());
    }

    #[test]
    fn test_cluster_split() {
        let config = SpFreshConfig {
            num_clusters: 1,
            max_cluster_size: 10,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(15, 4);
        index.train(&vectors).unwrap();

        // Add more to trigger split
        let more_vectors = test_vectors(20, 4)
            .into_iter()
            .enumerate()
            .map(|(i, mut v)| {
                v.id = format!("new{}", i);
                v
            })
            .collect();

        index.add(more_vectors).unwrap();

        // Exceeding max_cluster_size should have lazily split the cluster.
        let stats = index.stats();
        assert!(stats.num_clusters > 1);
    }

    #[test]
    fn test_stats() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);

        let vectors = test_vectors(100, 8);
        index.train(&vectors).unwrap();

        let stats = index.stats();
        assert_eq!(stats.total_vectors, 100);
        assert_eq!(stats.num_clusters, 4);
        assert!(stats.trained);
        assert_eq!(stats.dimension, Some(8));
    }
}