use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions};
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{ExecutionPlan, limit::GlobalLimitExec};
use datafusion::prelude::Expr;
use lance_core::{Result, is_system_column};
use tracing::instrument;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{DeduplicateExec, MEMTABLE_GEN_COLUMN, MemtableGenTagExec, ROW_ADDRESS_COLUMN};
use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset};
use super::projection::{
build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical,
};
use crate::session::Session;
/// Builds physical plans that read every LSM data source returned by an
/// [`LsmDataSourceCollector`], merge them on the primary key, and deduplicate.
pub struct LsmScanPlanner {
    /// Produces the data sources (base table, flushed and in-memory memtables) to scan.
    collector: LsmDataSourceCollector,
    /// Base-table schema; used to resolve projections and build the output schema.
    base_schema: SchemaRef,
    /// Primary-key columns every source is sorted, merged and deduplicated on.
    pk_columns: Vec<String>,
    /// Optional session forwarded when opening flushed memtable datasets.
    session: Option<Arc<Session>>,
    /// Optional cache forwarded when opening flushed memtable datasets.
    flushed_cache: Option<Arc<FlushedMemTableCache>>,
}
impl LsmScanPlanner {
pub fn new(
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
) -> Self {
Self {
collector,
pk_columns,
base_schema,
session: None,
flushed_cache: None,
}
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<FlushedMemTableCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
#[instrument(name = "lsm_plan_scan", level = "debug", skip_all, fields(has_filter = filter.is_some(), limit, offset))]
pub async fn plan_scan(
&self,
projection: Option<&[String]>,
filter: Option<&Expr>,
limit: Option<usize>,
offset: Option<usize>,
with_memtable_gen: bool,
keep_row_address: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
let user_wants_rowaddr = projection
.map(|p| p.iter().any(|c| c == ROW_ADDRESS_COLUMN))
.unwrap_or(false);
let keep_row_address = keep_row_address || user_wants_rowaddr;
let sources = self.collector.collect()?;
if sources.is_empty() {
return self.empty_plan(projection, with_memtable_gen, keep_row_address);
}
let sources: Vec<_> = sources.into_iter().rev().collect();
let mut sorted_plans = Vec::new();
for source in sources {
let is_base = matches!(source, LsmDataSource::BaseTable { .. });
let scan = self.build_source_scan(&source, projection, filter).await?;
let local_sort_exprs = self.build_local_sort_exprs(&scan)?;
let lex_ordering = LexOrdering::new(local_sort_exprs).ok_or_else(|| {
lance_core::Error::internal(
"Failed to create LexOrdering from sort expressions".to_string(),
)
})?;
let sorted: Arc<dyn ExecutionPlan> = Arc::new(SortExec::new(lex_ordering, scan));
let after_sort: Arc<dyn ExecutionPlan> = if !is_base && keep_row_address {
null_columns(sorted, &[ROW_ADDRESS_COLUMN])?
} else {
sorted
};
let plan: Arc<dyn ExecutionPlan> = if with_memtable_gen {
Arc::new(MemtableGenTagExec::new(after_sort, source.generation()))
} else {
after_sort
};
sorted_plans.push(plan);
}
let merged: Arc<dyn ExecutionPlan> = if sorted_plans.len() == 1 {
sorted_plans.remove(0)
} else {
let merge_sort_exprs = self.build_merge_sort_exprs(&sorted_plans[0])?;
let lex_ordering = LexOrdering::new(merge_sort_exprs).ok_or_else(|| {
lance_core::Error::internal(
"Failed to create LexOrdering from sort expressions".to_string(),
)
})?;
#[allow(deprecated)]
let union = Arc::new(UnionExec::new(sorted_plans));
Arc::new(SortPreservingMergeExec::new(lex_ordering, union))
};
let dedup = DeduplicateExec::new_sorted(
merged,
self.pk_columns.clone(),
with_memtable_gen,
keep_row_address,
)?;
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(dedup);
let user_wants_system = projection
.map(|p| p.iter().any(|c| is_system_column(c)))
.unwrap_or(false);
if user_wants_system {
plan = project_to_canonical(
plan,
&self.canonical_scan_schema(projection, with_memtable_gen, keep_row_address),
)?;
}
if let Some(limit) = limit {
plan = Arc::new(GlobalLimitExec::new(plan, offset.unwrap_or(0), Some(limit)));
}
Ok(plan)
}
fn canonical_scan_schema(
&self,
projection: Option<&[String]>,
with_memtable_gen: bool,
keep_row_address: bool,
) -> SchemaRef {
let canonical = canonical_output_schema(
projection,
&self.base_schema,
&self.pk_columns,
false, );
let mut fields: Vec<Arc<Field>> = canonical.fields().iter().cloned().collect();
if keep_row_address && !fields.iter().any(|f| f.name() == ROW_ADDRESS_COLUMN) {
fields.push(Arc::new(Field::new(
ROW_ADDRESS_COLUMN,
DataType::UInt64,
true,
)));
}
if with_memtable_gen && !fields.iter().any(|f| f.name() == MEMTABLE_GEN_COLUMN) {
fields.push(Arc::new(Field::new(
MEMTABLE_GEN_COLUMN,
DataType::UInt64,
false,
)));
}
Arc::new(Schema::new(fields))
}
fn build_local_sort_exprs(
&self,
plan: &Arc<dyn ExecutionPlan>,
) -> Result<Vec<PhysicalSortExpr>> {
let schema = plan.schema();
let mut sort_exprs = Vec::new();
for col in &self.pk_columns {
let (idx, _) = schema.column_with_name(col).ok_or_else(|| {
lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(col, idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
});
}
let (addr_idx, _) = schema.column_with_name(ROW_ADDRESS_COLUMN).ok_or_else(|| {
lance_core::Error::invalid_input(format!(
"Column '{}' not found in schema",
ROW_ADDRESS_COLUMN
))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(ROW_ADDRESS_COLUMN, addr_idx)),
options: SortOptions {
descending: true,
nulls_first: false,
},
});
Ok(sort_exprs)
}
fn build_merge_sort_exprs(
&self,
plan: &Arc<dyn ExecutionPlan>,
) -> Result<Vec<PhysicalSortExpr>> {
let schema = plan.schema();
let mut sort_exprs = Vec::new();
for col in &self.pk_columns {
let (idx, _) = schema.column_with_name(col).ok_or_else(|| {
lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(col, idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
});
}
Ok(sort_exprs)
}
async fn build_source_scan(
&self,
source: &LsmDataSource,
projection: Option<&[String]>,
filter: Option<&Expr>,
) -> Result<Arc<dyn ExecutionPlan>> {
match source {
LsmDataSource::BaseTable { dataset } => {
let mut scanner = dataset.scan();
let cols =
build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
LsmDataSource::FlushedMemTable { path, .. } => {
let dataset =
open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref())
.await?;
let mut scanner = dataset.scan();
let cols =
build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
LsmDataSource::ActiveMemTable {
batch_store,
index_store,
schema,
..
} => {
use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
let mut scanner =
MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
let cols =
build_scanner_projection(projection, &self.base_schema, &self.pk_columns);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>());
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
}
}
fn empty_plan(
&self,
projection: Option<&[String]>,
with_memtable_gen: bool,
keep_row_address: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion::physical_plan::empty::EmptyExec;
let schema = self.canonical_scan_schema(projection, with_memtable_gen, keep_row_address);
Ok(Arc::new(EmptyExec::new(schema)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;

    /// Three-column schema `(id: Int32 NOT NULL, name: Utf8, value: Float64)`.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// The scanner projection must keep the requested columns and also contain the
    /// primary-key columns, because every source scan is sorted by them.
    #[test]
    fn test_build_projection_with_rowaddr() {
        let schema = create_test_schema();
        let pk_columns = vec!["id".to_string()];

        let requested = vec!["name".to_string()];
        let cols = build_scanner_projection(Some(requested.as_slice()), &schema, &pk_columns);
        assert!(
            cols.iter().any(|c| c == "name"),
            "requested column must be projected"
        );
        assert!(
            cols.iter().any(|c| c == "id"),
            "primary-key column must be projected for sorting"
        );

        let all_cols = build_scanner_projection(None, &schema, &pk_columns);
        for field in schema.fields() {
            assert!(
                all_cols.iter().any(|c| c == field.name()),
                "no projection should read column '{}'",
                field.name()
            );
        }
    }

    #[test]
    fn test_shard_snapshot() {
        let shard_id = uuid::Uuid::new_v4();
        let snapshot = ShardSnapshot::new(shard_id)
            .with_current_generation(5)
            .with_flushed_generation(1, "gen_1".to_string())
            .with_flushed_generation(2, "gen_2".to_string());
        assert_eq!(snapshot.flushed_generations.len(), 2);
        assert_eq!(snapshot.current_generation, 5);
    }
}
#[cfg(test)]
mod integration_tests {
use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use futures::TryStreamExt;
use uuid::Uuid;
use crate::dataset::mem_wal::scanner::LsmScanner;
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables};
use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use crate::dataset::{Dataset, WriteParams};
use crate::utils::test::assert_plan_node_equals;
/// Schema `(id: Int32 NOT NULL, name: Utf8)` whose `id` field carries the
/// `lance-schema:unenforced-primary-key = "true"` metadata entry.
fn create_pk_schema() -> Arc<ArrowSchema> {
    let pk_metadata: HashMap<String, String> = [(
        "lance-schema:unenforced-primary-key".to_string(),
        "true".to_string(),
    )]
    .into_iter()
    .collect();
    let fields = vec![
        Field::new("id", DataType::Int32, false).with_metadata(pk_metadata),
        Field::new("name", DataType::Utf8, true),
    ];
    Arc::new(ArrowSchema::new(fields))
}
/// Builds a batch for `schema` holding `ids` and names formatted as `{name_prefix}_{id}`.
///
/// Panics if the columns do not match `schema`.
fn create_test_batch(schema: &ArrowSchema, ids: &[i32], name_prefix: &str) -> RecordBatch {
    let id_column = Int32Array::from(ids.to_vec());
    let labels: Vec<String> = ids.iter().map(|id| format!("{name_prefix}_{id}")).collect();
    let name_column = StringArray::from(labels);
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(id_column), Arc::new(name_column)],
    )
    .unwrap()
}
/// Writes `batches` as a new Lance dataset at `uri` using default write params.
///
/// Panics if `batches` is empty or the write fails.
async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
    let schema = batches[0].schema();
    let params = WriteParams::default();
    let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
    Dataset::write(reader, uri, Some(params)).await.unwrap()
}
/// Builds a three-level LSM fixture in a temporary directory:
///
/// - base table: ids 1..=5 (`base_*`)
/// - flushed generation 1: ids 3, 4 (`gen1_*`)
/// - flushed generation 2: ids 4, 5, 6 (`gen2_*`)
/// - active in-memory memtable (generation 3): ids 5, 6, 7 (`active_*`)
///
/// Returns the following values:
/// - the base dataset;
/// - the shard snapshot;
/// - the active memtable keyed by shard id;
/// - the primary-key columns (`["id"]`);
/// - the temporary directory path.
///
/// The directory is persisted via `TempDir::keep` and is not removed afterwards.
async fn setup_multi_level_lsm() -> (
    Arc<Dataset>,
    Vec<ShardSnapshot>,
    Option<(Uuid, InMemoryMemTables)>,
    Vec<String>,
    String,
) {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());

    let base_rows = create_test_batch(&schema, &[1, 2, 3, 4, 5], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);

    let shard_id = Uuid::new_v4();
    let shard_dir = format!("{}/_mem_wal/{}", base_uri, shard_id);
    let gen1_rows = create_test_batch(&schema, &[3, 4], "gen1");
    create_dataset(&format!("{}/gen_1", shard_dir), vec![gen1_rows]).await;
    let gen2_rows = create_test_batch(&schema, &[4, 5, 6], "gen2");
    create_dataset(&format!("{}/gen_2", shard_dir), vec![gen2_rows]).await;

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(3)
        .with_flushed_generation(1, "gen_1".to_string())
        .with_flushed_generation(2, "gen_2".to_string());

    let store = Arc::new(BatchStore::with_capacity(100));
    let indexes = Arc::new(IndexStore::new());
    let _ = store.append(create_test_batch(&schema, &[5, 6, 7], "active"));
    let memtables = InMemoryMemTables {
        active: InMemoryMemTableRef {
            batch_store: store,
            index_store: indexes,
            schema: schema.clone(),
            generation: 3,
        },
        frozen: vec![],
    };

    let kept_path = temp_dir.keep().to_string_lossy().to_string();
    (
        base_dataset,
        vec![snapshot],
        Some((shard_id, memtables)),
        vec!["id".to_string()],
        kept_path,
    )
}
/// Plan shape without generation tagging. Each source gets its own sort, and the sources
/// appear newest first under the union. They are merged on `id` and deduplicated.
#[tokio::test]
async fn test_lsm_scan_query_plan_without_memtable_gen() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let scanner = LsmScanner::new(base_dataset, shard_snapshots, pk_columns);
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => scanner.with_in_memory_memtables(shard_id, memtable),
        None => scanner,
    };
    let expected_plan = "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--";
    let plan = scanner.create_plan().await.unwrap();
    assert_plan_node_equals(plan, expected_plan).await.unwrap();
}
/// With generation tagging, each sorted source is wrapped in a `MemtableGenTagExec`
/// carrying its generation (`gen3` … `base`).
#[tokio::test]
async fn test_lsm_scan_query_plan_with_memtable_gen() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let tagged = LsmScanner::new(base_dataset, shard_snapshots, pk_columns).with_memtable_gen();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => tagged.with_in_memory_memtables(shard_id, memtable),
        None => tagged,
    };
    let expected_plan = "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=false, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
