1use std::cmp::Reverse;
8use std::collections::BinaryHeap;
9use std::sync::Arc;
10
11use arrow_schema::{DataType, Field};
12use deepsize::DeepSizeOf;
13
14use crate::vector::hnsw::builder::HnswQueryParams;
15
16pub mod builder;
17
18use crate::vector::DIST_COL;
19
20use crate::vector::storage::DistCalculator;
21
/// Column name used by [`NEIGHBORS_FIELD`], whose values are lists of `u32`.
pub(crate) const NEIGHBORS_COL: &str = "__neighbors";
23
24use std::sync::LazyLock;
25
26pub static NEIGHBORS_FIELD: LazyLock<Field> = LazyLock::new(|| {
28 Field::new(
29 NEIGHBORS_COL,
30 DataType::List(Field::new_list_field(DataType::UInt32, true).into()),
31 true,
32 )
33});
34pub static DISTS_FIELD: LazyLock<Field> = LazyLock::new(|| {
35 Field::new(
36 DIST_COL,
37 DataType::List(Field::new_list_field(DataType::Float32, true).into()),
38 true,
39 )
40});
41
/// A graph node: an identifier plus the identifiers of its neighbors.
///
/// `I` is the identifier type shared by the node and its neighbors; it
/// defaults to `u32`.
pub struct GraphNode<I = u32> {
    /// Identifier of this node.
    pub id: I,
    /// Identifiers of this node's neighbors.
    pub neighbors: Vec<I>,
}

impl<I> GraphNode<I> {
    /// Builds a node from its id and an already-assembled neighbor list.
    pub fn new(id: I, neighbors: Vec<I>) -> Self {
        GraphNode { id, neighbors }
    }
}

impl<I> From<I> for GraphNode<I> {
    /// Wraps a bare id into a node that has no neighbors.
    fn from(id: I) -> Self {
        GraphNode {
            id,
            neighbors: Vec::new(),
        }
    }
}
61
/// An `f32` wrapper that can be used where a total order is required
/// (e.g. as a key in a `BinaryHeap`).
///
/// Ordering (`Ord` / `PartialOrd`) is defined via `f32::total_cmp`, whereas
/// `==` is the derived comparison of the inner `f32`. The two therefore
/// disagree for NaN (never `==`, but ordered) and for `0.0` vs `-0.0`
/// (`==`, but ordered differently).
#[derive(Debug, PartialEq, Clone, Copy, DeepSizeOf)]
pub struct OrderedFloat(pub f32);
66
67impl PartialOrd for OrderedFloat {
68 #[inline(always)]
69 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
70 Some(self.cmp(other))
71 }
72}
73
// Marker impl; `Eq` is a supertrait of `Ord`, which `OrderedFloat` implements.
impl Eq for OrderedFloat {}
75
76impl Ord for OrderedFloat {
77 #[inline(always)]
78 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
79 self.0.total_cmp(&other.0)
80 }
81}
82
83impl From<f32> for OrderedFloat {
84 fn from(f: f32) -> Self {
85 Self(f)
86 }
87}
88
89impl From<OrderedFloat> for f32 {
90 fn from(f: OrderedFloat) -> Self {
91 f.0
92 }
93}
94
/// A node id paired with a distance.
///
/// Equality is derived and compares both `id` and `dist`; the `Ord` /
/// `PartialOrd` impls in this file compare `dist` only, so two nodes with
/// different ids and the same distance are unequal yet order as `Equal`.
#[derive(Debug, Eq, PartialEq, Clone, DeepSizeOf)]
pub struct OrderedNode<T = u32>
where
    T: PartialEq + Eq,
{
    /// Identifier of the node.
    pub id: T,
    /// Distance associated with the node; this is the ordering key.
    pub dist: OrderedFloat,
}
103
104impl<T: PartialEq + Eq> OrderedNode<T> {
105 pub fn new(id: T, dist: OrderedFloat) -> Self {
106 Self { id, dist }
107 }
108}
109
110impl<T: PartialEq + Eq> PartialOrd for OrderedNode<T> {
111 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
112 Some(self.cmp(other))
113 }
114}
115
116impl<T: PartialEq + Eq> Ord for OrderedNode<T> {
117 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
118 self.dist.cmp(&other.dist)
119 }
120}
121
122impl<T: PartialEq + Eq> From<(OrderedFloat, T)> for OrderedNode<T> {
123 fn from((dist, id): (OrderedFloat, T)) -> Self {
124 Self { id, dist }
125 }
126}
127
128impl<T: PartialEq + Eq> From<OrderedNode<T>> for (OrderedFloat, T) {
129 fn from(node: OrderedNode<T>) -> Self {
130 (node.dist, node.id)
131 }
132}
133
/// Batch distance computation over node ids.
pub trait DistanceCalculator {
    /// Returns an iterator of `f32` distances for the given `ids`.
    ///
    /// NOTE(review): presumably yields one distance per id, in input order,
    /// measured against a query held by the implementor — confirm with the
    /// implementations.
    fn compute_distances(&self, ids: &[u32]) -> Box<dyn Iterator<Item = f32>>;
}
143
/// Read access to a graph whose neighbor lists are handed out as shared,
/// reference-counted vectors.
pub trait Graph {
    /// Size of the graph.
    ///
    /// NOTE(review): presumably the number of nodes — confirm with implementors.
    fn len(&self) -> usize;

