use futures::{Stream, StreamExt};
use std::collections::VecDeque;
use std::sync::Arc;
use std::task::Waker;
use parking_lot::Mutex;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
use super::in_progress_spill_file::InProgressSpillFile;
use super::spill_manager::SpillManager;
/// State shared between all writer clones and the reader of a spill pool.
///
/// All fields are protected together by a single `parking_lot::Mutex`
/// (see `SpillPoolWriter::shared` / `SpillPoolReader::shared`).
struct SpillPoolShared {
    /// FIFO queue of spill files in creation order; the writer pushes new
    /// files to the back, the reader pops fully-consumed files from the front.
    files: VecDeque<Arc<Mutex<ActiveSpillFileShared>>>,
    /// Used by the writer to create new spill files on rotation.
    spill_manager: Arc<SpillManager>,
    /// Waker of a reader task blocked waiting for a new file or for EOF.
    waker: Option<Waker>,
    /// Set when the last writer clone is dropped; lets the reader observe EOF.
    writer_dropped: bool,
    /// The file currently accepting appends, if any (`None` right after a
    /// rotation, until the next `push_batch` creates a fresh file).
    current_write_file: Option<Arc<Mutex<ActiveSpillFileShared>>>,
    /// Number of live `SpillPoolWriter` clones; finalization happens only
    /// when this reaches zero.
    active_writer_count: usize,
}
impl SpillPoolShared {
fn new(spill_manager: Arc<SpillManager>) -> Self {
Self {
files: VecDeque::new(),
spill_manager,
waker: None,
writer_dropped: false,
current_write_file: None,
active_writer_count: 1,
}
}
fn register_waker(&mut self, waker: Waker) {
self.waker = Some(waker);
}
fn wake(&mut self) {
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
/// Writing half of a spill-pool channel. Cloneable, so multiple producers
/// may append batches into the same pool (see the `Clone` impl).
pub struct SpillPoolWriter {
    /// Once a file's estimated in-memory size exceeds this threshold, the
    /// file is finished and the next write starts a fresh one ("rotation").
    max_file_size_bytes: usize,
    /// State shared with the reader and with any other writer clones.
    shared: Arc<Mutex<SpillPoolShared>>,
}
impl Clone for SpillPoolWriter {
fn clone(&self) -> Self {
self.shared.lock().active_writer_count += 1;
Self {
max_file_size_bytes: self.max_file_size_bytes,
shared: Arc::clone(&self.shared),
}
}
}
impl SpillPoolWriter {
    /// Appends `batch` to the pool's current spill file, creating a new
    /// file if none is active and finishing ("rotating") the file once its
    /// estimated size exceeds `max_file_size_bytes`.
    ///
    /// Empty batches are ignored. The reader is woken when a new file
    /// becomes visible in the queue, after every appended batch, and when
    /// a file is finished.
    ///
    /// # Errors
    /// Propagates errors from creating the spill file and from appending,
    /// flushing, or finishing the underlying writer.
    pub fn push_batch(&self, batch: &RecordBatch) -> Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }
        let batch_size = batch.get_array_memory_size();
        // Hold the pool lock for the entire call. The previous version
        // released it while creating the file and while appending, which
        // let a concurrent writer clone observe `current_write_file ==
        // None`, create a second "current" file, and have one of the two
        // files orphaned in `files` with a writer that is never finished —
        // the reader would then wait on that file forever.
        let mut shared = self.shared.lock();
        if shared.current_write_file.is_none() {
            let writer = shared.spill_manager.create_in_progress_file("SpillPool")?;
            let file = writer
                .file()
                .expect("InProgressSpillFile should always have a file when it is first created")
                .clone();
            let file_shared = Arc::new(Mutex::new(ActiveSpillFileShared {
                writer: Some(writer),
                file: Some(file),
                batches_written: 0,
                estimated_size: 0,
                writer_finished: false,
                waker: None,
            }));
            shared.files.push_back(Arc::clone(&file_shared));
            shared.current_write_file = Some(file_shared);
            // Wake the reader: a new file is now visible in the queue.
            shared.wake();
        }
        let current_file = Arc::clone(
            shared
                .current_write_file
                .as_ref()
                .expect("current_write_file was just ensured above"),
        );
        // Lock order is always pool -> file; the reader never acquires the
        // pool lock while holding a file lock, so this cannot deadlock.
        let mut file_shared = current_file.lock();
        if let Some(writer) = file_shared.writer.as_mut() {
            writer.append_batch(batch)?;
            // Flush so the appended batch is immediately readable.
            writer.flush()?;
            file_shared.batches_written += 1;
            file_shared.estimated_size += batch_size;
        }
        // Wake a reader blocked waiting for more batches in this file.
        file_shared.wake();
        if file_shared.estimated_size > self.max_file_size_bytes {
            // Rotate: finish this file; the next push creates a new one.
            if let Some(mut writer) = file_shared.writer.take() {
                writer.finish()?;
            }
            file_shared.writer_finished = true;
            file_shared.wake();
            shared.current_write_file = None;
        }
        Ok(())
    }
}
impl Drop for SpillPoolWriter {
    /// When the last writer clone is dropped, finishes any in-progress
    /// file and marks the pool as done so the reader can observe EOF.
    fn drop(&mut self) {
        let mut shared = self.shared.lock();
        shared.active_writer_count -= 1;
        if shared.active_writer_count > 0 {
            // Other writer clones are still alive; nothing to finalize yet.
            return;
        }
        // Last writer gone: finish the current file (if any) so the reader
        // can drain it to completion instead of waiting for more batches.
        if let Some(current_file) = shared.current_write_file.take() {
            let mut file_shared = current_file.lock();
            if let Some(mut writer) = file_shared.writer.take() {
                // Errors cannot propagate out of `drop`; finish best-effort.
                let _ = writer.finish();
            }
            file_shared.writer_finished = true;
            file_shared.wake();
        }
        shared.writer_dropped = true;
        shared.wake();
    }
}
/// Creates a spill-pool channel: a cloneable writer that appends record
/// batches into size-limited spill files, and a stream that reads them
/// back in FIFO order.
pub fn channel(
    max_file_size_bytes: usize,
    spill_manager: Arc<SpillManager>,
) -> (SpillPoolWriter, SendableRecordBatchStream) {
    // Capture the schema before the manager is moved into the shared state.
    let schema = Arc::clone(spill_manager.schema());
    let shared = Arc::new(Mutex::new(SpillPoolShared::new(spill_manager)));
    let reader = SpillPoolReader::new(Arc::clone(&shared), schema);
    let writer = SpillPoolWriter {
        max_file_size_bytes,
        shared,
    };
    (writer, Box::pin(reader))
}
/// Per-file state shared between the writer and the reader's `SpillFile`
/// stream, protected by its own `parking_lot::Mutex`.
struct ActiveSpillFileShared {
    /// Write handle; `None` once the file has been finished (rotation or
    /// last-writer drop).
    writer: Option<InProgressSpillFile>,
    /// Temp-file handle; taken by the reader when it opens its read stream.
    file: Option<RefCountedTempFile>,
    /// Number of batches appended so far; the reader compares this against
    /// its own `batches_read` to know whether more data is available.
    batches_written: usize,
    /// Running total of appended batches' in-memory sizes, used by the
    /// writer for the rotation decision.
    estimated_size: usize,
    /// True once no further batches will ever be appended to this file.
    writer_finished: bool,
    /// Waker of a reader blocked waiting for more batches in this file.
    waker: Option<Waker>,
}
impl ActiveSpillFileShared {
    /// Stores `waker` so a later append/finish can wake the blocked reader.
    /// A previously stored waker is replaced (and dropped).
    fn register_waker(&mut self, waker: Waker) {
        self.waker = Some(waker);
    }

