Skip to main content

ck_ann/
lib.rs

use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize};
use std::path::Path;
4
5pub trait AnnIndex: Send + Sync {
6    fn build(vectors: &[Vec<f32>]) -> Result<Self>
7    where
8        Self: Sized;
9    fn search(&self, query: &[f32], topk: usize) -> Result<Vec<(u32, f32)>>;
10    fn add(&mut self, id: u32, vector: &[f32]) -> Result<()>;
11    fn save(&self, path: &Path) -> Result<()>;
12    fn load(path: &Path) -> Result<Self>
13    where
14        Self: Sized;
15}
16
17pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
18    Ok(Box::new(SimpleIndex::new()?))
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct SimpleIndex {
23    vectors: Vec<Vec<f32>>,
24    ids: Vec<u32>,
25    dim: usize,
26}
27
28impl SimpleIndex {
29    pub fn new() -> Result<Self> {
30        Ok(Self {
31            vectors: Vec::new(),
32            ids: Vec::new(),
33            dim: 0,
34        })
35    }
36
37    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
38        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
39        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
40        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
41
42        if norm_a == 0.0 || norm_b == 0.0 {
43            0.0
44        } else {
45            dot_product / (norm_a * norm_b)
46        }
47    }
48}
49
50impl AnnIndex for SimpleIndex {
51    fn build(vectors: &[Vec<f32>]) -> Result<Self>
52    where
53        Self: Sized,
54    {
55        if vectors.is_empty() {
56            return Self::new();
57        }
58
59        let dim = vectors[0].len();
60        if dim == 0 {
61            bail!(
62                "Embedding vectors are empty. The embedding model returned 0 values per vector. Re-run the command with a supported embedding model or rebuild the index."
63            );
64        }
65
66        for (i, vector) in vectors.iter().enumerate() {
67            if vector.len() != dim {
68                bail!(
69                    "Embedding size mismatch while building index: expected {dim} values but vector #{i} has {}. This usually means different embedding models were mixed. Clean the index (`ck --clean .`) and rebuild with a single model, or rerun your command using the same `--model` you originally indexed with.",
70                    vector.len()
71                );
72            }
73        }
74
75        let ids: Vec<u32> = (0..vectors.len() as u32).collect();
76
77        Ok(Self {
78            vectors: vectors.to_vec(),
79            ids,
80            dim,
81        })
82    }
83
84    fn search(&self, query: &[f32], topk: usize) -> Result<Vec<(u32, f32)>> {
85        if self.dim == 0 {
86            bail!(
87                "The ANN index is empty. Reindex the repository before running semantic search (`ck --index`)."
88            );
89        }
90
91        if query.len() != self.dim {
92            bail!(
93                "Embedding size mismatch during search: this index stores vectors with {expected} values, but the query provided {actual}. This happens when different embedding models are mixed. Re-run the command with the original model or clean the index (`ck --clean .`) and rebuild with a single model.",
94                expected = self.dim,
95                actual = query.len()
96            );
97        }
98
99        let mut similarities: Vec<_> = self
100            .vectors
101            .iter()
102            .zip(&self.ids)
103            .map(|(vector, &id)| {
104                let similarity = self.cosine_similarity(query, vector);
105                (id, similarity)
106            })
107            .collect();
108
109        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
110        similarities.truncate(topk);
111        Ok(similarities)
112    }
113
114    fn add(&mut self, id: u32, vector: &[f32]) -> Result<()> {
115        if self.dim == 0 {
116            self.dim = vector.len();
117        }
118
119        if vector.len() != self.dim {
120            bail!(
121                "Embedding size mismatch while updating index: expected {} values but received {}. To switch models, clean the index (`ck --clean .`) and rebuild with the new model. Otherwise rerun your command using the original `--model`.",
122                self.dim,
123                vector.len()
124            );
125        }
126
127        self.vectors.push(vector.to_vec());
128        self.ids.push(id);
129        Ok(())
130    }
131
132    fn save(&self, path: &Path) -> Result<()> {
133        let data = bincode::serialize(self)?;
134        std::fs::write(path, data)?;
135        Ok(())
136    }
137
138    fn load(path: &Path) -> Result<Self>
139    where
140        Self: Sized,
141    {
142        let data = std::fs::read(path)?;
143        let index: Self = bincode::deserialize(&data)?;
144        Ok(index)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.ids.len(), 0);
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build_empty() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids.len(), 3);
        assert_eq!(index.dim, 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        // Identical vectors should have similarity 1.0
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors should have similarity 0.0
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-6);

        // Opposite vectors should have similarity -1.0
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();

        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 0.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0], // id=0
            vec![0.0, 1.0, 0.0], // id=1
            vec![0.5, 0.5, 0.0], // id=2
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        // Query closest to first vector
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0); // First result should be vector 0
        assert!(results[0].1 > results[1].1); // First should have higher similarity
    }

    #[test]
    fn test_search_empty_index() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let err = index.search(&query, 5).unwrap_err();
        assert!(err.to_string().contains("The ANN index is empty"));
    }

    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        // Results should be sorted by similarity (descending)
        for i in 1..results.len() {
            assert!(results[i - 1].1 >= results[i].1);
        }
    }

    #[test]
    fn test_search_topk_exceeds_len() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let index = SimpleIndex::build(&vectors).unwrap();

        // Asking for more results than stored entries returns every entry.
        let results = index.search(&[1.0, 0.0], 10).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids.len(), 1);
        assert_eq!(index.ids[0], 100);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids.len(), 2);
        assert_eq!(index.ids[1], 200);
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        // Create and save index
        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        // Load index
        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        // Test that loaded index works the same
        let query = vec![1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2).unwrap();
        let loaded_results = loaded_index.search(&query, 2).unwrap();

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    #[test]
    fn test_create_index() {
        // Exercise the index the factory returns, through the trait object,
        // rather than a separately constructed `SimpleIndex`.
        let mut index = create_index(None).unwrap();
        index.add(0, &[1.0, 0.0]).unwrap();
        index.add(1, &[0.0, 1.0]).unwrap();

        let results = index.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_load_nonexistent_file() {
        // A fresh temp dir guarantees the file is absent; a CWD-relative path
        // could collide with a stray `nonexistent.bin` and make this test fail.
        let temp_dir = TempDir::new().unwrap();
        let result = SimpleIndex::load(&temp_dir.path().join("nonexistent.bin"));
        assert!(result.is_err());
    }

    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        // Test search through trait
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        // Test add through trait
        index.add(99, &[0.0, 0.0, 1.0]).unwrap();
        let results = index.search(&[0.0, 0.0, 1.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }

    #[test]
    fn test_build_rejects_mismatched_dimensions() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0, 0.0]];
        let err = SimpleIndex::build(&vectors)
            .err()
            .expect("Expected build to fail for mismatched dimensions");
        assert!(
            err.to_string()
                .contains("Embedding size mismatch while building index")
        );
    }

    #[test]
    fn test_add_rejects_mismatched_dimensions() {
        let mut index = SimpleIndex::new().unwrap();
        index.add(1, &[0.1, 0.2]).unwrap();
        let err = index.add(2, &[0.1, 0.2, 0.3]).unwrap_err();
        assert!(
            err.to_string()
                .contains("Embedding size mismatch while updating index")
        );
    }

    #[test]
    fn test_search_rejects_mismatched_query() {
        let vectors = vec![vec![1.0, 0.0, 0.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        let err = index.search(&[1.0, 0.0], 1).unwrap_err();
        assert!(
            err.to_string()
                .contains("Embedding size mismatch during search")
        );
    }
}