1use crate::storage::engine::crc32::{crc32, crc32_update};
2use std::io::{self, Read};
3
/// Four-byte magic tag of the WAL format (ASCII `RDBW`).
///
/// NOTE(review): not referenced in this module; presumably written/checked
/// as a file header elsewhere — confirm against the WAL file reader/writer.
pub const WAL_MAGIC: &[u8; 4] = &[b'R', b'D', b'B', b'W'];

/// Version number of the WAL on-disk format.
pub const WAL_VERSION: u8 = 2;

/// Page payloads of at least this many bytes are candidates for zstd
/// compression in `WalRecord::encode`; shorter payloads are always stored raw.
const COMPRESS_THRESHOLD: usize = 256;
13
/// Compression algorithm tag stored inside a `PageWriteCompressed` record.
///
/// The discriminant is the byte written to disk, so existing values must not
/// be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Compression {
    /// Payload bytes are stored unchanged.
    None = 0,
    /// Payload bytes were compressed with zstd.
    Zstd = 1,
}

impl Compression {
    /// Maps an on-disk tag byte back to a [`Compression`].
    ///
    /// Returns `None` when the byte does not name a known algorithm.
    fn from_u8(v: u8) -> Option<Self> {
        let algorithm = match v {
            0 => Self::None,
            1 => Self::Zstd,
            _ => return None,
        };
        Some(algorithm)
    }
}
31
/// Tag byte that opens every encoded WAL record and selects its body layout.
///
/// The discriminant is the byte written to disk; `0` is not a valid tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// Body: `tx_id: u64`.
    Begin = 1,
    /// Body: `tx_id: u64`.
    Commit = 2,
    /// Body: `tx_id: u64`.
    Rollback = 3,
    /// Body: `tx_id: u64, page_id: u32, len: u32, data[len]`.
    PageWrite = 4,
    /// Body: `lsn: u64`.
    Checkpoint = 5,
    /// Body: `tx_id: u64, page_id: u32, compression: u8, orig_len: u32,
    /// len: u32, compressed[len]`; decodes to `WalRecord::PageWrite`.
    PageWriteCompressed = 6,
    /// Body: `tx_id: u64, count: u32`, then `count` × `(len: u32, action[len])`.
    TxCommitBatch = 7,
}