    /// Wakes and clears the registered reader waker, if one is present.
    fn wake(&mut self) {
        if let Some(pending) = self.waker.take() {
            pending.wake();
        }
    }
}
/// Read-side cursor over a single spill file.
struct SpillFileReader {
    /// Stream decoding record batches out of the spill file.
    stream: SendableRecordBatchStream,
    /// How many batches have been yielded from `stream` so far.
    batches_read: usize,
}
/// Stream over one spill file that may still be receiving appends.
struct SpillFile {
    /// State shared with the writer for this particular file.
    shared: Arc<Mutex<ActiveSpillFileShared>>,
    /// Lazily-opened read stream; `None` until the first batch is available.
    reader: Option<SpillFileReader>,
    /// Used to open the read stream over the temp file.
    spill_manager: Arc<SpillManager>,
}
impl Stream for SpillFile {
    type Item = Result<RecordBatch>;

    /// Yields the next batch appended to this spill file.
    ///
    /// Returns `Pending` (after registering a waker on the file's shared
    /// state) when the reader has caught up with the writer but the file
    /// is not yet finished; returns `Ready(None)` once every written batch
    /// has been read and `writer_finished` is set.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;
        // Under the file lock, decide whether unread data exists. `file`
        // is `Some` only on the poll that first opens the read stream (the
        // temp-file handle is `take`n out of the shared state then).
        let (should_read, file) = {
            let mut shared = self.shared.lock();
            let batches_read = self.reader.as_ref().map_or(0, |r| r.batches_read);
            if batches_read < shared.batches_written {
                let file = if self.reader.is_none() {
                    shared.file.take()
                } else {
                    None
                };
                (true, file)
            } else if shared.writer_finished {
                // Fully drained and no more data will ever arrive.
                return Poll::Ready(None);
            } else {
                // Caught up with the writer: wait for the next append.
                shared.register_waker(cx.waker().clone());
                return Poll::Pending;
            }
        };
        // Lazily open the read stream the first time data is available.
        // (The file lock is released here, before doing I/O.)
        if self.reader.is_none() && should_read {
            if let Some(file) = file {
                match self
                    .spill_manager
                    .read_spill_as_stream_unbuffered(file, None)
                {
                    Ok(stream) => {
                        self.reader = Some(SpillFileReader {
                            stream,
                            batches_read: 0,
                        });
                    }
                    Err(e) => return Poll::Ready(Some(Err(e))),
                }
            } else {
                // The temp-file handle was already taken. NOTE(review):
                // with a single reader per file this branch looks
                // unreachable — confirm; for now it just waits.
                let mut shared = self.shared.lock();
                shared.register_waker(cx.waker().clone());
                return Poll::Pending;
            }
        }
        if let Some(reader) = &mut self.reader {
            match reader.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(batch))) => {
                    // Count consumed batches for the catch-up check above.
                    reader.batches_read += 1;
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => {
                    // Underlying stream hit EOF even though the check above
                    // expected more batches; surface end-of-file.
                    Poll::Ready(None)
                }
                Poll::Pending => Poll::Pending,
            }
        } else {
            Poll::Ready(None)
        }
    }
}
/// Reading half of the spill-pool channel; drains spill files in FIFO order.
pub struct SpillPoolReader {
    /// Pool state shared with the writer side.
    shared: Arc<Mutex<SpillPoolShared>>,
    /// Stream over the file currently being drained, if any.
    current_file: Option<SpillFile>,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
}
impl SpillPoolReader {
fn new(shared: Arc<Mutex<SpillPoolShared>>, schema: SchemaRef) -> Self {
Self {
shared,
current_file: None,
schema,
}
}
}
impl Stream for SpillPoolReader {
    type Item = Result<RecordBatch>;

