use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile};
use crate::coop::cooperative;
use crate::{common::spawn_buffered, metrics::SpillMetrics};
use arrow::array::StringViewArray;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::memory::get_record_batch_memory_size;
use datafusion_common::{DataFusionError, Result, config::SpillCompression};
use datafusion_execution::SendableRecordBatchStream;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::runtime_env::RuntimeEnv;
use std::borrow::Borrow;
use std::sync::Arc;
/// Coordinates spilling [`RecordBatch`]es to temporary disk files and reading
/// them back as record-batch streams.
///
/// The manager is `Clone`; [`Self::create_in_progress_file`] hands each
/// [`InProgressSpillFile`] an `Arc`-wrapped copy of it.
#[derive(Debug, Clone)]
pub struct SpillManager {
    /// Runtime environment whose disk manager supplies the temporary files.
    env: Arc<RuntimeEnv>,
    /// Spill metrics; a copy travels with each in-progress spill file.
    pub(crate) metrics: SpillMetrics,
    /// Schema of the batches written to and read back from spill files.
    schema: SchemaRef,
    /// Number of batches buffered when reading a spill file back
    /// (defaults to 2; see `with_batch_read_buffer_capacity`).
    batch_read_buffer_capacity: usize,
    /// Compression applied when writing spill files
    /// (defaults to `SpillCompression::default()`).
    pub(crate) compression: SpillCompression,
}
impl SpillManager {
    /// Creates a manager that spills batches of `schema` through `env`'s disk
    /// manager, recording activity in `metrics`. Read buffering defaults to
    /// 2 batches and compression to [`SpillCompression::default`].
    pub fn new(env: Arc<RuntimeEnv>, metrics: SpillMetrics, schema: SchemaRef) -> Self {
        Self {
            env,
            metrics,
            schema,
            batch_read_buffer_capacity: 2,
            compression: SpillCompression::default(),
        }
    }

    /// Builder-style setter for the number of batches buffered while a spill
    /// file is read back (used by [`Self::read_spill_as_stream`]).
    pub fn with_batch_read_buffer_capacity(
        self,
        batch_read_buffer_capacity: usize,
    ) -> Self {
        Self {
            batch_read_buffer_capacity,
            ..self
        }
    }

    /// Builder-style setter for the compression codec used when writing
    /// spill files.
    pub fn with_compression_type(self, spill_compression: SpillCompression) -> Self {
        Self {
            compression: spill_compression,
            ..self
        }
    }

    /// Returns the schema of the batches this manager spills and reads back.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Creates a temporary file, labelled with `request_msg` for disk-manager
    /// diagnostics, wrapped in an [`InProgressSpillFile`] that batches can be
    /// appended to incrementally.
    pub fn create_in_progress_file(
        &self,
        request_msg: &str,
    ) -> Result<InProgressSpillFile> {
        let temp_file = self.env.disk_manager.create_tmp_file(request_msg)?;
        Ok(InProgressSpillFile::new(Arc::new(self.clone()), temp_file))
    }

    /// Spills all `batches` to a fresh temporary file and finalizes it.
    /// Returns `None` when nothing was written (see
    /// [`InProgressSpillFile::finish`]).
    pub fn spill_record_batch_and_finish(
        &self,
        batches: &[RecordBatch],
        request_msg: &str,
    ) -> Result<Option<RefCountedTempFile>> {
        let mut spill_file = self.create_in_progress_file(request_msg)?;
        for batch in batches {
            spill_file.append_batch(batch)?;
        }
        spill_file.finish()
    }

    /// Spills every non-empty batch yielded by `iter` and reports, alongside
    /// the finished file, the largest in-memory batch size encountered
    /// (as measured by [`get_record_batch_memory_size`]).
    pub(crate) fn spill_record_batch_iter_and_return_max_batch_memory(
        &self,
        iter: impl Iterator<Item = Result<impl Borrow<RecordBatch>>>,
        request_description: &str,
    ) -> Result<Option<(RefCountedTempFile, usize)>> {
        let mut spill_file = self.create_in_progress_file(request_description)?;
        let mut largest_batch_bytes = 0;
        for item in iter {
            let item = item?;
            let batch = item.borrow();
            // Empty batches are skipped entirely: they carry no data and
            // would only add per-batch overhead to the spill file.
            if batch.num_rows() > 0 {
                spill_file.append_batch(batch)?;
                largest_batch_bytes =
                    largest_batch_bytes.max(get_record_batch_memory_size(batch));
            }
        }
        let finished = spill_file.finish()?;
        Ok(finished.map(|file| (file, largest_batch_bytes)))
    }

