use crate::error::{DbxError, DbxResult};
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;
use std::fs::File;
use std::io::BufWriter;
use std::path::{Path, PathBuf};
use tempfile::TempDir;
/// Memory-budget tracker that spills Arrow record batches to temporary IPC files.
///
/// Callers report estimated usage with `track`, poll `should_spill`, and write data
/// out with `spill_batches` / `spill_partition_batch`. All spill files are created
/// inside a private temporary directory owned by this context; `tempfile::TempDir`
/// deletes that directory when the context is dropped.
#[derive(Debug)]
pub struct SpillContext {
    /// Usage threshold in bytes; `should_spill` returns `true` once it is reached.
    budget_bytes: usize,
    /// Bytes reported through `track` since the last reset.
    used_bytes: usize,
    /// Owns the spill directory. It is never read, but holding it keeps the directory
    /// alive for the lifetime of the context.
    _temp_dir: TempDir,
    /// Cached path of `_temp_dir`, used as the parent directory for spill files.
    temp_path: PathBuf,
    /// Counter appended to every spill file name so each spill gets a unique file.
    spill_count: usize,
}
impl SpillContext {
pub fn new() -> DbxResult<Self> {
Self::with_budget(128 * 1024 * 1024)
}
pub fn with_budget(budget_bytes: usize) -> DbxResult<Self> {
let temp_dir = tempfile::tempdir()
.map_err(|e| DbxError::Storage(format!("Spill: tempdir 생성 실패: {}", e)))?;
let temp_path = temp_dir.path().to_path_buf();
Ok(Self {
budget_bytes,
used_bytes: 0,
_temp_dir: temp_dir,
temp_path,
spill_count: 0,
})
}
#[inline]
pub fn track(&mut self, bytes: usize) {
self.used_bytes += bytes;
}
#[inline]
pub fn should_spill(&self) -> bool {
self.used_bytes >= self.budget_bytes
}
#[inline]
pub fn reset_tracking(&mut self) {
self.used_bytes = 0;
}
pub fn spill_batches(&mut self, batches: &[RecordBatch]) -> DbxResult<PathBuf> {
if batches.is_empty() {
return Err(DbxError::Storage("Spill: 빈 batch 목록".to_string()));
}
let file_path = self
.temp_path
.join(format!("spill_{}.ipc", self.spill_count));
self.spill_count += 1;
self.write_ipc_file(&file_path, batches)?;
self.reset_tracking();
Ok(file_path)
}
pub fn spill_partition_batch(
&mut self,
side: &str,
part_idx: usize,
batch: RecordBatch,
) -> DbxResult<PathBuf> {
let file_path = self
.temp_path
.join(format!("{}_{}_{}.ipc", side, part_idx, self.spill_count));
self.spill_count += 1;
self.write_ipc_file(&file_path, &[batch])?;
Ok(file_path)
}
fn write_ipc_file(&self, path: &PathBuf, batches: &[RecordBatch]) -> DbxResult<()> {
let file = File::create(path).map_err(|e| {
DbxError::Storage(format!("Spill: 파일 생성 실패 {}: {}", path.display(), e))
})?;
let writer_buf = BufWriter::new(file);
let schema = batches[0].schema();
let mut writer = FileWriter::try_new(writer_buf, &schema).map_err(|e| {
DbxError::Storage(format!("Spill: Arrow IPC writer 초기화 실패: {}", e))
})?;
for batch in batches {
writer
.write(batch)
.map_err(|e| DbxError::Storage(format!("Spill: batch 쓰기 실패: {}", e)))?;
}
writer
.finish()
.map_err(|e| DbxError::Storage(format!("Spill: IPC 파일 완료 실패: {}", e)))?;
Ok(())
}
pub fn reload_batches(path: &PathBuf) -> DbxResult<Vec<RecordBatch>> {
let file = File::open(path).map_err(|e| {
DbxError::Storage(format!(
"Spill: reload 파일 열기 실패 {}: {}",
path.display(),
e
))
})?;
let reader = FileReader::try_new(file, None).map_err(|e| {
DbxError::Storage(format!("Spill: Arrow IPC reader 초기화 실패: {}", e))
})?;
let mut batches = Vec::new();
for result in reader {
let batch =
result.map_err(|e| DbxError::Storage(format!("Spill: batch 읽기 실패: {}", e)))?;
batches.push(batch);
}
Ok(batches)
}
pub fn estimate_batch_bytes(batch: &RecordBatch) -> usize {
batch
.columns()
.iter()
.map(|col| col.get_buffer_memory_size())
.sum()
}
#[cfg(test)]
pub fn used_bytes(&self) -> usize {
self.used_bytes
}
#[cfg(test)]
pub fn spill_count(&self) -> usize {
self.spill_count
}
#[cfg(test)]
pub fn budget_bytes(&self) -> usize {
self.budget_bytes
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a two-column batch (`id: Int32`, `val: Float64`) with `n` rows.
    fn make_batch(n: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Float64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from((0..n as i32).collect::<Vec<_>>())),
                Arc::new(Float64Array::from(
                    (0..n).map(|i| i as f64).collect::<Vec<_>>(),
                )),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_spill_context_creation() {
        let ctx = SpillContext::new().unwrap();
        assert_eq!(ctx.budget_bytes(), 128 * 1024 * 1024);
        assert!(!ctx.should_spill());
    }

    #[test]
    fn test_memory_tracking() {
        let mut ctx = SpillContext::with_budget(1000).unwrap();
        ctx.track(500);
        assert!(!ctx.should_spill());
        ctx.track(600);
        assert!(ctx.should_spill());
    }

    #[test]
    fn test_reset_tracking() {
        let mut ctx = SpillContext::with_budget(1000).unwrap();
        // Reaching the budget exactly already triggers a spill.
        ctx.track(1000);
        assert!(ctx.should_spill());
        ctx.reset_tracking();
        assert_eq!(ctx.used_bytes(), 0);
        assert!(!ctx.should_spill());
    }

    #[test]
    fn test_spill_and_reload_round_trip() {
        let mut ctx = SpillContext::new().unwrap();
        let batch1 = make_batch(100);
        let batch2 = make_batch(50);
        let batches = vec![batch1, batch2];
        let path = ctx.spill_batches(&batches).unwrap();
        assert!(path.exists());
        assert_eq!(ctx.spill_count(), 1);
        assert_eq!(ctx.used_bytes(), 0);
        let reloaded = SpillContext::reload_batches(&path).unwrap();
        let total_rows: usize = reloaded.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 150);
        // Content, not just row counts, must survive the IPC round trip.
        assert_eq!(reloaded, batches);
    }

