Skip to main content

claw_vector/index/
flat.rs

1// index/flat.rs — brute-force flat index for small collections (< 1 000 vectors).
2use std::sync::RwLock;
3
4use rayon::prelude::*;
5use tracing::instrument;
6
7use crate::{
8    config::VectorConfig,
9    error::{VectorError, VectorResult},
10    index::hnsw::HnswIndex,
11    types::DistanceMetric,
12};
13
14// ─── Distance kernels ────────────────────────────────────────────────────────
15
/// Cosine *distance* between `a` and `b`, computed as `1 − cos θ`.
///
/// Despite the function name, the result is a distance (lower = more similar):
/// `0.0` for vectors pointing the same way, `1.0` for orthogonal vectors and
/// `2.0` for opposite vectors. If either input has zero magnitude the result
/// is `1.0`. Callers are expected to pass slices of equal length.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        // The angle is undefined for a zero vector; treat it as "orthogonal".
        return 1.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    1.0 - dot / (mag_a * mag_b)
}
27
/// Euclidean (L2) distance: the square root of the summed squared
/// component differences. Callers are expected to pass slices of equal length.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
36
/// Negated dot product, used as a "distance" so that lower values mean more
/// similar vectors, matching the other kernels in this module.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let similarity: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    -similarity
}
41
42// ─── FlatIndex ───────────────────────────────────────────────────────────────
43
44/// Brute-force flat index backed by a `RwLock<Vec<(id, vector)>>`.
45pub struct FlatIndex {
46    vectors: RwLock<Vec<(usize, Vec<f32>)>>,
47    /// Expected vector dimensionality.
48    pub dimensions: usize,
49    /// Distance metric for similarity comparisons.
50    pub distance: DistanceMetric,
51}
52
53impl FlatIndex {
54    /// Create a new, empty flat index.
55    pub fn new(dimensions: usize, distance: DistanceMetric) -> Self {
56        FlatIndex {
57            vectors: RwLock::new(Vec::new()),
58            dimensions,
59            distance,
60        }
61    }
62
63    /// Insert a single vector, validating its dimensionality.
64    #[instrument(skip(self, vector))]
65    pub fn insert(&self, id: usize, vector: Vec<f32>) -> VectorResult<()> {
66        if vector.len() != self.dimensions {
67            return Err(VectorError::DimensionMismatch {
68                expected: self.dimensions,
69                got: vector.len(),
70            });
71        }
72        self.vectors
73            .write()
74            .map_err(|e| VectorError::Index(e.to_string()))?
75            .push((id, vector));
76        Ok(())
77    }
78
79    /// Insert multiple vectors.
80    #[instrument(skip(self, items))]
81    pub fn insert_batch(&self, items: Vec<(usize, Vec<f32>)>) -> VectorResult<()> {
82        for (_, v) in &items {
83            if v.len() != self.dimensions {
84                return Err(VectorError::DimensionMismatch {
85                    expected: self.dimensions,
86                    got: v.len(),
87                });
88            }
89        }
90        self.vectors
91            .write()
92            .map_err(|e| VectorError::Index(e.to_string()))?
93            .extend(items);
94        Ok(())
95    }
96
97    /// Score all stored vectors in parallel and return the `top_k` closest.
98    #[instrument(skip(self, query))]
99    pub fn search(&self, query: &[f32], top_k: usize) -> VectorResult<Vec<(usize, f32)>> {
100        if query.len() != self.dimensions {
101            return Err(VectorError::DimensionMismatch {
102                expected: self.dimensions,
103                got: query.len(),
104            });
105        }
106        let vecs = self
107            .vectors
108            .read()
109            .map_err(|e| VectorError::Index(e.to_string()))?;
110        let dist = self.distance;
111        let mut scores: Vec<(usize, f32)> = vecs
112            .par_iter()
113            .map(|(id, v)| {
114                let d = match dist {
115                    DistanceMetric::Cosine => cosine_similarity(query, v),
116                    DistanceMetric::Euclidean => euclidean_distance(query, v),
117                    DistanceMetric::DotProduct => dot_product(query, v),
118                };
119                (*id, d)
120            })
121            .collect();
122        scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
123        scores.truncate(top_k);
124        Ok(scores)
125    }
126
127    /// Remove a vector by id. Returns `true` if the id was present.
128    #[instrument(skip(self))]
129    pub fn delete(&self, id: usize) -> VectorResult<bool> {
130        let mut vecs = self
131            .vectors
132            .write()
133            .map_err(|e| VectorError::Index(e.to_string()))?;
134        let before = vecs.len();
135        vecs.retain(|(vid, _)| *vid != id);
136        Ok(vecs.len() < before)
137    }
138
139    /// Return the number of stored vectors.
140    pub fn len(&self) -> usize {
141        self.vectors.read().map(|v| v.len()).unwrap_or(0)
142    }
143
144    /// Return `true` if no vectors are stored.
145    pub fn is_empty(&self) -> bool {
146        self.len() == 0
147    }
148
149    /// Return all stored (id, vector) pairs (used for persistence and migration).
150    pub fn all_vectors(&self) -> VectorResult<Vec<(usize, Vec<f32>)>> {
151        Ok(self
152            .vectors
153            .read()
154            .map_err(|e| VectorError::Index(e.to_string()))?
155            .clone())
156    }
157
158    /// Migrate all vectors into a fresh [`HnswIndex`].
159    #[instrument(skip(self, config))]
160    pub fn to_hnsw(&self, config: &VectorConfig) -> VectorResult<HnswIndex> {
161        let hnsw = HnswIndex::new_with_dimensions(config, self.distance, self.dimensions)?;
162        let items = self.all_vectors()?;
163        hnsw.insert_batch(&items)?;
164        Ok(hnsw)
165    }
166}
167
168// ─── Unit tests ──────────────────────────────────────────────────────────────
169
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 1.0];
        assert_abs_diff_eq!(cosine_similarity(&a, &b), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_identical_vectors() {
        let a = vec![1.0f32, 1.0, 1.0];
        assert_abs_diff_eq!(cosine_similarity(&a, &a), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0f32, 2.0];
        let b = vec![-1.0f32, -2.0];
        assert_abs_diff_eq!(cosine_similarity(&a, &b), 2.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_zero_vector_returns_one() {
        // The zero-magnitude guard must apply to either argument.
        let zero = vec![0.0f32, 0.0];
        let b = vec![1.0f32, 2.0];
        assert_abs_diff_eq!(cosine_similarity(&zero, &b), 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(cosine_similarity(&b, &zero), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn euclidean_known_distance() {
        let a = vec![0.0f32, 0.0, 0.0];
        let b = vec![3.0f32, 4.0, 0.0];
        assert_abs_diff_eq!(euclidean_distance(&a, &b), 5.0, epsilon = 1e-6);
    }

    #[test]
    fn euclidean_same_point() {
        let a = vec![1.0f32, 2.0, 3.0];
        assert_abs_diff_eq!(euclidean_distance(&a, &a), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn dot_product_known() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![4.0f32, 5.0, 6.0];
        assert_abs_diff_eq!(dot_product(&a, &b), -32.0, epsilon = 1e-6);
    }

    #[test]
    fn flat_index_insert_search() {
        let idx = FlatIndex::new(2, DistanceMetric::Euclidean);
        idx.insert(0, vec![0.0, 0.0]).unwrap();
        idx.insert(1, vec![1.0, 1.0]).unwrap();
        idx.insert(2, vec![10.0, 10.0]).unwrap();
        let results = idx.search(&[0.1, 0.1], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert!(results[0].1 <= results[1].1);
    }

    #[test]
    fn flat_index_search_top_k_larger_than_len() {
        let idx = FlatIndex::new(2, DistanceMetric::Euclidean);
        idx.insert(0, vec![0.0, 0.0]).unwrap();
        idx.insert(1, vec![1.0, 1.0]).unwrap();
        let results = idx.search(&[0.0, 0.0], 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn flat_index_dot_product_ranks_larger_dot_first() {
        let idx = FlatIndex::new(2, DistanceMetric::DotProduct);
        idx.insert_batch(vec![
            (0, vec![1.0, 0.0]),
            (1, vec![5.0, 0.0]),
            (2, vec![-1.0, 0.0]),
        ])
        .unwrap();
        let ids: Vec<usize> = idx
            .search(&[1.0, 0.0], 3)
            .unwrap()
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(ids, vec![1, 0, 2]);
    }

    #[test]
    fn flat_index_delete() {
        let idx = FlatIndex::new(2, DistanceMetric::Euclidean);
        idx.insert(42, vec![1.0, 1.0]).unwrap();
        assert_eq!(idx.len(), 1);
        assert!(idx.delete(42).unwrap());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn flat_index_delete_missing_id() {
        let idx = FlatIndex::new(2, DistanceMetric::Euclidean);
        idx.insert(1, vec![1.0, 1.0]).unwrap();
        assert!(!idx.delete(7).unwrap());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn dimension_mismatch_returns_error() {
        let idx = FlatIndex::new(3, DistanceMetric::Euclidean);
        let err = idx.insert(0, vec![1.0, 2.0]).unwrap_err();
        assert!(err.is_dimension_mismatch());
    }

    #[test]
    fn insert_batch_mismatch_inserts_nothing() {
        // Validation happens before the write, so a bad vector anywhere in the
        // batch must leave the index untouched.
        let idx = FlatIndex::new(2, DistanceMetric::Euclidean);
        let err = idx
            .insert_batch(vec![(0, vec![1.0, 2.0]), (1, vec![1.0])])
            .unwrap_err();
        assert!(err.is_dimension_mismatch());
        assert!(idx.is_empty());
    }

    #[test]
    fn search_dimension_mismatch_returns_error() {
        let idx = FlatIndex::new(3, DistanceMetric::Cosine);
        let err = idx.search(&[1.0, 2.0], 1).unwrap_err();
        assert!(err.is_dimension_mismatch());
    }
}