datafusion_physical_plan/
memory.rs1use std::any::Any;
21use std::fmt;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use crate::coop::cooperative;
26use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
27use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
28use crate::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
30 RecordBatchStream, SendableRecordBatchStream, Statistics,
31};
32
33use arrow::array::RecordBatch;
34use arrow::datatypes::SchemaRef;
35use datafusion_common::{internal_err, Result};
36use datafusion_execution::memory_pool::MemoryReservation;
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::EquivalenceProperties;
39
40use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
41use futures::Stream;
42use parking_lot::RwLock;
43
44pub struct MemoryStream {
46 data: Vec<RecordBatch>,
48 reservation: Option<MemoryReservation>,
50 schema: SchemaRef,
52 projection: Option<Vec<usize>>,
54 index: usize,
56 fetch: Option<usize>,
58}
59
60impl MemoryStream {
61 pub fn try_new(
63 data: Vec<RecordBatch>,
64 schema: SchemaRef,
65 projection: Option<Vec<usize>>,
66 ) -> Result<Self> {
67 Ok(Self {
68 data,
69 reservation: None,
70 schema,
71 projection,
72 index: 0,
73 fetch: None,
74 })
75 }
76
77 pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self {
79 self.reservation = Some(reservation);
80 self
81 }
82
83 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
85 self.fetch = fetch;
86 self
87 }
88}
89
90impl Stream for MemoryStream {
91 type Item = Result<RecordBatch>;
92
93 fn poll_next(
94 mut self: std::pin::Pin<&mut Self>,
95 _: &mut Context<'_>,
96 ) -> Poll<Option<Self::Item>> {
97 if self.index >= self.data.len() {
98 return Poll::Ready(None);
99 }
100 self.index += 1;
101 let batch = &self.data[self.index - 1];
102 let batch = match self.projection.as_ref() {
104 Some(columns) => batch.project(columns)?,
105 None => batch.clone(),
106 };
107
108 let Some(&fetch) = self.fetch.as_ref() else {
109 return Poll::Ready(Some(Ok(batch)));
110 };
111 if fetch == 0 {
112 return Poll::Ready(None);
113 }
114
115 let batch = if batch.num_rows() > fetch {
116 batch.slice(0, fetch)
117 } else {
118 batch
119 };
120 self.fetch = Some(fetch - batch.num_rows());
121 Poll::Ready(Some(Ok(batch)))
122 }
123
124 fn size_hint(&self) -> (usize, Option<usize>) {
125 (self.data.len(), Some(self.data.len()))
126 }
127}
128
129impl RecordBatchStream for MemoryStream {
130 fn schema(&self) -> SchemaRef {
132 Arc::clone(&self.schema)
133 }
134}
135
136pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display {
137 fn as_any(&self) -> &dyn Any;
140
141 fn boundedness(&self) -> Boundedness {
142 Boundedness::Bounded
143 }
144
145 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>>;
147}
148
149pub struct LazyMemoryExec {
154 schema: SchemaRef,
156 batch_generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
158 cache: PlanProperties,
160 metrics: ExecutionPlanMetricsSet,
162}
163
164impl LazyMemoryExec {
165 pub fn try_new(
167 schema: SchemaRef,
168 generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
169 ) -> Result<Self> {
170 let boundedness = generators
171 .iter()
172 .map(|g| g.read().boundedness())
173 .reduce(|acc, b| match acc {
174 Boundedness::Bounded => b,
175 Boundedness::Unbounded {
176 requires_infinite_memory,
177 } => {
178 let acc_infinite_memory = requires_infinite_memory;
179 match b {
180 Boundedness::Bounded => acc,
181 Boundedness::Unbounded {
182 requires_infinite_memory,
183 } => Boundedness::Unbounded {
184 requires_infinite_memory: requires_infinite_memory
185 || acc_infinite_memory,
186 },
187 }
188 }
189 })
190 .unwrap_or(Boundedness::Bounded);
191
192 let cache = PlanProperties::new(
193 EquivalenceProperties::new(Arc::clone(&schema)),
194 Partitioning::RoundRobinBatch(generators.len()),
195 EmissionType::Incremental,
196 boundedness,
197 )
198 .with_scheduling_type(SchedulingType::Cooperative);
199
200 Ok(Self {
201 schema,
202 batch_generators: generators,
203 cache,
204 metrics: ExecutionPlanMetricsSet::new(),
205 })
206 }
207
208 pub fn try_set_partitioning(&mut self, partitioning: Partitioning) -> Result<()> {
209 if partitioning.partition_count() != self.batch_generators.len() {
210 internal_err!(
211 "Partition count must match generator count: {} != {}",
212 partitioning.partition_count(),
213 self.batch_generators.len()
214 )
215 } else {
216 self.cache.partitioning = partitioning;
217 Ok(())
218 }
219 }
220
221 pub fn add_ordering(&mut self, ordering: impl IntoIterator<Item = PhysicalSortExpr>) {
222 self.cache
223 .eq_properties
224 .add_orderings(std::iter::once(ordering));
225 }
226
227 pub fn generators(&self) -> &Vec<Arc<RwLock<dyn LazyBatchGenerator>>> {
229 &self.batch_generators
230 }
231}
232
233impl fmt::Debug for LazyMemoryExec {
234 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
235 f.debug_struct("LazyMemoryExec")
236 .field("schema", &self.schema)
237 .field("batch_generators", &self.batch_generators)
238 .finish()
239 }
240}
241
242impl DisplayAs for LazyMemoryExec {
243 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
244 match t {
245 DisplayFormatType::Default | DisplayFormatType::Verbose => {
246 write!(
247 f,
248 "LazyMemoryExec: partitions={}, batch_generators=[{}]",
249 self.batch_generators.len(),
250 self.batch_generators
251 .iter()
252 .map(|g| g.read().to_string())
253 .collect::<Vec<_>>()
254 .join(", ")
255 )
256 }
257 DisplayFormatType::TreeRender => {
258 writeln!(
260 f,
261 "batch_generators={}",
262 self.batch_generators
263 .iter()
264 .map(|g| g.read().to_string())
265 .collect::<Vec<String>>()
266 .join(", ")
267 )?;
268 Ok(())
269 }
270 }
271 }
272}
273
274impl ExecutionPlan for LazyMemoryExec {
275 fn name(&self) -> &'static str {
276 "LazyMemoryExec"
277 }
278
279 fn as_any(&self) -> &dyn Any {
280 self
281 }
282
283 fn schema(&self) -> SchemaRef {
284 Arc::clone(&self.schema)
285 }
286
287 fn properties(&self) -> &PlanProperties {
288 &self.cache
289 }
290
291 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
292 vec![]
293 }
294
295 fn with_new_children(
296 self: Arc<Self>,
297 children: Vec<Arc<dyn ExecutionPlan>>,
298 ) -> Result<Arc<dyn ExecutionPlan>> {
299 if children.is_empty() {
300 Ok(self)
301 } else {
302 internal_err!("Children cannot be replaced in LazyMemoryExec")
303 }
304 }
305
306 fn execute(
307 &self,
308 partition: usize,
309 _context: Arc<TaskContext>,
310 ) -> Result<SendableRecordBatchStream> {
311 if partition >= self.batch_generators.len() {
312 return internal_err!(
313 "Invalid partition {} for LazyMemoryExec with {} partitions",
314 partition,
315 self.batch_generators.len()
316 );
317 }
318
319 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
320
321 let stream = LazyMemoryStream {
322 schema: Arc::clone(&self.schema),
323 generator: Arc::clone(&self.batch_generators[partition]),
324 baseline_metrics,
325 };
326 Ok(Box::pin(cooperative(stream)))
327 }
328
329 fn metrics(&self) -> Option<MetricsSet> {
330 Some(self.metrics.clone_inner())
331 }
332
333 fn statistics(&self) -> Result<Statistics> {
334 Ok(Statistics::new_unknown(&self.schema))
335 }
336}
337
338pub struct LazyMemoryStream {
340 schema: SchemaRef,
341 generator: Arc<RwLock<dyn LazyBatchGenerator>>,
349 baseline_metrics: BaselineMetrics,
351}
352
353impl Stream for LazyMemoryStream {
354 type Item = Result<RecordBatch>;
355
356 fn poll_next(
357 self: std::pin::Pin<&mut Self>,
358 _: &mut Context<'_>,
359 ) -> Poll<Option<Self::Item>> {
360 let _timer_guard = self.baseline_metrics.elapsed_compute().timer();
361 let batch = self.generator.write().generate_next_batch();
362
363 let poll = match batch {
364 Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))),
365 Ok(None) => Poll::Ready(None),
366 Err(e) => Poll::Ready(Some(Err(e))),
367 };
368
369 self.baseline_metrics.record_poll(poll)
370 }
371}
372
373impl RecordBatchStream for LazyMemoryStream {
374 fn schema(&self) -> SchemaRef {
375 Arc::clone(&self.schema)
376 }
377}
378
#[cfg(test)]
mod lazy_memory_tests {
    use super::*;
    use crate::common::collect;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Emits `max_batches` batches of `batch_size` consecutive `Int64`
    /// values, starting from zero.
    #[derive(Debug, Clone)]
    struct TestGenerator {
        counter: i64,
        max_batches: i64,
        batch_size: usize,
        schema: SchemaRef,
    }

