use std::path::Path;

use anyhow::Result;
use serde::{Deserialize, Serialize};

5pub trait AnnIndex: Send + Sync {
6 fn build(vectors: &[Vec<f32>]) -> Result<Self> where Self: Sized;
7 fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)>;
8 fn add(&mut self, id: u32, vector: &[f32]);
9 fn save(&self, path: &Path) -> Result<()>;
10 fn load(path: &Path) -> Result<Self> where Self: Sized;
11}
// ---- Index factory ----
13pub fn create_index(_backend: Option<&str>) -> Result<Box<dyn AnnIndex>> {
14 Ok(Box::new(SimpleIndex::new()?))
15}
// ---- Brute-force index: data layout ----
17#[derive(Serialize, Deserialize)]
18pub struct SimpleIndex {
19 vectors: Vec<Vec<f32>>,
20 ids: Vec<u32>,
21 dim: usize,
22}
// ---- Brute-force index: constructor and scoring helper ----
24impl SimpleIndex {
25 pub fn new() -> Result<Self> {
26 Ok(Self {
27 vectors: Vec::new(),
28 ids: Vec::new(),
29 dim: 0,
30 })
31 }
32
33 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
34 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
35 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
36 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
37
38 if norm_a == 0.0 || norm_b == 0.0 {
39 0.0
40 } else {
41 dot_product / (norm_a * norm_b)
42 }
43 }
44}
// ---- Brute-force index: AnnIndex implementation ----
46impl AnnIndex for SimpleIndex {
47 fn build(vectors: &[Vec<f32>]) -> Result<Self> where Self: Sized {
48 if vectors.is_empty() {
49 return Ok(Self::new()?);
50 }
51
52 let dim = vectors[0].len();
53 let ids: Vec<u32> = (0..vectors.len() as u32).collect();
54
55 Ok(Self {
56 vectors: vectors.to_vec(),
57 ids,
58 dim,
59 })
60 }
61
62 fn search(&self, query: &[f32], topk: usize) -> Vec<(u32, f32)> {
63 let mut similarities: Vec<_> = self.vectors
64 .iter()
65 .zip(&self.ids)
66 .map(|(vector, &id)| {
67 let similarity = self.cosine_similarity(query, vector);
68 (id, similarity)
69 })
70 .collect();
71
72 similarities.sort_by(|a, b| {
73 b.1
74 .partial_cmp(&a.1)
75 .unwrap_or(std::cmp::Ordering::Equal)
76 });
77 similarities.truncate(topk);
78 similarities
79 }
80
81 fn add(&mut self, id: u32, vector: &[f32]) {
82 if self.dim == 0 {
83 self.dim = vector.len();
84 }
85
86 self.vectors.push(vector.to_vec());
87 self.ids.push(id);
88 }
89
90 fn save(&self, path: &Path) -> Result<()> {
91 let data = bincode::serialize(self)?;
92 std::fs::write(path, data)?;
93 Ok(())
94 }
95
96 fn load(path: &Path) -> Result<Self> where Self: Sized {
97 let data = std::fs::read(path)?;
98 let index: Self = bincode::deserialize(&data)?;
99 Ok(index)
100 }
101}
// ---- Unit tests ----
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A fresh index holds nothing and has no recorded dimension.
    #[test]
    fn test_simple_index_new() {
        let index = SimpleIndex::new().unwrap();
        assert!(index.vectors.is_empty());
        assert!(index.ids.is_empty());
        assert_eq!(index.dim, 0);
    }

    /// Building from no vectors yields an empty index.
    #[test]
    fn test_simple_index_build_empty() {
        let index = SimpleIndex::build(&[]).unwrap();
        assert!(index.vectors.is_empty());
        assert_eq!(index.dim, 0);
    }

    /// Building assigns sequential ids and records the vector length.
    #[test]
    fn test_simple_index_build() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SimpleIndex::build(&vectors).unwrap();
        assert_eq!(index.vectors.len(), 3);
        assert_eq!(index.ids, vec![0, 1, 2]);
        assert_eq!(index.dim, 3);
    }

    /// Identical, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let index = SimpleIndex::new().unwrap();

        let sim = index.cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((sim - 1.0).abs() < 1e-6);

        let sim = index.cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);

        let sim = index.cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    /// A zero-magnitude operand on either side scores exactly 0.
    #[test]
    fn test_cosine_similarity_zero_vectors() {
        let index = SimpleIndex::new().unwrap();

        let sim = index.cosine_similarity(&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]);
        assert_eq!(sim, 0.0);

        let sim = index.cosine_similarity(&[1.0, 2.0, 3.0], &[0.0, 0.0, 0.0]);
        assert_eq!(sim, 0.0);
    }

    /// The closest vector comes first and scores strictly higher than the next.
    #[test]
    fn test_search() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        let results = index.search(&[0.9, 0.1, 0.0], 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 > results[1].1);
    }

    /// Searching an empty index returns no results.
    #[test]
    fn test_search_empty_index() {
        let index = SimpleIndex::build(&[]).unwrap();

        let results = index.search(&[1.0, 0.0], 5);
        assert!(results.is_empty());
    }

    /// `topk` caps the result count and results are in descending score order.
    #[test]
    fn test_search_topk_limit() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.7, 0.3],
            vec![0.6, 0.4],
        ];
        let index = SimpleIndex::build(&vectors).unwrap();

        let results = index.search(&[1.0, 0.0], 3);

        assert_eq!(results.len(), 3);
        assert!(results.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    /// `add` stores the vector under the given id and records the dimension.
    #[test]
    fn test_add() {
        let mut index = SimpleIndex::new().unwrap();

        index.add(100, &[1.0, 2.0, 3.0]);
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.ids, vec![100]);
        assert_eq!(index.dim, 3);

        index.add(200, &[4.0, 5.0, 6.0]);
        assert_eq!(index.vectors.len(), 2);
        assert_eq!(index.ids, vec![100, 200]);
    }

    /// A saved index loads back with the same contents and search results.
    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let index_path = temp_dir.path().join("test_index.bin");

        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let index = SimpleIndex::build(&vectors).unwrap();
        index.save(&index_path).unwrap();

        let loaded_index = SimpleIndex::load(&index_path).unwrap();

        assert_eq!(loaded_index.vectors.len(), index.vectors.len());
        assert_eq!(loaded_index.ids.len(), index.ids.len());
        assert_eq!(loaded_index.dim, index.dim);

        let query = [1.0, 2.0, 3.0];
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (orig, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(orig.0, loaded.0);
            assert!((orig.1 - loaded.1).abs() < 1e-6);
        }
    }

    /// The factory returns a usable index; it is exercised through the trait
    /// object rather than only being constructed.
    #[test]
    fn test_create_index() {
        let mut index = create_index(None).unwrap();
        assert!(index.search(&[1.0, 0.0], 1).is_empty());

        index.add(0, &[1.0, 0.0]);
        index.add(1, &[0.0, 1.0]);

        let results = index.search(&[1.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    /// Loading a missing file is an error. The path lives in a fresh temporary
    /// directory so the outcome does not depend on the working directory.
    #[test]
    fn test_load_nonexistent_file() {
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("nonexistent.bin");

        let result = SimpleIndex::load(&missing);
        assert!(result.is_err());
    }

    /// `search` and `add` work through `Box<dyn AnnIndex>`.
    #[test]
    fn test_ann_index_trait() {
        let vectors = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let mut index: Box<dyn AnnIndex> = Box::new(SimpleIndex::build(&vectors).unwrap());

        let results = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);

        index.add(99, &[0.0, 0.0, 1.0]);

        let results = index.search(&[0.0, 0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 99);
    }
}