Skip to main content

nodedb_vector/
flat.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Flat (brute-force) vector index for small collections.
4//!
5//! Simple linear scan over all stored vectors. No graph overhead, exact
6//! results. Automatically used when a collection has fewer than
7//! `DEFAULT_FLAT_INDEX_THRESHOLD` vectors (default 10K). Also serves as the
8//! search method for growing segments before HNSW construction.
9//!
10//! Complexity: O(N × D) per query where N = vectors, D = dimensions.
11
12use roaring::RoaringBitmap;
13
14use crate::distance::{DistanceMetric, distance};
15use crate::hnsw::SearchResult;
16
/// Default vector-count threshold below which a collection uses the flat
/// index instead of HNSW.
///
/// Nothing in this module reads the constant; it is exported for the code
/// that picks an index type for a collection.
pub const DEFAULT_FLAT_INDEX_THRESHOLD: usize = 10_000;
19
/// Flat vector index: append-only buffer with brute-force search.
///
/// Vector IDs are dense `u32` values assigned in insertion order, starting
/// at 0. Deletion is a soft delete: the slot is flagged as a tombstone and
/// skipped by searches, but its data stays in the buffer so later IDs never
/// shift.
pub struct FlatIndex {
    /// Number of `f32` components in every stored vector.
    dim: usize,
    /// Default distance metric used by searches that do not override it.
    metric: DistanceMetric,
    /// Vectors stored contiguously for cache-friendly sequential scan.
    /// Vector `i` occupies `data[i * dim..(i + 1) * dim]`.
    data: Vec<f32>,
    /// Tombstone flags, one per stored vector: `deleted[i]` = true means
    /// vector i is soft-deleted. Its length is the total number of vectors
    /// ever inserted.
    deleted: Vec<bool>,
    /// Number of live (non-deleted) vectors.
    live_count: usize,
}
31
32impl FlatIndex {
33    /// Create a new empty flat index.
34    pub fn new(dim: usize, metric: DistanceMetric) -> Self {
35        Self {
36            dim,
37            metric,
38            data: Vec::new(),
39            deleted: Vec::new(),
40            live_count: 0,
41        }
42    }
43
44    /// Insert a vector. Returns the assigned vector ID.
45    pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
46        assert_eq!(
47            vector.len(),
48            self.dim,
49            "dimension mismatch: expected {}, got {}",
50            self.dim,
51            vector.len()
52        );
53        let id = self.len() as u32;
54        self.data.extend_from_slice(&vector);
55        self.deleted.push(false);
56        self.live_count += 1;
57        id
58    }
59
60    /// Soft-delete a vector by ID.
61    pub fn delete(&mut self, id: u32) -> bool {
62        let idx = id as usize;
63        if idx < self.deleted.len() && !self.deleted[idx] {
64            self.deleted[idx] = true;
65            self.live_count -= 1;
66            true
67        } else {
68            false
69        }
70    }
71
72    /// Brute-force k-NN search with an explicit distance metric override.
73    /// Overrides the `self.metric` configured at collection creation time.
74    pub fn search_with_metric(
75        &self,
76        query: &[f32],
77        top_k: usize,
78        metric: DistanceMetric,
79    ) -> Vec<SearchResult> {
80        assert_eq!(query.len(), self.dim);
81        let n = self.len();
82        if n == 0 || top_k == 0 {
83            return Vec::new();
84        }
85
86        let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
87        for i in 0..n {
88            if self.deleted[i] {
89                continue;
90            }
91            let start = i * self.dim;
92            let vec_slice = &self.data[start..start + self.dim];
93            let dist = distance(query, vec_slice, metric);
94            candidates.push(SearchResult {
95                id: i as u32,
96                distance: dist,
97            });
98        }
99
100        if candidates.len() > top_k {
101            candidates.select_nth_unstable_by(top_k, |a, b| {
102                a.distance
103                    .partial_cmp(&b.distance)
104                    .unwrap_or(std::cmp::Ordering::Equal)
105            });
106            candidates.truncate(top_k);
107        }
108        candidates.sort_by(|a, b| {
109            a.distance
110                .partial_cmp(&b.distance)
111                .unwrap_or(std::cmp::Ordering::Equal)
112        });
113        candidates
114    }
115
116    /// Brute-force k-NN search. Exact results — no approximation.
117    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
118        assert_eq!(query.len(), self.dim);
119        let n = self.len();
120        if n == 0 || top_k == 0 {
121            return Vec::new();
122        }
123
124        let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
125        for i in 0..n {
126            if self.deleted[i] {
127                continue;
128            }
129            let start = i * self.dim;
130            let vec_slice = &self.data[start..start + self.dim];
131            let dist = distance(query, vec_slice, self.metric);
132            candidates.push(SearchResult {
133                id: i as u32,
134                distance: dist,
135            });
136        }
137
138        if candidates.len() > top_k {
139            candidates.select_nth_unstable_by(top_k, |a, b| {
140                a.distance
141                    .partial_cmp(&b.distance)
142                    .unwrap_or(std::cmp::Ordering::Equal)
143            });
144            candidates.truncate(top_k);
145        }
146        candidates.sort_by(|a, b| {
147            a.distance
148                .partial_cmp(&b.distance)
149                .unwrap_or(std::cmp::Ordering::Equal)
150        });
151        candidates
152    }
153
154    /// Search with a pre-filter bitmap (byte-array format).
155    pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec<SearchResult> {
156        self.search_filtered_offset(query, top_k, bitmap, 0)
157    }
158
159    /// Filtered search with an explicit metric override.
160    pub fn search_filtered_offset_with_metric(
161        &self,
162        query: &[f32],
163        top_k: usize,
164        bitmap: &[u8],
165        id_offset: u32,
166        metric: DistanceMetric,
167    ) -> Vec<SearchResult> {
168        assert_eq!(query.len(), self.dim);
169        let n = self.len();
170        if n == 0 || top_k == 0 {
171            return Vec::new();
172        }
173
174        let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
175
176        let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
177        for i in 0..n {
178            if self.deleted[i] {
179                continue;
180            }
181            if let Some(ref bm) = parsed {
182                let global = (i as u32).saturating_add(id_offset);
183                if !bm.contains(global) {
184                    continue;
185                }
186            }
187            let start = i * self.dim;
188            let vec_slice = &self.data[start..start + self.dim];
189            let dist = distance(query, vec_slice, metric);
190            candidates.push(SearchResult {
191                id: i as u32,
192                distance: dist,
193            });
194        }
195
196        if candidates.len() > top_k {
197            candidates.select_nth_unstable_by(top_k, |a, b| {
198                a.distance
199                    .partial_cmp(&b.distance)
200                    .unwrap_or(std::cmp::Ordering::Equal)
201            });
202            candidates.truncate(top_k);
203        }
204        candidates.sort_by(|a, b| {
205            a.distance
206                .partial_cmp(&b.distance)
207                .unwrap_or(std::cmp::Ordering::Equal)
208        });
209        candidates
210    }
211
212    /// Search with a pre-filter bitmap applying a global id offset.
213    ///
214    /// `bitmap` is a serialized `RoaringBitmap` (matching the HNSW filter
215    /// format). Bit `i + id_offset` tests local id `i`. Used by multi-segment
216    /// collections where the bitmap holds GLOBAL vector ids. If the bytes
217    /// fail to deserialize, the search degrades to unfiltered.
218    pub fn search_filtered_offset(
219        &self,
220        query: &[f32],
221        top_k: usize,
222        bitmap: &[u8],
223        id_offset: u32,
224    ) -> Vec<SearchResult> {
225        assert_eq!(query.len(), self.dim);
226        let n = self.len();
227        if n == 0 || top_k == 0 {
228            return Vec::new();
229        }
230
231        let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
232
233        let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
234        for i in 0..n {
235            if self.deleted[i] {
236                continue;
237            }
238            if let Some(ref bm) = parsed {
239                let global = (i as u32).saturating_add(id_offset);
240                if !bm.contains(global) {
241                    continue;
242                }
243            }
244            let start = i * self.dim;
245            let vec_slice = &self.data[start..start + self.dim];
246            let dist = distance(query, vec_slice, self.metric);
247            candidates.push(SearchResult {
248                id: i as u32,
249                distance: dist,
250            });
251        }
252
253        if candidates.len() > top_k {
254            candidates.select_nth_unstable_by(top_k, |a, b| {
255                a.distance
256                    .partial_cmp(&b.distance)
257                    .unwrap_or(std::cmp::Ordering::Equal)
258            });
259            candidates.truncate(top_k);
260        }
261        candidates.sort_by(|a, b| {
262            a.distance
263                .partial_cmp(&b.distance)
264                .unwrap_or(std::cmp::Ordering::Equal)
265        });
266        candidates
267    }
268
269    pub fn len(&self) -> usize {
270        self.deleted.len()
271    }
272
273    pub fn live_count(&self) -> usize {
274        self.live_count
275    }
276
277    pub fn is_empty(&self) -> bool {
278        self.live_count == 0
279    }
280
281    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
282        let idx = id as usize;
283        if idx < self.deleted.len() && !self.deleted[idx] {
284            let start = idx * self.dim;
285            Some(&self.data[start..start + self.dim])
286        } else {
287            None
288        }
289    }
290
291    /// Raw access bypassing tombstone filter — used by snapshot/restore.
292    pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> {
293        let idx = id as usize;
294        if idx < self.deleted.len() {
295            let start = idx * self.dim;
296            Some(&self.data[start..start + self.dim])
297        } else {
298            None
299        }
300    }
301
302    /// Whether the given local id has been tombstoned.
303    pub fn is_deleted(&self, id: u32) -> bool {
304        let idx = id as usize;
305        idx < self.deleted.len() && self.deleted[idx]
306    }
307
308    /// Insert a vector that is already tombstoned (for checkpoint restore).
309    pub fn insert_tombstoned(&mut self, vector: Vec<f32>) -> u32 {
310        assert_eq!(
311            vector.len(),
312            self.dim,
313            "dimension mismatch: expected {}, got {}",
314            self.dim,
315            vector.len()
316        );
317        let id = self.len() as u32;
318        self.data.extend_from_slice(&vector);
319        self.deleted.push(true);
320        // No live_count increment — it's dead on arrival.
321        id
322    }
323
324    pub fn dim(&self) -> usize {
325        self.dim
326    }
327
328    pub fn metric(&self) -> DistanceMetric {
329        self.metric
330    }
331
332    pub fn tombstone_count(&self) -> usize {
333        self.len().saturating_sub(self.live_count)
334    }
335}
336
#[cfg(test)]
mod tests {
    use roaring::RoaringBitmap;

