// raft_lite/model_check.rs

use crate::config::{RaftConfig, RaftParams};
use crate::raft_protocol::Event;
use crate::raft_protocol::{
    LogRequestArgs, LogResponseArgs, RaftProtocol, Role, StepOutput, VoteRequestArgs,
    VoteResponseArgs,
};
use crate::runner::CheckerRunner;
use stateright::actor::model_timeout;
use stateright::actor::{Actor, ActorModel, Id, Network, Out};
use stateright::Expectation;
use std::borrow::Cow;
use std::collections::HashSet;

/// Stateright actor that hosts a single Raft node inside the model.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CheckerActor {
    /// Every server id in the cluster (this actor's own id included), handed to `RaftConfig`.
    peer_ids: Vec<String>,
}

18#[derive(Clone, Debug, Eq, Hash, PartialEq)]
19pub enum CheckerMessage {
20    VoteRequest(VoteRequestArgs),
21    VoteResponse(VoteResponseArgs),
22    LogRequest(LogRequestArgs),
23    LogResponse(LogResponseArgs),
24    Broadcast(Vec<u8>),
25}
26
/// Timers an actor can arm; when one fires, the same-named protocol event is stepped.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CheckerTimer {
    /// Delivered to the protocol as `Event::ElectionTimeout`.
    ElectionTimeout,
    /// Delivered to the protocol as `Event::ReplicationTimeout`.
    ReplicationTimeout,
}

