datafusion_physical_plan/
work_table.rs1use std::any::Any;
21use std::sync::{Arc, Mutex};
22
23use crate::coop::cooperative;
24use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
25use crate::memory::MemoryStream;
26use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
27use crate::{
28 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
29 SendableRecordBatchStream, Statistics,
30};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryReservation;
37use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
38
39#[derive(Debug)]
41pub(super) struct ReservedBatches {
42 batches: Vec<RecordBatch>,
43 reservation: MemoryReservation,
44}
45
46impl ReservedBatches {
47 pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
48 ReservedBatches {
49 batches,
50 reservation,
51 }
52 }
53}
54
55#[derive(Debug)]
59pub struct WorkTable {
60 batches: Mutex<Option<ReservedBatches>>,
61 name: String,
62}
63
64impl WorkTable {
65 pub(super) fn new(name: String) -> Self {
67 Self {
68 batches: Mutex::new(None),
69 name,
70 }
71 }
72
73 fn take(&self) -> Result<ReservedBatches> {
76 self.batches
77 .lock()
78 .unwrap()
79 .take()
80 .ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
81 }
82
83 pub(super) fn update(&self, batches: ReservedBatches) {
85 self.batches.lock().unwrap().replace(batches);
86 }
87}
88
89#[derive(Clone, Debug)]
100pub struct WorkTableExec {
101 name: String,
103 schema: SchemaRef,
105 projection: Option<Vec<usize>>,
107 work_table: Arc<WorkTable>,
109 metrics: ExecutionPlanMetricsSet,
111 cache: Arc<PlanProperties>,
113}
114
115impl WorkTableExec {
116 pub fn new(
118 name: String,
119 mut schema: SchemaRef,
120 projection: Option<Vec<usize>>,
121 ) -> Result<Self> {
122 if let Some(projection) = &projection {
123 schema = Arc::new(schema.project(projection)?);
124 }
125 let cache = Self::compute_properties(Arc::clone(&schema));
126 Ok(Self {
127 name: name.clone(),
128 schema,
129 projection,
130 work_table: Arc::new(WorkTable::new(name)),
131 metrics: ExecutionPlanMetricsSet::new(),
132 cache: Arc::new(cache),
133 })
134 }
135
136 pub fn name(&self) -> &str {
138 &self.name
139 }
140
141 pub fn schema(&self) -> SchemaRef {
143 Arc::clone(&self.schema)
144 }
145
146 fn compute_properties(schema: SchemaRef) -> PlanProperties {
148 PlanProperties::new(
149 EquivalenceProperties::new(schema),
150 Partitioning::UnknownPartitioning(1),
151 EmissionType::Incremental,
152 Boundedness::Bounded,
153 )
154 .with_scheduling_type(SchedulingType::Cooperative)
155 }
156}
157
158impl DisplayAs for WorkTableExec {
159 fn fmt_as(
160 &self,
161 t: DisplayFormatType,
162 f: &mut std::fmt::Formatter,
163 ) -> std::fmt::Result {
164 match t {
165 DisplayFormatType::Default | DisplayFormatType::Verbose => {
166 write!(f, "WorkTableExec: name={}", self.name)
167 }
168 DisplayFormatType::TreeRender => {
169 write!(f, "name={}", self.name)
170 }
171 }
172 }
173}
174
175impl ExecutionPlan for WorkTableExec {
176 fn name(&self) -> &'static str {
177 "WorkTableExec"
178 }
179
180 fn properties(&self) -> &Arc<PlanProperties> {
181 &self.cache
182 }
183
184 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
185 vec![]
186 }
187
188 fn with_new_children(
189 self: Arc<Self>,
190 _: Vec<Arc<dyn ExecutionPlan>>,
191 ) -> Result<Arc<dyn ExecutionPlan>> {
192 Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
193 }
194
195 fn execute(
197 &self,
198 partition: usize,
199 _context: Arc<TaskContext>,
200 ) -> Result<SendableRecordBatchStream> {
201 assert_eq_or_internal_err!(
203 partition,
204 0,
205 "WorkTableExec got an invalid partition {partition} (expected 0)"
206 );
207 let ReservedBatches {
208 mut batches,
209 reservation,
210 } = self.work_table.take()?;
211 if let Some(projection) = &self.projection {
212 batches = batches
216 .into_iter()
217 .map(|b| b.project(projection))
218 .collect::<Result<Vec<_>, _>>()?;
219 }
220
221 let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
222 .with_reservation(reservation);
223 Ok(Box::pin(cooperative(stream)))
224 }
225
226 fn metrics(&self) -> Option<MetricsSet> {
227 Some(self.metrics.clone_inner())
228 }
229
230 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Arc<Statistics>> {
231 Ok(Arc::new(Statistics::new_unknown(&self.schema())))
232 }
233
234 fn with_new_state(
242 &self,
243 state: Arc<dyn Any + Send + Sync>,
244 ) -> Option<Arc<dyn ExecutionPlan>> {
245 let work_table = state.downcast::<WorkTable>().ok()?;
247
248 if work_table.name != self.name {
249 return None; }
251
252 Some(Arc::new(Self {
253 name: self.name.clone(),
254 schema: Arc::clone(&self.schema),
255 projection: self.projection.clone(),
256 metrics: ExecutionPlanMetricsSet::new(),
257 work_table,
258 cache: Arc::clone(&self.cache),
259 }))
260 }
261}
262
/// Unit tests for [`WorkTable`] and [`WorkTableExec`].
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;

    /// Round-trips batches through a `WorkTable` and checks that the memory
    /// reservation stays held until the consuming stream is dropped.
    #[test]
    fn test_work_table() {
        let work_table = WorkTable::new("test".into());
        // Taking from an empty work table is an error.
        assert!(work_table.take().is_err());

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);

        // Store one batch, with 100 bytes reserved against the pool.
        let array: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap();
        reservation.try_grow(100).unwrap();
        work_table.update(ReservedBatches::new(vec![batch.clone()], reservation));
        // Taking returns exactly what was stored.
        let reserved_batches = work_table.take().unwrap();
        assert_eq!(reserved_batches.batches, vec![batch.clone()]);

        // Hand the batches and their reservation to a MemoryStream.
        let memory_stream =
            MemoryStream::try_new(reserved_batches.batches, batch.schema(), None)
                .unwrap()
                .with_reservation(reserved_batches.reservation);

        // The stream now owns the reservation, so the memory is still reserved.
        assert_eq!(pool.reserved(), 100);

        // Dropping the stream releases the reservation back to the pool.
        drop(memory_stream);
        assert_eq!(pool.reserved(), 0);
    }

    /// Binds a `WorkTableExec` to a shared table via `with_new_state` and
    /// checks that `execute` returns the stored batch with the projection
    /// (columns 2 then 1, i.e. `c`, `b`) applied.
    #[tokio::test]
    async fn test_work_table_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ]));
        let work_table_exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();

        // Bind the plan to a shared table with the same name.
        let work_table = Arc::new(WorkTable::new("wt".into()));
        let work_table_exec = work_table_exec
            .with_new_state(Arc::clone(&work_table) as _)
            .unwrap();

        // Store one unprojected batch in the shared table.
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
            ],
        )
        .unwrap();
        work_table.update(ReservedBatches::new(vec![batch], reservation));

        // Executing the plan yields the projected batch.
        let returned_batch = work_table_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap()
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            returned_batch,
            RecordBatch::try_from_iter(vec![
                ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _),
                ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
            ])
            .unwrap()
        );
    }
}