datafusion_physical_plan/work_table.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the work table query plan

use std::any::Any;
use std::sync::{Arc, Mutex};

use crate::coop::cooperative;
use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
use crate::memory::MemoryStream;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
    SendableRecordBatchStream, Statistics,
};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};

/// A vector of record batches with a memory reservation.
#[derive(Debug)]
pub(super) struct ReservedBatches {
    batches: Vec<RecordBatch>,
    reservation: MemoryReservation,
}

impl ReservedBatches {
    pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
        ReservedBatches {
            batches,
            reservation,
        }
    }
}

/// A temporary table that acts as a mirror or buffer between iterations of a
/// recursive query, holding the batches produced by the previous iteration.
///
/// The name comes from PostgreSQL's terminology; see
/// <https://wiki.postgresql.org/wiki/CTEReadme#How_Recursion_Works>.
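///
/// A minimal sketch of the in-crate write/read cycle (not a doctest: `update`
/// and `take` are crate-private, and the `batches`/`reservation` setup is assumed):
///
/// ```ignore
/// let work_table = WorkTable::new("recursive_cte".to_string());
/// // One iteration of the recursive query publishes its output...
/// work_table.update(ReservedBatches::new(batches, reservation));
/// // ...and the next iteration drains it again via `WorkTableExec`.
/// let reserved = work_table.take()?;
/// ```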
#[derive(Debug)]
pub struct WorkTable {
    batches: Mutex<Option<ReservedBatches>>,
    name: String,
}

impl WorkTable {
    /// Create a new work table.
    pub(super) fn new(name: String) -> Self {
        Self {
            batches: Mutex::new(None),
            name,
        }
    }

    /// Take the previously written batches from the work table.
    /// This will be called by the [`WorkTableExec`] when it is executed.
    fn take(&self) -> Result<ReservedBatches> {
        self.batches
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
    }

    /// Write the results of a recursive query iteration to the work table.
    pub(super) fn update(&self, batches: ReservedBatches) {
        self.batches.lock().unwrap().replace(batches);
    }
}

/// A temporary "working table" operation that takes its input data from the
/// named handle during execution and re-publishes it as-is (much like a mirror).
///
/// Most notably used in the implementation of recursive queries, where the
/// underlying relation does not exist yet but its data is produced as the
/// previous term is evaluated. The recursive plan shares a [`WorkTable`] with
/// this operator (injected via `with_new_state`), and this operator streams the
/// batches written there back up so that they are available to the next
/// iteration.
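///
/// A sketch of how the pieces in this module fit together (illustrative only;
/// the `schema`, the handle name, and the driving recursive plan are assumed):
///
/// ```ignore
/// // The operator is planned against the named CTE handle...
/// let exec = WorkTableExec::new("recursive_cte".to_string(), schema, None)?;
/// // ...the recursive plan injects the shared work table as run-time state...
/// let work_table = Arc::new(WorkTable::new("recursive_cte".to_string()));
/// let exec = exec.with_new_state(work_table as _).unwrap();
/// // ...and each execution replays whatever was last written to that table.
/// let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
/// ```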
#[derive(Clone, Debug)]
pub struct WorkTableExec {
    /// Name of the relation handle
    name: String,
    /// The schema of the stream
    schema: SchemaRef,
    /// Projection to apply to build the output stream from the recursion state
    projection: Option<Vec<usize>>,
    /// The work table
    work_table: Arc<WorkTable>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning, etc.
    cache: PlanProperties,
}

impl WorkTableExec {
    /// Create a new `WorkTableExec`.
    pub fn new(
        name: String,
        mut schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> Result<Self> {
        if let Some(projection) = &projection {
            schema = Arc::new(schema.project(projection)?);
        }
        let cache = Self::compute_properties(Arc::clone(&schema));
        Ok(Self {
            name: name.clone(),
            schema,
            projection,
            work_table: Arc::new(WorkTable::new(name)),
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// Returns the name of the work table relation.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the (possibly projected) schema of this operator.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Creates the cache object that stores the plan properties: equivalence
    /// properties, output partitioning, emission type, and boundedness.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
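        // The work table is replayed on a single partition; it only ever holds
        // what the previous iteration produced, so the output is bounded and
        // can be emitted incrementally under cooperative scheduling.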
        PlanProperties::new(
            EquivalenceProperties::new(schema),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
        .with_scheduling_type(SchedulingType::Cooperative)
    }
}

impl DisplayAs for WorkTableExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "WorkTableExec: name={}", self.name)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "name={}", self.name)
            }
        }
    }
}

impl ExecutionPlan for WorkTableExec {
    fn name(&self) -> &'static str {
        "WorkTableExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
    }

    /// Stream the batches that were written to the work table.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // `WorkTableExec` sits at the base of the plan and produces a single
        // partition, so only partition 0 is valid.
        assert_eq_or_internal_err!(
            partition,
            0,
            "WorkTableExec got an invalid partition {partition} (expected 0)"
        );
        let ReservedBatches {
            mut batches,
            reservation,
        } = self.work_table.take()?;
        if let Some(projection) = &self.projection {
            // Apply the projection to each stored batch.
            // TODO: it would be better to apply the projection as early as possible,
            // not only here.
            // TODO: an aggressive projection shrinks the actual memory footprint of
            // the batches, but the reservation is not adjusted here.
            batches = batches
                .into_iter()
                .map(|b| b.project(projection))
                .collect::<Result<Vec<_>, _>>()?;
        }

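        // The memory reservation made for these batches travels with the
        // stream and is released when the stream is dropped.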
        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
            .with_reservation(reservation);
        Ok(Box::pin(cooperative(stream)))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }

    /// Injects run-time state into this `WorkTableExec`.
    ///
    /// The only state this node currently understands is an [`Arc<WorkTable>`].
    /// If `state` can be down-cast to that type, a new `WorkTableExec` backed
    /// by the provided work table is returned.  Otherwise `None` is returned
    /// so that callers can attempt to propagate the state further down the
    /// execution plan tree.
    fn with_new_state(
        &self,
        state: Arc<dyn Any + Send + Sync>,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // Down-cast to the expected state type; propagate `None` on failure
        let work_table = state.downcast::<WorkTable>().ok()?;

        if work_table.name != self.name {
            return None; // Different table
        }

        Some(Arc::new(Self {
            name: self.name.clone(),
            schema: Arc::clone(&self.schema),
            projection: self.projection.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            work_table,
            cache: self.cache.clone(),
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;

    #[test]
    fn test_work_table() {
        let work_table = WorkTable::new("test".into());
        // Can't take from empty work_table
        assert!(work_table.take().is_err());

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let mut reservation = MemoryConsumer::new("test_work_table").register(&pool);

        // Write a batch to the work_table
        let array: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap();
        reservation.try_grow(100).unwrap();
        work_table.update(ReservedBatches::new(vec![batch.clone()], reservation));
        // Take from work_table
        let reserved_batches = work_table.take().unwrap();
        assert_eq!(reserved_batches.batches, vec![batch.clone()]);

        // Consume the batches via a MemoryStream
        let memory_stream =
            MemoryStream::try_new(reserved_batches.batches, batch.schema(), None)
                .unwrap()
                .with_reservation(reserved_batches.reservation);

        // The memory should still be reserved
        assert_eq!(pool.reserved(), 100);

        // The reservation should be freed after dropping the memory_stream
        drop(memory_stream);
        assert_eq!(pool.reserved(), 0);
    }

    #[tokio::test]
    async fn test_work_table_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ]));
        let work_table_exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();

        // We inject the work table
        let work_table = Arc::new(WorkTable::new("wt".into()));
        let work_table_exec = work_table_exec
            .with_new_state(Arc::clone(&work_table) as _)
            .unwrap();

        // We update the work table
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
            ],
        )
        .unwrap();
        work_table.update(ReservedBatches::new(vec![batch], reservation));

        // We get back the batch from the work table
        let returned_batch = work_table_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap()
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            returned_batch,
            RecordBatch::try_from_iter(vec![
                ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _),
                ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
            ])
            .unwrap()
        );
    }
}