    /// Returns `true` when [`Graph::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the neighbor ids stored for `key`.
    fn neighbors(&self, key: u32) -> Arc<Vec<u32>>;
}
162
/// Read access to a graph whose neighbor lists are borrowed directly from the
/// graph as slices (compare [`Graph`], which returns `Arc<Vec<u32>>`).
pub trait BorrowingGraph {
    /// Size of the graph.
    ///
    /// NOTE(review): presumably the number of nodes — confirm with implementors.
    fn len(&self) -> usize;

    /// Returns `true` when [`BorrowingGraph::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the neighbor ids stored for `key`, borrowed from `self`.
    fn neighbors(&self, key: u32) -> &[u32];
}
175
/// Number of bits held by each word of the visited bitmap.
const WORD_BITS: usize = usize::BITS as usize;

/// A set of node ids backed by two borrowed buffers: a bitmap (`visited`) and
/// the list of ids whose bit was set through this handle (`recently_visited`).
///
/// Dropping the handle clears exactly the bits it set and empties the list,
/// so the buffers can be reused without zeroing the whole bitmap.
pub struct Visited<'a> {
    // One bit per node id, packed into `usize` words.
    visited: &'a mut Vec<usize>,
    // Ids inserted through this handle, in insertion order, without duplicates.
    recently_visited: &'a mut Vec<u32>,
}

impl Visited<'_> {
    /// Maps a node id to its `(word index, bit mask)` position in the bitmap.
    #[inline]
    fn locate(node_id: u32) -> (usize, usize) {
        let bit = node_id as usize;
        (bit / WORD_BITS, 1usize << (bit % WORD_BITS))
    }

    /// Marks `node_id` as visited. Inserting an id that is already present is
    /// a no-op (it is not recorded twice).
    ///
    /// # Panics
    ///
    /// Panics if `node_id` lies beyond the end of the underlying bitmap.
    pub fn insert(&mut self, node_id: u32) {
        let (word_index, mask) = Self::locate(node_id);
        let word = &mut self.visited[word_index];
        if *word & mask == 0 {
            *word |= mask;
            self.recently_visited.push(node_id);
        }
    }

    /// Returns whether `node_id` is currently marked as visited.
    ///
    /// Ids that lie beyond the end of the bitmap can never have been inserted,
    /// so they are reported as not visited instead of panicking.
    pub fn contains(&self, node_id: u32) -> bool {
        let (word_index, mask) = Self::locate(node_id);
        self.visited
            .get(word_index)
            .is_some_and(|word| word & mask != 0)
    }

    /// Iterates over the ids inserted through this handle, in insertion order.
    #[inline(always)]
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.recently_visited.iter().map(|&node_id| node_id as usize)
    }

    /// Number of distinct ids inserted through this handle.
    pub fn count_ones(&self) -> usize {
        self.recently_visited.len()
    }
}

