Skip to main content

kaya_wal/
codec.rs

1use std::fmt;
2
3use kaya_core::{crc32c, Bytes, KayaError, Lsn, Result, SequenceNumber};
4
/// Magic number opening every WAL record header: the ASCII bytes `KAYA` read as a
/// big-endian integer (`0x4b41_5941`).
///
/// The encoder writes it little-endian, so the first four encoded bytes are `AYAK`.
pub const WAL_MAGIC: u32 = u32::from_be_bytes(*b"KAYA");
/// Format version written by the encoder and the only version the decoder accepts.
pub const WAL_VERSION: u16 = 1;
/// Size in bytes of the fixed record header:
/// magic (4) + version (2) + header length (2) + flags (2) + record type (2)
/// + LSN (8) + sequence (8) + payload length (4) + header CRC (4) + payload CRC (4).
pub const WAL_HEADER_LEN: usize = 4 + 2 + 2 + 2 + 2 + 8 + 8 + 4 + 4 + 4;
8
/// Kind of operation a WAL record carries; each variant's discriminant is its wire code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalRecordType {
    /// Payload carries a key and a value (wire code 1).
    Put = 1,
    /// Payload carries only a key (wire code 2).
    Delete = 2,
    /// Payload is empty (wire code 3).
    Noop = 3,
}

impl WalRecordType {
    /// Maps a wire code read from a record header back to a record type.
    ///
    /// Returns `None` for any code that does not belong to a known variant.
    pub fn from_wire(value: u16) -> Option<Self> {
        [Self::Put, Self::Delete, Self::Noop]
            .iter()
            .copied()
            .find(|candidate| candidate.as_wire() == value)
    }

    /// Returns the numeric code stored in the record header for this type.
    pub const fn as_wire(self) -> u16 {
        self as u16
    }
}

impl fmt::Display for WalRecordType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Put => "PUT",
            Self::Delete => "DELETE",
            Self::Noop => "NOOP",
        };
        f.write_str(label)
    }
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub enum WalPayload {
43    Put { key: Bytes, value: Bytes },
44    Delete { key: Bytes },
45    Noop,
46}
47
48impl WalPayload {
49    pub const fn record_type(&self) -> WalRecordType {
50        match self {
51            Self::Put { .. } => WalRecordType::Put,
52            Self::Delete { .. } => WalRecordType::Delete,
53            Self::Noop => WalRecordType::Noop,
54        }
55    }
56
57    pub fn key_len(&self) -> Option<usize> {
58        match self {
59            Self::Put { key, .. } | Self::Delete { key } => Some(key.len()),
60            Self::Noop => None,
61        }
62    }
63
64    pub fn value_len(&self) -> Option<usize> {
65        match self {
66            Self::Put { value, .. } => Some(value.len()),
67            Self::Delete { .. } | Self::Noop => None,
68        }
69    }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct WalRecord {
74    pub flags: u16,
75    pub lsn: Lsn,
76    pub sequence: SequenceNumber,
77    pub payload: WalPayload,
78}
79
80impl WalRecord {
81    pub fn new(lsn: Lsn, sequence: SequenceNumber, payload: WalPayload) -> Self {
82        Self {
83            flags: 0,
84            lsn,
85            sequence,
86            payload,
87        }
88    }
89
90    pub const fn record_type(&self) -> WalRecordType {
91        self.payload.record_type()
92    }
93}
94
/// Diagnostic describing why WAL bytes could not be (fully) decoded.
///
/// Every `offset` is the caller-supplied position of the record being decoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WalWarning {
    /// Fewer than `WAL_HEADER_LEN` bytes were available.
    PartialHeader { offset: u64 },
    /// The header checked out, but only `actual` of the `expected` record bytes
    /// (header + payload) were available.
    PartialPayload { offset: u64, expected: usize, actual: usize },
    /// The first four bytes were not `WAL_MAGIC`.
    BadMagic { offset: u64, found: u32 },
    /// The version field was not `WAL_VERSION`.
    UnsupportedVersion { offset: u64, found: u16 },
    /// The header-length field was not `WAL_HEADER_LEN`.
    BadHeaderLength { offset: u64, found: u16 },
    /// The flags field was non-zero.
    UnknownFlags { offset: u64, found: u16 },
    /// The record-type field held no known wire code.
    UnknownRecordType { offset: u64, found: u16 },
    /// The declared payload length exceeded `max`.
    OversizedPayload { offset: u64, found: u32, max: u32 },
    /// Header CRC mismatch: `expected` is computed from the bytes, `actual` is the stored value.
    BadHeaderChecksum { offset: u64, expected: u32, actual: u32 },
    /// Payload CRC mismatch: `expected` is computed from the bytes, `actual` is the stored value.
    BadPayloadChecksum { offset: u64, expected: u32, actual: u32 },
    /// The payload passed its checksum but could not be parsed; `message` holds the parse error.
    MalformedPayload { offset: u64, message: String },
    /// LSN ordering violation. Not produced in this module; presumably raised by the log
    /// reader — confirm semantics with the producer.
    NonMonotonicLsn { offset: u64, expected: u64, found: u64 },
    /// Tail truncation of the file at `path`. Not produced in this module; meaning inferred
    /// from field names — confirm with the producer.
    TailTruncated { path: String, valid_len: u64, truncated_bytes: u64 },
    /// Segments skipped after an earlier problem. Not produced in this module; confirm
    /// semantics with the producer.
    TrailingSegmentsIgnored { count: usize },
}

