use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use datafusion::common::stats::Precision;
use datafusion::common::ColumnStatistics;
use datafusion::error::Result as DataFusionResult;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
SendableRecordBatchStream, Statistics,
};
use datafusion::prelude::Expr;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExprRef};
use futures::stream::{self, StreamExt};
use crate::dataset::mem_wal::write::BatchStore;
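/// Name of the synthetic row-address column appended to the output when
/// `with_row_address` is set.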
pub const ROW_ADDRESS_COLUMN: &str = "_rowaddr";
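/// Physical scan over the batches of an in-memory [`BatchStore`].
///
/// Only batches up to (and including) `max_visible_batch_position` are read,
/// giving the scan a stable snapshot while writers continue to append. An
/// optional physical predicate is evaluated against each source batch before
/// projection, and synthetic `_rowid` / `_rowaddr` columns can be appended to
/// the output.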
pub struct MemTableScanExec {
batch_store: Arc<BatchStore>,
max_visible_batch_position: usize,
projection: Option<Vec<usize>>,
output_schema: SchemaRef,
source_schema: SchemaRef,
properties: PlanProperties,
metrics: ExecutionPlanMetricsSet,
with_row_id: bool,
with_row_address: bool,
filter_predicate: Option<PhysicalExprRef>,
filter_expr: Option<Expr>,
}
impl Debug for MemTableScanExec {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MemTableScanExec")
.field(
"max_visible_batch_position",
&self.max_visible_batch_position,
)
.field("projection", &self.projection)
.field("with_row_id", &self.with_row_id)
.field("with_row_address", &self.with_row_address)
.field("has_filter", &self.filter_predicate.is_some())
.finish()
}
}
impl MemTableScanExec {
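    /// Creates an unfiltered scan. The output schema doubles as the source
    /// schema, and no row-address column is emitted.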
pub fn new(
batch_store: Arc<BatchStore>,
max_visible_batch_position: usize,
projection: Option<Vec<usize>>,
output_schema: SchemaRef,
with_row_id: bool,
) -> Self {
Self::with_filter(
batch_store,
max_visible_batch_position,
projection,
output_schema.clone(),
output_schema,
with_row_id,
            false,
            None,
None,
)
}
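    /// Creates a scan with an optional filter.
    ///
    /// The physical `filter_predicate` is evaluated against batches in their
    /// `source_schema` layout (before projection), while `output_schema`
    /// describes what this node emits, including any `_rowid` / `_rowaddr`
    /// columns. The logical `filter_expr` is carried along only for
    /// `EXPLAIN` / display output.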
#[allow(clippy::too_many_arguments)]
pub fn with_filter(
batch_store: Arc<BatchStore>,
max_visible_batch_position: usize,
projection: Option<Vec<usize>>,
output_schema: SchemaRef,
source_schema: SchemaRef,
with_row_id: bool,
with_row_address: bool,
filter_predicate: Option<PhysicalExprRef>,
filter_expr: Option<Expr>,
) -> Self {
let properties = PlanProperties::new(
EquivalenceProperties::new(output_schema.clone()),
Partitioning::UnknownPartitioning(1),
EmissionType::Incremental,
Boundedness::Bounded,
);
Self {
batch_store,
max_visible_batch_position,
projection,
output_schema,
source_schema,
properties,
metrics: ExecutionPlanMetricsSet::new(),
with_row_id,
with_row_address,
filter_predicate,
filter_expr,
}
}
}
impl DisplayAs for MemTableScanExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
let projection_names: Vec<&str> = self
.output_schema
.fields()
.iter()
.map(|field| field.name().as_str())
.collect();
let filter_str = self
.filter_expr
.as_ref()
.map(|e| format!(", filter={}", e))
.unwrap_or_default();
let row_addr_str = if self.with_row_address {
", with_row_address=true"
} else {
""
};
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"MemTableScanExec: projection=[{}], with_row_id={}{}{}",
projection_names.join(", "),
self.with_row_id,
row_addr_str,
filter_str
)
}
DisplayFormatType::TreeRender => {
write!(
f,
"MemTableScanExec\nprojection=[{}]\nwith_row_id={}{}{}",
projection_names.join(", "),
self.with_row_id,
row_addr_str,
filter_str
)
}
}
}
}
impl ExecutionPlan for MemTableScanExec {
fn name(&self) -> &str {
"MemTableScanExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.output_schema.clone()
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if !children.is_empty() {
return Err(datafusion::error::DataFusionError::Internal(
"MemTableScanExec does not have children".to_string(),
));
}
Ok(self)
}
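    // Materializes every visible batch eagerly (the store is in memory),
    // applying filter -> projection -> row-offset columns, then replays the
    // results as a stream.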
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let batches_with_offsets = self
.batch_store
.visible_batches_with_offsets(self.max_visible_batch_position);
let projection = self.projection.clone();
let schema = self.output_schema.clone();
let source_schema = self.source_schema.clone();
let with_row_id = self.with_row_id;
let with_row_address = self.with_row_address;
let filter_predicate = self.filter_predicate.clone();
let need_row_offsets = with_row_id || with_row_address;
let projected_batches: Vec<DataFusionResult<RecordBatch>> = batches_with_offsets
.into_iter()
.filter_map(|(batch, row_offset)| {
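                // Apply the optional filter first, against the unprojected
                // source batch, so predicates can reference columns the
                // projection later drops.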
let (filtered_batch, filtered_row_offsets) = if let Some(ref predicate) =
filter_predicate
{
let filter_result = predicate.evaluate(&batch);
let filter_array = match filter_result {
Ok(v) => match v.into_array(batch.num_rows()) {
Ok(arr) => arr,
Err(e) => return Some(Err(e)),
},
Err(e) => return Some(Err(e)),
};
let Some(filter_array) = filter_array.as_any().downcast_ref::<BooleanArray>()
else {
return Some(Err(datafusion::error::DataFusionError::Internal(
"Filter predicate did not evaluate to boolean".to_string(),
)));
};
let filtered =
match arrow_select::filter::filter_record_batch(&batch, filter_array) {
Ok(b) => b,
Err(e) => return Some(Err(e.into())),
};
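                    // Recover the absolute store offsets of the rows that
                    // survived the filter.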
let row_offsets = if need_row_offsets {
let mut offsets = Vec::with_capacity(filtered.num_rows());
for (i, valid) in filter_array.iter().enumerate() {
if valid.unwrap_or(false) {
offsets.push(row_offset + i as u64);
}
}
offsets
} else {
vec![]
};
(filtered, row_offsets)
} else {
let row_offsets = if need_row_offsets {
(0..batch.num_rows() as u64)
.map(|i| row_offset + i)
.collect()
} else {
vec![]
};
(batch, row_offsets)
};
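                // Batches filtered down to zero rows produce no output.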
if filtered_batch.num_rows() == 0 {
return None;
}
let mut columns: Vec<Arc<dyn arrow_array::Array>> =
if let Some(ref indices) = projection {
indices
.iter()
.map(|&i| filtered_batch.column(i).clone())
.collect()
} else {
filtered_batch.columns().to_vec()
};
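                // Row ids and row addresses are the same value here: the
                // row's absolute offset within the store.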
if with_row_id {
columns.push(Arc::new(UInt64Array::from(filtered_row_offsets.clone())));
}
if with_row_address {
columns.push(Arc::new(UInt64Array::from(filtered_row_offsets)));
}
Some(
RecordBatch::try_new(schema.clone(), columns)
.map_err(datafusion::error::DataFusionError::from),
)
})
.collect();
        // `source_schema` is not needed during execution today: the stored
        // batches already carry the source layout, so the filter is evaluated
        // against them directly. The binding keeps the field plumbed through
        // without an unused-variable warning.
        let _ = source_schema;
let stream = stream::iter(projected_batches).boxed();
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.output_schema.clone(),
stream,
)))
}
    fn partition_statistics(&self, _partition: Option<usize>) -> DataFusionResult<Statistics> {
        // No per-batch statistics are tracked; report unknown values with one
        // unknown entry per output column so consumers that index
        // `column_statistics` by position do not hit a length mismatch.
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics::new_unknown();
                self.output_schema.fields().len()
            ],
        })
    }
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn supports_limit_pushdown(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
fn create_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]))
}
fn create_test_batch(schema: &Schema, start_id: i32, count: usize) -> RecordBatch {
let ids: Vec<i32> = (start_id..start_id + count as i32).collect();
let names: Vec<String> = ids.iter().map(|id| format!("name_{}", id)).collect();
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(ids)),
Arc::new(StringArray::from(names)),
],
)
.unwrap()
}
#[tokio::test]
async fn test_scan_exec_basic() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
let batch = create_test_batch(&schema, 0, 10);
batch_store.append(batch).unwrap();
let exec = MemTableScanExec::new(batch_store, 0, None, schema, false);
let ctx = Arc::new(TaskContext::default());
let stream = exec.execute(0, ctx).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_rows(), 10);
}
#[tokio::test]
async fn test_scan_exec_visibility() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
batch_store
.append(create_test_batch(&schema, 0, 10))
.unwrap();
batch_store
.append(create_test_batch(&schema, 10, 10))
.unwrap();
batch_store
.append(create_test_batch(&schema, 20, 10))
.unwrap();
let exec = MemTableScanExec::new(batch_store.clone(), 1, None, schema.clone(), false);
let ctx = Arc::new(TaskContext::default());
let stream = exec.execute(0, ctx).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 2);
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 20);
}
#[tokio::test]
async fn test_scan_exec_projection() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
let batch = create_test_batch(&schema, 0, 10);
batch_store.append(batch).unwrap();
let projected_schema =
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let exec = MemTableScanExec::new(batch_store, 0, Some(vec![0]), projected_schema, false);
let ctx = Arc::new(TaskContext::default());
let stream = exec.execute(0, ctx).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_columns(), 1);
assert_eq!(batches[0].schema().field(0).name(), "id");
}
#[tokio::test]
async fn test_scan_exec_empty() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
let exec = MemTableScanExec::new(batch_store, 0, None, schema, false);
let ctx = Arc::new(TaskContext::default());
let stream = exec.execute(0, ctx).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert!(batches.is_empty());
}
#[tokio::test]
async fn test_scan_exec_statistics() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
batch_store
.append(create_test_batch(&schema, 0, 10))
.unwrap();
batch_store
.append(create_test_batch(&schema, 10, 20))
.unwrap();
let exec = MemTableScanExec::new(batch_store, 1, None, schema, false);
let stats = exec.partition_statistics(None).unwrap();
assert_eq!(stats.num_rows, Precision::Absent);
}
#[tokio::test]
async fn test_scan_exec_with_row_id() {
let schema = create_test_schema();
let batch_store = Arc::new(BatchStore::with_capacity(100));
batch_store
.append(create_test_batch(&schema, 0, 5))
.unwrap();
batch_store
.append(create_test_batch(&schema, 5, 3))
.unwrap();
let schema_with_rowid = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
Field::new("_rowid", DataType::UInt64, true),
]));
let exec = MemTableScanExec::new(batch_store, 1, None, schema_with_rowid, true);
let ctx = Arc::new(TaskContext::default());
let stream = exec.execute(0, ctx).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 2);
let row_ids_1 = batches[0]
.column(2)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
assert_eq!(row_ids_1.len(), 5);
assert_eq!(row_ids_1.value(0), 0);
assert_eq!(row_ids_1.value(4), 4);
let row_ids_2 = batches[1]
.column(2)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
assert_eq!(row_ids_2.len(), 3);
assert_eq!(row_ids_2.value(0), 5);
assert_eq!(row_ids_2.value(2), 7);
}
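    // A hedged sketch of a filter test: it assumes the same `BatchStore` API
    // used above and builds the physical predicate `id < 5` by hand from
    // `datafusion_physical_expr::expressions`.
    #[tokio::test]
    async fn test_scan_exec_with_filter() {
        use datafusion::logical_expr::Operator;
        use datafusion::scalar::ScalarValue;
        use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};

        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        batch_store
            .append(create_test_batch(&schema, 0, 10))
            .unwrap();
        // `id` is column 0 of the source schema.
        let predicate: PhysicalExprRef = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("id", 0)),
            Operator::Lt,
            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
        ));
        let exec = MemTableScanExec::with_filter(
            batch_store,
            0,
            None,
            schema.clone(),
            schema,
            false,
            false,
            Some(predicate),
            None,
        );
        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        // Rows 0..=4 satisfy `id < 5`.
        assert_eq!(total_rows, 5);
    }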
}