use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};
use super::work_table::{ReservedBatches, WorkTable};
use crate::aggregates::group_values::{GroupValues, new_group_values};
use crate::aggregates::order::GroupOrdering;
use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states};
use crate::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
use crate::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
SendableRecordBatchStream,
};
use arrow::array::{BooleanArray, BooleanBuilder};
use arrow::compute::filter_record_batch;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use futures::{Stream, StreamExt, ready};
/// Execution plan node for a recursive CTE (`WITH RECURSIVE`): the static
/// (base-case) term is evaluated once, then the recursive term is re-executed
/// against the previous iteration's output — held in `work_table` — until a
/// pass produces no rows.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the recursive CTE this node evaluates.
    name: String,
    /// Shared table holding the previous iteration's batches; the recursive
    /// term reads it via the work-table node wired up in `try_new`.
    work_table: Arc<WorkTable>,
    /// The non-recursive (base case) input plan, executed exactly once.
    static_term: Arc<dyn ExecutionPlan>,
    /// The recursive input plan, re-executed once per iteration.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// When true, rows already emitted are filtered out of later batches
    /// (see `DistinctDeduplicator`).
    is_distinct: bool,
    /// Runtime execution metrics for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, emission type).
    cache: Arc<PlanProperties>,
}
impl RecursiveQueryExec {
    /// Creates a new recursive query node, wiring the shared work table into
    /// the recursive term so each iteration can read the previous pass's
    /// output.
    ///
    /// # Errors
    /// Fails if the recursive term references the same CTE more than once
    /// (see [`assign_work_table`]).
    pub fn try_new(
        name: String,
        static_term: Arc<dyn ExecutionPlan>,
        recursive_term: Arc<dyn ExecutionPlan>,
        is_distinct: bool,
    ) -> Result<Self> {
        let work_table = Arc::new(WorkTable::new(name.clone()));
        // Inject the work table into the recursive term's scan node.
        let recursive_term = assign_work_table(recursive_term, &work_table)?;
        // The output schema is taken from the static (base-case) term.
        let cache = Arc::new(Self::compute_properties(static_term.schema()));
        Ok(RecursiveQueryExec {
            name,
            static_term,
            recursive_term,
            is_distinct,
            work_table,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// Name of the recursive CTE.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The non-recursive (base case) term.
    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
        &self.static_term
    }

    /// The recursive term, re-executed each iteration.
    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
        &self.recursive_term
    }

    /// Whether already-seen rows are filtered out (DISTINCT semantics).
    pub fn is_distinct(&self) -> bool {
        self.is_distinct
    }

    /// Builds the cached plan properties: a single unknown partition,
    /// incremental emission, bounded output.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        PlanProperties::new(
            EquivalenceProperties::new(schema),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
    }
}
impl ExecutionPlan for RecursiveQueryExec {
    fn name(&self) -> &'static str {
        "RecursiveQueryExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.static_term, &self.recursive_term]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![false; 2]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false; 2]
    }

    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
        // Both terms must run on a single partition.
        vec![
            crate::Distribution::SinglePartition,
            crate::Distribution::SinglePartition,
        ]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Rebuild from scratch so a fresh work table gets wired into the
        // new recursive term.
        let exec = RecursiveQueryExec::try_new(
            self.name.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.is_distinct,
        )?;
        Ok(Arc::new(exec))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // The entire recursion runs on a single partition.
        if partition != 0 {
            return Err(internal_datafusion_err!(
                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
            ));
        }
        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = RecursiveQueryStream::new(
            context,
            Arc::clone(&self.work_table),
            Arc::clone(&self.recursive_term),
            static_stream,
            self.is_distinct,
            metrics,
        )?;
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
impl DisplayAs for RecursiveQueryExec {
    /// Renders the node for `EXPLAIN` output; the tree renderer adds no
    /// extra detail beyond the node name.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => write!(f, ""),
            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
                f,
                "RecursiveQueryExec: name={}, is_distinct={}",
                self.name, self.is_distinct
            ),
        }
    }
}
/// Stream driving the recursive-query state machine: first drains the static
/// term, then alternates between executing the recursive term and refilling
/// the work table until an iteration produces no new rows.
struct RecursiveQueryStream {
    /// Task context used to (re-)execute the recursive term each iteration.
    task_context: Arc<TaskContext>,
    /// Shared table the recursive term reads; refilled from `buffer`
    /// between iterations.
    work_table: Arc<WorkTable>,
    /// The recursive plan; its state is reset and it is re-executed per pass.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Static-term stream; `Some` only while the base case is being drained.
    static_stream: Option<SendableRecordBatchStream>,
    /// Stream of the current recursive iteration, if one is in flight.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// Output schema (taken from the static term).
    schema: SchemaRef,
    /// Batches produced by the current pass, destined for the work table.
    buffer: Vec<RecordBatch>,
    /// Memory reservation covering `buffer`; transferred to the work table
    /// together with the batches.
    reservation: MemoryReservation,
    /// Present iff the query is DISTINCT; filters out already-seen rows.
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Output row-count / elapsed-compute metrics.
    baseline_metrics: BaselineMetrics,
}
impl RecursiveQueryStream {
    /// Creates the stream. Registers a memory reservation for the buffered
    /// batches and, for DISTINCT queries, builds the deduplicator.
    fn new(
        task_context: Arc<TaskContext>,
        work_table: Arc<WorkTable>,
        recursive_term: Arc<dyn ExecutionPlan>,
        static_stream: SendableRecordBatchStream,
        is_distinct: bool,
        baseline_metrics: BaselineMetrics,
    ) -> Result<Self> {
        // Output schema comes from the static (base-case) term.
        let schema = static_stream.schema();
        let reservation =
            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
        // Only build the dedup hash table when DISTINCT was requested.
        let distinct_deduplicator = is_distinct
            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
            .transpose()?;
        Ok(Self {
            task_context,
            work_table,
            recursive_term,
            static_stream: Some(static_stream),
            recursive_stream: None,
            schema,
            buffer: vec![],
            reservation,
            distinct_deduplicator,
            baseline_metrics,
        })
    }

