datafusion_physical_plan/
memory.rs1use std::any::Any;
21use std::fmt;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
27use crate::{
28 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
29 RecordBatchStream, SendableRecordBatchStream, Statistics,
30};
31
32use arrow::array::RecordBatch;
33use arrow::datatypes::SchemaRef;
34use datafusion_common::{internal_err, Result};
35use datafusion_execution::memory_pool::MemoryReservation;
36use datafusion_execution::TaskContext;
37use datafusion_physical_expr::EquivalenceProperties;
38
39use futures::Stream;
40use parking_lot::RwLock;
41
42pub struct MemoryStream {
44 data: Vec<RecordBatch>,
46 reservation: Option<MemoryReservation>,
48 schema: SchemaRef,
50 projection: Option<Vec<usize>>,
52 index: usize,
54 fetch: Option<usize>,
56}
57
58impl MemoryStream {
59 pub fn try_new(
61 data: Vec<RecordBatch>,
62 schema: SchemaRef,
63 projection: Option<Vec<usize>>,
64 ) -> Result<Self> {
65 Ok(Self {
66 data,
67 reservation: None,
68 schema,
69 projection,
70 index: 0,
71 fetch: None,
72 })
73 }
74
75 pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self {
77 self.reservation = Some(reservation);
78 self
79 }
80
81 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
83 self.fetch = fetch;
84 self
85 }
86}
87
88impl Stream for MemoryStream {
89 type Item = Result<RecordBatch>;
90
91 fn poll_next(
92 mut self: std::pin::Pin<&mut Self>,
93 _: &mut Context<'_>,
94 ) -> Poll<Option<Self::Item>> {
95 if self.index >= self.data.len() {
96 return Poll::Ready(None);
97 }
98 self.index += 1;
99 let batch = &self.data[self.index - 1];
100 let batch = match self.projection.as_ref() {
102 Some(columns) => batch.project(columns)?,
103 None => batch.clone(),
104 };
105
106 let Some(&fetch) = self.fetch.as_ref() else {
107 return Poll::Ready(Some(Ok(batch)));
108 };
109 if fetch == 0 {
110 return Poll::Ready(None);
111 }
112
113 let batch = if batch.num_rows() > fetch {
114 batch.slice(0, fetch)
115 } else {
116 batch
117 };
118 self.fetch = Some(fetch - batch.num_rows());
119 Poll::Ready(Some(Ok(batch)))
120 }
121
122 fn size_hint(&self) -> (usize, Option<usize>) {
123 (self.data.len(), Some(self.data.len()))
124 }
125}
126
127impl RecordBatchStream for MemoryStream {
128 fn schema(&self) -> SchemaRef {
130 Arc::clone(&self.schema)
131 }
132}
133
134pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display {
135 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>>;
137}
138
139pub struct LazyMemoryExec {
144 schema: SchemaRef,
146 batch_generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
148 cache: PlanProperties,
150 metrics: ExecutionPlanMetricsSet,
152}
153
154impl LazyMemoryExec {
155 pub fn try_new(
157 schema: SchemaRef,
158 generators: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
159 ) -> Result<Self> {
160 let cache = PlanProperties::new(
161 EquivalenceProperties::new(Arc::clone(&schema)),
162 Partitioning::RoundRobinBatch(generators.len()),
163 EmissionType::Incremental,
164 Boundedness::Bounded,
165 );
166 Ok(Self {
167 schema,
168 batch_generators: generators,
169 cache,
170 metrics: ExecutionPlanMetricsSet::new(),
171 })
172 }
173}
174
175impl fmt::Debug for LazyMemoryExec {
176 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
177 f.debug_struct("LazyMemoryExec")
178 .field("schema", &self.schema)
179 .field("batch_generators", &self.batch_generators)
180 .finish()
181 }
182}
183
184impl DisplayAs for LazyMemoryExec {
185 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
186 match t {
187 DisplayFormatType::Default | DisplayFormatType::Verbose => {
188 write!(
189 f,
190 "LazyMemoryExec: partitions={}, batch_generators=[{}]",
191 self.batch_generators.len(),
192 self.batch_generators
193 .iter()
194 .map(|g| g.read().to_string())
195 .collect::<Vec<_>>()
196 .join(", ")
197 )
198 }
199 DisplayFormatType::TreeRender => {
200 writeln!(
202 f,
203 "batch_generators={}",
204 self.batch_generators
205 .iter()
206 .map(|g| g.read().to_string())
207 .collect::<Vec<String>>()
208 .join(", ")
209 )?;
210 Ok(())
211 }
212 }
213 }
214}
215
216impl ExecutionPlan for LazyMemoryExec {
217 fn name(&self) -> &'static str {
218 "LazyMemoryExec"
219 }
220
221 fn as_any(&self) -> &dyn Any {
222 self
223 }
224
225 fn schema(&self) -> SchemaRef {
226 Arc::clone(&self.schema)
227 }
228
229 fn properties(&self) -> &PlanProperties {
230 &self.cache
231 }
232
233 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
234 vec![]
235 }
236
237 fn with_new_children(
238 self: Arc<Self>,
239 children: Vec<Arc<dyn ExecutionPlan>>,
240 ) -> Result<Arc<dyn ExecutionPlan>> {
241 if children.is_empty() {
242 Ok(self)
243 } else {
244 internal_err!("Children cannot be replaced in LazyMemoryExec")
245 }
246 }
247
248 fn execute(
249 &self,
250 partition: usize,
251 _context: Arc<TaskContext>,
252 ) -> Result<SendableRecordBatchStream> {
253 if partition >= self.batch_generators.len() {
254 return internal_err!(
255 "Invalid partition {} for LazyMemoryExec with {} partitions",
256 partition,
257 self.batch_generators.len()
258 );
259 }
260
261 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
262 Ok(Box::pin(LazyMemoryStream {
263 schema: Arc::clone(&self.schema),
264 generator: Arc::clone(&self.batch_generators[partition]),
265 baseline_metrics,
266 }))
267 }
268
269 fn metrics(&self) -> Option<MetricsSet> {
270 Some(self.metrics.clone_inner())
271 }
272
273 fn statistics(&self) -> Result<Statistics> {
274 Ok(Statistics::new_unknown(&self.schema))
275 }
276}
277
278pub struct LazyMemoryStream {
280 schema: SchemaRef,
281 generator: Arc<RwLock<dyn LazyBatchGenerator>>,
289 baseline_metrics: BaselineMetrics,
291}
292
293impl Stream for LazyMemoryStream {
294 type Item = Result<RecordBatch>;
295
296 fn poll_next(
297 self: std::pin::Pin<&mut Self>,
298 _: &mut Context<'_>,
299 ) -> Poll<Option<Self::Item>> {
300 let _timer_guard = self.baseline_metrics.elapsed_compute().timer();
301 let batch = self.generator.write().generate_next_batch();
302
303 let poll = match batch {
304 Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))),
305 Ok(None) => Poll::Ready(None),
306 Err(e) => Poll::Ready(Some(Err(e))),
307 };
308
309 self.baseline_metrics.record_poll(poll)
310 }
311}
312
313impl RecordBatchStream for LazyMemoryStream {
314 fn schema(&self) -> SchemaRef {
315 Arc::clone(&self.schema)
316 }
317}
318
#[cfg(test)]
mod lazy_memory_tests {
    use super::*;
    use crate::common::collect;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Deterministic generator for these tests.
    ///
    /// It emits `max_batches` batches, each holding `batch_size` consecutive
    /// `Int64` values that continue where the previous batch stopped
    /// (0, 1, 2, ...).
    #[derive(Debug, Clone)]
    struct TestGenerator {
        counter: i64,
        max_batches: i64,
        batch_size: usize,
        schema: SchemaRef,
    }

