datafusion_physical_plan/
recursive_query.rs1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states};
28use crate::metrics::{
29 BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
30};
31use crate::{
32 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
33 SendableRecordBatchStream,
34};
35use arrow::array::{BooleanArray, BooleanBuilder};
36use arrow::compute::filter_record_batch;
37use arrow::datatypes::SchemaRef;
38use arrow::record_batch::RecordBatch;
39use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
40use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
41use datafusion_execution::TaskContext;
42use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
43use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
44
45use futures::{Stream, StreamExt, ready};
46
/// Physical plan node that evaluates a recursive query.
///
/// The node has two children: `static_term`, which is executed once when the
/// plan is executed, and `recursive_term`, which is re-executed for every
/// iteration. Batches produced by one iteration are handed to the next one
/// through `work_table`. Iteration stops once an iteration buffers no rows.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the recursive query; also passed to `WorkTable::new`.
    name: String,
    /// Work table created in `try_new` and attached to the recursive term.
    work_table: Arc<WorkTable>,
    /// Plan executed once to seed the recursion.
    static_term: Arc<dyn ExecutionPlan>,
    /// Plan re-executed on every iteration (after its state has been reset).
    recursive_term: Arc<dyn ExecutionPlan>,
    /// When `true`, the output stream filters rows through a
    /// `DistinctDeduplicator` before buffering and emitting them.
    is_distinct: bool,
    /// Metrics set backing the `BaselineMetrics` created in `execute`.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once in `try_new` and returned by `properties`.
    cache: Arc<PlanProperties>,
}
79
80impl RecursiveQueryExec {
81 pub fn try_new(
83 name: String,
84 static_term: Arc<dyn ExecutionPlan>,
85 recursive_term: Arc<dyn ExecutionPlan>,
86 is_distinct: bool,
87 ) -> Result<Self> {
88 let work_table = Arc::new(WorkTable::new(name.clone()));
90 let recursive_term = assign_work_table(recursive_term, &work_table)?;
92 let cache = Self::compute_properties(static_term.schema());
93 Ok(RecursiveQueryExec {
94 name,
95 static_term,
96 recursive_term,
97 is_distinct,
98 work_table,
99 metrics: ExecutionPlanMetricsSet::new(),
100 cache: Arc::new(cache),
101 })
102 }
103
104 pub fn name(&self) -> &str {
106 &self.name
107 }
108
109 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
111 &self.static_term
112 }
113
114 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
116 &self.recursive_term
117 }
118
119 pub fn is_distinct(&self) -> bool {
121 self.is_distinct
122 }
123
124 fn compute_properties(schema: SchemaRef) -> PlanProperties {
126 let eq_properties = EquivalenceProperties::new(schema);
127
128 PlanProperties::new(
129 eq_properties,
130 Partitioning::UnknownPartitioning(1),
131 EmissionType::Incremental,
132 Boundedness::Bounded,
133 )
134 }
135}
136
137impl ExecutionPlan for RecursiveQueryExec {
138 fn name(&self) -> &'static str {
139 "RecursiveQueryExec"
140 }
141
142 fn as_any(&self) -> &dyn Any {
143 self
144 }
145
146 fn properties(&self) -> &Arc<PlanProperties> {
147 &self.cache
148 }
149
150 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151 vec![&self.static_term, &self.recursive_term]
152 }
153
154 fn maintains_input_order(&self) -> Vec<bool> {
157 vec![false, false]
158 }
159
160 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
161 vec![false, false]
162 }
163
164 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
165 vec![
166 crate::Distribution::SinglePartition,
167 crate::Distribution::SinglePartition,
168 ]
169 }
170
171 fn with_new_children(
172 self: Arc<Self>,
173 children: Vec<Arc<dyn ExecutionPlan>>,
174 ) -> Result<Arc<dyn ExecutionPlan>> {
175 RecursiveQueryExec::try_new(
176 self.name.clone(),
177 Arc::clone(&children[0]),
178 Arc::clone(&children[1]),
179 self.is_distinct,
180 )
181 .map(|e| Arc::new(e) as _)
182 }
183
184 fn execute(
185 &self,
186 partition: usize,
187 context: Arc<TaskContext>,
188 ) -> Result<SendableRecordBatchStream> {
189 if partition != 0 {
191 return Err(internal_datafusion_err!(
192 "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
193 ));
194 }
195
196 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
197 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
198 Ok(Box::pin(RecursiveQueryStream::new(
199 context,
200 Arc::clone(&self.work_table),
201 Arc::clone(&self.recursive_term),
202 static_stream,
203 self.is_distinct,
204 baseline_metrics,
205 )?))
206 }
207
208 fn metrics(&self) -> Option<MetricsSet> {
209 Some(self.metrics.clone_inner())
210 }
211}
212
213impl DisplayAs for RecursiveQueryExec {
214 fn fmt_as(
215 &self,
216 t: DisplayFormatType,
217 f: &mut std::fmt::Formatter,
218 ) -> std::fmt::Result {
219 match t {
220 DisplayFormatType::Default | DisplayFormatType::Verbose => {
221 write!(
222 f,
223 "RecursiveQueryExec: name={}, is_distinct={}",
224 self.name, self.is_distinct
225 )
226 }
227 DisplayFormatType::TreeRender => {
228 write!(f, "")
230 }
231 }
232 }
233}
234
/// Stream that drives a recursive query.
///
/// It first forwards the batches of the static term, then repeatedly executes
/// the recursive term. Every emitted batch is also buffered so that it can be
/// handed to the work table as input of the next iteration.
struct RecursiveQueryStream {
    /// Task context used for every (re-)execution of the recursive term.
    task_context: Arc<TaskContext>,
    /// Work table that receives the buffered batches before each iteration.
    work_table: Arc<WorkTable>,
    /// Plan of the recursive term; its state is reset before each execution.
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Stream of the static term; set to `None` once it is exhausted.
    static_stream: Option<SendableRecordBatchStream>,
    /// Stream of the current iteration of the recursive term, if one is
    /// running; set to `None` once it is exhausted.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// Output schema, taken from the static term's stream.
    schema: SchemaRef,
    /// Batches emitted since the last work table update.
    buffer: Vec<RecordBatch>,
    /// Memory reservation grown for every buffered batch and handed to the
    /// work table together with the buffer.
    reservation: MemoryReservation,
    /// Present only for distinct queries; filters out rows already seen.
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Metrics for output recording and compute-time tracking.
    baseline_metrics: BaselineMetrics,
}
277
278impl RecursiveQueryStream {
279 fn new(
281 task_context: Arc<TaskContext>,
282 work_table: Arc<WorkTable>,
283 recursive_term: Arc<dyn ExecutionPlan>,
284 static_stream: SendableRecordBatchStream,
285 is_distinct: bool,
286 baseline_metrics: BaselineMetrics,
287 ) -> Result<Self> {
288 let schema = static_stream.schema();
289 let reservation =
290 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
291 let distinct_deduplicator = is_distinct
292 .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
293 .transpose()?;
294 Ok(Self {
295 task_context,
296 work_table,
297 recursive_term,
298 static_stream: Some(static_stream),
299 recursive_stream: None,
300 schema,
301 buffer: vec![],
302 reservation,
303 distinct_deduplicator,
304 baseline_metrics,
305 })
306 }
307
308 fn push_batch(
311 mut self: std::pin::Pin<&mut Self>,
312 mut batch: RecordBatch,
313 ) -> Poll<Option<Result<RecordBatch>>> {
314 let baseline_metrics = self.baseline_metrics.clone();
315 if let Some(deduplicator) = &mut self.distinct_deduplicator {
316 let _timer_guard = baseline_metrics.elapsed_compute().timer();
317 batch = deduplicator.deduplicate(&batch)?;
318 }
319
320 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
321 return Poll::Ready(Some(Err(e)));
322 }
323 self.buffer.push(batch.clone());
324 (&batch).record_output(&baseline_metrics);
325 Poll::Ready(Some(Ok(batch)))
326 }
327
328 fn poll_next_iteration(
332 mut self: std::pin::Pin<&mut Self>,
333 cx: &mut Context<'_>,
334 ) -> Poll<Option<Result<RecordBatch>>> {
335 let total_length = self
336 .buffer
337 .iter()
338 .fold(0, |acc, batch| acc + batch.num_rows());
339
340 if total_length == 0 {
341 return Poll::Ready(None);
342 }
343
344 let reserved_batches = ReservedBatches::new(
346 std::mem::take(&mut self.buffer),
347 self.reservation.take(),
348 );
349 self.work_table.update(reserved_batches);
350
351 let partition = 0;
354
355 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
356 self.recursive_stream =
357 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
358 self.poll_next(cx)
359 }
360}
361
362fn assign_work_table(
363 plan: Arc<dyn ExecutionPlan>,
364 work_table: &Arc<WorkTable>,
365) -> Result<Arc<dyn ExecutionPlan>> {
366 let mut work_table_refs = 0;
367 plan.transform_down(|plan| {
368 if let Some(new_plan) =
369 plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
370 {
371 if work_table_refs > 0 {
372 not_impl_err!(
373 "Multiple recursive references to the same CTE are not supported"
374 )
375 } else {
376 work_table_refs += 1;
377 Ok(Transformed::yes(new_plan))
378 }
379 } else {
380 Ok(Transformed::no(plan))
381 }
382 })
383 .data()
384}
385
386impl Stream for RecursiveQueryStream {
387 type Item = Result<RecordBatch>;
388
389 fn poll_next(
390 mut self: std::pin::Pin<&mut Self>,
391 cx: &mut Context<'_>,
392 ) -> Poll<Option<Self::Item>> {
393 if let Some(static_stream) = &mut self.static_stream {
394 let batch_result = ready!(static_stream.poll_next_unpin(cx));
397 match &batch_result {
398 None => {
399 self.static_stream = None;
401 self.poll_next_iteration(cx)
402 }
403 Some(Ok(batch)) => self.push_batch(batch.clone()),
404 _ => Poll::Ready(batch_result),
405 }
406 } else if let Some(recursive_stream) = &mut self.recursive_stream {
407 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
408 match batch_result {
409 None => {
410 self.recursive_stream = None;
411 self.poll_next_iteration(cx)
412 }
413 Some(Ok(batch)) => self.push_batch(batch),
414 _ => Poll::Ready(batch_result),
415 }
416 } else {
417 Poll::Ready(None)
418 }
419 }
420}
421
422impl RecordBatchStream for RecursiveQueryStream {
423 fn schema(&self) -> SchemaRef {
425 Arc::clone(&self.schema)
426 }
427}
428
/// Filters out rows that were already emitted by a distinct recursive query.
struct DistinctDeduplicator {
    /// Group store that every incoming batch is interned into; a row counts as
    /// new when interning yields a group id not present before the batch.
    group_values: Box<dyn GroupValues>,
    /// Memory reservation resized to `group_values.size()` after every batch.
    reservation: MemoryReservation,
    /// Scratch buffer receiving the group ids from `intern`; cleared after
    /// each batch so its allocation can be reused.
    intern_output_buffer: Vec<usize>,
}
436
437impl DistinctDeduplicator {
438 fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
439 let group_values = new_group_values(schema, &GroupOrdering::None)?;
440 let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
441 .register(task_context.memory_pool());
442 Ok(Self {
443 group_values,
444 reservation,
445 intern_output_buffer: Vec::new(),
446 })
447 }
448
449 fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
456 let size_before = self.group_values.len();
457 self.intern_output_buffer.reserve(batch.num_rows());
458 self.group_values
459 .intern(batch.columns(), &mut self.intern_output_buffer)?;
460 let mask = new_groups_mask(&self.intern_output_buffer, size_before);
461 self.intern_output_buffer.clear();
462 self.reservation.try_resize(self.group_values.size())?;
464 Ok(filter_record_batch(batch, &mask)?)
465 }
466}
467
468fn new_groups_mask(
470 values: &[usize],
471 mut max_already_seen_group_id: usize,
472) -> BooleanArray {
473 let mut output = BooleanBuilder::with_capacity(values.len());
474 for value in values {
475 if *value >= max_already_seen_group_id {
476 output.append_value(true);
477 max_already_seen_group_id = *value + 1; } else {
479 output.append_value(false);
480 }
481 }
482 output.finish()
483}
484
// Unit-test module for this file; it contains no test cases in this view.
#[cfg(test)]
mod tests {}