    /// Yields batches across all spill files in FIFO order.
    ///
    /// Drives the front file's `SpillFile` stream; once a file is exhausted
    /// and its writer has finished, pops it from the queue (releasing the
    /// temp file) and moves to the next. The stream ends only after the
    /// last writer handle is dropped and every queued file is drained.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;
        loop {
            if let Some(ref mut file) = self.current_file {
                match file.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(batch))) => {
                        return Poll::Ready(Some(Ok(batch)));
                    }
                    Poll::Ready(Some(Err(e))) => {
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Ready(None) => {
                        // Scoped lock: read the flag without holding the
                        // file lock while taking the pool lock below.
                        let writer_finished = { file.shared.lock().writer_finished };
                        if writer_finished {
                            // File fully consumed: drop it from the queue
                            // so its temp file can be reclaimed, then try
                            // the next file.
                            let mut shared = self.shared.lock();
                            shared.files.pop_front();
                            drop(shared);
                            self.current_file = None;
                            continue;
                        } else {
                            // NOTE(review): `SpillFile` yields `None` before
                            // `writer_finished` only if its underlying
                            // stream ended unexpectedly; this ends the whole
                            // pool stream early — confirm that is intended.
                            return Poll::Ready(None);
                        }
                    }
                    Poll::Pending => {
                        // The file registered its own waker; also register
                        // on the pool so pool-level events wake us.
                        let mut shared = self.shared.lock();
                        shared.register_waker(cx.waker().clone());
                        return Poll::Pending;
                    }
                }
            }
            // No current file: pick up the next queued one, if any.
            let mut shared = self.shared.lock();
            if let Some(file_shared) = shared.files.front() {
                let spill_manager = Arc::clone(&shared.spill_manager);
                let file_shared = Arc::clone(file_shared);
                drop(shared);
                self.current_file = Some(SpillFile {
                    shared: file_shared,
                    reader: None,
                    spill_manager,
                });
                continue;
            }
            if shared.writer_dropped {
                // All writers gone and no files remain: end of stream.
                return Poll::Ready(None);
            }
            // Wait for a writer to create the first/next file or hang up.
            shared.register_waker(cx.waker().clone());
            return Poll::Pending;
        }
    }
}
impl RecordBatchStream for SpillPoolReader {
    /// Schema of every batch this stream yields (captured in `channel()`).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
/// Unit tests for the spill-pool channel: FIFO ordering, size-based file
/// rotation, concurrent reader/writer coordination, writer-drop EOF
/// semantics, and disk-space reclamation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use futures::StreamExt;

