use std::sync::Arc;
use arrow_array::FixedSizeListArray;
use arrow_schema::SchemaRef;
use arrow_schema::SortOptions;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
#[allow(deprecated)]
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::union::UnionExec;
use lance_core::Result;
use lance_core::datatypes::OnMissing;
use tracing::instrument;
use crate::dataset::Dataset;
use crate::io::exec::TakeExec;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{FreshnessPolarity, LsmGlobalPkDedupExec, LsmSourceTagExec};
use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset};
use super::projection::{
DISTANCE_COLUMN, build_scanner_projection, canonical_internal_schema, canonical_output_schema,
null_columns, project_to_canonical, wants_row_id,
};
use crate::session::Session;
/// Plans top-k vector searches across the data sources of an LSM-structured table.
///
/// Each call to `plan_search` asks the collector for the current sources, builds
/// one KNN plan per source, and combines them into a single deduplicated plan
/// ordered by distance.
pub struct LsmVectorSearchPlanner {
    /// Yields the data sources to search; `collect()` is invoked once per planned search.
    collector: LsmDataSourceCollector,
    /// Primary-key column names: forwarded to the projection helpers and used as the
    /// key of the cross-source deduplication node.
    pk_columns: Vec<String>,
    /// Schema handed to the projection helpers when deriving scan and output schemas.
    base_schema: SchemaRef,
    /// Name of the vector column passed to each scanner's `nearest` call.
    vector_column: String,
    /// Distance metric applied to every per-source scanner.
    distance_type: lance_linalg::distance::DistanceType,
    /// Optional dataset; when set, `plan_search` attempts to wrap the final plan in a
    /// `TakeExec` against it.
    dataset: Option<Arc<Dataset>>,
    /// Optional session forwarded to `open_flushed_dataset` for flushed memtables.
    session: Option<Arc<Session>>,
    /// Optional cache forwarded to `open_flushed_dataset` for flushed memtables.
    flushed_cache: Option<Arc<FlushedMemTableCache>>,
}
impl LsmVectorSearchPlanner {
/// Creates a planner from its required configuration.
///
/// The optional dataset, session, and flushed-memtable cache all start unset;
/// attach them afterwards with the `with_*` builder methods.
pub fn new(
    collector: LsmDataSourceCollector,
    pk_columns: Vec<String>,
    base_schema: SchemaRef,
    vector_column: String,
    distance_type: lance_linalg::distance::DistanceType,
) -> Self {
    Self {
        // Optional collaborators are opt-in.
        dataset: None,
        session: None,
        flushed_cache: None,
        // Required configuration is moved in unchanged.
        distance_type,
        vector_column,
        base_schema,
        pk_columns,
        collector,
    }
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<FlushedMemTableCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
pub fn with_dataset(mut self, dataset: Arc<Dataset>) -> Self {
self.dataset = Some(dataset);
self
}
/// Builds the physical plan for a top-`k` vector search over every collected source.
///
/// The plan is assembled in these stages:
/// 1. One KNN plan per source (see `build_knn_plan`), wrapped in an
///    `LsmSourceTagExec` carrying the source generation and freshness polarity.
/// 2. For every source other than the base table, the `lance_core::ROW_ID`
///    column is nulled; each plan is then projected onto the shared internal schema.
/// 3. The per-source plans are unioned, coalesced into one partition, and passed
///    through `LsmGlobalPkDedupExec` keyed on the primary-key columns.
/// 4. The result is projected to the canonical output schema, sorted ascending by
///    `DISTANCE_COLUMN` (nulls last), and limited to `k` rows.
/// 5. If a dataset was attached via `with_dataset`, a `TakeExec` against that
///    dataset is attempted on top of the sorted plan; when `TakeExec::try_new`
///    yields nothing, the sorted plan is returned as is.
///
/// When the collector reports no sources, an empty plan with the canonical
/// output schema is returned instead.
///
/// # Errors
///
/// Propagates errors from the collector, the per-source planners, and the
/// projection helpers. Returns an invalid-input error if `DISTANCE_COLUMN` is
/// missing from the merged schema, and an internal error if the sort ordering
/// cannot be constructed.
// `#[instrument]` treats a bare field name as an empty placeholder that must be
// recorded later. Nothing here records them, so `k` and `nprobes` are assigned
// explicitly to make them show up on the span.
#[instrument(
    name = "lsm_vector_search",
    level = "info",
    skip_all,
    fields(
        k = k,
        nprobes = nprobes,
        vector_column = %self.vector_column,
        distance_type = ?self.distance_type
    )
)]
pub async fn plan_search(
    &self,
    query_vector: &FixedSizeListArray,
    k: usize,
    nprobes: usize,
    projection: Option<&[String]>,
    refine_factor: Option<u32>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let sources = self.collector.collect()?;
    if sources.is_empty() {
        return self.empty_plan(projection);
    }

    // Schema returned to the caller, and the wider schema used while merging.
    let canonical_schema =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, true);
    let internal_schema =
        canonical_internal_schema(projection, &self.base_schema, &self.pk_columns, true);

    let mut knn_plans = Vec::new();
    for source in &sources {
        let generation = source.generation();
        let is_base = matches!(source, LsmDataSource::BaseTable { .. });
        let knn = self
            .build_knn_plan(source, query_vector, k, nprobes, projection, refine_factor)
            .await?;

        let polarity = match source {
            LsmDataSource::FlushedMemTable { .. } => FreshnessPolarity::ReverseWrite,
            LsmDataSource::ActiveMemTable { .. } | LsmDataSource::BaseTable { .. } => {
                FreshnessPolarity::InsertOrder
            }
        };
        let tagged: Arc<dyn ExecutionPlan> = Arc::new(LsmSourceTagExec::new(
            knn,
            generation,
            polarity,
            lance_core::ROW_ID,
        ));

        // Only the base table keeps its row ids; they are nulled for every other source.
        let after_null = if is_base {
            tagged
        } else {
            null_columns(tagged, &[lance_core::ROW_ID])?
        };

        // Every branch of the union must expose the same schema.
        let normalized = project_to_canonical(after_null, &internal_schema)?;
        knn_plans.push(normalized);
    }

    #[allow(deprecated)]
    let union: Arc<dyn ExecutionPlan> = Arc::new(UnionExec::new(knn_plans));
    // The dedup node is fed a single partition holding the rows of all sources.
    let coalesced: Arc<dyn ExecutionPlan> = Arc::new(CoalescePartitionsExec::new(union));
    let deduped: Arc<dyn ExecutionPlan> = Arc::new(LsmGlobalPkDedupExec::new(
        coalesced,
        self.pk_columns.clone(),
        super::exec::MEMTABLE_GEN_COLUMN,
        super::exec::FRESHNESS_COLUMN,
    ));

    // Project the deduplicated rows onto the schema promised to the caller.
    let merged: Arc<dyn ExecutionPlan> = project_to_canonical(deduped, &canonical_schema)?;

    let distance_idx = merged.schema().index_of(DISTANCE_COLUMN).map_err(|_| {
        lance_core::Error::invalid_input(format!(
            "Column '{}' not found in schema",
            DISTANCE_COLUMN
        ))
    })?;

    // Ascending distance, nulls last.
    let sort_expr = vec![PhysicalSortExpr {
        expr: Arc::new(Column::new(DISTANCE_COLUMN, distance_idx)),
        options: SortOptions {
            descending: false,
            nulls_first: false,
        },
    }];
    let lex_ordering = LexOrdering::new(sort_expr).ok_or_else(|| {
        lance_core::Error::internal("Failed to create LexOrdering".to_string())
    })?;

    // Top-k per partition, then an order-preserving merge that again keeps only `k` rows.
    let per_partition_sorted: Arc<dyn ExecutionPlan> = Arc::new(
        SortExec::new(lex_ordering.clone(), merged)
            .with_preserve_partitioning(true)
            .with_fetch(Some(k)),
    );
    let merged_sorted: Arc<dyn ExecutionPlan> = Arc::new(
        SortPreservingMergeExec::new(lex_ordering, per_partition_sorted).with_fetch(Some(k)),
    );

    #[allow(deprecated)]
    let result = if let Some(dataset) = &self.dataset {
        let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
        // Columns unknown to the dataset are ignored rather than reported as errors.
        let output_projection = dataset
            .empty_projection()
            .union_columns(cols, OnMissing::Ignore)?;
        let coalesced: Arc<dyn ExecutionPlan> =
            Arc::new(CoalesceBatchesExec::new(merged_sorted.clone(), 8192));
        if let Some(take_plan) =
            TakeExec::try_new(dataset.clone(), coalesced, output_projection)?
        {
            Arc::new(take_plan) as Arc<dyn ExecutionPlan>
        } else {
            // No take node was produced: return the sorted plan unchanged.
            merged_sorted
        }
    } else {
        merged_sorted
    };
    Ok(result)
}
/// Builds the KNN plan for a single LSM source.
///
/// * Base table: scans the dataset, adds the row id when `wants_row_id(projection)`
///   says so, and applies `refine_factor` when one is given.
/// * Flushed memtable: opens the dataset at `path` through `open_flushed_dataset`
///   (using the optional session and cache) and scans it.
/// * Active memtable: scans the in-memory stores through `MemTableScanner`, always
///   with row ids.
///
/// Every source is projected to the same column list and searched with the
/// planner's vector column, distance type, `k`, and `nprobes`.
///
/// NOTE(review): `refine_factor` is only applied to the base table, and the active
/// memtable receives the whole query array rather than `single_query_array`'s
/// result — presumably intentional; confirm against the scanner implementations.
async fn build_knn_plan(
    &self,
    source: &LsmDataSource,
    query_vector: &FixedSizeListArray,
    k: usize,
    nprobes: usize,
    projection: Option<&[String]>,
    refine_factor: Option<u32>,
) -> Result<Arc<dyn ExecutionPlan>> {
    // The scanned column list does not depend on the source, so resolve it once.
    let columns = build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
    let column_refs = columns
        .iter()
        .map(|name| name.as_str())
        .collect::<Vec<_>>();

    match source {
        LsmDataSource::BaseTable { dataset } => {
            let mut base_scan = dataset.scan();
            base_scan.project(&column_refs)?;
            if wants_row_id(projection) {
                base_scan.with_row_id();
            }
            let query = single_query_array(query_vector);
            base_scan.nearest(&self.vector_column, query.as_ref(), k)?;
            base_scan.nprobes(nprobes);
            base_scan.distance_metric(self.distance_type);
            base_scan.fast_search();
            if let Some(factor) = refine_factor {
                base_scan.refine(factor);
            }
            base_scan.create_plan().await
        }
        LsmDataSource::FlushedMemTable { path, .. } => {
            let flushed =
                open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref())
                    .await?;
            let mut flushed_scan = flushed.scan();
            flushed_scan.project(&column_refs)?;
            let query = single_query_array(query_vector);
            flushed_scan.nearest(&self.vector_column, query.as_ref(), k)?;
            flushed_scan.nprobes(nprobes);
            flushed_scan.distance_metric(self.distance_type);
            flushed_scan.fast_search();
            flushed_scan.create_plan().await
        }
        LsmDataSource::ActiveMemTable {
            batch_store,
            index_store,
            schema,
            ..
        } => {
            use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
            use arrow_array::Array;

            let mut memtable_scan =
                MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
            memtable_scan.project(&column_refs);
            memtable_scan.with_row_id();
            let query: Arc<dyn Array> = Arc::new(query_vector.clone());
            memtable_scan.nearest(&self.vector_column, query, k);
            memtable_scan.nprobes(nprobes);
            memtable_scan.distance_metric(self.distance_type);
            memtable_scan.create_plan().await
        }
    }
}
/// Returns an `EmptyExec` plan exposing the canonical output schema for `projection`.
///
/// Used by `plan_search` when the collector reports no sources.
fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
    let output_schema =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, true);
    let plan = datafusion::physical_plan::empty::EmptyExec::new(output_schema);
    Ok(Arc::new(plan))
}
}
/// Converts the query into the array handed to a dataset scanner's `nearest` call.
///
/// A list holding exactly one entry is unwrapped to that entry's values; any other
/// length is passed through as a clone of the whole list array.
fn single_query_array(query_vector: &FixedSizeListArray) -> arrow_array::ArrayRef {
    use arrow_array::Array;
    match query_vector.len() {
        1 => query_vector.value(0),
        _ => std::sync::Arc::new(query_vector.clone()) as arrow_array::ArrayRef,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::{Dataset, WriteParams};
use arrow_array::{
Int32Array, RecordBatch, RecordBatchIterator, builder::FixedSizeListBuilder,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use std::collections::HashMap;
/// Test schema: an `id: Int32` column flagged as the unenforced primary key plus a
/// four-element `vector` column of `Float32` items.
fn create_vector_schema() -> Arc<ArrowSchema> {
    let pk_metadata = HashMap::from([(
        "lance-schema:unenforced-primary-key".to_string(),
        "true".to_string(),
    )]);
    let vector_type =
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4);
    let fields = vec![
        Field::new("id", DataType::Int32, false).with_metadata(pk_metadata),
        Field::new("vector", vector_type, false),
    ];
    Arc::new(ArrowSchema::new(fields))
}
/// Single-row query vector `[0.1, 0.2, 0.3, 0.4]`.
fn create_query_vector() -> FixedSizeListArray {
    use arrow_array::builder::Float32Builder;
    let mut query = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for component in [0.1, 0.2, 0.3, 0.4] {
        query.values().append_value(component);
    }
    query.append(true);
    query.finish()
}
/// Builds a batch with one row per id; the vector of row `id` is
/// `[b, b + 0.1, b + 0.2, b + 0.3]` where `b = id * 0.1`.
fn create_test_batch(schema: &ArrowSchema, ids: &[i32]) -> RecordBatch {
    use arrow_array::builder::Float32Builder;
    let mut vectors = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for &id in ids {
        let base = id as f32 * 0.1;
        for component in [base, base + 0.1, base + 0.2, base + 0.3] {
            vectors.values().append_value(component);
        }
        vectors.append(true);
    }
    let id_column = Int32Array::from(ids.to_vec());
    let vector_column = vectors.finish();
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(vector_column)],
    )
    .unwrap()
}
/// Writes `batches` to a new dataset at `uri` with default write parameters.
///
/// Panics when `batches` is empty or the write fails.
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
    // The first batch supplies the schema for the whole reader.
    let batch_schema = batches[0].schema();
    let source = RecordBatchIterator::new(batches.into_iter().map(Ok), batch_schema);
    let written = Dataset::write(source, uri, Some(WriteParams::default())).await;
    written.unwrap()
}
/// Planning succeeds when the base table is the only source.
#[tokio::test]
async fn test_vector_search_plan_structure() {
    let schema = create_vector_schema();
    let workdir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", workdir.path().to_str().unwrap());
    let base_rows = create_test_batch(&schema, &[1, 2, 3]);
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);

    let planner = LsmVectorSearchPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema.clone(),
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    planner
        .plan_search(&query, 10, 8, None, None)
        .await
        .expect("planner should produce a plan even when memtables are empty");
}
/// The scanner projection always carries the primary key next to the requested columns.
#[tokio::test]
async fn test_projection_includes_pk() {
    let schema = create_vector_schema();
    let workdir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", workdir.path().to_str().unwrap());
    let base_rows = create_test_batch(&schema, &[1]);
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
    let _planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema.clone(),
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let requested = ["vector".to_string()];
    let pk_columns = ["id".to_string()];
    let cols = build_scanner_projection(Some(&requested), &schema, &pk_columns);
    for expected in ["vector", "id"] {
        assert!(cols.contains(&expected.to_string()));
    }
}
#[tokio::test]
async fn test_vector_search_base_plus_active_returns_distance() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10, 20, 30]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 3, 1, None, None)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
let out_cols: Vec<String> = out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
assert!(
out_schema.field_with_name(DISTANCE_COLUMN).is_ok(),
"output schema is missing `_distance` column. Got: {:?}",
out_cols
);
assert!(
out_schema.field_with_name("_rowid").is_err(),
"`_rowid` leaked into output: {:?}",
out_cols
);
assert!(
out_schema
.field_with_name(super::super::exec::MEMTABLE_GEN_COLUMN)
.is_err(),
"`_memtable_gen` leaked into output: {:?}",
out_cols
);
let id_col = batches[0]
.column_by_name("id")
.expect("id column missing")
.as_any()
.downcast_ref::<Int32Array>()
.expect("id column should be Int32");
let dist_col = batches[0]
.column_by_name(DISTANCE_COLUMN)
.expect("_distance column missing")
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.expect("_distance column should be Float32");
assert_eq!(id_col.value(0), 1, "expected id=1 as nearest neighbor");
assert!(
dist_col.value(0).abs() < 1e-3,
"expected near-zero distance for self-match, got {}",
dist_col.value(0)
);
}
#[tokio::test]
async fn test_vector_search_with_projection_returns_distance_and_pk() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let projection = vec!["vector".to_string()];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), None)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
assert!(
out_schema.field_with_name("id").is_ok(),
"PK column `id` should be auto-included even when user projects only `vector`"
);
assert!(out_schema.field_with_name("vector").is_ok());
assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok());
}
#[tokio::test]
async fn test_vector_search_projection_with_explicit_distance_and_rowid() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let projection = vec![
"_distance".to_string(),
"vector".to_string(),
"_rowid".to_string(),
];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), None)
.await
.expect(
"planner must accept `_distance`/`_rowid` in projection without breaking the plan",
);
let ctx = SessionContext::new();
let stream = plan
.execute(0, ctx.task_ctx())
.expect("plan must execute when `_distance`/`_rowid` are in projection");
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
let distance_count = out_schema
.fields()
.iter()
.filter(|f| f.name() == DISTANCE_COLUMN)
.count();
assert_eq!(
distance_count,
1,
"`_distance` must appear exactly once in output, got schema: {:?}",
out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<Vec<_>>()
);
assert!(out_schema.field_with_name("vector").is_ok());
assert!(out_schema.field_with_name("id").is_ok());
assert!(out_schema.field_with_name("_rowid").is_ok());
let rowid = batches[0].column_by_name("_rowid").unwrap();
assert!(
rowid.is_null(0),
"active-memtable `_rowid` must be NULL (not a real Lance row id), got: {:?}",
rowid
);
}
#[tokio::test]
async fn test_vector_search_strips_internal_columns_and_preserves_active_rows() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 3, 1, None, None)
.await
.expect("planner should produce a plan");
let plan_str = format!(
"{}",
datafusion::physical_plan::displayable(plan.as_ref()).indent(true)
);
assert!(
plan_str.contains("LsmGlobalPkDedupExec"),
"expected new global-dedup pipeline, got:\n{}",
plan_str
);
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok());
for internal in [
super::super::exec::MEMTABLE_GEN_COLUMN,
super::super::exec::FRESHNESS_COLUMN,
] {
assert!(
out_schema.field_with_name(internal).is_err(),
"`{}` leaked into output: {:?}",
internal,
out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<Vec<_>>(),
);
}
let mut all_ids: Vec<i32> = Vec::new();
for batch in &batches {
let id_col = batch
.column_by_name("id")
.expect("id column missing")
.as_any()
.downcast_ref::<Int32Array>()
.expect("id column should be Int32");
for i in 0..batch.num_rows() {
all_ids.push(id_col.value(i));
}
}
assert!(
all_ids.iter().any(|&id| (1..=4).contains(&id)),
"expected at least one active-memtable row (id in 1..=4) — none found, so \
active partitions were silently dropped. Got ids: {:?}",
all_ids
);
}
/// A primary key present in both a flushed generation and the active memtable must
/// appear exactly once in the search output.
#[tokio::test]
async fn test_vector_search_dedup_across_generations() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;

    let schema = create_vector_schema();
    let workdir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", workdir.path().to_str().unwrap());
    let shard_id = uuid::Uuid::new_v4();

    // Generation 1 (flushed): pk=1 with a vector far away from the query.
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    let stale_pk1 = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]);
    create_dataset(&gen1_uri, vec![stale_pk1]).await;

    // Generation 2 (active): pk=1 rewritten to match the query, plus an unrelated pk=2.
    let store = Arc::new(BatchStore::with_capacity(16));
    let mut indexes = IndexStore::new();
    indexes.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );
    let fresh_pk1 = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]);
    let unrelated = create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]);
    let (_, _, fresh_position) = store.append(fresh_pk1.clone()).unwrap();
    indexes
        .insert_with_batch_position(&fresh_pk1, 0, Some(fresh_position))
        .unwrap();
    let (_, _, unrelated_position) = store.append(unrelated.clone()).unwrap();
    indexes
        .insert_with_batch_position(&unrelated, 1, Some(unrelated_position))
        .unwrap();
    let indexes = Arc::new(indexes);

    let shard_snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());
    let memtables = InMemoryMemTables {
        active: InMemoryMemTableRef {
            batch_store: store,
            index_store: indexes,
            schema: schema.clone(),
            generation: 2,
        },
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![shard_snapshot])
        .with_in_memory_memtables(shard_id, memtables);
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner.plan_search(&query, 5, 1, None, None).await.unwrap();
    let ctx = SessionContext::new();
    let results: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut ids: Vec<i32> = Vec::new();
    for batch in &results {
        let id_col = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        ids.extend(id_col.values().to_vec());
    }
    let pk1_count = ids.iter().filter(|&&id| id == 1).count();
    assert_eq!(
        pk1_count, 1,
        "pk=1 must appear exactly once after cross-source dedup; got ids={:?}",
        ids,
    );
}
#[tokio::test]
async fn test_vector_search_system_columns_real_only_for_base() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use crate::index::DatasetIndexExt;
use crate::index::vector::VectorIndexParams;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use lance_index::IndexType;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1]);
let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
let ivf_flat = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);
base_dataset
.create_index(&["vector"], IndexType::Vector, None, &ivf_flat, true)
.await
.unwrap();
let base_dataset = Arc::new(base_dataset);
let shard_id = uuid::Uuid::new_v4();
let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
let gen1_batch = create_test_batch(&schema, &[2]);
let mut gen1_dataset = create_dataset(&gen1_uri, vec![gen1_batch]).await;
gen1_dataset
.create_index(&["vector"], IndexType::Vector, None, &ivf_flat, true)
.await
.unwrap();
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let active_batch = create_test_batch(&schema, &[3]);
batch_store.append(active_batch.clone()).unwrap();
index_store
.insert_with_batch_position(&active_batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_snapshot = ShardSnapshot::new(shard_id)
.with_current_generation(2)
.with_flushed_generation(1, "gen_1".to_string());
let collector = LsmDataSourceCollector::new(base_dataset, vec![shard_snapshot])
.with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 2,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let projection = vec![
"id".to_string(),
"_rowid".to_string(),
"_rowaddr".to_string(),
"vector".to_string(),
];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), None)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 3, "expected one row per source");
let mut seen: std::collections::HashMap<i32, (bool, bool)> =
std::collections::HashMap::new();
for batch in &batches {
let ids = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let rowid = batch.column_by_name("_rowid").unwrap();
let rowaddr = batch.column_by_name("_rowaddr").unwrap();
for i in 0..batch.num_rows() {
seen.insert(ids.value(i), (rowid.is_null(i), rowaddr.is_null(i)));
}
}
let (rid_null, raddr_null) = seen.get(&1).expect("base row id=1 missing");
assert!(
!rid_null,
"base row `_rowid` must be real (Lance row id), got NULL"
);
assert!(
raddr_null,
"`_rowaddr` is incompatible with vector_search's fast_search; must be NULL"
);
let (rid_null, raddr_null) = seen.get(&2).expect("flushed row id=2 missing");
assert!(rid_null, "flushed row `_rowid` must be NULL");
assert!(raddr_null, "flushed row `_rowaddr` must be NULL");
let (rid_null, raddr_null) = seen.get(&3).expect("active row id=3 missing");
assert!(rid_null, "active row `_rowid` must be NULL");
assert!(raddr_null, "active row `_rowaddr` must be NULL");
}
/// With no sources at all, the empty plan still exposes the requested system columns
/// in the caller's order, followed by the primary key and `_distance`.
#[tokio::test]
async fn test_vector_search_empty_plan_with_system_columns() {
    let schema = create_vector_schema();
    let workdir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", workdir.path().to_str().unwrap());
    let planner = LsmVectorSearchPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![]),
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let projection = vec![
        "_rowid".to_string(),
        "vector".to_string(),
        "_rowaddr".to_string(),
    ];
    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 5, 1, Some(&projection), None)
        .await
        .expect("empty plan must accept system columns in projection");

    let plan_schema = plan.schema();
    let names: Vec<String> = plan_schema
        .fields()
        .iter()
        .map(|field| field.name().clone())
        .collect();
    let expected = vec![
        "_rowid".to_string(),
        "vector".to_string(),
        "_rowaddr".to_string(),
        "id".to_string(),
        "_distance".to_string(),
    ];
    assert_eq!(
        names, expected,
        "empty KNN plan must honor user position for system cols and append PK + _distance"
    );
}
#[tokio::test]
async fn test_vector_search_without_base_table() {
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 10, 8, None, None)
.await
.expect("planner should produce a plan without a base table");
let plan_str = format!(
"{}",
datafusion::physical_plan::displayable(plan.as_ref()).indent(true)
);
assert!(
!plan_str.contains("base/data"),
"Plan must not scan base table, got: {}",
plan_str
);
let ctx = datafusion::prelude::SessionContext::new();
let stream = plan
.execute(0, ctx.task_ctx())
.expect("plan should execute without a base table");
let batches: Vec<RecordBatch> = stream
.try_collect()
.await
.expect("collecting batches should succeed");
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 0, "fresh tier with no sources should yield no rows");
}
/// Builds a single-row batch holding `id` and exactly the given four-element vector.
fn create_test_batch_with_vector(
    schema: &ArrowSchema,
    id: i32,
    vector: [f32; 4],
) -> RecordBatch {
    use arrow_array::builder::Float32Builder;
    let mut list_builder = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for component in vector {
        list_builder.values().append_value(component);
    }
    list_builder.append(true);
    let vectors = list_builder.finish();
    let ids = Int32Array::from(vec![id]);
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(ids), Arc::new(vectors)],
    )
    .unwrap()
}
/// A primary key written twice into the same active memtable must appear exactly once
/// in the search output.
#[tokio::test]
async fn test_vector_search_dedup_within_active_memtable() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;

    let schema = create_vector_schema();
    let workdir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", workdir.path().to_str().unwrap());

    let store = Arc::new(BatchStore::with_capacity(16));
    let mut indexes = IndexStore::new();
    indexes.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // pk=1 is written twice (far vector first, then the query vector), followed by pk=2.
    let first_write = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]);
    let second_write = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]);
    let unrelated = create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]);
    let (_, _, first_position) = store.append(first_write.clone()).unwrap();
    indexes
        .insert_with_batch_position(&first_write, 0, Some(first_position))
        .unwrap();
    let (_, _, second_position) = store.append(second_write.clone()).unwrap();
    indexes
        .insert_with_batch_position(&second_write, 1, Some(second_position))
        .unwrap();
    let (_, _, unrelated_position) = store.append(unrelated.clone()).unwrap();
    indexes
        .insert_with_batch_position(&unrelated, 2, Some(unrelated_position))
        .unwrap();
    let indexes = Arc::new(indexes);

    let shard_id = uuid::Uuid::new_v4();
    let memtables = InMemoryMemTables {
        active: InMemoryMemTableRef {
            batch_store: store,
            index_store: indexes,
            schema: schema.clone(),
            generation: 1,
        },
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(shard_id, memtables);
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner.plan_search(&query, 5, 1, None, None).await.unwrap();
    let ctx = SessionContext::new();
    let results: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut ids: Vec<i32> = Vec::new();
    for batch in &results {
        let id_col = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        ids.extend(id_col.values().to_vec());
    }
    let pk1_count = ids.iter().filter(|&&id| id == 1).count();
    assert_eq!(
        pk1_count, 1,
        "pk=1 must appear exactly once after within-source dedup; got ids={:?}",
        ids,
    );
}
}