    /// Drains `stream`, spilling every batch it yields (empty batches are not
    /// filtered here, unlike the iterator variant) and reporting the largest
    /// sliced batch size seen (per [`GetSlicedSize::get_sliced_size`]).
    pub(crate) async fn spill_record_batch_stream_and_return_max_batch_memory(
        &self,
        stream: &mut SendableRecordBatchStream,
        request_description: &str,
    ) -> Result<Option<(RefCountedTempFile, usize)>> {
        use futures::StreamExt;
        let mut spill_file = self.create_in_progress_file(request_description)?;
        let mut largest_batch_bytes = 0;
        while let Some(next) = stream.next().await {
            let batch = next?;
            spill_file.append_batch(&batch)?;
            largest_batch_bytes = largest_batch_bytes.max(batch.get_sliced_size()?);
        }
        let finished = spill_file.finish()?;
        Ok(finished.map(|file| (file, largest_batch_bytes)))
    }

    /// Opens a finished spill file as a stream, handing the reader to
    /// `spawn_buffered` so up to `batch_read_buffer_capacity` batches are
    /// buffered ahead of the consumer.
    pub fn read_spill_as_stream(
        &self,
        spill_file_path: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Result<SendableRecordBatchStream> {
        let reader = SpillReaderStream::new(
            Arc::clone(&self.schema),
            spill_file_path,
            max_record_batch_memory,
        );
        Ok(spawn_buffered(
            Box::pin(cooperative(reader)),
            self.batch_read_buffer_capacity,
        ))
    }

    /// Opens a finished spill file as a stream that is polled directly by the
    /// caller, without any background buffering.
    pub fn read_spill_as_stream_unbuffered(
        &self,
        spill_file_path: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Result<SendableRecordBatchStream> {
        let reader = SpillReaderStream::new(
            Arc::clone(&self.schema),
            spill_file_path,
            max_record_batch_memory,
        );
        Ok(Box::pin(cooperative(reader)))
    }
}
/// Measures the bytes a (possibly sliced) value actually references in
/// memory, as opposed to the full size of its backing buffers.
pub(crate) trait GetSlicedSize {
    /// Returns the referenced size in bytes of `self` after any slicing.
    /// May fail if the underlying Arrow buffer accounting fails.
    fn get_sliced_size(&self) -> Result<usize>;
}
impl GetSlicedSize for RecordBatch {
    /// Sums, over every column, the memory referenced by the batch's slice.
    /// For `StringViewArray` columns the variadic data buffers are counted at
    /// full capacity in addition to the sliced view size (the views may point
    /// anywhere into those shared buffers).
    fn get_sliced_size(&self) -> Result<usize> {
        self.columns()
            .iter()
            .try_fold(0usize, |acc, array| -> Result<usize> {
                let mut column_bytes = array.to_data().get_slice_memory_size()?;
                // NOTE(review): the shared data buffers backing string views
                // appear not to be covered by `get_slice_memory_size`, so
                // their whole capacity is added on top.
                if let Some(views) = array.as_any().downcast_ref::<StringViewArray>() {
                    column_bytes += views
                        .data_buffers()
                        .iter()
                        .map(|buffer| buffer.capacity())
                        .sum::<usize>();
                }
                Ok(acc + column_bytes)
            })
    }
}
#[cfg(test)]
mod tests {
    use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::{
        array::{ArrayRef, StringViewArray},
        record_batch::RecordBatch,
    };
    use datafusion_common::Result;
    use std::sync::Arc;

    /// Verifies `get_sliced_size` against `get_record_batch_memory_size` for
    /// a `StringViewArray` column, before and after slicing the batch.
    #[test]
    fn check_sliced_size_for_string_view_array() -> Result<()> {
        let array_length = 50;
        let (short_len, long_len) = (8, 25);
        // Alternate short and long strings so the view array keeps part of
        // its payload in external data buffers.
        let mut strings: Vec<String> = Vec::with_capacity(array_length);
        for i in 0..array_length {
            let value = if i % 2 == 0 {
                "a".repeat(short_len)
            } else {
                "b".repeat(long_len)
            };
            strings.push(value);
        }
        let column: ArrayRef = Arc::new(StringViewArray::from(strings));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "strings",
            DataType::Utf8View,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

        // Unsliced, both measures account for the same referenced memory.
        assert_eq!(
            batch.get_sliced_size().unwrap(),
            get_record_batch_memory_size(&batch)
        );

        // After slicing, the sliced size drops below the full buffer size...
        let half_batch = batch.slice(0, array_length / 2);
        assert!(
            half_batch.get_sliced_size().unwrap()
                < get_record_batch_memory_size(&half_batch)
        );

        // ...but still exceeds the views alone, because the shared data
        // buffers are counted at full capacity.
        let data = arrow::array::Array::to_data(&half_batch.column(0));
        let views_sliced_size = data.get_slice_memory_size()?;
        assert!(views_sliced_size < half_batch.get_sliced_size().unwrap());
        Ok(())
    }
}