Skip to main content

lance_index/vector/hnsw/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Builder of Hnsw Graph.
5
6use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{DataType, Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use std::cmp::min;
18use std::collections::{BinaryHeap, HashMap, VecDeque};
19use std::fmt::Debug;
20use std::iter;
21use std::sync::Arc;
22use std::sync::RwLock;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use tracing::instrument;
25
26use lance_core::{Error, Result};
27use rand::{Rng, rng};
28use serde::{Deserialize, Serialize};
29
30use super::super::graph::beam_search;
31use super::{HNSW_TYPE, HnswMetadata, VECTOR_ID_COL, VECTOR_ID_FIELD, select_neighbors_heuristic};
32use crate::metrics::MetricsCollector;
33use crate::prefilter::PreFilter;
34use crate::vector::flat::storage::{FlatBinStorage, FlatFloatStorage};
35use crate::vector::graph::builder::GraphBuilderNode;
36use crate::vector::graph::{
37    BorrowingGraph, DISTS_FIELD, Graph, NEIGHBORS_COL, NEIGHBORS_FIELD, OrderedFloat, OrderedNode,
38    VisitedGenerator,
39};
40use crate::vector::graph::{Visited, beam_search_borrowed, greedy_search, greedy_search_borrowed};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{DIST_COL, Query, VECTOR_RESULT_SCHEMA};
44
45pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
/// Parameters for building an HNSW index.
#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
pub struct HnswBuildParams {
    /// Maximum number of levels in the graph.
    pub max_level: u16,

    /// Number of connections to establish while inserting a new element.
    pub m: usize,

    /// Size of the dynamic candidate list used while building the graph.
    pub ef_construction: usize,

    /// Number of vectors ahead to prefetch while building the graph.
    pub prefetch_distance: Option<usize>,
}
62
63impl Default for HnswBuildParams {
64    fn default() -> Self {
65        Self {
66            max_level: 7,
67            m: 20,
68            ef_construction: 150,
69            prefetch_distance: Some(2),
70        }
71    }
72}
73
74impl HnswBuildParams {
75    /// The maximum level of the graph.
76    /// The default value is `8`.
77    pub fn max_level(mut self, max_level: u16) -> Self {
78        self.max_level = max_level;
79        self
80    }
81
82    /// The number of connections to establish while inserting new element
83    /// The default value is `30`.
84    pub fn num_edges(mut self, m: usize) -> Self {
85        self.m = m;
86        self
87    }
88
89    /// Number of candidates to be considered when searching for the nearest neighbors
90    /// during the construction of the graph.
91    ///
92    /// The default value is `100`.
93    pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94        self.ef_construction = ef_construction;
95        self
96    }
97
98    /// Build the HNSW index from the given data.
99    ///
100    /// # Parameters
101    /// - `data`: A FixedSizeList to build the HNSW.
102    /// - `distance_type`: The distance type to use.
103    pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104        let vectors = data.as_fixed_size_list().clone();
105        match (vectors.value_type(), distance_type) {
106            (DataType::UInt8, DistanceType::Hamming) => {
107                let vec_store = Arc::new(FlatBinStorage::new(vectors, distance_type));
108                HNSW::index_vectors(vec_store.as_ref(), self)
109            }
110            (DataType::UInt8, _) => Err(Error::invalid_input(format!(
111                "HNSW only supports hamming distance for UInt8 vectors, got {}",
112                distance_type
113            ))),
114            (_, DistanceType::Hamming) => Err(Error::invalid_input(format!(
115                "HNSW hamming distance only supports UInt8 vectors, got {}",
116                vectors.value_type()
117            ))),
118            _ => {
119                let vec_store = Arc::new(FlatFloatStorage::new(vectors, distance_type));
120                HNSW::index_vectors(vec_store.as_ref(), self)
121            }
122        }
123    }
124}
125
/// An HNSW graph.
///
/// Currently, the HNSW graph is fully built in memory.
///
/// Each node in the graph has a global ID which is the index on the base layer.
///
/// Cloning is cheap: the graph data lives behind a shared `Arc`.
#[derive(Clone, DeepSizeOf)]
pub struct HNSW {
    inner: Arc<HnswCore>,
}
137
/// Immutable state shared by all clones of an [`HNSW`].
struct HnswCore {
    /// Parameters the graph was built with.
    params: HnswBuildParams,
    /// All graph nodes, indexed by node id.
    nodes: Arc<Vec<GraphBuilderNode>>,
    /// Number of nodes on each level, indexed by level.
    level_count: Vec<usize>,
    /// Id of the entry point recorded in the index metadata.
    entry_point: u32,
    /// Pool of reusable [`VisitedGenerator`]s shared between searches.
    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
}
145
146impl DeepSizeOf for HnswCore {
147    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
148        self.params.deep_size_of_children(context)
149            + self.nodes.deep_size_of_children(context)
150            + self.level_count.deep_size_of_children(context)
151        // Skipping the visited_generator_queue
152    }
153}
154
155impl HnswCore {
156    fn max_level(&self) -> u16 {
157        self.params.max_level
158    }
159
160    fn num_nodes(&self, level: usize) -> usize {
161        self.level_count[level]
162    }
163
164    fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
165        self.nodes.clone()
166    }
167}
168
169impl Debug for HNSW {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
172    }
173}
174
175impl HNSW {
176    pub fn empty() -> Self {
177        Self {
178            inner: Arc::new(HnswCore {
179                params: HnswBuildParams::default(),
180                nodes: Arc::new(Vec::new()),
181                level_count: Vec::new(),
182                entry_point: 0,
183                visited_generator_queue: Arc::new(ArrayQueue::new(1)),
184            }),
185        }
186    }
187
    /// Returns the number of nodes in the graph.
    pub fn len(&self) -> usize {
        self.inner.nodes.len()
    }
191
    /// Returns `true` if the graph contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
195
    /// Returns the maximum number of levels from the build parameters.
    pub fn max_level(&self) -> u16 {
        self.inner.max_level()
    }
199
    /// Returns the number of nodes on `level`.
    ///
    /// Panics if `level` is out of range of the recorded per-level counts.
    pub fn num_nodes(&self, level: usize) -> usize {
        self.inner.num_nodes(level)
    }
203
    /// Returns a shared handle to all nodes of the graph.
    pub fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
        self.inner.nodes()
    }
