// datafusion_physical_plan/recursive_query.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18//! Defines the recursive query plan
19
20use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::execution_plan::{Boundedness, EmissionType};
28use crate::metrics::{
29    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
30};
31use crate::{
32    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
33    SendableRecordBatchStream, Statistics,
34};
35use arrow::array::{BooleanArray, BooleanBuilder};
36use arrow::compute::filter_record_batch;
37use arrow::datatypes::SchemaRef;
38use arrow::record_batch::RecordBatch;
39use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
40use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
41use datafusion_execution::TaskContext;
42use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
43use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
44
45use futures::{Stream, StreamExt, ready};
46
/// Recursive query execution plan.
///
/// This plan has two components: a base part (the static term) and
/// a dynamic part (the recursive term). The execution will start from
/// the base, and as long as the previous iteration produced at least
/// a single new row (taking care of the distinction) the recursive
/// part will be continuously executed.
///
/// Before each execution of the dynamic part, the rows from the previous
/// iteration will be available in a "working table" (not a real table,
/// can be only accessed using a continuance operation).
///
/// Note that there won't be any limit or checks applied to detect
/// an infinite recursion, so it is up to the planner to ensure that
/// it won't happen.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the query handler
    name: String,
    /// The working table of the CTE, shared with the recursive term's
    /// work-table scan (see `assign_work_table`); holds the rows produced
    /// by the previous iteration
    work_table: Arc<WorkTable>,
    /// The base part (static term)
    static_term: Arc<dyn ExecutionPlan>,
    /// The dynamic part (recursive term)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// If true, duplicate rows are removed during streaming via
    /// `DistinctDeduplicator`
    is_distinct: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
