1use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use snafu::location;
18use std::cmp::min;
19use std::collections::{BinaryHeap, HashMap};
20use std::fmt::Debug;
21use std::iter;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::sync::RwLock;
25use tracing::instrument;
26
27use lance_core::{Error, Result};
28use rand::{rng, Rng};
29use serde::{Deserialize, Serialize};
30
31use super::super::graph::beam_search;
32use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD};
33use crate::metrics::MetricsCollector;
34use crate::prefilter::PreFilter;
35use crate::vector::flat::storage::FlatFloatStorage;
36use crate::vector::graph::builder::GraphBuilderNode;
37use crate::vector::graph::{greedy_search, Visited};
38use crate::vector::graph::{
39 Graph, OrderedFloat, OrderedNode, VisitedGenerator, DISTS_FIELD, NEIGHBORS_COL, NEIGHBORS_FIELD,
40};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{Query, DIST_COL, VECTOR_RESULT_SCHEMA};
44
/// Schema-metadata key under which the serialized HNSW metadata JSON is stored.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
47#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
49pub struct HnswBuildParams {
50 pub max_level: u16,
52
53 pub m: usize,
55
56 pub ef_construction: usize,
58
59 pub prefetch_distance: Option<usize>,
61}
62
63impl Default for HnswBuildParams {
64 fn default() -> Self {
65 Self {
66 max_level: 7,
67 m: 20,
68 ef_construction: 150,
69 prefetch_distance: Some(2),
70 }
71 }
72}
73
74impl HnswBuildParams {
75 pub fn max_level(mut self, max_level: u16) -> Self {
78 self.max_level = max_level;
79 self
80 }
81
82 pub fn num_edges(mut self, m: usize) -> Self {
85 self.m = m;
86 self
87 }
88
89 pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94 self.ef_construction = ef_construction;
95 self
96 }
97
98 pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104 let vec_store = Arc::new(FlatFloatStorage::new(
105 data.as_fixed_size_list().clone(),
106 distance_type,
107 ));
108 HNSW::index_vectors(vec_store.as_ref(), self)
109 }
110}
111
112#[derive(Clone, DeepSizeOf)]
120pub struct HNSW {
121 inner: Arc<HnswBuilder>,
122}
123
124impl Debug for HNSW {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
127 }
128}
129
130impl HNSW {
131 pub fn empty() -> Self {
132 Self {
133 inner: Arc::new(HnswBuilder {
134 params: HnswBuildParams::default(),
135 nodes: Arc::new(Vec::new()),
136 level_count: Vec::new(),
137 entry_point: 0,
138 visited_generator_queue: Arc::new(ArrayQueue::new(1)),
139 }),
140 }
141 }
142
143 pub fn len(&self) -> usize {
144 self.inner.nodes.len()
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.len() == 0
149 }
150
151 pub fn max_level(&self) -> u16 {
152 self.inner.max_level()
153 }
154
155 pub fn num_nodes(&self, level: usize) -> usize {
156 self.inner.num_nodes(level)
157 }
158
159 pub fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
160 self.inner.nodes()
161 }
162
163 #[allow(clippy::too_many_arguments)]
164 pub fn search_inner(
165 &self,
166 query: ArrayRef,
167 k: usize,
168 params: &HnswQueryParams,
169 bitset: Option<Visited>,
170 visited_generator: &mut VisitedGenerator,
171 storage: &impl VectorStore,
172 prefetch_distance: Option<usize>,
173 ) -> Result<Vec<OrderedNode>> {
174 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
175 let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
176 let nodes = &self.nodes();
177 for level in (0..self.max_level()).rev() {
178 let cur_level = HnswLevelView::new(level, nodes);
179 ep = greedy_search(
180 &cur_level,
181 ep,
182 &dist_calc,
183 self.inner.params.prefetch_distance,
184 );
185 }
186
187 let bottom_level = HnswBottomView::new(nodes);
188 let mut visited = visited_generator.generate(storage.len());
189 Ok(beam_search(
190 &bottom_level,
191 &ep,
192 params,
193 &dist_calc,
194 bitset.as_ref(),
195 prefetch_distance,
196 &mut visited,
197 )
198 .into_iter()
199 .take(k)
200 .collect())
201 }
202
203 #[instrument(level = "debug", skip(self, query, bitset, storage))]
204 pub fn search_basic(
205 &self,
206 query: ArrayRef,
207 k: usize,
208 params: &HnswQueryParams,
209 bitset: Option<Visited>,
210 storage: &impl VectorStore,
211 ) -> Result<Vec<OrderedNode>> {
212 let mut visited_generator = self
213 .inner
214 .visited_generator_queue
215 .pop()
216 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
217 let result = self.search_inner(
218 query,
219 k,
220 params,
221 bitset,
222 &mut visited_generator,
223 storage,
224 Some(2),
225 );
226
227 match self.inner.visited_generator_queue.push(visited_generator) {
228 Ok(_) => {}
229 Err(_) => {
230 log::warn!("visited_generator_queue is full");
231 }
232 }
233
234 result
235 }
236
237 #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
238 fn flat_search(
239 &self,
240 storage: &impl VectorStore,
241 query: ArrayRef,
242 k: usize,
243 prefilter_bitset: Visited,
244 params: &HnswQueryParams,
245 ) -> Vec<OrderedNode> {
246 let node_ids = storage
247 .row_ids()
248 .enumerate()
249 .filter_map(|(node_id, _)| {
250 prefilter_bitset
251 .contains(node_id as u32)
252 .then_some(node_id as u32)
253 })
254 .collect_vec();
255
256 let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
257 let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
258
259 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
260 let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
261 for i in 0..node_ids.len() {
262 if let Some(ahead) = self.inner.params.prefetch_distance {
263 if i + ahead < node_ids.len() {
264 dist_calc.prefetch(node_ids[i + ahead]);
265 }
266 }
267 let node_id = node_ids[i];
268 let dist: OrderedFloat = dist_calc.distance(node_id).into();
269 if dist <= lower_bound || dist > upper_bound {
270 continue;
271 }
272 if heap.len() < k {
273 heap.push((dist, node_id).into());
274 } else if dist < heap.peek().unwrap().dist {
275 heap.pop();
276 heap.push((dist, node_id).into());
277 }
278 }
279 heap.into_sorted_vec()
280 }
281
282 pub fn metadata(&self) -> HnswMetadata {
284 let level_offsets = self
287 .inner
288 .level_count
289 .iter()
290 .chain(iter::once(&AtomicUsize::new(0)))
291 .scan(0, |state, x| {
292 let start = *state;
293 *state += x.load(Ordering::Relaxed);
294 Some(start)
295 })
296 .collect();
297
298 HnswMetadata {
299 entry_point: self.inner.entry_point,
300 params: self.inner.params.clone(),
301 level_offsets,
302 }
303 }
304}
305
306struct HnswBuilder {
307 params: HnswBuildParams,
308
309 nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
310 level_count: Vec<AtomicUsize>,
311
312 entry_point: u32,
313
314 visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
315}
316
317impl DeepSizeOf for HnswBuilder {
318 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
319 self.params.deep_size_of_children(context)
320 + self.nodes.deep_size_of_children(context)
321 + self.level_count.deep_size_of_children(context)
322 }
324}
325
326impl HnswBuilder {
327 fn max_level(&self) -> u16 {
328 self.params.max_level
329 }
330
331 fn num_nodes(&self, level: usize) -> usize {
332 self.level_count[level].load(Ordering::Relaxed)
333 }
334
335 fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
336 self.nodes.clone()
337 }
338
339 pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
341 let len = storage.len();
342 let max_level = params.max_level;
343
344 let level_count = (0..max_level)
345 .map(|_| AtomicUsize::new(0))
346 .collect::<Vec<_>>();
347
348 let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
349 for _ in 0..get_num_compute_intensive_cpus() {
350 visited_generator_queue
351 .push(VisitedGenerator::new(0))
352 .unwrap();
353 }
354 let mut builder = Self {
355 params,
356 nodes: Arc::new(Vec::new()),
357 level_count,
358 entry_point: 0,
359 visited_generator_queue,
360 };
361
362 if storage.is_empty() {
363 return builder;
364 }
365
366 let mut nodes = Vec::with_capacity(len);
367 {
368 if len > 0 {
369 nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
370 }
371 for i in 1..len {
372 nodes.push(RwLock::new(GraphBuilderNode::new(
373 i as u32,
374 builder.random_level() as usize + 1,
375 )));
376 }
377 }
378 builder.nodes = Arc::new(nodes);
379
380 builder
381 }
382
383 fn random_level(&self) -> u16 {
387 let mut rng = rng();
388 let ml = 1.0 / (self.params.m as f32).ln();
389 min(
390 (-rng.random::<f32>().ln() * ml) as u16,
391 self.params.max_level - 1,
392 )
393 }
394
395 fn insert(
397 &self,
398 node: u32,
399 visited_generator: &mut VisitedGenerator,
400 storage: &impl VectorStore,
401 ) {
402 let nodes = &self.nodes;
403 let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
404 let dist_calc = storage.dist_calculator_from_id(node);
405 let mut ep = OrderedNode::new(
406 self.entry_point,
407 dist_calc.distance(self.entry_point).into(),
408 );
409
410 for level in (target_level + 1..self.params.max_level).rev() {
419 let cur_level = HnswLevelView::new(level, nodes);
420 ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
421 }
422
423 let mut pruned_neighbors_per_level: Vec<Vec<_>> =
424 vec![Vec::new(); (target_level + 1) as usize];
425 {
426 let mut current_node = nodes[node as usize].write().unwrap();
427 for level in (0..=target_level).rev() {
428 self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
429
430 let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
431 for neighbor in &neighbors {
432 current_node.add_neighbor(neighbor.id, neighbor.dist, level);
433 }
434 self.prune(storage, &mut current_node, level);
435 pruned_neighbors_per_level[level as usize]
436 .clone_from(¤t_node.level_neighbors_ranked[level as usize]);
437
438 ep = neighbors[0].clone();
439 }
440 }
441 for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
442 let _: Vec<_> = pruned_neighbors
443 .iter()
444 .map(|unpruned_edge| {
445 let level = level as u16;
446 let m_max = match level {
447 0 => self.params.m * 2,
448 _ => self.params.m,
449 };
450 if unpruned_edge.dist
451 < nodes[unpruned_edge.id as usize]
452 .read()
453 .unwrap()
454 .cutoff(level, m_max)
455 {
456 let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
457 chosen_node.add_neighbor(node, unpruned_edge.dist, level);
458 self.prune(storage, &mut chosen_node, level);
459 }
460 })
461 .collect();
462 }
463 }
464
465 fn search_level(
466 &self,
467 ep: &OrderedNode,
468 level: u16,
469 dist_calc: &impl DistCalculator,
470 nodes: &Vec<RwLock<GraphBuilderNode>>,
471 visited_generator: &mut VisitedGenerator,
472 ) -> Vec<OrderedNode> {
473 let cur_level = HnswLevelView::new(level, nodes);
474 let mut visited = visited_generator.generate(nodes.len());
475 beam_search(
476 &cur_level,
477 ep,
478 &HnswQueryParams {
479 ef: self.params.ef_construction,
480 lower_bound: None,
481 upper_bound: None,
482 dist_q_c: 0.0,
483 },
484 dist_calc,
485 None,
486 self.params.prefetch_distance,
487 &mut visited,
488 )
489 }
490
491 fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
492 let m_max = match level {
493 0 => self.params.m * 2,
494 _ => self.params.m,
495 };
496
497 let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
498 let level_neighbors = neighbors_ranked.clone();
499 if level_neighbors.len() <= m_max {
500 builder_node.update_from_ranked_neighbors(level);
501 return;
502 }
504
505 *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
506 builder_node.update_from_ranked_neighbors(level);
507 }
508}
509
510pub(crate) struct HnswLevelView<'a> {
513 level: u16,
514 nodes: &'a Vec<RwLock<GraphBuilderNode>>,
515}
516
517impl<'a> HnswLevelView<'a> {
518 pub fn new(level: u16, nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
519 Self { level, nodes }
520 }
521}
522
523impl Graph for HnswLevelView<'_> {
524 fn len(&self) -> usize {
525 self.nodes.len()
526 }
527
528 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
529 let node = &self.nodes[key as usize];
530 node.read().unwrap().level_neighbors[self.level as usize].clone()
531 }
532}
533
534pub(crate) struct HnswBottomView<'a> {
535 nodes: &'a Vec<RwLock<GraphBuilderNode>>,
536}
537
538impl<'a> HnswBottomView<'a> {
539 pub fn new(nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
540 Self { nodes }
541 }
542}
543
544impl Graph for HnswBottomView<'_> {
545 fn len(&self) -> usize {
546 self.nodes.len()
547 }
548
549 fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
550 let node = &self.nodes[key as usize];
551 node.read().unwrap().bottom_neighbors.clone()
552 }
553}
554
/// Parameters for a single HNSW search.
#[derive(Debug, Clone, Copy)]
pub struct HnswQueryParams {
    /// Beam width of the bottom-layer search; `search` requires it to be at least `k`.
    pub ef: usize,
    /// Optional lower bound on the distance of returned results.
    pub lower_bound: Option<f32>,
    /// Optional upper bound on the distance of returned results.
    pub upper_bound: Option<f32>,
    /// Forwarded to `VectorStore::dist_calculator` when the distance calculator is built.
    pub dist_q_c: f32,
}
562
563impl From<&Query> for HnswQueryParams {
564 fn from(query: &Query) -> Self {
565 let k = query.k * query.refine_factor.unwrap_or(1) as usize;
566 Self {
567 ef: query.ef.unwrap_or(k + k / 2),
568 lower_bound: query.lower_bound,
569 upper_bound: query.upper_bound,
570 dist_q_c: query.dist_q_c,
571 }
572 }
573}
574
575impl IvfSubIndex for HNSW {
576 type BuildParams = HnswBuildParams;
577 type QueryParams = HnswQueryParams;
578
579 fn load(data: RecordBatch) -> Result<Self>
580 where
581 Self: Sized,
582 {
583 if data.num_rows() == 0 {
584 return Ok(Self::empty());
585 }
586
587 let hnsw_metadata =
588 data.schema_ref()
589 .metadata()
590 .get(HNSW_METADATA_KEY)
591 .ok_or(Error::Index {
592 message: format!("{} not found", HNSW_METADATA_KEY),
593 location: location!(),
594 })?;
595 let hnsw_metadata: HnswMetadata =
596 serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index {
597 message: format!(
598 "Failed to decode HNSW metadata: {}, json: {}",
599 e, hnsw_metadata
600 ),
601 location: location!(),
602 })?;
603
604 let levels: Vec<_> = hnsw_metadata
605 .level_offsets
606 .iter()
607 .tuple_windows()
608 .map(|(start, end)| data.slice(*start, end - start))
609 .collect();
610
611 let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
612
613 let bottom_level_len = levels[0].num_rows();
614 let mut nodes = Vec::with_capacity(bottom_level_len);
615 for i in 0..bottom_level_len {
616 nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
617 }
618 for (level, batch) in levels.into_iter().enumerate() {
619 let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
620 let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
621 let distances = batch[DIST_COL].as_list::<i32>();
622
623 for ((node, neighbors), distances) in
624 ids.iter().zip(neighbors.iter()).zip(distances.iter())
625 {
626 let node = node.unwrap();
627 let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
628 let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
629
630 nodes[node as usize].level_neighbors_ranked[level] = neighbors
631 .iter()
632 .zip(distances.iter())
633 .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
634 .collect();
635 nodes[node as usize].update_from_ranked_neighbors(level as u16);
636 }
637 }
638
639 let visited_generator_queue =
640 Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
641 for _ in 0..get_num_compute_intensive_cpus() * 2 {
642 visited_generator_queue
643 .push(VisitedGenerator::new(0))
644 .unwrap();
645 }
646 let inner = HnswBuilder {
647 params: hnsw_metadata.params,
648 nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
649 level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
650 entry_point: hnsw_metadata.entry_point,
651 visited_generator_queue,
652 };
653
654 Ok(Self {
655 inner: Arc::new(inner),
656 })
657 }
658
659 fn name() -> &'static str {
660 HNSW_TYPE
661 }
662
663 fn metadata_key() -> &'static str {
664 "lance:hnsw"
665 }
666
667 fn schema() -> arrow_schema::SchemaRef {
669 arrow_schema::Schema::new(vec![
670 VECTOR_ID_FIELD.clone(),
671 NEIGHBORS_FIELD.clone(),
672 DISTS_FIELD.clone(),
673 ])
674 .into()
675 }
676
677 #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
678 fn search(
679 &self,
680 query: ArrayRef,
681 k: usize,
682 params: Self::QueryParams,
683 storage: &impl VectorStore,
684 prefilter: Arc<dyn PreFilter>,
685 _metrics: &dyn MetricsCollector,
686 ) -> Result<RecordBatch> {
687 if params.ef < k {
688 return Err(Error::Index {
689 message: "ef must be greater than or equal to k".to_string(),
690 location: location!(),
691 });
692 }
693
694 let schema = VECTOR_RESULT_SCHEMA.clone();
695 if self.is_empty() {
696 return Ok(RecordBatch::new_empty(schema));
697 }
698
699 let mut prefilter_generator = self
700 .inner
701 .visited_generator_queue
702 .pop()
703 .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
704 let prefilter_bitset = if prefilter.is_empty() {
705 None
706 } else {
707 let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
708 let mut bitset = prefilter_generator.generate(storage.len());
709 for indices in indices {
710 bitset.insert(indices as u32);
711 }
712 Some(bitset)
713 };
714
715 let remained = prefilter_bitset
716 .as_ref()
717 .map(|b| b.count_ones())
718 .unwrap_or(storage.len());
719 let results = if remained < self.len() * 10 / 100 {
720 let prefilter_bitset =
721 prefilter_bitset.expect("the prefilter bitset must be set for flat search");
722 self.flat_search(storage, query, k, prefilter_bitset, ¶ms)
723 } else {
724 self.search_basic(query, k, ¶ms, prefilter_bitset, storage)?
725 };
726 let _ = self.inner.visited_generator_queue.push(prefilter_generator);
728
729 let (row_ids, dists): (Vec<_>, Vec<_>) = results
731 .into_iter()
732 .map(|r| (storage.row_id(r.id), r.dist.0))
733 .unique_by(|r| r.0)
734 .unzip();
735 let row_ids = Arc::new(UInt64Array::from(row_ids));
736 let distances = Arc::new(Float32Array::from(dists));
737
738 Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
739 }
740
741 fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
743 where
744 Self: Sized,
745 {
746 let inner = HnswBuilder::with_params(params, storage);
747 let hnsw = Self {
748 inner: Arc::new(inner),
749 };
750
751 log::debug!(
752 "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
753 storage.len(),
754 hnsw.inner.params.max_level,
755 hnsw.inner.params.m,
756 hnsw.inner.params.ef_construction,
757 storage.distance_type(),
758 );
759
760 if storage.is_empty() {
761 return Ok(hnsw);
762 }
763
764 let len = storage.len();
765 hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
766 (1..len).into_par_iter().for_each_init(
767 || VisitedGenerator::new(len),
768 |visited_generator, node| {
769 hnsw.inner.insert(node as u32, visited_generator, storage);
770 },
771 );
772
773 assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
774 Ok(hnsw)
775 }
776
777 fn remap(
778 &self,
779 _mapping: &HashMap<u64, Option<u64>>, store: &impl VectorStore,
781 ) -> Result<Self> {
782 Self::index_vectors(store, self.inner.params.clone())
785 }
786
787 fn to_batch(&self) -> Result<RecordBatch> {
789 let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
790 let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
791 let mut distances_builder =
792 ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
793 let mut batches = Vec::with_capacity(self.max_level() as usize);
794 for level in 0..self.max_level() {
795 let level = level as usize;
796 for (id, node) in self.inner.nodes.iter().enumerate() {
797 let node = node.read().unwrap();
798 if level >= node.level_neighbors.len() {
799 continue;
800 }
801 let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
802 let distances = node.level_neighbors_ranked[level]
803 .iter()
804 .map(|n| Some(n.dist.0));
805 vector_id_builder.append_value(id as u32);
806 neighbors_builder.append_value(neighbors);
807 distances_builder.append_value(distances);
808 }
809
810 let batch = RecordBatch::try_new(
811 Self::schema(),
812 vec![
813 Arc::new(vector_id_builder.finish()),
814 Arc::new(neighbors_builder.finish()),
815 Arc::new(distances_builder.finish()),
816 ],
817 )?;
818 batches.push(batch);
819 }
820
821 let metadata = self.metadata();
822 let metadata = serde_json::to_string(&metadata)?;
823 let schema = Self::schema()
824 .as_ref()
825 .clone()
826 .with_metadata(HashMap::from_iter(vec![(
827 HNSW_METADATA_KEY.to_string(),
828 metadata,
829 )]));
830 let batch = concat_batches(&Self::schema(), batches.iter())?;
831 let batch = batch.with_schema(Arc::new(schema))?;
832 Ok(batch)
833 }
834}
835
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::{
        reader::FileReader,
        writer::{FileWriter, FileWriterOptions},
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::FlatFloatStorage,
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            builder::{HnswBuildParams, HnswQueryParams},
            HNSW, VECTOR_ID_FIELD,
        },
    };

    /// Round-trips a built graph through an in-memory Lance file and checks that the
    /// reloaded graph answers a query exactly like the original.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;

        // Build a graph over random vectors.
        let values = generate_random_array(TOTAL * DIM);
        let vectors = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(vectors.clone(), DistanceType::L2));
        let build_params = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50);
        let hnsw = HNSW::index_vectors(store.as_ref(), build_params).unwrap();

        // Serialize it into an in-memory file, carrying the HNSW metadata along.
        let object_store = ObjectStore::memory();
        let path = Path::from("test_builder_write_load");
        let object_writer = object_store.create(&path).await.unwrap();
        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut file_writer = FileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            lance_schema,
            &FileWriterOptions::default(),
        )
        .unwrap();
        let serialized = hnsw.to_batch().unwrap();
        let metadata = serialized.schema_ref().metadata().clone();
        file_writer.write_record_batch(serialized).await.unwrap();
        file_writer.finish_with_metadata(&metadata).await.unwrap();

        // Read it back and rebuild the graph.
        let reader = FileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let reloaded_batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        let reloaded = HNSW::load(reloaded_batch).unwrap();

        // Both graphs must return the same neighbors for the same query.
        let query = vectors.value(0);
        let k = 10;
        let query_params = HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        };
        let expected = hnsw
            .search_basic(query.clone(), k, &query_params, None, store.as_ref())
            .unwrap();
        let actual = reloaded
            .search_basic(query, k, &query_params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }
}