Skip to main content

raft_io/
wal_log.rs

//! A durable [`RaftLog`] backed by `wal-db`.
//!
//! [`WalLog`] gives a node a log that survives a restart. Raft's safety rests on
//! `current_term`, `voted_for`, and the log entries being durable *before* the
//! node acts on them — a node that forgot it had voted, or lost an
//! acknowledged entry, could break consensus. This type provides that durability
//! while presenting the same [`RaftLog`] interface as the in-memory store, so
//! the protocol core is unchanged.
//!
//! # Design
//!
//! The store is log-structured. Every mutation — an appended entry, a hard-state
//! update, a truncation — is encoded as a record and appended to a `wal-db`
//! write-ahead log, which frames and checksums each record. An in-memory index
//! (a [`MemoryLog`]) mirrors the current state for fast reads. On
//! [`open`](WalLog::open) the records are replayed in order to rebuild that index
//! exactly. Installing a snapshot appends a snapshot record, re-persists the
//! current hard state after it, and then asks the WAL to drop every earlier
//! record (best-effort), so the file stays bounded as the log is compacted.
//!
//! [`RaftLog`]: crate::RaftLog
22
23use wal_db::Wal;
24
25use crate::error::{Error, Result};
26use crate::log::{MemoryLog, RaftLog};
27use crate::types::{EntryKind, HardState, Index, LogEntry, NodeId, Snapshot, Term};
28
// Record tags: the first byte of every WAL record, selecting how the rest of the
// record is laid out. These values are part of the on-disk format.

