Skip to main content

dbx_core/sql/executor/
spill.rs

1//! Spill-to-Disk 메모리 방어 컨텍스트
2//!
3//! `SpillContext`는 쿼리 단위로 생성되는 메모리 예산 추적기입니다.
4//! 메모리 임계치를 초과할 경우 Arrow IPC 포맷으로 임시 파일에 Spill 합니다.
5//!
6//! # 설계 원칙
7//! - **쿼리 단위 격리**: 글로벌 상태 없음. Atomic 불필요 (`next()`는 `&mut self`)
8//! - **Arrow IPC 포맷**: Parquet 대비 낮은 쓰기 오버헤드 (메타데이터/통계 없음)
9//! - **자동 정리**: `SpillContext` drop 시 TempDir 자동 삭제
10
11use crate::error::{DbxError, DbxResult};
12use arrow::ipc::reader::FileReader;
13use arrow::ipc::writer::FileWriter;
14use arrow::record_batch::RecordBatch;
15use std::fs::File;
16use std::io::BufWriter;
17use std::path::PathBuf;
18use tempfile::TempDir;
19
20/// 쿼리 단위 Spill 컨텍스트.
21///
22/// `HashAggregateOperator` 등 메모리 집약 연산자에 주입되어,
23/// 메모리 사용량이 `budget_bytes`를 초과하면 RecordBatch를 임시 디스크로 내보냅니다.
24pub struct SpillContext {
25    /// 이 쿼리에 허용된 최대 메모리 바이트 (기본: 128MB)
26    budget_bytes: usize,
27    /// 현재 추적 중인 메모리 사용량 (바이트)
28    used_bytes: usize,
29    /// Spill 파일이 기록되는 임시 디렉토리 (Drop 시 자동 삭제)
30    _temp_dir: TempDir,
31    /// 임시 디렉토리 경로 (파일 생성 시 사용)
32    temp_path: PathBuf,
33    /// 생성된 Spill 파일 카운터
34    spill_count: usize,
35}
36
37impl SpillContext {
38    /// 기본 메모리 예산(128MB)으로 SpillContext 생성.
39    pub fn new() -> DbxResult<Self> {
40        Self::with_budget(128 * 1024 * 1024)
41    }
42
43    /// 지정된 메모리 예산(바이트)으로 SpillContext 생성.
44    pub fn with_budget(budget_bytes: usize) -> DbxResult<Self> {
45        let temp_dir = tempfile::tempdir()
46            .map_err(|e| DbxError::Storage(format!("Spill: tempdir 생성 실패: {}", e)))?;
47        let temp_path = temp_dir.path().to_path_buf();
48        Ok(Self {
49            budget_bytes,
50            used_bytes: 0,
51            _temp_dir: temp_dir,
52            temp_path,
53            spill_count: 0,
54        })
55    }
56
57    /// 바이트 단위 메모리 사용량을 등록합니다.
58    /// RecordBatch의 크기는 `estimate_batch_bytes`로 추정합니다.
59    #[inline]
60    pub fn track(&mut self, bytes: usize) {
61        self.used_bytes += bytes;
62    }
63
64    /// 현재 메모리 사용량이 예산을 초과했는지 확인합니다.
65    #[inline]
66    pub fn should_spill(&self) -> bool {
67        self.used_bytes >= self.budget_bytes
68    }
69
70    /// 메모리 사용량 추적을 초기화합니다 (Spill 후 재시작).
71    #[inline]
72    pub fn reset_tracking(&mut self) {
73        self.used_bytes = 0;
74    }
75
76    /// RecordBatch 목록을 Arrow IPC 포맷으로 임시 파일에 씁니다.
77    ///
78    /// # 반환
79    /// Spill된 파일의 경로를 반환합니다.
80    pub fn spill_batches(&mut self, batches: &[RecordBatch]) -> DbxResult<PathBuf> {
81        if batches.is_empty() {
82            return Err(DbxError::Storage("Spill: 빈 batch 목록".to_string()));
83        }
84
85        let file_path = self
86            .temp_path
87            .join(format!("spill_{}.ipc", self.spill_count));
88        self.spill_count += 1;
89        self.write_ipc_file(&file_path, batches)?;
90
91        self.reset_tracking();
92        Ok(file_path)
93    }
94
95    /// 특정 파티션에 속한 RecordBatch를 Spill 합니다 (Grace Hash Join 용).
96    /// 파일명 규칙: `{side}_{part_idx}_{count}.ipc`
97    pub fn spill_partition_batch(
98        &mut self,
99        side: &str,
100        part_idx: usize,
101        batch: RecordBatch,
102    ) -> DbxResult<PathBuf> {
103        let file_path = self
104            .temp_path
105            .join(format!("{}_{}_{}.ipc", side, part_idx, self.spill_count));
106        self.spill_count += 1;
107        self.write_ipc_file(&file_path, &[batch])?;
108        Ok(file_path)
109    }
110
111    /// 내부 유틸리티: RecordBatch 목록을 IPC 파일로 기록.
112    fn write_ipc_file(&self, path: &PathBuf, batches: &[RecordBatch]) -> DbxResult<()> {
113        let file = File::create(path).map_err(|e| {
114            DbxError::Storage(format!("Spill: 파일 생성 실패 {}: {}", path.display(), e))
115        })?;
116        let writer_buf = BufWriter::new(file);
117        let schema = batches[0].schema();
118
119        let mut writer = FileWriter::try_new(writer_buf, &schema).map_err(|e| {
120            DbxError::Storage(format!("Spill: Arrow IPC writer 초기화 실패: {}", e))
121        })?;
122
123        for batch in batches {
124            writer
125                .write(batch)
126                .map_err(|e| DbxError::Storage(format!("Spill: batch 쓰기 실패: {}", e)))?;
127        }
128
129        writer
130            .finish()
131            .map_err(|e| DbxError::Storage(format!("Spill: IPC 파일 완료 실패: {}", e)))?;
132
133        Ok(())
134    }
135
136    /// Spill 파일에서 RecordBatch 목록을 읽어옵니다.
137    ///
138    /// 읽기 완료 후 파일은 삭제되지 않습니다 (TempDir drop 시 일괄 삭제).
139    pub fn reload_batches(path: &PathBuf) -> DbxResult<Vec<RecordBatch>> {
140        let file = File::open(path).map_err(|e| {
141            DbxError::Storage(format!(
142                "Spill: reload 파일 열기 실패 {}: {}",
143                path.display(),
144                e
145            ))
146        })?;
147
148        let reader = FileReader::try_new(file, None).map_err(|e| {
149            DbxError::Storage(format!("Spill: Arrow IPC reader 초기화 실패: {}", e))
150        })?;
151
152        let mut batches = Vec::new();
153        for result in reader {
154            let batch =
155                result.map_err(|e| DbxError::Storage(format!("Spill: batch 읽기 실패: {}", e)))?;
156            batches.push(batch);
157        }
158
159        Ok(batches)
160    }
161
162    /// RecordBatch의 메모리 사용량을 추정합니다 (컬럼 버퍼 합산).
163    pub fn estimate_batch_bytes(batch: &RecordBatch) -> usize {
164        batch
165            .columns()
166            .iter()
167            .map(|col| col.get_buffer_memory_size())
168            .sum()
169    }
170
171    /// 현재 메모리 사용량 반환 (바이트).
172    #[cfg(test)]
173    pub fn used_bytes(&self) -> usize {
174        self.used_bytes
175    }
176
177    /// 생성된 Spill 파일 수 반환.
178    #[cfg(test)]
179    pub fn spill_count(&self) -> usize {
180        self.spill_count
181    }
182
183    /// 메모리 예산 반환 (바이트).
184    #[cfg(test)]
185    pub fn budget_bytes(&self) -> usize {
186        self.budget_bytes
187    }
188}
189
190#[cfg(test)]
191mod tests {
192    use super::*;
193    use arrow::array::{Float64Array, Int32Array};
194    use arrow::datatypes::{DataType, Field, Schema};
195    use std::sync::Arc;
196
197    fn make_batch(n: usize) -> RecordBatch {
198        let schema = Arc::new(Schema::new(vec![
199            Field::new("id", DataType::Int32, false),
200            Field::new("val", DataType::Float64, false),
201        ]));
202        RecordBatch::try_new(
203            schema,
204            vec![
205                Arc::new(Int32Array::from((0..n as i32).collect::<Vec<_>>())),
206                Arc::new(Float64Array::from(
207                    (0..n).map(|i| i as f64).collect::<Vec<_>>(),
208                )),
209            ],
210        )
211        .unwrap()
212    }
213
214    #[test]
215    fn test_spill_context_creation() {
216        let ctx = SpillContext::new().unwrap();
217        assert_eq!(ctx.budget_bytes(), 128 * 1024 * 1024);
218        assert!(!ctx.should_spill());
219    }
220
221    #[test]
222    fn test_memory_tracking() {
223        let mut ctx = SpillContext::with_budget(1000).unwrap();
224        ctx.track(500);
225        assert!(!ctx.should_spill());
226        ctx.track(600);
227        assert!(ctx.should_spill());
228    }
229
230    #[test]
231    fn test_spill_and_reload_round_trip() {
232        let mut ctx = SpillContext::new().unwrap();
233        let batch1 = make_batch(100);
234        let batch2 = make_batch(50);
235        let batches = vec![batch1, batch2];
236
237        let path = ctx.spill_batches(&batches).unwrap();
238        assert!(path.exists());
239        assert_eq!(ctx.spill_count(), 1);
240        assert_eq!(ctx.used_bytes(), 0); // reset 됐어야 함
241
242        let reloaded = SpillContext::reload_batches(&path).unwrap();
243        let total_rows: usize = reloaded.iter().map(|b| b.num_rows()).sum();
244        assert_eq!(total_rows, 150);
245    }
246
247    #[test]
248    fn test_estimate_batch_bytes() {
249        let batch = make_batch(1000);
250        let size = SpillContext::estimate_batch_bytes(&batch);
251        assert!(size > 0, "배치 크기가 0보다 커야 함");
252    }
253
254    #[test]
255    fn test_multiple_spills() {
256        let mut ctx = SpillContext::new().unwrap();
257        let batch = make_batch(10);
258
259        let path1 = ctx.spill_batches(&[batch.clone()]).unwrap();
260        let path2 = ctx.spill_batches(&[batch]).unwrap();
261
262        assert_ne!(path1, path2, "각 Spill은 다른 파일이어야 함");
263        assert_eq!(ctx.spill_count(), 2);
264    }
265}