lance_index/vector/
flat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Flat Vector Index.
5//!
6
7use std::sync::Arc;
8
9use arrow::{array::AsArray, buffer::NullBuffer};
10use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
11use arrow_schema::{DataType, Field as ArrowField};
12use lance_arrow::*;
13use lance_core::{Error, Result, ROW_ID};
14use lance_linalg::distance::{multivec_distance, DistanceType};
15use snafu::location;
16use tracing::instrument;
17
18use super::DIST_COL;
19
20pub mod index;
21pub mod storage;
22pub mod transform;
23
24fn distance_field() -> ArrowField {
25    ArrowField::new(DIST_COL, DataType::Float32, true)
26}
27
28/// Get a column from a RecordBatch, supporting nested field paths.
29///
30/// This function handles:
31/// - Simple column names: "column"
32/// - Nested paths: "parent.child" or "parent.child.grandchild"
33/// - Backtick-escaped field names: "parent.`field.with.dots`"
34fn get_column_from_batch(batch: &RecordBatch, column: &str) -> Result<ArrayRef> {
35    // Try to get the column directly first (fast path for simple columns)
36    if let Some(col) = batch.column_by_name(column) {
37        return Ok(col.clone());
38    }
39
40    // Parse the field path using Lance's field path parsing logic
41    // This properly handles backtick-escaped field names
42    let parts = lance_core::datatypes::parse_field_path(column).map_err(|e| Error::Schema {
43        message: format!("Failed to parse field path '{}': {}", column, e),
44        location: location!(),
45    })?;
46
47    if parts.is_empty() {
48        return Err(Error::Schema {
49            message: format!("Invalid empty field path: {}", column),
50            location: location!(),
51        });
52    }
53
54    // Get the root column
55    let mut current_array: ArrayRef = batch
56        .column_by_name(&parts[0])
57        .ok_or_else(|| Error::Schema {
58            message: format!(
59                "Column '{}' does not exist in batch (looking for root field '{}')",
60                column, parts[0]
61            ),
62            location: location!(),
63        })?
64        .clone();
65
66    // Navigate through nested struct fields
67    for part in &parts[1..] {
68        let struct_array = current_array
69            .as_any()
70            .downcast_ref::<arrow_array::StructArray>()
71            .ok_or_else(|| Error::Schema {
72                message: format!(
73                    "Cannot access nested field '{}' in column '{}': parent is not a struct",
74                    part, column
75                ),
76                location: location!(),
77            })?;
78
79        current_array = struct_array
80            .column_by_name(part)
81            .ok_or_else(|| Error::Schema {
82                message: format!(
83                    "Nested field '{}' does not exist in column '{}'",
84                    part, column
85                ),
86                location: location!(),
87            })?
88            .clone();
89    }
90
91    Ok(current_array)
92}
93
94#[instrument(level = "debug", skip_all)]
95pub async fn compute_distance(
96    key: ArrayRef,
97    dt: DistanceType,
98    column: &str,
99    mut batch: RecordBatch,
100) -> Result<RecordBatch> {
101    if batch.column_by_name(DIST_COL).is_some() {
102        // Ignore the distance calculated from inner vector index.
103        batch = batch.drop_column(DIST_COL)?;
104    }
105
106    let vectors = get_column_from_batch(&batch, column)?;
107
108    let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
109        NullBuffer::union(rowids.nulls(), vectors.nulls())
110    } else {
111        vectors.nulls().cloned()
112    };
113
114    tokio::task::spawn_blocking(move || {
115        // A selection vector may have been applied to _rowid column, so we need to
116        // push that onto vectors if possible.
117
118        let vectors = vectors
119            .into_data()
120            .into_builder()
121            .null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
122            .build()
123            .map(make_array)?;
124        let distances = match vectors.data_type() {
125            DataType::FixedSizeList(_, _) => {
126                let vectors = vectors.as_fixed_size_list();
127                dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef
128            }
129            DataType::List(_) => {
130                let vectors = vectors.as_list();
131                let dists = multivec_distance(key.as_ref(), vectors, dt)?;
132                Arc::new(Float32Array::from(dists))
133            }
134            _ => {
135                unreachable!()
136            }
137        };
138
139        batch
140            .try_with_column(distance_field(), distances)
141            .map_err(|e| Error::Execution {
142                message: format!("Failed to adding distance column: {}", e),
143                location: location!(),
144            })
145    })
146    .await
147    .unwrap()
148}