Skip to main content

reddb_server/storage/engine/
ivf.rs

//! IVF (Inverted File Index) for Vector Search
//!
//! Clustering-based approximate nearest neighbor search.
//! Partitions vectors into clusters (Voronoi cells) and only searches
//! the most relevant clusters at query time.
//!
//! # Design
//!
//! - k-means clustering to build centroids
//! - Each vector assigned to its nearest centroid
//! - At query time, probe only the `n_probes` nearest clusters
//! - Trade-off: more probes = better recall, slower search
//!
//! # Example
//!
//! ```ignore
//! let mut ivf = IvfIndex::new(IvfConfig {
//!     n_lists: 100,      // Number of clusters
//!     n_probes: 10,      // Clusters to search
//!     dimension: 384,
//!     ..Default::default()
//! });
//!
//! // Train on sample vectors
//! ivf.train(&training_vectors);
//!
//! // Add vectors (takes ownership, returns the generated ids)
//! let ids = ivf.add_batch(vectors);
//!
//! // Search
//! let results = ivf.search(&query, 10);
//! ```
32
33use std::collections::HashMap;
34
35use super::distance::{cmp_f32, l2_squared_simd, DistanceResult};
36use super::hnsw::NodeId;
37
/// Tuning knobs for an IVF index.
///
/// More probes relative to lists means better recall but more vectors scanned
/// per query.
#[derive(Clone, Debug)]
pub struct IvfConfig {
    /// How many Voronoi cells (inverted lists) k-means should produce.
    pub n_lists: usize,
    /// How many of the closest cells a default search scans.
    pub n_probes: usize,
    /// Expected vector length; k-means centroids are built with this many components.
    pub dimension: usize,
    /// Upper bound on k-means refinement rounds during training.
    pub max_iterations: usize,
    /// k-means stops once every centroid moved (L2 distance) less than this.
    pub convergence_threshold: f32,
}
52
53impl Default for IvfConfig {
54    fn default() -> Self {
55        Self {
56            n_lists: 100,
57            n_probes: 10,
58            dimension: 128,
59            max_iterations: 50,
60            convergence_threshold: 1e-4,
61        }
62    }
63}
64
65impl IvfConfig {
66    pub fn new(dimension: usize, n_lists: usize) -> Self {
67        Self {
68            n_lists,
69            n_probes: (n_lists / 10).max(1),
70            dimension,
71            ..Default::default()
72        }
73    }
74
75    pub fn with_probes(mut self, n_probes: usize) -> Self {
76        self.n_probes = n_probes;
77        self
78    }
79}
80
81/// A cluster containing vectors
82#[derive(Clone)]
83struct IvfList {
84    /// Centroid of this cluster
85    centroid: Vec<f32>,
86    /// Vector IDs in this cluster
87    ids: Vec<NodeId>,
88    /// Vectors in this cluster (stored for search)
89    vectors: Vec<Vec<f32>>,
90}
91
92impl IvfList {
93    fn new(centroid: Vec<f32>) -> Self {
94        Self {
95            centroid,
96            ids: Vec::new(),
97            vectors: Vec::new(),
98        }
99    }
100
101    fn add(&mut self, id: NodeId, vector: Vec<f32>) {
102        self.ids.push(id);
103        self.vectors.push(vector);
104    }
105
106    fn len(&self) -> usize {
107        self.ids.len()
108    }
109
110    fn is_empty(&self) -> bool {
111        self.ids.is_empty()
112    }
113}
114
115/// IVF Index for approximate nearest neighbor search
116pub struct IvfIndex {
117    config: IvfConfig,
118    /// Cluster lists
119    lists: Vec<IvfList>,
120    /// Mapping from vector ID to list index
121    id_to_list: HashMap<NodeId, usize>,
122    /// Whether the index has been trained
123    trained: bool,
124    /// Total vector count
125    count: usize,
126    /// Next auto-generated ID
127    next_id: NodeId,
128}
129
130impl IvfIndex {
131    /// Create a new IVF index (untrained)
132    pub fn new(config: IvfConfig) -> Self {
133        Self {
134            config,
135            lists: Vec::new(),
136            id_to_list: HashMap::new(),
137            trained: false,
138            count: 0,
139            next_id: 0,
140        }
141    }
142
143    /// Create with default config for given dimension
144    pub fn with_dimension(dimension: usize) -> Self {
145        Self::new(IvfConfig::new(dimension, 100))
146    }
147
148    /// Train the index using k-means clustering
149    pub fn train(&mut self, vectors: &[Vec<f32>]) {
150        if vectors.is_empty() {
151            return;
152        }
153
154        let n_lists = self.config.n_lists.min(vectors.len());
155
156        // Initialize centroids using k-means++
157        let centroids = self.kmeans_plusplus_init(vectors, n_lists);
158
159        // Run k-means
160        let final_centroids = self.kmeans(vectors, centroids);
161
162        // Create lists
163        self.lists = final_centroids.into_iter().map(IvfList::new).collect();
164
165        self.trained = true;
166    }
167
168    /// K-means++ initialization for better centroid starting points
169    fn kmeans_plusplus_init(&self, vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
170        let mut centroids = Vec::with_capacity(k);
171
172        if vectors.is_empty() || k == 0 {
173            return centroids;
174        }
175
176        // First centroid: random (use middle for determinism)
177        centroids.push(vectors[vectors.len() / 2].clone());
178
179        // Subsequent centroids: weighted by distance to nearest existing centroid
180        for _ in 1..k {
181            let mut distances: Vec<f32> = vectors
182                .iter()
183                .map(|v| {
184                    centroids
185                        .iter()
186                        .map(|c| l2_squared_simd(v, c))
187                        .fold(f32::MAX, f32::min)
188                })
189                .collect();
190
191            // Normalize to probabilities
192            let total: f32 = distances.iter().sum();
193            if total > 0.0 {
194                for d in &mut distances {
195                    *d /= total;
196                }
197            }
198
199            // Select based on cumulative probability (deterministic: use max distance)
200            let max_idx = distances
201                .iter()
202                .enumerate()
203                .max_by(|(la, a), (lb, b)| cmp_f32(**a, **b).then_with(|| la.cmp(lb)))
204                .map(|(i, _)| i)
205                .unwrap_or(0);
206
207            centroids.push(vectors[max_idx].clone());
208        }
209
210        centroids
211    }
212
213    /// Run k-means clustering
214    fn kmeans(&self, vectors: &[Vec<f32>], mut centroids: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
215        let dim = self.config.dimension;
216        let k = centroids.len();
217
218        for _ in 0..self.config.max_iterations {
219            // Assign vectors to nearest centroid
220            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
221            for (i, vector) in vectors.iter().enumerate() {
222                let nearest = self.find_nearest_centroid(vector, &centroids);
223                assignments[nearest].push(i);
224            }
225
226            // Compute new centroids
227            let mut new_centroids = Vec::with_capacity(k);
228            let mut max_shift: f32 = 0.0;
229
230            for (cluster_idx, indices) in assignments.iter().enumerate() {
231                if indices.is_empty() {
232                    // Keep old centroid if cluster is empty
233                    new_centroids.push(centroids[cluster_idx].clone());
234                    continue;
235                }
236
237                // Average of all vectors in cluster
238                let mut new_centroid = vec![0.0f32; dim];
239                for &idx in indices {
240                    for (j, val) in vectors[idx].iter().enumerate() {
241                        if j < dim {
242                            new_centroid[j] += val;
243                        }
244                    }
245                }
246                for val in &mut new_centroid {
247                    *val /= indices.len() as f32;
248                }
249
250                // Track centroid shift
251                let shift = l2_squared_simd(&new_centroid, &centroids[cluster_idx]).sqrt();
252                max_shift = max_shift.max(shift);
253
254                new_centroids.push(new_centroid);
255            }
256
257            centroids = new_centroids;
258
259            // Check convergence
260            if max_shift < self.config.convergence_threshold {
261                break;
262            }
263        }
264
265        centroids
266    }
267
268    /// Find nearest centroid index
269    fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
270        centroids
271            .iter()
272            .enumerate()
273            .map(|(i, c)| (i, l2_squared_simd(vector, c)))
274            .min_by(|(li, la), (ri, rb)| cmp_f32(*la, *rb).then_with(|| li.cmp(ri)))
275            .map(|(i, _)| i)
276            .unwrap_or(0)
277    }
278
279    /// Find k nearest centroids
280    fn find_nearest_centroids(&self, vector: &[f32], k: usize) -> Vec<usize> {
281        let mut distances: Vec<(usize, f32)> = self
282            .lists
283            .iter()
284            .enumerate()
285            .map(|(i, list)| (i, l2_squared_simd(vector, &list.centroid)))
286            .collect();
287
288        distances.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
289        distances.into_iter().take(k).map(|(i, _)| i).collect()
290    }
291
292    /// Add a single vector
293    pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
294        let id = self.next_id;
295        self.next_id += 1;
296        self.add_with_id(id, vector);
297        id
298    }
299
300    /// Add a vector with specific ID
301    pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
302        if !self.trained || self.lists.is_empty() {
303            // Auto-train with a single cluster if not trained
304            if self.lists.is_empty() {
305                self.lists.push(IvfList::new(vector.clone()));
306                self.trained = true;
307            }
308        }
309
310        let list_idx = self.find_nearest_centroid(
311            &vector,
312            &self
313                .lists
314                .iter()
315                .map(|l| l.centroid.clone())
316                .collect::<Vec<_>>(),
317        );
318
319        self.lists[list_idx].add(id, vector);
320        self.id_to_list.insert(id, list_idx);
321        self.count += 1;
322    }
323
324    /// Add multiple vectors
325    pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
326        vectors.into_iter().map(|v| self.add(v)).collect()
327    }
328
329    /// Add multiple vectors with IDs
330    pub fn add_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
331        for (id, vector) in items {
332            self.add_with_id(id, vector);
333        }
334    }
335
336    /// Remove a vector by ID
337    pub fn remove(&mut self, id: NodeId) -> bool {
338        if let Some(list_idx) = self.id_to_list.remove(&id) {
339            let list = &mut self.lists[list_idx];
340            if let Some(pos) = list.ids.iter().position(|&x| x == id) {
341                list.ids.remove(pos);
342                list.vectors.remove(pos);
343                self.count = self.count.saturating_sub(1);
344                return true;
345            }
346        }
347        false
348    }
349
350    /// Search for k nearest neighbors
351    pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
352        self.search_with_probes(query, k, self.config.n_probes)
353    }
354
355    /// Search with custom number of probes
356    pub fn search_with_probes(
357        &self,
358        query: &[f32],
359        k: usize,
360        n_probes: usize,
361    ) -> Vec<DistanceResult> {
362        if self.lists.is_empty() {
363            return Vec::new();
364        }
365
366        let probes = self.find_nearest_centroids(query, n_probes);
367
368        // Collect candidates from probed clusters
369        let mut candidates: Vec<DistanceResult> = Vec::new();
370        for list_idx in probes {
371            let list = &self.lists[list_idx];
372            for (i, vector) in list.vectors.iter().enumerate() {
373                let distance = l2_squared_simd(query, vector).sqrt();
374                candidates.push(DistanceResult::new(list.ids[i], distance));
375            }
376        }
377
378        // Sort and return top k
379        candidates.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
380        candidates.truncate(k);
381        candidates
382    }
383
384    /// Get a vector by ID
385    pub fn get(&self, id: NodeId) -> Option<&[f32]> {
386        if let Some(&list_idx) = self.id_to_list.get(&id) {
387            let list = &self.lists[list_idx];
388            if let Some(pos) = list.ids.iter().position(|&x| x == id) {
389                return Some(&list.vectors[pos]);
390            }
391        }
392        None
393    }
394
395    /// Check if index contains an ID
396    pub fn contains(&self, id: NodeId) -> bool {
397        self.id_to_list.contains_key(&id)
398    }
399
400    /// Get total vector count
401    pub fn len(&self) -> usize {
402        self.count
403    }
404
405    /// Check if empty
406    pub fn is_empty(&self) -> bool {
407        self.count == 0
408    }
409
410    /// Get number of clusters
411    pub fn n_lists(&self) -> usize {
412        self.lists.len()
413    }
414
415    /// Get cluster statistics
416    pub fn stats(&self) -> IvfStats {
417        let sizes: Vec<usize> = self.lists.iter().map(|l| l.len()).collect();
418        let non_empty = sizes.iter().filter(|&&s| s > 0).count();
419
420        let avg = if non_empty > 0 {
421            sizes.iter().sum::<usize>() as f64 / non_empty as f64
422        } else {
423            0.0
424        };
425
426        let max = sizes.iter().copied().max().unwrap_or(0);
427        let min = sizes.iter().filter(|&&s| s > 0).copied().min().unwrap_or(0);
428
429        IvfStats {
430            total_vectors: self.count,
431            n_lists: self.lists.len(),
432            non_empty_lists: non_empty,
433            avg_list_size: avg,
434            max_list_size: max,
435            min_list_size: min,
436            dimension: self.config.dimension,
437            trained: self.trained,
438        }
439    }
440
441    /// Serialize the index to bytes for storage.
442    ///
443    /// The `IVF1` payload byte layout is owned by `reddb-file` (ADR 0046); this
444    /// only projects the engine state into [`reddb_file::IvfIndexLayout`].
445    pub fn to_bytes(&self) -> Vec<u8> {
446        let lists = self
447            .lists
448            .iter()
449            .map(|list| reddb_file::IvfListLayout {
450                centroid: list.centroid.clone(),
451                ids: list.ids.clone(),
452                vectors: list.vectors.clone(),
453            })
454            .collect();
455        let layout = reddb_file::IvfIndexLayout {
456            n_lists: self.config.n_lists,
457            n_probes: self.config.n_probes,
458            dimension: self.config.dimension,
459            max_iterations: self.config.max_iterations,
460            convergence_threshold: self.config.convergence_threshold,
461            trained: self.trained,
462            count: self.count,
463            next_id: self.next_id,
464            lists,
465        };
466        reddb_file::encode_ivf_index(&layout)
467    }
468
469    /// Deserialize an index from bytes via the `reddb-file` codec.
470    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
471        let layout = reddb_file::decode_ivf_index(bytes).map_err(|e| e.to_string())?;
472
473        let config = IvfConfig {
474            n_lists: layout.n_lists,
475            n_probes: layout.n_probes,
476            dimension: layout.dimension,
477            max_iterations: layout.max_iterations,
478            convergence_threshold: layout.convergence_threshold,
479        };
480
481        let mut lists = Vec::with_capacity(layout.lists.len());
482        let mut id_to_list = HashMap::new();
483        for (list_idx, list) in layout.lists.into_iter().enumerate() {
484            for &id in &list.ids {
485                id_to_list.insert(id, list_idx);
486            }
487            lists.push(IvfList {
488                centroid: list.centroid,
489                ids: list.ids,
490                vectors: list.vectors,
491            });
492        }
493
494        Ok(Self {
495            config,
496            lists,
497            id_to_list,
498            trained: layout.trained,
499            count: layout.count,
500            next_id: layout.next_id,
501        })
502    }
503}
504
/// Snapshot of how vectors are distributed across an IVF index's lists.
#[derive(Clone, Debug)]
pub struct IvfStats {
    /// Number of stored vectors.
    pub total_vectors: usize,
    /// Number of lists (centroids), empty ones included.
    pub n_lists: usize,
    /// Lists holding at least one vector.
    pub non_empty_lists: usize,
    /// Mean size over the non-empty lists (0.0 when all lists are empty).
    pub avg_list_size: f64,
    /// Size of the largest list.
    pub max_list_size: usize,
    /// Size of the smallest non-empty list (0 when all lists are empty).
    pub min_list_size: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Whether the index has been trained (explicitly or by its first insert).
    pub trained: bool,
}
517
518// ============================================================================
519// Tests
520// ============================================================================
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [0, 1).
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let base = seed * 1103515245;
        let mut out = Vec::with_capacity(dim);
        for i in 0..dim as u64 {
            out.push(((base + i * 12345) % 1000) as f32 / 1000.0);
        }
        out
    }

    #[test]
    fn test_ivf_basic() {
        let training: Vec<Vec<f32>> = (0..100u64).map(|seed| random_vector(8, seed)).collect();
        let mut ivf = IvfIndex::new(IvfConfig::new(8, 4));

        ivf.train(&training);
        assert!(ivf.trained);
        assert_eq!(ivf.n_lists(), 4);

        for (id, vector) in (0u64..).zip(training.iter().cloned()) {
            ivf.add_with_id(id, vector);
        }
        assert_eq!(ivf.len(), 100);
    }

    #[test]
    fn test_ivf_search() {
        const DIM: usize = 8;
        let config = IvfConfig {
            n_lists: 4,
            n_probes: 2,
            dimension: DIM,
            ..Default::default()
        };
        let mut ivf = IvfIndex::new(config);

        // Four tight clusters of 25 points each, centred at 0, 10, 20 and 30.
        let vectors: Vec<Vec<f32>> = (0..4)
            .flat_map(|cluster| {
                (0..25).map(move |i| {
                    let mut v = vec![cluster as f32 * 10.0; DIM];
                    v[0] += i as f32 * 0.01;
                    v
                })
            })
            .collect();

        ivf.train(&vectors);
        for (id, v) in vectors.iter().enumerate() {
            ivf.add_with_id(id as u64, v.clone());
        }

        // A query next to the first cluster should only surface ids 0..25.
        let query = [0.05_f32; DIM];
        let results = ivf.search(&query, 5);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id < 25));
    }

    #[test]
    fn test_ivf_remove() {
        let mut ivf = IvfIndex::new(IvfConfig::new(4, 2));
        let axes = [
            (1, [1.0, 0.0, 0.0, 0.0]),
            (2, [0.0, 1.0, 0.0, 0.0]),
            (3, [0.0, 0.0, 1.0, 0.0]),
        ];
        for &(id, v) in axes.iter() {
            ivf.add_with_id(id, v.to_vec());
        }

        assert_eq!(ivf.len(), 3);
        assert!(ivf.contains(2));

        assert!(ivf.remove(2));
        assert_eq!(ivf.len(), 2);
        assert!(!ivf.contains(2));
    }

    #[test]
    fn test_ivf_stats() {
        let training: Vec<Vec<f32>> = (0..3)
            .map(|x| vec![x as f32, 0.0, 0.0, 0.0])
            .collect();
        let mut ivf = IvfIndex::new(IvfConfig::new(4, 3));

        ivf.train(&training);
        for (id, v) in training.iter().enumerate() {
            ivf.add_with_id(id as u64, v.clone());
        }

        let stats = ivf.stats();
        assert_eq!(stats.total_vectors, 3);
        assert_eq!(stats.n_lists, 3);
        assert!(stats.trained);
    }
}