1#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
20use rayon::prelude::*;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24pub type Result<T> = std::result::Result<T, AnnFlatError>;
26
27#[derive(Error, Debug)]
29pub enum AnnFlatError {
30 #[error("dim mismatch: expected {expected}, got {got}")]
32 DimMismatch {
33 expected: usize,
35 got: usize,
37 },
38 #[error("k ({k}) > index size ({n})")]
40 KTooLarge {
41 k: usize,
43 n: usize,
45 },
46 #[error("k must be > 0")]
48 KZero,
49 #[error("add_batch ids and matrix row counts disagree: {ids} vs {rows}")]
51 BatchLengthMismatch {
52 ids: usize,
54 rows: usize,
56 },
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61#[serde(rename_all = "snake_case")]
62pub enum Metric {
63 Cosine,
65 L2,
67 Dot,
69}
70
71#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Hit {
74 pub id: String,
76 pub score: f32,
78}
79
80pub struct Index {
82 metric: Metric,
83 dim: Option<usize>,
84 ids: Vec<String>,
85 vectors: Vec<Vec<f32>>,
87}
88
89impl Index {
90 pub fn new(metric: Metric) -> Self {
92 Self {
93 metric,
94 dim: None,
95 ids: Vec::new(),
96 vectors: Vec::new(),
97 }
98 }
99
100 pub fn metric(&self) -> Metric {
102 self.metric
103 }
104
105 pub fn len(&self) -> usize {
107 self.ids.len()
108 }
109
110 pub fn is_empty(&self) -> bool {
112 self.ids.is_empty()
113 }
114
115 pub fn dim(&self) -> Option<usize> {
117 self.dim
118 }
119
120 pub fn add(&mut self, id: impl Into<String>, vector: &[f32]) -> Result<()> {
122 match self.dim {
123 None => self.dim = Some(vector.len()),
124 Some(d) if d != vector.len() => {
125 return Err(AnnFlatError::DimMismatch {
126 expected: d,
127 got: vector.len(),
128 });
129 }
130 _ => {}
131 }
132 let mut v = vector.to_vec();
133 if self.metric == Metric::Cosine {
134 normalize_in_place(&mut v);
135 }
136 self.ids.push(id.into());
137 self.vectors.push(v);
138 Ok(())
139 }
140
141 pub fn add_batch(&mut self, ids: Vec<String>, matrix: &ArrayView2<'_, f32>) -> Result<()> {
143 if ids.len() != matrix.nrows() {
144 return Err(AnnFlatError::BatchLengthMismatch {
145 ids: ids.len(),
146 rows: matrix.nrows(),
147 });
148 }
149 for (id, row) in ids.into_iter().zip(matrix.axis_iter(Axis(0))) {
150 self.add(id, row.as_slice().unwrap_or(&row.to_vec()))?;
151 }
152 Ok(())
153 }
154
155 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<Hit>> {
158 if k == 0 {
159 return Err(AnnFlatError::KZero);
160 }
161 if k > self.len() {
162 return Err(AnnFlatError::KTooLarge { k, n: self.len() });
163 }
164 match self.dim {
165 Some(d) if d != query.len() => {
166 return Err(AnnFlatError::DimMismatch {
167 expected: d,
168 got: query.len(),
169 });
170 }
171 None => {
172 return Err(AnnFlatError::KTooLarge { k, n: 0 });
173 }
174 _ => {}
175 }
176 let q: Vec<f32> = if self.metric == Metric::Cosine {
177 let mut q2 = query.to_vec();
178 normalize_in_place(&mut q2);
179 q2
180 } else {
181 query.to_vec()
182 };
183
184 let mut heap: BinaryHeap<(Reverse<OrdScore>, usize)> = BinaryHeap::with_capacity(k);
187 for (i, v) in self.vectors.iter().enumerate() {
188 let s = self.score(&q, v);
189 let entry = (Reverse(OrdScore(s)), i);
190 if heap.len() < k {
191 heap.push(entry);
192 } else if let Some(top) = heap.peek() {
193 if entry.0 < top.0 {
194 heap.pop();
195 heap.push(entry);
196 }
197 }
198 }
199 let mut out: Vec<Hit> = heap
200 .into_iter()
201 .map(|(rs, i)| Hit {
202 id: self.ids[i].clone(),
203 score: rs.0 .0,
204 })
205 .collect();
206 out.sort_by(|a, b| {
207 b.score
208 .partial_cmp(&a.score)
209 .unwrap_or(std::cmp::Ordering::Equal)
210 .then(a.id.cmp(&b.id))
211 });
212 Ok(out)
213 }
214
215 pub fn search_batch(
218 &self,
219 queries: &ArrayView2<'_, f32>,
220 k: usize,
221 parallel: bool,
222 ) -> Result<Vec<Vec<Hit>>> {
223 if parallel {
224 queries
225 .axis_iter(Axis(0))
226 .into_par_iter()
227 .map(|row| self.search_view(&row, k))
228 .collect()
229 } else {
230 queries
231 .axis_iter(Axis(0))
232 .map(|row| self.search_view(&row, k))
233 .collect()
234 }
235 }
236
237 fn search_view(&self, row: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<Hit>> {
238 match row.as_slice() {
239 Some(s) => self.search(s, k),
240 None => self.search(&row.to_vec(), k),
241 }
242 }
243
244 pub fn vectors(&self) -> Result<Array2<f32>> {
246 let n = self.len();
247 let d = self.dim.unwrap_or(0);
248 if n == 0 {
249 return Ok(Array2::<f32>::zeros((0, 0)));
250 }
251 let mut out = Array2::<f32>::zeros((n, d));
252 for (i, v) in self.vectors.iter().enumerate() {
253 for (j, &x) in v.iter().enumerate() {
254 out[[i, j]] = x;
255 }
256 }
257 Ok(out)
258 }
259
260 fn score(&self, q: &[f32], v: &[f32]) -> f32 {
261 match self.metric {
262 Metric::Cosine | Metric::Dot => {
263 let mut s = 0.0_f32;
264 for (a, b) in q.iter().zip(v.iter()) {
265 s += a * b;
266 }
267 s
268 }
269 Metric::L2 => {
270 let mut s = 0.0_f32;
271 for (a, b) in q.iter().zip(v.iter()) {
272 let d = a - b;
273 s += d * d;
274 }
275 -s.sqrt()
276 }
277 }
278 }
279}
280
/// Scales `v` in place to unit Euclidean length.
///
/// If the norm is not greater than `1e-12` (including the all-zero and empty
/// vectors) or is NaN, every element is set to `0.0` instead, so this never
/// divides by (nearly) zero.
///
/// The sum of squares is accumulated in `f64`. Squaring an `f32` component
/// larger than about `1.8e19` overflows `f32` to infinity, which would
/// otherwise collapse a perfectly valid vector to zeros.
fn normalize_in_place(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm > 1e-12 {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    } else {
        v.fill(0.0);
    }
}
297
/// Orderable wrapper around an `f32` score, used as a `BinaryHeap` key.
///
/// `f32` is only `PartialOrd`, so this wrapper defines a total order by
/// treating NaN as equal to `f32::NEG_INFINITY` (the worst possible score).
/// Equality is derived from `cmp`, so `PartialEq`, `Eq` and `Ord` agree.
#[derive(Debug, Clone, Copy)]
struct OrdScore(f32);

impl OrdScore {
    /// The score with NaN replaced by negative infinity.
    fn key(self) -> f32 {
        if self.0.is_nan() {
            f32::NEG_INFINITY
        } else {
            self.0
        }
    }
}

impl PartialEq for OrdScore {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdScore {}

impl Ord for OrdScore {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Neither key is NaN, so `partial_cmp` always returns `Some`.
        self.key()
            .partial_cmp(&other.key())
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}

impl PartialOrd for OrdScore {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn empty_search_rejected() {
        let idx = Index::new(Metric::Cosine);
        assert!(matches!(
            idx.search(&[1.0, 2.0], 1),
            Err(AnnFlatError::KTooLarge { k: 1, n: 0 })
        ));
    }

    #[test]
    fn cosine_search_finds_self() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        idx.add("c", &[0.6, 0.8]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-4);
    }

    #[test]
    fn l2_search_smaller_distance_first() {
        let mut idx = Index::new(Metric::L2);
        idx.add("near", &[1.0, 1.0]).unwrap();
        idx.add("far", &[10.0, 10.0]).unwrap();
        let hits = idx.search(&[1.0, 1.1], 2).unwrap();
        assert_eq!(hits[0].id, "near");
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn dot_search() {
        let mut idx = Index::new(Metric::Dot);
        idx.add("a", &[1.0, 1.0]).unwrap();
        idx.add("b", &[2.0, 2.0]).unwrap();
        let hits = idx.search(&[1.0, 1.0], 2).unwrap();
        assert_eq!(hits[0].id, "b");
        assert!((hits[0].score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn equal_scores_ordered_by_id() {
        let mut idx = Index::new(Metric::Dot);
        idx.add("b", &[1.0, 0.0]).unwrap();
        idx.add("a", &[1.0, 0.0]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 2).unwrap();
        let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn dim_mismatch_on_add() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.add("b", &[1.0]),
            Err(AnnFlatError::DimMismatch { expected: 2, got: 1 })
        ));
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn dim_mismatch_on_search() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0], 1),
            Err(AnnFlatError::DimMismatch { expected: 2, got: 1 })
        ));
    }

    #[test]
    fn k_zero_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 0),
            Err(AnnFlatError::KZero)
        ));
    }

    #[test]
    fn k_too_large_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 5),
            Err(AnnFlatError::KTooLarge { .. })
        ));
    }

    #[test]
    fn add_batch_works() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0], [0.5, 0.5]]);
        idx.add_batch(
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            &m.view(),
        )
        .unwrap();
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn add_batch_accepts_non_contiguous_rows() {
        let m = arr2(&[[1.0_f32, 2.0], [3.0, 4.0]]);
        // Rows of the transposed view are strided, not contiguous.
        let t = m.t();
        let mut idx = Index::new(Metric::L2);
        idx.add_batch(vec!["x".to_string(), "y".to_string()], &t)
            .unwrap();
        assert_eq!(idx.vectors().unwrap(), arr2(&[[1.0_f32, 3.0], [2.0, 4.0]]));
    }

    #[test]
    fn add_batch_length_mismatch() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let r = idx.add_batch(vec!["a".to_string()], &m.view());
        assert!(matches!(r, Err(AnnFlatError::BatchLengthMismatch { .. })));
        assert!(idx.is_empty());
    }

    #[test]
    fn search_batch_serial_and_parallel_match() {
        let mut idx = Index::new(Metric::Cosine);
        for i in 0..50 {
            idx.add(format!("d{i}"), &[i as f32, 1.0, 2.0]).unwrap();
        }
        let q = arr2(&[[1.0_f32, 1.0, 2.0], [25.0, 1.0, 2.0]]);
        let s = idx.search_batch(&q.view(), 5, false).unwrap();
        let p = idx.search_batch(&q.view(), 5, true).unwrap();
        assert_eq!(s, p);
        assert_eq!(s.len(), 2);
        assert_eq!(s[0].len(), 5);
    }

    #[test]
    fn search_batch_non_contiguous_queries() {
        let mut idx = Index::new(Metric::L2);
        idx.add("p", &[1.0, 3.0]).unwrap();
        idx.add("q", &[2.0, 4.0]).unwrap();
        let m = arr2(&[[1.0_f32, 2.0], [3.0, 4.0]]);
        // Query rows of the transpose are [1, 3] and [2, 4].
        let res = idx.search_batch(&m.t(), 1, false).unwrap();
        assert_eq!(res[0][0].id, "p");
        assert_eq!(res[1][0].id, "q");
    }

    #[test]
    fn metric_get() {
        let idx = Index::new(Metric::L2);
        assert_eq!(idx.metric(), Metric::L2);
    }

    #[test]
    fn empty_index_dim_is_none() {
        let idx = Index::new(Metric::Cosine);
        assert!(idx.dim().is_none());
        assert!(idx.is_empty());
    }

    #[test]
    fn cosine_normalizes_at_insert() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[3.0, 4.0]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 0.6).abs() < 1e-4);
    }

    #[test]
    fn vectors_round_trip() {
        let mut idx = Index::new(Metric::L2);
        idx.add("a", &[1.0, 2.0, 3.0]).unwrap();
        idx.add("b", &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(
            idx.vectors().unwrap(),
            arr2(&[[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]])
        );
    }

    #[test]
    fn vectors_of_empty_index() {
        let idx = Index::new(Metric::Dot);
        assert_eq!(idx.vectors().unwrap().dim(), (0, 0));
    }

    #[test]
    fn cosine_vectors_are_stored_normalized() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[3.0, 4.0]).unwrap();
        let v = idx.vectors().unwrap();
        assert!((v[[0, 0]] - 0.6).abs() < 1e-6);
        assert!((v[[0, 1]] - 0.8).abs() < 1e-6);
    }
}