datafusion_physical_plan/spill/spill_pool.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use futures::{Stream, StreamExt};
19use std::collections::VecDeque;
20use std::sync::Arc;
21use std::task::Waker;
22
23use parking_lot::Mutex;
24
25use arrow::datatypes::SchemaRef;
26use arrow::record_batch::RecordBatch;
27use datafusion_common::Result;
28use datafusion_execution::disk_manager::RefCountedTempFile;
29use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
30
31use super::in_progress_spill_file::InProgressSpillFile;
32use super::spill_manager::SpillManager;
33
34/// Shared state between the writer and readers of a spill pool.
35/// This contains the queue of files and coordination state.
36///
37/// # Locking Design
38///
39/// This struct uses **fine-grained locking** with nested `Arc<Mutex<>>`:
40/// - `SpillPoolShared` is wrapped in `Arc<Mutex<>>` (outer lock)
41/// - Each `ActiveSpillFileShared` is wrapped in `Arc<Mutex<>>` (inner lock)
42///
43/// This enables:
44/// 1. **Short critical sections**: The outer lock is held only for queue operations
45/// 2. **I/O outside locks**: Disk I/O happens while holding only the file-specific lock
46/// 3. **Concurrent operations**: Reader can access the queue while writer does I/O
47///
48/// **Lock ordering discipline**: Never hold both locks simultaneously to prevent deadlock.
49/// Always: acquire outer lock → release outer lock → acquire inner lock (if needed).
struct SpillPoolShared {
    /// Queue of ALL files (including the current write file if it exists).
    /// Readers always read from the front of this queue (FIFO).
    /// Each file has its own lock to enable concurrent reader/writer access.
    /// Files are only removed by the reader, after it has fully consumed a
    /// finished file.
    files: VecDeque<Arc<Mutex<ActiveSpillFileShared>>>,
    /// SpillManager for creating files and tracking metrics
    spill_manager: Arc<SpillManager>,
    /// Pool-level waker to notify when new files are available (single reader)
    waker: Option<Waker>,
    /// Whether the writer has been dropped (no more files will be added)
    writer_dropped: bool,
    /// Writer's reference to the current file (shared by all cloned writers).
    /// Has its own lock to allow I/O without blocking queue access.
    /// `None` until the first batch is pushed, and again after a rotation
    /// until the next `push_batch` creates a fresh file.
    current_write_file: Option<Arc<Mutex<ActiveSpillFileShared>>>,
    /// Number of active writer clones. Only when this reaches zero should
    /// `writer_dropped` be set to true. This prevents premature EOF signaling
    /// when one writer clone is dropped while others are still active.
    active_writer_count: usize,
}
69
70impl SpillPoolShared {
71 /// Creates a new shared pool state
72 fn new(spill_manager: Arc<SpillManager>) -> Self {
73 Self {
74 files: VecDeque::new(),
75 spill_manager,
76 waker: None,
77 writer_dropped: false,
78 current_write_file: None,
79 active_writer_count: 1,
80 }
81 }
82
83 /// Registers a waker to be notified when new data is available (pool-level)
84 fn register_waker(&mut self, waker: Waker) {
85 self.waker = Some(waker);
86 }
87
88 /// Wakes the pool-level reader
89 fn wake(&mut self) {
90 if let Some(waker) = self.waker.take() {
91 waker.wake();
92 }
93 }
94}
95
96/// Writer for a spill pool. Provides coordinated write access with FIFO semantics.
97///
98/// Created by [`channel`]. See that function for architecture diagrams and usage examples.
99///
100/// The writer is `Clone`, allowing multiple writers to coordinate on the same pool.
101/// All clones share the same current write file and coordinate file rotation.
102/// The writer automatically manages file rotation based on the `max_file_size_bytes`
103/// configured in [`channel`]. When the last writer clone is dropped, it finalizes the
104/// current file so readers can access all written data.
pub struct SpillPoolWriter {
    /// Maximum size in bytes before rotating to a new file.
    /// Typically set from configuration `datafusion.execution.max_spill_file_size_bytes`.
    max_file_size_bytes: usize,
    /// Shared state with readers (includes current_write_file for coordination).
    /// Also used by `Clone`/`Drop` to track the number of live writer clones.
    shared: Arc<Mutex<SpillPoolShared>>,
}
112
113impl Clone for SpillPoolWriter {
114 fn clone(&self) -> Self {
115 // Increment the active writer count so that `writer_dropped` is only
116 // set to true when the *last* clone is dropped.
117 self.shared.lock().active_writer_count += 1;
118 Self {
119 max_file_size_bytes: self.max_file_size_bytes,
120 shared: Arc::clone(&self.shared),
121 }
122 }
123}
124
125impl SpillPoolWriter {
126 /// Spills a batch to the pool, rotating files when necessary.
127 ///
128 /// If the current file would exceed `max_file_size_bytes` after adding
129 /// this batch, the file is finalized and a new one is started.
130 ///
131 /// See [`channel`] for overall architecture and examples.
132 ///
133 /// # File Rotation Logic
134 ///
135 /// ```text
136 /// push_batch()
137 /// │
138 /// ▼
139 /// Current file exists?
140 /// │
141 /// ├─ No ──▶ Create new file ──▶ Add to shared queue
142 /// │ Wake readers
143 /// ▼
144 /// Write batch to current file
145 /// │
146 /// ▼
147 /// estimated_size > max_file_size_bytes?
148 /// │
149 /// ├─ No ──▶ Keep current file for next batch
150 /// │
151 /// ▼
152 /// Yes: finish() current file
153 /// Mark writer_finished = true
154 /// Wake readers
155 /// │
156 /// ▼
157 /// Next push_batch() creates new file
158 /// ```
159 ///
160 /// # Errors
161 ///
162 /// Returns an error if disk I/O fails or disk quota is exceeded.
163 pub fn push_batch(&self, batch: &RecordBatch) -> Result<()> {
164 if batch.num_rows() == 0 {
165 // Skip empty batches
166 return Ok(());
167 }
168
169 let batch_size = batch.get_array_memory_size();
170
171 // Fine-grained locking: Lock shared state briefly for queue access
172 let mut shared = self.shared.lock();
173
174 // Create new file if we don't have one yet
175 if shared.current_write_file.is_none() {
176 let spill_manager = Arc::clone(&shared.spill_manager);
177 // Release shared lock before disk I/O (fine-grained locking)
178 drop(shared);
179
180 let writer = spill_manager.create_in_progress_file("SpillPool")?;
181 // Clone the file so readers can access it immediately
182 let file = writer.file().expect("InProgressSpillFile should always have a file when it is first created").clone();
183
184 let file_shared = Arc::new(Mutex::new(ActiveSpillFileShared {
185 writer: Some(writer),
186 file: Some(file), // Set immediately so readers can access it
187 batches_written: 0,
188 estimated_size: 0,
189 writer_finished: false,
190 waker: None,
191 }));
192
193 // Re-acquire lock and push to shared queue
194 shared = self.shared.lock();
195 shared.files.push_back(Arc::clone(&file_shared));
196 shared.current_write_file = Some(file_shared);
197 shared.wake(); // Wake readers waiting for new files
198 }
199
200 let current_write_file = shared.current_write_file.take();
201 // Release shared lock before file I/O (fine-grained locking)
202 // This allows readers to access the queue while we do disk I/O
203 drop(shared);
204
205 // Write batch to current file - lock only the specific file
206 if let Some(current_file) = current_write_file {
207 // Now lock just this file for I/O (separate from shared lock)
208 let mut file_shared = current_file.lock();
209
210 // Append the batch
211 if let Some(ref mut writer) = file_shared.writer {
212 writer.append_batch(batch)?;
213 // make sure we flush the writer for readers
214 writer.flush()?;
215 file_shared.batches_written += 1;
216 file_shared.estimated_size += batch_size;
217 }
218
219 // Wake reader waiting on this specific file
220 file_shared.wake();
221
222 // Check if we need to rotate
223 let needs_rotation = file_shared.estimated_size > self.max_file_size_bytes;
224
225 if needs_rotation {
226 // Finish the IPC writer
227 if let Some(mut writer) = file_shared.writer.take() {
228 writer.finish()?;
229 }
230 // Mark as finished so readers know not to wait for more data
231 file_shared.writer_finished = true;
232 // Wake reader waiting on this file (it's now finished)
233 file_shared.wake();
234 // Don't put back current_write_file - let it rotate
235 } else {
236 // Release file lock
237 drop(file_shared);
238 // Put back the current file for further writing
239 let mut shared = self.shared.lock();
240 shared.current_write_file = Some(current_file);
241 }
242 }
243
244 Ok(())
245 }
246}
247
impl Drop for SpillPoolWriter {
    // Last-writer cleanup: finalize the in-progress spill file and signal EOF
    // to the reader. Follows the module's lock-ordering discipline: the
    // shared (outer) lock is always released before the per-file (inner)
    // lock is taken, and never held together with it.
    fn drop(&mut self) {
        let mut shared = self.shared.lock();

        shared.active_writer_count -= 1;
        let is_last_writer = shared.active_writer_count == 0;

        if !is_last_writer {
            // Other writer clones are still active; do not finalize or
            // signal EOF to readers.
            return;
        }

        // Finalize the current file when the last writer is dropped
        if let Some(current_file) = shared.current_write_file.take() {
            // Release shared lock before locking file
            drop(shared);

            let mut file_shared = current_file.lock();

            // Finish the current writer if it exists
            if let Some(mut writer) = file_shared.writer.take() {
                // Ignore errors on drop - we're in destructor
                let _ = writer.finish();
            }

            // Mark as finished so readers know not to wait for more data
            file_shared.writer_finished = true;

            // Wake reader waiting on this file (it's now finished)
            file_shared.wake();

            drop(file_shared);
            // Re-acquire the shared lock (file lock released above) so
            // `writer_dropped` can be published below.
            shared = self.shared.lock();
        }

        // Mark writer as dropped and wake pool-level readers
        shared.writer_dropped = true;
        shared.wake();
    }
}
289
290/// Creates a paired writer and reader for a spill pool with MPSC (multi-producer, single-consumer)
291/// semantics.
292///
293/// This is the recommended way to create a spill pool. The writer is `Clone`, allowing
294/// multiple producers to coordinate writes to the same pool. The reader can consume batches
295/// in FIFO order. The reader can start reading immediately after a writer appends a batch
296/// to the spill file, without waiting for the file to be sealed, while writers continue to
297/// write more data.
298///
299/// Internally this coordinates rotating spill files based on size limits, and
300/// handles asynchronous notification between the writer and reader using wakers.
301/// This ensures that we manage disk usage efficiently while allowing concurrent
302/// I/O between the writer and reader.
303///
304/// # Data Flow Overview
305///
306/// 1. Writer write batch `B0` to F1
307/// 2. Writer write batch `B1` to F1, notices the size limit exceeded, finishes F1.
308/// 3. Reader read `B0` from F1
309/// 4. Reader read `B1`, no more batch to read -> wait on the waker
310/// 5. Writer write batch `B2` to a new file `F2`, wake up the waiting reader.
311/// 6. Reader read `B2` from F2.
312/// 7. Repeat until writer is dropped.
313///
314/// # Architecture
315///
316/// ```text
317/// ┌─────────────────────────────────────────────────────────────────────────┐
318/// │ SpillPool │
319/// │ │
320/// │ Writer Side Shared State Reader Side │
321/// │ ─────────── ──────────── ─────────── │
322/// │ │
323/// │ SpillPoolWriter ┌────────────────────┐ SpillPoolReader │
324/// │ │ │ VecDeque<File> │ │ │
325/// │ │ │ ┌────┐┌────┐ │ │ │
326/// │ push_batch() │ │ F1 ││ F2 │ ... │ next().await │
327/// │ │ │ └────┘└────┘ │ │ │
328/// │ ▼ │ (FIFO order) │ ▼ │
329/// │ ┌─────────┐ │ │ ┌──────────┐ │
330/// │ │Current │───────▶│ Coordination: │◀───│ Current │ │
331/// │ │Write │ │ - Wakers │ │ Read │ │
332/// │ │File │ │ - Batch counts │ │ File │ │
333/// │ └─────────┘ │ - Writer status │ └──────────┘ │
334/// │ │ └────────────────────┘ │ │
335/// │ │ │ │
336/// │ Size > limit? Read all batches? │
337/// │ │ │ │
338/// │ ▼ ▼ │
339/// │ Rotate to new file Pop from queue │
340/// └─────────────────────────────────────────────────────────────────────────┘
341///
342/// Writer produces → Shared FIFO queue → Reader consumes
343/// ```
344///
345/// # File State Machine
346///
347/// Each file in the pool coordinates between writer and reader:
348///
349/// ```text
350/// Writer View Reader View
351/// ─────────── ───────────
352///
353/// Created writer: Some(..) batches_read: 0
354/// batches_written: 0 (waiting for data)
355/// │
356/// ▼
357/// Writing append_batch() Can read if:
358/// batches_written++ batches_read < batches_written
359/// wake readers
360/// │ │
361/// │ ▼
362/// ┌──────┴──────┐ poll_next() → batch
363/// │ │ batches_read++
364/// ▼ ▼
365/// Size > limit? More data?
366/// │ │
367/// │ └─▶ Yes ──▶ Continue writing
368/// ▼
369/// finish() Reader catches up:
370/// writer_finished = true batches_read == batches_written
371/// wake readers │
372/// │ ▼
373/// └─────────────────────▶ Returns Poll::Ready(None)
374/// File complete, pop from queue
375/// ```
376///
377/// # Arguments
378///
379/// * `max_file_size_bytes` - Maximum size per file before rotation. When a file
380/// exceeds this size, the writer automatically rotates to a new file.
381/// * `spill_manager` - Manager for file creation and metrics tracking
382///
383/// # Returns
384///
385/// A tuple of `(SpillPoolWriter, SendableRecordBatchStream)` that share the same
386/// underlying pool. The reader is returned as a stream for immediate use with
387/// async stream combinators.
388///
389/// # Example
390///
391/// ```
392/// use std::sync::Arc;
393/// use arrow::array::{ArrayRef, Int32Array};
394/// use arrow::datatypes::{DataType, Field, Schema};
395/// use arrow::record_batch::RecordBatch;
396/// use datafusion_execution::runtime_env::RuntimeEnv;
397/// use futures::StreamExt;
398///
399/// # use datafusion_physical_plan::spill::spill_pool;
400/// # use datafusion_physical_plan::spill::SpillManager; // Re-exported for doctests
401/// # use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
402/// #
403/// # #[tokio::main]
404/// # async fn main() -> datafusion_common::Result<()> {
405/// # // Setup for the example (typically comes from TaskContext in production)
406/// # let env = Arc::new(RuntimeEnv::default());
407/// # let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
408/// # let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
409/// # let spill_manager = Arc::new(SpillManager::new(env, metrics, schema.clone()));
410/// #
411/// // Create channel with 1MB file size limit
412/// let (writer, mut reader) = spill_pool::channel(1024 * 1024, spill_manager);
413///
414/// // Spawn writer and reader concurrently; writer wakes reader via wakers
415/// let writer_task = tokio::spawn(async move {
416/// for i in 0..5 {
417/// let array: ArrayRef = Arc::new(Int32Array::from(vec![i; 100]));
418/// let batch = RecordBatch::try_new(schema.clone(), vec![array]).unwrap();
419/// writer.push_batch(&batch)?;
420/// }
421/// // Explicitly drop writer to finalize the spill file and wake the reader
422/// drop(writer);
423/// datafusion_common::Result::<()>::Ok(())
424/// });
425///
426/// let reader_task = tokio::spawn(async move {
427/// let mut batches_read = 0;
428/// while let Some(result) = reader.next().await {
429/// let _batch = result?;
430/// batches_read += 1;
431/// }
432/// datafusion_common::Result::<usize>::Ok(batches_read)
433/// });
434///
435/// let (writer_res, reader_res) = tokio::join!(writer_task, reader_task);
436/// writer_res
437/// .map_err(|e| datafusion_common::DataFusionError::Execution(e.to_string()))??;
438/// let batches_read = reader_res
439/// .map_err(|e| datafusion_common::DataFusionError::Execution(e.to_string()))??;
440///
441/// assert_eq!(batches_read, 5);
442/// # Ok(())
443/// # }
444/// ```
445///
446/// # Why rotate files?
447///
448/// File rotation ensures we don't end up with unreferenced disk usage.
449/// If we used a single file for all spilled data, we would end up with
450/// unreferenced data at the beginning of the file that has already been read
451/// by readers but we can't delete because you can't truncate from the start of a file.
452///
453/// Consider the case of a query like `SELECT * FROM large_table WHERE false`.
454/// Obviously this query produces no output rows, but if we had a spilling operator
455/// in the middle of this query between the scan and the filter it would see the entire
456/// `large_table` flow through it and thus would spill all of that data to disk.
457/// So we'd end up using up to `size(large_table)` bytes of disk space.
458/// If instead we use file rotation, and as long as the readers can keep up with the writer,
459/// then we can ensure that once a file is fully read by all readers it can be deleted,
460/// thus bounding the maximum disk usage to roughly `max_file_size_bytes`.
461pub fn channel(
462 max_file_size_bytes: usize,
463 spill_manager: Arc<SpillManager>,
464) -> (SpillPoolWriter, SendableRecordBatchStream) {
465 let schema = Arc::clone(spill_manager.schema());
466 let shared = Arc::new(Mutex::new(SpillPoolShared::new(spill_manager)));
467
468 let writer = SpillPoolWriter {
469 max_file_size_bytes,
470 shared: Arc::clone(&shared),
471 };
472
473 let reader = SpillPoolReader::new(shared, schema);
474
475 (writer, Box::pin(reader))
476}
477
478/// Shared state between writer and readers for an active spill file.
479/// Protected by a Mutex to coordinate between concurrent readers and the writer.
struct ActiveSpillFileShared {
    /// Writer handle - taken (set to None) when finish() is called
    writer: Option<InProgressSpillFile>,
    /// The spill file handle, set as soon as the file is created (before the
    /// writer finishes) so the reader can start streaming immediately.
    /// Taken by the reader when creating a stream (the file stays open via file handles).
    file: Option<RefCountedTempFile>,
    /// Total number of batches written to this file
    batches_written: usize,
    /// Estimated size in bytes of data written to this file
    /// (sum of the in-memory sizes of appended batches, not on-disk bytes)
    estimated_size: usize,
    /// Whether the writer has finished writing to this file
    writer_finished: bool,
    /// Waker for reader waiting on this specific file (SPSC: only one reader)
    waker: Option<Waker>,
}
495
496impl ActiveSpillFileShared {
497 /// Registers a waker to be notified when new data is written to this file
498 fn register_waker(&mut self, waker: Waker) {
499 self.waker = Some(waker);
500 }
501
502 /// Wakes the reader waiting on this file
503 fn wake(&mut self) {
504 if let Some(waker) = self.waker.take() {
505 waker.wake();
506 }
507 }
508}
509
510/// Reader state for a SpillFile (owned by individual SpillFile instances).
511/// This is kept separate from the shared state to avoid holding locks during I/O.
struct SpillFileReader {
    /// The actual stream reading from disk
    stream: SendableRecordBatchStream,
    /// Number of batches this reader has consumed; compared against the
    /// writer's `batches_written` to decide whether more data is readable
    batches_read: usize,
}
518
/// Stream over a single spill file, coordinating with the (possibly still
/// active) writer through the shared per-file state.
struct SpillFile {
    /// Shared coordination state (contains writer and batch counts)
    shared: Arc<Mutex<ActiveSpillFileShared>>,
    /// Reader state (lazy-initialized, owned by this SpillFile)
    reader: Option<SpillFileReader>,
    /// Spill manager for creating readers
    spill_manager: Arc<SpillManager>,
}
527
impl Stream for SpillFile {
    type Item = Result<RecordBatch>;

