lance_index/vector/
flat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Flat Vector Index.
5//!
6
7use std::sync::Arc;
8
9use arrow::{array::AsArray, buffer::NullBuffer};
10use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
11use arrow_schema::{DataType, Field as ArrowField};
12use lance_arrow::*;
13use lance_core::{Error, Result, ROW_ID};
14use lance_linalg::distance::{multivec_distance, DistanceType};
15use snafu::location;
16use tracing::instrument;
17
18use super::DIST_COL;
19
20pub mod index;
21pub mod storage;
22pub mod transform;
23
24fn distance_field() -> ArrowField {
25    ArrowField::new(DIST_COL, DataType::Float32, true)
26}
27
28#[instrument(level = "debug", skip_all)]
29pub async fn compute_distance(
30    key: ArrayRef,
31    dt: DistanceType,
32    column: &str,
33    mut batch: RecordBatch,
34) -> Result<RecordBatch> {
35    if batch.column_by_name(DIST_COL).is_some() {
36        // Ignore the distance calculated from inner vector index.
37        batch = batch.drop_column(DIST_COL)?;
38    }
39    let vectors = batch
40        .column_by_name(column)
41        .ok_or_else(|| Error::Schema {
42            message: format!("column {} does not exist in dataset", column),
43            location: location!(),
44        })?
45        .clone();
46
47    let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
48        NullBuffer::union(rowids.nulls(), vectors.nulls())
49    } else {
50        vectors.nulls().cloned()
51    };
52
53    tokio::task::spawn_blocking(move || {
54        // A selection vector may have been applied to _rowid column, so we need to
55        // push that onto vectors if possible.
56
57        let vectors = vectors
58            .into_data()
59            .into_builder()
60            .null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
61            .build()
62            .map(make_array)?;
63        let distances = match vectors.data_type() {
64            DataType::FixedSizeList(_, _) => {
65                let vectors = vectors.as_fixed_size_list();
66                dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef
67            }
68            DataType::List(_) => {
69                let vectors = vectors.as_list();
70                let dists = multivec_distance(key.as_ref(), vectors, dt)?;
71                Arc::new(Float32Array::from(dists))
72            }
73            _ => {
74                unreachable!()
75            }
76        };
77
78        batch
79            .try_with_column(distance_field(), distances)
80            .map_err(|e| Error::Execution {
81                message: format!("Failed to adding distance column: {}", e),
82                location: location!(),
83            })
84    })
85    .await
86    .unwrap()
87}