pub(crate) mod in_progress_spill_file;
pub(crate) mod spill_manager;
pub mod spill_pool;
pub use datafusion_common::utils::memory::get_record_batch_memory_size;
#[doc(hidden)]
pub use spill_manager::SpillManager;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow::array::{BufferSpec, layout};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::{
MetadataVersion,
reader::StreamReader,
writer::{IpcWriteOptions, StreamWriter},
};
use arrow::record_batch::RecordBatch;
use datafusion_common::config::SpillCompression;
use datafusion_common::{DataFusionError, Result, exec_datafusion_err};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::RecordBatchStream;
use datafusion_execution::disk_manager::RefCountedTempFile;
use futures::{FutureExt as _, Stream};
use log::debug;
/// A stream that reads record batches back from a spilled Arrow IPC stream
/// file, offloading the blocking file I/O to a separate blocking task
/// (see `poll_next_inner`).
struct SpillReaderStream {
    schema: SchemaRef,
    /// Current position in the read state machine; see `SpillReaderStreamState`.
    state: SpillReaderStreamState,
    /// If set, the expected upper bound on the in-memory size of each batch
    /// read back (recorded at spill time). Used only for a debug-level
    /// sanity check against `get_record_batch_memory_size`.
    max_record_batch_memory: Option<usize>,
}
/// Tolerance (in bytes) allowed between a batch's memory size recorded when
/// it was spilled and the size observed when it is read back, before a
/// memory-accounting warning is logged (see `poll_next_inner`).
const SPILL_BATCH_MEMORY_MARGIN: usize = 4096;
/// Result of reading one batch from a spill file: the reader is returned so
/// it can be reused for the next read; the batch is `None` at end of stream.
type NextRecordBatchResult = Result<(StreamReader<BufReader<File>>, Option<RecordBatch>)>;
/// State machine for [`SpillReaderStream`].
enum SpillReaderStreamState {
    /// The spill file has not been opened yet; holds the temp file to read.
    Uninitialized(RefCountedTempFile),
    /// A blocking task is reading the next batch from the file.
    ReadInProgress(SpawnedTask<NextRecordBatchResult>),
    /// A batch was returned to the caller; the reader is parked here until
    /// the stream is polled again.
    Waiting(StreamReader<BufReader<File>>),
    /// End of file reached or an error occurred; the stream is exhausted.
    Done,
}
impl SpillReaderStream {
    /// Create a stream over `spill_file`; reading is lazy and starts on the
    /// first poll. `max_record_batch_memory` enables an optional debug-only
    /// memory-accounting check on each batch read back.
    fn new(
        schema: SchemaRef,
        spill_file: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Self {
        Self {
            schema,
            state: SpillReaderStreamState::Uninitialized(spill_file),
            max_record_batch_memory,
        }
    }

    /// Advance the state machine by one step.
    ///
    /// Opening the file and decoding each batch are blocking operations, so
    /// they run inside `SpawnedTask::spawn_blocking`; this method only polls
    /// the spawned task and transitions between states.
    fn poll_next_inner(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match &mut self.state {
            SpillReaderStreamState::Uninitialized(_) => {
                // Temporarily park the state as `Done` so we can take
                // ownership of the spill file out of the current variant.
                let SpillReaderStreamState::Uninitialized(spill_file) =
                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
                else {
                    unreachable!()
                };
                let task = SpawnedTask::spawn_blocking(move || {
                    let file = BufReader::new(File::open(spill_file.path())?);
                    // SAFETY: skipping IPC validation requires the data to be
                    // trusted; spill files are produced by this module's own
                    // writer within the same process, so this holds.
                    let mut reader = unsafe {
                        StreamReader::try_new(file, None)?.with_skip_validation(true)
                    };
                    let next_batch = reader.next().transpose()?;
                    Ok((reader, next_batch))
                });
                self.state = SpillReaderStreamState::ReadInProgress(task);
                // Recurse so the new `ReadInProgress` state registers `cx`.
                self.poll_next_inner(cx)
            }
            SpillReaderStreamState::ReadInProgress(task) => {
                // A JoinError (panic/cancellation) is surfaced as an
                // external DataFusionError.
                let result = futures::ready!(task.poll_unpin(cx))
                    .unwrap_or_else(|err| Err(DataFusionError::External(Box::new(err))));
                match result {
                    Ok((reader, batch)) => {
                        match batch {
                            Some(batch) => {
                                if let Some(max_record_batch_memory) =
                                    self.max_record_batch_memory
                                {
                                    let actual_size =
                                        get_record_batch_memory_size(&batch);
                                    // Debug-only sanity check of spill-time
                                    // memory accounting; logs, never fails.
                                    if actual_size
                                        > max_record_batch_memory
                                            + SPILL_BATCH_MEMORY_MARGIN
                                    {
                                        debug!(
                                            "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\
                                            by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\
                                            This likely indicates a bug in memory accounting during spilling.\n\
                                            Please report this issue in https://github.com/apache/datafusion/issues/17340."
                                        );
                                    }
                                }
                                self.state = SpillReaderStreamState::Waiting(reader);
                                Poll::Ready(Some(Ok(batch)))
                            }
                            // End of the IPC stream: the file is exhausted.
                            None => {
                                self.state = SpillReaderStreamState::Done;
                                Poll::Ready(None)
                            }
                        }
                    }
                    Err(err) => {
                        self.state = SpillReaderStreamState::Done;
                        Poll::Ready(Some(Err(err)))
                    }
                }
            }
            SpillReaderStreamState::Waiting(_) => {
                // Same take-ownership trick as `Uninitialized` above: move
                // the reader into a new blocking task for the next batch.
                let SpillReaderStreamState::Waiting(mut reader) =
                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
                else {
                    unreachable!()
                };
                let task = SpawnedTask::spawn_blocking(move || {
                    let next_batch = reader.next().transpose()?;
                    Ok((reader, next_batch))
                });
                self.state = SpillReaderStreamState::ReadInProgress(task);
                self.poll_next_inner(cx)
            }
            SpillReaderStreamState::Done => Poll::Ready(None),
        }
    }
}
impl Stream for SpillReaderStream {
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` is available because `SpillReaderStream` is `Unpin`;
        // all the real work happens in `poll_next_inner`.
        self.get_mut().poll_next_inner(cx)
    }
}
impl RecordBatchStream for SpillReaderStream {
    /// Schema of the batches produced by this stream (captured at creation).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
/// Writes `batch` to `path` as an uncompressed Arrow IPC stream file,
/// splitting it into slices of at most `batch_size_rows` rows each.
///
/// Deprecated: use `SpillManager::spill_record_batch_by_size` instead.
#[deprecated(
    since = "46.0.0",
    note = "This method is deprecated. Use `SpillManager::spill_record_batch_by_size` instead."
)]
#[expect(clippy::needless_pass_by_value)]
pub fn spill_record_batch_by_size(
    batch: &RecordBatch,
    path: PathBuf,
    schema: SchemaRef,
    batch_size_rows: usize,
) -> Result<()> {
    let mut writer =
        IPCStreamWriter::new(&path, schema.as_ref(), SpillCompression::Uncompressed)?;
    let total_rows = batch.num_rows();
    let mut start = 0;
    while start < total_rows {
        // Slicing is zero-copy; the last slice may be shorter.
        let len = batch_size_rows.min(total_rows - start);
        let slice = batch.slice(start, len);
        writer.write(&slice)?;
        start += slice.num_rows();
    }
    writer.finish()?;
    Ok(())
}
/// Wrapper around Arrow's IPC [`StreamWriter`] that tracks simple statistics
/// about what has been written so far.
struct IPCStreamWriter {
    /// The underlying Arrow IPC stream writer.
    pub writer: StreamWriter<File>,
    /// Number of batches written.
    pub num_batches: usize,
    /// Number of rows written.
    pub num_rows: usize,
    /// Cumulative in-memory size of the batches written (not the on-disk
    /// file size; see `write`).
    pub num_bytes: usize,
}
impl IPCStreamWriter {
    /// Create a writer targeting `path`, using IPC metadata version V5, the
    /// given compression, and buffer alignment derived from `schema`
    /// (see `get_max_alignment_for_schema`).
    pub fn new(
        path: &Path,
        schema: &Schema,
        compression_type: SpillCompression,
    ) -> Result<Self> {
        let file = File::create(path).map_err(|e| {
            exec_datafusion_err!("(Hint: you may increase the file descriptor limit with shell command 'ulimit -n 4096') Failed to create partition file at {path:?}: {e:?}")
        })?;
        let options = IpcWriteOptions::try_new(
            get_max_alignment_for_schema(schema),
            false,
            MetadataVersion::V5,
        )?
        .try_with_compression(compression_type.into())?;
        Ok(Self {
            writer: StreamWriter::try_new_with_options(file, schema, options)?,
            num_batches: 0,
            num_rows: 0,
            num_bytes: 0,
        })
    }

    /// Write one batch; returns `(rows_written, in-memory bytes of the batch)`
    /// and accumulates both into the writer's statistics.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(usize, usize)> {
        self.writer.write(batch)?;
        let rows = batch.num_rows();
        let bytes = batch.get_array_memory_size();
        self.num_batches += 1;
        self.num_rows += rows;
        self.num_bytes += bytes;
        Ok((rows, bytes))
    }

    /// Flush buffered data to the underlying file.
    pub fn flush(&mut self) -> Result<()> {
        self.writer.flush()?;
        Ok(())
    }

    /// Write the IPC end-of-stream marker and finalize the file.
    pub fn finish(&mut self) -> Result<()> {
        self.writer.finish().map_err(Into::into)
    }
}
/// Returns the largest buffer alignment required by any field in `schema`,
/// never less than 8 bytes. Non-fixed-width buffers contribute the minimum.
fn get_max_alignment_for_schema(schema: &Schema) -> usize {
    const MINIMUM_ALIGNMENT: usize = 8;
    schema
        .fields()
        .iter()
        // Every buffer of every field contributes one alignment requirement.
        .flat_map(|field| layout(field.data_type()).buffers)
        .map(|spec| match spec {
            BufferSpec::FixedWidth { alignment, .. } => alignment,
            _ => MINIMUM_ALIGNMENT,
        })
        .max()
        .unwrap_or(MINIMUM_ALIGNMENT)
        // Clamp to the floor in case all buffers require less than 8 bytes.
        .max(MINIMUM_ALIGNMENT)
}
#[cfg(test)]
mod tests {
use super::in_progress_spill_file::InProgressSpillFile;
use super::*;
use crate::common::collect;
use crate::metrics::ExecutionPlanMetricsSet;
use crate::metrics::SpillMetrics;
use crate::spill::spill_manager::SpillManager;
use crate::test::build_table_i32;
use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::runtime_env::RuntimeEnv;
use futures::StreamExt as _;
use std::sync::Arc;
/// Spilling two batches and reading them back yields the same number of
/// batches, the original schema, and a correct `spilled_rows` metric.
#[tokio::test]
async fn test_batch_spill_and_read() -> Result<()> {
    let batch1 = build_table_i32(
        ("a2", &vec![0, 1, 2]),
        ("b2", &vec![3, 4, 5]),
        ("c2", &vec![4, 5, 6]),
    );
    let batch2 = build_table_i32(
        ("a2", &vec![10, 11, 12]),
        ("b2", &vec![13, 14, 15]),
        ("c2", &vec![14, 15, 16]),
    );
    let schema = batch1.schema();
    let num_rows = batch1.num_rows() + batch2.num_rows();
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
    let spill_file = spill_manager
        .spill_record_batch_and_finish(&[batch1, batch2], "Test")?
        .unwrap();
    assert!(spill_file.path().exists());
    // The metric must reflect every spilled row across both batches.
    let spilled_rows = spill_manager.metrics.spilled_rows.value();
    assert_eq!(spilled_rows, num_rows);
    let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
    assert_eq!(stream.schema(), schema);
    let batches = collect(stream).await?;
    assert_eq!(batches.len(), 2);
    Ok(())
}
/// Dictionary-encoded columns survive a spill/read round trip (the IPC
/// stream format must preserve dictionaries), with correct row metrics.
#[tokio::test]
async fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> {
    let batch1 = build_table_i32(
        ("a2", &vec![0, 1, 2]),
        ("b2", &vec![3, 4, 5]),
        ("c2", &vec![4, 5, 6]),
    );
    let batch2 = build_table_i32(
        ("a2", &vec![10, 11, 12]),
        ("b2", &vec![13, 14, 15]),
        ("c2", &vec![14, 15, 16]),
    );
    // Re-encode all Int32 columns as Dictionary<Int32, Int32>.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
    let dict_schema = Arc::new(Schema::new(vec![
        Field::new("a2", dict_type.clone(), true),
        Field::new("b2", dict_type.clone(), true),
        Field::new("c2", dict_type.clone(), true),
    ]));
    let batch1 = RecordBatch::try_new(
        Arc::clone(&dict_schema),
        batch1
            .columns()
            .iter()
            .map(|array| cast(array, &dict_type))
            .collect::<Result<_, _>>()?,
    )?;
    let batch2 = RecordBatch::try_new(
        Arc::clone(&dict_schema),
        batch2
            .columns()
            .iter()
            .map(|array| cast(array, &dict_type))
            .collect::<Result<_, _>>()?,
    )?;
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let spill_manager = SpillManager::new(env, metrics, Arc::clone(&dict_schema));
    let num_rows = batch1.num_rows() + batch2.num_rows();
    let spill_file = spill_manager
        .spill_record_batch_and_finish(&[batch1, batch2], "Test")?
        .unwrap();
    let spilled_rows = spill_manager.metrics.spilled_rows.value();
    assert_eq!(spilled_rows, num_rows);
    let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
    assert_eq!(stream.schema(), dict_schema);
    let batches = collect(stream).await?;
    assert_eq!(batches.len(), 2);
    Ok(())
}
/// Spilling single-row slices via the batch-iterator API preserves the
/// batch boundaries on read-back and reports a positive max batch memory.
#[tokio::test]
async fn test_batch_spill_by_size() -> Result<()> {
    let batch1 = build_table_i32(
        ("a2", &vec![0, 1, 2, 3]),
        ("b2", &vec![3, 4, 5, 6]),
        ("c2", &vec![4, 5, 6, 7]),
    );
    let schema = batch1.schema();
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
    // Split the 4-row batch into four 1-row batches.
    let row_batches: Vec<RecordBatch> =
        (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect();
    let (spill_file, max_batch_mem) = spill_manager
        .spill_record_batch_iter_and_return_max_batch_memory(
            row_batches.iter().map(Ok),
            "Test Spill",
        )?
        .unwrap();
    assert!(spill_file.path().exists());
    assert!(max_batch_mem > 0);
    let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
    assert_eq!(stream.schema(), schema);
    let batches = collect(stream).await?;
    // One read batch per spilled slice.
    assert_eq!(batches.len(), 4);
    Ok(())
}
/// Builds a 100-row batch of highly repetitive values so that compressed
/// spill files come out measurably smaller than uncompressed ones.
fn build_compressible_batch() -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Utf8, false),
        Field::new("b", DataType::Int32, false),
        Field::new("c", DataType::Int32, true),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["repeated"; 100]));
    let ones: ArrayRef = Arc::new(Int32Array::from(vec![1; 100]));
    let twos: ArrayRef = Arc::new(Int32Array::from(vec![2; 100]));
    RecordBatch::try_new(schema, vec![strings, ones, twos]).unwrap()
}
/// Shared assertion helper: checks the manager's `spilled_rows` metric, then
/// reads `spill_file` back and verifies the stream schema and batch count.
async fn validate(
    spill_manager: &SpillManager,
    spill_file: RefCountedTempFile,
    num_rows: usize,
    schema: SchemaRef,
    batch_count: usize,
) -> Result<()> {
    let spilled_rows = spill_manager.metrics.spilled_rows.value();
    assert_eq!(spilled_rows, num_rows);
    let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
    assert_eq!(stream.schema(), schema);
    let batches = collect(stream).await?;
    assert_eq!(batches.len(), batch_count);
    Ok(())
}
/// Lz4 and Zstd spill files are smaller than an uncompressed spill of the
/// same (highly compressible) data, and all three round-trip correctly.
#[tokio::test]
async fn test_spill_compression() -> Result<()> {
    let batch = build_compressible_batch();
    let num_rows = batch.num_rows();
    let schema = batch.schema();
    let batch_count = 1;
    let batches = [batch];
    let env = Arc::new(RuntimeEnv::default());
    // Separate metrics per manager so each spill is accounted independently.
    let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let uncompressed_spill_manager = SpillManager::new(
        Arc::clone(&env),
        uncompressed_metrics,
        Arc::clone(&schema),
    );
    let lz4_spill_manager =
        SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema))
            .with_compression_type(SpillCompression::Lz4Frame);
    let zstd_spill_manager =
        SpillManager::new(env, zstd_metrics, Arc::clone(&schema))
            .with_compression_type(SpillCompression::Zstd);
    let uncompressed_spill_file = uncompressed_spill_manager
        .spill_record_batch_and_finish(&batches, "Test")?
        .unwrap();
    let lz4_spill_file = lz4_spill_manager
        .spill_record_batch_and_finish(&batches, "Lz4_Test")?
        .unwrap();
    let zstd_spill_file = zstd_spill_manager
        .spill_record_batch_and_finish(&batches, "ZSTD_Test")?
        .unwrap();
    assert!(uncompressed_spill_file.path().exists());
    assert!(lz4_spill_file.path().exists());
    assert!(zstd_spill_file.path().exists());
    // Compare actual on-disk sizes: compression must win on this data.
    let lz4_spill_size = std::fs::metadata(lz4_spill_file.path())?.len();
    let zstd_spill_size = std::fs::metadata(zstd_spill_file.path())?.len();
    let uncompressed_spill_size =
        std::fs::metadata(uncompressed_spill_file.path())?.len();
    assert!(uncompressed_spill_size > lz4_spill_size);
    assert!(uncompressed_spill_size > zstd_spill_size);
    // Each file must still read back to the original data.
    validate(
        &lz4_spill_manager,
        lz4_spill_file,
        num_rows,
        Arc::clone(&schema),
        batch_count,
    )
    .await?;
    validate(
        &zstd_spill_manager,
        zstd_spill_file,
        num_rows,
        Arc::clone(&schema),
        batch_count,
    )
    .await?;
    validate(
        &uncompressed_spill_manager,
        uncompressed_spill_file,
        num_rows,
        schema,
        batch_count,
    )
    .await?;
    Ok(())
}
/// `spill_record_batch_and_finish` with a non-empty input produces a spill
/// file that exists on disk.
#[test]
fn test_spill_manager_spill_record_batch_and_finish() -> Result<()> {
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )?;
    let temp_file = spill_manager.spill_record_batch_and_finish(&[batch], "Test")?;
    assert!(temp_file.is_some());
    assert!(temp_file.unwrap().path().exists());
    Ok(())
}
/// Asserts that the in-progress file's spill metrics (file count, bytes,
/// rows) match the expected values.
fn verify_metrics(
    in_progress_file: &InProgressSpillFile,
    expected_spill_file_count: usize,
    expected_spilled_bytes: usize,
    expected_spilled_rows: usize,
) -> Result<()> {
    let metrics = &in_progress_file.spill_writer.metrics;
    assert_eq!(
        metrics.spill_file_count.value(),
        expected_spill_file_count,
        "Spill file count mismatch"
    );
    assert_eq!(
        metrics.spilled_bytes.value(),
        expected_spilled_bytes,
        "Spilled bytes mismatch"
    );
    assert_eq!(
        metrics.spilled_rows.value(),
        expected_spilled_rows,
        "Spilled rows mismatch"
    );
    Ok(())
}
/// Appending batches to an in-progress spill file updates metrics
/// incrementally, `finish` adds the trailing IPC bytes (704 -> 712 here),
/// and a second `finish` is an error.
#[test]
fn test_in_progress_spill_file_append_and_finish() -> Result<()> {
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let spill_manager =
        Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));
    let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
    let batch1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )?;
    let batch2 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![4, 5, 6])),
            Arc::new(StringArray::from(vec!["d", "e", "f"])),
        ],
    )?;
    // Expected byte counts (440, 704, 712) are the exact on-disk sizes for
    // this schema with the current IPC writer settings.
    in_progress_file.append_batch(&batch1)?;
    verify_metrics(&in_progress_file, 1, 440, 3)?;
    in_progress_file.append_batch(&batch2)?;
    verify_metrics(&in_progress_file, 1, 704, 6)?;
    let completed_file = in_progress_file.finish()?;
    assert!(completed_file.is_some());
    assert!(completed_file.unwrap().path().exists());
    verify_metrics(&in_progress_file, 1, 712, 6)?;
    // Finishing twice must fail.
    let result = in_progress_file.finish();
    assert!(result.is_err());
    Ok(())
}
/// Spilling nothing — no batches, an empty slice, or a zero-row batch —
/// produces no spill file (`None`) in every API variant.
#[test]
fn test_in_progress_spill_file_write_no_batches() -> Result<()> {
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let spill_manager =
        Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));
    // Finishing without appending anything yields no file.
    let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
    let completed_file = in_progress_file.finish()?;
    assert!(completed_file.is_none());
    // An empty batch slice also yields no file.
    let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?;
    assert!(completed_file.is_none());
    // A batch with zero rows yields no file either.
    let empty_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(Vec::<Option<i32>>::new())),
            Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
        ],
    )?;
    let completed_file = spill_manager
        .spill_record_batch_iter_and_return_max_batch_memory(
            std::iter::once(Ok(&empty_batch)),
            "Test",
        )?;
    assert!(completed_file.is_none());
    Ok(())
}
/// Reading two spill streams concurrently must not deadlock even when the
/// runtime allows only a single blocking thread (regression check for the
/// spawn-blocking-based reader).
#[test]
fn test_reading_more_spills_than_tokio_blocking_threads() -> Result<()> {
    // Deliberately restrict the runtime to one blocking thread.
    tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .max_blocking_threads(1)
        .build()
        .unwrap()
        .block_on(async {
            let batch = build_table_i32(
                ("a2", &vec![0, 1, 2]),
                ("b2", &vec![3, 4, 5]),
                ("c2", &vec![4, 5, 6]),
            );
            let schema = batch.schema();
            let env = Arc::new(RuntimeEnv::default());
            let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
            let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
            let batches: [_; 10] = std::array::from_fn(|_| batch.clone());
            let spill_file_1 = spill_manager
                .spill_record_batch_and_finish(&batches, "Test1")?
                .unwrap();
            let spill_file_2 = spill_manager
                .spill_record_batch_and_finish(&batches, "Test2")?
                .unwrap();
            let mut stream_1 =
                spill_manager.read_spill_as_stream(spill_file_1, None)?;
            let mut stream_2 =
                spill_manager.read_spill_as_stream(spill_file_2, None)?;
            // Interleave polls of both streams on the single blocking thread.
            stream_1.next().await;
            stream_2.next().await;
            Ok(())
        })
}
/// `get_max_alignment_for_schema` returns the widest buffer alignment in
/// the schema, with an 8-byte floor for plain primitive columns.
#[test]
fn test_alignment_for_schema() -> Result<()> {
    // Utf8View buffers require 16-byte alignment.
    let view_schema =
        Schema::new(vec![Field::new("strings", DataType::Utf8View, false)]);
    assert_eq!(get_max_alignment_for_schema(&view_schema), 16);
    // Int32/Int64 columns fall back to the 8-byte minimum.
    let primitive_schema = Schema::new(vec![
        Field::new("int32", DataType::Int32, false),
        Field::new("int64", DataType::Int64, false),
    ]);
    assert_eq!(get_max_alignment_for_schema(&primitive_schema), 8);
    Ok(())
}
/// Spill metrics and the runtime's global spilling progress update in real
/// time as batches are appended, and the progress counters reset to zero
/// when the finished spill file is dropped.
#[tokio::test]
async fn test_real_time_spill_metrics() -> Result<()> {
    let env = Arc::new(RuntimeEnv::default());
    let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let spill_manager = Arc::new(SpillManager::new(
        Arc::clone(&env),
        metrics.clone(),
        Arc::clone(&schema),
    ));
    let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
    let batch1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )?;
    // Nothing written yet: all counters start at zero.
    assert_eq!(metrics.spilled_bytes.value(), 0);
    assert_eq!(metrics.spill_file_count.value(), 0);
    in_progress_file.append_batch(&batch1)?;
    // 440 is the exact on-disk size of the first batch for this schema.
    let bytes_after_batch1 = metrics.spilled_bytes.value();
    assert_eq!(bytes_after_batch1, 440);
    assert_eq!(metrics.spill_file_count.value(), 1);
    // The runtime-wide progress view must mirror the per-manager metrics.
    let progress = env.spilling_progress();
    assert_eq!(progress.current_bytes, bytes_after_batch1 as u64);
    assert_eq!(progress.active_files_count, 1);
    in_progress_file.append_batch(&batch1)?;
    let bytes_after_batch2 = metrics.spilled_bytes.value();
    assert!(bytes_after_batch2 > bytes_after_batch1);
    let progress = env.spilling_progress();
    assert_eq!(progress.current_bytes, bytes_after_batch2 as u64);
    // Finishing writes additional trailing bytes.
    let spilled_file = in_progress_file.finish()?;
    let final_bytes = metrics.spilled_bytes.value();
    assert!(final_bytes > bytes_after_batch2);
    let progress = env.spilling_progress();
    assert!(progress.current_bytes > 0);
    assert_eq!(progress.active_files_count, 1);
    // Dropping the file releases its contribution to global progress.
    drop(spilled_file);
    assert_eq!(env.spilling_progress().active_files_count, 0);
    assert_eq!(env.spilling_progress().current_bytes, 0);
    Ok(())
}
}