use std::any::Any;
use std::sync::{Arc, Mutex};
use crate::coop::cooperative;
use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
use crate::memory::MemoryStream;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
SendableRecordBatchStream, Statistics,
};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
/// Record batches stored in a [`WorkTable`], paired with the memory
/// reservation that accounts for them in the owning memory pool.
/// Dropping the reservation releases the accounted memory.
#[derive(Debug)]
pub(super) struct ReservedBatches {
    // Current generation of work-table data.
    batches: Vec<RecordBatch>,
    // Accounts for the memory held by `batches`; released on drop.
    reservation: MemoryReservation,
}
impl ReservedBatches {
    /// Bundle a set of batches with the memory reservation that accounts
    /// for them.
    pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
        Self {
            reservation,
            batches,
        }
    }
}
/// A named, shared container holding at most one [`ReservedBatches`] at a
/// time. Producers store data with [`WorkTable::update`]; `WorkTableExec`
/// consumes it via `take` when executed.
#[derive(Debug)]
pub struct WorkTable {
    // Current contents; `None` before the first `update` and after `take`.
    batches: Mutex<Option<ReservedBatches>>,
    // Identifier used to match this table against a `WorkTableExec` in
    // `with_new_state`.
    name: String,
}
impl WorkTable {
pub(super) fn new(name: String) -> Self {
Self {
batches: Mutex::new(None),
name,
}
}
fn take(&self) -> Result<ReservedBatches> {
self.batches
.lock()
.unwrap()
.take()
.ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
}
pub(super) fn update(&self, batches: ReservedBatches) {
self.batches.lock().unwrap().replace(batches);
}
}
/// Leaf execution plan that streams the current contents of a [`WorkTable`].
#[derive(Clone, Debug)]
pub struct WorkTableExec {
    /// Name of the work table this plan reads from.
    name: String,
    /// Output schema (already narrowed by `projection`, if any).
    schema: SchemaRef,
    /// Optional column indices applied to each batch at execution time.
    projection: Option<Vec<usize>>,
    /// The table whose batches are consumed by `execute`.
    work_table: Arc<WorkTable>,
    /// Execution metrics for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (partitioning, emission type, ...).
    cache: Arc<PlanProperties>,
}
impl WorkTableExec {
    /// Create a plan that reads the work table `name`, optionally projecting
    /// the given column indices out of `schema`.
    pub fn new(
        name: String,
        mut schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> Result<Self> {
        // Narrow the schema up front so the stream produced by `execute`
        // matches the projected batches.
        if let Some(indices) = projection.as_ref() {
            schema = Arc::new(schema.project(indices)?);
        }
        let work_table = Arc::new(WorkTable::new(name.clone()));
        Ok(Self {
            name,
            work_table,
            projection,
            metrics: ExecutionPlanMetricsSet::new(),
            // `cache` is built before `schema` is moved into the struct.
            cache: Arc::new(Self::compute_properties(Arc::clone(&schema))),
            schema,
        })
    }

    /// Name of the work table this plan reads from.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The (possibly projected) output schema.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Plan properties: a single unknown partition producing incremental,
    /// bounded output under cooperative scheduling.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        let eq_properties = EquivalenceProperties::new(schema);
        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
        .with_scheduling_type(SchedulingType::Cooperative)
    }
}
impl DisplayAs for WorkTableExec {
    /// Render this node for `EXPLAIN` output; the tree renderer omits the
    /// operator name prefix.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let prefix = match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => "WorkTableExec: ",
            DisplayFormatType::TreeRender => "",
        };
        write!(f, "{prefix}name={}", self.name)
    }
}
impl ExecutionPlan for WorkTableExec {
    fn name(&self) -> &'static str {
        "WorkTableExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
    // Leaf node: data comes from the shared work table, not from children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // No children to replace; return this plan unchanged.
        Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
    }
    /// Stream the batches currently stored in the work table.
    ///
    /// This *takes* the table contents, so it succeeds at most once per
    /// generation of batches; a subsequent call errors until `update`
    /// refills the table. Only partition 0 is valid.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        assert_eq_or_internal_err!(
            partition,
            0,
            "WorkTableExec got an invalid partition {partition} (expected 0)"
        );
        let ReservedBatches {
            mut batches,
            reservation,
        } = self.work_table.take()?;
        // Apply the projection to each batch; `self.schema` was already
        // projected in `WorkTableExec::new`, so the stream schema matches.
        if let Some(projection) = &self.projection {
            batches = batches
                .into_iter()
                .map(|b| b.project(projection))
                .collect::<Result<Vec<_>, _>>()?;
        }
        // Hand the reservation to the stream so the memory stays accounted
        // for until the stream is dropped.
        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
            .with_reservation(reservation);
        Ok(Box::pin(cooperative(stream)))
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        // Work-table contents are not known at planning time.
        Ok(Statistics::new_unknown(&self.schema()))
    }
    /// Attach a concrete [`WorkTable`] to this plan.
    ///
    /// Returns `None` when `state` is not a `WorkTable` or its name does not
    /// match this plan's table name.
    fn with_new_state(
        &self,
        state: Arc<dyn Any + Send + Sync>,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        let work_table = state.downcast::<WorkTable>().ok()?;
        if work_table.name != self.name {
            return None; }
        Some(Arc::new(Self {
            name: self.name.clone(),
            schema: Arc::clone(&self.schema),
            projection: self.projection.clone(),
            // Fresh metrics for the new plan instance.
            metrics: ExecutionPlanMetricsSet::new(),
            work_table,
            cache: Arc::clone(&self.cache),
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;
    /// Verifies `WorkTable` take/update semantics and that the memory
    /// reservation is released when the consuming stream is dropped.
    #[test]
    fn test_work_table() {
        let work_table = WorkTable::new("test".into());
        // Taking from an empty table is an error.
        assert!(work_table.take().is_err());
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let array: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap();
        reservation.try_grow(100).unwrap();
        work_table.update(ReservedBatches::new(vec![batch.clone()], reservation));
        // `take` returns exactly what was stored.
        let reserved_batches = work_table.take().unwrap();
        assert_eq!(reserved_batches.batches, vec![batch.clone()]);
        let memory_stream =
            MemoryStream::try_new(reserved_batches.batches, batch.schema(), None)
                .unwrap()
                .with_reservation(reserved_batches.reservation);
        // Memory stays reserved while the stream owns the reservation...
        assert_eq!(pool.reserved(), 100);
        drop(memory_stream);
        // ...and is released once the stream is dropped.
        assert_eq!(pool.reserved(), 0);
    }
    /// End-to-end: `execute` streams the stored batch with the projection
    /// (columns 2 and 1, i.e. "c" then "b") applied.
    #[tokio::test]
    async fn test_work_table_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ]));
        let work_table_exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();
        // Swap in an external work table via `with_new_state` (names match).
        let work_table = Arc::new(WorkTable::new("wt".into()));
        let work_table_exec = work_table_exec
            .with_new_state(Arc::clone(&work_table) as _)
            .unwrap();
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
            ],
        )
        .unwrap();
        work_table.update(ReservedBatches::new(vec![batch], reservation));
        let returned_batch = work_table_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap()
            .next()
            .await
            .unwrap()
            .unwrap();
        // Projection [2, 1] reorders the output to columns "c", "b".
        assert_eq!(
            returned_batch,
            RecordBatch::try_from_iter(vec![
                ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _),
                ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
            ])
            .unwrap()
        );
    }
}