1use crate::RetrieveError;
4use rand::seq::SliceRandom;
5use rand::Rng;
6use smallvec::SmallVec;
7use std::collections::HashSet;
8use std::path::Path;
9
10pub struct DiskANNIndex {
17 dimension: usize,
18 params: DiskANNParams,
19 built: bool,
20
21 vectors: Vec<f32>,
23 num_vectors: usize,
24
25 adj: Vec<SmallVec<[u32; 32]>>,
29
30 start_node: u32,
32}
33
34impl DiskANNIndex {
35 #[inline]
37 pub fn dimension(&self) -> usize {
38 self.dimension
39 }
40
41 #[inline]
43 pub fn num_vectors(&self) -> usize {
44 self.num_vectors
45 }
46
47 #[inline]
49 pub fn size_bytes(&self) -> usize {
50 self.vectors.len() * std::mem::size_of::<f32>()
51 + self
52 .adj
53 .iter()
54 .map(|n| n.len() * std::mem::size_of::<u32>())
55 .sum::<usize>()
56 }
57
58 pub fn save(&self, output_dir: &Path) -> Result<(), RetrieveError> {
65 if !self.built {
66 return Err(RetrieveError::Other(
67 "Cannot save unbuilt index".to_string(),
68 ));
69 }
70
71 if !output_dir.exists() {
72 std::fs::create_dir_all(output_dir).map_err(|e| RetrieveError::Io(e.to_string()))?;
73 }
74
75 let vectors_path = output_dir.join("vectors.bin");
77 let mut vectors_file =
78 std::fs::File::create(&vectors_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
79 let vectors_bytes = unsafe {
80 std::slice::from_raw_parts(
81 self.vectors.as_ptr() as *const u8,
82 self.vectors.len() * std::mem::size_of::<f32>(),
83 )
84 };
85 use std::io::Write;
86 vectors_file
87 .write_all(vectors_bytes)
88 .map_err(|e| RetrieveError::Io(e.to_string()))?;
89
90 let graph_path = output_dir.join("graph.index");
92 let mut graph_writer = super::disk_io::DiskGraphWriter::new(
95 &graph_path,
96 self.num_vectors,
97 self.params.m,
98 self.start_node,
99 )
100 .map_err(|e| RetrieveError::Other(format!("Failed to create graph writer: {}", e)))?;
101
102 for neighbors in &self.adj {
103 graph_writer
104 .write_adjacency(neighbors)
105 .map_err(|e| RetrieveError::Other(format!("Failed to write adjacency: {}", e)))?;
106 }
107 graph_writer
108 .flush()
109 .map_err(|e| RetrieveError::Other(format!("Failed to flush graph: {}", e)))?;
110
111 let metadata_path = output_dir.join("metadata.json");
113 let metadata = serde_json::json!({
114 "dimension": self.dimension,
115 "num_vectors": self.num_vectors,
116 "start_node": self.start_node,
117 "params": {
118 "m": self.params.m,
119 "ef_construction": self.params.ef_construction,
120 "alpha": self.params.alpha,
121 "ef_search": self.params.ef_search
122 }
123 });
124 let metadata_file =
125 std::fs::File::create(&metadata_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
126 serde_json::to_writer_pretty(metadata_file, &metadata)
127 .map_err(|e| RetrieveError::Serialization(e.to_string()))?; Ok(())
130 }
131}
132
133pub struct DiskANNSearcher {
137 dimension: usize,
138 num_vectors: usize,
139 start_node: u32,
140 params: DiskANNParams,
141
142 graph_reader: super::disk_io::DiskGraphReader,
144 vectors_file: std::fs::File, }
147
148impl DiskANNSearcher {
149 pub fn load(index_dir: &Path) -> Result<Self, RetrieveError> {
151 let metadata_path = index_dir.join("metadata.json");
153 let metadata_file =
154 std::fs::File::open(&metadata_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
155 let metadata: serde_json::Value = serde_json::from_reader(metadata_file)
156 .map_err(|e| RetrieveError::Serialization(e.to_string()))?;
157
158 let dimension = metadata["dimension"]
159 .as_u64()
160 .ok_or(RetrieveError::FormatError("Missing dimension".to_string()))?
161 as usize;
162 let num_vectors = metadata["num_vectors"]
163 .as_u64()
164 .ok_or(RetrieveError::FormatError(
165 "Missing num_vectors".to_string(),
166 ))? as usize;
167 let start_node = metadata["start_node"]
168 .as_u64()
169 .ok_or(RetrieveError::FormatError("Missing start_node".to_string()))?
170 as u32;
171
172 let params_val = &metadata["params"];
173 let params = DiskANNParams {
174 m: params_val["m"].as_u64().unwrap_or(32) as usize,
175 ef_construction: params_val["ef_construction"].as_u64().unwrap_or(100) as usize,
176 alpha: params_val["alpha"].as_f64().unwrap_or(1.2) as f32,
177 ef_search: params_val["ef_search"].as_u64().unwrap_or(100) as usize,
178 };
179
180 let graph_path = index_dir.join("graph.index");
182 let graph_reader = super::disk_io::DiskGraphReader::open(&graph_path)
183 .map_err(|e| RetrieveError::Other(format!("Failed to open graph: {}", e)))?;
184
185 let vectors_path = index_dir.join("vectors.bin");
187 let vectors_file =
188 std::fs::File::open(&vectors_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
189
190 Ok(Self {
191 dimension,
192 num_vectors,
193 start_node,
194 params,
195 graph_reader,
196 vectors_file,
197 })
198 }
199
200 pub fn search(
202 &mut self,
203 query: &[f32],
204 k: usize,
205 ef_search: usize,
206 ) -> Result<Vec<(u32, f32)>, RetrieveError> {
207 let ef = ef_search.max(k).max(self.params.ef_search);
208
209 let mut visited = HashSet::new();
214 let mut retset: Vec<Candidate> = Vec::with_capacity(ef + 1);
215
216 let start_vec = self.get_vector(self.start_node)?;
218 let start_dist = self.dist(query, &start_vec);
219
220 retset.push(Candidate {
221 id: self.start_node,
222 dist: start_dist,
223 });
224 visited.insert(self.start_node);
225
226 let mut current_idx = 0;
227
228 while current_idx < retset.len() {
229 retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
230
231 if current_idx >= retset.len() {
232 break;
233 }
234
235 let current = retset[current_idx];
236 current_idx += 1;
237
238 let neighbors = self.graph_reader.get_neighbors(current.id)?;
241
242 for neighbor in neighbors {
243 if visited.contains(&neighbor) {
244 continue;
245 }
246 visited.insert(neighbor);
247
248 let neighbor_vec = self.get_vector(neighbor)?;
250 let dist = self.dist(query, &neighbor_vec);
251
252 retset.push(Candidate { id: neighbor, dist });
253 }
254
255 retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
257 if retset.len() > ef {
258 retset.truncate(ef);
259 }
260 }
261
262 Ok(retset.into_iter().take(k).map(|c| (c.id, c.dist)).collect())
263 }
264
265 fn get_vector(&mut self, idx: u32) -> Result<Vec<f32>, RetrieveError> {
266 use std::io::{Read, Seek, SeekFrom};
267 let offset = idx as u64 * self.dimension as u64 * 4;
268 self.vectors_file
269 .seek(SeekFrom::Start(offset))
270 .map_err(|e| RetrieveError::Io(e.to_string()))?;
271
272 let mut buffer = vec![0u8; self.dimension * 4];
273 self.vectors_file
274 .read_exact(&mut buffer)
275 .map_err(|e| RetrieveError::Io(e.to_string()))?;
276
277 let mut vec = Vec::with_capacity(self.dimension);
278 for i in 0..self.dimension {
279 let start = i * 4;
280 let val = f32::from_le_bytes([
281 buffer[start],
282 buffer[start + 1],
283 buffer[start + 2],
284 buffer[start + 3],
285 ]);
286 vec.push(val);
287 }
288 Ok(vec)
289 }
290
291 fn dist(&self, a: &[f32], b: &[f32]) -> f32 {
292 a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
293 }
294}
295
/// Parameters controlling Vamana graph construction and search.
#[derive(Clone, Debug, PartialEq)]
pub struct DiskANNParams {
    /// Maximum out-degree of every node (R in the Vamana paper).
    pub m: usize,

    /// Candidate-list size used by the greedy searches during construction
    /// (L in the Vamana paper).
    pub ef_construction: usize,

    /// Pruning slack for the second construction pass. A candidate is dropped
    /// when `alpha * d(selected, candidate) <= d(node, candidate)`, so larger
    /// values prune less and keep more long-range edges.
    pub alpha: f32,

    /// Default candidate-list size for search. It is stored in the saved
    /// metadata; `DiskANNSearcher::search` uses it as a lower bound.
    pub ef_search: usize,
}
311
312impl Default for DiskANNParams {
313 fn default() -> Self {
314 Self {
315 m: 32,
316 ef_construction: 100,
317 alpha: 1.2,
318 ef_search: 100,
319 }
320 }
321}
322
/// A node id paired with its distance to the query being searched.
///
/// The derived equality compares both `id` and `dist`. Ordering is provided
/// by this type's `Ord` impl.
#[derive(Copy, Clone, PartialEq)]
struct Candidate {
    /// Internal node id (insertion position of the vector).
    id: u32,
    /// Distance from the query to this node.
    dist: f32,
}
329
330impl Eq for Candidate {}
331
332impl Ord for Candidate {
333 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
334 self.dist.total_cmp(&other.dist)
337 }
338}
339
340impl PartialOrd for Candidate {
341 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
342 Some(self.cmp(other))
343 }
344}
345
346impl DiskANNIndex {
347 pub fn new(dimension: usize, params: DiskANNParams) -> Result<Self, RetrieveError> {
349 if dimension == 0 {
350 return Err(RetrieveError::EmptyQuery);
351 }
352
353 Ok(Self {
354 dimension,
355 params,
356 built: false,
357 vectors: Vec::new(),
358 num_vectors: 0,
359 adj: Vec::new(),
360 start_node: 0,
361 })
362 }
363
364 pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
366 self.add_slice(_doc_id, &vector)
367 }
368
369 pub fn add_slice(&mut self, _doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
375 if self.built {
376 return Err(RetrieveError::Other(
377 "Cannot add vectors after index is built".to_string(),
378 ));
379 }
380
381 if vector.len() != self.dimension {
382 return Err(RetrieveError::DimensionMismatch {
383 query_dim: self.dimension,
384 doc_dim: vector.len(),
385 });
386 }
387
388 self.vectors.extend_from_slice(vector);
389 self.num_vectors += 1;
390 self.adj.push(SmallVec::new());
391 Ok(())
392 }
393
394 pub fn build(&mut self) -> Result<(), RetrieveError> {
396 if self.built {
397 return Ok(());
398 }
399
400 if self.num_vectors == 0 {
401 return Err(RetrieveError::EmptyIndex);
402 }
403
404 self.initialize_random_graph();
406
407 self.start_node = self.compute_medoid();
409
410 self.vamana_pass(1.0)?;
413
414 self.vamana_pass(self.params.alpha)?;
417
418 self.built = true;
419 Ok(())
420 }
421
422 fn initialize_random_graph(&mut self) {
424 let mut rng = rand::rng();
425 let r = self.params.m;
426
427 for i in 0..self.num_vectors {
428 let mut neighbors: HashSet<u32> = HashSet::with_capacity(r);
430 while neighbors.len() < r && neighbors.len() < self.num_vectors - 1 {
431 let n = rng.random_range(0..self.num_vectors) as u32;
432 if n != i as u32 {
433 neighbors.insert(n);
434 }
435 }
436 self.adj[i] = neighbors.into_iter().collect();
437 }
438 }
439
440 fn compute_medoid(&self) -> u32 {
442 0
447 }
448
449 fn vamana_pass(&mut self, alpha: f32) -> Result<(), RetrieveError> {
451 let mut nodes: Vec<u32> = (0..self.num_vectors as u32).collect();
453 nodes.shuffle(&mut rand::rng());
454
455 for &i in &nodes {
456 let query_vec = self.get_vector(i);
457
458 let (visited, _) =
461 self.greedy_search(query_vec, self.params.ef_construction, self.start_node);
462
463 let new_neighbors = self.robust_prune(i, &visited, alpha, self.params.m);
466
467 self.adj[i as usize] = new_neighbors.into_iter().collect();
469
470 }
474
475 Ok(())
476 }
477
478 fn robust_prune(
483 &self,
484 node: u32,
485 candidates: &[u32],
486 alpha: f32,
487 max_degree: usize,
488 ) -> Vec<u32> {
489 let node_vec = self.get_vector(node);
490
491 let mut candidates_with_dist: Vec<Candidate> = candidates
493 .iter()
494 .filter(|&&c| c != node) .map(|&c| Candidate {
496 id: c,
497 dist: self.dist(node_vec, self.get_vector(c)),
498 })
499 .collect();
500
501 for &neighbor in &self.adj[node as usize] {
503 if !candidates.contains(&neighbor) {
504 candidates_with_dist.push(Candidate {
505 id: neighbor,
506 dist: self.dist(node_vec, self.get_vector(neighbor)),
507 });
508 }
509 }
510
511 candidates_with_dist.sort_by(|a, b| a.dist.total_cmp(&b.dist));
513
514 let mut new_neighbors: Vec<u32> = Vec::with_capacity(max_degree);
516
517 candidates_with_dist.dedup_by(|a, b| a.id == b.id);
519
520 for cand in candidates_with_dist {
521 if new_neighbors.len() >= max_degree {
522 break;
523 }
524
525 let mut prune = false;
528 let cand_vec = self.get_vector(cand.id);
529
530 for &existing_neighbor in &new_neighbors {
531 let dist_existing_cand = self.dist(self.get_vector(existing_neighbor), cand_vec);
532
533 if alpha * dist_existing_cand <= cand.dist {
536 prune = true;
537 break;
538 }
539 }
540
541 if !prune {
542 new_neighbors.push(cand.id);
543 }
544 }
545
546 new_neighbors
547 }
548
549 fn greedy_search(
553 &self,
554 query: &[f32],
555 l_size: usize,
556 start_node: u32,
557 ) -> (Vec<u32>, Vec<Candidate>) {
558 let mut visited = HashSet::new();
559 let mut retset: Vec<Candidate> = Vec::with_capacity(l_size + 1);
568
569 let start_dist = self.dist(query, self.get_vector(start_node));
570 retset.push(Candidate {
571 id: start_node,
572 dist: start_dist,
573 });
574 visited.insert(start_node);
575
576 let mut current_idx = 0;
577
578 while current_idx < retset.len() {
579 retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
582
583 if current_idx >= retset.len() {
584 break;
585 }
586
587 let current = retset[current_idx];
588 current_idx += 1;
589
590 for &neighbor in &self.adj[current.id as usize] {
594 if visited.contains(&neighbor) {
595 continue;
596 }
597 visited.insert(neighbor);
598
599 let dist = self.dist(query, self.get_vector(neighbor));
600
601 retset.push(Candidate { id: neighbor, dist });
603 }
604
605 retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
607 if retset.len() > l_size {
608 retset.truncate(l_size);
609 }
610 }
611
612 let ids: Vec<u32> = retset.iter().map(|c| c.id).collect();
613 (ids, retset)
614 }
615
616 pub fn search(
618 &self,
619 query: &[f32],
620 k: usize,
621 ef_search: usize,
622 ) -> Result<Vec<(u32, f32)>, RetrieveError> {
623 if !self.built {
624 return Err(RetrieveError::Other(
625 "Index must be built before search".to_string(),
626 ));
627 }
628
629 if query.len() != self.dimension {
630 return Err(RetrieveError::DimensionMismatch {
631 query_dim: self.dimension,
632 doc_dim: query.len(),
633 });
634 }
635
636 let ef = ef_search.max(k);
637 let (_, candidates) = self.greedy_search(query, ef, self.start_node);
638
639 let result = candidates
641 .into_iter()
642 .take(k)
643 .map(|c| (c.id, c.dist))
644 .collect();
645
646 Ok(result)
647 }
648
649 fn get_vector(&self, idx: u32) -> &[f32] {
650 let start = idx as usize * self.dimension;
651 &self.vectors[start..start + self.dimension]
652 }
653
654 fn dist(&self, a: &[f32], b: &[f32]) -> f32 {
656 a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
658 }
659}