/// Tag of a record holding one appended [`LogEntry`].
const TAG_ENTRY: u8 = 0x01;
/// Tag of a record holding a [`HardState`] (term and vote) update.
const TAG_HARD_STATE: u8 = 0x02;
/// Tag of a record holding a truncation point (an entry index).
const TAG_TRUNCATE: u8 = 0x03;
/// Tag of a record holding an installed [`Snapshot`].
const TAG_SNAPSHOT: u8 = 0x04;
37
38/// A durable [`RaftLog`] whose entries and hard state survive a process restart.
39///
40/// Open it with [`open`](WalLog::open) and hand it to
41/// [`RaftNode::with_log`](crate::RaftNode::with_log). Reads are served from an
42/// in-memory index; writes are appended to the underlying `wal-db` log and
43/// become durable when [`sync`](RaftLog::sync) returns `Ok`.
44///
45/// # Examples
46///
47/// ```no_run
48/// use raft_io::{LogEntry, RaftLog, WalLog};
49///
50/// let mut log = WalLog::open("raft.wal")?;
51/// log.append(&[LogEntry::new(1, 1, b"set x = 1".to_vec())])?;
52/// log.sync()?; // durable from here
53///
54/// // After a restart, reopening the same path recovers the entry.
55/// let recovered = WalLog::open("raft.wal")?;
56/// assert_eq!(recovered.last_index(), 1);
57/// # Ok::<(), raft_io::Error>(())
58/// ```
59#[cfg_attr(docsrs, doc(cfg(feature = "persistence")))]
60pub struct WalLog {
61    wal: Wal,
62    index: MemoryLog,
63}
64
65impl WalLog {
66    /// Opens the durable log at `path`, replaying any existing records to recover
67    /// the log entries and hard state.
68    ///
69    /// Creates the file if it does not exist. Recovery is exact: the recovered
70    /// state is the logical result of every mutation that was appended before the
71    /// process stopped.
72    ///
73    /// # Errors
74    ///
75    /// Returns [`Error::Storage`] if the WAL cannot be opened or a record fails
76    /// to decode (for example, a checksum mismatch reported by `wal-db`).
77    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self> {
78        let wal = Wal::open(path).map_err(|e| Error::storage("open durable log", e))?;
79        let mut index = MemoryLog::new();
80        let iter = wal
81            .iter()
82            .map_err(|e| Error::storage("read durable log", e))?;
83        for record in iter {
84            let record = record.map_err(|e| Error::storage("read durable log record", e))?;
85            match decode(record.data())? {
86                Decoded::Entry(entry) => index.append(&[entry])?,
87                Decoded::HardState(hs) => index.set_hard_state(hs)?,
88                Decoded::Truncate(from) => index.truncate(from)?,
89                Decoded::Snapshot(snapshot) => index.apply_snapshot(&snapshot)?,
90            }
91        }
92        Ok(Self { wal, index })
93    }
94
95    /// Appends `record` to the WAL, mapping any backend failure to a storage
96    /// error tagged with `context`.
97    fn write(&self, context: &'static str, record: &[u8]) -> Result<()> {
98        self.wal
99            .append(record)
100            .map(|_lsn| ())
101            .map_err(|e| Error::storage(context, e))
102    }
103}
104
105impl RaftLog for WalLog {
106    #[inline]
107    fn last_index(&self) -> Index {
108        self.index.last_index()
109    }
110
111    #[inline]
112    fn last_term(&self) -> Term {
113        self.index.last_term()
114    }
115
116    #[inline]
117    fn term_at(&self, index: Index) -> Option<Term> {
118        self.index.term_at(index)
119    }
120
121    #[inline]
122    fn entry(&self, index: Index) -> Option<LogEntry> {
123        self.index.entry(index)
124    }
125
126    #[inline]
127    fn entries(&self, from: Index, to: Index) -> Vec<LogEntry> {
128        self.index.entries(from, to)
129    }
130
131    fn append(&mut self, entries: &[LogEntry]) -> Result<()> {
132        // Validate contiguity against the in-memory index first, so a bad batch
133        // is rejected before any record reaches the durable log.
134        self.index.append(entries)?;
135        for entry in entries {
136            self.write("append entry to durable log", &encode_entry(entry))?;
137        }
138        Ok(())
139    }
140
141    fn truncate(&mut self, from: Index) -> Result<()> {
142        self.index.truncate(from)?;
143        self.write("truncate durable log", &encode_truncate(from))
144    }
145
146    #[inline]
147    fn hard_state(&self) -> HardState {
148        self.index.hard_state()
149    }
150
151    fn set_hard_state(&mut self, state: HardState) -> Result<()> {
152        self.index.set_hard_state(state)?;
153        self.write("persist hard state", &encode_hard_state(&state))
154    }
155
156    fn sync(&mut self) -> Result<()> {
157        self.wal
158            .sync()
159            .map_err(|e| Error::storage("sync durable log", e))
160    }
161
162    #[inline]
163    fn snapshot_index(&self) -> Index {
164        self.index.snapshot_index()
165    }
166
167    fn snapshot(&self) -> Option<Snapshot> {
168        self.index.snapshot()
169    }
170
171    fn apply_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> {
172        if snapshot.index <= self.index.snapshot_index() {
173            return Ok(()); // stale; nothing to persist
174        }
175        // Compact the in-memory index first.
176        self.index.apply_snapshot(snapshot)?;
177        // Persist the snapshot record, then re-write the current hard state so
178        // the latest term/vote sits *after* the snapshot in the log.
179        let lsn = self
180            .wal
181            .append(&encode_snapshot(snapshot))
182            .map_err(|e| Error::storage("persist snapshot", e))?;
183        self.write(
184            "persist hard state",
185            &encode_hard_state(&self.index.hard_state()),
186        )?;
187        // Physically drop every record before the snapshot. This is an
188        // optimisation: if it fails, the WAL is merely larger — replay still
189        // re-applies the snapshot record and reconstructs the same state — so the
190        // outcome is deliberately ignored rather than turned into a fatal error.
191        let _ = self.wal.truncate_before(lsn);
192        Ok(())
193    }
194}
195
// ---- record codec --------------------------------------------------------
197
198/// A decoded WAL record.
199enum Decoded {
200    Entry(LogEntry),
201    HardState(HardState),
202    Truncate(Index),
203    Snapshot(Snapshot),
204}
205
206fn encode_snapshot(snapshot: &Snapshot) -> Vec<u8> {
207    let mut buf =
208        Vec::with_capacity(1 + 8 + 8 + 8 + snapshot.config.len() * 8 + 8 + snapshot.data.len());
209    buf.push(TAG_SNAPSHOT);
210    buf.extend_from_slice(&snapshot.index.to_le_bytes());
211    buf.extend_from_slice(&snapshot.term.to_le_bytes());
212    buf.extend_from_slice(&(snapshot.config.len() as u64).to_le_bytes());
213    for &id in &snapshot.config {
214        buf.extend_from_slice(&id.to_le_bytes());
215    }
216    buf.extend_from_slice(&(snapshot.data.len() as u64).to_le_bytes());
217    buf.extend_from_slice(&snapshot.data);
218    buf
219}
220
221/// On-disk byte for an [`EntryKind`].
222fn kind_byte(kind: EntryKind) -> u8 {
223    match kind {
224        EntryKind::Normal => 0,
225        EntryKind::Config => 1,
226    }
227}
228
229/// Reads an [`EntryKind`] from its on-disk byte.
230fn kind_from_byte(byte: u8) -> Result<EntryKind> {
231    match byte {
232        0 => Ok(EntryKind::Normal),
233        1 => Ok(EntryKind::Config),
234        other => Err(Error::storage(
235            "decode durable log record",
236            format!("unknown entry kind {other}"),
237        )),
238    }
239}
240
241fn encode_entry(entry: &LogEntry) -> Vec<u8> {
242    let mut buf = Vec::with_capacity(1 + 8 + 8 + 1 + 8 + entry.command.len());
243    buf.push(TAG_ENTRY);
244    buf.extend_from_slice(&entry.term.to_le_bytes());
245    buf.extend_from_slice(&entry.index.to_le_bytes());
246    buf.push(kind_byte(entry.kind));
247    buf.extend_from_slice(&(entry.command.len() as u64).to_le_bytes());
248    buf.extend_from_slice(&entry.command);
249    buf
250}
251
252fn encode_hard_state(state: &HardState) -> Vec<u8> {
253    let mut buf = Vec::with_capacity(1 + 8 + 1 + 8);
254    buf.push(TAG_HARD_STATE);
255    buf.extend_from_slice(&state.term.to_le_bytes());
256    match state.voted_for {
257        Some(id) => {
258            buf.push(1);
259            buf.extend_from_slice(&id.to_le_bytes());
260        }
261        None => {
262            buf.push(0);
263            buf.extend_from_slice(&0u64.to_le_bytes());
264        }
265    }
266    buf
267}
268
269fn encode_truncate(from: Index) -> Vec<u8> {
270    let mut buf = Vec::with_capacity(1 + 8);
271    buf.push(TAG_TRUNCATE);
272    buf.extend_from_slice(&from.to_le_bytes());
273    buf
274}
275
276/// Reads a little-endian `u64` at `offset`, bounds-checked.
277fn read_u64(data: &[u8], offset: usize) -> Result<u64> {
278    let end = offset
279        .checked_add(8)
280        .filter(|&e| e <= data.len())
281        .ok_or_else(|| Error::storage("decode durable log record", "record truncated"))?;
282    let mut bytes = [0u8; 8];
283    bytes.copy_from_slice(&data[offset..end]);
284    Ok(u64::from_le_bytes(bytes))
285}
286
287fn decode(data: &[u8]) -> Result<Decoded> {
288    let (&tag, rest_at) = match data.split_first() {
289        Some((tag, _)) => (tag, 1usize),
290        None => return Err(Error::storage("decode durable log record", "empty record")),
291    };
292    match tag {
293        TAG_ENTRY => {
294            let term = read_u64(data, rest_at)?;
295            let index = read_u64(data, rest_at + 8)?;
296            let kind =
297                kind_from_byte(*data.get(rest_at + 16).ok_or_else(|| {
298                    Error::storage("decode durable log record", "entry truncated")
299                })?)?;
300            let len = read_u64(data, rest_at + 17)? as usize;
301            let start = rest_at + 25;
302            let end = start
303                .checked_add(len)
304                .filter(|&e| e == data.len())
305                .ok_or_else(|| {
306                    Error::storage("decode durable log record", "entry length mismatch")
307                })?;
308            Ok(Decoded::Entry(LogEntry {
309                term,
310                index,
311                kind,
312                command: data[start..end].to_vec(),
313            }))
314        }
315        TAG_HARD_STATE => {
316            let term = read_u64(data, rest_at)?;
317            let flag = *data.get(rest_at + 8).ok_or_else(|| {
318                Error::storage("decode durable log record", "hard-state truncated")
319            })?;
320            let vote = read_u64(data, rest_at + 9)?;
321            let voted_for = if flag == 1 { Some(vote) } else { None };
322            Ok(Decoded::HardState(HardState { term, voted_for }))
323        }
324        TAG_TRUNCATE => {
325            let from = read_u64(data, rest_at)?;
326            Ok(Decoded::Truncate(from))
327        }
328        TAG_SNAPSHOT => {
329            let index = read_u64(data, rest_at)?;
330            let term = read_u64(data, rest_at + 8)?;
331            let config_count = read_u64(data, rest_at + 16)?;
332            // Bound the count to the bytes actually present before allocating, so
333            // a corrupt or hostile length cannot trigger a giant allocation. Each
334            // member is 8 bytes, and at least the trailing data length must follow.
335            let max_members = (data.len().saturating_sub(rest_at + 24) / 8) as u64;
336            if config_count > max_members {
337                return Err(Error::storage(
338                    "decode durable log record",
339                    "snapshot configuration length exceeds record",
340                ));
341            }
342            let config_count = config_count as usize;
343            let mut config = Vec::with_capacity(config_count);
344            let mut off = rest_at + 24;
345            for _ in 0..config_count {
346                config.push(read_u64(data, off)? as NodeId);
347                off += 8;
348            }
349            let len = read_u64(data, off)? as usize;
350            let start = off + 8;
351            let end = start
352                .checked_add(len)
353                .filter(|&e| e == data.len())
354                .ok_or_else(|| {
355                    Error::storage("decode durable log record", "snapshot length mismatch")
356                })?;
357            Ok(Decoded::Snapshot(Snapshot::with_config(
358                index,
359                term,
360                config,
361                data[start..end].to_vec(),
362            )))
363        }
364        other => Err(Error::storage(
365            "decode durable log record",
366            format!("unknown record tag {other}"),
367        )),
368    }
369}
370
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Builds a normal entry carrying `cmd`.
    fn entry(term: Term, index: Index, cmd: &[u8]) -> LogEntry {
        LogEntry::new(term, index, cmd.to_vec())
    }

    /// Builds a hard state from its two fields.
    fn hard(term: Term, voted_for: Option<NodeId>) -> HardState {
        HardState { term, voted_for }
    }

    /// Returns a WAL path inside a fresh temporary directory. The directory is
    /// removed when the returned guard drops, so keep it alive while the path
    /// is in use.
    fn temp_path() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("raft.wal");
        (dir, path)
    }

    #[test]
    fn test_entry_codec_round_trips() {
        let original = entry(3, 9, b"hello world");
        if let Decoded::Entry(got) = decode(&encode_entry(&original)).unwrap() {
            assert_eq!(got, original);
        } else {
            panic!("wrong record");
        }
    }

    #[test]
    fn test_snapshot_record_hostile_config_length_is_rejected() {
        // A snapshot record claiming a huge member count must be rejected, not
        // turned into a giant `Vec::with_capacity` that aborts the process.
        let mut hostile = vec![TAG_SNAPSHOT];
        // index, term, then a hostile config_count
        for word in [5u64, 2, u64::MAX] {
            hostile.extend_from_slice(&word.to_le_bytes());
        }
        assert!(decode(&hostile).is_err());
    }

    proptest::proptest! {
        /// Fuzz the WAL record decoder: arbitrary bytes must yield `Ok` or `Err`,
        /// never a panic or an unbounded allocation, so a corrupt or hostile
        /// record cannot crash recovery.
        #[test]
        fn wal_decode_never_panics(
            bytes in proptest::collection::vec(proptest::prelude::any::<u8>(), 0..512)
        ) {
            let _ = decode(&bytes);
        }
    }

    #[test]
    fn test_hard_state_codec_round_trips() {
        for expected in [hard(7, Some(4)), hard(0, None)] {
            if let Decoded::HardState(got) = decode(&encode_hard_state(&expected)).unwrap() {
                assert_eq!(got, expected);
            } else {
                panic!("wrong record");
            }
        }
    }

    #[test]
    fn test_truncate_codec_round_trips() {
        if let Decoded::Truncate(from) = decode(&encode_truncate(5)).unwrap() {
            assert_eq!(from, 5);
        } else {
            panic!("wrong record");
        }
    }

    #[test]
    fn test_decode_rejects_malformed() {
        // Entry claiming a longer command than is present: dropping the command
        // byte leaves the length prefix mismatched.
        let mut short_command = encode_entry(&entry(1, 1, b"x"));
        let _ = short_command.pop();
        let bad_records: [&[u8]; 5] = [
            &[],                   // empty
            &[TAG_ENTRY, 1, 2, 3], // truncated entry
            &[TAG_TRUNCATE, 0, 0], // short index
            &[99],                 // unknown tag
            short_command.as_slice(),
        ];
        for record in bad_records {
            assert!(decode(record).is_err());
        }
    }

    #[test]
    fn test_append_sync_recover() {
        let (_dir, path) = temp_path();
        let voted = hard(1, Some(2));
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b")]).unwrap();
            log.set_hard_state(voted).unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.last_index(), 2);
        assert_eq!(recovered.last_term(), 1);
        assert_eq!(recovered.entry(2).unwrap().command, b"b");
        assert_eq!(recovered.hard_state(), voted);
    }

    #[test]
    fn test_truncation_survives_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b"), entry(1, 3, b"c")])
                .unwrap();
            log.truncate(2).unwrap(); // drop indices >= 2
            log.append(&[entry(2, 2, b"B")]).unwrap(); // re-write index 2 in a new term
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        let rewritten = recovered.entry(2).unwrap();
        assert_eq!(recovered.last_index(), 2);
        assert_eq!(rewritten.term, 2);
        assert_eq!(rewritten.command, b"B");
        assert_eq!(recovered.entry(3), None);
    }

    #[test]
    fn test_latest_hard_state_wins_on_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            for state in [hard(1, Some(1)), hard(2, None), hard(3, Some(2))] {
                log.set_hard_state(state).unwrap();
            }
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.hard_state(), hard(3, Some(2)));
    }

    #[test]
    fn test_snapshot_compaction_survives_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b"), entry(2, 3, b"c")])
                .unwrap();
            log.apply_snapshot(&Snapshot::new(2, 1, b"state@2".to_vec()))
                .unwrap();
            log.append(&[entry(2, 4, b"d")]).unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        // The snapshot boundary, the surviving tail, and the snapshot bytes all
        // came back; compacted entries did not.
        assert_eq!(recovered.snapshot_index(), 2);
        assert_eq!(recovered.last_index(), 4);
        for compacted in [1, 2] {
            assert_eq!(recovered.entry(compacted), None);
        }
        assert_eq!(recovered.term_at(2), Some(1));
        assert_eq!(recovered.entry(3).unwrap().command, b"c");
        assert_eq!(recovered.entry(4).unwrap().command, b"d");
        assert_eq!(recovered.snapshot().unwrap().data, b"state@2");
    }

    #[test]
    fn test_snapshot_codec_round_trips() {
        let original = Snapshot::with_config(9, 4, vec![1, 2, 3], b"payload".to_vec());
        if let Decoded::Snapshot(got) = decode(&encode_snapshot(&original)).unwrap() {
            assert_eq!(got, original);
        } else {
            panic!("wrong record");
        }
    }

    #[test]
    fn test_config_entry_and_snapshot_membership_survive_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.apply_snapshot(&Snapshot::with_config(2, 1, vec![1, 2, 3], b"s".to_vec()))
                .unwrap();
            log.append(&[LogEntry::config(2, 3, &[1, 2, 3, 4])])
                .unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.snapshot().unwrap().config, vec![1, 2, 3]);
        assert_eq!(
            recovered.entry(3).unwrap().members(),
            Some(vec![1, 2, 3, 4])
        );
    }

    #[test]
    fn test_empty_log_opens_clean() {
        let (_dir, path) = temp_path();
        let fresh = WalLog::open(&path).unwrap();
        assert_eq!(fresh.last_index(), 0);
        assert_eq!(fresh.hard_state(), HardState::default());
    }

    #[test]
    fn test_non_contiguous_append_is_rejected_before_write() {
        let (_dir, path) = temp_path();
        let mut log = WalLog::open(&path).unwrap();
        assert!(log.append(&[entry(1, 5, b"x")]).is_err());
        // The rejected batch left nothing behind, in memory or on disk.
        assert_eq!(log.last_index(), 0);
        drop(log);
        let reopened = WalLog::open(&path).unwrap();
        assert_eq!(reopened.last_index(), 0);
    }
}