1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::common::project_plan_to_schema;
28use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states};
29use crate::metrics::{
30 BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
31};
32use crate::{
33 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
34 SendableRecordBatchStream,
35};
36use arrow::array::{BooleanArray, BooleanBuilder};
37use arrow::compute::filter_record_batch;
38use arrow::datatypes::{Field, Schema, SchemaRef};
39use arrow::record_batch::RecordBatch;
40use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
41use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
42use datafusion_execution::TaskContext;
43use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
44use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
45
46use futures::{Stream, StreamExt, ready};
47
/// Physical operator that evaluates a recursive query.
///
/// The static (seed) term is executed once. Then the recursive term is executed
/// repeatedly, each time reading the rows produced by the previous iteration from
/// the shared work table. Execution ends when an iteration produces no rows.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the recursive query. It is passed to the `WorkTable` and shown in
    /// the plan's display output.
    name: String,
    /// Table shared with the recursive term; refilled with the previous
    /// iteration's output before each new iteration.
    work_table: Arc<WorkTable>,
    /// Non-recursive term, projected to the reconciled output schema.
    static_term: Arc<dyn ExecutionPlan>,
    /// Recursive term, with the work table injected and projected to the
    /// reconciled output schema.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// When true, rows that were already emitted are filtered out
    /// (presumably `UNION` rather than `UNION ALL` semantics — confirm with the planner).
    is_distinct: bool,
    /// Execution metrics for this operator.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (single partition, incremental, bounded).
    cache: Arc<PlanProperties>,
}
80
81impl RecursiveQueryExec {
82 pub fn try_new(
84 name: String,
85 static_term: Arc<dyn ExecutionPlan>,
86 recursive_term: Arc<dyn ExecutionPlan>,
87 is_distinct: bool,
88 ) -> Result<Self> {
89 let work_table = Arc::new(WorkTable::new(name.clone()));
91 let output_schema =
93 recursive_output_schema(&static_term.schema(), &recursive_term.schema());
94 let static_term = project_plan_to_schema(static_term, &output_schema)?;
95 let recursive_term = assign_work_table(recursive_term, &work_table)?;
96 let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?;
97 let cache = Self::compute_properties(output_schema);
98 Ok(RecursiveQueryExec {
99 name,
100 static_term,
101 recursive_term,
102 is_distinct,
103 work_table,
104 metrics: ExecutionPlanMetricsSet::new(),
105 cache: Arc::new(cache),
106 })
107 }
108
109 pub fn name(&self) -> &str {
111 &self.name
112 }
113
114 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
116 &self.static_term
117 }
118
119 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
121 &self.recursive_term
122 }
123
124 pub fn is_distinct(&self) -> bool {
126 self.is_distinct
127 }
128
129 fn compute_properties(schema: SchemaRef) -> PlanProperties {
131 let eq_properties = EquivalenceProperties::new(schema);
132
133 PlanProperties::new(
134 eq_properties,
135 Partitioning::UnknownPartitioning(1),
136 EmissionType::Incremental,
137 Boundedness::Bounded,
138 )
139 }
140}
141
142impl ExecutionPlan for RecursiveQueryExec {
143 fn name(&self) -> &'static str {
144 "RecursiveQueryExec"
145 }
146
147 fn properties(&self) -> &Arc<PlanProperties> {
148 &self.cache
149 }
150
151 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
152 vec![&self.static_term, &self.recursive_term]
153 }
154
155 fn maintains_input_order(&self) -> Vec<bool> {
158 vec![false, false]
159 }
160
161 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
162 vec![false, false]
163 }
164
165 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
166 vec![
167 crate::Distribution::SinglePartition,
168 crate::Distribution::SinglePartition,
169 ]
170 }
171
172 fn with_new_children(
173 self: Arc<Self>,
174 children: Vec<Arc<dyn ExecutionPlan>>,
175 ) -> Result<Arc<dyn ExecutionPlan>> {
176 RecursiveQueryExec::try_new(
177 self.name.clone(),
178 Arc::clone(&children[0]),
179 Arc::clone(&children[1]),
180 self.is_distinct,
181 )
182 .map(|e| Arc::new(e) as _)
183 }
184
185 fn execute(
186 &self,
187 partition: usize,
188 context: Arc<TaskContext>,
189 ) -> Result<SendableRecordBatchStream> {
190 if partition != 0 {
192 return Err(internal_datafusion_err!(
193 "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
194 ));
195 }
196
197 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
198 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
199 Ok(Box::pin(RecursiveQueryStream::new(
200 context,
201 Arc::clone(&self.work_table),
202 Arc::clone(&self.recursive_term),
203 static_stream,
204 self.is_distinct,
205 baseline_metrics,
206 )?))
207 }
208
209 fn metrics(&self) -> Option<MetricsSet> {
210 Some(self.metrics.clone_inner())
211 }
212}
213
214impl DisplayAs for RecursiveQueryExec {
215 fn fmt_as(
216 &self,
217 t: DisplayFormatType,
218 f: &mut std::fmt::Formatter,
219 ) -> std::fmt::Result {
220 match t {
221 DisplayFormatType::Default | DisplayFormatType::Verbose => {
222 write!(
223 f,
224 "RecursiveQueryExec: name={}, is_distinct={}",
225 self.name, self.is_distinct
226 )
227 }
228 DisplayFormatType::TreeRender => {
229 write!(f, "")
231 }
232 }
233 }
234}
235
/// Stream produced by [`RecursiveQueryExec`].
///
/// It first forwards the static term's batches. Then it repeatedly re-executes
/// the recursive term against the batches emitted by the previous iteration, and
/// stops once an iteration yields no rows.
struct RecursiveQueryStream {
    /// Task context used to (re-)execute the recursive term.
    task_context: Arc<TaskContext>,
    /// Table read by the recursive term; refilled with `buffer` before each iteration.
    work_table: Arc<WorkTable>,
    /// Recursive term plan; its state is reset and it is executed again per iteration.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static term's stream; `Some` until it is exhausted.
    static_stream: Option<SendableRecordBatchStream>,
    /// Stream of the currently running recursive iteration, if any.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// Output schema, taken from the static term's stream.
    schema: SchemaRef,
    /// Batches emitted during the current iteration; they become the next
    /// iteration's work table contents.
    buffer: Vec<RecordBatch>,
    /// Memory accounted for `buffer`; handed to the work table together with it.
    reservation: MemoryReservation,
    /// Present only for distinct queries; filters out rows that were already emitted.
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Output row count and compute-time metrics.
    baseline_metrics: BaselineMetrics,
}
278
279impl RecursiveQueryStream {
280 fn new(
282 task_context: Arc<TaskContext>,
283 work_table: Arc<WorkTable>,
284 recursive_term: Arc<dyn ExecutionPlan>,
285 static_stream: SendableRecordBatchStream,
286 is_distinct: bool,
287 baseline_metrics: BaselineMetrics,
288 ) -> Result<Self> {
289 let schema = static_stream.schema();
290 let reservation =
291 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
292 let distinct_deduplicator = is_distinct
293 .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
294 .transpose()?;
295 Ok(Self {
296 task_context,
297 work_table,
298 recursive_term,
299 static_stream: Some(static_stream),
300 recursive_stream: None,
301 schema,
302 buffer: vec![],
303 reservation,
304 distinct_deduplicator,
305 baseline_metrics,
306 })
307 }
308
309 fn push_batch(
312 mut self: std::pin::Pin<&mut Self>,
313 mut batch: RecordBatch,
314 ) -> Poll<Option<Result<RecordBatch>>> {
315 let baseline_metrics = self.baseline_metrics.clone();
316
317 if let Some(deduplicator) = &mut self.distinct_deduplicator {
318 let _timer_guard = baseline_metrics.elapsed_compute().timer();
319 batch = deduplicator.deduplicate(&batch)?;
320 }
321
322 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
323 return Poll::Ready(Some(Err(e)));
324 }
325 self.buffer.push(batch.clone());
326 (&batch).record_output(&baseline_metrics);
327 Poll::Ready(Some(Ok(batch)))
328 }
329
330 fn poll_next_iteration(
334 mut self: std::pin::Pin<&mut Self>,
335 cx: &mut Context<'_>,
336 ) -> Poll<Option<Result<RecordBatch>>> {
337 let total_length = self
338 .buffer
339 .iter()
340 .fold(0, |acc, batch| acc + batch.num_rows());
341
342 if total_length == 0 {
343 return Poll::Ready(None);
344 }
345
346 let reserved_batches = ReservedBatches::new(
348 std::mem::take(&mut self.buffer),
349 self.reservation.take(),
350 );
351 self.work_table.update(reserved_batches);
352
353 let partition = 0;
356
357 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
358 self.recursive_stream =
359 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
360 self.poll_next(cx)
361 }
362}
363
364fn recursive_output_schema(
365 static_schema: &SchemaRef,
366 recursive_schema: &SchemaRef,
367) -> SchemaRef {
368 let fields = static_schema
369 .fields()
370 .iter()
371 .zip(recursive_schema.fields())
372 .map(|(static_field, recursive_field)| {
373 Field::new(
374 static_field.name(),
375 static_field.data_type().clone(),
376 static_field.is_nullable() || recursive_field.is_nullable(),
377 )
378 .with_metadata(static_field.metadata().clone())
379 })
380 .collect::<Vec<_>>();
381
382 Arc::new(Schema::new_with_metadata(
383 fields,
384 static_schema.metadata().clone(),
385 ))
386}
387
388fn assign_work_table(
389 plan: Arc<dyn ExecutionPlan>,
390 work_table: &Arc<WorkTable>,
391) -> Result<Arc<dyn ExecutionPlan>> {
392 let mut work_table_refs = 0;
393 plan.transform_down(|plan| {
394 if let Some(new_plan) =
395 plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
396 {
397 if work_table_refs > 0 {
398 not_impl_err!(
399 "Multiple recursive references to the same CTE are not supported"
400 )
401 } else {
402 work_table_refs += 1;
403 Ok(Transformed::yes(new_plan))
404 }
405 } else {
406 Ok(Transformed::no(plan))
407 }
408 })
409 .data()
410}
411
412impl Stream for RecursiveQueryStream {
413 type Item = Result<RecordBatch>;
414
415 fn poll_next(
416 mut self: std::pin::Pin<&mut Self>,
417 cx: &mut Context<'_>,
418 ) -> Poll<Option<Self::Item>> {
419 if let Some(static_stream) = &mut self.static_stream {
420 let batch_result = ready!(static_stream.poll_next_unpin(cx));
423 match &batch_result {
424 None => {
425 self.static_stream = None;
427 self.poll_next_iteration(cx)
428 }
429 Some(Ok(batch)) => self.push_batch(batch.clone()),
430 _ => Poll::Ready(batch_result),
431 }
432 } else if let Some(recursive_stream) = &mut self.recursive_stream {
433 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
434 match batch_result {
435 None => {
436 self.recursive_stream = None;
437 self.poll_next_iteration(cx)
438 }
439 Some(Ok(batch)) => self.push_batch(batch),
440 _ => Poll::Ready(batch_result),
441 }
442 } else {
443 Poll::Ready(None)
444 }
445 }
446}
447
448impl RecordBatchStream for RecursiveQueryStream {
449 fn schema(&self) -> SchemaRef {
451 Arc::clone(&self.schema)
452 }
453}
454
/// Filters out rows that were already emitted by a distinct recursive query.
struct DistinctDeduplicator {
    /// Every distinct row seen so far, interned as a group.
    group_values: Box<dyn GroupValues>,
    /// Memory accounted for `group_values`.
    reservation: MemoryReservation,
    /// Scratch buffer for the group ids returned by `intern`, reused across batches.
    intern_output_buffer: Vec<usize>,
}
462
463impl DistinctDeduplicator {
464 fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
465 let group_values = new_group_values(schema, &GroupOrdering::None)?;
466 let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
467 .register(task_context.memory_pool());
468 Ok(Self {
469 group_values,
470 reservation,
471 intern_output_buffer: Vec::new(),
472 })
473 }
474
475 fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
482 let size_before = self.group_values.len();
483 self.intern_output_buffer.reserve(batch.num_rows());
484 self.group_values
485 .intern(batch.columns(), &mut self.intern_output_buffer)?;
486 let mask = new_groups_mask(&self.intern_output_buffer, size_before);
487 self.intern_output_buffer.clear();
488 self.reservation.try_resize(self.group_values.size())?;
490 Ok(filter_record_batch(batch, &mask)?)
491 }
492}
493
494fn new_groups_mask(
496 values: &[usize],
497 mut max_already_seen_group_id: usize,
498) -> BooleanArray {
499 let mut output = BooleanBuilder::with_capacity(values.len());
500 for value in values {
501 if *value >= max_already_seen_group_id {
502 output.append_value(true);
503 max_already_seen_group_id = *value + 1; } else {
505 output.append_value(false);
506 }
507 }
508 output.finish()
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty::EmptyExec;
    use crate::projection::ProjectionExec;

    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds an `EmptyExec` whose schema consists of `fields`.
    fn empty_exec(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(fields));
        Arc::new(EmptyExec::new(schema))
    }

    /// When nullability already matches, the output schema equals the static
    /// schema. The recursive term is wrapped in a projection that renames its
    /// column to the static name.
    #[test]
    fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> {
        let seed = empty_exec(vec![Field::new("value", DataType::Int32, false)]);
        let step =
            empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]);

        let plan = RecursiveQueryExec::try_new(
            "numbers".to_string(),
            Arc::clone(&seed),
            Arc::clone(&step),
            false,
        )?;

        assert_eq!(plan.schema(), seed.schema());

        let aligned = plan
            .recursive_term()
            .downcast_ref::<ProjectionExec>()
            .expect("recursive term should be aligned with ProjectionExec");
        assert!(Arc::ptr_eq(aligned.input(), &step));
        assert!(!aligned.schema().field(0).is_nullable());
        assert_eq!(aligned.expr()[0].alias, "value");
        Ok(())
    }

    /// A nullable recursive column makes the reconciled column nullable, both
    /// in the output schema and in each projected term.
    #[test]
    fn recursive_query_exec_reconciles_nullability() -> Result<()> {
        let seed = empty_exec(vec![Field::new("value", DataType::Int32, false)]);
        let step =
            empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]);

        let plan = RecursiveQueryExec::try_new("numbers".to_string(), seed, step, false)?;

        assert!(plan.schema().field(0).is_nullable());
        assert!(plan.static_term().schema().field(0).is_nullable());
        assert!(plan.recursive_term().schema().field(0).is_nullable());
        Ok(())
    }
}