ck_ann/
lib.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4
5pub trait AnnIndex: Send + Sync {
6    fn build(vectors: &[Vec<f32>]) -> Result<Self>
7    where
8        Self: Sized;
9    fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)>;
10    fn add(&mut self, id: u32, vector: &[f32]);
11    fn save(&self, path: &Path) -> Result<()>;
12    fn load(path: &Path) -> Result<Self>
13    where
14        Self: Sized;
15}
16
17pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
18    Ok(Box::new(SimpleIndex::new()?))
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct SimpleIndex {
23    vectors: Vec<Vec<f32>>,
24    ids: Vec<u32>,
25    dim: usize,
26}
27
28impl SimpleIndex {
29    pub fn new() -> Result<Self> {
30        Ok(Self {
31            vectors: Vec::new(),
32            ids: Vec::new(),
33            dim: 0,
34        })
35    }
36
37    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
38        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
39        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
40        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
41
42        if norm_a == 0.0 || norm_b == 0.0 {
43            0.0
44        } else {
45            dot_product / (norm_a * norm_b)
46        }
47    }
48}
49
50impl AnnIndex for SimpleIndex {
51    fn build(vectors: &[Vec<f32>]) -> Result<Self>
52    where
53        Self: Sized,
54    {
55        if vectors.is_empty() {
56            return Self::new();
57        }
58
59        let dim = vectors[0].len();
60        let ids: Vec<u32> = (0..vectors.len() as u32).collect();
61
62        Ok(Self {
63            vectors: vectors.to_vec(),
64            ids,
65            dim,
66        })
67    }
68
69    fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)> {
70        let mut similarities: Vec<_> = self
71            .vectors
72            .iter()
73            .zip(&self.ids)
74            .map(|(vector, &id)| {
75                let similarity = self.cosine_similarity(query, vector);
76                (id, similarity)
77            })
78            .collect();
79
80        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
81        similarities.truncate(topk);
82        similarities
83    }
84
85    fn add(&mut self, id: u32, vector: &[f32]) {
86        if self.dim == 0 {
87            self.dim = vector.len();
88        }
89
90        self.vectors.push(vector.to_vec());
91        self.ids.push(id);
92    }
93
94    fn save(&self, path: &Path) -> Result<()> {
95        let data = bincode::serialize(self)?;
96        std::fs::write(path, data)?;
97        Ok(())
98    }
99
100    fn load(path: &Path) -> Result<Self>
101    where
102        Self: Sized,
103    {
104        let data = std::fs::read(path)?;
105        let index: Self = bincode::deserialize(&data)?;
106        Ok(index)
107    }
108}
109
#[cfg(test)]
mod tests {
    // Unit tests for `SimpleIndex`. They exercise it directly and through the
    // `AnnIndex` trait object returned by `create_index`.
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.ids.len(), 0);
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build_empty() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids.len(), 3);
        assert_eq!(index.dim, 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        // Identical vectors should have similarity 1.0
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors should have similarity 0.0
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-6);

        // Opposite vectors should have similarity -1.0
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();

        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 0.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0], // id=0
            vec![0.0, 1.0, 0.0], // id=1
            vec![0.5, 0.5, 0.0], // id=2
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        // Query closest to first vector
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0); // First result should be vector 0
        assert!(results[0].1 > results[1].1); // First should have higher similarity
    }

    #[test]
    fn test_search_empty_index() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search(&query, 5);
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search(&query, 3);

        assert_eq!(results.len(), 3);
        // Results should be sorted by similarity (descending)
        for i in 1..results.len() {
            assert!(results[i - 1].1 >= results[i].1);
        }
    }

    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]);
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids.len(), 1);
        assert_eq!(index.ids[0], 100);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]);
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids.len(), 2);
        assert_eq!(index.ids[1], 200);
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        // Create and save index
        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        // Load index
        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        // Test that loaded index works the same
        let query = vec![1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    #[test]
    fn test_create_index() {
        // `create_index` returns a trait object, so exercise the index it created
        // purely through the `AnnIndex` interface.
        let mut index = create_index(None).unwrap();
        index.add(0, &[1.0, 0.0]);
        index.add(1, &[0.0, 1.0]);

        let results = index.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_load_nonexistent_file() {
        // Point at a fresh temp dir. A stray `nonexistent.bin` in the working
        // directory then cannot affect the outcome.
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("nonexistent.bin");
        assert!(SimpleIndex::load(&missing).is_err());
    }

    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        // Test search through trait
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        // Test add through trait
        index.add(99, &[0.0, 0.0, 1.0]);
        let results = index.search(&[0.0, 0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }
}