browseraft/
raft.rs

1use gloo::timers::callback::Timeout;
2use serde::{Deserialize, Serialize};
3use std::{collections::HashSet, sync::Arc};
4
5use crate::{
6    rpc::{Message, Recipient},
7    NodeState,
8};
9
10use super::Node;
11
12#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Debug)]
13pub enum Role {
14    Follower,
15    Candidate,
16    Leader,
17}
18
19impl std::fmt::Display for Role {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            Role::Follower => write!(f, "Follower"),
23            Role::Candidate => write!(f, "Candidate"),
24            Role::Leader => write!(f, "Leader"),
25        }
26    }
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
30pub struct Peer(u32);
31
32impl Peer {
33    pub fn id(&self) -> u32 {
34        self.0
35    }
36}
37
38impl<T> Node<T>
39where
40    T: serde::ser::Serialize + serde::de::DeserializeOwned + 'static,
41{
42    pub fn peer(&self) -> Peer {
43        Peer(self.id)
44    }
45
46    pub fn is(&self, peer: &Peer) -> bool {
47        peer.id() == self.id
48    }
49
50    pub(crate) fn add_peer(&self, peer: Peer) {
51        let mut state = self.state.lock().expect("poisoned mutex!");
52        state.peers.insert(peer);
53
54        if state.role == Role::Leader {
55            self.send(Message::PeerSet(state.peers.clone()), Recipient::Everyone)
56        }
57    }
58
59    pub(crate) fn remove_peer(&self, peer: Peer) {
60        let mut state = self.state.lock().expect("poisoned mutex!");
61        state.peers.remove(&peer);
62    }
63
64    /// When the leader sees PeerAdded message, it sends out a PeerSet response
65    /// so that all nodes know
66    pub(crate) fn reconcile_peers(&self, peers: HashSet<Peer>) {
67        let mut state = self.state.lock().expect("poisoned mutex!");
68        state.peers = peers;
69    }
70
71    pub(crate) fn new_election_task(self: Arc<Self>) -> Timeout {
72        Timeout::new(self.election_timeout_ms, || self.start_election())
73    }
74
75    /// When election times out (the node hasn't heard from it's leader in the
76    /// specified interval), this node becomes a candidate and initiates an
77    /// election.
78    fn start_election(self: Arc<Self>) {
79        let mut state = self.state.lock().expect("poisoned mutex!");
80        match state.peers.len() {
81            // When only node, automatically win the election
82            1 => {
83                drop(state);
84                self.clone().win_election();
85            }
86
87            // When two nodes, we could possibly deadlock. Default to the node with lower id
88            2 => {
89                // This requires waiting until the lower node's election times out
90                let mut iter = state.peers.iter();
91                let lower = std::cmp::min(iter.next().unwrap(), iter.next().unwrap());
92                if self.peer().eq(lower) {
93                    drop(state);
94                    self.clone().win_election();
95                }
96            }
97
98            // With at least 3 nodes, elect normally
99            _ => {
100                let candidate = self.peer();
101                state.role = Role::Candidate;
102                state.term += 1;
103                state.votes.insert(self.peer());
104                state.voted_for = Some(candidate);
105                state.election_task = Some(self.clone().new_election_task());
106
107                self.call_on_role_change(Role::Candidate);
108
109                self.send(
110                    Message::VoteRequest {
111                        term: state.term,
112                        candidate,
113                    },
114                    Recipient::Everyone,
115                );
116            }
117        }
118    }
119
120    /// Receive a vote from the given follower
121    pub(crate) fn receive_vote(self: Arc<Self>, term: u32, follower: Peer) {
122        let mut state = self.state.lock().expect("poisoned mutex!");
123        if state.role != Role::Candidate {
124            return;
125        }
126
127        match term.cmp(&state.term) {
128            // Got vote for a future term // TODO?
129            std::cmp::Ordering::Greater => {}
130            // Got vote for the current term: continue
131            std::cmp::Ordering::Equal => {}
132            // Got vote for previous term. Ignore
133            std::cmp::Ordering::Less => return,
134        }
135
136        state.votes.insert(follower);
137
138        if state.votes.len() > (state.peers.len() / 2) {
139            drop(state);
140            self.win_election()
141        }
142    }
143
144    /// Win the current election and start sending out heartbeats
145    fn win_election(self: Arc<Self>) {
146        {
147            let mut state = self.state.lock().expect("poisoned mutex!");
148            // if state.role != Role::Candidate {
149            //     return;
150            // }
151
152            state.role = Role::Leader;
153            state.voted_for = None;
154            state.votes.clear();
155            state.election_task = None;
156
157            self.call_on_role_change(Role::Leader);
158        }
159        self.send_heartbeat();
160    }
161
162    pub(crate) fn receive_vote_request(self: Arc<Self>, term: u32, candidate: Peer) {
163        let mut state = self.state.lock().expect("poisoned mutex!");
164
165        if self.peer() == candidate {
166            return;
167        }
168
169        // Update term
170        match term.cmp(&state.term) {
171            std::cmp::Ordering::Less => return,
172            std::cmp::Ordering::Equal => (),
173            std::cmp::Ordering::Greater => {
174                state.term = term;
175                state.voted_for = None;
176                state.votes.clear();
177            }
178        }
179
180        if state.role == Role::Follower && state.voted_for.is_none() {
181            state.voted_for = Some(candidate);
182            self.send(
183                Message::VoteResponse {
184                    term,
185                    candidate,
186                    follower: Peer(self.id),
187                },
188                Recipient::Peer(candidate),
189            );
190        }
191        state.replace_election_task(Some(self.clone().new_election_task()));
192    }
193
194    pub(crate) fn new_heartbeat_task(self: Arc<Self>) -> Timeout {
195        Timeout::new(self.heartbeat_timeout_ms, || self.send_heartbeat())
196    }
197
198    fn send_heartbeat(self: Arc<Self>) {
199        let mut state = self.state.lock().expect("poisoned mutex!");
200        self.send(Message::Heartbeat { term: state.term }, Recipient::Everyone);
201        state.replace_heartbeat_task(Some(self.clone().new_heartbeat_task()));
202    }
203
204    pub(crate) fn receive_hearbeat(self: Arc<Self>, term: u32, _leader: &Peer) {
205        let mut state = self.state.lock().expect("poisoned mutex!");
206
207        match state.role {
208            // Leader's own heartbeat
209            Role::Leader => return,
210
211            // Someone else won an election
212            Role::Candidate => {
213                if term >= state.term {
214                    state.role = Role::Follower;
215                    state.voted_for = None;
216                    state.votes.clear();
217                    self.call_on_role_change(Role::Follower);
218                }
219            }
220
221            // Update term if there's a new term
222            Role::Follower => {
223                if term > state.term {
224                    state.term = term;
225                    state.voted_for = None;
226                }
227            }
228        }
229
230        state.replace_election_task(Some(self.clone().new_election_task()));
231    }
232}
233
234impl<T> Drop for Node<T>
235where
236    T: serde::ser::Serialize + serde::de::DeserializeOwned + 'static,
237{
238    fn drop(&mut self) {
239        self.stop();
240        self.send(Message::PeerRemoved, Recipient::Everyone) // TODO: replace peer-removed with a leader heartbeat response count
241    }
242}
243
244impl NodeState {
245    pub(crate) fn replace_election_task(&mut self, new_task: Option<Timeout>) {
246        if let Some(old_task) = if let Some(new_task) = new_task {
247            self.election_task.replace(new_task)
248        } else {
249            self.election_task.take()
250        } {
251            old_task.cancel();
252        }
253    }
254
255    pub(crate) fn replace_heartbeat_task(&mut self, new_task: Option<Timeout>) {
256        if let Some(old_task) = if let Some(new_task) = new_task {
257            self.heartbeat_task.replace(new_task)
258        } else {
259            self.heartbeat_task.take()
260        } {
261            old_task.cancel();
262        }
263    }
264}