// dbx_core/wal/partitioned_wal.rs

use crate::error::{DbxError, DbxResult};
use crate::wal::WalRecord;
use dashmap::DashMap;
use rayon::prelude::*;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
11
/// Write-ahead-log writer that buffers records per table ("partition") and
/// flushes each partition to its own `<table>.wal` file once it reaches
/// `flush_threshold` records.
pub struct PartitionedWalWriter {
    /// In-memory buffers of unflushed records, keyed by table name.
    partitions: DashMap<String, Vec<WalRecord>>,
    /// Directory that holds the per-table `.wal` segment files.
    wal_dir: PathBuf,
    /// Monotonically increasing sequence number assigned to each appended record.
    sequence: AtomicU64,
    /// Number of buffered records per partition that triggers an automatic flush.
    flush_threshold: usize,
}
26
27impl PartitionedWalWriter {
28 pub fn new(wal_dir: PathBuf, flush_threshold: usize) -> DbxResult<Self> {
30 if !wal_dir.exists() {
31 std::fs::create_dir_all(&wal_dir).map_err(|source| DbxError::Io { source })?;
32 }
33 Ok(Self {
34 partitions: DashMap::new(),
35 wal_dir,
36 sequence: AtomicU64::new(0),
37 flush_threshold,
38 })
39 }
40
41 pub fn with_defaults(wal_dir: PathBuf) -> DbxResult<Self> {
43 Self::new(wal_dir, 100)
44 }
45
46 pub fn append(&self, record: WalRecord) -> DbxResult<u64> {
48 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
49 let table = Self::extract_table(&record);
50
51 let mut partition = self.partitions.entry(table).or_default();
52 partition.push(record);
53
54 if partition.len() >= self.flush_threshold {
56 let records = std::mem::take(&mut *partition);
57 drop(partition); self.flush_records(&Self::extract_table(&records[0]), &records)?;
59 }
60
61 Ok(seq)
62 }
63
64 pub fn append_batch(&self, records: Vec<WalRecord>) -> DbxResult<Vec<u64>> {
66 let sequences: Vec<u64> = records
67 .iter()
68 .map(|_| self.sequence.fetch_add(1, Ordering::SeqCst))
69 .collect();
70
71 let mut grouped: std::collections::HashMap<String, Vec<WalRecord>> =
73 std::collections::HashMap::new();
74 for record in records {
75 let table = Self::extract_table(&record);
76 grouped.entry(table).or_default().push(record);
77 }
78
79 let results: Vec<DbxResult<()>> = grouped
81 .into_par_iter()
82 .map(|(table, partition_records)| {
83 let mut partition = self.partitions.entry(table.clone()).or_default();
84 partition.extend(partition_records);
85
86 if partition.len() >= self.flush_threshold {
87 let records = std::mem::take(&mut *partition);
88 drop(partition);
89 self.flush_records(&table, &records)?;
90 }
91 Ok(())
92 })
93 .collect();
94
95 for result in results {
97 result?;
98 }
99
100 Ok(sequences)
101 }
102
103 pub fn flush_all(&self) -> DbxResult<usize> {
105 let tables: Vec<String> = self.partitions.iter().map(|e| e.key().clone()).collect();
106
107 let flushed: Vec<DbxResult<usize>> = tables
108 .par_iter()
109 .map(|table| {
110 if let Some(mut partition) = self.partitions.get_mut(table) {
111 if partition.is_empty() {
112 return Ok(0);
113 }
114 let records = std::mem::take(&mut *partition);
115 let count = records.len();
116 drop(partition);
117 self.flush_records(table, &records)?;
118 Ok(count)
119 } else {
120 Ok(0)
121 }
122 })
123 .collect();
124
125 let mut total = 0;
126 for result in flushed {
127 total += result?;
128 }
129 Ok(total)
130 }
131
132 pub fn partition_count(&self) -> usize {
134 self.partitions.len()
135 }
136
137 pub fn buffered_count(&self) -> usize {
139 self.partitions.iter().map(|e| e.value().len()).sum()
140 }
141
142 pub fn current_sequence(&self) -> u64 {
144 self.sequence.load(Ordering::SeqCst)
145 }
146
147 fn extract_table(record: &WalRecord) -> String {
151 match record {
152 WalRecord::Insert { table, .. } => table.clone(),
153 WalRecord::Delete { table, .. } => table.clone(),
154 WalRecord::Batch { table, .. } => table.clone(),
155 WalRecord::Checkpoint { .. } => "__checkpoint__".to_string(),
156 WalRecord::Commit { .. } => "__tx__".to_string(),
157 WalRecord::Rollback { .. } => "__tx__".to_string(),
158 }
159 }
160
161 fn flush_records(&self, table: &str, records: &[WalRecord]) -> DbxResult<()> {
163 let safe_name = table.replace(['/', '\\', ':'], "_");
164 let path = self.wal_dir.join(format!("{safe_name}.wal"));
165
166 let serialized: Vec<u8> = records
167 .iter()
168 .flat_map(|r| {
169 let mut buf = bincode::serialize(r).unwrap_or_default();
170 let len = buf.len() as u32;
171 let mut frame = len.to_le_bytes().to_vec();
172 frame.append(&mut buf);
173 frame
174 })
175 .collect();
176
177 use std::io::Write;
178 let mut file = std::fs::OpenOptions::new()
179 .create(true)
180 .append(true)
181 .open(&path)
182 .map_err(|source| DbxError::Io { source })?;
183 file.write_all(&serialized)
184 .map_err(|source| DbxError::Io { source })?;
185 file.flush().map_err(|source| DbxError::Io { source })?;
186
187 Ok(())
188 }
189}
190
/// Checkpoints on-disk WAL segments by renaming `<table>.wal` files to
/// `<table>.checkpoint`, processing tables in parallel.
pub struct ParallelCheckpointManager {
    /// Directory containing the per-table `.wal` segment files.
    wal_dir: PathBuf,
}
197
198impl ParallelCheckpointManager {
199 pub fn new(wal_dir: PathBuf) -> Self {
200 Self { wal_dir }
201 }
202
203 pub fn checkpoint_tables(&self, tables: &[String]) -> DbxResult<usize> {
205 let results: Vec<DbxResult<()>> = tables
206 .par_iter()
207 .map(|table| {
208 let safe_name = table.replace(['/', '\\', ':'], "_");
210 let wal_path = self.wal_dir.join(format!("{safe_name}.wal"));
211 let checkpoint_path = self.wal_dir.join(format!("{safe_name}.checkpoint"));
212
213 if wal_path.exists() {
214 std::fs::rename(&wal_path, &checkpoint_path)
216 .map_err(|source| DbxError::Io { source })?;
217 }
218 Ok(())
219 })
220 .collect();
221
222 let mut count = 0;
223 for result in results {
224 result?;
225 count += 1;
226 }
227 Ok(count)
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Convenience builder for an `Insert` record with timestamp 0.
    fn insert_record(table: &str, key: &[u8], value: &[u8]) -> WalRecord {
        WalRecord::Insert {
            table: table.to_string(),
            key: key.to_vec(),
            value: value.to_vec(),
            ts: 0,
        }
    }

    #[test]
    fn test_partitioned_wal_basic() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();

        // Sequence numbers start at 0 and increase by one per append.
        assert_eq!(writer.append(insert_record("users", b"k1", b"v1")).unwrap(), 0);
        assert_eq!(writer.append(insert_record("orders", b"k2", b"v2")).unwrap(), 1);

        // Two distinct tables → two partitions, both records still buffered.
        assert_eq!(writer.partition_count(), 2);
        assert_eq!(writer.buffered_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_batch() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();

        let batch = vec![
            insert_record("users", b"k1", b"v1"),
            insert_record("users", b"k2", b"v2"),
            insert_record("orders", b"k3", b"v3"),
        ];

        // Three records across two tables: three sequences, two partitions.
        assert_eq!(writer.append_batch(batch).unwrap().len(), 3);
        assert_eq!(writer.partition_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_flush() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();

        // Ten records stay below the threshold, so nothing auto-flushes.
        (0..10).for_each(|i| {
            writer
                .append(insert_record("users", format!("k{i}").as_bytes(), b"v"))
                .unwrap();
        });

        // Explicit flush drains every buffer and reports the record count.
        assert_eq!(writer.flush_all().unwrap(), 10);
        assert_eq!(writer.buffered_count(), 0);
        assert!(tmp.path().join("users.wal").exists());
    }

    #[test]
    fn test_partitioned_wal_auto_flush() {
        let tmp = tempdir().unwrap();
        // Threshold of 5: the fifth append should trigger an automatic flush.
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 5).unwrap();

        for i in 0..5 {
            writer
                .append(insert_record("users", format!("k{i}").as_bytes(), b"v"))
                .unwrap();
        }

        assert_eq!(writer.buffered_count(), 0);
        assert!(tmp.path().join("users.wal").exists());
    }

    #[test]
    fn test_parallel_checkpoint() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();

        writer.append(insert_record("users", b"k1", b"v1")).unwrap();
        writer.append(insert_record("orders", b"k2", b"v2")).unwrap();
        writer.flush_all().unwrap();

        // Both segments exist on disk, so both should be checkpointed.
        let manager = ParallelCheckpointManager::new(tmp.path().to_path_buf());
        let tables = ["users".to_string(), "orders".to_string()];
        assert_eq!(manager.checkpoint_tables(&tables).unwrap(), 2);

        assert!(tmp.path().join("users.checkpoint").exists());
        assert!(tmp.path().join("orders.checkpoint").exists());
    }
}