    #[test]
    fn test_spill_empty_batches_is_rejected() {
        let mut ctx = SpillContext::with_budget(1000).unwrap();
        ctx.track(2000);
        assert!(matches!(ctx.spill_batches(&[]), Err(DbxError::Storage(_))));
        // A rejected spill consumes no file index and leaves tracking untouched.
        assert_eq!(ctx.spill_count(), 0);
        assert!(ctx.should_spill());
    }

    #[test]
    fn test_spill_partition_batch_round_trip() {
        let mut ctx = SpillContext::new().unwrap();
        ctx.track(4096);
        let batch = make_batch(20);
        let path = ctx.spill_partition_batch("build", 3, batch.clone()).unwrap();
        assert!(path.exists());
        let name = path.file_name().and_then(|n| n.to_str()).unwrap();
        assert_eq!(name, "build_3_0.ipc");
        assert_eq!(ctx.spill_count(), 1);
        // Partition spills do not reset memory tracking.
        assert_eq!(ctx.used_bytes(), 4096);
        let reloaded = SpillContext::reload_batches(&path).unwrap();
        assert_eq!(reloaded, vec![batch]);
    }

    #[test]
    fn test_estimate_batch_bytes() {
        let batch = make_batch(1000);
        let size = SpillContext::estimate_batch_bytes(&batch);
        assert!(size > 0, "배치 크기가 0보다 커야 함");
    }

    #[test]
    fn test_multiple_spills() {
        let mut ctx = SpillContext::new().unwrap();
        let batch = make_batch(10);
        let path1 = ctx.spill_batches(&[batch.clone()]).unwrap();
        let path2 = ctx.spill_batches(&[batch]).unwrap();
        assert_ne!(path1, path2, "각 Spill은 다른 파일이어야 함");
        assert_eq!(ctx.spill_count(), 2);
    }
}