    // Yields batches as the writer flushes them; returns `None` only once the
    // writer finished this file AND all written batches have been consumed.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        // Step 1: Lock shared state and check coordination
        let (should_read, file) = {
            let mut shared = self.shared.lock();

            // Determine if we can read
            let batches_read = self.reader.as_ref().map_or(0, |r| r.batches_read);

            if batches_read < shared.batches_written {
                // More data available to read - take the file if we don't have a reader yet
                let file = if self.reader.is_none() {
                    shared.file.take()
                } else {
                    None
                };
                (true, file)
            } else if shared.writer_finished {
                // No more data and writer is done - EOF
                return Poll::Ready(None);
            } else {
                // Caught up to writer, but writer still active - register waker and wait.
                // The writer wakes this file-level waker after each append/finish.
                shared.register_waker(cx.waker().clone());
                return Poll::Pending;
            }
        }; // Lock released here

        // Step 2: Lazy-create reader stream if needed (I/O done without the lock)
        if self.reader.is_none() && should_read {
            if let Some(file) = file {
                // we want this unbuffered because files are actively being written to
                match self
                    .spill_manager
                    .read_spill_as_stream_unbuffered(file, None)
                {
                    Ok(stream) => {
                        self.reader = Some(SpillFileReader {
                            stream,
                            batches_read: 0,
                        });
                    }
                    Err(e) => return Poll::Ready(Some(Err(e))),
                }
            } else {
                // File not available yet (writer hasn't finished or already taken)
                // Register waker and wait for file to be ready
                let mut shared = self.shared.lock();
                shared.register_waker(cx.waker().clone());
                return Poll::Pending;
            }
        }

