use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::common::ToDFSchema;
use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use datafusion::prelude::{Expr, SessionContext};
use futures::TryStreamExt;
use lance_core::{Error, Result};
use uuid::Uuid;
use super::collector::{InMemoryMemTableRef, InMemoryMemTables, LsmDataSourceCollector};
use super::data_source::ShardSnapshot;
use super::flushed_cache::FlushedMemTableCache;
use super::planner::LsmScanPlanner;
use crate::dataset::Dataset;
use crate::session::Session;
/// Origin of the base data for an [`LsmScanner`].
///
/// The scanner is built either from an opened dataset (`LsmScanner::new`) or
/// from a bare path (`LsmScanner::without_base_table`).
enum BaseSource {
/// An opened base dataset, shared behind an `Arc`.
Table(Arc<Dataset>),
/// Only a path string for the base; no dataset handle is held.
PathOnly(String),
}
/// Builder-style scanner that assembles an LSM scan plan and executes it.
///
/// Options are set through the consuming `with_*` / `project` / `filter` /
/// `limit` methods; `create_plan` hands everything to `LsmScanPlanner`.
pub struct LsmScanner {
// Base source: an opened dataset or just a base path.
base: BaseSource,
// Arrow schema: converted from the dataset in `new`, caller-supplied in
// `without_base_table`. Used to resolve SQL filters and given to the planner.
schema: SchemaRef,
// Shard snapshots forwarded to the data-source collector.
shard_snapshots: Vec<ShardSnapshot>,
// In-memory memtables (one active plus frozen ones) keyed by shard id.
in_memory_memtables: HashMap<Uuid, InMemoryMemTables>,
// Column names to project; `None` until `project` is called.
projection: Option<Vec<String>>,
// Filter expression; `None` until `filter` / `filter_expr` is called.
filter: Option<Expr>,
// Row limit forwarded to the planner; set by `limit`.
limit: Option<usize>,
// Row offset forwarded to the planner; overwritten by every `limit` call.
offset: Option<usize>,
// Flag forwarded to `plan_scan`; presumably adds a row-address column — confirm in the planner.
with_row_address: bool,
// Flag forwarded to `plan_scan`; presumably adds a memtable-generation column — confirm in the planner.
with_memtable_gen: bool,
// Key column names handed to the planner (presumably the primary key — confirm).
pk_columns: Vec<String>,
// Optional session: taken from the dataset in `new`, `None` in
// `without_base_table` unless `with_session` is called.
session: Option<Arc<Session>>,
// Optional cache of flushed memtables, forwarded to the planner when set.
flushed_cache: Option<Arc<FlushedMemTableCache>>,
}
impl LsmScanner {
pub fn new(
base_table: Arc<Dataset>,
shard_snapshots: Vec<ShardSnapshot>,
pk_columns: Vec<String>,
) -> Self {
let lance_schema = base_table.schema();
let arrow_schema: arrow_schema::Schema = lance_schema.into();
let session = Some(base_table.session());
Self {
base: BaseSource::Table(base_table),
schema: Arc::new(arrow_schema),
shard_snapshots,
in_memory_memtables: HashMap::new(),
projection: None,
filter: None,
limit: None,
offset: None,
with_row_address: false,
with_memtable_gen: false,
pk_columns,
session,
flushed_cache: None,
}
}
/// Creates a scanner that has only a base path and no opened base dataset.
///
/// The caller supplies the Arrow `schema` directly. No session is set (use
/// `with_session` to provide one) and every optional setting starts unset.
pub fn without_base_table(
    schema: SchemaRef,
    base_path: impl Into<String>,
    shard_snapshots: Vec<ShardSnapshot>,
    pk_columns: Vec<String>,
) -> Self {
    let base = BaseSource::PathOnly(base_path.into());
    let in_memory_memtables = HashMap::new();

    Self {
        base,
        schema,
        shard_snapshots,
        in_memory_memtables,
        projection: None,
        filter: None,
        limit: None,
        offset: None,
        with_row_address: false,
        with_memtable_gen: false,
        pk_columns,
        session: None,
        flushed_cache: None,
    }
}
/// Sets the active in-memory memtable for `shard_id`.
///
/// If the shard already has memtables registered, only its `active` entry is
/// replaced and its frozen list is left alone. Otherwise a new entry is
/// created with an empty frozen list.
pub fn with_active_memtable(mut self, shard_id: Uuid, memtable: InMemoryMemTableRef) -> Self {
    let slot = self.in_memory_memtables.entry(shard_id);
    match slot {
        Entry::Vacant(vacant) => {
            vacant.insert(InMemoryMemTables {
                active: memtable,
                frozen: Vec::new(),
            });
        }
        Entry::Occupied(occupied) => {
            occupied.into_mut().active = memtable;
        }
    }
    self
}
/// Registers the full set of in-memory memtables for `shard_id`.
///
/// Any memtables previously registered for the same shard are replaced.
pub fn with_in_memory_memtables(
    self,
    shard_id: Uuid,
    memtables: InMemoryMemTables,
) -> Self {
    let mut scanner = self;
    // `HashMap::insert` overwrites an existing entry for the same key.
    scanner.in_memory_memtables.insert(shard_id, memtables);
    scanner
}
pub fn with_session(mut self, session: Arc<Session>) -> Self {
self.session = Some(session);
self
}
pub fn with_flushed_cache(mut self, cache: Arc<FlushedMemTableCache>) -> Self {
self.flushed_cache = Some(cache);
self
}
pub fn project(mut self, columns: &[&str]) -> Self {
self.projection = Some(columns.iter().map(|s| s.to_string()).collect());
self
}
pub fn filter(mut self, filter_expr: &str) -> Result<Self> {
let ctx = SessionContext::new();
let df_schema = self
.schema
.as_ref()
.clone()
.to_dfschema()
.map_err(|e| Error::invalid_input(format!("Failed to create DFSchema: {}", e)))?;
let expr = ctx.parse_sql_expr(filter_expr, &df_schema).map_err(|e| {
Error::invalid_input(format!("Failed to parse filter expression: {}", e))
})?;
self.filter = Some(expr);
Ok(self)
}
pub fn filter_expr(mut self, expr: Expr) -> Self {
self.filter = Some(expr);
self
}
pub fn limit(mut self, limit: usize, offset: Option<usize>) -> Self {
self.limit = Some(limit);
self.offset = offset;
self
}
pub fn with_row_address(mut self) -> Self {
self.with_row_address = true;
self
}
pub fn with_memtable_gen(mut self) -> Self {
self.with_memtable_gen = true;
self
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
/// Builds the physical execution plan for the configured scan.
///
/// A collector is assembled from the base source, shard snapshots and
/// in-memory memtables. An `LsmScanPlanner` is then configured with the
/// optional session and flushed cache and asked to plan the scan.
///
/// # Errors
///
/// Propagates any error returned by `LsmScanPlanner::plan_scan`.
pub async fn create_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
    let mut planner = LsmScanPlanner::new(
        self.build_collector(),
        self.pk_columns.clone(),
        self.schema(),
    );

    // The optional collaborators are only attached when they were configured.
    if let Some(session) = self.session.as_ref() {
        planner = planner.with_session(Arc::clone(session));
    }
    if let Some(cache) = self.flushed_cache.as_ref() {
        planner = planner.with_flushed_cache(Arc::clone(cache));
    }

    let projection = self.projection.as_deref();
    let filter = self.filter.as_ref();
    planner
        .plan_scan(
            projection,
            filter,
            self.limit,
            self.offset,
            self.with_memtable_gen,
            self.with_row_address,
        )
        .await
}
pub async fn try_into_stream(&self) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan().await?;
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
plan.execute(0, task_ctx)
.map_err(|e| Error::io(format!("Failed to execute plan: {}", e)))
}
/// Executes the scan and returns every result row in a single `RecordBatch`.
///
/// When the scan yields no batches, the returned empty batch uses the
/// stream's output schema. It therefore matches what a non-empty result
/// would look like, rather than using the unprojected base schema.
///
/// # Errors
///
/// Propagates planning and execution errors. Failures while collecting or
/// concatenating batches are reported as I/O errors.
pub async fn try_into_batch(&self) -> Result<RecordBatch> {
    let stream = self.try_into_stream().await?;
    // Capture the output schema now: `try_collect` consumes the stream, and
    // `self.schema()` is the base schema, which does not account for the
    // projection or any extra columns the plan emits.
    let output_schema = stream.schema();
    let batches: Vec<RecordBatch> = stream
        .try_collect()
        .await
        .map_err(|e| Error::io(format!("Failed to collect batches: {}", e)))?;

    let schema = match batches.first() {
        Some(first) => first.schema(),
        None => return Ok(RecordBatch::new_empty(output_schema)),
    };
    arrow_select::concat::concat_batches(&schema, &batches)
        .map_err(|e| Error::io(format!("Failed to concatenate batches: {}", e)))
}
pub async fn count_rows(&self) -> Result<u64> {
let stream = self.try_into_stream().await?;
let batches: Vec<RecordBatch> = stream
.try_collect()
.await
.map_err(|e| Error::io(format!("Failed to count rows: {}", e)))?;
Ok(batches.iter().map(|b| b.num_rows() as u64).sum())
}
fn build_collector(&self) -> LsmDataSourceCollector {
let mut collector = match &self.base {
BaseSource::Table(dataset) => {
LsmDataSourceCollector::new(dataset.clone(), self.shard_snapshots.clone())
}
BaseSource::PathOnly(path) => LsmDataSourceCollector::without_base_table(
path.clone(),
self.shard_snapshots.clone(),
),
};
for (shard_id, mems) in &self.in_memory_memtables {
collector = collector.with_in_memory_memtables(*shard_id, mems.clone());
}
collector
}
}
/// Compact debug view of the scanner.
///
/// Prints the base (dataset URI or path), the number of shard snapshots, the
/// total number of in-memory memtables, and the projection, limit, offset and
/// key columns. The filter, session and cache are not printed.
impl std::fmt::Debug for LsmScanner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each shard contributes its single active memtable plus its frozen ones.
        let memtable_total: usize = self
            .in_memory_memtables
            .values()
            .map(|tables| tables.frozen.len() + 1)
            .sum();

        let mut out = f.debug_struct("LsmScanner");
        // The first field's label depends on how the scanner was constructed.
        match &self.base {
            BaseSource::Table(dataset) => out.field("base_table", &dataset.uri().to_string()),
            BaseSource::PathOnly(path) => out.field("base_path", path),
        };
        out.field("num_shards", &self.shard_snapshots.len());
        out.field("num_in_memory_memtables", &memtable_total);
        out.field("projection", &self.projection);
        out.field("limit", &self.limit);
        out.field("offset", &self.offset);
        out.field("pk_columns", &self.pk_columns);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scanner through `without_base_table`, so no dataset has to be
    /// opened; the schema has a single non-nullable `id` column.
    fn path_only_scanner() -> LsmScanner {
        let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
            "id",
            arrow_schema::DataType::Int32,
            false,
        )]));
        LsmScanner::without_base_table(
            schema,
            "memory://base",
            Vec::new(),
            vec!["id".to_string()],
        )
    }

    #[test]
    fn test_lsm_scanner_builder() {
        // A freshly built scanner has every optional setting unset.
        let scanner = path_only_scanner();
        assert!(matches!(
            &scanner.base,
            BaseSource::PathOnly(path) if path == "memory://base"
        ));
        assert_eq!(scanner.schema().fields().len(), 1);
        assert!(scanner.shard_snapshots.is_empty());
        assert!(scanner.in_memory_memtables.is_empty());
        assert!(scanner.projection.is_none());
        assert!(scanner.filter.is_none());
        assert_eq!(scanner.limit, None);
        assert_eq!(scanner.offset, None);
        assert!(!scanner.with_row_address);
        assert!(!scanner.with_memtable_gen);
        assert_eq!(scanner.pk_columns, vec!["id".to_string()]);
        assert!(scanner.session.is_none());
        assert!(scanner.flushed_cache.is_none());

        // Each builder method records its setting.
        let scanner = scanner
            .project(&["id"])
            .limit(10, Some(5))
            .with_row_address()
            .with_memtable_gen();
        assert_eq!(scanner.projection, Some(vec!["id".to_string()]));
        assert_eq!(scanner.limit, Some(10));
        assert_eq!(scanner.offset, Some(5));
        assert!(scanner.with_row_address);
        assert!(scanner.with_memtable_gen);

        // The debug view reports the path-only base and the counters.
        let rendered = format!("{:?}", scanner);
        assert!(rendered.contains("base_path: \"memory://base\""));
        assert!(rendered.contains("num_shards: 0"));
        assert!(rendered.contains("num_in_memory_memtables: 0"));

        // `limit` always overwrites the offset, so `None` clears it.
        let scanner = scanner.limit(3, None);
        assert_eq!(scanner.limit, Some(3));
        assert_eq!(scanner.offset, None);
    }

    #[test]
    fn test_shard_snapshot_construction() {
        let shard_id = Uuid::new_v4();
        let snapshot = ShardSnapshot::new(shard_id)
            .with_spec_id(1)
            .with_current_generation(5)
            .with_flushed_generation(1, "path/gen_1".to_string())
            .with_flushed_generation(2, "path/gen_2".to_string());

        assert_eq!(snapshot.shard_id, shard_id);
        assert_eq!(snapshot.spec_id, 1);
        assert_eq!(snapshot.current_generation, 5);
        assert_eq!(snapshot.flushed_generations.len(), 2);
    }

    #[test]
    fn test_in_memory_memtable_ref() {
        use crate::dataset::mem_wal::write::{BatchStore, IndexStore};

        let first = InMemoryMemTableRef {
            batch_store: Arc::new(BatchStore::with_capacity(100)),
            index_store: Arc::new(IndexStore::new()),
            schema: Arc::new(arrow_schema::Schema::empty()),
            generation: 10,
        };
        assert_eq!(first.generation, 10);

        let second = InMemoryMemTableRef {
            batch_store: Arc::new(BatchStore::with_capacity(100)),
            index_store: Arc::new(IndexStore::new()),
            schema: Arc::new(arrow_schema::Schema::empty()),
            generation: 11,
        };

        // First registration creates the shard entry with no frozen memtables.
        let shard_id = Uuid::new_v4();
        let scanner = path_only_scanner().with_active_memtable(shard_id, first);
        assert_eq!(scanner.in_memory_memtables.len(), 1);
        assert_eq!(scanner.in_memory_memtables[&shard_id].active.generation, 10);
        assert!(scanner.in_memory_memtables[&shard_id].frozen.is_empty());
        assert!(format!("{:?}", scanner).contains("num_in_memory_memtables: 1"));

        // Registering again for the same shard replaces only the active memtable.
        let scanner = scanner.with_active_memtable(shard_id, second);
        assert_eq!(scanner.in_memory_memtables.len(), 1);
        assert_eq!(scanner.in_memory_memtables[&shard_id].active.generation, 11);
        assert!(scanner.in_memory_memtables[&shard_id].frozen.is_empty());
    }
}