datafusion_physical_plan/recursive_query.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the recursive query plan

use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::execution_plan::{Boundedness, EmissionType};
use crate::{
    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};

use futures::{ready, Stream, StreamExt};

/// Recursive query execution plan.
///
/// This plan has two components: a base part (the static term) and
/// a dynamic part (the recursive term). Execution starts with the base
/// part, and as long as the previous iteration produced at least one
/// new row (subject to the distinct handling), the recursive part is
/// executed again.
///
/// Before each execution of the dynamic part, the rows produced by the
/// previous iteration are made available in a "work table" (not a real
/// table; it can only be accessed through the corresponding
/// [`WorkTableExec`] operator).
///
/// Note that no limit or check is applied to detect infinite recursion,
/// so it is up to the planner to ensure that it cannot happen.
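///
/// As an illustrative sketch (names and the literal bound are arbitrary, not
/// taken from any particular test), a recursive CTE such as the following
/// would be planned with a [`RecursiveQueryExec`] whose static term is the
/// `SELECT 1` branch and whose recursive term is the branch that references
/// `nodes` itself:
///
/// ```sql
/// WITH RECURSIVE nodes(n) AS (
///     SELECT 1                                -- static term
///     UNION ALL
///     SELECT n + 1 FROM nodes WHERE n < 10    -- recursive term
/// )
/// SELECT * FROM nodes;
/// ```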
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the query handler
    name: String,
    /// The work table of the recursive CTE
    work_table: Arc<WorkTable>,
    /// The base part (static term)
    static_term: Arc<dyn ExecutionPlan>,
    /// The dynamic part (recursive term)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Whether duplicate rows are eliminated (`UNION` vs. `UNION ALL`)
    is_distinct: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl RecursiveQueryExec {
    /// Create a new RecursiveQueryExec
    pub fn try_new(
        name: String,
        static_term: Arc<dyn ExecutionPlan>,
        recursive_term: Arc<dyn ExecutionPlan>,
        is_distinct: bool,
    ) -> Result<Self> {
        // Each recursive query needs its own work table
        let work_table = Arc::new(WorkTable::new());
        // Use the same work table for both the WorkTableExec and the recursive term
        let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
        let cache = Self::compute_properties(static_term.schema());
        Ok(RecursiveQueryExec {
            name,
            static_term,
            recursive_term,
            is_distinct,
            work_table,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// Return a reference to the name of the recursive query
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Return a reference to the static term (base case)
    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
        &self.static_term
    }

    /// Return a reference to the recursive term
    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
        &self.recursive_term
    }

    /// Return whether duplicate rows are eliminated
    pub fn is_distinct(&self) -> bool {
        self.is_distinct
    }

    /// This function creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        let eq_properties = EquivalenceProperties::new(schema);

        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
    }
}

impl ExecutionPlan for RecursiveQueryExec {
    fn name(&self) -> &'static str {
        "RecursiveQueryExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.static_term, &self.recursive_term]
    }

    // TODO: control these hints and see whether we can
    // infer some from the child plans (static/recursive terms).
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![false, false]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false, false]
    }

    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
        vec![
            crate::Distribution::SinglePartition,
            crate::Distribution::SinglePartition,
        ]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        RecursiveQueryExec::try_new(
            self.name.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.is_distinct,
        )
        .map(|e| Arc::new(e) as _)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // TODO: we might be able to handle multiple partitions in the future.
        if partition != 0 {
            return Err(DataFusionError::Internal(format!(
                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
            )));
        }

        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        Ok(Box::pin(RecursiveQueryStream::new(
            context,
            Arc::clone(&self.work_table),
            Arc::clone(&self.recursive_term),
            static_stream,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }
}

impl DisplayAs for RecursiveQueryExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "RecursiveQueryExec: name={}, is_distinct={}",
                    self.name, self.is_distinct
                )
            }
            DisplayFormatType::TreeRender => {
                // TODO: collect info
                write!(f, "")
            }
        }
    }
}

/// The actual logic of the recursive query happens during the streaming
/// process. A simplified version of the algorithm (mirroring the work-table
/// mechanism used below) is the following:
///
/// buffer = []
///
/// while batch := static_stream.next():
///     buffer.push(batch)
///     yield batch
///
/// while buffer.len() > 0:
///     work_table.update(buffer.drain())
///     recursive_stream = reset_plan_states(recursive_term).execute()
///     while batch := recursive_stream.next():
///         buffer.push(batch)
///         yield batch
///
struct RecursiveQueryStream {
    /// The context to be used for managing handlers & executing new tasks
    task_context: Arc<TaskContext>,
    /// The work table state, representing the self-referencing CTE table
    work_table: Arc<WorkTable>,
    /// The dynamic part (recursive term) as is (without being executed)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static part (static term) as a stream. If the processing of this
    /// part is completed, then it will be None.
    static_stream: Option<SendableRecordBatchStream>,
    /// The dynamic part (recursive term) as a stream. If the processing of this
    /// part has not started yet, or has been completed, then it will be None.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// The schema of the output.
    schema: SchemaRef,
    /// In-memory buffer for storing a copy of the current results. Will be
    /// cleared after each iteration.
    buffer: Vec<RecordBatch>,
    /// Tracks the memory used by the buffer
    reservation: MemoryReservation,
    /// Metrics (not yet recorded while streaming; see the TODO in `poll_next`).
    _baseline_metrics: BaselineMetrics,
}

impl RecursiveQueryStream {
    /// Create a new recursive query stream
    fn new(
        task_context: Arc<TaskContext>,
        work_table: Arc<WorkTable>,
        recursive_term: Arc<dyn ExecutionPlan>,
        static_stream: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = static_stream.schema();
        let reservation =
            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
        Self {
            task_context,
            work_table,
            recursive_term,
            static_stream: Some(static_stream),
            recursive_stream: None,
            schema,
            buffer: vec![],
            reservation,
            _baseline_metrics: baseline_metrics,
        }
    }

    /// Push a clone of the given batch to the in-memory buffer, and then return
    /// a poll with it.
    fn push_batch(
        mut self: std::pin::Pin<&mut Self>,
        batch: RecordBatch,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
            return Poll::Ready(Some(Err(e)));
        }

        self.buffer.push(batch.clone());
        Poll::Ready(Some(Ok(batch)))
    }

    /// Start polling for the next iteration. This is called either after the static
    /// term completes or after an iteration of the recursive term completes. It
    /// follows the algorithm above to check whether the recursion has ended.
    fn poll_next_iteration(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let total_length = self
            .buffer
            .iter()
            .fold(0, |acc, batch| acc + batch.num_rows());

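        // If the previous iteration produced no new rows, a fixpoint has been
        // reached and the recursion terminates.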
        if total_length == 0 {
            return Poll::Ready(None);
        }

        // Update the work table with the current buffer
        let reserved_batches = ReservedBatches::new(
            std::mem::take(&mut self.buffer),
            self.reservation.take(),
        );
        self.work_table.update(reserved_batches);

        // We always execute (and re-execute iteratively) the first partition.
        // Downstream plans should not expect any partitioning.
        let partition = 0;

        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
        self.recursive_stream =
            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
        self.poll_next(cx)
    }
}

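/// Assign `work_table` to the [`WorkTableExec`] in `plan` (if any), so that the
/// recursive term reads from the same work table that is updated after each
/// iteration.
///
/// Returns a `NotImplemented` error if the plan contains more than one recursive
/// reference to the CTE, or a nested recursive query.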
fn assign_work_table(
    plan: Arc<dyn ExecutionPlan>,
    work_table: Arc<WorkTable>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let mut work_table_refs = 0;
    plan.transform_down(|plan| {
        if let Some(exec) = plan.as_any().downcast_ref::<WorkTableExec>() {
            if work_table_refs > 0 {
                not_impl_err!(
                    "Multiple recursive references to the same CTE are not supported"
                )
            } else {
                work_table_refs += 1;
                Ok(Transformed::yes(Arc::new(
                    exec.with_work_table(Arc::clone(&work_table)),
                )))
            }
        } else if plan.as_any().is::<RecursiveQueryExec>() {
            not_impl_err!("Recursive queries cannot be nested")
        } else {
            Ok(Transformed::no(plan))
        }
    })
    .data()
}

/// Some plans will change their internal states after execution, making them unable to be executed again.
/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states.
///
/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
/// However, if the data of the left table is derived from the work table, it will become outdated
/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    plan.transform_up(|plan| {
        // WorkTableExec's states have already been updated correctly.
        if plan.as_any().is::<WorkTableExec>() {
            Ok(Transformed::no(plan))
        } else {
            let new_plan = Arc::clone(&plan)
                .with_new_children(plan.children().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        }
    })
    .data()
}

impl Stream for RecursiveQueryStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // TODO: we should use this poll to record some metrics!
        if let Some(static_stream) = &mut self.static_stream {
            // While the static term's stream is available, we'll be forwarding the batches from it (also
            // saving them for the initial iteration of the recursive term).
            let batch_result = ready!(static_stream.poll_next_unpin(cx));
            match &batch_result {
                None => {
                    // Once this is done, we can start running the setup for the recursive term.
                    self.static_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch.clone()),
                _ => Poll::Ready(batch_result),
            }
        } else if let Some(recursive_stream) = &mut self.recursive_stream {
            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
            match batch_result {
                None => {
                    self.recursive_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch),
                _ => Poll::Ready(batch_result),
            }
        } else {
            Poll::Ready(None)
        }
    }
}

impl RecordBatchStream for RecursiveQueryStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {}