Skip to main content

reddb_server/storage/engine/
ivf.rs

1//! IVF (Inverted File Index) for Vector Search
2//!
3//! Clustering-based approximate nearest neighbor search.
4//! Partitions vectors into clusters (Voronoi cells) and only searches
5//! the most relevant clusters at query time.
6//!
7//! # Design
8//!
9//! - k-means clustering to build centroids
10//! - Each vector assigned to its nearest centroid
11//! - At query time, probe only `nprobe` nearest clusters
12//! - Trade-off: more probes = better recall, slower search
13//!
//! # Example
//!
//! ```ignore
//! let mut ivf = IvfIndex::new(IvfConfig {
//!     n_lists: 100,      // Number of clusters
//!     n_probes: 10,      // Clusters to search
//!     dimension: 384,
//!     ..Default::default()
//! });
//!
//! // Train on sample vectors
//! ivf.train(&training_vectors);
//!
//! // Add vectors (takes ownership and returns the assigned IDs)
//! let ids = ivf.add_batch(vectors);
//!
//! // Search
//! let results = ivf.search(&query, 10);
//! ```
32
33use std::collections::HashMap;
34
35use super::distance::{cmp_f32, l2_squared_simd, DistanceResult};
36use super::hnsw::NodeId;
37
/// Tuning parameters for an IVF index.
#[derive(Clone, Debug)]
pub struct IvfConfig {
    /// Number of clusters (Voronoi cells) requested when training.
    pub n_lists: usize,
    /// Number of clusters probed per query by `IvfIndex::search`.
    pub n_probes: usize,
    /// Vector dimension.
    pub dimension: usize,
    /// Upper bound on k-means iterations during training.
    pub max_iterations: usize,
    /// k-means stops early once the largest centroid shift falls below this value.
    pub convergence_threshold: f32,
}

impl Default for IvfConfig {
    /// 100 lists, 10 probes, 128 dimensions, at most 50 k-means iterations,
    /// convergence threshold `1e-4`.
    fn default() -> Self {
        Self {
            n_lists: 100,
            n_probes: 10,
            dimension: 128,
            max_iterations: 50,
            convergence_threshold: 1e-4,
        }
    }
}

impl IvfConfig {
    /// Builds a config for `dimension`-sized vectors split over `n_lists` clusters.
    ///
    /// The probe count defaults to a tenth of `n_lists` (never less than one);
    /// the remaining fields come from [`Default`].
    pub fn new(dimension: usize, n_lists: usize) -> Self {
        Self {
            n_lists,
            n_probes: (n_lists / 10).max(1),
            dimension,
            ..Default::default()
        }
    }

