rag_plusplus_core/index/
flat.rs

1//! Flat Index (Exact Search)
2//!
3//! Brute-force nearest neighbor search. O(n) per query.
4//! Best for: small datasets (<10k vectors), testing, ground truth.
5
6use ahash::AHashMap;
7use ordered_float::OrderedFloat;
8use rayon::prelude::*;
9use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::error::{Error, Result};
13use super::traits::{DistanceType, IndexConfig, SearchResult, VectorIndex};
14
15/// Flat index using brute-force search.
16///
17/// Provides exact nearest neighbor results but O(n) per query.
18#[derive(Debug)]
19pub struct FlatIndex {
20    /// Configuration
21    config: IndexConfig,
22    /// ID -> vector mapping
23    vectors: AHashMap<String, Vec<f32>>,
24    /// Whether to use parallel search
25    parallel: bool,
26    /// Parallel threshold (use parallel if len > this)
27    parallel_threshold: usize,
28}
29
30impl FlatIndex {
31    /// Create a new flat index.
32    #[must_use]
33    pub fn new(config: IndexConfig) -> Self {
34        Self {
35            config,
36            vectors: AHashMap::new(),
37            parallel: true,
38            parallel_threshold: 1000,
39        }
40    }
41
42    /// Create with specified capacity.
43    #[must_use]
44    pub fn with_capacity(config: IndexConfig, capacity: usize) -> Self {
45        Self {
46            config,
47            vectors: AHashMap::with_capacity(capacity),
48            parallel: true,
49            parallel_threshold: 1000,
50        }
51    }
52
53    /// Enable/disable parallel search.
54    #[must_use]
55    pub const fn with_parallel(mut self, parallel: bool) -> Self {
56        self.parallel = parallel;
57        self
58    }
59
60    /// Set parallel threshold.
61    #[must_use]
62    pub const fn with_parallel_threshold(mut self, threshold: usize) -> Self {
63        self.parallel_threshold = threshold;
64        self
65    }
66
67    /// Compute distance between two vectors.
68    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
69        match self.config.distance_type {
70            DistanceType::L2 => Self::l2_distance(a, b),
71            DistanceType::InnerProduct => Self::inner_product(a, b),
72            DistanceType::Cosine => Self::cosine_similarity(a, b),
73        }
74    }
75
76    /// L2 (Euclidean) squared distance.
77    #[inline]
78    fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
79        a.iter()
80            .zip(b.iter())
81            .map(|(x, y)| (x - y).powi(2))
82            .sum::<f32>()
83            .sqrt()
84    }
85
86    /// Inner product (dot product).
87    #[inline]
88    fn inner_product(a: &[f32], b: &[f32]) -> f32 {
89        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
90    }
91
92    /// Cosine similarity.
93    #[inline]
94    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
95        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
96        let norm_a: f32 = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
97        let norm_b: f32 = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
98        
99        if norm_a == 0.0 || norm_b == 0.0 {
100            0.0
101        } else {
102            dot / (norm_a * norm_b)
103        }
104    }
105
106    /// Normalize a vector in-place.
107    fn normalize(vector: &mut [f32]) {
108        let norm: f32 = vector.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
109        if norm > 0.0 {
110            for x in vector.iter_mut() {
111                *x /= norm;
112            }
113        }
114    }
115
116    /// Sequential search implementation.
117    fn search_sequential(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
118        // Use a min-heap to track top-k (by distance for L2, by -score for IP/cosine)
119        let mut heap: BinaryHeap<(Reverse<OrderedFloat<f32>>, String)> = BinaryHeap::new();
120        
121        let is_similarity = matches!(
122            self.config.distance_type,
123            DistanceType::InnerProduct | DistanceType::Cosine
124        );
125
126        for (id, vector) in &self.vectors {
127            let dist = self.compute_distance(query, vector);
128            let key = if is_similarity {
129                // For similarity metrics, we want largest values
130                Reverse(OrderedFloat(-dist))
131            } else {
132                // For distance metrics, we want smallest values
133                Reverse(OrderedFloat(dist))
134            };
135
136            if heap.len() < k {
137                heap.push((key, id.clone()));
138            } else if let Some((top_key, _)) = heap.peek() {
139                if key > *top_key {
140                    heap.pop();
141                    heap.push((key, id.clone()));
142                }
143            }
144        }
145
146        // Convert heap to sorted results
147        let mut results: Vec<_> = heap
148            .into_iter()
149            .map(|(Reverse(OrderedFloat(dist)), id)| {
150                let actual_dist = if is_similarity { -dist } else { dist };
151                SearchResult::new(id, actual_dist, self.config.distance_type)
152            })
153            .collect();
154
155        // Sort by distance (ascending for L2, descending for similarity)
156        if is_similarity {
157            results.sort_by(|a, b| b.distance.partial_cmp(&a.distance).unwrap());
158        } else {
159            results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
160        }
161
162        results
163    }
164
165    /// Parallel search implementation.
166    fn search_parallel(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
167        let is_similarity = matches!(
168            self.config.distance_type,
169            DistanceType::InnerProduct | DistanceType::Cosine
170        );
171
172        // Compute distances in parallel
173        let mut distances: Vec<_> = self.vectors
174            .par_iter()
175            .map(|(id, vector)| {
176                let dist = self.compute_distance(query, vector);
177                (id.clone(), dist)
178            })
179            .collect();
180
181        // Sort by distance
182        if is_similarity {
183            distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
184        } else {
185            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
186        }
187
188        // Take top-k
189        distances
190            .into_iter()
191            .take(k)
192            .map(|(id, dist)| SearchResult::new(id, dist, self.config.distance_type))
193            .collect()
194    }
195}
196
197impl VectorIndex for FlatIndex {
198    fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
199        if vector.len() != self.config.dimension {
200            return Err(Error::InvalidQuery {
201                reason: format!(
202                    "Dimension mismatch: expected {}, got {}",
203                    self.config.dimension,
204                    vector.len()
205                ),
206            });
207        }
208
209        let mut vec = vector.to_vec();
210        if self.config.normalize {
211            Self::normalize(&mut vec);
212        }
213
214        self.vectors.insert(id, vec);
215        Ok(())
216    }
217
218    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
219        if query.len() != self.config.dimension {
220            return Err(Error::InvalidQuery {
221                reason: format!(
222                    "Query dimension mismatch: expected {}, got {}",
223                    self.config.dimension,
224                    query.len()
225                ),
226            });
227        }
228
229        if self.vectors.is_empty() {
230            return Ok(vec![]);
231        }
232
233        let k = k.min(self.vectors.len());
234
235        // Normalize query if needed
236        let query = if self.config.normalize {
237            let mut q = query.to_vec();
238            Self::normalize(&mut q);
239            q
240        } else {
241            query.to_vec()
242        };
243
244        // Choose sequential or parallel based on index size
245        let results = if self.parallel && self.vectors.len() > self.parallel_threshold {
246            self.search_parallel(&query, k)
247        } else {
248            self.search_sequential(&query, k)
249        };
250
251        Ok(results)
252    }
253
254    fn remove(&mut self, id: &str) -> Result<bool> {
255        Ok(self.vectors.remove(id).is_some())
256    }
257
258    fn contains(&self, id: &str) -> bool {
259        self.vectors.contains_key(id)
260    }
261
262    fn len(&self) -> usize {
263        self.vectors.len()
264    }
265
266    fn dimension(&self) -> usize {
267        self.config.dimension
268    }
269
270    fn distance_type(&self) -> DistanceType {
271        self.config.distance_type
272    }
273
274    fn clear(&mut self) {
275        self.vectors.clear();
276    }
277
278    fn memory_usage(&self) -> usize {
279        // Rough estimate: 
280        // - Each vector: dim * 4 bytes
281        // - Each ID: ~32 bytes average
282        // - HashMap overhead: ~48 bytes per entry
283        let per_entry = self.config.dimension * 4 + 32 + 48;
284        self.vectors.len() * per_entry
285    }
286}
287
288#[cfg(test)]
289mod tests {
290    use super::*;
291
292    fn create_test_index() -> FlatIndex {
293        let config = IndexConfig::new(4);
294        let mut index = FlatIndex::new(config).with_parallel(false);
295        
296        // Add test vectors
297        index.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
298        index.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
299        index.add("c".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
300        index.add("d".to_string(), &[0.5, 0.5, 0.0, 0.0]).unwrap();
301        
302        index
303    }
304
305    #[test]
306    fn test_add_and_len() {
307        let index = create_test_index();
308        assert_eq!(index.len(), 4);
309        assert!(index.contains("a"));
310        assert!(!index.contains("z"));
311    }
312
313    #[test]
314    fn test_search_l2() {
315        let index = create_test_index();
316        
317        // Query close to "a"
318        let results = index.search(&[0.9, 0.1, 0.0, 0.0], 2).unwrap();
319        
320        assert_eq!(results.len(), 2);
321        assert_eq!(results[0].id, "a"); // Closest
322    }
323
324    #[test]
325    fn test_search_cosine() {
326        let config = IndexConfig::new(4).with_distance(DistanceType::Cosine);
327        let mut index = FlatIndex::new(config).with_parallel(false);
328        
329        index.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
330        index.add("b".to_string(), &[1.0, 1.0, 0.0, 0.0]).unwrap();
331        index.add("c".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
332        
333        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
334        
335        assert_eq!(results[0].id, "a"); // Exact match
336        assert!((results[0].distance - 1.0).abs() < 1e-6); // Cosine = 1.0
337    }
338
339    #[test]
340    fn test_remove() {
341        let mut index = create_test_index();
342        
343        assert!(index.remove("a").unwrap());
344        assert!(!index.contains("a"));
345        assert_eq!(index.len(), 3);
346        
347        assert!(!index.remove("z").unwrap()); // Not found
348    }
349
350    #[test]
351    fn test_dimension_mismatch() {
352        let mut index = create_test_index();
353        
354        let result = index.add("e".to_string(), &[1.0, 2.0]); // Wrong dim
355        assert!(result.is_err());
356    }
357
358    #[test]
359    fn test_empty_search() {
360        let config = IndexConfig::new(4);
361        let index = FlatIndex::new(config);
362        
363        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 10).unwrap();
364        assert!(results.is_empty());
365    }
366
367    #[test]
368    fn test_clear() {
369        let mut index = create_test_index();
370        assert_eq!(index.len(), 4);
371        
372        index.clear();
373        assert!(index.is_empty());
374    }
375}