// raftmodel/raftmodel/raftserver.rs

use crate::{append_entries, LogEntry, RaftMessage};
use std::collections::HashSet;
use std::default::Default;
use std::fmt::Debug;

/// The role a Raft server is currently playing.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum ServerState {
    /// Accepts client requests and replicates its log to the other servers.
    Leader,
    /// Has timed out and is asking the other servers for votes.
    Candidate,
    /// The passive role every server starts in.
    Follower,
}

14/// A single Raft server.
15/// A server only communicates via messages
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct RaftServer<T>
18where
19    T: Sized + Clone + PartialEq + Eq + Debug + Default,
20{
21    // The following attributes are all per server
22    log: Vec<LogEntry<T>>,
23    state: ServerState,
24    current_term: usize,
25    voted_for: usize,
26    commit_index: usize,
27    last_applied: usize,
28
29    // The following attributes are used only on candidates
30    votes_responded: Option<HashSet<usize>>,
31    votes_granted: Option<HashSet<usize>>,
32    followers: Option<Vec<usize>>,
33
34    // The following attributes are used only on leaders
35    next_index: Option<Vec<usize>>,
36    match_index: Option<Vec<usize>>,
37}
38
39impl<T> RaftServer<T>
40where
41    T: Sized + Clone + PartialEq + Eq + Debug + Default,
42{
43    pub fn new(log: Vec<LogEntry<T>>) -> RaftServer<T> {
44        RaftServer {
45            log: log,
46            state: ServerState::Follower,
47            current_term: 1,
48            voted_for: 0,
49            commit_index: 0,
50            last_applied: 0,
51            votes_responded: Option::None,
52            votes_granted: Option::None,
53            followers: Option::None,
54            next_index: Option::None,
55            match_index: Option::None,
56        }
57    }
58
59    /// Returns the state of server as an immutable reference
60    pub fn server_state(&self) -> &ServerState {
61        return &self.state;
62    }
63
64    /// Returns an immutable reference to the server's log
65    pub fn log(&self) -> &Vec<LogEntry<T>> {
66        return &self.log;
67    }
68
69    /// This is the only public API to interact with the server
70    /// It takes an input message, dispatch it to some internal handlers based on the message type,
71    /// and returns a vector of messages as output.
72    /// It's up to the caller to decide what to do with the output messages
73    /// # Example
74    /// ```
75    /// use crate::raftmodel::*;
76    /// let log = create_empty_log::<String>();
77    ///
78    /// let mut servers = vec![
79    ///   RaftServer::new(log.clone()),
80    ///   RaftServer::new(log.clone()),
81    ///   RaftServer::new(log.clone()),
82    ///   RaftServer::new(log.clone()),
83    ///   RaftServer::new(log.clone()),
84    ///   RaftServer::new(log.clone()),
85    /// ];
86    ///
87    /// let message =
88    ///     RaftMessage::TimeOut {
89    ///         dest: 1,
90    ///         followers: (2..6).collect(),
91    ///     };
92    /// let server = &mut servers[1];
93    /// let responses = server.handle_message(message);
94    /// assert!(matches!(&responses[0], RaftMessage::RequestVoteRequest{..}));
95    /// assert_eq!(responses.len(), 4);
96    /// ```
97    pub fn handle_message(&mut self, msg: RaftMessage<T>) -> Vec<RaftMessage<T>> {
98        match msg {
99            RaftMessage::ClientRequest { dest, value } => self.handle_client_request(dest, value),
100            RaftMessage::BecomeLeader { dest, followers } => {
101                self.handle_become_leader(dest, followers)
102            }
103            RaftMessage::AppendEntries { dest, followers } => {
104                self.handle_append_entries(dest, followers)
105            }
106            RaftMessage::AppendEntriesRequest {
107                src,
108                dest,
109                term,
110                prev_index,
111                prev_term,
112                commit_index,
113                entries,
114            } => {
115                self.update_term(term);
116                self.handle_append_entries_request(
117                    src,
118                    dest,
119                    term,
120                    prev_index,
121                    prev_term,
122                    commit_index,
123                    entries,
124                )
125            }
126            RaftMessage::AppendEntriesResponse {
127                src,
128                dest,
129                term,
130                success,
131                match_index,
132            } => {
133                if term < self.current_term {
134                    return vec![];
135                }
136                self.update_term(term);
137                self.handle_append_entries_response(src, dest, term, success, match_index)
138            }
139            RaftMessage::RequestVoteRequest {
140                src,
141                dest,
142                term,
143                last_log_index,
144                last_log_term,
145            } => {
146                self.update_term(term);
147                self.handle_request_vote_request(src, dest, term, last_log_index, last_log_term)
148            }
149            RaftMessage::RequestVoteResponse {
150                src,
151                dest,
152                term,
153                vote_granted,
154            } => {
155                if term < self.current_term {
156                    return vec![];
157                }
158                self.update_term(term);
159                self.handle_request_vote_response(src, dest, term, vote_granted)
160            }
161            RaftMessage::TimeOut { dest, followers } => self.handle_time_out(dest, followers),
162        }
163    }
164
165    fn handle_client_request(&mut self, dest: usize, value: T) -> Vec<RaftMessage<T>> {
166        if self.state != ServerState::Leader {
167            return vec![];
168        }
169        let entries = vec![LogEntry {
170            term: self.current_term,
171            item: value,
172        }];
173        let prev_index = self.log.len() - 1;
174        let prev_term = self.log[prev_index].term;
175        // Call raftlog::append_entries
176        let success = append_entries(&mut self.log, prev_index, prev_term, entries);
177        if success {
178            self.match_index.as_mut().unwrap()[dest] = self.log.len() - 1;
179            self.next_index.as_mut().unwrap()[dest] = self.log.len();
180        }
181        vec![]
182    }
183
184    fn handle_become_leader(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
185        println!("{} become Leader", dest);
186        self.state = ServerState::Leader;
187        self.next_index = Some(vec![self.log.len(); followers.len() + 2]);
188        self.match_index = Some(vec![0; followers.len() + 2]);
189        return self.handle_append_entries(dest, followers);
190    }
191
192    fn handle_append_entries(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
193        if self.state != ServerState::Leader {
194            return vec![];
195        }
196        let mut msgs = vec![];
197        for follower in followers {
198            if follower == dest {
199                continue;
200            }
201            let next_idx = (self.next_index.as_ref().unwrap())[follower];
202            let prev_index = next_idx - 1;
203            let prev_term = if prev_index == 0 {
204                0
205            } else {
206                self.log[prev_index].term
207            };
208            let entries = self.log[next_idx..].to_vec();
209            msgs.push(RaftMessage::AppendEntriesRequest {
210                src: dest,
211                dest: follower,
212                term: self.current_term,
213                prev_index,
214                prev_term,
215                commit_index: self.commit_index,
216                entries,
217            });
218        }
219        msgs
220    }
221
222    fn handle_append_entries_request(
223        &mut self,
224        src: usize,
225        dest: usize,
226        term: usize,
227        prev_index: usize,
228        prev_term: usize,
229        commit_index: usize,
230        entries: Vec<LogEntry<T>>,
231    ) -> Vec<RaftMessage<T>> {
232        let mut msgs = vec![];
233        if term > self.current_term {
234            return msgs;
235        }
236        // Reject request
237        if term < self.current_term {
238            msgs.push(RaftMessage::AppendEntriesResponse {
239                src: dest,
240                dest: src,
241                term: self.current_term,
242                success: false,
243                match_index: 0,
244            });
245            return msgs;
246        }
247        // Return to follower state
248        if term == self.current_term && self.state == ServerState::Candidate {
249            self.state = ServerState::Follower;
250            return msgs;
251        }
252        let elen = entries.len();
253        if commit_index > self.commit_index {
254            self.commit_index = commit_index;
255            if self.commit_index > self.last_applied {
256                // To-do: send AppliedEntries message
257                self.last_applied = self.commit_index;
258            }
259        }
260        let success = append_entries(&mut self.log, prev_index, prev_term, entries);
261        let match_index = if success {
262            prev_index + elen
263        } else {
264            self.log.len() - 1
265        };
266        msgs.push(RaftMessage::AppendEntriesResponse {
267            src: dest,
268            dest: src,
269            term: self.current_term,
270            success,
271            match_index,
272        });
273
274        msgs
275    }
276
277    fn handle_append_entries_response(
278        &mut self,
279        src: usize,
280        dest: usize,
281        term: usize,
282        success: bool,
283        match_index: usize,
284    ) -> Vec<RaftMessage<T>> {
285        let mut msgs = vec![];
286        if term != self.current_term {
287            return msgs;
288        }
289        let next_index_mut = self.next_index.as_mut().unwrap();
290        let match_index_mut = self.match_index.as_mut().unwrap();
291        if !success {
292            next_index_mut[src] = next_index_mut[src] - 1;
293            let mut responses = self.handle_append_entries(dest, vec![src]);
294            msgs.append(&mut responses);
295        } else {
296            next_index_mut[src] = match_index + 1;
297            if match_index > match_index_mut[src] {
298                match_index_mut[src] = match_index;
299            }
300
301            self.advance_commit_index(dest);
302        }
303
304        msgs
305    }
306
307    fn handle_time_out(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
308        if self.state != ServerState::Follower && self.state != ServerState::Candidate {
309            return vec![];
310        }
311        self.state = ServerState::Candidate;
312        self.current_term = self.current_term + 1;
313        self.voted_for = dest;
314        self.votes_responded = Some(vec![dest].iter().cloned().collect());
315        self.votes_granted = Some(vec![dest].iter().cloned().collect());
316        self.followers = Some(followers.clone());
317        self.request_vote(dest, followers)
318    }
319
320    fn request_vote(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
321        let mut msgs = vec![];
322        if self.state != ServerState::Candidate {
323            return msgs;
324        }
325        for follower in followers {
326            if self.votes_responded.as_ref().unwrap().contains(&follower) {
327                continue;
328            }
329            let last_log_index = self.log.len() - 1;
330            let last_log_term = if last_log_index == 0 {
331                0
332            } else {
333                self.log[last_log_index].term
334            };
335            msgs.push(RaftMessage::RequestVoteRequest {
336                src: dest,
337                dest: follower,
338                term: self.current_term,
339                last_log_index: last_log_index,
340                last_log_term: last_log_term,
341            });
342            // dbg!(msgs.clone());
343        }
344        msgs
345    }
346
347    fn handle_request_vote_request(
348        &mut self,
349        src: usize,
350        dest: usize,
351        term: usize,
352        last_log_index: usize,
353        last_log_term: usize,
354    ) -> Vec<RaftMessage<T>> {
355        let mut msgs = vec![];
356        let last_term = if self.log.len() <= 1 {
357            0
358        } else {
359            self.log.last().unwrap().term
360        };
361        let log_ok = (last_log_term > last_term)
362            || (last_log_term == last_term && last_log_index >= self.log.len() - 1);
363        let grant =
364            (term == self.current_term) && log_ok && (self.voted_for == 0 || self.voted_for == src);
365        if term <= self.current_term {
366            if grant {
367                self.voted_for = src;
368            }
369            msgs.push(RaftMessage::RequestVoteResponse {
370                src: dest,
371                dest: src,
372                term: self.current_term,
373                vote_granted: grant,
374            });
375        }
376        // dbg!(msgs.clone());
377        msgs
378    }
379
380    fn handle_request_vote_response(
381        &mut self,
382        src: usize,
383        dest: usize,
384        term: usize,
385        vote_granted: bool,
386    ) -> Vec<RaftMessage<T>> {
387        // dbg!(src);
388        // dbg!(vote_granted);
389        // dbg!(term);
390        // dbg!(self.current_term);
391        // dbg!(self.state.clone());
392        if term != self.current_term {
393            //|| self.state != ServerState::Candidate {
394            return vec![];
395        }
396        self.votes_responded.as_mut().unwrap().insert(src);
397        if vote_granted {
398            self.votes_granted.as_mut().unwrap().insert(src);
399        }
400        // dbg!(self.votes_responded.clone());
401        // dbg!(self.votes_granted.clone());
402        let quorum = (self.followers.as_ref().unwrap().len() + 2) / 2;
403        // dbg!(quorum);
404        let followers = self.followers.as_ref().unwrap().clone();
405        if self.votes_granted.as_ref().unwrap().len() >= quorum {
406            self.handle_become_leader(dest, followers);
407        }
408        vec![]
409    }
410
411    fn update_term(&mut self, mterm: usize) {
412        if mterm > self.current_term {
413            self.current_term = mterm;
414            self.state = ServerState::Follower;
415            self.voted_for = 0;
416        }
417    }
418
419    fn advance_commit_index(&mut self, dest: usize) {
420        let mut match_index_cp = self.match_index.as_mut().unwrap().clone();
421
422        match_index_cp.sort_unstable();
423        let mid = match_index_cp.len() / 2 as usize;
424        let max_agree_index = match_index_cp[mid];
425        if self.log[max_agree_index].term >= self.current_term {
426            self.commit_index = max_agree_index;
427        }
428        if self.commit_index > self.last_applied {
429            // To-do: send ApplyEntries message
430            self.last_applied = self.commit_index;
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;

    /// Delivers `initial_message`, then every message it transitively causes,
    /// in FIFO order until no messages remain in flight.
    fn run_message<T>(initial_message: RaftMessage<T>, servers: &mut Vec<RaftServer<T>>)
    where
        T: Sized + Clone + PartialEq + Eq + Debug + Default,
    {
        let mut in_flight = VecDeque::from(vec![initial_message]);
        while let Some(msg) = in_flight.pop_front() {
            let target = match &msg {
                RaftMessage::ClientRequest { dest, .. }
                | RaftMessage::BecomeLeader { dest, .. }
                | RaftMessage::AppendEntries { dest, .. }
                | RaftMessage::AppendEntriesRequest { dest, .. }
                | RaftMessage::AppendEntriesResponse { dest, .. }
                | RaftMessage::RequestVoteRequest { dest, .. }
                | RaftMessage::RequestVoteResponse { dest, .. }
                | RaftMessage::TimeOut { dest, .. } => *dest,
            };
            in_flight.extend(servers[target].handle_message(msg));
        }
    }

    /// Builds a log: the placeholder entry at index 0, then one `"a"` entry
    /// per term in `terms`.
    fn make_log(terms: Vec<usize>) -> Vec<LogEntry<String>> {
        std::iter::once(LogEntry::default())
            .chain(terms.into_iter().map(|term| LogEntry {
                term,
                item: "a".to_string(),
            }))
            .collect()
    }

    /// Builds one server per log and moves every server to `term`.
    fn cluster_at_term(logs: Vec<Vec<LogEntry<String>>>, term: usize) -> Vec<RaftServer<String>> {
        logs.into_iter()
            .map(|log| {
                let mut server = RaftServer::new(log);
                server.current_term = term;
                server
            })
            .collect()
    }

    /// Logs used by the `test_figure_6*` tests (slot 0 is unused).
    fn figure_6_logs() -> Vec<Vec<LogEntry<String>>> {
        vec![
            vec![LogEntry::default()],
            make_log(vec![1, 1, 1, 2, 3, 3, 3, 3]),
            make_log(vec![1, 1, 1, 2, 3]),
            make_log(vec![1, 1, 1, 2, 3, 3, 3, 3]),
            make_log(vec![1, 1]),
            make_log(vec![1, 1, 1, 2, 3, 3, 3]),
        ]
    }

    #[test]
    fn test_replicate() {
        let mut servers = vec![
            RaftServer::new(vec![]),
            RaftServer::new(vec![LogEntry::default(), LogEntry { term: 1, item: "x" }]),
            RaftServer::new(vec![LogEntry::default()]),
            RaftServer::new(vec![LogEntry::default()]),
        ];
        let followers = vec![2, 3];

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: followers.clone(),
            },
            &mut servers,
        );
        run_message(RaftMessage::AppendEntries { dest: 1, followers }, &mut servers);

        assert_eq!(servers[1].log, servers[2].log);
    }

    #[test]
    fn test_figure_6() {
        let mut servers = cluster_at_term(figure_6_logs(), 3);

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );

        // Every real server ends up with the leader's log...
        let leader_log = &servers[1].log;
        assert!(servers.iter().skip(1).all(|server| &server.log == leader_log));
        // ...and the leader has committed all of it.
        assert_eq!(servers[1].commit_index, servers[1].log.len() - 1);
    }

    #[test]
    fn test_figure_7() {
        let logs = vec![
            vec![LogEntry::default()],
            make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6]),
            make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6]),
            make_log(vec![1, 1, 1, 4]),
            make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6, 6]),
            make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6, 7, 7]),
            make_log(vec![1, 1, 1, 4, 4, 4, 4]),
            make_log(vec![1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
        ];
        let mut servers = cluster_at_term(logs, 8);
        servers[1].commit_index = 10;

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: (2..8).collect(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::ClientRequest {
                dest: 1,
                value: "x".to_string(),
            },
            &mut servers,
        );

        // Round one lets the leader advance its commit index; round two
        // carries that commit index to the followers.
        for _round in 0..2 {
            run_message(
                RaftMessage::AppendEntries {
                    dest: 1,
                    followers: (2..8).collect(),
                },
                &mut servers,
            );
        }

        let leader_log = &servers[1].log;
        assert!(servers.iter().skip(1).all(|server| &server.log == leader_log));
        assert_eq!(servers[1].commit_index, servers[1].log.len() - 1);
    }

    #[test]
    fn test_commit() {
        let logs = vec![
            vec![LogEntry::default()],
            make_log(vec![1, 1, 1, 2, 2]),
            make_log(vec![1, 1, 1, 2, 2]),
            make_log(vec![1, 1, 1, 2, 2]),
        ];
        let mut servers = cluster_at_term(logs, 2);

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );
        run_message(
            RaftMessage::ClientRequest {
                dest: 1,
                value: "x".to_string(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );

        // The leader commits the new entry at once; followers only learn the
        // new commit index from the leader's next AppendEntries.
        assert_eq!(servers[1].commit_index, 6);
        assert_eq!(servers[1].last_applied, 6);
        for follower in servers.iter().skip(2) {
            assert!(follower.commit_index == 5);
        }
        for follower in servers.iter().skip(2) {
            assert!(follower.last_applied == 5);
        }

        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );
        for follower in servers.iter().skip(2) {
            assert!(follower.commit_index == 6);
        }
        for follower in servers.iter().skip(2) {
            assert!(follower.last_applied == 6);
        }
    }

    #[test]
    fn test_figure_6_election() {
        // Server 1 has the most up-to-date log and collects every vote.
        let mut servers = cluster_at_term(figure_6_logs(), 3);
        run_message(
            RaftMessage::TimeOut {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );
        assert_eq!(servers[1].state, ServerState::Leader);
        assert_eq!(
            servers[1].votes_granted.clone().unwrap(),
            (1..6).collect::<HashSet<usize>>()
        );

        // Server 2's log is behind most peers: it gets only server 4's vote
        // (plus its own) and stays a candidate.
        let mut servers = cluster_at_term(figure_6_logs(), 3);
        run_message(
            RaftMessage::TimeOut {
                dest: 2,
                followers: vec![1, 3, 4, 5],
            },
            &mut servers,
        );
        assert_eq!(servers[2].state, ServerState::Candidate);
        assert_eq!(
            servers[2].votes_granted.clone().unwrap(),
            [2, 4].iter().copied().collect::<HashSet<usize>>()
        );

        // Server 5 wins with exactly three votes.
        let mut servers = cluster_at_term(figure_6_logs(), 3);
        run_message(
            RaftMessage::TimeOut {
                dest: 5,
                followers: (1..5).collect(),
            },
            &mut servers,
        );
        assert_eq!(servers[5].state, ServerState::Leader);
        assert_eq!(
            servers[5].votes_granted.clone().unwrap(),
            [2, 4, 5].iter().copied().collect::<HashSet<usize>>()
        );
    }
}
742        );
743    }
744}