    impl fmt::Display for TestGenerator {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(
                f,
                "TestGenerator: counter={}, max_batches={}, batch_size={}",
                self.counter, self.max_batches, self.batch_size
            )
        }
    }

    impl LazyBatchGenerator for TestGenerator {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
            if self.counter >= self.max_batches {
                return Ok(None);
            }

            let width = self.batch_size as i64;
            let start = self.counter * width;
            let array = Int64Array::from_iter_values(start..start + width);
            self.counter += 1;

            let batch =
                RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?;
            Ok(Some(batch))
        }
    }

    /// Single non-nullable `Int64` column named `a`.
    fn int64_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
    }

    /// Builds a one-partition plan backed by a fresh [`TestGenerator`].
    fn single_generator_exec(max_batches: i64, batch_size: usize) -> Result<LazyMemoryExec> {
        let schema = int64_schema();
        let generator = TestGenerator {
            counter: 0,
            max_batches,
            batch_size,
            schema: Arc::clone(&schema),
        };
        LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])
    }

    /// Borrows the values of the first column, which must be `Int64`.
    fn int64_values(batch: &RecordBatch) -> &[i64] {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .values()
    }

    #[tokio::test]
    async fn test_lazy_memory_exec() -> Result<()> {
        let exec = single_generator_exec(3, 2)?;

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).name(), "a");

        let results: Vec<_> = exec
            .execute(0, Arc::new(TaskContext::default()))?
            .collect::<Vec<_>>()
            .await;
        assert_eq!(results.len(), 3);

        let expected: [&[i64]; 3] = [&[0, 1], &[2, 3], &[4, 5]];
        for (result, want) in results.iter().zip(expected) {
            let batch = result.as_ref().unwrap();
            assert_eq!(int64_values(batch), want);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_lazy_memory_exec_invalid_partition() -> Result<()> {
        let exec = single_generator_exec(1, 1)?;

        let result = exec.execute(1, Arc::new(TaskContext::default()));

        assert!(matches!(
            result,
            Err(e) if e.to_string().contains("Invalid partition 1 for LazyMemoryExec with 1 partitions")
        ));

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_series_metrics_integration() -> Result<()> {
        // (total rows, rows per batch, rows expected in the output)
        let test_cases: [(i64, i64, usize); 3] =
            [(10, 2, 10), (100, 10, 100), (5, 1, 5)];

        for (total_rows, batch_size, expected_rows) in test_cases {
            // Ceiling division: number of batches needed to cover `total_rows`.
            let max_batches = (total_rows + batch_size - 1) / batch_size;
            let exec = single_generator_exec(max_batches, batch_size as usize)?;

            let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
            let batches = collect(stream).await?;
            let metrics = exec.metrics().unwrap();

            let actual_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
            assert_eq!(actual_rows, expected_rows);
            assert_eq!(metrics.output_rows().unwrap(), expected_rows);
            assert!(metrics.elapsed_compute().unwrap() > 0);
        }

        Ok(())
    }
}