manifoldb_vector/index/
graph.rs

1//! HNSW graph data structure.
2//!
3//! This module contains the in-memory graph representation for HNSW.
4//! The graph is a multi-layer structure where each node can have connections
5//! to other nodes in the same layer.
6
7use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet};
9
10use manifoldb_core::EntityId;
11
12use crate::distance::DistanceMetric;
13use crate::types::Embedding;
14
15/// A node in the HNSW graph.
16#[derive(Debug, Clone)]
17pub struct HnswNode {
18    /// The entity ID this node represents.
19    pub entity_id: EntityId,
20    /// The embedding vector.
21    pub embedding: Embedding,
22    /// The maximum layer this node appears in.
23    pub max_layer: usize,
24    /// Connections to other nodes, indexed by layer.
25    /// `connections[layer]` = list of neighbor entity IDs
26    pub connections: Vec<Vec<EntityId>>,
27}
28
29impl HnswNode {
30    /// Create a new HNSW node.
31    #[inline]
32    pub fn new(entity_id: EntityId, embedding: Embedding, max_layer: usize) -> Self {
33        let connections = vec![Vec::new(); max_layer + 1];
34        Self { entity_id, embedding, max_layer, connections }
35    }
36
37    /// Get the connections at a specific layer.
38    #[inline]
39    #[must_use]
40    pub fn connections_at(&self, layer: usize) -> &[EntityId] {
41        self.connections.get(layer).map_or(&[], |c| c.as_slice())
42    }
43
44    /// Add a connection at a specific layer.
45    #[inline]
46    pub fn add_connection(&mut self, layer: usize, neighbor: EntityId) {
47        if layer < self.connections.len() && !self.connections[layer].contains(&neighbor) {
48            self.connections[layer].push(neighbor);
49        }
50    }
51
52    /// Remove a connection at a specific layer.
53    #[inline]
54    pub fn remove_connection(&mut self, layer: usize, neighbor: EntityId) {
55        if layer < self.connections.len() {
56            self.connections[layer].retain(|&id| id != neighbor);
57        }
58    }
59
60    /// Set the connections at a specific layer, replacing existing ones.
61    #[inline]
62    pub fn set_connections(&mut self, layer: usize, neighbors: Vec<EntityId>) {
63        if layer < self.connections.len() {
64            self.connections[layer] = neighbors;
65        }
66    }
67}
68
69/// The HNSW graph structure.
70#[derive(Debug)]
71pub struct HnswGraph {
72    /// All nodes in the graph, keyed by entity ID.
73    pub nodes: HashMap<EntityId, HnswNode>,
74    /// The entry point node (highest level node).
75    pub entry_point: Option<EntityId>,
76    /// The current maximum layer in the graph.
77    pub max_layer: usize,
78    /// The distance metric used for similarity.
79    pub distance_metric: DistanceMetric,
80    /// The dimension of embeddings in this graph.
81    pub dimension: usize,
82}
83
84impl HnswGraph {
85    /// Create a new empty HNSW graph.
86    #[must_use]
87    pub fn new(dimension: usize, distance_metric: DistanceMetric) -> Self {
88        Self { nodes: HashMap::new(), entry_point: None, max_layer: 0, distance_metric, dimension }
89    }
90
91    /// Get a node by entity ID.
92    #[inline]
93    #[must_use]
94    pub fn get_node(&self, entity_id: EntityId) -> Option<&HnswNode> {
95        self.nodes.get(&entity_id)
96    }
97
98    /// Get a mutable node by entity ID.
99    #[inline]
100    pub fn get_node_mut(&mut self, entity_id: EntityId) -> Option<&mut HnswNode> {
101        self.nodes.get_mut(&entity_id)
102    }
103
104    /// Check if a node exists in the graph.
105    #[inline]
106    #[must_use]
107    pub fn contains(&self, entity_id: EntityId) -> bool {
108        self.nodes.contains_key(&entity_id)
109    }
110
111    /// Get the number of nodes in the graph.
112    #[inline]
113    #[must_use]
114    pub fn len(&self) -> usize {
115        self.nodes.len()
116    }
117
118    /// Check if the graph is empty.
119    #[inline]
120    #[must_use]
121    pub fn is_empty(&self) -> bool {
122        self.nodes.is_empty()
123    }
124
125    /// Calculate the distance between two embeddings.
126    #[inline]
127    #[must_use]
128    pub fn distance(&self, a: &Embedding, b: &Embedding) -> f32 {
129        match self.distance_metric {
130            DistanceMetric::Euclidean => crate::distance::euclidean_distance(a, b),
131            DistanceMetric::Cosine => crate::distance::cosine_distance(a, b),
132            DistanceMetric::DotProduct => -crate::distance::dot_product(a, b), // Negate for min-distance
133            DistanceMetric::Manhattan => crate::distance::manhattan_distance(a, b),
134            DistanceMetric::Chebyshev => crate::distance::chebyshev_distance(a, b),
135        }
136    }
137
138    /// Calculate the distance from a query to a node.
139    #[inline]
140    #[must_use]
141    pub fn distance_to_node(&self, query: &Embedding, entity_id: EntityId) -> Option<f32> {
142        self.nodes.get(&entity_id).map(|node| self.distance(query, &node.embedding))
143    }
144
145    /// Insert a node into the graph.
146    pub fn insert_node(&mut self, node: HnswNode) {
147        let entity_id = node.entity_id;
148        let max_layer = node.max_layer;
149
150        // Update entry point if this is the first node or has a higher layer
151        if self.entry_point.is_none() || max_layer > self.max_layer {
152            self.entry_point = Some(entity_id);
153            self.max_layer = max_layer;
154        }
155
156        self.nodes.insert(entity_id, node);
157    }
158
159    /// Remove a node from the graph.
160    pub fn remove_node(&mut self, entity_id: EntityId) -> Option<HnswNode> {
161        let node = self.nodes.remove(&entity_id)?;
162
163        // Remove connections to this node from all neighbors
164        for layer in 0..=node.max_layer {
165            for &neighbor_id in &node.connections[layer] {
166                if let Some(neighbor) = self.nodes.get_mut(&neighbor_id) {
167                    neighbor.remove_connection(layer, entity_id);
168                }
169            }
170        }
171
172        // Update entry point if we removed it
173        if self.entry_point == Some(entity_id) {
174            self.update_entry_point();
175        }
176
177        Some(node)
178    }
179
180    /// Find a new entry point after removal.
181    fn update_entry_point(&mut self) {
182        // Find the node with the highest max_layer
183        let new_entry = self
184            .nodes
185            .iter()
186            .max_by_key(|(_, node)| node.max_layer)
187            .map(|(&id, node)| (id, node.max_layer));
188
189        if let Some((id, max_layer)) = new_entry {
190            self.entry_point = Some(id);
191            self.max_layer = max_layer;
192        } else {
193            self.entry_point = None;
194            self.max_layer = 0;
195        }
196    }
197}
198
199/// A candidate during HNSW search.
200///
201/// Used in the priority queue for greedy search.
202#[derive(Debug, Clone, Copy)]
203pub struct Candidate {
204    /// The entity ID of this candidate.
205    pub entity_id: EntityId,
206    /// The distance to the query.
207    pub distance: f32,
208}
209
210impl Candidate {
211    /// Create a new candidate.
212    #[inline]
213    #[must_use]
214    pub const fn new(entity_id: EntityId, distance: f32) -> Self {
215        Self { entity_id, distance }
216    }
217}
218
219impl PartialEq for Candidate {
220    #[inline]
221    fn eq(&self, other: &Self) -> bool {
222        self.distance == other.distance && self.entity_id == other.entity_id
223    }
224}
225
226impl Eq for Candidate {}
227
228impl PartialOrd for Candidate {
229    #[inline]
230    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
231        Some(self.cmp(other))
232    }
233}
234
235impl Ord for Candidate {
236    #[inline]
237    fn cmp(&self, other: &Self) -> Ordering {
238        // Reverse ordering for min-heap (smallest distance first)
239        // NaN values are treated as equal to maintain a total ordering for the heap.
240        // In practice, NaN distances should not occur from valid distance calculations.
241        other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
242    }
243}
244
245/// A max-heap candidate for tracking the worst element in the result set.
246#[derive(Debug, Clone, Copy)]
247pub struct MaxCandidate(pub Candidate);
248
249impl PartialEq for MaxCandidate {
250    #[inline]
251    fn eq(&self, other: &Self) -> bool {
252        self.0 == other.0
253    }
254}
255
256impl Eq for MaxCandidate {}
257
258impl PartialOrd for MaxCandidate {
259    #[inline]
260    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
261        Some(self.cmp(other))
262    }
263}
264
265impl Ord for MaxCandidate {
266    #[inline]
267    fn cmp(&self, other: &Self) -> Ordering {
268        // Normal ordering for max-heap (largest distance first)
269        // NaN values are treated as equal to maintain a total ordering for the heap.
270        // In practice, NaN distances should not occur from valid distance calculations.
271        self.0.distance.partial_cmp(&other.0.distance).unwrap_or(Ordering::Equal)
272    }
273}
274
275/// Search layer for a single entry point.
276///
277/// Performs a greedy search starting from the entry point, returning
278/// the ef closest candidates to the query.
279pub fn search_layer(
280    graph: &HnswGraph,
281    query: &Embedding,
282    entry_points: &[EntityId],
283    ef: usize,
284    layer: usize,
285) -> Vec<Candidate> {
286    search_layer_filtered(graph, query, entry_points, ef, layer, |_| true)
287}
288
289/// Search layer with a filter predicate applied during traversal.
290///
291/// Only nodes that pass the predicate are included in the results.
292/// Nodes that fail the predicate are still traversed (their neighbors are explored)
293/// but they are not added to the result set.
294///
295/// # Arguments
296///
297/// * `graph` - The HNSW graph to search
298/// * `query` - The query embedding
299/// * `entry_points` - Initial entry points to start search from
300/// * `ef` - The beam width (number of candidates to track)
301/// * `layer` - The layer to search
302/// * `predicate` - A predicate function that returns true for nodes to include in results
303///
304/// # Returns
305///
306/// A vector of candidates that pass the predicate, sorted by distance.
307pub fn search_layer_filtered<F>(
308    graph: &HnswGraph,
309    query: &Embedding,
310    entry_points: &[EntityId],
311    ef: usize,
312    layer: usize,
313    predicate: F,
314) -> Vec<Candidate>
315where
316    F: Fn(EntityId) -> bool,
317{
318    if entry_points.is_empty() {
319        return Vec::new();
320    }
321
322    // Initialize candidates with entry points
323    let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
324    let mut results: BinaryHeap<MaxCandidate> = BinaryHeap::new();
325    let mut visited: HashSet<EntityId> = HashSet::new();
326
327    for &ep in entry_points {
328        if let Some(dist) = graph.distance_to_node(query, ep) {
329            visited.insert(ep);
330            let candidate = Candidate::new(ep, dist);
331            candidates.push(candidate);
332            // Only add to results if it passes the predicate
333            if predicate(ep) {
334                results.push(MaxCandidate(candidate));
335            }
336        }
337    }
338
339    // Greedy search
340    while let Some(current) = candidates.pop() {
341        // Get the furthest result
342        let furthest_result = results.peek().map_or(f32::INFINITY, |c| c.0.distance);
343
344        // If the closest candidate is further than the furthest result, we're done
345        // But only if we have enough results
346        if current.distance > furthest_result && results.len() >= ef {
347            break;
348        }
349
350        // Explore neighbors
351        if let Some(node) = graph.get_node(current.entity_id) {
352            for &neighbor_id in node.connections_at(layer) {
353                if visited.contains(&neighbor_id) {
354                    continue;
355                }
356                visited.insert(neighbor_id);
357
358                if let Some(neighbor_dist) = graph.distance_to_node(query, neighbor_id) {
359                    let furthest_result = results.peek().map_or(f32::INFINITY, |c| c.0.distance);
360
361                    // Always explore for graph traversal (add to candidates)
362                    // This ensures we find good paths even through non-matching nodes
363                    let neighbor_candidate = Candidate::new(neighbor_id, neighbor_dist);
364                    candidates.push(neighbor_candidate);
365
366                    // Only add to results if it passes the predicate
367                    if predicate(neighbor_id)
368                        && (results.len() < ef || neighbor_dist < furthest_result)
369                    {
370                        results.push(MaxCandidate(neighbor_candidate));
371
372                        // Trim results to ef size
373                        if results.len() > ef {
374                            results.pop();
375                        }
376                    }
377                }
378            }
379        }
380    }
381
382    // Convert results to vector, sorted by distance
383    let mut result_vec: Vec<Candidate> = results.into_iter().map(|mc| mc.0).collect();
384    result_vec.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
385    result_vec
386}
387
388/// Select the best neighbors for a node using a simple heuristic.
389///
390/// This uses a simple approach: keep the M closest neighbors.
391/// This is an alternative to [`select_neighbors_heuristic`] that
392/// may be faster for small candidate sets.
393pub fn select_neighbors_simple(candidates: &[Candidate], m: usize) -> Vec<EntityId> {
394    candidates.iter().take(m).map(|c| c.entity_id).collect()
395}
396
397/// Select neighbors using the heuristic algorithm (Algorithm 4 from the paper).
398///
399/// This algorithm tries to ensure diversity in the neighborhood by preferring
400/// neighbors that are not too close to each other.
401pub fn select_neighbors_heuristic(
402    graph: &HnswGraph,
403    _query: &Embedding,
404    candidates: &[Candidate],
405    m: usize,
406    _extend_candidates: bool,
407) -> Vec<EntityId> {
408    if candidates.len() <= m {
409        return candidates.iter().map(|c| c.entity_id).collect();
410    }
411
412    let mut selected: Vec<EntityId> = Vec::with_capacity(m);
413    let mut remaining: Vec<Candidate> = candidates.to_vec();
414
415    // Sort by distance
416    remaining.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
417
418    for candidate in remaining {
419        if selected.len() >= m {
420            break;
421        }
422
423        // Check if this candidate is good (diverse enough from already selected)
424        let mut is_good = true;
425        let candidate_embedding = match graph.get_node(candidate.entity_id) {
426            Some(node) => &node.embedding,
427            None => continue,
428        };
429
430        for &selected_id in &selected {
431            if let Some(selected_node) = graph.get_node(selected_id) {
432                let dist_to_selected =
433                    graph.distance(candidate_embedding, &selected_node.embedding);
434                // If candidate is closer to an already selected node than to the query,
435                // it might not provide diverse coverage
436                if dist_to_selected < candidate.distance {
437                    is_good = false;
438                    break;
439                }
440            }
441        }
442
443        if is_good || selected.is_empty() {
444            selected.push(candidate.entity_id);
445        }
446    }
447
448    // If we didn't get enough diverse neighbors, fill with closest remaining
449    if selected.len() < m {
450        let remaining: Vec<Candidate> =
451            candidates.iter().filter(|c| !selected.contains(&c.entity_id)).copied().collect();
452
453        for candidate in remaining {
454            if selected.len() >= m {
455                break;
456            }
457            selected.push(candidate.entity_id);
458        }
459    }
460
461    selected
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-dimensional embedding whose components all equal `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        let components = vec![value; dim];
        Embedding::new(components).unwrap()
    }

    /// Shorthand for an entity ID.
    fn id(raw: u64) -> EntityId {
        EntityId::new(raw)
    }

    #[test]
    fn test_hnsw_node_creation() {
        let node = HnswNode::new(id(1), create_test_embedding(4, 1.0), 2);

        assert_eq!(node.entity_id, id(1));
        assert_eq!(node.max_layer, 2);
        // One adjacency list per layer: 0, 1 and 2.
        assert_eq!(node.connections.len(), 3);
    }

    #[test]
    fn test_node_connections() {
        let mut node = HnswNode::new(id(1), create_test_embedding(4, 1.0), 1);

        for (layer, neighbor) in vec![(0, 2), (0, 3), (1, 4)] {
            node.add_connection(layer, id(neighbor));
        }

        assert_eq!(node.connections_at(0), &[id(2), id(3)]);
        assert_eq!(node.connections_at(1), &[id(4)]);

        node.remove_connection(0, id(2));
        assert_eq!(node.connections_at(0), &[id(3)]);
    }

    #[test]
    fn test_graph_insert_and_remove() {
        let mut graph = HnswGraph::new(4, DistanceMetric::Euclidean);

        graph.insert_node(HnswNode::new(id(1), create_test_embedding(4, 1.0), 2));
        assert_eq!(graph.entry_point, Some(id(1)));
        assert_eq!(graph.max_layer, 2);

        // A node on a lower layer must not displace the entry point.
        graph.insert_node(HnswNode::new(id(2), create_test_embedding(4, 2.0), 1));
        assert_eq!(graph.entry_point, Some(id(1)));
        assert_eq!(graph.len(), 2);

        // Removing the entry point promotes the remaining node.
        graph.remove_node(id(1));
        assert_eq!(graph.entry_point, Some(id(2)));
        assert_eq!(graph.max_layer, 1);
    }

    #[test]
    fn test_candidate_ordering() {
        let mut heap: BinaryHeap<Candidate> = BinaryHeap::new();
        heap.push(Candidate::new(id(1), 1.0));
        heap.push(Candidate::new(id(2), 2.0));
        heap.push(Candidate::new(id(3), 0.5));

        // Min-heap: the smallest distance is popped first.
        for expected in vec![3, 1, 2] {
            assert_eq!(heap.pop().unwrap().entity_id, id(expected));
        }
    }

    #[test]
    fn test_search_layer_empty() {
        let graph = HnswGraph::new(4, DistanceMetric::Euclidean);
        let query = create_test_embedding(4, 1.0);

        let found = search_layer(&graph, &query, &[], 10, 0);
        assert!(found.is_empty());
    }

    #[test]
    fn test_search_layer_single_node() {
        let mut graph = HnswGraph::new(4, DistanceMetric::Euclidean);
        graph.insert_node(HnswNode::new(id(1), create_test_embedding(4, 1.0), 0));

        let query = create_test_embedding(4, 2.0);
        let found = search_layer(&graph, &query, &[id(1)], 10, 0);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].entity_id, id(1));
    }
}
556}