Skip to main content

dbx_core/wal/
partitioned_wal.rs

1//! Partitioned WAL Writer — Phase 2: Section 5.1
2//!
3//! 테이블별 독립 WAL 세그먼트로 병렬 쓰기를 가능하게 함
4
5use crate::error::{DbxError, DbxResult};
6use crate::wal::WalRecord;
7use dashmap::DashMap;
8use rayon::prelude::*;
9use std::path::PathBuf;
10use std::sync::atomic::{AtomicU64, Ordering};
11
12/// 파티션 기반 WAL 쓰기 엔진
13///
14/// 각 테이블이 독립적인 WAL 세그먼트를 가짐으로써
15/// 서로 다른 테이블의 쓰기가 동시에 진행될 수 있습니다.
16pub struct PartitionedWalWriter {
17    /// 파티션별 WAL 버퍼: table_name → records
18    partitions: DashMap<String, Vec<WalRecord>>,
19    /// WAL 디렉토리
20    wal_dir: PathBuf,
21    /// 글로벌 시퀀스 번호 (원자적 증가)
22    sequence: AtomicU64,
23    /// 버퍼 플러시 임계값 (레코드 수)
24    flush_threshold: usize,
25}
26
27impl PartitionedWalWriter {
28    /// 새 파티션 WAL 생성
29    pub fn new(wal_dir: PathBuf, flush_threshold: usize) -> DbxResult<Self> {
30        if !wal_dir.exists() {
31            std::fs::create_dir_all(&wal_dir).map_err(|source| DbxError::Io { source })?;
32        }
33        Ok(Self {
34            partitions: DashMap::new(),
35            wal_dir,
36            sequence: AtomicU64::new(0),
37            flush_threshold,
38        })
39    }
40
41    /// 기본 설정 (플러시 임계값: 100)
42    pub fn with_defaults(wal_dir: PathBuf) -> DbxResult<Self> {
43        Self::new(wal_dir, 100)
44    }
45
46    /// WAL 레코드 추가 (테이블별 파티션에 버퍼링)
47    pub fn append(&self, record: WalRecord) -> DbxResult<u64> {
48        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
49        let table = Self::extract_table(&record);
50
51        let mut partition = self.partitions.entry(table).or_default();
52        partition.push(record);
53
54        // 임계값 도달 시 자동 플러시
55        if partition.len() >= self.flush_threshold {
56            let records = std::mem::take(&mut *partition);
57            drop(partition); // DashMap lock 해제
58            self.flush_records(&Self::extract_table(&records[0]), &records)?;
59        }
60
61        Ok(seq)
62    }
63
64    /// 여러 레코드를 한번에 추가 (배치 쓰기)
65    pub fn append_batch(&self, records: Vec<WalRecord>) -> DbxResult<Vec<u64>> {
66        let sequences: Vec<u64> = records
67            .iter()
68            .map(|_| self.sequence.fetch_add(1, Ordering::SeqCst))
69            .collect();
70
71        // 테이블별로 그룹화
72        let mut grouped: std::collections::HashMap<String, Vec<WalRecord>> =
73            std::collections::HashMap::new();
74        for record in records {
75            let table = Self::extract_table(&record);
76            grouped.entry(table).or_default().push(record);
77        }
78
79        // 병렬로 각 파티션에 추가
80        let results: Vec<DbxResult<()>> = grouped
81            .into_par_iter()
82            .map(|(table, partition_records)| {
83                let mut partition = self.partitions.entry(table.clone()).or_default();
84                partition.extend(partition_records);
85
86                if partition.len() >= self.flush_threshold {
87                    let records = std::mem::take(&mut *partition);
88                    drop(partition);
89                    self.flush_records(&table, &records)?;
90                }
91                Ok(())
92            })
93            .collect();
94
95        // 에러 체크
96        for result in results {
97            result?;
98        }
99
100        Ok(sequences)
101    }
102
103    /// 모든 파티션 플러시 (병렬)
104    pub fn flush_all(&self) -> DbxResult<usize> {
105        let tables: Vec<String> = self.partitions.iter().map(|e| e.key().clone()).collect();
106
107        let flushed: Vec<DbxResult<usize>> = tables
108            .par_iter()
109            .map(|table| {
110                if let Some(mut partition) = self.partitions.get_mut(table) {
111                    if partition.is_empty() {
112                        return Ok(0);
113                    }
114                    let records = std::mem::take(&mut *partition);
115                    let count = records.len();
116                    drop(partition);
117                    self.flush_records(table, &records)?;
118                    Ok(count)
119                } else {
120                    Ok(0)
121                }
122            })
123            .collect();
124
125        let mut total = 0;
126        for result in flushed {
127            total += result?;
128        }
129        Ok(total)
130    }
131
132    /// 파티션 수 조회
133    pub fn partition_count(&self) -> usize {
134        self.partitions.len()
135    }
136
137    /// 버퍼에 있는 총 레코드 수
138    pub fn buffered_count(&self) -> usize {
139        self.partitions.iter().map(|e| e.value().len()).sum()
140    }
141
142    /// 현재 시퀀스 번호
143    pub fn current_sequence(&self) -> u64 {
144        self.sequence.load(Ordering::SeqCst)
145    }
146
147    // ─── Internal helpers ───────────────────────────────
148
149    /// 레코드에서 테이블 이름 추출
150    fn extract_table(record: &WalRecord) -> String {
151        match record {
152            WalRecord::Insert { table, .. } => table.clone(),
153            WalRecord::Delete { table, .. } => table.clone(),
154            WalRecord::Batch { table, .. } => table.clone(),
155            WalRecord::Checkpoint { .. } => "__checkpoint__".to_string(),
156            WalRecord::Commit { .. } => "__tx__".to_string(),
157            WalRecord::Rollback { .. } => "__tx__".to_string(),
158        }
159    }
160
161    /// 레코드를 디스크에 플러시
162    fn flush_records(&self, table: &str, records: &[WalRecord]) -> DbxResult<()> {
163        let safe_name = table.replace(['/', '\\', ':'], "_");
164        let path = self.wal_dir.join(format!("{safe_name}.wal"));
165
166        let serialized: Vec<u8> = records
167            .iter()
168            .flat_map(|r| {
169                let mut buf = bincode::serialize(r).unwrap_or_default();
170                let len = buf.len() as u32;
171                let mut frame = len.to_le_bytes().to_vec();
172                frame.append(&mut buf);
173                frame
174            })
175            .collect();
176
177        use std::io::Write;
178        let mut file = std::fs::OpenOptions::new()
179            .create(true)
180            .append(true)
181            .open(&path)
182            .map_err(|source| DbxError::Io { source })?;
183        file.write_all(&serialized)
184            .map_err(|source| DbxError::Io { source })?;
185        file.flush().map_err(|source| DbxError::Io { source })?;
186
187        Ok(())
188    }
189}
190
/// Parallel checkpoint manager.
///
/// Creates checkpoints for several tables at the same time.
pub struct ParallelCheckpointManager {
    /// Directory that holds the WAL and checkpoint files.
    wal_dir: PathBuf,
}
197
198impl ParallelCheckpointManager {
199    pub fn new(wal_dir: PathBuf) -> Self {
200        Self { wal_dir }
201    }
202
203    /// 여러 테이블을 병렬로 체크포인트
204    pub fn checkpoint_tables(&self, tables: &[String]) -> DbxResult<usize> {
205        let results: Vec<DbxResult<()>> = tables
206            .par_iter()
207            .map(|table| {
208                // 각 테이블의 WAL 파일을 체크포인트
209                let safe_name = table.replace(['/', '\\', ':'], "_");
210                let wal_path = self.wal_dir.join(format!("{safe_name}.wal"));
211                let checkpoint_path = self.wal_dir.join(format!("{safe_name}.checkpoint"));
212
213                if wal_path.exists() {
214                    // WAL 내용을 체크포인트로 이동
215                    std::fs::rename(&wal_path, &checkpoint_path)
216                        .map_err(|source| DbxError::Io { source })?;
217                }
218                Ok(())
219            })
220            .collect();
221
222        let mut count = 0;
223        for result in results {
224            result?;
225            count += 1;
226        }
227        Ok(count)
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an `Insert` record for `table` with timestamp 0.
    fn insert_record(table: &str, key: &[u8], value: &[u8]) -> WalRecord {
        WalRecord::Insert {
            table: table.to_owned(),
            key: key.to_owned(),
            value: value.to_owned(),
            ts: 0,
        }
    }

    /// Appends `count` records with keys `k0..k{count-1}` to the `users` table.
    fn append_users(wal: &PartitionedWalWriter, count: usize) {
        for i in 0..count {
            let key = format!("k{i}");
            wal.append(insert_record("users", key.as_bytes(), b"v"))
                .unwrap();
        }
    }

    #[test]
    fn test_partitioned_wal_basic() {
        let dir = tempdir().unwrap();
        let wal = PartitionedWalWriter::new(dir.path().to_path_buf(), 100).unwrap();

        let first = wal.append(insert_record("users", b"k1", b"v1")).unwrap();
        let second = wal.append(insert_record("orders", b"k2", b"v2")).unwrap();

        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(wal.partition_count(), 2);
        assert_eq!(wal.buffered_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_batch() {
        let dir = tempdir().unwrap();
        let wal = PartitionedWalWriter::new(dir.path().to_path_buf(), 100).unwrap();

        let batch = vec![
            insert_record("users", b"k1", b"v1"),
            insert_record("users", b"k2", b"v2"),
            insert_record("orders", b"k3", b"v3"),
        ];
        let seqs = wal.append_batch(batch).unwrap();

        assert_eq!(seqs.len(), 3);
        assert_eq!(wal.partition_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_flush() {
        let dir = tempdir().unwrap();
        let wal = PartitionedWalWriter::new(dir.path().to_path_buf(), 100).unwrap();

        append_users(&wal, 10);

        let flushed = wal.flush_all().unwrap();
        assert_eq!(flushed, 10);
        assert_eq!(wal.buffered_count(), 0);

        // The WAL file must have been created.
        assert!(dir.path().join("users.wal").exists());
    }

    #[test]
    fn test_partitioned_wal_auto_flush() {
        let dir = tempdir().unwrap();
        let wal = PartitionedWalWriter::new(dir.path().to_path_buf(), 5).unwrap();

        // The fifth append reaches the threshold and triggers an automatic flush.
        append_users(&wal, 5);

        // After the automatic flush the buffer must be empty.
        assert_eq!(wal.buffered_count(), 0);
        assert!(dir.path().join("users.wal").exists());
    }

    #[test]
    fn test_parallel_checkpoint() {
        let dir = tempdir().unwrap();
        let wal = PartitionedWalWriter::new(dir.path().to_path_buf(), 100).unwrap();

        // Add data, then flush it to disk.
        wal.append(insert_record("users", b"k1", b"v1")).unwrap();
        wal.append(insert_record("orders", b"k2", b"v2")).unwrap();
        wal.flush_all().unwrap();

        // Checkpoint both tables.
        let manager = ParallelCheckpointManager::new(dir.path().to_path_buf());
        let tables = ["users".to_string(), "orders".to_string()];
        let count = manager.checkpoint_tables(&tables).unwrap();
        assert_eq!(count, 2);

        // Both checkpoint files must exist.
        for stem in ["users", "orders"] {
            assert!(dir.path().join(format!("{stem}.checkpoint")).exists());
        }
    }
}
331}