impl RecordType {
    /// Decodes a tag byte into a [`RecordType`].
    ///
    /// Returns `None` for any byte outside `1..=7` (including `0`).
    pub fn from_u8(v: u8) -> Option<Self> {
        Some(match v {
            1 => Self::Begin,
            2 => Self::Commit,
            3 => Self::Rollback,
            4 => Self::PageWrite,
            5 => Self::Checkpoint,
            6 => Self::PageWriteCompressed,
            7 => Self::TxCommitBatch,
            _ => return None,
        })
    }
}
75
/// One logical write-ahead-log record.
///
/// `WalRecord::encode` and `WalRecord::read` convert between this type and
/// the on-disk byte format.
#[derive(Debug, Clone, PartialEq)]
pub enum WalRecord {
    /// Start marker for transaction `tx_id`.
    Begin {
        tx_id: u64,
    },
    /// Commit marker for transaction `tx_id`.
    Commit {
        tx_id: u64,
    },
    /// Rollback marker for transaction `tx_id`.
    Rollback {
        tx_id: u64,
    },
    /// Bytes `data` for page `page_id`, logged by transaction `tx_id`.
    ///
    /// On disk this may be stored zstd-compressed; in memory `data` always
    /// holds the uncompressed bytes.
    PageWrite {
        tx_id: u64,
        page_id: u32,
        data: Vec<u8>,
    },
    /// Transaction `tx_id` together with its opaque, length-prefixed action
    /// payloads.
    TxCommitBatch {
        tx_id: u64,
        actions: Vec<Vec<u8>>,
    },
    /// Checkpoint marker carrying `lsn` (presumably a log sequence number —
    /// confirm with the checkpoint writer).
    Checkpoint {
        lsn: u64,
    },
}
98
99impl WalRecord {
100 pub fn encode(&self) -> Vec<u8> {
106 let mut buf = Vec::new();
107
108 match self {
123 WalRecord::Begin { tx_id } => {
124 buf.push(RecordType::Begin as u8);
125 buf.extend_from_slice(&tx_id.to_le_bytes());
126 }
127 WalRecord::Commit { tx_id } => {
128 buf.push(RecordType::Commit as u8);
129 buf.extend_from_slice(&tx_id.to_le_bytes());
130 }
131 WalRecord::Rollback { tx_id } => {
132 buf.push(RecordType::Rollback as u8);
133 buf.extend_from_slice(&tx_id.to_le_bytes());
134 }
135 WalRecord::PageWrite {
136 tx_id,
137 page_id,
138 data,
139 } => {
140 if data.len() >= COMPRESS_THRESHOLD {
141 if let Ok(compressed) =
143 zstd::bulk::compress(data.as_slice(), 3)
144 {
145 if compressed.len() < data.len() {
146 buf.push(RecordType::PageWriteCompressed as u8);
148 buf.extend_from_slice(&tx_id.to_le_bytes());
149 buf.extend_from_slice(&page_id.to_le_bytes());
150 buf.push(Compression::Zstd as u8);
151 buf.extend_from_slice(&(data.len() as u32).to_le_bytes()); buf.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
153 buf.extend_from_slice(&compressed);
154 let checksum = crc32(&buf);
155 buf.extend_from_slice(&checksum.to_le_bytes());
156 return buf;
157 }
158 }
159 }
160 buf.push(RecordType::PageWrite as u8);
162 buf.extend_from_slice(&tx_id.to_le_bytes());
163 buf.extend_from_slice(&page_id.to_le_bytes());
164 buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
165 buf.extend_from_slice(data);
166 }
167 WalRecord::TxCommitBatch { tx_id, actions } => {
168 buf.push(RecordType::TxCommitBatch as u8);
169 buf.extend_from_slice(&tx_id.to_le_bytes());
170 buf.extend_from_slice(&(actions.len() as u32).to_le_bytes());
171 for action in actions {
172 buf.extend_from_slice(&(action.len() as u32).to_le_bytes());
173 buf.extend_from_slice(action);
174 }
175 }
176 WalRecord::Checkpoint { lsn } => {
177 buf.push(RecordType::Checkpoint as u8);
178 buf.extend_from_slice(&lsn.to_le_bytes());
179 }
180 }
181
182 let checksum = crc32(&buf);
184 buf.extend_from_slice(&checksum.to_le_bytes());
185
186 buf
187 }
188
189 pub fn read<R: Read>(reader: &mut R) -> io::Result<Option<WalRecord>> {
194 let mut type_buf = [0u8; 1];
196 match reader.read_exact(&mut type_buf) {
197 Ok(_) => (),
198 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
199 Err(e) => return Err(e),
200 };
201
202 let record_type = RecordType::from_u8(type_buf[0])
203 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))?;
204
205 let mut running_crc = crc32_update(0, &type_buf);
207
208 let record = match record_type {
209 RecordType::Begin | RecordType::Commit | RecordType::Rollback => {
210 let mut buf = [0u8; 8];
211 reader.read_exact(&mut buf)?;
212 running_crc = crc32_update(running_crc, &buf);
213 let tx_id = u64::from_le_bytes(buf);
214
215 match record_type {
216 RecordType::Begin => WalRecord::Begin { tx_id },
217 RecordType::Commit => WalRecord::Commit { tx_id },
218 RecordType::Rollback => WalRecord::Rollback { tx_id },
219 _ => unreachable!(),
220 }
221 }
222 RecordType::PageWrite => {
223 let mut tx_buf = [0u8; 8];
225 reader.read_exact(&mut tx_buf)?;
226 running_crc = crc32_update(running_crc, &tx_buf);
227 let tx_id = u64::from_le_bytes(tx_buf);
228
229 let mut page_buf = [0u8; 4];
231 reader.read_exact(&mut page_buf)?;
232 running_crc = crc32_update(running_crc, &page_buf);
233 let page_id = u32::from_le_bytes(page_buf);
234
235 let mut len_buf = [0u8; 4];
237 reader.read_exact(&mut len_buf)?;
238 running_crc = crc32_update(running_crc, &len_buf);
239 let len = u32::from_le_bytes(len_buf) as usize;
240
241 let mut data = vec![0u8; len];
243 reader.read_exact(&mut data)?;
244 running_crc = crc32_update(running_crc, &data);
245
246 WalRecord::PageWrite {
247 tx_id,
248 page_id,
249 data,
250 }
251 }
252 RecordType::PageWriteCompressed => {
253 let mut tx_buf = [0u8; 8];
255 reader.read_exact(&mut tx_buf)?;
256 running_crc = crc32_update(running_crc, &tx_buf);
257 let tx_id = u64::from_le_bytes(tx_buf);
258
259 let mut page_buf = [0u8; 4];
261 reader.read_exact(&mut page_buf)?;
262 running_crc = crc32_update(running_crc, &page_buf);
263 let page_id = u32::from_le_bytes(page_buf);
264
265 let mut comp_buf = [0u8; 1];
267 reader.read_exact(&mut comp_buf)?;
268 running_crc = crc32_update(running_crc, &comp_buf);
269 let compression = Compression::from_u8(comp_buf[0]).ok_or_else(|| {
270 io::Error::new(
271 io::ErrorKind::InvalidData,
272 format!("Unknown WAL compression algorithm: {}", comp_buf[0]),
273 )
274 })?;
275
276 let mut orig_len_buf = [0u8; 4];
278 reader.read_exact(&mut orig_len_buf)?;
279 running_crc = crc32_update(running_crc, &orig_len_buf);
280 let orig_len = u32::from_le_bytes(orig_len_buf) as usize;
281
282 let mut len_buf = [0u8; 4];
284 reader.read_exact(&mut len_buf)?;
285 running_crc = crc32_update(running_crc, &len_buf);
286 let len = u32::from_le_bytes(len_buf) as usize;
287
288 let mut compressed = vec![0u8; len];
290 reader.read_exact(&mut compressed)?;
291 running_crc = crc32_update(running_crc, &compressed);
292
293 let data = match compression {
295 Compression::Zstd => {
296 let mut out = vec![0u8; orig_len];
297 zstd::bulk::decompress_to_buffer(&compressed, &mut out).map_err(|e| {
298 io::Error::new(
299 io::ErrorKind::InvalidData,
300 format!("WAL zstd decompress failed: {e}"),
301 )
302 })?;
303 out
304 }
305 Compression::None => compressed,
306 };
307
308 WalRecord::PageWrite {
309 tx_id,
310 page_id,
311 data,
312 }
313 }
314 RecordType::TxCommitBatch => {
315 let mut tx_buf = [0u8; 8];
316 reader.read_exact(&mut tx_buf)?;
317 running_crc = crc32_update(running_crc, &tx_buf);
318 let tx_id = u64::from_le_bytes(tx_buf);
319
320 let mut count_buf = [0u8; 4];
321 reader.read_exact(&mut count_buf)?;
322 running_crc = crc32_update(running_crc, &count_buf);
323 let count = u32::from_le_bytes(count_buf) as usize;
324
325 let mut actions = Vec::with_capacity(count);
326 for _ in 0..count {
327 let mut len_buf = [0u8; 4];
328 reader.read_exact(&mut len_buf)?;
329 running_crc = crc32_update(running_crc, &len_buf);
330 let len = u32::from_le_bytes(len_buf) as usize;
331
332 let mut action = vec![0u8; len];
333 reader.read_exact(&mut action)?;
334 running_crc = crc32_update(running_crc, &action);
335 actions.push(action);
336 }
337
338 WalRecord::TxCommitBatch { tx_id, actions }
339 }
340 RecordType::Checkpoint => {
341 let mut buf = [0u8; 8];
342 reader.read_exact(&mut buf)?;
343 running_crc = crc32_update(running_crc, &buf);
344 let lsn = u64::from_le_bytes(buf);
345 WalRecord::Checkpoint { lsn }
346 }
347 };
348
349 let mut crc_buf = [0u8; 4];
351 reader.read_exact(&mut crc_buf)?;
352 let stored_crc = u32::from_le_bytes(crc_buf);
353
354 if running_crc != stored_crc {
355 return Err(io::Error::new(
356 io::ErrorKind::InvalidData,
357 "WAL record checksum mismatch",
358 ));
359 }
360
361 Ok(Some(record))
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Encodes `record` and decodes a single record back from those bytes.
    fn roundtrip(record: &WalRecord) -> WalRecord {
        let mut cursor = Cursor::new(record.encode());
        WalRecord::read(&mut cursor).unwrap().unwrap()
    }

    /// Attempts to decode one record from `bytes`.
    fn decode(bytes: Vec<u8>) -> std::io::Result<Option<WalRecord>> {
        let mut cursor = Cursor::new(bytes);
        WalRecord::read(&mut cursor)
    }

    /// Checks the total encoded size and the leading type byte of `record`.
    fn assert_encoded_shape(record: &WalRecord, len: usize, record_type: RecordType) {
        let encoded = record.encode();
        assert_eq!(encoded.len(), len);
        assert_eq!(encoded[0], record_type as u8);
    }

    // ---- Record-type tag decoding ----

    #[test]
    fn test_record_type_from_u8() {
        let table = [
            (1u8, RecordType::Begin),
            (2, RecordType::Commit),
            (3, RecordType::Rollback),
            (4, RecordType::PageWrite),
            (5, RecordType::Checkpoint),
            (6, RecordType::PageWriteCompressed),
            (7, RecordType::TxCommitBatch),
        ];
        for &(tag, expected) in &table {
            assert_eq!(RecordType::from_u8(tag), Some(expected));
        }
    }

    #[test]
    fn test_record_type_invalid() {
        for &tag in &[0u8, 8, 255] {
            assert_eq!(RecordType::from_u8(tag), None);
        }
    }

    // ---- Encoded sizes: 1 type byte + body + 4-byte CRC ----

    #[test]
    fn test_encode_begin() {
        assert_encoded_shape(&WalRecord::Begin { tx_id: 12345 }, 13, RecordType::Begin);
    }

    #[test]
    fn test_encode_commit() {
        assert_encoded_shape(&WalRecord::Commit { tx_id: 99999 }, 13, RecordType::Commit);
    }

    #[test]
    fn test_encode_rollback() {
        assert_encoded_shape(
            &WalRecord::Rollback { tx_id: 54321 },
            13,
            RecordType::Rollback,
        );
    }

    #[test]
    fn test_encode_checkpoint() {
        assert_encoded_shape(
            &WalRecord::Checkpoint { lsn: 1000000 },
            13,
            RecordType::Checkpoint,
        );
    }

    #[test]
    fn test_encode_page_write_small() {
        // Below the compression threshold, so stored raw: 1 + 8 + 4 + 4 + 5 + 4.
        let record = WalRecord::PageWrite {
            tx_id: 100,
            page_id: 42,
            data: vec![1, 2, 3, 4, 5],
        };
        assert_encoded_shape(&record, 26, RecordType::PageWrite);
    }

    #[test]
    fn test_encode_page_write_empty_data() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data: Vec::new(),
        };
        assert_eq!(record.encode().len(), 21);
    }

    #[test]
    fn test_encode_tx_commit_batch() {
        let record = WalRecord::TxCommitBatch {
            tx_id: 7,
            actions: vec![b"insert".to_vec(), b"update".to_vec()],
        };
        assert_eq!(record.encode()[0], RecordType::TxCommitBatch as u8);
    }

    // ---- Encode/read round trips ----

    #[test]
    fn test_read_begin_roundtrip() {
        let original = WalRecord::Begin { tx_id: 42 };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_commit_roundtrip() {
        let original = WalRecord::Commit { tx_id: 999 };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_rollback_roundtrip() {
        let original = WalRecord::Rollback { tx_id: 777 };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_checkpoint_roundtrip() {
        let original = WalRecord::Checkpoint { lsn: 123456789 };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_page_write_roundtrip() {
        let original = WalRecord::PageWrite {
            tx_id: 50,
            page_id: 100,
            data: vec![10, 20, 30, 40, 50, 60, 70, 80],
        };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_tx_commit_batch_roundtrip() {
        let original = WalRecord::TxCommitBatch {
            tx_id: 42,
            actions: vec![b"old-version".to_vec(), b"new-version".to_vec()],
        };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_read_page_write_large_data() {
        let original = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data: (0..4096).map(|i| (i % 256) as u8).collect(),
        };
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn page_write_compressed_roundtrip() {
        let data = vec![0xABu8; 1024];
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: data.clone(),
        };
        let encoded = record.encode();

        // A highly repetitive 1 KiB page must take the compressed encoding.
        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        let decoded = decode(encoded).unwrap().unwrap();
        let expected = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data,
        };
        assert_eq!(decoded, expected);
    }

    // ---- Error paths ----

    #[test]
    fn test_read_eof() {
        assert!(decode(Vec::new()).unwrap().is_none());
    }

    #[test]
    fn test_read_invalid_record_type() {
        let mut bytes = vec![0u8; 13];
        bytes[0] = 99;
        assert!(decode(bytes).is_err());
    }

    #[test]
    fn test_read_checksum_mismatch() {
        let mut encoded = WalRecord::Begin { tx_id: 42 }.encode();
        // Flip every bit of the final CRC byte.
        *encoded.last_mut().unwrap() ^= 0xFF;
        assert!(decode(encoded).is_err());
    }

    #[test]
    fn test_read_data_corruption() {
        let mut encoded = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![1, 2, 3, 4],
        }
        .encode();
        // Offset 15 lies inside the 4-byte length field (offsets 13..17).
        encoded[15] ^= 0xFF;
        assert!(decode(encoded).is_err());
    }

    // ---- Streams of records ----

    #[test]
    fn test_multiple_records_sequential() {
        let records = vec![
            WalRecord::Begin { tx_id: 1 },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 10,
                data: vec![1, 2, 3],
            },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 20,
                data: vec![4, 5, 6],
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Checkpoint { lsn: 100 },
        ];

        let stream: Vec<u8> = records.iter().flat_map(WalRecord::encode).collect();
        let mut cursor = Cursor::new(stream);

        for expected in &records {
            let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
            assert_eq!(&decoded, expected);
        }

        // After the last record the reader reports a clean end of stream.
        assert!(WalRecord::read(&mut cursor).unwrap().is_none());
    }

    // ---- Format constants ----

    #[test]
    fn test_wal_magic() {
        assert_eq!(WAL_MAGIC, b"RDBW");
    }

    #[test]
    fn test_wal_version() {
        assert_eq!(WAL_VERSION, 2);
    }
}