impl fmt::Display for WalWarning {
    /// Renders the warning as its variant name followed by `key=value` pairs; checksums,
    /// magic and flags are printed as zero-padded hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::PartialHeader { offset } => write!(f, "PartialHeader offset={offset}"),
            Self::PartialPayload {
                offset,
                expected,
                actual,
            } => write!(
                f,
                "PartialPayload offset={offset} expected={expected} actual={actual}"
            ),
            Self::BadMagic { offset, found } => write!(f, "BadMagic offset={offset} found=0x{found:08x}"),
            Self::UnsupportedVersion { offset, found } => write!(f, "UnsupportedVersion offset={offset} found={found}"),
            Self::BadHeaderLength { offset, found } => write!(f, "BadHeaderLength offset={offset} found={found}"),
            Self::UnknownFlags { offset, found } => write!(f, "UnknownFlags offset={offset} found=0x{found:04x}"),
            Self::UnknownRecordType { offset, found } => write!(f, "UnknownRecordType offset={offset} found={found}"),
            Self::OversizedPayload { offset, found, max } => write!(
                f,
                "OversizedPayload offset={offset} found={found} max={max}"
            ),
            Self::BadHeaderChecksum {
                offset,
                expected,
                actual,
            } => write!(
                f,
                "BadHeaderChecksum offset={offset} expected=0x{expected:08x} actual=0x{actual:08x}"
            ),
            Self::BadPayloadChecksum {
                offset,
                expected,
                actual,
            } => write!(
                f,
                "BadPayloadChecksum offset={offset} expected=0x{expected:08x} actual=0x{actual:08x}"
            ),
            Self::MalformedPayload { offset, ref message } => write!(f, "MalformedPayload offset={offset} {message}"),
            Self::NonMonotonicLsn {
                offset,
                expected,
                found,
            } => write!(
                f,
                "NonMonotonicLsn offset={offset} expected={expected} found={found}"
            ),
            Self::TailTruncated {
                ref path,
                valid_len,
                truncated_bytes,
            } => write!(
                f,
                "TailTruncated path={path} valid_len={valid_len} truncated_bytes={truncated_bytes}"
            ),
            Self::TrailingSegmentsIgnored { count } => write!(f, "TrailingSegmentsIgnored count={count}"),
        }
    }
}
231
232#[derive(Debug, Clone, PartialEq, Eq)]
233pub enum DecodeRecordResult {
234    Complete {
235        record: WalRecord,
236        bytes_read: usize,
237    },
238    Incomplete {
239        warning: WalWarning,
240    },
241    Invalid {
242        warning: WalWarning,
243    },
244}
245
246pub fn encode_record(record: &WalRecord) -> Result<Vec<u8>> {
247    let payload = encode_payload(&record.payload)?;
248    let payload_len = u32::try_from(payload.len())
249        .map_err(|_| KayaError::invalid_argument("WAL payload length does not fit into u32"))?;
250    let payload_crc = crc32c(&payload);
251
252    let mut encoded = Vec::with_capacity(WAL_HEADER_LEN + payload.len());
253    encoded.extend_from_slice(&WAL_MAGIC.to_le_bytes());
254    encoded.extend_from_slice(&WAL_VERSION.to_le_bytes());
255    encoded.extend_from_slice(&(WAL_HEADER_LEN as u16).to_le_bytes());
256    encoded.extend_from_slice(&record.flags.to_le_bytes());
257    encoded.extend_from_slice(&record.record_type().as_wire().to_le_bytes());
258    encoded.extend_from_slice(&record.lsn.get().to_le_bytes());
259    encoded.extend_from_slice(&record.sequence.get().to_le_bytes());
260    encoded.extend_from_slice(&payload_len.to_le_bytes());
261    encoded.extend_from_slice(&0_u32.to_le_bytes());
262    encoded.extend_from_slice(&payload_crc.to_le_bytes());
263
264    let header_crc = crc32c(&encoded[..WAL_HEADER_LEN]);
265    encoded[32..36].copy_from_slice(&header_crc.to_le_bytes());
266    encoded.extend_from_slice(&payload);
267    Ok(encoded)
268}
269
270pub fn decode_record(input: &[u8], offset: u64, max_payload_len: u32) -> DecodeRecordResult {
271    if input.is_empty() {
272        return DecodeRecordResult::Incomplete {
273            warning: WalWarning::PartialHeader { offset },
274        };
275    }
276    if input.len() < WAL_HEADER_LEN {
277        return DecodeRecordResult::Incomplete {
278            warning: WalWarning::PartialHeader { offset },
279        };
280    }
281
282    let magic = read_u32(&input[0..4]);
283    if magic != WAL_MAGIC {
284        return DecodeRecordResult::Invalid {
285            warning: WalWarning::BadMagic {
286                offset,
287                found: magic,
288            },
289        };
290    }
291
292    let version = read_u16(&input[4..6]);
293    if version != WAL_VERSION {
294        return DecodeRecordResult::Invalid {
295            warning: WalWarning::UnsupportedVersion {
296                offset,
297                found: version,
298            },
299        };
300    }
301
302    let header_len = read_u16(&input[6..8]);
303    if usize::from(header_len) != WAL_HEADER_LEN {
304        return DecodeRecordResult::Invalid {
305            warning: WalWarning::BadHeaderLength {
306                offset,
307                found: header_len,
308            },
309        };
310    }
311
312    let flags = read_u16(&input[8..10]);
313    if flags != 0 {
314        return DecodeRecordResult::Invalid {
315            warning: WalWarning::UnknownFlags {
316                offset,
317                found: flags,
318            },
319        };
320    }
321
322    let record_type_raw = read_u16(&input[10..12]);
323    let Some(record_type) = WalRecordType::from_wire(record_type_raw) else {
324        return DecodeRecordResult::Invalid {
325            warning: WalWarning::UnknownRecordType {
326                offset,
327                found: record_type_raw,
328            },
329        };
330    };
331
332    let lsn = read_u64(&input[12..20]);
333    let sequence = read_u64(&input[20..28]);
334    let payload_len = read_u32(&input[28..32]);
335    if payload_len > max_payload_len {
336        return DecodeRecordResult::Invalid {
337            warning: WalWarning::OversizedPayload {
338                offset,
339                found: payload_len,
340                max: max_payload_len,
341            },
342        };
343    }
344
345    let actual_header_crc = read_u32(&input[32..36]);
346    let mut header = [0_u8; WAL_HEADER_LEN];
347    header.copy_from_slice(&input[..WAL_HEADER_LEN]);
348    header[32..36].copy_from_slice(&0_u32.to_le_bytes());
349    let expected_header_crc = crc32c(&header);
350    if actual_header_crc != expected_header_crc {
351        return DecodeRecordResult::Invalid {
352            warning: WalWarning::BadHeaderChecksum {
353                offset,
354                expected: expected_header_crc,
355                actual: actual_header_crc,
356            },
357        };
358    }
359
360    let actual_payload_crc = read_u32(&input[36..40]);
361    let total_len = WAL_HEADER_LEN + payload_len as usize;
362    if input.len() < total_len {
363        return DecodeRecordResult::Incomplete {
364            warning: WalWarning::PartialPayload {
365                offset,
366                expected: total_len,
367                actual: input.len(),
368            },
369        };
370    }
371
372    let payload_bytes = &input[WAL_HEADER_LEN..total_len];
373    let expected_payload_crc = crc32c(payload_bytes);
374    if actual_payload_crc != expected_payload_crc {
375        return DecodeRecordResult::Invalid {
376            warning: WalWarning::BadPayloadChecksum {
377                offset,
378                expected: expected_payload_crc,
379                actual: actual_payload_crc,
380            },
381        };
382    }
383
384    let payload = match decode_payload(record_type, payload_bytes) {
385        Ok(payload) => payload,
386        Err(error) => {
387            return DecodeRecordResult::Invalid {
388                warning: WalWarning::MalformedPayload {
389                    offset,
390                    message: error.to_string(),
391                },
392            };
393        }
394    };
395
396    DecodeRecordResult::Complete {
397        record: WalRecord {
398            flags,
399            lsn: Lsn::new(lsn),
400            sequence: SequenceNumber::new(sequence),
401            payload,
402        },
403        bytes_read: total_len,
404    }
405}
406
407fn encode_payload(payload: &WalPayload) -> Result<Vec<u8>> {
408    let mut encoded = Vec::new();
409    match payload {
410        WalPayload::Put { key, value } => {
411            let key_len = u32::try_from(key.len()).map_err(|_| {
412                KayaError::invalid_argument("WAL PUT key length does not fit into u32")
413            })?;
414            let value_len = u32::try_from(value.len()).map_err(|_| {
415                KayaError::invalid_argument("WAL PUT value length does not fit into u32")
416            })?;
417            encoded.extend_from_slice(&key_len.to_le_bytes());
418            encoded.extend_from_slice(&value_len.to_le_bytes());
419            encoded.extend_from_slice(key);
420            encoded.extend_from_slice(value);
421        }
422        WalPayload::Delete { key } => {
423            let key_len = u32::try_from(key.len()).map_err(|_| {
424                KayaError::invalid_argument("WAL DELETE key length does not fit into u32")
425            })?;
426            encoded.extend_from_slice(&key_len.to_le_bytes());
427            encoded.extend_from_slice(key);
428        }
429        WalPayload::Noop => {}
430    }
431    Ok(encoded)
432}
433
434fn decode_payload(record_type: WalRecordType, payload: &[u8]) -> Result<WalPayload> {
435    match record_type {
436        WalRecordType::Put => {
437            if payload.len() < 8 {
438                return Err(KayaError::corruption("PUT payload header is too short"));
439            }
440            let key_len = read_u32(&payload[0..4]) as usize;
441            let value_len = read_u32(&payload[4..8]) as usize;
442            let expected = 8_usize
443                .checked_add(key_len)
444                .and_then(|len| len.checked_add(value_len))
445                .ok_or_else(|| KayaError::corruption("PUT payload length overflows usize"))?;
446            if payload.len() != expected {
447                return Err(KayaError::corruption(format!(
448                    "PUT payload length mismatch: expected {expected}, got {}",
449                    payload.len()
450                )));
451            }
452            let key = payload[8..8 + key_len].to_vec();
453            let value = payload[8 + key_len..].to_vec();
454            Ok(WalPayload::Put { key, value })
455        }
456        WalRecordType::Delete => {
457            if payload.len() < 4 {
458                return Err(KayaError::corruption("DELETE payload header is too short"));
459            }
460            let key_len = read_u32(&payload[0..4]) as usize;
461            let expected = 4_usize
462                .checked_add(key_len)
463                .ok_or_else(|| KayaError::corruption("DELETE payload length overflows usize"))?;
464            if payload.len() != expected {
465                return Err(KayaError::corruption(format!(
466                    "DELETE payload length mismatch: expected {expected}, got {}",
467                    payload.len()
468                )));
469            }
470            Ok(WalPayload::Delete {
471                key: payload[4..].to_vec(),
472            })
473        }
474        WalRecordType::Noop => {
475            if !payload.is_empty() {
476                return Err(KayaError::corruption("NOOP payload must be empty"));
477            }
478            Ok(WalPayload::Noop)
479        }
480    }
481}
482
/// Reads a little-endian `u16` from `bytes`.
///
/// # Panics
///
/// Panics unless `bytes` is exactly 2 bytes long; callers pass fixed-width sub-slices.
fn read_u16(bytes: &[u8]) -> u16 {
    let raw: [u8; 2] = bytes.try_into().expect("slice length checked by caller");
    u16::from_le_bytes(raw)
}

