1use crate::error::{RaftError, RaftResult};
4use crate::types::{LogIndex, Term};
5use std::collections::VecDeque;
6
/// Interface for a state machine that consumes log entries.
///
/// Implementors must be `Send + Sync`. In the visible part of this file the
/// trait is not called by `RaftLog` (which applies entries through a closure
/// callback); the only implementation in view is the key-value machine in the
/// unit tests.
pub trait StateMachine: Send + Sync {
    /// Applies one log entry and returns an output byte buffer, or an error.
    fn apply(&mut self, entry: &LogEntry) -> RaftResult<Vec<u8>>;

    /// Serializes the current state into a byte buffer.
    fn snapshot(&self) -> RaftResult<Vec<u8>>;

    /// Replaces the current state with the contents of `snapshot`.
    ///
    /// NOTE(review): presumably `snapshot` is a buffer previously produced by
    /// `snapshot()` — confirm against implementors.
    fn restore(&mut self, snapshot: &[u8]) -> RaftResult<()>;
}
25
/// Outcome of applying one log entry through `RaftLog`'s apply callback.
#[derive(Debug, Clone)]
pub struct ApplyResult {
    /// Index of the entry that was applied.
    pub index: LogIndex,
    /// Term of the entry that was applied.
    pub term: Term,
    /// Bytes returned by the apply callback (empty when no callback is set).
    pub output: Vec<u8>,
}
36
/// Snapshot descriptor: the log position it covers plus an opaque payload.
#[derive(Debug, Clone)]
pub struct SnapshotData {
    /// Index of the last log entry covered by the snapshot.
    pub last_included_index: LogIndex,
    /// Term of the last log entry covered by the snapshot.
    pub last_included_term: Term,
    /// Snapshot payload. `RaftLog::create_snapshot` leaves this empty.
    pub data: Vec<u8>,
}
48
/// An opaque command carried by a log entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Command {
    /// Raw command bytes; this file never interprets them.
    pub data: Vec<u8>,
}
55
56impl Command {
57 pub fn new(data: Vec<u8>) -> Self {
59 Self { data }
60 }
61
62 #[allow(clippy::should_implement_trait)]
64 pub fn from_str(s: &str) -> Self {
65 Self::new(s.as_bytes().to_vec())
66 }
67}
68
/// A single replicated log entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LogEntry {
    /// Term recorded for this entry.
    pub term: Term,
    /// Position of the entry in the log (the log in this file starts at 1).
    pub index: LogIndex,
    /// Command payload carried by the entry.
    pub command: Command,
    /// Fencing token attached to the entry; `LogEntry::new` sets it to 0.
    /// NOTE(review): the token's meaning is not visible in this file — confirm with callers.
    pub fencing_token: u64,
}
84
85impl LogEntry {
86 pub fn new(term: Term, index: LogIndex, command: Command) -> Self {
88 Self {
89 term,
90 index,
91 command,
92 fencing_token: 0,
93 }
94 }
95
96 pub fn with_fencing_token(
98 term: Term,
99 index: LogIndex,
100 command: Command,
101 fencing_token: u64,
102 ) -> Self {
103 Self {
104 term,
105 index,
106 command,
107 fencing_token,
108 }
109 }
110}
111
/// In-memory replicated log with commit/apply tracking and snapshot metadata.
pub struct RaftLog {
    /// Entries currently held in memory, in index order.
    entries: VecDeque<LogEntry>,
    /// Log index corresponding to `entries[0]` (1 for a fresh log,
    /// `snapshot_index + 1` after compaction or snapshot installation).
    first_index: LogIndex,
    /// Index of the newest entry, or the snapshot index when `entries` is empty.
    last_index: LogIndex,
    /// Term of the newest entry, or the snapshot term when `entries` is empty.
    last_term: Term,
    /// Highest index marked committed.
    commit_index: LogIndex,
    /// Highest index handed to the apply callback (or set via `set_applied_index`).
    applied_index: LogIndex,
    /// Index of the last entry folded into the snapshot (0 when there is none).
    snapshot_index: LogIndex,
    /// Term of the last entry folded into the snapshot.
    snapshot_term: Term,
    /// Optional callback invoked for each entry as it is applied; wrapped in a
    /// mutex and boxed as a `Send` closure.
    apply_callback:
        std::sync::Mutex<Option<Box<dyn FnMut(&LogEntry) -> RaftResult<Vec<u8>> + Send>>>,
}
136
137impl RaftLog {
138 pub fn new() -> Self {
140 Self {
141 entries: VecDeque::new(),
142 first_index: 1,
143 last_index: 0,
144 last_term: 0,
145 commit_index: 0,
146 applied_index: 0,
147 snapshot_index: 0,
148 snapshot_term: 0,
149 apply_callback: std::sync::Mutex::new(None),
150 }
151 }
152
153 pub fn append(&mut self, term: Term, command: Command) -> LogIndex {
155 let index = self.last_index + 1;
156 let entry = LogEntry::new(term, index, command);
157
158 self.entries.push_back(entry);
159 self.last_index = index;
160 self.last_term = term;
161
162 index
163 }
164
165 pub fn append_entries(&mut self, entries: Vec<LogEntry>) -> RaftResult<()> {
167 if entries.is_empty() {
168 return Ok(());
169 }
170
171 for (expected_index, entry) in (self.last_index + 1..).zip(entries.iter()) {
173 if entry.index != expected_index {
174 return Err(RaftError::LogInconsistency {
175 reason: format!("Expected index {}, got {}", expected_index, entry.index),
176 });
177 }
178 }
179
180 for entry in entries {
182 self.last_index = entry.index;
183 self.last_term = entry.term;
184 self.entries.push_back(entry);
185 }
186
187 Ok(())
188 }
189
190 pub fn get(&self, index: LogIndex) -> Option<&LogEntry> {
192 if index < self.first_index || index > self.last_index {
193 return None;
194 }
195
196 let offset = (index - self.first_index) as usize;
197 self.entries.get(offset)
198 }
199
200 pub fn get_entries_from(&self, start_index: LogIndex, max_count: usize) -> Vec<LogEntry> {
202 if start_index < self.first_index || start_index > self.last_index {
203 return Vec::new();
204 }
205
206 let offset = (start_index - self.first_index) as usize;
207 self.entries
208 .iter()
209 .skip(offset)
210 .take(max_count)
211 .cloned()
212 .collect()
213 }
214
215 pub fn get_term(&self, index: LogIndex) -> Option<Term> {
217 if index == 0 {
218 return Some(0);
219 }
220
221 if index == self.snapshot_index {
222 return Some(self.snapshot_term);
223 }
224
225 self.get(index).map(|entry| entry.term)
226 }
227
228 pub fn last_index(&self) -> LogIndex {
230 self.last_index
231 }
232
233 pub fn last_term(&self) -> Term {
235 self.last_term
236 }
237
238 pub fn truncate_from(&mut self, from_index: LogIndex) -> RaftResult<()> {
240 if from_index <= self.snapshot_index {
241 return Err(RaftError::LogInconsistency {
242 reason: format!(
243 "Cannot truncate before snapshot index {}",
244 self.snapshot_index
245 ),
246 });
247 }
248
249 if from_index > self.last_index {
250 return Ok(());
251 }
252
253 let offset = (from_index - self.first_index) as usize;
255 self.entries.truncate(offset);
256
257 if let Some(last_entry) = self.entries.back() {
259 self.last_index = last_entry.index;
260 self.last_term = last_entry.term;
261 } else {
262 self.last_index = self.snapshot_index;
263 self.last_term = self.snapshot_term;
264 }
265
266 Ok(())
267 }
268
269 pub fn matches(&self, index: LogIndex, term: Term) -> bool {
271 if index == 0 {
272 return term == 0;
273 }
274
275 if index == self.snapshot_index {
276 return term == self.snapshot_term;
277 }
278
279 match self.get_term(index) {
280 Some(t) => t == term,
281 None => false,
282 }
283 }
284
285 pub fn commit_index(&self) -> LogIndex {
287 self.commit_index
288 }
289
290 pub fn set_commit_index(&mut self, index: LogIndex) -> RaftResult<()> {
292 if index < self.commit_index {
293 return Err(RaftError::LogInconsistency {
294 reason: format!(
295 "Cannot decrease commit index from {} to {}",
296 self.commit_index, index
297 ),
298 });
299 }
300
301 if index > self.last_index {
302 return Err(RaftError::LogInconsistency {
303 reason: format!(
304 "Cannot commit beyond last index {} (tried to commit {})",
305 self.last_index, index
306 ),
307 });
308 }
309
310 self.commit_index = index;
311 Ok(())
312 }
313
314 pub fn applied_index(&self) -> LogIndex {
316 self.applied_index
317 }
318
319 pub fn set_applied_index(&mut self, index: LogIndex) -> RaftResult<()> {
321 if index < self.applied_index {
322 return Err(RaftError::LogInconsistency {
323 reason: format!(
324 "Cannot decrease applied index from {} to {}",
325 self.applied_index, index
326 ),
327 });
328 }
329
330 if index > self.commit_index {
331 return Err(RaftError::LogInconsistency {
332 reason: format!(
333 "Cannot apply beyond commit index {} (tried to apply {})",
334 self.commit_index, index
335 ),
336 });
337 }
338
339 self.applied_index = index;
340 Ok(())
341 }
342
343 pub fn get_uncommitted_entries(&self) -> Vec<LogEntry> {
345 if self.applied_index >= self.commit_index {
346 return Vec::new();
347 }
348
349 self.get_entries_from(self.applied_index + 1, usize::MAX)
350 .into_iter()
351 .take_while(|entry| entry.index <= self.commit_index)
352 .collect()
353 }
354
355 pub fn compact_until(&mut self, index: LogIndex, term: Term) -> RaftResult<()> {
365 if index == 0 {
366 return Ok(());
367 }
368
369 if index <= self.snapshot_index {
370 return Ok(());
372 }
373
374 if index > self.applied_index {
375 return Err(RaftError::LogInconsistency {
376 reason: format!(
377 "Cannot compact beyond applied index {} (tried to compact until {})",
378 self.applied_index, index
379 ),
380 });
381 }
382
383 if let Some(entry_term) = self.get_term(index) {
385 if entry_term != term {
386 return Err(RaftError::LogInconsistency {
387 reason: format!(
388 "Term mismatch at index {}: expected {}, found {}",
389 index, term, entry_term
390 ),
391 });
392 }
393 }
394
395 let entries_to_remove = if index >= self.first_index {
397 ((index - self.first_index) + 1) as usize
398 } else {
399 0
400 };
401
402 let drain_count = entries_to_remove.min(self.entries.len());
403 self.entries.drain(..drain_count);
404
405 self.snapshot_index = index;
407 self.snapshot_term = term;
408 self.first_index = index + 1;
409
410 Ok(())
411 }
412
413 pub fn get_snapshot_point(&self) -> (LogIndex, Term) {
415 (self.snapshot_index, self.snapshot_term)
416 }
417
418 pub fn snapshot_index(&self) -> LogIndex {
420 self.snapshot_index
421 }
422
423 pub fn snapshot_term(&self) -> Term {
425 self.snapshot_term
426 }
427
428 pub fn install_snapshot(&mut self, last_included_index: LogIndex, last_included_term: Term) {
433 self.entries.clear();
434 self.snapshot_index = last_included_index;
435 self.snapshot_term = last_included_term;
436 self.first_index = last_included_index + 1;
437 self.last_index = last_included_index;
438 self.last_term = last_included_term;
439
440 if self.commit_index < last_included_index {
442 self.commit_index = last_included_index;
443 }
444 if self.applied_index < last_included_index {
445 self.applied_index = last_included_index;
446 }
447 }
448
449 pub fn entries_since_snapshot(&self) -> u64 {
451 self.last_index.saturating_sub(self.snapshot_index)
452 }
453
454 pub fn is_empty(&self) -> bool {
456 self.entries.is_empty()
457 }
458
459 pub fn len(&self) -> usize {
461 self.entries.len()
462 }
463
464 pub fn set_apply_callback<F>(&mut self, callback: F)
470 where
471 F: FnMut(&LogEntry) -> RaftResult<Vec<u8>> + Send + 'static,
472 {
473 let mut guard = self
474 .apply_callback
475 .lock()
476 .unwrap_or_else(|e| e.into_inner());
477 *guard = Some(Box::new(callback));
478 }
479
480 pub fn apply_committed_entries(&mut self) -> RaftResult<Vec<ApplyResult>> {
486 let entries = self.get_uncommitted_entries();
487 let mut results = Vec::with_capacity(entries.len());
488 for entry in &entries {
489 let output = {
490 let mut guard = self
491 .apply_callback
492 .lock()
493 .unwrap_or_else(|e| e.into_inner());
494 if let Some(ref mut cb) = *guard {
495 cb(entry)?
496 } else {
497 Vec::new()
498 }
499 };
500 self.applied_index = entry.index;
501 results.push(ApplyResult {
502 index: entry.index,
503 term: entry.term,
504 output,
505 });
506 }
507 Ok(results)
508 }
509
510 pub fn apply_batch(&mut self, max_entries: usize) -> RaftResult<Vec<ApplyResult>> {
516 let entries = self.get_uncommitted_entries();
517 let batch: Vec<_> = entries.into_iter().take(max_entries).collect();
518 let saved_applied = self.applied_index;
519 let mut results = Vec::new();
520 for entry in &batch {
521 let invoke_result = {
522 let mut guard = self
523 .apply_callback
524 .lock()
525 .unwrap_or_else(|e| e.into_inner());
526 if let Some(ref mut cb) = *guard {
527 cb(entry)
528 } else {
529 Ok(Vec::new())
530 }
531 };
532 match invoke_result {
533 Ok(output) => {
534 self.applied_index = entry.index;
535 results.push(ApplyResult {
536 index: entry.index,
537 term: entry.term,
538 output,
539 });
540 }
541 Err(e) => {
542 self.applied_index = saved_applied;
543 return Err(e);
544 }
545 }
546 }
547 Ok(results)
548 }
549
550 pub fn create_snapshot(&self) -> RaftResult<SnapshotData> {
555 let term = self
556 .entries
557 .iter()
558 .find(|e| e.index == self.applied_index)
559 .map(|e| e.term)
560 .unwrap_or(self.snapshot_term);
561 Ok(SnapshotData {
562 last_included_index: self.applied_index,
563 last_included_term: term,
564 data: Vec::new(),
565 })
566 }
567
568 pub fn pending_apply_count(&self) -> usize {
570 if self.commit_index <= self.applied_index {
571 0
572 } else {
573 (self.commit_index - self.applied_index) as usize
574 }
575 }
576
577 pub fn is_fully_applied(&self) -> bool {
580 self.applied_index >= self.commit_index
581 }
582}
583
584impl Default for RaftLog {
585 fn default() -> Self {
586 Self::new()
587 }
588}
589
// Unit tests for `RaftLog` and a sample `StateMachine` implementation.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh log reports zero for every index/term and holds no entries.
    #[test]
    fn test_new_log() {
        let log = RaftLog::new();
        assert_eq!(log.last_index(), 0);
        assert_eq!(log.last_term(), 0);
        assert_eq!(log.commit_index(), 0);
        assert_eq!(log.applied_index(), 0);
        assert!(log.is_empty());
    }

    // The first appended entry lands at index 1 and is retrievable.
    #[test]
    fn test_append_entry() {
        let mut log = RaftLog::new();
        let cmd = Command::from_str("test");

        let index = log.append(1, cmd.clone());
        assert_eq!(index, 1);
        assert_eq!(log.last_index(), 1);
        assert_eq!(log.last_term(), 1);
        assert_eq!(log.len(), 1);

        let entry = log.get(1).expect("Entry should exist");
        assert_eq!(entry.index, 1);
        assert_eq!(entry.term, 1);
        assert_eq!(entry.command, cmd);
    }

    // `last_index` / `last_term` track the most recently appended entry.
    #[test]
    fn test_append_multiple_entries() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        assert_eq!(log.last_index(), 3);
        assert_eq!(log.last_term(), 2);
        assert_eq!(log.len(), 3);
    }

    // Reading from the middle returns the remaining tail.
    #[test]
    fn test_get_entries_from() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        let entries = log.get_entries_from(2, 10);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].index, 2);
        assert_eq!(entries[1].index, 3);
    }

    // `max_count` caps the number of entries returned.
    #[test]
    fn test_get_entries_from_with_limit() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        let entries = log.get_entries_from(1, 2);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].index, 1);
        assert_eq!(entries[1].index, 2);
    }

    // Truncation removes the given index and everything after it.
    #[test]
    fn test_truncate_from() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        log.truncate_from(2).expect("Truncate should succeed");

        assert_eq!(log.last_index(), 1);
        assert_eq!(log.last_term(), 1);
        assert_eq!(log.len(), 1);
        assert!(log.get(2).is_none());
        assert!(log.get(3).is_none());
    }

    // `matches` requires both the index to exist and the term to agree.
    #[test]
    fn test_matches() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        assert!(log.matches(1, 1));
        assert!(log.matches(2, 1));
        assert!(log.matches(3, 2));
        assert!(!log.matches(3, 1));
        assert!(!log.matches(4, 2));
    }

    #[test]
    fn test_commit_index() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));

        assert_eq!(log.commit_index(), 0);

        log.set_commit_index(2).expect("Set commit should succeed");
        assert_eq!(log.commit_index(), 2);

        // The commit index may not move backwards.
        let result = log.set_commit_index(1);
        assert!(result.is_err());
    }

    #[test]
    fn test_applied_index() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.set_commit_index(2).expect("Set commit should succeed");

        assert_eq!(log.applied_index(), 0);

        log.set_applied_index(1)
            .expect("Set applied should succeed");
        assert_eq!(log.applied_index(), 1);

        // The applied index may not pass the commit index (2).
        let result = log.set_applied_index(3);
        assert!(result.is_err());
    }

    #[test]
    fn test_compact_until() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));
        log.append(2, Command::from_str("cmd4"));
        log.append(3, Command::from_str("cmd5"));

        // Compaction is only allowed up to the applied index.
        log.set_commit_index(3).expect("Set commit should succeed");
        log.set_applied_index(3)
            .expect("Set applied should succeed");

        // Entry 2 was appended with term 1.
        log.compact_until(2, 1).expect("Compact should succeed");

        assert_eq!(log.snapshot_index(), 2);
        assert_eq!(log.snapshot_term(), 1);
        // Entries 3, 4 and 5 remain in memory.
        assert_eq!(log.len(), 3);
        assert!(log.get(1).is_none());
        assert!(log.get(2).is_none());
        assert!(log.get(3).is_some());
        assert_eq!(log.last_index(), 5);
    }

    #[test]
    fn test_compact_until_beyond_applied_fails() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.set_commit_index(1).expect("Set commit should succeed");
        log.set_applied_index(1)
            .expect("Set applied should succeed");

        // Index 2 has not been applied, so it cannot be compacted.
        let result = log.compact_until(2, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_compact_preserves_snapshot_metadata() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(2, Command::from_str("cmd2"));
        log.append(3, Command::from_str("cmd3"));
        log.set_commit_index(3).expect("Set commit should succeed");
        log.set_applied_index(3)
            .expect("Set applied should succeed");

        log.compact_until(2, 2).expect("Compact should succeed");

        let (snap_idx, snap_term) = log.get_snapshot_point();
        assert_eq!(snap_idx, 2);
        assert_eq!(snap_term, 2);

        // The term at the snapshot index is still answerable after compaction.
        assert_eq!(log.get_term(2), Some(2));
    }

    #[test]
    fn test_entries_since_snapshot() {
        let mut log = RaftLog::new();
        for i in 1..=10 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        assert_eq!(log.entries_since_snapshot(), 10);

        log.set_commit_index(5).expect("Set commit should succeed");
        log.set_applied_index(5)
            .expect("Set applied should succeed");
        log.compact_until(5, 1).expect("Compact should succeed");

        assert_eq!(log.entries_since_snapshot(), 5);
    }

    // A snapshot far ahead of the local log replaces it and fast-forwards
    // the commit and applied indices.
    #[test]
    fn test_install_snapshot_resets_log() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));

        log.install_snapshot(100, 5);

        assert_eq!(log.last_index(), 100);
        assert_eq!(log.last_term(), 5);
        assert_eq!(log.snapshot_index(), 100);
        assert_eq!(log.snapshot_term(), 5);
        assert_eq!(log.commit_index(), 100);
        assert_eq!(log.applied_index(), 100);
        assert!(log.is_empty());
    }

    // Despite its name, `get_uncommitted_entries` returns the entries that
    // are committed but not yet applied.
    #[test]
    fn test_get_uncommitted_entries() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("cmd1"));
        log.append(1, Command::from_str("cmd2"));
        log.append(2, Command::from_str("cmd3"));
        log.set_commit_index(2).expect("Set commit should succeed");

        let entries = log.get_uncommitted_entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].index, 1);
        assert_eq!(entries[1].index, 2);

        log.set_applied_index(1)
            .expect("Set applied should succeed");
        let entries = log.get_uncommitted_entries();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].index, 2);
    }

    // Entries are applied in index order and each yields one result.
    #[test]
    fn test_apply_committed_sequential() {
        let mut log = RaftLog::new();
        for i in 1..=5 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        log.set_commit_index(5).expect("commit");

        let results = log.apply_committed_entries().expect("apply");
        assert_eq!(results.len(), 5);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.index, (i + 1) as u64);
            assert_eq!(r.term, 1);
        }
        assert_eq!(log.applied_index(), 5);
    }

    // Only entries up to the commit index are applied.
    #[test]
    fn test_apply_committed_partial() {
        let mut log = RaftLog::new();
        for i in 1..=5 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        log.set_commit_index(3).expect("commit");

        let results = log.apply_committed_entries().expect("apply");
        assert_eq!(results.len(), 3);
        assert_eq!(log.applied_index(), 3);
    }

    // The callback sees every applied entry, in order.
    #[test]
    fn test_apply_with_callback() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("hello"));
        log.append(1, Command::from_str("world"));
        log.set_commit_index(2).expect("commit");

        let seen = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let seen_clone = seen.clone();
        log.set_apply_callback(move |entry| {
            seen_clone
                .lock()
                .expect("lock")
                .push(entry.command.data.clone());
            Ok(Vec::new())
        });

        log.apply_committed_entries().expect("apply");
        let data = seen.lock().expect("lock");
        assert_eq!(data.len(), 2);
        assert_eq!(data[0], b"hello");
        assert_eq!(data[1], b"world");
    }

    // The callback's return value is surfaced as `ApplyResult::output`.
    #[test]
    fn test_apply_callback_output() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("ping"));
        log.set_commit_index(1).expect("commit");

        log.set_apply_callback(|_entry| Ok(b"pong".to_vec()));

        let results = log.apply_committed_entries().expect("apply");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].output, b"pong");
    }

    #[test]
    fn test_apply_callback_error() {
        let mut log = RaftLog::new();
        for i in 1..=5 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        log.set_commit_index(5).expect("commit");

        // `count` makes the closure stateful, exercising the `FnMut` bound.
        let mut count = 0u64;
        log.set_apply_callback(move |entry| {
            count += 1;
            if entry.index == 3 {
                return Err(RaftError::StateMachineError {
                    message: "boom".into(),
                });
            }
            let _ = count;
            Ok(Vec::new())
        });

        let err = log.apply_committed_entries().expect_err("should fail");
        assert!(matches!(err, RaftError::StateMachineError { .. }));
        // Entries 1 and 2 succeeded before the failure at entry 3.
        assert_eq!(log.applied_index(), 2);
    }

    // `apply_batch` stops after `max_entries` entries.
    #[test]
    fn test_apply_batch_limited() {
        let mut log = RaftLog::new();
        for i in 1..=5 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        log.set_commit_index(5).expect("commit");

        let results = log.apply_batch(2).expect("batch");
        assert_eq!(results.len(), 2);
        assert_eq!(log.applied_index(), 2);
    }

    #[test]
    fn test_apply_batch_rollback() {
        let mut log = RaftLog::new();
        for i in 1..=5 {
            log.append(1, Command::from_str(&format!("cmd{}", i)));
        }
        log.set_commit_index(5).expect("commit");

        log.set_apply_callback(|entry| {
            if entry.index == 3 {
                return Err(RaftError::StateMachineError {
                    message: "fail".into(),
                });
            }
            Ok(Vec::new())
        });

        let err = log.apply_batch(5).expect_err("should fail");
        assert!(matches!(err, RaftError::StateMachineError { .. }));
        // Unlike `apply_committed_entries`, a failed batch restores the
        // applied index to its value before the call.
        assert_eq!(log.applied_index(), 0);
    }

    // Without a callback, entries still advance `applied_index` and produce
    // empty outputs.
    #[test]
    fn test_apply_no_callback() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("x"));
        log.append(1, Command::from_str("y"));
        log.set_commit_index(2).expect("commit");

        let results = log.apply_committed_entries().expect("apply");
        assert_eq!(results.len(), 2);
        assert!(results[0].output.is_empty());
        assert!(results[1].output.is_empty());
        assert_eq!(log.applied_index(), 2);
    }

    #[test]
    fn test_apply_empty() {
        let mut log = RaftLog::new();
        // Nothing is committed, so there is nothing to apply.
        let results = log.apply_committed_entries().expect("apply");
        assert!(results.is_empty());
    }

    // A second apply pass finds nothing left to do.
    #[test]
    fn test_apply_idempotent() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("a"));
        log.set_commit_index(1).expect("commit");

        let r1 = log.apply_committed_entries().expect("first apply");
        assert_eq!(r1.len(), 1);

        let r2 = log.apply_committed_entries().expect("second apply");
        assert!(r2.is_empty());
    }

    #[test]
    fn test_pending_apply_count() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("a"));
        log.append(1, Command::from_str("b"));
        log.append(1, Command::from_str("c"));
        log.set_commit_index(3).expect("commit");

        assert_eq!(log.pending_apply_count(), 3);

        log.set_applied_index(1).expect("apply");
        assert_eq!(log.pending_apply_count(), 2);

        log.set_applied_index(3).expect("apply");
        assert_eq!(log.pending_apply_count(), 0);
    }

    #[test]
    fn test_is_fully_applied() {
        let mut log = RaftLog::new();
        // An empty log has nothing pending.
        assert!(log.is_fully_applied());
        log.append(1, Command::from_str("a"));
        log.set_commit_index(1).expect("commit");
        assert!(!log.is_fully_applied());

        log.set_applied_index(1).expect("apply");
        assert!(log.is_fully_applied());
    }

    // The snapshot descriptor reflects the applied index and that entry's term.
    #[test]
    fn test_create_snapshot() {
        let mut log = RaftLog::new();
        log.append(1, Command::from_str("a"));
        log.append(2, Command::from_str("b"));
        log.append(2, Command::from_str("c"));
        log.set_commit_index(3).expect("commit");
        log.set_applied_index(3).expect("apply");

        let snap = log.create_snapshot().expect("snapshot");
        assert_eq!(snap.last_included_index, 3);
        assert_eq!(snap.last_included_term, 2);
        assert!(snap.data.is_empty());
    }

    // Exercises the `StateMachine` trait with a small key-value store:
    // "key=value" commands write, a bare "key" command reads.
    #[test]
    fn test_state_machine_trait() {
        use std::collections::HashMap;

        struct KvStateMachine {
            // Key-value pairs written so far.
            store: HashMap<String, String>,
        }

        impl KvStateMachine {
            fn new() -> Self {
                Self {
                    store: HashMap::new(),
                }
            }
        }

        impl StateMachine for KvStateMachine {
            fn apply(&mut self, entry: &LogEntry) -> RaftResult<Vec<u8>> {
                let text = std::str::from_utf8(&entry.command.data).map_err(|e| {
                    RaftError::StateMachineError {
                        message: format!("invalid utf8: {}", e),
                    }
                })?;
                let parts: Vec<&str> = text.splitn(2, '=').collect();
                if parts.len() == 2 {
                    self.store
                        .insert(parts[0].to_string(), parts[1].to_string());
                    Ok(b"OK".to_vec())
                } else {
                    // No '=' present: treat the command as a read of that key.
                    let val = self.store.get(parts[0]).cloned().unwrap_or_default();
                    Ok(val.into_bytes())
                }
            }

            fn snapshot(&self) -> RaftResult<Vec<u8>> {
                // One "key=value" line per stored pair.
                let mut buf = Vec::new();
                for (k, v) in &self.store {
                    buf.extend_from_slice(k.as_bytes());
                    buf.push(b'=');
                    buf.extend_from_slice(v.as_bytes());
                    buf.push(b'\n');
                }
                Ok(buf)
            }

            fn restore(&mut self, snapshot: &[u8]) -> RaftResult<()> {
                self.store.clear();
                let text =
                    std::str::from_utf8(snapshot).map_err(|e| RaftError::StateMachineError {
                        message: format!("invalid utf8: {}", e),
                    })?;
                for line in text.lines() {
                    let parts: Vec<&str> = line.splitn(2, '=').collect();
                    if parts.len() == 2 {
                        self.store
                            .insert(parts[0].to_string(), parts[1].to_string());
                    }
                }
                Ok(())
            }
        }

        let mut sm = KvStateMachine::new();
        // Two writes.
        let entry1 = LogEntry::new(1, 1, Command::from_str("foo=bar"));
        let entry2 = LogEntry::new(1, 2, Command::from_str("baz=qux"));

        let out1 = sm.apply(&entry1).expect("apply1");
        assert_eq!(out1, b"OK");
        let out2 = sm.apply(&entry2).expect("apply2");
        assert_eq!(out2, b"OK");

        // Snapshot the populated machine.
        let snap = sm.snapshot().expect("snapshot");
        assert!(!snap.is_empty());

        // Restoring into a fresh machine reproduces the stored pairs.
        let mut sm2 = KvStateMachine::new();
        sm2.restore(&snap).expect("restore");
        assert_eq!(sm2.store.get("foo").map(|s| s.as_str()), Some("bar"));
        assert_eq!(sm2.store.get("baz").map(|s| s.as_str()), Some("qux"));

        // A bare key is a read against the restored state.
        let entry3 = LogEntry::new(1, 3, Command::from_str("foo"));
        let out3 = sm2.apply(&entry3).expect("apply3");
        assert_eq!(out3, b"bar");
    }
}