1use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{DataType, Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use std::cmp::min;
18use std::collections::{BinaryHeap, HashMap, VecDeque};
19use std::fmt::Debug;
20use std::iter;
21use std::sync::Arc;
22use std::sync::RwLock;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use tracing::instrument;
25
26use lance_core::{Error, Result};
27use rand::{Rng, rng};
28use serde::{Deserialize, Serialize};
29
30use super::super::graph::beam_search;
31use super::{HNSW_TYPE, HnswMetadata, VECTOR_ID_COL, VECTOR_ID_FIELD, select_neighbors_heuristic};
32use crate::metrics::MetricsCollector;
33use crate::prefilter::PreFilter;
34use crate::vector::flat::storage::{FlatBinStorage, FlatFloatStorage};
35use crate::vector::graph::builder::GraphBuilderNode;
36use crate::vector::graph::{
37 BorrowingGraph, DISTS_FIELD, Graph, NEIGHBORS_COL, NEIGHBORS_FIELD, OrderedFloat, OrderedNode,
38 VisitedGenerator,
39};
40use crate::vector::graph::{Visited, beam_search_borrowed, greedy_search, greedy_search_borrowed};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{DIST_COL, Query, VECTOR_RESULT_SCHEMA};
44
/// Schema-metadata key under which the JSON-encoded HNSW metadata is stored
/// when the index is serialized to a record batch, and from which it is read
/// back on load.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
/// Parameters controlling how an HNSW graph is built.
#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
pub struct HnswBuildParams {
    /// Number of levels in the graph. The builder allocates one level counter
    /// per level and caps randomly drawn node levels at `max_level - 1`.
    pub max_level: u16,

    /// Number of edges kept per node on levels above 0; level 0 keeps up to
    /// `2 * m` edges. Also used to derive the random level distribution.
    pub m: usize,

    /// Value passed as `ef` to the beam search that collects neighbor
    /// candidates while inserting a node.
    pub ef_construction: usize,

    /// Prefetch distance forwarded to the graph search routines; `None`
    /// forwards no distance.
    pub prefetch_distance: Option<usize>,
}
62
63impl Default for HnswBuildParams {
64 fn default() -> Self {
65 Self {
66 max_level: 7,
67 m: 20,
68 ef_construction: 150,
69 prefetch_distance: Some(2),
70 }
71 }
72}
73
74impl HnswBuildParams {
75 pub fn max_level(mut self, max_level: u16) -> Self {
78 self.max_level = max_level;
79 self
80 }
81
82 pub fn num_edges(mut self, m: usize) -> Self {
85 self.m = m;
86 self
87 }
88
89 pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94 self.ef_construction = ef_construction;
95 self
96 }
97
98 pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104 let vectors = data.as_fixed_size_list().clone();
105 match (vectors.value_type(), distance_type) {
106 (DataType::UInt8, DistanceType::Hamming) => {
107 let vec_store = Arc::new(FlatBinStorage::new(vectors, distance_type));
108 HNSW::index_vectors(vec_store.as_ref(), self)
109 }
110 (DataType::UInt8, _) => Err(Error::invalid_input(format!(
111 "HNSW only supports hamming distance for UInt8 vectors, got {}",
112 distance_type
113 ))),
114 (_, DistanceType::Hamming) => Err(Error::invalid_input(format!(
115 "HNSW hamming distance only supports UInt8 vectors, got {}",
116 vectors.value_type()
117 ))),
118 _ => {
119 let vec_store = Arc::new(FlatFloatStorage::new(vectors, distance_type));
120 HNSW::index_vectors(vec_store.as_ref(), self)
121 }
122 }
123 }
124}
125
/// A built HNSW graph.
///
/// The state lives behind an [`Arc`], so cloning an `HNSW` shares the same
/// underlying graph rather than copying it.
#[derive(Clone, DeepSizeOf)]
pub struct HNSW {
    /// Shared graph state.
    inner: Arc<HnswCore>,
}
137
/// Shared state of a built [`HNSW`] index.
struct HnswCore {
    /// Parameters the graph was built (or loaded) with.
    params: HnswBuildParams,
    /// Graph nodes, indexed by node id.
    nodes: Arc<Vec<GraphBuilderNode>>,
    /// Number of nodes on each level, indexed by level.
    level_count: Vec<usize>,
    /// Entry point recorded for the graph; written to the serialized metadata.
    entry_point: u32,
    /// Pool of reusable [`VisitedGenerator`]s borrowed by searches.
    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
}
145
146impl DeepSizeOf for HnswCore {
147 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
148 self.params.deep_size_of_children(context)
149 + self.nodes.deep_size_of_children(context)
150 + self.level_count.deep_size_of_children(context)
151 }
153}
154
155impl HnswCore {
156 fn max_level(&self) -> u16 {
157 self.params.max_level
158 }
159
160 fn num_nodes(&self, level: usize) -> usize {
161 self.level_count[level]
162 }
163
164 fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
165 self.nodes.clone()
166 }
167}
168
169impl Debug for HNSW {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
172 }
173}
174
175impl HNSW {
176 pub fn empty() -> Self {
177 Self {
178 inner: Arc::new(HnswCore {
179 params: HnswBuildParams::default(),
180 nodes: Arc::new(Vec::new()),
181 level_count: Vec::new(),
182 entry_point: 0,
183 visited_generator_queue: Arc::new(ArrayQueue::new(1)),
184 }),
185 }
186 }
187
188 pub fn len(&self) -> usize {
189 self.inner.nodes.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
193 self.len() == 0
194 }
195
196 pub fn max_level(&self) -> u16 {
197 self.inner.max_level()
198 }
199
200 pub fn num_nodes(&self, level: usize) -> usize {
201 self.inner.num_nodes(level)
202 }
203
204 pub fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
205 self.inner.nodes()
206 }
207
208 #[allow(clippy::too_many_arguments)]
209 pub fn search_inner(
210 &self,
211 query: ArrayRef,
212 k: usize,
213 params: &HnswQueryParams,
214 bitset: Option<Visited>,
215 visited_generator: &mut VisitedGenerator,
216 storage: &impl VectorStore,
217 prefetch_distance: Option<usize>,
218 ) -> Result<Vec<OrderedNode>> {
219 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
220 let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
221 let nodes = self.inner.nodes.as_ref();
222 for level in (0..self.max_level()).rev() {
223 let cur_level = ImmutableHnswLevelView::new(level, nodes);
224 ep = greedy_search_borrowed(
225 &cur_level,
226 ep,
227 &dist_calc,
228 self.inner.params.prefetch_distance,
229 );
230 }
231
232 let bottom_level = ImmutableHnswBottomView::new(nodes);
233 let mut visited = visited_generator.generate(storage.len());
234 Ok(beam_search_borrowed(
235 &bottom_level,
236 &ep,
237 params,
238 &dist_calc,
239 bitset.as_ref(),
240 prefetch_distance,
241 &mut visited,
242 )
243 .into_iter()
244 .take(k)
245 .collect())
246 }
247
248 #[instrument(level = "debug", skip(self, query, bitset, storage))]
249 pub fn search_basic(
250 &self,
251 query: ArrayRef,
252 k: usize,
253 params: &HnswQueryParams,
254 bitset: Option<Visited>,
255 storage: &impl VectorStore,
256 ) -> Result<Vec<OrderedNode>> {
257 let mut visited_generator = self
258 .inner
259 .visited_generator_queue
260 .pop()
261 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
262 let result = self.search_inner(
263 query,
264 k,
265 params,
266 bitset,
267 &mut visited_generator,
268 storage,
269 Some(2),
270 );
271
272 match self.inner.visited_generator_queue.push(visited_generator) {
273 Ok(_) => {}
274 Err(_) => {
275 log::warn!("visited_generator_queue is full");
276 }
277 }
278
279 result
280 }
281
282 #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
283 fn flat_search(
284 &self,
285 storage: &impl VectorStore,
286 query: ArrayRef,
287 k: usize,
288 prefilter_bitset: Visited,
289 params: &HnswQueryParams,
290 ) -> Vec<OrderedNode> {
291 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
292 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
293
294 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
295 let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
296
297 match self.inner.params.prefetch_distance {
298 Some(ahead) if ahead > 0 => {
299 let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
300 let mut buffer = VecDeque::with_capacity(ahead + 1);
301 for _ in 0..=ahead {
302 if let Some(id) = ids_iter.next() {
303 buffer.push_back(id);
304 } else {
305 break;
306 }
307 }
308
309 while let Some(node_id) = buffer.pop_front() {
310 if let Some(&prefetch_id) = buffer.get(ahead - 1) {
311 dist_calc.prefetch(prefetch_id);
312 }
313 if let Some(next) = ids_iter.next() {
314 buffer.push_back(next);
315 }
316
317 let dist: OrderedFloat = dist_calc.distance(node_id).into();
318 if dist <= lower_bound || dist > upper_bound {
319 continue;
320 }
321 if heap.len() < k {
322 heap.push((dist, node_id).into());
323 } else if dist < heap.peek().unwrap().dist {
324 heap.pop();
325 heap.push((dist, node_id).into());
326 }
327 }
328 }
329 _ => {
330 for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
331 let dist: OrderedFloat = dist_calc.distance(node_id).into();
332 if dist <= lower_bound || dist > upper_bound {
333 continue;
334 }
335 if heap.len() < k {
336 heap.push((dist, node_id).into());
337 } else if dist < heap.peek().unwrap().dist {
338 heap.pop();
339 heap.push((dist, node_id).into());
340 }
341 }
342 }
343 };
344 heap.into_sorted_vec()
345 }
346
347 pub fn metadata(&self) -> HnswMetadata {
349 let level_offsets = self
352 .inner
353 .level_count
354 .iter()
355 .chain(iter::once(&0))
356 .scan(0, |state, x| {
357 let start = *state;
358 *state += *x;
359 Some(start)
360 })
361 .collect();
362
363 HnswMetadata {
364 entry_point: self.inner.entry_point,
365 params: self.inner.params.clone(),
366 level_offsets,
367 }
368 }
369}
370
/// Mutable state used while an HNSW graph is being constructed.
struct HnswBuilder {
    /// Parameters the graph is built with.
    params: HnswBuildParams,

