Skip to main content

datafusion_physical_plan/
work_table.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the work table query plan
19
20use std::any::Any;
21use std::sync::{Arc, Mutex};
22
23use crate::coop::cooperative;
24use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
25use crate::memory::MemoryStream;
26use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
27use crate::{
28    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
29    SendableRecordBatchStream, Statistics,
30};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryReservation;
37use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
38
39/// A vector of record batches with a memory reservation.
40#[derive(Debug)]
41pub(super) struct ReservedBatches {
42    batches: Vec<RecordBatch>,
43    reservation: MemoryReservation,
44}
45
46impl ReservedBatches {
47    pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
48        ReservedBatches {
49            batches,
50            reservation,
51        }
52    }
53}
54
55/// The name is from PostgreSQL's terminology.
56/// See <https://wiki.postgresql.org/wiki/CTEReadme#How_Recursion_Works>
57/// This table serves as a mirror or buffer between each iteration of a recursive query.
58#[derive(Debug)]
59pub struct WorkTable {
60    batches: Mutex<Option<ReservedBatches>>,
61    name: String,
62}
63
64impl WorkTable {
65    /// Create a new work table.
66    pub(super) fn new(name: String) -> Self {
67        Self {
68            batches: Mutex::new(None),
69            name,
70        }
71    }
72
73    /// Take the previously written batches from the work table.
74    /// This will be called by the [`WorkTableExec`] when it is executed.
75    fn take(&self) -> Result<ReservedBatches> {
76        self.batches
77            .lock()
78            .unwrap()
79            .take()
80            .ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
81    }
82
83    /// Update the results of a recursive query iteration to the work table.
84    pub(super) fn update(&self, batches: ReservedBatches) {
85        self.batches.lock().unwrap().replace(batches);
86    }
87}
88
89/// A temporary "working table" operation where the input data will be
90/// taken from the named handle during the execution and will be re-published
91/// as is (kind of like a mirror).
92///
93/// Most notably used in the implementation of recursive queries where the
94/// underlying relation does not exist yet but the data will come as the previous
95/// term is evaluated. This table will be used such that the recursive plan
96/// will register a receiver in the task context and this plan will use that
97/// receiver to get the data and stream it back up so that the batches are available
98/// in the next iteration.
99#[derive(Clone, Debug)]
100pub struct WorkTableExec {
101    /// Name of the relation handler
102    name: String,
103    /// The schema of the stream
104    schema: SchemaRef,
105    /// Projection to apply to build the output stream from the recursion state
106    projection: Option<Vec<usize>>,
107    /// The work table
108    work_table: Arc<WorkTable>,
109    /// Execution metrics
110    metrics: ExecutionPlanMetricsSet,
111    /// Cache holding plan properties like equivalences, output partitioning etc.
112    cache: Arc<PlanProperties>,
113}
114
115impl WorkTableExec {
116    /// Create a new execution plan for a worktable exec.
117    pub fn new(
118        name: String,
119        mut schema: SchemaRef,
120        projection: Option<Vec<usize>>,
121    ) -> Result<Self> {
122        if let Some(projection) = &projection {
123            schema = Arc::new(schema.project(projection)?);
124        }
125        let cache = Self::compute_properties(Arc::clone(&schema));
126        Ok(Self {
127            name: name.clone(),
128            schema,
129            projection,
130            work_table: Arc::new(WorkTable::new(name)),
131            metrics: ExecutionPlanMetricsSet::new(),
132            cache: Arc::new(cache),
133        })
134    }
135
136    /// Ref to name
137    pub fn name(&self) -> &str {
138        &self.name
139    }
140
141    /// Arc clone of ref to schema
142    pub fn schema(&self) -> SchemaRef {
143        Arc::clone(&self.schema)
144    }
145
146    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
147    fn compute_properties(schema: SchemaRef) -> PlanProperties {
148        PlanProperties::new(
149            EquivalenceProperties::new(schema),
150            Partitioning::UnknownPartitioning(1),
151            EmissionType::Incremental,
152            Boundedness::Bounded,
153        )
154        .with_scheduling_type(SchedulingType::Cooperative)
155    }
156}
157
158impl DisplayAs for WorkTableExec {
159    fn fmt_as(
160        &self,
161        t: DisplayFormatType,
162        f: &mut std::fmt::Formatter,
163    ) -> std::fmt::Result {
164        match t {
165            DisplayFormatType::Default | DisplayFormatType::Verbose => {
166                write!(f, "WorkTableExec: name={}", self.name)
167            }
168            DisplayFormatType::TreeRender => {
169                write!(f, "name={}", self.name)
170            }
171        }
172    }
173}
174
175impl ExecutionPlan for WorkTableExec {
176    fn name(&self) -> &'static str {
177        "WorkTableExec"
178    }
179
180    fn properties(&self) -> &Arc<PlanProperties> {
181        &self.cache
182    }
183
184    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
185        vec![]
186    }
187
188    fn with_new_children(
189        self: Arc<Self>,
190        _: Vec<Arc<dyn ExecutionPlan>>,
191    ) -> Result<Arc<dyn ExecutionPlan>> {
192        Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
193    }
194
195    /// Stream the batches that were written to the work table.
196    fn execute(
197        &self,
198        partition: usize,
199        _context: Arc<TaskContext>,
200    ) -> Result<SendableRecordBatchStream> {
201        // WorkTable streams must be the plan base.
202        assert_eq_or_internal_err!(
203            partition,
204            0,
205            "WorkTableExec got an invalid partition {partition} (expected 0)"
206        );
207        let ReservedBatches {
208            mut batches,
209            reservation,
210        } = self.work_table.take()?;
211        if let Some(projection) = &self.projection {
212            // We apply the projection
213            // TODO: it would be better to apply it as soon as possible and not only here
214            // TODO: an aggressive projection makes the memory reservation smaller, even if we do not edit it
215            batches = batches
216                .into_iter()
217                .map(|b| b.project(projection))
218                .collect::<Result<Vec<_>, _>>()?;
219        }
220
221        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
222            .with_reservation(reservation);
223        Ok(Box::pin(cooperative(stream)))
224    }
225
226    fn metrics(&self) -> Option<MetricsSet> {
227        Some(self.metrics.clone_inner())
228    }
229
230    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Arc<Statistics>> {
231        Ok(Arc::new(Statistics::new_unknown(&self.schema())))
232    }
233
234    /// Injects run-time state into this `WorkTableExec`.
235    ///
236    /// The only state this node currently understands is an [`Arc<WorkTable>`].
237    /// If `state` can be down-cast to that type, a new `WorkTableExec` backed
238    /// by the provided work table is returned.  Otherwise `None` is returned
239    /// so that callers can attempt to propagate the state further down the
240    /// execution plan tree.
241    fn with_new_state(
242        &self,
243        state: Arc<dyn Any + Send + Sync>,
244    ) -> Option<Arc<dyn ExecutionPlan>> {
245        // Down-cast to the expected state type; propagate `None` on failure
246        let work_table = state.downcast::<WorkTable>().ok()?;
247
248        if work_table.name != self.name {
249            return None; // Different table
250        }
251
252        Some(Arc::new(Self {
253            name: self.name.clone(),
254            schema: Arc::clone(&self.schema),
255            projection: self.projection.clone(),
256            metrics: ExecutionPlanMetricsSet::new(),
257            work_table,
258            cache: Arc::clone(&self.cache),
259        }))
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;

    /// Round-trips a batch through a `WorkTable` and checks that the memory
    /// reservation lives exactly as long as the stream that ends up owning it.
    #[test]
    fn test_work_table() {
        let table = WorkTable::new("test".into());
        // Nothing has been written yet, so there is nothing to take
        assert!(table.take().is_err());

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);

        // Write one batch together with a 100 byte reservation
        let column: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", column)]).unwrap();
        reservation.try_grow(100).unwrap();
        table.update(ReservedBatches::new(vec![batch.clone()], reservation));

        // Read it back
        let ReservedBatches {
            batches,
            reservation,
        } = table.take().unwrap();
        assert_eq!(batches, vec![batch.clone()]);

        // Hand batches and reservation over to a MemoryStream
        let stream = MemoryStream::try_new(batches, batch.schema(), None)
            .unwrap()
            .with_reservation(reservation);

        // The stream is alive, so the memory is still reserved
        assert_eq!(pool.reserved(), 100);

        // Dropping the stream releases the reservation
        drop(stream);
        assert_eq!(pool.reserved(), 0);
    }

    /// Executes a projecting `WorkTableExec` against an injected work table
    /// and checks that the projected columns come back in projection order.
    #[tokio::test]
    async fn test_work_table_exec() {
        let fields = vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ];
        let schema = Arc::new(Schema::new(fields));
        let exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();

        // Inject a work table carrying the same name
        let table = Arc::new(WorkTable::new("wt".into()));
        let exec = exec.with_new_state(Arc::clone(&table) as _).unwrap();

        // Fill the work table with a single three-column batch
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let input = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
            ],
        )
        .unwrap();
        table.update(ReservedBatches::new(vec![input], reservation));

        // Only columns `c` and `b` are expected, in that order
        let expected = RecordBatch::try_from_iter(vec![
            ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef),
            ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef),
        ])
        .unwrap();

        // Pull the first batch out of the executed plan
        let mut stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
        let returned_batch = stream.next().await.unwrap().unwrap();
        assert_eq!(returned_batch, expected);
    }
}