1#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
20use rayon::prelude::*;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24pub type Result<T> = std::result::Result<T, AnnFlatError>;
26
27#[derive(Error, Debug)]
29pub enum AnnFlatError {
30 #[error("io error: {0}")]
32 Io(#[from] std::io::Error),
33 #[error("serde error: {0}")]
35 Serde(#[from] serde_json::Error),
36 #[error("dim mismatch: expected {expected}, got {got}")]
38 DimMismatch {
39 expected: usize,
41 got: usize,
43 },
44 #[error("k ({k}) > index size ({n})")]
46 KTooLarge {
47 k: usize,
49 n: usize,
51 },
52 #[error("k must be > 0")]
54 KZero,
55 #[error("add_batch ids and matrix row counts disagree: {ids} vs {rows}")]
57 BatchLengthMismatch {
58 ids: usize,
60 rows: usize,
62 },
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67#[serde(rename_all = "snake_case")]
68pub enum Metric {
69 Cosine,
71 L2,
73 Dot,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub struct Hit {
80 pub id: String,
82 pub score: f32,
84}
85
86#[derive(Serialize, Deserialize)]
88pub struct Index {
89 metric: Metric,
90 dim: Option<usize>,
91 ids: Vec<String>,
92 vectors: Vec<Vec<f32>>,
94}
95
96impl Index {
97 pub fn new(metric: Metric) -> Self {
99 Self {
100 metric,
101 dim: None,
102 ids: Vec::new(),
103 vectors: Vec::new(),
104 }
105 }
106
107 pub fn metric(&self) -> Metric {
109 self.metric
110 }
111
112 pub fn len(&self) -> usize {
114 self.ids.len()
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.ids.is_empty()
120 }
121
122 pub fn dim(&self) -> Option<usize> {
124 self.dim
125 }
126
127 pub fn add(&mut self, id: impl Into<String>, vector: &[f32]) -> Result<()> {
129 match self.dim {
130 None => self.dim = Some(vector.len()),
131 Some(d) if d != vector.len() => {
132 return Err(AnnFlatError::DimMismatch {
133 expected: d,
134 got: vector.len(),
135 });
136 }
137 _ => {}
138 }
139 let mut v = vector.to_vec();
140 if self.metric == Metric::Cosine {
141 normalize_in_place(&mut v);
142 }
143 self.ids.push(id.into());
144 self.vectors.push(v);
145 Ok(())
146 }
147
148 pub fn remove(&mut self, id: &str) -> bool {
152 let Some(pos) = self.ids.iter().position(|s| s == id) else {
153 return false;
154 };
155 self.ids.swap_remove(pos);
156 self.vectors.swap_remove(pos);
157 true
158 }
159
160 pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
165 let file = std::fs::File::create(path)?;
166 let buf = std::io::BufWriter::new(file);
167 serde_json::to_writer(buf, self)?;
168 Ok(())
169 }
170
171 pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
174 let file = std::fs::File::open(path)?;
175 let buf = std::io::BufReader::new(file);
176 let idx: Self = serde_json::from_reader(buf)?;
177 Ok(idx)
178 }
179
180 pub fn add_batch(&mut self, ids: Vec<String>, matrix: &ArrayView2<'_, f32>) -> Result<()> {
182 if ids.len() != matrix.nrows() {
183 return Err(AnnFlatError::BatchLengthMismatch {
184 ids: ids.len(),
185 rows: matrix.nrows(),
186 });
187 }
188 for (id, row) in ids.into_iter().zip(matrix.axis_iter(Axis(0))) {
189 self.add(id, row.as_slice().unwrap_or(&row.to_vec()))?;
190 }
191 Ok(())
192 }
193
194 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<Hit>> {
197 if k == 0 {
198 return Err(AnnFlatError::KZero);
199 }
200 if k > self.len() {
201 return Err(AnnFlatError::KTooLarge { k, n: self.len() });
202 }
203 match self.dim {
204 Some(d) if d != query.len() => {
205 return Err(AnnFlatError::DimMismatch {
206 expected: d,
207 got: query.len(),
208 });
209 }
210 None => {
211 return Err(AnnFlatError::KTooLarge { k, n: 0 });
212 }
213 _ => {}
214 }
215 let q: Vec<f32> = if self.metric == Metric::Cosine {
216 let mut q2 = query.to_vec();
217 normalize_in_place(&mut q2);
218 q2
219 } else {
220 query.to_vec()
221 };
222
223 let mut heap: BinaryHeap<(Reverse<OrdScore>, usize)> = BinaryHeap::with_capacity(k);
226 for (i, v) in self.vectors.iter().enumerate() {
227 let s = self.score(&q, v);
228 let entry = (Reverse(OrdScore(s)), i);
229 if heap.len() < k {
230 heap.push(entry);
231 } else if let Some(top) = heap.peek() {
232 if entry.0 < top.0 {
233 heap.pop();
234 heap.push(entry);
235 }
236 }
237 }
238 let mut out: Vec<Hit> = heap
239 .into_iter()
240 .map(|(rs, i)| Hit {
241 id: self.ids[i].clone(),
242 score: rs.0 .0,
243 })
244 .collect();
245 out.sort_by(|a, b| {
246 b.score
247 .partial_cmp(&a.score)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 .then(a.id.cmp(&b.id))
250 });
251 Ok(out)
252 }
253
254 pub fn search_batch(
257 &self,
258 queries: &ArrayView2<'_, f32>,
259 k: usize,
260 parallel: bool,
261 ) -> Result<Vec<Vec<Hit>>> {
262 if parallel {
263 queries
264 .axis_iter(Axis(0))
265 .into_par_iter()
266 .map(|row| self.search_view(&row, k))
267 .collect()
268 } else {
269 queries
270 .axis_iter(Axis(0))
271 .map(|row| self.search_view(&row, k))
272 .collect()
273 }
274 }
275
276 fn search_view(&self, row: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<Hit>> {
277 match row.as_slice() {
278 Some(s) => self.search(s, k),
279 None => self.search(&row.to_vec(), k),
280 }
281 }
282
283 pub fn vectors(&self) -> Result<Array2<f32>> {
285 let n = self.len();
286 let d = self.dim.unwrap_or(0);
287 if n == 0 {
288 return Ok(Array2::<f32>::zeros((0, 0)));
289 }
290 let mut out = Array2::<f32>::zeros((n, d));
291 for (i, v) in self.vectors.iter().enumerate() {
292 for (j, &x) in v.iter().enumerate() {
293 out[[i, j]] = x;
294 }
295 }
296 Ok(out)
297 }
298
299 fn score(&self, q: &[f32], v: &[f32]) -> f32 {
300 match self.metric {
301 Metric::Cosine | Metric::Dot => {
302 let mut s = 0.0_f32;
303 for (a, b) in q.iter().zip(v.iter()) {
304 s += a * b;
305 }
306 s
307 }
308 Metric::L2 => {
309 let mut s = 0.0_f32;
310 for (a, b) in q.iter().zip(v.iter()) {
311 let d = a - b;
312 s += d * d;
313 }
314 -s.sqrt()
315 }
316 }
317 }
318}
319
/// Scales `v` in place to unit Euclidean (L2) length.
///
/// The squared norm is accumulated in `f64`. Components whose squares
/// overflow `f32` (around `1.8e19` and above) therefore still normalise
/// correctly instead of collapsing to zero.
///
/// A vector is set to all zeros when its norm is not above `1e-12`, including
/// the all-zero vector, or when the norm is NaN.
fn normalize_in_place(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum::<f64>()
        .sqrt();
    // A NaN norm fails this comparison and falls through to zeroing.
    if norm > 1e-12 {
        for x in v.iter_mut() {
            // Narrowing back to f32 only rounds: the ratio is at most 1 in magnitude.
            *x = (f64::from(*x) / norm) as f32;
        }
    } else {
        v.fill(0.0);
    }
}
336
/// Total-order wrapper around an `f32` score, usable as a `BinaryHeap` key
/// and with `sort`.
///
/// NaN sorts below every other value, including negative infinity, and equals
/// itself, so a NaN score can never displace a real score from a top-k
/// selection. `-0.0` and `0.0` compare equal. `PartialEq` is derived from
/// `cmp` to keep the two consistent, as `Ord` requires.
#[derive(Debug, Clone, Copy)]
struct OrdScore(f32);

impl PartialEq for OrdScore {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdScore {}

impl Ord for OrdScore {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering::{Equal, Greater, Less};
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Equal,
            (true, false) => Less,
            (false, true) => Greater,
            // Non-NaN floats are totally ordered, so partial_cmp is always Some.
            (false, false) => self.0.partial_cmp(&other.0).unwrap_or(Equal),
        }
    }
}

impl PartialOrd for OrdScore {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn empty_search_rejected() {
        let idx = Index::new(Metric::Cosine);
        assert!(idx.search(&[1.0, 2.0], 1).is_err());
    }

    #[test]
    fn cosine_search_finds_self() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        idx.add("c", &[0.6, 0.8]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-4);
    }

    #[test]
    fn l2_search_smaller_distance_first() {
        let mut idx = Index::new(Metric::L2);
        idx.add("near", &[1.0, 1.0]).unwrap();
        idx.add("far", &[10.0, 10.0]).unwrap();
        let hits = idx.search(&[1.0, 1.1], 2).unwrap();
        assert_eq!(hits[0].id, "near");
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn dot_search() {
        let mut idx = Index::new(Metric::Dot);
        idx.add("a", &[1.0, 1.0]).unwrap();
        idx.add("b", &[2.0, 2.0]).unwrap();
        let hits = idx.search(&[1.0, 1.0], 2).unwrap();
        assert_eq!(hits[0].id, "b");
        assert!((hits[0].score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn dim_mismatch_on_add() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(idx.add("b", &[1.0]).is_err());
    }

    #[test]
    fn dim_mismatch_on_search() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(idx.search(&[1.0], 1).is_err());
    }

    #[test]
    fn k_zero_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 0),
            Err(AnnFlatError::KZero)
        ));
    }

    #[test]
    fn k_too_large_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 5),
            Err(AnnFlatError::KTooLarge { .. })
        ));
    }

    #[test]
    fn add_batch_works() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0], [0.5, 0.5]]);
        idx.add_batch(
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            &m.view(),
        )
        .unwrap();
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn add_batch_length_mismatch() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let r = idx.add_batch(vec!["a".to_string()], &m.view());
        assert!(matches!(r, Err(AnnFlatError::BatchLengthMismatch { .. })));
    }

    #[test]
    fn search_batch_serial_and_parallel_match() {
        let mut idx = Index::new(Metric::Cosine);
        for i in 0..50 {
            idx.add(format!("d{i}"), &[i as f32, 1.0, 2.0]).unwrap();
        }
        let q = arr2(&[[1.0_f32, 1.0, 2.0], [25.0, 1.0, 2.0]]);
        let s = idx.search_batch(&q.view(), 5, false).unwrap();
        let p = idx.search_batch(&q.view(), 5, true).unwrap();
        assert_eq!(s, p);
        assert_eq!(s.len(), 2);
        assert_eq!(s[0].len(), 5);
    }

    #[test]
    fn metric_get() {
        let idx = Index::new(Metric::L2);
        assert_eq!(idx.metric(), Metric::L2);
    }

    #[test]
    fn empty_index_dim_is_none() {
        let idx = Index::new(Metric::Cosine);
        assert!(idx.dim().is_none());
        assert!(idx.is_empty());
    }

    #[test]
    fn cosine_normalizes_at_insert() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[3.0, 4.0]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 0.6).abs() < 1e-4);
    }

    #[test]
    fn remove_present_returns_true() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        assert!(idx.remove("a"));
        assert_eq!(idx.len(), 1);
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(hits[0].id, "b");
    }

    #[test]
    fn remove_missing_returns_false() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(!idx.remove("nonexistent"));
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn vectors_returns_stored_rows() {
        let mut idx = Index::new(Metric::L2);
        idx.add("a", &[1.0, 2.0, 3.0]).unwrap();
        idx.add("b", &[4.0, 5.0, 6.0]).unwrap();
        let m = idx.vectors().unwrap();
        assert_eq!(m, arr2(&[[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]]));
    }

    #[test]
    fn vectors_of_empty_index_is_0x0() {
        let idx = Index::new(Metric::Dot);
        assert_eq!(idx.vectors().unwrap().shape(), &[0, 0]);
    }

    #[test]
    fn save_load_round_trip() {
        let dir = std::env::temp_dir().join(format!(
            "annflat-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("index.json");

        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        idx.add("c", &[0.6, 0.8]).unwrap();
        idx.save(&path).unwrap();

        let loaded = Index::load(&path).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.metric(), Metric::Cosine);
        let hits = loaded.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits[0].id, "a");

        // Best-effort cleanup so repeated runs don't accumulate temp directories.
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_nonexistent_path_errors() {
        let r = Index::load("/no/such/path/should/exist.json");
        assert!(r.is_err());
    }
}
539}