lance_index/vector/hnsw/
builder.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Builder of Hnsw Graph.
5
6use arrow::array::{AsArray, ListBuilder, UInt32Builder};
7use arrow::compute::concat_batches;
8use arrow::datatypes::{Float32Type, UInt32Type};
9use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
10use crossbeam_queue::ArrayQueue;
11use deepsize::DeepSizeOf;
12use itertools::Itertools;
13
14use lance_core::utils::tokio::get_num_compute_intensive_cpus;
15use lance_linalg::distance::DistanceType;
16use rayon::prelude::*;
17use snafu::location;
18use std::cmp::min;
19use std::collections::{BinaryHeap, HashMap};
20use std::fmt::Debug;
21use std::iter;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::sync::Arc;
24use std::sync::RwLock;
25use tracing::instrument;
26
27use lance_core::{Error, Result};
28use rand::{rng, Rng};
29use serde::{Deserialize, Serialize};
30
31use super::super::graph::beam_search;
32use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD};
33use crate::metrics::MetricsCollector;
34use crate::prefilter::PreFilter;
35use crate::vector::flat::storage::FlatFloatStorage;
36use crate::vector::graph::builder::GraphBuilderNode;
37use crate::vector::graph::{greedy_search, Visited};
38use crate::vector::graph::{
39    Graph, OrderedFloat, OrderedNode, VisitedGenerator, DISTS_FIELD, NEIGHBORS_COL, NEIGHBORS_FIELD,
40};
41use crate::vector::storage::{DistCalculator, VectorStore};
42use crate::vector::v3::subindex::IvfSubIndex;
43use crate::vector::{Query, DIST_COL, VECTOR_RESULT_SCHEMA};
44
45pub const HNSW_METADATA_KEY: &str = "lance:hnsw";
46
47/// Parameters of building HNSW index
48#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
49pub struct HnswBuildParams {
50    /// max level ofm
51    pub max_level: u16,
52
53    /// number of connections to establish while inserting new element
54    pub m: usize,
55
56    /// size of the dynamic list for the candidates
57    pub ef_construction: usize,
58
59    /// number of vectors ahead to prefetch while building the graph
60    pub prefetch_distance: Option<usize>,
61}
62
63impl Default for HnswBuildParams {
64    fn default() -> Self {
65        Self {
66            max_level: 7,
67            m: 20,
68            ef_construction: 150,
69            prefetch_distance: Some(2),
70        }
71    }
72}
73
74impl HnswBuildParams {
75    /// The maximum level of the graph.
76    /// The default value is `8`.
77    pub fn max_level(mut self, max_level: u16) -> Self {
78        self.max_level = max_level;
79        self
80    }
81
82    /// The number of connections to establish while inserting new element
83    /// The default value is `30`.
84    pub fn num_edges(mut self, m: usize) -> Self {
85        self.m = m;
86        self
87    }
88
89    /// Number of candidates to be considered when searching for the nearest neighbors
90    /// during the construction of the graph.
91    ///
92    /// The default value is `100`.
93    pub fn ef_construction(mut self, ef_construction: usize) -> Self {
94        self.ef_construction = ef_construction;
95        self
96    }
97
98    /// Build the HNSW index from the given data.
99    ///
100    /// # Parameters
101    /// - `data`: A FixedSizeList to build the HNSW.
102    /// - `distance_type`: The distance type to use.
103    pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
104        let vec_store = Arc::new(FlatFloatStorage::new(
105            data.as_fixed_size_list().clone(),
106            distance_type,
107        ));
108        HNSW::index_vectors(vec_store.as_ref(), self)
109    }
110}
111
112/// Build a HNSW graph.
113///
114/// Currently, the HNSW graph is fully built in memory.
115///
116/// During the build, the graph is built layer by layer.
117///
118/// Each node in the graph has a global ID which is the index on the base layer.
119#[derive(Clone, DeepSizeOf)]
120pub struct HNSW {
121    inner: Arc<HnswBuilder>,
122}
123
124impl Debug for HNSW {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "HNSW(max_layers: {})", self.inner.max_level() as usize,)
127    }
128}
129
130impl HNSW {
131    pub fn empty() -> Self {
132        Self {
133            inner: Arc::new(HnswBuilder {
134                params: HnswBuildParams::default(),
135                nodes: Arc::new(Vec::new()),
136                level_count: Vec::new(),
137                entry_point: 0,
138                visited_generator_queue: Arc::new(ArrayQueue::new(1)),
139            }),
140        }
141    }
142
143    pub fn len(&self) -> usize {
144        self.inner.nodes.len()
145    }
146
147    pub fn is_empty(&self) -> bool {
148        self.len() == 0
149    }
150
151    pub fn max_level(&self) -> u16 {
152        self.inner.max_level()
153    }
154
155    pub fn num_nodes(&self, level: usize) -> usize {
156        self.inner.num_nodes(level)
157    }
158
159    pub fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
160        self.inner.nodes()
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    pub fn search_inner(
165        &self,
166        query: ArrayRef,
167        k: usize,
168        params: &HnswQueryParams,
169        bitset: Option<Visited>,
170        visited_generator: &mut VisitedGenerator,
171        storage: &impl VectorStore,
172        prefetch_distance: Option<usize>,
173    ) -> Result<Vec<OrderedNode>> {
174        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
175        let mut ep = OrderedNode::new(0, dist_calc.distance(0).into());
176        let nodes = &self.nodes();
177        for level in (0..self.max_level()).rev() {
178            let cur_level = HnswLevelView::new(level, nodes);
179            ep = greedy_search(
180                &cur_level,
181                ep,
182                &dist_calc,
183                self.inner.params.prefetch_distance,
184            );
185        }
186
187        let bottom_level = HnswBottomView::new(nodes);
188        let mut visited = visited_generator.generate(storage.len());
189        Ok(beam_search(
190            &bottom_level,
191            &ep,
192            params,
193            &dist_calc,
194            bitset.as_ref(),
195            prefetch_distance,
196            &mut visited,
197        )
198        .into_iter()
199        .take(k)
200        .collect())
201    }
202
203    #[instrument(level = "debug", skip(self, query, bitset, storage))]
204    pub fn search_basic(
205        &self,
206        query: ArrayRef,
207        k: usize,
208        params: &HnswQueryParams,
209        bitset: Option<Visited>,
210        storage: &impl VectorStore,
211    ) -> Result<Vec<OrderedNode>> {
212        let mut visited_generator = self
213            .inner
214            .visited_generator_queue
215            .pop()
216            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
217        let result = self.search_inner(
218            query,
219            k,
220            params,
221            bitset,
222            &mut visited_generator,
223            storage,
224            Some(2),
225        );
226
227        match self.inner.visited_generator_queue.push(visited_generator) {
228            Ok(_) => {}
229            Err(_) => {
230                log::warn!("visited_generator_queue is full");
231            }
232        }
233
234        result
235    }
236
237    #[instrument(level = "debug", skip(self, storage, query, prefilter_bitset))]
238    fn flat_search(
239        &self,
240        storage: &impl VectorStore,
241        query: ArrayRef,
242        k: usize,
243        prefilter_bitset: Visited,
244        params: &HnswQueryParams,
245    ) -> Vec<OrderedNode> {
246        let node_ids = storage
247            .row_ids()
248            .enumerate()
249            .filter_map(|(node_id, _)| {
250                prefilter_bitset
251                    .contains(node_id as u32)
252                    .then_some(node_id as u32)
253            })
254            .collect_vec();
255
256        let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
257        let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();
258
259        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
260        let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
261        for i in 0..node_ids.len() {
262            if let Some(ahead) = self.inner.params.prefetch_distance {
263                if i + ahead < node_ids.len() {
264                    dist_calc.prefetch(node_ids[i + ahead]);
265                }
266            }
267            let node_id = node_ids[i];
268            let dist: OrderedFloat = dist_calc.distance(node_id).into();
269            if dist <= lower_bound || dist > upper_bound {
270                continue;
271            }
272            if heap.len() < k {
273                heap.push((dist, node_id).into());
274            } else if dist < heap.peek().unwrap().dist {
275                heap.pop();
276                heap.push((dist, node_id).into());
277            }
278        }
279        heap.into_sorted_vec()
280    }
281
282    /// Returns the metadata of this [`HNSW`].
283    pub fn metadata(&self) -> HnswMetadata {
284        // calculate the offsets of each level,
285        // start from 0
286        let level_offsets = self
287            .inner
288            .level_count
289            .iter()
290            .chain(iter::once(&AtomicUsize::new(0)))
291            .scan(0, |state, x| {
292                let start = *state;
293                *state += x.load(Ordering::Relaxed);
294                Some(start)
295            })
296            .collect();
297
298        HnswMetadata {
299            entry_point: self.inner.entry_point,
300            params: self.inner.params.clone(),
301            level_offsets,
302        }
303    }
304}
305
306struct HnswBuilder {
307    params: HnswBuildParams,
308
309    nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
310    level_count: Vec<AtomicUsize>,
311
312    entry_point: u32,
313
314    visited_generator_queue: Arc<ArrayQueue<VisitedGenerator>>,
315}
316
317impl DeepSizeOf for HnswBuilder {
318    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
319        self.params.deep_size_of_children(context)
320            + self.nodes.deep_size_of_children(context)
321            + self.level_count.deep_size_of_children(context)
322        // Skipping the visited_generator_queue
323    }
324}
325
326impl HnswBuilder {
327    fn max_level(&self) -> u16 {
328        self.params.max_level
329    }
330
331    fn num_nodes(&self, level: usize) -> usize {
332        self.level_count[level].load(Ordering::Relaxed)
333    }
334
335    fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
336        self.nodes.clone()
337    }
338
339    /// Create a new [`HNSWBuilder`] with prepared params and in memory vector storage.
340    pub fn with_params(params: HnswBuildParams, storage: &impl VectorStore) -> Self {
341        let len = storage.len();
342        let max_level = params.max_level;
343
344        let level_count = (0..max_level)
345            .map(|_| AtomicUsize::new(0))
346            .collect::<Vec<_>>();
347
348        let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
349        for _ in 0..get_num_compute_intensive_cpus() {
350            visited_generator_queue
351                .push(VisitedGenerator::new(0))
352                .unwrap();
353        }
354        let mut builder = Self {
355            params,
356            nodes: Arc::new(Vec::new()),
357            level_count,
358            entry_point: 0,
359            visited_generator_queue,
360        };
361
362        if storage.is_empty() {
363            return builder;
364        }
365
366        let mut nodes = Vec::with_capacity(len);
367        {
368            if len > 0 {
369                nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize)));
370            }
371            for i in 1..len {
372                nodes.push(RwLock::new(GraphBuilderNode::new(
373                    i as u32,
374                    builder.random_level() as usize + 1,
375                )));
376            }
377        }
378        builder.nodes = Arc::new(nodes);
379
380        builder
381    }
382
383    /// New node's level
384    ///
385    /// See paper `Algorithm 1`
386    fn random_level(&self) -> u16 {
387        let mut rng = rng();
388        let ml = 1.0 / (self.params.m as f32).ln();
389        min(
390            (-rng.random::<f32>().ln() * ml) as u16,
391            self.params.max_level - 1,
392        )
393    }
394
395    /// Insert one node.
396    fn insert(
397        &self,
398        node: u32,
399        visited_generator: &mut VisitedGenerator,
400        storage: &impl VectorStore,
401    ) {
402        let nodes = &self.nodes;
403        let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1;
404        let dist_calc = storage.dist_calculator_from_id(node);
405        let mut ep = OrderedNode::new(
406            self.entry_point,
407            dist_calc.distance(self.entry_point).into(),
408        );
409
410        //
411        // Search for entry point in paper.
412        // ```
413        //   for l_c in (L..l+1) {
414        //     W = Search-Layer(q, ep, ef=1, l_c)
415        //    ep = Select-Neighbors(W, 1)
416        //  }
417        // ```
418        for level in (target_level + 1..self.params.max_level).rev() {
419            let cur_level = HnswLevelView::new(level, nodes);
420            ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance);
421        }
422
423        let mut pruned_neighbors_per_level: Vec<Vec<_>> =
424            vec![Vec::new(); (target_level + 1) as usize];
425        {
426            let mut current_node = nodes[node as usize].write().unwrap();
427            for level in (0..=target_level).rev() {
428                self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
429
430                let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
431                for neighbor in &neighbors {
432                    current_node.add_neighbor(neighbor.id, neighbor.dist, level);
433                }
434                self.prune(storage, &mut current_node, level);
435                pruned_neighbors_per_level[level as usize]
436                    .clone_from(&current_node.level_neighbors_ranked[level as usize]);
437
438                ep = neighbors[0].clone();
439            }
440        }
441        for (level, pruned_neighbors) in pruned_neighbors_per_level.iter().enumerate() {
442            let _: Vec<_> = pruned_neighbors
443                .iter()
444                .map(|unpruned_edge| {
445                    let level = level as u16;
446                    let m_max = match level {
447                        0 => self.params.m * 2,
448                        _ => self.params.m,
449                    };
450                    if unpruned_edge.dist
451                        < nodes[unpruned_edge.id as usize]
452                            .read()
453                            .unwrap()
454                            .cutoff(level, m_max)
455                    {
456                        let mut chosen_node = nodes[unpruned_edge.id as usize].write().unwrap();
457                        chosen_node.add_neighbor(node, unpruned_edge.dist, level);
458                        self.prune(storage, &mut chosen_node, level);
459                    }
460                })
461                .collect();
462        }
463    }
464
465    fn search_level(
466        &self,
467        ep: &OrderedNode,
468        level: u16,
469        dist_calc: &impl DistCalculator,
470        nodes: &Vec<RwLock<GraphBuilderNode>>,
471        visited_generator: &mut VisitedGenerator,
472    ) -> Vec<OrderedNode> {
473        let cur_level = HnswLevelView::new(level, nodes);
474        let mut visited = visited_generator.generate(nodes.len());
475        beam_search(
476            &cur_level,
477            ep,
478            &HnswQueryParams {
479                ef: self.params.ef_construction,
480                lower_bound: None,
481                upper_bound: None,
482                dist_q_c: 0.0,
483            },
484            dist_calc,
485            None,
486            self.params.prefetch_distance,
487            &mut visited,
488        )
489    }
490
491    fn prune(&self, storage: &impl VectorStore, builder_node: &mut GraphBuilderNode, level: u16) {
492        let m_max = match level {
493            0 => self.params.m * 2,
494            _ => self.params.m,
495        };
496
497        let neighbors_ranked = &mut builder_node.level_neighbors_ranked[level as usize];
498        let level_neighbors = neighbors_ranked.clone();
499        if level_neighbors.len() <= m_max {
500            builder_node.update_from_ranked_neighbors(level);
501            return;
502            //return level_neighbors;
503        }
504
505        *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
506        builder_node.update_from_ranked_neighbors(level);
507    }
508}
509
510// View of a level in HNSW graph.
511// This is used to iterate over neighbors in a specific level.
512pub(crate) struct HnswLevelView<'a> {
513    level: u16,
514    nodes: &'a Vec<RwLock<GraphBuilderNode>>,
515}
516
517impl<'a> HnswLevelView<'a> {
518    pub fn new(level: u16, nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
519        Self { level, nodes }
520    }
521}
522
523impl Graph for HnswLevelView<'_> {
524    fn len(&self) -> usize {
525        self.nodes.len()
526    }
527
528    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
529        let node = &self.nodes[key as usize];
530        node.read().unwrap().level_neighbors[self.level as usize].clone()
531    }
532}
533
534pub(crate) struct HnswBottomView<'a> {
535    nodes: &'a Vec<RwLock<GraphBuilderNode>>,
536}
537
538impl<'a> HnswBottomView<'a> {
539    pub fn new(nodes: &'a Vec<RwLock<GraphBuilderNode>>) -> Self {
540        Self { nodes }
541    }
542}
543
544impl Graph for HnswBottomView<'_> {
545    fn len(&self) -> usize {
546        self.nodes.len()
547    }
548
549    fn neighbors(&self, key: u32) -> Arc<Vec<u32>> {
550        let node = &self.nodes[key as usize];
551        node.read().unwrap().bottom_neighbors.clone()
552    }
553}
554
555#[derive(Debug, Clone, Copy)]
556pub struct HnswQueryParams {
557    pub ef: usize,
558    pub lower_bound: Option<f32>,
559    pub upper_bound: Option<f32>,
560    pub dist_q_c: f32,
561}
562
563impl From<&Query> for HnswQueryParams {
564    fn from(query: &Query) -> Self {
565        let k = query.k * query.refine_factor.unwrap_or(1) as usize;
566        Self {
567            ef: query.ef.unwrap_or(k + k / 2),
568            lower_bound: query.lower_bound,
569            upper_bound: query.upper_bound,
570            dist_q_c: query.dist_q_c,
571        }
572    }
573}
574
575impl IvfSubIndex for HNSW {
576    type BuildParams = HnswBuildParams;
577    type QueryParams = HnswQueryParams;
578
579    fn load(data: RecordBatch) -> Result<Self>
580    where
581        Self: Sized,
582    {
583        if data.num_rows() == 0 {
584            return Ok(Self::empty());
585        }
586
587        let hnsw_metadata =
588            data.schema_ref()
589                .metadata()
590                .get(HNSW_METADATA_KEY)
591                .ok_or(Error::Index {
592                    message: format!("{} not found", HNSW_METADATA_KEY),
593                    location: location!(),
594                })?;
595        let hnsw_metadata: HnswMetadata =
596            serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index {
597                message: format!(
598                    "Failed to decode HNSW metadata: {}, json: {}",
599                    e, hnsw_metadata
600                ),
601                location: location!(),
602            })?;
603
604        let levels: Vec<_> = hnsw_metadata
605            .level_offsets
606            .iter()
607            .tuple_windows()
608            .map(|(start, end)| data.slice(*start, end - start))
609            .collect();
610
611        let level_count = levels.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
612
613        let bottom_level_len = levels[0].num_rows();
614        let mut nodes = Vec::with_capacity(bottom_level_len);
615        for i in 0..bottom_level_len {
616            nodes.push(GraphBuilderNode::new(i as u32, levels.len()));
617        }
618        for (level, batch) in levels.into_iter().enumerate() {
619            let ids = batch[VECTOR_ID_COL].as_primitive::<UInt32Type>();
620            let neighbors = batch[NEIGHBORS_COL].as_list::<i32>();
621            let distances = batch[DIST_COL].as_list::<i32>();
622
623            for ((node, neighbors), distances) in
624                ids.iter().zip(neighbors.iter()).zip(distances.iter())
625            {
626                let node = node.unwrap();
627                let neighbors = neighbors.as_ref().unwrap().as_primitive::<UInt32Type>();
628                let distances = distances.as_ref().unwrap().as_primitive::<Float32Type>();
629
630                nodes[node as usize].level_neighbors_ranked[level] = neighbors
631                    .iter()
632                    .zip(distances.iter())
633                    .map(|(n, dist)| OrderedNode::new(n.unwrap(), OrderedFloat(dist.unwrap())))
634                    .collect();
635                nodes[node as usize].update_from_ranked_neighbors(level as u16);
636            }
637        }
638
639        let visited_generator_queue =
640            Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus() * 2));
641        for _ in 0..get_num_compute_intensive_cpus() * 2 {
642            visited_generator_queue
643                .push(VisitedGenerator::new(0))
644                .unwrap();
645        }
646        let inner = HnswBuilder {
647            params: hnsw_metadata.params,
648            nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
649            level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
650            entry_point: hnsw_metadata.entry_point,
651            visited_generator_queue,
652        };
653
654        Ok(Self {
655            inner: Arc::new(inner),
656        })
657    }
658
659    fn name() -> &'static str {
660        HNSW_TYPE
661    }
662
663    fn metadata_key() -> &'static str {
664        "lance:hnsw"
665    }
666
667    /// Return the schema of the sub index
668    fn schema() -> arrow_schema::SchemaRef {
669        arrow_schema::Schema::new(vec![
670            VECTOR_ID_FIELD.clone(),
671            NEIGHBORS_FIELD.clone(),
672            DISTS_FIELD.clone(),
673        ])
674        .into()
675    }
676
677    #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))]
678    fn search(
679        &self,
680        query: ArrayRef,
681        k: usize,
682        params: Self::QueryParams,
683        storage: &impl VectorStore,
684        prefilter: Arc<dyn PreFilter>,
685        _metrics: &dyn MetricsCollector,
686    ) -> Result<RecordBatch> {
687        if params.ef < k {
688            return Err(Error::Index {
689                message: "ef must be greater than or equal to k".to_string(),
690                location: location!(),
691            });
692        }
693
694        let schema = VECTOR_RESULT_SCHEMA.clone();
695        if self.is_empty() {
696            return Ok(RecordBatch::new_empty(schema));
697        }
698
699        let mut prefilter_generator = self
700            .inner
701            .visited_generator_queue
702            .pop()
703            .unwrap_or_else(|| VisitedGenerator::new(storage.len()));
704        let prefilter_bitset = if prefilter.is_empty() {
705            None
706        } else {
707            let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
708            let mut bitset = prefilter_generator.generate(storage.len());
709            for indices in indices {
710                bitset.insert(indices as u32);
711            }
712            Some(bitset)
713        };
714
715        let remained = prefilter_bitset
716            .as_ref()
717            .map(|b| b.count_ones())
718            .unwrap_or(storage.len());
719        let results = if remained < self.len() * 10 / 100 {
720            let prefilter_bitset =
721                prefilter_bitset.expect("the prefilter bitset must be set for flat search");
722            self.flat_search(storage, query, k, prefilter_bitset, &params)
723        } else {
724            self.search_basic(query, k, &params, prefilter_bitset, storage)?
725        };
726        // if the queue is full, we just don't push it back, so ignore the error here
727        let _ = self.inner.visited_generator_queue.push(prefilter_generator);
728
729        // need to unique by row ids in case of searching multivector
730        let (row_ids, dists): (Vec<_>, Vec<_>) = results
731            .into_iter()
732            .map(|r| (storage.row_id(r.id), r.dist.0))
733            .unique_by(|r| r.0)
734            .unzip();
735        let row_ids = Arc::new(UInt64Array::from(row_ids));
736        let distances = Arc::new(Float32Array::from(dists));
737
738        Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
739    }
740
741    /// Given a vector storage, containing all the data for the IVF partition, build the sub index.
742    fn index_vectors(storage: &impl VectorStore, params: Self::BuildParams) -> Result<Self>
743    where
744        Self: Sized,
745    {
746        let inner = HnswBuilder::with_params(params, storage);
747        let hnsw = Self {
748            inner: Arc::new(inner),
749        };
750
751        log::debug!(
752            "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
753            storage.len(),
754            hnsw.inner.params.max_level,
755            hnsw.inner.params.m,
756            hnsw.inner.params.ef_construction,
757            storage.distance_type(),
758        );
759
760        if storage.is_empty() {
761            return Ok(hnsw);
762        }
763
764        let len = storage.len();
765        hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
766        (1..len).into_par_iter().for_each_init(
767            || VisitedGenerator::new(len),
768            |visited_generator, node| {
769                hnsw.inner.insert(node as u32, visited_generator, storage);
770            },
771        );
772
773        assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
774        Ok(hnsw)
775    }
776
777    fn remap(
778        &self,
779        _mapping: &HashMap<u64, Option<u64>>, // we don't need the mapping here because we rebuild the graph from remapped storage
780        store: &impl VectorStore,
781    ) -> Result<Self> {
782        // We can't simply remap the row ids in the graph because the vectors are changed,
783        // so the graph needs to be rebuilt.
784        Self::index_vectors(store, self.inner.params.clone())
785    }
786
787    /// Encode the sub index into a record batch
788    fn to_batch(&self) -> Result<RecordBatch> {
789        let mut vector_id_builder = UInt32Builder::with_capacity(self.len());
790        let mut neighbors_builder = ListBuilder::with_capacity(UInt32Builder::new(), self.len());
791        let mut distances_builder =
792            ListBuilder::with_capacity(arrow_array::builder::Float32Builder::new(), self.len());
793        let mut batches = Vec::with_capacity(self.max_level() as usize);
794        for level in 0..self.max_level() {
795            let level = level as usize;
796            for (id, node) in self.inner.nodes.iter().enumerate() {
797                let node = node.read().unwrap();
798                if level >= node.level_neighbors.len() {
799                    continue;
800                }
801                let neighbors = node.level_neighbors[level].iter().map(|n| Some(*n));
802                let distances = node.level_neighbors_ranked[level]
803                    .iter()
804                    .map(|n| Some(n.dist.0));
805                vector_id_builder.append_value(id as u32);
806                neighbors_builder.append_value(neighbors);
807                distances_builder.append_value(distances);
808            }
809
810            let batch = RecordBatch::try_new(
811                Self::schema(),
812                vec![
813                    Arc::new(vector_id_builder.finish()),
814                    Arc::new(neighbors_builder.finish()),
815                    Arc::new(distances_builder.finish()),
816                ],
817            )?;
818            batches.push(batch);
819        }
820
821        let metadata = self.metadata();
822        let metadata = serde_json::to_string(&metadata)?;
823        let schema = Self::schema()
824            .as_ref()
825            .clone()
826            .with_metadata(HashMap::from_iter(vec![(
827                HNSW_METADATA_KEY.to_string(),
828                metadata,
829            )]));
830        let batch = concat_batches(&Self::schema(), batches.iter())?;
831        let batch = batch.with_schema(Arc::new(schema))?;
832        Ok(batch)
833    }
834}
835
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::Schema;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_file::{
        reader::FileReader,
        writer::{FileWriter, FileWriterOptions},
    };
    use lance_io::object_store::ObjectStore;
    use lance_linalg::distance::DistanceType;
    use lance_table::format::SelfDescribingFileReader;
    use lance_table::io::manifest::ManifestDescribing;
    use lance_testing::datagen::generate_random_array;
    use object_store::path::Path;

    use crate::scalar::IndexWriter;
    use crate::vector::v3::subindex::IvfSubIndex;
    use crate::vector::{
        flat::storage::FlatFloatStorage,
        graph::{DISTS_FIELD, NEIGHBORS_FIELD},
        hnsw::{
            builder::{HnswBuildParams, HnswQueryParams},
            HNSW, VECTOR_ID_FIELD,
        },
    };

    /// Builds a graph, writes it to a file, loads it back and checks that the
    /// built graph and the loaded graph return the same search results.
    #[tokio::test]
    async fn test_builder_write_load() {
        const DIM: usize = 32;
        const TOTAL: usize = 2048;
        const NUM_EDGES: usize = 20;

        // Build the graph over random vectors.
        let values = generate_random_array(TOTAL * DIM);
        let vectors = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
        let store = Arc::new(FlatFloatStorage::new(vectors.clone(), DistanceType::L2));
        let build_params = HnswBuildParams::default()
            .num_edges(NUM_EDGES)
            .ef_construction(50);
        let built_hnsw = HNSW::index_vectors(store.as_ref(), build_params).unwrap();

        // Write the graph to an in-memory object store.
        let object_store = ObjectStore::memory();
        let path = Path::from("test_builder_write_load");
        let object_writer = object_store.create(&path).await.unwrap();
        let arrow_schema = Schema::new(vec![
            VECTOR_ID_FIELD.clone(),
            NEIGHBORS_FIELD.clone(),
            DISTS_FIELD.clone(),
        ]);
        let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
        let mut file_writer = FileWriter::<ManifestDescribing>::with_object_writer(
            object_writer,
            lance_schema,
            &FileWriterOptions::default(),
        )
        .unwrap();
        let written_batch = built_hnsw.to_batch().unwrap();
        let metadata = written_batch.schema_ref().metadata().clone();
        file_writer.write_record_batch(written_batch).await.unwrap();
        file_writer.finish_with_metadata(&metadata).await.unwrap();

        // Read everything back and load the graph.
        let reader = FileReader::try_new_self_described(&object_store, &path, None)
            .await
            .unwrap();
        let read_batch = reader
            .read_range(0..reader.len(), reader.schema())
            .await
            .unwrap();
        let loaded_hnsw = HNSW::load(read_batch).unwrap();

        // Both graphs must answer the same query identically.
        let query = vectors.value(0);
        let k = 10;
        let query_params = HnswQueryParams {
            ef: 50,
            lower_bound: None,
            upper_bound: None,
            dist_q_c: 0.0,
        };
        let builder_results = built_hnsw
            .search_basic(query.clone(), k, &query_params, None, store.as_ref())
            .unwrap();
        let loaded_results = loaded_hnsw
            .search_basic(query, k, &query_params, None, store.as_ref())
            .unwrap();
        assert_eq!(builder_results, loaded_results);
    }
}
926}