207
208    #[allow(clippy::too_many_arguments)]
209    pub fn search_inner(
210        &self,
211        query: ArrayRef,
212        k: usize,
213        params: &HnswQueryParams,
214        bitset: Option<Visited>,
215        visited_generator: &mut VisitedGenerator,
216        storage: &impl VectorStore,
217        prefetch_distance: Option<usize>,
218    ) -> Result<Vec<OrderedNode>> {
219        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
220        let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
221        let nodes = self.inner.nodes.as_ref();
222        for level in (0..self.max_level()).rev() {
223            let cur_level = ImmutableHnswLevelView::new(level, nodes);
224            ep = greedy_search_borrowed(
225                &cur_level,
226                ep,
227                &dist_calc,
228                self.inner.params.prefetch_distance,
229            );
230        }
231
232        let bottom_level = ImmutableHnswBottomView::new(nodes);
233        let mut visited = visited_generator.generate(storage.len());
234        Ok(beam_search_borrowed(
235            &bottom_level,
236            &ep,
237            params,
238            &dist_calc,
239            bitset.as_ref(),
240            prefetch_distance,
241            &mut visited,
242        )
243        .into_iter()
244        .take(k)
245        .collect())
246    }
247
248    #[instrument(level = "debug", skip(self, query, bitset, storage))]
249    pub fn search_basic(
250        &self,
251        query: ArrayRef,
252        k: usize,
253        params: &HnswQueryParams,
254        bitset: Option<Visited>,
255        storage: &impl VectorStore,
256    ) -> Result<Vec<OrderedNode>> {
257        let mut visited_generator = self
258            .inner
259            .visited_generator_queue
260            .pop()
261            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
262        let result = self.search_inner(
263            query,
264            k,
265            params,
266            bitset,
267            &mut visited_generator,
268            storage,
269            Some(2),
270        );
271
272        match self.inner.visited_generator_queue.push(visited_generator) {
273            Ok(_) => {}
274            Err(_) => {
275                log::warn!("visited_generator_queue is full");
276            }
277        }
278
279        result
280    }
281
282    #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
283    fn flat_search(
284        &self,
285        storage: &impl VectorStore,
286        query: ArrayRef,
287        k: usize,
288        prefilter_bitset: Visited,
289        params: &HnswQueryParams,
290    ) -> Vec<OrderedNode> {
291        let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
292        let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
293
294        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
295        let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
296
297        match self.inner.params.prefetch_distance {
298            Some(ahead) if ahead > 0 => {
299                let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
300                let mut buffer = VecDeque::with_capacity(ahead + 1);
301                for _ in 0..=ahead {
302                    if let Some(id) = ids_iter.next() {
303                        buffer.push_back(id);
304                    } else {
305                        break;
306                    }
307                }
308
309                while let Some(node_id) = buffer.pop_front() {
310                    if let Some(&prefetch_id) = buffer.get(ahead - 1) {
311                        dist_calc.prefetch(prefetch_id);
312                    }
313                    if let Some(next) = ids_iter.next() {
314                        buffer.push_back(next);
315                    }
316
317                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
318                    if dist <= lower_bound || dist > upper_bound {
319                        continue;
320                    }
321                    if heap.len() < k {
322                        heap.push((dist, node_id).into());
323                    } else if dist < heap.peek().unwrap().dist {
324                        heap.pop();
325                        heap.push((dist, node_id).into());
326                    }
327                }
328            }
329            _ => {
330                for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
331                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
332                    if dist <= lower_bound || dist > upper_bound {
333                        continue;
334                    }
335                    if heap.len() < k {
336                        heap.push((dist, node_id).into());
337                    } else if dist < heap.peek().unwrap().dist {
338                        heap.pop();
339                        heap.push((dist, node_id).into());
340                    }
341                }
342            }
343        };
344        heap.into_sorted_vec()
345    }
346
347    /// Returns the metadata of this [`HNSW`].
348    pub fn metadata(&self) -> HnswMetadata {
349        // calculate the offsets of each level,
350        // start from 0
351        let level_offsets = self
352            .inner
353            .level_count
354            .iter()
355            .chain(iter::once(&0))
356            .scan(0, |state, x| {
357                let start = *state;
358                *state += *x;
359                Some(start)
360            })
361            .collect();
362
363        HnswMetadata {
364            entry_point: self.inner.entry_point,
365            params: self.inner.params.clone(),
366            level_offsets,
367        }
368    }
369}
370
/// Mutable state used while constructing an [`HNSW`] graph.
struct HnswBuilder {
    /// Parameters controlling the build.
    params: HnswBuildParams,

