Skip to main content

lance_index/vector/hnsw/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Builder of Hnsw Graph.
5
6use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{DataType, Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use std::cmp::min;
18use std::collections::{BinaryHeap, HashMap, VecDeque};
19use std::fmt::Debug;
20use std::iter;
21use std::sync::Arc;
22use std::sync::RwLock;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use tracing::instrument;
25
26use lance_core::{Error, Result};
27use rand::{Rng, SeedableRng, rngs::SmallRng};
28use serde::{Deserialize, Serialize};
29
30use super::super::graph::beam_search;
31use super::{HNSW_TYPE, HnswMetadata, VECTOR_ID_COL, VECTOR_ID_FIELD, select_neighbors_heuristic};
32use crate::metrics::MetricsCollector;
33use crate::prefilter::PreFilter;
34use crate::vector::flat::storage::{FlatBinStorage, FlatFloatStorage};
35use crate::vector::graph::builder::GraphBuilderNode;
36use crate::vector::graph::{
37    BorrowingGraph, DISTS_FIELD, Graph, NEIGHBORS_COL, NEIGHBORS_FIELD, OrderedFloat, OrderedNode,
38    VisitedGenerator,
39};
40use crate::vector::graph::{Visited, beam_search_borrowed, greedy_search, greedy_search_borrowed};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{DIST_COL, Query, VECTOR_RESULT_SCHEMA};
44
/// Schema-metadata key under which the serialized HNSW metadata is looked up.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";

/// Seed for the random generator that assigns a level to each HNSW node.
///
/// Because the seed never changes, building from the same data with the same
/// parameters always produces the same graph, so index builds are
/// deterministic and tests are stable. Recall is statistically unaffected:
/// the level distribution stays the same, only the individual draws are
/// fixed. The offline builder ([`HnswBuilder`]) and the online builder
/// ([`super::online::OnlineHnswBuilder`]) share this seed so that both
/// produce comparable graphs.
pub(crate) const HNSW_LEVEL_RNG_SEED: u64 = 42;
55
56/// Parameters of building HNSW index
57#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
58pub struct HnswBuildParams {
59    /// max level ofm
60    pub max_level: u16,
61
62    /// number of connections to establish while inserting new element
63    pub m: usize,
64
65    /// size of the dynamic list for the candidates
66    pub ef_construction: usize,
67
68    /// number of vectors ahead to prefetch while building the graph
69    pub prefetch_distance: Option<usize>,
70}
71
72impl From<&HnswBuildParams> for crate::pb::HnswParameters {
73    fn from(params: &HnswBuildParams) -> Self {
74        Self {
75            max_connections: params.m as u32,
76            construction_ef: params.ef_construction as u32,
77            max_level: params.max_level as u32,
78        }
79    }
80}
81
82impl Default for HnswBuildParams {
83    fn default() -> Self {
84        Self {
85            max_level: 7,
86            m: 20,
87            ef_construction: 150,
88            prefetch_distance: Some(2),
89        }
90    }
91}
92
93impl HnswBuildParams {
94    /// The maximum level of the graph.
95    /// The default value is `8`.
96    pub fn max_level(mut self, max_level: u16) -> Self {
97        self.max_level = max_level;
98        self
99    }
100
101    /// The number of connections to establish while inserting new element
102    /// The default value is `30`.
103    pub fn num_edges(mut self, m: usize) -> Self {
104        self.m = m;
105        self
106    }
107
108    /// Number of candidates to be considered when searching for the nearest neighbors
109    /// during the construction of the graph.
110    ///
111    /// The default value is `100`.
112    pub fn ef_construction(mut self, ef_construction: usize) -> Self {
113        self.ef_construction = ef_construction;
114        self
115    }
116
117    /// Build the HNSW index from the given data.
118    ///
119    /// # Parameters
120    /// - `data`: A FixedSizeList to build the HNSW.
121    /// - `distance_type`: The distance type to use.
122    pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
123        let vectors = data.as_fixed_size_list().clone();
124        match (vectors.value_type(), distance_type) {
125            (DataType::UInt8, DistanceType::Hamming) => {
126                let vec_store = Arc::new(FlatBinStorage::new(vectors, distance_type));
127                HNSW::index_vectors(vec_store.as_ref(), self)
128            }
129            (DataType::UInt8, _) => Err(Error::invalid_input(format!(
130                "HNSW only supports hamming distance for UInt8 vectors, got {}",
131                distance_type
132            ))),
133            (_, DistanceType::Hamming) => Err(Error::invalid_input(format!(
134                "HNSW hamming distance only supports UInt8 vectors, got {}",
135                vectors.value_type()
136            ))),
137            _ => {
138                let vec_store = Arc::new(FlatFloatStorage::new(vectors, distance_type));
139                HNSW::index_vectors(vec_store.as_ref(), self)
140            }
141        }
142    }
143}
144
145/// Build a HNSW graph.
146///
147/// Currently, the HNSW graph is fully built in memory.
148///
149/// During the build, the graph is built layer by layer.
150///
151/// Each node in the graph has a global ID which is the index on the base layer.
152#[derive(Clone, DeepSizeOf)]
153pub struct HNSW {
154    inner: Arc<HnswCore>,
155}
156
157struct HnswCore {
158    params: HnswBuildParams,
159    nodes: Arc<Vec<GraphBuilderNode>>,
160    level_count: Vec<usize>,
161    entry_point: u32,
162    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
163}
164
165impl DeepSizeOf for HnswCore {
166    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
167        self.params.deep_size_of_children(context)
168            + self.nodes.deep_size_of_children(context)
169            + self.level_count.deep_size_of_children(context)
170        // Skipping the visited_generator_queue
171    }
172}
173
174impl HnswCore {
175    fn max_level(&self) -> u16 {
176        self.params.max_level
177    }
178
179    fn num_nodes(&self, level: usize) -> usize {
180        self.level_count[level]
181    }
182
183    fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
184        self.nodes.clone()
185    }
186}
187
188impl Debug for HNSW {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
191    }
192}
193
194impl HNSW {
195    /// Construct an HNSW from its constituent parts. Used by the online
196    /// builder when finalizing.
197    pub(crate) fn from_parts(
198        params: HnswBuildParams,
199        nodes: Vec<GraphBuilderNode>,
200        level_count: Vec<usize>,
201        entry_point: u32,
202    ) -> Self {
203        let queue_size = get_num_compute_intensive_cpus().max(1) * 2;
204        let visited_generator_queue = Arc::new(ArrayQueue::new(queue_size));
205        for _ in 0..queue_size {
206            let _ = visited_generator_queue.push(VisitedGenerator::new(0));
207        }
208        Self {
209            inner: Arc::new(HnswCore {
210                params,
211                nodes: Arc::new(nodes),
212                level_count,
213                entry_point,
214                visited_generator_queue,
215            }),
216        }
217    }
218
219    pub fn empty() -> Self {
220        Self {
221            inner: Arc::new(HnswCore {
222                params: HnswBuildParams::default(),
223                nodes: Arc::new(Vec::new()),
224                level_count: Vec::new(),
225                entry_point: 0,
226                visited_generator_queue: Arc::new(ArrayQueue::new(1)),
227            }),
228        }
229    }
230
231    pub fn len(&self) -> usize {
232        self.inner.nodes.len()
233    }
234
235    pub fn is_empty(&self) -> bool {
236        self.len() == 0
237    }
238
239    pub fn max_level(&self) -> u16 {
240        self.inner.max_level()
241    }
242
243    pub fn num_nodes(&self, level: usize) -> usize {
244        self.inner.num_nodes(level)
245    }
246
247    pub fn nodes(&self) -> Arc<Vec<GraphBuilderNode>> {
248        self.inner.nodes()
249    }
250
251    #[allow(clippy::too_many_arguments)]
252    pub fn search_inner(
253        &self,
254        query: ArrayRef,
255        k: usize,
256        params: &HnswQueryParams,
257        bitset: Option<Visited>,
258        visited_generator: &mut VisitedGenerator,
259        storage: &impl VectorStore,
260        prefetch_distance: Option<usize>,
261    ) -> Result<Vec<OrderedNode>> {
262        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
263        let entry = self.inner.entry_point;
264        let mut ep = OrderedNode::new(entry, dist_calc.distance(entry).into());
265        let nodes = self.inner.nodes.as_ref();
266        for level in (0..self.max_level()).rev() {
267            let cur_level = ImmutableHnswLevelView::new(level, nodes);
268            ep = greedy_search_borrowed(
269                &cur_level,
270                ep,
271                &dist_calc,
272                self.inner.params.prefetch_distance,
273            );
274        }
275
276        let bottom_level = ImmutableHnswBottomView::new(nodes);
277        let mut visited = visited_generator.generate(storage.len());
278        Ok(beam_search_borrowed(
279            &bottom_level,
280            &ep,
281            params,
282            &dist_calc,
283            bitset.as_ref(),
284            prefetch_distance,
285            &mut visited,
286        )
287        .into_iter()
288        .take(k)
289        .collect())
290    }
291
292    #[instrument(level = "debug", skip(self, query, bitset, storage))]
293    pub fn search_basic(
294        &self,
295        query: ArrayRef,
296        k: usize,
297        params: &HnswQueryParams,
298        bitset: Option<Visited>,
299        storage: &impl VectorStore,
300    ) -> Result<Vec<OrderedNode>> {
301        let mut visited_generator = self
302            .inner
303            .visited_generator_queue
304            .pop()
305            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
306        let result = self.search_inner(
307            query,
308            k,
309            params,
310            bitset,
311            &mut visited_generator,
312            storage,
313            Some(2),
314        );
315
316        match self.inner.visited_generator_queue.push(visited_generator) {
317            Ok(_) => {}
318            Err(_) => {
319                log::warn!("visited_generator_queue is full");
320            }
321        }
322
323        result
324    }
325
326    #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
327    fn flat_search(
328        &self,
329        storage: &impl VectorStore,
330        query: ArrayRef,
331        k: usize,
332        prefilter_bitset: Visited,
333        params: &HnswQueryParams,
334    ) -> Vec<OrderedNode> {
335        let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
336        let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
337
338        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
339        let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
340
341        match self.inner.params.prefetch_distance {
342            Some(ahead) if ahead > 0 => {
343                let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
344                let mut buffer = VecDeque::with_capacity(ahead + 1);
345                for _ in 0..=ahead {
346                    if let Some(id) = ids_iter.next() {
347                        buffer.push_back(id);
348                    } else {
349                        break;
350                    }
351                }
352
353                while let Some(node_id) = buffer.pop_front() {
354                    if let Some(&prefetch_id) = buffer.get(ahead - 1) {
355                        dist_calc.prefetch(prefetch_id);
356                    }
357                    if let Some(next) = ids_iter.next() {
358                        buffer.push_back(next);
359                    }
360
361                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
362                    if dist <= lower_bound || dist > upper_bound {
363                        continue;
364                    }
365                    if heap.len() < k {
366                        heap.push((dist, node_id).into());
367                    } else if dist < heap.peek().unwrap().dist {
368                        heap.pop();
369                        heap.push((dist, node_id).into());
370                    }
371                }
372            }
373            _ => {
374                for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
375                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
376                    if dist <= lower_bound || dist > upper_bound {
377                        continue;
378                    }
379                    if heap.len() < k {
380                        heap.push((dist, node_id).into());
381                    } else if dist < heap.peek().unwrap().dist {
382                        heap.pop();
383                        heap.push((dist, node_id).into());
384                    }
385                }
386            }
387        };
388        heap.into_sorted_vec()
389    }
390
391    /// Returns the metadata of this [`HNSW`].
392    pub fn metadata(&self) -> HnswMetadata {
393        // calculate the offsets of each level,
394        // start from 0
395        let level_offsets = self
396            .inner
397            .level_count
398            .iter()
399            .chain(iter::once(&0))
400            .scan(0, |state, x| {
401                let start = *state;
402                *state += *x;
403                Some(start)
404            })
405            .collect();
406
407        HnswMetadata {
408            entry_point: self.inner.entry_point,
409            params: self.inner.params.clone(),
410            level_offsets,
411        }
412    }
413}
414
415struct HnswBuilder {
416    params: HnswBuildParams,
417
418    nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
419    level_count: Vec<AtomicUsize>,
420
421    entry_point: u32,
422
423    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
424}
425
426impl DeepSizeOf for HnswBuilder {
427    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
428        self.params.deep_size_of_children(context)
429            + self.nodes.deep_size_of_children(context)
430            + self.level_count.deep_size_of_children(context)
431        // Skipping the visited_generator_queue
432    }
433}
434
435impl HnswBuilder {
436    fn finish(self) -> HNSW {
437        let nodes = match Arc::try_unwrap(self.nodes) {
438            Ok(nodes) => nodes
439                .into_iter()
440                .map(|node| node.into_inner().expect("builder lock poisoned"))
441                .collect(),
442            Err(nodes) => nodes
443                .iter()
444                .map(|node| node.read().expect("builder lock poisoned").clone())
445                .collect(),
446        };
447
448        let level_count = self
449            .level_count
450            .into_iter()
451            .map(|count| count.load(Ordering::Relaxed))
452            .collect();
453
454        HNSW {
455            inner: Arc::new(HnswCore {
456                params: self.params,
457                nodes: Arc::new(nodes),
458                level_count,
459                entry_point: self.entry_point,
460                visited_generator_queue: self.visited_generator_queue,
461            }),
462        }
463    }
464
465    /// Create a new [`HNSWBuilder`] with prepared params and in memory vector storage.
466    pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
467        let len = storage.len();
468        let max_level = params.max_level;
469
470        let level_count = (0..max_level)
471            .map(|_| AtomicUsize::new(0))
472            .collect::<Vec<_>>();
473
474        let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
475        for _ in 0..get_num_compute_intensive_cpus() {
476            visited_generator_queue
477                .push(VisitedGenerator::new(0))
478                .unwrap();
479        }
480        let mut builder = Self {
481            params,
482            nodes: Arc::new(Vec::new()),
483            level_count,
484            entry_point: 0,
485            visited_generator_queue,
486        };
487
488        if storage.is_empty() {
489            return builder;
490        }
491
492        let mut nodes = Vec::with_capacity(len);
493        {
494            if len > 0 {
495                nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
496            }
497            let mut level_rng = SmallRng::seed_from_u64(HNSW_LEVEL_RNG_SEED);
498            for i in 1..len {
499                nodes.push(RwLock::new(GraphBuilderNode::new(
500                    i as u32,
501                    builder.random_level(&mut level_rng) as usize + 1,
502                )));
503            }
504        }
505        builder.nodes = Arc::new(nodes);
506
507        builder
508    }
509
510    /// New node's level
511    ///
512    /// See paper `Algorithm 1`
513    fn random_level<R: Rng + ?Sized>(&self, rng: &mut R) -> u16 {
514        let ml = 1.0 / (self.params.m as f32).ln();
515        min(
516            (-rng.random::<f32>().ln() * ml) as u16,
517            self.params.max_level - 1,
518        )
519    }
520
521    /// Insert one node.
522    fn insert(
523        &self,
524        node: u32,
525        visited_generator: &mut VisitedGenerator,
526        storage: &impl VectorStore,
527    ) {
528        let nodes = &self.nodes;
529        let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
530        let dist_calc = storage.dist_calculator_from_id(node);
531        let mut ep = OrderedNode::new(
532            self.entry_point,
533            dist_calc.distance(self.entry_point).into(),
534        );
535
536        //
537        // Search for entry point in paper.
538        // ```
539        //   for l_c in (L..l+1) {
540        //     W = Search-Layer(q, ep, ef=1, l_c)
541        //    ep = Select-Neighbors(W, 1)
542        //  }
543        // ```
544        for level in (target_level + 1..self.params.max_level).rev() {
545            let cur_level = HnswLevelView::new(level, nodes);
546            ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
547        }
548
549        let mut pruned_neighbors_per_level: Vec<Vec<_>> =
550            vec![Vec::new(); (target_level + 1) as usize];
551        {
552            let mut current_node = nodes[node as usize].write().unwrap();
553            for level in (0..=target_level).rev() {
554                self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
555
556                let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
557                for neighbor in &neighbors {
558                    current_node.add_neighbor(neighbor.id, neighbor.dist, level);
559                }
560                self.prune(storage, &mut current_node, level);
561                pruned_neighbors_per_level[level as usize]
562                    .clone_from(&current_node.level_neighbors_ranked[level as usize]);
563
564                ep = neighbors[0].clone();
565            }
566        }
567        for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
568            let _: Vec<_> = pruned_neighbors
569                .iter()
570                .map(|unpruned_edge| {
571                    let level = level as u16;
572                    let m_max = match level {
573                        0 => self.params.m * 2,
574                        _ => self.params.m,
575                    };
576                    if unpruned_edge.dist
577                        < nodes[unpruned_edge.id as usize]
578                            .read()
579                            .unwrap()
580                            .cutoff(level, m_max)
581                    {
582                        let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
583                        chosen_node.add_neighbor(node, unpruned_edge.dist, level);
584                        self.prune(storage, &mut chosen_node, level);
585                    }
586                })
587                .collect();
588        }
589    }
590
591    fn search_level(
592        &self,
593        ep: &OrderedNode,
594        level: u16,
595        dist_calc: &impl DistCalculator,
596        nodes: &[RwLock<GraphBuilderNode>],
597        visited_generator: &mut VisitedGenerator,
598    ) -> Vec<OrderedNode> {
599        let cur_level = HnswLevelView::new(level, nodes);
600        let mut visited = visited_generator.generate(nodes.len());
601        beam_search(
602            &cur_level,
603            ep,
604            &HnswQueryParams {
605                ef: self.params.ef_construction,
606                lower_bound: None,
607                upper_bound: None,
608                dist_q_c: 0.0,
609            },
610            dist_calc,
611            None,
612            self.params.prefetch_distance,
613            &mut visited,
614        )
615    }
616
617    fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
618        let m_max = match level {
619            0 => self.params.m * 2,
620            _ => self.params.m,
621        };
622
623        let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
624        let level_neighbors = neighbors_ranked.clone();
625        if level_neighbors.len() <= m_max {
626            builder_node.update_from_ranked_neighbors(level);
627            return;
628        }
629
630        *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
631        builder_node.update_from_ranked_neighbors(level);
632    }
633}
634
635// View of a level in HNSW graph.
636// This is used to iterate over neighbors in a specific level.
637pub(crate) struct HnswLevelView<'a> {
638    level: u16,
639    nodes: &'a [RwLock<GraphBuilderNode>],
640}
641
642impl<'a> HnswLevelView<'a> {
643    pub fn new(level: u16, nodes: &'a [RwLock<GraphBuilderNode>]) -> Self {
644        Self { level, nodes }
645    }
646}
647
648impl Graph for HnswLevelView<'_> {
649    fn len(&self) -> usize {
650        self.nodes.len()
651    }
652
653    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
654        let node = &self.nodes[key as usize];
655        node.read().unwrap().level_neighbors[self.level as usize].clone()
656    }
657}
658
659pub(crate) struct ImmutableHnswLevelView<'a> {
660    level: u16,
661    nodes: &'a [GraphBuilderNode],
662}
663
664impl<'a> ImmutableHnswLevelView<'a> {
665    pub fn new(level: u16, nodes: &'a [GraphBuilderNode]) -> Self {
666        Self { level, nodes }
667    }
668}
669
670impl Graph for ImmutableHnswLevelView<'_> {
671    fn len(&self) -> usize {
672        self.nodes.len()
673    }
674
675    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
676        self.nodes[key as usize].level_neighbors[self.level as usize].clone()
677    }
678}
679
680impl BorrowingGraph for ImmutableHnswLevelView<'_> {
681    fn len(&self) -> usize {
682        self.nodes.len()
683    }
684
685    fn neighbors(&self, key: u32) -> &[u32] {
686        self.nodes[key as usize].level_neighbors[self.level as usize].as_slice()
687    }
688}
689
690pub(crate) struct ImmutableHnswBottomView<'a> {
691    nodes: &'a [GraphBuilderNode],
692}
693
694impl<'a> ImmutableHnswBottomView<'a> {
695    pub fn new(nodes: &'a [GraphBuilderNode]) -> Self {
696        Self { nodes }
697    }
698}
699
700impl Graph for ImmutableHnswBottomView<'_> {
701    fn len(&self) -> usize {
702        self.nodes.len()
703    }
704
705    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
706        self.nodes[key as usize].bottom_neighbors.clone()
707    }
708}
709
710impl BorrowingGraph for ImmutableHnswBottomView<'_> {
711    fn len(&self) -> usize {
712        self.nodes.len()
713    }
714
715    fn neighbors(&self, key: u32) -> &[u32] {
716        self.nodes[key as usize].bottom_neighbors.as_slice()
717    }
718}
719
/// Parameters for a single HNSW query.
#[derive(Debug, Clone, Copy)]
pub struct HnswQueryParams {
    /// `ef` handed to the beam search; the sub-index search rejects `ef < k`.
    pub ef: usize,
    /// Optional lower distance bound for results.
    pub lower_bound: Option<f32>,
    /// Optional upper distance bound for results.
    pub upper_bound: Option<f32>,
    /// Value handed to the storage's distance calculator together with the
    /// query vector.
    pub dist_q_c: f32,
}
727
728impl From<&Query> for HnswQueryParams {
729    fn from(query: &Query) -> Self {
730        let k = query.k * query.refine_factor.unwrap_or(1) as usize;
731        Self {
732            ef: query.ef.unwrap_or(k + k / 2),
733            lower_bound: query.lower_bound,
734            upper_bound: query.upper_bound,
735            dist_q_c: query.dist_q_c,
736        }
737    }
738}
739
740impl IvfSubIndex for HNSW {
741    type BuildParams = HnswBuildParams;
742    type QueryParams = HnswQueryParams;
743
744    fn load(data: RecordBatch) -> Result<Self>
745    where
746        Self: Sized,
747    {
748        if data.num_rows() == 0 {
749            return Ok(Self::empty());
750        }
751
752        let hnsw_metadata = data
753            .schema_ref()
754            .metadata()
755            .get(HNSW_METADATA_KEY)
756            .ok_or(Error::index(format!("{} not found", HNSW_METADATA_KEY)))?;
757        let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata).map_err(|e| {
758            Error::index(format!(
759                "Failed to decode HNSW metadata: {}, json: {}",
760                e, hnsw_metadata
761            ))
762        })?;
763
764        let levels: Vec<_> = hnsw_metadata
765            .level_offsets
766            .iter()
767            .tuple_windows()
768            .map(|(start, end)| data.slice(*start, end - start))
769            .collect();
770
771        let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
772
773        let bottom_level_len = levels[0].num_rows();
774        let mut nodes = Vec::with_capacity(bottom_level_len);
775        for i in 0..bottom_level_len {
776            nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
777        }
778        for (level, batch) in levels.into_iter().enumerate() {
779            let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
780            let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
781            let distances = batch[DIST_COL].as_list::<i32>();
782
783            for ((node, neighbors), distances) in
784                ids.iter().zip(neighbors.iter()).zip(distances.iter())
785            {
786                let node = node.unwrap();
787                let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
788                let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
789
790                nodes[node as usize].level_neighbors_ranked[level] = neighbors
791                    .iter()
792                    .zip(distances.iter())
793                    .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
794                    .collect();
795                nodes[node as usize].update_from_ranked_neighbors(level as u16);
796            }
797        }
798
799        let visited_generator_queue =
800            Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
801        for _ in 0..get_num_compute_intensive_cpus() * 2 {
802            visited_generator_queue
803                .push(VisitedGenerator::new(0))
804                .unwrap();
805        }
806        let inner = HnswCore {
807            params: hnsw_metadata.params,
808            nodes: Arc::new(nodes),
809            level_count,
810            entry_point: hnsw_metadata.entry_point,
811            visited_generator_queue,
812        };
813
814        Ok(Self {
815            inner: Arc::new(inner),
816        })
817    }
818
819    fn name() -> &'static str {
820        HNSW_TYPE
821    }
822
823    fn metadata_key() -> &'static str {
824        "lance:hnsw"
825    }
826
827    /// Return the schema of the sub index
828    fn schema() -> arrow_schema::SchemaRef {
829        arrow_schema::Schema::new(vec![
830            VECTOR_ID_FIELD.clone(),
831            NEIGHBORS_FIELD.clone(),
832            DISTS_FIELD.clone(),
833        ])
834        .into()
835    }
836
837    #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
838    fn search(
839        &self,
840        query: ArrayRef,
841        k: usize,
842        params: Self::QueryParams,
843        storage: &impl VectorStore,
844        prefilter: Arc<dyn PreFilter>,
845        _metrics: &dyn MetricsCollector,
846    ) -> Result<RecordBatch> {
847        if params.ef < k {
848            return Err(Error::index(
849                "ef must be greater than or equal to k".to_string(),
850            ));
851        }
852
853        let schema = VECTOR_RESULT_SCHEMA.clone();
854        if self.is_empty() {
855            return Ok(RecordBatch::new_empty(schema));
856        }
857
858        let mut prefilter_generator = self
859            .inner
860            .visited_generator_queue
861            .pop()
862            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
863        let prefilter_bitset = if prefilter.is_empty() {
864            None
865        } else {
866            let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
867            let mut bitset = prefilter_generator.generate(storage.len());
868            for indices in indices {
869                bitset.insert(indices as u32);
870            }
871            Some(bitset)
872        };
873
874        let remained = prefilter_bitset
875            .as_ref()
876            .map(|b| b.count_ones())
877            .unwrap_or(storage.len());
878        let results = if remained < self.len() * 10 / 100 {
879            let prefilter_bitset =
880                prefilter_bitset.expect("the prefilter bitset must be set for flat search");
881            self.flat_search(storage, query, k, prefilter_bitset, &params)
882        } else {
883            self.search_basic(query, k, &params, prefilter_bitset, storage)?
884        };
885        // if the queue is full, we just don't push it back, so ignore the error here
886        let _ = self.inner.visited_generator_queue.push(prefilter_generator);
887
888        // need to unique by row ids in case of searching multivector
889        let (row_ids, dists): (Vec<_>, Vec<_>) = results
890            .into_iter()
891            .map(|r| (storage.row_id(r.id), r.dist.0))
892            .unique_by(|r| r.0)
893            .unzip();
894        let row_ids = Arc::new(UInt64Array::from(row_ids));
895        let distances = Arc::new(Float32Array::from(dists));
896
897        Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
898    }
899
900    /// Given a vector storage, containing all the data for the IVF partition, build the sub index.
901    fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
902    where
903        Self: Sized,
904    {
905        let builder = HnswBuilder::with_params(params, storage);
906
907        log::debug!(
908            "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
909            storage.len(),
910            builder.params.max_level,
911            builder.params.m,
912            builder.params.ef_construction,
913            storage.distance_type(),
914        );
915
916        if storage.is_empty() {
917            return Ok(builder.finish());
918        }
919
920        let len = storage.len();
921        builder.level_count[0].fetch_add(1, Ordering::Relaxed);
922        (1..len).into_par_iter().for_each_init(
923            || VisitedGenerator::new(len),
924            |visited_generator, node| {
925                builder.insert(node as u32, visited_generator, storage);
926            },
927        );
928
929        assert_eq!(builder.level_count[0].load(Ordering::Relaxed), len);
930        Ok(builder.finish())
931    }
932
933    fn remap(
934        &self,
935        _mapping: &HashMap<u64, Option<u64>>, // we don't need the mapping here because we rebuild the graph from remapped storage
936        store: &impl VectorStore,
937    ) -> Result<Self> {
938        // We can't simply remap the row ids in the graph because the vectors are changed,
939        // so the graph needs to be rebuilt.
940        Self::index_vectors(store, self.inner.params.clone())
941    }
942
943    /// Encode the sub index into a record batch
944    fn to_batch(&self) -> Result<RecordBatch> {
945        let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
946        let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
947        let mut distances_builder =
948            ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
949        let mut batches = Vec::with_capacity(self.max_level() as usize);
950        for level in 0..self.max_level() {
951            let level = level as usize;
952            for (id, node) in self.inner.nodes.iter().enumerate() {
953                if level >= node.level_neighbors.len() {
954                    continue;
955                }
956                let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
957                let distances = node.level_neighbors_ranked[level]
958                    .iter()
959                    .map(|n| Some(n.dist.0));
960                vector_id_builder.append_value(id as u32);
961                neighbors_builder.append_value(neighbors);
962                distances_builder.append_value(distances);
963            }
964
965            let batch = RecordBatch::try_new(
966                Self::schema(),
967                vec![
968                    Arc::new(vector_id_builder.finish()),
969                    Arc::new(neighbors_builder.finish()),
970                    Arc::new(distances_builder.finish()),
971                ],
972            )?;
973            batches.push(batch);
974        }
975
976        let metadata = self.metadata();
977        let metadata = serde_json::to_string(&metadata)?;
978        let schema = Self::schema()
979            .as_ref()
980            .clone()
981            .with_metadata(HashMap::from_iter(vec![(
982                HNSW_METADATA_KEY.to_string(),
983                metadata,
984            )]));
985        let batch = concat_batches(&Self::schema(), batches.iter())?;
986        let batch = batch.with_schema(Arc::new(schema))?;
987        Ok(batch)
988    }
989}
990
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{FixedSizeListArray, UInt8Array};
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::{FlatBinStorage, FlatFloatStorage},
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            HNSW, VECTOR_ID_FIELD,
            builder::{HnswBuildParams, HnswQueryParams},
        },
    };

    /// Query parameters shared by the round-trip tests.
    fn round_trip_query_params() -> HnswQueryParams {
        HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        }
    }

    /// Writes `hnsw` to an in-memory object store at `location`, reads the file
    /// back and loads an index from the stored batch.
    async fn write_then_load(hnsw: &HNSW, location: &str) -> HNSW {
        let object_store = ObjectStore::memory();
        let path = Path::from(location);
        let object_writer = object_store.create(&path).await.unwrap();

        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut file_writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            lance_schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();

        // The index metadata travels in the batch schema and is stored with the file.
        let batch = hnsw.to_batch().unwrap();
        let metadata = batch.schema_ref().metadata().clone();
        file_writer.write_record_batch(batch).await.unwrap();
        file_writer.finish_with_metadata(&metadata).await.unwrap();

        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let stored = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        HNSW::load(stored).unwrap()
    }

    /// A float/L2 index must return the same search results before and after a
    /// write/load round trip.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;

        let values = generate_random_array(TOTAL * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));
        let build_params = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50);
        let built = HNSW::index_vectors(store.as_ref(), build_params).unwrap();

        let loaded = write_then_load(&built, "test_builder_write_load").await;

        let query = fsl.value(0);
        let k = 10;
        let params = round_trip_query_params();
        let expected = built
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let actual = loaded
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }

    /// A binary/Hamming index must return the same search results before and
    /// after a write/load round trip.
    #[tokio::test]
    async fn test_builder_write_load_binary_hamming() {
        const DIM: usize = 8;
        const TOTAL: usize = 256;
        const NUM_EDGES: usize = 20;

        let values = UInt8Array::from_iter_values((0..TOTAL * DIM).map(|v| (v % 16) as u8));
        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
        let store = Arc::new(FlatBinStorage::new(fsl.clone(), DistanceType::Hamming));
        let built = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50)
            .build(Arc::new(fsl.clone()), DistanceType::Hamming)
            .await
            .unwrap();

        let loaded = write_then_load(&built, "test_builder_write_load_binary_hamming").await;

        let query = fsl.value(0);
        let k = 10;
        let params = round_trip_query_params();
        let expected = built
            .search_basic(query.clone(), k, &params, None, store.as_ref())
            .unwrap();
        let actual = loaded
            .search_basic(query, k, &params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }
}