Skip to main content

reddb_server/storage/wal/
record.rs

1use crate::storage::engine::crc32::{crc32, crc32_update};
2use std::io::{self, Read};
3
/// Magic bytes identifying a WAL file: the ASCII string `RDBW`.
pub const WAL_MAGIC: &[u8; 4] = b"RDBW";

/// Version number of the WAL file format.
pub const WAL_VERSION: u8 = 2;

/// Smallest `PageWrite` payload, in bytes, for which zstd compression is
/// attempted. Below this size the compression overhead outweighs the savings.
const COMPRESS_THRESHOLD: usize = 256;
13
/// Identifies how the payload of a `PageWriteCompressed` record is encoded.
///
/// The discriminant is the tag byte stored in the record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Compression {
    /// Payload is stored as-is.
    None = 0,
    /// Payload is zstd-compressed.
    Zstd = 1,
}

impl Compression {
    /// Maps a tag byte back to its variant, or `None` for an unknown tag.
    fn from_u8(v: u8) -> Option<Self> {
        const ALL: [Compression; 2] = [Compression::None, Compression::Zstd];
        ALL.iter().copied().find(|algo| *algo as u8 == v)
    }
}
31
/// Kind of record stored in the WAL.
///
/// The discriminant is the type byte that begins every encoded record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// Transaction start.
    Begin = 1,
    /// Transaction commit.
    Commit = 2,
    /// Transaction rollback.
    Rollback = 3,
    /// Legacy uncompressed page write (v1 format). Still emitted for small
    /// payloads, where compression would cost more than it saves.
    PageWrite = 4,
    /// Checkpoint marker.
    Checkpoint = 5,
    /// Compressed page write (v2 format).
    ///
    /// Layout (after the type byte):
    /// ```text
    /// [TxID: 8][PageID: 4][Compression: 1][OrigLen: 4][DataLen: 4][Data: N][CRC: 4]
    /// ```
    /// `OrigLen` is the uncompressed size, used to size the decompression
    /// buffer.
    PageWriteCompressed = 6,
}