33#[derive(Clone, Debug, Eq, Hash, PartialEq)]
34pub struct CheckerState {
35    delivered_messages: Vec<Vec<u8>>,
36    raft_protocol: RaftProtocol<CheckerRunner>,
37}
38
39fn process_event(
40    state: &mut Cow<CheckerState>,
41    event: Event,
42    o: &mut Out<CheckerActor>,
43) {
44    let state = state.to_mut();
45    let step_output = state.raft_protocol.step(event);
46    for output in step_output {
47        match output {
48            StepOutput::DeliverMessage(payload) => {
49                state.delivered_messages.push(payload);
50            }
51            StepOutput::VoteRequest(peer_id, vote_request_args) => {
52                o.send(
53                    Id::from(peer_id as usize),
54                    CheckerMessage::VoteRequest(vote_request_args),
55                );
56            }
57            StepOutput::VoteResponse(peer_id, vote_response_args) => {
58                o.send(
59                    Id::from(peer_id as usize),
60                    CheckerMessage::VoteResponse(vote_response_args),
61                );
62            }
63            StepOutput::LogRequest(peer_id, log_request_args) => {
64                o.send(
65                    Id::from(peer_id as usize),
66                    CheckerMessage::LogRequest(log_request_args),
67                );
68            }
69            StepOutput::LogResponse(peer_id, log_response_args) => {
70                o.send(
71                    Id::from(peer_id as usize),
72                    CheckerMessage::LogResponse(log_response_args),
73                );
74            }
75            StepOutput::Broadcast(peer_id, payload) => {
76                o.send(
77                    Id::from(peer_id as usize),
78                    CheckerMessage::Broadcast(payload),
79                );
80            }
81            StepOutput::ElectionTimerReset => {
82                o.set_timer(CheckerTimer::ElectionTimeout, model_timeout());
83            }
84            StepOutput::ReplicateTimerReset => {
85                o.set_timer(CheckerTimer::ReplicationTimeout, model_timeout());
86            }
87        }
88    }
89}
90
91impl Actor for CheckerActor {
92    type Msg = CheckerMessage;
93    type Timer = CheckerTimer;
94    type State = CheckerState;
95
96    fn on_start(&self, id: Id, o: &mut Out<Self>) -> Self::State {
97        let id: usize = id.into();
98        let checker_runner = CheckerRunner::new(id as u64, self.peer_ids.len());
99
100        let raft_config = RaftConfig::new(
101            self.peer_ids.clone(),
102            id.to_string(),
103            RaftParams::default(),
104            None,
105        );
106        let raft_protocol = RaftProtocol::new(raft_config, checker_runner);
107        let state = CheckerState {
108            delivered_messages: vec![],
109            raft_protocol,
110        };
111
112        // set timer
113        o.set_timer(CheckerTimer::ElectionTimeout, model_timeout());
114        o.set_timer(CheckerTimer::ReplicationTimeout, model_timeout());
115        // broadcast a message (the id of the actor)
116        o.send(
117            Id::from(id),
118            CheckerMessage::Broadcast(id.to_string().into_bytes()),
119        );
120        state
121    }
122
123    fn on_msg(
124        &self,
125        _id: Id,
126        state: &mut Cow<Self::State>,
127        _src: Id,
128        msg: Self::Msg,
129        o: &mut Out<Self>,
130    ) {
131        let event = match msg {
132            CheckerMessage::VoteRequest(vote_request_args) => Event::VoteRequest(vote_request_args),
133            CheckerMessage::VoteResponse(vote_response_args) => {
134                Event::VoteResponse(vote_response_args)
135            }
136            CheckerMessage::LogRequest(log_request_args) => Event::LogRequest(log_request_args),
137            CheckerMessage::LogResponse(log_response_args) => Event::LogResponse(log_response_args),
138            CheckerMessage::Broadcast(payload) => Event::Broadcast(payload),
139        };
140        process_event(state, event, o);
141    }
142
143    fn on_timeout(
144        &self,
145        _id: Id,
146        state: &mut Cow<Self::State>,
147        timer: &Self::Timer,
148        o: &mut Out<Self>,
149    ) {
150        let event = match timer {
151            CheckerTimer::ElectionTimeout => Event::ElectionTimeout,
152            CheckerTimer::ReplicationTimeout => Event::ReplicationTimeout,
153        };
154        process_event(state, event, o);
155    }
156}
157
158#[derive(Clone)]
159pub struct RaftModelCfg {
160    pub server_count: usize,
161    pub network: Network<<CheckerActor as Actor>::Msg>,
162}
163
164impl RaftModelCfg {
165    pub fn into_model(self) -> ActorModel<CheckerActor, Self> {
166        let peers: Vec<String> = (0..self.server_count).map(|i| i.to_string()).collect();
167        ActorModel::new(self.clone(), ())
168            .max_crashes((self.server_count - 1) / 2)
169            .actors((0..self.server_count).map(|_| CheckerActor {
170                peer_ids: peers.clone(),
171            }))
172            .init_network(self.network)
173            .property(Expectation::Sometimes, "Election Liveness", |_, state| {
174                state.actor_states.iter().any(|s| {
175                    s.raft_protocol.state.current_role == crate::raft_protocol::Role::Leader
176                })
177            })
178            .property(Expectation::Sometimes, "Log Liveness", |_, state| {
179                state
180                    .actor_states
181                    .iter()
182                    .any(|s| s.raft_protocol.state.commit_length > 0)
183            })
184            .property(Expectation::Always, "Election Safety", |_, state| {
185                // at most one leader can be elected in a given term
186
187                let mut leaders_term = HashSet::new();
188                for s in &state.actor_states {
189                    if s.raft_protocol.state.current_role == crate::raft_protocol::Role::Leader
190                        && !leaders_term.insert(s.raft_protocol.state.current_term)
191                    {
192                        return false;
193                    }
194                }
195                true
196            })
197            .property(Expectation::Always, "State Machine Safety", |_, state| {
198                // if a server has applied a log entry at a given index to its state machine, no other server will
199                // ever apply a different log entry for the same index.
200
201                let mut max_commit_length = 0;
202                let mut max_commit_length_actor_id = 0;
203                for (i, s) in state.actor_states.iter().enumerate() {
204                    if s.delivered_messages.len() > max_commit_length {
205                        max_commit_length = s.delivered_messages.len();
206                        max_commit_length_actor_id = i;
207                    }
208                }
209                if max_commit_length == 0 {
210                    return true;
211                }
212
213                for i in 0..max_commit_length {
214                    let ref_log = state.actor_states[max_commit_length_actor_id]
215                        .delivered_messages
216                        .get(i)
217                        .unwrap();
218                    for s in &state.actor_states {
219                        if let Some(log) = s.delivered_messages.get(i) {
220                            if log != ref_log {
221                                println!("log mismatch: {:?} != {:?}", log, ref_log);
222                                return false;
223                            }
224                        }
225                    }
226                }
227                true
228            })
229    }
230}