    impl fmt::Display for TestGenerator {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            let Self {
                counter,
                max_batches,
                batch_size,
                ..
            } = self;
            write!(
                f,
                "TestGenerator: counter={counter}, max_batches={max_batches}, batch_size={batch_size}"
            )
        }
    }

    impl LazyBatchGenerator for TestGenerator {
        fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
            if self.counter >= self.max_batches {
                return Ok(None);
            }

            let width = self.batch_size as i64;
            let first = self.counter * width;
            let values = Int64Array::from_iter_values(first..first + width);
            self.counter += 1;

            let batch =
                RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(values)])?;
            Ok(Some(batch))
        }
    }

    /// Single non-nullable `Int64` column named `a`, shared by every test.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
    }

    /// Wraps a fresh `TestGenerator` in a single-partition `LazyMemoryExec`.
    fn single_generator_exec(max_batches: i64, batch_size: usize) -> Result<LazyMemoryExec> {
        let schema = test_schema();
        let generator = TestGenerator {
            counter: 0,
            max_batches,
            batch_size,
            schema: Arc::clone(&schema),
        };
        LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])
    }

    /// Values of the first column of `batch`, which must be an `Int64Array`.
    fn first_column(batch: &RecordBatch) -> &[i64] {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("first column should be Int64")
            .values()
    }

    #[tokio::test]
    async fn test_lazy_memory_exec() -> Result<()> {
        let exec = single_generator_exec(3, 2)?;

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).name(), "a");

        let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
        let results: Vec<_> = stream.collect::<Vec<_>>().await;
        assert_eq!(results.len(), 3);

        let expected: [&[i64]; 3] = [&[0, 1], &[2, 3], &[4, 5]];
        for (result, want) in results.iter().zip(expected) {
            let batch = result.as_ref().unwrap();
            assert_eq!(first_column(batch), want);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_lazy_memory_exec_invalid_partition() -> Result<()> {
        let exec = single_generator_exec(1, 1)?;

        // Only partition 0 exists, so asking for partition 1 must fail.
        let message = match exec.execute(1, Arc::new(TaskContext::default())) {
            Ok(_) => panic!("executing a non-existent partition should fail"),
            Err(e) => e.to_string(),
        };
        assert!(message
            .contains("Invalid partition 1 for LazyMemoryExec with 1 partitions"));

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_series_metrics_integration() -> Result<()> {
        // (total_rows, batch_size, expected_rows)
        let test_cases: [(i64, i64, usize); 3] = [(10, 2, 10), (100, 10, 100), (5, 1, 5)];

        for (total_rows, batch_size, expected_rows) in test_cases {
            // Ceiling division: number of batches needed to cover `total_rows`.
            let max_batches = (total_rows + batch_size - 1) / batch_size;
            let exec = single_generator_exec(max_batches, batch_size as usize)?;

            let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
            let batches = collect(stream).await?;

            let actual_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
            assert_eq!(actual_rows, expected_rows);

            let metrics = exec.metrics().unwrap();
            assert_eq!(metrics.output_rows().unwrap(), expected_rows);
            assert!(metrics.elapsed_compute().unwrap() > 0);
        }

        Ok(())
    }
}