use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile};
use crate::coop::cooperative;
use crate::{common::spawn_buffered, metrics::SpillMetrics};
use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
use arrow::datatypes::{ByteViewType, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, config::SpillCompression};
use datafusion_execution::SendableRecordBatchStream;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::runtime_env::RuntimeEnv;
use std::borrow::Borrow;
use std::sync::Arc;
/// Bundles the state needed to write record batches to temporary spill files
/// and to read them back as record batch streams.
///
/// The type is `Clone`; `create_in_progress_file` stores a cloned copy of the
/// manager (wrapped in an `Arc`) inside every in-progress spill file.
#[derive(Debug, Clone)]
pub struct SpillManager {
/// Runtime environment; its `disk_manager` is used to create temporary files.
env: Arc<RuntimeEnv>,
/// Spill metrics handle supplied at construction. It is not read in this
/// file; it is crate-visible, presumably so the spill writer can update it
/// (TODO confirm against `InProgressSpillFile`).
pub(crate) metrics: SpillMetrics,
/// Schema handed to `SpillReaderStream` when a spill file is read back, and
/// returned by `schema()`.
schema: SchemaRef,
/// Buffer size passed to `spawn_buffered` by `read_spill_as_stream`.
/// Initialized to 2 by `new`.
batch_read_buffer_capacity: usize,
/// Compression setting, initialized to `SpillCompression::default()` and
/// overridable through `with_compression_type`. Not read in this file;
/// presumably consumed by the spill writer (TODO confirm).
pub(crate) compression: SpillCompression,
}
impl SpillManager {
    /// Creates a spill manager from a runtime environment, a metrics handle and
    /// the schema used when reading spill files back.
    ///
    /// The read buffer capacity starts at 2 and the compression setting at
    /// `SpillCompression::default()`; use the `with_*` builders to override them.
    pub fn new(env: Arc<RuntimeEnv>, metrics: SpillMetrics, schema: SchemaRef) -> Self {
        let compression = SpillCompression::default();
        Self {
            env,
            metrics,
            schema,
            batch_read_buffer_capacity: 2,
            compression,
        }
    }

    /// Returns the manager with the buffer capacity used by
    /// [`Self::read_spill_as_stream`] replaced by `batch_read_buffer_capacity`.
    pub fn with_batch_read_buffer_capacity(
        self,
        batch_read_buffer_capacity: usize,
    ) -> Self {
        Self {
            batch_read_buffer_capacity,
            ..self
        }
    }

    /// Returns the manager with its compression setting replaced by
    /// `spill_compression`.
    pub fn with_compression_type(self, spill_compression: SpillCompression) -> Self {
        Self {
            compression: spill_compression,
            ..self
        }
    }

    /// Returns the schema this manager was constructed with.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Opens a new in-progress spill file.
    ///
    /// A temporary file is requested from the runtime environment's disk
    /// manager (passing `request_msg` along), then wrapped together with an
    /// `Arc`-held clone of this manager.
    ///
    /// # Errors
    /// Propagates any error returned while creating the temporary file.
    pub fn create_in_progress_file(
        &self,
        request_msg: &str,
    ) -> Result<InProgressSpillFile> {
        let spill_file = self.env.disk_manager.create_tmp_file(request_msg)?;
        let manager = Arc::new(self.clone());
        Ok(InProgressSpillFile::new(manager, spill_file))
    }

    /// Appends every batch in `batches` to a fresh spill file and finishes it.
    ///
    /// The return value is whatever `InProgressSpillFile::finish` yields.
    ///
    /// # Errors
    /// Fails on the first error from file creation, appending, or finishing.
    pub fn spill_record_batch_and_finish(
        &self,
        batches: &[RecordBatch],
        request_msg: &str,
    ) -> Result<Option<RefCountedTempFile>> {
        let mut in_progress_file = self.create_in_progress_file(request_msg)?;
        batches.iter().try_for_each(|batch| {
            // The size reported by `append_batch` is not needed here.
            in_progress_file.append_batch(batch).map(|_written| ())
        })?;
        in_progress_file.finish()
    }

    /// Spills every batch produced by `iter` into a fresh spill file and reports
    /// the largest per-batch size returned by `append_batch`.
    ///
    /// Batches with zero rows are skipped and never appended. When `finish`
    /// yields a file, it is returned paired with that maximum size.
    ///
    /// # Errors
    /// Stops at the first error coming from the iterator, from appending, or
    /// from finishing the file.
    pub(crate) fn spill_record_batch_iter_and_return_max_batch_memory(
        &self,
        iter: impl Iterator<Item = Result<impl Borrow<RecordBatch>>>,
        request_description: &str,
    ) -> Result<Option<(RefCountedTempFile, usize)>> {
        let mut in_progress_file = self.create_in_progress_file(request_description)?;
        let mut largest_batch = 0;
        for item in iter {
            let holder = item?;
            let batch: &RecordBatch = holder.borrow();
            if batch.num_rows() == 0 {
                continue;
            }
            let appended_size = in_progress_file.append_batch(batch)?;
            largest_batch = largest_batch.max(appended_size);
        }
        Ok(in_progress_file
            .finish()?
            .map(|spill_file| (spill_file, largest_batch)))
    }

