// dbx_core/sql/executor/spill.rs
use crate::error::{DbxError, DbxResult};

use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use tempfile::TempDir;
19
/// Tracks an in-memory byte budget for query execution and owns the
/// temp directory into which batches are spilled when the budget is hit.
pub struct SpillContext {
    /// Soft memory limit in bytes; reaching it makes `should_spill` true.
    budget_bytes: usize,
    /// Bytes currently tracked as held in memory (see `track`/`reset_tracking`).
    used_bytes: usize,
    /// Owns the temp directory; kept alive so all spill files are deleted on drop.
    _temp_dir: TempDir,
    /// Cached path of the temp directory, used to build spill file paths.
    temp_path: PathBuf,
    /// Monotonic counter of spill files written; makes file names unique.
    spill_count: usize,
}
36
37impl SpillContext {
38 pub fn new() -> DbxResult<Self> {
40 Self::with_budget(128 * 1024 * 1024)
41 }
42
43 pub fn with_budget(budget_bytes: usize) -> DbxResult<Self> {
45 let temp_dir = tempfile::tempdir()
46 .map_err(|e| DbxError::Storage(format!("Spill: tempdir 생성 실패: {}", e)))?;
47 let temp_path = temp_dir.path().to_path_buf();
48 Ok(Self {
49 budget_bytes,
50 used_bytes: 0,
51 _temp_dir: temp_dir,
52 temp_path,
53 spill_count: 0,
54 })
55 }
56
57 #[inline]
60 pub fn track(&mut self, bytes: usize) {
61 self.used_bytes += bytes;
62 }
63
64 #[inline]
66 pub fn should_spill(&self) -> bool {
67 self.used_bytes >= self.budget_bytes
68 }
69
70 #[inline]
72 pub fn reset_tracking(&mut self) {
73 self.used_bytes = 0;
74 }
75
76 pub fn spill_batches(&mut self, batches: &[RecordBatch]) -> DbxResult<PathBuf> {
81 if batches.is_empty() {
82 return Err(DbxError::Storage("Spill: 빈 batch 목록".to_string()));
83 }
84
85 let file_path = self
86 .temp_path
87 .join(format!("spill_{}.ipc", self.spill_count));
88 self.spill_count += 1;
89 self.write_ipc_file(&file_path, batches)?;
90
91 self.reset_tracking();
92 Ok(file_path)
93 }
94
95 pub fn spill_partition_batch(
98 &mut self,
99 side: &str,
100 part_idx: usize,
101 batch: RecordBatch,
102 ) -> DbxResult<PathBuf> {
103 let file_path = self
104 .temp_path
105 .join(format!("{}_{}_{}.ipc", side, part_idx, self.spill_count));
106 self.spill_count += 1;
107 self.write_ipc_file(&file_path, &[batch])?;
108 Ok(file_path)
109 }
110
111 fn write_ipc_file(&self, path: &PathBuf, batches: &[RecordBatch]) -> DbxResult<()> {
113 let file = File::create(path).map_err(|e| {
114 DbxError::Storage(format!("Spill: 파일 생성 실패 {}: {}", path.display(), e))
115 })?;
116 let writer_buf = BufWriter::new(file);
117 let schema = batches[0].schema();
118
119 let mut writer = FileWriter::try_new(writer_buf, &schema).map_err(|e| {
120 DbxError::Storage(format!("Spill: Arrow IPC writer 초기화 실패: {}", e))
121 })?;
122
123 for batch in batches {
124 writer
125 .write(batch)
126 .map_err(|e| DbxError::Storage(format!("Spill: batch 쓰기 실패: {}", e)))?;
127 }
128
129 writer
130 .finish()
131 .map_err(|e| DbxError::Storage(format!("Spill: IPC 파일 완료 실패: {}", e)))?;
132
133 Ok(())
134 }
135
136 pub fn reload_batches(path: &PathBuf) -> DbxResult<Vec<RecordBatch>> {
140 let file = File::open(path).map_err(|e| {
141 DbxError::Storage(format!(
142 "Spill: reload 파일 열기 실패 {}: {}",
143 path.display(),
144 e
145 ))
146 })?;
147
148 let reader = FileReader::try_new(file, None).map_err(|e| {
149 DbxError::Storage(format!("Spill: Arrow IPC reader 초기화 실패: {}", e))
150 })?;
151
152 let mut batches = Vec::new();
153 for result in reader {
154 let batch =
155 result.map_err(|e| DbxError::Storage(format!("Spill: batch 읽기 실패: {}", e)))?;
156 batches.push(batch);
157 }
158
159 Ok(batches)
160 }
161
162 pub fn estimate_batch_bytes(batch: &RecordBatch) -> usize {
164 batch
165 .columns()
166 .iter()
167 .map(|col| col.get_buffer_memory_size())
168 .sum()
169 }
170
171 #[cfg(test)]
173 pub fn used_bytes(&self) -> usize {
174 self.used_bytes
175 }
176
177 #[cfg(test)]
179 pub fn spill_count(&self) -> usize {
180 self.spill_count
181 }
182
183 #[cfg(test)]
185 pub fn budget_bytes(&self) -> usize {
186 self.budget_bytes
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds an `n`-row batch with an Int32 "id" column and a Float64 "val" column.
    fn make_batch(n: usize) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Float64, false),
        ];
        let ids: Int32Array = (0..n as i32).collect::<Vec<_>>().into();
        let vals: Float64Array = (0..n).map(|i| i as f64).collect::<Vec<_>>().into();
        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(vals)],
        )
        .unwrap()
    }

    #[test]
    fn test_spill_context_creation() {
        let ctx = SpillContext::new().unwrap();
        // Default budget is 128 MiB and nothing is tracked yet.
        assert_eq!(ctx.budget_bytes(), 128 * 1024 * 1024);
        assert!(!ctx.should_spill());
    }

    #[test]
    fn test_memory_tracking() {
        let mut ctx = SpillContext::with_budget(1000).unwrap();
        // 500 < 1000: still under budget.
        ctx.track(500);
        assert!(!ctx.should_spill());
        // 500 + 600 = 1100 >= 1000: budget reached.
        ctx.track(600);
        assert!(ctx.should_spill());
    }

    #[test]
    fn test_spill_and_reload_round_trip() {
        let mut ctx = SpillContext::new().unwrap();
        let batches = vec![make_batch(100), make_batch(50)];

        let path = ctx.spill_batches(&batches).unwrap();
        assert!(path.exists());
        assert_eq!(ctx.spill_count(), 1);
        // A successful spill resets the tracked usage.
        assert_eq!(ctx.used_bytes(), 0);

        let reloaded = SpillContext::reload_batches(&path).unwrap();
        let total_rows: usize = reloaded.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 150);
    }

    #[test]
    fn test_estimate_batch_bytes() {
        let size = SpillContext::estimate_batch_bytes(&make_batch(1000));
        assert!(size > 0, "배치 크기가 0보다 커야 함");
    }

    #[test]
    fn test_multiple_spills() {
        let mut ctx = SpillContext::new().unwrap();
        let batch = make_batch(10);

        let first = ctx.spill_batches(&[batch.clone()]).unwrap();
        let second = ctx.spill_batches(&[batch]).unwrap();

        assert_ne!(first, second, "각 Spill은 다른 파일이어야 함");
        assert_eq!(ctx.spill_count(), 2);
    }
}