1use roaring::RoaringBitmap;
13
14use crate::distance::{DistanceMetric, distance};
15use crate::hnsw::SearchResult;
16
17pub const DEFAULT_FLAT_INDEX_THRESHOLD: usize = 10_000;
19
20pub struct FlatIndex {
22 dim: usize,
23 metric: DistanceMetric,
24 data: Vec<f32>,
26 deleted: Vec<bool>,
28 live_count: usize,
30}
31
32impl FlatIndex {
33 pub fn new(dim: usize, metric: DistanceMetric) -> Self {
35 Self {
36 dim,
37 metric,
38 data: Vec::new(),
39 deleted: Vec::new(),
40 live_count: 0,
41 }
42 }
43
44 pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
46 assert_eq!(
47 vector.len(),
48 self.dim,
49 "dimension mismatch: expected {}, got {}",
50 self.dim,
51 vector.len()
52 );
53 let id = self.len() as u32;
54 self.data.extend_from_slice(&vector);
55 self.deleted.push(false);
56 self.live_count += 1;
57 id
58 }
59
60 pub fn delete(&mut self, id: u32) -> bool {
62 let idx = id as usize;
63 if idx < self.deleted.len() && !self.deleted[idx] {
64 self.deleted[idx] = true;
65 self.live_count -= 1;
66 true
67 } else {
68 false
69 }
70 }
71
72 pub fn search_with_metric(
75 &self,
76 query: &[f32],
77 top_k: usize,
78 metric: DistanceMetric,
79 ) -> Vec<SearchResult> {
80 assert_eq!(query.len(), self.dim);
81 let n = self.len();
82 if n == 0 || top_k == 0 {
83 return Vec::new();
84 }
85
86 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
88 for i in 0..n {
89 if self.deleted[i] {
90 continue;
91 }
92 let start = i * self.dim;
93 let vec_slice = &self.data[start..start + self.dim];
94 let dist = distance(query, vec_slice, metric);
95 candidates.push(SearchResult {
96 id: i as u32,
97 distance: dist,
98 });
99 }
100
101 if candidates.len() > top_k {
102 candidates.select_nth_unstable_by(top_k, |a, b| {
103 a.distance
104 .partial_cmp(&b.distance)
105 .unwrap_or(std::cmp::Ordering::Equal)
106 });
107 candidates.truncate(top_k);
108 }
109 candidates.sort_by(|a, b| {
110 a.distance
111 .partial_cmp(&b.distance)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 candidates
115 }
116
117 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
119 assert_eq!(query.len(), self.dim);
120 let n = self.len();
121 if n == 0 || top_k == 0 {
122 return Vec::new();
123 }
124
125 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
127 for i in 0..n {
128 if self.deleted[i] {
129 continue;
130 }
131 let start = i * self.dim;
132 let vec_slice = &self.data[start..start + self.dim];
133 let dist = distance(query, vec_slice, self.metric);
134 candidates.push(SearchResult {
135 id: i as u32,
136 distance: dist,
137 });
138 }
139
140 if candidates.len() > top_k {
141 candidates.select_nth_unstable_by(top_k, |a, b| {
142 a.distance
143 .partial_cmp(&b.distance)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 candidates.truncate(top_k);
147 }
148 candidates.sort_by(|a, b| {
149 a.distance
150 .partial_cmp(&b.distance)
151 .unwrap_or(std::cmp::Ordering::Equal)
152 });
153 candidates
154 }
155
156 pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec<SearchResult> {
158 self.search_filtered_offset(query, top_k, bitmap, 0)
159 }
160
161 pub fn search_filtered_offset_with_metric(
163 &self,
164 query: &[f32],
165 top_k: usize,
166 bitmap: &[u8],
167 id_offset: u32,
168 metric: DistanceMetric,
169 ) -> Vec<SearchResult> {
170 assert_eq!(query.len(), self.dim);
171 let n = self.len();
172 if n == 0 || top_k == 0 {
173 return Vec::new();
174 }
175
176 let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
177
178 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
180 for i in 0..n {
181 if self.deleted[i] {
182 continue;
183 }
184 if let Some(ref bm) = parsed {
185 let global = (i as u32).saturating_add(id_offset);
186 if !bm.contains(global) {
187 continue;
188 }
189 }
190 let start = i * self.dim;
191 let vec_slice = &self.data[start..start + self.dim];
192 let dist = distance(query, vec_slice, metric);
193 candidates.push(SearchResult {
194 id: i as u32,
195 distance: dist,
196 });
197 }
198
199 if candidates.len() > top_k {
200 candidates.select_nth_unstable_by(top_k, |a, b| {
201 a.distance
202 .partial_cmp(&b.distance)
203 .unwrap_or(std::cmp::Ordering::Equal)
204 });
205 candidates.truncate(top_k);
206 }
207 candidates.sort_by(|a, b| {
208 a.distance
209 .partial_cmp(&b.distance)
210 .unwrap_or(std::cmp::Ordering::Equal)
211 });
212 candidates
213 }
214
215 pub fn search_filtered_offset(
222 &self,
223 query: &[f32],
224 top_k: usize,
225 bitmap: &[u8],
226 id_offset: u32,
227 ) -> Vec<SearchResult> {
228 assert_eq!(query.len(), self.dim);
229 let n = self.len();
230 if n == 0 || top_k == 0 {
231 return Vec::new();
232 }
233
234 let parsed = RoaringBitmap::deserialize_from(bitmap).ok();
235
236 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
238 for i in 0..n {
239 if self.deleted[i] {
240 continue;
241 }
242 if let Some(ref bm) = parsed {
243 let global = (i as u32).saturating_add(id_offset);
244 if !bm.contains(global) {
245 continue;
246 }
247 }
248 let start = i * self.dim;
249 let vec_slice = &self.data[start..start + self.dim];
250 let dist = distance(query, vec_slice, self.metric);
251 candidates.push(SearchResult {
252 id: i as u32,
253 distance: dist,
254 });
255 }
256
257 if candidates.len() > top_k {
258 candidates.select_nth_unstable_by(top_k, |a, b| {
259 a.distance
260 .partial_cmp(&b.distance)
261 .unwrap_or(std::cmp::Ordering::Equal)
262 });
263 candidates.truncate(top_k);
264 }
265 candidates.sort_by(|a, b| {
266 a.distance
267 .partial_cmp(&b.distance)
268 .unwrap_or(std::cmp::Ordering::Equal)
269 });
270 candidates
271 }
272
273 pub fn len(&self) -> usize {
274 self.deleted.len()
275 }
276
277 pub fn live_count(&self) -> usize {
278 self.live_count
279 }
280
281 pub fn is_empty(&self) -> bool {
282 self.live_count == 0
283 }
284
285 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
286 let idx = id as usize;
287 if idx < self.deleted.len() && !self.deleted[idx] {
288 let start = idx * self.dim;
289 Some(&self.data[start..start + self.dim])
290 } else {
291 None
292 }
293 }
294
295 pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> {
297 let idx = id as usize;
298 if idx < self.deleted.len() {
299 let start = idx * self.dim;
300 Some(&self.data[start..start + self.dim])
301 } else {
302 None
303 }
304 }
305
306 pub fn is_deleted(&self, id: u32) -> bool {
308 let idx = id as usize;
309 idx < self.deleted.len() && self.deleted[idx]
310 }
311
312 pub fn insert_tombstoned(&mut self, vector: Vec<f32>) -> u32 {
314 assert_eq!(
315 vector.len(),
316 self.dim,
317 "dimension mismatch: expected {}, got {}",
318 self.dim,
319 vector.len()
320 );
321 let id = self.len() as u32;
322 self.data.extend_from_slice(&vector);
323 self.deleted.push(true);
324 id
326 }
327
328 pub fn dim(&self) -> usize {
329 self.dim
330 }
331
332 pub fn metric(&self) -> DistanceMetric {
333 self.metric
334 }
335
336 pub fn tombstone_count(&self) -> usize {
337 self.len().saturating_sub(self.live_count)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `ids` in the RoaringBitmap format the filtered searches parse.
    fn serialized_bitmap(ids: &[u32]) -> Vec<u8> {
        let bitmap: RoaringBitmap = ids.iter().copied().collect();
        let mut bytes = Vec::new();
        bitmap
            .serialize_into(&mut bytes)
            .expect("serializing into a Vec cannot fail");
        bytes
    }

    #[test]
    fn insert_and_search() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        for i in 0..100u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]);
        }
        assert_eq!(idx.len(), 100);
        assert_eq!(idx.live_count(), 100);

        let results = idx.search(&[50.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 50);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![0.0, 0.0]);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![2.0, 0.0]);

        assert!(idx.delete(1));
        assert_eq!(idx.live_count(), 2);

        let results = idx.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != 1));
    }

    #[test]
    fn exact_results() {
        let mut idx = FlatIndex::new(2, DistanceMetric::Cosine);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![0.0, 1.0]);
        idx.insert(vec![1.0, 1.0]);

        let results = idx.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn empty_search() {
        let idx = FlatIndex::new(3, DistanceMetric::L2);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn filtered_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // A real serialized roaring bitmap: a raw byte such as 0b11001100 is not
        // valid roaring data, so it would silently disable the filter.
        // Only ids 2, 3, 6 and 7 are eligible; unfiltered, id 4 would tie with id 2.
        let bitmap = serialized_bitmap(&[2, 3, 6, 7]);
        let results = idx.search_filtered(&[3.0, 0.0], 2, &bitmap);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, 3);
        assert_eq!(results[1].id, 2);
    }

    #[test]
    fn filtered_search_applies_id_offset() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        // Global ids 105 and 107 are local ids 5 and 7; results use local ids.
        let bitmap = serialized_bitmap(&[105, 107]);
        let results = idx.search_filtered_offset(&[0.0, 0.0], 5, &bitmap, 100);
        let ids: Vec<u32> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![5, 7]);
    }

    #[test]
    fn filtered_search_skips_deleted_ids() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..4u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        assert!(idx.delete(1));
        let bitmap = serialized_bitmap(&[1, 2]);
        let results = idx.search_filtered(&[1.0, 0.0], 3, &bitmap);
        let ids: Vec<u32> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![2]);
    }

    #[test]
    fn tombstoned_insert_is_not_live() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![0.0, 0.0]);
        let id = idx.insert_tombstoned(vec![1.0, 1.0]);
        assert_eq!(id, 1);
        assert_eq!(idx.len(), 2);
        assert_eq!(idx.live_count(), 1);
        assert_eq!(idx.tombstone_count(), 1);
        assert!(idx.is_deleted(id));
        assert!(idx.get_vector(id).is_none());
        assert_eq!(idx.get_vector_raw(id), Some(&[1.0, 1.0][..]));
        assert!(!idx.delete(id));
    }

    #[test]
    fn search_with_metric_overrides_index_metric() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![10.0, 10.0]);
        // Id 0 is nearest in Euclidean terms; id 1 points in the query's direction.
        assert_eq!(idx.search(&[2.0, 2.0], 1)[0].id, 0);
        assert_eq!(
            idx.search_with_metric(&[2.0, 2.0], 1, DistanceMetric::Cosine)[0].id,
            1
        );
    }
}