1use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{DataType, Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use std::cmp::min;
18use std::collections::{BinaryHeap, HashMap, VecDeque};
19use std::fmt::Debug;
20use std::iter;
21use std::sync::Arc;
22use std::sync::RwLock;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use tracing::instrument;
25
26use lance_core::{Error, Result};
27use rand::{Rng, SeedableRng, rngs::SmallRng};
28use serde::{Deserialize, Serialize};
29
30use super::super::graph::beam_search;
31use super::{HNSW_TYPE, HnswMetadata, VECTOR_ID_COL, VECTOR_ID_FIELD, select_neighbors_heuristic};
32use crate::metrics::MetricsCollector;
33use crate::prefilter::PreFilter;
34use crate::vector::flat::storage::{FlatBinStorage, FlatFloatStorage};
35use crate::vector::graph::builder::GraphBuilderNode;
36use crate::vector::graph::{
37 BorrowingGraph, DISTS_FIELD, Graph, NEIGHBORS_COL, NEIGHBORS_FIELD, OrderedFloat, OrderedNode,
38 VisitedGenerator,
39};
40use crate::vector::graph::{Visited, beam_search_borrowed, greedy_search, greedy_search_borrowed};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{DIST_COL, Query, VECTOR_RESULT_SCHEMA};
44
/// Schema-metadata key under which the serialized [`HnswMetadata`] is stored.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";

/// Fixed seed for the RNG that draws each node's level, so that level assignment
/// is reproducible from one build to the next.
pub(crate) const HNSW_LEVEL_RNG_SEED: u64 = 0x2A;
55
56#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
58pub struct HnswBuildParams {
59 pub max_level: u16,
61
62 pub m: usize,
64
65 pub ef_construction: usize,
67
68 pub prefetch_distance: Option<usize>,
70}
71
72impl From<&HnswBuildParams> for crate::pb::HnswParameters {
73 fn from(params: &HnswBuildParams) -> Self {
74 Self {
75 max_connections: params.m as u32,
76 construction_ef: params.ef_construction as u32,
77 max_level: params.max_level as u32,
78 }
79 }
80}
81
82impl Default for HnswBuildParams {
83 fn default() -> Self {
84 Self {
85 max_level: 7,
86 m: 20,
87 ef_construction: 150,
88 prefetch_distance: Some(2),
89 }
90 }
91}
92
93impl HnswBuildParams {
94 pub fn max_level(mut self, max_level: u16) -> Self {
97 self.max_level = max_level;
98 self
99 }
100
101 pub fn num_edges(mut self, m: usize) -> Self {
104 self.m = m;
105 self
106 }
107
108 pub fn ef_construction(mut self, ef_construction: usize) -> Self {
113 self.ef_construction = ef_construction;
114 self
115 }
116
117 pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
123 let vectors = data.as_fixed_size_list().clone();
124 match (vectors.value_type(), distance_type) {
125 (DataType::UInt8, DistanceType::Hamming) => {
126 let vec_store = Arc::new(FlatBinStorage::new(vectors, distance_type));
127 HNSW::index_vectors(vec_store.as_ref(), self)
128 }
129 (DataType::UInt8, _) => Err(Error::invalid_input(format!(
130 "HNSW only supports hamming distance for UInt8 vectors, got {}",
131 distance_type
132 ))),
133 (_, DistanceType::Hamming) => Err(Error::invalid_input(format!(
134 "HNSW hamming distance only supports UInt8 vectors, got {}",
135 vectors.value_type()
136 ))),
137 _ => {
138 let vec_store = Arc::new(FlatFloatStorage::new(vectors, distance_type));
139 HNSW::index_vectors(vec_store.as_ref(), self)
140 }
141 }
142 }
143}
144
145#[derive(Clone, DeepSizeOf)]
153pub struct HNSW {
154 inner: Arc<HnswCore>,
155}
156
157struct HnswCore {
158 params: HnswBuildParams,
159 nodes: Arc<Vec<GraphBuilderNode>>,
160 level_count: Vec<usize>,
161 entry_point: u32,
162 visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
163}
164
165impl DeepSizeOf for HnswCore {
166 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
167 self.params.deep_size_of_children(context)
168 + self.nodes.deep_size_of_children(context)
169 + self.level_count.deep_size_of_children(context)
170 }
172}
173
174impl HnswCore {
175 fn max_level(&self) -> u16 {
176 self.params.max_level
177 }
178
179 fn num_nodes(&self, level: usize) -> usize {
180 self.level_count[level]
181 }
182
183 fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
184 self.nodes.clone()
185 }
186}
187
188impl Debug for HNSW {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
191 }
192}
193
194impl HNSW {
195 pub(crate) fn from_parts(
198 params: HnswBuildParams,
199 nodes: Vec<GraphBuilderNode>,
200 level_count: Vec<usize>,
201 entry_point: u32,
202 ) -> Self {
203 let queue_size = get_num_compute_intensive_cpus().max(1) * 2;
204 let visited_generator_queue = Arc::new(ArrayQueue::new(queue_size));
205 for _ in 0..queue_size {
206 let _ = visited_generator_queue.push(VisitedGenerator::new(0));
207 }
208 Self {
209 inner: Arc::new(HnswCore {
210 params,
211 nodes: Arc::new(nodes),
212 level_count,
213 entry_point,
214 visited_generator_queue,
215 }),
216 }
217 }
218
219 pub fn empty() -> Self {
220 Self {
221 inner: Arc::new(HnswCore {
222 params: HnswBuildParams::default(),
223 nodes: Arc::new(Vec::new()),
224 level_count: Vec::new(),
225 entry_point: 0,
226 visited_generator_queue: Arc::new(ArrayQueue::new(1)),
227 }),
228 }
229 }
230
231 pub fn len(&self) -> usize {
232 self.inner.nodes.len()
233 }
234
235 pub fn is_empty(&self) -> bool {
236 self.len() == 0
237 }
238
239 pub fn max_level(&self) -> u16 {
240 self.inner.max_level()
241 }
242
243 pub fn num_nodes(&self, level: usize) -> usize {
244 self.inner.num_nodes(level)
245 }
246
247 pub fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
248 self.inner.nodes()
249 }
250
251 #[allow(clippy::too_many_arguments)]
252 pub fn search_inner(
253 &self,
254 query: ArrayRef,
255 k: usize,
256 params: &HnswQueryParams,
257 bitset: Option<Visited>,
258 visited_generator: &mut VisitedGenerator,
259 storage: &impl VectorStore,
260 prefetch_distance: Option<usize>,
261 ) -> Result<Vec<OrderedNode>> {
262 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
263 let entry = self.inner.entry_point;
264 let mut ep = OrderedNode::new(entry, dist_calc.distance(entry).into());
265 let nodes = self.inner.nodes.as_ref();
266 for level in (0..self.max_level()).rev() {
267 let cur_level = ImmutableHnswLevelView::new(level, nodes);
268 ep = greedy_search_borrowed(
269 &cur_level,
270 ep,
271 &dist_calc,
272 self.inner.params.prefetch_distance,
273 );
274 }
275
276 let bottom_level = ImmutableHnswBottomView::new(nodes);
277 let mut visited = visited_generator.generate(storage.len());
278 Ok(beam_search_borrowed(
279 &bottom_level,
280 &ep,
281 params,
282 &dist_calc,
283 bitset.as_ref(),
284 prefetch_distance,
285 &mut visited,
286 )
287 .into_iter()
288 .take(k)
289 .collect())
290 }
291
292 #[instrument(level = "debug", skip(self, query, bitset, storage))]
293 pub fn search_basic(
294 &self,
295 query: ArrayRef,
296 k: usize,
297 params: &HnswQueryParams,
298 bitset: Option<Visited>,
299 storage: &impl VectorStore,
300 ) -> Result<Vec<OrderedNode>> {
301 let mut visited_generator = self
302 .inner
303 .visited_generator_queue
304 .pop()
305 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
306 let result = self.search_inner(
307 query,
308 k,
309 params,
310 bitset,
311 &mut visited_generator,
312 storage,
313 Some(2),
314 );
315
316 match self.inner.visited_generator_queue.push(visited_generator) {
317 Ok(_) => {}
318 Err(_) => {
319 log::warn!("visited_generator_queue is full");
320 }
321 }
322
323 result
324 }
325
326 #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
327 fn flat_search(
328 &self,
329 storage: &impl VectorStore,
330 query: ArrayRef,
331 k: usize,
332 prefilter_bitset: Visited,
333 params: &HnswQueryParams,
334 ) -> Vec<OrderedNode> {
335 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
336 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
337
338 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
339 let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
340
341 match self.inner.params.prefetch_distance {
342 Some(ahead) if ahead > 0 => {
343 let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
344 let mut buffer = VecDeque::with_capacity(ahead + 1);
345 for _ in 0..=ahead {
346 if let Some(id) = ids_iter.next() {
347 buffer.push_back(id);
348 } else {
349 break;
350 }
351 }
352
353 while let Some(node_id) = buffer.pop_front() {
354 if let Some(&prefetch_id) = buffer.get(ahead - 1) {
355 dist_calc.prefetch(prefetch_id);
356 }
357 if let Some(next) = ids_iter.next() {
358 buffer.push_back(next);
359 }
360
361 let dist: OrderedFloat = dist_calc.distance(node_id).into();
362 if dist <= lower_bound || dist > upper_bound {
363 continue;
364 }
365 if heap.len() < k {
366 heap.push((dist, node_id).into());
367 } else if dist < heap.peek().unwrap().dist {
368 heap.pop();
369 heap.push((dist, node_id).into());
370 }
371 }
372 }
373 _ => {
374 for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
375 let dist: OrderedFloat = dist_calc.distance(node_id).into();
376 if dist <= lower_bound || dist > upper_bound {
377 continue;
378 }
379 if heap.len() < k {
380 heap.push((dist, node_id).into());
381 } else if dist < heap.peek().unwrap().dist {
382 heap.pop();
383 heap.push((dist, node_id).into());
384 }
385 }
386 }
387 };
388 heap.into_sorted_vec()
389 }
390
391 pub fn metadata(&self) -> HnswMetadata {
393 let level_offsets = self
396 .inner
397 .level_count
398 .iter()
399 .chain(iter::once(&0))
400 .scan(0, |state, x| {
401 let start = *state;
402 *state += *x;
403 Some(start)
404 })
405 .collect();
406
407 HnswMetadata {
408 entry_point: self.inner.entry_point,
409 params: self.inner.params.clone(),
410 level_offsets,
411 }
412 }
413}
414
415struct HnswBuilder {
416 params: HnswBuildParams,
417
418 nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
419 level_count: Vec<AtomicUsize>,
420
421 entry_point: u32,
422
423 visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
424}
425
426impl DeepSizeOf for HnswBuilder {
427 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
428 self.params.deep_size_of_children(context)
429 + self.nodes.deep_size_of_children(context)
430 + self.level_count.deep_size_of_children(context)
431 }
433}
434
435impl HnswBuilder {
436 fn finish(self) -> HNSW {
437 let nodes = match Arc::try_unwrap(self.nodes) {
438 Ok(nodes) => nodes
439 .into_iter()
440 .map(|node| node.into_inner().expect("builder lock poisoned"))
441 .collect(),
442 Err(nodes) => nodes
443 .iter()
444 .map(|node| node.read().expect("builder lock poisoned").clone())
445 .collect(),
446 };
447
448 let level_count = self
449 .level_count
450 .into_iter()
451 .map(|count| count.load(Ordering::Relaxed))
452 .collect();
453
454 HNSW {
455 inner: Arc::new(HnswCore {
456 params: self.params,
457 nodes: Arc::new(nodes),
458 level_count,
459 entry_point: self.entry_point,
460 visited_generator_queue: self.visited_generator_queue,
461 }),
462 }
463 }
464
465 pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
467 let len = storage.len();
468 let max_level = params.max_level;
469
470 let level_count = (0..max_level)
471 .map(|_| AtomicUsize::new(0))
472 .collect::<Vec<_>>();
473
474 let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
475 for _ in 0..get_num_compute_intensive_cpus() {
476 visited_generator_queue
477 .push(VisitedGenerator::new(0))
478 .unwrap();
479 }
480 let mut builder = Self {
481 params,
482 nodes: Arc::new(Vec::new()),
483 level_count,
484 entry_point: 0,
485 visited_generator_queue,
486 };
487
488 if storage.is_empty() {
489 return builder;
490 }
491
492 let mut nodes = Vec::with_capacity(len);
493 {
494 if len > 0 {
495 nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
496 }
497 let mut level_rng = SmallRng::seed_from_u64(HNSW_LEVEL_RNG_SEED);
498 for i in 1..len {
499 nodes.push(RwLock::new(GraphBuilderNode::new(
500 i as u32,
501 builder.random_level(&mut level_rng) as usize + 1,
502 )));
503 }
504 }
505 builder.nodes = Arc::new(nodes);
506
507 builder
508 }
509
510 fn random_level<R: Rng + ?Sized>(&self, rng: &mut R) -> u16 {
514 let ml = 1.0 / (self.params.m as f32).ln();
515 min(
516 (-rng.random::<f32>().ln() * ml) as u16,
517 self.params.max_level - 1,
518 )
519 }
520
521 fn insert(
523 &self,
524 node: u32,
525 visited_generator: &mut VisitedGenerator,
526 storage: &impl VectorStore,
527 ) {
528 let nodes = &self.nodes;
529 let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
530 let dist_calc = storage.dist_calculator_from_id(node);
531 let mut ep = OrderedNode::new(
532 self.entry_point,
533 dist_calc.distance(self.entry_point).into(),
534 );
535
536 for level in (target_level + 1..self.params.max_level).rev() {
545 let cur_level = HnswLevelView::new(level, nodes);
546 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
547 }
548
549 let mut pruned_neighbors_per_level: Vec<Vec<_>> =
550 vec![Vec::new(); (target_level + 1) as usize];
551 {
552 let mut current_node = nodes[node as usize].write().unwrap();
553 for level in (0..=target_level).rev() {
554 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
555
556 let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
557 for neighbor in &neighbors {
558 current_node.add_neighbor(neighbor.id, neighbor.dist, level);
559 }
560 self.prune(storage, &mut current_node, level);
561 pruned_neighbors_per_level[level as usize]
562 .clone_from(¤t_node.level_neighbors_ranked[level as usize]);
563
564 ep = neighbors[0].clone();
565 }
566 }
567 for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
568 let _: Vec<_> = pruned_neighbors
569 .iter()
570 .map(|unpruned_edge| {
571 let level = level as u16;
572 let m_max = match level {
573 0 => self.params.m * 2,
574 _ => self.params.m,
575 };
576 if unpruned_edge.dist
577 < nodes[unpruned_edge.id as usize]
578 .read()
579 .unwrap()
580 .cutoff(level, m_max)
581 {
582 let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
583 chosen_node.add_neighbor(node, unpruned_edge.dist, level);
584 self.prune(storage, &mut chosen_node, level);
585 }
586 })
587 .collect();
588 }
589 }
590
591 fn search_level(
592 &self,
593 ep: &OrderedNode,
594 level: u16,
595 dist_calc: &impl DistCalculator,
596 nodes: &[RwLock<GraphBuilderNode>],
597 visited_generator: &mut VisitedGenerator,
598 ) -> Vec<OrderedNode> {
599 let cur_level = HnswLevelView::new(level, nodes);
600 let mut visited = visited_generator.generate(nodes.len());
601 beam_search(
602 &cur_level,
603 ep,
604 &HnswQueryParams {
605 ef: self.params.ef_construction,
606 lower_bound: None,
607 upper_bound: None,
608 dist_q_c: 0.0,
609 },
610 dist_calc,
611 None,
612 self.params.prefetch_distance,
613 &mut visited,
614 )
615 }
616
617 fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
618 let m_max = match level {
619 0 => self.params.m * 2,
620 _ => self.params.m,
621 };
622
623 let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
624 let level_neighbors = neighbors_ranked.clone();
625 if level_neighbors.len() <= m_max {
626 builder_node.update_from_ranked_neighbors(level);
627 return;
628 }
629
630 *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
631 builder_node.update_from_ranked_neighbors(level);
632 }
633}
634
635pub(crate) struct HnswLevelView<'a> {
638 level: u16,
639 nodes: &'a [RwLock<GraphBuilderNode>],
640}
641
642impl<'a> HnswLevelView<'a> {
643 pub fn new(level: u16, nodes: &'a [RwLock<GraphBuilderNode>]) -> Self {
644 Self { level, nodes }
645 }
646}
647
648impl Graph for HnswLevelView<'_> {
649 fn len(&self) -> usize {
650 self.nodes.len()
651 }
652
653 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
654 let node = &self.nodes[key as usize];
655 node.read().unwrap().level_neighbors[self.level as usize].clone()
656 }
657}
658
659pub(crate) struct ImmutableHnswLevelView<'a> {
660 level: u16,
661 nodes: &'a [GraphBuilderNode],
662}
663
664impl<'a> ImmutableHnswLevelView<'a> {
665 pub fn new(level: u16, nodes: &'a [GraphBuilderNode]) -> Self {
666 Self { level, nodes }
667 }
668}
669
670impl Graph for ImmutableHnswLevelView<'_> {
671 fn len(&self) -> usize {
672 self.nodes.len()
673 }
674
675 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
676 self.nodes[key as usize].level_neighbors[self.level as usize].clone()
677 }
678}
679
680impl BorrowingGraph for ImmutableHnswLevelView<'_> {
681 fn len(&self) -> usize {
682 self.nodes.len()
683 }
684
685 fn neighbors(&self, key: u32) -> &[u32] {
686 self.nodes[key as usize].level_neighbors[self.level as usize].as_slice()
687 }
688}
689
690pub(crate) struct ImmutableHnswBottomView<'a> {
691 nodes: &'a [GraphBuilderNode],
692}
693
694impl<'a> ImmutableHnswBottomView<'a> {
695 pub fn new(nodes: &'a [GraphBuilderNode]) -> Self {
696 Self { nodes }
697 }
698}
699
700impl Graph for ImmutableHnswBottomView<'_> {
701 fn len(&self) -> usize {
702 self.nodes.len()
703 }
704
705 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
706 self.nodes[key as usize].bottom_neighbors.clone()
707 }
708}
709
710impl BorrowingGraph for ImmutableHnswBottomView<'_> {
711 fn len(&self) -> usize {
712 self.nodes.len()
713 }
714
715 fn neighbors(&self, key: u32) -> &[u32] {
716 self.nodes[key as usize].bottom_neighbors.as_slice()
717 }
718}
719
/// Per-query settings for an HNSW search.
#[derive(Debug, Clone, Copy)]
pub struct HnswQueryParams {
    /// Beam width of the bottom-level search; the sub-index search rejects
    /// `ef < k`.
    pub ef: usize,
    /// Optional lower bound on result distances.
    pub lower_bound: Option<f32>,
    /// Optional upper bound on result distances.
    pub upper_bound: Option<f32>,
    /// Passed through to the storage's distance calculator.
    pub dist_q_c: f32,
}
727
728impl From<&Query> for HnswQueryParams {
729 fn from(query: &Query) -> Self {
730 let k = query.k * query.refine_factor.unwrap_or(1) as usize;
731 Self {
732 ef: query.ef.unwrap_or(k + k / 2),
733 lower_bound: query.lower_bound,
734 upper_bound: query.upper_bound,
735 dist_q_c: query.dist_q_c,
736 }
737 }
738}
739
740impl IvfSubIndex for HNSW {
741 type BuildParams = HnswBuildParams;
742 type QueryParams = HnswQueryParams;
743
744 fn load(data: RecordBatch) -> Result<Self>
745 where
746 Self: Sized,
747 {
748 if data.num_rows() == 0 {
749 return Ok(Self::empty());
750 }
751
752 let hnsw_metadata = data
753 .schema_ref()
754 .metadata()
755 .get(HNSW_METADATA_KEY)
756 .ok_or(Error::index(format!("{} not found", HNSW_METADATA_KEY)))?;
757 let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata).map_err(|e| {
758 Error::index(format!(
759 "Failed to decode HNSW metadata: {}, json: {}",
760 e, hnsw_metadata
761 ))
762 })?;
763
764 let levels: Vec<_> = hnsw_metadata
765 .level_offsets
766 .iter()
767 .tuple_windows()
768 .map(|(start, end)| data.slice(*start, end - start))
769 .collect();
770
771 let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
772
773 let bottom_level_len = levels[0].num_rows();
774 let mut nodes = Vec::with_capacity(bottom_level_len);
775 for i in 0..bottom_level_len {
776 nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
777 }
778 for (level, batch) in levels.into_iter().enumerate() {
779 let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
780 let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
781 let distances = batch[DIST_COL].as_list::<i32>();
782
783 for ((node, neighbors), distances) in
784 ids.iter().zip(neighbors.iter()).zip(distances.iter())
785 {
786 let node = node.unwrap();
787 let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
788 let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
789
790 nodes[node as usize].level_neighbors_ranked[level] = neighbors
791 .iter()
792 .zip(distances.iter())
793 .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
794 .collect();
795 nodes[node as usize].update_from_ranked_neighbors(level as u16);
796 }
797 }
798
799 let visited_generator_queue =
800 Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
801 for _ in 0..get_num_compute_intensive_cpus() * 2 {
802 visited_generator_queue
803 .push(VisitedGenerator::new(0))
804 .unwrap();
805 }
806 let inner = HnswCore {
807 params: hnsw_metadata.params,
808 nodes: Arc::new(nodes),
809 level_count,
810 entry_point: hnsw_metadata.entry_point,
811 visited_generator_queue,
812 };
813
814 Ok(Self {
815 inner: Arc::new(inner),
816 })
817 }
818
819 fn name() -> &'static str {
820 HNSW_TYPE
821 }
822
823 fn metadata_key() -> &'static str {
824 "lance:hnsw"
825 }
826
827 fn schema() -> arrow_schema::SchemaRef {
829 arrow_schema::Schema::new(vec![
830 VECTOR_ID_FIELD.clone(),
831 NEIGHBORS_FIELD.clone(),
832 DISTS_FIELD.clone(),
833 ])
834 .into()
835 }
836
837 #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
838 fn search(
839 &self,
840 query: ArrayRef,
841 k: usize,
842 params: Self::QueryParams,
843 storage: &impl VectorStore,
844 prefilter: Arc<dyn PreFilter>,
845 _metrics: &dyn MetricsCollector,
846 ) -> Result<RecordBatch> {
847 if params.ef < k {
848 return Err(Error::index(
849 "ef must be greater than or equal to k".to_string(),
850 ));
851 }
852
853 let schema = VECTOR_RESULT_SCHEMA.clone();
854 if self.is_empty() {
855 return Ok(RecordBatch::new_empty(schema));
856 }
857
858 let mut prefilter_generator = self
859 .inner
860 .visited_generator_queue
861 .pop()
862 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
863 let prefilter_bitset = if prefilter.is_empty() {
864 None
865 } else {
866 let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
867 let mut bitset = prefilter_generator.generate(storage.len());
868 for indices in indices {
869 bitset.insert(indices as u32);
870 }
871 Some(bitset)
872 };
873
874 let remained = prefilter_bitset
875 .as_ref()
876 .map(|b| b.count_ones())
877 .unwrap_or(storage.len());
878 let results = if remained < self.len() * 10 / 100 {
879 let prefilter_bitset =
880 prefilter_bitset.expect("the prefilter bitset must be set for flat search");
881 self.flat_search(storage, query, k, prefilter_bitset, ¶ms)
882 } else {
883 self.search_basic(query, k, ¶ms, prefilter_bitset, storage)?
884 };
885 let _ = self.inner.visited_generator_queue.push(prefilter_generator);
887
888 let (row_ids, dists): (Vec<_>, Vec<_>) = results
890 .into_iter()
891 .map(|r| (storage.row_id(r.id), r.dist.0))
892 .unique_by(|r| r.0)
893 .unzip();
894 let row_ids = Arc::new(UInt64Array::from(row_ids));
895 let distances = Arc::new(Float32Array::from(dists));
896
897 Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
898 }
899
900 fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
902 where
903 Self: Sized,
904 {
905 let builder = HnswBuilder::with_params(params, storage);
906
907 log::debug!(
908 "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
909 storage.len(),
910 builder.params.max_level,
911 builder.params.m,
912 builder.params.ef_construction,
913 storage.distance_type(),
914 );
915
916 if storage.is_empty() {
917 return Ok(builder.finish());
918 }
919
920 let len = storage.len();
921 builder.level_count[0].fetch_add(1, Ordering::Relaxed);
922 (1..len).into_par_iter().for_each_init(
923 || VisitedGenerator::new(len),
924 |visited_generator, node| {
925 builder.insert(node as u32, visited_generator, storage);
926 },
927 );
928
929 assert_eq!(builder.level_count[0].load(Ordering::Relaxed), len);
930 Ok(builder.finish())
931 }
932
933 fn remap(
934 &self,
935 _mapping: &HashMap<u64, Option<u64>>, store: &impl VectorStore,
937 ) -> Result<Self> {
938 Self::index_vectors(store, self.inner.params.clone())
941 }
942
943 fn to_batch(&self) -> Result<RecordBatch> {
945 let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
946 let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
947 let mut distances_builder =
948 ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
949 let mut batches = Vec::with_capacity(self.max_level() as usize);
950 for level in 0..self.max_level() {
951 let level = level as usize;
952 for (id, node) in self.inner.nodes.iter().enumerate() {
953 if level >= node.level_neighbors.len() {
954 continue;
955 }
956 let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
957 let distances = node.level_neighbors_ranked[level]
958 .iter()
959 .map(|n| Some(n.dist.0));
960 vector_id_builder.append_value(id as u32);
961 neighbors_builder.append_value(neighbors);
962 distances_builder.append_value(distances);
963 }
964
965 let batch = RecordBatch::try_new(
966 Self::schema(),
967 vec![
968 Arc::new(vector_id_builder.finish()),
969 Arc::new(neighbors_builder.finish()),
970 Arc::new(distances_builder.finish()),
971 ],
972 )?;
973 batches.push(batch);
974 }
975
976 let metadata = self.metadata();
977 let metadata = serde_json::to_string(&metadata)?;
978 let schema = Self::schema()
979 .as_ref()
980 .clone()
981 .with_metadata(HashMap::from_iter(vec![(
982 HNSW_METADATA_KEY.to_string(),
983 metadata,
984 )]));
985 let batch = concat_batches(&Self::schema(), batches.iter())?;
986 let batch = batch.with_schema(Arc::new(schema))?;
987 Ok(batch)
988 }
989}
990
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{FixedSizeListArray, UInt8Array};
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::{FlatBinStorage, FlatFloatStorage},
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            HNSW, VECTOR_ID_FIELD,
            builder::{HnswBuildParams, HnswQueryParams},
        },
    };

    /// Writes `hnsw` with the previous-format file writer into an in-memory object
    /// store under `name`, then reads the file back and loads it.
    async fn write_and_reload(hnsw: &HNSW, name: &str) -> HNSW {
        let object_store = ObjectStore::memory();
        let path = Path::from(name);
        let object_writer = object_store.create(&path).await.unwrap();
        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut file_writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            lance_schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();
        let batch = hnsw.to_batch().unwrap();
        let metadata = batch.schema_ref().metadata().clone();
        file_writer.write_record_batch(batch).await.unwrap();
        file_writer.finish_with_metadata(&metadata).await.unwrap();

        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        HNSW::load(batch).unwrap()
    }

    /// Query parameters shared by both round-trip tests.
    fn search_params() -> HnswQueryParams {
        HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        }
    }

    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;
        let data = generate_random_array(TOTAL * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        let built = HNSW::index_vectors(
            store.as_ref(),
            HnswBuildParams::default()
                .num_edges(NUM_EDGES)
                .ef_construction(50),
        )
        .unwrap();

        let loaded = write_and_reload(&built, "test_builder_write_load").await;

        let query = fsl.value(0);
        let k = 10;
        let params = search_params();
        let expected = built
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let actual = loaded
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn test_builder_write_load_binary_hamming() {
        const DIM: usize = 8;
        const TOTAL: usize = 256;
        const NUM_EDGES: usize = 20;
        let data = UInt8Array::from_iter_values((0..TOTAL * DIM).map(|v| (v % 16) as u8));
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatBinStorage::new(fsl.clone(), DistanceType::Hamming));
        let built = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50)
            .build(Arc::new(fsl.clone()), DistanceType::Hamming)
            .await
            .unwrap();

        let loaded = write_and_reload(&built, "test_builder_write_load_binary_hamming").await;

        let query = fsl.value(0);
        let k = 10;
        let params = search_params();
        let expected = built
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let actual = loaded
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }
}