1use std::cmp::min;
33use std::sync::Arc;
34use std::sync::Mutex;
35use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
36
37use arc_swap::ArcSwap;
38use crossbeam_queue::ArrayQueue;
39use rand::{Rng, SeedableRng, rngs::SmallRng};
40
41use super::builder::{HNSW, HNSW_LEVEL_RNG_SEED, HnswBuildParams, HnswQueryParams};
42use super::select_neighbors_heuristic;
43use crate::vector::graph::builder::GraphBuilderNode;
44use crate::vector::graph::{
45 Graph, OrderedFloat, OrderedNode, VisitedGenerator, beam_search, greedy_search,
46};
47use crate::vector::storage::{DistCalculator, VectorStore};
48use lance_core::utils::tokio::get_num_compute_intensive_cpus;
49
50pub struct OnlineGraphBuilderNode {
57 pub(crate) level_neighbors: Vec<ArcSwap<Vec<u32>>>,
60 pub(crate) level_neighbors_ranked: Mutex<Vec<Vec<OrderedNode>>>,
62 pub(crate) bottom_neighbors: ArcSwap<Vec<u32>>,
64}
65
66impl OnlineGraphBuilderNode {
67 pub fn new(target_level: u16) -> Self {
68 let levels = (target_level as usize) + 1;
69 let level_neighbors = (0..levels)
70 .map(|_| ArcSwap::from_pointee(Vec::new()))
71 .collect();
72 let level_neighbors_ranked = (0..levels).map(|_| Vec::new()).collect();
73 Self {
74 level_neighbors,
75 level_neighbors_ranked: Mutex::new(level_neighbors_ranked),
76 bottom_neighbors: ArcSwap::from_pointee(Vec::new()),
77 }
78 }
79
80 fn target_level(&self) -> u16 {
81 self.level_neighbors.len() as u16 - 1
82 }
83
84 fn has_level(&self, level: u16) -> bool {
86 (level as usize) < self.level_neighbors.len()
87 }
88
89 fn add_neighbor(&self, v: u32, dist: OrderedFloat, level: u16) {
90 if !self.has_level(level) {
91 return;
92 }
93 let mut ranked = self
94 .level_neighbors_ranked
95 .lock()
96 .expect("level_neighbors_ranked mutex poisoned");
97 ranked[level as usize].push(OrderedNode { dist, id: v });
98 }
99
100 fn cutoff(&self, level: u16, max_size: usize) -> OrderedFloat {
101 if !self.has_level(level) {
102 return OrderedFloat(f32::NEG_INFINITY);
103 }
104 let ranked = self
105 .level_neighbors_ranked
106 .lock()
107 .expect("level_neighbors_ranked mutex poisoned");
108 let neighbors = &ranked[level as usize];
109 if neighbors.len() < max_size {
110 OrderedFloat(f32::INFINITY)
111 } else {
112 neighbors.last().unwrap().dist
113 }
114 }
115
116 fn publish_from_ranked(&self, level: u16) {
119 if !self.has_level(level) {
120 return;
121 }
122 let ranked = self
123 .level_neighbors_ranked
124 .lock()
125 .expect("level_neighbors_ranked mutex poisoned");
126 let new_list: Vec<u32> = ranked[level as usize].iter().map(|n| n.id).collect();
127 drop(ranked);
128 let new_arc = Arc::new(new_list);
129 self.level_neighbors[level as usize].store(new_arc.clone());
130 if level == 0 {
131 self.bottom_neighbors.store(new_arc);
132 }
133 }
134}
135
136pub struct OnlineHnswBuilder {
144 params: HnswBuildParams,
145 nodes: Vec<OnlineGraphBuilderNode>,
146 level_count: Vec<AtomicUsize>,
149 entry_point: AtomicU32,
152 inserted_len: AtomicUsize,
154 visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
156}
157
158impl OnlineHnswBuilder {
159 pub fn with_capacity(capacity: usize, params: HnswBuildParams) -> Self {
162 assert!(
163 params.max_level > 0,
164 "HnswBuildParams::max_level must be > 0"
165 );
166 let max_level = params.max_level;
167 let level_count = (0..max_level).map(|_| AtomicUsize::new(0)).collect();
168
169 let mut level_rng = SmallRng::seed_from_u64(HNSW_LEVEL_RNG_SEED);
170 let nodes: Vec<_> = (0..capacity)
171 .map(|i| {
172 let target_level = if i == 0 {
173 0
176 } else {
177 Self::random_level_with(¶ms, &mut level_rng)
178 };
179 OnlineGraphBuilderNode::new(target_level)
180 })
181 .collect();
182
183 let queue_size = get_num_compute_intensive_cpus().max(1);
184 let visited_generator_queue = Arc::new(ArrayQueue::new(queue_size));
185 for _ in 0..queue_size {
186 let _ = visited_generator_queue.push(VisitedGenerator::new(0));
187 }
188
189 Self {
190 params,
191 nodes,
192 level_count,
193 entry_point: AtomicU32::new(0),
194 inserted_len: AtomicUsize::new(0),
195 visited_generator_queue,
196 }
197 }
198
199 fn random_level_with<R: Rng + ?Sized>(params: &HnswBuildParams, rng: &mut R) -> u16 {
200 let ml = 1.0 / (params.m as f32).ln();
201 min(
202 (-rng.random::<f32>().ln() * ml) as u16,
203 params.max_level - 1,
204 )
205 }
206
207 pub fn capacity(&self) -> usize {
208 self.nodes.len()
209 }
210
211 pub fn len(&self) -> usize {
212 self.inserted_len.load(Ordering::Acquire)
213 }
214
215 pub fn is_empty(&self) -> bool {
216 self.len() == 0
217 }
218
219 pub fn params(&self) -> &HnswBuildParams {
220 &self.params
221 }
222
223 pub fn insert(&self, id: u32, storage: &impl VectorStore) {
229 let mut visited_generator = self
230 .visited_generator_queue
231 .pop()
232 .unwrap_or_else(|| VisitedGenerator::new(self.nodes.len()));
233
234 self.insert_with_generator(id, storage, &mut visited_generator);
235
236 let _ = self.visited_generator_queue.push(visited_generator);
238 }
239
240 fn insert_with_generator(
241 &self,
242 id: u32,
243 storage: &impl VectorStore,
244 visited_generator: &mut VisitedGenerator,
245 ) {
246 let nodes = self.nodes.as_slice();
247 let target_level = nodes[id as usize].target_level();
248 let dist_calc = storage.dist_calculator_from_id(id);
249
250 if self.inserted_len.load(Ordering::Acquire) == 0 {
252 for level in 0..=target_level {
253 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
254 }
255 self.entry_point.store(id, Ordering::Release);
256 self.inserted_len.store(1, Ordering::Release);
257 return;
258 }
259
260 let entry = self.entry_point.load(Ordering::Acquire);
261 let mut ep = OrderedNode::new(entry, dist_calc.distance(entry).into());
262
263 for level in (target_level + 1..self.params.max_level).rev() {
265 let cur_level = OnlineHnswLevelView::new(level, nodes);
266 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
267 }
268
269 let mut pruned_neighbors_per_level: Vec<Vec<OrderedNode>> =
271 vec![Vec::new(); (target_level + 1) as usize];
272
273 let current_node = &nodes[id as usize];
274 for level in (0..=target_level).rev() {
275 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
276
277 let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
278 for neighbor in &neighbors {
282 if !nodes[neighbor.id as usize].has_level(level) {
283 continue;
284 }
285 current_node.add_neighbor(neighbor.id, neighbor.dist, level);
286 }
287 self.prune(storage, current_node, level);
288 let snapshot = {
290 let ranked = current_node
291 .level_neighbors_ranked
292 .lock()
293 .expect("level_neighbors_ranked mutex poisoned");
294 ranked[level as usize].clone()
295 };
296 current_node.publish_from_ranked(level);
297 pruned_neighbors_per_level[level as usize] = snapshot;
298
299 if let Some(next) = neighbors
303 .iter()
304 .find(|n| nodes[n.id as usize].has_level(level))
305 {
306 ep = next.clone();
307 }
308 }
309
310 for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
312 let level = level as u16;
313 let m_max = if level == 0 {
314 self.params.m * 2
315 } else {
316 self.params.m
317 };
318 for unpruned_edge in pruned_neighbors {
319 let chosen = &nodes[unpruned_edge.id as usize];
320 if unpruned_edge.dist < chosen.cutoff(level, m_max) {
321 chosen.add_neighbor(id, unpruned_edge.dist, level);
322 self.prune(storage, chosen, level);
323 chosen.publish_from_ranked(level);
324 }
325 }
326 }
327
328 let entry_target_level = nodes[entry as usize].target_level();
335 if target_level > entry_target_level {
336 let _ =
337 self.entry_point
338 .compare_exchange(entry, id, Ordering::AcqRel, Ordering::Acquire);
339 }
340
341 self.inserted_len.fetch_add(1, Ordering::AcqRel);
342 }
343
344 fn search_level(
345 &self,
346 ep: &OrderedNode,
347 level: u16,
348 dist_calc: &impl DistCalculator,
349 nodes: &[OnlineGraphBuilderNode],
350 visited_generator: &mut VisitedGenerator,
351 ) -> Vec<OrderedNode> {
352 let cur_level = OnlineHnswLevelView::new(level, nodes);
353 let mut visited = visited_generator.generate(nodes.len());
354 beam_search(
355 &cur_level,
356 ep,
357 &HnswQueryParams {
358 ef: self.params.ef_construction,
359 lower_bound: None,
360 upper_bound: None,
361 dist_q_c: 0.0,
362 },
363 dist_calc,
364 None,
365 self.params.prefetch_distance,
366 &mut visited,
367 )
368 }
369
370 fn prune(&self, storage: &impl VectorStore, node: &OnlineGraphBuilderNode, level: u16) {
371 let m_max = if level == 0 {
372 self.params.m * 2
373 } else {
374 self.params.m
375 };
376
377 let mut ranked = node
378 .level_neighbors_ranked
379 .lock()
380 .expect("level_neighbors_ranked mutex poisoned");
381 let level_neighbors = ranked[level as usize].clone();
382 if level_neighbors.len() <= m_max {
383 return;
384 }
385 ranked[level as usize] = select_neighbors_heuristic(storage, &level_neighbors, m_max);
386 }
387
388 pub fn search(
400 &self,
401 query: arrow_array::ArrayRef,
402 k: usize,
403 ef: usize,
404 storage: &impl VectorStore,
405 ) -> Vec<OrderedNode> {
406 let visible = self.inserted_len.load(Ordering::Acquire);
407 if visible == 0 {
408 return Vec::new();
409 }
410
411 let mut visited_generator = self
412 .visited_generator_queue
413 .pop()
414 .unwrap_or_else(|| VisitedGenerator::new(self.nodes.len()));
415
416 let dist_calc = storage.dist_calculator(query, 0.0);
417 let entry = self.entry_point.load(Ordering::Acquire);
418 let mut ep = OrderedNode::new(entry, dist_calc.distance(entry).into());
419
420 let nodes = self.nodes.as_slice();
421 for level in (1..self.params.max_level).rev() {
422 let cur_level = OnlineHnswLevelView::new(level, nodes);
423 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
424 }
425
426 let bottom = OnlineHnswBottomView::new(nodes);
427 let mut visited = visited_generator.generate(nodes.len());
428 let params = HnswQueryParams {
429 ef: ef.max(k),
430 lower_bound: None,
431 upper_bound: None,
432 dist_q_c: 0.0,
433 };
434 let result = beam_search(
435 &bottom,
436 &ep,
437 ¶ms,
438 &dist_calc,
439 None,
440 self.params.prefetch_distance,
441 &mut visited,
442 );
443 drop(visited);
444
445 let _ = self.visited_generator_queue.push(visited_generator);
446
447 let limit = ef.max(k);
450 result.into_iter().take(limit).collect()
451 }
452
453 pub fn to_hnsw(&self) -> HNSW {
463 let inserted = self.inserted_len.load(Ordering::Acquire);
464 let entry_point = self.entry_point.load(Ordering::Acquire);
465 let max_level = self.params.max_level as usize;
466
467 let mut frozen_nodes: Vec<GraphBuilderNode> = Vec::with_capacity(inserted);
468 for (idx, node) in self.nodes.iter().enumerate().take(inserted) {
469 let mut level_neighbors: Vec<Arc<Vec<u32>>> = node
470 .level_neighbors
471 .iter()
472 .map(|sl| sl.load_full())
473 .collect();
474 let mut level_neighbors_ranked = node
475 .level_neighbors_ranked
476 .lock()
477 .expect("level_neighbors_ranked mutex poisoned")
478 .clone();
479
480 if idx as u32 == entry_point {
481 while level_neighbors.len() < max_level {
482 level_neighbors.push(Arc::new(Vec::new()));
483 level_neighbors_ranked.push(Vec::new());
484 }
485 }
486
487 let bottom_neighbors = level_neighbors
488 .first()
489 .cloned()
490 .unwrap_or_else(|| Arc::new(Vec::new()));
491 frozen_nodes.push(GraphBuilderNode::from_parts(
492 level_neighbors,
493 level_neighbors_ranked,
494 bottom_neighbors,
495 ));
496 }
497
498 let mut level_count: Vec<usize> = vec![0; max_level];
499 for node in &frozen_nodes {
500 let levels = node.level_neighbors.len().min(max_level);
501 for count in level_count.iter_mut().take(levels) {
502 *count += 1;
503 }
504 }
505
506 HNSW::from_parts(self.params.clone(), frozen_nodes, level_count, entry_point)
507 }
508
509 pub fn finalize(self) -> HNSW {
511 self.to_hnsw()
512 }
513}
514
515pub struct OnlineHnswLevelView<'a> {
518 level: u16,
519 nodes: &'a [OnlineGraphBuilderNode],
520}
521
522impl<'a> OnlineHnswLevelView<'a> {
523 pub fn new(level: u16, nodes: &'a [OnlineGraphBuilderNode]) -> Self {
524 Self { level, nodes }
525 }
526}
527
528impl Graph for OnlineHnswLevelView<'_> {
529 fn len(&self) -> usize {
530 self.nodes.len()
531 }
532
533 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
534 let node = &self.nodes[key as usize];
535 let level_idx = self.level as usize;
536 if level_idx >= node.level_neighbors.len() {
537 return Arc::new(Vec::new());
539 }
540 node.level_neighbors[level_idx].load_full()
541 }
542}
543
544pub struct OnlineHnswBottomView<'a> {
546 nodes: &'a [OnlineGraphBuilderNode],
547}
548
549impl<'a> OnlineHnswBottomView<'a> {
550 pub fn new(nodes: &'a [OnlineGraphBuilderNode]) -> Self {
551 Self { nodes }
552 }
553}
554
555impl Graph for OnlineHnswBottomView<'_> {
556 fn len(&self) -> usize {
557 self.nodes.len()
558 }
559
560 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
561 self.nodes[key as usize].bottom_neighbors.load_full()
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::flat::storage::FlatFloatStorage;
    use arrow_array::FixedSizeListArray;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_linalg::distance::DistanceType;
    use lance_testing::datagen::generate_random_array;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Builds an L2 flat storage over `n` random vectors of dimension `dim`.
    /// It also returns the raw vectors for brute-force checks.
    fn build_storage(n: usize, dim: usize) -> (Arc<FlatFloatStorage>, FixedSizeListArray) {
        let data = generate_random_array(n * dim);
        let fsl = FixedSizeListArray::try_new_from_values(data, dim as i32).unwrap();
        let storage = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        (storage, fsl)
    }

    /// Squared L2 distance between rows `a` and `b` of `fsl` (each `dim` long).
    fn l2_squared(fsl: &FixedSizeListArray, a: usize, b: usize, dim: usize) -> f32 {
        let lhs = fsl.value(a);
        let rhs = fsl.value(b);
        let lhs = lhs
            .as_any()
            .downcast_ref::<arrow_array::Float32Array>()
            .unwrap();
        let rhs = rhs
            .as_any()
            .downcast_ref::<arrow_array::Float32Array>()
            .unwrap();
        (0..dim)
            .map(|j| {
                let d = lhs.value(j) - rhs.value(j);
                d * d
            })
            .sum()
    }

    /// Exact `k` nearest rows (by squared L2) to row `query_idx` among rows `0..n`.
    fn brute_force_top_k(
        fsl: &FixedSizeListArray,
        n: usize,
        dim: usize,
        query_idx: usize,
        k: usize,
    ) -> HashSet<usize> {
        let mut all_dists: Vec<(usize, f32)> = (0..n)
            .map(|i| (i, l2_squared(fsl, query_idx, i, dim)))
            .collect();
        all_dists.sort_by(|a, b| a.1.total_cmp(&b.1));
        all_dists.into_iter().take(k).map(|(i, _)| i).collect()
    }

    #[test]
    fn test_online_hnsw_recall() {
        const N: usize = 1000;
        const DIM: usize = 32;
        const K: usize = 10;
        const NUM_QUERIES: usize = 50;

        let (storage, fsl) = build_storage(N, DIM);
        let params = HnswBuildParams::default()
            .num_edges(16)
            .ef_construction(100);
        let builder = OnlineHnswBuilder::with_capacity(N, params);

        for i in 0..N {
            builder.insert(i as u32, storage.as_ref());
        }
        assert_eq!(builder.len(), N);

        let mut total_correct = 0usize;
        for q_idx in 0..NUM_QUERIES {
            let truth = brute_force_top_k(&fsl, N, DIM, q_idx, K);

            let results = builder.search(fsl.value(q_idx), K, 64, storage.as_ref());
            // `search` can return up to `max(ef, k)` candidates (assumed
            // nearest-first). Score only the first `K`, so this measures
            // recall@K rather than recall over the whole beam.
            let found: HashSet<usize> = results.iter().take(K).map(|r| r.id as usize).collect();
            total_correct += truth.intersection(&found).count();
        }

        let recall = total_correct as f32 / (NUM_QUERIES * K) as f32;
        assert!(recall >= 0.85, "recall too low: {}", recall);
    }

    #[test]
    fn test_online_hnsw_finalize_matches_search() {
        const N: usize = 256;
        const DIM: usize = 16;
        const K: usize = 10;

        let (storage, fsl) = build_storage(N, DIM);
        let params = HnswBuildParams::default()
            .num_edges(16)
            .ef_construction(100);
        let builder = OnlineHnswBuilder::with_capacity(N, params);
        for i in 0..N {
            builder.insert(i as u32, storage.as_ref());
        }

        let online_results = builder.search(fsl.value(0), K, 64, storage.as_ref());

        let hnsw = builder.finalize();
        let mut visited = VisitedGenerator::new(N);
        let bottom_results = hnsw
            .search_inner(
                fsl.value(0),
                K,
                &HnswQueryParams {
                    ef: 64,
                    lower_bound: None,
                    upper_bound: None,
                    dist_q_c: 0.0,
                },
                None,
                &mut visited,
                storage.as_ref(),
                Some(2),
            )
            .unwrap();

        // Compare top-K against top-K. Otherwise the up-to-`ef` online results
        // would trivially contain the frozen ones.
        let online_ids: HashSet<u32> = online_results.iter().take(K).map(|r| r.id).collect();
        let frozen_ids: HashSet<u32> = bottom_results.iter().take(K).map(|r| r.id).collect();
        let overlap = online_ids.intersection(&frozen_ids).count();
        assert!(
            overlap >= 7,
            "frozen vs online overlap too low: {}",
            overlap
        );
    }

    #[test]
    fn test_online_hnsw_empty_search() {
        let params = HnswBuildParams::default();
        let builder = OnlineHnswBuilder::with_capacity(16, params);
        let (storage, fsl) = build_storage(1, 8);
        let results = builder.search(fsl.value(0), 10, 32, storage.as_ref());
        assert!(results.is_empty());
    }
}