1use ahash::AHashMap;
7use ordered_float::OrderedFloat;
8use rayon::prelude::*;
9use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::error::{Error, Result};
13use super::traits::{DistanceType, IndexConfig, SearchResult, VectorIndex};
14
15#[derive(Debug)]
19pub struct FlatIndex {
20 config: IndexConfig,
22 vectors: AHashMap<String, Vec<f32>>,
24 parallel: bool,
26 parallel_threshold: usize,
28}
29
30impl FlatIndex {
31 #[must_use]
33 pub fn new(config: IndexConfig) -> Self {
34 Self {
35 config,
36 vectors: AHashMap::new(),
37 parallel: true,
38 parallel_threshold: 1000,
39 }
40 }
41
42 #[must_use]
44 pub fn with_capacity(config: IndexConfig, capacity: usize) -> Self {
45 Self {
46 config,
47 vectors: AHashMap::with_capacity(capacity),
48 parallel: true,
49 parallel_threshold: 1000,
50 }
51 }
52
53 #[must_use]
55 pub const fn with_parallel(mut self, parallel: bool) -> Self {
56 self.parallel = parallel;
57 self
58 }
59
60 #[must_use]
62 pub const fn with_parallel_threshold(mut self, threshold: usize) -> Self {
63 self.parallel_threshold = threshold;
64 self
65 }
66
67 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
69 match self.config.distance_type {
70 DistanceType::L2 => Self::l2_distance(a, b),
71 DistanceType::InnerProduct => Self::inner_product(a, b),
72 DistanceType::Cosine => Self::cosine_similarity(a, b),
73 }
74 }
75
76 #[inline]
78 fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
79 a.iter()
80 .zip(b.iter())
81 .map(|(x, y)| (x - y).powi(2))
82 .sum::<f32>()
83 .sqrt()
84 }
85
86 #[inline]
88 fn inner_product(a: &[f32], b: &[f32]) -> f32 {
89 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
90 }
91
92 #[inline]
94 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
95 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
96 let norm_a: f32 = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
97 let norm_b: f32 = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
98
99 if norm_a == 0.0 || norm_b == 0.0 {
100 0.0
101 } else {
102 dot / (norm_a * norm_b)
103 }
104 }
105
106 fn normalize(vector: &mut [f32]) {
108 let norm: f32 = vector.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
109 if norm > 0.0 {
110 for x in vector.iter_mut() {
111 *x /= norm;
112 }
113 }
114 }
115
116 fn search_sequential(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
118 let mut heap: BinaryHeap<(Reverse<OrderedFloat<f32>>, String)> = BinaryHeap::new();
120
121 let is_similarity = matches!(
122 self.config.distance_type,
123 DistanceType::InnerProduct | DistanceType::Cosine
124 );
125
126 for (id, vector) in &self.vectors {
127 let dist = self.compute_distance(query, vector);
128 let key = if is_similarity {
129 Reverse(OrderedFloat(-dist))
131 } else {
132 Reverse(OrderedFloat(dist))
134 };
135
136 if heap.len() < k {
137 heap.push((key, id.clone()));
138 } else if let Some((top_key, _)) = heap.peek() {
139 if key > *top_key {
140 heap.pop();
141 heap.push((key, id.clone()));
142 }
143 }
144 }
145
146 let mut results: Vec<_> = heap
148 .into_iter()
149 .map(|(Reverse(OrderedFloat(dist)), id)| {
150 let actual_dist = if is_similarity { -dist } else { dist };
151 SearchResult::new(id, actual_dist, self.config.distance_type)
152 })
153 .collect();
154
155 if is_similarity {
157 results.sort_by(|a, b| b.distance.partial_cmp(&a.distance).unwrap());
158 } else {
159 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
160 }
161
162 results
163 }
164
165 fn search_parallel(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
167 let is_similarity = matches!(
168 self.config.distance_type,
169 DistanceType::InnerProduct | DistanceType::Cosine
170 );
171
172 let mut distances: Vec<_> = self.vectors
174 .par_iter()
175 .map(|(id, vector)| {
176 let dist = self.compute_distance(query, vector);
177 (id.clone(), dist)
178 })
179 .collect();
180
181 if is_similarity {
183 distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
184 } else {
185 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
186 }
187
188 distances
190 .into_iter()
191 .take(k)
192 .map(|(id, dist)| SearchResult::new(id, dist, self.config.distance_type))
193 .collect()
194 }
195}
196
197impl VectorIndex for FlatIndex {
198 fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
199 if vector.len() != self.config.dimension {
200 return Err(Error::InvalidQuery {
201 reason: format!(
202 "Dimension mismatch: expected {}, got {}",
203 self.config.dimension,
204 vector.len()
205 ),
206 });
207 }
208
209 let mut vec = vector.to_vec();
210 if self.config.normalize {
211 Self::normalize(&mut vec);
212 }
213
214 self.vectors.insert(id, vec);
215 Ok(())
216 }
217
218 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
219 if query.len() != self.config.dimension {
220 return Err(Error::InvalidQuery {
221 reason: format!(
222 "Query dimension mismatch: expected {}, got {}",
223 self.config.dimension,
224 query.len()
225 ),
226 });
227 }
228
229 if self.vectors.is_empty() {
230 return Ok(vec![]);
231 }
232
233 let k = k.min(self.vectors.len());
234
235 let query = if self.config.normalize {
237 let mut q = query.to_vec();
238 Self::normalize(&mut q);
239 q
240 } else {
241 query.to_vec()
242 };
243
244 let results = if self.parallel && self.vectors.len() > self.parallel_threshold {
246 self.search_parallel(&query, k)
247 } else {
248 self.search_sequential(&query, k)
249 };
250
251 Ok(results)
252 }
253
254 fn remove(&mut self, id: &str) -> Result<bool> {
255 Ok(self.vectors.remove(id).is_some())
256 }
257
258 fn contains(&self, id: &str) -> bool {
259 self.vectors.contains_key(id)
260 }
261
262 fn len(&self) -> usize {
263 self.vectors.len()
264 }
265
266 fn dimension(&self) -> usize {
267 self.config.dimension
268 }
269
270 fn distance_type(&self) -> DistanceType {
271 self.config.distance_type
272 }
273
274 fn clear(&mut self) {
275 self.vectors.clear();
276 }
277
278 fn memory_usage(&self) -> usize {
279 let per_entry = self.config.dimension * 4 + 32 + 48;
284 self.vectors.len() * per_entry
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4-dimensional index (parallel search disabled) holding the
    /// vectors "a", "b", "c" and "d".
    fn create_test_index() -> FlatIndex {
        let fixtures: [(&str, [f32; 4]); 4] = [
            ("a", [1.0, 0.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0, 0.0]),
            ("c", [0.0, 0.0, 1.0, 0.0]),
            ("d", [0.5, 0.5, 0.0, 0.0]),
        ];

        let mut index = FlatIndex::new(IndexConfig::new(4)).with_parallel(false);
        for (id, vector) in &fixtures {
            index.add(id.to_string(), vector).unwrap();
        }
        index
    }

    #[test]
    fn test_add_and_len() {
        let index = create_test_index();

        assert_eq!(index.len(), 4);
        assert!(index.contains("a"));
        assert!(!index.contains("z"));
    }

    #[test]
    fn test_search_l2() {
        let index = create_test_index();
        let query = [0.9, 0.1, 0.0, 0.0];

        let results = index.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // The query lies closest to the stored vector "a".
        assert_eq!(results[0].id, "a");
    }

    #[test]
    fn test_search_cosine() {
        let config = IndexConfig::new(4).with_distance(DistanceType::Cosine);
        let mut index = FlatIndex::new(config).with_parallel(false);

        index.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index.add("b".to_string(), &[1.0, 1.0, 0.0, 0.0]).unwrap();
        index.add("c".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();

        // "a" points in exactly the query's direction, so its similarity is 1.
        assert_eq!(results[0].id, "a");
        assert!((results[0].distance - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_remove() {
        let mut index = create_test_index();

        assert!(index.remove("a").unwrap());
        assert!(!index.contains("a"));
        assert_eq!(index.len(), 3);

        // Removing an id that was never stored reports `false`.
        assert!(!index.remove("z").unwrap());
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut index = create_test_index();

        // Two components offered to a four-dimensional index.
        let result = index.add("e".to_string(), &[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_search() {
        let index = FlatIndex::new(IndexConfig::new(4));

        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut index = create_test_index();
        assert_eq!(index.len(), 4);

        index.clear();
        assert!(index.is_empty());
    }
}