        // Step 3: Poll the reader stream (no lock held)
        if let Some(reader) = &mut self.reader {
            match reader.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(batch))) => {
                    // Successfully read a batch - increment counter
                    reader.batches_read += 1;
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => {
                    // Stream exhausted unexpectedly
                    // This shouldn't happen if coordination is correct, but handle gracefully
                    Poll::Ready(None)
                }
                Poll::Pending => Poll::Pending,
            }
        } else {
            // Should not reach here, but handle gracefully
            Poll::Ready(None)
        }
    }
}
609
610/// A stream that reads from a SpillPool in FIFO order.
611///
612/// Created by [`channel`]. See that function for architecture diagrams and usage examples.
613///
614/// The stream automatically handles file rotation and reads from completed files.
615/// When no data is available, it returns `Poll::Pending` and registers a waker to
616/// be notified when the writer produces more data.
617///
618/// # Infinite Stream Semantics
619///
620/// This stream never returns `None` (`Poll::Ready(None)`) on its own - it will keep
621/// waiting for the writer to produce more data. The stream ends only when:
622/// - The reader is dropped
623/// - The writer is dropped AND all queued data has been consumed
624///
625/// This makes it suitable for continuous streaming scenarios where the writer may
626/// produce data intermittently.
pub struct SpillPoolReader {
    /// Shared reference to the spill pool
    shared: Arc<Mutex<SpillPoolShared>>,
    /// Current SpillFile we're reading from (None until the first file is
    /// pulled from the queue, and again after a finished file is popped)
    current_file: Option<SpillFile>,
    /// Schema of the spilled data
    schema: SchemaRef,
}
635
636impl SpillPoolReader {
637 /// Creates a new reader from shared pool state.
638 ///
639 /// This is private - use the `channel()` function to create a reader/writer pair.
640 ///
641 /// # Arguments
642 ///
643 /// * `shared` - Shared reference to the pool state
644 fn new(shared: Arc<Mutex<SpillPoolShared>>, schema: SchemaRef) -> Self {
645 Self {
646 shared,
647 current_file: None,
648 schema,
649 }
650 }
651}
652
impl Stream for SpillPoolReader {
    type Item = Result<RecordBatch>;