    /// Overrides the number of clusters probed per query.
    ///
    /// A value of `0` is raised to `1`, matching the lower bound applied by
    /// [`IvfConfig::new`]: a config with zero probes would make every default
    /// search return no results.
    pub fn with_probes(mut self, n_probes: usize) -> Self {
        self.n_probes = n_probes.max(1);
        self
    }
}
80
81/// A cluster containing vectors
82#[derive(Clone)]
83struct IvfList {
84    /// Centroid of this cluster
85    centroid: Vec<f32>,
86    /// Vector IDs in this cluster
87    ids: Vec<NodeId>,
88    /// Vectors in this cluster (stored for search)
89    vectors: Vec<Vec<f32>>,
90}
91
92impl IvfList {
93    fn new(centroid: Vec<f32>) -> Self {
94        Self {
95            centroid,
96            ids: Vec::new(),
97            vectors: Vec::new(),
98        }
99    }
100
101    fn add(&mut self, id: NodeId, vector: Vec<f32>) {
102        self.ids.push(id);
103        self.vectors.push(vector);
104    }
105
106    fn len(&self) -> usize {
107        self.ids.len()
108    }
109
110    fn is_empty(&self) -> bool {
111        self.ids.is_empty()
112    }
113}
114
115/// IVF Index for approximate nearest neighbor search
116pub struct IvfIndex {
117    config: IvfConfig,
118    /// Cluster lists
119    lists: Vec<IvfList>,
120    /// Mapping from vector ID to list index
121    id_to_list: HashMap<NodeId, usize>,
122    /// Whether the index has been trained
123    trained: bool,
124    /// Total vector count
125    count: usize,
126    /// Next auto-generated ID
127    next_id: NodeId,
128}
129
130impl IvfIndex {
131    /// Create a new IVF index (untrained)
132    pub fn new(config: IvfConfig) -> Self {
133        Self {
134            config,
135            lists: Vec::new(),
136            id_to_list: HashMap::new(),
137            trained: false,
138            count: 0,
139            next_id: 0,
140        }
141    }
142
143    /// Create with default config for given dimension
144    pub fn with_dimension(dimension: usize) -> Self {
145        Self::new(IvfConfig::new(dimension, 100))
146    }
147
148    /// Train the index using k-means clustering
149    pub fn train(&mut self, vectors: &[Vec<f32>]) {
150        if vectors.is_empty() {
151            return;
152        }
153
154        let n_lists = self.config.n_lists.min(vectors.len());
155
156        // Initialize centroids using k-means++
157        let centroids = self.kmeans_plusplus_init(vectors, n_lists);
158
159        // Run k-means
160        let final_centroids = self.kmeans(vectors, centroids);
161
162        // Create lists
163        self.lists = final_centroids.into_iter().map(IvfList::new).collect();
164
165        self.trained = true;
166    }
167
168    /// K-means++ initialization for better centroid starting points
169    fn kmeans_plusplus_init(&self, vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
170        let mut centroids = Vec::with_capacity(k);
171
172        if vectors.is_empty() || k == 0 {
173            return centroids;
174        }
175
176        // First centroid: random (use middle for determinism)
177        centroids.push(vectors[vectors.len() / 2].clone());
178
179        // Subsequent centroids: weighted by distance to nearest existing centroid
180        for _ in 1..k {
181            let mut distances: Vec<f32> = vectors
182                .iter()
183                .map(|v| {
184                    centroids
185                        .iter()
186                        .map(|c| l2_squared_simd(v, c))
187                        .fold(f32::MAX, f32::min)
188                })
189                .collect();
190
191            // Normalize to probabilities
192            let total: f32 = distances.iter().sum();
193            if total > 0.0 {
194                for d in &mut distances {
195                    *d /= total;
196                }
197            }
198
199            // Select based on cumulative probability (deterministic: use max distance)
200            let max_idx = distances
201                .iter()
202                .enumerate()
203                .max_by(|(la, a), (lb, b)| cmp_f32(**a, **b).then_with(|| la.cmp(lb)))
204                .map(|(i, _)| i)
205                .unwrap_or(0);
206
207            centroids.push(vectors[max_idx].clone());
208        }
209
210        centroids
211    }
212
213    /// Run k-means clustering
214    fn kmeans(&self, vectors: &[Vec<f32>], mut centroids: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
215        let dim = self.config.dimension;
216        let k = centroids.len();
217
218        for _ in 0..self.config.max_iterations {
219            // Assign vectors to nearest centroid
220            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
221            for (i, vector) in vectors.iter().enumerate() {
222                let nearest = self.find_nearest_centroid(vector, &centroids);
223                assignments[nearest].push(i);
224            }
225
226            // Compute new centroids
227            let mut new_centroids = Vec::with_capacity(k);
228            let mut max_shift: f32 = 0.0;
229
230            for (cluster_idx, indices) in assignments.iter().enumerate() {
231                if indices.is_empty() {
232                    // Keep old centroid if cluster is empty
233                    new_centroids.push(centroids[cluster_idx].clone());
234                    continue;
235                }
236
237                // Average of all vectors in cluster
238                let mut new_centroid = vec![0.0f32; dim];
239                for &idx in indices {
240                    for (j, val) in vectors[idx].iter().enumerate() {
241                        if j < dim {
242                            new_centroid[j] += val;
243                        }
244                    }
245                }
246                for val in &mut new_centroid {
247                    *val /= indices.len() as f32;
248                }
249
250                // Track centroid shift
251                let shift = l2_squared_simd(&new_centroid, &centroids[cluster_idx]).sqrt();
252                max_shift = max_shift.max(shift);
253
254                new_centroids.push(new_centroid);
255            }
256
257            centroids = new_centroids;
258
259            // Check convergence
260            if max_shift < self.config.convergence_threshold {
261                break;
262            }
263        }
264
265        centroids
266    }
267
268    /// Find nearest centroid index
269    fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
270        centroids
271            .iter()
272            .enumerate()
273            .map(|(i, c)| (i, l2_squared_simd(vector, c)))
274            .min_by(|(li, la), (ri, rb)| cmp_f32(*la, *rb).then_with(|| li.cmp(ri)))
275            .map(|(i, _)| i)
276            .unwrap_or(0)
277    }
278
279    /// Find k nearest centroids
280    fn find_nearest_centroids(&self, vector: &[f32], k: usize) -> Vec<usize> {
281        let mut distances: Vec<(usize, f32)> = self
282            .lists
283            .iter()
284            .enumerate()
285            .map(|(i, list)| (i, l2_squared_simd(vector, &list.centroid)))
286            .collect();
287
288        distances.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
289        distances.into_iter().take(k).map(|(i, _)| i).collect()
290    }
291
292    /// Add a single vector
293    pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
294        let id = self.next_id;
295        self.next_id += 1;
296        self.add_with_id(id, vector);
297        id
298    }
299
300    /// Add a vector with specific ID
301    pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
302        if !self.trained || self.lists.is_empty() {
303            // Auto-train with a single cluster if not trained
304            if self.lists.is_empty() {
305                self.lists.push(IvfList::new(vector.clone()));
306                self.trained = true;
307            }
308        }
309
310        let list_idx = self.find_nearest_centroid(
311            &vector,
312            &self
313                .lists
314                .iter()
315                .map(|l| l.centroid.clone())
316                .collect::<Vec<_>>(),
317        );
318
319        self.lists[list_idx].add(id, vector);
320        self.id_to_list.insert(id, list_idx);
321        self.count += 1;
322    }
323
324    /// Add multiple vectors
325    pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
326        vectors.into_iter().map(|v| self.add(v)).collect()
327    }
328
329    /// Add multiple vectors with IDs
330    pub fn add_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
331        for (id, vector) in items {
332            self.add_with_id(id, vector);
333        }
334    }
335
336    /// Remove a vector by ID
337    pub fn remove(&mut self, id: NodeId) -> bool {
338        if let Some(list_idx) = self.id_to_list.remove(&id) {
339            let list = &mut self.lists[list_idx];
340            if let Some(pos) = list.ids.iter().position(|&x| x == id) {
341                list.ids.remove(pos);
342                list.vectors.remove(pos);
343                self.count = self.count.saturating_sub(1);
344                return true;
345            }
346        }
347        false
348    }
349
350    /// Search for k nearest neighbors
351    pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
352        self.search_with_probes(query, k, self.config.n_probes)
353    }
354
355    /// Search with custom number of probes
356    pub fn search_with_probes(
357        &self,
358        query: &[f32],
359        k: usize,
360        n_probes: usize,
361    ) -> Vec<DistanceResult> {
362        if self.lists.is_empty() {
363            return Vec::new();
364        }
365
366        let probes = self.find_nearest_centroids(query, n_probes);
367
368        // Collect candidates from probed clusters
369        let mut candidates: Vec<DistanceResult> = Vec::new();
370        for list_idx in probes {
371            let list = &self.lists[list_idx];
372            for (i, vector) in list.vectors.iter().enumerate() {
373                let distance = l2_squared_simd(query, vector).sqrt();
374                candidates.push(DistanceResult::new(list.ids[i], distance));
375            }
376        }
377
378        // Sort and return top k
379        candidates.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
380        candidates.truncate(k);
381        candidates
382    }
383
384    /// Get a vector by ID
385    pub fn get(&self, id: NodeId) -> Option<&[f32]> {
386        if let Some(&list_idx) = self.id_to_list.get(&id) {
387            let list = &self.lists[list_idx];
388            if let Some(pos) = list.ids.iter().position(|&x| x == id) {
389                return Some(&list.vectors[pos]);
390            }
391        }
392        None
393    }
394
395    /// Check if index contains an ID
396    pub fn contains(&self, id: NodeId) -> bool {
397        self.id_to_list.contains_key(&id)
398    }
399
400    /// Get total vector count
401    pub fn len(&self) -> usize {
402        self.count
403    }
404
405    /// Check if empty
406    pub fn is_empty(&self) -> bool {
407        self.count == 0
408    }
409
410    /// Get number of clusters
411    pub fn n_lists(&self) -> usize {
412        self.lists.len()
413    }
414
415    /// Get cluster statistics
416    pub fn stats(&self) -> IvfStats {
417        let sizes: Vec<usize> = self.lists.iter().map(|l| l.len()).collect();
418        let non_empty = sizes.iter().filter(|&&s| s > 0).count();
419
420        let avg = if non_empty > 0 {
421            sizes.iter().sum::<usize>() as f64 / non_empty as f64
422        } else {
423            0.0
424        };
425
426        let max = sizes.iter().copied().max().unwrap_or(0);
427        let min = sizes.iter().filter(|&&s| s > 0).copied().min().unwrap_or(0);
428
429        IvfStats {
430            total_vectors: self.count,
431            n_lists: self.lists.len(),
432            non_empty_lists: non_empty,
433            avg_list_size: avg,
434            max_list_size: max,
435            min_list_size: min,
436            dimension: self.config.dimension,
437            trained: self.trained,
438        }
439    }
440
441    /// Serialize the index to bytes for storage
442    pub fn to_bytes(&self) -> Vec<u8> {
443        let mut bytes = Vec::new();
444        bytes.extend_from_slice(b"IVF1");
445        bytes.extend_from_slice(&(self.config.n_lists as u32).to_le_bytes());
446        bytes.extend_from_slice(&(self.config.n_probes as u32).to_le_bytes());
447        bytes.extend_from_slice(&(self.config.dimension as u32).to_le_bytes());
448        bytes.extend_from_slice(&(self.config.max_iterations as u32).to_le_bytes());
449        bytes.extend_from_slice(&self.config.convergence_threshold.to_le_bytes());
450        bytes.push(if self.trained { 1 } else { 0 });
451        bytes.extend_from_slice(&(self.count as u64).to_le_bytes());
452        bytes.extend_from_slice(&self.next_id.to_le_bytes());
453        bytes.extend_from_slice(&(self.lists.len() as u32).to_le_bytes());
454
455        for list in &self.lists {
456            bytes.extend_from_slice(&(list.centroid.len() as u32).to_le_bytes());
457            for value in &list.centroid {
458                bytes.extend_from_slice(&value.to_le_bytes());
459            }
460
461            bytes.extend_from_slice(&(list.ids.len() as u32).to_le_bytes());
462            for id in &list.ids {
463                bytes.extend_from_slice(&id.to_le_bytes());
464            }
465
466            bytes.extend_from_slice(&(list.vectors.len() as u32).to_le_bytes());
467            for vector in &list.vectors {
468                bytes.extend_from_slice(&(vector.len() as u32).to_le_bytes());
469                for value in vector {
470                    bytes.extend_from_slice(&value.to_le_bytes());
471                }
472            }
473        }
474
475        bytes
476    }
477
478    /// Deserialize an index from bytes
479    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
480        if bytes.len() < 41 {
481            return Err("data too short".to_string());
482        }
483        if &bytes[0..4] != b"IVF1" {
484            return Err("invalid IVF magic".to_string());
485        }
486
487        let mut pos = 4usize;
488        let read_u32 = |buf: &[u8], pos: &mut usize| -> Result<u32, String> {
489            if *pos + 4 > buf.len() {
490                return Err("truncated IVF payload".to_string());
491            }
492            let value =
493                u32::from_le_bytes([buf[*pos], buf[*pos + 1], buf[*pos + 2], buf[*pos + 3]]);
494            *pos += 4;
495            Ok(value)
496        };
497        let read_u64 = |buf: &[u8], pos: &mut usize| -> Result<u64, String> {
498            if *pos + 8 > buf.len() {
499                return Err("truncated IVF payload".to_string());
500            }
501            let value = u64::from_le_bytes([
502                buf[*pos],
503                buf[*pos + 1],
504                buf[*pos + 2],
505                buf[*pos + 3],
506                buf[*pos + 4],
507                buf[*pos + 5],
508                buf[*pos + 6],
509                buf[*pos + 7],
510            ]);
511            *pos += 8;
512            Ok(value)
513        };
514        let read_f32 = |buf: &[u8], pos: &mut usize| -> Result<f32, String> {
515            if *pos + 4 > buf.len() {
516                return Err("truncated IVF payload".to_string());
517            }
518            let value =
519                f32::from_le_bytes([buf[*pos], buf[*pos + 1], buf[*pos + 2], buf[*pos + 3]]);
520            *pos += 4;
521            Ok(value)
522        };
523
524        let config = IvfConfig {
525            n_lists: read_u32(bytes, &mut pos)? as usize,
526            n_probes: read_u32(bytes, &mut pos)? as usize,
527            dimension: read_u32(bytes, &mut pos)? as usize,
528            max_iterations: read_u32(bytes, &mut pos)? as usize,
529            convergence_threshold: read_f32(bytes, &mut pos)?,
530        };
531        if pos >= bytes.len() {
532            return Err("truncated IVF payload".to_string());
533        }
534        let trained = bytes[pos] == 1;
535        pos += 1;
536        let count = read_u64(bytes, &mut pos)? as usize;
537        let next_id = read_u64(bytes, &mut pos)?;
538        let list_count = read_u32(bytes, &mut pos)? as usize;
539
540        let mut lists = Vec::with_capacity(list_count);
541        let mut id_to_list = HashMap::new();
542        for list_idx in 0..list_count {
543            let centroid_len = read_u32(bytes, &mut pos)? as usize;
544            let mut centroid = Vec::with_capacity(centroid_len);
545            for _ in 0..centroid_len {
546                centroid.push(read_f32(bytes, &mut pos)?);
547            }
548
549            let id_count = read_u32(bytes, &mut pos)? as usize;
550            let mut ids = Vec::with_capacity(id_count);
551            for _ in 0..id_count {
552                let id = read_u64(bytes, &mut pos)?;
553                id_to_list.insert(id, list_idx);
554                ids.push(id);
555            }
556
557            let vector_count = read_u32(bytes, &mut pos)? as usize;
558            let mut vectors = Vec::with_capacity(vector_count);
559            for _ in 0..vector_count {
560                let vector_len = read_u32(bytes, &mut pos)? as usize;
561                let mut vector = Vec::with_capacity(vector_len);
562                for _ in 0..vector_len {
563                    vector.push(read_f32(bytes, &mut pos)?);
564                }
565                vectors.push(vector);
566            }
567
568            lists.push(IvfList {
569                centroid,
570                ids,
571                vectors,
572            });
573        }
574
575        Ok(Self {
576            config,
577            lists,
578            id_to_list,
579            trained,
580            count,
581            next_id,
582        })
583    }
584}
585
/// Snapshot of an IVF index's size and cluster balance, as produced by
/// `IvfIndex::stats`.
#[derive(Clone, Debug)]
pub struct IvfStats {
    /// Number of vectors tracked by the index.
    pub total_vectors: usize,
    /// Number of clusters that exist in the index.
    pub n_lists: usize,
    /// Number of clusters holding at least one vector.
    pub non_empty_lists: usize,
    /// Mean size of the non-empty clusters (`0.0` if there are none).
    pub avg_list_size: f64,
    /// Size of the largest cluster.
    pub max_list_size: usize,
    /// Size of the smallest non-empty cluster (`0` if there are none).
    pub min_list_size: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Whether the index reports itself as trained.
    pub trained: bool,
}
598
599// ============================================================================
600// Tests
601// ============================================================================
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[0, 1)`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut out = Vec::with_capacity(dim);
        for i in 0..dim {
            let raw = (seed * 1103515245 + i as u64 * 12345) % 1000;
            out.push(raw as f32 / 1000.0);
        }
        out
    }

    /// Training yields the requested number of lists and inserts are counted.
    #[test]
    fn test_ivf_basic() {
        let mut index = IvfIndex::new(IvfConfig::new(8, 4));

        let samples: Vec<Vec<f32>> = (0..100).map(|seed| random_vector(8, seed)).collect();

        index.train(&samples);
        assert!(index.trained);
        assert_eq!(index.n_lists(), 4);

        for (id, sample) in samples.iter().enumerate() {
            index.add_with_id(id as u64, sample.clone());
        }

        assert_eq!(index.len(), 100);
    }

    /// A query next to the first synthetic cluster only returns its members.
    #[test]
    fn test_ivf_search() {
        let dim = 8;
        let config = IvfConfig {
            n_lists: 4,
            n_probes: 2,
            dimension: dim,
            ..Default::default()
        };
        let mut index = IvfIndex::new(config);

        // Four well-separated clusters of 25 vectors each.
        let vectors: Vec<Vec<f32>> = (0..4)
            .flat_map(|cluster| {
                let base = cluster as f32 * 10.0;
                (0..25).map(move |i| {
                    let mut v = vec![base; dim];
                    v[0] += i as f32 * 0.01;
                    v
                })
            })
            .collect();

        index.train(&vectors);

        for (id, v) in vectors.iter().enumerate() {
            index.add_with_id(id as u64, v.clone());
        }

        let query = vec![0.05; dim];
        let hits = index.search(&query, 5);

        assert!(!hits.is_empty());
        // Cluster 0 owns IDs 0-24.
        for hit in &hits {
            assert!(hit.id < 25);
        }
    }

    /// Removing an ID shrinks the index and makes the ID unknown.
    #[test]
    fn test_ivf_remove() {
        let mut index = IvfIndex::new(IvfConfig::new(4, 2));

        let entries = vec![
            (1, vec![1.0, 0.0, 0.0, 0.0]),
            (2, vec![0.0, 1.0, 0.0, 0.0]),
            (3, vec![0.0, 0.0, 1.0, 0.0]),
        ];
        for (id, v) in entries {
            index.add_with_id(id, v);
        }

        assert_eq!(index.len(), 3);
        assert!(index.contains(2));

        assert!(index.remove(2));
        assert_eq!(index.len(), 2);
        assert!(!index.contains(2));
    }

    /// Stats reflect the number of vectors, lists and the trained flag.
    #[test]
    fn test_ivf_stats() {
        let mut index = IvfIndex::new(IvfConfig::new(4, 3));

        let samples: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0, 0.0],
            vec![2.0, 0.0, 0.0, 0.0],
        ];

        index.train(&samples);

        for (id, v) in samples.iter().enumerate() {
            index.add_with_id(id as u64, v.clone());
        }

        let summary = index.stats();
        assert_eq!(summary.total_vectors, 3);
        assert_eq!(summary.n_lists, 3);
        assert!(summary.trained);
    }
}