1use crate::storage::engine::crc32::{crc32, crc32_update};
2use std::io::{self, Read};
3
/// Magic tag of the WAL format (the ASCII bytes `RDBW`).
/// NOTE(review): presumably written at the start of WAL files — the writer is not in this file.
pub const WAL_MAGIC: &[u8; 4] = &[b'R', b'D', b'B', b'W'];

/// Version number of the WAL record format defined in this module.
pub const WAL_VERSION: u8 = 2;

/// Minimum payload size, in bytes, at which `WalRecord::encode` attempts to
/// zstd-compress a page write.
const COMPRESS_THRESHOLD: usize = 256;
13
/// Compression algorithm tag stored in a `PageWriteCompressed` record.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Compression {
    /// Payload bytes are stored verbatim.
    None = 0,
    /// Payload bytes are a zstd frame.
    Zstd = 1,
}

impl Compression {
    /// Decodes an on-disk algorithm byte; bytes that name no known algorithm
    /// yield `None`.
    fn from_u8(v: u8) -> Option<Self> {
        // Match against the discriminants themselves so the decoder cannot
        // drift from the `#[repr(u8)]` values above.
        const NONE: u8 = Compression::None as u8;
        const ZSTD: u8 = Compression::Zstd as u8;

        match v {
            NONE => Some(Self::None),
            ZSTD => Some(Self::Zstd),
            _ => None,
        }
    }
}
31
/// One-byte tag that opens every encoded WAL record and selects how the rest
/// of the record is laid out.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordType {
    Begin = 1,
    Commit = 2,
    Rollback = 3,
    /// Page write whose payload is stored uncompressed.
    PageWrite = 4,
    Checkpoint = 5,
    /// Page write whose payload is preceded by a compression-algorithm byte
    /// plus the original and stored lengths.
    PageWriteCompressed = 6,
}

impl RecordType {
    /// Every record type, used to map a tag byte back to its variant.
    const ALL: [RecordType; 6] = [
        RecordType::Begin,
        RecordType::Commit,
        RecordType::Rollback,
        RecordType::PageWrite,
        RecordType::Checkpoint,
        RecordType::PageWriteCompressed,
    ];