    // Drains files from the pool in FIFO order, advancing to the next queued
    // file whenever the current one is fully consumed.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        loop {
            // If we have a current file, try to read from it
            if let Some(ref mut file) = self.current_file {
                match file.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(batch))) => {
                        // Got a batch, return it
                        return Poll::Ready(Some(Ok(batch)));
                    }
                    Poll::Ready(Some(Err(e))) => {
                        // Error reading batch
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Ready(None) => {
                        // Current file stream exhausted
                        // Check if this file is marked as writer_finished
                        let writer_finished = { file.shared.lock().writer_finished };

                        if writer_finished {
                            // File is complete, pop it from the queue and move to next.
                            // Single consumer, so the front is always our current file.
                            let mut shared = self.shared.lock();
                            shared.files.pop_front();
                            drop(shared); // Release lock

                            // Clear current file and continue loop to get next file
                            self.current_file = None;
                            continue;
                        } else {
                            // Stream exhausted but writer not finished - unexpected
                            // This shouldn't happen with proper coordination
                            return Poll::Ready(None);
                        }
                    }
                    Poll::Pending => {
                        // File not ready yet (waiting for writer).
                        // The SpillFile already registered the file-level waker;
                        // also register the pool-level waker so queue-level
                        // changes (new files) wake us too.
                        let mut shared = self.shared.lock();
                        shared.register_waker(cx.waker().clone());
                        return Poll::Pending;
                    }
                }
            }

            // No current file, need to get the next one
            let mut shared = self.shared.lock();

            // Peek at the front of the queue (don't pop yet)
            if let Some(file_shared) = shared.files.front() {
                // Create a SpillFile from the shared state
                let spill_manager = Arc::clone(&shared.spill_manager);
                let file_shared = Arc::clone(file_shared);
                drop(shared); // Release lock before creating SpillFile

                self.current_file = Some(SpillFile {
                    shared: file_shared,
                    reader: None,
                    spill_manager,
                });

                // Continue loop to poll the new file
                continue;
            }

            // No files in queue - check if writer is done
            if shared.writer_dropped {
                // Writer is done and no more files will be added - EOF
                return Poll::Ready(None);
            }

            // Writer still active, register waker that will get notified when new files are added
            shared.register_waker(cx.waker().clone());
            return Poll::Pending;
        }
    }
}
736
impl RecordBatchStream for SpillPoolReader {
    /// Schema of every batch produced by this stream (fixed at pool creation).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
742
743#[cfg(test)]
744mod tests {
745 use super::*;
746 use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
747 use arrow::array::{ArrayRef, Int32Array};
748 use arrow::datatypes::{DataType, Field, Schema};
749 use datafusion_common_runtime::SpawnedTask;
750 use datafusion_execution::runtime_env::RuntimeEnv;
751 use futures::StreamExt;
752
753 fn create_test_schema() -> SchemaRef {
754 Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
755 }
756
757 fn create_test_batch(start: i32, count: usize) -> RecordBatch {
758 let schema = create_test_schema();
759 let a: ArrayRef = Arc::new(Int32Array::from(
760 (start..start + count as i32).collect::<Vec<_>>(),
761 ));
762 RecordBatch::try_new(schema, vec![a]).unwrap()
763 }
764
765 fn create_spill_channel(
766 max_file_size: usize,
767 ) -> (SpillPoolWriter, SendableRecordBatchStream) {
768 let env = Arc::new(RuntimeEnv::default());
769 let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
770 let schema = create_test_schema();
771 let spill_manager = Arc::new(SpillManager::new(env, metrics, schema));
772
773 channel(max_file_size, spill_manager)
774 }
775
776 fn create_spill_channel_with_metrics(
777 max_file_size: usize,
778 ) -> (SpillPoolWriter, SendableRecordBatchStream, SpillMetrics) {
779 let env = Arc::new(RuntimeEnv::default());
780 let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
781 let schema = create_test_schema();
782 let spill_manager = Arc::new(SpillManager::new(env, metrics.clone(), schema));
783
784 let (writer, reader) = channel(max_file_size, spill_manager);
785 (writer, reader, metrics)
786 }
787
    #[tokio::test]
    async fn test_basic_write_and_read() -> Result<()> {
        // Batches written to the pool are readable immediately — no need to
        // drop the writer (finalize the file) before reading.
        let (writer, mut reader) = create_spill_channel(1024 * 1024);

        // Write one batch
        let batch1 = create_test_batch(0, 10);
        writer.push_batch(&batch1)?;

        // Read the batch
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 10);

        // Write another batch
        let batch2 = create_test_batch(10, 5);
        writer.push_batch(&batch2)?;
        // Read the second batch
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);

