lance_index/vector/
graph.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Generic Graph implementation.
5//!
6
7use std::cmp::Reverse;
8use std::collections::BinaryHeap;
9use std::sync::Arc;
10
11use arrow_schema::{DataType, Field};
12use bitvec::vec::BitVec;
13use deepsize::DeepSizeOf;
14
15use crate::vector::hnsw::builder::HnswQueryParams;
16
17pub mod builder;
18
19use crate::vector::DIST_COL;
20
21use crate::vector::storage::DistCalculator;
22
23pub(crate) const NEIGHBORS_COL: &str = "__neighbors";
24
25use std::sync::LazyLock;
26
27/// NEIGHBORS field.
28pub static NEIGHBORS_FIELD: LazyLock<Field> = LazyLock::new(|| {
29    Field::new(
30        NEIGHBORS_COL,
31        DataType::List(Field::new_list_field(DataType::UInt32, true).into()),
32        true,
33    )
34});
35pub static DISTS_FIELD: LazyLock<Field> = LazyLock::new(|| {
36    Field::new(
37        DIST_COL,
38        DataType::List(Field::new_list_field(DataType::Float32, true).into()),
39        true,
40    )
41});
42
/// A graph vertex together with its adjacency list.
///
/// `I` is the id type used both for the node itself and for its neighbors
/// (defaults to `u32`).
pub struct GraphNode<I = u32> {
    pub id: I,
    pub neighbors: Vec<I>,
}

impl<I> GraphNode<I> {
    /// Creates a node with the given id and neighbor ids.
    pub fn new(id: I, neighbors: Vec<I>) -> Self {
        GraphNode { neighbors, id }
    }
}

/// Converts a bare id into a node that has no neighbors yet.
impl<I> From<I> for GraphNode<I> {
    fn from(id: I) -> Self {
        Self::new(id, Vec::new())
    }
}
62
63/// A wrapper for f32 to make it ordered, so that we can put it into
64/// a BTree or Heap
65#[derive(Debug, PartialEq, Clone, Copy, DeepSizeOf)]
66pub struct OrderedFloat(pub f32);
67
68impl PartialOrd for OrderedFloat {
69    #[inline(always)]
70    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
71        Some(self.cmp(other))
72    }
73}
74
75impl Eq for OrderedFloat {}
76
77impl Ord for OrderedFloat {
78    #[inline(always)]
79    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
80        self.0.total_cmp(&other.0)
81    }
82}
83
84impl From<f32> for OrderedFloat {
85    fn from(f: f32) -> Self {
86        Self(f)
87    }
88}
89
90impl From<OrderedFloat> for f32 {
91    fn from(f: OrderedFloat) -> Self {
92        f.0
93    }
94}
95
96#[derive(Debug, Eq, PartialEq, Clone, DeepSizeOf)]
97pub struct OrderedNode<T = u32>
98where
99    T: PartialEq + Eq,
100{
101    pub id: T,
102    pub dist: OrderedFloat,
103}
104
105impl<T: PartialEq + Eq> OrderedNode<T> {
106    pub fn new(id: T, dist: OrderedFloat) -> Self {
107        Self { id, dist }
108    }
109}
110
111impl<T: PartialEq + Eq> PartialOrd for OrderedNode<T> {
112    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
113        Some(self.cmp(other))
114    }
115}
116
117impl<T: PartialEq + Eq> Ord for OrderedNode<T> {
118    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
119        self.dist.cmp(&other.dist)
120    }
121}
122
123impl<T: PartialEq + Eq> From<(OrderedFloat, T)> for OrderedNode<T> {
124    fn from((dist, id): (OrderedFloat, T)) -> Self {
125        Self { id, dist }
126    }
127}
128
129impl<T: PartialEq + Eq> From<OrderedNode<T>> for (OrderedFloat, T) {
130    fn from(node: OrderedNode<T>) -> Self {
131        (node.dist, node.id)
132    }
133}
134
/// Distance calculator.
///
/// Computes distances from a single query vector (held by the implementor,
/// since it is not a parameter) to a batch of vectors identified by their IDs.
pub trait DistanceCalculator {
    /// Compute the distance from the query vector to each of the vectors in
    /// `ids`.
    ///
    /// Callers presumably expect one distance per entry of `ids`, in the same
    /// order; implementations should preserve that correspondence.
    fn compute_distances(&self, ids: &[u32]) -> Box<dyn Iterator<Item = f32>>;
}
144
/// Graph trait.
///
/// A read-only view of a graph whose nodes are identified by `u32` ids.
pub trait Graph {
    /// Get the number of nodes in the graph.
    fn len(&self) -> usize;