    /// All nodes, indexed by node id; each node is individually lockable.
    nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
    /// Number of nodes inserted on each level, indexed by level.
    level_count: Vec<AtomicUsize>,

    /// Id of the node used as the starting point of every insertion.
    entry_point: u32,

    /// Pool of [`VisitedGenerator`]s handed over to the finished [`HNSW`].
    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
}
381
382impl DeepSizeOf for HnswBuilder {
383    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
384        self.params.deep_size_of_children(context)
385            + self.nodes.deep_size_of_children(context)
386            + self.level_count.deep_size_of_children(context)
387        // Skipping the visited_generator_queue
388    }
389}
390
391impl HnswBuilder {
392    fn finish(self) -> HNSW {
393        let nodes = match Arc::try_unwrap(self.nodes) {
394            Ok(nodes) => nodes
395                .into_iter()
396                .map(|node| node.into_inner().expect("builder lock poisoned"))
397                .collect(),
398            Err(nodes) => nodes
399                .iter()
400                .map(|node| node.read().expect("builder lock poisoned").clone())
401                .collect(),
402        };
403
404        let level_count = self
405            .level_count
406            .into_iter()
407            .map(|count| count.load(Ordering::Relaxed))
408            .collect();
409
410        HNSW {
411            inner: Arc::new(HnswCore {
412                params: self.params,
413                nodes: Arc::new(nodes),
414                level_count,
415                entry_point: self.entry_point,
416                visited_generator_queue: self.visited_generator_queue,
417            }),
418        }
419    }
420
421    /// Create a new [`HNSWBuilder`] with prepared params and in memory vector storage.
422    pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
423        let len = storage.len();
424        let max_level = params.max_level;
425
426        let level_count = (0..max_level)
427            .map(|_| AtomicUsize::new(0))
428            .collect::<Vec<_>>();
429
430        let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
431        for _ in 0..get_num_compute_intensive_cpus() {
432            visited_generator_queue
433                .push(VisitedGenerator::new(0))
434                .unwrap();
435        }
436        let mut builder = Self {
437            params,
438            nodes: Arc::new(Vec::new()),
439            level_count,
440            entry_point: 0,
441            visited_generator_queue,
442        };
443
444        if storage.is_empty() {
445            return builder;
446        }
447
448        let mut nodes = Vec::with_capacity(len);
449        {
450            if len > 0 {
451                nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
452            }
453            let mut level_rng = rng();
454            for i in 1..len {
455                nodes.push(RwLock::new(GraphBuilderNode::new(
456                    i as u32,
457                    builder.random_level(&mut level_rng) as usize + 1,
458                )));
459            }
460        }
461        builder.nodes = Arc::new(nodes);
462
463        builder
464    }
465
466    /// New node's level
467    ///
468    /// See paper `Algorithm 1`
469    fn random_level<R: Rng + ?Sized>(&self, rng: &mut R) -> u16 {
470        let ml = 1.0 / (self.params.m as f32).ln();
471        min(
472            (-rng.random::<f32>().ln() * ml) as u16,
473            self.params.max_level - 1,
474        )
475    }
476
477    /// Insert one node.
478    fn insert(
479        &self,
480        node: u32,
481        visited_generator: &mut VisitedGenerator,
482        storage: &impl VectorStore,
483    ) {
484        let nodes = &self.nodes;
485        let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
486        let dist_calc = storage.dist_calculator_from_id(node);
487        let mut ep = OrderedNode::new(
488            self.entry_point,
489            dist_calc.distance(self.entry_point).into(),
490        );
491
492        //
493        // Search for entry point in paper.
494        // ```
495        //   for l_c in (L..l+1) {
496        //     W = Search-Layer(q, ep, ef=1, l_c)
497        //    ep = Select-Neighbors(W, 1)
498        //  }
499        // ```
500        for level in (target_level + 1..self.params.max_level).rev() {
501            let cur_level = HnswLevelView::new(level, nodes);
502            ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
503        }
504
505        let mut pruned_neighbors_per_level: Vec<Vec<_>> =
506            vec![Vec::new(); (target_level + 1) as usize];
507        {
508            let mut current_node = nodes[node as usize].write().unwrap();
509            for level in (0..=target_level).rev() {
510                self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
511
512                let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
513                for neighbor in &neighbors {
514                    current_node.add_neighbor(neighbor.id, neighbor.dist, level);
515                }
516                self.prune(storage, &mut current_node, level);
517                pruned_neighbors_per_level[level as usize]
518                    .clone_from(&current_node.level_neighbors_ranked[level as usize]);
519
520                ep = neighbors[0].clone();
521            }
522        }
523        for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
524            let _: Vec<_> = pruned_neighbors
525                .iter()
526                .map(|unpruned_edge| {
527                    let level = level as u16;
528                    let m_max = match level {
529                        0 => self.params.m * 2,
530                        _ => self.params.m,
531                    };
532                    if unpruned_edge.dist
533                        < nodes[unpruned_edge.id as usize]
534                            .read()
535                            .unwrap()
536                            .cutoff(level, m_max)
537                    {
538                        let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
539                        chosen_node.add_neighbor(node, unpruned_edge.dist, level);
540                        self.prune(storage, &mut chosen_node, level);
541                    }
542                })
543                .collect();
544        }
545    }
546
547    fn search_level(
548        &self,
549        ep: &OrderedNode,
550        level: u16,
551        dist_calc: &impl DistCalculator,
552        nodes: &[RwLock<GraphBuilderNode>],
553        visited_generator: &mut VisitedGenerator,
554    ) -> Vec<OrderedNode> {
555        let cur_level = HnswLevelView::new(level, nodes);
556        let mut visited = visited_generator.generate(nodes.len());
557        beam_search(
558            &cur_level,
559            ep,
560            &HnswQueryParams {
561                ef: self.params.ef_construction,
562                lower_bound: None,
563                upper_bound: None,
564                dist_q_c: 0.0,
565            },
566            dist_calc,
567            None,
568            self.params.prefetch_distance,
569            &mut visited,
570        )
571    }
572
573    fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
574        let m_max = match level {
575            0 => self.params.m * 2,
576            _ => self.params.m,
577        };
578
579        let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
580        let level_neighbors = neighbors_ranked.clone();
581        if level_neighbors.len() <= m_max {
582            builder_node.update_from_ranked_neighbors(level);
583            return;
584        }
585
586        *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
587        builder_node.update_from_ranked_neighbors(level);
588    }
589}
590
/// View of a level in HNSW graph.
///
/// This is used to iterate over neighbors in a specific level while the
/// nodes are still behind their builder locks.
pub(crate) struct HnswLevelView<'a> {
    /// Level this view exposes.
    level: u16,
    /// All lockable nodes, indexed by node id.
    nodes: &'a [RwLock<GraphBuilderNode>],
}
597
impl<'a> HnswLevelView<'a> {
    /// Creates a view of `level` over the given lockable nodes.
    pub fn new(level: u16, nodes: &'a [RwLock<GraphBuilderNode>]) -> Self {
        Self { level, nodes }
    }
}
603
604impl Graph for HnswLevelView<'_> {
605    fn len(&self) -> usize {
606        self.nodes.len()
607    }
608
609    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
610        let node = &self.nodes[key as usize];
611        node.read().unwrap().level_neighbors[self.level as usize].clone()
612    }
613}
614
/// Lock-free view of a single level over already finished nodes.
pub(crate) struct ImmutableHnswLevelView<'a> {
    /// Level this view exposes.
    level: u16,
    /// All nodes, indexed by node id.
    nodes: &'a [GraphBuilderNode],
}
619
impl<'a> ImmutableHnswLevelView<'a> {
    /// Creates a view of `level` over the given nodes.
    pub fn new(level: u16, nodes: &'a [GraphBuilderNode]) -> Self {
        Self { level, nodes }
    }
}
625
626impl Graph for ImmutableHnswLevelView<'_> {
627    fn len(&self) -> usize {
628        self.nodes.len()
629    }
630
631    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
632        self.nodes[key as usize].level_neighbors[self.level as usize].clone()
633    }
634}
635
636impl BorrowingGraph for ImmutableHnswLevelView<'_> {
637    fn len(&self) -> usize {
638        self.nodes.len()
639    }
640
641    fn neighbors(&self, key: u32) -> &[u32] {
642        self.nodes[key as usize].level_neighbors[self.level as usize].as_slice()
643    }
644}
645
/// Lock-free view exposing each node's `bottom_neighbors`.
pub(crate) struct ImmutableHnswBottomView<'a> {
    /// All nodes, indexed by node id.
    nodes: &'a [GraphBuilderNode],
}
649
impl<'a> ImmutableHnswBottomView<'a> {
    /// Creates a bottom-level view over the given nodes.
    pub fn new(nodes: &'a [GraphBuilderNode]) -> Self {
        Self { nodes }
    }
}
655
656impl Graph for ImmutableHnswBottomView<'_> {
657    fn len(&self) -> usize {
658        self.nodes.len()
659    }
660
661    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
662        self.nodes[key as usize].bottom_neighbors.clone()
663    }
664}
665
666impl BorrowingGraph for ImmutableHnswBottomView<'_> {
667    fn len(&self) -> usize {
668        self.nodes.len()
669    }
670
671    fn neighbors(&self, key: u32) -> &[u32] {
672        self.nodes[key as usize].bottom_neighbors.as_slice()
673    }
674}
675
/// Parameters of a single HNSW query.
#[derive(Debug, Clone, Copy)]
pub struct HnswQueryParams {
    /// Size of the candidate list used by the search; `search` rejects
    /// queries where this is smaller than `k`.
    pub ef: usize,
    /// Optional lower distance bound used to filter results.
    pub lower_bound: Option<f32>,
    /// Optional upper distance bound used to filter results.
    pub upper_bound: Option<f32>,
    /// Value forwarded to the storage's distance calculator.
    // NOTE(review): presumably the query-to-centroid distance; confirm against `VectorStore`.
    pub dist_q_c: f32,
}
683
684impl From<&Query> for HnswQueryParams {
685    fn from(query: &Query) -> Self {
686        let k = query.k * query.refine_factor.unwrap_or(1) as usize;
687        Self {
688            ef: query.ef.unwrap_or(k + k / 2),
689            lower_bound: query.lower_bound,
690            upper_bound: query.upper_bound,
691            dist_q_c: query.dist_q_c,
692        }
693    }
694}
695
696impl IvfSubIndex for HNSW {
697    type BuildParams = HnswBuildParams;
698    type QueryParams = HnswQueryParams;
699
700    fn load(data: RecordBatch) -> Result<Self>
701    where
702        Self: Sized,
703    {
704        if data.num_rows() == 0 {
705            return Ok(Self::empty());
706        }
707
708        let hnsw_metadata = data
709            .schema_ref()
710            .metadata()
711            .get(HNSW_METADATA_KEY)
712            .ok_or(Error::index(format!("{} not found", HNSW_METADATA_KEY)))?;
713        let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata).map_err(|e| {
714            Error::index(format!(
715                "Failed to decode HNSW metadata: {}, json: {}",
716                e, hnsw_metadata
717            ))
718        })?;
719
720        let levels: Vec<_> = hnsw_metadata
721            .level_offsets
722            .iter()
723            .tuple_windows()
724            .map(|(start, end)| data.slice(*start, end - start))
725            .collect();
726
727        let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
728
729        let bottom_level_len = levels[0].num_rows();
730        let mut nodes = Vec::with_capacity(bottom_level_len);
731        for i in 0..bottom_level_len {
732            nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
733        }
734        for (level, batch) in levels.into_iter().enumerate() {
735            let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
736            let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
737            let distances = batch[DIST_COL].as_list::<i32>();
738
739            for ((node, neighbors), distances) in
740                ids.iter().zip(neighbors.iter()).zip(distances.iter())
741            {
742                let node = node.unwrap();
743                let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
744                let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
745
746                nodes[node as usize].level_neighbors_ranked[level] = neighbors
747                    .iter()
748                    .zip(distances.iter())
749                    .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
750                    .collect();
751                nodes[node as usize].update_from_ranked_neighbors(level as u16);
752            }
753        }
754
755        let visited_generator_queue =
756            Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
757        for _ in 0..get_num_compute_intensive_cpus() * 2 {
758            visited_generator_queue
759                .push(VisitedGenerator::new(0))
760                .unwrap();
761        }
762        let inner = HnswCore {
763            params: hnsw_metadata.params,
764            nodes: Arc::new(nodes),
765            level_count,
766            entry_point: hnsw_metadata.entry_point,
767            visited_generator_queue,
768        };
769
770        Ok(Self {
771            inner: Arc::new(inner),
772        })
773    }
774
    /// Returns the name of this sub index type.
    fn name() -> &'static str {
        HNSW_TYPE
    }