    /// Drains `stream` into a fresh spill file and reports the largest
    /// per-batch size returned by `append_batch`.
    ///
    /// Unlike the iterator-based variant, zero-row batches are not filtered
    /// out here; every batch the stream yields is appended.
    ///
    /// # Errors
    /// Stops at the first error coming from the stream, from appending, or
    /// from finishing the file.
    pub(crate) async fn spill_record_batch_stream_and_return_max_batch_memory(
        &self,
        stream: &mut SendableRecordBatchStream,
        request_description: &str,
    ) -> Result<Option<(RefCountedTempFile, usize)>> {
        use futures::StreamExt;
        let mut in_progress_file = self.create_in_progress_file(request_description)?;
        let mut largest_batch = 0;
        // `transpose` turns the stream's `Option<Result<_>>` into
        // `Result<Option<_>>` so an error item ends the loop through `?`.
        while let Some(batch) = stream.next().await.transpose()? {
            let appended_size = in_progress_file.append_batch(&batch)?;
            largest_batch = largest_batch.max(appended_size);
        }
        Ok(in_progress_file
            .finish()?
            .map(|spill_file| (spill_file, largest_batch)))
    }

    /// Reads a spill file back as a record batch stream that is additionally
    /// passed through `spawn_buffered` with this manager's read buffer capacity.
    ///
    /// The underlying stream is the same one produced by
    /// [`Self::read_spill_as_stream_unbuffered`].
    pub fn read_spill_as_stream(
        &self,
        spill_file_path: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Result<SendableRecordBatchStream> {
        let unbuffered = self
            .read_spill_as_stream_unbuffered(spill_file_path, max_record_batch_memory)?;
        Ok(spawn_buffered(unbuffered, self.batch_read_buffer_capacity))
    }

    /// Reads a spill file back as a record batch stream without the extra
    /// `spawn_buffered` stage.
    ///
    /// The stream is a `SpillReaderStream` built from this manager's schema,
    /// the given file and `max_record_batch_memory`, wrapped in `cooperative`.
    pub fn read_spill_as_stream_unbuffered(
        &self,
        spill_file_path: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Result<SendableRecordBatchStream> {
        let reader = SpillReaderStream::new(
            Arc::clone(&self.schema),
            spill_file_path,
            max_record_batch_memory,
        );
        Ok(Box::pin(cooperative(reader)))
    }
}
/// Crate-internal helper trait for obtaining a byte-size figure for a value.
///
/// The `RecordBatch` implementation in this file adds up, per column, the
/// result of `ArrayData::get_slice_memory_size()` plus the capacity of the
/// data buffers of string-view / binary-view columns.
pub(crate) trait GetSlicedSize {
/// Returns the computed size, or the error raised while computing it.
fn get_sliced_size(&self) -> Result<usize>;
}
impl GetSlicedSize for RecordBatch {
    /// Sums a size figure over all columns of the batch.
    ///
    /// Each column contributes `ArrayData::get_slice_memory_size()`. Columns
    /// that are `StringViewArray` or `BinaryViewArray` additionally contribute
    /// the total capacity of their data buffers, which the slice size alone
    /// does not appear to cover (NOTE(review): confirm against arrow-rs).
    ///
    /// # Errors
    /// Returns the first error produced by `get_slice_memory_size()`.
    fn get_sliced_size(&self) -> Result<usize> {
        self.columns().iter().try_fold(0, |running_total, column| {
            let slice_bytes = column.to_data().get_slice_memory_size()?;
            let any_column = column.as_any();
            // A column can match at most one of the two view array types.
            let data_buffer_bytes =
                if let Some(string_views) = any_column.downcast_ref::<StringViewArray>() {
                    byte_view_data_buffer_size(string_views)
                } else if let Some(binary_views) =
                    any_column.downcast_ref::<BinaryViewArray>()
                {
                    byte_view_data_buffer_size(binary_views)
                } else {
                    0
                };
            Ok(running_total + slice_bytes + data_buffer_bytes)
        })
    }
}
/// Returns the sum of `capacity()` over every data buffer of a byte-view array.
fn byte_view_data_buffer_size<T: ByteViewType>(array: &GenericByteViewArray<T>) -> usize {
    let mut total_capacity = 0;
    for data_buffer in array.data_buffers() {
        total_capacity += data_buffer.capacity();
    }
    total_capacity
}
#[cfg(test)]
mod tests {
    use super::SpillManager;
    use crate::common::collect;
    use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
    use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::{
        array::{ArrayRef, Int32Array, StringArray, StringViewArray},
        record_batch::RecordBatch,
    };
    use datafusion_common::Result;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use std::sync::Arc;

    /// Builds a `SpillManager` for `schema` backed by fresh spill metrics.
    fn build_test_spill_manager(
        env: Arc<RuntimeEnv>,
        schema: Arc<Schema>,
    ) -> SpillManager {
        let metrics_set = ExecutionPlanMetricsSet::new();
        SpillManager::new(env, SpillMetrics::new(&metrics_set, 0), schema)
    }

    /// Builds a three-row batch with an `Int32` column and a `Utf8` column.
    fn build_writer_batch(schema: Arc<Schema>) -> Result<RecordBatch> {
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        Ok(RecordBatch::try_new(schema, vec![ids, values])?)
    }

    /// A file spilled by one manager can be read by a second manager whose
    /// schema is equal but separately constructed.
    #[tokio::test]
    async fn test_read_spill_as_stream_from_another_spill_manager_same_schema()
    -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        // Called twice on purpose so writer and reader hold distinct `Arc`s.
        let id_value_schema = || {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("value", DataType::Utf8, false),
            ]))
        };
        let writer_schema = id_value_schema();
        let reader_schema = id_value_schema();

        let writer =
            build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
        let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));

        let expected_batch = build_writer_batch(writer_schema)?;
        let finished = writer.spill_record_batch_and_finish(
            std::slice::from_ref(&expected_batch),
            "writer",
        )?;
        let spill_file = finished.unwrap();

        let stream = reader.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), reader_schema);

        let read_back = collect(stream).await?;
        assert_eq!(read_back, vec![expected_batch]);
        Ok(())
    }

    /// Reading a spill file through a manager with a different schema must
    /// surface a schema-mismatch error.
    #[tokio::test]
    async fn test_read_spill_as_stream_from_another_spill_manager_different_schema()
    -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        let two_column_schema = |id_name: &str, value_name: &str, nullable: bool| {
            Arc::new(Schema::new(vec![
                Field::new(id_name, DataType::Int32, nullable),
                Field::new(value_name, DataType::Utf8, nullable),
            ]))
        };
        let writer_schema = two_column_schema("id", "value", false);
        let reader_schema = two_column_schema("other_id", "other_value", true);

        let writer =
            build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
        let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));

        let source_batch = build_writer_batch(writer_schema)?;
        let finished = writer.spill_record_batch_and_finish(
            std::slice::from_ref(&source_batch),
            "writer",
        )?;
        let spill_file = finished.unwrap();

        let stream = reader.read_spill_as_stream(spill_file, None)?;
        let message = collect(stream)
            .await
            .expect_err("schema mismatch should fail fast")
            .to_string();
        assert!(message.contains("Spill file schema mismatch"));
        assert!(message.contains("expected"));
        assert!(message.contains("got"));
        Ok(())
    }

    /// Compares `get_sliced_size` with `get_record_batch_memory_size` for a
    /// full and a half-sliced string-view batch.
    #[test]
    fn check_sliced_size_for_string_view_array() -> Result<()> {
        let array_length = 50;
        let short_len = 8;
        let long_len = 25;

        // Alternate 8-byte and 25-byte strings.
        let strings: Vec<String> = (0..array_length)
            .map(|index| match index % 2 {
                0 => "a".repeat(short_len),
                _ => "b".repeat(long_len),
            })
            .collect();

        let column: ArrayRef = Arc::new(StringViewArray::from(strings));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "strings",
            DataType::Utf8View,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

        // Unsliced batch: both size computations must agree.
        let full_sliced_size = batch.get_sliced_size().unwrap();
        assert_eq!(full_sliced_size, get_record_batch_memory_size(&batch));

        // Half batch: the sliced size must be strictly smaller.
        let half_batch = batch.slice(0, array_length / 2);
        let half_sliced_size = half_batch.get_sliced_size().unwrap();
        assert!(half_sliced_size < get_record_batch_memory_size(&half_batch));

        // The sliced size must still exceed the slice size of the column data alone.
        let half_column_data = arrow::array::Array::to_data(&half_batch.column(0));
        let views_sliced_size = half_column_data.get_slice_memory_size()?;
        assert!(views_sliced_size < half_sliced_size);
        Ok(())
    }
}