Skip to main content

lance_index/vector/
graph.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Generic Graph implementation.
5//!
6
7use std::cmp::Reverse;
8use std::collections::BinaryHeap;
9use std::sync::Arc;
10
11use arrow_schema::{DataType, Field};
12use deepsize::DeepSizeOf;
13
14use crate::vector::hnsw::builder::HnswQueryParams;
15
16pub mod builder;
17
18use crate::vector::DIST_COL;
19
20use crate::vector::storage::DistCalculator;
21
/// Name of the neighbors column; used as the field name of [`NEIGHBORS_FIELD`].
pub(crate) const NEIGHBORS_COL: &str = "__neighbors";
23
24use std::sync::LazyLock;
25
26/// NEIGHBORS field.
27pub static NEIGHBORS_FIELD: LazyLock<Field> = LazyLock::new(|| {
28    Field::new(
29        NEIGHBORS_COL,
30        DataType::List(Field::new_list_field(DataType::UInt32, true).into()),
31        true,
32    )
33});
34pub static DISTS_FIELD: LazyLock<Field> = LazyLock::new(|| {
35    Field::new(
36        DIST_COL,
37        DataType::List(Field::new_list_field(DataType::Float32, true).into()),
38        true,
39    )
40});
41
/// A graph vertex together with its adjacency list.
///
/// `I` is the vertex identifier type and defaults to `u32`.
pub struct GraphNode<I = u32> {
    /// Identifier of this node.
    pub id: I,
    /// Identifiers of this node's neighbors.
    pub neighbors: Vec<I>,
}

impl<I> GraphNode<I> {
    /// Creates a node from its identifier and its neighbor list.
    pub fn new(id: I, neighbors: Vec<I>) -> Self {
        GraphNode { id, neighbors }
    }
}

impl<I> From<I> for GraphNode<I> {
    /// Wraps a bare identifier into a node with an empty neighbor list.
    fn from(id: I) -> Self {
        Self::new(id, Vec::new())
    }
}
61
62/// A wrapper for f32 to make it ordered, so that we can put it into
63/// a BTree or Heap
64#[derive(Debug, PartialEq, Clone, Copy, DeepSizeOf)]
65pub struct OrderedFloat(pub f32);
66
67impl PartialOrd for OrderedFloat {
68    #[inline(always)]
69    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
70        Some(self.cmp(other))
71    }
72}
73
74impl Eq for OrderedFloat {}
75
76impl Ord for OrderedFloat {
77    #[inline(always)]
78    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
79        self.0.total_cmp(&other.0)
80    }
81}
82
83impl From<f32> for OrderedFloat {
84    fn from(f: f32) -> Self {
85        Self(f)
86    }
87}
88
89impl From<OrderedFloat> for f32 {
90    fn from(f: OrderedFloat) -> Self {
91        f.0
92    }
93}
94
95#[derive(Debug, Eq, PartialEq, Clone, DeepSizeOf)]
96pub struct OrderedNode<T = u32>
97where
98    T: PartialEq + Eq,
99{
100    pub id: T,
101    pub dist: OrderedFloat,
102}
103
104impl<T: PartialEq + Eq> OrderedNode<T> {
105    pub fn new(id: T, dist: OrderedFloat) -> Self {
106        Self { id, dist }
107    }
108}
109
110impl<T: PartialEq + Eq> PartialOrd for OrderedNode<T> {
111    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
112        Some(self.cmp(other))
113    }
114}
115
116impl<T: PartialEq + Eq> Ord for OrderedNode<T> {
117    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
118        self.dist.cmp(&other.dist)
119    }
120}
121
122impl<T: PartialEq + Eq> From<(OrderedFloat, T)> for OrderedNode<T> {
123    fn from((dist, id): (OrderedFloat, T)) -> Self {
124        Self { id, dist }
125    }
126}
127
128impl<T: PartialEq + Eq> From<OrderedNode<T>> for (OrderedFloat, T) {
129    fn from(node: OrderedNode<T>) -> Self {
130        (node.dist, node.id)
131    }
132}
133
/// Distance calculator.
///
/// Implementors compute the distances from one query vector to the vectors
/// named by a list of IDs.
pub trait DistanceCalculator {
    /// Computes the distances between the query vector and all the vectors
    /// listed in `ids`, returned as a boxed iterator of `f32` values.
    fn compute_distances(&self, ids: &[u32]) -> Box<dyn Iterator<Item = f32>>;
}
143
/// Graph trait.
///
/// A read-only view of a graph whose nodes are addressed by `u32` keys.
/// Neighbor lists are handed out as shared (`Arc`) vectors.
pub trait Graph {
    /// Get the number of nodes in the graph.
    fn len(&self) -> usize;

