lance_index/vector/
graph.rs1use std::cmp::Reverse;
8use std::collections::BinaryHeap;
9use std::sync::Arc;
10
11use arrow_schema::{DataType, Field};
12use bitvec::vec::BitVec;
13use deepsize::DeepSizeOf;
14
15use crate::vector::hnsw::builder::HnswQueryParams;
16
17pub mod builder;
18
19use crate::vector::DIST_COL;
20
21use crate::vector::storage::DistCalculator;
22
23pub(crate) const NEIGHBORS_COL: &str = "__neighbors";
24
25use std::sync::LazyLock;
26
27pub static NEIGHBORS_FIELD: LazyLock<Field> = LazyLock::new(|| {
29 Field::new(
30 NEIGHBORS_COL,
31 DataType::List(Field::new_list_field(DataType::UInt32, true).into()),
32 true,
33 )
34});
35pub static DISTS_FIELD: LazyLock<Field> = LazyLock::new(|| {
36 Field::new(
37 DIST_COL,
38 DataType::List(Field::new_list_field(DataType::Float32, true).into()),
39 true,
40 )
41});
42
/// A graph vertex: its own identifier plus the identifiers of its neighbors.
pub struct GraphNode<I = u32> {
    /// Identifier of this node.
    pub id: I,
    /// Identifiers of the nodes adjacent to this one.
    pub neighbors: Vec<I>,
}

impl<I> GraphNode<I> {
    /// Builds a node from an id and an already assembled neighbor list.
    pub fn new(id: I, neighbors: Vec<I>) -> Self {
        GraphNode { id, neighbors }
    }
}

impl<I> From<I> for GraphNode<I> {
    /// Wraps a bare id into a node that has no neighbors yet.
    fn from(id: I) -> Self {
        Self::new(id, Vec::new())
    }
}
62
63#[derive(Debug, PartialEq, Clone, Copy, DeepSizeOf)]
66pub struct OrderedFloat(pub f32);
67
68impl PartialOrd for OrderedFloat {
69 #[inline(always)]
70 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
71 Some(self.cmp(other))
72 }
73}
74
75impl Eq for OrderedFloat {}
76
77impl Ord for OrderedFloat {
78 #[inline(always)]
79 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
80 self.0.total_cmp(&other.0)
81 }
82}
83
84impl From<f32> for OrderedFloat {
85 fn from(f: f32) -> Self {
86 Self(f)
87 }
88}
89
90impl From<OrderedFloat> for f32 {
91 fn from(f: OrderedFloat) -> Self {
92 f.0
93 }
94}
95
96#[derive(Debug, Eq, PartialEq, Clone, DeepSizeOf)]
97pub struct OrderedNode<T = u32>
98where
99 T: PartialEq + Eq,
100{
101 pub id: T,
102 pub dist: OrderedFloat,
103}
104
105impl<T: PartialEq + Eq> OrderedNode<T> {
106 pub fn new(id: T, dist: OrderedFloat) -> Self {
107 Self { id, dist }
108 }
109}
110
111impl<T: PartialEq + Eq> PartialOrd for OrderedNode<T> {
112 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
113 Some(self.cmp(other))
114 }
115}
116
117impl<T: PartialEq + Eq> Ord for OrderedNode<T> {
118 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
119 self.dist.cmp(&other.dist)
120 }
121}
122
123impl<T: PartialEq + Eq> From<(OrderedFloat, T)> for OrderedNode<T> {
124 fn from((dist, id): (OrderedFloat, T)) -> Self {
125 Self { id, dist }
126 }
127}
128
129impl<T: PartialEq + Eq> From<OrderedNode<T>> for (OrderedFloat, T) {
130 fn from(node: OrderedNode<T>) -> Self {
131 (node.dist, node.id)
132 }
133}
134
/// Batch distance computation over a list of node ids.
pub trait DistanceCalculator {
    /// Returns a boxed iterator of `f32` distances for `ids`.
    ///
    /// NOTE(review): presumably one distance per id, in the same order as
    /// `ids`, measured against a query held by the implementor; confirm
    /// against the implementations.
    fn compute_distances(&self, ids: &[u32]) -> Box<dyn Iterator<Item = f32>>;
}
144
/// Read-only view of a graph whose nodes are addressed by `u32` keys.
pub trait Graph {
    /// Size of the graph as reported by the implementor.
    ///
    /// NOTE(review): presumably the number of nodes; confirm against the
    /// implementations.
    fn len(&self) -> usize;