    /// Emits `batch` downstream while also buffering a copy for the next
    /// iteration's work table. For DISTINCT queries the batch is first
    /// filtered down to previously-unseen rows.
    fn push_batch(
        mut self: std::pin::Pin<&mut Self>,
        mut batch: RecordBatch,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let baseline_metrics = self.baseline_metrics.clone();
        if let Some(deduplicator) = &mut self.distinct_deduplicator {
            // Deduplication time counts as this node's compute time.
            let _timer_guard = baseline_metrics.elapsed_compute().timer();
            batch = deduplicator.deduplicate(&batch)?;
        }
        // Account for the buffered copy *before* storing it; on failure the
        // memory error is surfaced instead of over-committing the pool.
        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
            return Poll::Ready(Some(Err(e)));
        }
        // RecordBatch clone is shallow (columns are reference-counted).
        self.buffer.push(batch.clone());
        (&batch).record_output(&baseline_metrics);
        Poll::Ready(Some(Ok(batch)))
    }

    /// Starts the next recursive iteration: moves the buffered batches (and
    /// their memory reservation) into the work table, resets the recursive
    /// plan's per-execution state, and begins executing it. Terminates the
    /// stream when the previous pass produced no rows.
    fn poll_next_iteration(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let total_length = self
            .buffer
            .iter()
            .fold(0, |acc, batch| acc + batch.num_rows());
        // Fixed point reached: the last pass added no rows, so we are done.
        if total_length == 0 {
            return Poll::Ready(None);
        }
        // Hand the batches plus their reservation over to the work table,
        // leaving `buffer`/`reservation` empty for the next pass.
        let reserved_batches = ReservedBatches::new(
            std::mem::take(&mut self.buffer),
            self.reservation.take(),
        );
        self.work_table.update(reserved_batches);
        let partition = 0;
        // Re-executing requires clearing any state left over from the
        // previous pass of the recursive plan.
        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
        self.recursive_stream =
            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
        self.poll_next(cx)
    }
}
/// Walks `plan` top-down and injects the shared work table into the single
/// node that accepts it (the recursive CTE's work-table scan).
///
/// # Errors
/// Returns a not-implemented error if more than one node accepts the work
/// table, since multiple recursive references to the same CTE are not
/// supported.
fn assign_work_table(
    plan: Arc<dyn ExecutionPlan>,
    work_table: &Arc<WorkTable>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let mut assigned = false;
    plan.transform_down(|node| {
        let state = Arc::clone(work_table) as Arc<dyn Any + Send + Sync>;
        match node.with_new_state(state) {
            // A second node accepting the work table means the CTE is
            // referenced recursively more than once.
            Some(_) if assigned => not_impl_err!(
                "Multiple recursive references to the same CTE are not supported"
            ),
            Some(updated) => {
                assigned = true;
                Ok(Transformed::yes(updated))
            }
            None => Ok(Transformed::no(node)),
        }
    })
    .data()
}
impl Stream for RecursiveQueryStream {
    type Item = Result<RecordBatch>;

