1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4
5pub trait AnnIndex: Send + Sync {
6 fn build(vectors: &[Vec<f32>]) -> Result<Self>
7 where
8 Self: Sized;
9 fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)>;
10 fn add(&mut self, id: u32, vector: &[f32]);
11 fn save(&self, path: &Path) -> Result<()>;
12 fn load(path: &Path) -> Result<Self>
13 where
14 Self: Sized;
15}
16
17pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
18 Ok(Box::new(SimpleIndex::new()?))
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct SimpleIndex {
23 vectors: Vec<Vec<f32>>,
24 ids: Vec<u32>,
25 dim: usize,
26}
27
28impl SimpleIndex {
29 pub fn new() -> Result<Self> {
30 Ok(Self {
31 vectors: Vec::new(),
32 ids: Vec::new(),
33 dim: 0,
34 })
35 }
36
37 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
38 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
39 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
40 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
41
42 if norm_a == 0.0 || norm_b == 0.0 {
43 0.0
44 } else {
45 dot_product / (norm_a * norm_b)
46 }
47 }
48}
49
50impl AnnIndex for SimpleIndex {
51 fn build(vectors: &[Vec<f32>]) -> Result<Self>
52 where
53 Self: Sized,
54 {
55 if vectors.is_empty() {
56 return Self::new();
57 }
58
59 let dim = vectors[0].len();
60 let ids: Vec<u32> = (0..vectors.len() as u32).collect();
61
62 Ok(Self {
63 vectors: vectors.to_vec(),
64 ids,
65 dim,
66 })
67 }
68
69 fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)> {
70 let mut similarities: Vec<_> = self
71 .vectors
72 .iter()
73 .zip(&self.ids)
74 .map(|(vector, &id)| {
75 let similarity = self.cosine_similarity(query, vector);
76 (id, similarity)
77 })
78 .collect();
79
80 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
81 similarities.truncate(topk);
82 similarities
83 }
84
85 fn add(&mut self, id: u32, vector: &[f32]) {
86 if self.dim == 0 {
87 self.dim = vector.len();
88 }
89
90 self.vectors.push(vector.to_vec());
91 self.ids.push(id);
92 }
93
94 fn save(&self, path: &Path) -> Result<()> {
95 let data = bincode::serialize(self)?;
96 std::fs::write(path, data)?;
97 Ok(())
98 }
99
100 fn load(path: &Path) -> Result<Self>
101 where
102 Self: Sized,
103 {
104 let data = std::fs::read(path)?;
105 let index: Self = bincode::deserialize(&data)?;
106 Ok(index)
107 }
108}
109
#[cfg(test)]
mod tests {
    //! Unit tests for `SimpleIndex` and for its use through the `AnnIndex` trait object.

    use super::*;
    use tempfile::TempDir;

    /// A freshly created index holds nothing and has no dimension yet.
    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.ids.len(), 0);
        assert_eq!(index.dim, 0);
    }

    /// Building from an empty batch behaves like `new`.
    #[test]
    fn test_simple_index_build_empty() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 0);
        assert_eq!(index.dim, 0);
    }

    /// Building assigns sequential ids and records the vector dimension.
    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids.len(), 3);
        assert_eq!(index.dim, 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
    }

    /// Identical, orthogonal and opposite vectors score 1, 0 and -1 respectively.
    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        // Identical vectors.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-6);

        // Opposite vectors.
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    /// A zero vector on either side yields a similarity of exactly 0.
    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();

        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 0.0, 0.0];
        let sim = index.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    /// The closest vector comes first and the runner-up second.
    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        // The query points mostly along the first axis.
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        // The diagonal vector is nearer to the query than the second axis.
        assert_eq!(results[1].0, 2);
        assert!(results[0].1 > results[1].1);
    }

    /// Searching an empty index returns no results.
    #[test]
    fn test_search_empty_index() {
        let vectors: Vec<Vec<f32>> = vec![];
        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search(&query, 5);
        assert_eq!(results.len(), 0);
    }

    /// No more than `topk` results come back, in descending score order.
    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();

        let query = vec![1.0, 0.0];
        let results = index.search(&query, 3);

        assert_eq!(results.len(), 3);
        for i in 1..results.len() {
            assert!(results[i - 1].1 >= results[i].1);
        }
    }

    /// `topk` of zero returns nothing; `topk` beyond the index size returns everything.
    #[test]
    fn test_search_topk_bounds() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let index = SimpleIndex::build(&vectors).unwrap();

        assert!(index.search(&[1.0, 0.0], 0).is_empty());
        assert_eq!(index.search(&[1.0, 0.0], 10).len(), 2);
    }

    /// `add` stores the caller's id and sets the dimension on first insert.
    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]);
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids.len(), 1);
        assert_eq!(index.ids[0], 100);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]);
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids.len(), 2);
        assert_eq!(index.ids[1], 200);
    }

    /// A saved index loads back with the same contents and the same search results.
    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        let query = vec![1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    /// `create_index` returns a usable, initially empty index whatever backend is named.
    #[test]
    fn test_create_index() {
        let mut index = create_index(None).unwrap();
        assert!(index.search(&[1.0, 0.0], 1).is_empty());

        // Exercise the returned trait object itself rather than a separate index.
        index.add(0, &[1.0, 0.0]);
        index.add(1, &[0.0, 1.0]);

        let results = index.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        assert!(create_index(Some("any-backend")).is_ok());
    }

    /// Loading a path that does not exist is reported as an error.
    #[test]
    fn test_load_nonexistent_file() {
        // A path inside a fresh temp dir cannot collide with a file in the working directory.
        let temp_dir = TempDir::new().unwrap();
        let missing_path = temp_dir.path().join("nonexistent.bin");

        let result = SimpleIndex::load(&missing_path);
        assert!(result.is_err());
    }

    /// `search` and `add` work through a `Box<dyn AnnIndex>`.
    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        // A vector added afterwards is found under the id it was given.
        index.add(99, &[0.0, 0.0, 1.0]);
        let results = index.search(&[0.0, 0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }
}