/// Reads a little-endian `u32` from `bytes`.
///
/// # Panics
///
/// Panics unless `bytes` is exactly 4 bytes long; callers pass fixed-width sub-slices.
fn read_u32(bytes: &[u8]) -> u32 {
    let raw: [u8; 4] = bytes.try_into().expect("slice length checked by caller");
    u32::from_le_bytes(raw)
}

/// Reads a little-endian `u64` from `bytes`.
///
/// # Panics
///
/// Panics unless `bytes` is exactly 8 bytes long; callers pass fixed-width sub-slices.
fn read_u64(bytes: &[u8]) -> u64 {
    let raw: [u8; 8] = bytes.try_into().expect("slice length checked by caller");
    u64::from_le_bytes(raw)
}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// A small PUT record shared by several tests (payload is 8 + 6 + 3 = 17 bytes).
    fn put_record() -> WalRecord {
        WalRecord::new(
            Lsn::new(1),
            SequenceNumber::new(1),
            WalPayload::Put {
                key: b"user:1".to_vec(),
                value: b"Ada".to_vec(),
            },
        )
    }

    /// Encodes `record`, decodes it back and checks the result and the consumed length.
    fn assert_roundtrip(record: &WalRecord) {
        let encoded = encode_record(record).expect("record encodes");
        match decode_record(&encoded, 0, 1024) {
            DecodeRecordResult::Complete {
                record: decoded,
                bytes_read,
            } => {
                assert_eq!(&decoded, record);
                assert_eq!(bytes_read, encoded.len());
            }
            other => panic!("unexpected decode result: {other:?}"),
        }
    }

    #[test]
    fn put_roundtrip() {
        assert_roundtrip(&put_record());
    }

    #[test]
    fn delete_and_noop_roundtrip() {
        assert_roundtrip(&WalRecord::new(
            Lsn::new(2),
            SequenceNumber::new(7),
            WalPayload::Delete {
                key: b"user:1".to_vec(),
            },
        ));
        assert_roundtrip(&WalRecord::new(
            Lsn::new(3),
            SequenceNumber::new(8),
            WalPayload::Noop,
        ));
    }

    #[test]
    fn rejects_bad_magic_without_panic() {
        let record = WalRecord::new(Lsn::new(1), SequenceNumber::new(1), WalPayload::Noop);
        let mut encoded = encode_record(&record).expect("record encodes");
        encoded[0] = 0;
        match decode_record(&encoded, 0, 1024) {
            DecodeRecordResult::Invalid {
                warning: WalWarning::BadMagic { .. },
            } => {}
            other => panic!("unexpected decode result: {other:?}"),
        }
    }

    #[test]
    fn truncated_input_is_incomplete() {
        let encoded = encode_record(&put_record()).expect("record encodes");

        for cut in [0, 1, WAL_HEADER_LEN - 1] {
            match decode_record(&encoded[..cut], 128, 1024) {
                DecodeRecordResult::Incomplete {
                    warning: WalWarning::PartialHeader { offset: 128 },
                } => {}
                other => panic!("unexpected decode result for cut {cut}: {other:?}"),
            }
        }

        // Intact header, torn payload.
        match decode_record(&encoded[..encoded.len() - 1], 0, 1024) {
            DecodeRecordResult::Incomplete {
                warning:
                    WalWarning::PartialPayload {
                        expected, actual, ..
                    },
            } => {
                assert_eq!(expected, encoded.len());
                assert_eq!(actual, encoded.len() - 1);
            }
            other => panic!("unexpected decode result: {other:?}"),
        }
    }

    #[test]
    fn detects_corrupted_header_and_payload() {
        let encoded = encode_record(&put_record()).expect("record encodes");

        // Byte 12 lies in the LSN field, which only the header checksum protects.
        let mut bad_header = encoded.clone();
        bad_header[12] ^= 0xff;
        match decode_record(&bad_header, 0, 1024) {
            DecodeRecordResult::Invalid {
                warning: WalWarning::BadHeaderChecksum { .. },
            } => {}
            other => panic!("unexpected decode result: {other:?}"),
        }

        let mut bad_payload = encoded;
        let last = bad_payload.len() - 1;
        bad_payload[last] ^= 0xff;
        match decode_record(&bad_payload, 0, 1024) {
            DecodeRecordResult::Invalid {
                warning: WalWarning::BadPayloadChecksum { .. },
            } => {}
            other => panic!("unexpected decode result: {other:?}"),
        }
    }

    #[test]
    fn rejects_payload_larger_than_limit() {
        let encoded = encode_record(&put_record()).expect("record encodes");
        let payload_len =
            u32::try_from(encoded.len() - WAL_HEADER_LEN).expect("test payload is small");
        match decode_record(&encoded, 0, payload_len - 1) {
            DecodeRecordResult::Invalid {
                warning: WalWarning::OversizedPayload { found, max, .. },
            } => {
                assert_eq!(found, payload_len);
                assert_eq!(max, payload_len - 1);
            }
            other => panic!("unexpected decode result: {other:?}"),
        }
    }
}