lance_index/vector/
flat.rs1use std::sync::Arc;
8
9use arrow::{array::AsArray, buffer::NullBuffer};
10use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
11use arrow_schema::{DataType, Field as ArrowField};
12use lance_arrow::*;
13use lance_core::{Error, Result, ROW_ID};
14use lance_linalg::distance::{multivec_distance, DistanceType};
15use snafu::location;
16use tracing::instrument;
17
18use super::DIST_COL;
19
20pub mod index;
21pub mod storage;
22pub mod transform;
23
24fn distance_field() -> ArrowField {
25 ArrowField::new(DIST_COL, DataType::Float32, true)
26}
27
28#[instrument(level = "debug", skip_all)]
29pub async fn compute_distance(
30 key: ArrayRef,
31 dt: DistanceType,
32 column: &str,
33 mut batch: RecordBatch,
34) -> Result<RecordBatch> {
35 if batch.column_by_name(DIST_COL).is_some() {
36 batch = batch.drop_column(DIST_COL)?;
38 }
39 let vectors = batch
40 .column_by_name(column)
41 .ok_or_else(|| Error::Schema {
42 message: format!("column {} does not exist in dataset", column),
43 location: location!(),
44 })?
45 .clone();
46
47 let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
48 NullBuffer::union(rowids.nulls(), vectors.nulls())
49 } else {
50 vectors.nulls().cloned()
51 };
52
53 tokio::task::spawn_blocking(move || {
54 let vectors = vectors
58 .into_data()
59 .into_builder()
60 .null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
61 .build()
62 .map(make_array)?;
63 let distances = match vectors.data_type() {
64 DataType::FixedSizeList(_, _) => {
65 let vectors = vectors.as_fixed_size_list();
66 dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef
67 }
68 DataType::List(_) => {
69 let vectors = vectors.as_list();
70 let dists = multivec_distance(key.as_ref(), vectors, dt)?;
71 Arc::new(Float32Array::from(dists))
72 }
73 _ => {
74 unreachable!()
75 }
76 };
77
78 batch
79 .try_with_column(distance_field(), distances)
80 .map_err(|e| Error::Execution {
81 message: format!("Failed to adding distance column: {}", e),
82 location: location!(),
83 })
84 })
85 .await
86 .unwrap()
87}