1use crate::error::{RaftError, RaftResult};
8use crate::log::{Command, LogEntry};
9use crate::types::{LogIndex, NodeId, Term};
10use std::io::{Read, Write};
11use std::path::{Path, PathBuf};
12
13pub trait RaftPersistence: Send + Sync {
15 fn save_state(&self, term: Term, voted_for: Option<NodeId>) -> RaftResult<()>;
17
18 fn load_state(&self) -> RaftResult<(Term, Option<NodeId>)>;
20
21 fn append_entries(&self, entries: &[LogEntry]) -> RaftResult<()>;
23
24 fn load_log(&self) -> RaftResult<Vec<LogEntry>>;
26
27 fn truncate_log_from(&self, index: LogIndex) -> RaftResult<()>;
29
30 fn last_log_index(&self) -> RaftResult<LogIndex>;
32
33 fn save_applied_index(&self, index: LogIndex) -> RaftResult<()>;
38
39 fn load_applied_index(&self) -> RaftResult<LogIndex>;
41
42 fn sync(&self) -> RaftResult<()>;
44}
45
/// File-backed persistence handle.
///
/// Stores only the locations of the two files it manages and the fsync policy;
/// no file handles are kept open between calls.
pub struct FilePersistence {
    /// Path of the JSON file holding term, vote and applied index.
    state_path: PathBuf,
    /// Path of the binary, length-prefixed log file.
    log_path: PathBuf,
    /// When `true`, writes are followed by `sync_all` on the written file.
    sync_on_write: bool,
}
62
63#[derive(serde::Serialize, serde::Deserialize)]
65struct PersistedState {
66 current_term: Term,
67 voted_for: Option<NodeId>,
68 #[serde(default)]
69 applied_index: LogIndex,
70}
71
72impl FilePersistence {
73 pub fn new(dir: &Path, sync_on_write: bool) -> RaftResult<Self> {
77 std::fs::create_dir_all(dir).map_err(|e| RaftError::StorageError {
78 message: format!("failed to create persistence dir {}: {e}", dir.display()),
79 })?;
80
81 Ok(Self {
82 state_path: dir.join("raft_state.json"),
83 log_path: dir.join("raft_log.bin"),
84 sync_on_write,
85 })
86 }
87
88 fn atomic_write_state(&self, data: &[u8]) -> RaftResult<()> {
90 let tmp_path = self.state_path.with_extension("json.tmp");
91
92 let mut f = std::fs::File::create(&tmp_path).map_err(|e| RaftError::StorageError {
93 message: format!("failed to create tmp state file: {e}"),
94 })?;
95
96 f.write_all(data).map_err(|e| RaftError::StorageError {
97 message: format!("failed to write tmp state file: {e}"),
98 })?;
99
100 if self.sync_on_write {
101 f.sync_all().map_err(|e| RaftError::StorageError {
102 message: format!("failed to sync tmp state file: {e}"),
103 })?;
104 }
105
106 std::fs::rename(&tmp_path, &self.state_path).map_err(|e| RaftError::StorageError {
107 message: format!("failed to rename tmp state file: {e}"),
108 })?;
109
110 Ok(())
111 }
112
113 fn encode_entry(entry: &LogEntry) -> Vec<u8> {
120 let cmd_bytes = &entry.command.data;
121 let payload_len = 8 + 8 + 4 + cmd_bytes.len() + 4;
123
124 let mut buf = Vec::with_capacity(4 + payload_len);
125
126 buf.extend_from_slice(&(payload_len as u32).to_le_bytes());
128 buf.extend_from_slice(&entry.term.to_le_bytes());
130 buf.extend_from_slice(&entry.index.to_le_bytes());
132 buf.extend_from_slice(&(cmd_bytes.len() as u32).to_le_bytes());
134 buf.extend_from_slice(cmd_bytes);
136 let crc = crc32fast::hash(&buf[4..]);
138 buf.extend_from_slice(&crc.to_le_bytes());
139
140 buf
141 }
142
143 fn decode_entries(data: &[u8]) -> RaftResult<Vec<LogEntry>> {
146 let mut entries = Vec::new();
147 let mut pos = 0;
148
149 while pos + 4 <= data.len() {
150 let total_len = u32::from_le_bytes(read_4(data, pos)?) as usize;
152
153 if pos + 4 + total_len > data.len() {
155 break;
157 }
158
159 let record_start = pos + 4;
160 let record_end = record_start + total_len;
161 let record = &data[record_start..record_end];
162
163 if total_len < 4 {
165 break; }
167 let payload = &record[..total_len - 4];
168 let stored_crc = u32::from_le_bytes(read_4(record, total_len - 4)?);
169 let computed_crc = crc32fast::hash(payload);
170
171 if stored_crc != computed_crc {
172 return Err(RaftError::StorageError {
173 message: format!(
174 "CRC mismatch at offset {pos}: stored={stored_crc:#010x}, computed={computed_crc:#010x}"
175 ),
176 });
177 }
178
179 if payload.len() < 20 {
181 return Err(RaftError::StorageError {
182 message: format!("record too short at offset {pos}"),
183 });
184 }
185
186 let term = u64::from_le_bytes(read_8(payload, 0)?);
187 let index = u64::from_le_bytes(read_8(payload, 8)?);
188 let cmd_len = u32::from_le_bytes(read_4(payload, 16)?) as usize;
189
190 if payload.len() < 20 + cmd_len {
191 return Err(RaftError::StorageError {
192 message: format!("cmd_len exceeds record at offset {pos}"),
193 });
194 }
195
196 let cmd_data = payload[20..20 + cmd_len].to_vec();
197 entries.push(LogEntry::new(term, index, Command::new(cmd_data)));
198
199 pos = record_end;
200 }
201
202 Ok(entries)
203 }
204
205 fn rewrite_log_without(&self, from_index: LogIndex) -> RaftResult<()> {
207 let entries = self.load_log()?;
208 let kept: Vec<&LogEntry> = entries.iter().filter(|e| e.index < from_index).collect();
209
210 let tmp_path = self.log_path.with_extension("bin.tmp");
211 let mut f = std::fs::File::create(&tmp_path).map_err(|e| RaftError::StorageError {
212 message: format!("failed to create tmp log file: {e}"),
213 })?;
214
215 for entry in &kept {
216 let encoded = Self::encode_entry(entry);
217 f.write_all(&encoded).map_err(|e| RaftError::StorageError {
218 message: format!("failed to write entry to tmp log: {e}"),
219 })?;
220 }
221
222 if self.sync_on_write {
223 f.sync_all().map_err(|e| RaftError::StorageError {
224 message: format!("failed to sync tmp log: {e}"),
225 })?;
226 }
227
228 std::fs::rename(&tmp_path, &self.log_path).map_err(|e| RaftError::StorageError {
229 message: format!("failed to rename tmp log: {e}"),
230 })?;
231
232 Ok(())
233 }
234}
235
236impl RaftPersistence for FilePersistence {
237 fn save_state(&self, term: Term, voted_for: Option<NodeId>) -> RaftResult<()> {
238 let applied_index = if self.state_path.exists() {
240 self.load_applied_index().unwrap_or(0)
241 } else {
242 0
243 };
244 let state = PersistedState {
245 current_term: term,
246 voted_for,
247 applied_index,
248 };
249 let json = serde_json::to_vec_pretty(&state).map_err(|e| RaftError::StorageError {
250 message: format!("failed to serialize state: {e}"),
251 })?;
252 self.atomic_write_state(&json)
253 }
254
255 fn load_state(&self) -> RaftResult<(Term, Option<NodeId>)> {
256 if !self.state_path.exists() {
257 return Ok((0, None));
258 }
259
260 let mut f = std::fs::File::open(&self.state_path).map_err(|e| RaftError::StorageError {
261 message: format!("failed to open state file: {e}"),
262 })?;
263
264 let mut data = Vec::new();
265 f.read_to_end(&mut data)
266 .map_err(|e| RaftError::StorageError {
267 message: format!("failed to read state file: {e}"),
268 })?;
269
270 let state: PersistedState =
271 serde_json::from_slice(&data).map_err(|e| RaftError::StorageError {
272 message: format!("failed to parse state file: {e}"),
273 })?;
274
275 Ok((state.current_term, state.voted_for))
276 }
277
278 fn append_entries(&self, entries: &[LogEntry]) -> RaftResult<()> {
279 if entries.is_empty() {
280 return Ok(());
281 }
282
283 let mut f = std::fs::OpenOptions::new()
284 .create(true)
285 .append(true)
286 .open(&self.log_path)
287 .map_err(|e| RaftError::StorageError {
288 message: format!("failed to open log file for append: {e}"),
289 })?;
290
291 for entry in entries {
292 let encoded = Self::encode_entry(entry);
293 f.write_all(&encoded).map_err(|e| RaftError::StorageError {
294 message: format!("failed to append entry: {e}"),
295 })?;
296 }
297
298 if self.sync_on_write {
299 f.sync_all().map_err(|e| RaftError::StorageError {
300 message: format!("failed to sync log file: {e}"),
301 })?;
302 }
303
304 Ok(())
305 }
306
307 fn load_log(&self) -> RaftResult<Vec<LogEntry>> {
308 if !self.log_path.exists() {
309 return Ok(Vec::new());
310 }
311
312 let mut f = std::fs::File::open(&self.log_path).map_err(|e| RaftError::StorageError {
313 message: format!("failed to open log file: {e}"),
314 })?;
315
316 let mut data = Vec::new();
317 f.read_to_end(&mut data)
318 .map_err(|e| RaftError::StorageError {
319 message: format!("failed to read log file: {e}"),
320 })?;
321
322 Self::decode_entries(&data)
323 }
324
325 fn truncate_log_from(&self, index: LogIndex) -> RaftResult<()> {
326 if !self.log_path.exists() {
327 return Ok(());
328 }
329 self.rewrite_log_without(index)
330 }
331
332 fn last_log_index(&self) -> RaftResult<LogIndex> {
333 let entries = self.load_log()?;
334 Ok(entries.last().map_or(0, |e| e.index))
335 }
336
337 fn save_applied_index(&self, index: LogIndex) -> RaftResult<()> {
338 let (current_term, voted_for) = if self.state_path.exists() {
340 self.load_state()?
341 } else {
342 (0, None)
343 };
344 let state = PersistedState {
345 current_term,
346 voted_for,
347 applied_index: index,
348 };
349 let json = serde_json::to_vec_pretty(&state).map_err(|e| RaftError::StorageError {
350 message: format!("failed to serialize state (applied_index update): {e}"),
351 })?;
352 self.atomic_write_state(&json)
353 }
354
355 fn load_applied_index(&self) -> RaftResult<LogIndex> {
356 if !self.state_path.exists() {
357 return Ok(0);
358 }
359 let mut f = std::fs::File::open(&self.state_path).map_err(|e| RaftError::StorageError {
360 message: format!("failed to open state file: {e}"),
361 })?;
362 let mut data = Vec::new();
363 f.read_to_end(&mut data)
364 .map_err(|e| RaftError::StorageError {
365 message: format!("failed to read state file: {e}"),
366 })?;
367 let state: PersistedState =
368 serde_json::from_slice(&data).map_err(|e| RaftError::StorageError {
369 message: format!("failed to parse state file (applied_index): {e}"),
370 })?;
371 Ok(state.applied_index)
372 }
373
374 fn sync(&self) -> RaftResult<()> {
375 if let Ok(dir) =
378 std::fs::File::open(self.state_path.parent().unwrap_or_else(|| Path::new(".")))
379 {
380 let _ = dir.sync_all();
381 }
382 Ok(())
383 }
384}
385
386pub struct MemoryPersistence {
392 state: parking_lot::RwLock<(Term, Option<NodeId>)>,
393 log: parking_lot::RwLock<Vec<LogEntry>>,
394 applied_index: parking_lot::RwLock<LogIndex>,
395}
396
397impl MemoryPersistence {
398 pub fn new() -> Self {
400 Self {
401 state: parking_lot::RwLock::new((0, None)),
402 log: parking_lot::RwLock::new(Vec::new()),
403 applied_index: parking_lot::RwLock::new(0),
404 }
405 }
406}
407
408impl Default for MemoryPersistence {
409 fn default() -> Self {
410 Self::new()
411 }
412}
413
414impl RaftPersistence for MemoryPersistence {
415 fn save_state(&self, term: Term, voted_for: Option<NodeId>) -> RaftResult<()> {
416 *self.state.write() = (term, voted_for);
417 Ok(())
418 }
419
420 fn load_state(&self) -> RaftResult<(Term, Option<NodeId>)> {
421 Ok(*self.state.read())
422 }
423
424 fn append_entries(&self, entries: &[LogEntry]) -> RaftResult<()> {
425 self.log.write().extend(entries.iter().cloned());
426 Ok(())
427 }
428
429 fn load_log(&self) -> RaftResult<Vec<LogEntry>> {
430 Ok(self.log.read().clone())
431 }
432
433 fn truncate_log_from(&self, index: LogIndex) -> RaftResult<()> {
434 self.log.write().retain(|e| e.index < index);
435 Ok(())
436 }
437
438 fn last_log_index(&self) -> RaftResult<LogIndex> {
439 Ok(self.log.read().last().map_or(0, |e| e.index))
440 }
441
442 fn save_applied_index(&self, index: LogIndex) -> RaftResult<()> {
443 *self.applied_index.write() = index;
444 Ok(())
445 }
446
447 fn load_applied_index(&self) -> RaftResult<LogIndex> {
448 Ok(*self.applied_index.read())
449 }
450
451 fn sync(&self) -> RaftResult<()> {
452 Ok(())
453 }
454}
455
456fn read_4(data: &[u8], offset: usize) -> RaftResult<[u8; 4]> {
461 data.get(offset..offset + 4)
462 .and_then(|s| s.try_into().ok())
463 .ok_or_else(|| RaftError::StorageError {
464 message: format!("unexpected EOF reading 4 bytes at offset {offset}"),
465 })
466}
467
468fn read_8(data: &[u8], offset: usize) -> RaftResult<[u8; 8]> {
469 data.get(offset..offset + 8)
470 .and_then(|s| s.try_into().ok())
471 .ok_or_else(|| RaftError::StorageError {
472 message: format!("unexpected EOF reading 8 bytes at offset {offset}"),
473 })
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Returns a fresh, not-yet-created directory path under the system temp
    /// dir. The name combines the test prefix, the process id and a nanosecond
    /// timestamp so concurrent runs do not share a directory.
    fn temp_persistence_dir(prefix: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "amaters_test_{prefix}_{}_{nanos}",
            std::process::id()
        ));
        // Clear leftovers from an earlier run, if any.
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// Builds a log entry whose command is created from `data`.
    fn make_entry(term: Term, index: LogIndex, data: &str) -> LogEntry {
        LogEntry::new(term, index, Command::from_str(data))
    }

    /// State defaults to `(0, None)` and round-trips through save/load.
    #[test]
    fn test_file_persistence_save_load_state() {
        let dir = temp_persistence_dir("state_save_load");
        let fp = FilePersistence::new(&dir, true).expect("create persistence");

        let (term, voted) = fp.load_state().expect("load default");
        assert_eq!(term, 0);
        assert_eq!(voted, None);

        fp.save_state(5, Some(42)).expect("save");
        let (term, voted) = fp.load_state().expect("load after save");
        assert_eq!(term, 5);
        assert_eq!(voted, Some(42));

        fp.save_state(10, None).expect("overwrite");
        let (term, voted) = fp.load_state().expect("load overwritten");
        assert_eq!(term, 10);
        assert_eq!(voted, None);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Appended entries come back in order, across multiple appends.
    #[test]
    fn test_file_persistence_append_load_log() {
        let dir = temp_persistence_dir("log_append_load");
        let fp = FilePersistence::new(&dir, true).expect("create");

        let entries = vec![
            make_entry(1, 1, "cmd1"),
            make_entry(1, 2, "cmd2"),
            make_entry(2, 3, "cmd3"),
        ];

        fp.append_entries(&entries).expect("append");

        let loaded = fp.load_log().expect("load");
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].term, 1);
        assert_eq!(loaded[0].index, 1);
        assert_eq!(loaded[0].command.data, b"cmd1");
        assert_eq!(loaded[2].term, 2);
        assert_eq!(loaded[2].index, 3);

        fp.append_entries(&[make_entry(2, 4, "cmd4")])
            .expect("append more");
        let loaded = fp.load_log().expect("load 2");
        assert_eq!(loaded.len(), 4);

        assert_eq!(fp.last_log_index().expect("last idx"), 4);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Truncating from index 3 keeps only indices 1 and 2.
    #[test]
    fn test_file_persistence_truncate_log() {
        let dir = temp_persistence_dir("log_truncate");
        let fp = FilePersistence::new(&dir, true).expect("create");

        let entries = vec![
            make_entry(1, 1, "a"),
            make_entry(1, 2, "b"),
            make_entry(2, 3, "c"),
            make_entry(2, 4, "d"),
        ];
        fp.append_entries(&entries).expect("append");

        fp.truncate_log_from(3).expect("truncate");
        let loaded = fp.load_log().expect("load");
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].index, 1);
        assert_eq!(loaded[1].index, 2);

        assert_eq!(fp.last_log_index().expect("last idx"), 2);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Truncating before any log file exists succeeds and leaves the log empty.
    #[test]
    fn test_file_persistence_truncate_missing_log() {
        let dir = temp_persistence_dir("log_truncate_missing");
        let fp = FilePersistence::new(&dir, true).expect("create");

        fp.truncate_log_from(1).expect("truncate without log file");
        assert!(fp.load_log().expect("load").is_empty());
        assert_eq!(fp.last_log_index().expect("last idx"), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// State and log written by one handle are visible to a new handle on the
    /// same directory.
    #[test]
    fn test_file_persistence_crash_recovery() {
        let dir = temp_persistence_dir("crash_recovery");

        {
            let fp = FilePersistence::new(&dir, true).expect("create");
            fp.save_state(7, Some(99)).expect("save state");
            fp.append_entries(&[
                make_entry(5, 1, "hello"),
                make_entry(6, 2, "world"),
                make_entry(7, 3, "!"),
            ])
            .expect("append");
            fp.sync().expect("sync");
        }
        {
            let fp = FilePersistence::new(&dir, true).expect("reopen");

            let (term, voted) = fp.load_state().expect("load state");
            assert_eq!(term, 7);
            assert_eq!(voted, Some(99));

            let log = fp.load_log().expect("load log");
            assert_eq!(log.len(), 3);
            assert_eq!(log[0].command.data, b"hello");
            assert_eq!(log[2].index, 3);
        }

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// After a save the temporary file is gone and the latest state is readable.
    #[test]
    fn test_file_persistence_atomic_state_write() {
        let dir = temp_persistence_dir("atomic_state");
        let fp = FilePersistence::new(&dir, true).expect("create");

        fp.save_state(1, Some(10)).expect("save 1");

        fp.save_state(2, Some(20)).expect("save 2");

        let tmp = fp.state_path.with_extension("json.tmp");
        assert!(!tmp.exists(), "tmp file should have been renamed away");

        let (term, voted) = fp.load_state().expect("load");
        assert_eq!(term, 2);
        assert_eq!(voted, Some(20));

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Flipping a payload byte is reported as a CRC mismatch.
    #[test]
    fn test_file_persistence_corrupted_entry() {
        let dir = temp_persistence_dir("corrupted");
        let fp = FilePersistence::new(&dir, true).expect("create");

        fp.append_entries(&[make_entry(1, 1, "good")])
            .expect("append");

        let mut data = std::fs::read(&fp.log_path).expect("read raw");
        if data.len() > 10 {
            // Byte 10 lies inside the checksummed payload, past the length prefix.
            data[10] ^= 0xFF;
        }
        std::fs::write(&fp.log_path, &data).expect("write corrupted");

        let result = fp.load_log();
        assert!(result.is_err(), "should detect CRC mismatch");
        let err_msg = format!("{}", result.expect_err("expected error"));
        assert!(
            err_msg.contains("CRC mismatch"),
            "error should mention CRC: {err_msg}"
        );

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// An incomplete trailing frame is skipped and the intact entries before it
    /// are still returned.
    #[test]
    fn test_file_persistence_ignores_torn_tail() {
        let dir = temp_persistence_dir("torn_tail");
        let fp = FilePersistence::new(&dir, true).expect("create");

        fp.append_entries(&[make_entry(1, 1, "a"), make_entry(1, 2, "b")])
            .expect("append");

        let mut raw = std::fs::read(&fp.log_path).expect("read raw");
        // A length prefix promising 64 bytes, followed by only 3.
        raw.extend_from_slice(&64u32.to_le_bytes());
        raw.extend_from_slice(&[0xAA, 0xBB, 0xCC]);
        std::fs::write(&fp.log_path, &raw).expect("write torn tail");

        let loaded = fp.load_log().expect("load with torn tail");
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[1].index, 2);
        assert_eq!(fp.last_log_index().expect("last idx"), 2);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// The applied index and the term/vote share one file; saving either must
    /// not clobber the other.
    #[test]
    fn test_file_persistence_applied_index_roundtrip() {
        let dir = temp_persistence_dir("applied_index");
        let fp = FilePersistence::new(&dir, true).expect("create");

        assert_eq!(fp.load_applied_index().expect("load default"), 0);

        fp.save_applied_index(7).expect("save applied");
        assert_eq!(fp.load_applied_index().expect("load applied"), 7);
        let (term, voted) = fp.load_state().expect("state after applied save");
        assert_eq!(term, 0);
        assert_eq!(voted, None);

        fp.save_state(3, Some(1)).expect("save state");
        assert_eq!(fp.load_applied_index().expect("applied preserved"), 7);

        fp.save_applied_index(9).expect("save applied again");
        let (term, voted) = fp.load_state().expect("state preserved");
        assert_eq!(term, 3);
        assert_eq!(voted, Some(1));
        assert_eq!(fp.load_applied_index().expect("load applied 2"), 9);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Basic state, log and truncate behavior of the in-memory backend.
    #[test]
    fn test_memory_persistence_basic() {
        let mp = MemoryPersistence::new();

        let (t, v) = mp.load_state().expect("load default");
        assert_eq!(t, 0);
        assert_eq!(v, None);

        mp.save_state(3, Some(7)).expect("save");
        let (t, v) = mp.load_state().expect("load");
        assert_eq!(t, 3);
        assert_eq!(v, Some(7));

        mp.append_entries(&[make_entry(1, 1, "x"), make_entry(1, 2, "y")])
            .expect("append");
        assert_eq!(mp.last_log_index().expect("last"), 2);

        mp.truncate_log_from(2).expect("truncate");
        assert_eq!(mp.last_log_index().expect("last after trunc"), 1);

        mp.sync().expect("sync");
    }

    /// The in-memory applied index defaults to 0 and is independent of the
    /// term/vote state.
    #[test]
    fn test_memory_persistence_applied_index() {
        let mp = MemoryPersistence::new();
        assert_eq!(mp.load_applied_index().expect("load default"), 0);

        mp.save_applied_index(4).expect("save applied");
        mp.save_state(2, Some(5)).expect("save state");
        assert_eq!(mp.load_applied_index().expect("load applied"), 4);

        let (t, v) = mp.load_state().expect("load state");
        assert_eq!(t, 2);
        assert_eq!(v, Some(5));
    }

    /// The trait is usable as a shared trait object.
    #[test]
    fn test_persistence_trait_object() {
        let mp: Arc<dyn RaftPersistence> = Arc::new(MemoryPersistence::new());
        mp.save_state(1, None).expect("save via trait object");
        let (t, _) = mp.load_state().expect("load via trait object");
        assert_eq!(t, 1);
    }
}