Skip to main content

annflat_core/
lib.rs

1//! Pure-Rust core for `annflat`. Flat brute-force ANN over `(n, d)` f32
2//! vectors with cosine / L2 / inner-product metrics.
3//!
4//! Cosine is implemented as inner product over L2-normalized vectors, so
5//! adding under the cosine metric runs an in-place normalization once at
6//! insert time. Subsequent search is a single matrix-vector dot.
7//!
8//! `score` is "higher is better": for cosine it's the cosine similarity
9//! in `[-1, 1]`, for inner product it's the dot, for L2 it's `-distance`
10//! (so the same `top_k` heap logic works across metrics).
11
12#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
20use rayon::prelude::*;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24/// Crate-wide result alias.
25pub type Result<T> = std::result::Result<T, AnnFlatError>;
26
27/// All errors surfaced by `annflat-core`.
28#[derive(Error, Debug)]
29pub enum AnnFlatError {
30    /// Caller supplied a vector of the wrong dimensionality.
31    #[error("dim mismatch: expected {expected}, got {got}")]
32    DimMismatch {
33        /// Expected dimensionality (set by the first insert).
34        expected: usize,
35        /// Provided dimensionality.
36        got: usize,
37    },
38    /// Caller asked for more results than exist in the index.
39    #[error("k ({k}) > index size ({n})")]
40    KTooLarge {
41        /// Requested k.
42        k: usize,
43        /// Available entries.
44        n: usize,
45    },
46    /// Caller passed `k = 0`.
47    #[error("k must be > 0")]
48    KZero,
49    /// `add_batch` length mismatch.
50    #[error("add_batch ids and matrix row counts disagree: {ids} vs {rows}")]
51    BatchLengthMismatch {
52        /// Number of ids supplied.
53        ids: usize,
54        /// Number of rows in the matrix.
55        rows: usize,
56    },
57}
58
59/// Distance metric.
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61#[serde(rename_all = "snake_case")]
62pub enum Metric {
63    /// Cosine similarity. Vectors are L2-normalized at insert time.
64    Cosine,
65    /// L2 (Euclidean) distance. Reported `score` is `-distance`.
66    L2,
67    /// Inner product. No normalization.
68    Dot,
69}
70
71/// One match returned by [`Index::search`].
72#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Hit {
74    /// Document id.
75    pub id: String,
76    /// Score; higher is better. See [`Metric`] for semantics.
77    pub score: f32,
78}
79
80/// In-memory flat ANN index.
81pub struct Index {
82    metric: Metric,
83    dim: Option<usize>,
84    ids: Vec<String>,
85    /// `(n, d)` f32 storage. For `Metric::Cosine`, rows are L2-normalized.
86    vectors: Vec<Vec<f32>>,
87}
88
89impl Index {
90    /// Build an empty index.
91    pub fn new(metric: Metric) -> Self {
92        Self {
93            metric,
94            dim: None,
95            ids: Vec::new(),
96            vectors: Vec::new(),
97        }
98    }
99
100    /// Active metric.
101    pub fn metric(&self) -> Metric {
102        self.metric
103    }
104
105    /// Number of indexed vectors.
106    pub fn len(&self) -> usize {
107        self.ids.len()
108    }
109
110    /// True iff no vectors are indexed.
111    pub fn is_empty(&self) -> bool {
112        self.ids.is_empty()
113    }
114
115    /// Dimension of stored vectors, set on first insert.
116    pub fn dim(&self) -> Option<usize> {
117        self.dim
118    }
119
120    /// Insert a single vector.
121    pub fn add(&mut self, id: impl Into<String>, vector: &[f32]) -> Result<()> {
122        match self.dim {
123            None => self.dim = Some(vector.len()),
124            Some(d) if d != vector.len() => {
125                return Err(AnnFlatError::DimMismatch {
126                    expected: d,
127                    got: vector.len(),
128                });
129            }
130            _ => {}
131        }
132        let mut v = vector.to_vec();
133        if self.metric == Metric::Cosine {
134            normalize_in_place(&mut v);
135        }
136        self.ids.push(id.into());
137        self.vectors.push(v);
138        Ok(())
139    }
140
141    /// Insert many vectors. `matrix` is `(n, d)`.
142    pub fn add_batch(&mut self, ids: Vec<String>, matrix: &ArrayView2<'_, f32>) -> Result<()> {
143        if ids.len() != matrix.nrows() {
144            return Err(AnnFlatError::BatchLengthMismatch {
145                ids: ids.len(),
146                rows: matrix.nrows(),
147            });
148        }
149        for (id, row) in ids.into_iter().zip(matrix.axis_iter(Axis(0))) {
150            self.add(id, row.as_slice().unwrap_or(&row.to_vec()))?;
151        }
152        Ok(())
153    }
154
155    /// Top-k search. `query` is a 1-D vector of the same dimension as
156    /// stored vectors.
157    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<Hit>> {
158        if k == 0 {
159            return Err(AnnFlatError::KZero);
160        }
161        if k > self.len() {
162            return Err(AnnFlatError::KTooLarge { k, n: self.len() });
163        }
164        match self.dim {
165            Some(d) if d != query.len() => {
166                return Err(AnnFlatError::DimMismatch {
167                    expected: d,
168                    got: query.len(),
169                });
170            }
171            None => {
172                return Err(AnnFlatError::KTooLarge { k, n: 0 });
173            }
174            _ => {}
175        }
176        let q: Vec<f32> = if self.metric == Metric::Cosine {
177            let mut q2 = query.to_vec();
178            normalize_in_place(&mut q2);
179            q2
180        } else {
181            query.to_vec()
182        };
183
184        // Maintain a min-heap of (score, idx) of size k. Use OrdScore so we
185        // can BinaryHeap on f32.
186        let mut heap: BinaryHeap<(Reverse<OrdScore>, usize)> = BinaryHeap::with_capacity(k);
187        for (i, v) in self.vectors.iter().enumerate() {
188            let s = self.score(&q, v);
189            let entry = (Reverse(OrdScore(s)), i);
190            if heap.len() < k {
191                heap.push(entry);
192            } else if let Some(top) = heap.peek() {
193                if entry.0 < top.0 {
194                    heap.pop();
195                    heap.push(entry);
196                }
197            }
198        }
199        let mut out: Vec<Hit> = heap
200            .into_iter()
201            .map(|(rs, i)| Hit {
202                id: self.ids[i].clone(),
203                score: rs.0 .0,
204            })
205            .collect();
206        out.sort_by(|a, b| {
207            b.score
208                .partial_cmp(&a.score)
209                .unwrap_or(std::cmp::Ordering::Equal)
210                .then(a.id.cmp(&b.id))
211        });
212        Ok(out)
213    }
214
215    /// Batch search. `queries` is `(n_q, d)`. With `parallel = true`, each
216    /// query runs on a rayon thread.
217    pub fn search_batch(
218        &self,
219        queries: &ArrayView2<'_, f32>,
220        k: usize,
221        parallel: bool,
222    ) -> Result<Vec<Vec<Hit>>> {
223        if parallel {
224            queries
225                .axis_iter(Axis(0))
226                .into_par_iter()
227                .map(|row| self.search_view(&row, k))
228                .collect()
229        } else {
230            queries
231                .axis_iter(Axis(0))
232                .map(|row| self.search_view(&row, k))
233                .collect()
234        }
235    }
236
237    fn search_view(&self, row: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<Hit>> {
238        match row.as_slice() {
239            Some(s) => self.search(s, k),
240            None => self.search(&row.to_vec(), k),
241        }
242    }
243
244    /// Snapshot vectors as a `(n, d)` matrix (allocates a copy).
245    pub fn vectors(&self) -> Result<Array2<f32>> {
246        let n = self.len();
247        let d = self.dim.unwrap_or(0);
248        if n == 0 {
249            return Ok(Array2::<f32>::zeros((0, 0)));
250        }
251        let mut out = Array2::<f32>::zeros((n, d));
252        for (i, v) in self.vectors.iter().enumerate() {
253            for (j, &x) in v.iter().enumerate() {
254                out[[i, j]] = x;
255            }
256        }
257        Ok(out)
258    }
259
260    fn score(&self, q: &[f32], v: &[f32]) -> f32 {
261        match self.metric {
262            Metric::Cosine | Metric::Dot => {
263                let mut s = 0.0_f32;
264                for (a, b) in q.iter().zip(v.iter()) {
265                    s += a * b;
266                }
267                s
268            }
269            Metric::L2 => {
270                let mut s = 0.0_f32;
271                for (a, b) in q.iter().zip(v.iter()) {
272                    let d = a - b;
273                    s += d * d;
274                }
275                -s.sqrt()
276            }
277        }
278    }
279}
280
/// Scale `v` to unit L2 length in place.
///
/// The squared norm is accumulated in `f64` so that large components
/// (above roughly `1.8e19`, whose squares overflow `f32`) still normalize
/// correctly instead of collapsing to zero.
///
/// The vector is set to all zeros when no direction can be derived from it:
/// the norm is at most `1e-12` (including the empty and all-zero cases), or
/// the norm is not finite because a component is NaN or infinite.
fn normalize_in_place(v: &mut [f32]) {
    // Same threshold value the f32 comparison used, widened losslessly.
    let min_norm = f64::from(1e-12_f32);

    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    if norm.is_finite() && norm > min_norm {
        for x in v.iter_mut() {
            // Divide in f64: the norm itself may exceed the f32 range.
            *x = (f64::from(*x) / norm) as f32;
        }
    } else {
        for x in v.iter_mut() {
            *x = 0.0;
        }
    }
}
297
/// Wrapper giving `f32` scores a total order so they can be stored in a
/// `BinaryHeap`.
///
/// Real values compare numerically (`-0.0` equals `0.0`). NaN compares equal
/// to NaN and below every other value, so a NaN score is always the first
/// candidate evicted from a top-k heap. `PartialEq` is defined through `cmp`
/// so that `Eq`, `PartialOrd` and `Ord` all agree with each other.
#[derive(Debug, Clone, Copy)]
struct OrdScore(f32);

impl PartialEq for OrdScore {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdScore {}

impl Ord for OrdScore {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN; rank
        // the NaN side lower (false < true on the *other* side's NaN flag).
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| other.0.is_nan().cmp(&self.0.is_nan()))
    }
}

