// datafusion_physical_plan/recursive_query.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the recursive query plan
19
20use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable};
25use crate::aggregates::group_values::{GroupValues, new_group_values};
26use crate::aggregates::order::GroupOrdering;
27use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states};
28use crate::metrics::{
29    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
30};
31use crate::{
32    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
33    SendableRecordBatchStream,
34};
35use arrow::array::{BooleanArray, BooleanBuilder};
36use arrow::compute::filter_record_batch;
37use arrow::datatypes::SchemaRef;
38use arrow::record_batch::RecordBatch;
39use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
40use datafusion_common::{Result, internal_datafusion_err, not_impl_err};
41use datafusion_execution::TaskContext;
42use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
43use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
44
45use futures::{Stream, StreamExt, ready};
46
/// Recursive query execution plan.
///
/// This plan has two components: a base part (the static term) and
/// a dynamic part (the recursive term). The execution will start from
/// the base, and as long as the previous iteration produced at least
/// a single new row (taking care of the distinction) the recursive
/// part will be continuously executed.
///
/// Before each execution of the dynamic part, the rows from the previous
/// iteration will be available in a "working table" (not a real table,
/// can be only accessed using a continuance operation).
///
/// Note that there won't be any limit or checks applied to detect
/// an infinite recursion, so it is up to the planner to ensure that
/// it won't happen.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the query handler
    name: String,
    /// The working table of the CTE, shared with the `WorkTableExec` node
    /// inside the recursive term (see [`RecursiveQueryExec::try_new`])
    work_table: Arc<WorkTable>,
    /// The base part (static term)
    static_term: Arc<dyn ExecutionPlan>,
    /// The dynamic part (recursive term)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// If set, duplicate rows are removed from the result and work tables
    /// using a hash table (see `DistinctDeduplicator`)
    is_distinct: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
