1use anyhow::{Result, bail};
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4
5pub trait AnnIndex: Send + Sync {
6 fn build(vectors: &[Vec<f32>]) -> Result<Self>
7 where
8 Self: Sized;
9 fn search(&self, query: &[f32], topk: usize) -> Result<Vec<(u32, f32)>>;
10 fn add(&mut self, id: u32, vector: &[f32]) -> Result<()>;
11 fn save(&self, path: &Path) -> Result<()>;
12 fn load(path: &Path) -> Result<Self>
13 where
14 Self: Sized;
15}
16
17pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
18 Ok(Box::new(SimpleIndex::new()?))
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct SimpleIndex {
23 vectors: Vec<Vec<f32>>,
24 ids: Vec<u32>,
25 dim: usize,
26}
27
28impl SimpleIndex {
29 pub fn new() -> Result<Self> {
30 Ok(Self {
31 vectors: Vec::new(),
32 ids: Vec::new(),
33 dim: 0,
34 })
35 }
36
37 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
38 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
39 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
40 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
41
42 if norm_a == 0.0 || norm_b == 0.0 {
43 0.0
44 } else {
45 dot_product / (norm_a * norm_b)
46 }
47 }
48}
49
50impl AnnIndex for SimpleIndex {
51 fn build(vectors: &[Vec<f32>]) -> Result<Self>
52 where
53 Self: Sized,
54 {
55 if vectors.is_empty() {
56 return Self::new();
57 }
58
59 let dim = vectors[0].len();
60 if dim == 0 {
61 bail!(
62 "Embedding vectors are empty. The embedding model returned 0 values per vector. Re-run the command with a supported embedding model or rebuild the index."
63 );
64 }
65
66 for (i, vector) in vectors.iter().enumerate() {
67 if vector.len() != dim {
68 bail!(
69 "Embedding size mismatch while building index: expected {dim} values but vector #{i} has {}. This usually means different embedding models were mixed. Clean the index (`ck --clean .`) and rebuild with a single model, or rerun your command using the same `--model` you originally indexed with.",
70 vector.len()
71 );
72 }
73 }
74
75 let ids: Vec<u32> = (0..vectors.len() as u32).collect();
76
77 Ok(Self {
78 vectors: vectors.to_vec(),
79 ids,
80 dim,
81 })
82 }
83
84 fn search(&self, query: &[f32], topk: usize) -> Result<Vec<(u32, f32)>> {
85 if self.dim == 0 {
86 bail!(
87 "The ANN index is empty. Reindex the repository before running semantic search (`ck --index`)."
88 );
89 }
90
91 if query.len() != self.dim {
92 bail!(
93 "Embedding size mismatch during search: this index stores vectors with {expected} values, but the query provided {actual}. This happens when different embedding models are mixed. Re-run the command with the original model or clean the index (`ck --clean .`) and rebuild with a single model.",
94 expected = self.dim,
95 actual = query.len()
96 );
97 }
98
99 let mut similarities: Vec<_> = self
100 .vectors
101 .iter()
102 .zip(&self.ids)
103 .map(|(vector, &id)| {
104 let similarity = self.cosine_similarity(query, vector);
105 (id, similarity)
106 })
107 .collect();
108
109 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
110 similarities.truncate(topk);
111 Ok(similarities)
112 }
113
114 fn add(&mut self, id: u32, vector: &[f32]) -> Result<()> {
115 if self.dim == 0 {
116 self.dim = vector.len();
117 }
118
119 if vector.len() != self.dim {
120 bail!(
121 "Embedding size mismatch while updating index: expected {} values but received {}. To switch models, clean the index (`ck --clean .`) and rebuild with the new model. Otherwise rerun your command using the original `--model`.",
122 self.dim,
123 vector.len()
124 );
125 }
126
127 self.vectors.push(vector.to_vec());
128 self.ids.push(id);
129 Ok(())
130 }
131
132 fn save(&self, path: &Path) -> Result<()> {
133 let data = bincode::serialize(self)?;
134 std::fs::write(path, data)?;
135 Ok(())
136 }
137
138 fn load(path: &Path) -> Result<Self>
139 where
140 Self: Sized,
141 {
142 let data = std::fs::read(path)?;
143 let index: Self = bincode::deserialize(&data)?;
144 Ok(index)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert!(index.vectors.is_empty());
        assert!(index.ids.is_empty());
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build_empty() {
        let vectors: Vec<Vec<f32>> = Vec::new();
        let index = SimpleIndex::build(&vectors).unwrap();
        assert!(index.vectors.is_empty());
        assert_eq!(index.dim, 0);
    }

    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids.len(), 3);
        assert_eq!(index.dim, 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
    }

    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        // Identical vectors point the same way.
        let sim = index.cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors.
        let sim = index.cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);

        // Opposite vectors.
        let sim = index.cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();
        let zero = [0.0, 0.0, 0.0];
        let other = [1.0, 2.0, 3.0];

        // A zero-magnitude operand on either side yields exactly 0.0.
        assert_eq!(index.cosine_similarity(&zero, &other), 0.0);
        assert_eq!(index.cosine_similarity(&other, &zero), 0.0);
    }

    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        // The query is closest to vector 0, then to vector 2.
        let results = index.search(&[0.9, 0.1, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 2);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn test_search_empty_index() {
        let vectors: Vec<Vec<f32>> = Vec::new();
        let index = SimpleIndex::build(&vectors).unwrap();

        let err = index.search(&[1.0, 0.0], 5).unwrap_err();
        assert!(err.to_string().contains("The ANN index is empty"));
    }

    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        let results = index.search(&[1.0, 0.0], 3).unwrap();

        assert_eq!(results.len(), 3);
        // Scores must come back in non-increasing order.
        assert!(results.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids, vec![100]);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids, vec![100, 200]);
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        // The reloaded index must answer queries like the original.
        let query = [1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2).unwrap();
        let loaded_results = loaded_index.search(&query, 2).unwrap();

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    #[test]
    fn test_create_index() {
        // Exercise the boxed index returned by the factory itself.
        let mut index = create_index(None).unwrap();

        // A freshly created index holds nothing, so searching must fail.
        assert!(index.search(&[1.0, 0.0], 1).is_err());

        index.add(0, &[1.0, 0.0]).unwrap();
        index.add(1, &[0.0, 1.0]).unwrap();

        let results = index.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_load_nonexistent_file() {
        // A path inside a fresh temp dir is guaranteed not to exist,
        // whatever the current working directory contains.
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("nonexistent.bin");

        assert!(SimpleIndex::load(&missing).is_err());
    }

    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        // Search through the trait object.
        let results = index.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        // Add through the trait object, then find the new entry.
        index.add(99, &[0.0, 0.0, 1.0]).unwrap();
        let results = index.search(&[0.0, 0.0, 1.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }

    #[test]
    fn test_build_rejects_mismatched_dimensions() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0, 0.0]];
        // `match` rather than `unwrap_err`, so the test does not require
        // `SimpleIndex: Debug`.
        let err = match SimpleIndex::build(&vectors) {
            Ok(_) => panic!("Expected build to fail for mismatched dimensions"),
            Err(err) => err,
        };
        assert!(
            err.to_string()
                .contains("Embedding size mismatch while building index")
        );
    }

    #[test]
    fn test_add_rejects_mismatched_dimensions() {
        let mut index = SimpleIndex::new().unwrap();
        index.add(1, &[0.1, 0.2]).unwrap();

        let err = index.add(2, &[0.1, 0.2, 0.3]).unwrap_err();
        assert!(
            err.to_string()
                .contains("Embedding size mismatch while updating index")
        );
        // The rejected vector must not have been stored.
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids, vec![1]);
    }

    #[test]
    fn test_search_rejects_mismatched_query() {
        let vectors = vec![vec![1.0, 0.0, 0.0]];
        let index = SimpleIndex::build(&vectors).unwrap();

        let err = index.search(&[1.0, 0.0], 1).unwrap_err();
        assert!(
            err.to_string()
                .contains("Embedding size mismatch during search")
        );
    }
}