    /// Single-column (`a: Int32`) schema used by every test.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Builds a batch of `count` consecutive Int32 values starting at `start`.
    fn create_test_batch(start: i32, count: usize) -> RecordBatch {
        let schema = create_test_schema();
        let a: ArrayRef = Arc::new(Int32Array::from(
            (start..start + count as i32).collect::<Vec<_>>(),
        ));
        RecordBatch::try_new(schema, vec![a]).unwrap()
    }

    /// Creates a writer/reader pair backed by a default runtime environment.
    fn create_spill_channel(
        max_file_size: usize,
    ) -> (SpillPoolWriter, SendableRecordBatchStream) {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = create_test_schema();
        let spill_manager = Arc::new(SpillManager::new(env, metrics, schema));
        channel(max_file_size, spill_manager)
    }

    /// Like `create_spill_channel`, but also returns the metrics handle so
    /// tests can assert on spill counters.
    fn create_spill_channel_with_metrics(
        max_file_size: usize,
    ) -> (SpillPoolWriter, SendableRecordBatchStream, SpillMetrics) {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = create_test_schema();
        let spill_manager = Arc::new(SpillManager::new(env, metrics.clone(), schema));
        let (writer, reader) = channel(max_file_size, spill_manager);
        (writer, reader, metrics)
    }