    /// Decodes a tag byte; values that name no record type yield `None`.
    pub fn from_u8(v: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| *kind as u8 == v)
    }
}
67
/// A single logical WAL entry.
///
/// Page payloads may be stored zstd-compressed on disk, but a decoded
/// `PageWrite` always holds the uncompressed bytes.
#[derive(Debug, Clone, PartialEq)]
pub enum WalRecord {
    /// Begin marker carrying a transaction id.
    Begin {
        tx_id: u64,
    },
    /// Commit marker carrying a transaction id.
    Commit {
        tx_id: u64,
    },
    /// Rollback marker carrying a transaction id.
    Rollback {
        tx_id: u64,
    },
    /// Page image `data` for page `page_id`, written by transaction `tx_id`.
    PageWrite { tx_id: u64, page_id: u32, data: Vec<u8> },
    /// Checkpoint marker carrying a log sequence number.
    Checkpoint {
        lsn: u64,
    },
}
87
88impl WalRecord {
89 pub fn encode(&self) -> Vec<u8> {
95 let mut buf = Vec::new();
96
97 match self {
109 WalRecord::Begin { tx_id } => {
110 buf.push(RecordType::Begin as u8);
111 buf.extend_from_slice(&tx_id.to_le_bytes());
112 }
113 WalRecord::Commit { tx_id } => {
114 buf.push(RecordType::Commit as u8);
115 buf.extend_from_slice(&tx_id.to_le_bytes());
116 }
117 WalRecord::Rollback { tx_id } => {
118 buf.push(RecordType::Rollback as u8);
119 buf.extend_from_slice(&tx_id.to_le_bytes());
120 }
121 WalRecord::PageWrite {
122 tx_id,
123 page_id,
124 data,
125 } => {
126 if data.len() >= COMPRESS_THRESHOLD {
127 if let Ok(compressed) =
129 zstd::bulk::compress(data.as_slice(), 3)
130 {
131 if compressed.len() < data.len() {
132 buf.push(RecordType::PageWriteCompressed as u8);
134 buf.extend_from_slice(&tx_id.to_le_bytes());
135 buf.extend_from_slice(&page_id.to_le_bytes());
136 buf.push(Compression::Zstd as u8);
137 buf.extend_from_slice(&(data.len() as u32).to_le_bytes()); buf.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
139 buf.extend_from_slice(&compressed);
140 let checksum = crc32(&buf);
141 buf.extend_from_slice(&checksum.to_le_bytes());
142 return buf;
143 }
144 }
145 }
146 buf.push(RecordType::PageWrite as u8);
148 buf.extend_from_slice(&tx_id.to_le_bytes());
149 buf.extend_from_slice(&page_id.to_le_bytes());
150 buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
151 buf.extend_from_slice(data);
152 }
153 WalRecord::Checkpoint { lsn } => {
154 buf.push(RecordType::Checkpoint as u8);
155 buf.extend_from_slice(&lsn.to_le_bytes());
156 }
157 }
158
159 let checksum = crc32(&buf);
161 buf.extend_from_slice(&checksum.to_le_bytes());
162
163 buf
164 }
165
166 pub fn read<R: Read>(reader: &mut R) -> io::Result<Option<WalRecord>> {
171 let mut type_buf = [0u8; 1];
173 match reader.read_exact(&mut type_buf) {
174 Ok(_) => (),
175 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
176 Err(e) => return Err(e),
177 };
178
179 let record_type = RecordType::from_u8(type_buf[0])
180 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))?;
181
182 let mut running_crc = crc32_update(0, &type_buf);
184
185 let record = match record_type {
186 RecordType::Begin | RecordType::Commit | RecordType::Rollback => {
187 let mut buf = [0u8; 8];
188 reader.read_exact(&mut buf)?;
189 running_crc = crc32_update(running_crc, &buf);
190 let tx_id = u64::from_le_bytes(buf);
191
192 match record_type {
193 RecordType::Begin => WalRecord::Begin { tx_id },
194 RecordType::Commit => WalRecord::Commit { tx_id },
195 RecordType::Rollback => WalRecord::Rollback { tx_id },
196 _ => unreachable!(),
197 }
198 }
199 RecordType::PageWrite => {
200 let mut tx_buf = [0u8; 8];
202 reader.read_exact(&mut tx_buf)?;
203 running_crc = crc32_update(running_crc, &tx_buf);
204 let tx_id = u64::from_le_bytes(tx_buf);
205
206 let mut page_buf = [0u8; 4];
208 reader.read_exact(&mut page_buf)?;
209 running_crc = crc32_update(running_crc, &page_buf);
210 let page_id = u32::from_le_bytes(page_buf);
211
212 let mut len_buf = [0u8; 4];
214 reader.read_exact(&mut len_buf)?;
215 running_crc = crc32_update(running_crc, &len_buf);
216 let len = u32::from_le_bytes(len_buf) as usize;
217
218 let mut data = vec![0u8; len];
220 reader.read_exact(&mut data)?;
221 running_crc = crc32_update(running_crc, &data);
222
223 WalRecord::PageWrite {
224 tx_id,
225 page_id,
226 data,
227 }
228 }
229 RecordType::PageWriteCompressed => {
230 let mut tx_buf = [0u8; 8];
232 reader.read_exact(&mut tx_buf)?;
233 running_crc = crc32_update(running_crc, &tx_buf);
234 let tx_id = u64::from_le_bytes(tx_buf);
235
236 let mut page_buf = [0u8; 4];
238 reader.read_exact(&mut page_buf)?;
239 running_crc = crc32_update(running_crc, &page_buf);
240 let page_id = u32::from_le_bytes(page_buf);
241
242 let mut comp_buf = [0u8; 1];
244 reader.read_exact(&mut comp_buf)?;
245 running_crc = crc32_update(running_crc, &comp_buf);
246 let compression = Compression::from_u8(comp_buf[0]).ok_or_else(|| {
247 io::Error::new(
248 io::ErrorKind::InvalidData,
249 format!("Unknown WAL compression algorithm: {}", comp_buf[0]),
250 )
251 })?;
252
253 let mut orig_len_buf = [0u8; 4];
255 reader.read_exact(&mut orig_len_buf)?;
256 running_crc = crc32_update(running_crc, &orig_len_buf);
257 let orig_len = u32::from_le_bytes(orig_len_buf) as usize;
258
259 let mut len_buf = [0u8; 4];
261 reader.read_exact(&mut len_buf)?;
262 running_crc = crc32_update(running_crc, &len_buf);
263 let len = u32::from_le_bytes(len_buf) as usize;
264
265 let mut compressed = vec![0u8; len];
267 reader.read_exact(&mut compressed)?;
268 running_crc = crc32_update(running_crc, &compressed);
269
270 let data = match compression {
272 Compression::Zstd => {
273 let mut out = vec![0u8; orig_len];
274 zstd::bulk::decompress_to_buffer(&compressed, &mut out).map_err(|e| {
275 io::Error::new(
276 io::ErrorKind::InvalidData,
277 format!("WAL zstd decompress failed: {e}"),
278 )
279 })?;
280 out
281 }
282 Compression::None => compressed,
283 };
284
285 WalRecord::PageWrite {
286 tx_id,
287 page_id,
288 data,
289 }
290 }
291 RecordType::Checkpoint => {
292 let mut buf = [0u8; 8];
293 reader.read_exact(&mut buf)?;
294 running_crc = crc32_update(running_crc, &buf);
295 let lsn = u64::from_le_bytes(buf);
296 WalRecord::Checkpoint { lsn }
297 }
298 };
299
300 let mut crc_buf = [0u8; 4];
302 reader.read_exact(&mut crc_buf)?;
303 let stored_crc = u32::from_le_bytes(crc_buf);
304
305 if running_crc != stored_crc {
306 return Err(io::Error::new(
307 io::ErrorKind::InvalidData,
308 "WAL record checksum mismatch",
309 ));
310 }
311
312 Ok(Some(record))
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_record_type_from_u8() {
        assert_eq!(RecordType::from_u8(1), Some(RecordType::Begin));
        assert_eq!(RecordType::from_u8(2), Some(RecordType::Commit));
        assert_eq!(RecordType::from_u8(3), Some(RecordType::Rollback));
        assert_eq!(RecordType::from_u8(4), Some(RecordType::PageWrite));
        assert_eq!(RecordType::from_u8(5), Some(RecordType::Checkpoint));
        assert_eq!(
            RecordType::from_u8(6),
            Some(RecordType::PageWriteCompressed)
        );
    }

    #[test]
    fn test_record_type_invalid() {
        assert_eq!(RecordType::from_u8(0), None);
        assert_eq!(RecordType::from_u8(7), None);
        assert_eq!(RecordType::from_u8(255), None);
    }

    #[test]
    fn test_compression_from_u8() {
        assert_eq!(Compression::from_u8(0), Some(Compression::None));
        assert_eq!(Compression::from_u8(1), Some(Compression::Zstd));
        assert_eq!(Compression::from_u8(2), None);
        assert_eq!(Compression::from_u8(255), None);
    }

    #[test]
    fn test_encode_begin() {
        let record = WalRecord::Begin { tx_id: 12345 };
        let encoded = record.encode();

        // type (1) + tx_id (8) + crc (4)
        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Begin as u8);
    }

    #[test]
    fn test_encode_commit() {
        let record = WalRecord::Commit { tx_id: 99999 };
        let encoded = record.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Commit as u8);
    }

    #[test]
    fn test_encode_rollback() {
        let record = WalRecord::Rollback { tx_id: 54321 };
        let encoded = record.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Rollback as u8);
    }

    #[test]
    fn test_encode_checkpoint() {
        let record = WalRecord::Checkpoint { lsn: 1000000 };
        let encoded = record.encode();

        // type (1) + lsn (8) + crc (4)
        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Checkpoint as u8);
    }

    #[test]
    fn test_encode_page_write_small() {
        // Below COMPRESS_THRESHOLD, so stored uncompressed.
        let data = vec![1, 2, 3, 4, 5];
        let record = WalRecord::PageWrite {
            tx_id: 100,
            page_id: 42,
            data: data.clone(),
        };
        let encoded = record.encode();

        // type (1) + tx_id (8) + page_id (4) + len (4) + data (5) + crc (4)
        assert_eq!(encoded.len(), 26);
        assert_eq!(encoded[0], RecordType::PageWrite as u8);
    }

    #[test]
    fn test_encode_page_write_empty_data() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data: vec![],
        };
        let encoded = record.encode();

        // type (1) + tx_id (8) + page_id (4) + len (4) + crc (4)
        assert_eq!(encoded.len(), 21);
    }

    #[test]
    fn test_read_begin_roundtrip() {
        let original = WalRecord::Begin { tx_id: 42 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_commit_roundtrip() {
        let original = WalRecord::Commit { tx_id: 999 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_rollback_roundtrip() {
        let original = WalRecord::Rollback { tx_id: 777 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_checkpoint_roundtrip() {
        let original = WalRecord::Checkpoint { lsn: 123456789 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_page_write_roundtrip() {
        let original = WalRecord::PageWrite {
            tx_id: 50,
            page_id: 100,
            data: vec![10, 20, 30, 40, 50, 60, 70, 80],
        };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_page_write_large_data() {
        let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
        let original = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data,
        };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn page_write_compressed_roundtrip() {
        let data = vec![0xABu8; 1024];
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: data.clone(),
        };
        let encoded = record.encode();

        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
        assert_eq!(
            decoded,
            WalRecord::PageWrite {
                tx_id: 7,
                page_id: 3,
                data
            }
        );
    }

    #[test]
    fn test_incompressible_page_stays_uncompressed() {
        // xorshift32 noise: above COMPRESS_THRESHOLD, but zstd cannot shrink it.
        let mut state: u32 = 0x9E37_79B9;
        let data: Vec<u8> = (0..512)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 17;
                state ^= state << 5;
                (state >> 24) as u8
            })
            .collect();
        let original = WalRecord::PageWrite {
            tx_id: 9,
            page_id: 4,
            data,
        };
        let encoded = original.encode();

        assert_eq!(encoded[0], RecordType::PageWrite as u8);

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_eof() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let result = WalRecord::read(&mut cursor).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_read_invalid_record_type() {
        let buf = vec![99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let mut cursor = Cursor::new(buf);
        let result = WalRecord::read(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_checksum_mismatch() {
        let record = WalRecord::Begin { tx_id: 42 };
        let mut encoded = record.encode();

        // Flip the last byte of the stored CRC.
        let len = encoded.len();
        encoded[len - 1] ^= 0xFF;

        let mut cursor = Cursor::new(encoded);
        let result = WalRecord::read(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_data_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![1, 2, 3, 4],
        };
        let mut encoded = record.encode();

        // Byte 15 lies inside the payload length field (bytes 13..17).
        encoded[15] ^= 0xFF;

        let mut cursor = Cursor::new(encoded);
        let result = WalRecord::read(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_truncated_record() {
        let encoded = WalRecord::Begin { tx_id: 42 }.encode();

        // Type byte plus half of the tx_id: a torn write.
        let mut cursor = Cursor::new(encoded[..5].to_vec());
        let err = WalRecord::read(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_read_truncated_page_payload() {
        let encoded = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![9; 16],
        }
        .encode();

        // Full 17-byte header, but only 8 of the 16 payload bytes.
        let mut cursor = Cursor::new(encoded[..25].to_vec());
        let err = WalRecord::read(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_read_compressed_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: vec![0xAB; 1024],
        };
        let mut encoded = record.encode();
        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        // First compressed byte:
        // type (1) + tx (8) + page (4) + algo (1) + orig_len (4) + len (4).
        encoded[22] ^= 0xFF;

        let err = WalRecord::read(&mut Cursor::new(encoded)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_multiple_records_sequential() {
        let records = vec![
            WalRecord::Begin { tx_id: 1 },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 10,
                data: vec![1, 2, 3],
            },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 20,
                data: vec![4, 5, 6],
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Checkpoint { lsn: 100 },
        ];

        let mut buf = Vec::new();
        for r in &records {
            buf.extend_from_slice(&r.encode());
        }

        let mut cursor = Cursor::new(buf);
        for expected in &records {
            let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
            assert_eq!(&decoded, expected);
        }

        assert!(WalRecord::read(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn test_wal_magic() {
        assert_eq!(WAL_MAGIC, b"RDBW");
    }

    #[test]
    fn test_wal_version() {
        assert_eq!(WAL_VERSION, 2);
    }
}