Skip to main content

reddb_server/storage/wal/
record.rs

1use crate::storage::engine::crc32::{crc32, crc32_update};
2use std::io::{self, Read};
3
/// WAL file magic bytes (ASCII `RDBW`).
pub const WAL_MAGIC: &[u8; 4] = b"RDBW";

/// WAL file format version.
///
/// The `PageWriteCompressed` and `TxCommitBatch` record types belong to the
/// v2 format; plain `PageWrite` records keep the v1 layout.
pub const WAL_VERSION: u8 = 2;

/// Minimum `PageWrite` payload size (bytes) at which `WalRecord::encode`
/// attempts zstd compression.
///
/// Smaller records pay more overhead than benefit from compression.
const COMPRESS_THRESHOLD: usize = 256;
13
/// Algorithm tag stored in the `Compression` byte of a `PageWriteCompressed`
/// record.
///
/// The discriminant is the exact byte written to disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Compression {
    /// Payload is stored as-is.
    None = 0,
    /// Payload is a zstd frame.
    Zstd = 1,
}

impl Compression {
    /// Every known algorithm, used to map a tag byte back to its variant.
    const ALL: [Compression; 2] = [Compression::None, Compression::Zstd];

    /// Decode an on-disk tag byte; unknown tags yield `None`.
    fn from_u8(v: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|&algo| algo as u8 == v)
    }
}
31
/// Discriminant byte that opens every encoded WAL record.
///
/// `WalRecord::encode` writes these numeric values to disk, so changing them
/// makes existing logs unreadable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// Transaction start marker.
    Begin = 1,
    /// Transaction commit marker.
    Commit = 2,
    /// Transaction rollback marker.
    Rollback = 3,
    /// Uncompressed page write (v1 layout). Still emitted for payloads that
    /// are too small to compress, that zstd fails on, or that zstd would not
    /// shrink.
    PageWrite = 4,
    /// Checkpoint marker carrying an LSN.
    Checkpoint = 5,
    /// Compressed page write (v2 layout).
    ///
    /// Body after the type byte:
    /// ```text
    /// [TxID: 8][PageID: 4][Compression: 1][OrigLen: 4][DataLen: 4][Data: N][CRC: 4]
    /// ```
    /// `OrigLen` is the page size before compression, needed to size the
    /// decompression buffer.
    PageWriteCompressed = 6,
    /// Logical autocommit transaction commit batch (v2 layout).
    ///
    /// Body after the type byte:
    /// ```text
    /// [TxID: 8][ActionCount: 4][[DataLen: 4][Data: N]...][CRC: 4]
    /// ```
    TxCommitBatch = 7,
}

impl RecordType {
    /// Every record type, used to map a tag byte back to its variant.
    const ALL: [RecordType; 7] = [
        RecordType::Begin,
        RecordType::Commit,
        RecordType::Rollback,
        RecordType::PageWrite,
        RecordType::Checkpoint,
        RecordType::PageWriteCompressed,
        RecordType::TxCommitBatch,
    ];