778
779    fn metadata_key() -> &'static str {
780        "lance:hnsw"
781    }
782
783    /// Return the schema of the sub index
784    fn schema() -> arrow_schema::SchemaRef {
785        arrow_schema::Schema::new(vec![
786            VECTOR_ID_FIELD.clone(),
787            NEIGHBORS_FIELD.clone(),
788            DISTS_FIELD.clone(),
789        ])
790        .into()
791    }
792
793    #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
794    fn search(
795        &self,
796        query: ArrayRef,
797        k: usize,
798        params: Self::QueryParams,
799        storage: &impl VectorStore,
800        prefilter: Arc<dyn PreFilter>,
801        _metrics: &dyn MetricsCollector,
802    ) -> Result<RecordBatch> {
803        if params.ef < k {
804            return Err(Error::index(
805                "ef must be greater than or equal to k".to_string(),
806            ));
807        }
808
809        let schema = VECTOR_RESULT_SCHEMA.clone();
810        if self.is_empty() {
811            return Ok(RecordBatch::new_empty(schema));
812        }
813
814        let mut prefilter_generator = self
815            .inner
816            .visited_generator_queue
817            .pop()
818            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
819        let prefilter_bitset = if prefilter.is_empty() {
820            None
821        } else {
822            let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
823            let mut bitset = prefilter_generator.generate(storage.len());
824            for indices in indices {
825                bitset.insert(indices as u32);
826            }
827            Some(bitset)
828        };
829
830        let remained = prefilter_bitset
831            .as_ref()
832            .map(|b| b.count_ones())
833            .unwrap_or(storage.len());
834        let results = if remained < self.len() * 10 / 100 {
835            let prefilter_bitset =
836                prefilter_bitset.expect("the prefilter bitset must be set for flat search");
837            self.flat_search(storage, query, k, prefilter_bitset, &params)
838        } else {
839            self.search_basic(query, k, &params, prefilter_bitset, storage)?
840        };
841        // if the queue is full, we just don't push it back, so ignore the error here
842        let _ = self.inner.visited_generator_queue.push(prefilter_generator);
843
844        // need to unique by row ids in case of searching multivector
845        let (row_ids, dists): (Vec<_>, Vec<_>) = results
846            .into_iter()
847            .map(|r| (storage.row_id(r.id), r.dist.0))
848            .unique_by(|r| r.0)
849            .unzip();
850        let row_ids = Arc::new(UInt64Array::from(row_ids));
851        let distances = Arc::new(Float32Array::from(dists));
852
853        Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
854    }
855
    /// Given a vector storage, containing all the data for the IVF partition, build the sub index.
    ///
    /// Node 0 is counted on level 0 up front and serves as the initial entry
    /// point; all remaining nodes are inserted in parallel, each worker using
    /// its own [`VisitedGenerator`]. An empty storage yields an empty graph.
    ///
    /// # Panics
    /// Panics if the number of nodes counted on level 0 after the build does not
    /// equal the number of vectors in `storage`.
    fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
    where
        Self: Sized,
    {
        let builder = HnswBuilder::with_params(params, storage);

        log::debug!(
            "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
            storage.len(),
            builder.params.max_level,
            builder.params.m,
            builder.params.ef_construction,
            storage.distance_type(),
        );

        if storage.is_empty() {
            return Ok(builder.finish());
        }

        let len = storage.len();
        // Node 0 is never passed to `insert`, so account for it on level 0 here.
        builder.level_count[0].fetch_add(1, Ordering::Relaxed);
        (1..len).into_par_iter().for_each_init(
            || VisitedGenerator::new(len),
            |visited_generator, node| {
                builder.insert(node as u32, visited_generator, storage);
            },
        );

        // Every node must have been counted exactly once on the bottom level.
        assert_eq!(builder.level_count[0].load(Ordering::Relaxed), len);
        Ok(builder.finish())
    }
