lance_index/vector/
flat.rs1use std::sync::Arc;
8
9use arrow::{array::AsArray, buffer::NullBuffer};
10use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
11use arrow_schema::{DataType, Field as ArrowField};
12use lance_arrow::*;
13use lance_core::{Error, Result, ROW_ID};
14use lance_linalg::distance::{multivec_distance, DistanceType};
15use snafu::location;
16use tracing::instrument;
17
18use super::DIST_COL;
19
20pub mod index;
21pub mod storage;
22pub mod transform;
23
24fn distance_field() -> ArrowField {
25 ArrowField::new(DIST_COL, DataType::Float32, true)
26}
27
28fn get_column_from_batch(batch: &RecordBatch, column: &str) -> Result<ArrayRef> {
35 if let Some(col) = batch.column_by_name(column) {
37 return Ok(col.clone());
38 }
39
40 let parts = lance_core::datatypes::parse_field_path(column).map_err(|e| Error::Schema {
43 message: format!("Failed to parse field path '{}': {}", column, e),
44 location: location!(),
45 })?;
46
47 if parts.is_empty() {
48 return Err(Error::Schema {
49 message: format!("Invalid empty field path: {}", column),
50 location: location!(),
51 });
52 }
53
54 let mut current_array: ArrayRef = batch
56 .column_by_name(&parts[0])
57 .ok_or_else(|| Error::Schema {
58 message: format!(
59 "Column '{}' does not exist in batch (looking for root field '{}')",
60 column, parts[0]
61 ),
62 location: location!(),
63 })?
64 .clone();
65
66 for part in &parts[1..] {
68 let struct_array = current_array
69 .as_any()
70 .downcast_ref::<arrow_array::StructArray>()
71 .ok_or_else(|| Error::Schema {
72 message: format!(
73 "Cannot access nested field '{}' in column '{}': parent is not a struct",
74 part, column
75 ),
76 location: location!(),
77 })?;
78
79 current_array = struct_array
80 .column_by_name(part)
81 .ok_or_else(|| Error::Schema {
82 message: format!(
83 "Nested field '{}' does not exist in column '{}'",
84 part, column
85 ),
86 location: location!(),
87 })?
88 .clone();
89 }
90
91 Ok(current_array)
92}
93
94#[instrument(level = "debug", skip_all)]
95pub async fn compute_distance(
96 key: ArrayRef,
97 dt: DistanceType,
98 column: &str,
99 mut batch: RecordBatch,
100) -> Result<RecordBatch> {
101 if batch.column_by_name(DIST_COL).is_some() {
102 batch = batch.drop_column(DIST_COL)?;
104 }
105
106 let vectors = get_column_from_batch(&batch, column)?;
107
108 let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
109 NullBuffer::union(rowids.nulls(), vectors.nulls())
110 } else {
111 vectors.nulls().cloned()
112 };
113
114 tokio::task::spawn_blocking(move || {
115 let vectors = vectors
119 .into_data()
120 .into_builder()
121 .null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
122 .build()
123 .map(make_array)?;
124 let distances = match vectors.data_type() {
125 DataType::FixedSizeList(_, _) => {
126 let vectors = vectors.as_fixed_size_list();
127 dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef
128 }
129 DataType::List(_) => {
130 let vectors = vectors.as_list();
131 let dists = multivec_distance(key.as_ref(), vectors, dt)?;
132 Arc::new(Float32Array::from(dists))
133 }
134 _ => {
135 unreachable!()
136 }
137 };
138
139 batch
140 .try_with_column(distance_field(), distances)
141 .map_err(|e| Error::Execution {
142 message: format!("Failed to adding distance column: {}", e),
143 location: location!(),
144 })
145 })
146 .await
147 .unwrap()
148}