    /// Returns true if the graph has no nodes.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Get the neighbors of the graph node identified by `key`.
    fn neighbors(&self, key: u32) -> Arc<Vec<u32>>;
}
162
/// Graph view that lends out neighbor lists as slices.
///
/// Mirrors the methods of [`Graph`], except that `neighbors` returns a
/// borrowed `&[u32]` instead of an `Arc<Vec<u32>>`.
pub trait BorrowingGraph {
    /// Get the number of nodes in the graph.
    fn len(&self) -> usize;

    /// Returns true if the graph has no nodes.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Borrow the neighbors of the graph node identified by `key`.
    fn neighbors(&self, key: u32) -> &[u32];
}
175
/// Number of bits in one bitmap word (`usize`).
const WORD_BITS: usize = usize::BITS as usize;

/// Compact visited list for graph traversals.
///
/// Membership is kept in a borrowed bitmap (`visited`, one bit per node ID).
/// Every ID that this handle newly sets is also logged in `recently_visited`.
/// Dropping the handle clears exactly the logged bits and empties the log.
pub struct Visited<'a> {
    visited: &'a mut Vec<usize>,
    recently_visited: &'a mut Vec<u32>,
}

impl Visited<'_> {
    /// Splits a node ID into the index of its bitmap word and the mask that
    /// selects its bit inside that word.
    #[inline]
    fn locate(node_id: u32) -> (usize, usize) {
        let bit = node_id as usize;
        (bit / WORD_BITS, 1usize << (bit % WORD_BITS))
    }

    /// Marks `node_id` as visited. IDs whose bit is already set are ignored
    /// and are not logged a second time.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` lies beyond the end of the underlying bitmap.
    pub fn insert(&mut self, node_id: u32) {
        let (word, mask) = Self::locate(node_id);
        let bits = &mut self.visited[word];
        if *bits & mask != 0 {
            return;
        }
        *bits |= mask;
        self.recently_visited.push(node_id);
    }

    /// Returns whether the bit for `node_id` is set.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` lies beyond the end of the underlying bitmap.
    pub fn contains(&self, node_id: u32) -> bool {
        let (word, mask) = Self::locate(node_id);
        (self.visited[word] & mask) != 0
    }

    /// Iterates over the IDs logged by this handle, in insertion order.
    #[inline(always)]
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.recently_visited.iter().copied().map(|id| id as usize)
    }

    /// Number of IDs logged by this handle.
    pub fn count_ones(&self) -> usize {
        self.recently_visited.len()
    }
}