MemtableGenTagExec: gen=gen3
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
MemtableGenTagExec: gen=gen2
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
MemtableGenTagExec: gen=gen1
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
MemtableGenTagExec: gen=base
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--";
    let plan = scanner.create_plan().await.unwrap();
    assert_plan_node_equals(plan, expected_plan).await.unwrap();
}
/// Every id appears once, and the value comes from the newest level that contains it.
#[tokio::test]
async fn test_lsm_scan_deduplication_results() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let scanner = LsmScanner::new(base_dataset, shard_snapshots, pk_columns);
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => scanner.with_in_memory_memtables(shard_id, memtable),
        None => scanner,
    };
    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }

    assert_eq!(results.len(), 7, "Should have 7 unique rows after dedup");
    let expected: [(i32, &str); 7] = [
        (1, "base_1"),
        (2, "base_2"),
        (3, "gen1_3"),
        (4, "gen2_4"),
        (5, "active_5"),
        (6, "active_6"),
        (7, "active_7"),
    ];
    for (id, name) in expected {
        assert_eq!(results.get(&id), Some(&name.to_string()));
    }
}
/// A frozen memtable takes part in the read union. It overrides older flushed data
/// (id 6) and is itself overridden by the active memtable (id 7).
#[tokio::test]
async fn test_lsm_scan_frozen_memtable_in_read_union() {
    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
    let base_rows = create_test_batch(&schema, &[1, 2, 3, 4, 5], "base");
    let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_rows]).await);

    let shard_id = Uuid::new_v4();
    let shard_dir = format!("{}/_mem_wal/{}", base_uri, shard_id);
    create_dataset(
        &format!("{}/gen_1", shard_dir),
        vec![create_test_batch(&schema, &[3, 4], "gen1")],
    )
    .await;
    create_dataset(
        &format!("{}/gen_2", shard_dir),
        vec![create_test_batch(&schema, &[4, 5, 6], "gen2")],
    )
    .await;

    let snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(4)
        .with_flushed_generation(1, "gen_1".to_string())
        .with_flushed_generation(2, "gen_2".to_string());

    let frozen_batches = Arc::new(BatchStore::with_capacity(100));
    let _ = frozen_batches.append(create_test_batch(&schema, &[6, 7], "frozen"));
    let frozen_ref = InMemoryMemTableRef {
        batch_store: frozen_batches,
        index_store: Arc::new(IndexStore::new()),
        schema: schema.clone(),
        generation: 3,
    };

    let active_batches = Arc::new(BatchStore::with_capacity(100));
    let _ = active_batches.append(create_test_batch(&schema, &[7, 8], "active"));
    let memtables = InMemoryMemTables {
        active: InMemoryMemTableRef {
            batch_store: active_batches,
            index_store: Arc::new(IndexStore::new()),
            schema: schema.clone(),
            generation: 4,
        },
        frozen: vec![frozen_ref],
    };

    let batches: Vec<RecordBatch> =
        LsmScanner::new(base_dataset, vec![snapshot], vec!["id".to_string()])
            .with_in_memory_memtables(shard_id, memtables)
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();

    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }

    assert_eq!(results.len(), 8, "ids 1-8 should all be present");
    let expected: [(i32, &str); 7] = [
        (1, "base_1"),
        (3, "gen1_3"),
        (4, "gen2_4"),
        (5, "gen2_5"),
        (6, "frozen_6"),
        (7, "active_7"),
        (8, "active_8"),
    ];
    for (id, name) in expected {
        assert_eq!(results.get(&id), Some(&name.to_string()));
    }
}
#[tokio::test]
async fn test_lsm_scan_with_projection() {
let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
setup_multi_level_lsm().await;
let mut scanner =
LsmScanner::new(base_dataset, shard_snapshots, pk_columns).project(&["id"]);
if let Some((shard_id, memtable)) = active_memtable {
scanner = scanner.with_in_memory_memtables(shard_id, memtable);
}
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let schema = batches[0].schema();
assert_eq!(schema.fields().len(), 1);
assert_eq!(schema.field(0).name(), "id");
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 7, "Should have 7 unique rows after dedup");
}
#[tokio::test]
async fn test_lsm_scan_with_limit() {
let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
setup_multi_level_lsm().await;
let mut scanner = LsmScanner::new(base_dataset, shard_snapshots, pk_columns).limit(3, None);
if let Some((shard_id, memtable)) = active_memtable {
scanner = scanner.with_in_memory_memtables(shard_id, memtable);
}
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 3, "Should have 3 rows due to limit");
}
#[tokio::test]
async fn test_lsm_scan_base_only() {
let (base_dataset, _, _, pk_columns, _temp_path) = setup_multi_level_lsm().await;
let scanner = LsmScanner::new(base_dataset, vec![], pk_columns);
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan,
"DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
)
.await
.unwrap();
let scanner = LsmScanner::new(
Arc::new(
Dataset::open(&format!("{}/base", _temp_path))
.await
.unwrap(),
),
vec![],
vec!["id".to_string()],
);
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 5, "Should have 5 rows from base table");
}
/// Without the active memtable, the flushed generations are the newest data and id 7
/// is absent.
#[tokio::test]
async fn test_lsm_scan_flushed_only_no_active() {
    let (base_dataset, shard_snapshots, _, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let batches: Vec<RecordBatch> = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }

    assert_eq!(results.len(), 6, "Should have 6 unique rows (no id=7)");
    let expected: [(i32, &str); 6] = [
        (1, "base_1"),
        (2, "base_2"),
        (3, "gen1_3"),
        (4, "gen2_4"),
        (5, "gen2_5"),
        (6, "gen2_6"),
    ];
    for (id, name) in expected {
        assert_eq!(results.get(&id), Some(&name.to_string()));
    }
    assert_eq!(results.get(&7), None);
}
/// When row addresses are kept, non-base sources get a null `_rowaddr` projection after
/// sorting. A base-only scan exposes `_rowaddr` in its output schema.
#[tokio::test]
async fn test_lsm_scan_with_row_address() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, temp_path) =
        setup_multi_level_lsm().await;
    let with_addr =
        LsmScanner::new(base_dataset, shard_snapshots, pk_columns).with_row_address();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => with_addr.with_in_memory_memtables(shard_id, memtable),
        None => with_addr,
    };
    let expected_plan = "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=true, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--";
    let plan = scanner.create_plan().await.unwrap();
    assert_plan_node_equals(plan, expected_plan).await.unwrap();

    let reopened = Dataset::open(&format!("{}/base", temp_path)).await.unwrap();
    let base_scanner =
        LsmScanner::new(Arc::new(reopened), vec![], vec!["id".to_string()]).with_row_address();
    let batches: Vec<RecordBatch> = base_scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let output_schema = batches[0].schema();
    assert!(
        output_schema.column_with_name("_rowaddr").is_some(),
        "Schema should include _rowaddr"
    );
}
/// Generation tagging wraps the `_rowaddr`-nulling projection of each non-base source.
#[tokio::test]
async fn test_lsm_scan_with_both_memtable_gen_and_row_address() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let configured = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .with_memtable_gen()
        .with_row_address();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => configured.with_in_memory_memtables(shard_id, memtable),
        None => configured,
    };
    let expected_plan = "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=true, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