        Ok(())
    }
809
    #[tokio::test]
    async fn test_single_batch_write_read() -> Result<()> {
        // Round-trips a single batch and verifies the cell values survive
        // the spill-to-disk / read-back cycle intact.
        let (writer, mut reader) = create_spill_channel(1024 * 1024);

        // Write one batch
        let batch = create_test_batch(0, 5);
        writer.push_batch(&batch)?;

        // Read it back
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);

        // Verify the actual data
        let col = result
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 0);
        assert_eq!(col.value(4), 4);

        Ok(())
    }
833
    #[tokio::test]
    async fn test_multiple_batches_sequential() -> Result<()> {
        // Verifies FIFO ordering: batches come back in exactly the order
        // they were pushed (checked via each batch's first cell value).
        let (writer, mut reader) = create_spill_channel(1024 * 1024);

        // Write multiple batches
        for i in 0..5 {
            let batch = create_test_batch(i * 10, 10);
            writer.push_batch(&batch)?;
        }

        // Read all batches and verify FIFO order
        for i in 0..5 {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), 10);

            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(col.value(0), i * 10, "Batch {i} not in FIFO order");
        }

        Ok(())
    }
859
    #[tokio::test]
    async fn test_empty_writer() -> Result<()> {
        // `_writer` is intentionally kept alive: with no batches written and
        // the writer not yet dropped, the reader must stay pending (not EOF).
        let (_writer, reader) = create_spill_channel(1024 * 1024);

        // Reader should pend since no batches were written
        let mut reader = reader;
        let result =
            tokio::time::timeout(std::time::Duration::from_millis(100), reader.next())
                .await;

        assert!(result.is_err(), "Reader should timeout on empty writer");

        Ok(())
    }
874
    #[tokio::test]
    async fn test_empty_batch_skipping() -> Result<()> {
        // push_batch silently drops zero-row batches, so only the non-empty
        // batch should ever reach the reader.
        let (writer, mut reader) = create_spill_channel(1024 * 1024);

        // Write empty batch
        let empty_batch = create_test_batch(0, 0);
        writer.push_batch(&empty_batch)?;

        // Write non-empty batch
        let batch = create_test_batch(0, 5);
        writer.push_batch(&batch)?;

        // Should only read the non-empty batch
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);

        Ok(())
    }
893
    #[tokio::test]
    async fn test_rotation_triggered_by_size() -> Result<()> {
        // Exercises the rotation path: the size limit fits exactly one batch,
        // so the second push finalizes file 1, and the third push lazily
        // creates file 2. Metrics are used to observe the rotation.
        // Set a small max_file_size to trigger rotation after one batch
        let batch1 = create_test_batch(0, 10);
        let batch_size = batch1.get_array_memory_size() + 1;

        let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);

        // Write first batch (should fit in first file)
        writer.push_batch(&batch1)?;

        // Check metrics after first batch - file created but not finalized yet
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should have created 1 file after first batch"
        );
        assert_eq!(
            metrics.spilled_bytes.value(),
            0,
            "Spilled bytes should be 0 before file finalization"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            10,
            "Should have spilled 10 rows from first batch"
        );

        // Write second batch (should trigger rotation - finalize first file)
        let batch2 = create_test_batch(10, 10);
        assert!(
            batch2.get_array_memory_size() <= batch_size,
            "batch2 size {} exceeds limit {batch_size}",
            batch2.get_array_memory_size(),
        );
        assert!(
            batch1.get_array_memory_size() + batch2.get_array_memory_size() > batch_size,
            "Combined size {} does not exceed limit to trigger rotation",
            batch1.get_array_memory_size() + batch2.get_array_memory_size()
        );
        writer.push_batch(&batch2)?;

        // Check metrics after rotation - first file finalized, but second file not created yet
        // (new file created lazily on next push_batch call)
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should still have 1 file (second file not created until next write)"
        );
        assert!(
            metrics.spilled_bytes.value() > 0,
            "Spilled bytes should be > 0 after first file finalized (got {})",
            metrics.spilled_bytes.value()
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            20,
            "Should have spilled 20 total rows (10 + 10)"
        );

        // Write a third batch to confirm rotation occurred (creates second file)
        let batch3 = create_test_batch(20, 5);
        writer.push_batch(&batch3)?;

        // Now check that second file was created
        assert_eq!(
            metrics.spill_file_count.value(),
            2,
            "Should have created 2 files after writing to new file"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            25,
            "Should have spilled 25 total rows (10 + 10 + 5)"
        );

        // Read all three batches
        let result1 = reader.next().await.unwrap()?;
        assert_eq!(result1.num_rows(), 10);

        let result2 = reader.next().await.unwrap()?;
        assert_eq!(result2.num_rows(), 10);

        let result3 = reader.next().await.unwrap()?;
        assert_eq!(result3.num_rows(), 5);

        Ok(())
    }
