jin/diskann/graph.rs

//! DiskANN graph structure and Vamana construction.

use crate::RetrieveError;
use rand::seq::SliceRandom;
use rand::Rng;
use smallvec::SmallVec;
use std::collections::HashSet;
use std::path::Path;

/// DiskANN index for disk-based approximate nearest neighbor search.
///
/// Implements the Vamana graph construction algorithm:
/// 1. Random graph initialization
/// 2. Two-pass construction (alpha=1.0, then alpha>1.0)
/// 3. Robust pruning (alpha-pruning) to maintain long-range edges
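///
/// # Example
///
/// A minimal build-query-save sketch. The crate path `jin::diskann` is an
/// assumption (adjust to however this module is re-exported), error handling
/// via `?` is elided, and the block is marked `ignore` because it writes to disk.
///
/// ```ignore
/// use jin::diskann::{DiskANNIndex, DiskANNParams};
/// use std::path::Path;
///
/// let mut index = DiskANNIndex::new(4, DiskANNParams::default())?;
/// index.add(0, vec![0.1, 0.2, 0.3, 0.4])?;
/// index.add(1, vec![0.9, 0.8, 0.7, 0.6])?;
/// index.build()?;
///
/// // Returns (internal id, squared Euclidean distance) pairs, closest first.
/// let hits = index.search(&[0.1, 0.2, 0.3, 0.4], 1, 50)?;
///
/// index.save(Path::new("/tmp/diskann_index"))?;
/// ```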
pub struct DiskANNIndex {
    dimension: usize,
    params: DiskANNParams,
    built: bool,

    // Vectors stored in memory for build (would be on disk in prod)
    vectors: Vec<f32>,
    num_vectors: usize,

    // Graph structure (adjacency list)
    // Using SmallVec to optimize for typical degree M=16-32
    // Stored in memory for construction, serialized to disk later
    adj: Vec<SmallVec<[u32; 32]>>,

    // Entry point for search (medoid)
    start_node: u32,
}

impl DiskANNIndex {
    /// Vector dimensionality.
    #[inline]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Number of vectors currently stored in the index.
    #[inline]
    pub fn num_vectors(&self) -> usize {
        self.num_vectors
    }

    /// Approximate memory usage in bytes (vectors + adjacency lists).
    #[inline]
    pub fn size_bytes(&self) -> usize {
        self.vectors.len() * std::mem::size_of::<f32>()
            + self
                .adj
                .iter()
                .map(|n| n.len() * std::mem::size_of::<u32>())
                .sum::<usize>()
    }

    /// Save the built index to disk.
    ///
    /// Saves:
    /// - Graph structure (adjacency list) using DiskGraphWriter
    /// - Vectors (flat binary format)
    /// - Metadata (JSON)
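    ///
    /// The directory written by this method ends up looking like:
    ///
    /// ```text
    /// <output_dir>/
    ///   vectors.bin     num_vectors * dimension f32 values (4 bytes each)
    ///   graph.index     adjacency lists, written via DiskGraphWriter
    ///   metadata.json   dimension, num_vectors, start_node, params
    /// ```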
    pub fn save(&self, output_dir: &Path) -> Result<(), RetrieveError> {
        if !self.built {
            return Err(RetrieveError::Other(
                "Cannot save unbuilt index".to_string(),
            ));
        }

        if !output_dir.exists() {
            std::fs::create_dir_all(output_dir).map_err(|e| RetrieveError::Io(e.to_string()))?;
        }

        // 1. Save Vectors (vectors.bin)
        let vectors_path = output_dir.join("vectors.bin");
        let mut vectors_file =
            std::fs::File::create(&vectors_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
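        // Note: this dumps the raw in-memory bytes (native-endian f32), while
        // `DiskANNSearcher::get_vector` decodes little-endian; the two only
        // agree on little-endian targets.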
        let vectors_bytes = unsafe {
            std::slice::from_raw_parts(
                self.vectors.as_ptr() as *const u8,
                self.vectors.len() * std::mem::size_of::<f32>(),
            )
        };
        use std::io::Write;
        vectors_file
            .write_all(vectors_bytes)
            .map_err(|e| RetrieveError::Io(e.to_string()))?;

        // 2. Save Graph (graph.index)
        let graph_path = output_dir.join("graph.index");
        // DiskGraphWriter reports its own error type; map it into RetrieveError::Other.
        let mut graph_writer = super::disk_io::DiskGraphWriter::new(
            &graph_path,
            self.num_vectors,
            self.params.m,
            self.start_node,
        )
        .map_err(|e| RetrieveError::Other(format!("Failed to create graph writer: {}", e)))?;

        for neighbors in &self.adj {
            graph_writer
                .write_adjacency(neighbors)
                .map_err(|e| RetrieveError::Other(format!("Failed to write adjacency: {}", e)))?;
        }
        graph_writer
            .flush()
            .map_err(|e| RetrieveError::Other(format!("Failed to flush graph: {}", e)))?;

        // 3. Save Metadata (metadata.json)
        let metadata_path = output_dir.join("metadata.json");
        let metadata = serde_json::json!({
            "dimension": self.dimension,
            "num_vectors": self.num_vectors,
            "start_node": self.start_node,
            "params": {
                "m": self.params.m,
                "ef_construction": self.params.ef_construction,
                "alpha": self.params.alpha,
                "ef_search": self.params.ef_search
            }
        });
        let metadata_file =
            std::fs::File::create(&metadata_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
        serde_json::to_writer_pretty(metadata_file, &metadata)
            .map_err(|e| RetrieveError::Serialization(e.to_string()))?;

        Ok(())
    }
}

/// Disk-based searcher for DiskANN.
///
/// Operates on the persisted index without loading the full graph into RAM.
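///
/// # Example
///
/// A load-and-query sketch, assuming an index previously written by
/// `DiskANNIndex::save` and a crate path of `jin::diskann` (assumed); marked
/// `ignore` because it performs file I/O and elides error handling.
///
/// ```ignore
/// use jin::diskann::DiskANNSearcher;
/// use std::path::Path;
///
/// let mut searcher = DiskANNSearcher::load(Path::new("/tmp/diskann_index"))?;
/// let query = vec![0.1f32; 4];
/// // Returns (internal id, squared Euclidean distance) pairs, closest first.
/// let hits = searcher.search(&query, 10, 100)?;
/// for (id, dist) in hits {
///     println!("{id}: {dist}");
/// }
/// ```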
pub struct DiskANNSearcher {
    dimension: usize,
    num_vectors: usize,
    start_node: u32,
    params: DiskANNParams,

    // Components
    graph_reader: super::disk_io::DiskGraphReader,
    // Simple file I/O for the vectors for now; upgradable to mmap.
    vectors_file: std::fs::File,
}

impl DiskANNSearcher {
    /// Load searcher from index directory.
    pub fn load(index_dir: &Path) -> Result<Self, RetrieveError> {
        // 1. Load Metadata
        let metadata_path = index_dir.join("metadata.json");
        let metadata_file =
            std::fs::File::open(&metadata_path).map_err(|e| RetrieveError::Io(e.to_string()))?;
        let metadata: serde_json::Value = serde_json::from_reader(metadata_file)
            .map_err(|e| RetrieveError::Serialization(e.to_string()))?;

        let dimension = metadata["dimension"]
            .as_u64()
            .ok_or(RetrieveError::FormatError("Missing dimension".to_string()))?
            as usize;
        let num_vectors = metadata["num_vectors"]
            .as_u64()
            .ok_or(RetrieveError::FormatError(
                "Missing num_vectors".to_string(),
            ))? as usize;
        let start_node = metadata["start_node"]
            .as_u64()
            .ok_or(RetrieveError::FormatError("Missing start_node".to_string()))?
            as u32;

        let params_val = &metadata["params"];
        let params = DiskANNParams {
            m: params_val["m"].as_u64().unwrap_or(32) as usize,
            ef_construction: params_val["ef_construction"].as_u64().unwrap_or(100) as usize,
            alpha: params_val["alpha"].as_f64().unwrap_or(1.2) as f32,
            ef_search: params_val["ef_search"].as_u64().unwrap_or(100) as usize,
        };

        // 2. Open Graph
        let graph_path = index_dir.join("graph.index");
        let graph_reader = super::disk_io::DiskGraphReader::open(&graph_path)
            .map_err(|e| RetrieveError::Other(format!("Failed to open graph: {}", e)))?;

        // 3. Open Vectors
        let vectors_path = index_dir.join("vectors.bin");
        let vectors_file =
            std::fs::File::open(&vectors_path).map_err(|e| RetrieveError::Io(e.to_string()))?;

        Ok(Self {
            dimension,
            num_vectors,
            start_node,
            params,
            graph_reader,
            vectors_file,
        })
    }

    /// Search for k nearest neighbors using disk-based graph.
    pub fn search(
        &mut self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        let ef = ef_search.max(k).max(self.params.ef_search);

        // Use greedy search similar to in-memory, but fetching neighbors from disk
        // Note: Performance will be limited by random I/O here without caching/prefetching
        // This is a functional baseline.

        let mut visited = HashSet::new();
        let mut retset: Vec<Candidate> = Vec::with_capacity(ef + 1);

        // Fetch start node vector
        let start_vec = self.get_vector(self.start_node)?;
        let start_dist = self.dist(query, &start_vec);

        retset.push(Candidate {
            id: self.start_node,
            dist: start_dist,
        });
        visited.insert(self.start_node);

        let mut current_idx = 0;

        while current_idx < retset.len() {
            retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));

            if current_idx >= retset.len() {
                break;
            }

            let current = retset[current_idx];
            current_idx += 1;

            // Fetch neighbors from disk.
            // TODO: cache hot nodes (frequently visited nodes near the entry point)
            // in RAM; Vamana is a flat graph, so there are no "levels" to pin.
            let neighbors = self
                .graph_reader
                .get_neighbors(current.id)
                .map_err(|e| RetrieveError::Other(format!("Failed to read neighbors: {}", e)))?;

            for neighbor in neighbors {
                if visited.contains(&neighbor) {
                    continue;
                }
                visited.insert(neighbor);

                // Fetch neighbor vector from disk
                let neighbor_vec = self.get_vector(neighbor)?;
                let dist = self.dist(query, &neighbor_vec);

                retset.push(Candidate { id: neighbor, dist });
            }

            // Keep top L
            retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
            if retset.len() > ef {
                retset.truncate(ef);
            }
        }

        Ok(retset.into_iter().take(k).map(|c| (c.id, c.dist)).collect())
    }

    fn get_vector(&mut self, idx: u32) -> Result<Vec<f32>, RetrieveError> {
        use std::io::{Read, Seek, SeekFrom};
        let offset = idx as u64 * self.dimension as u64 * 4;
        self.vectors_file
            .seek(SeekFrom::Start(offset))
            .map_err(|e| RetrieveError::Io(e.to_string()))?;

        let mut buffer = vec![0u8; self.dimension * 4];
        self.vectors_file
            .read_exact(&mut buffer)
            .map_err(|e| RetrieveError::Io(e.to_string()))?;

        let mut vec = Vec::with_capacity(self.dimension);
        for i in 0..self.dimension {
            let start = i * 4;
            let val = f32::from_le_bytes([
                buffer[start],
                buffer[start + 1],
                buffer[start + 2],
                buffer[start + 3],
            ]);
            vec.push(val);
        }
        Ok(vec)
    }

    fn dist(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
    }
}

/// DiskANN parameters.
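///
/// A construction sketch overriding a couple of defaults via struct-update
/// syntax; the values are illustrative, not tuned recommendations. Marked
/// `ignore` because the crate path `jin::diskann` is an assumption.
///
/// ```ignore
/// use jin::diskann::DiskANNParams;
///
/// let params = DiskANNParams {
///     m: 48,
///     ef_construction: 200,
///     ..Default::default()
/// };
/// ```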
#[derive(Clone, Debug)]
pub struct DiskANNParams {
    /// Maximum connections per node (R in paper)
    pub m: usize,

    /// Beam width for construction search (L in paper)
    pub ef_construction: usize,

    /// Alpha parameter for pruning (typically 1.2 - 1.4)
    pub alpha: f32,

    /// Search width
    pub ef_search: usize,
}

impl Default for DiskANNParams {
    fn default() -> Self {
        Self {
            m: 32,
            ef_construction: 100,
            alpha: 1.2,
            ef_search: 100,
        }
    }
}

/// Candidate for priority queues
#[derive(Clone, Copy, PartialEq)]
struct Candidate {
    id: u32,
    dist: f32,
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Max-heap ordering: larger distance = higher priority (for results pruning)
        // total_cmp gives the IEEE 754 total order (NaN-safe; positive NaN sorts last)
        self.dist.total_cmp(&other.dist)
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl DiskANNIndex {
    /// Create a new DiskANN index.
    pub fn new(dimension: usize, params: DiskANNParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::EmptyQuery);
        }

        Ok(Self {
            dimension,
            params,
            built: false,
            vectors: Vec::new(),
            num_vectors: 0,
            adj: Vec::new(),
            start_node: 0,
        })
    }

    /// Add a vector to the index.
    pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        self.add_slice(_doc_id, &vector)
    }

    /// Add a vector to the index from a borrowed slice.
    ///
    /// Notes:
    /// - The index stores vectors internally, so it must copy the slice into its own storage.
    /// - DiskANN currently ignores `doc_id` and uses insertion order as the internal ID.
    pub fn add_slice(&mut self, _doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::Other(
                "Cannot add vectors after index is built".to_string(),
            ));
        }

        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: vector.len(),
            });
        }

        self.vectors.extend_from_slice(vector);
        self.num_vectors += 1;
        self.adj.push(SmallVec::new());
        Ok(())
    }

    /// Build the index using Vamana construction.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }

        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }

        // 1. Initialize random graph (R-regular)
        self.initialize_random_graph();

        // 2. Compute medoid as start node
        self.start_node = self.compute_medoid();

        // 3. First pass: alpha = 1.0 (approximates the relative neighborhood graph, RNG)
        // Helps build initial connectivity
        self.vamana_pass(1.0)?;

        // 4. Second pass: alpha = params.alpha (e.g. 1.2)
        // Adds long-range edges for small-world navigation
        self.vamana_pass(self.params.alpha)?;

        self.built = true;
        Ok(())
    }

    /// Initialize random R-regular graph.
    fn initialize_random_graph(&mut self) {
        let mut rng = rand::rng();
        let r = self.params.m;

        for i in 0..self.num_vectors {
            // Pick R random neighbors
            let mut neighbors: HashSet<u32> = HashSet::with_capacity(r);
            while neighbors.len() < r && neighbors.len() < self.num_vectors - 1 {
                let n = rng.random_range(0..self.num_vectors) as u32;
                if n != i as u32 {
                    neighbors.insert(n);
                }
            }
            self.adj[i] = neighbors.into_iter().collect();
        }
    }

    /// Compute the start node (medoid) of the dataset.
    fn compute_medoid(&self) -> u32 {
        // A true medoid would minimize total distance to all other points. For this
        // prototype we simply return node 0, which is a common, valid simplification;
        // see the commented sketch below for a fuller centroid-based version.
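        //
        // Sketch (assumption: an O(n * d) centroid pass is acceptable at build time;
        // not enabled here):
        //
        //     let mut centroid = vec![0.0f32; self.dimension];
        //     for i in 0..self.num_vectors {
        //         for (c, v) in centroid.iter_mut().zip(self.get_vector(i as u32)) {
        //             *c += *v / self.num_vectors as f32;
        //         }
        //     }
        //     (0..self.num_vectors as u32)
        //         .min_by(|&a, &b| {
        //             self.dist(self.get_vector(a), &centroid)
        //                 .total_cmp(&self.dist(self.get_vector(b), &centroid))
        //         })
        //         .unwrap_or(0)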
        0
    }

    /// Single pass of Vamana construction.
    fn vamana_pass(&mut self, alpha: f32) -> Result<(), RetrieveError> {
        // Random permutation of nodes
        let mut nodes: Vec<u32> = (0..self.num_vectors as u32).collect();
        nodes.shuffle(&mut rand::rng());

        for &i in &nodes {
            let query_vec = self.get_vector(i);

            // Greedy search to find candidates
            // We use the graph as it exists so far
            let (visited, _) =
                self.greedy_search(query_vec, self.params.ef_construction, self.start_node);

            // Candidate set V = visited nodes
            // Run RobustPrune on V to find new neighbors for i
            let new_neighbors = self.robust_prune(i, &visited, alpha, self.params.m);

            // Update graph: add directed edges
            self.adj[i as usize] = new_neighbors.into_iter().collect();

            // Note: the Vamana paper also inserts reverse edges (adding `i` to each new
            // neighbor's list and re-pruning that list when it exceeds the degree cap).
            // This pass skips the reverse-edge step for simplicity, trading some recall;
            // a production build would enforce the max degree on those reverse updates.
        }

        Ok(())
    }

    /// RobustPrune (Alpha-Pruning) algorithm.
    ///
    /// Selects neighbors that are close to `node`, but also "orthogonal" to each other
    /// to ensure good coverage of the space.
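    ///
    /// Concretely, a candidate `p'` (scanned in order of increasing distance from
    /// `node`) is dropped when some already-selected neighbor `p*` satisfies
    /// `alpha * d(p*, p') <= d(node, p')`. For example, with `alpha = 1.2` and
    /// `d(node, p') = 12.0`, the candidate is pruned as soon as a selected neighbor
    /// lies within `12.0 / 1.2 = 10.0` of it. Note that `d` here is the squared
    /// Euclidean distance used by `Self::dist`, so `alpha` effectively scales
    /// squared distances.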
    fn robust_prune(
        &self,
        node: u32,
        candidates: &[u32],
        alpha: f32,
        max_degree: usize,
    ) -> Vec<u32> {
        let node_vec = self.get_vector(node);

        // 1. Calculate distances to all candidates
        let mut candidates_with_dist: Vec<Candidate> = candidates
            .iter()
            .filter(|&&c| c != node) // distinct
            .map(|&c| Candidate {
                id: c,
                dist: self.dist(node_vec, self.get_vector(c)),
            })
            .collect();

        // Add current neighbors to candidate set (to refine them)
        for &neighbor in &self.adj[node as usize] {
            if !candidates.contains(&neighbor) {
                candidates_with_dist.push(Candidate {
                    id: neighbor,
                    dist: self.dist(node_vec, self.get_vector(neighbor)),
                });
            }
        }

        // 2. Sort by distance (ascending)
        candidates_with_dist.sort_by(|a, b| a.dist.total_cmp(&b.dist));

        // 3. Prune
        let mut new_neighbors: Vec<u32> = Vec::with_capacity(max_degree);

        // Remove duplicates if any
        candidates_with_dist.dedup_by(|a, b| a.id == b.id);

        for cand in candidates_with_dist {
            if new_neighbors.len() >= max_degree {
                break;
            }

            // Check whether cand is already "covered" by a selected neighbor.
            // The alpha rule (as in the DiskANN paper): prune cand (= p') if
            //     alpha * distance(p*, p') <= distance(p, p')
            // for some already-selected neighbor p*, where p is `node`.
            let mut prune = false;
            let cand_vec = self.get_vector(cand.id);

            for &existing_neighbor in &new_neighbors {
                let dist_existing_cand = self.dist(self.get_vector(existing_neighbor), cand_vec);

                // If existing neighbor is closer to candidate than node is (scaled by alpha),
                // then candidate is redundant (we can reach it via existing neighbor).
                if alpha * dist_existing_cand <= cand.dist {
                    prune = true;
                    break;
                }
            }

            if !prune {
                new_neighbors.push(cand.id);
            }
        }

        new_neighbors
    }

    /// Greedy search used for both construction and querying.
    ///
    /// Returns `(ids, candidates)`: the ids of the retained L-list and the same
    /// candidates with distances. Note that the Vamana paper feeds the full *visited*
    /// set into RobustPrune; this implementation passes the retained L-list instead,
    /// as a simplification.
    fn greedy_search(
        &self,
        query: &[f32],
        l_size: usize,
        start_node: u32,
    ) -> (Vec<u32>, Vec<Candidate>) {
        let mut visited = HashSet::new();
        // Note: we keep the L closest candidates in a sorted Vec (`retset`) rather than a
        // BinaryHeap; repeatedly re-sorting is simple and adequate for this prototype, and
        // it gives direct control over truncating to the L best.

        // Results set: the L closest candidates found so far.
        let mut retset: Vec<Candidate> = Vec::with_capacity(l_size + 1);

        let start_dist = self.dist(query, self.get_vector(start_node));
        retset.push(Candidate {
            id: start_node,
            dist: start_dist,
        });
        visited.insert(start_node);

        let mut current_idx = 0;

        while current_idx < retset.len() {
            // Find the closest unvisited node in retset
            // (In optimized impl, we iterate sorted retset)
            retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));

            if current_idx >= retset.len() {
                break;
            }

            let current = retset[current_idx];
            current_idx += 1;

            // A possible early exit: stop once the closest unvisited candidate is farther
            // than the worst retained one and the list is full. Vamana does not require
            // this, so we simply explore all of the current node's neighbors.

            for &neighbor in &self.adj[current.id as usize] {
                if visited.contains(&neighbor) {
                    continue;
                }
                visited.insert(neighbor);

                let dist = self.dist(query, self.get_vector(neighbor));

                // Add to retset
                retset.push(Candidate { id: neighbor, dist });
            }

            // Keep only top L
            retset.sort_by(|a, b| a.dist.total_cmp(&b.dist));
            if retset.len() > l_size {
                retset.truncate(l_size);
            }
        }

        let ids: Vec<u32> = retset.iter().map(|c| c.id).collect();
        (ids, retset)
    }

    /// Search for k nearest neighbors.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::Other(
                "Index must be built before search".to_string(),
            ));
        }

        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: query.len(),
            });
        }

        let ef = ef_search.max(k);
        let (_, candidates) = self.greedy_search(query, ef, self.start_node);

        // Return top k
        let result = candidates
            .into_iter()
            .take(k)
            .map(|c| (c.id, c.dist))
            .collect();

        Ok(result)
    }

    fn get_vector(&self, idx: u32) -> &[f32] {
        let start = idx as usize * self.dimension;
        &self.vectors[start..start + self.dimension]
    }

    // Euclidean distance (squared)
    fn dist(&self, a: &[f32], b: &[f32]) -> f32 {
        // In full impl, use SIMD from crate::simd
        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
    }
}