use std::sync::Arc;
use arrow_array::FixedSizeListArray;
use arrow_schema::SchemaRef;
use arrow_schema::SortOptions;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
#[allow(deprecated)]
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::union::UnionExec;
use lance_core::Result;
use lance_core::datatypes::OnMissing;
use tracing::instrument;
use crate::dataset::Dataset;
use crate::io::exec::TakeExec;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::flushed_cache::{DatasetCache, GenerationWarmer, open_flushed_dataset};
use super::projection::{
DISTANCE_COLUMN, build_scanner_projection, canonical_output_schema, null_columns,
project_to_canonical, wants_row_id,
};
use crate::session::Session;
/// Planner that assembles a DataFusion execution plan for a k-nearest-neighbour
/// vector search over every data source produced by an [`LsmDataSourceCollector`].
///
/// Build one with [`LsmVectorSearchPlanner::new`], optionally chain the `with_*`
/// builders, then call [`LsmVectorSearchPlanner::plan_search`].
pub struct LsmVectorSearchPlanner {
/// Supplies the sources (base table, flushed memtables, active memtables) to search.
collector: LsmDataSourceCollector,
/// Primary-key column names, handed to the projection helpers and per-arm filter nodes.
pk_columns: Vec<String>,
/// Arrow schema handed to the projection helpers to derive scan and output columns.
base_schema: SchemaRef,
/// Name of the column passed to each scanner's `nearest` call.
vector_column: String,
/// Distance metric applied to every per-source scan.
distance_type: lance_linalg::distance::DistanceType,
/// Optional dataset; when set, `plan_search` wraps the merged result in a `TakeExec` over it.
dataset: Option<Arc<Dataset>>,
/// Optional session forwarded to block-list computation and flushed-dataset opening.
session: Option<Arc<Session>>,
/// Optional cache forwarded to block-list computation and flushed-dataset opening.
flushed_cache: Option<Arc<dyn DatasetCache>>,
/// Optional warmer forwarded to `open_flushed_dataset`.
warmer: Option<Arc<dyn GenerationWarmer>>,
}
impl LsmVectorSearchPlanner {
/// Creates a planner with no dataset, session, flushed-dataset cache, or warmer attached.
///
/// * `collector` - source of the data sources to search.
/// * `pk_columns` - primary-key column names.
/// * `base_schema` - Arrow schema used to derive scan and output projections.
/// * `vector_column` - name of the vector column to search.
/// * `distance_type` - distance metric applied to every scan.
pub fn new(
    collector: LsmDataSourceCollector,
    pk_columns: Vec<String>,
    base_schema: SchemaRef,
    vector_column: String,
    distance_type: lance_linalg::distance::DistanceType,
) -> Self {
    // Every optional collaborator starts unset; the `with_*` builders attach them.
    let (dataset, session, flushed_cache, warmer) = (None, None, None, None);
    Self {
        collector,
        pk_columns,
        base_schema,
        vector_column,
        distance_type,
        dataset,
        session,
        flushed_cache,
        warmer,
    }
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<dyn DatasetCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
pub fn with_warmer(mut self, warmer: Arc<dyn GenerationWarmer>) -> Self {
self.warmer = Some(warmer);
self
}
pub fn with_dataset(mut self, dataset: Arc<Dataset>) -> Self {
self.dataset = Some(dataset);
self
}
#[instrument(name = "lsm_vector_search", level = "info", skip_all, fields(k, nprobes, vector_column = %self.vector_column, distance_type = ?self.distance_type))]
pub async fn plan_search(
&self,
query_vector: &FixedSizeListArray,
k: usize,
nprobes: usize,
projection: Option<&[String]>,
refine_base_table: bool,
overfetch_factor: f64,
) -> Result<Arc<dyn ExecutionPlan>> {
let sources = self.collector.collect()?;
if sources.is_empty() {
return self.empty_plan(projection);
}
let overfetch_factor = overfetch_factor.max(1.0);
let block_lists = Box::pin(super::block_list::compute_source_block_lists(
&sources,
self.session.as_ref(),
self.flushed_cache.as_ref(),
))
.await?;
let canonical_schema = canonical_output_schema(
projection,
&self.base_schema,
&self.pk_columns,
true, );
let refine_base = refine_base_table || !block_lists.is_empty();
let arm_inputs: Vec<_> = sources
.iter()
.map(|source| {
let generation = source.generation();
let is_base = matches!(source, LsmDataSource::BaseTable { .. });
let is_active = matches!(source, LsmDataSource::ActiveMemTable { .. });
let blocked = block_lists.get(&(source.shard_id(), generation));
let fetch_k = if blocked.is_some() || is_active {
((k as f64) * overfetch_factor).ceil() as usize
} else {
k
};
(source, is_base, is_active, blocked, fetch_k)
})
.collect();
let built = futures::future::try_join_all(arm_inputs.iter().map(
|(source, is_base, _, _, fetch_k)| {
Box::pin(self.build_knn_plan(
source,
query_vector,
*fetch_k,
nprobes,
projection,
*is_base && refine_base,
))
},
))
.await?;
let mut knn_plans = Vec::new();
for ((source, is_base, is_active, blocked, _), (knn, active_max_visible)) in
arm_inputs.iter().zip(built)
{
let is_base = *is_base;
let is_active = *is_active;
let blocked = *blocked;
let knn = if is_active {
let (batch_store, index_store) = match source {
LsmDataSource::ActiveMemTable {
batch_store,
index_store,
..
} => (batch_store.clone(), index_store.clone()),
_ => unreachable!("is_active implies ActiveMemTable"),
};
let filtered: Arc<dyn ExecutionPlan> =
Arc::new(super::exec::NewestPkFilterExec::new(
knn,
self.pk_columns.clone(),
lance_core::ROW_ID,
index_store,
batch_store,
active_max_visible.expect("active arm returns its max_visible snapshot"),
));
sort_by_distance(filtered, k)?
} else {
match blocked {
Some(set) => Arc::new(super::exec::PkBlockFilterExec::new(
knn,
self.pk_columns.clone(),
set.clone(),
k,
)) as Arc<dyn ExecutionPlan>,
None => knn,
}
};
let after_null = if is_base {
knn
} else {
null_columns(knn, &[lance_core::ROW_ID])?
};
let normalized = project_to_canonical(after_null, &canonical_schema)?;
knn_plans.push(normalized);
}
#[allow(deprecated)]
let merged: Arc<dyn ExecutionPlan> = Arc::new(UnionExec::new(knn_plans));
let distance_idx = merged.schema().index_of(DISTANCE_COLUMN).map_err(|_| {
lance_core::Error::invalid_input(format!(
"Column '{}' not found in schema",
DISTANCE_COLUMN
))
})?;
let sort_expr = vec![PhysicalSortExpr {
expr: Arc::new(Column::new(DISTANCE_COLUMN, distance_idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
}];
let lex_ordering = LexOrdering::new(sort_expr).ok_or_else(|| {
lance_core::Error::internal("Failed to create LexOrdering".to_string())
})?;
let per_partition_sorted: Arc<dyn ExecutionPlan> = Arc::new(
SortExec::new(lex_ordering.clone(), merged)
.with_preserve_partitioning(true)
.with_fetch(Some(k)),
);
let merged_sorted: Arc<dyn ExecutionPlan> = Arc::new(
SortPreservingMergeExec::new(lex_ordering, per_partition_sorted).with_fetch(Some(k)),
);
#[allow(deprecated)]
let result = if let Some(dataset) = &self.dataset {
let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
let output_projection = dataset
.empty_projection()
.union_columns(cols, OnMissing::Ignore)?;
let coalesced: Arc<dyn ExecutionPlan> =
Arc::new(CoalesceBatchesExec::new(merged_sorted.clone(), 8192));
if let Some(take_plan) =
TakeExec::try_new(dataset.clone(), coalesced, output_projection)?
{
Arc::new(take_plan) as Arc<dyn ExecutionPlan>
} else {
merged_sorted
}
} else {
merged_sorted
};
Ok(result)
}
/// Builds the KNN plan for a single source.
///
/// Returns the plan together with the active memtable's max-visible batch
/// position. The position is `Some` only for `ActiveMemTable` sources; base
/// and flushed sources return `None`.
///
/// * `k` - number of candidates this arm should fetch.
/// * `refine` - consulted only for the base table; enables `refine(1)` on its scan.
async fn build_knn_plan(
    &self,
    source: &LsmDataSource,
    query_vector: &FixedSizeListArray,
    k: usize,
    nprobes: usize,
    projection: Option<&[String]>,
    refine: bool,
) -> Result<(Arc<dyn ExecutionPlan>, Option<usize>)> {
    match source {
        LsmDataSource::BaseTable { dataset } => {
            let mut scan = dataset.scan();
            let columns =
                build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
            let column_refs = columns.iter().map(|name| name.as_str()).collect::<Vec<_>>();
            scan.project(&column_refs)?;
            // Only the base table can expose a row id, and only on request.
            if wants_row_id(projection) {
                scan.with_row_id();
            }
            let query = single_query_array(query_vector);
            scan.nearest(&self.vector_column, query.as_ref(), k)?;
            scan.nprobes(nprobes);
            scan.distance_metric(self.distance_type);
            scan.fast_search();
            if refine {
                scan.refine(1);
            }
            let plan = scan.create_plan().await?;
            Ok((plan, None))
        }
        LsmDataSource::FlushedMemTable { path, .. } => {
            let flushed = open_flushed_dataset(
                path,
                self.session.as_ref(),
                self.flushed_cache.as_ref(),
                self.warmer.as_ref(),
            )
            .await?;
            let mut scan = flushed.scan();
            let columns =
                build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
            let column_refs = columns.iter().map(|name| name.as_str()).collect::<Vec<_>>();
            scan.project(&column_refs)?;
            let query = single_query_array(query_vector);
            scan.nearest(&self.vector_column, query.as_ref(), k)?;
            scan.nprobes(nprobes);
            scan.distance_metric(self.distance_type);
            scan.fast_search();
            let plan = scan.create_plan().await?;
            Ok((plan, None))
        }
        LsmDataSource::ActiveMemTable {
            batch_store,
            index_store,
            schema,
            ..
        } => {
            use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
            use arrow_array::Array;
            let mut scan =
                MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
            let columns =
                build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
            let column_refs = columns.iter().map(|name| name.as_str()).collect::<Vec<_>>();
            scan.project(&column_refs);
            // Row ids are always requested here; the caller's filter node is keyed on them.
            scan.with_row_id();
            // The memtable scanner receives the whole query array, not an extracted row.
            let query: Arc<dyn Array> = Arc::new(query_vector.clone());
            scan.nearest(&self.vector_column, query, k);
            scan.nprobes(nprobes);
            scan.distance_metric(self.distance_type);
            let plan = scan.create_plan().await?;
            // Read only after the plan exists, matching the original call order.
            let max_visible = scan.max_visible_batch_position();
            Ok((plan, Some(max_visible)))
        }
    }
}
/// Returns a plan that produces no rows but carries the canonical output
/// schema (distance column included) for `projection`.
fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
    let output_schema =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, true);
    let plan: Arc<dyn ExecutionPlan> = Arc::new(
        datafusion::physical_plan::empty::EmptyExec::new(output_schema),
    );
    Ok(plan)
}
}
/// Wraps `plan` in a `SortExec` that orders rows by ascending distance
/// (NULLs last) and keeps at most `k` rows.
///
/// Returns an invalid-input error when `plan`'s schema has no distance column.
fn sort_by_distance(plan: Arc<dyn ExecutionPlan>, k: usize) -> Result<Arc<dyn ExecutionPlan>> {
    let input_schema = plan.schema();
    let distance_idx = input_schema.index_of(DISTANCE_COLUMN).map_err(|_| {
        lance_core::Error::invalid_input(format!(
            "Column '{}' not found in schema",
            DISTANCE_COLUMN
        ))
    })?;

    let ascending_nulls_last = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let by_distance = PhysicalSortExpr {
        expr: Arc::new(Column::new(DISTANCE_COLUMN, distance_idx)),
        options: ascending_nulls_last,
    };
    let ordering = LexOrdering::new(vec![by_distance])
        .ok_or_else(|| lance_core::Error::internal("Failed to create LexOrdering".to_string()))?;

    let top_k = SortExec::new(ordering, plan).with_fetch(Some(k));
    Ok(Arc::new(top_k))
}
fn single_query_array(query_vector: &FixedSizeListArray) -> arrow_array::ArrayRef {
use arrow_array::Array;
if query_vector.len() == 1 {
query_vector.value(0)
} else {
std::sync::Arc::new(query_vector.clone()) as arrow_array::ArrayRef
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::{Dataset, WriteParams};
use arrow_array::{
Int32Array, RecordBatch, RecordBatchIterator, builder::FixedSizeListBuilder,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use std::collections::HashMap;
/// Test schema: a non-nullable Int32 `id` tagged as the unenforced primary
/// key, plus a non-nullable 4-wide Float32 fixed-size-list `vector`.
fn create_vector_schema() -> Arc<ArrowSchema> {
    let pk_metadata = HashMap::from([(
        "lance-schema:unenforced-primary-key".to_string(),
        "true".to_string(),
    )]);
    let vector_type =
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4);
    let fields = vec![
        Field::new("id", DataType::Int32, false).with_metadata(pk_metadata),
        Field::new("vector", vector_type, false),
    ];
    Arc::new(ArrowSchema::new(fields))
}
/// Single-row query vector `[0.1, 0.2, 0.3, 0.4]`.
fn create_query_vector() -> FixedSizeListArray {
    use arrow_array::builder::Float32Builder;
    let mut query = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for component in [0.1, 0.2, 0.3, 0.4] {
        query.values().append_value(component);
    }
    query.append(true);
    query.finish()
}
/// Builds a batch with one row per id; the vector for `id` is
/// `[b, b + 0.1, b + 0.2, b + 0.3]` where `b = id * 0.1`.
fn create_test_batch(schema: &ArrowSchema, ids: &[i32]) -> RecordBatch {
    use arrow_array::builder::Float32Builder;
    let mut vectors = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for &id in ids {
        let base = id as f32 * 0.1;
        for component in [base, base + 0.1, base + 0.2, base + 0.3] {
            vectors.values().append_value(component);
        }
        vectors.append(true);
    }
    let id_column = Int32Array::from(ids.to_vec());
    let vector_column = vectors.finish();
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(vector_column)],
    )
    .unwrap()
}
/// Writes `batches` to a new dataset at `uri` and, when the schema has an
/// `id` column, also writes the primary-key sidecar for it.
///
/// Panics if `batches` is empty or any write fails.
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
    use crate::dataset::mem_wal::scanner::block_list::write_pk_sidecar;

