lance_index/vector/hnsw/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Builder of Hnsw Graph.
5
6use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use snafu::location;
18use std::cmp::min;
19use std::collections::{BinaryHeap, HashMap, VecDeque};
20use std::fmt::Debug;
21use std::iter;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::sync::RwLock;
25use tracing::instrument;
26
27use lance_core::{Error, Result};
28use rand::{rng, Rng};
29use serde::{Deserialize, Serialize};
30
31use super::super::graph::beam_search;
32use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD};
33use crate::metrics::MetricsCollector;
34use crate::prefilter::PreFilter;
35use crate::vector::flat::storage::FlatFloatStorage;
36use crate::vector::graph::builder::GraphBuilderNode;
37use crate::vector::graph::{greedy_search, Visited};
38use crate::vector::graph::{
39    Graph, OrderedFloat, OrderedNode, VisitedGenerator, DISTS_FIELD, NEIGHBORS_COL, NEIGHBORS_FIELD,
40};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{Query, DIST_COL, VECTOR_RESULT_SCHEMA};
44
/// Schema-metadata key under which the JSON-encoded [`HnswMetadata`] of a graph is stored.
pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
47/// Parameters of building HNSW index
48#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
49pub struct HnswBuildParams {
50    /// max level ofm
51    pub max_level: u16,
52
53    /// number of connections to establish while inserting new element
54    pub m: usize,
55
56    /// size of the dynamic list for the candidates
57    pub ef_construction: usize,
58
59    /// number of vectors ahead to prefetch while building the graph
60    pub prefetch_distance: Option<usize>,
61}
62
63impl Default for HnswBuildParams {
64    fn default() -> Self {
65        Self {
66            max_level: 7,
67            m: 20,
68            ef_construction: 150,
69            prefetch_distance: Some(2),
70        }
71    }
72}
73
74impl HnswBuildParams {
75    /// The maximum level of the graph.
76    /// The default value is `8`.
77    pub fn max_level(mut self, max_level: u16) -> Self {
78        self.max_level = max_level;
79        self
80    }
81
82    /// The number of connections to establish while inserting new element
83    /// The default value is `30`.
84    pub fn num_edges(mut self, m: usize) -> Self {
85        self.m = m;
86        self
87    }
88
89    /// Number of candidates to be considered when searching for the nearest neighbors
90    /// during the construction of the graph.
91    ///
92    /// The default value is `100`.
93    pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94        self.ef_construction = ef_construction;
95        self
96    }
97
98    /// Build the HNSW index from the given data.
99    ///
100    /// # Parameters
101    /// - `data`: A FixedSizeList to build the HNSW.
102    /// - `distance_type`: The distance type to use.
103    pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104        let vec_store = Arc::new(FlatFloatStorage::new(
105            data.as_fixed_size_list().clone(),
106            distance_type,
107        ));
108        HNSW::index_vectors(vec_store.as_ref(), self)
109    }
110}
111
112/// Build a HNSW graph.
113///
114/// Currently, the HNSW graph is fully built in memory.
115///
116/// During the build, the graph is built layer by layer.
117///
118/// Each node in the graph has a global ID which is the index on the base layer.
119#[derive(Clone, DeepSizeOf)]
120pub struct HNSW {
121    inner: Arc<HnswBuilder>,
122}
123
124impl Debug for HNSW {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
127    }
128}
129
130impl HNSW {
131    pub fn empty() -> Self {
132        Self {
133            inner: Arc::new(HnswBuilder {
134                params: HnswBuildParams::default(),
135                nodes: Arc::new(Vec::new()),
136                level_count: Vec::new(),
137                entry_point: 0,
138                visited_generator_queue: Arc::new(ArrayQueue::new(1)),
139            }),
140        }
141    }
142
143    pub fn len(&self) -> usize {
144        self.inner.nodes.len()
145    }
146
147    pub fn is_empty(&self) -> bool {
148        self.len() == 0
149    }
150
151    pub fn max_level(&self) -> u16 {
152        self.inner.max_level()
153    }
154
155    pub fn num_nodes(&self, level: usize) -> usize {
156        self.inner.num_nodes(level)
157    }
158
159    pub fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
160        self.inner.nodes()
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    pub fn search_inner(
165        &self,
166        query: ArrayRef,
167        k: usize,
168        params: &HnswQueryParams,
169        bitset: Option<Visited>,
170        visited_generator: &mut VisitedGenerator,
171        storage: &impl VectorStore,
172        prefetch_distance: Option<usize>,
173    ) -> Result<Vec<OrderedNode>> {
174        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
175        let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
176        let nodes = &self.nodes();
177        for level in (0..self.max_level()).rev() {
178            let cur_level = HnswLevelView::new(level, nodes);
179            ep = greedy_search(
180                &cur_level,
181                ep,
182                &dist_calc,
183                self.inner.params.prefetch_distance,
184            );
185        }
186
187        let bottom_level = HnswBottomView::new(nodes);
188        let mut visited = visited_generator.generate(storage.len());
189        Ok(beam_search(
190            &bottom_level,
191            &ep,
192            params,
193            &dist_calc,
194            bitset.as_ref(),
195            prefetch_distance,
196            &mut visited,
197        )
198        .into_iter()
199        .take(k)
200        .collect())
201    }
202
203    #[instrument(level = "debug", skip(self, query, bitset, storage))]
204    pub fn search_basic(
205        &self,
206        query: ArrayRef,
207        k: usize,
208        params: &HnswQueryParams,
209        bitset: Option<Visited>,
210        storage: &impl VectorStore,
211    ) -> Result<Vec<OrderedNode>> {
212        let mut visited_generator = self
213            .inner
214            .visited_generator_queue
215            .pop()
216            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
217        let result = self.search_inner(
218            query,
219            k,
220            params,
221            bitset,
222            &mut visited_generator,
223            storage,
224            Some(2),
225        );
226
227        match self.inner.visited_generator_queue.push(visited_generator) {
228            Ok(_) => {}
229            Err(_) => {
230                log::warn!("visited_generator_queue is full");
231            }
232        }
233
234        result
235    }
236
237    #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
238    fn flat_search(
239        &self,
240        storage: &impl VectorStore,
241        query: ArrayRef,
242        k: usize,
243        prefilter_bitset: Visited,
244        params: &HnswQueryParams,
245    ) -> Vec<OrderedNode> {
246        let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
247        let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
248
249        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
250        let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
251
252        match self.inner.params.prefetch_distance {
253            Some(ahead) if ahead > 0 => {
254                let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
255                let mut buffer = VecDeque::with_capacity(ahead + 1);
256                for _ in 0..=ahead {
257                    if let Some(id) = ids_iter.next() {
258                        buffer.push_back(id);
259                    } else {
260                        break;
261                    }
262                }
263
264                while let Some(node_id) = buffer.pop_front() {
265                    if let Some(&prefetch_id) = buffer.get(ahead - 1) {
266                        dist_calc.prefetch(prefetch_id);
267                    }
268                    if let Some(next) = ids_iter.next() {
269                        buffer.push_back(next);
270                    }
271
272                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
273                    if dist <= lower_bound || dist > upper_bound {
274                        continue;
275                    }
276                    if heap.len() < k {
277                        heap.push((dist, node_id).into());
278                    } else if dist < heap.peek().unwrap().dist {
279                        heap.pop();
280                        heap.push((dist, node_id).into());
281                    }
282                }
283            }
284            _ => {
285                for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
286                    let dist: OrderedFloat = dist_calc.distance(node_id).into();
287                    if dist <= lower_bound || dist > upper_bound {
288                        continue;
289                    }
290                    if heap.len() < k {
291                        heap.push((dist, node_id).into());
292                    } else if dist < heap.peek().unwrap().dist {
293                        heap.pop();
294                        heap.push((dist, node_id).into());
295                    }
296                }
297            }
298        };
299        heap.into_sorted_vec()
300    }
301
302    /// Returns the metadata of this [`HNSW`].
303    pub fn metadata(&self) -> HnswMetadata {
304        // calculate the offsets of each level,
305        // start from 0
306        let level_offsets = self
307            .inner
308            .level_count
309            .iter()
310            .chain(iter::once(&AtomicUsize::new(0)))
311            .scan(0, |state, x| {
312                let start = *state;
313                *state += x.load(Ordering::Relaxed);
314                Some(start)
315            })
316            .collect();
317
318        HnswMetadata {
319            entry_point: self.inner.entry_point,
320            params: self.inner.params.clone(),
321            level_offsets,
322        }
323    }
324}
325
326struct HnswBuilder {
327    params: HnswBuildParams,
328
329    nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
330    level_count: Vec<AtomicUsize>,
331
332    entry_point: u32,
333
334    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
335}
336
337impl DeepSizeOf for HnswBuilder {
338    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
339        self.params.deep_size_of_children(context)
340            + self.nodes.deep_size_of_children(context)
341            + self.level_count.deep_size_of_children(context)
342        // Skipping the visited_generator_queue
343    }
344}
345
346impl HnswBuilder {
347    fn max_level(&self) -> u16 {
348        self.params.max_level
349    }
350
351    fn num_nodes(&self, level: usize) -> usize {
352        self.level_count[level].load(Ordering::Relaxed)
353    }
354
355    fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
356        self.nodes.clone()
357    }
358
359    /// Create a new [`HNSWBuilder`] with prepared params and in memory vector storage.
360    pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
361        let len = storage.len();
362        let max_level = params.max_level;
363
364        let level_count = (0..max_level)
365            .map(|_| AtomicUsize::new(0))
366            .collect::<Vec<_>>();
367
368        let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
369        for _ in 0..get_num_compute_intensive_cpus() {
370            visited_generator_queue
371                .push(VisitedGenerator::new(0))
372                .unwrap();
373        }
374        let mut builder = Self {
375            params,
376            nodes: Arc::new(Vec::new()),
377            level_count,
378            entry_point: 0,
379            visited_generator_queue,
380        };
381
382        if storage.is_empty() {
383            return builder;
384        }
385
386        let mut nodes = Vec::with_capacity(len);
387        {
388            if len > 0 {
389                nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
390            }
391            for i in 1..len {
392                nodes.push(RwLock::new(GraphBuilderNode::new(
393                    i as u32,
394                    builder.random_level() as usize + 1,
395                )));
396            }
397        }
398        builder.nodes = Arc::new(nodes);
399
400        builder
401    }
402
403    /// New node's level
404    ///
405    /// See paper `Algorithm 1`
406    fn random_level(&self) -> u16 {
407        let mut rng = rng();
408        let ml = 1.0 / (self.params.m as f32).ln();
409        min(
410            (-rng.random::<f32>().ln() * ml) as u16,
411            self.params.max_level - 1,
412        )
413    }
414
415    /// Insert one node.
416    fn insert(
417        &self,
418        node: u32,
419        visited_generator: &mut VisitedGenerator,
420        storage: &impl VectorStore,
421    ) {
422        let nodes = &self.nodes;
423        let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
424        let dist_calc = storage.dist_calculator_from_id(node);
425        let mut ep = OrderedNode::new(
426            self.entry_point,
427            dist_calc.distance(self.entry_point).into(),
428        );
429
430        //
431        // Search for entry point in paper.
432        // ```
433        //   for l_c in (L..l+1) {
434        //     W = Search-Layer(q, ep, ef=1, l_c)
435        //    ep = Select-Neighbors(W, 1)
436        //  }
437        // ```
438        for level in (target_level + 1..self.params.max_level).rev() {
439            let cur_level = HnswLevelView::new(level, nodes);
440            ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
441        }
442
443        let mut pruned_neighbors_per_level: Vec<Vec<_>> =
444            vec![Vec::new(); (target_level + 1) as usize];
445        {
446            let mut current_node = nodes[node as usize].write().unwrap();
447            for level in (0..=target_level).rev() {
448                self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
449
450                let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
451                for neighbor in &neighbors {
452                    current_node.add_neighbor(neighbor.id, neighbor.dist, level);
453                }
454                self.prune(storage, &mut current_node, level);
455                pruned_neighbors_per_level[level as usize]
456                    .clone_from(&current_node.level_neighbors_ranked[level as usize]);
457
458                ep = neighbors[0].clone();
459            }
460        }
461        for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
462            let _: Vec<_> = pruned_neighbors
463                .iter()
464                .map(|unpruned_edge| {
465                    let level = level as u16;
466                    let m_max = match level {
467                        0 => self.params.m * 2,
468                        _ => self.params.m,
469                    };
470                    if unpruned_edge.dist
471                        < nodes[unpruned_edge.id as usize]
472                            .read()
473                            .unwrap()
474                            .cutoff(level, m_max)
475                    {
476                        let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
477                        chosen_node.add_neighbor(node, unpruned_edge.dist, level);
478                        self.prune(storage, &mut chosen_node, level);
479                    }
480                })
481                .collect();
482        }
483    }
484
485    fn search_level(
486        &self,
487        ep: &OrderedNode,
488        level: u16,
489        dist_calc: &impl DistCalculator,
490        nodes: &Vec<RwLock<GraphBuilderNode>>,
491        visited_generator: &mut VisitedGenerator,
492    ) -> Vec<OrderedNode> {
493        let cur_level = HnswLevelView::new(level, nodes);
494        let mut visited = visited_generator.generate(nodes.len());
495        beam_search(
496            &cur_level,
497            ep,
498            &HnswQueryParams {
499                ef: self.params.ef_construction,
500                lower_bound: None,
501                upper_bound: None,
502                dist_q_c: 0.0,
503            },
504            dist_calc,
505            None,
506            self.params.prefetch_distance,
507            &mut visited,
508        )
509    }
510
511    fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
512        let m_max = match level {
513            0 => self.params.m * 2,
514            _ => self.params.m,
515        };
516
517        let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
518        let level_neighbors = neighbors_ranked.clone();
519        if level_neighbors.len() <= m_max {
520            builder_node.update_from_ranked_neighbors(level);
521            return;
522            //return level_neighbors;
523        }
524
525        *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
526        builder_node.update_from_ranked_neighbors(level);
527    }
528}
529
530// View of a level in HNSW graph.
531// This is used to iterate over neighbors in a specific level.
532pub(crate) struct HnswLevelView<'a> {
533    level: u16,
534    nodes: &'a Vec<RwLock<GraphBuilderNode>>,
535}
536
537impl<'a> HnswLevelView<'a> {
538    pub fn new(level: u16, nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
539        Self { level, nodes }
540    }
541}
542
543impl Graph for HnswLevelView<'_> {
544    fn len(&self) -> usize {
545        self.nodes.len()
546    }
547
548    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
549        let node = &self.nodes[key as usize];
550        node.read().unwrap().level_neighbors[self.level as usize].clone()
551    }
552}
553
554pub(crate) struct HnswBottomView<'a> {
555    nodes: &'a Vec<RwLock<GraphBuilderNode>>,
556}
557
558impl<'a> HnswBottomView<'a> {
559    pub fn new(nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
560        Self { nodes }
561    }
562}
563
564impl Graph for HnswBottomView<'_> {
565    fn len(&self) -> usize {
566        self.nodes.len()
567    }
568
569    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
570        let node = &self.nodes[key as usize];
571        node.read().unwrap().bottom_neighbors.clone()
572    }
573}
574
575#[derive(Debug, Clone, Copy)]
576pub struct HnswQueryParams {
577    pub ef: usize,
578    pub lower_bound: Option<f32>,
579    pub upper_bound: Option<f32>,
580    pub dist_q_c: f32,
581}
582
583impl From<&Query> for HnswQueryParams {
584    fn from(query: &Query) -> Self {
585        let k = query.k * query.refine_factor.unwrap_or(1) as usize;
586        Self {
587            ef: query.ef.unwrap_or(k + k / 2),
588            lower_bound: query.lower_bound,
589            upper_bound: query.upper_bound,
590            dist_q_c: query.dist_q_c,
591        }
592    }
593}
594
595impl IvfSubIndex for HNSW {
596    type BuildParams = HnswBuildParams;
597    type QueryParams = HnswQueryParams;
598
599    fn load(data: RecordBatch) -> Result<Self>
600    where
601        Self: Sized,
602    {
603        if data.num_rows() == 0 {
604            return Ok(Self::empty());
605        }
606
607        let hnsw_metadata =
608            data.schema_ref()
609                .metadata()
610                .get(HNSW_METADATA_KEY)
611                .ok_or(Error::Index {
612                    message: format!("{} not found", HNSW_METADATA_KEY),
613                    location: location!(),
614                })?;
615        let hnsw_metadata: HnswMetadata =
616            serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index {
617                message: format!(
618                    "Failed to decode HNSW metadata: {}, json: {}",
619                    e, hnsw_metadata
620                ),
621                location: location!(),
622            })?;
623
624        let levels: Vec<_> = hnsw_metadata
625            .level_offsets
626            .iter()
627            .tuple_windows()
628            .map(|(start, end)| data.slice(*start, end - start))
629            .collect();
630
631        let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
632
633        let bottom_level_len = levels[0].num_rows();
634        let mut nodes = Vec::with_capacity(bottom_level_len);
635        for i in 0..bottom_level_len {
636            nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
637        }
638        for (level, batch) in levels.into_iter().enumerate() {
639            let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
640            let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
641            let distances = batch[DIST_COL].as_list::<i32>();
642
643            for ((node, neighbors), distances) in
644                ids.iter().zip(neighbors.iter()).zip(distances.iter())
645            {
646                let node = node.unwrap();
647                let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
648                let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
649
650                nodes[node as usize].level_neighbors_ranked[level] = neighbors
651                    .iter()
652                    .zip(distances.iter())
653                    .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
654                    .collect();
655                nodes[node as usize].update_from_ranked_neighbors(level as u16);
656            }
657        }
658
659        let visited_generator_queue =
660            Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
661        for _ in 0..get_num_compute_intensive_cpus() * 2 {
662            visited_generator_queue
663                .push(VisitedGenerator::new(0))
664                .unwrap();
665        }
666        let inner = HnswBuilder {
667            params: hnsw_metadata.params,
668            nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
669            level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
670            entry_point: hnsw_metadata.entry_point,
671            visited_generator_queue,
672        };
673
674        Ok(Self {
675            inner: Arc::new(inner),
676        })
677    }
678
679    fn name() -> &'static str {
680        HNSW_TYPE
681    }
682
683    fn metadata_key() -> &'static str {
684        "lance:hnsw"
685    }
686
687    /// Return the schema of the sub index
688    fn schema() -> arrow_schema::SchemaRef {
689        arrow_schema::Schema::new(vec![
690            VECTOR_ID_FIELD.clone(),
691            NEIGHBORS_FIELD.clone(),
692            DISTS_FIELD.clone(),
693        ])
694        .into()
695    }
696
697    #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
698    fn search(
699        &self,
700        query: ArrayRef,
701        k: usize,
702        params: Self::QueryParams,
703        storage: &impl VectorStore,
704        prefilter: Arc<dyn PreFilter>,
705        _metrics: &dyn MetricsCollector,
706    ) -> Result<RecordBatch> {
707        if params.ef < k {
708            return Err(Error::Index {
709                message: "ef must be greater than or equal to k".to_string(),
710                location: location!(),
711            });
712        }
713
714        let schema = VECTOR_RESULT_SCHEMA.clone();
715        if self.is_empty() {
716            return Ok(RecordBatch::new_empty(schema));
717        }
718
719        let mut prefilter_generator = self
720            .inner
721            .visited_generator_queue
722            .pop()
723            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
724        let prefilter_bitset = if prefilter.is_empty() {
725            None
726        } else {
727            let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
728            let mut bitset = prefilter_generator.generate(storage.len());
729            for indices in indices {
730                bitset.insert(indices as u32);
731            }
732            Some(bitset)
733        };
734
735        let remained = prefilter_bitset
736            .as_ref()
737            .map(|b| b.count_ones())
738            .unwrap_or(storage.len());
739        let results = if remained < self.len() * 10 / 100 {
740            let prefilter_bitset =
741                prefilter_bitset.expect("the prefilter bitset must be set for flat search");
742            self.flat_search(storage, query, k, prefilter_bitset, &params)
743        } else {
744            self.search_basic(query, k, &params, prefilter_bitset, storage)?
745        };
746        // if the queue is full, we just don't push it back, so ignore the error here
747        let _ = self.inner.visited_generator_queue.push(prefilter_generator);
748
749        // need to unique by row ids in case of searching multivector
750        let (row_ids, dists): (Vec<_>, Vec<_>) = results
751            .into_iter()
752            .map(|r| (storage.row_id(r.id), r.dist.0))
753            .unique_by(|r| r.0)
754            .unzip();
755        let row_ids = Arc::new(UInt64Array::from(row_ids));
756        let distances = Arc::new(Float32Array::from(dists));
757
758        Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
759    }
760
761    /// Given a vector storage, containing all the data for the IVF partition, build the sub index.
762    fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
763    where
764        Self: Sized,
765    {
766        let inner = HnswBuilder::with_params(params, storage);
767        let hnsw = Self {
768            inner: Arc::new(inner),
769        };
770
771        log::debug!(
772            "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
773            storage.len(),
774            hnsw.inner.params.max_level,
775            hnsw.inner.params.m,
776            hnsw.inner.params.ef_construction,
777            storage.distance_type(),
778        );
779
780        if storage.is_empty() {
781            return Ok(hnsw);
782        }
783
784        let len = storage.len();
785        hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
786        (1..len).into_par_iter().for_each_init(
787            || VisitedGenerator::new(len),
788            |visited_generator, node| {
789                hnsw.inner.insert(node as u32, visited_generator, storage);
790            },
791        );
792
793        assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
794        Ok(hnsw)
795    }
796
797    fn remap(
798        &self,
799        _mapping: &HashMap<u64, Option<u64>>, // we don't need the mapping here because we rebuild the graph from remapped storage
800        store: &impl VectorStore,
801    ) -> Result<Self> {
802        // We can't simply remap the row ids in the graph because the vectors are changed,
803        // so the graph needs to be rebuilt.
804        Self::index_vectors(store, self.inner.params.clone())
805    }
806
807    /// Encode the sub index into a record batch
808    fn to_batch(&self) -> Result<RecordBatch> {
809        let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
810        let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
811        let mut distances_builder =
812            ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
813        let mut batches = Vec::with_capacity(self.max_level() as usize);
814        for level in 0..self.max_level() {
815            let level = level as usize;
816            for (id, node) in self.inner.nodes.iter().enumerate() {
817                let node = node.read().unwrap();
818                if level >= node.level_neighbors.len() {
819                    continue;
820                }
821                let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
822                let distances = node.level_neighbors_ranked[level]
823                    .iter()
824                    .map(|n| Some(n.dist.0));
825                vector_id_builder.append_value(id as u32);
826                neighbors_builder.append_value(neighbors);
827                distances_builder.append_value(distances);
828            }
829
830            let batch = RecordBatch::try_new(
831                Self::schema(),
832                vec![
833                    Arc::new(vector_id_builder.finish()),
834                    Arc::new(neighbors_builder.finish()),
835                    Arc::new(distances_builder.finish()),
836                ],
837            )?;
838            batches.push(batch);
839        }
840
841        let metadata = self.metadata();
842        let metadata = serde_json::to_string(&metadata)?;
843        let schema = Self::schema()
844            .as_ref()
845            .clone()
846            .with_metadata(HashMap::from_iter(vec![(
847                HNSW_METADATA_KEY.to_string(),
848                metadata,
849            )]));
850        let batch = concat_batches(&Self::schema(), batches.iter())?;
851        let batch = batch.with_schema(Arc::new(schema))?;
852        Ok(batch)
853    }
854}
855
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::previous::{
        reader::FileReader as PreviousFileReader,
        writer::{
            FileWriter as PreviousFileWriter, FileWriterOptions as PreviousFileWriterOptions,
        },
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::FlatFloatStorage,
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            builder::{HnswBuildParams, HnswQueryParams},
            HNSW, VECTOR_ID_FIELD,
        },
    };

    /// Builds a graph, round-trips it through the legacy file format, and checks
    /// that the reloaded graph answers a query exactly like the original one.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;

        let values = generate_random_array(TOTAL * DIM);
        let vectors = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(vectors.clone(), DistanceType::L2));
        let build_params = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50);
        let hnsw = HNSW::index_vectors(store.as_ref(), build_params).unwrap();

        // Write the encoded graph with the legacy file writer.
        let object_store = ObjectStore::memory();
        let path = Path::from("test_builder_write_load");
        let object_writer = object_store.create(&path).await.unwrap();
        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut file_writer = PreviousFileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            lance_schema,
            &PreviousFileWriterOptions::default(),
        )
        .unwrap();
        let graph_batch = hnsw.to_batch().unwrap();
        let graph_metadata = graph_batch.schema_ref().metadata().clone();
        file_writer.write_record_batch(graph_batch).await.unwrap();
        file_writer
            .finish_with_metadata(&graph_metadata)
            .await
            .unwrap();

        // Read the batch back and rebuild the graph from it.
        let reader = PreviousFileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let stored = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        let reloaded = HNSW::load(stored).unwrap();

        let query = vectors.value(0);
        let k = 10;
        let query_params = HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        };
        let expected = hnsw
            .search_basic(query.clone(), k, &query_params, None, store.as_ref())
            .unwrap();
        let actual = reloaded
            .search_basic(query, k, &query_params, None, store.as_ref())
            .unwrap();
        assert_eq!(expected, actual);
    }
}