use crate::physical_plan::Index;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::DataFusionError;
use datafusion::execution::TaskContext;
use datafusion::logical_expr::Expr;
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::Boundedness;
use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
use datafusion::physical_plan::{
execution_plan::EmissionType, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
PlanProperties, SendableRecordBatchStream,
};
use std::sync::Arc;
/// Leaf physical operator that produces its rows by scanning an [`Index`].
///
/// The node has no child plans; `execute` forwards the stored filters and
/// limit to `Index::scan` and returns the resulting stream.
#[derive(Debug)]
pub struct IndexScanExec {
/// Index that is scanned when the plan executes; its name is shown in plan output.
index: Arc<dyn Index>,
/// Logical filter expressions handed to `Index::scan` as-is.
filters: Vec<Expr>,
/// Optional limit handed to `Index::scan` as-is.
limit: Option<usize>,
/// Plan properties computed once in `try_new` and returned by `properties()`.
plan_properties: PlanProperties,
}
impl DisplayAs for IndexScanExec {
fn fmt_as(
&self,
t: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default
| DisplayFormatType::Verbose
| DisplayFormatType::TreeRender => {
write!(f, "IndexScanExec: index={}, filters=[", self.index.name())?;
for (i, filter) in self.filters.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{filter}")?;
}
write!(f, "], limit={:?}", self.limit)
}
}
}
}
impl ExecutionPlan for IndexScanExec {
    /// Static operator name used in plan output.
    fn name(&self) -> &str {
        "IndexScanExec"
    }

    /// Exposes the node as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns the properties that were computed when the node was built.
    fn properties(&self) -> &PlanProperties {
        &self.plan_properties
    }

    /// This is a leaf node, so the list of children is always empty.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }

    /// Returns the node unchanged; the supplied children are ignored.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        Ok(self)
    }

    /// Scans the index with the stored filters and limit.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Internal` for any partition other than `0`,
    /// and otherwise whatever error `Index::scan` reports.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        match partition {
            0 => self.index.scan(&self.filters, self.limit),
            _ => Err(DataFusionError::Internal(String::from(
                "IndexScanExec only supports a single partition",
            ))),
        }
    }
}
impl IndexScanExec {
    /// Builds an index scan node over `index`.
    ///
    /// `filters` and `limit` are stored and later passed unchanged to
    /// `Index::scan`. `schema` is the output schema the plan properties are
    /// built from.
    ///
    /// When `index.is_ordered()` is true, the node advertises one output
    /// ordering made of every column of `schema`, in schema order, each with
    /// the default sort options. Otherwise an empty ordering is supplied.
    ///
    /// The node always reports a single partition of unknown partitioning,
    /// incremental emission and bounded input.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` return
    /// type is kept for API stability.
    pub fn try_new(
        index: Arc<dyn Index>,
        filters: Vec<Expr>,
        limit: Option<usize>,
        schema: SchemaRef,
    ) -> Result<Self, DataFusionError> {
        let ordering = if index.is_ordered() {
            schema
                .fields()
                .iter()
                .enumerate()
                .map(|(i, field)| PhysicalSortExpr {
                    expr: Arc::new(Column::new(field.name(), i)),
                    options: Default::default(),
                })
                .collect()
        } else {
            vec![]
        };

        // `schema` is not needed after this point, so it is moved here rather
        // than cloned.
        let eq = EquivalenceProperties::new_with_orderings(schema, [ordering]);
        let plan_properties = PlanProperties::new(
            eq,
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        Ok(Self {
            index,
            filters,
            limit,
            plan_properties,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::Statistics;
    use datafusion::physical_plan::memory::MemoryStream;
    use std::any::Any;
    use std::sync::Mutex;

    /// Test double for [`Index`]: records whether `scan` ran and yields one
    /// empty batch with a single non-nullable `id: UInt64` column.
    #[derive(Debug)]
    struct MockIndex {
        schema: SchemaRef,
        scan_called: Mutex<bool>,
    }

    impl MockIndex {
        fn new() -> Self {
            let id_field = Field::new("id", DataType::UInt64, false);
            Self {
                schema: Arc::new(Schema::new(vec![id_field])),
                scan_called: Mutex::new(false),
            }
        }
    }

    impl Index for MockIndex {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_index"
        }

        fn index_schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn table_name(&self) -> &str {
            "mock_table"
        }

        fn column_name(&self) -> &str {
            "mock_column"
        }

        fn scan(
            &self,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            // Record the call first so the tests can observe that it happened.
            *self.scan_called.lock().unwrap() = true;
            let empty_batch = RecordBatch::new_empty(self.schema.clone());
            let batches = vec![empty_batch];
            let stream = MemoryStream::try_new(batches, self.schema.clone(), None)?;
            Ok(Box::pin(stream))
        }

        fn statistics(&self) -> Statistics {
            Statistics::new_unknown(&self.schema)
        }
    }

    /// Creates a fresh mock index and an exec node over it with no filters
    /// and no limit.
    fn mock_exec() -> datafusion::common::Result<(Arc<MockIndex>, IndexScanExec)> {
        let mock = Arc::new(MockIndex::new());
        let schema = mock.index_schema();
        let exec = IndexScanExec::try_new(mock.clone(), Vec::new(), None, schema)?;
        Ok((mock, exec))
    }

    /// Executing partition 0 delegates to the index and streams its batches.
    #[tokio::test]
    async fn test_index_scan_exec() -> datafusion::common::Result<()> {
        let (mock, exec) = mock_exec()?;
        let ctx = Arc::new(TaskContext::default());

        let stream = exec.execute(0, ctx)?;
        let collected = datafusion::physical_plan::common::collect(stream).await?;

        assert!(*mock.scan_called.lock().unwrap());
        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0].num_rows(), 0);
        Ok(())
    }

    /// Any partition other than 0 is rejected with the single-partition error.
    #[tokio::test]
    async fn test_index_scan_exec_invalid_partition() {
        let (_mock, exec) = mock_exec().unwrap();
        let ctx = Arc::new(TaskContext::default());

        match exec.execute(1, ctx) {
            Ok(_) => panic!("expected an error"),
            Err(e) => assert!(
                e.to_string()
                    .contains("IndexScanExec only supports a single partition"),
                "unexpected error message: {e}"
            ),
        }
    }
}