datafusion_physical_plan/
recursive_query.rs1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::{
27 metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
28 PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
29};
30use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
35use datafusion_common::{not_impl_err, DataFusionError, Result};
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
39
40use futures::{ready, Stream, StreamExt};
41
42#[derive(Debug, Clone)]
58pub struct RecursiveQueryExec {
59 name: String,
61 work_table: Arc<WorkTable>,
63 static_term: Arc<dyn ExecutionPlan>,
65 recursive_term: Arc<dyn ExecutionPlan>,
67 is_distinct: bool,
69 metrics: ExecutionPlanMetricsSet,
71 cache: PlanProperties,
73}
74
75impl RecursiveQueryExec {
76 pub fn try_new(
78 name: String,
79 static_term: Arc<dyn ExecutionPlan>,
80 recursive_term: Arc<dyn ExecutionPlan>,
81 is_distinct: bool,
82 ) -> Result<Self> {
83 let work_table = Arc::new(WorkTable::new());
85 let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
87 let cache = Self::compute_properties(static_term.schema());
88 Ok(RecursiveQueryExec {
89 name,
90 static_term,
91 recursive_term,
92 is_distinct,
93 work_table,
94 metrics: ExecutionPlanMetricsSet::new(),
95 cache,
96 })
97 }
98
99 pub fn name(&self) -> &str {
101 &self.name
102 }
103
104 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
106 &self.static_term
107 }
108
109 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
111 &self.recursive_term
112 }
113
114 pub fn is_distinct(&self) -> bool {
116 self.is_distinct
117 }
118
119 fn compute_properties(schema: SchemaRef) -> PlanProperties {
121 let eq_properties = EquivalenceProperties::new(schema);
122
123 PlanProperties::new(
124 eq_properties,
125 Partitioning::UnknownPartitioning(1),
126 EmissionType::Incremental,
127 Boundedness::Bounded,
128 )
129 }
130}
131
132impl ExecutionPlan for RecursiveQueryExec {
133 fn name(&self) -> &'static str {
134 "RecursiveQueryExec"
135 }
136
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn properties(&self) -> &PlanProperties {
142 &self.cache
143 }
144
145 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
146 vec![&self.static_term, &self.recursive_term]
147 }
148
149 fn maintains_input_order(&self) -> Vec<bool> {
152 vec![false, false]
153 }
154
155 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
156 vec![false, false]
157 }
158
159 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
160 vec![
161 crate::Distribution::SinglePartition,
162 crate::Distribution::SinglePartition,
163 ]
164 }
165
166 fn with_new_children(
167 self: Arc<Self>,
168 children: Vec<Arc<dyn ExecutionPlan>>,
169 ) -> Result<Arc<dyn ExecutionPlan>> {
170 RecursiveQueryExec::try_new(
171 self.name.clone(),
172 Arc::clone(&children[0]),
173 Arc::clone(&children[1]),
174 self.is_distinct,
175 )
176 .map(|e| Arc::new(e) as _)
177 }
178
179 fn execute(
180 &self,
181 partition: usize,
182 context: Arc<TaskContext>,
183 ) -> Result<SendableRecordBatchStream> {
184 if partition != 0 {
186 return Err(DataFusionError::Internal(format!(
187 "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
188 )));
189 }
190
191 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
192 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
193 Ok(Box::pin(RecursiveQueryStream::new(
194 context,
195 Arc::clone(&self.work_table),
196 Arc::clone(&self.recursive_term),
197 static_stream,
198 baseline_metrics,
199 )))
200 }
201
202 fn metrics(&self) -> Option<MetricsSet> {
203 Some(self.metrics.clone_inner())
204 }
205
206 fn statistics(&self) -> Result<Statistics> {
207 Ok(Statistics::new_unknown(&self.schema()))
208 }
209}
210
211impl DisplayAs for RecursiveQueryExec {
212 fn fmt_as(
213 &self,
214 t: DisplayFormatType,
215 f: &mut std::fmt::Formatter,
216 ) -> std::fmt::Result {
217 match t {
218 DisplayFormatType::Default | DisplayFormatType::Verbose => {
219 write!(
220 f,
221 "RecursiveQueryExec: name={}, is_distinct={}",
222 self.name, self.is_distinct
223 )
224 }
225 DisplayFormatType::TreeRender => {
226 write!(f, "")
228 }
229 }
230 }
231}
232
233struct RecursiveQueryStream {
252 task_context: Arc<TaskContext>,
254 work_table: Arc<WorkTable>,
256 recursive_term: Arc<dyn ExecutionPlan>,
258 static_stream: Option<SendableRecordBatchStream>,
261 recursive_stream: Option<SendableRecordBatchStream>,
264 schema: SchemaRef,
266 buffer: Vec<RecordBatch>,
269 reservation: MemoryReservation,
271 _baseline_metrics: BaselineMetrics,
273}
274
275impl RecursiveQueryStream {
276 fn new(
278 task_context: Arc<TaskContext>,
279 work_table: Arc<WorkTable>,
280 recursive_term: Arc<dyn ExecutionPlan>,
281 static_stream: SendableRecordBatchStream,
282 baseline_metrics: BaselineMetrics,
283 ) -> Self {
284 let schema = static_stream.schema();
285 let reservation =
286 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
287 Self {
288 task_context,
289 work_table,
290 recursive_term,
291 static_stream: Some(static_stream),
292 recursive_stream: None,
293 schema,
294 buffer: vec![],
295 reservation,
296 _baseline_metrics: baseline_metrics,
297 }
298 }
299
300 fn push_batch(
303 mut self: std::pin::Pin<&mut Self>,
304 batch: RecordBatch,
305 ) -> Poll<Option<Result<RecordBatch>>> {
306 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
307 return Poll::Ready(Some(Err(e)));
308 }
309
310 self.buffer.push(batch.clone());
311 Poll::Ready(Some(Ok(batch)))
312 }
313
314 fn poll_next_iteration(
318 mut self: std::pin::Pin<&mut Self>,
319 cx: &mut Context<'_>,
320 ) -> Poll<Option<Result<RecordBatch>>> {
321 let total_length = self
322 .buffer
323 .iter()
324 .fold(0, |acc, batch| acc + batch.num_rows());
325
326 if total_length == 0 {
327 return Poll::Ready(None);
328 }
329
330 let reserved_batches = ReservedBatches::new(
332 std::mem::take(&mut self.buffer),
333 self.reservation.take(),
334 );
335 self.work_table.update(reserved_batches);
336
337 let partition = 0;
340
341 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
342 self.recursive_stream =
343 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
344 self.poll_next(cx)
345 }
346}
347
348fn assign_work_table(
349 plan: Arc<dyn ExecutionPlan>,
350 work_table: Arc<WorkTable>,
351) -> Result<Arc<dyn ExecutionPlan>> {
352 let mut work_table_refs = 0;
353 plan.transform_down(|plan| {
354 if let Some(new_plan) =
355 plan.with_new_state(Arc::clone(&work_table) as Arc<dyn Any + Send + Sync>)
356 {
357 if work_table_refs > 0 {
358 not_impl_err!(
359 "Multiple recursive references to the same CTE are not supported"
360 )
361 } else {
362 work_table_refs += 1;
363 Ok(Transformed::yes(new_plan))
364 }
365 } else if plan.as_any().is::<RecursiveQueryExec>() {
366 not_impl_err!("Recursive queries cannot be nested")
367 } else {
368 Ok(Transformed::no(plan))
369 }
370 })
371 .data()
372}
373
374fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
381 plan.transform_up(|plan| {
382 if plan.as_any().is::<WorkTableExec>() {
384 Ok(Transformed::no(plan))
385 } else {
386 let new_plan = Arc::clone(&plan).reset_state()?;
387 Ok(Transformed::yes(new_plan))
388 }
389 })
390 .data()
391}
392
393impl Stream for RecursiveQueryStream {
394 type Item = Result<RecordBatch>;
395
396 fn poll_next(
397 mut self: std::pin::Pin<&mut Self>,
398 cx: &mut Context<'_>,
399 ) -> Poll<Option<Self::Item>> {
400 if let Some(static_stream) = &mut self.static_stream {
402 let batch_result = ready!(static_stream.poll_next_unpin(cx));
405 match &batch_result {
406 None => {
407 self.static_stream = None;
409 self.poll_next_iteration(cx)
410 }
411 Some(Ok(batch)) => self.push_batch(batch.clone()),
412 _ => Poll::Ready(batch_result),
413 }
414 } else if let Some(recursive_stream) = &mut self.recursive_stream {
415 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
416 match batch_result {
417 None => {
418 self.recursive_stream = None;
419 self.poll_next_iteration(cx)
420 }
421 Some(Ok(batch)) => self.push_batch(batch),
422 _ => Poll::Ready(batch_result),
423 }
424 } else {
425 Poll::Ready(None)
426 }
427 }
428}
429
430impl RecordBatchStream for RecursiveQueryStream {
431 fn schema(&self) -> SchemaRef {
433 Arc::clone(&self.schema)
434 }
435}
436
#[cfg(test)]
mod tests {
    // No unit tests are defined in this module.
}