982
983 #[tokio::test]
984 async fn test_multiple_rotations() -> Result<()> {
985 let batches = (0..10)
986 .map(|i| create_test_batch(i * 10, 10))
987 .collect::<Vec<_>>();
988
989 let batch_size = batches[0].get_array_memory_size() * 2 + 1;
990
991 // Very small max_file_size to force frequent rotations
992 let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);
993
994 // Write many batches to cause multiple rotations
995 for i in 0..10 {
996 let batch = create_test_batch(i * 10, 10);
997 writer.push_batch(&batch)?;
998 }
999
1000 // Check metrics after all writes - should have multiple files due to rotations
1001 // With batch_size = 2 * one_batch + 1, each file fits ~2 batches before rotating
1002 // 10 batches should create multiple files (exact count depends on rotation timing)
1003 let file_count = metrics.spill_file_count.value();
1004 assert!(
1005 file_count >= 4,
1006 "Should have created at least 4 files with multiple rotations (got {file_count})"
1007 );
1008 assert!(
1009 metrics.spilled_bytes.value() > 0,
1010 "Spilled bytes should be > 0 after rotations (got {})",
1011 metrics.spilled_bytes.value()
1012 );
1013 assert_eq!(
1014 metrics.spilled_rows.value(),
1015 100,
1016 "Should have spilled 100 total rows (10 batches * 10 rows)"
1017 );
1018
1019 // Read all batches and verify order
1020 for i in 0..10 {
1021 let result = reader.next().await.unwrap()?;
1022 assert_eq!(result.num_rows(), 10);
1023
1024 let col = result
1025 .column(0)
1026 .as_any()
1027 .downcast_ref::<Int32Array>()
1028 .unwrap();
1029 assert_eq!(
1030 col.value(0),
1031 i * 10,
1032 "Batch {i} not in correct order after rotations"
1033 );
1034 }
1035
1036 Ok(())
1037 }
1038
1039 #[tokio::test]
1040 async fn test_single_batch_larger_than_limit() -> Result<()> {
1041 // Very small limit
1042 let (writer, mut reader, metrics) = create_spill_channel_with_metrics(100);
1043
1044 // Write a batch that exceeds the limit
1045 let large_batch = create_test_batch(0, 100);
1046 writer.push_batch(&large_batch)?;
1047
1048 // Check metrics after large batch - should trigger rotation immediately
1049 assert_eq!(
1050 metrics.spill_file_count.value(),
1051 1,
1052 "Should have created 1 file for large batch"
1053 );
1054 assert_eq!(
1055 metrics.spilled_rows.value(),
1056 100,
1057 "Should have spilled 100 rows from large batch"
1058 );
1059
1060 // Should still write and read successfully
1061 let result = reader.next().await.unwrap()?;
1062 assert_eq!(result.num_rows(), 100);
1063
1064 // Next batch should go to a new file
1065 let batch2 = create_test_batch(100, 10);
1066 writer.push_batch(&batch2)?;
1067
1068 // Check metrics after second batch - should have rotated to a new file
1069 assert_eq!(
1070 metrics.spill_file_count.value(),
1071 2,
1072 "Should have created 2 files after rotation"
1073 );
1074 assert_eq!(
1075 metrics.spilled_rows.value(),
1076 110,
1077 "Should have spilled 110 total rows (100 + 10)"
1078 );
1079
1080 let result2 = reader.next().await.unwrap()?;
1081 assert_eq!(result2.num_rows(), 10);
1082
1083 Ok(())
1084 }
1085
1086 #[tokio::test]
1087 async fn test_very_small_max_file_size() -> Result<()> {
1088 // Test with just 1 byte max (extreme case)
1089 let (writer, mut reader) = create_spill_channel(1);
1090
1091 // Any batch will exceed this limit
1092 let batch = create_test_batch(0, 5);
1093 writer.push_batch(&batch)?;
1094
1095 // Should still work
1096 let result = reader.next().await.unwrap()?;
1097 assert_eq!(result.num_rows(), 5);
1098
1099 Ok(())
1100 }
1101
1102 #[tokio::test]
1103 async fn test_exact_size_boundary() -> Result<()> {
1104 // Create a batch and measure its approximate size
1105 let batch = create_test_batch(0, 10);
1106 let batch_size = batch.get_array_memory_size();
1107
1108 // Set max_file_size to exactly the batch size
1109 let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);
1110
1111 // Write first batch (exactly at the size limit)
1112 writer.push_batch(&batch)?;
1113
1114 // Check metrics after first batch - should NOT rotate yet (size == limit, not >)
1115 assert_eq!(
1116 metrics.spill_file_count.value(),
1117 1,
1118 "Should have created 1 file after first batch at exact boundary"
1119 );
1120 assert_eq!(
1121 metrics.spilled_rows.value(),
1122 10,
1123 "Should have spilled 10 rows from first batch"
1124 );
1125
1126 // Write second batch (exceeds the limit, should trigger rotation)
1127 let batch2 = create_test_batch(10, 10);
1128 writer.push_batch(&batch2)?;
1129
1130 // Check metrics after second batch - rotation triggered, first file finalized
1131 // Note: second file not created yet (lazy creation on next write)
1132 assert_eq!(
1133 metrics.spill_file_count.value(),
1134 1,
1135 "Should still have 1 file after rotation (second file created lazily)"
1136 );
1137 assert_eq!(
1138 metrics.spilled_rows.value(),
1139 20,
1140 "Should have spilled 20 total rows (10 + 10)"
1141 );
1142 // Verify first file was finalized by checking spilled_bytes
1143 assert!(
1144 metrics.spilled_bytes.value() > 0,
1145 "Spilled bytes should be > 0 after file finalization (got {})",
1146 metrics.spilled_bytes.value()
1147 );
1148
1149 // Both should be readable
1150 let result1 = reader.next().await.unwrap()?;
1151 assert_eq!(result1.num_rows(), 10);
1152
1153 let result2 = reader.next().await.unwrap()?;
1154 assert_eq!(result2.num_rows(), 10);
1155
1156 // Spill another batch, now we should see the second file created
1157 let batch3 = create_test_batch(20, 5);
1158 writer.push_batch(&batch3)?;
1159 assert_eq!(
1160 metrics.spill_file_count.value(),
1161 2,
1162 "Should have created 2 files after writing to new file"
1163 );
1164
1165 Ok(())
1166 }
1167
1168 #[tokio::test]
1169 async fn test_concurrent_reader_writer() -> Result<()> {
1170 let (writer, mut reader) = create_spill_channel(1024 * 1024);
1171
1172 // Spawn writer task
1173 let writer_handle = SpawnedTask::spawn(async move {
1174 for i in 0..10 {
1175 let batch = create_test_batch(i * 10, 10);
1176 writer.push_batch(&batch).unwrap();
1177 // Small delay to simulate real concurrent work
1178 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
1179 }
1180 });
1181
1182 // Reader task (runs concurrently)
1183 let reader_handle = SpawnedTask::spawn(async move {
1184 let mut count = 0;
1185 for i in 0..10 {
1186 let result = reader.next().await.unwrap().unwrap();
1187 assert_eq!(result.num_rows(), 10);
1188
1189 let col = result
1190 .column(0)
1191 .as_any()
1192 .downcast_ref::<Int32Array>()
1193 .unwrap();
1194 assert_eq!(col.value(0), i * 10);
1195 count += 1;
1196 }
1197 count
1198 });
1199
1200 // Wait for both to complete
1201 writer_handle.await.unwrap();
1202 let batches_read = reader_handle.await.unwrap();
1203 assert_eq!(batches_read, 10);
1204
1205 Ok(())
1206 }
1207
    /// Verifies the wake-up path: a reader that is pending on an empty pool
    /// must be woken by a subsequent write, and the observed event order
    /// must interleave as write → read for each batch.
    #[tokio::test]
    async fn test_reader_catches_up_to_writer() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);

        // Oneshot channels sequence the test deterministically:
        // `reader_waiting` fires once the reader task is about to pend, and
        // `first_read_done` fires once it has consumed the first batch.
        let (reader_waiting_tx, reader_waiting_rx) = tokio::sync::oneshot::channel();
        let (first_read_done_tx, first_read_done_rx) = tokio::sync::oneshot::channel();

        // Event log entries recorded by both sides; the final assertion
        // checks their exact interleaving.
        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
        enum ReadWriteEvent {
            ReadStart,
            Read(usize),
            Write(usize),
        }

        let events = Arc::new(Mutex::new(vec![]));
        // Start reader first (will pend)
        let reader_events = Arc::clone(&events);
        let reader_handle = SpawnedTask::spawn(async move {
            reader_events.lock().push(ReadWriteEvent::ReadStart);
            reader_waiting_tx
                .send(())
                .expect("reader_waiting channel closed unexpectedly");
            // First read: blocks until the main task writes a batch.
            let result = reader.next().await.unwrap().unwrap();
            reader_events
                .lock()
                .push(ReadWriteEvent::Read(result.num_rows()));
            first_read_done_tx
                .send(())
                .expect("first_read_done channel closed unexpectedly");
            // Second read: again waits for the main task's second write.
            let result = reader.next().await.unwrap().unwrap();
            reader_events
                .lock()
                .push(ReadWriteEvent::Read(result.num_rows()));
        });

        // Wait until the reader is pending on the first batch
        reader_waiting_rx
            .await
            .expect("reader should signal when waiting");

        // Now write a batch (should wake the reader)
        let batch = create_test_batch(0, 5);
        events.lock().push(ReadWriteEvent::Write(batch.num_rows()));
        writer.push_batch(&batch)?;

        // Wait for the reader to finish the first read before allowing the
        // second write. This ensures deterministic ordering of events:
        // 1. The reader starts and pends on the first `next()`
        // 2. The first write wakes the reader
        // 3. The reader processes the first batch and signals completion
        // 4. The second write is issued, ensuring consistent event ordering
        first_read_done_rx
            .await
            .expect("reader should signal when first read completes");

        // Write another batch
        let batch = create_test_batch(5, 10);
        events.lock().push(ReadWriteEvent::Write(batch.num_rows()));
        writer.push_batch(&batch)?;

        // Reader should complete
        reader_handle.await.unwrap();
        // The log must show strict write-then-read alternation (5-row batch
        // first, then the 10-row batch).
        let events = events.lock().clone();
        assert_eq!(
            events,
            vec![
                ReadWriteEvent::ReadStart,
                ReadWriteEvent::Write(5),
                ReadWriteEvent::Read(5),
                ReadWriteEvent::Write(10),
                ReadWriteEvent::Read(10)
            ]
        );

        Ok(())
    }