    let schema = batches[0].schema();
    let needs_sidecar = schema.column_with_name("id").is_some();
    // The reader consumes a clone so `batches` stays available for the sidecar.
    let reader = RecordBatchIterator::new(batches.clone().into_iter().map(Ok), schema);
    let params = Some(WriteParams::default());
    let dataset = Dataset::write(reader, uri, params).await.unwrap();
    if !needs_sidecar {
        return dataset;
    }
    write_pk_sidecar(uri, &batches, &["id"]).await.unwrap();
    dataset
}
/// Planning over a base table alone (no shard snapshots, no in-memory
/// memtables) must succeed. Only plan construction is checked; the plan is
/// not executed.
#[tokio::test]
async fn test_vector_search_plan_structure() {
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1, 2, 3]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
// Empty snapshot list: the base table is the only source.
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema.clone(),
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner.plan_search(&query, 10, 8, None, false, 1.0).await;
plan.expect("planner should produce a plan even when memtables are empty");
}
/// `build_scanner_projection` must add the primary-key column (`id`) when the
/// user projection names only `vector`.
#[tokio::test]
async fn test_projection_includes_pk() {
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
// The planner is constructed but never used; the assertions below exercise
// the projection helper directly.
let _planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema.clone(),
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let cols =
build_scanner_projection(Some(&["vector".to_string()]), &schema, &["id".to_string()]);
assert!(cols.contains(&"vector".to_string()));
assert!(cols.contains(&"id".to_string()));
}
/// Searches a base table (ids 10, 20, 30) plus an active memtable (ids 1..=4)
/// with no projection and checks the output: `_distance` is present, `_rowid`
/// and the memtable-generation column are absent, and the first row is the
/// exact match `id = 1` at near-zero distance.
#[tokio::test]
async fn test_vector_search_base_plus_active_returns_distance() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10, 20, 30]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
// Active memtable: a PK index on `id` plus an HNSW index on `vector`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 3, 1, None, false, 1.0)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
let out_cols: Vec<String> = out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
assert!(
out_schema.field_with_name(DISTANCE_COLUMN).is_ok(),
"output schema is missing `_distance` column. Got: {:?}",
out_cols
);
// No projection was given, so the row id and internal columns must not surface.
assert!(
out_schema.field_with_name("_rowid").is_err(),
"`_rowid` leaked into output: {:?}",
out_cols
);
assert!(
out_schema
.field_with_name(super::super::exec::MEMTABLE_GEN_COLUMN)
.is_err(),
"`_memtable_gen` leaked into output: {:?}",
out_cols
);
let id_col = batches[0]
.column_by_name("id")
.expect("id column missing")
.as_any()
.downcast_ref::<Int32Array>()
.expect("id column should be Int32");
let dist_col = batches[0]
.column_by_name(DISTANCE_COLUMN)
.expect("_distance column missing")
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.expect("_distance column should be Float32");
// `create_test_batch` gives id 1 the vector [0.1, 0.2, 0.3, 0.4], which is the query.
assert_eq!(id_col.value(0), 1, "expected id=1 as nearest neighbor");
assert!(
dist_col.value(0).abs() < 1e-3,
"expected near-zero distance for self-match, got {}",
dist_col.value(0)
);
}
/// With a user projection of only `vector`, the executed plan's output must
/// still contain the primary key `id`, the projected `vector`, and `_distance`.
#[tokio::test]
async fn test_vector_search_with_projection_returns_distance_and_pk() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
// Active memtable: a PK index on `id` plus an HNSW index on `vector`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
// The projection deliberately omits the primary key.
let projection = vec!["vector".to_string()];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), false, 1.0)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
assert!(
out_schema.field_with_name("id").is_ok(),
"PK column `id` should be auto-included even when user projects only `vector`"
);
assert!(out_schema.field_with_name("vector").is_ok());
assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok());
}
/// A projection that explicitly names `_distance` and `_rowid` must plan and
/// execute. `_distance` must appear exactly once in the output, `vector`,
/// `id` and `_rowid` must be present, and the first output row's `_rowid`
/// must be NULL.
#[tokio::test]
async fn test_vector_search_projection_with_explicit_distance_and_rowid() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
// Active memtable: a PK index on `id` plus an HNSW index on `vector`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
// System columns are named explicitly, in user-chosen positions.
let projection = vec![
"_distance".to_string(),
"vector".to_string(),
"_rowid".to_string(),
];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), false, 1.0)
.await
.expect(
"planner must accept `_distance`/`_rowid` in projection without breaking the plan",
);
let ctx = SessionContext::new();
let stream = plan
.execute(0, ctx.task_ctx())
.expect("plan must execute when `_distance`/`_rowid` are in projection");
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
// Naming `_distance` in the projection must not duplicate the column.
let distance_count = out_schema
.fields()
.iter()
.filter(|f| f.name() == DISTANCE_COLUMN)
.count();
assert_eq!(
distance_count,
1,
"`_distance` must appear exactly once in output, got schema: {:?}",
out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<Vec<_>>()
);
assert!(out_schema.field_with_name("vector").is_ok());
assert!(out_schema.field_with_name("id").is_ok());
assert!(out_schema.field_with_name("_rowid").is_ok());
// Only the first output row is inspected here.
let rowid = batches[0].column_by_name("_rowid").unwrap();
assert!(
rowid.is_null(0),
"active-memtable `_rowid` must be NULL (not a real Lance row id), got: {:?}",
rowid
);
}
/// Checks the plan shape and the executed output for base + active sources:
/// the rendered plan must contain `NewestPkFilterExec` and
/// `SortPreservingMergeExec` but no `LsmGlobalPkDedupExec` / `LsmSourceTagExec`;
/// the output must carry `_distance`, must not expose the internal columns,
/// and must include at least one row from the active memtable (ids 1..=4).
#[tokio::test]
async fn test_vector_search_strips_internal_columns_and_preserves_active_rows() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[10]);
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
// Active memtable: a PK index on `id` plus an HNSW index on `vector`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let batch = create_test_batch(&schema, &[1, 2, 3, 4]);
batch_store.append(batch.clone()).unwrap();
index_store
.insert_with_batch_position(&batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = uuid::Uuid::new_v4();
let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 3, 1, None, false, 1.0)
.await
.expect("planner should produce a plan");
// Inspect the rendered plan tree for the expected / forbidden node names.
let plan_str = format!(
"{}",
datafusion::physical_plan::displayable(plan.as_ref()).indent(true)
);
assert!(
!plan_str.contains("LsmGlobalPkDedupExec") && !plan_str.contains("LsmSourceTagExec"),
"vector plan must not contain a global PK dedup or source tag node, got:\n{}",
plan_str
);
assert!(
plan_str.contains("NewestPkFilterExec") && plan_str.contains("SortPreservingMergeExec"),
"expected per-arm dedup + distance merge, got:\n{}",
plan_str
);
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(total > 0, "expected at least one result row");
let out_schema = batches[0].schema();
assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok());
for internal in [super::super::exec::MEMTABLE_GEN_COLUMN, "_freshness"] {
assert!(
out_schema.field_with_name(internal).is_err(),
"`{}` leaked into output: {:?}",
internal,
out_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<Vec<_>>(),
);
}
// Gather every returned id across all output batches.
let mut all_ids: Vec<i32> = Vec::new();
for batch in &batches {
let id_col = batch
.column_by_name("id")
.expect("id column missing")
.as_any()
.downcast_ref::<Int32Array>()
.expect("id column should be Int32");
for i in 0..batch.num_rows() {
all_ids.push(id_col.value(i));
}
}
assert!(
all_ids.iter().any(|&id| (1..=4).contains(&id)),
"expected at least one active-memtable row (id in 1..=4) — none found, so \
active partitions were silently dropped. Got ids: {:?}",
all_ids
);
}
/// The same primary key (`id = 1`) exists in a flushed generation (gen 1,
/// vector `[9, 9, 9, 9]`) and in the active memtable (gen 2, vector equal to
/// the query). With no base table, the merged result must contain `id = 1`
/// exactly once.
#[tokio::test]
async fn test_vector_search_dedup_across_generations() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_path = temp_dir.path().to_str().unwrap();
let base_uri = format!("{}/base", base_path);
let shard_id = uuid::Uuid::new_v4();
// Flushed generation 1 is written under `<base>/_mem_wal/<shard>/gen_1`.
let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
let old_pk1 = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]);
create_dataset(&gen1_uri, vec![old_pk1]).await;
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
// The active memtable holds a second version of pk=1 and an unrelated pk=2.
let new_pk1 = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]);
let other = create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]);
let (_, _, bp1) = batch_store.append(new_pk1.clone()).unwrap();
index_store
.insert_with_batch_position(&new_pk1, 0, Some(bp1))
.unwrap();
let (_, _, bp2) = batch_store.append(other.clone()).unwrap();
index_store
.insert_with_batch_position(&other, 1, Some(bp2))
.unwrap();
let index_store = Arc::new(index_store);
let shard_snapshot = ShardSnapshot::new(shard_id)
.with_current_generation(2)
.with_flushed_generation(1, "gen_1".to_string());
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![shard_snapshot])
.with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 2,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let plan = planner
.plan_search(&query, 5, 1, None, false, 1.0)
.await
.unwrap();
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
// Flatten the `id` column of every output batch.
let ids: Vec<i32> = batches
.iter()
.flat_map(|b| {
b.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.values()
.to_vec()
})
.collect();
let pk1_count = ids.iter().filter(|i| **i == 1).count();
assert_eq!(
pk1_count, 1,
"pk=1 must appear exactly once in the merged top-k; got ids={:?}",
ids,
);
}
/// Searches three sources with one row each (base id 1, flushed gen 1 id 2,
/// active gen 2 id 3) while projecting `_rowid` and `_rowaddr`. Expects three
/// rows: `_rowid` is non-NULL only for the base row, and `_rowaddr` is NULL
/// for every row.
#[tokio::test]
async fn test_vector_search_system_columns_real_only_for_base() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use crate::index::DatasetIndexExt;
use crate::index::vector::VectorIndexParams;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use lance_index::IndexType;
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1]);
let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
// Both on-disk datasets get an IVF_FLAT vector index on `vector`.
let ivf_flat = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);
base_dataset
.create_index(&["vector"], IndexType::Vector, None, &ivf_flat, true)
.await
.unwrap();
let base_dataset = Arc::new(base_dataset);
let shard_id = uuid::Uuid::new_v4();
let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
let gen1_batch = create_test_batch(&schema, &[2]);
let mut gen1_dataset = create_dataset(&gen1_uri, vec![gen1_batch]).await;
gen1_dataset
.create_index(&["vector"], IndexType::Vector, None, &ivf_flat, true)
.await
.unwrap();
// Active memtable: a PK index on `id` plus an HNSW index on `vector`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.enable_pk_index(&[("id".to_string(), 0)]);
index_store.add_hnsw(
"vector_hnsw".to_string(),
1,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
64,
8,
);
let active_batch = create_test_batch(&schema, &[3]);
batch_store.append(active_batch.clone()).unwrap();
index_store
.insert_with_batch_position(&active_batch, 0, Some(0))
.unwrap();
let index_store = Arc::new(index_store);
let shard_snapshot = ShardSnapshot::new(shard_id)
.with_current_generation(2)
.with_flushed_generation(1, "gen_1".to_string());
let collector = LsmDataSourceCollector::new(base_dataset, vec![shard_snapshot])
.with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 2,
},
frozen: vec![],
},
);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let query = create_query_vector();
let projection = vec![
"id".to_string(),
"_rowid".to_string(),
"_rowaddr".to_string(),
"vector".to_string(),
];
let plan = planner
.plan_search(&query, 3, 1, Some(&projection), false, 1.0)
.await
.expect("planner should produce a plan");
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 3, "expected one row per source");
// Maps id -> (`_rowid` is NULL, `_rowaddr` is NULL).
let mut seen: std::collections::HashMap<i32, (bool, bool)> =
std::collections::HashMap::new();
for batch in &batches {
let ids = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let rowid = batch.column_by_name("_rowid").unwrap();
let rowaddr = batch.column_by_name("_rowaddr").unwrap();
for i in 0..batch.num_rows() {
seen.insert(ids.value(i), (rowid.is_null(i), rowaddr.is_null(i)));
}
}
// Base-table row.
let (rid_null, raddr_null) = seen.get(&1).expect("base row id=1 missing");
assert!(
!rid_null,
"base row `_rowid` must be real (Lance row id), got NULL"
);
assert!(
raddr_null,
"`_rowaddr` is incompatible with vector_search's fast_search; must be NULL"
);
// Flushed-memtable row.
let (rid_null, raddr_null) = seen.get(&2).expect("flushed row id=2 missing");
assert!(rid_null, "flushed row `_rowid` must be NULL");
assert!(raddr_null, "flushed row `_rowaddr` must be NULL");
// Active-memtable row.
let (rid_null, raddr_null) = seen.get(&3).expect("active row id=3 missing");
assert!(rid_null, "active row `_rowid` must be NULL");
assert!(raddr_null, "active row `_rowaddr` must be NULL");
}
/// With no sources at all, planning must still succeed when the projection
/// names system columns. The plan's schema must keep the user's column order
/// and then append the primary key and `_distance`.
#[tokio::test]
async fn test_vector_search_empty_plan_with_system_columns() {
let schema = create_vector_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
// No base table and no shard snapshots: the collector yields nothing.
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]);
let planner = LsmVectorSearchPlanner::new(
collector,
vec!["id".to_string()],
schema,
"vector".to_string(),
lance_linalg::distance::DistanceType::L2,
);
let projection = vec![
"_rowid".to_string(),
"vector".to_string(),
"_rowaddr".to_string(),
];
let query = create_query_vector();
let plan = planner
.plan_search(&query, 5, 1, Some(&projection), false, 1.0)
.await
.expect("empty plan must accept system columns in projection");
// Only the schema is inspected; the plan is not executed.
let names: Vec<String> = plan
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
assert_eq!(
names,
vec![
"_rowid".to_string(),
"vector".to_string(),
"_rowaddr".to_string(),
"id".to_string(), "_distance".to_string(), ],
"empty KNN plan must honor user position for system cols and append PK + _distance"
);
}
/// A planner built without a base table and without any shard data must
/// still plan and execute, must not mention `base/data` in the rendered plan,
/// and must yield zero rows.
#[tokio::test]
async fn test_vector_search_without_base_table() {
    use futures::TryStreamExt;

    let vector_schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());
    let planner = LsmVectorSearchPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![]),
        vec!["id".to_string()],
        vector_schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 10, 8, None, false, 1.0)
        .await
        .expect("planner should produce a plan without a base table");

    // Render the physical plan and make sure no base-table path shows up.
    let rendered = datafusion::physical_plan::displayable(plan.as_ref())
        .indent(true)
        .to_string();
    assert!(
        !rendered.contains("base/data"),
        "Plan must not scan base table, got: {}",
        rendered
    );

    let session = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .expect("plan should execute without a base table")
        .try_collect()
        .await
        .expect("collecting batches should succeed");
    let row_count = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.num_rows());
    assert_eq!(
        row_count, 0,
        "fresh tier with no sources should yield no rows"
    );
}
/// Builds a single-row batch for `schema`: an `Int32` column holding `id`
/// followed by a 4-wide fixed-size list of `f32` holding `vector`.
///
/// Panics if the columns do not match `schema`.
fn create_test_batch_with_vector(
    schema: &ArrowSchema,
    id: i32,
    vector: [f32; 4],
) -> RecordBatch {
    use arrow_array::builder::Float32Builder;

    let mut list_builder = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    // Push all four components, then close them off as one non-null list entry.
    list_builder.values().append_slice(&vector);
    list_builder.append(true);

    let id_column = Int32Array::from(vec![id]);
    let vector_column = list_builder.finish();
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(vector_column)],
    )
    .unwrap()
}
/// Two versions of pk=1 written to the same active memtable must collapse to
/// one result row: the rendered plan has to contain `NewestPkFilterExec` and
/// the output must list pk=1 exactly once.
#[tokio::test]
async fn test_vector_search_dedup_within_active_memtable() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;

    let schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // Written in order: the old pk=1, the new pk=1, then an unrelated pk=2.
    // Each entry is (row offset passed to the index, one-row batch).
    let writes = [
        (
            0,
            create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]),
        ),
        (
            1,
            create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]),
        ),
        (
            2,
            create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]),
        ),
    ];
    for (row_offset, batch) in &writes {
        // NOTE(review): this takes the third element of `append`'s tuple as the
        // batch position, whereas
        // `test_vector_search_active_stale_update_out_of_neighborhood` uses the
        // first element — confirm which one is intended.
        let (_, _, batch_position) = batch_store.append(batch.clone()).unwrap();
        index_store
            .insert_with_batch_position(batch, *row_offset, Some(batch_position))
            .unwrap();
    }
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(
            shard_id,
            InMemoryMemTables {
                active,
                frozen: vec![],
            },
        );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 5, 1, None, false, 1.0)
        .await
        .unwrap();
    let rendered = datafusion::physical_plan::displayable(plan.as_ref())
        .indent(true)
        .to_string();
    assert!(
        rendered.contains("NewestPkFilterExec"),
        "active vector arm must self-dedup, got:\n{}",
        rendered
    );

    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    // Gather every returned primary key across all output batches.
    let mut ids: Vec<i32> = Vec::new();
    for batch in &batches {
        let id_column = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        ids.extend(id_column.values().iter().copied());
    }
    let pk1_count = ids.iter().filter(|&&id| id == 1).count();
    assert_eq!(
        pk1_count, 1,
        "pk=1 must appear exactly once after within-source dedup; got ids={:?}",
        ids,
    );
}
/// pk=1 is first written together with five filler rows, then rewritten with
/// a far-away vector in a later batch of the same active memtable. The first
/// pk=1 vector is presumably the query point (see the assertion message), so
/// the stale copy must not be returned at distance ~0.
#[tokio::test]
async fn test_vector_search_active_stale_update_out_of_neighborhood() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;

    let schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // First write: the soon-to-be-stale pk=1 plus five fillers close to it.
    let query_point = [0.1, 0.2, 0.3, 0.4];
    let first_write = batch_rows(
        &schema,
        &[
            (1, query_point),
            (10, [0.11, 0.21, 0.31, 0.41]),
            (11, [0.13, 0.23, 0.33, 0.43]),
            (12, [0.15, 0.25, 0.35, 0.45]),
            (13, [0.17, 0.27, 0.37, 0.47]),
            (14, [0.19, 0.29, 0.39, 0.49]),
        ],
    );
    let (batch_position, row_offset, _) = batch_store.append(first_write.clone()).unwrap();
    index_store
        .insert_with_batch_position(&first_write, row_offset, Some(batch_position))
        .unwrap();

    // Second write: the live pk=1, far from the first one.
    let second_write = batch_rows(&schema, &[(1, [9.0, 9.0, 9.0, 9.0])]);
    let (batch_position, row_offset, _) = batch_store.append(second_write.clone()).unwrap();
    index_store
        .insert_with_batch_position(&second_write, row_offset, Some(batch_position))
        .unwrap();
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(
            shard_id,
            InMemoryMemTables {
                active,
                frozen: vec![],
            },
        );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 3, 1, None, false, 1.0)
        .await
        .unwrap();
    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let rows = collect_id_dist(&batches);

    // No returned row may be pk=1 at (near-)zero distance.
    let stale_absent = rows
        .iter()
        .all(|&(id, d)| !(id == 1 && d.abs() < 1e-3));
    assert!(
        stale_absent,
        "stale near pk=1 leaked: its live vector is far from the query, so it \
         must not appear at distance ~0. results={:?}",
        rows
    );
}
/// Regression test for a stale read across sources.
///
/// The indexed base table holds pk=1; the active memtable (generation 1)
/// rewrites pk=1 with a far-away vector and adds pk=2. With `k = 1` the fresh
/// pk=1 cannot make the active arm's top-k, so the stale base copy must be
/// suppressed by other means and the single result must be pk=2.
///
/// The same check is repeated with an overfetch factor of `0.0` to pin that
/// the suppression does not depend on overfetching.
#[tokio::test]
async fn test_vector_search_stale_read_when_fresh_falls_out_of_top_k() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use crate::index::DatasetIndexExt;
    use crate::index::vector::VectorIndexParams;
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;
    use lance_index::IndexType;

    let schema = create_vector_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    // Base table: the soon-to-be-stale pk=1, with a vector index on `vector`.
    let base_batch = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]);
    let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
    let ivf_flat = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);
    base_dataset
        .create_index(&["vector"], IndexType::Vector, None, &ivf_flat, true)
        .await
        .unwrap();
    let base_dataset = Arc::new(base_dataset);

    // Active memtable: the live pk=1 (far away) and an unrelated pk=2.
    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );
    let fresh_pk1 = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]);
    let pk2 = create_test_batch_with_vector(&schema, 2, [1.0, 1.0, 1.0, 1.0]);
    let (_, _, bp1) = batch_store.append(fresh_pk1.clone()).unwrap();
    index_store
        .insert_with_batch_position(&fresh_pk1, 0, Some(bp1))
        .unwrap();
    let (_, _, bp2) = batch_store.append(pk2.clone()).unwrap();
    index_store
        .insert_with_batch_position(&pk2, 1, Some(bp2))
        .unwrap();
    let index_store = Arc::new(index_store);

    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
        shard_id,
        InMemoryMemTables {
            active: InMemoryMemTableRef {
                batch_store,
                index_store,
                schema: schema.clone(),
                generation: 1,
            },
            frozen: vec![],
        },
    );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );
    let query = create_query_vector();

    // First pass: k = 1 with an overfetch factor of 1.0.
    let plan = planner
        .plan_search(&query, 1, 1, None, false, 1.0)
        .await
        .unwrap();
    let ctx = SessionContext::new();
    let stream = plan.execute(0, ctx.task_ctx()).unwrap();
    let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
    // Use the shared helper rather than a second hand-written extraction loop,
    // so both passes in this test read `(id, _distance)` the same way.
    let rows = collect_id_dist(&batches);
    assert!(
        rows.iter().all(|&(id, d)| !(id == 1 && d.abs() < 1e-3)),
        "stale read: pk=1 was updated to a far vector in gen 1, but the \
         stale base-table copy (distance ~0) was served because fresh \
         pk=1 fell out of the active arm's top-k and never deduped it; \
         got {:?}",
        rows
    );
    assert_eq!(
        rows.len(),
        1,
        "k=1 must return exactly one row, got {:?}",
        rows
    );
    assert_eq!(
        rows[0].0, 2,
        "expected nearest live neighbor pk=2, got {:?}",
        rows
    );

    // Second pass: identical query, but with an overfetch factor of 0.0.
    let still_filtered = planner
        .plan_search(&query, 1, 1, None, false, 0.0)
        .await
        .unwrap();
    let still_filtered_rows = {
        let stream = still_filtered
            .execute(0, SessionContext::new().task_ctx())
            .unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
        collect_id_dist(&batches)
    };
    assert!(
        still_filtered_rows
            .iter()
            .all(|&(id, d)| !(id == 1 && d.abs() < 1e-3)),
        "block-list is unconditional: stale pk=1 must stay suppressed even \
         with overfetch_factor < 1.0; got {:?}",
        still_filtered_rows
    );
}
/// Builds a multi-row batch for `schema` from `(id, vector)` pairs: an `Int32`
/// id column followed by a 4-wide fixed-size list of `f32`.
///
/// Panics if the columns do not match `schema`.
fn batch_rows(schema: &ArrowSchema, rows: &[(i32, [f32; 4])]) -> RecordBatch {
    use arrow_array::builder::Float32Builder;

    let mut id_values: Vec<i32> = Vec::with_capacity(rows.len());
    let mut vectors = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for (id, components) in rows {
        id_values.push(*id);
        // Four child values followed by one non-null list slot per row.
        vectors.values().append_slice(components);
        vectors.append(true);
    }

    let id_column = Int32Array::from(id_values);
    let vector_column = vectors.finish();
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(vector_column)],
    )
    .unwrap()
}
/// Flattens result batches into `(id, distance)` pairs, in batch order and
/// then row order.
///
/// Panics if a batch lacks an `Int32` `id` column or a `Float32`
/// `DISTANCE_COLUMN` column.
fn collect_id_dist(batches: &[RecordBatch]) -> Vec<(i32, f32)> {
    let mut pairs = Vec::new();
    for batch in batches {
        let id_column = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let distance_column = batch
            .column_by_name(DISTANCE_COLUMN)
            .unwrap()
            .as_any()
            .downcast_ref::<arrow_array::Float32Array>()
            .unwrap();
        pairs.extend(
            (0..batch.num_rows()).map(|row| (id_column.value(row), distance_column.value(row))),
        );
    }
    pairs
}
/// The indexed base table holds pk 1..=3 at one point and pk 4..=6 slightly
/// further out; the active memtable rewrites pk 1..=3 with far-away vectors.
/// With `k = 3` and an overfetch factor of 2.5 the result must be exactly
/// pk {4, 5, 6}, and no superseded base row may appear at distance ~0.
#[tokio::test]
async fn test_vector_search_overfetch_backfills_when_top_k_all_stale() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use crate::index::DatasetIndexExt;
    use crate::index::vector::VectorIndexParams;
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;
    use lance_index::IndexType;

    let schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    // Base table: ids 1..=3 share one vector, ids 4..=6 share a nearby one.
    let at_query: [f32; 4] = [0.1, 0.2, 0.3, 0.4];
    let close_by: [f32; 4] = [0.12, 0.22, 0.32, 0.42];
    let base_rows: Vec<(i32, [f32; 4])> = (1..=3)
        .map(|id| (id, at_query))
        .chain((4..=6).map(|id| (id, close_by)))
        .collect();
    let base_batch = batch_rows(&schema, &base_rows);
    let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
    let index_params = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);
    base_dataset
        .create_index(&["vector"], IndexType::Vector, None, &index_params, true)
        .await
        .unwrap();
    let base_dataset = Arc::new(base_dataset);

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // Active memtable: ids 1..=3 rewritten far away from their base vectors.
    let far_away: [f32; 4] = [9.0, 9.0, 9.0, 9.0];
    let rewrites: Vec<(i32, [f32; 4])> = (1..=3).map(|id| (id, far_away)).collect();
    let active_batch = batch_rows(&schema, &rewrites);
    let (_, _, batch_position) = batch_store.append(active_batch.clone()).unwrap();
    index_store
        .insert_with_batch_position(&active_batch, 0, Some(batch_position))
        .unwrap();
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
        shard_id,
        InMemoryMemTables {
            active,
            frozen: vec![],
        },
    );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 3, 1, None, false, 2.5)
        .await
        .unwrap();
    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let rows = collect_id_dist(&batches);

    assert_eq!(rows.len(), 3, "expected k=3 live results, got {:?}", rows);
    let returned_ids: std::collections::HashSet<i32> = rows.iter().map(|&(id, _)| id).collect();
    let live_ids = std::collections::HashSet::from([4, 5, 6]);
    assert_eq!(
        returned_ids, live_ids,
        "expected the next-nearest live rows {{4,5,6}}, got {:?}",
        rows
    );
    let no_stale_rows = rows
        .iter()
        .all(|&(id, d)| !((1..=3).contains(&id) && d.abs() < 1e-3));
    assert!(
        no_stale_rows,
        "stale read: a superseded base row was served; got {:?}",
        rows
    );
}
/// Two flushed generations and no base table: `gen_1` holds pk=1, `gen_2`
/// rewrites pk=1 far away and adds pk=2. With `k = 1` the only result must be
/// pk=2, i.e. the older flushed copy of pk=1 must not be served.
#[tokio::test]
async fn test_vector_search_flushed_superseded_by_newer_flushed() {
    use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
    use crate::index::DatasetIndexExt;
    use crate::index::vector::VectorIndexParams;
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;
    use lance_index::IndexType;

    let schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());
    let shard_id = uuid::Uuid::new_v4();
    let index_params = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);

    let at_query = [0.1, 0.2, 0.3, 0.4];
    let far_away = [9.0, 9.0, 9.0, 9.0];
    let close_by = [0.12, 0.22, 0.32, 0.42];

    // Older flushed generation: the copy of pk=1 that gets superseded.
    let older_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    let older_rows = batch_rows(&schema, &[(1, at_query)]);
    let mut older_gen = create_dataset(&older_uri, vec![older_rows]).await;
    older_gen
        .create_index(&["vector"], IndexType::Vector, None, &index_params, true)
        .await
        .unwrap();

    // Newer flushed generation: pk=1 moved far away, plus a nearby pk=2.
    let newer_uri = format!("{}/_mem_wal/{}/gen_2", base_uri, shard_id);
    let newer_rows = batch_rows(&schema, &[(1, far_away), (2, close_by)]);
    let mut newer_gen = create_dataset(&newer_uri, vec![newer_rows]).await;
    newer_gen
        .create_index(&["vector"], IndexType::Vector, None, &index_params, true)
        .await
        .unwrap();

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(3)
        .with_flushed_generation(1, "gen_1".to_string())
        .with_flushed_generation(2, "gen_2".to_string());
    let planner = LsmVectorSearchPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![snapshot]),
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 1, 1, None, false, 1.0)
        .await
        .unwrap();
    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let rows = collect_id_dist(&batches);

    assert_eq!(rows.len(), 1, "expected one result, got {:?}", rows);
    let (winner_id, _) = rows[0];
    assert_eq!(
        winner_id, 2,
        "expected nearest live row pk=2, got {:?}",
        rows
    );
}
/// Schema with a two-column key (`id1`, `id2`, both non-null `Int32`) and a
/// non-null `vector` column: a 4-wide fixed-size list of nullable `Float32`.
fn create_multicol_schema() -> Arc<ArrowSchema> {
    let list_item = Arc::new(Field::new("item", DataType::Float32, true));
    let fields = vec![
        Field::new("id1", DataType::Int32, false),
        Field::new("id2", DataType::Int32, false),
        Field::new("vector", DataType::FixedSizeList(list_item, 4), false),
    ];
    Arc::new(ArrowSchema::new(fields))
}
/// Builds a batch for `schema` from `((id1, id2), vector)` rows: two `Int32`
/// key columns followed by a 4-wide fixed-size list of `f32`.
///
/// Panics if the columns do not match `schema`.
fn multicol_batch(schema: &ArrowSchema, rows: &[((i32, i32), [f32; 4])]) -> RecordBatch {
    use arrow_array::builder::Float32Builder;

    let mut first_keys: Vec<i32> = Vec::with_capacity(rows.len());
    let mut second_keys: Vec<i32> = Vec::with_capacity(rows.len());
    let mut vectors = FixedSizeListBuilder::new(Float32Builder::new(), 4);
    for ((first, second), components) in rows {
        first_keys.push(*first);
        second_keys.push(*second);
        // Four child values followed by one non-null list slot per row.
        vectors.values().append_slice(components);
        vectors.append(true);
    }

    let id1_column = Int32Array::from(first_keys);
    let id2_column = Int32Array::from(second_keys);
    let vector_column = vectors.finish();
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(id1_column),
            Arc::new(id2_column),
            Arc::new(vector_column),
        ],
    )
    .unwrap()
}
/// Composite-key variant of the stale-read check: the indexed base table
/// holds key (1, 1); the active memtable rewrites (1, 1) far away and adds a
/// nearby (2, 2). With `k = 1` the only result must be (2, 2).
#[tokio::test]
async fn test_vector_search_stale_read_with_composite_pk() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use crate::index::DatasetIndexExt;
    use crate::index::vector::VectorIndexParams;
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;
    use lance_index::IndexType;

    let schema = create_multicol_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    let at_query = [0.1, 0.2, 0.3, 0.4];
    let far_away = [9.0, 9.0, 9.0, 9.0];
    let close_by = [0.12, 0.22, 0.32, 0.42];

    // Base table: the copy of key (1, 1) that gets superseded.
    let base_batch = multicol_batch(&schema, &[((1, 1), at_query)]);
    let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
    let index_params = VectorIndexParams::ivf_flat(1, lance_linalg::distance::DistanceType::L2);
    base_dataset
        .create_index(&["vector"], IndexType::Vector, None, &index_params, true)
        .await
        .unwrap();
    let base_dataset = Arc::new(base_dataset);

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id1".to_string(), 0), ("id2".to_string(), 1)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // Active memtable: key (1, 1) rewritten far away, plus a nearby (2, 2).
    let active_batch = multicol_batch(&schema, &[((1, 1), far_away), ((2, 2), close_by)]);
    let (_, _, batch_position) = batch_store.append(active_batch.clone()).unwrap();
    index_store
        .insert_with_batch_position(&active_batch, 0, Some(batch_position))
        .unwrap();
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]).with_in_memory_memtables(
        shard_id,
        InMemoryMemTables {
            active,
            frozen: vec![],
        },
    );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id1".to_string(), "id2".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 1, 1, None, false, 1.0)
        .await
        .unwrap();
    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    // Flatten the output into (id1, id2, distance) triples.
    let mut rows: Vec<(i32, i32, f32)> = Vec::new();
    for batch in &batches {
        let first_key = batch
            .column_by_name("id1")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let second_key = batch
            .column_by_name("id2")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let distance = batch
            .column_by_name(DISTANCE_COLUMN)
            .unwrap()
            .as_any()
            .downcast_ref::<arrow_array::Float32Array>()
            .unwrap();
        rows.extend(
            (0..batch.num_rows())
                .map(|row| (first_key.value(row), second_key.value(row), distance.value(row))),
        );
    }

    assert_eq!(rows.len(), 1, "expected one result, got {:?}", rows);
    let (winner_id1, winner_id2, _) = rows[0];
    assert_eq!(
        (winner_id1, winner_id2),
        (2, 2),
        "expected nearest live composite key (2,2), got {:?}",
        rows
    );
}
/// id=1 is written twice into the same active memtable: first with the
/// `on_query` vector (alongside id=2), then with a far-away vector. The result
/// must contain id=1 exactly once, at a distance above 1.0, and no row at all
/// at distance ~0.
#[tokio::test]
async fn test_vector_search_same_l0_override_newest_wins() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use datafusion::prelude::SessionContext;
    use futures::TryStreamExt;

    let schema = create_vector_schema();
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    let on_query = [0.1, 0.2, 0.3, 0.4];
    let far = [9.0, 9.0, 9.0, 9.0];
    let other = [1.0, 1.0, 1.0, 1.0];

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);
    index_store.add_hnsw(
        "vector_hnsw".to_string(),
        1,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
        64,
        8,
    );

    // Older batch: two rows, indexed from row offset 0.
    let older = batch_rows(&schema, &[(1, on_query), (2, other)]);
    let (_, _, older_position) = batch_store.append(older.clone()).unwrap();
    index_store
        .insert_with_batch_position(&older, 0, Some(older_position))
        .unwrap();
    // Newer batch: the override of id=1, indexed from row offset 2 (right
    // after the two rows above).
    let newer = batch_rows(&schema, &[(1, far)]);
    let (_, _, newer_position) = batch_store.append(newer.clone()).unwrap();
    index_store
        .insert_with_batch_position(&newer, 2, Some(newer_position))
        .unwrap();
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let shard_id = uuid::Uuid::new_v4();
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(
            shard_id,
            InMemoryMemTables {
                active,
                frozen: vec![],
            },
        );
    let planner = LsmVectorSearchPlanner::new(
        collector,
        vec!["id".to_string()],
        schema,
        "vector".to_string(),
        lance_linalg::distance::DistanceType::L2,
    );

    let query = create_query_vector();
    let plan = planner
        .plan_search(&query, 5, 1, None, false, 1.0)
        .await
        .unwrap();
    let session = SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, session.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let rows = collect_id_dist(&batches);

    // Distances of every returned row whose id is 1.
    let id1_distances: Vec<f32> = rows
        .iter()
        .filter_map(|&(id, d)| if id == 1 { Some(d) } else { None })
        .collect();
    assert_eq!(
        id1_distances.len(),
        1,
        "newest-wins: id=1 must appear exactly once after a same-L0 override, got {:?}",
        rows
    );
    let surviving_distance = id1_distances[0];
    assert!(
        surviving_distance > 1.0,
        "newest-wins: surviving id=1 must be the newer far vector, not the stale near one — got distance {}",
        surviving_distance
    );
    assert!(
        rows.iter().all(|&(_, d)| d.abs() >= 1e-3),
        "newest-wins: the stale on-query copy (distance ~0) must be excluded, got {:?}",
        rows
    );
}
}