impl Drop for Visited<'_> {
    /// Clears every bit this handle set and empties the log, so the borrowed
    /// buffers can be reused.
    fn drop(&mut self) {
        for node_id in self.recently_visited.drain(..) {
            let (word, mask) = Self::locate(node_id);
            self.visited[word] &= !mask;
        }
    }
}
225
226#[derive(Debug, Clone)]
227pub struct VisitedGenerator {
228    visited: Vec<usize>,
229    recently_visited: Vec<u32>,
230    capacity: usize,
231}
232
233impl VisitedGenerator {
234    pub fn new(capacity: usize) -> Self {
235        Self {
236            visited: vec![0; capacity.div_ceil(WORD_BITS)],
237            recently_visited: Vec::new(),
238            capacity,
239        }
240    }
241
242    pub fn generate(&mut self, node_count: usize) -> Visited<'_> {
243        if node_count > self.capacity {
244            let new_capacity = self.capacity.max(node_count).next_power_of_two();
245            self.visited.resize(new_capacity.div_ceil(WORD_BITS), 0);
246            self.capacity = new_capacity;
247        }
248        Visited {
249            visited: &mut self.visited,
250            recently_visited: &mut self.recently_visited,
251        }
252    }
253}
254
255fn process_neighbors_with_look_ahead<F>(
256    neighbors: &[u32],
257    mut process_neighbor: F,
258    look_ahead: Option<usize>,
259    dist_calc: &impl DistCalculator,
260) where
261    F: FnMut(u32),
262{
263    match look_ahead {
264        Some(look_ahead) => {
265            for i in 0..neighbors.len().saturating_sub(look_ahead) {
266                dist_calc.prefetch(neighbors[i + look_ahead]);
267                process_neighbor(neighbors[i]);
268            }
269            for neighbor in &neighbors[neighbors.len().saturating_sub(look_ahead)..] {
270                process_neighbor(*neighbor);
271            }
272        }
273        None => {
274            for neighbor in neighbors.iter() {
275                process_neighbor(*neighbor);
276            }
277        }
278    }
279}
280
281#[inline]
282fn furthest_distance(results: &BinaryHeap<OrderedNode>) -> OrderedFloat {
283    results
284        .peek()
285        .map(|node| node.dist)
286        .unwrap_or(OrderedFloat(f32::INFINITY))
287}
288
289#[inline]
290fn push_result(results: &mut BinaryHeap<OrderedNode>, candidate: OrderedNode, k: usize) {
291    if results.len() < k {
292        results.push(candidate);
293    } else if candidate.dist < results.peek().unwrap().dist {
294        results.pop();
295        results.push(candidate);
296    }
297}
298
/// Expands to the main loop shared by the beam-search entry points.
///
/// Arguments, in order:
/// - `$candidates`: min-heap (`BinaryHeap<Reverse<OrderedNode>>`) of nodes still to expand.
/// - `$results`: max-heap (`BinaryHeap<OrderedNode>`) of accepted results.
/// - `$visited`: visited set, queried with `contains` and updated with `insert`.
/// - `$k`: target size of `$results`.
/// - `$dist_calc`: provides `distance(node_id)`.
/// - `$prefetch_distance`: accepted for call-site symmetry; not used in the expansion.
/// - `$accepts_result`: predicate `(node_id, dist) -> bool` deciding whether a
///   node may enter `$results`.
/// - `|$current, $process_neighbor| { .. }`: block that feeds the neighbors of
///   `$current` to `$process_neighbor`.
///
/// The loop pops the closest candidate and stops when `$results` holds exactly
/// `$k` entries and that candidate is farther than the furthest result. Each
/// unvisited neighbor is marked visited and its distance is evaluated. A
/// neighbor within the furthest-result distance sampled at the start of the
/// iteration (or any neighbor while `$results` is below `$k`) is queued as a
/// candidate, and is also offered to `$results` if `$accepts_result` allows it.
macro_rules! beam_search_loop {
    (
        $candidates:ident,
        $results:ident,
        $visited:ident,
        $k:expr,
        $dist_calc:expr,
        $prefetch_distance:expr,
        $accepts_result:expr,
        |$current:ident, $process_neighbor:ident| $visit_neighbors:block
    ) => {{
        while let Some(Reverse($current)) = $candidates.pop() {
            // Sampled once per expanded node; the closure below compares
            // against this snapshot rather than the live heap top.
            let worst_kept = furthest_distance(&$results);

            if $results.len() == $k && $current.dist > worst_kept {
                break;
            }

            let $process_neighbor = |neighbor: u32| {
                if !$visited.contains(neighbor) {
                    $visited.insert(neighbor);
                    let dist = OrderedFloat::from($dist_calc.distance(neighbor));
                    if $results.len() < $k || dist <= worst_kept {
                        if $accepts_result(neighbor, dist) {
                            push_result(&mut $results, OrderedNode::new(neighbor, dist), $k);
                        }
                        // Queued even when rejected as a result, so the search
                        // can still travel through filtered-out nodes.
                        $candidates.push(Reverse(OrderedNode::new(neighbor, dist)));
                    }
                }
            };
            $visit_neighbors
        }
    }};
}
335
/// Expands to the hill-climbing loop shared by the greedy-search entry points.
///
/// Arguments, in order:
/// - `$current`: mutable variable holding the ID of the node being expanded.
/// - `$closest_dist`: mutable variable holding the smallest distance seen so far.
/// - `$dist_calc`: provides `distance(node_id)`.
/// - `$prefetch_distance`: accepted for call-site symmetry; not used in the expansion.
/// - `|$process_neighbor| { .. }`: block that feeds the neighbors of `$current`
///   to `$process_neighbor`.
///
/// Each pass evaluates every neighbor handed to `$process_neighbor`. Whenever
/// one is strictly closer than `$closest_dist`, that distance is recorded and
/// the neighbor is remembered. After the pass `$current` moves to the last
/// neighbor remembered; the loop ends when a pass finds no closer neighbor.
macro_rules! greedy_search_loop {
    (
        $current:ident,
        $closest_dist:ident,
        $dist_calc:expr,
        $prefetch_distance:expr,
        |$process_neighbor:ident| $visit_neighbors:block
    ) => {{
        loop {
            let mut improved = None;
            let $process_neighbor = |neighbor: u32| {
                let candidate_dist = $dist_calc.distance(neighbor);
                if candidate_dist < $closest_dist {
                    $closest_dist = candidate_dist;
                    improved = Some(neighbor);
                }
            };
            $visit_neighbors

            match improved {
                Some(closer) => $current = closer,
                None => break,
            }
        }
    }};
}
363
364/// Beam search over a graph
365///
366/// This is the same as ``search-layer`` in HNSW.
367///
368/// Parameters
369/// ----------
370/// graph : Graph
371///  The graph to search.
372/// start : &[OrderedNode]
373///  The starting point.
374/// query : &[f32]
375///  The query vector.
376/// k : usize
377///  The number of results to return.
378/// bitset : Option<&RoaringBitmap>
379///  The bitset of node IDs to filter the results, bit 1 for the node to keep, and bit 0 for the node to discard.
380///
381/// Returns
382/// -------
383/// A descending sorted list of ``(dist, node_id)`` pairs.
384///
385/// WARNING: Internal API,  API stability is not guaranteed
386///
387/// TODO: This isn't actually beam search, function should probably be renamed
388pub fn beam_search(
389    graph: &dyn Graph,
390    ep: &OrderedNode,
391    params: &HnswQueryParams,
392    dist_calc: &impl DistCalculator,
393    bitset: Option<&Visited>,
394    prefetch_distance: Option<usize>,
395    visited: &mut Visited,
396) -> Vec<OrderedNode> {
397    let k = params.ef;
398    let mut candidates = BinaryHeap::with_capacity(k);
399    visited.insert(ep.id);
400    candidates.push(Reverse(ep.clone()));
401
402    let mut results = BinaryHeap::with_capacity(k);
403    let no_filter =
404        bitset.is_none() && params.lower_bound.is_none() && params.upper_bound.is_none();
405
406    if no_filter {
407        results.push(ep.clone());
408        let accepts_result = |_: u32, _: OrderedFloat| true;
409        beam_search_loop!(
410            candidates,
411            results,
412            visited,
413            k,
414            dist_calc,
415            prefetch_distance,
416            accepts_result,
417            |current, process_neighbor| {
418                let neighbors = graph.neighbors(current.id);
419                process_neighbors_with_look_ahead(
420                    &neighbors,
421                    process_neighbor,
422                    prefetch_distance,
423                    dist_calc,
424                );
425            }
426        );
427        return results.into_sorted_vec();
428    }
429
430    // add range search support
431    let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
432    let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
433
434    if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
435        && ep.dist >= lower_bound
436        && ep.dist < upper_bound
437    {
438        results.push(ep.clone());
439    }
440
441    let accepts_result = |node_id: u32, dist: OrderedFloat| {
442        bitset
443            .map(|bitset| bitset.contains(node_id))
444            .unwrap_or(true)
445            && dist >= lower_bound
446            && dist < upper_bound
447    };
448    beam_search_loop!(
449        candidates,
450        results,
451        visited,
452        k,
453        dist_calc,
454        prefetch_distance,
455        accepts_result,
456        |current, process_neighbor| {
457            let neighbors = graph.neighbors(current.id);
458            process_neighbors_with_look_ahead(
459                &neighbors,
460                process_neighbor,
461                prefetch_distance,
462                dist_calc,
463            );
464        }
465    );
466    results.into_sorted_vec()
467}
468
469pub fn beam_search_borrowed(
470    graph: &impl BorrowingGraph,
471    ep: &OrderedNode,
472    params: &HnswQueryParams,
473    dist_calc: &impl DistCalculator,
474    bitset: Option<&Visited>,
475    prefetch_distance: Option<usize>,
476    visited: &mut Visited,
477) -> Vec<OrderedNode> {
478    let k = params.ef;
479    let mut candidates = BinaryHeap::with_capacity(k);
480    visited.insert(ep.id);
481    candidates.push(Reverse(ep.clone()));
482
483    let mut results = BinaryHeap::with_capacity(k);
484    let no_filter =
485        bitset.is_none() && params.lower_bound.is_none() && params.upper_bound.is_none();
486
487    if no_filter {
488        results.push(ep.clone());
489        let accepts_result = |_: u32, _: OrderedFloat| true;
490        beam_search_loop!(
491            candidates,
492            results,
493            visited,
494            k,
495            dist_calc,
496            prefetch_distance,
497            accepts_result,
498            |current, process_neighbor| {
499                let neighbors = graph.neighbors(current.id);
500                process_neighbors_with_look_ahead(
501                    neighbors,
502                    process_neighbor,
503                    prefetch_distance,
504                    dist_calc,
505                );
506            }
507        );
508        return results.into_sorted_vec();
509    }
510
511    let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
512    let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
513
514    if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
515        && ep.dist >= lower_bound
516        && ep.dist < upper_bound
517    {
518        results.push(ep.clone());
519    }
520
521    let accepts_result = |node_id: u32, dist: OrderedFloat| {
522        bitset
523            .map(|bitset| bitset.contains(node_id))
524            .unwrap_or(true)
525            && dist >= lower_bound
526            && dist < upper_bound
527    };
528    beam_search_loop!(
529        candidates,
530        results,
531        visited,
532        k,
533        dist_calc,
534        prefetch_distance,
535        accepts_result,
536        |current, process_neighbor| {
537            let neighbors = graph.neighbors(current.id);
538            process_neighbors_with_look_ahead(
539                neighbors,
540                process_neighbor,
541                prefetch_distance,
542                dist_calc,
543            );
544        }
545    );
546    results.into_sorted_vec()
547}
548
549/// Greedy search over a graph
550///
551/// This searches for only one result, only used for finding the entry point
552///
553/// Parameters
554/// ----------
555/// graph : Graph
556///    The graph to search.
557/// start : u32
558///   The index starting point.
559/// query : &[f32]
560///   The query vector.
561///
562/// Returns
563/// -------
564/// A ``(dist, node_id)`` pair.
565///
566/// WARNING: Internal API,  API stability is not guaranteed
567pub fn greedy_search(
568    graph: &dyn Graph,
569    start: OrderedNode,
570    dist_calc: &impl DistCalculator,
571    prefetch_distance: Option<usize>,
572) -> OrderedNode {
573    let mut current = start.id;
574    let mut closest_dist = start.dist.0;
575    greedy_search_loop!(
576        current,
577        closest_dist,
578        dist_calc,
579        prefetch_distance,
580        |process_neighbor| {
581            let neighbors = graph.neighbors(current);
582            process_neighbors_with_look_ahead(
583                &neighbors,
584                process_neighbor,
585                prefetch_distance,
586                dist_calc,
587            );
588        }
589    );
590    OrderedNode::new(current, closest_dist.into())
591}
592
593pub fn greedy_search_borrowed(
594    graph: &impl BorrowingGraph,
595    start: OrderedNode,
596    dist_calc: &impl DistCalculator,
597    prefetch_distance: Option<usize>,
598) -> OrderedNode {
599    let mut current = start.id;
600    let mut closest_dist = start.dist.0;
601    greedy_search_loop!(
602        current,
603        closest_dist,
604        dist_calc,
605        prefetch_distance,
606        |process_neighbor| {
607            let neighbors = graph.neighbors(current);
608            process_neighbors_with_look_ahead(
609                neighbors,
610                process_neighbor,
611                prefetch_distance,
612                dist_calc,
613            );
614        }
615    );
616    OrderedNode::new(current, closest_dist.into())
617}
618
#[cfg(test)]
mod tests {
    // No unit tests are defined in this module.
}