    use super::*;

    /// Serialize `ids` into the byte format the filtered searches expect:
    /// a serialized `RoaringBitmap`, not a raw bitset.
    fn serialized_bitmap(ids: &[u32]) -> Vec<u8> {
        let bitmap: RoaringBitmap = ids.iter().copied().collect();
        let mut bytes = Vec::with_capacity(bitmap.serialized_size());
        bitmap
            .serialize_into(&mut bytes)
            .expect("writing to a Vec<u8> cannot fail");
        bytes
    }

    #[test]
    fn insert_and_search() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        for i in 0..100u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]);
        }
        assert_eq!(idx.len(), 100);
        assert_eq!(idx.live_count(), 100);

        let results = idx.search(&[50.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 50);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![0.0, 0.0]);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![2.0, 0.0]);

        assert!(idx.delete(1));
        assert_eq!(idx.live_count(), 2);

        let results = idx.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != 1));
    }

    #[test]
    fn exact_results() {
        let mut idx = FlatIndex::new(2, DistanceMetric::Cosine);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![0.0, 1.0]);
        idx.insert(vec![1.0, 1.0]);

        let results = idx.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn empty_search() {
        let idx = FlatIndex::new(3, DistanceMetric::L2);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn filtered_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // Allow ids {2, 3, 6, 7}. Id 4 is as close to the query as id 2, so
        // it would be a valid runner-up if the filter were not applied.
        let allowed = [2u32, 3, 6, 7];
        let bitmap = serialized_bitmap(&allowed);
        let results = idx.search_filtered(&[3.0, 0.0], 2, &bitmap);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, 3);
        assert_eq!(results[1].id, 2);

        // Asking for more than the filter admits returns only admitted ids.
        let results = idx.search_filtered(&[3.0, 0.0], 8, &bitmap);
        assert_eq!(results.len(), allowed.len());
        assert!(results.iter().all(|r| allowed.contains(&r.id)));
    }

    #[test]
    fn filtered_search_with_offset() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..4u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // Global ids 10 and 11 map to local ids 2 and 3 with an offset of 8.
        let bitmap = serialized_bitmap(&[10, 11]);
        let results = idx.search_filtered_offset(&[0.0, 0.0], 4, &bitmap, 8);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, 2);
        assert_eq!(results[1].id, 3);
    }

    #[test]
    fn undecodable_bitmap_degrades_to_unfiltered() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // A raw bitset byte is not a serialized RoaringBitmap.
        let results = idx.search_filtered(&[3.0, 0.0], 8, &[0b1100_1100u8]);
        assert_eq!(results.len(), 8);
    }
}