    /// Returns true if the graph has no nodes (i.e. `len() == 0`).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the neighbors of the graph node identified by `key`.
    ///
    /// The list is returned behind an `Arc`, so implementations can hand out a
    /// shared adjacency list without copying it.
    fn neighbors(&self, key: u32) -> Arc<Vec<u32>>;
}
163
164/// Array-based visited list (faster than HashSet)
165pub struct Visited<'a> {
166    visited: &'a mut BitVec,
167    recently_visited: Vec<u32>,
168}
169
170impl Visited<'_> {
171    pub fn insert(&mut self, node_id: u32) {
172        let node_id_usize = node_id as usize;
173        if !self.visited[node_id_usize] {
174            self.visited.set(node_id_usize, true);
175            self.recently_visited.push(node_id);
176        }
177    }
178
179    pub fn contains(&self, node_id: u32) -> bool {
180        let node_id_usize = node_id as usize;
181        self.visited[node_id_usize]
182    }
183
184    #[inline(always)]
185    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
186        self.visited.iter_ones()
187    }
188
189    pub fn count_ones(&self) -> usize {
190        self.visited.count_ones()
191    }
192}
193
194impl Drop for Visited<'_> {
195    fn drop(&mut self) {
196        for node_id in self.recently_visited.iter() {
197            self.visited.set(*node_id as usize, false);
198        }
199        self.recently_visited.clear();
200    }
201}
202
203#[derive(Debug, Clone)]
204pub struct VisitedGenerator {
205    visited: BitVec,
206    capacity: usize,
207}
208
209impl VisitedGenerator {
210    pub fn new(capacity: usize) -> Self {
211        Self {
212            visited: BitVec::repeat(false, capacity),
213            capacity,
214        }
215    }
216
217    pub fn generate(&mut self, node_count: usize) -> Visited<'_> {
218        if node_count > self.capacity {
219            let new_capacity = self.capacity.max(node_count).next_power_of_two();
220            self.visited.resize(new_capacity, false);
221            self.capacity = new_capacity;
222        }
223        Visited {
224            visited: &mut self.visited,
225            recently_visited: Vec::new(),
226        }
227    }
228}
229
230fn process_neighbors_with_look_ahead<F>(
231    neighbors: &[u32],
232    mut process_neighbor: F,
233    look_ahead: Option<usize>,
234    dist_calc: &impl DistCalculator,
235) where
236    F: FnMut(u32),
237{
238    match look_ahead {
239        Some(look_ahead) => {
240            for i in 0..neighbors.len().saturating_sub(look_ahead) {
241                dist_calc.prefetch(neighbors[i + look_ahead]);
242                process_neighbor(neighbors[i]);
243            }
244            for neighbor in &neighbors[neighbors.len().saturating_sub(look_ahead)..] {
245                process_neighbor(*neighbor);
246            }
247        }
248        None => {
249            for neighbor in neighbors.iter() {
250                process_neighbor(*neighbor);
251            }
252        }
253    }
254}
255
256/// Beam search over a graph
257///
258/// This is the same as ``search-layer`` in HNSW.
259///
260/// Parameters
261/// ----------
262/// graph : Graph
263///  The graph to search.
264/// start : &[OrderedNode]
265///  The starting point.
266/// query : &[f32]
267///  The query vector.
268/// k : usize
269///  The number of results to return.
270/// bitset : Option<&RoaringBitmap>
271///  The bitset of node IDs to filter the results, bit 1 for the node to keep, and bit 0 for the node to discard.
272///
273/// Returns
274/// -------
275/// A descending sorted list of ``(dist, node_id)`` pairs.
276///
277/// WARNING: Internal API,  API stability is not guaranteed
278///
279/// TODO: This isn't actually beam search, function should probably be renamed
280pub fn beam_search(
281    graph: &dyn Graph,
282    ep: &OrderedNode,
283    params: &HnswQueryParams,
284    dist_calc: &impl DistCalculator,
285    bitset: Option<&Visited>,
286    prefetch_distance: Option<usize>,
287    visited: &mut Visited,
288) -> Vec<OrderedNode> {
289    let k = params.ef;
290    let mut candidates = BinaryHeap::with_capacity(k);
291    visited.insert(ep.id);
292    candidates.push(Reverse(ep.clone()));
293
294    // add range search support
295    let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
296    let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
297
298    let mut results = BinaryHeap::with_capacity(k);
299
300    if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
301        && ep.dist >= lower_bound
302        && ep.dist < upper_bound
303    {
304        results.push(ep.clone());
305    }
306
307    while !candidates.is_empty() {
308        let current = candidates.pop().expect("candidates is empty").0;
309        let furthest = results
310            .peek()
311            .map(|node| node.dist)
312            .unwrap_or(OrderedFloat(f32::INFINITY));
313
314        // TODO: add an option to ignore the second condition for better performance.
315        if current.dist > furthest && results.len() == k {
316            break;
317        }
318        let furthest = results
319            .peek()
320            .map(|node| node.dist)
321            .unwrap_or(OrderedFloat(f32::INFINITY));
322
323        let process_neighbor = |neighbor: u32| {
324            if visited.contains(neighbor) {
325                return;
326            }
327            visited.insert(neighbor);
328            let dist: OrderedFloat = dist_calc.distance(neighbor).into();
329            if dist <= furthest || results.len() < k {
330                if bitset
331                    .map(|bitset| bitset.contains(neighbor))
332                    .unwrap_or(true)
333                    && dist >= lower_bound
334                    && dist < upper_bound
335                {
336                    if results.len() < k {
337                        results.push((dist, neighbor).into());
338                    } else if results.len() == k && dist < results.peek().unwrap().dist {
339                        results.pop();
340                        results.push((dist, neighbor).into());
341                    }
342                }
343                candidates.push(Reverse((dist, neighbor).into()));
344            }
345        };
346        let neighbors = graph.neighbors(current.id);
347        process_neighbors_with_look_ahead(
348            &neighbors,
349            process_neighbor,
350            prefetch_distance,
351            dist_calc,
352        );
353    }
354
355    results.into_sorted_vec()
356}
357
358/// Greedy search over a graph
359///
360/// This searches for only one result, only used for finding the entry point
361///
362/// Parameters
363/// ----------
364/// graph : Graph
365///    The graph to search.
366/// start : u32
367///   The index starting point.
368/// query : &[f32]
369///   The query vector.
370///
371/// Returns
372/// -------
373/// A ``(dist, node_id)`` pair.
374///
375/// WARNING: Internal API,  API stability is not guaranteed
376pub fn greedy_search(
377    graph: &dyn Graph,
378    start: OrderedNode,
379    dist_calc: &impl DistCalculator,
380    prefetch_distance: Option<usize>,
381) -> OrderedNode {
382    let mut current = start.id;
383    let mut closest_dist = start.dist.0;
384    loop {
385        let neighbors = graph.neighbors(current);
386        let mut next = None;
387
388        let process_neighbor = |neighbor: u32| {
389            let dist = dist_calc.distance(neighbor);
390            if dist < closest_dist {
391                closest_dist = dist;
392                next = Some(neighbor);
393            }
394        };
395        process_neighbors_with_look_ahead(
396            &neighbors,
397            process_neighbor,
398            prefetch_distance,
399            dist_calc,
400        );
401
402        if let Some(next) = next {
403            current = next;
404        } else {
405            break;
406        }
407    }
408
409    OrderedNode::new(current, closest_dist.into())
410}
411
// Unit-test module for this file. It currently defines no tests.
// NOTE(review): the ordering wrappers (`OrderedFloat`, `OrderedNode`) and the
// `Visited` / `VisitedGenerator` reset behavior are candidates for coverage.
#[cfg(test)]
mod tests {}