    /// Drives the recursion: first drains the static term, buffering (and
    /// emitting) each batch, then repeatedly executes the recursive term
    /// against the work table until an iteration yields no rows.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if let Some(static_stream) = &mut self.static_stream {
            // Static-term phase: every batch seeds iteration 0.
            // Match by value (no `batch.clone()`), consistent with the
            // recursive branch below.
            let batch_result = ready!(static_stream.poll_next_unpin(cx));
            match batch_result {
                None => {
                    // Base case exhausted; kick off the first recursive pass.
                    self.static_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch),
                err => Poll::Ready(err),
            }
        } else if let Some(recursive_stream) = &mut self.recursive_stream {
            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
            match batch_result {
                None => {
                    // Current pass finished; start the next one (or end).
                    self.recursive_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch),
                err => Poll::Ready(err),
            }
        } else {
            // Both phases finished: recursion reached its fixed point.
            Poll::Ready(None)
        }
    }
}
impl RecordBatchStream for RecursiveQueryStream {
    /// Returns the output schema (shared with the static term's stream).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
/// Tracks every distinct row seen so far and filters incoming batches down
/// to rows not seen before, using an aggregate-style group-values hash table.
struct DistinctDeduplicator {
    /// Hash table of all distinct rows observed across calls.
    group_values: Box<dyn GroupValues>,
    /// Memory reservation kept in sync with the hash table's size.
    reservation: MemoryReservation,
    /// Scratch buffer for the group ids produced by `intern`; reused across
    /// calls to avoid reallocating per batch.
    intern_output_buffer: Vec<usize>,
}
impl DistinctDeduplicator {
    /// Creates a deduplicator backed by an unordered group-values hash table
    /// whose memory usage is tracked by its own reservation.
    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
        let group_values = new_group_values(schema, &GroupOrdering::None)?;
        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
            .register(task_context.memory_pool());
        Ok(Self {
            group_values,
            reservation,
            intern_output_buffer: Vec::new(),
        })
    }

    /// Returns `batch` filtered down to rows whose values have not been seen
    /// in any previous call (duplicates within the batch are also removed).
    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
        // Clear the scratch buffer up-front rather than after use: if a
        // previous call bailed out with an error after `intern`, stale group
        // ids would otherwise corrupt this batch's mask.
        self.intern_output_buffer.clear();
        let size_before = self.group_values.len();
        self.intern_output_buffer.reserve(batch.num_rows());
        // `intern` appends one group id per row; ids >= `size_before`
        // belong to groups first seen in this batch.
        self.group_values
            .intern(batch.columns(), &mut self.intern_output_buffer)?;
        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
        // Keep the reservation in sync with the (growing) hash table.
        self.reservation.try_resize(self.group_values.size())?;
        Ok(filter_record_batch(batch, &mask)?)
    }
}
/// Builds a filter mask selecting the rows whose group id was assigned in
/// this batch (i.e. id at or above the table size prior to interning).
/// New ids are handed out sequentially, so a duplicate row within the batch
/// maps back to an already-counted id and is masked out.
fn new_groups_mask(
    values: &[usize],
    mut max_already_seen_group_id: usize,
) -> BooleanArray {
    let mut builder = BooleanBuilder::with_capacity(values.len());
    for &group_id in values {
        let is_new = group_id >= max_already_seen_group_id;
        if is_new {
            // Advance the watermark past this freshly-assigned id.
            max_already_seen_group_id = group_id + 1;
        }
        builder.append_value(is_new);
    }
    builder.finish()
}
// NOTE(review): this test module is empty — coverage presumably lives in
// integration tests elsewhere in the workspace; confirm before relying on it.
#[cfg(test)]
mod tests {}