1use crate::distance::{DistanceMetric, distance};
11use crate::hnsw::SearchResult;
12
/// Default size threshold associated with choosing a flat index.
///
/// NOTE(review): nothing in this file reads the constant. It is presumably the
/// vector count up to which a [`FlatIndex`] is preferred over the HNSW index;
/// confirm against the call sites.
pub const DEFAULT_FLAT_INDEX_THRESHOLD: usize = 10_000;
15
16pub struct FlatIndex {
18 dim: usize,
19 metric: DistanceMetric,
20 data: Vec<f32>,
22 deleted: Vec<bool>,
24 live_count: usize,
26}
27
28impl FlatIndex {
29 pub fn new(dim: usize, metric: DistanceMetric) -> Self {
31 Self {
32 dim,
33 metric,
34 data: Vec::new(),
35 deleted: Vec::new(),
36 live_count: 0,
37 }
38 }
39
40 pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
42 assert_eq!(
43 vector.len(),
44 self.dim,
45 "dimension mismatch: expected {}, got {}",
46 self.dim,
47 vector.len()
48 );
49 let id = self.len() as u32;
50 self.data.extend_from_slice(&vector);
51 self.deleted.push(false);
52 self.live_count += 1;
53 id
54 }
55
56 pub fn delete(&mut self, id: u32) -> bool {
58 let idx = id as usize;
59 if idx < self.deleted.len() && !self.deleted[idx] {
60 self.deleted[idx] = true;
61 self.live_count -= 1;
62 true
63 } else {
64 false
65 }
66 }
67
68 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
70 assert_eq!(query.len(), self.dim);
71 let n = self.len();
72 if n == 0 || top_k == 0 {
73 return Vec::new();
74 }
75
76 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
77 for i in 0..n {
78 if self.deleted[i] {
79 continue;
80 }
81 let start = i * self.dim;
82 let vec_slice = &self.data[start..start + self.dim];
83 let dist = distance(query, vec_slice, self.metric);
84 candidates.push(SearchResult {
85 id: i as u32,
86 distance: dist,
87 });
88 }
89
90 if candidates.len() > top_k {
91 candidates.select_nth_unstable_by(top_k, |a, b| {
92 a.distance
93 .partial_cmp(&b.distance)
94 .unwrap_or(std::cmp::Ordering::Equal)
95 });
96 candidates.truncate(top_k);
97 }
98 candidates.sort_by(|a, b| {
99 a.distance
100 .partial_cmp(&b.distance)
101 .unwrap_or(std::cmp::Ordering::Equal)
102 });
103 candidates
104 }
105
106 pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec<SearchResult> {
108 self.search_filtered_offset(query, top_k, bitmap, 0)
109 }
110
111 pub fn search_filtered_offset(
117 &self,
118 query: &[f32],
119 top_k: usize,
120 bitmap: &[u8],
121 id_offset: u32,
122 ) -> Vec<SearchResult> {
123 assert_eq!(query.len(), self.dim);
124 let n = self.len();
125 if n == 0 || top_k == 0 {
126 return Vec::new();
127 }
128
129 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
130 for i in 0..n {
131 if self.deleted[i] {
132 continue;
133 }
134 let global = i + id_offset as usize;
135 let byte_idx = global / 8;
136 let bit_idx = global % 8;
137 if byte_idx >= bitmap.len() || (bitmap[byte_idx] & (1 << bit_idx)) == 0 {
138 continue;
139 }
140 let start = i * self.dim;
141 let vec_slice = &self.data[start..start + self.dim];
142 let dist = distance(query, vec_slice, self.metric);
143 candidates.push(SearchResult {
144 id: i as u32,
145 distance: dist,
146 });
147 }
148
149 if candidates.len() > top_k {
150 candidates.select_nth_unstable_by(top_k, |a, b| {
151 a.distance
152 .partial_cmp(&b.distance)
153 .unwrap_or(std::cmp::Ordering::Equal)
154 });
155 candidates.truncate(top_k);
156 }
157 candidates.sort_by(|a, b| {
158 a.distance
159 .partial_cmp(&b.distance)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162 candidates
163 }
164
165 pub fn len(&self) -> usize {
166 self.deleted.len()
167 }
168
169 pub fn live_count(&self) -> usize {
170 self.live_count
171 }
172
173 pub fn is_empty(&self) -> bool {
174 self.live_count == 0
175 }
176
177 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
178 let idx = id as usize;
179 if idx < self.deleted.len() && !self.deleted[idx] {
180 let start = idx * self.dim;
181 Some(&self.data[start..start + self.dim])
182 } else {
183 None
184 }
185 }
186
187 pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> {
189 let idx = id as usize;
190 if idx < self.deleted.len() {
191 let start = idx * self.dim;
192 Some(&self.data[start..start + self.dim])
193 } else {
194 None
195 }
196 }
197
198 pub fn is_deleted(&self, id: u32) -> bool {
200 let idx = id as usize;
201 idx < self.deleted.len() && self.deleted[idx]
202 }
203
204 pub fn insert_tombstoned(&mut self, vector: Vec<f32>) -> u32 {
206 assert_eq!(
207 vector.len(),
208 self.dim,
209 "dimension mismatch: expected {}, got {}",
210 self.dim,
211 vector.len()
212 );
213 let id = self.len() as u32;
214 self.data.extend_from_slice(&vector);
215 self.deleted.push(true);
216 id
218 }
219
220 pub fn dim(&self) -> usize {
221 self.dim
222 }
223
224 pub fn metric(&self) -> DistanceMetric {
225 self.metric
226 }
227
228 pub fn tombstone_count(&self) -> usize {
229 self.len().saturating_sub(self.live_count)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// The exact match is ranked first among 100 points laid out on the x axis.
    #[test]
    fn insert_and_search() {
        let mut index = FlatIndex::new(3, DistanceMetric::L2);
        (0..100u32).for_each(|x| {
            index.insert(vec![x as f32, 0.0, 0.0]);
        });
        assert_eq!(index.len(), 100);
        assert_eq!(index.live_count(), 100);

        let hits = index.search(&[50.0, 0.0, 0.0], 3);
        assert_eq!(hits.len(), 3);
        let best = &hits[0];
        assert_eq!(best.id, 50);
        assert!(best.distance < 0.01);
    }

    /// A tombstoned entry lowers the live count and never shows up in results.
    #[test]
    fn delete_excludes_from_search() {
        let mut index = FlatIndex::new(2, DistanceMetric::L2);
        for x in [0.0, 1.0, 2.0] {
            index.insert(vec![x, 0.0]);
        }

        assert!(index.delete(1));
        assert_eq!(index.live_count(), 2);

        let hits = index.search(&[1.0, 0.0], 3);
        assert_eq!(hits.len(), 2);
        assert!(!hits.iter().any(|hit| hit.id == 1));
    }

    /// Under the cosine metric the identical direction wins.
    #[test]
    fn exact_results() {
        let mut index = FlatIndex::new(2, DistanceMetric::Cosine);
        index.insert(vec![1.0, 0.0]);
        index.insert(vec![0.0, 1.0]);
        index.insert(vec![1.0, 1.0]);

        let hits = index.search(&[1.0, 0.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 0);
    }

    /// Searching an index with no entries yields nothing.
    #[test]
    fn empty_search() {
        let index = FlatIndex::new(3, DistanceMetric::L2);
        assert!(index.search(&[1.0, 0.0, 0.0], 5).is_empty());
    }

    /// Only ids whose bitmap bit is set (2, 3, 6, 7) are eligible.
    #[test]
    fn filtered_search() {
        let mut index = FlatIndex::new(2, DistanceMetric::L2);
        (0..8u32).for_each(|x| {
            index.insert(vec![x as f32, 0.0]);
        });
        let allowed = [0b1100_1100u8];
        let hits = index.search_filtered(&[3.0, 0.0], 2, &allowed);
        let ids: Vec<u32> = hits.iter().map(|hit| hit.id).collect();
        assert_eq!(ids, vec![3, 2]);
    }
}