use std::any::Any;
use std::sync::{Arc, Mutex};

use crate::coop::cooperative;
use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
use crate::memory::MemoryStream;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
    SendableRecordBatchStream, Statistics,
};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};

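/// Record batches paired with the [`MemoryReservation`] that accounts for
/// them, so their memory stays tracked while they sit in the work table and
/// until the stream that finally consumes them is dropped.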
#[derive(Debug)]
pub(super) struct ReservedBatches {
    batches: Vec<RecordBatch>,
    reservation: MemoryReservation,
}

impl ReservedBatches {
    pub(super) fn new(batches: Vec<RecordBatch>, reservation: MemoryReservation) -> Self {
        ReservedBatches {
            batches,
            reservation,
        }
    }
}

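/// The shared state of a recursive query: one side writes each iteration's
/// batches in with `update`, and [`WorkTableExec`] takes them back out with
/// `take`. The `name` identifies which CTE the table belongs to.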
#[derive(Debug)]
pub struct WorkTable {
    batches: Mutex<Option<ReservedBatches>>,
    name: String,
}

impl WorkTable {
    pub(super) fn new(name: String) -> Self {
        Self {
            batches: Mutex::new(None),
            name,
        }
    }

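    /// Take the batches out of the work table, leaving it empty.
    /// Returns an internal error if the work table has no batches.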
    fn take(&self) -> Result<ReservedBatches> {
        self.batches
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| internal_datafusion_err!("Unexpected empty work table"))
    }

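    /// Write `batches` into the work table, replacing any previous contents.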
    pub(super) fn update(&self, batches: ReservedBatches) {
        self.batches.lock().unwrap().replace(batches);
    }
}

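/// Execution plan node that streams the batches currently stored in a
/// [`WorkTable`].
///
/// A recursive query writes each iteration's result into the shared work
/// table (installed via [`ExecutionPlan::with_new_state`]), and this operator
/// reads it back on the next iteration, optionally applying a column
/// projection.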
#[derive(Clone, Debug)]
pub struct WorkTableExec {
    name: String,
    schema: SchemaRef,
    projection: Option<Vec<usize>>,
    work_table: Arc<WorkTable>,
    metrics: ExecutionPlanMetricsSet,
    cache: PlanProperties,
}

impl WorkTableExec {
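    /// Create a new `WorkTableExec` for the CTE `name`, optionally projecting
    /// `schema` down to the given column indices.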
    pub fn new(
        name: String,
        mut schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> Result<Self> {
        if let Some(projection) = &projection {
            schema = Arc::new(schema.project(projection)?);
        }
        let cache = Self::compute_properties(Arc::clone(&schema));
        Ok(Self {
            name: name.clone(),
            schema,
            projection,
            work_table: Arc::new(WorkTable::new(name)),
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

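    /// The name of this work table (the recursive CTE it scans).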
    pub fn name(&self) -> &str {
        &self.name
    }

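    /// The schema produced by this node, after the optional projection.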
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

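    /// Plan properties: a single (unknown) partition, incremental emission,
    /// bounded output, and cooperative scheduling.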
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        PlanProperties::new(
            EquivalenceProperties::new(schema),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
        .with_scheduling_type(SchedulingType::Cooperative)
    }
}

impl DisplayAs for WorkTableExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "WorkTableExec: name={}", self.name)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "name={}", self.name)
            }
        }
    }
}

impl ExecutionPlan for WorkTableExec {
    fn name(&self) -> &'static str {
        "WorkTableExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::clone(&self) as Arc<dyn ExecutionPlan>)
    }

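    /// Stream out whatever is currently stored in the work table.
    ///
    /// Only partition 0 is valid. The stored batches are taken out of the
    /// table (leaving it empty) and returned as a [`MemoryStream`] that
    /// carries their memory reservation with it.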
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        assert_eq_or_internal_err!(
            partition,
            0,
            "WorkTableExec got an invalid partition {partition} (expected 0)"
        );
        let ReservedBatches {
            mut batches,
            reservation,
        } = self.work_table.take()?;
        if let Some(projection) = &self.projection {
            batches = batches
                .into_iter()
                .map(|b| b.project(projection))
                .collect::<Result<Vec<_>, _>>()?;
        }

        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
            .with_reservation(reservation);
        Ok(Box::pin(cooperative(stream)))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }

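    /// Swap in a shared [`WorkTable`] as this node's state.
    ///
    /// Returns a copy of this node that reads from the provided work table,
    /// or `None` if the state is not a [`WorkTable`] or its name does not
    /// match. This is how the recursive query machinery hands every scan of
    /// the same CTE a single shared table.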
    fn with_new_state(
        &self,
        state: Arc<dyn Any + Send + Sync>,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        let work_table = state.downcast::<WorkTable>().ok()?;

        if work_table.name != self.name {
            return None;
        }

        Some(Arc::new(Self {
            name: self.name.clone(),
            schema: Arc::clone(&self.schema),
            projection: self.projection.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            work_table,
            cache: self.cache.clone(),
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
    use futures::StreamExt;

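    // Round-trip through the work table: what was stored comes back out, and
    // the memory reservation is only released once the consuming stream drops.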
    #[test]
    fn test_work_table() {
        let work_table = WorkTable::new("test".into());
        assert!(work_table.take().is_err());

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let mut reservation = MemoryConsumer::new("test_work_table").register(&pool);

        let array: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
        let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap();
        reservation.try_grow(100).unwrap();
        work_table.update(ReservedBatches::new(vec![batch.clone()], reservation));
        let reserved_batches = work_table.take().unwrap();
        assert_eq!(reserved_batches.batches, vec![batch.clone()]);

        let memory_stream =
            MemoryStream::try_new(reserved_batches.batches, batch.schema(), None)
                .unwrap()
                .with_reservation(reserved_batches.reservation);

        assert_eq!(pool.reserved(), 100);

        drop(memory_stream);
        assert_eq!(pool.reserved(), 0);
    }

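    // End-to-end through WorkTableExec: the projection (columns 2, 1) reorders
    // the output to (c, b), and execute() streams the batch written via update.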
    #[tokio::test]
    async fn test_work_table_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int16, false),
        ]));
        let work_table_exec =
            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
                .unwrap();

        let work_table = Arc::new(WorkTable::new("wt".into()));
        let work_table_exec = work_table_exec
            .with_new_state(Arc::clone(&work_table) as _)
            .unwrap();

        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
            ],
        )
        .unwrap();
        work_table.update(ReservedBatches::new(vec![batch], reservation));

        let returned_batch = work_table_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap()
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            returned_batch,
            RecordBatch::try_from_iter(vec![
                ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _),
                ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
            ])
            .unwrap()
        );
    }
}