use std::collections::HashMap;
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion::common::ScalarValue;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::prelude::Expr;
use lance_core::Result;
use lance_index::scalar::bloomfilter::sbbf::Sbbf;

use super::collector::LsmDataSourceCollector;
use super::data_source::LsmDataSource;
use super::exec::{BloomFilterGuardExec, CoalesceFirstExec, compute_pk_hash_from_scalars};
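
/// Plans point lookups by primary key across the LSM view of a table: the
/// active memtable, flushed memtable generations, and the base dataset.
///
/// A minimal usage sketch (the collector, schema, and bloom filter setup are
/// elided; the tests below show a runnable version):
///
/// ```ignore
/// let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema)
///     .with_bloom_filter(1, bloom_filter);
/// let plan = planner.plan_lookup(&[ScalarValue::Int32(Some(42))], None).await?;
/// ```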
pub struct LsmPointLookupPlanner {
    /// Discovers the data sources currently visible to this lookup.
    collector: LsmDataSourceCollector,
    /// Primary key column names, in key order.
    pk_columns: Vec<String>,
    /// Schema of the base table, used for projections and empty plans.
    base_schema: SchemaRef,
    /// Optional bloom filter per flushed generation, keyed by generation number.
    bloom_filters: HashMap<u64, Arc<Sbbf>>,
}

impl LsmPointLookupPlanner {
    /// Creates a planner for the given data sources, primary key columns, and
    /// base table schema.
    pub fn new(
        collector: LsmDataSourceCollector,
        pk_columns: Vec<String>,
        base_schema: SchemaRef,
    ) -> Self {
        Self {
            collector,
            pk_columns,
            base_schema,
            bloom_filters: HashMap::new(),
        }
    }

    /// Registers a bloom filter for a single flushed generation. Lookups can
    /// then skip that generation when the filter rules out the key.
    pub fn with_bloom_filter(mut self, generation: u64, bloom_filter: Arc<Sbbf>) -> Self {
        self.bloom_filters.insert(generation, bloom_filter);
        self
    }

    /// Registers bloom filters for multiple generations at once.
    pub fn with_bloom_filters(
        mut self,
        bloom_filters: impl IntoIterator<Item = (u64, Arc<Sbbf>)>,
    ) -> Self {
        self.bloom_filters.extend(bloom_filters);
        self
    }

    /// Builds an execution plan that looks up a single primary key.
    ///
    /// Sources are scanned from newest generation to oldest, each capped at
    /// one row, and the combined plan returns the first match found, so newer
    /// writes shadow older ones.
    ///
    /// Returns an error if `pk_values` does not match the number of primary
    /// key columns.
    pub async fn plan_lookup(
        &self,
        pk_values: &[ScalarValue],
        projection: Option<&[String]>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if pk_values.len() != self.pk_columns.len() {
            return Err(lance_core::Error::invalid_input(format!(
                "Expected {} primary key values, got {}",
                self.pk_columns.len(),
                pk_values.len()
            )));
        }

        // Compute the key hash once; every bloom filter guard reuses it.
        let pk_hash = compute_pk_hash_from_scalars(pk_values);
        let filter_expr = self.build_pk_filter_expr(pk_values)?;

        let sources = self.collector.collect()?;
        if sources.is_empty() {
            return self.empty_plan(projection);
        }

        // Scan newest generations first so the most recent value wins.
        let mut sources: Vec<_> = sources.into_iter().collect();
        sources.sort_by_key(|s| std::cmp::Reverse(s.generation()));

        let mut source_plans = Vec::new();
        for source in sources {
            let generation = source.generation().as_u64();
            let scan = self
                .build_source_scan(&source, projection, &filter_expr)
                .await?;
            // A point lookup needs at most one row from each source.
            let limited: Arc<dyn ExecutionPlan> = Arc::new(GlobalLimitExec::new(scan, 0, Some(1)));
            // If this generation has a bloom filter, wrap the scan in a guard
            // that skips the source entirely when the filter rules out the key.
            let guarded_plan: Arc<dyn ExecutionPlan> =
                if let Some(bf) = self.bloom_filters.get(&generation) {
                    Arc::new(BloomFilterGuardExec::new(
                        limited,
                        bf.clone(),
                        pk_hash,
                        generation,
                    ))
                } else {
                    limited
                };
            source_plans.push(guarded_plan);
        }

        // A single source needs no coalescing; otherwise return the first
        // non-empty result in newest-to-oldest order.
        let plan: Arc<dyn ExecutionPlan> = if source_plans.len() == 1 {
            source_plans.remove(0)
        } else {
            Arc::new(CoalesceFirstExec::new(source_plans))
        };
        Ok(plan)
    }

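    // For illustration, with one flushed generation (guarded by a bloom
    // filter) plus the base table, the plan built above is shaped roughly
    // like:
    //
    //   CoalesceFirstExec
    //     BloomFilterGuardExec (generation 1)
    //       GlobalLimitExec (fetch=1)
    //         <flushed memtable scan>
    //     GlobalLimitExec (fetch=1)
    //       <base table scan>
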
    /// Builds the conjunctive equality filter over the primary key columns.
    /// Errors if no primary key columns are configured.
    fn build_pk_filter_expr(&self, pk_values: &[ScalarValue]) -> Result<Expr> {
        use datafusion::prelude::{col, lit};

        let mut expr: Option<Expr> = None;
        for (col_name, value) in self.pk_columns.iter().zip(pk_values.iter()) {
            let eq_expr = col(col_name.as_str()).eq(lit(value.clone()));
            expr = Some(match expr {
                Some(e) => e.and(eq_expr),
                None => eq_expr,
            });
        }
        expr.ok_or_else(|| lance_core::Error::invalid_input("No primary key columns specified"))
    }

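    // As a concrete example, with pk_columns = ["tenant_id", "user_id"] and
    // values ["acme", 7], build_pk_filter_expr yields an expression equivalent
    // to the SQL predicate `tenant_id = 'acme' AND user_id = 7`.
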
    /// Builds the physical scan for a single data source, pushing the primary
    /// key filter and the projection down into the scanner.
    async fn build_source_scan(
        &self,
        source: &LsmDataSource,
        projection: Option<&[String]>,
        filter: &Expr,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match source {
            // The base table is an already-open dataset.
            LsmDataSource::BaseTable { dataset } => {
                let mut scanner = dataset.scan();
                let cols = self.build_projection(projection);
                scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
                scanner.filter_expr(filter.clone());
                scanner.create_plan().await
            }
            // Flushed memtables are persisted as datasets at their own path
            // and must be opened first.
            LsmDataSource::FlushedMemTable { path, .. } => {
                let dataset = crate::dataset::DatasetBuilder::from_uri(path)
                    .load()
                    .await?;
                let mut scanner = dataset.scan();
                let cols = self.build_projection(projection);
                scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
                scanner.filter_expr(filter.clone());
                scanner.create_plan().await
            }
            // The active memtable is scanned directly from memory.
            LsmDataSource::ActiveMemTable {
                batch_store,
                index_store,
                schema,
                ..
            } => {
                use crate::dataset::mem_wal::memtable::scanner::MemTableScanner;

                let mut scanner =
                    MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone());
                // Mirror the dataset scans above: keep the key columns in the
                // projection so the pushed-down filter can be evaluated.
                if projection.is_some() {
                    let cols = self.build_projection(projection);
                    scanner.project(&cols.iter().map(|s| s.as_str()).collect::<Vec<_>>());
                }
                scanner.filter_expr(filter.clone());
                scanner.create_plan().await
            }
        }
    }

    /// Resolves the output projection, appending any primary key columns the
    /// caller did not request so the key filter can always be evaluated.
    fn build_projection(&self, projection: Option<&[String]>) -> Vec<String> {
        let mut cols: Vec<String> = if let Some(p) = projection {
            p.to_vec()
        } else {
            self.base_schema
                .fields()
                .iter()
                .map(|f| f.name().clone())
                .collect()
        };
        for pk in &self.pk_columns {
            if !cols.contains(pk) {
                cols.push(pk.clone());
            }
        }
        cols
    }

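    // For example, with pk_columns = ["id"], a caller projection of ["name"]
    // becomes ["name", "id"]: the key column is appended even though it was
    // not requested, so the pushed-down filter can still be evaluated.
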
    /// Builds an empty plan carrying the projected schema, used when there are
    /// no data sources at all. Projected names that are not in the base schema
    /// are silently dropped.
    fn empty_plan(&self, projection: Option<&[String]>) -> Result<Arc<dyn ExecutionPlan>> {
        use arrow_schema::{Field, Schema};
        use datafusion::physical_plan::empty::EmptyExec;

        let fields: Vec<Arc<Field>> = if let Some(cols) = projection {
            cols.iter()
                .filter_map(|name| {
                    self.base_schema
                        .field_with_name(name)
                        .ok()
                        .map(|f| Arc::new(f.clone()))
                })
                .collect()
        } else {
            self.base_schema.fields().iter().cloned().collect()
        };
        let schema = Arc::new(Schema::new(fields));
        Ok(Arc::new(EmptyExec::new(schema)))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use datafusion::physical_plan::displayable;
    use uuid::Uuid;

    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::RegionSnapshot;
    use crate::dataset::{Dataset, WriteParams};

    /// Builds a two-column schema whose `id` field carries the unenforced
    /// primary key metadata marker.
    fn create_pk_schema() -> Arc<ArrowSchema> {
        let mut id_metadata = HashMap::new();
        id_metadata.insert(
            "lance-schema:unenforced-primary-key".to_string(),
            "true".to_string(),
        );
        let id_field = Field::new("id", DataType::Int32, false).with_metadata(id_metadata);
        Arc::new(ArrowSchema::new(vec![
            id_field,
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Builds a batch of `ids` paired with names of the form `"{prefix}_{id}"`.
    fn create_test_batch(schema: &ArrowSchema, ids: &[i32], name_prefix: &str) -> RecordBatch {
let names: Vec<String> = ids
.iter()
.map(|id| format!("{}_{}", name_prefix, id))
.collect();
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(ids.to_vec())),
Arc::new(StringArray::from(names)),
],
)
.unwrap()
}

    /// Writes `batches` as a new dataset at `uri` and returns it.
    async fn create_dataset(uri: &str, batches: Vec<RecordBatch>) -> Dataset {
let schema = batches[0].schema();
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
Dataset::write(reader, uri, Some(WriteParams::default()))
.await
.unwrap()
}

    #[tokio::test]
async fn test_point_lookup_plan_structure() {
let schema = create_pk_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_path = temp_dir.path().to_str().unwrap();
let base_uri = format!("{}/base", base_path);
let base_batch = create_test_batch(&schema, &[1, 2, 3], "base");
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone());
let pk_values = vec![ScalarValue::Int32(Some(2))];
let plan = planner.plan_lookup(&pk_values, None).await.unwrap();
let plan_str = format!("{}", displayable(plan.as_ref()).indent(true));
assert!(
plan_str.contains("GlobalLimitExec"),
"Should have GlobalLimitExec in plan: {}",
plan_str
);
}

    #[tokio::test]
async fn test_point_lookup_with_memtables() {
let schema = create_pk_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_path = temp_dir.path().to_str().unwrap();
let base_uri = format!("{}/base", base_path);
let base_batch = create_test_batch(&schema, &[1, 2, 3], "base");
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let region_id = Uuid::new_v4();
let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, region_id);
        // Generation 1 holds a newer row for id = 2 that shadows the base table.
        let gen1_batch = create_test_batch(&schema, &[2], "gen1");
        create_dataset(&gen1_uri, vec![gen1_batch]).await;
let region_snapshot = RegionSnapshot::new(region_id)
.with_current_generation(2)
.with_flushed_generation(1, "gen_1".to_string());
let collector = LsmDataSourceCollector::new(base_dataset, vec![region_snapshot]);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone());
let pk_values = vec![ScalarValue::Int32(Some(2))];
let plan = planner.plan_lookup(&pk_values, None).await.unwrap();
let plan_str = format!("{}", displayable(plan.as_ref()).indent(true));
assert!(
plan_str.contains("CoalesceFirstExec") || plan_str.contains("GlobalLimitExec"),
"Should have CoalesceFirstExec or GlobalLimitExec in plan: {}",
plan_str
);
}

    #[tokio::test]
async fn test_point_lookup_with_bloom_filter() {
let schema = create_pk_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_path = temp_dir.path().to_str().unwrap();
let base_uri = format!("{}/base", base_path);
let base_batch = create_test_batch(&schema, &[1, 2, 3], "base");
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let mut bf = Sbbf::with_ndv_fpp(100, 0.01).unwrap();
let pk_hash = compute_pk_hash_from_scalars(&[ScalarValue::Int32(Some(2))]);
bf.insert_hash(pk_hash);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema.clone())
.with_bloom_filter(1, Arc::new(bf));
let pk_values = vec![ScalarValue::Int32(Some(2))];
let plan = planner.plan_lookup(&pk_values, None).await.unwrap();
assert!(plan.schema().field_with_name("id").is_ok());
}

    #[tokio::test]
async fn test_pk_filter_expr() {
let schema = create_pk_schema();
let temp_dir = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap());
let base_batch = create_test_batch(&schema, &[1], "base");
let base_dataset = Arc::new(create_dataset(&base_uri, vec![base_batch]).await);
let collector = LsmDataSourceCollector::new(base_dataset, vec![]);
let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema);
let pk_values = vec![ScalarValue::Int32(Some(42))];
let expr = planner.build_pk_filter_expr(&pk_values).unwrap();
let expr_str = format!("{}", expr);
assert!(
expr_str.contains("id"),
"Expression should contain column name"
);
}
}