datafusion_physical_plan/
recursive_query.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the recursive query plan
19
20use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::{
27    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
28    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
29};
30use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
35use datafusion_common::{internal_datafusion_err, not_impl_err, Result};
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
39
40use futures::{ready, Stream, StreamExt};
41
42/// Recursive query execution plan.
43///
44/// This plan has two components: a base part (the static term) and
45/// a dynamic part (the recursive term). The execution will start from
46/// the base, and as long as the previous iteration produced at least
47/// a single new row (taking care of the distinction) the recursive
48/// part will be continuously executed.
49///
50/// Before each execution of the dynamic part, the rows from the previous
51/// iteration will be available in a "working table" (not a real table,
52/// can be only accessed using a continuance operation).
53///
54/// Note that there won't be any limit or checks applied to detect
55/// an infinite recursion, so it is up to the planner to ensure that
56/// it won't happen.
57#[derive(Debug, Clone)]
58pub struct RecursiveQueryExec {
59    /// Name of the query handler
60    name: String,
61    /// The working table of cte
62    work_table: Arc<WorkTable>,
63    /// The base part (static term)
64    static_term: Arc<dyn ExecutionPlan>,
65    /// The dynamic part (recursive term)
66    recursive_term: Arc<dyn ExecutionPlan>,
67    /// Distinction
68    is_distinct: bool,
69    /// Execution metrics
70    metrics: ExecutionPlanMetricsSet,
71    /// Cache holding plan properties like equivalences, output partitioning etc.
72    cache: PlanProperties,
73}
74
75impl RecursiveQueryExec {
76    /// Create a new RecursiveQueryExec
77    pub fn try_new(
78        name: String,
79        static_term: Arc<dyn ExecutionPlan>,
80        recursive_term: Arc<dyn ExecutionPlan>,
81        is_distinct: bool,
82    ) -> Result<Self> {
83        // Each recursive query needs its own work table
84        let work_table = Arc::new(WorkTable::new());
85        // Use the same work table for both the WorkTableExec and the recursive term
86        let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
87        let cache = Self::compute_properties(static_term.schema());
88        Ok(RecursiveQueryExec {
89            name,
90            static_term,
91            recursive_term,
92            is_distinct,
93            work_table,
94            metrics: ExecutionPlanMetricsSet::new(),
95            cache,
96        })
97    }
98
99    /// Ref to name
100    pub fn name(&self) -> &str {
101        &self.name
102    }
103
104    /// Ref to static term
105    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
106        &self.static_term
107    }
108
109    /// Ref to recursive term
110    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
111        &self.recursive_term
112    }
113
114    /// is distinct
115    pub fn is_distinct(&self) -> bool {
116        self.is_distinct
117    }
118
119    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
120    fn compute_properties(schema: SchemaRef) -> PlanProperties {
121        let eq_properties = EquivalenceProperties::new(schema);
122
123        PlanProperties::new(
124            eq_properties,
125            Partitioning::UnknownPartitioning(1),
126            EmissionType::Incremental,
127            Boundedness::Bounded,
128        )
129    }
130}
131
132impl ExecutionPlan for RecursiveQueryExec {
133    fn name(&self) -> &'static str {
134        "RecursiveQueryExec"
135    }
136
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140
141    fn properties(&self) -> &PlanProperties {
142        &self.cache
143    }
144
145    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
146        vec![&self.static_term, &self.recursive_term]
147    }
148
149    // TODO: control these hints and see whether we can
150    // infer some from the child plans (static/recursive terms).
151    fn maintains_input_order(&self) -> Vec<bool> {
152        vec![false, false]
153    }
154
155    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
156        vec![false, false]
157    }
158
159    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
160        vec![
161            crate::Distribution::SinglePartition,
162            crate::Distribution::SinglePartition,
163        ]
164    }
165
166    fn with_new_children(
167        self: Arc<Self>,
168        children: Vec<Arc<dyn ExecutionPlan>>,
169    ) -> Result<Arc<dyn ExecutionPlan>> {
170        RecursiveQueryExec::try_new(
171            self.name.clone(),
172            Arc::clone(&children[0]),
173            Arc::clone(&children[1]),
174            self.is_distinct,
175        )
176        .map(|e| Arc::new(e) as _)
177    }
178
179    fn execute(
180        &self,
181        partition: usize,
182        context: Arc<TaskContext>,
183    ) -> Result<SendableRecordBatchStream> {
184        // TODO: we might be able to handle multiple partitions in the future.
185        if partition != 0 {
186            return Err(internal_datafusion_err!(
187                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
188            ));
189        }
190
191        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
192        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
193        Ok(Box::pin(RecursiveQueryStream::new(
194            context,
195            Arc::clone(&self.work_table),
196            Arc::clone(&self.recursive_term),
197            static_stream,
198            baseline_metrics,
199        )))
200    }
201
202    fn metrics(&self) -> Option<MetricsSet> {
203        Some(self.metrics.clone_inner())
204    }
205
206    fn statistics(&self) -> Result<Statistics> {
207        Ok(Statistics::new_unknown(&self.schema()))
208    }
209}
210
211impl DisplayAs for RecursiveQueryExec {
212    fn fmt_as(
213        &self,
214        t: DisplayFormatType,
215        f: &mut std::fmt::Formatter,
216    ) -> std::fmt::Result {
217        match t {
218            DisplayFormatType::Default | DisplayFormatType::Verbose => {
219                write!(
220                    f,
221                    "RecursiveQueryExec: name={}, is_distinct={}",
222                    self.name, self.is_distinct
223                )
224            }
225            DisplayFormatType::TreeRender => {
226                // TODO: collect info
227                write!(f, "")
228            }
229        }
230    }
231}
232
233/// The actual logic of the recursive queries happens during the streaming
234/// process. A simplified version of the algorithm is the following:
235///
236/// buffer = []
237///
238/// while batch := static_stream.next():
239///    buffer.push(batch)
240///    yield buffer
241///
242/// while buffer.len() > 0:
243///    sender, receiver = Channel()
244///    register_continuation(handle_name, receiver)
245///    sender.send(buffer.drain())
246///    recursive_stream = recursive_term.execute()
247///    while batch := recursive_stream.next():
248///        buffer.append(batch)
249///        yield buffer
250struct RecursiveQueryStream {
251    /// The context to be used for managing handlers & executing new tasks
252    task_context: Arc<TaskContext>,
253    /// The working table state, representing the self referencing cte table
254    work_table: Arc<WorkTable>,
255    /// The dynamic part (recursive term) as is (without being executed)
256    recursive_term: Arc<dyn ExecutionPlan>,
257    /// The static part (static term) as a stream. If the processing of this
258    /// part is completed, then it will be None.
259    static_stream: Option<SendableRecordBatchStream>,
260    /// The dynamic part (recursive term) as a stream. If the processing of this
261    /// part has not started yet, or has been completed, then it will be None.
262    recursive_stream: Option<SendableRecordBatchStream>,
263    /// The schema of the output.
264    schema: SchemaRef,
265    /// In-memory buffer for storing a copy of the current results. Will be
266    /// cleared after each iteration.
267    buffer: Vec<RecordBatch>,
268    /// Tracks the memory used by the buffer
269    reservation: MemoryReservation,
270    // /// Metrics.
271    _baseline_metrics: BaselineMetrics,
272}
273
274impl RecursiveQueryStream {
275    /// Create a new recursive query stream
276    fn new(
277        task_context: Arc<TaskContext>,
278        work_table: Arc<WorkTable>,
279        recursive_term: Arc<dyn ExecutionPlan>,
280        static_stream: SendableRecordBatchStream,
281        baseline_metrics: BaselineMetrics,
282    ) -> Self {
283        let schema = static_stream.schema();
284        let reservation =
285            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
286        Self {
287            task_context,
288            work_table,
289            recursive_term,
290            static_stream: Some(static_stream),
291            recursive_stream: None,
292            schema,
293            buffer: vec![],
294            reservation,
295            _baseline_metrics: baseline_metrics,
296        }
297    }
298
299    /// Push a clone of the given batch to the in memory buffer, and then return
300    /// a poll with it.
301    fn push_batch(
302        mut self: std::pin::Pin<&mut Self>,
303        batch: RecordBatch,
304    ) -> Poll<Option<Result<RecordBatch>>> {
305        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
306            return Poll::Ready(Some(Err(e)));
307        }
308
309        self.buffer.push(batch.clone());
310        Poll::Ready(Some(Ok(batch)))
311    }
312
313    /// Start polling for the next iteration, will be called either after the static term
314    /// is completed or another term is completed. It will follow the algorithm above on
315    /// to check whether the recursion has ended.
316    fn poll_next_iteration(
317        mut self: std::pin::Pin<&mut Self>,
318        cx: &mut Context<'_>,
319    ) -> Poll<Option<Result<RecordBatch>>> {
320        let total_length = self
321            .buffer
322            .iter()
323            .fold(0, |acc, batch| acc + batch.num_rows());
324
325        if total_length == 0 {
326            return Poll::Ready(None);
327        }
328
329        // Update the work table with the current buffer
330        let reserved_batches = ReservedBatches::new(
331            std::mem::take(&mut self.buffer),
332            self.reservation.take(),
333        );
334        self.work_table.update(reserved_batches);
335
336        // We always execute (and re-execute iteratively) the first partition.
337        // Downstream plans should not expect any partitioning.
338        let partition = 0;
339
340        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
341        self.recursive_stream =
342            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
343        self.poll_next(cx)
344    }
345}
346
347fn assign_work_table(
348    plan: Arc<dyn ExecutionPlan>,
349    work_table: Arc<WorkTable>,
350) -> Result<Arc<dyn ExecutionPlan>> {
351    let mut work_table_refs = 0;
352    plan.transform_down(|plan| {
353        if let Some(new_plan) =
354            plan.with_new_state(Arc::clone(&work_table) as Arc<dyn Any + Send + Sync>)
355        {
356            if work_table_refs > 0 {
357                not_impl_err!(
358                    "Multiple recursive references to the same CTE are not supported"
359                )
360            } else {
361                work_table_refs += 1;
362                Ok(Transformed::yes(new_plan))
363            }
364        } else if plan.as_any().is::<RecursiveQueryExec>() {
365            not_impl_err!("Recursive queries cannot be nested")
366        } else {
367            Ok(Transformed::no(plan))
368        }
369    })
370    .data()
371}
372
373/// Some plans will change their internal states after execution, making them unable to be executed again.
374/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan.
375///
376/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
377/// However, if the data of the left table is derived from the work table, it will become outdated
378/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
379fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
380    plan.transform_up(|plan| {
381        // WorkTableExec's states have already been updated correctly.
382        if plan.as_any().is::<WorkTableExec>() {
383            Ok(Transformed::no(plan))
384        } else {
385            let new_plan = Arc::clone(&plan).reset_state()?;
386            Ok(Transformed::yes(new_plan))
387        }
388    })
389    .data()
390}
391
392impl Stream for RecursiveQueryStream {
393    type Item = Result<RecordBatch>;
394
395    fn poll_next(
396        mut self: std::pin::Pin<&mut Self>,
397        cx: &mut Context<'_>,
398    ) -> Poll<Option<Self::Item>> {
399        // TODO: we should use this poll to record some metrics!
400        if let Some(static_stream) = &mut self.static_stream {
401            // While the static term's stream is available, we'll be forwarding the batches from it (also
402            // saving them for the initial iteration of the recursive term).
403            let batch_result = ready!(static_stream.poll_next_unpin(cx));
404            match &batch_result {
405                None => {
406                    // Once this is done, we can start running the setup for the recursive term.
407                    self.static_stream = None;
408                    self.poll_next_iteration(cx)
409                }
410                Some(Ok(batch)) => self.push_batch(batch.clone()),
411                _ => Poll::Ready(batch_result),
412            }
413        } else if let Some(recursive_stream) = &mut self.recursive_stream {
414            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
415            match batch_result {
416                None => {
417                    self.recursive_stream = None;
418                    self.poll_next_iteration(cx)
419                }
420                Some(Ok(batch)) => self.push_batch(batch),
421                _ => Poll::Ready(batch_result),
422            }
423        } else {
424            Poll::Ready(None)
425        }
426    }
427}
428
429impl RecordBatchStream for RecursiveQueryStream {
430    /// Get the schema
431    fn schema(&self) -> SchemaRef {
432        Arc::clone(&self.schema)
433    }
434}
435
436#[cfg(test)]
437mod tests {}