1use crate::distance::{DistanceMetric, distance};
11use crate::hnsw::SearchResult;
12
/// Default size threshold associated with [`FlatIndex`].
///
/// NOTE(review): not referenced in this file. Judging by the name, this is presumably the
/// collection size below which callers prefer the exact brute-force [`FlatIndex`] over an
/// approximate index — confirm against callers.
pub const DEFAULT_FLAT_INDEX_THRESHOLD: usize = 10_000;
15
16pub struct FlatIndex {
18 dim: usize,
19 metric: DistanceMetric,
20 data: Vec<f32>,
22 deleted: Vec<bool>,
24 live_count: usize,
26}
27
28impl FlatIndex {
29 pub fn new(dim: usize, metric: DistanceMetric) -> Self {
31 Self {
32 dim,
33 metric,
34 data: Vec::new(),
35 deleted: Vec::new(),
36 live_count: 0,
37 }
38 }
39
40 pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
42 assert_eq!(
43 vector.len(),
44 self.dim,
45 "dimension mismatch: expected {}, got {}",
46 self.dim,
47 vector.len()
48 );
49 let id = self.len() as u32;
50 self.data.extend_from_slice(&vector);
51 self.deleted.push(false);
52 self.live_count += 1;
53 id
54 }
55
56 pub fn delete(&mut self, id: u32) -> bool {
58 let idx = id as usize;
59 if idx < self.deleted.len() && !self.deleted[idx] {
60 self.deleted[idx] = true;
61 self.live_count -= 1;
62 true
63 } else {
64 false
65 }
66 }
67
68 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
70 assert_eq!(query.len(), self.dim);
71 let n = self.len();
72 if n == 0 || top_k == 0 {
73 return Vec::new();
74 }
75
76 let mut candidates: Vec<SearchResult> = Vec::with_capacity(n.min(top_k * 2));
77 for i in 0..n {
78 if self.deleted[i] {
79 continue;
80 }
81 let start = i * self.dim;
82 let vec_slice = &self.data[start..start + self.dim];
83 let dist = distance(query, vec_slice, self.metric);
84 candidates.push(SearchResult {
85 id: i as u32,
86 distance: dist,
87 });
88 }
89
90 if candidates.len() > top_k {
91 candidates.select_nth_unstable_by(top_k, |a, b| {
92 a.distance
93 .partial_cmp(&b.distance)
94 .unwrap_or(std::cmp::Ordering::Equal)
95 });
96 candidates.truncate(top_k);
97 }
98 candidates.sort_by(|a, b| {
99 a.distance
100 .partial_cmp(&b.distance)
101 .unwrap_or(std::cmp::Ordering::Equal)
102 });
103 candidates
104 }
105
106 pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec<SearchResult> {
108 assert_eq!(query.len(), self.dim);
109 let n = self.len();
110 if n == 0 || top_k == 0 {
111 return Vec::new();
112 }
113
114 let mut candidates: Vec<SearchResult> = Vec::with_capacity(top_k * 2);
115 for i in 0..n {
116 if self.deleted[i] {
117 continue;
118 }
119 let byte_idx = i / 8;
120 let bit_idx = i % 8;
121 if byte_idx >= bitmap.len() || (bitmap[byte_idx] & (1 << bit_idx)) == 0 {
122 continue;
123 }
124 let start = i * self.dim;
125 let vec_slice = &self.data[start..start + self.dim];
126 let dist = distance(query, vec_slice, self.metric);
127 candidates.push(SearchResult {
128 id: i as u32,
129 distance: dist,
130 });
131 }
132
133 if candidates.len() > top_k {
134 candidates.select_nth_unstable_by(top_k, |a, b| {
135 a.distance
136 .partial_cmp(&b.distance)
137 .unwrap_or(std::cmp::Ordering::Equal)
138 });
139 candidates.truncate(top_k);
140 }
141 candidates.sort_by(|a, b| {
142 a.distance
143 .partial_cmp(&b.distance)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 candidates
147 }
148
149 pub fn len(&self) -> usize {
150 self.deleted.len()
151 }
152
153 pub fn live_count(&self) -> usize {
154 self.live_count
155 }
156
157 pub fn is_empty(&self) -> bool {
158 self.live_count == 0
159 }
160
161 pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
162 let idx = id as usize;
163 if idx < self.deleted.len() {
164 let start = idx * self.dim;
165 Some(&self.data[start..start + self.dim])
166 } else {
167 None
168 }
169 }
170
171 pub fn dim(&self) -> usize {
172 self.dim
173 }
174
175 pub fn metric(&self) -> DistanceMetric {
176 self.metric
177 }
178
179 pub fn tombstone_count(&self) -> usize {
180 self.len().saturating_sub(self.live_count)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects result ids in rank order for compact assertions.
    fn ids(results: &[SearchResult]) -> Vec<u32> {
        results.iter().map(|r| r.id).collect()
    }

    #[test]
    fn insert_and_search() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        for i in 0..100u32 {
            idx.insert(vec![i as f32, 0.0, 0.0]);
        }
        assert_eq!(idx.len(), 100);
        assert_eq!(idx.live_count(), 100);

        let results = idx.search(&[50.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 50);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        idx.insert(vec![0.0, 0.0]);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![2.0, 0.0]);

        assert!(idx.delete(1));
        assert_eq!(idx.live_count(), 2);

        let results = idx.search(&[1.0, 0.0], 3);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != 1));
    }

    #[test]
    fn delete_is_idempotent_and_bounds_checked() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        let a = idx.insert(vec![0.0, 0.0]);
        let b = idx.insert(vec![1.0, 0.0]);
        assert_eq!((a, b), (0, 1));

        assert!(idx.delete(a));
        assert!(!idx.delete(a), "second delete of the same id must report false");
        assert!(!idx.delete(42), "unknown id must report false");

        assert_eq!(idx.len(), 2);
        assert_eq!(idx.live_count(), 1);
        assert_eq!(idx.tombstone_count(), 1);
        assert!(!idx.is_empty());

        assert!(idx.delete(b));
        // `is_empty` tracks live vectors; `len` still counts the tombstoned slots.
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 2);
        assert!(idx.search(&[0.0, 0.0], 5).is_empty());
    }

    #[test]
    fn get_vector_returns_stored_components() {
        let mut idx = FlatIndex::new(3, DistanceMetric::L2);
        idx.insert(vec![1.0, 2.0, 3.0]);
        idx.insert(vec![4.0, 5.0, 6.0]);
        assert_eq!(idx.get_vector(1), Some(&[4.0_f32, 5.0, 6.0][..]));
        assert_eq!(idx.get_vector(2), None);

        // Tombstoned slots still expose their retained data.
        assert!(idx.delete(0));
        assert_eq!(idx.get_vector(0), Some(&[1.0_f32, 2.0, 3.0][..]));
    }

    #[test]
    fn exact_results() {
        let mut idx = FlatIndex::new(2, DistanceMetric::Cosine);
        idx.insert(vec![1.0, 0.0]);
        idx.insert(vec![0.0, 1.0]);
        idx.insert(vec![1.0, 1.0]);

        let results = idx.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn empty_search() {
        let idx = FlatIndex::new(3, DistanceMetric::L2);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn filtered_search() {
        let mut idx = FlatIndex::new(2, DistanceMetric::L2);
        for i in 0..8u32 {
            idx.insert(vec![i as f32, 0.0]);
        }
        let bitmap = vec![0b11001100u8];
        let results = idx.search_filtered(&[3.0, 0.0], 2, &bitmap);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, 3);
        assert_eq!(results[1].id, 2);
    }

    #[test]
    fn filtered_search_ignores_ids_outside_bitmap() {
        let mut idx = FlatIndex::new(1, DistanceMetric::L2);
        for i in 0..10u32 {
            idx.insert(vec![i as f32]);
        }
        // One byte only covers ids 0..8, so ids 8 and 9 are never selected.
        let bitmap = [0b1000_0001u8];
        assert_eq!(ids(&idx.search_filtered(&[9.0], 10, &bitmap)), vec![7, 0]);
    }

    #[test]
    fn filtered_search_skips_deleted() {
        let mut idx = FlatIndex::new(1, DistanceMetric::L2);
        for i in 0..4u32 {
            idx.insert(vec![i as f32]);
        }
        assert!(idx.delete(1));
        let results = idx.search_filtered(&[0.0], 4, &[0b0000_1111]);
        assert_eq!(ids(&results), vec![0, 2, 3]);
    }
}