    /// Two sequential write/read round-trips through a single file.
    #[tokio::test]
    async fn test_basic_write_and_read() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        let batch1 = create_test_batch(0, 10);
        writer.push_batch(&batch1)?;
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 10);
        let batch2 = create_test_batch(10, 5);
        writer.push_batch(&batch2)?;
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);
        Ok(())
    }

    /// A single batch round-trips with its values intact.
    #[tokio::test]
    async fn test_single_batch_write_read() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        let batch = create_test_batch(0, 5);
        writer.push_batch(&batch)?;
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);
        let col = result
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 0);
        assert_eq!(col.value(4), 4);
        Ok(())
    }

    /// Batches come back in the order they were pushed (FIFO).
    #[tokio::test]
    async fn test_multiple_batches_sequential() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        for i in 0..5 {
            let batch = create_test_batch(i * 10, 10);
            writer.push_batch(&batch)?;
        }
        for i in 0..5 {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), 10);
            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(col.value(0), i * 10, "Batch {i} not in FIFO order");
        }
        Ok(())
    }

    /// With no data written, the reader stays pending (times out) rather
    /// than yielding anything.
    #[tokio::test]
    async fn test_empty_writer() -> Result<()> {
        let (_writer, reader) = create_spill_channel(1024 * 1024);
        let mut reader = reader;
        let result =
            tokio::time::timeout(std::time::Duration::from_millis(100), reader.next())
                .await;
        assert!(result.is_err(), "Reader should timeout on empty writer");
        Ok(())
    }

    /// Zero-row batches are dropped by `push_batch` and never reach the reader.
    #[tokio::test]
    async fn test_empty_batch_skipping() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        let empty_batch = create_test_batch(0, 0);
        writer.push_batch(&empty_batch)?;
        let batch = create_test_batch(0, 5);
        writer.push_batch(&batch)?;
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);
        Ok(())
    }

    /// Exceeding `max_file_size` finishes the current file; the next write
    /// lazily creates a second file. Metrics track files/bytes/rows.
    #[tokio::test]
    async fn test_rotation_triggered_by_size() -> Result<()> {
        let batch1 = create_test_batch(0, 10);
        // One batch fits; two batches exceed the limit and trigger rotation.
        let batch_size = batch1.get_array_memory_size() + 1;
        let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);
        writer.push_batch(&batch1)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should have created 1 file after first batch"
        );
        assert_eq!(
            metrics.spilled_bytes.value(),
            320,
            "Spilled bytes should reflect data written (header + 1 batch)"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            10,
            "Should have spilled 10 rows from first batch"
        );
        let batch2 = create_test_batch(10, 10);
        assert!(
            batch2.get_array_memory_size() <= batch_size,
            "batch2 size {} exceeds limit {batch_size}",
            batch2.get_array_memory_size(),
        );
        assert!(
            batch1.get_array_memory_size() + batch2.get_array_memory_size() > batch_size,
            "Combined size {} does not exceed limit to trigger rotation",
            batch1.get_array_memory_size() + batch2.get_array_memory_size()
        );
        writer.push_batch(&batch2)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should still have 1 file (second file not created until next write)"
        );
        assert!(
            metrics.spilled_bytes.value() > 0,
            "Spilled bytes should be > 0 after first file finalized (got {})",
            metrics.spilled_bytes.value()
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            20,
            "Should have spilled 20 total rows (10 + 10)"
        );
        let batch3 = create_test_batch(20, 5);
        writer.push_batch(&batch3)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            2,
            "Should have created 2 files after writing to new file"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            25,
            "Should have spilled 25 total rows (10 + 10 + 5)"
        );
        let result1 = reader.next().await.unwrap()?;
        assert_eq!(result1.num_rows(), 10);
        let result2 = reader.next().await.unwrap()?;
        assert_eq!(result2.num_rows(), 10);
        let result3 = reader.next().await.unwrap()?;
        assert_eq!(result3.num_rows(), 5);
        Ok(())
    }

    /// Several rotations in a row still preserve FIFO order across files.
    #[tokio::test]
    async fn test_multiple_rotations() -> Result<()> {
        let batches = (0..10)
            .map(|i| create_test_batch(i * 10, 10))
            .collect::<Vec<_>>();
        // Roughly two batches per file before rotation.
        let batch_size = batches[0].get_array_memory_size() * 2 + 1;
        let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);
        for i in 0..10 {
            let batch = create_test_batch(i * 10, 10);
            writer.push_batch(&batch)?;
        }
        let file_count = metrics.spill_file_count.value();
        assert!(
            file_count >= 4,
            "Should have created at least 4 files with multiple rotations (got {file_count})"
        );
        assert!(
            metrics.spilled_bytes.value() > 0,
            "Spilled bytes should be > 0 after rotations (got {})",
            metrics.spilled_bytes.value()
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            100,
            "Should have spilled 100 total rows (10 batches * 10 rows)"
        );
        for i in 0..10 {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), 10);
            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(
                col.value(0),
                i * 10,
                "Batch {i} not in correct order after rotations"
            );
        }
        Ok(())
    }

    /// A batch bigger than the limit is still written whole (rotation
    /// happens after the write, never splitting a batch).
    #[tokio::test]
    async fn test_single_batch_larger_than_limit() -> Result<()> {
        let (writer, mut reader, metrics) = create_spill_channel_with_metrics(100);
        let large_batch = create_test_batch(0, 100);
        writer.push_batch(&large_batch)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should have created 1 file for large batch"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            100,
            "Should have spilled 100 rows from large batch"
        );
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 100);
        let batch2 = create_test_batch(100, 10);
        writer.push_batch(&batch2)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            2,
            "Should have created 2 files after rotation"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            110,
            "Should have spilled 110 total rows (100 + 10)"
        );
        let result2 = reader.next().await.unwrap()?;
        assert_eq!(result2.num_rows(), 10);
        Ok(())
    }

    /// A 1-byte limit (every batch rotates) still round-trips data.
    #[tokio::test]
    async fn test_very_small_max_file_size() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1);
        let batch = create_test_batch(0, 5);
        writer.push_batch(&batch)?;
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), 5);
        Ok(())
    }

    /// Rotation uses a strict `>` comparison: a batch exactly at the limit
    /// does not rotate until the next batch pushes it over.
    #[tokio::test]
    async fn test_exact_size_boundary() -> Result<()> {
        let batch = create_test_batch(0, 10);
        let batch_size = batch.get_array_memory_size();
        let (writer, mut reader, metrics) = create_spill_channel_with_metrics(batch_size);
        writer.push_batch(&batch)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should have created 1 file after first batch at exact boundary"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            10,
            "Should have spilled 10 rows from first batch"
        );
        let batch2 = create_test_batch(10, 10);
        writer.push_batch(&batch2)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            1,
            "Should still have 1 file after rotation (second file created lazily)"
        );
        assert_eq!(
            metrics.spilled_rows.value(),
            20,
            "Should have spilled 20 total rows (10 + 10)"
        );
        assert!(
            metrics.spilled_bytes.value() > 0,
            "Spilled bytes should be > 0 after file finalization (got {})",
            metrics.spilled_bytes.value()
        );
        let result1 = reader.next().await.unwrap()?;
        assert_eq!(result1.num_rows(), 10);
        let result2 = reader.next().await.unwrap()?;
        assert_eq!(result2.num_rows(), 10);
        let batch3 = create_test_batch(20, 5);
        writer.push_batch(&batch3)?;
        assert_eq!(
            metrics.spill_file_count.value(),
            2,
            "Should have created 2 files after writing to new file"
        );
        Ok(())
    }

    /// Writer and reader run on separate tasks; reader still sees all 10
    /// batches in order.
    #[tokio::test]
    async fn test_concurrent_reader_writer() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        let writer_handle = SpawnedTask::spawn(async move {
            for i in 0..10 {
                let batch = create_test_batch(i * 10, 10);
                writer.push_batch(&batch).unwrap();
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            }
        });
        let reader_handle = SpawnedTask::spawn(async move {
            let mut count = 0;
            for i in 0..10 {
                let result = reader.next().await.unwrap().unwrap();
                assert_eq!(result.num_rows(), 10);
                let col = result
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                assert_eq!(col.value(0), i * 10);
                count += 1;
            }
            count
        });
        writer_handle.await.unwrap();
        let batches_read = reader_handle.await.unwrap();
        assert_eq!(batches_read, 10);
        Ok(())
    }

    /// A reader that polls before any data exists is woken by each write;
    /// the observed event order is strictly write-then-read.
    #[tokio::test]
    async fn test_reader_catches_up_to_writer() -> Result<()> {
        let (writer, mut reader) = create_spill_channel(1024 * 1024);
        let (reader_waiting_tx, reader_waiting_rx) = tokio::sync::oneshot::channel();
        let (first_read_done_tx, first_read_done_rx) = tokio::sync::oneshot::channel();
        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
        enum ReadWriteEvent {
            ReadStart,
            Read(usize),
            Write(usize),
        }
        let events = Arc::new(Mutex::new(vec![]));
        let reader_events = Arc::clone(&events);
        let reader_handle = SpawnedTask::spawn(async move {
            reader_events.lock().push(ReadWriteEvent::ReadStart);
            reader_waiting_tx
                .send(())
                .expect("reader_waiting channel closed unexpectedly");
            let result = reader.next().await.unwrap().unwrap();
            reader_events
                .lock()
                .push(ReadWriteEvent::Read(result.num_rows()));
            first_read_done_tx
                .send(())
                .expect("first_read_done channel closed unexpectedly");
            let result = reader.next().await.unwrap().unwrap();
            reader_events
                .lock()
                .push(ReadWriteEvent::Read(result.num_rows()));
        });
        reader_waiting_rx
            .await
            .expect("reader should signal when waiting");
        let batch = create_test_batch(0, 5);
        events.lock().push(ReadWriteEvent::Write(batch.num_rows()));
        writer.push_batch(&batch)?;
        first_read_done_rx
            .await
            .expect("reader should signal when first read completes");
        let batch = create_test_batch(5, 10);
        events.lock().push(ReadWriteEvent::Write(batch.num_rows()));
        writer.push_batch(&batch)?;
        reader_handle.await.unwrap();
        let events = events.lock().clone();
        assert_eq!(
            events,
            vec![
                ReadWriteEvent::ReadStart,
                ReadWriteEvent::Write(5),
                ReadWriteEvent::Read(5),
                ReadWriteEvent::Write(10),
                ReadWriteEvent::Read(10)
            ]
        );
        Ok(())
    }

    /// All batches remain readable after the writer is dropped (files are
    /// finalized on drop, not discarded).
    #[tokio::test]
    async fn test_reader_starts_after_writer_finishes() -> Result<()> {
        let (writer, reader) = create_spill_channel(128);
        for i in 0..5 {
            let batch = create_test_batch(i * 10, 10);
            writer.push_batch(&batch)?;
        }
        drop(writer);
        let mut reader = reader;
        let mut count = 0;
        for i in 0..5 {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), 10);
            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(col.value(0), i * 10);
            count += 1;
        }
        assert_eq!(count, 5, "Should read all batches after writer finishes");
        Ok(())
    }

    /// Dropping the writer finalizes the in-progress file and leaves the
    /// spilled-bytes metric consistent.
    #[tokio::test]
    async fn test_writer_drop_finalizes_file() -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = create_test_schema();
        let spill_manager =
            Arc::new(SpillManager::new(Arc::clone(&env), metrics.clone(), schema));
        let (writer, mut reader) = channel(1024 * 1024, spill_manager);
        for i in 0..5 {
            let batch = create_test_batch(i * 10, 10);
            writer.push_batch(&batch)?;
        }
        let spilled_bytes_before = metrics.spilled_bytes.value();
        assert_eq!(
            spilled_bytes_before, 1088,
            "Spilled bytes should reflect data written (header + 5 batches)"
        );
        drop(writer);
        let spilled_bytes_after = metrics.spilled_bytes.value();
        assert!(
            spilled_bytes_after > 0,
            "Spilled bytes should be > 0 after writer is dropped (got {spilled_bytes_after})"
        );
        let mut count = 0;
        for i in 0..5 {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), 10);
            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(col.value(0), i * 10);
            count += 1;
        }
        assert_eq!(count, 5, "Should read all batches after writer is dropped");
        Ok(())
    }

    /// Dropping one writer clone while another is alive must not signal
    /// EOF; the reader only ends after the last clone is gone.
    #[tokio::test]
    async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> {
        let (writer1, mut reader) = create_spill_channel(1024 * 1024);
        let writer2 = writer1.clone();
        let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>();
        let writer2_handle = SpawnedTask::spawn(async move {
            proceed_rx.await.unwrap();
            writer2.push_batch(&create_test_batch(10, 10)).unwrap();
        });
        writer1.push_batch(&create_test_batch(0, 10))?;
        drop(writer1);
        let batch1 = reader.next().await.unwrap()?;
        assert_eq!(batch1.num_rows(), 10);
        let col = batch1
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 0);
        proceed_tx.send(()).unwrap();
        let batch2 =
            tokio::time::timeout(std::time::Duration::from_secs(5), reader.next())
                .await
                .expect("Reader timed out — should not hang");
        assert!(
            batch2.is_some(),
            "Reader must not return EOF while a writer clone is still alive"
        );
        let batch2 = batch2.unwrap()?;
        assert_eq!(batch2.num_rows(), 10);
        let col = batch2
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col.value(0), 10);
        writer2_handle.await.unwrap();
        assert!(reader.next().await.is_none());
        Ok(())
    }

    /// Temp files are deleted as they are consumed, so disk usage shrinks
    /// while reading and reaches zero after reader drop.
    #[tokio::test]
    async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> {
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        const NUM_BATCHES: usize = 3;
        const ROWS_PER_BATCH: usize = 100;
        let batch = create_test_batch(0, ROWS_PER_BATCH);
        let batch_size = batch.get_array_memory_size();
        let runtime = Arc::new(RuntimeEnvBuilder::default().build()?);
        let disk_manager = Arc::clone(&runtime.disk_manager);
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = create_test_schema();
        let spill_manager = Arc::new(SpillManager::new(runtime, metrics.clone(), schema));
        let (writer, mut reader) = channel(batch_size, spill_manager);
        for i in 0..NUM_BATCHES {
            let start = (i * ROWS_PER_BATCH) as i32;
            writer.push_batch(&create_test_batch(start, ROWS_PER_BATCH))?;
        }
        let file_count = metrics.spill_file_count.value();
        assert_eq!(
            file_count,
            NUM_BATCHES - 1,
            "Expected at {} files with rotation, got {file_count}",
            NUM_BATCHES - 1
        );
        let initial_disk_usage = disk_manager.used_disk_space();
        assert!(
            initial_disk_usage > 0,
            "Expected disk usage > 0 after writing batches, got {initial_disk_usage}"
        );
        for i in 0..(NUM_BATCHES - 1) {
            let result = reader.next().await.unwrap()?;
            assert_eq!(result.num_rows(), ROWS_PER_BATCH);
            let col = result
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            assert_eq!(col.value(0), (i * ROWS_PER_BATCH) as i32);
        }
        let partial_disk_usage = disk_manager.used_disk_space();
        assert!(
            partial_disk_usage > 0
                && partial_disk_usage < (batch_size * NUM_BATCHES * 2) as u64,
            "Disk usage should be > 0 with remaining batches"
        );
        assert!(
            partial_disk_usage < initial_disk_usage,
            "Disk usage should have decreased after reading most batches: initial={initial_disk_usage}, partial={partial_disk_usage}"
        );
        let result = reader.next().await.unwrap()?;
        assert_eq!(result.num_rows(), ROWS_PER_BATCH);
        drop(writer);
        assert!(
            reader.next().await.is_none(),
            "Should have no more batches to read"
        );
        drop(reader);
        let final_disk_usage = disk_manager.used_disk_space();
        assert_eq!(
            final_disk_usage, 0,
            "Disk usage should be 0 after all files dropped, got {final_disk_usage}"
        );
        Ok(())
    }
}