impl RecordType {
    /// Maps a raw type byte back to its variant, or `None` if the byte does
    /// not name a known record type.
    pub fn from_u8(v: u8) -> Option<Self> {
        const ALL: [RecordType; 6] = [
            RecordType::Begin,
            RecordType::Commit,
            RecordType::Rollback,
            RecordType::PageWrite,
            RecordType::Checkpoint,
            RecordType::PageWriteCompressed,
        ];
        ALL.iter().copied().find(|kind| *kind as u8 == v)
    }
}
67
68/// A single entry in the write-ahead log
69#[derive(Debug, Clone, PartialEq)]
70pub enum WalRecord {
71    /// Start of a transaction
72    Begin { tx_id: u64 },
73    /// Commit of a transaction
74    Commit { tx_id: u64 },
75    /// Rollback of a transaction
76    Rollback { tx_id: u64 },
77    /// Write of a page — always carries uncompressed data (transparent to
78    /// callers: `read()` decompresses on-the-fly).
79    PageWrite {
80        tx_id: u64,
81        page_id: u32,
82        data: Vec<u8>,
83    },
84    /// Checkpoint marker (indicates up to which LSN pages are flushed)
85    Checkpoint { lsn: u64 },
86}
87
88impl WalRecord {
89    /// Serialize record to bytes (including checksum).
90    ///
91    /// `PageWrite` records whose payload is ≥ `COMPRESS_THRESHOLD` bytes are
92    /// compressed with zstd level 3 and emitted as `PageWriteCompressed`.
93    /// Smaller payloads use the plain `PageWrite` encoding (no overhead).
94    pub fn encode(&self) -> Vec<u8> {
95        let mut buf = Vec::new();
96
97        // Layout (non-PageWrite):
98        // [Type: 1]
99        // [TxID/LSN: 8]
100        // [Checksum: 4]
101        //
102        // PageWrite (uncompressed):
103        // [Type: 1][TxID: 8][PageID: 4][DataLen: 4][Data: N][CRC: 4]
104        //
105        // PageWriteCompressed:
106        // [Type: 1][TxID: 8][PageID: 4][Compression: 1][OrigLen: 4][DataLen: 4][Data: N][CRC: 4]
107
108        match self {
109            WalRecord::Begin { tx_id } => {
110                buf.push(RecordType::Begin as u8);
111                buf.extend_from_slice(&tx_id.to_le_bytes());
112            }
113            WalRecord::Commit { tx_id } => {
114                buf.push(RecordType::Commit as u8);
115                buf.extend_from_slice(&tx_id.to_le_bytes());
116            }
117            WalRecord::Rollback { tx_id } => {
118                buf.push(RecordType::Rollback as u8);
119                buf.extend_from_slice(&tx_id.to_le_bytes());
120            }
121            WalRecord::PageWrite {
122                tx_id,
123                page_id,
124                data,
125            } => {
126                if data.len() >= COMPRESS_THRESHOLD {
127                    // Try zstd compression; fall back to uncompressed if it expands.
128                    if let Ok(compressed) =
129                        zstd::bulk::compress(data.as_slice(), /* level */ 3)
130                    {
131                        if compressed.len() < data.len() {
132                            // Compressed is smaller — use compressed format.
133                            buf.push(RecordType::PageWriteCompressed as u8);
134                            buf.extend_from_slice(&tx_id.to_le_bytes());
135                            buf.extend_from_slice(&page_id.to_le_bytes());
136                            buf.push(Compression::Zstd as u8);
137                            buf.extend_from_slice(&(data.len() as u32).to_le_bytes()); // orig_len
138                            buf.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
139                            buf.extend_from_slice(&compressed);
140                            let checksum = crc32(&buf);
141                            buf.extend_from_slice(&checksum.to_le_bytes());
142                            return buf;
143                        }
144                    }
145                }
146                // Uncompressed path (small payload or compression expanded).
147                buf.push(RecordType::PageWrite as u8);
148                buf.extend_from_slice(&tx_id.to_le_bytes());
149                buf.extend_from_slice(&page_id.to_le_bytes());
150                buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
151                buf.extend_from_slice(data);
152            }
153            WalRecord::Checkpoint { lsn } => {
154                buf.push(RecordType::Checkpoint as u8);
155                buf.extend_from_slice(&lsn.to_le_bytes());
156            }
157        }
158
159        // Calculate and append checksum
160        let checksum = crc32(&buf);
161        buf.extend_from_slice(&checksum.to_le_bytes());
162
163        buf
164    }
165
166    /// Read a record from a reader.
167    ///
168    /// Handles both v1 (`PageWrite`) and v2 (`PageWriteCompressed`) record
169    /// formats transparently — callers always receive uncompressed data.
170    pub fn read<R: Read>(reader: &mut R) -> io::Result<Option<WalRecord>> {
171        // Read type byte
172        let mut type_buf = [0u8; 1];
173        match reader.read_exact(&mut type_buf) {
174            Ok(_) => (),
175            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
176            Err(e) => return Err(e),
177        };
178
179        let record_type = RecordType::from_u8(type_buf[0])
180            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))?;
181
182        // Start checksum calculation
183        let mut running_crc = crc32_update(0, &type_buf);
184
185        let record = match record_type {
186            RecordType::Begin | RecordType::Commit | RecordType::Rollback => {
187                let mut buf = [0u8; 8];
188                reader.read_exact(&mut buf)?;
189                running_crc = crc32_update(running_crc, &buf);
190                let tx_id = u64::from_le_bytes(buf);
191
192                match record_type {
193                    RecordType::Begin => WalRecord::Begin { tx_id },
194                    RecordType::Commit => WalRecord::Commit { tx_id },
195                    RecordType::Rollback => WalRecord::Rollback { tx_id },
196                    _ => unreachable!(),
197                }
198            }
199            RecordType::PageWrite => {
200                // Read TxID
201                let mut tx_buf = [0u8; 8];
202                reader.read_exact(&mut tx_buf)?;
203                running_crc = crc32_update(running_crc, &tx_buf);
204                let tx_id = u64::from_le_bytes(tx_buf);
205
206                // Read PageID
207                let mut page_buf = [0u8; 4];
208                reader.read_exact(&mut page_buf)?;
209                running_crc = crc32_update(running_crc, &page_buf);
210                let page_id = u32::from_le_bytes(page_buf);
211
212                // Read Length
213                let mut len_buf = [0u8; 4];
214                reader.read_exact(&mut len_buf)?;
215                running_crc = crc32_update(running_crc, &len_buf);
216                let len = u32::from_le_bytes(len_buf) as usize;
217
218                // Read Data
219                let mut data = vec![0u8; len];
220                reader.read_exact(&mut data)?;
221                running_crc = crc32_update(running_crc, &data);
222
223                WalRecord::PageWrite {
224                    tx_id,
225                    page_id,
226                    data,
227                }
228            }
229            RecordType::PageWriteCompressed => {
230                // Read TxID
231                let mut tx_buf = [0u8; 8];
232                reader.read_exact(&mut tx_buf)?;
233                running_crc = crc32_update(running_crc, &tx_buf);
234                let tx_id = u64::from_le_bytes(tx_buf);
235
236                // Read PageID
237                let mut page_buf = [0u8; 4];
238                reader.read_exact(&mut page_buf)?;
239                running_crc = crc32_update(running_crc, &page_buf);
240                let page_id = u32::from_le_bytes(page_buf);
241
242                // Read Compression algorithm byte
243                let mut comp_buf = [0u8; 1];
244                reader.read_exact(&mut comp_buf)?;
245                running_crc = crc32_update(running_crc, &comp_buf);
246                let compression = Compression::from_u8(comp_buf[0]).ok_or_else(|| {
247                    io::Error::new(
248                        io::ErrorKind::InvalidData,
249                        format!("Unknown WAL compression algorithm: {}", comp_buf[0]),
250                    )
251                })?;
252
253                // Read original (uncompressed) length — used to pre-allocate decompression buffer
254                let mut orig_len_buf = [0u8; 4];
255                reader.read_exact(&mut orig_len_buf)?;
256                running_crc = crc32_update(running_crc, &orig_len_buf);
257                let orig_len = u32::from_le_bytes(orig_len_buf) as usize;
258
259                // Read compressed data length
260                let mut len_buf = [0u8; 4];
261                reader.read_exact(&mut len_buf)?;
262                running_crc = crc32_update(running_crc, &len_buf);
263                let len = u32::from_le_bytes(len_buf) as usize;
264
265                // Read compressed data
266                let mut compressed = vec![0u8; len];
267                reader.read_exact(&mut compressed)?;
268                running_crc = crc32_update(running_crc, &compressed);
269
270                // Decompress
271                let data = match compression {
272                    Compression::Zstd => {
273                        let mut out = vec![0u8; orig_len];
274                        zstd::bulk::decompress_to_buffer(&compressed, &mut out).map_err(|e| {
275                            io::Error::new(
276                                io::ErrorKind::InvalidData,
277                                format!("WAL zstd decompress failed: {e}"),
278                            )
279                        })?;
280                        out
281                    }
282                    Compression::None => compressed,
283                };
284
285                WalRecord::PageWrite {
286                    tx_id,
287                    page_id,
288                    data,
289                }
290            }
291            RecordType::Checkpoint => {
292                let mut buf = [0u8; 8];
293                reader.read_exact(&mut buf)?;
294                running_crc = crc32_update(running_crc, &buf);
295                let lsn = u64::from_le_bytes(buf);
296                WalRecord::Checkpoint { lsn }
297            }
298        };
299
300        // Verify checksum
301        let mut crc_buf = [0u8; 4];
302        reader.read_exact(&mut crc_buf)?;
303        let stored_crc = u32::from_le_bytes(crc_buf);
304
305        if running_crc != stored_crc {
306            return Err(io::Error::new(
307                io::ErrorKind::InvalidData,
308                "WAL record checksum mismatch",
309            ));
310        }
311
312        Ok(Some(record))
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Cursor};

    /// Decodes the first record found in `bytes`.
    fn read_back(bytes: Vec<u8>) -> io::Result<Option<WalRecord>> {
        WalRecord::read(&mut Cursor::new(bytes))
    }

    /// Asserts that `record` survives an encode → read round trip unchanged.
    fn assert_roundtrip(record: WalRecord) {
        let decoded = read_back(record.encode()).unwrap().unwrap();
        assert_eq!(decoded, record);
    }

    // ==================== RecordType Tests ====================

    #[test]
    fn test_record_type_from_u8() {
        assert_eq!(RecordType::from_u8(1), Some(RecordType::Begin));
        assert_eq!(RecordType::from_u8(2), Some(RecordType::Commit));
        assert_eq!(RecordType::from_u8(3), Some(RecordType::Rollback));
        assert_eq!(RecordType::from_u8(4), Some(RecordType::PageWrite));
        assert_eq!(RecordType::from_u8(5), Some(RecordType::Checkpoint));
        assert_eq!(
            RecordType::from_u8(6),
            Some(RecordType::PageWriteCompressed)
        );
    }

    #[test]
    fn test_record_type_invalid() {
        assert_eq!(RecordType::from_u8(0), None);
        assert_eq!(RecordType::from_u8(7), None);
        assert_eq!(RecordType::from_u8(255), None);
    }

    // ==================== WalRecord::encode Tests ====================

    #[test]
    fn test_encode_begin() {
        let encoded = WalRecord::Begin { tx_id: 12345 }.encode();

        // Type (1) + TxID (8) + Checksum (4) = 13 bytes
        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Begin as u8);
    }

    #[test]
    fn test_encode_commit() {
        let encoded = WalRecord::Commit { tx_id: 99999 }.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Commit as u8);
    }

    #[test]
    fn test_encode_rollback() {
        let encoded = WalRecord::Rollback { tx_id: 54321 }.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Rollback as u8);
    }

    #[test]
    fn test_encode_checkpoint() {
        let encoded = WalRecord::Checkpoint { lsn: 1000000 }.encode();

        assert_eq!(encoded.len(), 13);
        assert_eq!(encoded[0], RecordType::Checkpoint as u8);
    }

    #[test]
    fn test_encode_page_write_small() {
        // Small data (< COMPRESS_THRESHOLD) stays uncompressed.
        let record = WalRecord::PageWrite {
            tx_id: 100,
            page_id: 42,
            data: vec![1, 2, 3, 4, 5],
        };
        let encoded = record.encode();

        // Type (1) + TxID (8) + PageID (4) + Len (4) + Data (5) + Checksum (4) = 26 bytes
        assert_eq!(encoded.len(), 26);
        assert_eq!(encoded[0], RecordType::PageWrite as u8);
    }

    #[test]
    fn test_encode_page_write_empty_data() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data: vec![],
        };
        let encoded = record.encode();

        // Type (1) + TxID (8) + PageID (4) + Len (4) + Checksum (4) = 21 bytes
        assert_eq!(encoded.len(), 21);
    }

    // ==================== WalRecord::read Tests ====================

    #[test]
    fn test_read_begin_roundtrip() {
        assert_roundtrip(WalRecord::Begin { tx_id: 42 });
    }

    #[test]
    fn test_read_commit_roundtrip() {
        assert_roundtrip(WalRecord::Commit { tx_id: 999 });
    }

    #[test]
    fn test_read_rollback_roundtrip() {
        assert_roundtrip(WalRecord::Rollback { tx_id: 777 });
    }

    #[test]
    fn test_read_checkpoint_roundtrip() {
        assert_roundtrip(WalRecord::Checkpoint { lsn: 123456789 });
    }

    #[test]
    fn test_read_page_write_roundtrip() {
        assert_roundtrip(WalRecord::PageWrite {
            tx_id: 50,
            page_id: 100,
            data: vec![10, 20, 30, 40, 50, 60, 70, 80],
        });
    }

    #[test]
    fn test_read_page_write_large_data() {
        // Large enough to trigger compression; the decoded data must match the
        // original regardless of which encoding was chosen.
        let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
        assert_roundtrip(WalRecord::PageWrite {
            tx_id: 1,
            page_id: 0,
            data,
        });
    }

    #[test]
    fn page_write_compressed_roundtrip() {
        // Highly compressible payload: 1 KiB of repeated bytes.
        let record = WalRecord::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: vec![0xABu8; 1024],
        };
        let encoded = record.encode();

        // Should be stored as PageWriteCompressed (compressible > threshold).
        assert_eq!(encoded[0], RecordType::PageWriteCompressed as u8);

        // And round-trip decoding recovers the original.
        let decoded = read_back(encoded).unwrap().unwrap();
        assert_eq!(decoded, record);
    }

    #[test]
    fn test_read_eof() {
        let result = read_back(Vec::new()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_read_invalid_record_type() {
        let buf = vec![99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; // Invalid type 99
        assert!(read_back(buf).is_err());
    }

    #[test]
    fn test_read_checksum_mismatch() {
        let mut encoded = WalRecord::Begin { tx_id: 42 }.encode();

        // Corrupt the last byte (part of the checksum).
        let len = encoded.len();
        encoded[len - 1] ^= 0xFF;

        let err = read_back(encoded).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_data_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![1, 2, 3, 4],
        };
        let mut encoded = record.encode();

        // Header is Type (1) + TxID (8) + PageID (4) + DataLen (4), so the
        // payload starts at offset 17.
        const DATA_OFFSET: usize = 1 + 8 + 4 + 4;
        encoded[DATA_OFFSET] ^= 0xFF;

        // The record is still well-formed, so only the checksum can catch it.
        let err = read_back(encoded).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_length_corruption() {
        let record = WalRecord::PageWrite {
            tx_id: 1,
            page_id: 2,
            data: vec![1, 2, 3, 4],
        };
        let mut encoded = record.encode();

        // Offset 15 lies inside the 4-byte DataLen field (offsets 13..17), so
        // the record now claims far more payload than the input holds.
        encoded[15] ^= 0xFF;

        assert!(read_back(encoded).is_err());
    }

    // ==================== Multiple Records Tests ====================

    #[test]
    fn test_multiple_records_sequential() {
        let records = vec![
            WalRecord::Begin { tx_id: 1 },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 10,
                data: vec![1, 2, 3],
            },
            WalRecord::PageWrite {
                tx_id: 1,
                page_id: 20,
                data: vec![4, 5, 6],
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Checkpoint { lsn: 100 },
        ];

        // Encode all records back to back.
        let mut buf = Vec::new();
        for r in &records {
            buf.extend_from_slice(&r.encode());
        }

        // Read them back in order.
        let mut cursor = Cursor::new(buf);
        for expected in &records {
            let decoded = WalRecord::read(&mut cursor).unwrap().unwrap();
            assert_eq!(&decoded, expected);
        }

        // Next read should return None (EOF).
        assert!(WalRecord::read(&mut cursor).unwrap().is_none());
    }

    // ==================== Constants Tests ====================

    #[test]
    fn test_wal_magic() {
        assert_eq!(WAL_MAGIC, b"RDBW");
    }

    #[test]
    fn test_wal_version() {
        assert_eq!(WAL_VERSION, 2);
    }
}