impl PartialOrd for OrdScore {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Cosine index holding the single entry `"a" -> [1, 0]`; shared by the
    /// argument-validation tests below.
    fn cosine_with_unit_x() -> Index {
        let mut index = Index::new(Metric::Cosine);
        index.add("a", &[1.0, 0.0]).unwrap();
        index
    }

    /// Owned ids from string literals, for `add_batch` calls.
    fn owned_ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn empty_search_rejected() {
        let index = Index::new(Metric::Cosine);
        let outcome = index.search(&[1.0, 2.0], 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn cosine_search_finds_self() {
        let mut index = cosine_with_unit_x();
        index.add("b", &[0.0, 1.0]).unwrap();
        index.add("c", &[0.6, 0.8]).unwrap();

        let hits = index.search(&[1.0, 0.0], 3).unwrap();
        let best = &hits[0];
        assert_eq!(best.id, "a");
        assert!((best.score - 1.0).abs() < 1e-4);
    }

    #[test]
    fn l2_search_smaller_distance_first() {
        let mut index = Index::new(Metric::L2);
        index.add("near", &[1.0, 1.0]).unwrap();
        index.add("far", &[10.0, 10.0]).unwrap();

        let hits = index.search(&[1.0, 1.1], 2).unwrap();
        assert_eq!(hits[0].id, "near");
        // Scores are negated distances, so the nearer point scores higher.
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn dot_search() {
        let mut index = Index::new(Metric::Dot);
        index.add("a", &[1.0, 1.0]).unwrap();
        index.add("b", &[2.0, 2.0]).unwrap();

        let hits = index.search(&[1.0, 1.0], 2).unwrap();
        // q . b = 4 beats q . a = 2.
        assert_eq!(hits[0].id, "b");
        assert!((hits[0].score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn dim_mismatch_on_add() {
        let mut index = cosine_with_unit_x();
        assert!(index.add("b", &[1.0]).is_err());
    }

    #[test]
    fn dim_mismatch_on_search() {
        let index = cosine_with_unit_x();
        assert!(index.search(&[1.0], 1).is_err());
    }

    #[test]
    fn k_zero_rejected() {
        let index = cosine_with_unit_x();
        let outcome = index.search(&[1.0, 0.0], 0);
        assert!(matches!(outcome, Err(AnnFlatError::KZero)));
    }

    #[test]
    fn k_too_large_rejected() {
        let index = cosine_with_unit_x();
        let outcome = index.search(&[1.0, 0.0], 5);
        assert!(matches!(outcome, Err(AnnFlatError::KTooLarge { .. })));
    }

    #[test]
    fn add_batch_works() {
        let mut index = Index::new(Metric::Cosine);
        let rows = arr2(&[[1.0_f32, 0.0], [0.0, 1.0], [0.5, 0.5]]);
        index
            .add_batch(owned_ids(&["a", "b", "c"]), &rows.view())
            .unwrap();
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn add_batch_length_mismatch() {
        let mut index = Index::new(Metric::Cosine);
        let rows = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let outcome = index.add_batch(owned_ids(&["a"]), &rows.view());
        assert!(matches!(
            outcome,
            Err(AnnFlatError::BatchLengthMismatch { .. })
        ));
    }

    #[test]
    fn search_batch_serial_and_parallel_match() {
        let mut index = Index::new(Metric::Cosine);
        for i in 0..50 {
            index.add(format!("d{i}"), &[i as f32, 1.0, 2.0]).unwrap();
        }
        let queries = arr2(&[[1.0_f32, 1.0, 2.0], [25.0, 1.0, 2.0]]);

        let serial = index.search_batch(&queries.view(), 5, false).unwrap();
        let parallel = index.search_batch(&queries.view(), 5, true).unwrap();

        assert_eq!(serial, parallel);
        assert_eq!(serial.len(), 2);
        assert_eq!(serial[0].len(), 5);
    }

    #[test]
    fn metric_get() {
        assert_eq!(Index::new(Metric::L2).metric(), Metric::L2);
    }

    #[test]
    fn empty_index_dim_is_none() {
        let index = Index::new(Metric::Cosine);
        assert!(index.dim().is_none());
        assert!(index.is_empty());
    }

    #[test]
    fn cosine_normalizes_at_insert() {
        let mut index = Index::new(Metric::Cosine);
        // [3, 4] is not unit length; it is normalized to [0.6, 0.8] on insert.
        index.add("a", &[3.0, 4.0]).unwrap();
        // Against the query [1, 0] the cosine is therefore 3/5 = 0.6.
        let hits = index.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 0.6).abs() < 1e-4);
    }
}