1use roaring::RoaringBitmap;
13
14use crate::distance::{DistanceMetric, distance};
15use crate::hnsw::SearchResult;
16
17pub const DEFAULT_FLAT_INDEX_THRESHOLD: usize = 10_000;
19
20pub struct FlatIndex {
22 dim: usize,
23 metric: DistanceMetric,
24 data: Vec<f32>,
26 deleted: Vec<bool>,
28 live_count: usize,
30}
31
32impl FlatIndex {
33 pub fn new(dim: usize, metric: DistanceMetric) -> Self {
35 Self {
36 dim,
37 metric,
38 data: Vec::new(),
39 deleted: Vec::new(),
40 live_count: 0,
41 }
42 }
43
44 pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
46 assert_eq!(
47 vector.len(),
48 self.dim,
49 "dimension mismatch: expected {}, got {}",
50 self.dim,
51 vector.len()
52 );
53 let id = self.len() as u32;
54 self.data.extend_from_slice(&vector);
55 self.deleted.push(false);
56 self.live_count += 1;
57 id
58 }
59
60 pub fn delete(&mut self, id: u32) -> bool {
62 let idx = id as usize;
63 if idx < self.deleted.len() && !self.deleted[idx] {
64 self.deleted[idx] = true;
65 self.live_count -= 1;
66 true
67 } else {
68 false
69 }
70 }
71
72 pub fn search_with_metric(
75 &self,
76 query: &[f32],
77 top_k: usize,
78 metric: DistanceMetric,
79 ) -> Vec<SearchResult> {
80 assert_eq!(query.len(), self.dim);
81 let n = self.len();
82 if n == 0 || top_k == 0 {
83 return Vec::new();
84 }
85
86 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
87 for i in 0..n {
88 if self.deleted[i] {
89 continue;
90 }
91 let start = i * self.dim;
92 let vec_slice = &self.data[start..start + self.dim];
93 let dist = distance(query, vec_slice, metric);
94 candidates.push(SearchResult {
95 id: i as u32,
96 distance: dist,
97 });
98 }
99
100 if candidates.len() > top_k {
101 candidates.select_nth_unstable_by(top_k, |a, b| {
102 a.distance
103 .partial_cmp(&b.distance)
104 .unwrap_or(std::cmp::Ordering::Equal)
105 });
106 candidates.truncate(top_k);
107 }
108 candidates.sort_by(|a, b| {
109 a.distance
110 .partial_cmp(&b.distance)
111 .unwrap_or(std::cmp::Ordering::Equal)
112 });
113 candidates
114 }
115
116 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
118 assert_eq!(query.len(), self.dim);
119 let n = self.len();
120 if n == 0 || top_k == 0 {
121 return Vec::new();
122 }
123
124 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
125 for i in 0..n {
126 if self.deleted[i] {
127 continue;
128 }
129 let start = i * self.dim;
130 let vec_slice = &self.data[start..start + self.dim];
131 let dist = distance(query, vec_slice, self.metric);
132 candidates.push(SearchResult {
133 id: i as u32,
134 distance: dist,
135 });
136 }
137
138 if candidates.len() > top_k {
139 candidates.select_nth_unstable_by(top_k, |a, b| {
140 a.distance
141 .partial_cmp(&b.distance)
142 .unwrap_or(std::cmp::Ordering::Equal)
143 });
144 candidates.truncate(top_k);
145 }
146 candidates.sort_by(|a, b| {
147 a.distance
148 .partial_cmp(&b.distance)
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151 candidates
152 }
153
154 pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec<SearchResult> {
156 self.search_filtered_offset(query, top_k, bitmap, 0)
157 }
158
159 pub fn search_filtered_offset_with_metric(
161 &self,
162 query: &[f32],
163 top_k: usize,
164 bitmap: &[u8],
165 id_offset: u32,
166 metric: DistanceMetric,
167 ) -> Vec<SearchResult> {
168 assert_eq!(query.len(), self.dim);
169 let n = self.len();
170 if n == 0 || top_k == 0 {
171 return Vec::new();
172 }
173
174 let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
175
176 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
177 for i in 0..n {
178 if self.deleted[i] {
179 continue;
180 }
181 if let Some(ref bm) = parsed {
182 let global = (i as u32).saturating_add(id_offset);
183 if !bm.contains(global) {
184 continue;
185 }
186 }
187 let start = i * self.dim;
188 let vec_slice = &self.data[start..start + self.dim];
189 let dist = distance(query, vec_slice, metric);
190 candidates.push(SearchResult {
191 id: i as u32,
192 distance: dist,
193 });
194 }
195
196 if candidates.len() > top_k {
197 candidates.select_nth_unstable_by(top_k, |a, b| {
198 a.distance
199 .partial_cmp(&b.distance)
200 .unwrap_or(std::cmp::Ordering::Equal)
201 });
202 candidates.truncate(top_k);
203 }
204 candidates.sort_by(|a, b| {
205 a.distance
206 .partial_cmp(&b.distance)
207 .unwrap_or(std::cmp::Ordering::Equal)
208 });
209 candidates
210 }
211
212 pub fn search_filtered_offset(
219 &self,
220 query: &[f32],
221 top_k: usize,
222 bitmap: &[u8],
223 id_offset: u32,
224 ) -> Vec<SearchResult> {
225 assert_eq!(query.len(), self.dim);
226 let n = self.len();
227 if n == 0 || top_k == 0 {
228 return Vec::new();
229 }
230
231 let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
232
233 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
234 for i in 0..n {
235 if self.deleted[i] {
236 continue;
237 }
238 if let Some(ref bm) = parsed {
239 let global = (i as u32).saturating_add(id_offset);
240 if !bm.contains(global) {
241 continue;
242 }
243 }
244 let start = i * self.dim;
245 let vec_slice = &self.data[start..start + self.dim];
246 let dist = distance(query, vec_slice, self.metric);
247 candidates.push(SearchResult {
248 id: i as u32,
249 distance: dist,
250 });
251 }
252
253 if candidates.len() > top_k {
254 candidates.select_nth_unstable_by(top_k, |a, b| {
255 a.distance
256 .partial_cmp(&b.distance)
257 .unwrap_or(std::cmp::Ordering::Equal)
258 });
259 candidates.truncate(top_k);
260 }
261 candidates.sort_by(|a, b| {
262 a.distance
263 .partial_cmp(&b.distance)
264 .unwrap_or(std::cmp::Ordering::Equal)
265 });
266 candidates
267 }
268
269 pub fn len(&self) -> usize {
270 self.deleted.len()
271 }
272
273 pub fn live_count(&self) -> usize {
274 self.live_count
275 }
276
277 pub fn is_empty(&self) -> bool {
278 self.live_count == 0
279 }
280
281 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
282 let idx = id as usize;
283 if idx < self.deleted.len() && !self.deleted[idx] {
284 let start = idx * self.dim;
285 Some(&self.data[start..start + self.dim])
286 } else {
287 None
288 }
289 }
290
291 pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> {
293 let idx = id as usize;
294 if idx < self.deleted.len() {
295 let start = idx * self.dim;
296 Some(&self.data[start..start + self.dim])
297 } else {
298 None
299 }
300 }
301
302 pub fn is_deleted(&self, id: u32) -> bool {
304 let idx = id as usize;
305 idx < self.deleted.len() && self.deleted[idx]
306 }
307
308 pub fn insert_tombstoned(&mut self, vector: Vec<f32>) -> u32 {
310 assert_eq!(
311 vector.len(),
312 self.dim,
313 "dimension mismatch: expected {}, got {}",
314 self.dim,
315 vector.len()
316 );
317 let id = self.len() as u32;
318 self.data.extend_from_slice(&vector);
319 self.deleted.push(true);
320 id
322 }
323
324 pub fn dim(&self) -> usize {
325 self.dim
326 }
327
328 pub fn metric(&self) -> DistanceMetric {
329 self.metric
330 }
331
332 pub fn tombstone_count(&self) -> usize {
333 self.len().saturating_sub(self.live_count)
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `ids` in the RoaringBitmap wire format the filtered searches parse.
    fn bitmap_of(ids: &[u32]) -> Vec<u8> {
        let mut bitmap = RoaringBitmap::new();
        for &id in ids {
            bitmap.insert(id);
        }
        let mut bytes: Vec<u8> = Vec::new();
        bitmap
            .serialize_into(&mut bytes)
            .expect("serializing into a Vec cannot fail");
        bytes
    }

    /// Extracts result ids in order.
    fn ids(results: &[SearchResult]) -> Vec<u32> {
        results.iter().map(|r| r.id).collect()
    }

    #[test]
    fn insert_and_search() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        for i in 0..100u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]);
        }
        assert_eq!(idx.len(), 100);
        assert_eq!(idx.live_count(), 100);

        let results = idx.search(&[50.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 50);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![0.0, 0.0]);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![2.0, 0.0]);

        assert!(idx.delete(1));
        assert_eq!(idx.live_count(), 2);

        let results = idx.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != 1));
    }

    #[test]
    fn delete_is_idempotent_and_bounds_checked() {
        let mut idx = FlatIndex::new(1, DistanceMetric::L2);
        idx.insert(vec![0.0]);
        assert!(idx.delete(0));
        assert!(!idx.delete(0));
        assert!(!idx.delete(7));
        assert_eq!(idx.live_count(), 0);
        assert_eq!(idx.len(), 1);
        assert!(idx.is_empty());
        assert!(idx.search(&[0.0], 5).is_empty());
    }

    #[test]
    fn exact_results() {
        let mut idx = FlatIndex::new(2, DistanceMetric::Cosine);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![0.0, 1.0]);
        idx.insert(vec![1.0, 1.0]);

        let results = idx.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn empty_search() {
        let idx = FlatIndex::new(3, DistanceMetric::L2);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn zero_top_k_returns_nothing() {
        let mut idx = FlatIndex::new(1, DistanceMetric::L2);
        idx.insert(vec![1.0]);
        assert!(idx.search(&[1.0], 0).is_empty());
    }

    #[test]
    fn filtered_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // Only ids 2, 3, 6 and 7 pass; id 4 is as close as id 2 and must be filtered out.
        let bitmap = bitmap_of(&[2, 3, 6, 7]);
        let results = idx.search_filtered(&[3.0, 0.0], 2, &bitmap);
        assert_eq!(ids(&results), vec![3, 2]);
    }

    #[test]
    fn filtered_search_with_offset_returns_local_ids() {
        let mut idx = FlatIndex::new(1, DistanceMetric::L2);
        for i in 0..4u32 {
            idx.insert(vec![i as f32]);
        }
        // Global id = local id + 10, so the filter admits local ids 1 and 3.
        let bitmap = bitmap_of(&[11, 13]);
        let results = idx.search_filtered_offset(&[0.0], 10, &bitmap, 10);
        assert_eq!(ids(&results), vec![1, 3]);
    }

    #[test]
    fn tombstoned_inserts_are_stored_but_not_searchable() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        assert_eq!(idx.insert(vec![0.0, 0.0]), 0);
        assert_eq!(idx.insert_tombstoned(vec![1.0, 0.0]), 1);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx.live_count(), 1);
        assert_eq!(idx.tombstone_count(), 1);
        assert!(idx.is_deleted(1));
        assert!(!idx.is_deleted(0));
        assert_eq!(idx.get_vector(1), None);
        assert_eq!(idx.get_vector_raw(1), Some(&[1.0f32, 0.0][..]));
        assert_eq!(idx.get_vector(0), Some(&[0.0f32, 0.0][..]));
        assert!(!idx.delete(1));
        assert_eq!(ids(&idx.search(&[1.0, 0.0], 5)), vec![0]);
    }

    #[test]
    #[should_panic]
    fn insert_rejects_wrong_dimension() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        idx.insert(vec![1.0, 2.0]);
    }
}