use datafusion_common::Result;
use std::sync::Arc;
use arrow::array::RecordBatch;
use datafusion_common::exec_datafusion_err;
use datafusion_execution::disk_manager::RefCountedTempFile;
use super::{IPCStreamWriter, spill_manager::SpillManager};
/// A spill file that is written incrementally, one `RecordBatch` at a time,
/// and finalized into a [`RefCountedTempFile`] via [`Self::finish`].
pub struct InProgressSpillFile {
    // Owning [`SpillManager`]: supplies the schema, IPC compression setting,
    // and the spill metrics updated while writing.
    pub(crate) spill_writer: Arc<SpillManager>,
    // Lazily initialized IPC stream writer; created on the first
    // `append_batch` call, `None` until then.
    writer: Option<IPCStreamWriter>,
    // The temp file being written. `None` after `finish` takes it, at which
    // point further appends are rejected.
    in_progress_file: Option<RefCountedTempFile>,
}
impl InProgressSpillFile {
    /// Creates a new in-progress spill file backed by `in_progress_file`.
    ///
    /// No bytes are written and no IPC writer is created until the first
    /// [`Self::append_batch`] call.
    pub fn new(
        spill_writer: Arc<SpillManager>,
        in_progress_file: RefCountedTempFile,
    ) -> Self {
        Self {
            spill_writer,
            writer: None,
            in_progress_file: Some(in_progress_file),
        }
    }

    /// Appends `batch` to the spill file and updates the spill metrics.
    ///
    /// On the first call this lazily creates the IPC stream writer using the
    /// [`SpillManager`]'s schema (not the batch's own schema) and its
    /// configured compression, and records the initial file size plus the
    /// spill-file count.
    ///
    /// # Errors
    /// Returns an error if the file has already been finalized, or if
    /// creating the writer / writing the batch / refreshing disk usage fails.
    pub fn append_batch(&mut self, batch: &RecordBatch) -> Result<()> {
        // Guard: once `finish` has taken the file, appends are invalid.
        let Some(in_progress_file) = self.in_progress_file.as_mut() else {
            return Err(exec_datafusion_err!(
                "Append operation failed: No active in-progress file. The file may have already been finalized."
            ));
        };

        if self.writer.is_none() {
            // First append: create the writer with the manager's schema so
            // every batch in this file is written under the same schema.
            let schema = self.spill_writer.schema();
            self.writer = Some(IPCStreamWriter::new(
                in_progress_file.path(),
                schema.as_ref(),
                self.spill_writer.compression,
            )?);
            self.spill_writer.metrics.spill_file_count.add(1);

            // Account for the bytes the writer emits up front (stream header).
            in_progress_file.update_disk_usage()?;
            let initial_size = in_progress_file.current_disk_usage();
            self.spill_writer
                .metrics
                .spilled_bytes
                .add(initial_size as usize);
        }

        let writer = self
            .writer
            .as_mut()
            .expect("writer initialized above when missing");
        let (spilled_rows, _) = writer.write(batch)?;

        // Record only the bytes this batch added to the file.
        let pre_size = in_progress_file.current_disk_usage();
        in_progress_file.update_disk_usage()?;
        let post_size = in_progress_file.current_disk_usage();
        self.spill_writer.metrics.spilled_rows.add(spilled_rows);
        self.spill_writer
            .metrics
            .spilled_bytes
            .add((post_size - pre_size) as usize);
        Ok(())
    }

    /// Flushes any buffered bytes of the underlying writer to disk.
    /// A no-op if nothing has been appended yet.
    pub fn flush(&mut self) -> Result<()> {
        match self.writer.as_mut() {
            Some(writer) => writer.flush(),
            None => Ok(()),
        }
    }

    /// Returns the in-progress temp file, or `None` once it has been
    /// finalized and taken by [`Self::finish`].
    pub fn file(&self) -> Option<&RefCountedTempFile> {
        self.in_progress_file.as_ref()
    }

    /// Finalizes the IPC stream and hands back the completed temp file.
    ///
    /// Returns `Ok(None)` when no batch was ever appended (there is nothing
    /// to read back). Any bytes emitted by the stream footer are added to the
    /// spilled-bytes metric.
    pub fn finish(&mut self) -> Result<Option<RefCountedTempFile>> {
        // Nothing was written: no writer was ever created, so there is no
        // usable spill file to return.
        let Some(writer) = self.writer.as_mut() else {
            return Ok(None);
        };
        writer.finish()?;

        if let Some(in_progress_file) = self.in_progress_file.as_mut() {
            // Account for the footer bytes written by `finish`.
            let pre_size = in_progress_file.current_disk_usage();
            in_progress_file.update_disk_usage()?;
            let post_size = in_progress_file.current_disk_usage();
            self.spill_writer
                .metrics
                .spilled_bytes
                .add((post_size - pre_size) as usize);
        }
        Ok(self.in_progress_file.take())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_physical_expr_common::metrics::{
        ExecutionPlanMetricsSet, SpillMetrics,
    };
    use futures::TryStreamExt;

    /// Batches appended to an in-progress spill file must be written under
    /// the [`SpillManager`]'s schema rather than each batch's own schema, so
    /// one file can hold batches whose schemas differ only in nullability.
    #[tokio::test]
    async fn test_spill_file_uses_spill_manager_schema() -> Result<()> {
        // Two schemas differing only in the nullability of `val`.
        let schema_nullable = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, true),
        ]));
        let schema_non_nullable = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, false),
        ]));

        // Spill manager is configured with the nullable schema.
        let runtime = Arc::new(RuntimeEnvBuilder::new().build()?);
        let metrics_set = ExecutionPlanMetricsSet::new();
        let spill_metrics = SpillMetrics::new(&metrics_set, 0);
        let spill_manager = Arc::new(SpillManager::new(
            runtime,
            spill_metrics,
            Arc::clone(&schema_nullable),
        ));
        let mut in_progress = spill_manager.create_in_progress_file("test")?;

        // First append: a batch carrying the stricter, non-nullable schema.
        let batch_non_nullable = RecordBatch::try_new(
            Arc::clone(&schema_non_nullable),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![0, 0, 0])),
            ],
        )?;
        in_progress.append_batch(&batch_non_nullable)?;

        // Second append: a batch that actually contains nulls.
        let batch_nullable = RecordBatch::try_new(
            Arc::clone(&schema_nullable),
            vec![
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
            ],
        )?;
        in_progress.append_batch(&batch_nullable)?;

        // Reading the finished file back must yield the manager's schema...
        let spill_file = in_progress.finish()?.unwrap();
        let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), schema_nullable);

        // ...and both batches, with the first re-tagged to that schema.
        let batches = stream.try_collect::<Vec<_>>().await?;
        assert_eq!(batches.len(), 2);
        assert_eq!(
            batches[0],
            batch_non_nullable.with_schema(Arc::clone(&schema_nullable))?
        );
        assert_eq!(batches[1], batch_nullable);
        Ok(())
    }
}