    /// Graph nodes indexed by node id, each behind its own lock.
    nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
    /// Per-level node counters, incremented as nodes are inserted.
    level_count: Vec<AtomicUsize>,

    /// Node id from which every insertion starts its search.
    entry_point: u32,

    /// Pool of [`VisitedGenerator`]s handed over to the finished index.
    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
}
381
382impl DeepSizeOf for HnswBuilder {
383 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
384 self.params.deep_size_of_children(context)
385 + self.nodes.deep_size_of_children(context)
386 + self.level_count.deep_size_of_children(context)
387 }
389}
390
391impl HnswBuilder {
392 fn finish(self) -> HNSW {
393 let nodes = match Arc::try_unwrap(self.nodes) {
394 Ok(nodes) => nodes
395 .into_iter()
396 .map(|node| node.into_inner().expect("builder lock poisoned"))
397 .collect(),
398 Err(nodes) => nodes
399 .iter()
400 .map(|node| node.read().expect("builder lock poisoned").clone())
401 .collect(),
402 };
403
404 let level_count = self
405 .level_count
406 .into_iter()
407 .map(|count| count.load(Ordering::Relaxed))
408 .collect();
409
410 HNSW {
411 inner: Arc::new(HnswCore {
412 params: self.params,
413 nodes: Arc::new(nodes),
414 level_count,
415 entry_point: self.entry_point,
416 visited_generator_queue: self.visited_generator_queue,
417 }),
418 }
419 }
420
421 pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
423 let len = storage.len();
424 let max_level = params.max_level;
425
426 let level_count = (0..max_level)
427 .map(|_| AtomicUsize::new(0))
428 .collect::<Vec<_>>();
429
430 let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
431 for _ in 0..get_num_compute_intensive_cpus() {
432 visited_generator_queue
433 .push(VisitedGenerator::new(0))
434 .unwrap();
435 }
436 let mut builder = Self {
437 params,
438 nodes: Arc::new(Vec::new()),
439 level_count,
440 entry_point: 0,
441 visited_generator_queue,
442 };
443
444 if storage.is_empty() {
445 return builder;
446 }
447
448 let mut nodes = Vec::with_capacity(len);
449 {
450 if len > 0 {
451 nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
452 }
453 let mut level_rng = rng();
454 for i in 1..len {
455 nodes.push(RwLock::new(GraphBuilderNode::new(
456 i as u32,
457 builder.random_level(&mut level_rng) as usize + 1,
458 )));
459 }
460 }
461 builder.nodes = Arc::new(nodes);
462
463 builder
464 }
465
466 fn random_level<R: Rng + ?Sized>(&self, rng: &mut R) -> u16 {
470 let ml = 1.0 / (self.params.m as f32).ln();
471 min(
472 (-rng.random::<f32>().ln() * ml) as u16,
473 self.params.max_level - 1,
474 )
475 }
476
477 fn insert(
479 &self,
480 node: u32,
481 visited_generator: &mut VisitedGenerator,
482 storage: &impl VectorStore,
483 ) {
484 let nodes = &self.nodes;
485 let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
486 let dist_calc = storage.dist_calculator_from_id(node);
487 let mut ep = OrderedNode::new(
488 self.entry_point,
489 dist_calc.distance(self.entry_point).into(),
490 );
491
492 for level in (target_level + 1..self.params.max_level).rev() {
501 let cur_level = HnswLevelView::new(level, nodes);
502 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
503 }
504
505 let mut pruned_neighbors_per_level: Vec<Vec<_>> =
506 vec![Vec::new(); (target_level + 1) as usize];
507 {
508 let mut current_node = nodes[node as usize].write().unwrap();
509 for level in (0..=target_level).rev() {
510 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
511
512 let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
513 for neighbor in &neighbors {
514 current_node.add_neighbor(neighbor.id, neighbor.dist, level);
515 }
516 self.prune(storage, &mut current_node, level);
517 pruned_neighbors_per_level[level as usize]
518 .clone_from(¤t_node.level_neighbors_ranked[level as usize]);
519
520 ep = neighbors[0].clone();
521 }
522 }
523 for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
524 let _: Vec<_> = pruned_neighbors
525 .iter()
526 .map(|unpruned_edge| {
527 let level = level as u16;
528 let m_max = match level {
529 0 => self.params.m * 2,
530 _ => self.params.m,
531 };
532 if unpruned_edge.dist
533 < nodes[unpruned_edge.id as usize]
534 .read()
535 .unwrap()
536 .cutoff(level, m_max)
537 {
538 let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
539 chosen_node.add_neighbor(node, unpruned_edge.dist, level);
540 self.prune(storage, &mut chosen_node, level);
541 }
542 })
543 .collect();
544 }
545 }
546
547 fn search_level(
548 &self,
549 ep: &OrderedNode,
550 level: u16,
551 dist_calc: &impl DistCalculator,
552 nodes: &[RwLock<GraphBuilderNode>],
553 visited_generator: &mut VisitedGenerator,
554 ) -> Vec<OrderedNode> {
555 let cur_level = HnswLevelView::new(level, nodes);
556 let mut visited = visited_generator.generate(nodes.len());
557 beam_search(
558 &cur_level,
559 ep,
560 &HnswQueryParams {
561 ef: self.params.ef_construction,
562 lower_bound: None,
563 upper_bound: None,
564 dist_q_c: 0.0,
565 },
566 dist_calc,
567 None,
568 self.params.prefetch_distance,
569 &mut visited,
570 )
571 }
572
573 fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
574 let m_max = match level {
575 0 => self.params.m * 2,
576 _ => self.params.m,
577 };
578
579 let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
580 let level_neighbors = neighbors_ranked.clone();
581 if level_neighbors.len() <= m_max {
582 builder_node.update_from_ranked_neighbors(level);
583 return;
584 }
585
586 *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
587 builder_node.update_from_ranked_neighbors(level);
588 }
589}
590
/// View of a single level of a graph that is still being built, with every
/// node behind its own lock.
pub(crate) struct HnswLevelView<'a> {
    /// Level whose neighbor lists this view exposes.
    level: u16,
    /// All graph nodes, indexed by node id.
    nodes: &'a [RwLock<GraphBuilderNode>],
}
597
598impl<'a> HnswLevelView<'a> {
599 pub fn new(level: u16, nodes: &'a [RwLock<GraphBuilderNode>]) -> Self {
600 Self { level, nodes }
601 }
602}
603
604impl Graph for HnswLevelView<'_> {
605 fn len(&self) -> usize {
606 self.nodes.len()
607 }
608
609 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
610 let node = &self.nodes[key as usize];
611 node.read().unwrap().level_neighbors[self.level as usize].clone()
612 }
613}
614
/// Lock-free view of a single level of a finished graph.
pub(crate) struct ImmutableHnswLevelView<'a> {
    /// Level whose neighbor lists this view exposes.
    level: u16,
    /// All graph nodes, indexed by node id.
    nodes: &'a [GraphBuilderNode],
}
619
620impl<'a> ImmutableHnswLevelView<'a> {
621 pub fn new(level: u16, nodes: &'a [GraphBuilderNode]) -> Self {
622 Self { level, nodes }
623 }
624}
625
626impl Graph for ImmutableHnswLevelView<'_> {
627 fn len(&self) -> usize {
628 self.nodes.len()
629 }
630
631 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
632 self.nodes[key as usize].level_neighbors[self.level as usize].clone()
633 }
634}
635
636impl BorrowingGraph for ImmutableHnswLevelView<'_> {
637 fn len(&self) -> usize {
638 self.nodes.len()
639 }
640
641 fn neighbors(&self, key: u32) -> &[u32] {
642 self.nodes[key as usize].level_neighbors[self.level as usize].as_slice()
643 }
644}
645
/// Lock-free view of a finished graph that exposes each node's
/// `bottom_neighbors` list.
pub(crate) struct ImmutableHnswBottomView<'a> {
    /// All graph nodes, indexed by node id.
    nodes: &'a [GraphBuilderNode],
}
649
650impl<'a> ImmutableHnswBottomView<'a> {
651 pub fn new(nodes: &'a [GraphBuilderNode]) -> Self {
652 Self { nodes }
653 }
654}
655
656impl Graph for ImmutableHnswBottomView<'_> {
657 fn len(&self) -> usize {
658 self.nodes.len()
659 }
660
661 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
662 self.nodes[key as usize].bottom_neighbors.clone()
663 }
664}
665
666impl BorrowingGraph for ImmutableHnswBottomView<'_> {
667 fn len(&self) -> usize {
668 self.nodes.len()
669 }
670
671 fn neighbors(&self, key: u32) -> &[u32] {
672 self.nodes[key as usize].bottom_neighbors.as_slice()
673 }
674}
675
/// Parameters for a single HNSW search.
#[derive(Debug, Clone, Copy)]
pub struct HnswQueryParams {
    /// Search beam width; the index search rejects values smaller than `k`.
    pub ef: usize,
    /// Optional lower distance bound for results.
    pub lower_bound: Option<f32>,
    /// Optional upper distance bound for results.
    pub upper_bound: Option<f32>,
    /// Value forwarded to the storage when it creates the distance calculator
    /// for a query.
    pub dist_q_c: f32,
}
683
684impl From<&Query> for HnswQueryParams {
685 fn from(query: &Query) -> Self {
686 let k = query.k * query.refine_factor.unwrap_or(1) as usize;
687 Self {
688 ef: query.ef.unwrap_or(k + k / 2),
689 lower_bound: query.lower_bound,
690 upper_bound: query.upper_bound,
691 dist_q_c: query.dist_q_c,
692 }
693 }
694}
695
696impl IvfSubIndex for HNSW {
697 type BuildParams = HnswBuildParams;
698 type QueryParams = HnswQueryParams;
699
700 fn load(data: RecordBatch) -> Result<Self>
701 where
702 Self: Sized,
703 {
704 if data.num_rows() == 0 {
705 return Ok(Self::empty());
706 }
707
708 let hnsw_metadata = data
709 .schema_ref()
710 .metadata()
711 .get(HNSW_METADATA_KEY)
712 .ok_or(Error::index(format!("{} not found", HNSW_METADATA_KEY)))?;
713 let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata).map_err(|e| {
714 Error::index(format!(
715 "Failed to decode HNSW metadata: {}, json: {}",
716 e, hnsw_metadata
717 ))
718 })?;
719
720 let levels: Vec<_> = hnsw_metadata
721 .level_offsets
722 .iter()
723 .tuple_windows()
724 .map(|(start, end)| data.slice(*start, end - start))
725 .collect();
726
727 let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
728
729 let bottom_level_len = levels[0].num_rows();
730 let mut nodes = Vec::with_capacity(bottom_level_len);
731 for i in 0..bottom_level_len {
732 nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
733 }
734 for (level, batch) in levels.into_iter().enumerate() {
735 let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
736 let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
737 let distances = batch[DIST_COL].as_list::<i32>();
738
739 for ((node, neighbors), distances) in
740 ids.iter().zip(neighbors.iter()).zip(distances.iter())
741 {
742 let node = node.unwrap();
743 let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
744 let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
745
746 nodes[node as usize].level_neighbors_ranked[level] = neighbors
747 .iter()
748 .zip(distances.iter())
749 .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
750 .collect();
751 nodes[node as usize].update_from_ranked_neighbors(level as u16);
752 }
753 }
754
755 let visited_generator_queue =
756 Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
757 for _ in 0..get_num_compute_intensive_cpus() * 2 {
758 visited_generator_queue
759 .push(VisitedGenerator::new(0))
760 .unwrap();
761 }
762 let inner = HnswCore {
763 params: hnsw_metadata.params,
764 nodes: Arc::new(nodes),
765 level_count,
766 entry_point: hnsw_metadata.entry_point,
767 visited_generator_queue,
768 };
769
770 Ok(Self {
771 inner: Arc::new(inner),
772 })
773 }
774
775 fn name() -> &'static str {
776 HNSW_TYPE
777 }
778
779 fn metadata_key() -> &'static str {
780 "lance:hnsw"
781 }
782
783 fn schema() -> arrow_schema::SchemaRef {
785 arrow_schema::Schema::new(vec![
786 VECTOR_ID_FIELD.clone(),
787 NEIGHBORS_FIELD.clone(),
788 DISTS_FIELD.clone(),
789 ])
790 .into()
791 }
792
793 #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
794 fn search(
795 &self,
796 query: ArrayRef,
797 k: usize,
798 params: Self::QueryParams,
799 storage: &impl VectorStore,
800 prefilter: Arc<dyn PreFilter>,
801 _metrics: &dyn MetricsCollector,
802 ) -> Result<RecordBatch> {
803 if params.ef < k {
804 return Err(Error::index(
805 "ef must be greater than or equal to k".to_string(),
806 ));
807 }
808
809 let schema = VECTOR_RESULT_SCHEMA.clone();
810 if self.is_empty() {
811 return Ok(RecordBatch::new_empty(schema));
812 }
813
814 let mut prefilter_generator = self
815 .inner
816 .visited_generator_queue
817 .pop()
818 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
819 let prefilter_bitset = if prefilter.is_empty() {
820 None
821 } else {
822 let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
823 let mut bitset = prefilter_generator.generate(storage.len());
824 for indices in indices {
825 bitset.insert(indices as u32);
826 }
827 Some(bitset)
828 };
829
830 let remained = prefilter_bitset
831 .as_ref()
832 .map(|b| b.count_ones())
833 .unwrap_or(storage.len());
834 let results = if remained < self.len() * 10 / 100 {
835 let prefilter_bitset =
836 prefilter_bitset.expect("the prefilter bitset must be set for flat search");
837 self.flat_search(storage, query, k, prefilter_bitset, ¶ms)
838 } else {
839 self.search_basic(query, k, ¶ms, prefilter_bitset, storage)?
840 };
841 let _ = self.inner.visited_generator_queue.push(prefilter_generator);
843
844 let (row_ids, dists): (Vec<_>, Vec<_>) = results
846 .into_iter()
847 .map(|r| (storage.row_id(r.id), r.dist.0))
848 .unique_by(|r| r.0)
849 .unzip();
850 let row_ids = Arc::new(UInt64Array::from(row_ids));
851 let distances = Arc::new(Float32Array::from(dists));
852
853 Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
854 }
855
856 fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
858 where
859 Self: Sized,
860 {
861 let builder = HnswBuilder::with_params(params, storage);
862
863 log::debug!(
864 "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
865 storage.len(),
866 builder.params.max_level,
867 builder.params.m,
868 builder.params.ef_construction,
869 storage.distance_type(),
870 );
871
872 if storage.is_empty() {
873 return Ok(builder.finish());
874 }
875
876 let len = storage.len();
877 builder.level_count[0].fetch_add(1, Ordering::Relaxed);
878 (1..len).into_par_iter().for_each_init(
879 || VisitedGenerator::new(len),
880 |visited_generator, node| {
881 builder.insert(node as u32, visited_generator, storage);
882 },
883 );
884
885 assert_eq!(builder.level_count[0].load(Ordering::Relaxed), len);
886 Ok(builder.finish())
887 }
888
889 fn remap(
890 &self,
891 _mapping: &HashMap<u64, Option<u64>>, store: &impl VectorStore,
893 ) -> Result<Self> {
894 Self::index_vectors(store, self.inner.params.clone())
897 }
898
899 fn to_batch(&self) -> Result<RecordBatch> {
901 let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
902 let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
903 let mut distances_builder =
904 ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
905 let mut batches = Vec::with_capacity(self.max_level() as usize);
906 for level in 0..self.max_level() {
907 let level = level as usize;
908 for (id, node) in self.inner.nodes.iter().enumerate() {
909 if level >= node.level_neighbors.len() {
910 continue;
911 }
912 let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
913 let distances = node.level_neighbors_ranked[level]
914 .iter()
915 .map(|n| Some(n.dist.0));
916 vector_id_builder.append_value(id as u32);
917 neighbors_builder.append_value(neighbors);
918 distances_builder.append_value(distances);
919 }
920
921 let batch = RecordBatch::try_new(
922 Self::schema(),
923 vec![
924 Arc::new(vector_id_builder.finish()),
925 Arc::new(neighbors_builder.finish()),
926 Arc::new(distances_builder.finish()),
927 ],
928 )?;
929 batches.push(batch);
930 }
931
932 let metadata = self.metadata();
933 let metadata = serde_json::to_string(&metadata)?;
934 let schema = Self::schema()
935 .as_ref()
936 .clone()
937 .with_metadata(HashMap::from_iter(vec![(
938 HNSW_METADATA_KEY.to_string(),
939 metadata,
940 )]));
941 let batch = concat_batches(&Self::schema(), batches.iter())?;
942 let batch = batch.with_schema(Arc::new(schema))?;
943 Ok(batch)
944 }
945}
946
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{FixedSizeListArray, UInt8Array};
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::{FlatBinStorage, FlatFloatStorage},
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            HNSW, VECTOR_ID_FIELD,
            builder::{HnswBuildParams, HnswQueryParams},
        },
    };

    /// Query parameters shared by the round-trip tests.
    fn test_query_params() -> HnswQueryParams {
        HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        }
    }

    /// Writes `hnsw` to an in-memory object store at `path`, reads the file
    /// back and loads a new index from it.
    async fn write_and_load(hnsw: &HNSW, path: &str) -> HNSW {
        let object_store = ObjectStore::memory();
        let path = Path::from(path);
        let writer = object_store.create(&path).await.unwrap();
        let schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let schema = lance_core::datatypes::Schema::try_from(&schema).unwrap();
        let mut writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            writer,
            schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();
        let batch = hnsw.to_batch().unwrap();
        let metadata = batch.schema_ref().metadata().clone();
        writer.write_record_batch(batch).await.unwrap();
        writer.finish_with_metadata(&metadata).await.unwrap();

        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        HNSW::load(batch).unwrap()
    }

    /// A float index must return the same results before and after a
    /// write/load round trip.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;
        let data = generate_random_array(TOTAL * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        let builder = HNSW::index_vectors(
            store.as_ref(),
            HnswBuildParams::default()
                .num_edges(NUM_EDGES)
                .ef_construction(50),
        )
        .unwrap();

        let loaded_hnsw = write_and_load(&builder, "test_builder_write_load").await;

        let query = fsl.value(0);
        let k = 10;
        let params = test_query_params();
        let builder_results = builder
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_hnsw
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(builder_results, loaded_results);
    }

    /// A binary (Hamming) index must return the same results before and after
    /// a write/load round trip.
    #[tokio::test]
    async fn test_builder_write_load_binary_hamming() {
        const DIM: usize = 8;
        const TOTAL: usize = 256;
        const NUM_EDGES: usize = 20;
        let data = UInt8Array::from_iter_values((0..TOTAL * DIM).map(|v| (v % 16) as u8));
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatBinStorage::new(fsl.clone(), DistanceType::Hamming));
        let builder = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50)
            .build(Arc::new(fsl.clone()), DistanceType::Hamming)
            .await
            .unwrap();

        let loaded_hnsw =
            write_and_load(&builder, "test_builder_write_load_binary_hamming").await;

        let query = fsl.value(0);
        let k = 10;
        let params = test_query_params();
        let builder_results = builder
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_hnsw
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(builder_results, loaded_results);
    }
}