    /// Decode a record's leading type byte; unknown bytes yield `None`.
    pub fn from_u8(v: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|&kind| kind as u8 == v)
    }
}
75
/// A single entry in the write-ahead log.
///
/// Serialized with [`WalRecord::encode`] and parsed back with
/// [`WalRecord::read`]; every encoded record ends with a CRC32 computed over
/// all of its preceding bytes.
#[derive(Debug, Clone, PartialEq)]
pub enum WalRecord {
    /// Start of a transaction
    Begin { tx_id: u64 },
    /// Commit of a transaction
    Commit { tx_id: u64 },
    /// Rollback of a transaction
    Rollback { tx_id: u64 },
    /// Write of a page — always carries uncompressed data (transparent to
    /// callers: `read()` decompresses on-the-fly).
    ///
    /// `encode()` may store a large payload as a zstd-compressed
    /// `PageWriteCompressed` record; that choice is not visible here.
    PageWrite {
        /// Transaction ID the write is logged under.
        tx_id: u64,
        /// Identifier of the page being written.
        page_id: u32,
        /// Uncompressed page bytes.
        data: Vec<u8>,
    },
    /// Atomic logical commit batch. Recovery applies all actions in
    /// order iff this complete record and checksum are present.
    ///
    /// Each action is an opaque byte string; this module only length-prefixes
    /// and copies it.
    TxCommitBatch { tx_id: u64, actions: Vec<Vec<u8>> },
    /// Checkpoint marker (indicates up to which LSN pages are flushed)
    Checkpoint { lsn: u64 },
}
98
99impl WalRecord {
100    /// Serialize record to bytes (including checksum).
101    ///
102    /// `PageWrite` records whose payload is ≥ `COMPRESS_THRESHOLD` bytes are
103    /// compressed with zstd level 3 and emitted as `PageWriteCompressed`.
104    /// Smaller payloads use the plain `PageWrite` encoding (no overhead).
105    pub fn encode(&self) -> Vec<u8> {
106        let mut buf = Vec::new();
107
108        // Layout (non-PageWrite):
109        // [Type: 1]
110        // [TxID/LSN: 8]
111        // [Checksum: 4]
112        //
113        // PageWrite (uncompressed):
114        // [Type: 1][TxID: 8][PageID: 4][DataLen: 4][Data: N][CRC: 4]
115        //
116        // PageWriteCompressed:
117        // [Type: 1][TxID: 8][PageID: 4][Compression: 1][OrigLen: 4][DataLen: 4][Data: N][CRC: 4]
118        //
119        // TxCommitBatch:
120        // [Type: 1][TxID: 8][ActionCount: 4][[DataLen: 4][Data: N]...][CRC: 4]
121
122        match self {
123            WalRecord::Begin { tx_id } => {
124                buf.push(RecordType::Begin as u8);
125                buf.extend_from_slice(&tx_id.to_le_bytes());
126            }
127            WalRecord::Commit { tx_id } => {
128                buf.push(RecordType::Commit as u8);
129                buf.extend_from_slice(&tx_id.to_le_bytes());
130            }
131            WalRecord::Rollback { tx_id } => {
132                buf.push(RecordType::Rollback as u8);
133                buf.extend_from_slice(&tx_id.to_le_bytes());
134            }
135            WalRecord::PageWrite {
136                tx_id,
137                page_id,
138                data,
139            } => {
140                if data.len() >= COMPRESS_THRESHOLD {
141                    // Try zstd compression; fall back to uncompressed if it expands.
142                    if let Ok(compressed) =
143                        zstd::bulk::compress(data.as_slice(), /* level */ 3)
144                    {
145                        if compressed.len() < data.len() {
146                            // Compressed is smaller — use compressed format.
147                            buf.push(RecordType::PageWriteCompressed as u8);
148                            buf.extend_from_slice(&tx_id.to_le_bytes());
149                            buf.extend_from_slice(&page_id.to_le_bytes());
150                            buf.push(Compression::Zstd as u8);
151                            buf.extend_from_slice(&(data.len() as u32).to_le_bytes()); // orig_len
152                            buf.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
153                            buf.extend_from_slice(&compressed);
154                            let checksum = crc32(&buf);
155                            buf.extend_from_slice(&checksum.to_le_bytes());
156                            return buf;
157                        }
158                    }
159                }
160                // Uncompressed path (small payload or compression expanded).
161                buf.push(RecordType::PageWrite as u8);
162                buf.extend_from_slice(&tx_id.to_le_bytes());
163                buf.extend_from_slice(&page_id.to_le_bytes());
164                buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
165                buf.extend_from_slice(data);
166            }
167            WalRecord::TxCommitBatch { tx_id, actions } => {
168                buf.push(RecordType::TxCommitBatch as u8);
169                buf.extend_from_slice(&tx_id.to_le_bytes());
170                buf.extend_from_slice(&(actions.len() as u32).to_le_bytes());
171                for action in actions {
172                    buf.extend_from_slice(&(action.len() as u32).to_le_bytes());
173                    buf.extend_from_slice(action);
174                }
175            }
176            WalRecord::Checkpoint { lsn } => {
177                buf.push(RecordType::Checkpoint as u8);
178                buf.extend_from_slice(&lsn.to_le_bytes());
179            }
180        }
181
182        // Calculate and append checksum
183        let checksum = crc32(&buf);
184        buf.extend_from_slice(&checksum.to_le_bytes());
185
186        buf
187    }
188
189    /// Read a record from a reader.
190    ///
191    /// Handles both v1 (`PageWrite`) and v2 (`PageWriteCompressed`) record
192    /// formats transparently — callers always receive uncompressed data.
193    pub fn read<R: Read>(reader: &mut R) -> io::Result<Option<WalRecord>> {
194        // Read type byte
195        let mut type_buf = [0u8; 1];
196        match reader.read_exact(&mut type_buf) {
197            Ok(_) => (),
198            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
199            Err(e) => return Err(e),
200        };
201
202        let record_type = RecordType::from_u8(type_buf[0])
203            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))?;
204
205        // Start checksum calculation
206        let mut running_crc = crc32_update(0, &type_buf);
207
208        let record = match record_type {
209            RecordType::Begin | RecordType::Commit | RecordType::Rollback => {
210                let mut buf = [0u8; 8];
211                reader.read_exact(&mut buf)?;
212                running_crc = crc32_update(running_crc, &buf);
213                let tx_id = u64::from_le_bytes(buf);
214
215                match record_type {
216                    RecordType::Begin => WalRecord::Begin { tx_id },
217                    RecordType::Commit => WalRecord::Commit { tx_id },
218                    RecordType::Rollback => WalRecord::Rollback { tx_id },
219                    _ => unreachable!(),
220                }
221            }
222            RecordType::PageWrite => {
223                // Read TxID
224                let mut tx_buf = [0u8; 8];
225                reader.read_exact(&mut tx_buf)?;
226                running_crc = crc32_update(running_crc, &tx_buf);
227                let tx_id = u64::from_le_bytes(tx_buf);
228
229                // Read PageID
230                let mut page_buf = [0u8; 4];
231                reader.read_exact(&mut page_buf)?;
232                running_crc = crc32_update(running_crc, &page_buf);
233                let page_id = u32::from_le_bytes(page_buf);
234
235                // Read Length
236                let mut len_buf = [0u8; 4];
237                reader.read_exact(&mut len_buf)?;
238                running_crc = crc32_update(running_crc, &len_buf);
239                let len = u32::from_le_bytes(len_buf) as usize;
240
241                // Read Data
242                let mut data = vec![0u8; len];
243                reader.read_exact(&mut data)?;
244                running_crc = crc32_update(running_crc, &data);
245
246                WalRecord::PageWrite {
247                    tx_id,
248                    page_id,
249                    data,
250                }
251            }
252            RecordType::PageWriteCompressed => {
253                // Read TxID
254                let mut tx_buf = [0u8; 8];
255                reader.read_exact(&mut tx_buf)?;
256                running_crc = crc32_update(running_crc, &tx_buf);
257                let tx_id = u64::from_le_bytes(tx_buf);
258
259                // Read PageID
260                let mut page_buf = [0u8; 4];
261                reader.read_exact(&mut page_buf)?;
262                running_crc = crc32_update(running_crc, &page_buf);
263                let page_id = u32::from_le_bytes(page_buf);
264
265                // Read Compression algorithm byte
266                let mut comp_buf = [0u8; 1];
267                reader.read_exact(&mut comp_buf)?;
268                running_crc = crc32_update(running_crc, &comp_buf);
269                let compression = Compression::from_u8(comp_buf[0]).ok_or_else(|| {
270                    io::Error::new(
271                        io::ErrorKind::InvalidData,
272                        format!("Unknown WAL compression algorithm: {}", comp_buf[0]),
273                    )
274                })?;
275
276                // Read original (uncompressed) length — used to pre-allocate decompression buffer
277                let mut orig_len_buf = [0u8; 4];
278                reader.read_exact(&mut orig_len_buf)?;
279                running_crc = crc32_update(running_crc, &orig_len_buf);
280                let orig_len = u32::from_le_bytes(orig_len_buf) as usize;
281
282                // Read compressed data length
283                let mut len_buf = [0u8; 4];
284                reader.read_exact(&mut len_buf)?;
285                running_crc = crc32_update(running_crc, &len_buf);
286                let len = u32::from_le_bytes(len_buf) as usize;
287
288                // Read compressed data
289                let mut compressed = vec![0u8; len];
290                reader.read_exact(&mut compressed)?;
291                running_crc = crc32_update(running_crc, &compressed);
292
293                // Decompress
294                let data = match compression {
295                    Compression::Zstd => {
296                        let mut out = vec![0u8; orig_len];
297                        zstd::bulk::decompress_to_buffer(&compressed, &mut out).map_err(|e| {
298                            io::Error::new(
299                                io::ErrorKind::InvalidData,
300                                format!("WAL zstd decompress failed: {e}"),
301                            )
302                        })?;
303                        out
304                    }
305                    Compression::None => compressed,
306                };
307
308                WalRecord::PageWrite {
309                    tx_id,
310                    page_id,
311                    data,
312                }
313            }
314            RecordType::TxCommitBatch => {
315                let mut tx_buf = [0u8; 8];
316                reader.read_exact(&mut tx_buf)?;
317                running_crc = crc32_update(running_crc, &tx_buf);
318                let tx_id = u64::from_le_bytes(tx_buf);
319
320                let mut count_buf = [0u8; 4];
321                reader.read_exact(&mut count_buf)?;
322                running_crc = crc32_update(running_crc, &count_buf);
323                let count = u32::from_le_bytes(count_buf) as usize;
324
325                let mut actions = Vec::with_capacity(count);
326                for _ in 0..count {
327                    let mut len_buf = [0u8; 4];
328                    reader.read_exact(&mut len_buf)?;
329                    running_crc = crc32_update(running_crc, &len_buf);
330                    let len = u32::from_le_bytes(len_buf) as usize;
331
332                    let mut action = vec![0u8; len];
333                    reader.read_exact(&mut action)?;
334                    running_crc = crc32_update(running_crc, &action);
335                    actions.push(action);
336                }
337
338                WalRecord::TxCommitBatch { tx_id, actions }
339            }
340            RecordType::Checkpoint => {
341                let mut buf = [0u8; 8];
342                reader.read_exact(&mut buf)?;
343                running_crc = crc32_update(running_crc, &buf);
344                let lsn = u64::from_le_bytes(buf);
345                WalRecord::Checkpoint { lsn }
346            }
347        };
348
349        // Verify checksum
350        let mut crc_buf = [0u8; 4];
351        reader.read_exact(&mut crc_buf)?;
352        let stored_crc = u32::from_le_bytes(crc_buf);
353
354        if running_crc != stored_crc {
355            return Err(io::Error::new(
356                io::ErrorKind::InvalidData,
357                "WAL record checksum mismatch",
358            ));
359        }
360
361        Ok(Some(record))
362    }
363}
364
#[cfg(test)]
mod tests {
    //! Unit tests for WAL record encoding/decoding: per-type layouts,
    //! round-trips, compression, and corruption/truncation detection.

