1use std::any::Any;
21use std::fmt;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use crate::coop::cooperative;
26use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
27use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
28use crate::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
30 RecordBatchStream, SendableRecordBatchStream, Statistics,
31};
32
33use arrow::array::RecordBatch;
34use arrow::datatypes::SchemaRef;
35use datafusion_common::{Result, assert_eq_or_internal_err, assert_or_internal_err};
36use datafusion_execution::TaskContext;
37use datafusion_execution::memory_pool::MemoryReservation;
38use datafusion_physical_expr::EquivalenceProperties;
39
40use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
41use futures::Stream;
42use parking_lot::RwLock;
43
/// Stream that yields a fixed, in-memory list of [`RecordBatch`]es in order.
///
/// Each batch is optionally projected and, when a `fetch` limit is set,
/// truncated so the stream emits at most that many rows in total.
pub struct MemoryStream {
    /// Batches to emit, in order.
    data: Vec<RecordBatch>,
    /// Optional memory reservation stored alongside the data. It is not read by
    /// this stream's methods (presumably it accounts for `data` — confirm).
    reservation: Option<MemoryReservation>,
    /// Schema reported by [`RecordBatchStream::schema`], stored exactly as
    /// passed to `try_new`.
    schema: SchemaRef,
    /// Column indices applied to each batch via `RecordBatch::project`.
    projection: Option<Vec<usize>>,
    /// Index into `data` of the next batch to emit.
    index: usize,
    /// Number of rows that may still be emitted; `None` means no limit.
    fetch: Option<usize>,
}
59
60impl MemoryStream {
61 pub fn try_new(
63 data: Vec<RecordBatch>,
64 schema: SchemaRef,
65 projection: Option<Vec<usize>>,
66 ) -> Result<Self> {
67 Ok(Self {
68 data,
69 reservation: None,
70 schema,
71 projection,
72 index: 0,
73 fetch: None,
74 })
75 }
76
77 pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self {
79 self.reservation = Some(reservation);
80 self
81 }
82
83 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
85 self.fetch = fetch;
86 self
87 }
88}
89
90impl Stream for MemoryStream {
91 type Item = Result<RecordBatch>;
92
93 fn poll_next(
94 mut self: std::pin::Pin<&mut Self>,
95 _: &mut Context<'_>,
96 ) -> Poll<Option<Self::Item>> {
97 if self.index >= self.data.len() {
98 return Poll::Ready(None);
99 }
100 self.index += 1;
101 let batch = &self.data[self.index - 1];
102 let batch = match self.projection.as_ref() {
104 Some(columns) => batch.project(columns)?,
105 None => batch.clone(),
106 };
107
108 let Some(&fetch) = self.fetch.as_ref() else {
109 return Poll::Ready(Some(Ok(batch)));
110 };
111 if fetch == 0 {
112 return Poll::Ready(None);
113 }
114
115 let batch = if batch.num_rows() > fetch {
116 batch.slice(0, fetch)
117 } else {
118 batch
119 };
120 self.fetch = Some(fetch - batch.num_rows());
121 Poll::Ready(Some(Ok(batch)))
122 }
123
124 fn size_hint(&self) -> (usize, Option<usize>) {
125 (self.data.len(), Some(self.data.len()))
126 }
127}
128
129impl RecordBatchStream for MemoryStream {
130 fn schema(&self) -> SchemaRef {
132 Arc::clone(&self.schema)
133 }
134}
135
/// Produces record batches on demand for [`LazyMemoryExec`].
///
/// Implementations must be `Send + Sync` and implement `Debug` and `Display`;
/// the `Display` output is used when a `LazyMemoryExec` plan is rendered.
pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Reports whether this generator's output is bounded.
    ///
    /// Defaults to [`Boundedness::Bounded`]. `LazyMemoryExec::try_new` combines
    /// the values reported by all of its generators into the plan's boundedness.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Bounded
    }

    /// Generates the next batch, or returns `Ok(None)` when no batches remain.
    /// `LazyMemoryStream` treats `Ok(None)` as the end of the stream.
    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>>;

    /// Returns a new generator for `LazyMemoryExec::reset_state` to use when it
    /// rebuilds the plan. Presumably this is positioned back at its first batch;
    /// confirm per implementation.
    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>>;
}
151
/// Execution plan that produces batches lazily from [`LazyBatchGenerator`]s,
/// with one output partition per generator.
pub struct LazyMemoryExec {
    /// Output schema; replaced by the projected schema in `with_projection`.
    schema: SchemaRef,
    /// Optional column indices applied to every generated batch.
    projection: Option<Vec<usize>>,
    /// Generators, indexed by partition number in `execute`.
    batch_generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
    /// Cached plan properties (equivalences, partitioning, emission type,
    /// boundedness, scheduling type).
    cache: PlanProperties,
    /// Metrics set in which each partition's stream registers its baseline metrics.
    metrics: ExecutionPlanMetricsSet,
}
168
169impl LazyMemoryExec {
170 pub fn try_new(
172 schema: SchemaRef,
173 generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
174 ) -> Result<Self> {
175 let boundedness = generators
176 .iter()
177 .map(|g| g.read().boundedness())
178 .reduce(|acc, b| match acc {
179 Boundedness::Bounded => b,
180 Boundedness::Unbounded {
181 requires_infinite_memory,
182 } => {
183 let acc_infinite_memory = requires_infinite_memory;
184 match b {
185 Boundedness::Bounded => acc,
186 Boundedness::Unbounded {
187 requires_infinite_memory,
188 } => Boundedness::Unbounded {
189 requires_infinite_memory: requires_infinite_memory
190 || acc_infinite_memory,
191 },
192 }
193 }
194 })
195 .unwrap_or(Boundedness::Bounded);
196
197 let cache = PlanProperties::new(
198 EquivalenceProperties::new(Arc::clone(&schema)),
199 Partitioning::RoundRobinBatch(generators.len()),
200 EmissionType::Incremental,
201 boundedness,
202 )
203 .with_scheduling_type(SchedulingType::Cooperative);
204
205 Ok(Self {
206 schema,
207 projection: None,
208 batch_generators: generators,
209 cache,
210 metrics: ExecutionPlanMetricsSet::new(),
211 })
212 }
213
214 pub fn with_projection(mut self, projection: Option<Vec<usize>>) -> Self {
215 match projection.as_ref() {
216 Some(columns) => {
217 let projected = Arc::new(self.schema.project(columns).unwrap());
218 self.cache = self.cache.with_eq_properties(EquivalenceProperties::new(
219 Arc::clone(&projected),
220 ));
221 self.schema = projected;
222 self.projection = projection;
223 self
224 }
225 _ => self,
226 }
227 }
228
229 pub fn try_set_partitioning(&mut self, partitioning: Partitioning) -> Result<()> {
230 let partition_count = partitioning.partition_count();
231 let generator_count = self.batch_generators.len();
232 assert_eq_or_internal_err!(
233 partition_count,
234 generator_count,
235 "Partition count must match generator count: {} != {}",
236 partition_count,
237 generator_count
238 );
239 self.cache.partitioning = partitioning;
240 Ok(())
241 }
242
243 pub fn add_ordering(&mut self, ordering: impl IntoIterator<Item = PhysicalSortExpr>) {
244 self.cache
245 .eq_properties
246 .add_orderings(std::iter::once(ordering));
247 }
248
249 pub fn generators(&self) -> &Vec<Arc<RwLock<dyn LazyBatchGenerator>>> {
251 &self.batch_generators
252 }
253}
254
255impl fmt::Debug for LazyMemoryExec {
256 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
257 f.debug_struct("LazyMemoryExec")
258 .field("schema", &self.schema)
259 .field("batch_generators", &self.batch_generators)
260 .finish()
261 }
262}
263
264impl DisplayAs for LazyMemoryExec {
265 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
266 match t {
267 DisplayFormatType::Default | DisplayFormatType::Verbose => {
268 write!(
269 f,
270 "LazyMemoryExec: partitions={}, batch_generators=[{}]",
271 self.batch_generators.len(),
272 self.batch_generators
273 .iter()
274 .map(|g| g.read().to_string())
275 .collect::<Vec<_>>()
276 .join(", ")
277 )
278 }
279 DisplayFormatType::TreeRender => {
280 writeln!(
282 f,
283 "batch_generators={}",
284 self.batch_generators
285 .iter()
286 .map(|g| g.read().to_string())
287 .collect::<Vec<String>>()
288 .join(", ")
289 )?;
290 Ok(())
291 }
292 }
293 }
294}
295
296impl ExecutionPlan for LazyMemoryExec {
297 fn name(&self) -> &'static str {
298 "LazyMemoryExec"
299 }
300
301 fn as_any(&self) -> &dyn Any {
302 self
303 }
304
305 fn schema(&self) -> SchemaRef {
306 Arc::clone(&self.schema)
307 }
308
309 fn properties(&self) -> &PlanProperties {
310 &self.cache
311 }
312
313 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
314 vec![]
315 }
316
317 fn with_new_children(
318 self: Arc<Self>,
319 children: Vec<Arc<dyn ExecutionPlan>>,
320 ) -> Result<Arc<dyn ExecutionPlan>> {
321 assert_or_internal_err!(
322 children.is_empty(),
323 "Children cannot be replaced in LazyMemoryExec"
324 );
325 Ok(self)
326 }
327
328 fn execute(
329 &self,
330 partition: usize,
331 _context: Arc<TaskContext>,
332 ) -> Result<SendableRecordBatchStream> {
333 assert_or_internal_err!(
334 partition < self.batch_generators.len(),
335 "Invalid partition {} for LazyMemoryExec with {} partitions",
336 partition,
337 self.batch_generators.len()
338 );
339
340 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
341
342 let stream = LazyMemoryStream {
343 schema: Arc::clone(&self.schema),
344 projection: self.projection.clone(),
345 generator: Arc::clone(&self.batch_generators[partition]),
346 baseline_metrics,
347 };
348 Ok(Box::pin(cooperative(stream)))
349 }
350
351 fn metrics(&self) -> Option<MetricsSet> {
352 Some(self.metrics.clone_inner())
353 }
354
355 fn statistics(&self) -> Result<Statistics> {
356 Ok(Statistics::new_unknown(&self.schema))
357 }
358
359 fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
360 let generators = self
361 .generators()
362 .iter()
363 .map(|g| g.read().reset_state())
364 .collect::<Vec<_>>();
365 Ok(Arc::new(LazyMemoryExec {
366 schema: Arc::clone(&self.schema),
367 batch_generators: generators,
368 cache: self.cache.clone(),
369 metrics: ExecutionPlanMetricsSet::new(),
370 projection: self.projection.clone(),
371 }))
372 }
373}
374
/// Stream for one `LazyMemoryExec` partition that pulls batches from a single
/// generator.
pub struct LazyMemoryStream {
    /// Schema reported by [`RecordBatchStream::schema`].
    schema: SchemaRef,
    /// Optional column indices applied to each generated batch.
    projection: Option<Vec<usize>>,
    /// Generator shared with the owning plan; write-locked on every poll.
    generator: Arc<RwLock<dyn LazyBatchGenerator>>,
    /// Records compute time and poll outcomes for this partition.
    baseline_metrics: BaselineMetrics,
}
391
392impl Stream for LazyMemoryStream {
393 type Item = Result<RecordBatch>;
394
395 fn poll_next(
396 self: std::pin::Pin<&mut Self>,
397 _: &mut Context<'_>,
398 ) -> Poll<Option<Self::Item>> {
399 let _timer_guard = self.baseline_metrics.elapsed_compute().timer();
400 let batch = self.generator.write().generate_next_batch();
401
402 let poll = match batch {
403 Ok(Some(batch)) => {
404 let batch = match self.projection.as_ref() {
406 Some(columns) => batch.project(columns)?,
407 None => batch,
408 };
409 Poll::Ready(Some(Ok(batch)))
410 }
411 Ok(None) => Poll::Ready(None),
412 Err(e) => Poll::Ready(Some(Err(e))),
413 };
414
415 self.baseline_metrics.record_poll(poll)
416 }
417}
418
419impl RecordBatchStream for LazyMemoryStream {
420 fn schema(&self) -> SchemaRef {
421 Arc::clone(&self.schema)
422 }
423}
424
#[cfg(test)]
mod lazy_memory_tests {
    use super::*;
    use crate::common::collect;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Test generator that emits `max_batches` batches, each holding
    /// `batch_size` consecutive `Int64` values that continue from the previous
    /// batch.
    #[derive(Debug, Clone)]
    struct TestGenerator {
        counter: i64,
        max_batches: i64,
        batch_size: usize,
        schema: SchemaRef,
    }

