Skip to main content

datafusion_physical_plan/
work_table.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the work table query plan
19
20use std::any::Any;
21use std::sync::{Arc, Mutex};
22
23use crate::coop::cooperative;
24use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
25use crate::memory::MemoryStream;
26use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
27use crate::{
28    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
29    SendableRecordBatchStream, Statistics,
30};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryReservation;
37use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
38
39/// A vector of record batches with a memory reservation.
40#[derive(Debug)]
41pub(super) struct ReservedBatches {
42    batches: Vec<RecordBatch>,
43    reservation: MemoryReservation,
44}
45
46impl ReservedBatches {
47    pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
48        ReservedBatches {
49            batches,
50            reservation,
51        }
52    }
53}
54
55/// The name is from PostgreSQL's terminology.
56/// See <https://wiki.postgresql.org/wiki/CTEReadme#How_Recursion_Works>
57/// This table serves as a mirror or buffer between each iteration of a recursive query.
58#[derive(Debug)]
59pub struct WorkTable {
60    batches: Mutex<Option<ReservedBatches>>,
61    name: String,
62}
63
64impl WorkTable {
65    /// Create a new work table.
66    pub(super) fn new(name: String) -> Self {
67        Self {
68            batches: Mutex::new(None),
69            name,
70        }
71    }
72
73    /// Take the previously written batches from the work table.
74    /// This will be called by the [`WorkTableExec`] when it is executed.
75    fn take(&self) -> Result<ReservedBatches> {
76        self.batches
77            .lock()
78            .unwrap()
79            .take()
80            .ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
81    }
82
83    /// Update the results of a recursive query iteration to the work table.
84    pub(super) fn update(&self, batches: ReservedBatches) {
85        self.batches.lock().unwrap().replace(batches);
86    }
87}
88
89/// A temporary "working table" operation where the input data will be
90/// taken from the named handle during the execution and will be re-published
91/// as is (kind of like a mirror).
92///
93/// Most notably used in the implementation of recursive queries where the
94/// underlying relation does not exist yet but the data will come as the previous
95/// term is evaluated. This table will be used such that the recursive plan
96/// will register a receiver in the task context and this plan will use that
97/// receiver to get the data and stream it back up so that the batches are available
98/// in the next iteration.
99#[derive(Clone, Debug)]
100pub struct WorkTableExec {
101    /// Name of the relation handler
102    name: String,
103    /// The schema of the stream
104    schema: SchemaRef,
105    /// Projection to apply to build the output stream from the recursion state
106    projection: Option<Vec<usize>>,
107    /// The work table
108    work_table: Arc<WorkTable>,
109    /// Execution metrics
110    metrics: ExecutionPlanMetricsSet,
111    /// Cache holding plan properties like equivalences, output partitioning etc.
112    cache: Arc<PlanProperties>,
113}
114
115impl WorkTableExec {
116    /// Create a new execution plan for a worktable exec.
117    pub fn new(
118        name: String,
119        mut schema: SchemaRef,
120        projection: Option<Vec<usize>>,
121    ) -> Result<Self> {
122        if let Some(projection) = &projection {
123            schema = Arc::new(schema.project(projection)?);
124        }
125        let cache = Self::compute_properties(Arc::clone(&schema));
126        Ok(Self {
127            name: name.clone(),
128            schema,
129            projection,
130            work_table: Arc::new(WorkTable::new(name)),
131            metrics: ExecutionPlanMetricsSet::new(),
132            cache: Arc::new(cache),
133        })
134    }
135
136    /// Ref to name
137    pub fn name(&self) -> &str {
138        &self.name
139    }
140
141    /// Arc clone of ref to schema
142    pub fn schema(&self) -> SchemaRef {
143        Arc::clone(&self.schema)
144    }
145
146    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
147    fn compute_properties(schema: SchemaRef) -> PlanProperties {
148        PlanProperties::new(
149            EquivalenceProperties::new(schema),
150            Partitioning::UnknownPartitioning(1),
151            EmissionType::Incremental,
152            Boundedness::Bounded,
153        )
154        .with_scheduling_type(SchedulingType::Cooperative)
155    }
156}
157
158impl DisplayAs for WorkTableExec {
159    fn fmt_as(
160        &self,
161        t: DisplayFormatType,
162        f: &mut std::fmt::Formatter,
163    ) -> std::fmt::Result {
164        match t {
165            DisplayFormatType::Default | DisplayFormatType::Verbose => {
166                write!(f, "WorkTableExec: name={}", self.name)
167            }
168            DisplayFormatType::TreeRender => {
169                write!(f, "name={}", self.name)
170            }
171        }
172    }
173}
174
175impl ExecutionPlan for WorkTableExec {
176    fn name(&self) -> &'static str {
177        "WorkTableExec"
178    }
179
180    fn as_any(&self) -> &dyn Any {
181        self
182    }
183
184    fn properties(&self) -> &Arc<PlanProperties> {
185        &self.cache
186    }
187
188    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
189        vec![]
190    }
191
192    fn with_new_children(
193        self: Arc<Self>,
194        _: Vec<Arc<dyn ExecutionPlan>>,
195    ) -> Result<Arc<dyn ExecutionPlan>> {
196        Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
197    }
198
199    /// Stream the batches that were written to the work table.
200    fn execute(
201        &self,
202        partition: usize,
203        _context: Arc<TaskContext>,
204    ) -> Result<SendableRecordBatchStream> {
205        // WorkTable streams must be the plan base.
206        assert_eq_or_internal_err!(
207            partition,
208            0,
209            "WorkTableExec got an invalid partition {partition} (expected 0)"
210        );
211        let ReservedBatches {
212            mut batches,
213            reservation,
214        } = self.work_table.take()?;
215        if let Some(projection) = &self.projection {
216            // We apply the projection
217            // TODO: it would be better to apply it as soon as possible and not only here
218            // TODO: an aggressive projection makes the memory reservation smaller, even if we do not edit it
219            batches = batches
220                .into_iter()
221                .map(|b| b.project(projection))
222                .collect::<Result<Vec<_>, _>>()?;
223        }
224
225        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
226            .with_reservation(reservation);
227        Ok(Box::pin(cooperative(stream)))
228    }
229
230    fn metrics(&self) -> Option<MetricsSet> {
231        Some(self.metrics.clone_inner())
232    }
233
234    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
235        Ok(Statistics::new_unknown(&self.schema()))
236    }
237
238    /// Injects run-time state into this `WorkTableExec`.
239    ///
240    /// The only state this node currently understands is an [`Arc<WorkTable>`].
241    /// If `state` can be down-cast to that type, a new `WorkTableExec` backed
242    /// by the provided work table is returned.  Otherwise `None` is returned
243    /// so that callers can attempt to propagate the state further down the
244    /// execution plan tree.
245    fn with_new_state(
246        &self,
247        state: Arc<dyn Any + Send + Sync>,
248    ) -> Option<Arc<dyn ExecutionPlan>> {
249        // Down-cast to the expected state type; propagate `None` on failure
250        let work_table = state.downcast::<WorkTable>().ok()?;
251
252        if work_table.name != self.name {
253            return None; // Different table
254        }
255
256        Some(Arc::new(Self {
257            name: self.name.clone(),
258            schema: Arc::clone(&self.schema),
259            projection: self.projection.clone(),
260            metrics: ExecutionPlanMetricsSet::new(),
261            work_table,
262            cache: Arc::clone(&self.cache),
263        }))
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;

    /// Checks the take/update cycle of a `WorkTable` and that the memory
    /// reservation stays alive until the consuming stream is dropped.
    #[test]
    fn test_work_table() {
        let table = WorkTable::new("test".into());
        // Nothing was written yet, so there is nothing to take
        assert!(table.take().is_err());

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);

        // Write one batch, together with 100 reserved bytes, to the table
        let column: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", column)]).unwrap();
        reservation.try_grow(100).unwrap();
        table.update(ReservedBatches::new(vec![batch.clone()], reservation));

        // Read it back
        let ReservedBatches {
            batches,
            reservation,
        } = table.take().unwrap();
        assert_eq!(batches, vec![batch.clone()]);

        // Hand the batch and its reservation over to a MemoryStream
        let stream = MemoryStream::try_new(batches, batch.schema(), None)
            .unwrap()
            .with_reservation(reservation);

        // The stream is alive, so the memory is still reserved
        assert_eq!(pool.reserved(), 100);

        // Dropping the stream releases the reservation
        drop(stream);
        assert_eq!(pool.reserved(), 0);
    }

    /// Checks that a `WorkTableExec` streams the projected content of the
    /// work table injected with `with_new_state`.
    #[tokio::test]
    async fn test_work_table_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ]));
        let exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();

        // Inject a work table with the same name as the plan
        let table = Arc::new(WorkTable::new("wt".into()));
        let exec = exec.with_new_state(Arc::clone(&table) as _).unwrap();

        // Fill the work table with a single three-column batch
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
            Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
        ];
        let input = RecordBatch::try_new(Arc::clone(&schema), columns).unwrap();
        table.update(ReservedBatches::new(vec![input], reservation));

        // The plan returns the batch with only the projected columns, in order
        let mut stream = exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let returned_batch = stream.next().await.unwrap().unwrap();

        let expected = RecordBatch::try_from_iter(vec![
            ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef),
            ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef),
        ])
        .unwrap();
        assert_eq!(returned_batch, expected);
    }
}