Skip to main content

datafusion_physical_plan/
recursive_query.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the recursive query plan
19
20use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::common::project_plan_to_schema;
28use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states};
29use crate::metrics::{
30    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
31};
32use crate::{
33    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
34    SendableRecordBatchStream,
35};
36use arrow::array::{BooleanArray, BooleanBuilder};
37use arrow::compute::filter_record_batch;
38use arrow::datatypes::{Field, Schema, SchemaRef};
39use arrow::record_batch::RecordBatch;
40use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
41use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
42use datafusion_execution::TaskContext;
43use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
44use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
45
46use futures::{Stream, StreamExt, ready};
47
48/// Recursive query execution plan.
49///
50/// This plan has two components: a base part (the static term) and
51/// a dynamic part (the recursive term). The execution will start from
52/// the base, and as long as the previous iteration produced at least
53/// a single new row (taking care of the distinction) the recursive
54/// part will be continuously executed.
55///
56/// Before each execution of the dynamic part, the rows from the previous
57/// iteration will be available in a "working table" (not a real table,
58/// can be only accessed using a continuance operation).
59///
60/// Note that there won't be any limit or checks applied to detect
61/// an infinite recursion, so it is up to the planner to ensure that
62/// it won't happen.
63#[derive(Debug, Clone)]
64pub struct RecursiveQueryExec {
65    /// Name of the query handler
66    name: String,
67    /// The working table of cte
68    work_table: Arc<WorkTable>,
69    /// The base part (static term)
70    static_term: Arc<dyn ExecutionPlan>,
71    /// The dynamic part (recursive term)
72    recursive_term: Arc<dyn ExecutionPlan>,
73    /// Distinction
74    is_distinct: bool,
75    /// Execution metrics
76    metrics: ExecutionPlanMetricsSet,
77    /// Cache holding plan properties like equivalences, output partitioning etc.
78    cache: Arc<PlanProperties>,
79}
80
81impl RecursiveQueryExec {
82    /// Create a new RecursiveQueryExec
83    pub fn try_new(
84        name: String,
85        static_term: Arc<dyn ExecutionPlan>,
86        recursive_term: Arc<dyn ExecutionPlan>,
87        is_distinct: bool,
88    ) -> Result<Self> {
89        // Each recursive query needs its own work table
90        let work_table = Arc::new(WorkTable::new(name.clone()));
91        // Use the same work table for both the WorkTableExec and the recursive term
92        let output_schema =
93            recursive_output_schema(&static_term.schema(), &recursive_term.schema());
94        let static_term = project_plan_to_schema(static_term, &output_schema)?;
95        let recursive_term = assign_work_table(recursive_term, &work_table)?;
96        let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?;
97        let cache = Self::compute_properties(output_schema);
98        Ok(RecursiveQueryExec {
99            name,
100            static_term,
101            recursive_term,
102            is_distinct,
103            work_table,
104            metrics: ExecutionPlanMetricsSet::new(),
105            cache: Arc::new(cache),
106        })
107    }
108
109    /// Ref to name
110    pub fn name(&self) -> &str {
111        &self.name
112    }
113
114    /// Ref to static term
115    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
116        &self.static_term
117    }
118
119    /// Ref to recursive term
120    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
121        &self.recursive_term
122    }
123
124    /// is distinct
125    pub fn is_distinct(&self) -> bool {
126        self.is_distinct
127    }
128
129    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
130    fn compute_properties(schema: SchemaRef) -> PlanProperties {
131        let eq_properties = EquivalenceProperties::new(schema);
132
133        PlanProperties::new(
134            eq_properties,
135            Partitioning::UnknownPartitioning(1),
136            EmissionType::Incremental,
137            Boundedness::Bounded,
138        )
139    }
140}
141
142impl ExecutionPlan for RecursiveQueryExec {
143    fn name(&self) -> &'static str {
144        "RecursiveQueryExec"
145    }
146
147    fn properties(&self) -> &Arc<PlanProperties> {
148        &self.cache
149    }
150
151    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
152        vec![&self.static_term, &self.recursive_term]
153    }
154
155    // TODO: control these hints and see whether we can
156    // infer some from the child plans (static/recursive terms).
157    fn maintains_input_order(&self) -> Vec<bool> {
158        vec![false, false]
159    }
160
161    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
162        vec![false, false]
163    }
164
165    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
166        vec![
167            crate::Distribution::SinglePartition,
168            crate::Distribution::SinglePartition,
169        ]
170    }
171
172    fn with_new_children(
173        self: Arc<Self>,
174        children: Vec<Arc<dyn ExecutionPlan>>,
175    ) -> Result<Arc<dyn ExecutionPlan>> {
176        RecursiveQueryExec::try_new(
177            self.name.clone(),
178            Arc::clone(&children[0]),
179            Arc::clone(&children[1]),
180            self.is_distinct,
181        )
182        .map(|e| Arc::new(e) as _)
183    }
184
185    fn execute(
186        &self,
187        partition: usize,
188        context: Arc<TaskContext>,
189    ) -> Result<SendableRecordBatchStream> {
190        // TODO: we might be able to handle multiple partitions in the future.
191        if partition != 0 {
192            return Err(internal_datafusion_err!(
193                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
194            ));
195        }
196
197        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
198        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
199        Ok(Box::pin(RecursiveQueryStream::new(
200            context,
201            Arc::clone(&self.work_table),
202            Arc::clone(&self.recursive_term),
203            static_stream,
204            self.is_distinct,
205            baseline_metrics,
206        )?))
207    }
208
209    fn metrics(&self) -> Option<MetricsSet> {
210        Some(self.metrics.clone_inner())
211    }
212}
213
214impl DisplayAs for RecursiveQueryExec {
215    fn fmt_as(
216        &self,
217        t: DisplayFormatType,
218        f: &mut std::fmt::Formatter,
219    ) -> std::fmt::Result {
220        match t {
221            DisplayFormatType::Default | DisplayFormatType::Verbose => {
222                write!(
223                    f,
224                    "RecursiveQueryExec: name={}, is_distinct={}",
225                    self.name, self.is_distinct
226                )
227            }
228            DisplayFormatType::TreeRender => {
229                // TODO: collect info
230                write!(f, "")
231            }
232        }
233    }
234}
235
236/// The actual logic of the recursive queries happens during the streaming
237/// process. A simplified version of the algorithm is the following:
238///
239/// buffer = []
240///
241/// while batch := static_stream.next():
242///    buffer.push(batch)
243///    yield buffer
244///
245/// while buffer.len() > 0:
246///    sender, receiver = Channel()
247///    register_continuation(handle_name, receiver)
248///    sender.send(buffer.drain())
249///    recursive_stream = recursive_term.execute()
250///    while batch := recursive_stream.next():
251///        buffer.append(batch)
252///        yield buffer
253struct RecursiveQueryStream {
254    /// The context to be used for managing handlers & executing new tasks
255    task_context: Arc<TaskContext>,
256    /// The working table state, representing the self referencing cte table
257    work_table: Arc<WorkTable>,
258    /// The dynamic part (recursive term) as is (without being executed)
259    recursive_term: Arc<dyn ExecutionPlan>,
260    /// The static part (static term) as a stream. If the processing of this
261    /// part is completed, then it will be None.
262    static_stream: Option<SendableRecordBatchStream>,
263    /// The dynamic part (recursive term) as a stream. If the processing of this
264    /// part has not started yet, or has been completed, then it will be None.
265    recursive_stream: Option<SendableRecordBatchStream>,
266    /// The schema of the output.
267    schema: SchemaRef,
268    /// In-memory buffer for storing a copy of the current results. Will be
269    /// cleared after each iteration.
270    buffer: Vec<RecordBatch>,
271    /// Tracks the memory used by the buffer
272    reservation: MemoryReservation,
273    /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables
274    distinct_deduplicator: Option<DistinctDeduplicator>,
275    /// Metrics.
276    baseline_metrics: BaselineMetrics,
277}
278
279impl RecursiveQueryStream {
280    /// Create a new recursive query stream
281    fn new(
282        task_context: Arc<TaskContext>,
283        work_table: Arc<WorkTable>,
284        recursive_term: Arc<dyn ExecutionPlan>,
285        static_stream: SendableRecordBatchStream,
286        is_distinct: bool,
287        baseline_metrics: BaselineMetrics,
288    ) -> Result<Self> {
289        let schema = static_stream.schema();
290        let reservation =
291            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
292        let distinct_deduplicator = is_distinct
293            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
294            .transpose()?;
295        Ok(Self {
296            task_context,
297            work_table,
298            recursive_term,
299            static_stream: Some(static_stream),
300            recursive_stream: None,
301            schema,
302            buffer: vec![],
303            reservation,
304            distinct_deduplicator,
305            baseline_metrics,
306        })
307    }
308
309    /// Push a clone of the given batch to the in memory buffer, and then return
310    /// a poll with it.
311    fn push_batch(
312        mut self: std::pin::Pin<&mut Self>,
313        mut batch: RecordBatch,
314    ) -> Poll<Option<Result<RecordBatch>>> {
315        let baseline_metrics = self.baseline_metrics.clone();
316
317        if let Some(deduplicator) = &mut self.distinct_deduplicator {
318            let _timer_guard = baseline_metrics.elapsed_compute().timer();
319            batch = deduplicator.deduplicate(&batch)?;
320        }
321
322        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
323            return Poll::Ready(Some(Err(e)));
324        }
325        self.buffer.push(batch.clone());
326        (&batch).record_output(&baseline_metrics);
327        Poll::Ready(Some(Ok(batch)))
328    }
329
330    /// Start polling for the next iteration, will be called either after the static term
331    /// is completed or another term is completed. It will follow the algorithm above on
332    /// to check whether the recursion has ended.
333    fn poll_next_iteration(
334        mut self: std::pin::Pin<&mut Self>,
335        cx: &mut Context<'_>,
336    ) -> Poll<Option<Result<RecordBatch>>> {
337        let total_length = self
338            .buffer
339            .iter()
340            .fold(0, |acc, batch| acc + batch.num_rows());
341
342        if total_length == 0 {
343            return Poll::Ready(None);
344        }
345
346        // Update the work table with the current buffer
347        let reserved_batches = ReservedBatches::new(
348            std::mem::take(&mut self.buffer),
349            self.reservation.take(),
350        );
351        self.work_table.update(reserved_batches);
352
353        // We always execute (and re-execute iteratively) the first partition.
354        // Downstream plans should not expect any partitioning.
355        let partition = 0;
356
357        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
358        self.recursive_stream =
359            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
360        self.poll_next(cx)
361    }
362}
363
364fn recursive_output_schema(
365    static_schema: &SchemaRef,
366    recursive_schema: &SchemaRef,
367) -> SchemaRef {
368    let fields = static_schema
369        .fields()
370        .iter()
371        .zip(recursive_schema.fields())
372        .map(|(static_field, recursive_field)| {
373            Field::new(
374                static_field.name(),
375                static_field.data_type().clone(),
376                static_field.is_nullable() || recursive_field.is_nullable(),
377            )
378            .with_metadata(static_field.metadata().clone())
379        })
380        .collect::<Vec<_>>();
381
382    Arc::new(Schema::new_with_metadata(
383        fields,
384        static_schema.metadata().clone(),
385    ))
386}
387
388fn assign_work_table(
389    plan: Arc<dyn ExecutionPlan>,
390    work_table: &Arc<WorkTable>,
391) -> Result<Arc<dyn ExecutionPlan>> {
392    let mut work_table_refs = 0;
393    plan.transform_down(|plan| {
394        if let Some(new_plan) =
395            plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
396        {
397            if work_table_refs > 0 {
398                not_impl_err!(
399                    "Multiple recursive references to the same CTE are not supported"
400                )
401            } else {
402                work_table_refs += 1;
403                Ok(Transformed::yes(new_plan))
404            }
405        } else {
406            Ok(Transformed::no(plan))
407        }
408    })
409    .data()
410}
411
412impl Stream for RecursiveQueryStream {
413    type Item = Result<RecordBatch>;
414
415    fn poll_next(
416        mut self: std::pin::Pin<&mut Self>,
417        cx: &mut Context<'_>,
418    ) -> Poll<Option<Self::Item>> {
419        if let Some(static_stream) = &mut self.static_stream {
420            // While the static term's stream is available, we'll be forwarding the batches from it (also
421            // saving them for the initial iteration of the recursive term).
422            let batch_result = ready!(static_stream.poll_next_unpin(cx));
423            match &batch_result {
424                None => {
425                    // Once this is done, we can start running the setup for the recursive term.
426                    self.static_stream = None;
427                    self.poll_next_iteration(cx)
428                }
429                Some(Ok(batch)) => self.push_batch(batch.clone()),
430                _ => Poll::Ready(batch_result),
431            }
432        } else if let Some(recursive_stream) = &mut self.recursive_stream {
433            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
434            match batch_result {
435                None => {
436                    self.recursive_stream = None;
437                    self.poll_next_iteration(cx)
438                }
439                Some(Ok(batch)) => self.push_batch(batch),
440                _ => Poll::Ready(batch_result),
441            }
442        } else {
443            Poll::Ready(None)
444        }
445    }
446}
447
448impl RecordBatchStream for RecursiveQueryStream {
449    /// Get the schema
450    fn schema(&self) -> SchemaRef {
451        Arc::clone(&self.schema)
452    }
453}
454
455/// Deduplicator based on a hash table.
456struct DistinctDeduplicator {
457    /// Grouped rows used for distinct
458    group_values: Box<dyn GroupValues>,
459    reservation: MemoryReservation,
460    intern_output_buffer: Vec<usize>,
461}
462
463impl DistinctDeduplicator {
464    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
465        let group_values = new_group_values(schema, &GroupOrdering::None)?;
466        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
467            .register(task_context.memory_pool());
468        Ok(Self {
469            group_values,
470            reservation,
471            intern_output_buffer: Vec::new(),
472        })
473    }
474
475    /// Remove duplicated rows from the given batch, keeping a state between batches.
476    ///
477    /// We use a hash table to allocate new group ids for the new rows.
478    /// [`GroupValues`] allocate increasing group ids.
479    /// Hence, if groups (i.e., rows) are new, then they have ids >= length before interning, we keep them.
480    /// We also detect duplicates by enforcing that group ids are increasing.
481    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
482        let size_before = self.group_values.len();
483        self.intern_output_buffer.reserve(batch.num_rows());
484        self.group_values
485            .intern(batch.columns(), &mut self.intern_output_buffer)?;
486        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
487        self.intern_output_buffer.clear();
488        // We update the reservation to reflect the new size of the hash table.
489        self.reservation.try_resize(self.group_values.size())?;
490        Ok(filter_record_batch(batch, &mask)?)
491    }
492}
493
494/// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater or equal than the provided max_already_seen_group_id
495fn new_groups_mask(
496    values: &[usize],
497    mut max_already_seen_group_id: usize,
498) -> BooleanArray {
499    let mut output = BooleanBuilder::with_capacity(values.len());
500    for value in values {
501        if *value >= max_already_seen_group_id {
502            output.append_value(true);
503            max_already_seen_group_id = *value + 1; // We want to be increasing
504        } else {
505            output.append_value(false);
506        }
507    }
508    output.finish()
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty::EmptyExec;
    use crate::projection::ProjectionExec;

    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a row-less plan whose schema consists of `fields`.
    fn empty_exec(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(fields));
        Arc::new(EmptyExec::new(schema))
    }

    /// A single `Int32` column with the given name and nullability.
    fn int32_field(name: &str, nullable: bool) -> Field {
        Field::new(name, DataType::Int32, nullable)
    }

    #[test]
    fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> {
        let static_term = empty_exec(vec![int32_field("value", false)]);
        let recursive_term = empty_exec(vec![int32_field("value + Int32(1)", false)]);

        let plan = RecursiveQueryExec::try_new(
            "numbers".to_string(),
            Arc::clone(&static_term),
            Arc::clone(&recursive_term),
            false,
        )?;

        // Nullability already agrees, so the output schema is the static one.
        assert_eq!(plan.schema(), static_term.schema());

        // The recursive column has a different name, so the recursive term is
        // wrapped in a projection that renames it to the static term's name.
        let projection = plan
            .recursive_term()
            .downcast_ref::<ProjectionExec>()
            .expect("recursive term should be aligned with ProjectionExec");
        assert!(Arc::ptr_eq(projection.input(), &recursive_term));
        assert!(!projection.schema().field(0).is_nullable());
        assert_eq!(projection.expr()[0].alias, "value");
        Ok(())
    }

    #[test]
    fn recursive_query_exec_reconciles_nullability() -> Result<()> {
        let static_term = empty_exec(vec![int32_field("value", false)]);
        let recursive_term = empty_exec(vec![int32_field("value + Int32(1)", true)]);

        let plan = RecursiveQueryExec::try_new(
            "numbers".to_string(),
            static_term,
            recursive_term,
            false,
        )?;

        // A column nullable in either term is nullable in the output and in
        // both projected terms.
        assert!(plan.schema().field(0).is_nullable());
        assert!(plan.static_term().schema().field(0).is_nullable());
        assert!(plan.recursive_term().schema().field(0).is_nullable());
        Ok(())
    }
}