Skip to main content

annflat_core/
lib.rs

//! Pure-Rust core for `annflat`: a flat, brute-force ANN index over `(n, d)`
//! f32 vectors with cosine / L2 / inner-product metrics.
//!
//! Cosine is implemented as inner product over L2-normalized vectors, so
//! adding under the cosine metric normalizes each vector once, at insert
//! time. A subsequent search is then a single matrix-vector dot product.
//!
//! `score` is "higher is better": for cosine it is the cosine similarity
//! in `[-1, 1]`, for inner product it is the dot product, and for L2 it is
//! `-distance` (so the same `top_k` heap logic works across metrics).
11
12#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
20use rayon::prelude::*;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24/// Crate-wide result alias.
25pub type Result<T> = std::result::Result<T, AnnFlatError>;
26
27/// All errors surfaced by `annflat-core`.
28#[derive(Error, Debug)]
29pub enum AnnFlatError {
30    /// I/O failure during save/load.
31    #[error("io error: {0}")]
32    Io(#[from] std::io::Error),
33    /// JSON serialization or deserialization failure.
34    #[error("serde error: {0}")]
35    Serde(#[from] serde_json::Error),
36    /// Caller supplied a vector of the wrong dimensionality.
37    #[error("dim mismatch: expected {expected}, got {got}")]
38    DimMismatch {
39        /// Expected dimensionality (set by the first insert).
40        expected: usize,
41        /// Provided dimensionality.
42        got: usize,
43    },
44    /// Caller asked for more results than exist in the index.
45    #[error("k ({k}) > index size ({n})")]
46    KTooLarge {
47        /// Requested k.
48        k: usize,
49        /// Available entries.
50        n: usize,
51    },
52    /// Caller passed `k = 0`.
53    #[error("k must be > 0")]
54    KZero,
55    /// `add_batch` length mismatch.
56    #[error("add_batch ids and matrix row counts disagree: {ids} vs {rows}")]
57    BatchLengthMismatch {
58        /// Number of ids supplied.
59        ids: usize,
60        /// Number of rows in the matrix.
61        rows: usize,
62    },
63}
64
65/// Distance metric.
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67#[serde(rename_all = "snake_case")]
68pub enum Metric {
69    /// Cosine similarity. Vectors are L2-normalized at insert time.
70    Cosine,
71    /// L2 (Euclidean) distance. Reported `score` is `-distance`.
72    L2,
73    /// Inner product. No normalization.
74    Dot,
75}
76
77/// One match returned by [`Index::search`].
78#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub struct Hit {
80    /// Document id.
81    pub id: String,
82    /// Score; higher is better. See [`Metric`] for semantics.
83    pub score: f32,
84}
85
86/// In-memory flat ANN index.
87#[derive(Serialize, Deserialize)]
88pub struct Index {
89    metric: Metric,
90    dim: Option<usize>,
91    ids: Vec<String>,
92    /// `(n, d)` f32 storage. For `Metric::Cosine`, rows are L2-normalized.
93    vectors: Vec<Vec<f32>>,
94}
95
96impl Index {
97    /// Build an empty index.
98    pub fn new(metric: Metric) -> Self {
99        Self {
100            metric,
101            dim: None,
102            ids: Vec::new(),
103            vectors: Vec::new(),
104        }
105    }
106
107    /// Active metric.
108    pub fn metric(&self) -> Metric {
109        self.metric
110    }
111
112    /// Number of indexed vectors.
113    pub fn len(&self) -> usize {
114        self.ids.len()
115    }
116
117    /// True iff no vectors are indexed.
118    pub fn is_empty(&self) -> bool {
119        self.ids.is_empty()
120    }
121
122    /// Dimension of stored vectors, set on first insert.
123    pub fn dim(&self) -> Option<usize> {
124        self.dim
125    }
126
127    /// Insert a single vector.
128    pub fn add(&mut self, id: impl Into<String>, vector: &[f32]) -> Result<()> {
129        match self.dim {
130            None => self.dim = Some(vector.len()),
131            Some(d) if d != vector.len() => {
132                return Err(AnnFlatError::DimMismatch {
133                    expected: d,
134                    got: vector.len(),
135                });
136            }
137            _ => {}
138        }
139        let mut v = vector.to_vec();
140        if self.metric == Metric::Cosine {
141            normalize_in_place(&mut v);
142        }
143        self.ids.push(id.into());
144        self.vectors.push(v);
145        Ok(())
146    }
147
148    /// Remove the first entry whose id matches. Returns `true` if found.
149    /// O(n) — uses swap-remove so the rest of the index isn't shifted.
150    /// `dim` is preserved even after the index becomes empty.
151    pub fn remove(&mut self, id: &str) -> bool {
152        let Some(pos) = self.ids.iter().position(|s| s == id) else {
153            return false;
154        };
155        self.ids.swap_remove(pos);
156        self.vectors.swap_remove(pos);
157        true
158    }
159
160    /// Persist the entire index (metric, dim, ids, vectors) to a JSON
161    /// file. Re-load with [`Index::load`]. JSON is verbose for f32 arrays
162    /// but cross-platform and debuggable; binary persistence is on the
163    /// v0.2 list.
164    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
165        let file = std::fs::File::create(path)?;
166        let buf = std::io::BufWriter::new(file);
167        serde_json::to_writer(buf, self)?;
168        Ok(())
169    }
170
171    /// Reverse of [`Index::save`]. Loads an index previously saved with
172    /// the same crate version.
173    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
174        let file = std::fs::File::open(path)?;
175        let buf = std::io::BufReader::new(file);
176        let idx: Self = serde_json::from_reader(buf)?;
177        Ok(idx)
178    }
179
180    /// Insert many vectors. `matrix` is `(n, d)`.
181    pub fn add_batch(&mut self, ids: Vec<String>, matrix: &ArrayView2<'_, f32>) -> Result<()> {
182        if ids.len() != matrix.nrows() {
183            return Err(AnnFlatError::BatchLengthMismatch {
184                ids: ids.len(),
185                rows: matrix.nrows(),
186            });
187        }
188        for (id, row) in ids.into_iter().zip(matrix.axis_iter(Axis(0))) {
189            self.add(id, row.as_slice().unwrap_or(&row.to_vec()))?;
190        }
191        Ok(())
192    }
193
194    /// Top-k search. `query` is a 1-D vector of the same dimension as
195    /// stored vectors.
196    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<Hit>> {
197        if k == 0 {
198            return Err(AnnFlatError::KZero);
199        }
200        if k > self.len() {
201            return Err(AnnFlatError::KTooLarge { k, n: self.len() });
202        }
203        match self.dim {
204            Some(d) if d != query.len() => {
205                return Err(AnnFlatError::DimMismatch {
206                    expected: d,
207                    got: query.len(),
208                });
209            }
210            None => {
211                return Err(AnnFlatError::KTooLarge { k, n: 0 });
212            }
213            _ => {}
214        }
215        let q: Vec<f32> = if self.metric == Metric::Cosine {
216            let mut q2 = query.to_vec();
217            normalize_in_place(&mut q2);
218            q2
219        } else {
220            query.to_vec()
221        };
222
223        // Maintain a min-heap of (score, idx) of size k. Use OrdScore so we
224        // can BinaryHeap on f32.
225        let mut heap: BinaryHeap<(Reverse<OrdScore>, usize)> = BinaryHeap::with_capacity(k);
226        for (i, v) in self.vectors.iter().enumerate() {
227            let s = self.score(&q, v);
228            let entry = (Reverse(OrdScore(s)), i);
229            if heap.len() < k {
230                heap.push(entry);
231            } else if let Some(top) = heap.peek() {
232                if entry.0 < top.0 {
233                    heap.pop();
234                    heap.push(entry);
235                }
236            }
237        }
238        let mut out: Vec<Hit> = heap
239            .into_iter()
240            .map(|(rs, i)| Hit {
241                id: self.ids[i].clone(),
242                score: rs.0 .0,
243            })
244            .collect();
245        out.sort_by(|a, b| {
246            b.score
247                .partial_cmp(&a.score)
248                .unwrap_or(std::cmp::Ordering::Equal)
249                .then(a.id.cmp(&b.id))
250        });
251        Ok(out)
252    }
253
254    /// Batch search. `queries` is `(n_q, d)`. With `parallel = true`, each
255    /// query runs on a rayon thread.
256    pub fn search_batch(
257        &self,
258        queries: &ArrayView2<'_, f32>,
259        k: usize,
260        parallel: bool,
261    ) -> Result<Vec<Vec<Hit>>> {
262        if parallel {
263            queries
264                .axis_iter(Axis(0))
265                .into_par_iter()
266                .map(|row| self.search_view(&row, k))
267                .collect()
268        } else {
269            queries
270                .axis_iter(Axis(0))
271                .map(|row| self.search_view(&row, k))
272                .collect()
273        }
274    }
275
276    fn search_view(&self, row: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<Hit>> {
277        match row.as_slice() {
278            Some(s) => self.search(s, k),
279            None => self.search(&row.to_vec(), k),
280        }
281    }
282
283    /// Snapshot vectors as a `(n, d)` matrix (allocates a copy).
284    pub fn vectors(&self) -> Result<Array2<f32>> {
285        let n = self.len();
286        let d = self.dim.unwrap_or(0);
287        if n == 0 {
288            return Ok(Array2::<f32>::zeros((0, 0)));
289        }
290        let mut out = Array2::<f32>::zeros((n, d));
291        for (i, v) in self.vectors.iter().enumerate() {
292            for (j, &x) in v.iter().enumerate() {
293                out[[i, j]] = x;
294            }
295        }
296        Ok(out)
297    }
298
299    fn score(&self, q: &[f32], v: &[f32]) -> f32 {
300        match self.metric {
301            Metric::Cosine | Metric::Dot => {
302                let mut s = 0.0_f32;
303                for (a, b) in q.iter().zip(v.iter()) {
304                    s += a * b;
305                }
306                s
307            }
308            Metric::L2 => {
309                let mut s = 0.0_f32;
310                for (a, b) in q.iter().zip(v.iter()) {
311                    let d = a - b;
312                    s += d * d;
313                }
314                -s.sqrt()
315            }
316        }
317    }
318}
319
/// Scales `v` to unit L2 length in place.
///
/// The norm is accumulated in `f64`, so vectors whose squared components
/// would overflow `f32` are still normalized correctly. If the norm is not
/// larger than `1e-12`, or is not finite (a component is NaN or infinite),
/// no direction can be derived and `v` is set to all zeros instead.
fn normalize_in_place(v: &mut [f32]) {
    let sum_sq: f64 = v.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    let norm = sum_sq.sqrt();
    // A NaN norm fails both tests, so NaN input falls through to the zero case.
    if norm.is_finite() && norm > 1e-12 {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    } else {
        v.fill(0.0);
    }
}
336
/// `f32` score wrapper with a total order, so scores can live in a
/// `BinaryHeap`.
///
/// Real scores compare numerically (`-0.0` and `0.0` are equal). NaN compares
/// equal to NaN and below every other value, which keeps the order total and
/// makes an incomparable score the worst possible one.
#[derive(Debug, Clone, Copy)]
struct OrdScore(f32);

