datafusion_physical_plan/
recursive_query.rs1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::execution_plan::{Boundedness, EmissionType};
28use crate::metrics::{
29 BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
30};
31use crate::{
32 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
33 SendableRecordBatchStream, Statistics,
34};
35use arrow::array::{BooleanArray, BooleanBuilder};
36use arrow::compute::filter_record_batch;
37use arrow::datatypes::SchemaRef;
38use arrow::record_batch::RecordBatch;
39use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
40use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
41use datafusion_execution::TaskContext;
42use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
43use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
44
45use futures::{Stream, StreamExt, ready};
46
/// Physical operator for a recursive CTE.
///
/// At execution time all rows of `static_term` are forwarded first. Then
/// `recursive_term` is re-executed over and over. Before each pass, `work_table`
/// is refilled with the rows emitted by the previous pass. This stops once a
/// pass yields no (new) rows.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the recursive relation; also passed to `WorkTable::new` and
    /// shown in the plan's display output.
    name: String,
    /// Work table created in `try_new` and attached to `recursive_term`.
    work_table: Arc<WorkTable>,
    /// Non-recursive (seed) part of the query; executed once per `execute` call.
    static_term: Arc<dyn ExecutionPlan>,
    /// Recursive part of the query; reset and re-executed for every iteration.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// When true, rows already emitted are filtered out of later batches
    /// (presumably `UNION` rather than `UNION ALL` semantics — confirm with the planner).
    is_distinct: bool,
    /// Execution metrics for this operator.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once in `try_new` from the static term's schema.
    cache: PlanProperties,
}
79
80impl RecursiveQueryExec {
81 pub fn try_new(
83 name: String,
84 static_term: Arc<dyn ExecutionPlan>,
85 recursive_term: Arc<dyn ExecutionPlan>,
86 is_distinct: bool,
87 ) -> Result<Self> {
88 let work_table = Arc::new(WorkTable::new(name.clone()));
90 let recursive_term = assign_work_table(recursive_term, &work_table)?;
92 let cache = Self::compute_properties(static_term.schema());
93 Ok(RecursiveQueryExec {
94 name,
95 static_term,
96 recursive_term,
97 is_distinct,
98 work_table,
99 metrics: ExecutionPlanMetricsSet::new(),
100 cache,
101 })
102 }
103
104 pub fn name(&self) -> &str {
106 &self.name
107 }
108
109 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
111 &self.static_term
112 }
113
114 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
116 &self.recursive_term
117 }
118
119 pub fn is_distinct(&self) -> bool {
121 self.is_distinct
122 }
123
124 fn compute_properties(schema: SchemaRef) -> PlanProperties {
126 let eq_properties = EquivalenceProperties::new(schema);
127
128 PlanProperties::new(
129 eq_properties,
130 Partitioning::UnknownPartitioning(1),
131 EmissionType::Incremental,
132 Boundedness::Bounded,
133 )
134 }
135}
136
137impl ExecutionPlan for RecursiveQueryExec {
138 fn name(&self) -> &'static str {
139 "RecursiveQueryExec"
140 }
141
142 fn as_any(&self) -> &dyn Any {
143 self
144 }
145
146 fn properties(&self) -> &PlanProperties {
147 &self.cache
148 }
149
150 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151 vec![&self.static_term, &self.recursive_term]
152 }
153
154 fn maintains_input_order(&self) -> Vec<bool> {
157 vec![false, false]
158 }
159
160 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
161 vec![false, false]
162 }
163
164 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
165 vec![
166 crate::Distribution::SinglePartition,
167 crate::Distribution::SinglePartition,
168 ]
169 }
170
171 fn with_new_children(
172 self: Arc<Self>,
173 children: Vec<Arc<dyn ExecutionPlan>>,
174 ) -> Result<Arc<dyn ExecutionPlan>> {
175 RecursiveQueryExec::try_new(
176 self.name.clone(),
177 Arc::clone(&children[0]),
178 Arc::clone(&children[1]),
179 self.is_distinct,
180 )
181 .map(|e| Arc::new(e) as _)
182 }
183
184 fn execute(
185 &self,
186 partition: usize,
187 context: Arc<TaskContext>,
188 ) -> Result<SendableRecordBatchStream> {
189 if partition != 0 {
191 return Err(internal_datafusion_err!(
192 "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
193 ));
194 }
195
196 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
197 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
198 Ok(Box::pin(RecursiveQueryStream::new(
199 context,
200 Arc::clone(&self.work_table),
201 Arc::clone(&self.recursive_term),
202 static_stream,
203 self.is_distinct,
204 baseline_metrics,
205 )?))
206 }
207
208 fn metrics(&self) -> Option<MetricsSet> {
209 Some(self.metrics.clone_inner())
210 }
211
212 fn statistics(&self) -> Result<Statistics> {
213 Ok(Statistics::new_unknown(&self.schema()))
214 }
215}
216
217impl DisplayAs for RecursiveQueryExec {
218 fn fmt_as(
219 &self,
220 t: DisplayFormatType,
221 f: &mut std::fmt::Formatter,
222 ) -> std::fmt::Result {
223 match t {
224 DisplayFormatType::Default | DisplayFormatType::Verbose => {
225 write!(
226 f,
227 "RecursiveQueryExec: name={}, is_distinct={}",
228 self.name, self.is_distinct
229 )
230 }
231 DisplayFormatType::TreeRender => {
232 write!(f, "")
234 }
235 }
236 }
237}
238
/// Stream producing the output of [`RecursiveQueryExec`].
///
/// It first forwards the static term's batches. It then runs the recursive
/// term repeatedly, feeding each pass with the batches emitted by the
/// previous one.
struct RecursiveQueryStream {
    /// Task context used to re-execute the recursive term.
    task_context: Arc<TaskContext>,
    /// Work table refilled from `buffer` before every recursive iteration.
    work_table: Arc<WorkTable>,
    /// Recursive term; reset and re-executed for every iteration.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Stream of the static term; set to `None` once it is exhausted.
    static_stream: Option<SendableRecordBatchStream>,
    /// Stream of the currently running recursive iteration, if any.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// Output schema (taken from the static term's stream).
    schema: SchemaRef,
    /// Batches emitted since the current pass began. They become the work
    /// table contents for the next pass.
    buffer: Vec<RecordBatch>,
    /// Memory reservation covering `buffer`; handed to the work table with it.
    reservation: MemoryReservation,
    /// Present only when DISTINCT is requested; filters already-emitted rows.
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Output row count and compute-time metrics.
    baseline_metrics: BaselineMetrics,
}
281
282impl RecursiveQueryStream {
283 fn new(
285 task_context: Arc<TaskContext>,
286 work_table: Arc<WorkTable>,
287 recursive_term: Arc<dyn ExecutionPlan>,
288 static_stream: SendableRecordBatchStream,
289 is_distinct: bool,
290 baseline_metrics: BaselineMetrics,
291 ) -> Result<Self> {
292 let schema = static_stream.schema();
293 let reservation =
294 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
295 let distinct_deduplicator = is_distinct
296 .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
297 .transpose()?;
298 Ok(Self {
299 task_context,
300 work_table,
301 recursive_term,
302 static_stream: Some(static_stream),
303 recursive_stream: None,
304 schema,
305 buffer: vec![],
306 reservation,
307 distinct_deduplicator,
308 baseline_metrics,
309 })
310 }
311
312 fn push_batch(
315 mut self: std::pin::Pin<&mut Self>,
316 mut batch: RecordBatch,
317 ) -> Poll<Option<Result<RecordBatch>>> {
318 let baseline_metrics = self.baseline_metrics.clone();
319 if let Some(deduplicator) = &mut self.distinct_deduplicator {
320 let _timer_guard = baseline_metrics.elapsed_compute().timer();
321 batch = deduplicator.deduplicate(&batch)?;
322 }
323
324 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
325 return Poll::Ready(Some(Err(e)));
326 }
327 self.buffer.push(batch.clone());
328 (&batch).record_output(&baseline_metrics);
329 Poll::Ready(Some(Ok(batch)))
330 }
331
332 fn poll_next_iteration(
336 mut self: std::pin::Pin<&mut Self>,
337 cx: &mut Context<'_>,
338 ) -> Poll<Option<Result<RecordBatch>>> {
339 let total_length = self
340 .buffer
341 .iter()
342 .fold(0, |acc, batch| acc + batch.num_rows());
343
344 if total_length == 0 {
345 return Poll::Ready(None);
346 }
347
348 let reserved_batches = ReservedBatches::new(
350 std::mem::take(&mut self.buffer),
351 self.reservation.take(),
352 );
353 self.work_table.update(reserved_batches);
354
355 let partition = 0;
358
359 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
360 self.recursive_stream =
361 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
362 self.poll_next(cx)
363 }
364}
365
366fn assign_work_table(
367 plan: Arc<dyn ExecutionPlan>,
368 work_table: &Arc<WorkTable>,
369) -> Result<Arc<dyn ExecutionPlan>> {
370 let mut work_table_refs = 0;
371 plan.transform_down(|plan| {
372 if let Some(new_plan) =
373 plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
374 {
375 if work_table_refs > 0 {
376 not_impl_err!(
377 "Multiple recursive references to the same CTE are not supported"
378 )
379 } else {
380 work_table_refs += 1;
381 Ok(Transformed::yes(new_plan))
382 }
383 } else {
384 Ok(Transformed::no(plan))
385 }
386 })
387 .data()
388}
389
390fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
397 plan.transform_up(|plan| {
398 let new_plan = Arc::clone(&plan).reset_state()?;
399 Ok(Transformed::yes(new_plan))
400 })
401 .data()
402}
403
404impl Stream for RecursiveQueryStream {
405 type Item = Result<RecordBatch>;
406
407 fn poll_next(
408 mut self: std::pin::Pin<&mut Self>,
409 cx: &mut Context<'_>,
410 ) -> Poll<Option<Self::Item>> {
411 if let Some(static_stream) = &mut self.static_stream {
412 let batch_result = ready!(static_stream.poll_next_unpin(cx));
415 match &batch_result {
416 None => {
417 self.static_stream = None;
419 self.poll_next_iteration(cx)
420 }
421 Some(Ok(batch)) => self.push_batch(batch.clone()),
422 _ => Poll::Ready(batch_result),
423 }
424 } else if let Some(recursive_stream) = &mut self.recursive_stream {
425 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
426 match batch_result {
427 None => {
428 self.recursive_stream = None;
429 self.poll_next_iteration(cx)
430 }
431 Some(Ok(batch)) => self.push_batch(batch),
432 _ => Poll::Ready(batch_result),
433 }
434 } else {
435 Poll::Ready(None)
436 }
437 }
438}
439
440impl RecordBatchStream for RecursiveQueryStream {
441 fn schema(&self) -> SchemaRef {
443 Arc::clone(&self.schema)
444 }
445}
446
/// Drops rows that were already seen in any batch previously passed to
/// `deduplicate`.
struct DistinctDeduplicator {
    /// Every distinct row seen so far, interned as a group.
    group_values: Box<dyn GroupValues>,
    /// Memory reservation resized to the size of `group_values` after each batch.
    reservation: MemoryReservation,
    /// Scratch buffer receiving group ids from `GroupValues::intern`. It is
    /// cleared after each use and kept so its allocation can be reused.
    intern_output_buffer: Vec<usize>,
}
454
455impl DistinctDeduplicator {
456 fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
457 let group_values = new_group_values(schema, &GroupOrdering::None)?;
458 let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
459 .register(task_context.memory_pool());
460 Ok(Self {
461 group_values,
462 reservation,
463 intern_output_buffer: Vec::new(),
464 })
465 }
466
467 fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
474 let size_before = self.group_values.len();
475 self.intern_output_buffer.reserve(batch.num_rows());
476 self.group_values
477 .intern(batch.columns(), &mut self.intern_output_buffer)?;
478 let mask = new_groups_mask(&self.intern_output_buffer, size_before);
479 self.intern_output_buffer.clear();
480 self.reservation.try_resize(self.group_values.size())?;
482 Ok(filter_record_batch(batch, &mask)?)
483 }
484}
485
486fn new_groups_mask(
488 values: &[usize],
489 mut max_already_seen_group_id: usize,
490) -> BooleanArray {
491 let mut output = BooleanBuilder::with_capacity(values.len());
492 for value in values {
493 if *value >= max_already_seen_group_id {
494 output.append_value(true);
495 max_already_seen_group_id = *value + 1; } else {
497 output.append_value(false);
498 }
499 }
500 output.finish()
501}
502
// Unit tests for this module.
// NOTE(review): the module is empty in this view. `new_groups_mask` and
// `DistinctDeduplicator::deduplicate` are candidates for direct unit tests.
#[cfg(test)]
mod tests {}