    /// Returns `true` when [`Graph::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the shared list of neighbor ids for the node `key`.
    fn neighbors(&self, key: u32) -> Arc<Vec<u32>>;
}
163
164pub struct Visited<'a> {
166 visited: &'a mut BitVec,
167 recently_visited: Vec<u32>,
168}
169
170impl Visited<'_> {
171 pub fn insert(&mut self, node_id: u32) {
172 let node_id_usize = node_id as usize;
173 if !self.visited[node_id_usize] {
174 self.visited.set(node_id_usize, true);
175 self.recently_visited.push(node_id);
176 }
177 }
178
179 pub fn contains(&self, node_id: u32) -> bool {
180 let node_id_usize = node_id as usize;
181 self.visited[node_id_usize]
182 }
183
184 #[inline(always)]
185 pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
186 self.visited.iter_ones()
187 }
188
189 pub fn count_ones(&self) -> usize {
190 self.visited.count_ones()
191 }
192}
193
194impl Drop for Visited<'_> {
195 fn drop(&mut self) {
196 for node_id in self.recently_visited.iter() {
197 self.visited.set(*node_id as usize, false);
198 }
199 self.recently_visited.clear();
200 }
201}
202
203#[derive(Debug, Clone)]
204pub struct VisitedGenerator {
205 visited: BitVec,
206 capacity: usize,
207}
208
209impl VisitedGenerator {
210 pub fn new(capacity: usize) -> Self {
211 Self {
212 visited: BitVec::repeat(false, capacity),
213 capacity,
214 }
215 }
216
217 pub fn generate(&mut self, node_count: usize) -> Visited<'_> {
218 if node_count > self.capacity {
219 let new_capacity = self.capacity.max(node_count).next_power_of_two();
220 self.visited.resize(new_capacity, false);
221 self.capacity = new_capacity;
222 }
223 Visited {
224 visited: &mut self.visited,
225 recently_visited: Vec::new(),
226 }
227 }
228}
229
230fn process_neighbors_with_look_ahead<F>(
231 neighbors: &[u32],
232 mut process_neighbor: F,
233 look_ahead: Option<usize>,
234 dist_calc: &impl DistCalculator,
235) where
236 F: FnMut(u32),
237{
238 match look_ahead {
239 Some(look_ahead) => {
240 for i in 0..neighbors.len().saturating_sub(look_ahead) {
241 dist_calc.prefetch(neighbors[i + look_ahead]);
242 process_neighbor(neighbors[i]);
243 }
244 for neighbor in &neighbors[neighbors.len().saturating_sub(look_ahead)..] {
245 process_neighbor(*neighbor);
246 }
247 }
248 None => {
249 for neighbor in neighbors.iter() {
250 process_neighbor(*neighbor);
251 }
252 }
253 }
254}
255
256pub fn beam_search(
281 graph: &dyn Graph,
282 ep: &OrderedNode,
283 params: &HnswQueryParams,
284 dist_calc: &impl DistCalculator,
285 bitset: Option<&Visited>,
286 prefetch_distance: Option<usize>,
287 visited: &mut Visited,
288) -> Vec<OrderedNode> {
289 let k = params.ef;
290 let mut candidates = BinaryHeap::with_capacity(k);
291 visited.insert(ep.id);
292 candidates.push(Reverse(ep.clone()));
293
294 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
296 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
297
298 let mut results = BinaryHeap::with_capacity(k);
299
300 if bitset.map(|bitset| bitset.contains(ep.id)).unwrap_or(true)
301 && ep.dist >= lower_bound
302 && ep.dist < upper_bound
303 {
304 results.push(ep.clone());
305 }
306
307 while !candidates.is_empty() {
308 let current = candidates.pop().expect("candidates is empty").0;
309 let furthest = results
310 .peek()
311 .map(|node| node.dist)
312 .unwrap_or(OrderedFloat(f32::INFINITY));
313
314 if current.dist > furthest && results.len() == k {
316 break;
317 }
318 let furthest = results
319 .peek()
320 .map(|node| node.dist)
321 .unwrap_or(OrderedFloat(f32::INFINITY));
322
323 let process_neighbor = |neighbor: u32| {
324 if visited.contains(neighbor) {
325 return;
326 }
327 visited.insert(neighbor);
328 let dist: OrderedFloat = dist_calc.distance(neighbor).into();
329 if dist <= furthest || results.len() < k {
330 if bitset
331 .map(|bitset| bitset.contains(neighbor))
332 .unwrap_or(true)
333 && dist >= lower_bound
334 && dist < upper_bound
335 {
336 if results.len() < k {
337 results.push((dist, neighbor).into());
338 } else if results.len() == k && dist < results.peek().unwrap().dist {
339 results.pop();
340 results.push((dist, neighbor).into());
341 }
342 }
343 candidates.push(Reverse((dist, neighbor).into()));
344 }
345 };
346 let neighbors = graph.neighbors(current.id);
347 process_neighbors_with_look_ahead(
348 &neighbors,
349 process_neighbor,
350 prefetch_distance,
351 dist_calc,
352 );
353 }
354
355 results.into_sorted_vec()
356}
357
358pub fn greedy_search(
377 graph: &dyn Graph,
378 start: OrderedNode,
379 dist_calc: &impl DistCalculator,
380 prefetch_distance: Option<usize>,
381) -> OrderedNode {
382 let mut current = start.id;
383 let mut closest_dist = start.dist.0;
384 loop {
385 let neighbors = graph.neighbors(current);
386 let mut next = None;
387
388 let process_neighbor = |neighbor: u32| {
389 let dist = dist_calc.distance(neighbor);
390 if dist < closest_dist {
391 closest_dist = dist;
392 next = Some(neighbor);
393 }
394 };
395 process_neighbors_with_look_ahead(
396 &neighbors,
397 process_neighbor,
398 prefetch_distance,
399 dist_calc,
400 );
401
402 if let Some(next) = next {
403 current = next;
404 } else {
405 break;
406 }
407 }
408
409 OrderedNode::new(current, closest_dist.into())
410}
411
412#[cfg(test)]
413mod tests {}