use std::sync::Arc;
use arrow_array::FixedSizeListArray;
use arrow_schema::SortOptions;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::union::UnionExec;
use lance_core::Result;
use lance_index::scalar::bloomfilter::sbbf::Sbbf;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{FilterStaleExec, GenerationBloomFilter, MemtableGenTagExec};
pub const DISTANCE_COLUMN: &str = "_distance";
pub struct LsmVectorSearchPlanner {
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
bloom_filters: Vec<GenerationBloomFilter>,
vector_column: String,
distance_type: lance_linalg::distance::DistanceType,
}
impl LsmVectorSearchPlanner {
pub fn new(
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
vector_column: String,
distance_type: lance_linalg::distance::DistanceType,
) -> Self {
Self {
collector,
pk_columns,
base_schema,
bloom_filters: Vec::new(),
vector_column,
distance_type,
}
}
pub fn with_bloom_filter(mut self, generation: u64, bloom_filter: Arc<Sbbf>) -> Self {
self.bloom_filters.push(GenerationBloomFilter {
generation,
bloom_filter,
});
self
}
pub fn with_bloom_filters(
mut self,
bloom_filters: impl IntoIterator<Item = (u64, Arc<Sbbf>)>,
) -> Self {
for (generation, bf) in bloom_filters {
self.bloom_filters.push(GenerationBloomFilter {
generation,
bloom_filter: bf,
});
}
self
}
pub async fn plan_search(
&self,
query_vector: &FixedSizeListArray,
k: usize,
nprobes: usize,
projection: Option<&[String]>,
) -> Result<Arc<dyn ExecutionPlan>> {
let sources = self.collector.collect()?;
if sources.is_empty() {
return self.empty_plan(projection);
}
let mut knn_plans = Vec::new();
for source in &sources {
let generation = source.generation();
let knn = self
.build_knn_plan(source, query_vector, k, nprobes, projection)
.await?;
let tagged: Arc<dyn ExecutionPlan> = Arc::new(MemtableGenTagExec::new(knn, generation));
knn_plans.push(tagged);
}
#[allow(deprecated)]
let union: Arc<dyn ExecutionPlan> = Arc::new(UnionExec::new(knn_plans));
let filtered: Arc<dyn ExecutionPlan> = if !self.bloom_filters.is_empty() {
Arc::new(FilterStaleExec::new(
union,
self.pk_columns.clone(),
self.bloom_filters.clone(),
))
} else {
union
};
let distance_idx = filtered.schema().index_of(DISTANCE_COLUMN).map_err(|_| {
lance_core::Error::invalid_input(format!(
"Column '{}' not found in schema",
DISTANCE_COLUMN
))
})?;
let sort_expr = vec![PhysicalSortExpr {
expr: Arc::new(Column::new(DISTANCE_COLUMN, distance_idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
}];
let lex_ordering = LexOrdering::new(sort_expr).ok_or_else(|| {
lance_core::Error::internal("Failed to create LexOrdering".to_string())
})?;
let sorted: Arc<dyn ExecutionPlan> = Arc::new(SortExec::new(lex_ordering, filtered));
let limited: Arc<dyn ExecutionPlan> = Arc::new(GlobalLimitExec::new(sorted, 0, Some(k)));
Ok(limited)
}
async fn build_knn_plan(
&self,
source: &LsmDataSource,
query_vector: &FixedSizeListArray,
k: usize,
nprobes: usize,
projection: Option<&[String]>,
) -> Result<Arc<dyn ExecutionPlan>> {
match source {
LsmDataSource::BaseTable { dataset } => {
let mut scanner = dataset.scan();
let cols = self.build_projection_for_knn(projection);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.nearest(&self.vector_column, query_vector, k)?;
scanner.nprobes(nprobes);
scanner.distance_metric(self.distance_type);
scanner.fast_search();
scanner.create_plan().await
}
LsmDataSource::FlushedMemTable { path, .. } => {
let dataset = crate::dataset::DatasetBuilder::from_uri(path)
.load()
.await?;
let mut scanner = dataset.scan();
let cols = self.build_projection_for_knn(projection);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.nearest(&self.vector_column, query_vector, k)?;
scanner.nprobes(nprobes);
scanner.distance_metric(self.distance_type);
scanner.fast_search();
scanner.create_plan().await
}
LsmDataSource::ActiveMemTable {
batch_store,
index_store,
schema,
..
} => {
use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
use arrow_array::Array;
let mut scanner =
MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
if let Some(cols) = projection {
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>());
}
let query_arr: Arc<dyn Array> = Arc::new(query_vector.clone());
scanner.nearest(&self.vector_column, query_arr, k);
scanner.nprobes(nprobes);
scanner.distance_metric(self.distance_type);
scanner.create_plan().await
}
}
}
fn build_projection_for_knn(&self, projection: Option<&[String]>) -> Vec<String> {
let mut cols: Vec<String> = if let Some(p) = projection {
p.to_vec()
} else {
self.base_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect()
};
for pk in &self.pk_columns {
if !cols.contains(pk) {
cols.push(pk.clone());
}
}
cols
}
fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion::physical_plan::empty::EmptyExec;
let mut fields: Vec<Arc<Field>> = if let Some(cols) = projection {
cols.iter()
.filter_map(|name| {
self.base_schema
.field_with_name(name)
.ok()
.map(|f| Arc::new(f.clone()))
})
.collect()
} else {
self.base_schema.fields().iter().cloned().collect()
};
fields.push(Arc::new(Field::new(
DISTANCE_COLUMN,
DataType::Float32,
false,
)));
let schema = Arc::new(Schema::new(fields));
Ok(Arc::new(EmptyExec::new(schema)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::{Dataset, WriteParams};
use arrow_array::{
Int32Array, RecordBatch, RecordBatchIterator, builder::FixedSizeListBuilder,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use std::collections::HashMap;
fn create_vector_schema() -> Arc<ArrowSchema> {
let mut id_metadata = HashMap::new();
id_metadata.insert(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
);
let id_field = Field::new("id", DataType::Int32, false).with_metadata(id_metadata);
Arc::new(ArrowSchema::new(vec![
id_field,
Field::new(
"vector",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
),
]))
}
fn create_query_vector() -> FixedSizeListArray {
use arrow_array::builder::Float32Builder;
let mut builder = FixedSizeListBuilder::new(Float32Builder::new(), 4);
builder.values().append_value(0.1);
builder.values().append_value(0.2);
builder.values().append_value(0.3);
builder.values().append_value(0.4);
builder.append(true);
builder.finish()
}
fn create_test_batch(schema: &ArrowSchema, ids: &[i32]) -> RecordBatch {
use arrow_array::builder::Float32Builder;
let mut vector_builder = FixedSizeListBuilder::new(Float32Builder::new(), 4);
for id in ids {
let base = *id as f32 * 0.1;
vector_builder.values().append_value(base);
vector_builder.values().append_value(base + 0.1);
vector_builder.values().append_value(base + 0.2);
vector_builder.values().append_value(base + 0.3);
vector_builder.append(true);
}
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(ids.to_vec())),
Arc::new(vector_builder.finish()),
],
)
.unwrap()
}
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
let schema = batches[0].schema();
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
Dataset::write(reader, uri, Some(WriteParams::default()))
.await
.unwrap()
}
#[tokio::test]
async fn test_vector_search_plan_structure() {
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1, 2, 3]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema.clone(),
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner.plan_search(&query, 10, 8, None).await;
assert!(plan.is_ok() || plan.is_err()); }
#[tokio::test]
async fn test_projection_includes_pk() {
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let cols = planner.build_projection_for_knn(Some(&["vector".to_string()]));
assert!(cols.contains(&"vector".to_string()));
assert!(cols.contains(&"id".to_string()));
}
}