1use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet};
9
10use manifoldb_core::EntityId;
11
12use crate::distance::DistanceMetric;
13use crate::types::Embedding;
14
15#[derive(Debug, Clone)]
17pub struct HnswNode {
18 pub entity_id: EntityId,
20 pub embedding: Embedding,
22 pub max_layer: usize,
24 pub connections: Vec<Vec<EntityId>>,
27}
28
29impl HnswNode {
30 #[inline]
32 pub fn new(entity_id: EntityId, embedding: Embedding, max_layer: usize) -> Self {
33 let connections = vec![Vec::new(); max_layer + 1];
34 Self { entity_id, embedding, max_layer, connections }
35 }
36
37 #[inline]
39 #[must_use]
40 pub fn connections_at(&self, layer: usize) -> &[EntityId] {
41 self.connections.get(layer).map_or(&[], |c| c.as_slice())
42 }
43
44 #[inline]
46 pub fn add_connection(&mut self, layer: usize, neighbor: EntityId) {
47 if layer < self.connections.len() && !self.connections[layer].contains(&neighbor) {
48 self.connections[layer].push(neighbor);
49 }
50 }
51
52 #[inline]
54 pub fn remove_connection(&mut self, layer: usize, neighbor: EntityId) {
55 if layer < self.connections.len() {
56 self.connections[layer].retain(|&id| id != neighbor);
57 }
58 }
59
60 #[inline]
62 pub fn set_connections(&mut self, layer: usize, neighbors: Vec<EntityId>) {
63 if layer < self.connections.len() {
64 self.connections[layer] = neighbors;
65 }
66 }
67}
68
69#[derive(Debug)]
71pub struct HnswGraph {
72 pub nodes: HashMap<EntityId, HnswNode>,
74 pub entry_point: Option<EntityId>,
76 pub max_layer: usize,
78 pub distance_metric: DistanceMetric,
80 pub dimension: usize,
82}
83
84impl HnswGraph {
85 #[must_use]
87 pub fn new(dimension: usize, distance_metric: DistanceMetric) -> Self {
88 Self { nodes: HashMap::new(), entry_point: None, max_layer: 0, distance_metric, dimension }
89 }
90
91 #[inline]
93 #[must_use]
94 pub fn get_node(&self, entity_id: EntityId) -> Option<&HnswNode> {
95 self.nodes.get(&entity_id)
96 }
97
98 #[inline]
100 pub fn get_node_mut(&mut self, entity_id: EntityId) -> Option<&mut HnswNode> {
101 self.nodes.get_mut(&entity_id)
102 }
103
104 #[inline]
106 #[must_use]
107 pub fn contains(&self, entity_id: EntityId) -> bool {
108 self.nodes.contains_key(&entity_id)
109 }
110
111 #[inline]
113 #[must_use]
114 pub fn len(&self) -> usize {
115 self.nodes.len()
116 }
117
118 #[inline]
120 #[must_use]
121 pub fn is_empty(&self) -> bool {
122 self.nodes.is_empty()
123 }
124
125 #[inline]
127 #[must_use]
128 pub fn distance(&self, a: &Embedding, b: &Embedding) -> f32 {
129 match self.distance_metric {
130 DistanceMetric::Euclidean => crate::distance::euclidean_distance(a, b),
131 DistanceMetric::Cosine => crate::distance::cosine_distance(a, b),
132 DistanceMetric::DotProduct => -crate::distance::dot_product(a, b), DistanceMetric::Manhattan => crate::distance::manhattan_distance(a, b),
134 DistanceMetric::Chebyshev => crate::distance::chebyshev_distance(a, b),
135 }
136 }
137
138 #[inline]
140 #[must_use]
141 pub fn distance_to_node(&self, query: &Embedding, entity_id: EntityId) -> Option<f32> {
142 self.nodes.get(&entity_id).map(|node| self.distance(query, &node.embedding))
143 }
144
145 pub fn insert_node(&mut self, node: HnswNode) {
147 let entity_id = node.entity_id;
148 let max_layer = node.max_layer;
149
150 if self.entry_point.is_none() || max_layer > self.max_layer {
152 self.entry_point = Some(entity_id);
153 self.max_layer = max_layer;
154 }
155
156 self.nodes.insert(entity_id, node);
157 }
158
159 pub fn remove_node(&mut self, entity_id: EntityId) -> Option<HnswNode> {
161 let node = self.nodes.remove(&entity_id)?;
162
163 for layer in 0..=node.max_layer {
165 for &neighbor_id in &node.connections[layer] {
166 if let Some(neighbor) = self.nodes.get_mut(&neighbor_id) {
167 neighbor.remove_connection(layer, entity_id);
168 }
169 }
170 }
171
172 if self.entry_point == Some(entity_id) {
174 self.update_entry_point();
175 }
176
177 Some(node)
178 }
179
180 fn update_entry_point(&mut self) {
182 let new_entry = self
184 .nodes
185 .iter()
186 .max_by_key(|(_, node)| node.max_layer)
187 .map(|(&id, node)| (id, node.max_layer));
188
189 if let Some((id, max_layer)) = new_entry {
190 self.entry_point = Some(id);
191 self.max_layer = max_layer;
192 } else {
193 self.entry_point = None;
194 self.max_layer = 0;
195 }
196 }
197}
198
199#[derive(Debug, Clone, Copy)]
203pub struct Candidate {
204 pub entity_id: EntityId,
206 pub distance: f32,
208}
209
210impl Candidate {
211 #[inline]
213 #[must_use]
214 pub const fn new(entity_id: EntityId, distance: f32) -> Self {
215 Self { entity_id, distance }
216 }
217}
218
219impl PartialEq for Candidate {
220 #[inline]
221 fn eq(&self, other: &Self) -> bool {
222 self.distance == other.distance && self.entity_id == other.entity_id
223 }
224}
225
226impl Eq for Candidate {}
227
228impl PartialOrd for Candidate {
229 #[inline]
230 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
231 Some(self.cmp(other))
232 }
233}
234
235impl Ord for Candidate {
236 #[inline]
237 fn cmp(&self, other: &Self) -> Ordering {
238 other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
242 }
243}
244
245#[derive(Debug, Clone, Copy)]
247pub struct MaxCandidate(pub Candidate);
248
249impl PartialEq for MaxCandidate {
250 #[inline]
251 fn eq(&self, other: &Self) -> bool {
252 self.0 == other.0
253 }
254}
255
256impl Eq for MaxCandidate {}
257
258impl PartialOrd for MaxCandidate {
259 #[inline]
260 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
261 Some(self.cmp(other))
262 }
263}
264
265impl Ord for MaxCandidate {
266 #[inline]
267 fn cmp(&self, other: &Self) -> Ordering {
268 self.0.distance.partial_cmp(&other.0.distance).unwrap_or(Ordering::Equal)
272 }
273}
274
275pub fn search_layer(
280 graph: &HnswGraph,
281 query: &Embedding,
282 entry_points: &[EntityId],
283 ef: usize,
284 layer: usize,
285) -> Vec<Candidate> {
286 search_layer_filtered(graph, query, entry_points, ef, layer, |_| true)
287}
288
289pub fn search_layer_filtered<F>(
308 graph: &HnswGraph,
309 query: &Embedding,
310 entry_points: &[EntityId],
311 ef: usize,
312 layer: usize,
313 predicate: F,
314) -> Vec<Candidate>
315where
316 F: Fn(EntityId) -> bool,
317{
318 if entry_points.is_empty() {
319 return Vec::new();
320 }
321
322 let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
324 let mut results: BinaryHeap<MaxCandidate> = BinaryHeap::new();
325 let mut visited: HashSet<EntityId> = HashSet::new();
326
327 for &ep in entry_points {
328 if let Some(dist) = graph.distance_to_node(query, ep) {
329 visited.insert(ep);
330 let candidate = Candidate::new(ep, dist);
331 candidates.push(candidate);
332 if predicate(ep) {
334 results.push(MaxCandidate(candidate));
335 }
336 }
337 }
338
339 while let Some(current) = candidates.pop() {
341 let furthest_result = results.peek().map_or(f32::INFINITY, |c| c.0.distance);
343
344 if current.distance > furthest_result && results.len() >= ef {
347 break;
348 }
349
350 if let Some(node) = graph.get_node(current.entity_id) {
352 for &neighbor_id in node.connections_at(layer) {
353 if visited.contains(&neighbor_id) {
354 continue;
355 }
356 visited.insert(neighbor_id);
357
358 if let Some(neighbor_dist) = graph.distance_to_node(query, neighbor_id) {
359 let furthest_result = results.peek().map_or(f32::INFINITY, |c| c.0.distance);
360
361 let neighbor_candidate = Candidate::new(neighbor_id, neighbor_dist);
364 candidates.push(neighbor_candidate);
365
366 if predicate(neighbor_id)
368 && (results.len() < ef || neighbor_dist < furthest_result)
369 {
370 results.push(MaxCandidate(neighbor_candidate));
371
372 if results.len() > ef {
374 results.pop();
375 }
376 }
377 }
378 }
379 }
380 }
381
382 let mut result_vec: Vec<Candidate> = results.into_iter().map(|mc| mc.0).collect();
384 result_vec.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
385 result_vec
386}
387
388pub fn select_neighbors_simple(candidates: &[Candidate], m: usize) -> Vec<EntityId> {
394 candidates.iter().take(m).map(|c| c.entity_id).collect()
395}
396
397pub fn select_neighbors_heuristic(
402 graph: &HnswGraph,
403 _query: &Embedding,
404 candidates: &[Candidate],
405 m: usize,
406 _extend_candidates: bool,
407) -> Vec<EntityId> {
408 if candidates.len() <= m {
409 return candidates.iter().map(|c| c.entity_id).collect();
410 }
411
412 let mut selected: Vec<EntityId> = Vec::with_capacity(m);
413 let mut remaining: Vec<Candidate> = candidates.to_vec();
414
415 remaining.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
417
418 for candidate in remaining {
419 if selected.len() >= m {
420 break;
421 }
422
423 let mut is_good = true;
425 let candidate_embedding = match graph.get_node(candidate.entity_id) {
426 Some(node) => &node.embedding,
427 None => continue,
428 };
429
430 for &selected_id in &selected {
431 if let Some(selected_node) = graph.get_node(selected_id) {
432 let dist_to_selected =
433 graph.distance(candidate_embedding, &selected_node.embedding);
434 if dist_to_selected < candidate.distance {
437 is_good = false;
438 break;
439 }
440 }
441 }
442
443 if is_good || selected.is_empty() {
444 selected.push(candidate.entity_id);
445 }
446 }
447
448 if selected.len() < m {
450 let remaining: Vec<Candidate> =
451 candidates.iter().filter(|c| !selected.contains(&c.entity_id)).copied().collect();
452
453 for candidate in remaining {
454 if selected.len() >= m {
455 break;
456 }
457 selected.push(candidate.entity_id);
458 }
459 }
460
461 selected
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding of length `dim` whose components all equal `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// A freshly created node carries its id, its top layer and one
    /// connection list per layer.
    #[test]
    fn test_hnsw_node_creation() {
        let id = EntityId::new(1);
        let node = HnswNode::new(id, create_test_embedding(4, 1.0), 2);

        assert_eq!(node.entity_id, id);
        assert_eq!(node.max_layer, 2);
        // Layers 0, 1 and 2.
        assert_eq!(node.connections.len(), 3);
    }

    /// Connections are tracked per layer and can be removed again.
    #[test]
    fn test_node_connections() {
        let mut node = HnswNode::new(EntityId::new(1), create_test_embedding(4, 1.0), 1);

        node.add_connection(0, EntityId::new(2));
        node.add_connection(0, EntityId::new(3));
        node.add_connection(1, EntityId::new(4));

        assert_eq!(node.connections_at(0), &[EntityId::new(2), EntityId::new(3)]);
        assert_eq!(node.connections_at(1), &[EntityId::new(4)]);

        node.remove_connection(0, EntityId::new(2));
        assert_eq!(node.connections_at(0), &[EntityId::new(3)]);
    }

    /// The entry point follows the node with the highest layer across
    /// inserts and removals.
    #[test]
    fn test_graph_insert_and_remove() {
        let first = EntityId::new(1);
        let second = EntityId::new(2);
        let mut graph = HnswGraph::new(4, DistanceMetric::Euclidean);

        graph.insert_node(HnswNode::new(first, create_test_embedding(4, 1.0), 2));
        assert_eq!(graph.entry_point, Some(first));
        assert_eq!(graph.max_layer, 2);

        // A lower-layer node must not take over as entry point.
        graph.insert_node(HnswNode::new(second, create_test_embedding(4, 2.0), 1));
        assert_eq!(graph.entry_point, Some(first));
        assert_eq!(graph.len(), 2);

        graph.remove_node(first);
        assert_eq!(graph.entry_point, Some(second));
        assert_eq!(graph.max_layer, 1);
    }

    /// `Candidate` is ordered so that a `BinaryHeap` pops the smallest
    /// distance first.
    #[test]
    fn test_candidate_ordering() {
        let mut heap: BinaryHeap<Candidate> = vec![
            Candidate::new(EntityId::new(1), 1.0),
            Candidate::new(EntityId::new(2), 2.0),
            Candidate::new(EntityId::new(3), 0.5),
        ]
        .into_iter()
        .collect();

        let mut popped = Vec::new();
        while let Some(candidate) = heap.pop() {
            popped.push(candidate.entity_id);
        }

        assert_eq!(popped, vec![EntityId::new(3), EntityId::new(1), EntityId::new(2)]);
    }

    /// Searching without entry points yields nothing.
    #[test]
    fn test_search_layer_empty() {
        let graph = HnswGraph::new(4, DistanceMetric::Euclidean);
        let query = create_test_embedding(4, 1.0);

        assert!(search_layer(&graph, &query, &[], 10, 0).is_empty());
    }

    /// A single-node graph returns that node.
    #[test]
    fn test_search_layer_single_node() {
        let only = EntityId::new(1);
        let mut graph = HnswGraph::new(4, DistanceMetric::Euclidean);
        graph.insert_node(HnswNode::new(only, create_test_embedding(4, 1.0), 0));

        let query = create_test_embedding(4, 2.0);
        let results = search_layer(&graph, &query, &[only], 10, 0);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, only);
    }
}