impl PartialEq for OrdScore {
    // Derived from `cmp` so that `Eq`, `PartialOrd` and `Ord` all agree,
    // including for NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdScore {}

impl Ord for OrdScore {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering::{Equal, Greater, Less};
        match (self.0.is_nan(), other.0.is_nan()) {
            // Neither side is NaN, so `partial_cmp` always yields an ordering.
            (false, false) => self.0.partial_cmp(&other.0).unwrap_or(Equal),
            (true, true) => Equal,
            (true, false) => Less,
            (false, true) => Greater,
        }
    }
}

impl PartialOrd for OrdScore {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Uniquely named scratch directory that is deleted again when the guard
    /// is dropped, including when the owning test fails part-way through.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        /// Creates the directory under the system temp dir.
        fn create() -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let dir = std::env::temp_dir().join(format!(
                "annflat-test-{}-{}",
                std::process::id(),
                nanos
            ));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn empty_search_rejected() {
        let idx = Index::new(Metric::Cosine);
        assert!(idx.search(&[1.0, 2.0], 1).is_err());
    }

    #[test]
    fn cosine_search_finds_self() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        idx.add("c", &[0.6, 0.8]).unwrap();
        let hits = idx.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-4);
    }

    #[test]
    fn l2_search_smaller_distance_first() {
        let mut idx = Index::new(Metric::L2);
        idx.add("near", &[1.0, 1.0]).unwrap();
        idx.add("far", &[10.0, 10.0]).unwrap();
        let hits = idx.search(&[1.0, 1.1], 2).unwrap();
        assert_eq!(hits[0].id, "near");
        // Score is -distance, so the nearer entry has the score closer to 0.
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn dot_search() {
        let mut idx = Index::new(Metric::Dot);
        idx.add("a", &[1.0, 1.0]).unwrap();
        idx.add("b", &[2.0, 2.0]).unwrap();
        let hits = idx.search(&[1.0, 1.0], 2).unwrap();
        // b · q = 4, a · q = 2.
        assert_eq!(hits[0].id, "b");
        assert!((hits[0].score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn dim_mismatch_on_add() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(idx.add("b", &[1.0]).is_err());
    }

    #[test]
    fn dim_mismatch_on_search() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(idx.search(&[1.0], 1).is_err());
    }

    #[test]
    fn k_zero_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 0),
            Err(AnnFlatError::KZero)
        ));
    }

    #[test]
    fn k_too_large_rejected() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(matches!(
            idx.search(&[1.0, 0.0], 5),
            Err(AnnFlatError::KTooLarge { .. })
        ));
    }

    #[test]
    fn add_batch_works() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0], [0.5, 0.5]]);
        idx.add_batch(
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            &m.view(),
        )
        .unwrap();
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn add_batch_length_mismatch() {
        let mut idx = Index::new(Metric::Cosine);
        let m = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let r = idx.add_batch(vec!["a".to_string()], &m.view());
        assert!(matches!(r, Err(AnnFlatError::BatchLengthMismatch { .. })));
    }

    #[test]
    fn search_batch_serial_and_parallel_match() {
        let mut idx = Index::new(Metric::Cosine);
        for i in 0..50 {
            idx.add(format!("d{i}"), &[i as f32, 1.0, 2.0]).unwrap();
        }
        let q = arr2(&[[1.0_f32, 1.0, 2.0], [25.0, 1.0, 2.0]]);
        let s = idx.search_batch(&q.view(), 5, false).unwrap();
        let p = idx.search_batch(&q.view(), 5, true).unwrap();
        assert_eq!(s, p);
        assert_eq!(s.len(), 2);
        assert_eq!(s[0].len(), 5);
    }

    #[test]
    fn metric_get() {
        let idx = Index::new(Metric::L2);
        assert_eq!(idx.metric(), Metric::L2);
    }

    #[test]
    fn empty_index_dim_is_none() {
        let idx = Index::new(Metric::Cosine);
        assert!(idx.dim().is_none());
        assert!(idx.is_empty());
    }

    #[test]
    fn cosine_normalizes_at_insert() {
        let mut idx = Index::new(Metric::Cosine);
        // Insert a non-unit vector. Internally it gets normalized.
        idx.add("a", &[3.0, 4.0]).unwrap();
        // Search with [1, 0]; cosine is 3/5 = 0.6.
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 0.6).abs() < 1e-4);
    }

    #[test]
    fn remove_present_returns_true() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        assert!(idx.remove("a"));
        assert_eq!(idx.len(), 1);
        // Search excludes the removed id.
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(hits[0].id, "b");
    }

    #[test]
    fn remove_missing_returns_false() {
        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        assert!(!idx.remove("nonexistent"));
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn save_load_round_trip() {
        // The guard removes the directory when the test ends, pass or fail.
        let scratch = ScratchDir::create();
        let path = scratch.0.join("index.json");

        let mut idx = Index::new(Metric::Cosine);
        idx.add("a", &[1.0, 0.0]).unwrap();
        idx.add("b", &[0.0, 1.0]).unwrap();
        idx.add("c", &[0.6, 0.8]).unwrap();
        idx.save(&path).unwrap();

        let loaded = Index::load(&path).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.metric(), Metric::Cosine);
        let hits = loaded.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn load_nonexistent_path_errors() {
        let r = Index::load("/no/such/path/should/exist.json");
        assert!(r.is_err());
    }
}
539}