use std::sync::Arc;
use arrow_schema::SchemaRef;
use datafusion::common::ScalarValue;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::prelude::Expr;
use lance_core::Result;
use lance_index::scalar::bloomfilter::sbbf::Sbbf;
use tracing::instrument;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{
BloomFilterGuardExec, CoalesceFirstExec, DedupDirection, WithinSourceDedupExec,
compute_pk_hash_from_scalars,
};
use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset};
use super::projection::{
build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical,
wants_row_address, wants_row_id,
};
use crate::session::Session;
/// Plans single-key ("point") lookups across the data sources reported by an
/// [`LsmDataSourceCollector`].
///
/// Configure it builder-style: start from [`LsmPointLookupPlanner::new`] and optionally
/// attach per-generation bloom filters, a session, and a flushed-memtable cache.
pub struct LsmPointLookupPlanner {
/// Supplies the data sources that `plan_lookup` scans.
collector: LsmDataSourceCollector,
/// Primary key column names; lookup values are paired with these positionally.
pk_columns: Vec<String>,
/// Schema handed to the projection helpers to derive scan columns and the output schema.
base_schema: SchemaRef,
/// Bloom filters keyed by source generation; a source whose generation has an entry is
/// wrapped in a `BloomFilterGuardExec` by `plan_lookup`.
bloom_filters: std::collections::HashMap<u64, Arc<Sbbf>>,
/// Optional session forwarded to `open_flushed_dataset`.
session: Option<Arc<Session>>,
/// Optional cache forwarded to `open_flushed_dataset`.
flushed_cache: Option<Arc<FlushedMemTableCache>>,
}
impl LsmPointLookupPlanner {
/// Creates a planner with no bloom filters, no session and no flushed-memtable cache.
///
/// * `collector` - provides the data sources scanned by each lookup.
/// * `pk_columns` - primary key column names, in the order lookup values are supplied.
/// * `base_schema` - schema used to derive scan projections and the output schema.
pub fn new(
    collector: LsmDataSourceCollector,
    pk_columns: Vec<String>,
    base_schema: SchemaRef,
) -> Self {
    // Optional pieces start out unset; the `with_*` builders fill them in.
    let bloom_filters = Default::default();
    let session = None;
    let flushed_cache = None;
    Self {
        collector,
        pk_columns,
        base_schema,
        bloom_filters,
        session,
        flushed_cache,
    }
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<FlushedMemTableCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
/// Registers `bloom_filter` for the source with the given `generation`.
///
/// A filter already registered for the same generation is replaced.
pub fn with_bloom_filter(self, generation: u64, bloom_filter: Arc<Sbbf>) -> Self {
    // A single registration is just the one-element case of the bulk builder.
    self.with_bloom_filters(std::iter::once((generation, bloom_filter)))
}
/// Registers several `(generation, bloom filter)` pairs at once.
///
/// Later pairs win over earlier ones, and over existing entries, for the same generation.
pub fn with_bloom_filters(
    mut self,
    bloom_filters: impl IntoIterator<Item = (u64, Arc<Sbbf>)>,
) -> Self {
    for (generation, filter) in bloom_filters {
        self.bloom_filters.insert(generation, filter);
    }
    self
}
/// Builds the physical plan that looks up the row whose primary key equals `pk_values`.
///
/// Every collected source is scanned with a primary-key equality filter and limited to a
/// single row. If a bloom filter is registered for the source's generation, that plan is
/// additionally wrapped in a `BloomFilterGuardExec`. Sources are ordered by descending
/// generation; one source yields its plan directly, several are combined under
/// `CoalesceFirstExec`. With no sources at all an empty plan is returned.
///
/// # Errors
///
/// Returns an invalid-input error when the number of `pk_values` differs from the number
/// of primary key columns, and propagates errors from filter construction, source
/// collection and per-source scan planning.
#[instrument(name = "lsm_point_lookup", level = "debug", skip_all, fields(pk_column_count = self.pk_columns.len()))]
pub async fn plan_lookup(
    &self,
    pk_values: &[ScalarValue],
    projection: Option<&[String]>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let expected = self.pk_columns.len();
    let actual = pk_values.len();
    if expected != actual {
        return Err(lance_core::Error::invalid_input(format!(
            "Expected {} primary key values, got {}",
            expected, actual
        )));
    }

    let pk_hash = compute_pk_hash_from_scalars(pk_values);
    let filter_expr = self.build_pk_filter_expr(pk_values)?;

    let collected = self.collector.collect()?;
    if collected.is_empty() {
        return self.empty_plan(projection);
    }

    // Highest generation first, so the combined plan consults newer sources before older ones.
    let mut ordered: Vec<_> = collected.into_iter().collect();
    ordered.sort_by(|a, b| b.generation().cmp(&a.generation()));

    let mut per_source = Vec::new();
    for source in ordered {
        let generation = source.generation().as_u64();
        let scan = self
            .build_source_scan(&source, projection, &filter_expr)
            .await?;
        // skip = 0, fetch = 1: at most one row per source.
        let first_row: Arc<dyn ExecutionPlan> =
            Arc::new(GlobalLimitExec::new(scan, 0, Some(1)));
        let plan = match self.bloom_filters.get(&generation) {
            Some(bf) => Arc::new(BloomFilterGuardExec::new(
                first_row,
                Arc::clone(bf),
                pk_hash,
                generation,
            )) as Arc<dyn ExecutionPlan>,
            None => first_row,
        };
        per_source.push(plan);
    }

    if per_source.len() == 1 {
        // A lone source needs no coalescing wrapper.
        return Ok(per_source.remove(0));
    }
    Ok(Arc::new(CoalesceFirstExec::new(per_source)))
}
fn build_pk_filter_expr(&self, pk_values: &[ScalarValue]) -> Result<Expr> {
use datafusion::prelude::{col, lit};
let mut expr: Option<Expr> = None;
for (col_name, value) in self.pk_columns.iter().zip(pk_values.iter()) {
let eq_expr = col(col_name.as_str()).eq(lit(value.clone()));
expr = Some(match expr {
Some(e) => e.and(eq_expr),
None => eq_expr,
});
}
expr.ok_or_else(|| lance_core::Error::invalid_input("No primary key columns specified"))
}
/// Builds the scan plan for one LSM data source and projects it onto the canonical
/// output schema.
///
/// The scan reads the columns chosen by `build_scanner_projection`, applies `filter`, and
/// is finally reshaped by `project_to_canonical` against the schema returned by
/// `canonical_output_schema`.
///
/// Per-source handling:
/// * `BaseTable` and `FlushedMemTable` are both scanned through the dataset scanner.
///   `_rowid` / `_rowaddr` are requested from the scanner only when the projection asks
///   for them (`wants_row_id` / `wants_row_address`).
/// * `ActiveMemTable` always scans with the row id, because `WithinSourceDedupExec`
///   (configured with `DedupDirection::KeepMaxRowAddr`) deduplicates on it per primary
///   key. The row id column is then passed through `null_columns` before projection.
///
/// # Errors
///
/// Propagates errors from opening a flushed dataset, configuring a scanner, creating the
/// scan plan, and the final projection.
async fn build_source_scan(
    &self,
    source: &LsmDataSource,
    projection: Option<&[String]>,
    filter: &Expr,
) -> Result<Arc<dyn ExecutionPlan>> {
    let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
    // Borrowed view of the column names, built once and shared by every arm below.
    let col_refs = cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();
    let target =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, false);
    let want_row_id = wants_row_id(projection);
    let want_row_addr = wants_row_address(projection);

    let scan: Arc<dyn ExecutionPlan> = match source {
        LsmDataSource::BaseTable { dataset } => {
            let mut scanner = dataset.scan();
            scanner.project(&col_refs)?;
            if want_row_id {
                scanner.with_row_id();
            }
            if want_row_addr {
                scanner.with_row_address();
            }
            scanner.filter_expr(filter.clone());
            scanner.create_plan().await?
        }
        LsmDataSource::FlushedMemTable { path, .. } => {
            let dataset =
                open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref())
                    .await?;
            let mut scanner = dataset.scan();
            scanner.project(&col_refs)?;
            // Honor system-column requests here exactly like the base-table arm. The
            // flags were previously computed but never applied to flushed generations.
            // NOTE(review): the values are local to the flushed dataset — confirm that
            // exposing them (rather than NULL) is the intended contract.
            if want_row_id {
                scanner.with_row_id();
            }
            if want_row_addr {
                scanner.with_row_address();
            }
            scanner.filter_expr(filter.clone());
            scanner.create_plan().await?
        }
        LsmDataSource::ActiveMemTable {
            batch_store,
            index_store,
            schema,
            ..
        } => {
            use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
            let mut scanner =
                MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
            scanner.project(&col_refs);
            scanner.filter_expr(filter.clone());
            // The row id is always needed: the dedup node below keys on it.
            scanner.with_row_id();
            let raw = scanner.create_plan().await?;
            let deduped: Arc<dyn ExecutionPlan> = Arc::new(WithinSourceDedupExec::new(
                raw,
                self.pk_columns.clone(),
                lance_core::ROW_ID,
                DedupDirection::KeepMaxRowAddr,
            ));
            null_columns(deduped, &[lance_core::ROW_ID])?
        }
    };
    project_to_canonical(scan, &target)
}
/// Returns an `EmptyExec` whose schema is the canonical output schema for `projection`;
/// used when the collector reports no data sources.
fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
    let output_schema =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, false);
    let empty = datafusion::physical_plan::empty::EmptyExec::new(output_schema);
    Ok(Arc::new(empty))
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::physical_plan::displayable;
use std::collections::HashMap;
use uuid::Uuid;
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::{Dataset, WriteParams};
/// Test schema: non-nullable `id: Int32` tagged with the
/// `lance-schema:unenforced-primary-key` metadata entry, plus nullable `name: Utf8`.
fn create_pk_schema() -> Arc<ArrowSchema> {
    let pk_marker = HashMap::from([(
        "lance-schema:unenforced-primary-key".to_string(),
        "true".to_string(),
    )]);
    let fields = vec![
        Field::new("id", DataType::Int32, false).with_metadata(pk_marker),
        Field::new("name", DataType::Utf8, true),
    ];
    Arc::new(ArrowSchema::new(fields))
}
fn create_test_batch(schema: &ArrowSchema, ids: &[i32], name_prefix: &str) -> RecordBatch {
let names: Vec<String> = ids
.iter()
.map(|id| format!("{}_{}", name_prefix, id))
.collect();
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(ids.to_vec())),
Arc::new(StringArray::from(names)),
],
)
.unwrap()
}
/// Writes `batches` as a new dataset at `uri` with default write parameters.
///
/// Panics if `batches` is empty (the schema is taken from the first batch) or if the
/// write fails.
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
    let first_schema = batches[0].schema();
    let source = RecordBatchIterator::new(batches.into_iter().map(Ok), first_schema);
    let write_params = Some(WriteParams::default());
    let written = Dataset::write(source, uri, write_params).await;
    written.unwrap()
}
/// A lookup against a base table alone must produce a plan containing `GlobalLimitExec`.
#[tokio::test]
async fn test_point_lookup_plan_structure() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let rows = create_test_batch(&schema, &[1, 2, 3], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![rows]).await);

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema.clone(),
    );

    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], None)
        .await
        .unwrap();
    let rendered = displayable(plan.as_ref()).indent(true).to_string();
    assert!(
        rendered.contains("GlobalLimitExec"),
        "Should have GlobalLimitExec in plan: {}",
        rendered
    );
}
/// A lookup across a base table and one flushed generation must produce a plan that
/// contains `CoalesceFirstExec` or `GlobalLimitExec`.
#[tokio::test]
async fn test_point_lookup_with_memtables() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let base_rows = create_test_batch(&schema, &[1, 2, 3], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);

    // Flushed generation 1 lives under the base table's `_mem_wal/<shard>` directory.
    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    let gen1_rows = create_test_batch(&schema, &[2], "gen1");
    create_dataset(&gen1_uri, vec![gen1_rows]).await;

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![snapshot]),
        vec!["id".to_string()],
        schema.clone(),
    );

    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], None)
        .await
        .unwrap();
    let rendered = displayable(plan.as_ref()).indent(true).to_string();
    let has_expected_node =
        rendered.contains("CoalesceFirstExec") || rendered.contains("GlobalLimitExec");
    assert!(
        has_expected_node,
        "Should have CoalesceFirstExec or GlobalLimitExec in plan: {}",
        rendered
    );
}
/// Planning with a bloom filter registered for generation 1 must succeed and expose the
/// `id` column in the plan schema.
#[tokio::test]
async fn test_point_lookup_with_bloom_filter() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let base_rows = create_test_batch(&schema, &[1, 2, 3], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]);

    // Bloom filter that contains the hash of the key we are about to look up.
    let lookup_key = vec![ScalarValue::Int32(Some(2))];
    let mut filter = Sbbf::with_ndv_fpp(100, 0.01).unwrap();
    filter.insert_hash(compute_pk_hash_from_scalars(&[ScalarValue::Int32(Some(2))]));

    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone())
        .with_bloom_filter(1, Arc::new(filter));

    let plan = planner.plan_lookup(&lookup_key, None).await.unwrap();
    assert!(plan.schema().field_with_name("id").is_ok());
}
/// The primary-key filter expression must mention the key column by name.
#[tokio::test]
async fn test_pk_filter_expr() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let base_rows = create_test_batch(&schema, &[1], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let rendered = planner
        .build_pk_filter_expr(&[ScalarValue::Int32(Some(42))])
        .unwrap()
        .to_string();
    assert!(
        rendered.contains("id"),
        "Expression should contain column name"
    );
}
/// With no base table, lookups are served from the flushed generation only: the plan
/// must not reference `base/data`, a present key yields one row and an absent key none.
#[tokio::test]
async fn test_point_lookup_without_base_table() {
    use futures::TryStreamExt;

    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    create_dataset(
        &gen1_uri,
        vec![create_test_batch(&schema, &[2, 3], "gen1")],
    )
    .await;

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![snapshot]),
        vec!["id".to_string()],
        schema,
    );

    // Key present in the flushed generation.
    let hit_plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(3))], None)
        .await
        .unwrap();
    let rendered = displayable(hit_plan.as_ref()).indent(true).to_string();
    assert!(
        !rendered.contains("base/data"),
        "Plan must not scan base table, got: {}",
        rendered
    );
    assert!(rendered.contains("gen_1"));

    let ctx = datafusion::prelude::SessionContext::new();
    let hit_batches: Vec<RecordBatch> = hit_plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let hit_rows = hit_batches.iter().fold(0usize, |n, b| n + b.num_rows());
    assert_eq!(hit_rows, 1);

    // Key absent everywhere.
    let miss_plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(99))], None)
        .await
        .unwrap();
    let miss_batches: Vec<RecordBatch> = miss_plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let miss_rows = miss_batches.iter().fold(0usize, |n, b| n + b.num_rows());
    assert_eq!(miss_rows, 0);
}
/// A projection mixing user and system columns must keep the requested column order;
/// for a base-table hit `_rowaddr` is populated while `_rowoffset` is NULL.
#[tokio::test]
async fn test_point_lookup_projection_with_system_columns() {
    use futures::TryStreamExt;
    use lance_core::is_system_column;

    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_rows = create_test_batch(&schema, &[1, 2, 3], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let requested: Vec<String> = ["id", "_rowaddr", "name", "_rowoffset"]
        .iter()
        .map(|c| c.to_string())
        .collect();
    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], Some(&requested))
        .await
        .expect("planner must accept system columns in projection");

    let ctx = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let row_count = batches.iter().fold(0usize, |n, b| n + b.num_rows());
    assert_eq!(row_count, 1, "expected exactly one matching row");

    let first = &batches[0];
    let out_cols: Vec<String> = first
        .schema()
        .fields()
        .iter()
        .map(|f| f.name().clone())
        .collect();
    assert_eq!(
        out_cols,
        vec![
            "id".to_string(),
            "_rowaddr".to_string(),
            "name".to_string(),
            "_rowoffset".to_string(),
        ],
        "system columns must appear at the user's requested position"
    );

    let rowaddr = first.column_by_name("_rowaddr").unwrap();
    assert!(
        !rowaddr.is_null(0),
        "_rowaddr from base should be populated, got: {:?}",
        rowaddr
    );

    let rowoffset = first.column_by_name("_rowoffset").unwrap();
    assert!(is_system_column("_rowoffset"));
    assert!(
        rowoffset.is_null(0),
        "_rowoffset has no per-source flag, must be NULL across LSM, got: {:?}",
        rowoffset
    );
}
/// With no data sources at all, the empty plan's schema must still list the projected
/// columns — system columns included — in the requested order.
#[tokio::test]
async fn test_point_lookup_empty_plan_with_system_columns() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let requested: Vec<String> = ["id", "_rowaddr", "name", "_rowid"]
        .iter()
        .map(|c| c.to_string())
        .collect();
    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], Some(&requested))
        .await
        .expect("empty plan must accept system columns in projection");

    let mut names: Vec<String> = Vec::new();
    for field in plan.schema().fields().iter() {
        names.push(field.name().clone());
    }
    assert_eq!(
        names,
        vec![
            "id".to_string(),
            "_rowaddr".to_string(),
            "name".to_string(),
            "_rowid".to_string(),
        ],
        "empty point-lookup plan must honor user column order including system columns"
    );
}
/// When the active memtable holds two rows with the same primary key, the lookup must
/// return exactly one row, and it must be the one appended later.
///
/// The append order below (old, new, other) is what the final assertion relies on.
#[tokio::test]
async fn test_point_lookup_active_memtable_returns_newest_duplicate() {
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use futures::TryStreamExt;
let schema = create_pk_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
// In-memory stores backing the active memtable, with a btree index registered for `id`.
let batch_store = Arc::new(BatchStore::with_capacity(16));
let mut index_store = IndexStore::new();
index_store.add_btree("id_idx".to_string(), 0, "id".to_string());
// Two versions of id=1 plus an unrelated id=2 row.
let b_old = create_test_batch(&schema, &[1], "old");
let b_new = create_test_batch(&schema, &[1], "new");
let b_other = create_test_batch(&schema, &[2], "two");
// Append and index each batch in turn: old first, then new, then the unrelated row.
let (_, _, bp_old) = batch_store.append(b_old.clone()).unwrap();
index_store
.insert_with_batch_position(&b_old, 0, Some(bp_old))
.unwrap();
let (_, _, bp_new) = batch_store.append(b_new.clone()).unwrap();
index_store
.insert_with_batch_position(&b_new, 1, Some(bp_new))
.unwrap();
let (_, _, bp_other) = batch_store.append(b_other.clone()).unwrap();
index_store
.insert_with_batch_position(&b_other, 2, Some(bp_other))
.unwrap();
let index_store = Arc::new(index_store);
let shard_id = Uuid::new_v4();
// No base table and no flushed generations: the active memtable is the only source.
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
.with_in_memory_memtables(
shard_id,
InMemoryMemTables {
active: InMemoryMemTableRef {
batch_store,
index_store,
schema: schema.clone(),
generation: 1,
},
frozen: vec![],
},
);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);
let plan = planner
.plan_lookup(&[ScalarValue::Int32(Some(1))], None)
.await
.unwrap();
// Execute partition 0 of the plan and gather every produced batch.
let ctx = datafusion::prelude::SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 1, "expected exactly one row for pk=1");
// The surviving row must carry the name written by the later ("new") batch.
let name_col = batches[0].column_by_name("name").unwrap();
let name_arr = name_col.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(
name_arr.value(0),
"new_1",
"active-arm lookup must return the newer insert, not the oldest"
);
}
/// When a flushed generation stores two rows with the same primary key, the lookup must
/// return exactly one row, and it must be the one written first in the dataset.
#[tokio::test]
async fn test_point_lookup_flushed_memtable_returns_newest_duplicate() {
    use futures::TryStreamExt;

    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);

    // Write order matters: the "new" row is stored ahead of the "old" one.
    let duplicates = vec![
        create_test_batch(&schema, &[1], "new"),
        create_test_batch(&schema, &[1], "old"),
    ];
    create_dataset(&gen1_uri, duplicates).await;

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![snapshot]),
        vec!["id".to_string()],
        schema,
    );

    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(1))], None)
        .await
        .unwrap();
    let ctx = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let row_count = batches.iter().fold(0usize, |n, b| n + b.num_rows());
    assert_eq!(row_count, 1, "expected exactly one row for pk=1");

    let names = batches[0]
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        names.value(0),
        "new_1",
        "flushed-arm lookup must return the row at the smallest _rowid (newest under reverse-write)"
    );
}
}