use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{Array, RecordBatch};
use arrow_schema::{SchemaRef, SortOptions};
use datafusion::common::ScalarValue;
use datafusion::execution::TaskContext;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::prelude::{Expr, SessionContext};
use futures::TryStreamExt;
use lance_core::utils::bloomfilter::sbbf::Sbbf;
use lance_core::{Result, is_system_column};
use lance_datafusion::exec::OneShotExec;
use tracing::instrument;
use crate::dataset::mem_wal::index::IndexStore;
use crate::dataset::mem_wal::memtable::batch_store::BatchStore;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{BloomFilterGuardExec, CoalesceFirstExec, compute_pk_hash_from_scalars};
use super::flushed_cache::{DatasetCache, GenerationWarmer, open_flushed_dataset};
use super::projection::{
build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical,
wants_row_address, wants_row_id,
};
use crate::session::Session;
pub struct LsmPointLookupPlanner {
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
bloom_filters: std::collections::HashMap<u64, Arc<Sbbf>>,
session: Option<Arc<Session>>,
flushed_cache: Option<Arc<dyn DatasetCache>>,
warmer: Option<Arc<dyn GenerationWarmer>>,
none_target: SchemaRef,
task_ctx: Arc<TaskContext>,
}
impl LsmPointLookupPlanner {
pub fn new(
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
) -> Self {
let none_target = canonical_output_schema(None, &base_schema, &pk_columns, false);
Self {
collector,
pk_columns,
base_schema,
bloom_filters: std::collections::HashMap::new(),
session: None,
flushed_cache: None,
warmer: None,
none_target,
task_ctx: SessionContext::new().task_ctx(),
}
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<dyn DatasetCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
pub fn with_warmer(mut self, warmer: Arc<dyn GenerationWarmer>) -> Self {
self.warmer = Some(warmer);
self
}
/// Registers the bloom filter for one generation, replacing any filter
/// previously registered for that generation.
pub fn with_bloom_filter(self, generation: u64, bloom_filter: Arc<Sbbf>) -> Self {
    self.with_bloom_filters(std::iter::once((generation, bloom_filter)))
}
/// Registers several `(generation, bloom filter)` pairs; a later pair
/// overrides an earlier one for the same generation.
pub fn with_bloom_filters(
    mut self,
    bloom_filters: impl IntoIterator<Item = (u64, Arc<Sbbf>)>,
) -> Self {
    for (generation, filter) in bloom_filters {
        self.bloom_filters.insert(generation, filter);
    }
    self
}
/// Builds a physical plan that looks up one primary key across every
/// collected data source.
///
/// Sources are visited in descending generation order. Each source scan is
/// capped at a single row and, when a bloom filter is registered for the
/// source's generation, wrapped in a `BloomFilterGuardExec`. A single source
/// plan is returned as-is; several are combined with `CoalesceFirstExec`.
///
/// # Errors
///
/// Returns an invalid-input error when the number of key values differs
/// from the number of primary-key columns, and propagates collector and
/// scan-construction failures.
#[instrument(name = "lsm_point_lookup", level = "debug", skip_all, fields(pk_column_count = self.pk_columns.len()))]
pub async fn plan_lookup(
    &self,
    pk_values: &[ScalarValue],
    projection: Option<&[String]>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let expected = self.pk_columns.len();
    let actual = pk_values.len();
    if actual != expected {
        return Err(lance_core::Error::invalid_input(format!(
            "Expected {} primary key values, got {}",
            expected, actual
        )));
    }

    let pk_hash = compute_pk_hash_from_scalars(pk_values);
    let filter_expr = self.build_pk_filter_expr(pk_values)?;

    let collected = self.collector.collect()?;
    if collected.is_empty() {
        return self.empty_plan(projection);
    }

    // Highest generation first.
    let mut ordered: Vec<_> = collected.into_iter().collect();
    ordered.sort_by_key(|src| std::cmp::Reverse(src.generation()));

    let mut per_source = Vec::new();
    for source in ordered {
        let generation = source.generation().as_u64();
        let scan = self
            .build_source_scan(&source, projection, &filter_expr)
            .await?;
        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(GlobalLimitExec::new(scan, 0, Some(1)));
        if let Some(filter) = self.bloom_filters.get(&generation) {
            plan = Arc::new(BloomFilterGuardExec::new(
                plan,
                Arc::clone(filter),
                pk_hash,
                generation,
            ));
        }
        per_source.push(plan);
    }

    let plan: Arc<dyn ExecutionPlan> = if per_source.len() == 1 {
        per_source.swap_remove(0)
    } else {
        Arc::new(CoalesceFirstExec::new(per_source))
    };
    Ok(plan)
}
#[instrument(name = "lsm_lookup", level = "debug", skip_all)]
pub async fn lookup(
&self,
pk_values: &[ScalarValue],
projection: Option<&[String]>,
) -> Result<Option<RecordBatch>> {
let fast_eligible = pk_values.len() == 1
&& self.pk_columns.len() == 1
&& self
.base_schema
.field_with_name(&self.pk_columns[0])
.ok()
.map(|f| f.data_type() == &pk_values[0].data_type())
.unwrap_or(false);
if fast_eligible {
let projected;
let target: &SchemaRef = match projection {
None => &self.none_target,
Some(_) => {
projected = canonical_output_schema(
projection,
&self.base_schema,
&self.pk_columns,
false,
);
&projected
}
};
if !target.fields().iter().any(|f| is_system_column(f.name())) {
let outcome = self.collector.find_in_memory_newest_first(
|m| -> Result<Option<FastOutcome>> {
match probe_memtable(
&m.batch_store,
&m.index_store,
&self.pk_columns[0],
&pk_values[0],
target,
)? {
Probe::Hit(batch) => Ok(Some(FastOutcome::Hit(batch))),
Probe::Miss => Ok(None),
Probe::NoIndex => Ok(Some(FastOutcome::NeedsFallback)),
}
},
)?;
match outcome {
Some(FastOutcome::Hit(batch)) => return Ok(Some(batch)),
Some(FastOutcome::NeedsFallback) => { }
None => {
if !self.collector.has_on_disk_sources() {
return Ok(None);
}
}
}
}
}
self.lookup_via_plan(pk_values, projection).await
}
async fn lookup_via_plan(
&self,
pk_values: &[ScalarValue],
projection: Option<&[String]>,
) -> Result<Option<RecordBatch>> {
let plan = self.plan_lookup(pk_values, projection).await?;
let batches: Vec<RecordBatch> = plan
.execute(0, self.task_ctx.clone())?
.try_collect()
.await?;
for batch in batches {
if batch.num_rows() > 0 {
return Ok(Some(batch.slice(0, 1)));
}
}
Ok(None)
}
#[instrument(name = "lsm_lookup_many", level = "debug", skip_all, fields(n = keys.len()))]
pub async fn lookup_many(
&self,
keys: &[ScalarValue],
projection: Option<&[String]>,
) -> Result<RecordBatch> {
let target = match projection {
None => self.none_target.clone(),
Some(_) => {
canonical_output_schema(projection, &self.base_schema, &self.pk_columns, false)
}
};
if keys.is_empty() {
return Ok(RecordBatch::new_empty(target));
}
if keys.len() == 1 {
return Ok(self
.lookup(keys, projection)
.await?
.unwrap_or_else(|| RecordBatch::new_empty(target)));
}
let pk_type = self
.pk_columns
.first()
.and_then(|c| self.base_schema.field_with_name(c).ok())
.map(|f| f.data_type().clone());
let fast_eligible = self.pk_columns.len() == 1
&& !target.fields().iter().any(|f| is_system_column(f.name()))
&& pk_type
.as_ref()
.map(|t| keys.iter().all(|k| &k.data_type() == t))
.unwrap_or(false);
if !fast_eligible {
return self
.lookup_many_via_per_key(keys, projection, &target)
.await;
}
let pk_col = &self.pk_columns[0];
let refs = self.collector.in_memory_refs_newest_first();
let mut hits: HashMap<(usize, usize), Vec<u32>> = HashMap::new();
let mut pending: Vec<ScalarValue> = Vec::new();
for key in keys {
let mut resolved = false;
for (ri, m) in refs.iter().enumerate() {
match probe_position(&m.batch_store, &m.index_store, pk_col, key)? {
ProbePos::Found { batch_idx, row } => {
hits.entry((ri, batch_idx)).or_default().push(row as u32);
resolved = true;
break;
}
ProbePos::Miss => continue,
ProbePos::NoIndex => {
return self
.lookup_many_via_per_key(keys, projection, &target)
.await;
}
}
}
if !resolved {
pending.push(key.clone());
}
}
let mut out: Vec<RecordBatch> = Vec::with_capacity(hits.len() + 1);
for ((ri, batch_idx), rows) in hits {
out.push(gather_rows(
&refs[ri].batch_store,
batch_idx,
&rows,
&target,
)?);
}
if !pending.is_empty() && self.collector.has_on_disk_sources() {
out.push(
self.lookup_many_via_per_key(&pending, projection, &target)
.await?,
);
}
match out.len() {
0 => Ok(RecordBatch::new_empty(target)),
1 => Ok(out.pop().unwrap()),
_ => Ok(arrow_select::concat::concat_batches(&target, &out)?),
}
}
/// Resolves each key with an individual `lookup` call, in order, and
/// combines the found rows into one batch with schema `target`.
///
/// Missing keys contribute no row; when nothing is found an empty batch
/// with schema `target` is returned.
async fn lookup_many_via_per_key(
    &self,
    keys: &[ScalarValue],
    projection: Option<&[String]>,
    target: &SchemaRef,
) -> Result<RecordBatch> {
    let mut found: Vec<RecordBatch> = Vec::new();
    for key in keys {
        let single = std::slice::from_ref(key);
        found.extend(self.lookup(single, projection).await?);
    }
    if found.len() > 1 {
        return Ok(arrow_select::concat::concat_batches(target, &found)?);
    }
    Ok(found
        .pop()
        .unwrap_or_else(|| RecordBatch::new_empty(target.clone())))
}
/// Resolves `keys` eagerly and wraps the resulting batch in a one-shot
/// execution plan.
///
/// Exactly one key goes through `lookup` (an empty batch with the canonical
/// output schema stands in for a miss); any other number of keys goes
/// through `lookup_many`.
pub async fn plan_point_lookup(
    &self,
    keys: &[ScalarValue],
    projection: Option<&[String]>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let batch = match keys {
        [_] => self.lookup(keys, projection).await?.unwrap_or_else(|| {
            RecordBatch::new_empty(canonical_output_schema(
                projection,
                &self.base_schema,
                &self.pk_columns,
                false,
            ))
        }),
        _ => self.lookup_many(keys, projection).await?,
    };
    let schema = batch.schema();
    let single_batch = futures::stream::once(async move { Ok(batch) });
    let stream = Box::pin(RecordBatchStreamAdapter::new(schema, single_batch));
    Ok(Arc::new(OneShotExec::new(stream)))
}
fn build_pk_filter_expr(&self, pk_values: &[ScalarValue]) -> Result<Expr> {
use datafusion::prelude::{col, lit};
let mut expr: Option<Expr> = None;
for (col_name, value) in self.pk_columns.iter().zip(pk_values.iter()) {
let eq_expr = col(col_name.as_str()).eq(lit(value.clone()));
expr = Some(match expr {
Some(e) => e.and(eq_expr),
None => eq_expr,
});
}
expr.ok_or_else(|| lance_core::Error::invalid_input("No primary key columns specified"))
}
/// Builds the filtered scan for one data source and projects it to the
/// canonical output schema for `projection`.
///
/// * Base table: scanned with the dataset scanner; `_rowid` / `_rowaddr`
///   are requested only when the projection asks for them.
/// * Flushed memtable: the dataset at `path` is opened via
///   `open_flushed_dataset` and scanned with the same projection and filter.
/// * Active memtable: scanned with `_rowid`, sorted by `_rowid` descending
///   keeping one row, after which the `_rowid` column is passed to
///   `null_columns`.
async fn build_source_scan(
    &self,
    source: &LsmDataSource,
    projection: Option<&[String]>,
    filter: &Expr,
) -> Result<Arc<dyn ExecutionPlan>> {
    let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
    let target =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, false);
    let include_row_id = wants_row_id(projection);
    let include_row_addr = wants_row_address(projection);
    let col_refs = cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();

    let scan: Arc<dyn ExecutionPlan> = match source {
        LsmDataSource::BaseTable { dataset } => {
            let mut scanner = dataset.scan();
            scanner.project(&col_refs)?;
            if include_row_id {
                scanner.with_row_id();
            }
            if include_row_addr {
                scanner.with_row_address();
            }
            scanner.filter_expr(filter.clone());
            scanner.create_plan().await?
        }
        LsmDataSource::FlushedMemTable { path, .. } => {
            let dataset = open_flushed_dataset(
                path,
                self.session.as_ref(),
                self.flushed_cache.as_ref(),
                self.warmer.as_ref(),
            )
            .await?;
            // NOTE(review): unlike the base-table arm, `_rowid` / `_rowaddr` are
            // not requested here; presumably `project_to_canonical` fills them —
            // confirm this is intentional.
            let mut scanner = dataset.scan();
            scanner.project(&col_refs)?;
            scanner.filter_expr(filter.clone());
            scanner.create_plan().await?
        }
        LsmDataSource::ActiveMemTable {
            batch_store,
            index_store,
            schema,
            ..
        } => {
            use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
            let mut scanner =
                MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
            scanner.project(&col_refs);
            scanner.filter_expr(filter.clone());
            // `_rowid` is needed as the sort key below.
            scanner.with_row_id();
            let raw = scanner.create_plan().await?;

            let rowid_idx = raw.schema().index_of(lance_core::ROW_ID)?;
            let by_rowid_desc = PhysicalSortExpr {
                expr: Arc::new(Column::new(lance_core::ROW_ID, rowid_idx)),
                options: SortOptions {
                    descending: true,
                    nulls_first: false,
                },
            };
            let ordering = LexOrdering::new(vec![by_rowid_desc]).ok_or_else(|| {
                lance_core::Error::internal("point-lookup: failed to build _rowid ordering")
            })?;
            // Keep only the row with the largest `_rowid`.
            let newest: Arc<dyn ExecutionPlan> =
                Arc::new(SortExec::new(ordering, raw).with_fetch(Some(1)));
            null_columns(newest, &[lance_core::ROW_ID])?
        }
    };
    project_to_canonical(scan, &target)
}
/// Returns an `EmptyExec` carrying the canonical output schema for
/// `projection`; used when the collector yields no sources.
fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
    let schema =
        canonical_output_schema(projection, &self.base_schema, &self.pk_columns, false);
    let plan = datafusion::physical_plan::empty::EmptyExec::new(schema);
    Ok(Arc::new(plan))
}
}
/// Result reported by the in-memory fast path of `LsmPointLookupPlanner::lookup`
/// for a single memtable.
enum FastOutcome {
/// The key was found; carries the projected single-row batch.
Hit(RecordBatch),
/// The memtable has no usable index, so the plan-based lookup must run.
NeedsFallback,
}
/// Result of probing one memtable for a key and materialising the row
/// (see `probe_memtable`).
enum Probe {
/// The key was found; carries the gathered row.
Hit(RecordBatch),
/// The key is not present among the memtable's visible rows.
Miss,
/// The memtable has no B-tree index on the primary-key column.
NoIndex,
}
/// Result of locating a key inside one memtable without materialising the
/// row (see `probe_position`).
enum ProbePos {
/// The key was found at `row` (offset within the batch) of batch `batch_idx`.
Found {
batch_idx: usize,
row: usize,
},
/// The key is not present among the memtable's visible rows.
Miss,
/// The memtable has no B-tree index on the primary-key column.
NoIndex,
}
fn probe_position(
batch_store: &BatchStore,
index_store: &IndexStore,
pk_column: &str,
pk_value: &ScalarValue,
) -> Result<ProbePos> {
let len = batch_store.len();
if len == 0 {
return Ok(ProbePos::Miss);
}
let last_visible_idx = index_store.max_visible_batch_position().min(len - 1);
let last = batch_store.get(last_visible_idx).ok_or_else(|| {
lance_core::Error::internal("point-lookup: visible batch index out of range")
})?;
let visible_end = last.row_offset + last.num_rows as u64; if visible_end == 0 {
return Ok(ProbePos::Miss);
}
let max_visible_row = visible_end - 1;
let Some(btree) = index_store.get_btree_by_column(pk_column) else {
return Ok(ProbePos::NoIndex);
};
let Some(pos) = btree.get_newest_visible(pk_value, max_visible_row) else {
return Ok(ProbePos::Miss);
};
let (batch_idx, row) = resolve_position(batch_store, last_visible_idx, pos)?;
Ok(ProbePos::Found { batch_idx, row })
}
fn resolve_position(
batch_store: &BatchStore,
last_visible_idx: usize,
position: u64,
) -> Result<(usize, usize)> {
let (mut lo, mut hi) = (0usize, last_visible_idx);
while lo < hi {
let mid = lo + (hi - lo).div_ceil(2);
let off = batch_store.get(mid).map(|b| b.row_offset).ok_or_else(|| {
lance_core::Error::internal("point-lookup: batch index out of range during search")
})?;
if off <= position {
lo = mid;
} else {
hi = mid - 1;
}
}
let stored = batch_store
.get(lo)
.ok_or_else(|| lance_core::Error::internal("point-lookup: resolved batch missing"))?;
Ok((lo, (position - stored.row_offset) as usize))
}
/// Materialises the given `rows` of stored batch `batch_idx`, projected to
/// the fields of `target` (matched by column name).
///
/// Exactly one row is produced by slicing; any other count — including an
/// empty `rows`, which yields an empty batch instead of panicking — goes
/// through the `take` kernel. Output rows follow the order of `rows`.
///
/// # Errors
///
/// Returns an internal error when the batch is missing, an invalid-input
/// error when a target field is absent from the stored batch, and
/// propagates Arrow errors from `take` or batch construction.
fn gather_rows(
    batch_store: &BatchStore,
    batch_idx: usize,
    rows: &[u32],
    target: &SchemaRef,
) -> Result<RecordBatch> {
    let stored = batch_store
        .get(batch_idx)
        .ok_or_else(|| lance_core::Error::internal("point-lookup: gather batch missing"))?;
    // `None` only for exactly one row, so `rows[0]` below is always in bounds.
    let indices = (rows.len() != 1).then(|| arrow_array::UInt32Array::from(rows.to_vec()));
    let stored_schema = stored.data.schema_ref();
    let cols: Vec<Arc<dyn Array>> = target
        .fields()
        .iter()
        .map(|f| {
            let idx = stored_schema.index_of(f.name()).map_err(|_| {
                lance_core::Error::invalid_input(format!(
                    "point-lookup projection column '{}' not found in memtable batch",
                    f.name()
                ))
            })?;
            let col = stored.data.column(idx);
            match &indices {
                Some(idxs) => arrow_select::take::take(col.as_ref(), idxs, None)
                    .map_err(lance_core::Error::from),
                // Single-row case: slicing avoids copying the values.
                None => Ok(col.slice(rows[0] as usize, 1)),
            }
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(RecordBatch::try_new(target.clone(), cols)?)
}
fn probe_memtable(
batch_store: &BatchStore,
index_store: &IndexStore,
pk_column: &str,
pk_value: &ScalarValue,
target: &SchemaRef,
) -> Result<Probe> {
match probe_position(batch_store, index_store, pk_column, pk_value)? {
ProbePos::NoIndex => Ok(Probe::NoIndex),
ProbePos::Miss => Ok(Probe::Miss),
ProbePos::Found { batch_idx, row } => Ok(Probe::Hit(gather_rows(
batch_store,
batch_idx,
&[row as u32],
target,
)?)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::physical_plan::displayable;
use std::collections::HashMap;
use uuid::Uuid;
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::{Dataset, WriteParams};
/// Two-column test schema: non-null `id: Int32`, tagged with the
/// unenforced-primary-key metadata, and nullable `name: Utf8`.
fn create_pk_schema() -> Arc<ArrowSchema> {
    let pk_metadata = HashMap::from([(
        "lance-schema:unenforced-primary-key".to_string(),
        "true".to_string(),
    )]);
    let fields = vec![
        Field::new("id", DataType::Int32, false).with_metadata(pk_metadata),
        Field::new("name", DataType::Utf8, true),
    ];
    Arc::new(ArrowSchema::new(fields))
}
/// Builds a batch with the given `ids` and names of the form
/// `{name_prefix}_{id}`.
fn create_test_batch(schema: &ArrowSchema, ids: &[i32], name_prefix: &str) -> RecordBatch {
    let id_column = Int32Array::from(ids.to_vec());
    let name_column = StringArray::from(
        ids.iter()
            .map(|id| format!("{}_{}", name_prefix, id))
            .collect::<Vec<String>>(),
    );
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(name_column)],
    )
    .unwrap()
}
/// Writes `batches` as a new dataset at `uri` with default write parameters.
/// Panics when `batches` is empty or the write fails.
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
    let schema = batches[0].schema();
    let params = WriteParams::default();
    let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
    let written = Dataset::write(reader, uri, Some(params)).await;
    written.unwrap()
}
/// A base-table-only lookup plan must contain the per-source row limit.
#[tokio::test]
async fn test_point_lookup_plan_structure() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_dataset = Arc::new(
        create_dataset(
            &base_uri,
            vec![create_test_batch(&schema, &[1, 2, 3], "base")],
        )
        .await,
    );

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema.clone(),
    );
    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], None)
        .await
        .unwrap();

    let rendered = displayable(plan.as_ref()).indent(true).to_string();
    assert!(
        rendered.contains("GlobalLimitExec"),
        "Should have GlobalLimitExec in plan: {}",
        rendered
    );
}
/// Planning over a base table plus one flushed generation must produce a
/// plan containing the coalescing or limiting operator.
#[tokio::test]
async fn test_point_lookup_with_memtables() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_dataset = Arc::new(
        create_dataset(
            &base_uri,
            vec![create_test_batch(&schema, &[1, 2, 3], "base")],
        )
        .await,
    );

    // Flushed generation 1 lives under the base table's `_mem_wal` directory.
    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    let gen1_batch = create_test_batch(&schema, &[2], "gen1");
    create_dataset(&gen1_uri, vec![gen1_batch]).await;
    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![snapshot]),
        vec!["id".to_string()],
        schema.clone(),
    );
    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(2))], None)
        .await
        .unwrap();

    let rendered = displayable(plan.as_ref()).indent(true).to_string();
    let has_expected_operator =
        rendered.contains("CoalesceFirstExec") || rendered.contains("GlobalLimitExec");
    assert!(
        has_expected_operator,
        "Should have CoalesceFirstExec or GlobalLimitExec in plan: {}",
        rendered
    );
}
/// Registering a bloom filter must not break planning; the plan still
/// exposes the `id` column.
#[tokio::test]
async fn test_point_lookup_with_bloom_filter() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_dataset = Arc::new(
        create_dataset(
            &base_uri,
            vec![create_test_batch(&schema, &[1, 2, 3], "base")],
        )
        .await,
    );
    let collector = LsmDataSourceCollector::new(base_dataset, vec![]);

    let key = [ScalarValue::Int32(Some(2))];
    let mut filter = Sbbf::with_ndv_fpp(100, 0.01).unwrap();
    filter.insert_hash(compute_pk_hash_from_scalars(&key));

    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone())
        .with_bloom_filter(1, Arc::new(filter));
    let plan = planner.plan_lookup(&key, None).await.unwrap();
    assert!(plan.schema().field_with_name("id").is_ok());
}
/// The generated primary-key filter must mention the key column.
#[tokio::test]
async fn test_pk_filter_expr() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_dataset = Arc::new(
        create_dataset(&base_uri, vec![create_test_batch(&schema, &[1], "base")]).await,
    );
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let rendered = planner
        .build_pk_filter_expr(&[ScalarValue::Int32(Some(42))])
        .unwrap()
        .to_string();
    assert!(
        rendered.contains("id"),
        "Expression should contain column name"
    );
}
/// Without a base table the plan must scan only the flushed generation,
/// return one row for a present key and no rows for an absent key.
#[tokio::test]
async fn test_point_lookup_without_base_table() {
    use futures::TryStreamExt;
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    create_dataset(
        &gen1_uri,
        vec![create_test_batch(&schema, &[2, 3], "gen1")],
    )
    .await;
    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![snapshot]),
        vec!["id".to_string()],
        schema,
    );

    // Present key.
    let hit_plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(3))], None)
        .await
        .unwrap();
    let rendered = displayable(hit_plan.as_ref()).indent(true).to_string();
    assert!(
        !rendered.contains("base/data"),
        "Plan must not scan base table, got: {}",
        rendered
    );
    assert!(rendered.contains("gen_1"));
    let ctx = datafusion::prelude::SessionContext::new();
    let hit_batches: Vec<RecordBatch> = hit_plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let hit_rows: usize = hit_batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(hit_rows, 1);

    // Absent key.
    let miss_plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(99))], None)
        .await
        .unwrap();
    let miss_batches: Vec<RecordBatch> = miss_plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let miss_rows: usize = miss_batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(miss_rows, 0);
}
/// System columns in the projection must be accepted, keep the requested
/// position, and be populated (`_rowaddr`) or NULL (`_rowoffset`).
#[tokio::test]
async fn test_point_lookup_projection_with_system_columns() {
    use futures::TryStreamExt;
    use lance_core::is_system_column;
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_dataset = Arc::new(
        create_dataset(
            &base_uri,
            vec![create_test_batch(&schema, &[1, 2, 3], "base")],
        )
        .await,
    );
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::new(base_dataset, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let projection: Vec<String> = ["id", "_rowaddr", "name", "_rowoffset"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let pk_values = vec![ScalarValue::Int32(Some(2))];
    let plan = planner
        .plan_lookup(&pk_values, Some(&projection))
        .await
        .expect("planner must accept system columns in projection");

    let ctx = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(row_count, 1, "expected exactly one matching row");

    let result = &batches[0];
    let out_cols: Vec<String> = result
        .schema()
        .fields()
        .iter()
        .map(|f| f.name().clone())
        .collect();
    assert_eq!(
        out_cols, projection,
        "system columns must appear at the user's requested position"
    );

    let rowaddr = result.column_by_name("_rowaddr").unwrap();
    assert!(
        !rowaddr.is_null(0),
        "_rowaddr from base should be populated, got: {:?}",
        rowaddr
    );
    let rowoffset = result.column_by_name("_rowoffset").unwrap();
    assert!(is_system_column("_rowoffset"));
    assert!(
        rowoffset.is_null(0),
        "_rowoffset has no per-source flag, must be NULL across LSM, got: {:?}",
        rowoffset
    );
}
/// With no sources at all, the empty plan's schema must still follow the
/// requested column order, system columns included.
#[tokio::test]
async fn test_point_lookup_empty_plan_with_system_columns() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![]),
        vec!["id".to_string()],
        schema,
    );

    let projection: Vec<String> = ["id", "_rowaddr", "name", "_rowid"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let pk_values = vec![ScalarValue::Int32(Some(2))];
    let plan = planner
        .plan_lookup(&pk_values, Some(&projection))
        .await
        .expect("empty plan must accept system columns in projection");

    let plan_schema = plan.schema();
    let names: Vec<String> = plan_schema
        .fields()
        .iter()
        .map(|f| f.name().clone())
        .collect();
    assert_eq!(
        names, projection,
        "empty point-lookup plan must honor user column order including system columns"
    );
}
/// When the active memtable holds the same key twice, the planned lookup
/// must return the later insert.
#[tokio::test]
async fn test_point_lookup_active_memtable_returns_newest_duplicate() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use futures::TryStreamExt;
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.add_btree("id_idx".to_string(), 0, "id".to_string());

    // Insertion order matters: "old" first, then "new", then an unrelated key.
    let b_old = create_test_batch(&schema, &[1], "old");
    let b_new = create_test_batch(&schema, &[1], "new");
    let b_other = create_test_batch(&schema, &[2], "two");
    for batch in [&b_old, &b_new, &b_other] {
        let (batch_pos, row_offset, _) = batch_store.append(batch.clone()).unwrap();
        index_store
            .insert_with_batch_position(batch, row_offset, Some(batch_pos))
            .unwrap();
    }
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(
            Uuid::new_v4(),
            InMemoryMemTables {
                active,
                frozen: vec![],
            },
        );
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);

    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(1))], None)
        .await
        .unwrap();
    let ctx = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(row_count, 1, "expected exactly one row for pk=1");

    let names = batches[0]
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        names.value(0),
        "new_1",
        "active-arm lookup must return the newer insert, not the oldest"
    );
}
/// With the PK index enabled via `enable_pk_index`, `lookup` must find the
/// newest version of a present key and report a miss for an absent one.
#[tokio::test]
async fn test_point_lookup_probes_auto_created_pk_btree() {
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let mut index_store = IndexStore::new();
    index_store.enable_pk_index(&[("id".to_string(), 0)]);

    // Insertion order matters: "old" first, then "new", then an unrelated key.
    let b_old = create_test_batch(&schema, &[1], "old");
    let b_new = create_test_batch(&schema, &[1], "new");
    let b_other = create_test_batch(&schema, &[2], "two");
    for batch in [&b_old, &b_new, &b_other] {
        let (batch_pos, row_offset, _) = batch_store.append(batch.clone()).unwrap();
        index_store
            .insert_with_batch_position(batch, row_offset, Some(batch_pos))
            .unwrap();
    }
    let index_store = Arc::new(index_store);

    let active = InMemoryMemTableRef {
        batch_store,
        index_store,
        schema: schema.clone(),
        generation: 1,
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(
            Uuid::new_v4(),
            InMemoryMemTables {
                active,
                frozen: vec![],
            },
        );
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);

    let hit = planner
        .lookup(&[ScalarValue::Int32(Some(1))], None)
        .await
        .unwrap()
        .expect("pk=1 must be found via the PK-position index probe");
    assert_eq!(hit.num_rows(), 1);
    let names = hit
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        names.value(0),
        "new_1",
        "probe must return the newest version"
    );

    let absent = planner
        .lookup(&[ScalarValue::Int32(Some(999))], None)
        .await
        .unwrap();
    assert!(absent.is_none(), "absent key must miss");
}
/// For a flushed generation holding the same key twice, the planned lookup
/// must return the row written first (the "new" row in this fixture).
#[tokio::test]
async fn test_point_lookup_flushed_memtable_returns_newest_duplicate() {
    use futures::TryStreamExt;
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    // The "new" row is written ahead of the "old" one.
    let written_rows = vec![
        create_test_batch(&schema, &[1], "new"),
        create_test_batch(&schema, &[1], "old"),
    ];
    create_dataset(&gen1_uri, written_rows).await;
    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(2)
        .with_flushed_generation(1, "gen_1".to_string());

    let planner = LsmPointLookupPlanner::new(
        LsmDataSourceCollector::without_base_table(base_uri, vec![snapshot]),
        vec!["id".to_string()],
        schema,
    );
    let plan = planner
        .plan_lookup(&[ScalarValue::Int32(Some(1))], None)
        .await
        .unwrap();
    let ctx = datafusion::prelude::SessionContext::new();
    let batches: Vec<RecordBatch> = plan
        .execute(0, ctx.task_ctx())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(row_count, 1, "expected exactly one row for pk=1");

    let names = batches[0]
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        names.value(0),
        "new_1",
        "flushed-arm lookup must return the row at the smallest _rowid (newest under reverse-write)"
    );
}
/// Builds an in-memory memtable reference holding `batches` (appended in
/// order) with a B-tree index named `id_idx` on the `id` column.
fn active_memtable_ref(
    schema: &Arc<ArrowSchema>,
    batches: &[RecordBatch],
    generation: u64,
) -> crate::dataset::mem_wal::scanner::collector::InMemoryMemTableRef {
    use crate::dataset::mem_wal::scanner::collector::InMemoryMemTableRef;
    let store = Arc::new(BatchStore::with_capacity(64));
    let mut index = IndexStore::new();
    index.add_btree("id_idx".to_string(), 0, "id".to_string());
    for batch in batches.iter() {
        let (batch_pos, row_offset, _) = store.append(batch.clone()).unwrap();
        index
            .insert_with_batch_position(batch, row_offset, Some(batch_pos))
            .unwrap();
    }
    InMemoryMemTableRef {
        batch_store: store,
        index_store: Arc::new(index),
        schema: schema.clone(),
        generation,
    }
}
/// Returns the `id` value of the first row of `batch`.
fn id_at(batch: &RecordBatch) -> i32 {
    let column = batch.column_by_name("id").unwrap();
    let ids = column.as_any().downcast_ref::<Int32Array>().unwrap();
    ids.value(0)
}
/// Returns the `name` value of the first row of `batch`.
fn name_at(batch: &RecordBatch) -> String {
    let column = batch.column_by_name("name").unwrap();
    let names = column.as_any().downcast_ref::<StringArray>().unwrap();
    names.value(0).to_owned()
}
/// With only an indexed active memtable, `lookup` returns a present key's
/// row and `None` for an absent key.
#[tokio::test]
async fn test_lookup_fast_path_active_hit_and_absent() {
    use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;
    let schema = create_pk_schema();
    let temp = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp.path().to_str().unwrap());

    let memtables = InMemoryMemTables {
        active: active_memtable_ref(
            &schema,
            &[create_test_batch(&schema, &[10, 20, 30], "v")],
            1,
        ),
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(Uuid::new_v4(), memtables);
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone());

    let found = planner
        .lookup(&[ScalarValue::Int32(Some(20))], None)
        .await
        .unwrap()
        .expect("hit");
    assert_eq!(found.num_rows(), 1);
    assert_eq!(id_at(&found), 20);

    let absent = planner
        .lookup(&[ScalarValue::Int32(Some(99))], None)
        .await
        .unwrap();
    assert!(absent.is_none());
}
/// When the active memtable holds a key twice, `lookup` returns the later
/// insert.
#[tokio::test]
async fn test_lookup_fast_path_newest_duplicate() {
    use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;
    let schema = create_pk_schema();
    let temp = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp.path().to_str().unwrap());

    // Appended in this order: "old" first, "new" second.
    let inserts = [
        create_test_batch(&schema, &[5], "old"),
        create_test_batch(&schema, &[5], "new"),
    ];
    let memtables = InMemoryMemTables {
        active: active_memtable_ref(&schema, &inserts, 1),
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(Uuid::new_v4(), memtables);
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);

    let found = planner
        .lookup(&[ScalarValue::Int32(Some(5))], None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(name_at(&found), "new_5", "must return the newest insert");
}
/// With both a base dataset and an active memtable, a key held only by the
/// memtable is served from it, a key held only by the base dataset is still
/// found, and a key present in neither yields `None`.
#[tokio::test]
async fn test_lookup_miss_falls_back_to_base() {
    use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;

    let schema = create_pk_schema();
    // `temp` must outlive the lookups: the base dataset lives underneath it.
    let temp = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp.path().to_str().unwrap());

    let base_batches = vec![create_test_batch(&schema, &[1, 2, 3], "base")];
    let base = Arc::new(create_dataset(&base_uri, base_batches).await);

    let memtables = InMemoryMemTables {
        active: active_memtable_ref(&schema, &[create_test_batch(&schema, &[99], "act")], 1),
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::new(base, vec![])
        .with_in_memory_memtables(Uuid::new_v4(), memtables);
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone());

    // Key that only the active memtable holds.
    let from_memtable = planner
        .lookup(&[ScalarValue::Int32(Some(99))], None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(id_at(&from_memtable), 99);

    // Key that only the base dataset holds.
    let from_base = planner
        .lookup(&[ScalarValue::Int32(Some(2))], None)
        .await
        .unwrap()
        .expect("base hit via fallback");
    assert_eq!(id_at(&from_base), 2);
    assert_eq!(name_at(&from_base), "base_2");

    // Key that nobody holds.
    let nowhere = planner
        .lookup(&[ScalarValue::Int32(Some(1000))], None)
        .await
        .unwrap();
    assert!(nowhere.is_none());
}
#[tokio::test]
async fn test_lookup_projection_regular_columns() {
use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;
let schema = create_pk_schema();
let temp = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp.path().to_str().unwrap());
let active = active_memtable_ref(
&schema,
&[create_test_batch(&schema, &[10, 20, 30], "v")],
1,
);
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
.with_in_memory_memtables(
Uuid::new_v4(),
InMemoryMemTables {
active,
frozen: vec![],
},
);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);
let row = planner
.lookup(&[ScalarValue::Int32(Some(20))], Some(&["name".to_string()]))
.await
.unwrap()
.unwrap();
let row_schema = row.schema();
let names: Vec<&str> = row_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
assert_eq!(names, vec!["name", "id"]);
assert_eq!(name_at(&row), "v_20");
assert_eq!(id_at(&row), 20);
}
#[tokio::test]
async fn test_lookup_type_mismatch_falls_back_no_panic() {
use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;
let schema = create_pk_schema();
let temp = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp.path().to_str().unwrap());
let active = active_memtable_ref(
&schema,
&[create_test_batch(&schema, &[10, 20, 30], "v")],
1,
);
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
.with_in_memory_memtables(
Uuid::new_v4(),
InMemoryMemTables {
active,
frozen: vec![],
},
);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);
let row = planner
.lookup(&[ScalarValue::Int64(Some(20))], None)
.await
.expect("must not panic on a coercible-but-different key type")
.expect("plan path coerces Int64 → Int32 and finds id=20");
assert_eq!(id_at(&row), 20);
}
#[tokio::test]
async fn test_lookup_empty_pk_values_errors_not_panics() {
use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;
let schema = create_pk_schema();
let temp = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp.path().to_str().unwrap());
let active = active_memtable_ref(&schema, &[create_test_batch(&schema, &[1], "v")], 1);
let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
.with_in_memory_memtables(
Uuid::new_v4(),
InMemoryMemTables {
active,
frozen: vec![],
},
);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);
let err = planner.lookup(&[], None).await;
assert!(err.is_err(), "empty pk_values must error, not panic");
}
/// Collects every value of the `id` column of `batch` in ascending order.
///
/// Panics if `batch` has no `id` column or that column is not an `Int32Array`.
fn sorted_ids(batch: &RecordBatch) -> Vec<i32> {
    let column = batch.column_by_name("id").unwrap();
    let ids = column.as_any().downcast_ref::<Int32Array>().unwrap();

    let mut values = Vec::with_capacity(ids.len());
    for row in 0..ids.len() {
        values.push(ids.value(row));
    }
    values.sort_unstable();
    values
}
/// Builds a planner whose only data source is an active in-memory memtable
/// (generation 1) holding `batches`: no base table, no frozen memtables, and
/// `id` as the primary-key column.
fn active_planner(batches: &[RecordBatch]) -> LsmPointLookupPlanner {
    use crate::dataset::mem_wal::scanner::collector::InMemoryMemTables;

    let schema = create_pk_schema();
    // NOTE(review): `scratch` is dropped when this function returns, so the
    // directory behind `base_uri` presumably no longer exists while the planner
    // is in use — confirm nothing on this path reads from it.
    let scratch = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", scratch.path().to_str().unwrap());

    let memtables = InMemoryMemTables {
        active: active_memtable_ref(&schema, batches, 1),
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(Uuid::new_v4(), memtables);

    LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema)
}
/// `lookup_many` returns one row per key that exists, silently drops keys that
/// do not, and returns an empty batch that still has an `id` field when given
/// no keys.
#[tokio::test]
async fn test_lookup_many_hits_and_misses() {
    let schema = create_pk_schema();
    let planner = active_planner(&[create_test_batch(&schema, &[10, 20, 30], "v")]);

    // 999 is the only key that was never written.
    let keys = [30, 10, 999, 20].map(|id| ScalarValue::Int32(Some(id)));
    let found = planner.lookup_many(&keys, None).await.unwrap();
    assert_eq!(found.num_rows(), 3);
    assert_eq!(sorted_ids(&found), vec![10, 20, 30]);

    let nothing = planner.lookup_many(&[], None).await.unwrap();
    assert_eq!(nothing.num_rows(), 0);
    assert!(nothing.schema().field_with_name("id").is_ok());
}
/// When a key was written in two batches, `lookup_many` must return only the
/// row from the later batch for it.
#[tokio::test]
async fn test_lookup_many_newest_duplicate() {
    let schema = create_pk_schema();
    let older = create_test_batch(&schema, &[5], "old");
    let newer = create_test_batch(&schema, &[5, 7], "new");
    let planner = active_planner(&[older, newer]);

    let keys = [ScalarValue::Int32(Some(5)), ScalarValue::Int32(Some(7))];
    let batch = planner.lookup_many(&keys, None).await.unwrap();
    assert_eq!(batch.num_rows(), 2);

    let name_col = batch
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    // Row order is not asserted, so compare the names as a sorted list.
    let mut got: Vec<&str> = Vec::with_capacity(name_col.len());
    for row in 0..name_col.len() {
        got.push(name_col.value(row));
    }
    got.sort_unstable();
    assert_eq!(got, vec!["new_5", "new_7"]);
}
#[tokio::test]
async fn test_lookup_many_projection_and_equivalence_to_lookup() {
let schema = create_pk_schema();
let planner = active_planner(&[create_test_batch(&schema, &[1, 2, 3, 4], "v")]);
let keys = [
ScalarValue::Int32(Some(2)),
ScalarValue::Int32(Some(4)),
ScalarValue::Int32(Some(1)),
];
let proj = vec!["name".to_string()];
let batch = planner.lookup_many(&keys, Some(&proj)).await.unwrap();
let batch_schema = batch.schema();
let names: Vec<&str> = batch_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
assert_eq!(names, vec!["name", "id"]); assert_eq!(batch.num_rows(), 3);
assert_eq!(sorted_ids(&batch), vec![1, 2, 4]);
}
/// The plan produced by `plan_point_lookup` can be executed directly and
/// yields one row per requested key that exists.
#[tokio::test]
async fn test_plan_point_lookup_executes() {
    use futures::TryStreamExt;

    let schema = create_pk_schema();
    let planner = active_planner(&[create_test_batch(&schema, &[10, 20, 30], "v")]);

    let keys = [ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))];
    let plan = planner.plan_point_lookup(&keys, None).await.unwrap();

    // Run partition 0 of the plan and drain the resulting stream.
    let session = datafusion::prelude::SessionContext::new();
    let stream = plan.execute(0, session.task_ctx()).unwrap();
    let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

    let total = batches.iter().fold(0usize, |rows, b| rows + b.num_rows());
    assert_eq!(total, 2);
}
/// Point lookups must also work when the memtable's index store was built via
/// `IndexStore::from_configs` instead of `add_btree`: a present key is found
/// and an absent key misses.
#[tokio::test]
async fn test_lookup_against_from_configs_built_index() {
    use crate::dataset::mem_wal::index::{BTreeIndexConfig, IndexStore, MemIndexConfig};
    use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};

    let schema = create_pk_schema();
    let batch = create_test_batch(&schema, &[10, 20, 30], "v");

    let batch_store = Arc::new(BatchStore::with_capacity(16));
    let configs = [MemIndexConfig::BTree(BTreeIndexConfig {
        name: "id_idx".to_string(),
        field_id: 0,
        column: "id".to_string(),
    })];
    let index_store = IndexStore::from_configs(&configs, 1000, 100).unwrap();

    // Store the batch, then index it at the position the store reported.
    let (batch_idx, offset, _) = batch_store.append(batch.clone()).unwrap();
    index_store
        .insert_with_batch_position(&batch, offset, Some(batch_idx))
        .unwrap();

    let temp = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp.path().to_str().unwrap());
    let active = InMemoryMemTableRef {
        batch_store,
        index_store: Arc::new(index_store),
        schema: schema.clone(),
        generation: 1,
    };
    let memtables = InMemoryMemTables {
        active,
        frozen: vec![],
    };
    let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![])
        .with_in_memory_memtables(Uuid::new_v4(), memtables);
    let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);

    let hit = planner
        .lookup(&[ScalarValue::Int32(Some(20))], None)
        .await
        .unwrap()
        .expect("range fallback must find the row");
    assert_eq!(id_at(&hit), 20);

    let miss = planner
        .lookup(&[ScalarValue::Int32(Some(99))], None)
        .await
        .unwrap();
    assert!(miss.is_none(), "absent key must miss");
}
}