1use crate::config::{RaftConfig, RaftParams};
2use crate::raft_protocol::Event;
3use crate::raft_protocol::{
4 LogRequestArgs, LogResponseArgs, RaftProtocol, StepOutput, VoteRequestArgs, VoteResponseArgs,
5};
6use crate::runner::{CheckerRunner};
7use stateright::actor::model_timeout;
8use stateright::actor::{Actor, ActorModel, Id, Network, Out};
9use stateright::Expectation;
10use std::borrow::Cow;
11use std::collections::HashSet;
12
/// Stateright actor that hosts one Raft server (`RaftProtocol<CheckerRunner>`) inside the model.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CheckerActor {
    /// Identifiers of every cluster member, as strings.
    ///
    /// `RaftModelCfg::into_model` fills this with `"0"`..`"server_count-1"` for every actor.
    /// `on_start` passes the list to `RaftConfig::new` and passes its length to `CheckerRunner::new`.
    peer_ids: Vec<String>,
}
17
/// Network messages exchanged between `CheckerActor`s.
///
/// The first four variants wrap the Raft RPC argument types one-to-one. `on_msg` converts each
/// into the corresponding `Event` variant.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CheckerMessage {
    VoteRequest(VoteRequestArgs),
    VoteResponse(VoteResponseArgs),
    LogRequest(LogRequestArgs),
    LogResponse(LogResponseArgs),
    /// Application payload to be broadcast through Raft.
    ///
    /// `on_start` sends one of these to the actor itself, carrying its id as text, and
    /// `process_event` forwards `StepOutput::Broadcast` outputs to peers with this variant.
    Broadcast(Vec<u8>),
}
26
/// Timers used by `CheckerActor`. Each one maps to the `Event` of the same name in `on_timeout`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CheckerTimer {
    /// Re-armed whenever the protocol emits `StepOutput::ElectionTimerReset`.
    ElectionTimeout,
    /// Re-armed whenever the protocol emits `StepOutput::ReplicateTimerReset`.
    ReplicationTimeout,
}
32
/// Per-actor model state. It is hashed and compared by Stateright, so it derives `Eq` and `Hash`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CheckerState {
    /// Payloads the protocol emitted through `StepOutput::DeliverMessage`, in emission order.
    ///
    /// The "State Machine Safety" property compares these sequences across actors.
    delivered_messages: Vec<Vec<u8>>,
    /// The Raft server driven by this actor.
    raft_protocol: RaftProtocol<CheckerRunner>,
}
38
39fn process_event(
40 state: &mut Cow<CheckerState>,
41 event: Event,
42 o: &mut Out<CheckerActor>,
43) {
44 let state = state.to_mut();
45 let step_output = state.raft_protocol.step(event);
46 for output in step_output {
47 match output {
48 StepOutput::DeliverMessage(payload) => {
49 state.delivered_messages.push(payload);
50 }
51 StepOutput::VoteRequest(peer_id, vote_request_args) => {
52 o.send(
53 Id::from(peer_id as usize),
54 CheckerMessage::VoteRequest(vote_request_args),
55 );
56 }
57 StepOutput::VoteResponse(peer_id, vote_response_args) => {
58 o.send(
59 Id::from(peer_id as usize),
60 CheckerMessage::VoteResponse(vote_response_args),
61 );
62 }
63 StepOutput::LogRequest(peer_id, log_request_args) => {
64 o.send(
65 Id::from(peer_id as usize),
66 CheckerMessage::LogRequest(log_request_args),
67 );
68 }
69 StepOutput::LogResponse(peer_id, log_response_args) => {
70 o.send(
71 Id::from(peer_id as usize),
72 CheckerMessage::LogResponse(log_response_args),
73 );
74 }
75 StepOutput::Broadcast(peer_id, payload) => {
76 o.send(
77 Id::from(peer_id as usize),
78 CheckerMessage::Broadcast(payload),
79 );
80 }
81 StepOutput::ElectionTimerReset => {
82 o.set_timer(CheckerTimer::ElectionTimeout, model_timeout());
83 }
84 StepOutput::ReplicateTimerReset => {
85 o.set_timer(CheckerTimer::ReplicationTimeout, model_timeout());
86 }
87 }
88 }
89}
90
91impl Actor for CheckerActor {
92 type Msg = CheckerMessage;
93 type Timer = CheckerTimer;
94 type State = CheckerState;
95
96 fn on_start(&self, id: Id, o: &mut Out<Self>) -> Self::State {
97 let id: usize = id.into();
98 let checker_runner = CheckerRunner::new(id as u64, self.peer_ids.len());
99
100 let raft_config = RaftConfig::new(
101 self.peer_ids.clone(),
102 id.to_string(),
103 RaftParams::default(),
104 None,
105 );
106 let raft_protocol = RaftProtocol::new(raft_config, checker_runner);
107 let state = CheckerState {
108 delivered_messages: vec![],
109 raft_protocol,
110 };
111
112 o.set_timer(CheckerTimer::ElectionTimeout, model_timeout());
114 o.set_timer(CheckerTimer::ReplicationTimeout, model_timeout());
115 o.send(
117 Id::from(id),
118 CheckerMessage::Broadcast(id.to_string().into_bytes()),
119 );
120 state
121 }
122
123 fn on_msg(
124 &self,
125 _id: Id,
126 state: &mut Cow<Self::State>,
127 _src: Id,
128 msg: Self::Msg,
129 o: &mut Out<Self>,
130 ) {
131 let event = match msg {
132 CheckerMessage::VoteRequest(vote_request_args) => Event::VoteRequest(vote_request_args),
133 CheckerMessage::VoteResponse(vote_response_args) => {
134 Event::VoteResponse(vote_response_args)
135 }
136 CheckerMessage::LogRequest(log_request_args) => Event::LogRequest(log_request_args),
137 CheckerMessage::LogResponse(log_response_args) => Event::LogResponse(log_response_args),
138 CheckerMessage::Broadcast(payload) => Event::Broadcast(payload),
139 };
140 process_event(state, event, o);
141 }
142
143 fn on_timeout(
144 &self,
145 _id: Id,
146 state: &mut Cow<Self::State>,
147 timer: &Self::Timer,
148 o: &mut Out<Self>,
149 ) {
150 let event = match timer {
151 CheckerTimer::ElectionTimeout => Event::ElectionTimeout,
152 CheckerTimer::ReplicationTimeout => Event::ReplicationTimeout,
153 };
154 process_event(state, event, o);
155 }
156}
157
158#[derive(Clone)]
159pub struct RaftModelCfg {
160 pub server_count: usize,
161 pub network: Network<<CheckerActor as Actor>::Msg>,
162}
163
164impl RaftModelCfg {
165 pub fn into_model(self) -> ActorModel<CheckerActor, Self> {
166 let peers: Vec<String> = (0..self.server_count).map(|i| i.to_string()).collect();
167 ActorModel::new(self.clone(), ())
168 .max_crashes((self.server_count - 1) / 2)
169 .actors((0..self.server_count).map(|_| CheckerActor {
170 peer_ids: peers.clone(),
171 }))
172 .init_network(self.network)
173 .property(Expectation::Sometimes, "Election Liveness", |_, state| {
174 state.actor_states.iter().any(|s| {
175 s.raft_protocol.state.current_role == crate::raft_protocol::Role::Leader
176 })
177 })
178 .property(Expectation::Sometimes, "Log Liveness", |_, state| {
179 state
180 .actor_states
181 .iter()
182 .any(|s| s.raft_protocol.state.commit_length > 0)
183 })
184 .property(Expectation::Always, "Election Safety", |_, state| {
185 let mut leaders_term = HashSet::new();
188 for s in &state.actor_states {
189 if s.raft_protocol.state.current_role == crate::raft_protocol::Role::Leader
190 && !leaders_term.insert(s.raft_protocol.state.current_term)
191 {
192 return false;
193 }
194 }
195 true
196 })
197 .property(Expectation::Always, "State Machine Safety", |_, state| {
198 let mut max_commit_length = 0;
202 let mut max_commit_length_actor_id = 0;
203 for (i, s) in state.actor_states.iter().enumerate() {
204 if s.delivered_messages.len() > max_commit_length {
205 max_commit_length = s.delivered_messages.len();
206 max_commit_length_actor_id = i;
207 }
208 }
209 if max_commit_length == 0 {
210 return true;
211 }
212
213 for i in 0..max_commit_length {
214 let ref_log = state.actor_states[max_commit_length_actor_id]
215 .delivered_messages
216 .get(i)
217 .unwrap();
218 for s in &state.actor_states {
219 if let Some(log) = s.delivered_messages.get(i) {
220 if log != ref_log {
221 println!("log mismatch: {:?} != {:?}", log, ref_log);
222 return false;
223 }
224 }
225 }
226 }
227 true
228 })
229 }
230}