    impl TestGenerator {
        /// Builds a generator positioned at its first batch.
        fn new(schema: &SchemaRef, max_batches: i64, batch_size: usize) -> Self {
            Self {
                counter: 0,
                max_batches,
                batch_size,
                schema: Arc::clone(schema),
            }
        }
    }

    impl fmt::Display for TestGenerator {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(
                f,
                "TestGenerator: counter={}, max_batches={}, batch_size={}",
                self.counter, self.max_batches, self.batch_size
            )
        }
    }

    impl LazyBatchGenerator for TestGenerator {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
            if self.counter >= self.max_batches {
                return Ok(None);
            }
            let width = self.batch_size as i64;
            let start = self.counter * width;
            let values = Int64Array::from_iter_values(start..start + width);
            self.counter += 1;
            let batch =
                RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(values)])?;
            Ok(Some(batch))
        }

        fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
            Arc::new(RwLock::new(Self::new(
                &self.schema,
                self.max_batches,
                self.batch_size,
            )))
        }
    }

    /// Schema with a single non-nullable `Int64` column named "a".
    fn int64_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
    }

    /// Wraps one generator into a single-partition `LazyMemoryExec`.
    fn single_partition_exec(
        schema: SchemaRef,
        generator: TestGenerator,
    ) -> Result<LazyMemoryExec> {
        LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])
    }

    /// Copies out the values of the first column, which must be `Int64`.
    fn int64_values(batch: &RecordBatch) -> Vec<i64> {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .values()
            .to_vec()
    }

    #[tokio::test]
    async fn test_lazy_memory_exec() -> Result<()> {
        let schema = int64_schema();
        let generator = TestGenerator::new(&schema, 3, 2);
        let exec = single_partition_exec(schema, generator)?;

        assert_eq!(exec.schema().fields().len(), 1);
        assert_eq!(exec.schema().field(0).name(), "a");

        let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
        let batches: Vec<_> = stream.collect::<Vec<_>>().await;
        assert_eq!(batches.len(), 3);

        let expected: [[i64; 2]; 3] = [[0, 1], [2, 3], [4, 5]];
        for (batch, want) in batches.iter().zip(expected) {
            assert_eq!(int64_values(batch.as_ref().unwrap()), want);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_lazy_memory_exec_invalid_partition() -> Result<()> {
        let schema = int64_schema();
        let generator = TestGenerator::new(&schema, 1, 1);
        let exec = single_partition_exec(schema, generator)?;

        let result = exec.execute(1, Arc::new(TaskContext::default()));
        assert!(matches!(
            result,
            Err(e) if e.to_string().contains("Invalid partition 1 for LazyMemoryExec with 1 partitions")
        ));

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_series_metrics_integration() -> Result<()> {
        // (total_rows, batch_size, expected_rows)
        let test_cases = [(10, 2, 10), (100, 10, 100), (5, 1, 5)];

        for (total_rows, batch_size, expected_rows) in test_cases {
            let schema = int64_schema();
            // Number of batches needed to cover `total_rows` (ceiling division).
            let max_batches = (total_rows + batch_size - 1) / batch_size;
            let generator = TestGenerator::new(&schema, max_batches, batch_size as usize);
            let exec = single_partition_exec(schema, generator)?;

            let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
            let batches = collect(stream).await?;

            let actual_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
            assert_eq!(actual_rows, expected_rows);

            let metrics = exec.metrics().unwrap();
            assert_eq!(metrics.output_rows().unwrap(), expected_rows);
            assert!(metrics.elapsed_compute().unwrap() > 0);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_lazy_memory_exec_reset_state() -> Result<()> {
        let schema = int64_schema();
        let generator = TestGenerator::new(&schema, 3, 2);
        let exec = Arc::new(single_partition_exec(schema, generator)?);

        let first_run = collect(exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        let exec_reset = exec.reset_state()?;
        let second_run =
            collect(exec_reset.execute(0, Arc::new(TaskContext::default()))?).await?;

        assert_eq!(first_run, second_run);

        Ok(())
    }
}