79
80impl RecursiveQueryExec {
81    /// Create a new RecursiveQueryExec
82    pub fn try_new(
83        name: String,
84        static_term: Arc<dyn ExecutionPlan>,
85        recursive_term: Arc<dyn ExecutionPlan>,
86        is_distinct: bool,
87    ) -> Result<Self> {
88        // Each recursive query needs its own work table
89        let work_table = Arc::new(WorkTable::new(name.clone()));
90        // Use the same work table for both the WorkTableExec and the recursive term
91        let recursive_term = assign_work_table(recursive_term, &work_table)?;
92        let cache = Self::compute_properties(static_term.schema());
93        Ok(RecursiveQueryExec {
94            name,
95            static_term,
96            recursive_term,
97            is_distinct,
98            work_table,
99            metrics: ExecutionPlanMetricsSet::new(),
100            cache: Arc::new(cache),
101        })
102    }
103
104    /// Ref to name
105    pub fn name(&self) -> &str {
106        &self.name
107    }
108
109    /// Ref to static term
110    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
111        &self.static_term
112    }
113
114    /// Ref to recursive term
115    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
116        &self.recursive_term
117    }
118
119    /// is distinct
120    pub fn is_distinct(&self) -> bool {
121        self.is_distinct
122    }
123
124    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
125    fn compute_properties(schema: SchemaRef) -> PlanProperties {
126        let eq_properties = EquivalenceProperties::new(schema);
127
128        PlanProperties::new(
129            eq_properties,
130            Partitioning::UnknownPartitioning(1),
131            EmissionType::Incremental,
132            Boundedness::Bounded,
133        )
134    }
135}
136
137impl ExecutionPlan for RecursiveQueryExec {
138    fn name(&self) -> &'static str {
139        "RecursiveQueryExec"
140    }
141
142    fn as_any(&self) -> &dyn Any {
143        self
144    }
145
146    fn properties(&self) -> &Arc<PlanProperties> {
147        &self.cache
148    }
149
150    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151        vec![&self.static_term, &self.recursive_term]
152    }
153
154    // TODO: control these hints and see whether we can
155    // infer some from the child plans (static/recursive terms).
156    fn maintains_input_order(&self) -> Vec<bool> {
157        vec![false, false]
158    }
159
160    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
161        vec![false, false]
162    }
163
164    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
165        vec![
166            crate::Distribution::SinglePartition,
167            crate::Distribution::SinglePartition,
168        ]
169    }
170
171    fn with_new_children(
172        self: Arc<Self>,
173        children: Vec<Arc<dyn ExecutionPlan>>,
174    ) -> Result<Arc<dyn ExecutionPlan>> {
175        RecursiveQueryExec::try_new(
176            self.name.clone(),
177            Arc::clone(&children[0]),
178            Arc::clone(&children[1]),
179            self.is_distinct,
180        )
181        .map(|e| Arc::new(e) as _)
182    }
183
184    fn execute(
185        &self,
186        partition: usize,
187        context: Arc<TaskContext>,
188    ) -> Result<SendableRecordBatchStream> {
189        // TODO: we might be able to handle multiple partitions in the future.
190        if partition != 0 {
191            return Err(internal_datafusion_err!(
192                "RecursiveQueryExec got an invalid partition {partition} (expected 0)"
193            ));
194        }
195
196        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
197        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
198        Ok(Box::pin(RecursiveQueryStream::new(
199            context,
200            Arc::clone(&self.work_table),
201            Arc::clone(&self.recursive_term),
202            static_stream,
203            self.is_distinct,
204            baseline_metrics,
205        )?))
206    }
207
208    fn metrics(&self) -> Option<MetricsSet> {
209        Some(self.metrics.clone_inner())
210    }
211}
212
213impl DisplayAs for RecursiveQueryExec {
214    fn fmt_as(
215        &self,
216        t: DisplayFormatType,
217        f: &mut std::fmt::Formatter,
218    ) -> std::fmt::Result {
219        match t {
220            DisplayFormatType::Default | DisplayFormatType::Verbose => {
221                write!(
222                    f,
223                    "RecursiveQueryExec: name={}, is_distinct={}",
224                    self.name, self.is_distinct
225                )
226            }
227            DisplayFormatType::TreeRender => {
228                // TODO: collect info
229                write!(f, "")
230            }
231        }
232    }
233}
234
/// The actual logic of the recursive queries happens during the streaming
/// process. A simplified version of the algorithm is the following:
///
/// buffer = []
///
/// while batch := static_stream.next():
///    buffer.push(batch)
///    yield buffer
///
/// while buffer.len() > 0:
///    sender, receiver = Channel()
///    register_continuation(handle_name, receiver)
///    sender.send(buffer.drain())
///    recursive_stream = recursive_term.execute()
///    while batch := recursive_stream.next():
///        buffer.append(batch)
///        yield buffer
struct RecursiveQueryStream {
    /// The context to be used for managing handlers & executing new tasks
    task_context: Arc<TaskContext>,
    /// The working table state, representing the self referencing cte table
    work_table: Arc<WorkTable>,
    /// The dynamic part (recursive term) as is (without being executed)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static part (static term) as a stream. If the processing of this
    /// part is completed, then it will be None.
    static_stream: Option<SendableRecordBatchStream>,
    /// The dynamic part (recursive term) as a stream. If the processing of this
    /// part has not started yet, or has been completed, then it will be None.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// The schema of the output (taken from the static term's stream).
    schema: SchemaRef,
    /// In-memory buffer for storing a copy of the current results. Will be
    /// cleared (moved into the work table) after each iteration.
    buffer: Vec<RecordBatch>,
    /// Tracks the memory used by the buffer; transferred to the work table
    /// together with the buffered batches at the end of each iteration
    reservation: MemoryReservation,
    /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables
    distinct_deduplicator: Option<DistinctDeduplicator>,
    /// Metrics.
    baseline_metrics: BaselineMetrics,
}
277
impl RecursiveQueryStream {
    /// Create a new recursive query stream.
    ///
    /// The deduplicator is only built when `is_distinct` is set; its
    /// construction can fail, hence the `transpose()?`.
    fn new(
        task_context: Arc<TaskContext>,
        work_table: Arc<WorkTable>,
        recursive_term: Arc<dyn ExecutionPlan>,
        static_stream: SendableRecordBatchStream,
        is_distinct: bool,
        baseline_metrics: BaselineMetrics,
    ) -> Result<Self> {
        // The output schema is the static term's schema.
        let schema = static_stream.schema();
        // Reservation tracking the batches buffered for the next iteration.
        let reservation =
            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
        let distinct_deduplicator = is_distinct
            .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
            .transpose()?;
        Ok(Self {
            task_context,
            work_table,
            recursive_term,
            static_stream: Some(static_stream),
            recursive_stream: None,
            schema,
            buffer: vec![],
            reservation,
            distinct_deduplicator,
            baseline_metrics,
        })
    }