1284
1285 #[tokio::test]
1286 async fn test_reader_starts_after_writer_finishes() -> Result<()> {
1287 let (writer, reader) = create_spill_channel(128);
1288
1289 // Writer writes all data
1290 for i in 0..5 {
1291 let batch = create_test_batch(i * 10, 10);
1292 writer.push_batch(&batch)?;
1293 }
1294
1295 drop(writer);
1296
1297 // Now start reader
1298 let mut reader = reader;
1299 let mut count = 0;
1300 for i in 0..5 {
1301 let result = reader.next().await.unwrap()?;
1302 assert_eq!(result.num_rows(), 10);
1303
1304 let col = result
1305 .column(0)
1306 .as_any()
1307 .downcast_ref::<Int32Array>()
1308 .unwrap();
1309 assert_eq!(col.value(0), i * 10);
1310 count += 1;
1311 }
1312
1313 assert_eq!(count, 5, "Should read all batches after writer finishes");
1314
1315 Ok(())
1316 }
1317
1318 #[tokio::test]
1319 async fn test_writer_drop_finalizes_file() -> Result<()> {
1320 let env = Arc::new(RuntimeEnv::default());
1321 let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
1322 let schema = create_test_schema();
1323 let spill_manager =
1324 Arc::new(SpillManager::new(Arc::clone(&env), metrics.clone(), schema));
1325
1326 let (writer, mut reader) = channel(1024 * 1024, spill_manager);
1327
1328 // Write some batches
1329 for i in 0..5 {
1330 let batch = create_test_batch(i * 10, 10);
1331 writer.push_batch(&batch)?;
1332 }
1333
1334 // Check metrics before drop - spilled_bytes should be 0 since file isn't finalized yet
1335 let spilled_bytes_before = metrics.spilled_bytes.value();
1336 assert_eq!(
1337 spilled_bytes_before, 0,
1338 "Spilled bytes should be 0 before writer is dropped"
1339 );
1340
1341 // Explicitly drop the writer - this should finalize the current file
1342 drop(writer);
1343
1344 // Check metrics after drop - spilled_bytes should be > 0 now
1345 let spilled_bytes_after = metrics.spilled_bytes.value();
1346 assert!(
1347 spilled_bytes_after > 0,
1348 "Spilled bytes should be > 0 after writer is dropped (got {spilled_bytes_after})"
1349 );
1350
1351 // Verify reader can still read all batches
1352 let mut count = 0;
1353 for i in 0..5 {
1354 let result = reader.next().await.unwrap()?;
1355 assert_eq!(result.num_rows(), 10);
1356
1357 let col = result
1358 .column(0)
1359 .as_any()
1360 .downcast_ref::<Int32Array>()
1361 .unwrap();
1362 assert_eq!(col.value(0), i * 10);
1363 count += 1;
1364 }
1365
1366 assert_eq!(count, 5, "Should read all batches after writer is dropped");
1367
1368 Ok(())
1369 }
1370
    /// Verifies that the reader stays alive as long as any writer clone exists.
    ///
    /// `SpillPoolWriter` is `Clone`, and in non-preserve-order repartitioning
    /// mode multiple input partition tasks share clones of the same writer.
    /// The reader must not see EOF until **all** clones have been dropped,
    /// even if the queue is temporarily empty between writes from different
    /// clones.
    ///
    /// The test sequence is:
    ///
    /// 1. writer1 writes a batch, then is dropped.
    /// 2. The reader consumes that batch (queue is now empty).
    /// 3. writer2 (still alive) writes a batch.
    /// 4. The reader must see that batch.
    /// 5. EOF is only signalled after writer2 is also dropped.
    #[tokio::test]
    async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> {
        let (writer1, mut reader) = create_spill_channel(1024 * 1024);
        // Second handle to the same pool; EOF requires BOTH to be dropped.
        let writer2 = writer1.clone();

        // Synchronization: tell writer2 when it may proceed.
        let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>();

        // Spawn writer2 — it waits for the signal before writing.
        let writer2_handle = SpawnedTask::spawn(async move {
            proceed_rx.await.unwrap();
            writer2.push_batch(&create_test_batch(10, 10)).unwrap();
            // writer2 is dropped here (last clone → true EOF)
        });

        // Writer1 writes one batch, then drops.
        writer1.push_batch(&create_test_batch(0, 10))?;
        drop(writer1);

        // Read writer1's batch.
        let batch1 = reader.next().await.unwrap()?;
        assert_eq!(batch1.num_rows(), 10);
        // First column value 0 identifies this as writer1's batch.
        let col = batch1
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 0);

        // Signal writer2 to write its batch. It will execute when the
        // current task yields (i.e. when reader.next() returns Pending).
        proceed_tx.send(()).unwrap();

        // The reader should wait (Pending) for writer2's data, not EOF.
        // The timeout guards against a hang if wake-up is broken.
        let batch2 =
            tokio::time::timeout(std::time::Duration::from_secs(5), reader.next())
                .await
                .expect("Reader timed out — should not hang");

        assert!(
            batch2.is_some(),
            "Reader must not return EOF while a writer clone is still alive"
        );
        let batch2 = batch2.unwrap()?;
        assert_eq!(batch2.num_rows(), 10);
        // First column value 10 identifies this as writer2's batch.
        let col = batch2
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 10);

        writer2_handle.await.unwrap();

        // All writers dropped — reader should see real EOF now.
        assert!(reader.next().await.is_none());

        Ok(())
    }