888
    /// Rebuilds the graph from the remapped `store` using the same build parameters.
    ///
    /// The row id mapping itself is ignored.
    fn remap(
        &self,
        _mapping: &HashMap<u64, Option<u64>>, // we don't need the mapping here because we rebuild the graph from remapped storage
        store: &impl VectorStore,
    ) -> Result<Self> {
        // We can't simply remap the row ids in the graph because the vectors are changed,
        // so the graph needs to be rebuilt.
        Self::index_vectors(store, self.inner.params.clone())
    }
898
899    /// Encode the sub index into a record batch
900    fn to_batch(&self) -> Result<RecordBatch> {
901        let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
902        let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
903        let mut distances_builder =
904            ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
905        let mut batches = Vec::with_capacity(self.max_level() as usize);
906        for level in 0..self.max_level() {
907            let level = level as usize;
908            for (id, node) in self.inner.nodes.iter().enumerate() {
909                if level >= node.level_neighbors.len() {
910                    continue;
911                }
912                let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
913                let distances = node.level_neighbors_ranked[level]
914                    .iter()
915                    .map(|n| Some(n.dist.0));
916                vector_id_builder.append_value(id as u32);
917                neighbors_builder.append_value(neighbors);
918                distances_builder.append_value(distances);
919            }
920
921            let batch = RecordBatch::try_new(
922                Self::schema(),
923                vec![
924                    Arc::new(vector_id_builder.finish()),
925                    Arc::new(neighbors_builder.finish()),
926                    Arc::new(distances_builder.finish()),
927                ],
928            )?;
929            batches.push(batch);
930        }
931
932        let metadata = self.metadata();
933        let metadata = serde_json::to_string(&metadata)?;
934        let schema = Self::schema()
935            .as_ref()
936            .clone()
937            .with_metadata(HashMap::from_iter(vec![(
938                HNSW_METADATA_KEY.to_string(),
939                metadata,
940            )]));
941        let batch = concat_batches(&Self::schema(), batches.iter())?;
942        let batch = batch.with_schema(Arc::new(schema))?;
943        Ok(batch)
944    }
945}
946
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{FixedSizeListArray, UInt8Array};
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::{FlatBinStorage, FlatFloatStorage},
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            HNSW, VECTOR_ID_FIELD,
            builder::{HnswBuildParams, HnswQueryParams},
        },
    };

    /// Query parameters shared by the write/load round-trip tests.
    fn query_params() -> HnswQueryParams {
        HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        }
    }

    /// Writes `index` to an in-memory object store under `file_name`, reads the
    /// file back, and returns the HNSW graph loaded from the stored batch.
    ///
    /// The schema metadata of the encoded batch is written as file metadata so
    /// the reader can reconstruct the graph.
    async fn write_and_reload(index: &HNSW, file_name: &str) -> HNSW {
        let object_store = ObjectStore::memory();
        let path = Path::from(file_name);
        let object_writer = object_store.create(&path).await.unwrap();
        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();
        let batch = index.to_batch().unwrap();
        let metadata = batch.schema_ref().metadata().clone();
        writer.write_record_batch(batch).await.unwrap();
        writer.finish_with_metadata(&metadata).await.unwrap();

        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        HNSW::load(batch).unwrap()
    }

    /// A graph built over random float vectors must return the same search
    /// results before and after a write/load round trip.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;
        let data = generate_random_array(TOTAL * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        let index = HNSW::index_vectors(
            store.as_ref(),
            HnswBuildParams::default()
                .num_edges(NUM_EDGES)
                .ef_construction(50),
        )
        .unwrap();

        let loaded_index = write_and_reload(&index, "test_builder_write_load").await;

        let query = fsl.value(0);
        let k = 10;
        let params = query_params();
        let built_results = index
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_index
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(built_results, loaded_results);
    }

    /// Same round-trip check for binary vectors compared with Hamming distance.
    #[tokio::test]
    async fn test_builder_write_load_binary_hamming() {
        const DIM: usize = 8;
        const TOTAL: usize = 256;
        const NUM_EDGES: usize = 20;
        let data = UInt8Array::from_iter_values((0..TOTAL * DIM).map(|v| (v % 16) as u8));
        let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
        let store = Arc::new(FlatBinStorage::new(fsl.clone(), DistanceType::Hamming));
        let index = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50)
            .build(Arc::new(fsl.clone()), DistanceType::Hamming)
            .await
            .unwrap();

        let loaded_index =
            write_and_reload(&index, "test_builder_write_load_binary_hamming").await;

        let query = fsl.value(0);
        let k = 10;
        let params = query_params();
        let built_results = index
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_index
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(built_results, loaded_results);
    }
}