1use wal_db::Wal;
24
25use crate::error::{Error, Result};
26use crate::log::{MemoryLog, RaftLog};
27use crate::types::{EntryKind, HardState, Index, LogEntry, NodeId, Snapshot, Term};
28
/// Record tag for a single appended log entry (written by `encode_entry`).
const TAG_ENTRY: u8 = 0x01;
/// Record tag for the persisted term and vote (written by `encode_hard_state`).
const TAG_HARD_STATE: u8 = 0x02;
/// Record tag for a truncation that drops entries from an index onward
/// (written by `encode_truncate`).
const TAG_TRUNCATE: u8 = 0x03;
/// Record tag for an installed snapshot (written by `encode_snapshot`).
const TAG_SNAPSHOT: u8 = 0x04;
37
38#[cfg_attr(docsrs, doc(cfg(feature = "persistence")))]
60pub struct WalLog {
61 wal: Wal,
62 index: MemoryLog,
63}
64
65impl WalLog {
66 pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self> {
78 let wal = Wal::open(path).map_err(|e| Error::storage("open durable log", e))?;
79 let mut index = MemoryLog::new();
80 let iter = wal
81 .iter()
82 .map_err(|e| Error::storage("read durable log", e))?;
83 for record in iter {
84 let record = record.map_err(|e| Error::storage("read durable log record", e))?;
85 match decode(record.data())? {
86 Decoded::Entry(entry) => index.append(&[entry])?,
87 Decoded::HardState(hs) => index.set_hard_state(hs)?,
88 Decoded::Truncate(from) => index.truncate(from)?,
89 Decoded::Snapshot(snapshot) => index.apply_snapshot(&snapshot)?,
90 }
91 }
92 Ok(Self { wal, index })
93 }
94
95 fn write(&self, context: &'static str, record: &[u8]) -> Result<()> {
98 self.wal
99 .append(record)
100 .map(|_lsn| ())
101 .map_err(|e| Error::storage(context, e))
102 }
103}
104
105impl RaftLog for WalLog {
106 #[inline]
107 fn last_index(&self) -> Index {
108 self.index.last_index()
109 }
110
111 #[inline]
112 fn last_term(&self) -> Term {
113 self.index.last_term()
114 }
115
116 #[inline]
117 fn term_at(&self, index: Index) -> Option<Term> {
118 self.index.term_at(index)
119 }
120
121 #[inline]
122 fn entry(&self, index: Index) -> Option<LogEntry> {
123 self.index.entry(index)
124 }
125
126 #[inline]
127 fn entries(&self, from: Index, to: Index) -> Vec<LogEntry> {
128 self.index.entries(from, to)
129 }
130
131 fn append(&mut self, entries: &[LogEntry]) -> Result<()> {
132 self.index.append(entries)?;
135 for entry in entries {
136 self.write("append entry to durable log", &encode_entry(entry))?;
137 }
138 Ok(())
139 }
140
141 fn truncate(&mut self, from: Index) -> Result<()> {
142 self.index.truncate(from)?;
143 self.write("truncate durable log", &encode_truncate(from))
144 }
145
146 #[inline]
147 fn hard_state(&self) -> HardState {
148 self.index.hard_state()
149 }
150
151 fn set_hard_state(&mut self, state: HardState) -> Result<()> {
152 self.index.set_hard_state(state)?;
153 self.write("persist hard state", &encode_hard_state(&state))
154 }
155
156 fn sync(&mut self) -> Result<()> {
157 self.wal
158 .sync()
159 .map_err(|e| Error::storage("sync durable log", e))
160 }
161
162 #[inline]
163 fn snapshot_index(&self) -> Index {
164 self.index.snapshot_index()
165 }
166
167 fn snapshot(&self) -> Option<Snapshot> {
168 self.index.snapshot()
169 }
170
171 fn apply_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> {
172 if snapshot.index <= self.index.snapshot_index() {
173 return Ok(()); }
175 self.index.apply_snapshot(snapshot)?;
177 let lsn = self
180 .wal
181 .append(&encode_snapshot(snapshot))
182 .map_err(|e| Error::storage("persist snapshot", e))?;
183 self.write(
184 "persist hard state",
185 &encode_hard_state(&self.index.hard_state()),
186 )?;
187 let _ = self.wal.truncate_before(lsn);
192 Ok(())
193 }
194}
195
196enum Decoded {
200 Entry(LogEntry),
201 HardState(HardState),
202 Truncate(Index),
203 Snapshot(Snapshot),
204}
205
206fn encode_snapshot(snapshot: &Snapshot) -> Vec<u8> {
207 let mut buf =
208 Vec::with_capacity(1 + 8 + 8 + 8 + snapshot.config.len() * 8 + 8 + snapshot.data.len());
209 buf.push(TAG_SNAPSHOT);
210 buf.extend_from_slice(&snapshot.index.to_le_bytes());
211 buf.extend_from_slice(&snapshot.term.to_le_bytes());
212 buf.extend_from_slice(&(snapshot.config.len() as u64).to_le_bytes());
213 for &id in &snapshot.config {
214 buf.extend_from_slice(&id.to_le_bytes());
215 }
216 buf.extend_from_slice(&(snapshot.data.len() as u64).to_le_bytes());
217 buf.extend_from_slice(&snapshot.data);
218 buf
219}
220
221fn kind_byte(kind: EntryKind) -> u8 {
223 match kind {
224 EntryKind::Normal => 0,
225 EntryKind::Config => 1,
226 }
227}
228
229fn kind_from_byte(byte: u8) -> Result<EntryKind> {
231 match byte {
232 0 => Ok(EntryKind::Normal),
233 1 => Ok(EntryKind::Config),
234 other => Err(Error::storage(
235 "decode durable log record",
236 format!("unknown entry kind {other}"),
237 )),
238 }
239}
240
241fn encode_entry(entry: &LogEntry) -> Vec<u8> {
242 let mut buf = Vec::with_capacity(1 + 8 + 8 + 1 + 8 + entry.command.len());
243 buf.push(TAG_ENTRY);
244 buf.extend_from_slice(&entry.term.to_le_bytes());
245 buf.extend_from_slice(&entry.index.to_le_bytes());
246 buf.push(kind_byte(entry.kind));
247 buf.extend_from_slice(&(entry.command.len() as u64).to_le_bytes());
248 buf.extend_from_slice(&entry.command);
249 buf
250}
251
252fn encode_hard_state(state: &HardState) -> Vec<u8> {
253 let mut buf = Vec::with_capacity(1 + 8 + 1 + 8);
254 buf.push(TAG_HARD_STATE);
255 buf.extend_from_slice(&state.term.to_le_bytes());
256 match state.voted_for {
257 Some(id) => {
258 buf.push(1);
259 buf.extend_from_slice(&id.to_le_bytes());
260 }
261 None => {
262 buf.push(0);
263 buf.extend_from_slice(&0u64.to_le_bytes());
264 }
265 }
266 buf
267}
268
269fn encode_truncate(from: Index) -> Vec<u8> {
270 let mut buf = Vec::with_capacity(1 + 8);
271 buf.push(TAG_TRUNCATE);
272 buf.extend_from_slice(&from.to_le_bytes());
273 buf
274}
275
276fn read_u64(data: &[u8], offset: usize) -> Result<u64> {
278 let end = offset
279 .checked_add(8)
280 .filter(|&e| e <= data.len())
281 .ok_or_else(|| Error::storage("decode durable log record", "record truncated"))?;
282 let mut bytes = [0u8; 8];
283 bytes.copy_from_slice(&data[offset..end]);
284 Ok(u64::from_le_bytes(bytes))
285}
286
287fn decode(data: &[u8]) -> Result<Decoded> {
288 let (&tag, rest_at) = match data.split_first() {
289 Some((tag, _)) => (tag, 1usize),
290 None => return Err(Error::storage("decode durable log record", "empty record")),
291 };
292 match tag {
293 TAG_ENTRY => {
294 let term = read_u64(data, rest_at)?;
295 let index = read_u64(data, rest_at + 8)?;
296 let kind =
297 kind_from_byte(*data.get(rest_at + 16).ok_or_else(|| {
298 Error::storage("decode durable log record", "entry truncated")
299 })?)?;
300 let len = read_u64(data, rest_at + 17)? as usize;
301 let start = rest_at + 25;
302 let end = start
303 .checked_add(len)
304 .filter(|&e| e == data.len())
305 .ok_or_else(|| {
306 Error::storage("decode durable log record", "entry length mismatch")
307 })?;
308 Ok(Decoded::Entry(LogEntry {
309 term,
310 index,
311 kind,
312 command: data[start..end].to_vec(),
313 }))
314 }
315 TAG_HARD_STATE => {
316 let term = read_u64(data, rest_at)?;
317 let flag = *data.get(rest_at + 8).ok_or_else(|| {
318 Error::storage("decode durable log record", "hard-state truncated")
319 })?;
320 let vote = read_u64(data, rest_at + 9)?;
321 let voted_for = if flag == 1 { Some(vote) } else { None };
322 Ok(Decoded::HardState(HardState { term, voted_for }))
323 }
324 TAG_TRUNCATE => {
325 let from = read_u64(data, rest_at)?;
326 Ok(Decoded::Truncate(from))
327 }
328 TAG_SNAPSHOT => {
329 let index = read_u64(data, rest_at)?;
330 let term = read_u64(data, rest_at + 8)?;
331 let config_count = read_u64(data, rest_at + 16)?;
332 let max_members = (data.len().saturating_sub(rest_at + 24) / 8) as u64;
336 if config_count > max_members {
337 return Err(Error::storage(
338 "decode durable log record",
339 "snapshot configuration length exceeds record",
340 ));
341 }
342 let config_count = config_count as usize;
343 let mut config = Vec::with_capacity(config_count);
344 let mut off = rest_at + 24;
345 for _ in 0..config_count {
346 config.push(read_u64(data, off)? as NodeId);
347 off += 8;
348 }
349 let len = read_u64(data, off)? as usize;
350 let start = off + 8;
351 let end = start
352 .checked_add(len)
353 .filter(|&e| e == data.len())
354 .ok_or_else(|| {
355 Error::storage("decode durable log record", "snapshot length mismatch")
356 })?;
357 Ok(Decoded::Snapshot(Snapshot::with_config(
358 index,
359 term,
360 config,
361 data[start..end].to_vec(),
362 )))
363 }
364 other => Err(Error::storage(
365 "decode durable log record",
366 format!("unknown record tag {other}"),
367 )),
368 }
369}
370
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Builds a normal entry with the given term, index and command bytes.
    fn entry(term: Term, index: Index, cmd: &[u8]) -> LogEntry {
        LogEntry::new(term, index, cmd.to_vec())
    }

    /// Returns a fresh temp dir (kept alive by the caller) and a WAL path in it.
    fn temp_path() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("raft.wal");
        (dir, path)
    }

    #[test]
    fn test_entry_codec_round_trips() {
        let e = entry(3, 9, b"hello world");
        match decode(&encode_entry(&e)).unwrap() {
            Decoded::Entry(got) => assert_eq!(got, e),
            _ => panic!("wrong record"),
        }
    }

    #[test]
    fn test_snapshot_record_hostile_config_length_is_rejected() {
        let mut bad = vec![TAG_SNAPSHOT];
        bad.extend_from_slice(&5u64.to_le_bytes()); // index
        bad.extend_from_slice(&2u64.to_le_bytes()); // term
        bad.extend_from_slice(&u64::MAX.to_le_bytes()); // claimed member count
        assert!(decode(&bad).is_err());
    }

    proptest::proptest! {
        // Arbitrary bytes must decode to `Ok` or `Err`, never panic.
        #[test]
        fn wal_decode_never_panics(
            bytes in proptest::collection::vec(proptest::prelude::any::<u8>(), 0..512)
        ) {
            let _ = decode(&bytes);
        }
    }

    #[test]
    fn test_hard_state_codec_round_trips() {
        for hs in [
            HardState {
                term: 7,
                voted_for: Some(4),
            },
            HardState {
                term: 0,
                voted_for: None,
            },
        ] {
            match decode(&encode_hard_state(&hs)).unwrap() {
                Decoded::HardState(got) => assert_eq!(got, hs),
                _ => panic!("wrong record"),
            }
        }
    }

    #[test]
    fn test_truncate_codec_round_trips() {
        match decode(&encode_truncate(5)).unwrap() {
            Decoded::Truncate(from) => assert_eq!(from, 5),
            _ => panic!("wrong record"),
        }
    }

    #[test]
    fn test_decode_rejects_malformed() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[TAG_ENTRY, 1, 2, 3]).is_err());
        assert!(decode(&[TAG_TRUNCATE, 0, 0]).is_err());
        assert!(decode(&[99]).is_err());
        // An entry missing its last command byte no longer matches its length.
        let mut bad = encode_entry(&entry(1, 1, b"x"));
        let _ = bad.pop();
        assert!(decode(&bad).is_err());
    }

    #[test]
    fn test_append_sync_recover() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b")]).unwrap();
            log.set_hard_state(HardState {
                term: 1,
                voted_for: Some(2),
            })
            .unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.last_index(), 2);
        assert_eq!(recovered.last_term(), 1);
        assert_eq!(recovered.entry(2).unwrap().command, b"b");
        assert_eq!(
            recovered.hard_state(),
            HardState {
                term: 1,
                voted_for: Some(2)
            }
        );
    }

    #[test]
    fn test_truncation_survives_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b"), entry(1, 3, b"c")])
                .unwrap();
            log.truncate(2).unwrap();
            log.append(&[entry(2, 2, b"B")]).unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.last_index(), 2);
        assert_eq!(recovered.entry(2).unwrap().term, 2);
        assert_eq!(recovered.entry(2).unwrap().command, b"B");
        assert_eq!(recovered.entry(3), None);
    }

    #[test]
    fn test_latest_hard_state_wins_on_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.set_hard_state(HardState {
                term: 1,
                voted_for: Some(1),
            })
            .unwrap();
            log.set_hard_state(HardState {
                term: 2,
                voted_for: None,
            })
            .unwrap();
            log.set_hard_state(HardState {
                term: 3,
                voted_for: Some(2),
            })
            .unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(
            recovered.hard_state(),
            HardState {
                term: 3,
                voted_for: Some(2)
            }
        );
    }

    #[test]
    fn test_snapshot_compaction_survives_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b"), entry(2, 3, b"c")])
                .unwrap();
            log.apply_snapshot(&Snapshot::new(2, 1, b"state@2".to_vec()))
                .unwrap();
            log.append(&[entry(2, 4, b"d")]).unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.snapshot_index(), 2);
        assert_eq!(recovered.last_index(), 4);
        // Entries covered by the snapshot are gone; its term is still known.
        assert_eq!(recovered.entry(1), None);
        assert_eq!(recovered.entry(2), None);
        assert_eq!(recovered.term_at(2), Some(1));
        assert_eq!(recovered.entry(3).unwrap().command, b"c");
        assert_eq!(recovered.entry(4).unwrap().command, b"d");
        assert_eq!(recovered.snapshot().unwrap().data, b"state@2");
    }

    #[test]
    fn test_stale_snapshot_is_ignored() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b"), entry(1, 3, b"c")])
                .unwrap();
            log.apply_snapshot(&Snapshot::new(2, 1, b"state@2".to_vec()))
                .unwrap();
            // Older than the installed snapshot: accepted, but a no-op.
            log.apply_snapshot(&Snapshot::new(1, 1, b"state@1".to_vec()))
                .unwrap();
            assert_eq!(log.snapshot_index(), 2);
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.snapshot_index(), 2);
        assert_eq!(recovered.snapshot().unwrap().data, b"state@2");
    }

    #[test]
    fn test_hard_state_survives_snapshot_compaction() {
        let (_dir, path) = temp_path();
        let expected = {
            let mut log = WalLog::open(&path).unwrap();
            log.set_hard_state(HardState {
                term: 1,
                voted_for: Some(2),
            })
            .unwrap();
            log.append(&[entry(1, 1, b"a"), entry(1, 2, b"b")]).unwrap();
            log.apply_snapshot(&Snapshot::new(2, 1, b"s".to_vec()))
                .unwrap();
            log.sync().unwrap();
            log.hard_state()
        };
        assert_eq!(WalLog::open(&path).unwrap().hard_state(), expected);
    }

    #[test]
    fn test_snapshot_codec_round_trips() {
        let snap = Snapshot::with_config(9, 4, vec![1, 2, 3], b"payload".to_vec());
        match decode(&encode_snapshot(&snap)).unwrap() {
            Decoded::Snapshot(got) => assert_eq!(got, snap),
            _ => panic!("wrong record"),
        }
    }

    #[test]
    fn test_config_entry_and_snapshot_membership_survive_recovery() {
        let (_dir, path) = temp_path();
        {
            let mut log = WalLog::open(&path).unwrap();
            log.apply_snapshot(&Snapshot::with_config(2, 1, vec![1, 2, 3], b"s".to_vec()))
                .unwrap();
            log.append(&[LogEntry::config(2, 3, &[1, 2, 3, 4])])
                .unwrap();
            log.sync().unwrap();
        }
        let recovered = WalLog::open(&path).unwrap();
        assert_eq!(recovered.snapshot().unwrap().config, vec![1, 2, 3]);
        assert_eq!(
            recovered.entry(3).unwrap().members(),
            Some(vec![1, 2, 3, 4])
        );
    }

    #[test]
    fn test_empty_log_opens_clean() {
        let (_dir, path) = temp_path();
        let log = WalLog::open(&path).unwrap();
        assert_eq!(log.last_index(), 0);
        assert_eq!(log.hard_state(), HardState::default());
    }

    #[test]
    fn test_non_contiguous_append_is_rejected_before_write() {
        let (_dir, path) = temp_path();
        let mut log = WalLog::open(&path).unwrap();
        assert!(log.append(&[entry(1, 5, b"x")]).is_err());
        assert_eq!(log.last_index(), 0);
        // Nothing was written, so a reopened log is still empty.
        drop(log);
        assert_eq!(WalLog::open(&path).unwrap().last_index(), 0);
    }
}