1445
1446 #[tokio::test]
1447 async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> {
1448 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1449
1450 // Test configuration
1451 const NUM_BATCHES: usize = 3;
1452 const ROWS_PER_BATCH: usize = 100;
1453
1454 // Step 1: Create a test batch and measure its size
1455 let batch = create_test_batch(0, ROWS_PER_BATCH);
1456 let batch_size = batch.get_array_memory_size();
1457
1458 // Step 2: Configure file rotation to approximately 1 batch per file
1459 // Create a custom RuntimeEnv so we can access the DiskManager
1460 let runtime = Arc::new(RuntimeEnvBuilder::default().build()?);
1461 let disk_manager = Arc::clone(&runtime.disk_manager);
1462
1463 let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
1464 let schema = create_test_schema();
1465 let spill_manager = Arc::new(SpillManager::new(runtime, metrics.clone(), schema));
1466
1467 let (writer, mut reader) = channel(batch_size, spill_manager);
1468
1469 // Step 3: Write NUM_BATCHES batches to create approximately NUM_BATCHES files
1470 for i in 0..NUM_BATCHES {
1471 let start = (i * ROWS_PER_BATCH) as i32;
1472 writer.push_batch(&create_test_batch(start, ROWS_PER_BATCH))?;
1473 }
1474
1475 // Check how many files were created (should be at least a few due to file rotation)
1476 let file_count = metrics.spill_file_count.value();
1477 assert_eq!(
1478 file_count,
1479 NUM_BATCHES - 1,
1480 "Expected at {} files with rotation, got {file_count}",
1481 NUM_BATCHES - 1
1482 );
1483
1484 // Step 4: Verify initial disk usage reflects all files
1485 let initial_disk_usage = disk_manager.used_disk_space();
1486 assert!(
1487 initial_disk_usage > 0,
1488 "Expected disk usage > 0 after writing batches, got {initial_disk_usage}"
1489 );
1490
1491 // Step 5: Read NUM_BATCHES - 1 batches (all but 1)
1492 // As each file is fully consumed, it should be dropped and disk usage should decrease
1493 for i in 0..(NUM_BATCHES - 1) {
1494 let result = reader.next().await.unwrap()?;
1495 assert_eq!(result.num_rows(), ROWS_PER_BATCH);
1496
1497 let col = result
1498 .column(0)
1499 .as_any()
1500 .downcast_ref::<Int32Array>()
1501 .unwrap();
1502 assert_eq!(col.value(0), (i * ROWS_PER_BATCH) as i32);
1503 }
1504
1505 // Step 6: Verify disk usage decreased but is not zero (at least 1 batch remains)
1506 let partial_disk_usage = disk_manager.used_disk_space();
1507 assert!(
1508 partial_disk_usage > 0
1509 && partial_disk_usage < (batch_size * NUM_BATCHES * 2) as u64,
1510 "Disk usage should be > 0 with remaining batches"
1511 );
1512 assert!(
1513 partial_disk_usage < initial_disk_usage,
1514 "Disk usage should have decreased after reading most batches: initial={initial_disk_usage}, partial={partial_disk_usage}"
1515 );
1516
1517 // Step 7: Read the final batch
1518 let result = reader.next().await.unwrap()?;
1519 assert_eq!(result.num_rows(), ROWS_PER_BATCH);
1520
1521 // Step 8: Drop writer first to signal no more data will be written
1522 // The reader has infinite stream semantics and will wait for the writer
1523 // to be dropped before returning None
1524 drop(writer);
1525
1526 // Verify we've read all batches - now the reader should return None
1527 assert!(
1528 reader.next().await.is_none(),
1529 "Should have no more batches to read"
1530 );
1531
1532 // Step 9: Drop reader to release all references
1533 drop(reader);
1534
1535 // Step 10: Verify complete cleanup - disk usage should be 0
1536 let final_disk_usage = disk_manager.used_disk_space();
1537 assert_eq!(
1538 final_disk_usage, 0,
1539 "Disk usage should be 0 after all files dropped, got {final_disk_usage}"
1540 );
1541
1542 Ok(())
1543 }
1544}