    /// Push a clone of the given batch to the in memory buffer, and then return
    /// a poll with it.
    ///
    /// When distinct mode is on, the batch is first filtered down to the rows
    /// not seen before; only the surviving rows are buffered and emitted.
    fn push_batch(
        mut self: std::pin::Pin<&mut Self>,
        mut batch: RecordBatch,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // Clone the metrics handle up front so the mutable borrow of
        // `self.distinct_deduplicator` below does not conflict with it.
        let baseline_metrics = self.baseline_metrics.clone();
        if let Some(deduplicator) = &mut self.distinct_deduplicator {
            // Deduplication time is charged to elapsed_compute.
            let _timer_guard = baseline_metrics.elapsed_compute().timer();
            batch = deduplicator.deduplicate(&batch)?;
        }

        // Reserve memory for the buffered copy before storing it; surface the
        // pool error through the stream if the reservation cannot grow.
        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
            return Poll::Ready(Some(Err(e)));
        }
        self.buffer.push(batch.clone());
        (&batch).record_output(&baseline_metrics);
        Poll::Ready(Some(Ok(batch)))
    }

    /// Start polling for the next iteration, will be called either after the static term
    /// is completed or another term is completed. It will follow the algorithm above on
    /// to check whether the recursion has ended.
    fn poll_next_iteration(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let total_length = self
            .buffer
            .iter()
            .fold(0, |acc, batch| acc + batch.num_rows());

        // The previous iteration produced no rows: the recursion has reached
        // its fixpoint and the stream terminates.
        if total_length == 0 {
            return Poll::Ready(None);
        }

        // Update the work table with the current buffer; ownership of the
        // buffered batches and their memory reservation moves to the table.
        let reserved_batches = ReservedBatches::new(
            std::mem::take(&mut self.buffer),
            self.reservation.take(),
        );
        self.work_table.update(reserved_batches);

        // We always execute (and re-execute iteratively) the first partition.
        // Downstream plans should not expect any partitioning.
        let partition = 0;

        // Reset per-execution state in the recursive term before re-executing it.
        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
        self.recursive_stream =
            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
        self.poll_next(cx)
    }
}
361
362fn assign_work_table(
363    plan: Arc<dyn ExecutionPlan>,
364    work_table: &Arc<WorkTable>,
365) -> Result<Arc<dyn ExecutionPlan>> {
366    let mut work_table_refs = 0;
367    plan.transform_down(|plan| {
368        if let Some(new_plan) =
369            plan.with_new_state(Arc::clone(work_table) as Arc<dyn Any + Send + Sync>)
370        {
371            if work_table_refs > 0 {
372                not_impl_err!(
373                    "Multiple recursive references to the same CTE are not supported"
374                )
375            } else {
376                work_table_refs += 1;
377                Ok(Transformed::yes(new_plan))
378            }
379        } else {
380            Ok(Transformed::no(plan))
381        }
382    })
383    .data()
384}
385
impl Stream for RecursiveQueryStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if let Some(static_stream) = &mut self.static_stream {
            // While the static term's stream is available, we'll be forwarding the batches from it (also
            // saving them for the initial iteration of the recursive term).
            let batch_result = ready!(static_stream.poll_next_unpin(cx));
            match &batch_result {
                None => {
                    // Once this is done, we can start running the setup for the recursive term.
                    self.static_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch.clone()),
                // Errors are forwarded to the caller unchanged.
                _ => Poll::Ready(batch_result),
            }
        } else if let Some(recursive_stream) = &mut self.recursive_stream {
            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
            match batch_result {
                None => {
                    // The current iteration is exhausted; hand the buffered rows
                    // to the work table and start the next iteration (or finish).
                    self.recursive_stream = None;
                    self.poll_next_iteration(cx)
                }
                Some(Ok(batch)) => self.push_batch(batch),
                _ => Poll::Ready(batch_result),
            }
        } else {
            // Neither stream is active: the recursion has terminated.
            Poll::Ready(None)
        }
    }
}
421
impl RecordBatchStream for RecursiveQueryStream {
    /// Get the schema of the emitted batches (the static term's schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
428
/// Deduplicator based on a hash table.
struct DistinctDeduplicator {
    /// Grouped rows used for distinct; assigns increasing group ids to
    /// newly-seen rows (see [`DistinctDeduplicator::deduplicate`])
    group_values: Box<dyn GroupValues>,
    /// Tracks the memory used by the hash table
    reservation: MemoryReservation,
    /// Reusable scratch buffer for the group ids produced by `intern`;
    /// cleared after every batch
    intern_output_buffer: Vec<usize>,
}
436
437impl DistinctDeduplicator {
438    fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
439        let group_values = new_group_values(schema, &GroupOrdering::None)?;
440        let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
441            .register(task_context.memory_pool());
442        Ok(Self {
443            group_values,
444            reservation,
445            intern_output_buffer: Vec::new(),
446        })
447    }
448
449    /// Remove duplicated rows from the given batch, keeping a state between batches.
450    ///
451    /// We use a hash table to allocate new group ids for the new rows.
452    /// [`GroupValues`] allocate increasing group ids.
453    /// Hence, if groups (i.e., rows) are new, then they have ids >= length before interning, we keep them.
454    /// We also detect duplicates by enforcing that group ids are increasing.
455    fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
456        let size_before = self.group_values.len();
457        self.intern_output_buffer.reserve(batch.num_rows());
458        self.group_values
459            .intern(batch.columns(), &mut self.intern_output_buffer)?;
460        let mask = new_groups_mask(&self.intern_output_buffer, size_before);
461        self.intern_output_buffer.clear();
462        // We update the reservation to reflect the new size of the hash table.
463        self.reservation.try_resize(self.group_values.size())?;
464        Ok(filter_record_batch(batch, &mask)?)
465    }
466}
467
468/// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater or equal than the provided max_already_seen_group_id
469fn new_groups_mask(
470    values: &[usize],
471    mut max_already_seen_group_id: usize,
472) -> BooleanArray {
473    let mut output = BooleanBuilder::with_capacity(values.len());
474    for value in values {
475        if *value >= max_already_seen_group_id {
476            output.append_value(true);
477            max_already_seen_group_id = *value + 1; // We want to be increasing
478        } else {
479            output.append_value(false);
480        }
481    }
482    output.finish()
483}
484
// Placeholder for unit tests.
#[cfg(test)]
mod tests {}