MemtableGenTagExec: gen=gen3
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
MemtableGenTagExec: gen=gen2
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
MemtableGenTagExec: gen=gen1
ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr]
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
MemtableGenTagExec: gen=base
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--";
    let plan = scanner.create_plan().await.unwrap();
    assert_plan_node_equals(plan, expected_plan).await.unwrap();
}
/// Same data layout as `setup_multi_level_lsm`, with a BTree index on `id` at every level:
///
/// - base table (ids 1..=5), flushed gen 1 (ids 3, 4) and flushed gen 2 (ids 4, 5, 6)
///   are each indexed with `CreateIndexBuilder` using default `ScalarIndexParams`;
/// - the active memtable (generation 3, ids 5, 6, 7) registers a BTree named `id_idx`
///   on `id` in its `IndexStore` and indexes the active batch.
///
/// The temporary directory is persisted via `TempDir::keep` and is not removed afterwards.
async fn setup_multi_level_lsm_with_btree_index() -> (
    Arc<Dataset>,
    Vec<ShardSnapshot>,
    Option<(Uuid, InMemoryMemTables)>,
    Vec<String>,
    String,
) {
    use crate::index::CreateIndexBuilder;
    use lance_index::IndexType;
    use lance_index::scalar::ScalarIndexParams;

    let schema = create_pk_schema();
    let temp_dir = tempfile::tempdir().unwrap();
    let base_path = temp_dir.path().to_str().unwrap();
    let base_uri = format!("{}/base", base_path);
    let base_batch = create_test_batch(&schema, &[1, 2, 3, 4, 5], "base");
    let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
    let params = ScalarIndexParams::default();
    CreateIndexBuilder::new(&mut base_dataset, &["id"], IndexType::BTree, &params)
        .await
        .unwrap();
    // Reopen from the URI instead of reusing the handle — presumably so the returned
    // dataset is at the latest version, which includes the index.
    let base_dataset = Arc::new(Dataset::open(&base_uri).await.unwrap());

    let shard_id = Uuid::new_v4();
    let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id);
    let gen1_batch = create_test_batch(&schema, &[3, 4], "gen1");
    let mut gen1_dataset = create_dataset(&gen1_uri, vec![gen1_batch]).await;
    CreateIndexBuilder::new(&mut gen1_dataset, &["id"], IndexType::BTree, &params)
        .await
        .unwrap();
    let gen2_uri = format!("{}/_mem_wal/{}/gen_2", base_uri, shard_id);
    let gen2_batch = create_test_batch(&schema, &[4, 5, 6], "gen2");
    let mut gen2_dataset = create_dataset(&gen2_uri, vec![gen2_batch]).await;
    CreateIndexBuilder::new(&mut gen2_dataset, &["id"], IndexType::BTree, &params)
        .await
        .unwrap();

    let shard_snapshot = ShardSnapshot::new(shard_id)
        .with_current_generation(3)
        .with_flushed_generation(1, "gen_1".to_string())
        .with_flushed_generation(2, "gen_2".to_string());

    let batch_store = Arc::new(BatchStore::with_capacity(100));
    let mut index_store = IndexStore::new();
    index_store.add_btree("id_idx".to_string(), 0, "id".to_string());
    let active_batch = create_test_batch(&schema, &[5, 6, 7], "active");
    let _ = batch_store.append(active_batch.clone());
    index_store
        .insert_with_batch_position(&active_batch, 0, Some(0))
        .unwrap();
    let index_store = Arc::new(index_store);
    let active_memtable = InMemoryMemTables {
        active: InMemoryMemTableRef {
            batch_store,
            index_store,
            schema: schema.clone(),
            generation: 3,
        },
        frozen: vec![],
    };

    let pk_columns = vec!["id".to_string()];
    let temp_path = temp_dir.keep().to_string_lossy().to_string();
    (
        base_dataset,
        vec![shard_snapshot],
        Some((shard_id, active_memtable)),
        pk_columns,
        temp_path,
    )
}
/// With a BTree index on `id`, the active memtable is planned with `BTreeIndexExec`,
/// and `id = 5` resolves to the newest version of the row.
///
/// NOTE(review): the `full_filter=` checks search the whole plan text rather than the
/// line of each individual read, so they cannot distinguish one source from another.
#[tokio::test]
async fn test_lsm_scan_with_btree_index_filter() {
    use datafusion::physical_plan::displayable;

    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm_with_btree_index().await;
    let filtered = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .filter("id = 5")
        .unwrap();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => filtered.with_in_memory_memtables(shard_id, memtable),
        None => filtered,
    };

    let plan = scanner.create_plan().await.unwrap();
    let plan_str = format!("{}", displayable(plan.as_ref()).indent(true));
    let has = |needle: &str| plan_str.contains(needle);
    assert!(
        has("DeduplicateExec: pk=[id]"),
        "Should have DeduplicateExec at top"
    );
    assert!(
        has("SortPreservingMergeExec"),
        "Should use SortPreservingMergeExec for merging"
    );
    assert!(has("UnionExec"), "Should have UnionExec");
    assert!(
        has("BTreeIndexExec: predicate=Eq"),
        "Active memtable should use BTreeIndexExec instead of MemTableScanExec"
    );
    assert!(
        has("gen_2") && has("full_filter="),
        "gen_2 should have filter pushed down"
    );
    assert!(
        has("gen_1") && has("full_filter="),
        "gen_1 should have filter pushed down"
    );
    assert!(
        has("base/data") && has("full_filter="),
        "base table should have filter pushed down"
    );

    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }
    assert_eq!(results.len(), 1, "Filter should return only matching rows");
    assert_eq!(
        results.get(&5),
        Some(&"active_5".to_string()),
        "Should get newest version (active) for id=5"
    );
}
/// Without indexes, `id = 3` still returns only the newest version (from gen 1).
#[tokio::test]
async fn test_lsm_scan_with_filter_no_index() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let filtered = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .filter("id = 3")
        .unwrap();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => filtered.with_in_memory_memtables(shard_id, memtable),
        None => filtered,
    };
    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }
    assert_eq!(results.len(), 1);
    assert_eq!(results.get(&3), Some(&"gen1_3".to_string()));
}
/// `without_base_table` plans only the flushed generations and the active memtable, so
/// base-only ids (1, 2) disappear.
#[tokio::test]
async fn test_lsm_scan_without_base_table() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let base_uri = base_dataset.uri().to_string();
    let arrow_schema: arrow_schema::Schema = base_dataset.schema().into();
    let fresh_tier =
        LsmScanner::without_base_table(Arc::new(arrow_schema), base_uri, shard_snapshots, pk_columns);
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => fresh_tier.with_in_memory_memtables(shard_id, memtable),
        None => fresh_tier,
    };

    let plan = scanner.create_plan().await.unwrap();
    let plan_str = format!(
        "{}",
        datafusion::physical_plan::displayable(plan.as_ref()).indent(true)
    );
    assert!(
        !plan_str.contains("base/data"),
        "Plan must not include base table scan, got: {}",
        plan_str
    );
    assert!(
        plan_str.contains("gen_1") && plan_str.contains("gen_2"),
        "Plan must scan flushed generations, got: {}",
        plan_str
    );
    assert!(
        plan_str.contains("MemTableScanExec"),
        "Plan must scan the active memtable, got: {}",
        plan_str
    );

    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }

    assert_eq!(results.len(), 5, "Fresh tier should yield 5 unique rows");
    assert_eq!(results.get(&1), None);
    assert_eq!(results.get(&2), None);
    let expected: [(i32, &str); 5] = [
        (3, "gen1_3"),
        (4, "gen2_4"),
        (5, "active_5"),
        (6, "active_6"),
        (7, "active_7"),
    ];
    for (id, name) in expected {
        assert_eq!(results.get(&id), Some(&name.to_string()));
    }
}
/// Filters also apply to base-less scans: `id > 3` keeps the newest versions of ids 4..=7.
#[tokio::test]
async fn test_lsm_scan_without_base_table_with_filter() {
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let base_uri = base_dataset.uri().to_string();
    let arrow_schema: arrow_schema::Schema = base_dataset.schema().into();
    let filtered =
        LsmScanner::without_base_table(Arc::new(arrow_schema), base_uri, shard_snapshots, pk_columns)
            .filter("id > 3")
            .unwrap();
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => filtered.with_in_memory_memtables(shard_id, memtable),
        None => filtered,
    };
    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    let mut results: HashMap<i32, String> = HashMap::new();
    for batch in &batches {
        let ids = batch
            .column_by_name("id")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = batch
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        results.extend(
            (0..batch.num_rows()).map(|row| (ids.value(row), names.value(row).to_string())),
        );
    }

    assert_eq!(results.len(), 4);
    let expected: [(i32, &str); 4] = [
        (4, "gen2_4"),
        (5, "active_5"),
        (6, "active_6"),
        (7, "active_7"),
    ];
    for (id, name) in expected {
        assert_eq!(results.get(&id), Some(&name.to_string()));
    }
}
#[tokio::test]
async fn test_lsm_scan_without_base_table_no_flushed_no_active() {
let schema = create_pk_schema();
let scanner = LsmScanner::without_base_table(
schema,
"memory:///fresh-tier-empty",
vec![],
vec!["id".to_string()],
);
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 0);
}
#[tokio::test]
async fn test_lsm_scan_projection_with_system_columns() {
use lance_core::is_system_column;
let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
setup_multi_level_lsm().await;
let mut scanner = LsmScanner::new(base_dataset, shard_snapshots, pk_columns).project(&[
"id",
"_rowoffset",
"name",
"_rowaddr",
"_rowid",
]);
if let Some((shard_id, memtable)) = active_memtable {
scanner = scanner.with_in_memory_memtables(shard_id, memtable);
}
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.expect("plan must execute when system columns are in projection")
.try_collect()
.await
.expect("collecting batches must not fail");
assert!(!batches.is_empty(), "expected at least one batch");
let out_schema = batches[0].schema();
let names: Vec<&str> = out_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
assert_eq!(
names,
vec!["id", "_rowoffset", "name", "_rowaddr", "_rowid"],
"system columns must appear at the user's requested position"
);
for sys in ["_rowoffset", "_rowaddr", "_rowid"] {
assert!(is_system_column(sys));
}
let mut rowaddr_real = 0usize;
let mut rowaddr_null = 0usize;
let mut rowid_null = 0usize;
let mut rowoffset_null = 0usize;
let mut total = 0usize;
for batch in &batches {
let rowaddr = batch.column_by_name("_rowaddr").unwrap();
let rowid = batch.column_by_name("_rowid").unwrap();
let rowoffset = batch.column_by_name("_rowoffset").unwrap();
for i in 0..batch.num_rows() {
total += 1;
if rowaddr.is_null(i) {
rowaddr_null += 1;
} else {
rowaddr_real += 1;
}
if rowid.is_null(i) {
rowid_null += 1;
}
if rowoffset.is_null(i) {
rowoffset_null += 1;
}
}
}
assert_eq!(total, 7, "expected 7 unique pks after dedup");
assert_eq!(
rowaddr_real, 2,
"expected 2 rows (id=1,2) with real `_rowaddr` from base"
);
assert_eq!(
rowaddr_null, 5,
"expected 5 rows (id=3-7) with NULL `_rowaddr` from non-base sources"
);
assert_eq!(rowid_null, total, "_rowid must be NULL for every row");
assert_eq!(
rowoffset_null, total,
"_rowoffset must be NULL for every row"
);
}
#[tokio::test]
async fn test_lsm_scan_projection_with_rowid_only_no_rowaddr() {
    // The caller asks for `_rowid` but not `_rowaddr`. The plan must still be
    // wrapped in a canonical projection, and every `_rowid` value must be NULL.
    let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let projected = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .project(&["id", "_rowid", "name"]);
    let scanner = match active_memtable {
        Some((shard_id, memtable)) => projected.with_in_memory_memtables(shard_id, memtable),
        None => projected,
    };

    let plan = scanner.create_plan().await.unwrap();
    let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
        .indent(true)
        .to_string();
    assert!(
        plan_str.contains("ProjectionExec"),
        "expected canonical projection wrap, got:\n{plan_str}",
    );
    assert!(
        !plan_str.contains("NULL as _rowaddr"),
        "no per-arm `_rowaddr` NULL'ing expected when caller didn't ask for `_rowaddr`, got:\n{plan_str}",
    );

    let batches: Vec<RecordBatch> = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    let total: usize = batches.iter().map(|b| b.num_rows()).sum();
    let rowid_null: usize = batches
        .iter()
        .map(|b| {
            let rowid = b.column_by_name("_rowid").unwrap();
            (0..b.num_rows()).filter(|&row| rowid.is_null(row)).count()
        })
        .sum();
    assert_eq!(total, 7, "expected 7 unique pks after dedup");
    assert_eq!(
        rowid_null, total,
        "_rowid must be NULL for every row (no opt-in)"
    );
}
#[tokio::test]
async fn test_lsm_scan_empty_plan_with_system_columns() {
    // With no flushed generations and no memtables the planner produces an
    // empty plan. Its schema must still list the requested columns, including
    // system columns, in the caller's order.
    let tmp_dir = tempfile::tempdir().unwrap();
    let base_uri = format!("{}/base", tmp_dir.path().to_str().unwrap());
    let id_field = Field::new("id", DataType::Int32, false);
    let name_field = Field::new("name", DataType::Utf8, true);
    let schema: super::SchemaRef = Arc::new(ArrowSchema::new(vec![id_field, name_field]));

    let requested = ["id", "_rowaddr", "name", "_rowid"];
    let scanner =
        LsmScanner::without_base_table(schema, base_uri, Vec::new(), vec![String::from("id")])
            .project(&requested);
    let plan = scanner.create_plan().await.unwrap();

    let plan_schema = plan.schema();
    let actual: Vec<&str> = plan_schema
        .fields()
        .iter()
        .map(|f| f.name().as_str())
        .collect();
    assert_eq!(
        actual, requested,
        "empty plan must honor user column order including system columns"
    );
}
#[tokio::test]
async fn test_lsm_scan_filter_referencing_rowaddr_is_rejected() {
    // The active memtable is not needed for this check, so it is dropped
    // straight away.
    let (base_dataset, shard_snapshots, _, pk_columns, _temp_path) =
        setup_multi_level_lsm().await;
    let err = LsmScanner::new(base_dataset, shard_snapshots, pk_columns)
        .filter("_rowaddr > 0")
        .expect_err("filter referencing `_rowaddr` must be rejected");
    let msg = err.to_string();
    assert!(
        msg.contains("_rowaddr"),
        "rejection message should mention the offending column, got: {msg}",
    );
}
}