79
80impl RecursiveQueryExec {
81    /// Create a new RecursiveQueryExec
82    pub fn try_new(
83        name: String,
84        static_term: Arc<dyn ExecutionPlan>,
85        recursive_term: Arc<dyn ExecutionPlan>,
86        is_distinct: bool,
87    ) -> Result<Self> {
88        // Each recursive query needs its own work table
89        let work_table = Arc::new(WorkTable::new(name.clone()));
90        // Use the same work table for both the WorkTableExec and the recursive term
91        let recursive_term = assign_work_table(recursive_term, &work_table)?;
92        let cache = Self::compute_properties(static_term.schema());
93        Ok(RecursiveQueryExec {
94            name,
95            static_term,
96            recursive_term,
97            is_distinct,
98            work_table,
99            metrics: ExecutionPlanMetricsSet::new(),
100            cache,
101        })
102    }
103
104    /// Ref to name
105    pub fn name(&self) -> &str {
106        &self.name
107    }
108
109    /// Ref to static term
110    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
111        &self.static_term
112    }
113
114    /// Ref to recursive term
115    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
116        &self.recursive_term
117    }
118
119    /// is distinct
120    pub fn is_distinct(&self) -> bool {
121        self.is_distinct
122    }
123
124    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
125    fn compute_properties(schema: SchemaRef) -> PlanProperties {
126        let eq_properties = EquivalenceProperties::new(schema);
127
128        PlanProperties::new(
129            eq_properties,
130            Partitioning::UnknownPartitioning(1),
131            EmissionType::Incremental,
132            Boundedness::Bounded,
133        )
134    }
135}
136
137impl ExecutionPlan for RecursiveQueryExec {
138    fn name(&self) -> &'static str {
139        "RecursiveQueryExec"
140    }
141
142    fn as_any(&self) -> &dyn Any {
143        self
144    }
145
146    fn properties(&self) -> &PlanProperties {
147        &self.cache
148    }
149
150    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151        vec![&self.static_term, &self.recursive_term]
152    }
153
154    // TODO: control these hints and see whether we can
155    // infer some from the child plans (static/recursive terms).
156    fn maintains_input_order(&self) -> Vec<bool> {
157        vec![false, false]
158    }
159
160    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
161        vec![false, false]
162    }
163
164    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
165        vec![
166            crate::Distribution::SinglePartition,
167            crate::Distribution::SinglePartition,
168        ]
169    }
170
171    fn with_new_children(
172        self: Arc<Self>,
173        children: Vec<Arc<dyn ExecutionPlan>>,
174    ) -> Result<Arc<dyn ExecutionPlan>> {
175        RecursiveQueryExec::try_new(
176            self.name.clone(),
177            Arc::clone(&children[0]),
178            Arc::clone(&children[1]),
179            self.is_distinct,
180        )
181        .map(|e| Arc::new(e) as _)
182    }
183
184    fn execute(
185        &self,
186        partition: usize,
187        context: Arc<TaskContext>,
188    ) -> Result<SendableRecordBatchStream> {
189        // TODO: we might be able to handle multiple partitions in the future.
190        if partition != 0 {
191            return Err(internal_datafusion_err!(
192                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
193            ));
194        }
195
196        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
197        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
198        Ok(Box::pin(RecursiveQueryStream::new(
199            context,
200            Arc::clone(&self.work_table),
201            Arc::clone(&self.recursive_term),
202            static_stream,
203            self.is_distinct,
204            baseline_metrics,
205        )?))
206    }
207
208    fn metrics(&self) -> Option<MetricsSet> {
209        Some(self.metrics.clone_inner())
210    }
211
212    fn statistics(&self) -> Result<Statistics> {
213        Ok(Statistics::new_unknown(&self.schema()))
214    }
215}
216
217impl DisplayAs for RecursiveQueryExec {
218    fn fmt_as(
219        &self,
220        t: DisplayFormatType,
221        f: &mut std::fmt::Formatter,
222    ) -> std::fmt::Result {
223        match t {
224            DisplayFormatType::Default | DisplayFormatType::Verbose => {
225                write!(
226                    f,
227                    "RecursiveQueryExec: name={}, is_distinct={}",
228                    self.name, self.is_distinct
229                )
230            }
231            DisplayFormatType::TreeRender => {
232                // TODO: collect info
233                write!(f, "")
234            }
235        }
236    }
237}
238
/// The actual logic of the recursive queries happens during the streaming
/// process. A simplified version of the algorithm is the following:
///
/// buffer = []
///
/// while batch := static_stream.next():
///    buffer.push(batch)
///    yield buffer
///
/// while buffer.len() > 0:
///    sender, receiver = Channel()
///    register_continuation(handle_name, receiver)
///    sender.send(buffer.drain())
///    recursive_stream = recursive_term.execute()
///    while batch := recursive_stream.next():
///        buffer.append(batch)
///        yield buffer
struct RecursiveQueryStream {
    /// The context to be used for managing handlers & executing new tasks
    task_context: Arc<TaskContext>,
    /// The working table state, representing the self referencing cte table
    work_table: Arc<WorkTable>,
    /// The dynamic part (recursive term) as is (without being executed)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static part (static term) as a stream. If the processing of this
    /// part is completed, then it will be None.
    static_stream: Option<SendableRecordBatchStream>,
    /// The dynamic part (recursive term) as a stream. If the processing of this
    /// part has not started yet, or has been completed, then it will be None.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// The schema of the output (taken from the static term's stream).
    schema: SchemaRef,
    /// In-memory buffer for storing a copy of the current results. Will be
    /// cleared (moved into the work table) after each iteration.
    buffer: Vec<RecordBatch>,
    /// Tracks the memory used by the buffer; transferred to the work table
    /// together with the buffered batches at the end of each iteration.
    reservation: MemoryReservation,
    /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Metrics.
    baseline_metrics: BaselineMetrics,
}
281
282impl RecursiveQueryStream {
283    /// Create a new recursive query stream
284    fn new(
285        task_context: Arc<TaskContext>,
286        work_table: Arc<WorkTable>,
287        recursive_term: Arc<dyn ExecutionPlan>,
288        static_stream: SendableRecordBatchStream,
289        is_distinct: bool,
290        baseline_metrics: BaselineMetrics,
291    ) -> Result<Self> {
292        let schema = static_stream.schema();
293        let reservation =
294            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
295        let distinct_deduplicator = is_distinct
296            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
297            .transpose()?;
298        Ok(Self {
299            task_context,
300            work_table,
301            recursive_term,
302            static_stream: Some(static_stream),
303            recursive_stream: None,
304            schema,
305            buffer: vec![],
306            reservation,
307            distinct_deduplicator,
308            baseline_metrics,
309        })
310    }
311
312    /// Push a clone of the given batch to the in memory buffer, and then return
313    /// a poll with it.
314    fn push_batch(
315        mut self: std::pin::Pin<&mut Self>,
316        mut batch: RecordBatch,
317    ) -> Poll<Option<Result<RecordBatch>>> {
318        let baseline_metrics = self.baseline_metrics.clone();
319        if let Some(deduplicator) = &mut self.distinct_deduplicator {
320            let _timer_guard = baseline_metrics.elapsed_compute().timer();
321            batch = deduplicator.deduplicate(&batch)?;
322        }
323
324        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
325            return Poll::Ready(Some(Err(e)));
326        }
327        self.buffer.push(batch.clone());
328        (&batch).record_output(&baseline_metrics);
329        Poll::Ready(Some(Ok(batch)))
330    }
331
332    /// Start polling for the next iteration, will be called either after the static term
333    /// is completed or another term is completed. It will follow the algorithm above on
334    /// to check whether the recursion has ended.
335    fn poll_next_iteration(
336        mut self: std::pin::Pin<&mut Self>,
337        cx: &mut Context<'_>,
338    ) -> Poll<Option<Result<RecordBatch>>> {
339        let total_length = self
340            .buffer
341            .iter()
342            .fold(0, |acc, batch| acc + batch.num_rows());
343
344        if total_length == 0 {
345            return Poll::Ready(None);
346        }
347
348        // Update the work table with the current buffer
349        let reserved_batches = ReservedBatches::new(
350            std::mem::take(&mut self.buffer),
351            self.reservation.take(),
352        );
353        self.work_table.update(reserved_batches);
354
355        // We always execute (and re-execute iteratively) the first partition.
356        // Downstream plans should not expect any partitioning.
357        let partition = 0;
358
359        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
360        self.recursive_stream =
361            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
362        self.poll_next(cx)
363    }
364}
365
366fn assign_work_table(
367    plan: Arc<dyn ExecutionPlan>,
368    work_table: &Arc<WorkTable>,
369) -> Result<Arc<dyn ExecutionPlan>> {
370    let mut work_table_refs = 0;
371    plan.transform_down(|plan| {
372        if let Some(new_plan) =
373            plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
374        {
375            if work_table_refs > 0 {
376                not_impl_err!(
377                    "Multiple recursive references to the same CTE are not supported"
378                )
379            } else {
380                work_table_refs += 1;
381                Ok(Transformed::yes(new_plan))
382            }
383        } else {
384            Ok(Transformed::no(plan))
385        }
386    })
387    .data()
388}
389
390/// Some plans will change their internal states after execution, making them unable to be executed again.
391/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan.
392///
393/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
394/// However, if the data of the left table is derived from the work table, it will become outdated
395/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
396fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
397    plan.transform_up(|plan| {
398        let new_plan = Arc::clone(&plan).reset_state()?;
399        Ok(Transformed::yes(new_plan))
400    })
401    .data()
402}
403
404impl Stream for RecursiveQueryStream {
405    type Item = Result<RecordBatch>;
406
407    fn poll_next(
408        mut self: std::pin::Pin<&mut Self>,
409        cx: &mut Context<'_>,
410    ) -> Poll<Option<Self::Item>> {
411        if let Some(static_stream) = &mut self.static_stream {
412            // While the static term's stream is available, we'll be forwarding the batches from it (also
413            // saving them for the initial iteration of the recursive term).
414            let batch_result = ready!(static_stream.poll_next_unpin(cx));
415            match &batch_result {
416                None => {
417                    // Once this is done, we can start running the setup for the recursive term.
418                    self.static_stream = None;
419                    self.poll_next_iteration(cx)
420                }
421                Some(Ok(batch)) => self.push_batch(batch.clone()),
422                _ => Poll::Ready(batch_result),
423            }
424        } else if let Some(recursive_stream) = &mut self.recursive_stream {
425            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
426            match batch_result {
427                None => {
428                    self.recursive_stream = None;
429                    self.poll_next_iteration(cx)
430                }
431                Some(Ok(batch)) => self.push_batch(batch),
432                _ => Poll::Ready(batch_result),
433            }
434        } else {
435            Poll::Ready(None)
436        }
437    }
438}
439
impl RecordBatchStream for RecursiveQueryStream {
    /// Get the schema of the batches this stream yields (the static term's schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
446
/// Deduplicator based on a hash table.
struct DistinctDeduplicator {
    /// Grouped rows used for distinct; each distinct row gets a monotonically
    /// increasing group id (see `deduplicate`)
    group_values: Box<dyn GroupValues>,
    /// Tracks the memory used by the hash table
    reservation: MemoryReservation,
    /// Scratch buffer for group ids, reused across `deduplicate` calls to
    /// avoid reallocating per batch
    intern_output_buffer: Vec<usize>,
}
454
455impl DistinctDeduplicator {
456    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
457        let group_values = new_group_values(schema, &GroupOrdering::None)?;
458        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
459            .register(task_context.memory_pool());
460        Ok(Self {
461            group_values,
462            reservation,
463            intern_output_buffer: Vec::new(),
464        })
465    }
466
467    /// Remove duplicated rows from the given batch, keeping a state between batches.
468    ///
469    /// We use a hash table to allocate new group ids for the new rows.
470    /// [`GroupValues`] allocate increasing group ids.
471    /// Hence, if groups (i.e., rows) are new, then they have ids >= length before interning, we keep them.
472    /// We also detect duplicates by enforcing that group ids are increasing.
473    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
474        let size_before = self.group_values.len();
475        self.intern_output_buffer.reserve(batch.num_rows());
476        self.group_values
477            .intern(batch.columns(), &mut self.intern_output_buffer)?;
478        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
479        self.intern_output_buffer.clear();
480        // We update the reservation to reflect the new size of the hash table.
481        self.reservation.try_resize(self.group_values.size())?;
482        Ok(filter_record_batch(batch, &mask)?)
483    }
484}
485
486/// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater or equal than the provided max_already_seen_group_id
487fn new_groups_mask(
488    values: &[usize],
489    mut max_already_seen_group_id: usize,
490) -> BooleanArray {
491    let mut output = BooleanBuilder::with_capacity(values.len());
492    for value in values {
493        if *value >= max_already_seen_group_id {
494            output.append_value(true);
495            max_already_seen_group_id = *value + 1; // We want to be increasing
496        } else {
497            output.append_value(false);
498        }
499    }
500    output.finish()
501}
502
#[cfg(test)]
// NOTE(review): no unit tests in this module — coverage presumably lives in
// integration tests elsewhere in the crate; TODO confirm.
mod tests {}