1use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use snafu::location;
18use std::cmp::min;
19use std::collections::{BinaryHeap, HashMap, VecDeque};
20use std::fmt::Debug;
21use std::iter;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::sync::RwLock;
25use tracing::instrument;
26
27use lance_core::{Error, Result};
28use rand::{rng, Rng};
29use serde::{Deserialize, Serialize};
30
31use super::super::graph::beam_search;
32use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD};
33use crate::metrics::MetricsCollector;
34use crate::prefilter::PreFilter;
35use crate::vector::flat::storage::FlatFloatStorage;
36use crate::vector::graph::builder::GraphBuilderNode;
37use crate::vector::graph::{greedy_search, Visited};
38use crate::vector::graph::{
39 Graph, OrderedFloat, OrderedNode, VisitedGenerator, DISTS_FIELD, NEIGHBORS_COL, NEIGHBORS_FIELD,
40};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{Query, DIST_COL, VECTOR_RESULT_SCHEMA};
44
/// Key in the Arrow schema metadata under which the JSON-encoded HNSW
/// metadata is written by `to_batch` and looked up again by `load`.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
47#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
49pub struct HnswBuildParams {
50 pub max_level: u16,
52
53 pub m: usize,
55
56 pub ef_construction: usize,
58
59 pub prefetch_distance: Option<usize>,
61}
62
63impl Default for HnswBuildParams {
64 fn default() -> Self {
65 Self {
66 max_level: 7,
67 m: 20,
68 ef_construction: 150,
69 prefetch_distance: Some(2),
70 }
71 }
72}
73
74impl HnswBuildParams {
75 pub fn max_level(mut self, max_level: u16) -> Self {
78 self.max_level = max_level;
79 self
80 }
81
82 pub fn num_edges(mut self, m: usize) -> Self {
85 self.m = m;
86 self
87 }
88
89 pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94 self.ef_construction = ef_construction;
95 self
96 }
97
98 pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104 let vec_store = Arc::new(FlatFloatStorage::new(
105 data.as_fixed_size_list().clone(),
106 distance_type,
107 ));
108 HNSW::index_vectors(vec_store.as_ref(), self)
109 }
110}
111
112#[derive(Clone, DeepSizeOf)]
120pub struct HNSW {
121 inner: Arc<HnswBuilder>,
122}
123
124impl Debug for HNSW {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
127 }
128}
129
130impl HNSW {
131 pub fn empty() -> Self {
132 Self {
133 inner: Arc::new(HnswBuilder {
134 params: HnswBuildParams::default(),
135 nodes: Arc::new(Vec::new()),
136 level_count: Vec::new(),
137 entry_point: 0,
138 visited_generator_queue: Arc::new(ArrayQueue::new(1)),
139 }),
140 }
141 }
142
143 pub fn len(&self) -> usize {
144 self.inner.nodes.len()
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.len() == 0
149 }
150
151 pub fn max_level(&self) -> u16 {
152 self.inner.max_level()
153 }
154
155 pub fn num_nodes(&self, level: usize) -> usize {
156 self.inner.num_nodes(level)
157 }
158
159 pub fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
160 self.inner.nodes()
161 }
162
163 #[allow(clippy::too_many_arguments)]
164 pub fn search_inner(
165 &self,
166 query: ArrayRef,
167 k: usize,
168 params: &HnswQueryParams,
169 bitset: Option<Visited>,
170 visited_generator: &mut VisitedGenerator,
171 storage: &impl VectorStore,
172 prefetch_distance: Option<usize>,
173 ) -> Result<Vec<OrderedNode>> {
174 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
175 let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
176 let nodes = &self.nodes();
177 for level in (0..self.max_level()).rev() {
178 let cur_level = HnswLevelView::new(level, nodes);
179 ep = greedy_search(
180 &cur_level,
181 ep,
182 &dist_calc,
183 self.inner.params.prefetch_distance,
184 );
185 }
186
187 let bottom_level = HnswBottomView::new(nodes);
188 let mut visited = visited_generator.generate(storage.len());
189 Ok(beam_search(
190 &bottom_level,
191 &ep,
192 params,
193 &dist_calc,
194 bitset.as_ref(),
195 prefetch_distance,
196 &mut visited,
197 )
198 .into_iter()
199 .take(k)
200 .collect())
201 }
202
203 #[instrument(level = "debug", skip(self, query, bitset, storage))]
204 pub fn search_basic(
205 &self,
206 query: ArrayRef,
207 k: usize,
208 params: &HnswQueryParams,
209 bitset: Option<Visited>,
210 storage: &impl VectorStore,
211 ) -> Result<Vec<OrderedNode>> {
212 let mut visited_generator = self
213 .inner
214 .visited_generator_queue
215 .pop()
216 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
217 let result = self.search_inner(
218 query,
219 k,
220 params,
221 bitset,
222 &mut visited_generator,
223 storage,
224 Some(2),
225 );
226
227 match self.inner.visited_generator_queue.push(visited_generator) {
228 Ok(_) => {}
229 Err(_) => {
230 log::warn!("visited_generator_queue is full");
231 }
232 }
233
234 result
235 }
236
237 #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
238 fn flat_search(
239 &self,
240 storage: &impl VectorStore,
241 query: ArrayRef,
242 k: usize,
243 prefilter_bitset: Visited,
244 params: &HnswQueryParams,
245 ) -> Vec<OrderedNode> {
246 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
247 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
248
249 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
250 let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
251
252 match self.inner.params.prefetch_distance {
253 Some(ahead) if ahead > 0 => {
254 let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
255 let mut buffer = VecDeque::with_capacity(ahead + 1);
256 for _ in 0..=ahead {
257 if let Some(id) = ids_iter.next() {
258 buffer.push_back(id);
259 } else {
260 break;
261 }
262 }
263
264 while let Some(node_id) = buffer.pop_front() {
265 if let Some(&prefetch_id) = buffer.get(ahead - 1) {
266 dist_calc.prefetch(prefetch_id);
267 }
268 if let Some(next) = ids_iter.next() {
269 buffer.push_back(next);
270 }
271
272 let dist: OrderedFloat = dist_calc.distance(node_id).into();
273 if dist <= lower_bound || dist > upper_bound {
274 continue;
275 }
276 if heap.len() < k {
277 heap.push((dist, node_id).into());
278 } else if dist < heap.peek().unwrap().dist {
279 heap.pop();
280 heap.push((dist, node_id).into());
281 }
282 }
283 }
284 _ => {
285 for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
286 let dist: OrderedFloat = dist_calc.distance(node_id).into();
287 if dist <= lower_bound || dist > upper_bound {
288 continue;
289 }
290 if heap.len() < k {
291 heap.push((dist, node_id).into());
292 } else if dist < heap.peek().unwrap().dist {
293 heap.pop();
294 heap.push((dist, node_id).into());
295 }
296 }
297 }
298 };
299 heap.into_sorted_vec()
300 }
301
302 pub fn metadata(&self) -> HnswMetadata {
304 let level_offsets = self
307 .inner
308 .level_count
309 .iter()
310 .chain(iter::once(&AtomicUsize::new(0)))
311 .scan(0, |state, x| {
312 let start = *state;
313 *state += x.load(Ordering::Relaxed);
314 Some(start)
315 })
316 .collect();
317
318 HnswMetadata {
319 entry_point: self.inner.entry_point,
320 params: self.inner.params.clone(),
321 level_offsets,
322 }
323 }
324}
325
326struct HnswBuilder {
327 params: HnswBuildParams,
328
329 nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
330 level_count: Vec<AtomicUsize>,
331
332 entry_point: u32,
333
334 visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
335}
336
337impl DeepSizeOf for HnswBuilder {
338 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
339 self.params.deep_size_of_children(context)
340 + self.nodes.deep_size_of_children(context)
341 + self.level_count.deep_size_of_children(context)
342 }
344}
345
346impl HnswBuilder {
347 fn max_level(&self) -> u16 {
348 self.params.max_level
349 }
350
351 fn num_nodes(&self, level: usize) -> usize {
352 self.level_count[level].load(Ordering::Relaxed)
353 }
354
355 fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
356 self.nodes.clone()
357 }
358
359 pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
361 let len = storage.len();
362 let max_level = params.max_level;
363
364 let level_count = (0..max_level)
365 .map(|_| AtomicUsize::new(0))
366 .collect::<Vec<_>>();
367
368 let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
369 for _ in 0..get_num_compute_intensive_cpus() {
370 visited_generator_queue
371 .push(VisitedGenerator::new(0))
372 .unwrap();
373 }
374 let mut builder = Self {
375 params,
376 nodes: Arc::new(Vec::new()),
377 level_count,
378 entry_point: 0,
379 visited_generator_queue,
380 };
381
382 if storage.is_empty() {
383 return builder;
384 }
385
386 let mut nodes = Vec::with_capacity(len);
387 {
388 if len > 0 {
389 nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
390 }
391 for i in 1..len {
392 nodes.push(RwLock::new(GraphBuilderNode::new(
393 i as u32,
394 builder.random_level() as usize + 1,
395 )));
396 }
397 }
398 builder.nodes = Arc::new(nodes);
399
400 builder
401 }
402
403 fn random_level(&self) -> u16 {
407 let mut rng = rng();
408 let ml = 1.0 / (self.params.m as f32).ln();
409 min(
410 (-rng.random::<f32>().ln() * ml) as u16,
411 self.params.max_level - 1,
412 )
413 }
414
415 fn insert(
417 &self,
418 node: u32,
419 visited_generator: &mut VisitedGenerator,
420 storage: &impl VectorStore,
421 ) {
422 let nodes = &self.nodes;
423 let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
424 let dist_calc = storage.dist_calculator_from_id(node);
425 let mut ep = OrderedNode::new(
426 self.entry_point,
427 dist_calc.distance(self.entry_point).into(),
428 );
429
430 for level in (target_level + 1..self.params.max_level).rev() {
439 let cur_level = HnswLevelView::new(level, nodes);
440 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
441 }
442
443 let mut pruned_neighbors_per_level: Vec<Vec<_>> =
444 vec![Vec::new(); (target_level + 1) as usize];
445 {
446 let mut current_node = nodes[node as usize].write().unwrap();
447 for level in (0..=target_level).rev() {
448 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
449
450 let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
451 for neighbor in &neighbors {
452 current_node.add_neighbor(neighbor.id, neighbor.dist, level);
453 }
454 self.prune(storage, &mut current_node, level);
455 pruned_neighbors_per_level[level as usize]
456 .clone_from(¤t_node.level_neighbors_ranked[level as usize]);
457
458 ep = neighbors[0].clone();
459 }
460 }
461 for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
462 let _: Vec<_> = pruned_neighbors
463 .iter()
464 .map(|unpruned_edge| {
465 let level = level as u16;
466 let m_max = match level {
467 0 => self.params.m * 2,
468 _ => self.params.m,
469 };
470 if unpruned_edge.dist
471 < nodes[unpruned_edge.id as usize]
472 .read()
473 .unwrap()
474 .cutoff(level, m_max)
475 {
476 let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
477 chosen_node.add_neighbor(node, unpruned_edge.dist, level);
478 self.prune(storage, &mut chosen_node, level);
479 }
480 })
481 .collect();
482 }
483 }
484
485 fn search_level(
486 &self,
487 ep: &OrderedNode,
488 level: u16,
489 dist_calc: &impl DistCalculator,
490 nodes: &Vec<RwLock<GraphBuilderNode>>,
491 visited_generator: &mut VisitedGenerator,
492 ) -> Vec<OrderedNode> {
493 let cur_level = HnswLevelView::new(level, nodes);
494 let mut visited = visited_generator.generate(nodes.len());
495 beam_search(
496 &cur_level,
497 ep,
498 &HnswQueryParams {
499 ef: self.params.ef_construction,
500 lower_bound: None,
501 upper_bound: None,
502 dist_q_c: 0.0,
503 },
504 dist_calc,
505 None,
506 self.params.prefetch_distance,
507 &mut visited,
508 )
509 }
510
511 fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
512 let m_max = match level {
513 0 => self.params.m * 2,
514 _ => self.params.m,
515 };
516
517 let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
518 let level_neighbors = neighbors_ranked.clone();
519 if level_neighbors.len() <= m_max {
520 builder_node.update_from_ranked_neighbors(level);
521 return;
522 }
524
525 *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
526 builder_node.update_from_ranked_neighbors(level);
527 }
528}
529
530pub(crate) struct HnswLevelView<'a> {
533 level: u16,
534 nodes: &'a Vec<RwLock<GraphBuilderNode>>,
535}
536
537impl<'a> HnswLevelView<'a> {
538 pub fn new(level: u16, nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
539 Self { level, nodes }
540 }
541}
542
543impl Graph for HnswLevelView<'_> {
544 fn len(&self) -> usize {
545 self.nodes.len()
546 }
547
548 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
549 let node = &self.nodes[key as usize];
550 node.read().unwrap().level_neighbors[self.level as usize].clone()
551 }
552}
553
554pub(crate) struct HnswBottomView<'a> {
555 nodes: &'a Vec<RwLock<GraphBuilderNode>>,
556}
557
558impl<'a> HnswBottomView<'a> {
559 pub fn new(nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
560 Self { nodes }
561 }
562}
563
564impl Graph for HnswBottomView<'_> {
565 fn len(&self) -> usize {
566 self.nodes.len()
567 }
568
569 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
570 let node = &self.nodes[key as usize];
571 node.read().unwrap().bottom_neighbors.clone()
572 }
573}
574
575#[derive(Debug, Clone, Copy)]
576pub struct HnswQueryParams {
577 pub ef: usize,
578 pub lower_bound: Option<f32>,
579 pub upper_bound: Option<f32>,
580 pub dist_q_c: f32,
581}
582
583impl From<&Query> for HnswQueryParams {
584 fn from(query: &Query) -> Self {
585 let k = query.k * query.refine_factor.unwrap_or(1) as usize;
586 Self {
587 ef: query.ef.unwrap_or(k + k / 2),
588 lower_bound: query.lower_bound,
589 upper_bound: query.upper_bound,
590 dist_q_c: query.dist_q_c,
591 }
592 }
593}
594
595impl IvfSubIndex for HNSW {
596 type BuildParams = HnswBuildParams;
597 type QueryParams = HnswQueryParams;
598
599 fn load(data: RecordBatch) -> Result<Self>
600 where
601 Self: Sized,
602 {
603 if data.num_rows() == 0 {
604 return Ok(Self::empty());
605 }
606
607 let hnsw_metadata =
608 data.schema_ref()
609 .metadata()
610 .get(HNSW_METADATA_KEY)
611 .ok_or(Error::Index {
612 message: format!("{} not found", HNSW_METADATA_KEY),
613 location: location!(),
614 })?;
615 let hnsw_metadata: HnswMetadata =
616 serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index {
617 message: format!(
618 "Failed to decode HNSW metadata: {}, json: {}",
619 e, hnsw_metadata
620 ),
621 location: location!(),
622 })?;
623
624 let levels: Vec<_> = hnsw_metadata
625 .level_offsets
626 .iter()
627 .tuple_windows()
628 .map(|(start, end)| data.slice(*start, end - start))
629 .collect();
630
631 let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
632
633 let bottom_level_len = levels[0].num_rows();
634 let mut nodes = Vec::with_capacity(bottom_level_len);
635 for i in 0..bottom_level_len {
636 nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
637 }
638 for (level, batch) in levels.into_iter().enumerate() {
639 let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
640 let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
641 let distances = batch[DIST_COL].as_list::<i32>();
642
643 for ((node, neighbors), distances) in
644 ids.iter().zip(neighbors.iter()).zip(distances.iter())
645 {
646 let node = node.unwrap();
647 let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
648 let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
649
650 nodes[node as usize].level_neighbors_ranked[level] = neighbors
651 .iter()
652 .zip(distances.iter())
653 .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
654 .collect();
655 nodes[node as usize].update_from_ranked_neighbors(level as u16);
656 }
657 }
658
659 let visited_generator_queue =
660 Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
661 for _ in 0..get_num_compute_intensive_cpus() * 2 {
662 visited_generator_queue
663 .push(VisitedGenerator::new(0))
664 .unwrap();
665 }
666 let inner = HnswBuilder {
667 params: hnsw_metadata.params,
668 nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
669 level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
670 entry_point: hnsw_metadata.entry_point,
671 visited_generator_queue,
672 };
673
674 Ok(Self {
675 inner: Arc::new(inner),
676 })
677 }
678
    /// Name of this sub-index type (the `HNSW_TYPE` constant).
    fn name() -> &'static str {
        HNSW_TYPE
    }
682
683 fn metadata_key() -> &'static str {
684 "lance:hnsw"
685 }
686
687 fn schema() -> arrow_schema::SchemaRef {
689 arrow_schema::Schema::new(vec![
690 VECTOR_ID_FIELD.clone(),
691 NEIGHBORS_FIELD.clone(),
692 DISTS_FIELD.clone(),
693 ])
694 .into()
695 }
696
697 #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
698 fn search(
699 &self,
700 query: ArrayRef,
701 k: usize,
702 params: Self::QueryParams,
703 storage: &impl VectorStore,
704 prefilter: Arc<dyn PreFilter>,
705 _metrics: &dyn MetricsCollector,
706 ) -> Result<RecordBatch> {
707 if params.ef < k {
708 return Err(Error::Index {
709 message: "ef must be greater than or equal to k".to_string(),
710 location: location!(),
711 });
712 }
713
714 let schema = VECTOR_RESULT_SCHEMA.clone();
715 if self.is_empty() {
716 return Ok(RecordBatch::new_empty(schema));
717 }
718
719 let mut prefilter_generator = self
720 .inner
721 .visited_generator_queue
722 .pop()
723 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
724 let prefilter_bitset = if prefilter.is_empty() {
725 None
726 } else {
727 let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
728 let mut bitset = prefilter_generator.generate(storage.len());
729 for indices in indices {
730 bitset.insert(indices as u32);
731 }
732 Some(bitset)
733 };
734
735 let remained = prefilter_bitset
736 .as_ref()
737 .map(|b| b.count_ones())
738 .unwrap_or(storage.len());
739 let results = if remained < self.len() * 10 / 100 {
740 let prefilter_bitset =
741 prefilter_bitset.expect("the prefilter bitset must be set for flat search");
742 self.flat_search(storage, query, k, prefilter_bitset, ¶ms)
743 } else {
744 self.search_basic(query, k, ¶ms, prefilter_bitset, storage)?
745 };
746 let _ = self.inner.visited_generator_queue.push(prefilter_generator);
748
749 let (row_ids, dists): (Vec<_>, Vec<_>) = results
751 .into_iter()
752 .map(|r| (storage.row_id(r.id), r.dist.0))
753 .unique_by(|r| r.0)
754 .unzip();
755 let row_ids = Arc::new(UInt64Array::from(row_ids));
756 let distances = Arc::new(Float32Array::from(dists));
757
758 Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
759 }
760
761 fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
763 where
764 Self: Sized,
765 {
766 let inner = HnswBuilder::with_params(params, storage);
767 let hnsw = Self {
768 inner: Arc::new(inner),
769 };
770
771 log::debug!(
772 "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
773 storage.len(),
774 hnsw.inner.params.max_level,
775 hnsw.inner.params.m,
776 hnsw.inner.params.ef_construction,
777 storage.distance_type(),
778 );
779
780 if storage.is_empty() {
781 return Ok(hnsw);
782 }
783
784 let len = storage.len();
785 hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
786 (1..len).into_par_iter().for_each_init(
787 || VisitedGenerator::new(len),
788 |visited_generator, node| {
789 hnsw.inner.insert(node as u32, visited_generator, storage);
790 },
791 );
792
793 assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
794 Ok(hnsw)
795 }
796
797 fn remap(
798 &self,
799 _mapping: &HashMap<u64, Option<u64>>, store: &impl VectorStore,
801 ) -> Result<Self> {
802 Self::index_vectors(store, self.inner.params.clone())
805 }
806
807 fn to_batch(&self) -> Result<RecordBatch> {
809 let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
810 let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
811 let mut distances_builder =
812 ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
813 let mut batches = Vec::with_capacity(self.max_level() as usize);
814 for level in 0..self.max_level() {
815 let level = level as usize;
816 for (id, node) in self.inner.nodes.iter().enumerate() {
817 let node = node.read().unwrap();
818 if level >= node.level_neighbors.len() {
819 continue;
820 }
821 let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
822 let distances = node.level_neighbors_ranked[level]
823 .iter()
824 .map(|n| Some(n.dist.0));
825 vector_id_builder.append_value(id as u32);
826 neighbors_builder.append_value(neighbors);
827 distances_builder.append_value(distances);
828 }
829
830 let batch = RecordBatch::try_new(
831 Self::schema(),
832 vec![
833 Arc::new(vector_id_builder.finish()),
834 Arc::new(neighbors_builder.finish()),
835 Arc::new(distances_builder.finish()),
836 ],
837 )?;
838 batches.push(batch);
839 }
840
841 let metadata = self.metadata();
842 let metadata = serde_json::to_string(&metadata)?;
843 let schema = Self::schema()
844 .as_ref()
845 .clone()
846 .with_metadata(HashMap::from_iter(vec![(
847 HNSW_METADATA_KEY.to_string(),
848 metadata,
849 )]));
850 let batch = concat_batches(&Self::schema(), batches.iter())?;
851 let batch = batch.with_schema(Arc::new(schema))?;
852 Ok(batch)
853 }
854}
855
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::FlatFloatStorage,
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            builder::{HnswBuildParams, HnswQueryParams},
            HNSW, VECTOR_ID_FIELD,
        },
    };

    /// Round-trips a freshly built graph through the file writer/reader and
    /// checks that the loaded graph answers a query exactly like the original.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;
        let data = generate_random_array(TOTAL * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        let builder = HNSW::index_vectors(
            store.as_ref(),
            HnswBuildParams::default()
                .num_edges(NUM_EDGES)
                .ef_construction(50),
        )
        .unwrap();

        // Write the serialized graph, including its schema metadata.
        let object_store = ObjectStore::memory();
        let path = Path::from("test_builder_write_load");
        let writer = object_store.create(&path).await.unwrap();
        let schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let schema = lance_core::datatypes::Schema::try_from(&schema).unwrap();
        let mut writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            writer,
            schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();
        let batch = builder.to_batch().unwrap();
        let metadata = batch.schema_ref().metadata().clone();
        writer.write_record_batch(batch).await.unwrap();
        writer.finish_with_metadata(&metadata).await.unwrap();

        // Read everything back and rebuild the graph from the batch.
        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        let loaded_hnsw = HNSW::load(batch).unwrap();

        let query = fsl.value(0);
        let k = 10;
        let params = HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        };
        let builder_results = builder
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_hnsw
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(builder_results, loaded_results);
    }
}