impl Drop for Visited<'_> {
    /// Clears every bit set through this handle and empties the id list,
    /// leaving both buffers ready for the next use.
    fn drop(&mut self) {
        // Only the touched bits are reset, which is cheaper than zeroing the
        // whole bitmap when few nodes were visited.
        for node_id in self.recently_visited.drain(..) {
            let (word_index, mask) = Self::locate(node_id);
            self.visited[word_index] &= !mask;
        }
    }
}
225
226#[derive(Debug, Clone)]
227pub struct VisitedGenerator {
228 visited: Vec<usize>,
229 recently_visited: Vec<u32>,
230 capacity: usize,
231}
232
233impl VisitedGenerator {
234 pub fn new(capacity: usize) -> Self {
235 Self {
236 visited: vec![0; capacity.div_ceil(WORD_BITS)],
237 recently_visited: Vec::new(),
238 capacity,
239 }
240 }
241
242 pub fn generate(&mut self, node_count: usize) -> Visited<'_> {
243 if node_count > self.capacity {
244 let new_capacity = self.capacity.max(node_count).next_power_of_two();
245 self.visited.resize(new_capacity.div_ceil(WORD_BITS), 0);
246 self.capacity = new_capacity;
247 }
248 Visited {
249 visited: &mut self.visited,
250 recently_visited: &mut self.recently_visited,
251 }
252 }
253}
254
255fn process_neighbors_with_look_ahead<F>(
256 neighbors: &[u32],
257 mut process_neighbor: F,
258 look_ahead: Option<usize>,
259 dist_calc: &impl DistCalculator,
260) where
261 F: FnMut(u32),
262{
263 match look_ahead {
264 Some(look_ahead) => {
265 for i in 0..neighbors.len().saturating_sub(look_ahead) {
266 dist_calc.prefetch(neighbors[i + look_ahead]);
267 process_neighbor(neighbors[i]);
268 }
269 for neighbor in &neighbors[neighbors.len().saturating_sub(look_ahead)..] {
270 process_neighbor(*neighbor);
271 }
272 }
273 None => {
274 for neighbor in neighbors.iter() {
275 process_neighbor(*neighbor);
276 }
277 }
278 }
279}
280
281#[inline]
282fn furthest_distance(results: &BinaryHeap<OrderedNode>) -> OrderedFloat {
283 results
284 .peek()
285 .map(|node| node.dist)
286 .unwrap_or(OrderedFloat(f32::INFINITY))
287}
288
289#[inline]
290fn push_result(results: &mut BinaryHeap<OrderedNode>, candidate: OrderedNode, k: usize) {
291 if results.len() < k {
292 results.push(candidate);
293 } else if candidate.dist < results.peek().unwrap().dist {
294 results.pop();
295 results.push(candidate);
296 }
297}
298
// Shared body of the beam-search functions.
//
// Repeatedly pops the closest entry from `$candidates` (a min-heap via
// `Reverse`) and binds it to `$current`. The loop stops when `$results`
// already holds `$k` entries and `$current` is farther than the worst of them.
// Otherwise a per-neighbor closure is bound to `$process_neighbor` and the
// caller-supplied `$visit_neighbors` block is run; that block is expected to
// feed the neighbors of `$current` to the closure.
//
// The closure skips ids already in `$visited`; for a new id it computes the
// distance and, if the id is competitive (results not full, or no farther than
// the worst result at the time `$current` was popped), pushes it onto
// `$candidates`. It is added to `$results` only when `$accepts_result`
// returns true, so rejected nodes are still traversed.
//
// `$prefetch_distance` is accepted for call-site symmetry but is not used by
// the macro body itself.
macro_rules! beam_search_loop {
    (
        $candidates:ident,
        $results:ident,
        $visited:ident,
        $k:expr,
        $dist_calc:expr,
        $prefetch_distance:expr,
        $accepts_result:expr,
        |$current:ident, $process_neighbor:ident| $visit_neighbors:block
    ) => {{
        while !$candidates.is_empty() {
            let Reverse($current) = $candidates.pop().expect("candidates is empty");
            // Snapshot of the worst accepted distance; it is not refreshed
            // while the neighbors of `$current` are being processed.
            let furthest = furthest_distance(&$results);

            if $results.len() == $k && $current.dist > furthest {
                break;
            }

            let $process_neighbor = |neighbor: u32| {
                if !$visited.contains(neighbor) {
                    $visited.insert(neighbor);
                    let dist = OrderedFloat::from($dist_calc.distance(neighbor));
                    if $results.len() < $k || dist <= furthest {
                        if $accepts_result(neighbor, dist) {
                            push_result(&mut $results, OrderedNode::new(neighbor, dist), $k);
                        }
                        $candidates.push(Reverse(OrderedNode::new(neighbor, dist)));
                    }
                }
            };
            $visit_neighbors
        }
    }};
}
335
// Shared body of the greedy-search functions.
//
// Each round binds a per-neighbor closure to `$process_neighbor` and runs the
// caller-supplied `$visit_neighbors` block, which is expected to feed the
// neighbors of `$current` to the closure. The closure tracks the neighbor
// with the smallest distance strictly below `$closest_dist`, updating
// `$closest_dist` as it goes. If some neighbor improved the distance,
// `$current` moves to it and another round starts; otherwise the loop ends.
//
// `$prefetch_distance` is accepted for call-site symmetry but is not used by
// the macro body itself.
macro_rules! greedy_search_loop {
    (
        $current:ident,
        $closest_dist:ident,
        $dist_calc:expr,
        $prefetch_distance:expr,
        |$process_neighbor:ident| $visit_neighbors:block
    ) => {{
        loop {
            let mut improved = None;
            let $process_neighbor = |neighbor: u32| {
                let candidate_dist = $dist_calc.distance(neighbor);
                if candidate_dist < $closest_dist {
                    $closest_dist = candidate_dist;
                    improved = Some(neighbor);
                }
            };
            $visit_neighbors

            match improved {
                Some(closer) => $current = closer,
                None => break,
            }
        }
    }};
}
363
364pub fn beam_search(
389 graph: &dyn Graph,
390 ep: &OrderedNode,
391 params: &HnswQueryParams,
392 dist_calc: &impl DistCalculator,
393 bitset: Option<&Visited>,
394 prefetch_distance: Option<usize>,
395 visited: &mut Visited,
396) -> Vec<OrderedNode> {
397 let k = params.ef;
398 let mut candidates = BinaryHeap::with_capacity(k);
399 visited.insert(ep.id);
400 candidates.push(Reverse(ep.clone()));
401
402 let mut results = BinaryHeap::with_capacity(k);
403 let no_filter =
404 bitset.is_none() && params.lower_bound.is_none() && params.upper_bound.is_none();
405
406 if no_filter {
407 results.push(ep.clone());
408 let accepts_result = |_: u32, _: OrderedFloat| true;
409 beam_search_loop!(
410 candidates,
411 results,
412 visited,
413 k,
414 dist_calc,
415 prefetch_distance,
416 accepts_result,
417 |current, process_neighbor| {
418 let neighbors = graph.neighbors(current.id);
419 process_neighbors_with_look_ahead(
420 &neighbors,
421 process_neighbor,
422 prefetch_distance,
423 dist_calc,
424 );
425 }
426 );
427 return results.into_sorted_vec();
428 }
429
430 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
432 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
433
434 if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
435 && ep.dist >= lower_bound
436 && ep.dist < upper_bound
437 {
438 results.push(ep.clone());
439 }
440
441 let accepts_result = |node_id: u32, dist: OrderedFloat| {
442 bitset
443 .map(|bitset| bitset.contains(node_id))
444 .unwrap_or(true)
445 && dist >= lower_bound
446 && dist < upper_bound
447 };
448 beam_search_loop!(
449 candidates,
450 results,
451 visited,
452 k,
453 dist_calc,
454 prefetch_distance,
455 accepts_result,
456 |current, process_neighbor| {
457 let neighbors = graph.neighbors(current.id);
458 process_neighbors_with_look_ahead(
459 &neighbors,
460 process_neighbor,
461 prefetch_distance,
462 dist_calc,
463 );
464 }
465 );
466 results.into_sorted_vec()
467}
468
469pub fn beam_search_borrowed(
470 graph: &impl BorrowingGraph,
471 ep: &OrderedNode,
472 params: &HnswQueryParams,
473 dist_calc: &impl DistCalculator,
474 bitset: Option<&Visited>,
475 prefetch_distance: Option<usize>,
476 visited: &mut Visited,
477) -> Vec<OrderedNode> {
478 let k = params.ef;
479 let mut candidates = BinaryHeap::with_capacity(k);
480 visited.insert(ep.id);
481 candidates.push(Reverse(ep.clone()));
482
483 let mut results = BinaryHeap::with_capacity(k);
484 let no_filter =
485 bitset.is_none() && params.lower_bound.is_none() && params.upper_bound.is_none();
486
487 if no_filter {
488 results.push(ep.clone());
489 let accepts_result = |_: u32, _: OrderedFloat| true;
490 beam_search_loop!(
491 candidates,
492 results,
493 visited,
494 k,
495 dist_calc,
496 prefetch_distance,
497 accepts_result,
498 |current, process_neighbor| {
499 let neighbors = graph.neighbors(current.id);
500 process_neighbors_with_look_ahead(
501 neighbors,
502 process_neighbor,
503 prefetch_distance,
504 dist_calc,
505 );
506 }
507 );
508 return results.into_sorted_vec();
509 }
510
511 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
512 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
513
514 if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
515 && ep.dist >= lower_bound
516 && ep.dist < upper_bound
517 {
518 results.push(ep.clone());
519 }
520
521 let accepts_result = |node_id: u32, dist: OrderedFloat| {
522 bitset
523 .map(|bitset| bitset.contains(node_id))
524 .unwrap_or(true)
525 && dist >= lower_bound
526 && dist < upper_bound
527 };
528 beam_search_loop!(
529 candidates,
530 results,
531 visited,
532 k,
533 dist_calc,
534 prefetch_distance,
535 accepts_result,
536 |current, process_neighbor| {
537 let neighbors = graph.neighbors(current.id);
538 process_neighbors_with_look_ahead(
539 neighbors,
540 process_neighbor,
541 prefetch_distance,
542 dist_calc,
543 );
544 }
545 );
546 results.into_sorted_vec()
547}
548
549pub fn greedy_search(
568 graph: &dyn Graph,
569 start: OrderedNode,
570 dist_calc: &impl DistCalculator,
571 prefetch_distance: Option<usize>,
572) -> OrderedNode {
573 let mut current = start.id;
574 let mut closest_dist = start.dist.0;
575 greedy_search_loop!(
576 current,
577 closest_dist,
578 dist_calc,
579 prefetch_distance,
580 |process_neighbor| {
581 let neighbors = graph.neighbors(current);
582 process_neighbors_with_look_ahead(
583 &neighbors,
584 process_neighbor,
585 prefetch_distance,
586 dist_calc,
587 );
588 }
589 );
590 OrderedNode::new(current, closest_dist.into())
591}
592
593pub fn greedy_search_borrowed(
594 graph: &impl BorrowingGraph,
595 start: OrderedNode,
596 dist_calc: &impl DistCalculator,
597 prefetch_distance: Option<usize>,
598) -> OrderedNode {
599 let mut current = start.id;
600 let mut closest_dist = start.dist.0;
601 greedy_search_loop!(
602 current,
603 closest_dist,
604 dist_calc,
605 prefetch_distance,
606 |process_neighbor| {
607 let neighbors = graph.neighbors(current);
608 process_neighbors_with_look_ahead(
609 neighbors,
610 process_neighbor,
611 prefetch_distance,
612 dist_calc,
613 );
614 }
615 );
616 OrderedNode::new(current, closest_dist.into())
617}
618
// Test-only module; it currently contains no tests.
#[cfg(test)]
mod tests {}