    use super::*;
    use std::io::Cursor;

    // ==================== RecordType Tests ====================

    #[test]
    fn test_record_type_from_u8() {
        assert_eq!(RecordType::from_u8(1), Some(RecordType::Begin));
        assert_eq!(RecordType::from_u8(2), Some(RecordType::Commit));
        assert_eq!(RecordType::from_u8(3), Some(RecordType::Rollback));
        assert_eq!(RecordType::from_u8(4), Some(RecordType::PageWrite));
        assert_eq!(RecordType::from_u8(5), Some(RecordType::Checkpoint));
        assert_eq!(
            RecordType::from_u8(6),
            Some(RecordType::PageWriteCompressed)
        );
        assert_eq!(RecordType::from_u8(7), Some(RecordType::TxCommitBatch));
    }

    #[test]
    fn test_record_type_invalid() {
        assert_eq!(RecordType::from_u8(0), None);
        assert_eq!(RecordType::from_u8(8), None);
        assert_eq!(RecordType::from_u8(255), None);
    }

    // ==================== WalRecord::encode Tests ====================

    #[test]
    fn test_encode_begin() {
        let record = WalRecord::Begin { tx_id: 12345 };
        let encoded = record.encode();

        // Type (1) + TxID (8) + Checksum (4) = 13 bytes
        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Begin as u8);
    }

    #[test]
    fn test_encode_commit() {
        let record = WalRecord::Commit { tx_id: 99999 };
        let encoded = record.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Commit as u8);
    }

    #[test]
    fn test_encode_rollback() {
        let record = WalRecord::Rollback { tx_id: 54321 };
        let encoded = record.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Rollback as u8);
    }

    #[test]
    fn test_encode_checkpoint() {
        let record = WalRecord::Checkpoint { lsn: 1000000 };
        let encoded = record.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Checkpoint as u8);
    }

    #[test]
    fn test_encode_page_write_small() {
        // Small data (< COMPRESS_THRESHOLD) stays uncompressed.
        let data = vec![1, 2, 3, 4, 5];
        let record = WalRecord::PageWrite {
            tx_id: 100,
            page_id: 42,
            data: data.clone(),
        };
        let encoded = record.encode();

        // Type (1) + TxID (8) + PageID (4) + Len (4) + Data (5) + Checksum (4) = 26 bytes
        assert_eq!(encoded.len(), 26);
        assert_eq!(encoded[0], RecordType::PageWrite as u8);
    }

    #[test]
    fn test_encode_page_write_empty_data() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data: vec![],
        };
        let encoded = record.encode();

        // Type (1) + TxID (8) + PageID (4) + Len (4) + Checksum (4) = 21 bytes
        assert_eq!(encoded.len(), 21);
    }

    #[test]
    fn test_encode_tx_commit_batch() {
        let record = WalRecord::TxCommitBatch {
            tx_id: 7,
            actions: vec![b"insert".to_vec(), b"update".to_vec()],
        };
        let encoded = record.encode();

        assert_eq!(encoded[0], RecordType::TxCommitBatch as u8);
        // Type (1) + TxID (8) + Count (4) + 2 * (Len (4) + Data (6)) + Checksum (4) = 37 bytes
        assert_eq!(encoded.len(), 37);
    }

    // ==================== WalRecord::read Tests ====================

    #[test]
    fn test_read_begin_roundtrip() {
        let original = WalRecord::Begin { tx_id: 42 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_commit_roundtrip() {
        let original = WalRecord::Commit { tx_id: 999 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_rollback_roundtrip() {
        let original = WalRecord::Rollback { tx_id: 777 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_checkpoint_roundtrip() {
        let original = WalRecord::Checkpoint { lsn: 123456789 };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_page_write_roundtrip() {
        let original = WalRecord::PageWrite {
            tx_id: 50,
            page_id: 100,
            data: vec![10, 20, 30, 40, 50, 60, 70, 80],
        };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_tx_commit_batch_roundtrip() {
        let original = WalRecord::TxCommitBatch {
            tx_id: 42,
            actions: vec![b"old-version".to_vec(), b"new-version".to_vec()],
        };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_tx_commit_batch_empty_roundtrip() {
        let original = WalRecord::TxCommitBatch {
            tx_id: 5,
            actions: vec![],
        };
        let encoded = original.encode();

        // Type (1) + TxID (8) + Count (4) + Checksum (4) = 17 bytes
        assert_eq!(encoded.len(), 17);

        let decoded = WalRecord::read(&mut Cursor::new(encoded)).unwrap().unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_read_page_write_large_data() {
        // Large enough to trigger compression.
        let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
        let original = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data,
        };
        let encoded = original.encode();

        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();

        // Round-trip: decoded data matches original (even if encoded differently).
        assert_eq!(decoded, original);
    }

    #[test]
    fn page_write_compressed_roundtrip() {
        // Highly compressible payload: 1 KiB of repeated bytes.
        let data = vec![0xABu8; 1024];
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: data.clone(),
        };
        let encoded = record.encode();

        // Should be stored as PageWriteCompressed (compressible > threshold).
        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        // And round-trip decoding recovers original.
        let mut cursor = Cursor::new(encoded);
        let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
        assert_eq!(
            decoded,
            WalRecord::PageWrite {
                tx_id: 7,
                page_id: 3,
                data
            }
        );
    }

    #[test]
    fn test_read_eof() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let result = WalRecord::read(&mut cursor).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_read_invalid_record_type() {
        let buf = vec![99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; // Invalid type 99
        let mut cursor = Cursor::new(buf);
        let err = WalRecord::read(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_checksum_mismatch() {
        let record = WalRecord::Begin { tx_id: 42 };
        let mut encoded = record.encode();

        // Corrupt the last byte (checksum)
        let len = encoded.len();
        encoded[len - 1] ^= 0xFF;

        let mut cursor = Cursor::new(encoded);
        let err = WalRecord::read(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_data_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![1, 2, 3, 4],
        };
        let mut encoded = record.encode();

        // Layout: Type [0], TxID [1..9], PageID [9..13], Len [13..17], Data [17..21].
        // Corrupt the first *data* byte so that only the checksum can catch it
        // (byte 15 would hit the length field and fail with EOF instead).
        encoded[17] ^= 0xFF;

        let mut cursor = Cursor::new(encoded);
        let err = WalRecord::read(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); // Checksum fails
    }

    #[test]
    fn test_read_compressed_payload_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: vec![0xAB; 1024],
        };
        let mut encoded = record.encode();
        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        // Flip the last compressed-payload byte (the final 4 bytes are the CRC).
        let idx = encoded.len() - 5;
        encoded[idx] ^= 0xFF;

        let err = WalRecord::read(&mut Cursor::new(encoded)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_truncated_record() {
        let record = WalRecord::PageWrite {
            tx_id: 9,
            page_id: 4,
            data: vec![7; 8],
        };
        let encoded = record.encode();
        // Type (1) + TxID (8) + PageID (4) + Len (4) + Data (8) + Checksum (4) = 29 bytes
        assert_eq!(encoded.len(), 29);

        // Torn writes: cut inside the payload, then inside the checksum.
        for cut in [20, encoded.len() - 2] {
            let mut cursor = Cursor::new(encoded[..cut].to_vec());
            let err = WalRecord::read(&mut cursor).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    // ==================== Multiple Records Tests ====================

    #[test]
    fn test_multiple_records_sequential() {
        let records = vec![
            WalRecord::Begin { tx_id: 1 },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 10,
                data: vec![1, 2, 3],
            },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 20,
                data: vec![4, 5, 6],
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Checkpoint { lsn: 100 },
        ];

        // Encode all
        let mut buf = Vec::new();
        for r in &records {
            buf.extend_from_slice(&r.encode());
        }

        // Read them back
        let mut cursor = Cursor::new(buf);
        for expected in &records {
            let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
            assert_eq!(&decoded, expected);
        }

        // Next read should return None (EOF)
        assert!(WalRecord::read(&mut cursor).unwrap().is_none());
    }

    // ==================== Constants Tests ====================

    #[test]
    fn test_wal_magic() {
        assert_eq!(WAL_MAGIC, b"RDBW");
    }

    #[test]
    fn test_wal_version() {
        assert_eq!(WAL_VERSION, 2);
    }
}