ck_ann/
lib.rs

1use anyhow::Result;
2use std::path::Path;
3use serde::{Deserialize, Serialize};
4
/// Common interface for vector indexes that answer similarity queries.
///
/// `build` and `load` are bounded by `Self: Sized`, so they are not callable
/// on `dyn AnnIndex`; `search`, `add` and `save` remain usable through a
/// trait object.
pub trait AnnIndex: Send + Sync {
    /// Constructs an index from a batch of vectors.
    fn build(vectors: &[Vec<f32>]) -> Result<Self> where Self: Sized;
    /// Scores `query` against the index and returns `(id, score)` pairs.
    ///
    /// The `SimpleIndex` implementation in this file returns at most `topk`
    /// pairs, highest score first.
    fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)>;
    /// Inserts `vector` under the caller-supplied `id`.
    fn add(&mut self, id: u32, vector: &[f32]);
    /// Persists the index to the file at `path`.
    fn save(&self, path: &Path) -> Result<()>;
    /// Restores an index from the file at `path`.
    fn load(path: &Path) -> Result<Self> where Self: Sized;
}
12
13pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
14    Ok(Box::new(SimpleIndex::new()?))
15}
16
/// Brute-force index that keeps every vector in memory and scores each one
/// against the query at search time.
#[derive(Serialize, Deserialize)]
pub struct SimpleIndex {
    /// Stored vectors; entry `i` is paired with `ids[i]` during search.
    vectors: Vec<Vec<f32>>,
    /// Identifier reported for each stored vector, parallel to `vectors`.
    ids: Vec<u32>,
    /// Vector length recorded from the first stored vector; 0 for a new,
    /// empty index.
    dim: usize,
}
23
24impl SimpleIndex {
25    pub fn new() -> Result<Self> {
26        Ok(Self {
27            vectors: Vec::new(),
28            ids: Vec::new(),
29            dim: 0,
30        })
31    }
32    
33    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
34        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
35        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
36        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
37        
38        if norm_a == 0.0 || norm_b == 0.0 {
39            0.0
40        } else {
41            dot_product / (norm_a * norm_b)
42        }
43    }
44}
45
46impl AnnIndex for SimpleIndex {
47    fn build(vectors: &[Vec<f32>]) -> Result<Self> where Self: Sized {
48        if vectors.is_empty() {
49            return Ok(Self::new()?);
50        }
51        
52        let dim = vectors[0].len();
53        let ids: Vec<u32> = (0..vectors.len() as u32).collect();
54        
55        Ok(Self {
56            vectors: vectors.to_vec(),
57            ids,
58            dim,
59        })
60    }
61    
62    fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)> {
63        let mut similarities: Vec<_> = self.vectors
64            .iter()
65            .zip(&self.ids)
66            .map(|(vector, &id)| {
67                let similarity = self.cosine_similarity(query, vector);
68                (id, similarity)
69            })
70            .collect();
71        
72        similarities.sort_by(|a, b| {
73            b.1
74                .partial_cmp(&a.1)
75                .unwrap_or(std::cmp::Ordering::Equal)
76        });
77        similarities.truncate(topk);
78        similarities
79    }
80    
81    fn add(&mut self, id: u32, vector: &[f32]) {
82        if self.dim == 0 {
83            self.dim = vector.len();
84        }
85        
86        self.vectors.push(vector.to_vec());
87        self.ids.push(id);
88    }
89    
90    fn save(&self, path: &Path) -> Result<()> {
91        let data = bincode::serialize(self)?;
92        std::fs::write(path, data)?;
93        Ok(())
94    }
95    
96    fn load(path: &Path) -> Result<Self> where Self: Sized {
97        let data = std::fs::read(path)?;
98        let index: Self = bincode::deserialize(&data)?;
99        Ok(index)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A fresh index holds nothing and has no recorded dimension.
    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert!(index.vectors.is_empty());
        assert!(index.ids.is_empty());
        assert_eq!(index.dim, 0);
    }

    /// Building from an empty batch behaves like `new`.
    #[test]
    fn test_simple_index_build_empty() {
        let index = SimpleIndex::build(&[]).unwrap();
        assert!(index.vectors.is_empty());
        assert_eq!(index.dim, 0);
    }

    /// Building assigns ids 0..n and records the dimension of the input.
    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
        assert_eq!(index.dim, 3);
    }

    /// Identical, orthogonal and opposite vectors give 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        let identical = index.cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((identical - 1.0).abs() < 1e-6);

        let orthogonal = index.cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(orthogonal.abs() < 1e-6);

        let opposite = index.cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((opposite + 1.0).abs() < 1e-6);
    }

    /// A zero vector on either side yields a similarity of exactly 0.
    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();

        let zero = [0.0, 0.0, 0.0];
        let other = [1.0, 2.0, 3.0];
        assert_eq!(index.cosine_similarity(&zero, &other), 0.0);
        assert_eq!(index.cosine_similarity(&other, &zero), 0.0);
    }

    /// The closest stored vector is ranked first.
    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0], // id=0
            vec![0.0, 1.0, 0.0], // id=1
            vec![0.5, 0.5, 0.0], // id=2
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        // Query closest to the first vector.
        let results = index.search(&[0.9, 0.1, 0.0], 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 > results[1].1);
    }

    /// Searching an empty index returns no results.
    #[test]
    fn test_search_empty_index() {
        let index = SimpleIndex::build(&[]).unwrap();
        assert!(index.search(&[1.0, 0.0], 5).is_empty());
    }

    /// Results are capped at `topk` and sorted by descending similarity.
    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        let results = index.search(&[1.0, 0.0], 3);

        assert_eq!(results.len(), 3);
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    /// `add` stores the caller's id and the first vector fixes the dimension.
    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]);
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids, vec![100]);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]);
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids, vec![100, 200]);
    }

    /// An index survives a save/load round trip and answers queries the same.
    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        let query = [1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    /// The factory returns an empty index that works through the trait object.
    #[test]
    fn test_create_index() {
        let mut index = create_index(None).unwrap();

        // The concrete type is hidden behind the trait object, so exercise
        // the created index through the interface itself.
        assert!(index.search(&[1.0, 0.0], 1).is_empty());

        index.add(7, &[1.0, 0.0]);
        index.add(8, &[0.0, 1.0]);

        let results = index.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 7);
    }

    /// Loading a path that does not exist is reported as an error.
    #[test]
    fn test_load_nonexistent_file() {
        // Resolve inside a fresh temp dir so the outcome cannot depend on
        // whatever happens to be in the working directory.
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("nonexistent.bin");
        assert!(SimpleIndex::load(&missing).is_err());
    }

    /// `search` and `add` work through `Box<dyn AnnIndex>`.
    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        let results = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        index.add(99, &[0.0, 0.0, 1.0]);
        let results = index.search(&[0.0, 0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }
}