datafusion_physical_plan/
recursive_query.rs1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::{
27 metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
28 PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
29};
30use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
35use datafusion_common::{not_impl_err, DataFusionError, Result};
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
39
40use futures::{ready, Stream, StreamExt};
41
42#[derive(Debug, Clone)]
58pub struct RecursiveQueryExec {
59 name: String,
61 work_table: Arc<WorkTable>,
63 static_term: Arc<dyn ExecutionPlan>,
65 recursive_term: Arc<dyn ExecutionPlan>,
67 is_distinct: bool,
69 metrics: ExecutionPlanMetricsSet,
71 cache: PlanProperties,
73}
74
75impl RecursiveQueryExec {
76 pub fn try_new(
78 name: String,
79 static_term: Arc<dyn ExecutionPlan>,
80 recursive_term: Arc<dyn ExecutionPlan>,
81 is_distinct: bool,
82 ) -> Result<Self> {
83 let work_table = Arc::new(WorkTable::new());
85 let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
87 let cache = Self::compute_properties(static_term.schema());
88 Ok(RecursiveQueryExec {
89 name,
90 static_term,
91 recursive_term,
92 is_distinct,
93 work_table,
94 metrics: ExecutionPlanMetricsSet::new(),
95 cache,
96 })
97 }
98
99 pub fn name(&self) -> &str {
101 &self.name
102 }
103
104 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
106 &self.static_term
107 }
108
109 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
111 &self.recursive_term
112 }
113
114 pub fn is_distinct(&self) -> bool {
116 self.is_distinct
117 }
118
119 fn compute_properties(schema: SchemaRef) -> PlanProperties {
121 let eq_properties = EquivalenceProperties::new(schema);
122
123 PlanProperties::new(
124 eq_properties,
125 Partitioning::UnknownPartitioning(1),
126 EmissionType::Incremental,
127 Boundedness::Bounded,
128 )
129 }
130}
131
132impl ExecutionPlan for RecursiveQueryExec {
133 fn name(&self) -> &'static str {
134 "RecursiveQueryExec"
135 }
136
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn properties(&self) -> &PlanProperties {
142 &self.cache
143 }
144
145 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
146 vec![&self.static_term, &self.recursive_term]
147 }
148
149 fn maintains_input_order(&self) -> Vec<bool> {
152 vec![false, false]
153 }
154
155 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
156 vec![false, false]
157 }
158
159 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
160 vec![
161 crate::Distribution::SinglePartition,
162 crate::Distribution::SinglePartition,
163 ]
164 }
165
166 fn with_new_children(
167 self: Arc<Self>,
168 children: Vec<Arc<dyn ExecutionPlan>>,
169 ) -> Result<Arc<dyn ExecutionPlan>> {
170 RecursiveQueryExec::try_new(
171 self.name.clone(),
172 Arc::clone(&children[0]),
173 Arc::clone(&children[1]),
174 self.is_distinct,
175 )
176 .map(|e| Arc::new(e) as _)
177 }
178
179 fn execute(
180 &self,
181 partition: usize,
182 context: Arc<TaskContext>,
183 ) -> Result<SendableRecordBatchStream> {
184 if partition != 0 {
186 return Err(DataFusionError::Internal(format!(
187 "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
188 )));
189 }
190
191 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
192 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
193 Ok(Box::pin(RecursiveQueryStream::new(
194 context,
195 Arc::clone(&self.work_table),
196 Arc::clone(&self.recursive_term),
197 static_stream,
198 baseline_metrics,
199 )))
200 }
201
202 fn metrics(&self) -> Option<MetricsSet> {
203 Some(self.metrics.clone_inner())
204 }
205
206 fn statistics(&self) -> Result<Statistics> {
207 Ok(Statistics::new_unknown(&self.schema()))
208 }
209}
210
211impl DisplayAs for RecursiveQueryExec {
212 fn fmt_as(
213 &self,
214 t: DisplayFormatType,
215 f: &mut std::fmt::Formatter,
216 ) -> std::fmt::Result {
217 match t {
218 DisplayFormatType::Default | DisplayFormatType::Verbose => {
219 write!(
220 f,
221 "RecursiveQueryExec: name={}, is_distinct={}",
222 self.name, self.is_distinct
223 )
224 }
225 DisplayFormatType::TreeRender => {
226 write!(f, "")
228 }
229 }
230 }
231}
232
233struct RecursiveQueryStream {
252 task_context: Arc<TaskContext>,
254 work_table: Arc<WorkTable>,
256 recursive_term: Arc<dyn ExecutionPlan>,
258 static_stream: Option<SendableRecordBatchStream>,
261 recursive_stream: Option<SendableRecordBatchStream>,
264 schema: SchemaRef,
266 buffer: Vec<RecordBatch>,
269 reservation: MemoryReservation,
271 _baseline_metrics: BaselineMetrics,
273}
274
275impl RecursiveQueryStream {
276 fn new(
278 task_context: Arc<TaskContext>,
279 work_table: Arc<WorkTable>,
280 recursive_term: Arc<dyn ExecutionPlan>,
281 static_stream: SendableRecordBatchStream,
282 baseline_metrics: BaselineMetrics,
283 ) -> Self {
284 let schema = static_stream.schema();
285 let reservation =
286 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
287 Self {
288 task_context,
289 work_table,
290 recursive_term,
291 static_stream: Some(static_stream),
292 recursive_stream: None,
293 schema,
294 buffer: vec![],
295 reservation,
296 _baseline_metrics: baseline_metrics,
297 }
298 }
299
300 fn push_batch(
303 mut self: std::pin::Pin<&mut Self>,
304 batch: RecordBatch,
305 ) -> Poll<Option<Result<RecordBatch>>> {
306 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
307 return Poll::Ready(Some(Err(e)));
308 }
309
310 self.buffer.push(batch.clone());
311 Poll::Ready(Some(Ok(batch)))
312 }
313
314 fn poll_next_iteration(
318 mut self: std::pin::Pin<&mut Self>,
319 cx: &mut Context<'_>,
320 ) -> Poll<Option<Result<RecordBatch>>> {
321 let total_length = self
322 .buffer
323 .iter()
324 .fold(0, |acc, batch| acc + batch.num_rows());
325
326 if total_length == 0 {
327 return Poll::Ready(None);
328 }
329
330 let reserved_batches = ReservedBatches::new(
332 std::mem::take(&mut self.buffer),
333 self.reservation.take(),
334 );
335 self.work_table.update(reserved_batches);
336
337 let partition = 0;
340
341 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
342 self.recursive_stream =
343 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
344 self.poll_next(cx)
345 }
346}
347
348fn assign_work_table(
349 plan: Arc<dyn ExecutionPlan>,
350 work_table: Arc<WorkTable>,
351) -> Result<Arc<dyn ExecutionPlan>> {
352 let mut work_table_refs = 0;
353 plan.transform_down(|plan| {
354 if let Some(exec) = plan.as_any().downcast_ref::<WorkTableExec>() {
355 if work_table_refs > 0 {
356 not_impl_err!(
357 "Multiple recursive references to the same CTE are not supported"
358 )
359 } else {
360 work_table_refs += 1;
361 Ok(Transformed::yes(Arc::new(
362 exec.with_work_table(Arc::clone(&work_table)),
363 )))
364 }
365 } else if plan.as_any().is::<RecursiveQueryExec>() {
366 not_impl_err!("Recursive queries cannot be nested")
367 } else {
368 Ok(Transformed::no(plan))
369 }
370 })
371 .data()
372}
373
374fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
381 plan.transform_up(|plan| {
382 if plan.as_any().is::<WorkTableExec>() {
384 Ok(Transformed::no(plan))
385 } else {
386 let new_plan = Arc::clone(&plan)
387 .with_new_children(plan.children().into_iter().cloned().collect())?;
388 Ok(Transformed::yes(new_plan))
389 }
390 })
391 .data()
392}
393
394impl Stream for RecursiveQueryStream {
395 type Item = Result<RecordBatch>;
396
397 fn poll_next(
398 mut self: std::pin::Pin<&mut Self>,
399 cx: &mut Context<'_>,
400 ) -> Poll<Option<Self::Item>> {
401 if let Some(static_stream) = &mut self.static_stream {
403 let batch_result = ready!(static_stream.poll_next_unpin(cx));
406 match &batch_result {
407 None => {
408 self.static_stream = None;
410 self.poll_next_iteration(cx)
411 }
412 Some(Ok(batch)) => self.push_batch(batch.clone()),
413 _ => Poll::Ready(batch_result),
414 }
415 } else if let Some(recursive_stream) = &mut self.recursive_stream {
416 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
417 match batch_result {
418 None => {
419 self.recursive_stream = None;
420 self.poll_next_iteration(cx)
421 }
422 Some(Ok(batch)) => self.push_batch(batch),
423 _ => Poll::Ready(batch_result),
424 }
425 } else {
426 Poll::Ready(None)
427 }
428 }
429}
430
431impl RecordBatchStream for RecursiveQueryStream {
432 fn schema(&self) -> SchemaRef {
434 Arc::clone(&self.schema)
435 }
436}
437
// NOTE(review): this module contains no unit tests. Recursive query behavior
// is presumably covered by SQL-level tests elsewhere — confirm.
#[cfg(test)]
mod tests {}