use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions};
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{ExecutionPlan, limit::GlobalLimitExec};
use datafusion::prelude::Expr;
use lance_core::Result;
use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{DeduplicateExec, MEMTABLE_GEN_COLUMN, MemtableGenTagExec, ROW_ADDRESS_COLUMN};
/// Plans DataFusion physical plans for LSM-style scans over a base table,
/// flushed memtable generations, and the active in-memory memtable, merging
/// them with primary-key deduplication (newest generation wins).
pub struct LsmScanPlanner {
    // Discovers the LSM data sources (base table, flushed memtables, active
    // memtable) that a scan must cover.
    collector: LsmDataSourceCollector,
    // Primary-key column names; used for sorting, merging and deduplication.
    pk_columns: Vec<String>,
    // Base table schema; used for default projections and empty-plan schemas.
    base_schema: SchemaRef,
}
impl LsmScanPlanner {
pub fn new(
collector: LsmDataSourceCollector,
pk_columns: Vec<String>,
base_schema: SchemaRef,
) -> Self {
Self {
collector,
pk_columns,
base_schema,
}
}
pub async fn plan_scan(
&self,
projection: Option<&[String]>,
filter: Option<&Expr>,
limit: Option<usize>,
offset: Option<usize>,
with_memtable_gen: bool,
keep_row_address: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
let sources = self.collector.collect()?;
if sources.is_empty() {
return self.empty_plan(projection, with_memtable_gen, keep_row_address);
}
let sources: Vec<_> = sources.into_iter().rev().collect();
let mut sorted_plans = Vec::new();
for source in sources {
let scan = self.build_source_scan(&source, projection, filter).await?;
let local_sort_exprs = self.build_local_sort_exprs(&scan)?;
let lex_ordering = LexOrdering::new(local_sort_exprs).ok_or_else(|| {
lance_core::Error::internal(
"Failed to create LexOrdering from sort expressions".to_string(),
)
})?;
let sorted: Arc<dyn ExecutionPlan> = Arc::new(SortExec::new(lex_ordering, scan));
let plan: Arc<dyn ExecutionPlan> = if with_memtable_gen {
Arc::new(MemtableGenTagExec::new(sorted, source.generation()))
} else {
sorted
};
sorted_plans.push(plan);
}
let merged: Arc<dyn ExecutionPlan> = if sorted_plans.len() == 1 {
sorted_plans.remove(0)
} else {
let merge_sort_exprs = self.build_merge_sort_exprs(&sorted_plans[0])?;
let lex_ordering = LexOrdering::new(merge_sort_exprs).ok_or_else(|| {
lance_core::Error::internal(
"Failed to create LexOrdering from sort expressions".to_string(),
)
})?;
#[allow(deprecated)]
let union = Arc::new(UnionExec::new(sorted_plans));
Arc::new(SortPreservingMergeExec::new(lex_ordering, union))
};
let dedup = DeduplicateExec::new_sorted(
merged,
self.pk_columns.clone(),
with_memtable_gen,
keep_row_address,
)?;
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(dedup);
if let Some(limit) = limit {
plan = Arc::new(GlobalLimitExec::new(plan, offset.unwrap_or(0), Some(limit)));
}
Ok(plan)
}
fn build_local_sort_exprs(
&self,
plan: &Arc<dyn ExecutionPlan>,
) -> Result<Vec<PhysicalSortExpr>> {
let schema = plan.schema();
let mut sort_exprs = Vec::new();
for col in &self.pk_columns {
let (idx, _) = schema.column_with_name(col).ok_or_else(|| {
lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(col, idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
});
}
let (addr_idx, _) = schema.column_with_name(ROW_ADDRESS_COLUMN).ok_or_else(|| {
lance_core::Error::invalid_input(format!(
"Column '{}' not found in schema",
ROW_ADDRESS_COLUMN
))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(ROW_ADDRESS_COLUMN, addr_idx)),
options: SortOptions {
descending: true,
nulls_first: false,
},
});
Ok(sort_exprs)
}
fn build_merge_sort_exprs(
&self,
plan: &Arc<dyn ExecutionPlan>,
) -> Result<Vec<PhysicalSortExpr>> {
let schema = plan.schema();
let mut sort_exprs = Vec::new();
for col in &self.pk_columns {
let (idx, _) = schema.column_with_name(col).ok_or_else(|| {
lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col))
})?;
sort_exprs.push(PhysicalSortExpr {
expr: Arc::new(Column::new(col, idx)),
options: SortOptions {
descending: false,
nulls_first: false,
},
});
}
Ok(sort_exprs)
}
async fn build_source_scan(
&self,
source: &LsmDataSource,
projection: Option<&[String]>,
filter: Option<&Expr>,
) -> Result<Arc<dyn ExecutionPlan>> {
match source {
LsmDataSource::BaseTable { dataset } => {
let mut scanner = dataset.scan();
let cols = self.build_projection_with_rowaddr(projection);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
LsmDataSource::FlushedMemTable { path, .. } => {
let dataset = crate::dataset::DatasetBuilder::from_uri(path)
.load()
.await?;
let mut scanner = dataset.scan();
let cols = self.build_projection_with_rowaddr(projection);
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
LsmDataSource::ActiveMemTable {
batch_store,
index_store,
schema,
..
} => {
use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;
let mut scanner =
MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
if let Some(cols) = projection {
scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>());
}
scanner.with_row_address();
if let Some(expr) = filter {
scanner.filter_expr(expr.clone());
}
scanner.create_plan().await
}
}
}
fn build_projection_with_rowaddr(&self, projection: Option<&[String]>) -> Vec<String> {
let mut cols: Vec<String> = if let Some(p) = projection {
p.to_vec()
} else {
self.base_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect()
};
for pk in &self.pk_columns {
if !cols.contains(pk) {
cols.push(pk.clone());
}
}
cols
}
fn empty_plan(
&self,
projection: Option<&[String]>,
with_memtable_gen: bool,
keep_row_address: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion::physical_plan::empty::EmptyExec;
let mut fields: Vec<Arc<Field>> = if let Some(cols) = projection {
cols.iter()
.filter_map(|name| {
self.base_schema
.field_with_name(name)
.ok()
.map(|f| Arc::new(f.clone()))
})
.collect()
} else {
self.base_schema.fields().iter().cloned().collect()
};
if with_memtable_gen {
fields.push(Arc::new(Field::new(
MEMTABLE_GEN_COLUMN,
DataType::UInt64,
false,
)));
}
if keep_row_address {
fields.push(Arc::new(Field::new(
ROW_ADDRESS_COLUMN,
DataType::UInt64,
false,
)));
}
let schema = Arc::new(Schema::new(fields));
Ok(Arc::new(EmptyExec::new(schema)))
}
}
// Unit tests for projection assembly and region-snapshot bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::RegionSnapshot;

    // Simple three-column schema used by the unit tests below.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    #[test]
    fn test_build_projection_with_rowaddr() {
        // NOTE(review): this re-implements the pk-append loop from
        // `build_projection_with_rowaddr` inline rather than calling the
        // method on a planner instance; consider exercising the method itself.
        let schema = create_test_schema();
        let pk_columns = vec!["id".to_string()];
        let mut cols: Vec<String> = vec!["name".to_string()];
        for pk in &pk_columns {
            if !cols.contains(pk) {
                cols.push(pk.clone());
            }
        }
        // The pk column must be appended to a projection that omitted it.
        assert!(cols.contains(&"name".to_string()));
        assert!(cols.contains(&"id".to_string()));
        // With no explicit projection, all schema columns are used.
        let cols_all: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
        assert_eq!(cols_all.len(), 3);
    }

    #[test]
    fn test_region_snapshot() {
        // Verifies RegionSnapshot's builder-style accumulation of flushed
        // generations and the current-generation setter.
        let region_id = uuid::Uuid::new_v4();
        let snapshot = RegionSnapshot::new(region_id)
            .with_current_generation(5)
            .with_flushed_generation(1, "gen_1".to_string())
            .with_flushed_generation(2, "gen_2".to_string());
        assert_eq!(snapshot.flushed_generations.len(), 2);
        assert_eq!(snapshot.current_generation, 5);
    }
}
// End-to-end tests: build a real multi-level LSM layout on disk (base table,
// two flushed memtable generations, one active in-memory memtable) and verify
// both the produced physical plan shapes and the deduplicated scan results.
#[cfg(test)]
mod integration_tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use futures::TryStreamExt;
    use uuid::Uuid;
    use crate::dataset::mem_wal::scanner::LsmScanner;
    use crate::dataset::mem_wal::scanner::collector::ActiveMemTableRef;
    use crate::dataset::mem_wal::scanner::data_source::RegionSnapshot;
    use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
    use crate::dataset::{Dataset, WriteParams};
    use crate::utils::test::assert_plan_node_equals;

    // Schema with `id` marked as an (unenforced) primary key via field
    // metadata, plus a nullable `name` payload column.
    fn create_pk_schema() -> Arc<ArrowSchema> {
        let mut id_metadata = HashMap::new();
        id_metadata.insert(
            "lance-schema:unenforced-primary-key".to_string(),
            "true".to_string(),
        );
        let id_field = Field::new("id", DataType::Int32, false).with_metadata(id_metadata);
        Arc::new(ArrowSchema::new(vec![
            id_field,
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    // Builds a batch whose `name` values encode their origin, e.g. "base_3",
    // so dedup winners can be identified per id.
    fn create_test_batch(schema: &ArrowSchema, ids: &[i32], name_prefix: &str) -> RecordBatch {
        let names: Vec<String> = ids
            .iter()
            .map(|id| format!("{}_{}", name_prefix, id))
            .collect();
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(ids.to_vec())),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap()
    }

    // Writes the batches to a new Lance dataset at `uri`.
    async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
        let schema = batches[0].schema();
        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
        Dataset::write(reader, uri, Some(WriteParams::default()))
            .await
            .unwrap()
    }

    // Sets up overlapping LSM levels so deduplication is observable:
    //   base:   ids 1-5 ("base_*")
    //   gen_1:  ids 3,4 ("gen1_*")
    //   gen_2:  ids 4,5,6 ("gen2_*")
    //   active: ids 5,6,7 ("active_*", generation 3)
    // Newer levels must win for overlapping ids. The temp dir is kept alive
    // and its path returned so tests can reopen datasets.
    async fn setup_multi_level_lsm() -> (
        Arc<Dataset>,
        Vec<RegionSnapshot>,
        Option<(Uuid, ActiveMemTableRef)>,
        Vec<String>,
        String, // temp_dir path for cleanup
    ) {
        let schema = create_pk_schema();
        let temp_dir = tempfile::tempdir().unwrap();
        let base_path = temp_dir.path().to_str().unwrap();
        let base_uri = format!("{}/base", base_path);
        let base_batch = create_test_batch(&schema, &[1, 2, 3, 4, 5], "base");
        let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
        let region_id = Uuid::new_v4();
        // Flushed generations live under the base dataset's _mem_wal dir.
        let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, region_id);
        let gen1_batch = create_test_batch(&schema, &[3, 4], "gen1");
        create_dataset(&gen1_uri, vec![gen1_batch]).await;
        let gen2_uri = format!("{}/_mem_wal/{}/gen_2", base_uri, region_id);
        let gen2_batch = create_test_batch(&schema, &[4, 5, 6], "gen2");
        create_dataset(&gen2_uri, vec![gen2_batch]).await;
        let region_snapshot = RegionSnapshot::new(region_id)
            .with_current_generation(3)
            .with_flushed_generation(1, "gen_1".to_string())
            .with_flushed_generation(2, "gen_2".to_string());
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let index_store = Arc::new(IndexStore::new());
        let active_batch = create_test_batch(&schema, &[5, 6, 7], "active");
        let _ = batch_store.append(active_batch);
        let active_memtable = ActiveMemTableRef {
            batch_store,
            index_store,
            schema: schema.clone(),
            generation: 3,
        };
        let pk_columns = vec!["id".to_string()];
        // keep() persists the dir past this function; tests hold the path.
        let temp_path = temp_dir.keep().to_string_lossy().to_string();
        (
            base_dataset,
            vec![region_snapshot],
            Some((region_id, active_memtable)),
            pk_columns,
            temp_path,
        )
    }

    // Plan shape without generation tagging: union branches ordered
    // newest-first (active, gen_2, gen_1, base), each locally sorted.
    #[tokio::test]
    async fn test_lsm_scan_query_plan_without_memtable_gen() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns);
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
        )
        .await
        .unwrap();
    }

    // Plan shape with generation tagging: each branch gains a
    // MemtableGenTagExec above its SortExec.
    #[tokio::test]
    async fn test_lsm_scan_query_plan_with_memtable_gen() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner =
            LsmScanner::new(base_dataset, region_snapshots, pk_columns).with_memtable_gen();
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=false, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
MemtableGenTagExec: gen=gen3
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
MemtableGenTagExec: gen=gen2
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
MemtableGenTagExec: gen=gen1
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
MemtableGenTagExec: gen=base
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
        )
        .await
        .unwrap();
    }

    // Result correctness: with all four levels active, each id must resolve
    // to the value from the newest level containing it.
    #[tokio::test]
    async fn test_lsm_scan_deduplication_results() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns);
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let mut results: HashMap<i32, String> = HashMap::new();
        for batch in batches {
            let ids = batch
                .column_by_name("id")
                .unwrap()
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            let names = batch
                .column_by_name("name")
                .unwrap()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            for i in 0..batch.num_rows() {
                results.insert(ids.value(i), names.value(i).to_string());
            }
        }
        assert_eq!(results.len(), 7, "Should have 7 unique rows after dedup");
        assert_eq!(results.get(&1), Some(&"base_1".to_string()));
        assert_eq!(results.get(&2), Some(&"base_2".to_string()));
        assert_eq!(results.get(&3), Some(&"gen1_3".to_string()));
        assert_eq!(results.get(&4), Some(&"gen2_4".to_string()));
        assert_eq!(results.get(&5), Some(&"active_5".to_string()));
        assert_eq!(results.get(&6), Some(&"active_6".to_string()));
        assert_eq!(results.get(&7), Some(&"active_7".to_string()));
    }

    // Projecting only the pk column: output schema has just `id`, dedup still
    // yields all 7 unique rows.
    #[tokio::test]
    async fn test_lsm_scan_with_projection() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner =
            LsmScanner::new(base_dataset, region_snapshots, pk_columns).project(&["id"]);
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let schema = batches[0].schema();
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).name(), "id");
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 7, "Should have 7 unique rows after dedup");
    }

    // limit(3) applies after deduplication, capping the result to 3 rows.
    #[tokio::test]
    async fn test_lsm_scan_with_limit() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner =
            LsmScanner::new(base_dataset, region_snapshots, pk_columns).limit(3, None);
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3, "Should have 3 rows due to limit");
    }

    // Base table only (no regions, no active memtable): single-branch plan
    // with no union/merge, and the scan returns the 5 base rows.
    #[tokio::test]
    async fn test_lsm_scan_base_only() {
        let (base_dataset, _, _, pk_columns, _temp_path) = setup_multi_level_lsm().await;
        let scanner = LsmScanner::new(base_dataset, vec![], pk_columns);
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
        )
        .await
        .unwrap();
        // Reopen the persisted base dataset and verify row count end-to-end.
        let scanner = LsmScanner::new(
            Arc::new(
                Dataset::open(&format!("{}/base", _temp_path))
                    .await
                    .unwrap(),
            ),
            vec![],
            vec!["id".to_string()],
        );
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 5, "Should have 5 rows from base table");
    }

    // Flushed generations without the active memtable: id=7 (active-only)
    // must be absent and gen_2 wins for ids 4-6.
    #[tokio::test]
    async fn test_lsm_scan_flushed_only_no_active() {
        let (base_dataset, region_snapshots, _, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns);
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let mut results: HashMap<i32, String> = HashMap::new();
        for batch in batches {
            let ids = batch
                .column_by_name("id")
                .unwrap()
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            let names = batch
                .column_by_name("name")
                .unwrap()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            for i in 0..batch.num_rows() {
                results.insert(ids.value(i), names.value(i).to_string());
            }
        }
        assert_eq!(results.len(), 6, "Should have 6 unique rows (no id=7)");
        assert_eq!(results.get(&1), Some(&"base_1".to_string()));
        assert_eq!(results.get(&2), Some(&"base_2".to_string()));
        assert_eq!(results.get(&3), Some(&"gen1_3".to_string()));
        assert_eq!(results.get(&4), Some(&"gen2_4".to_string()));
        assert_eq!(results.get(&5), Some(&"gen2_5".to_string()));
        assert_eq!(results.get(&6), Some(&"gen2_6".to_string()));
        assert_eq!(results.get(&7), None);
    }

    // with_row_address(): dedup keeps `_rowaddr` in the output (keep_addr=true
    // in the plan) and the result schema exposes the column.
    #[tokio::test]
    async fn test_lsm_scan_with_row_address() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner =
            LsmScanner::new(base_dataset, region_snapshots, pk_columns).with_row_address();
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=true, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
        )
        .await
        .unwrap();
        // Base-only variant: the streamed schema must include _rowaddr too.
        let scanner = LsmScanner::new(
            Arc::new(
                Dataset::open(&format!("{}/base", _temp_path))
                    .await
                    .unwrap(),
            ),
            vec![],
            vec!["id".to_string()],
        )
        .with_row_address();
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let schema = batches[0].schema();
        assert!(
            schema.column_with_name("_rowaddr").is_some(),
            "Schema should include _rowaddr"
        );
    }

    // Both options together: plan carries with_memtable_gen=true AND
    // keep_addr=true, with tagging nodes on every branch.
    #[tokio::test]
    async fn test_lsm_scan_with_both_memtable_gen_and_row_address() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns)
            .with_memtable_gen()
            .with_row_address();
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=true, input_sorted=true
SortPreservingMergeExec: [id@0 ASC NULLS LAST]
UnionExec
MemtableGenTagExec: gen=gen3
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true
MemtableGenTagExec: gen=gen2
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_2...
MemtableGenTagExec: gen=gen1
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...gen_1...
MemtableGenTagExec: gen=base
SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]...
LanceRead:...base/data...refine_filter=--",
        )
        .await
        .unwrap();
    }

    // Same layout as setup_multi_level_lsm, but with a BTree index on `id` in
    // every on-disk level and a memtable btree in the active level, so filter
    // pushdown / index usage can be observed in the plan.
    async fn setup_multi_level_lsm_with_btree_index() -> (
        Arc<Dataset>,
        Vec<RegionSnapshot>,
        Option<(Uuid, ActiveMemTableRef)>,
        Vec<String>,
        String,
    ) {
        use crate::index::CreateIndexBuilder;
        use lance_index::IndexType;
        use lance_index::scalar::ScalarIndexParams;
        let schema = create_pk_schema();
        let temp_dir = tempfile::tempdir().unwrap();
        let base_path = temp_dir.path().to_str().unwrap();
        let base_uri = format!("{}/base", base_path);
        let base_batch = create_test_batch(&schema, &[1, 2, 3, 4, 5], "base");
        let mut base_dataset = create_dataset(&base_uri, vec![base_batch]).await;
        let params = ScalarIndexParams::default();
        CreateIndexBuilder::new(&mut base_dataset, &["id"], IndexType::BTree, &params)
            .await
            .unwrap();
        // Reopen so the freshly created index is visible to the scan.
        let base_dataset = Arc::new(Dataset::open(&base_uri).await.unwrap());
        let region_id = Uuid::new_v4();
        let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, region_id);
        let gen1_batch = create_test_batch(&schema, &[3, 4], "gen1");
        let mut gen1_dataset = create_dataset(&gen1_uri, vec![gen1_batch]).await;
        CreateIndexBuilder::new(&mut gen1_dataset, &["id"], IndexType::BTree, &params)
            .await
            .unwrap();
        let gen2_uri = format!("{}/_mem_wal/{}/gen_2", base_uri, region_id);
        let gen2_batch = create_test_batch(&schema, &[4, 5, 6], "gen2");
        let mut gen2_dataset = create_dataset(&gen2_uri, vec![gen2_batch]).await;
        CreateIndexBuilder::new(&mut gen2_dataset, &["id"], IndexType::BTree, &params)
            .await
            .unwrap();
        let region_snapshot = RegionSnapshot::new(region_id)
            .with_current_generation(3)
            .with_flushed_generation(1, "gen_1".to_string())
            .with_flushed_generation(2, "gen_2".to_string());
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let mut index_store = IndexStore::new();
        index_store.add_btree("id_idx".to_string(), 0, "id".to_string());
        let active_batch = create_test_batch(&schema, &[5, 6, 7], "active");
        let _ = batch_store.append(active_batch.clone());
        // Index the active batch at store position 0.
        index_store
            .insert_with_batch_position(&active_batch, 0, Some(0))
            .unwrap();
        let index_store = Arc::new(index_store);
        let active_memtable = ActiveMemTableRef {
            batch_store,
            index_store,
            schema: schema.clone(),
            generation: 3,
        };
        let pk_columns = vec!["id".to_string()];
        let temp_path = temp_dir.keep().to_string_lossy().to_string();
        (
            base_dataset,
            vec![region_snapshot],
            Some((region_id, active_memtable)),
            pk_columns,
            temp_path,
        )
    }

    // With btree indexes everywhere, `id = 5` should be answered via
    // BTreeIndexExec in the active memtable and full_filter pushdown in each
    // on-disk level; the result must still be the newest version of id=5.
    #[tokio::test]
    async fn test_lsm_scan_with_btree_index_filter() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm_with_btree_index().await;
        let mut scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns)
            .filter("id = 5")
            .unwrap();
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let plan = scanner.create_plan().await.unwrap();
        use datafusion::physical_plan::displayable;
        // Substring checks on the rendered plan rather than exact-shape
        // matching, since index details vary.
        let plan_str = format!("{}", displayable(plan.as_ref()).indent(true));
        assert!(
            plan_str.contains("DeduplicateExec: pk=[id]"),
            "Should have DeduplicateExec at top"
        );
        assert!(
            plan_str.contains("SortPreservingMergeExec"),
            "Should use SortPreservingMergeExec for merging"
        );
        assert!(plan_str.contains("UnionExec"), "Should have UnionExec");
        assert!(
            plan_str.contains("BTreeIndexExec: predicate=Eq"),
            "Active memtable should use BTreeIndexExec instead of MemTableScanExec"
        );
        assert!(
            plan_str.contains("gen_2") && plan_str.contains("full_filter="),
            "gen_2 should have filter pushed down"
        );
        assert!(
            plan_str.contains("gen_1") && plan_str.contains("full_filter="),
            "gen_1 should have filter pushed down"
        );
        assert!(
            plan_str.contains("base/data") && plan_str.contains("full_filter="),
            "base table should have filter pushed down"
        );
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let mut results: HashMap<i32, String> = HashMap::new();
        for batch in batches {
            let ids = batch
                .column_by_name("id")
                .unwrap()
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            let names = batch
                .column_by_name("name")
                .unwrap()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            for i in 0..batch.num_rows() {
                results.insert(ids.value(i), names.value(i).to_string());
            }
        }
        assert_eq!(results.len(), 1, "Filter should return only matching rows");
        assert_eq!(
            results.get(&5),
            Some(&"active_5".to_string()),
            "Should get newest version (active) for id=5"
        );
    }

    // Without any index, filtering still works (filters evaluated during the
    // scans) and dedup still returns the newest matching version.
    #[tokio::test]
    async fn test_lsm_scan_with_filter_no_index() {
        let (base_dataset, region_snapshots, active_memtable, pk_columns, _temp_path) =
            setup_multi_level_lsm().await;
        let mut scanner = LsmScanner::new(base_dataset, region_snapshots, pk_columns)
            .filter("id = 3")
            .unwrap();
        if let Some((region_id, memtable)) = active_memtable {
            scanner = scanner.with_active_memtable(region_id, memtable);
        }
        let batches: Vec<RecordBatch> = scanner
            .try_into_stream()
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        let mut results: HashMap<i32, String> = HashMap::new();
        for batch in batches {
            let ids = batch
                .column_by_name("id")
                .unwrap()
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            let names = batch
                .column_by_name("name")
                .unwrap()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            for i in 0..batch.num_rows() {
                results.insert(ids.value(i), names.value(i).to_string());
            }
        }
        assert_eq!(results.len(), 1);
        assert_eq!(results.get(&3), Some(&"gen1_3".to_string()));
    }
}