1use crate::{append_entries, LogEntry, RaftMessage};
2use std::collections::HashSet;
3use std::default::Default;
4use std::fmt::Debug;
5
/// Role a Raft server currently plays.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash` as well: callers can
/// compare or store states by value without cloning.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum ServerState {
    /// Accepts client requests and replicates its log to the other servers.
    Leader,
    /// Has timed out and is collecting votes for a new term.
    Candidate,
    /// Receives entries from a leader; every server starts in this state.
    Follower,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct RaftServer<T>
18where
19 T: Sized + Clone + PartialEq + Eq + Debug + Default,
20{
21 log: Vec<LogEntry<T>>,
23 state: ServerState,
24 current_term: usize,
25 voted_for: usize,
26 commit_index: usize,
27 last_applied: usize,
28
29 votes_responded: Option<HashSet<usize>>,
31 votes_granted: Option<HashSet<usize>>,
32 followers: Option<Vec<usize>>,
33
34 next_index: Option<Vec<usize>>,
36 match_index: Option<Vec<usize>>,
37}
38
39impl<T> RaftServer<T>
40where
41 T: Sized + Clone + PartialEq + Eq + Debug + Default,
42{
43 pub fn new(log: Vec<LogEntry<T>>) -> RaftServer<T> {
44 RaftServer {
45 log: log,
46 state: ServerState::Follower,
47 current_term: 1,
48 voted_for: 0,
49 commit_index: 0,
50 last_applied: 0,
51 votes_responded: Option::None,
52 votes_granted: Option::None,
53 followers: Option::None,
54 next_index: Option::None,
55 match_index: Option::None,
56 }
57 }
58
59 pub fn server_state(&self) -> &ServerState {
61 return &self.state;
62 }
63
64 pub fn log(&self) -> &Vec<LogEntry<T>> {
66 return &self.log;
67 }
68
69 pub fn handle_message(&mut self, msg: RaftMessage<T>) -> Vec<RaftMessage<T>> {
98 match msg {
99 RaftMessage::ClientRequest { dest, value } => self.handle_client_request(dest, value),
100 RaftMessage::BecomeLeader { dest, followers } => {
101 self.handle_become_leader(dest, followers)
102 }
103 RaftMessage::AppendEntries { dest, followers } => {
104 self.handle_append_entries(dest, followers)
105 }
106 RaftMessage::AppendEntriesRequest {
107 src,
108 dest,
109 term,
110 prev_index,
111 prev_term,
112 commit_index,
113 entries,
114 } => {
115 self.update_term(term);
116 self.handle_append_entries_request(
117 src,
118 dest,
119 term,
120 prev_index,
121 prev_term,
122 commit_index,
123 entries,
124 )
125 }
126 RaftMessage::AppendEntriesResponse {
127 src,
128 dest,
129 term,
130 success,
131 match_index,
132 } => {
133 if term < self.current_term {
134 return vec![];
135 }
136 self.update_term(term);
137 self.handle_append_entries_response(src, dest, term, success, match_index)
138 }
139 RaftMessage::RequestVoteRequest {
140 src,
141 dest,
142 term,
143 last_log_index,
144 last_log_term,
145 } => {
146 self.update_term(term);
147 self.handle_request_vote_request(src, dest, term, last_log_index, last_log_term)
148 }
149 RaftMessage::RequestVoteResponse {
150 src,
151 dest,
152 term,
153 vote_granted,
154 } => {
155 if term < self.current_term {
156 return vec![];
157 }
158 self.update_term(term);
159 self.handle_request_vote_response(src, dest, term, vote_granted)
160 }
161 RaftMessage::TimeOut { dest, followers } => self.handle_time_out(dest, followers),
162 }
163 }
164
165 fn handle_client_request(&mut self, dest: usize, value: T) -> Vec<RaftMessage<T>> {
166 if self.state != ServerState::Leader {
167 return vec![];
168 }
169 let entries = vec![LogEntry {
170 term: self.current_term,
171 item: value,
172 }];
173 let prev_index = self.log.len() - 1;
174 let prev_term = self.log[prev_index].term;
175 let success = append_entries(&mut self.log, prev_index, prev_term, entries);
177 if success {
178 self.match_index.as_mut().unwrap()[dest] = self.log.len() - 1;
179 self.next_index.as_mut().unwrap()[dest] = self.log.len();
180 }
181 vec![]
182 }
183
184 fn handle_become_leader(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
185 println!("{} become Leader", dest);
186 self.state = ServerState::Leader;
187 self.next_index = Some(vec![self.log.len(); followers.len() + 2]);
188 self.match_index = Some(vec![0; followers.len() + 2]);
189 return self.handle_append_entries(dest, followers);
190 }
191
192 fn handle_append_entries(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
193 if self.state != ServerState::Leader {
194 return vec![];
195 }
196 let mut msgs = vec![];
197 for follower in followers {
198 if follower == dest {
199 continue;
200 }
201 let next_idx = (self.next_index.as_ref().unwrap())[follower];
202 let prev_index = next_idx - 1;
203 let prev_term = if prev_index == 0 {
204 0
205 } else {
206 self.log[prev_index].term
207 };
208 let entries = self.log[next_idx..].to_vec();
209 msgs.push(RaftMessage::AppendEntriesRequest {
210 src: dest,
211 dest: follower,
212 term: self.current_term,
213 prev_index,
214 prev_term,
215 commit_index: self.commit_index,
216 entries,
217 });
218 }
219 msgs
220 }
221
222 fn handle_append_entries_request(
223 &mut self,
224 src: usize,
225 dest: usize,
226 term: usize,
227 prev_index: usize,
228 prev_term: usize,
229 commit_index: usize,
230 entries: Vec<LogEntry<T>>,
231 ) -> Vec<RaftMessage<T>> {
232 let mut msgs = vec![];
233 if term > self.current_term {
234 return msgs;
235 }
236 if term < self.current_term {
238 msgs.push(RaftMessage::AppendEntriesResponse {
239 src: dest,
240 dest: src,
241 term: self.current_term,
242 success: false,
243 match_index: 0,
244 });
245 return msgs;
246 }
247 if term == self.current_term && self.state == ServerState::Candidate {
249 self.state = ServerState::Follower;
250 return msgs;
251 }
252 let elen = entries.len();
253 if commit_index > self.commit_index {
254 self.commit_index = commit_index;
255 if self.commit_index > self.last_applied {
256 self.last_applied = self.commit_index;
258 }
259 }
260 let success = append_entries(&mut self.log, prev_index, prev_term, entries);
261 let match_index = if success {
262 prev_index + elen
263 } else {
264 self.log.len() - 1
265 };
266 msgs.push(RaftMessage::AppendEntriesResponse {
267 src: dest,
268 dest: src,
269 term: self.current_term,
270 success,
271 match_index,
272 });
273
274 msgs
275 }
276
277 fn handle_append_entries_response(
278 &mut self,
279 src: usize,
280 dest: usize,
281 term: usize,
282 success: bool,
283 match_index: usize,
284 ) -> Vec<RaftMessage<T>> {
285 let mut msgs = vec![];
286 if term != self.current_term {
287 return msgs;
288 }
289 let next_index_mut = self.next_index.as_mut().unwrap();
290 let match_index_mut = self.match_index.as_mut().unwrap();
291 if !success {
292 next_index_mut[src] = next_index_mut[src] - 1;
293 let mut responses = self.handle_append_entries(dest, vec![src]);
294 msgs.append(&mut responses);
295 } else {
296 next_index_mut[src] = match_index + 1;
297 if match_index > match_index_mut[src] {
298 match_index_mut[src] = match_index;
299 }
300
301 self.advance_commit_index(dest);
302 }
303
304 msgs
305 }
306
307 fn handle_time_out(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
308 if self.state != ServerState::Follower && self.state != ServerState::Candidate {
309 return vec![];
310 }
311 self.state = ServerState::Candidate;
312 self.current_term = self.current_term + 1;
313 self.voted_for = dest;
314 self.votes_responded = Some(vec![dest].iter().cloned().collect());
315 self.votes_granted = Some(vec![dest].iter().cloned().collect());
316 self.followers = Some(followers.clone());
317 self.request_vote(dest, followers)
318 }
319
320 fn request_vote(&mut self, dest: usize, followers: Vec<usize>) -> Vec<RaftMessage<T>> {
321 let mut msgs = vec![];
322 if self.state != ServerState::Candidate {
323 return msgs;
324 }
325 for follower in followers {
326 if self.votes_responded.as_ref().unwrap().contains(&follower) {
327 continue;
328 }
329 let last_log_index = self.log.len() - 1;
330 let last_log_term = if last_log_index == 0 {
331 0
332 } else {
333 self.log[last_log_index].term
334 };
335 msgs.push(RaftMessage::RequestVoteRequest {
336 src: dest,
337 dest: follower,
338 term: self.current_term,
339 last_log_index: last_log_index,
340 last_log_term: last_log_term,
341 });
342 }
344 msgs
345 }
346
347 fn handle_request_vote_request(
348 &mut self,
349 src: usize,
350 dest: usize,
351 term: usize,
352 last_log_index: usize,
353 last_log_term: usize,
354 ) -> Vec<RaftMessage<T>> {
355 let mut msgs = vec![];
356 let last_term = if self.log.len() <= 1 {
357 0
358 } else {
359 self.log.last().unwrap().term
360 };
361 let log_ok = (last_log_term > last_term)
362 || (last_log_term == last_term && last_log_index >= self.log.len() - 1);
363 let grant =
364 (term == self.current_term) && log_ok && (self.voted_for == 0 || self.voted_for == src);
365 if term <= self.current_term {
366 if grant {
367 self.voted_for = src;
368 }
369 msgs.push(RaftMessage::RequestVoteResponse {
370 src: dest,
371 dest: src,
372 term: self.current_term,
373 vote_granted: grant,
374 });
375 }
376 msgs
378 }
379
380 fn handle_request_vote_response(
381 &mut self,
382 src: usize,
383 dest: usize,
384 term: usize,
385 vote_granted: bool,
386 ) -> Vec<RaftMessage<T>> {
387 if term != self.current_term {
393 return vec![];
395 }
396 self.votes_responded.as_mut().unwrap().insert(src);
397 if vote_granted {
398 self.votes_granted.as_mut().unwrap().insert(src);
399 }
400 let quorum = (self.followers.as_ref().unwrap().len() + 2) / 2;
403 let followers = self.followers.as_ref().unwrap().clone();
405 if self.votes_granted.as_ref().unwrap().len() >= quorum {
406 self.handle_become_leader(dest, followers);
407 }
408 vec![]
409 }
410
411 fn update_term(&mut self, mterm: usize) {
412 if mterm > self.current_term {
413 self.current_term = mterm;
414 self.state = ServerState::Follower;
415 self.voted_for = 0;
416 }
417 }
418
419 fn advance_commit_index(&mut self, dest: usize) {
420 let mut match_index_cp = self.match_index.as_mut().unwrap().clone();
421
422 match_index_cp.sort_unstable();
423 let mid = match_index_cp.len() / 2 as usize;
424 let max_agree_index = match_index_cp[mid];
425 if self.log[max_agree_index].term >= self.current_term {
426 self.commit_index = max_agree_index;
427 }
428 if self.commit_index > self.last_applied {
429 self.last_applied = self.commit_index;
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;

    /// Safety valve for `run_message`. These small clusters are expected to
    /// settle well within this bound, so reaching it means the servers keep
    /// answering each other forever.
    const MAX_STEPS: usize = 100_000;

    /// Returns the id of the server a message is addressed to.
    fn destination<T>(msg: &RaftMessage<T>) -> usize
    where
        T: Sized + Clone + PartialEq + Eq + Debug + Default,
    {
        match msg {
            RaftMessage::ClientRequest { dest, .. }
            | RaftMessage::BecomeLeader { dest, .. }
            | RaftMessage::AppendEntries { dest, .. }
            | RaftMessage::AppendEntriesRequest { dest, .. }
            | RaftMessage::AppendEntriesResponse { dest, .. }
            | RaftMessage::RequestVoteRequest { dest, .. }
            | RaftMessage::RequestVoteResponse { dest, .. }
            | RaftMessage::TimeOut { dest, .. } => *dest,
        }
    }

    /// Delivers `initial_message` and every message it transitively produces,
    /// in FIFO order, until no messages remain.
    fn run_message<T>(initial_message: RaftMessage<T>, servers: &mut [RaftServer<T>])
    where
        T: Sized + Clone + PartialEq + Eq + Debug + Default,
    {
        let mut messages = VecDeque::new();
        messages.push_back(initial_message);
        let mut delivered = 0;
        while let Some(msg) = messages.pop_front() {
            delivered += 1;
            assert!(
                delivered <= MAX_STEPS,
                "message exchange did not settle after {} messages",
                MAX_STEPS
            );
            let dest = destination(&msg);
            messages.extend(servers[dest].handle_message(msg));
        }
    }

    /// Builds a log of `"a"` entries with the given terms after a default
    /// placeholder at index 0.
    fn make_log(terms: Vec<usize>) -> Vec<LogEntry<String>> {
        std::iter::once(LogEntry::default())
            .chain(terms.into_iter().map(|term| LogEntry {
                term,
                item: "a".to_string(),
            }))
            .collect()
    }

    /// The five-server configuration from Figure 6 of the Raft paper (server 0
    /// is unused), with every server in term 3.
    fn figure_6_servers() -> Vec<RaftServer<String>> {
        let mut servers = vec![
            RaftServer::new(vec![LogEntry::default()]),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 3, 3, 3, 3])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 3])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 3, 3, 3, 3])),
            RaftServer::new(make_log(vec![1, 1])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 3, 3, 3])),
        ];
        for server in &mut servers {
            server.current_term = 3;
        }
        servers
    }

    #[test]
    fn test_replicate() {
        let mut servers = vec![
            RaftServer::new(vec![]),
            RaftServer::new(vec![LogEntry::default(), LogEntry { term: 1, item: "x" }]),
            RaftServer::new(vec![LogEntry::default()]),
            RaftServer::new(vec![LogEntry::default()]),
        ];

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );
        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );

        assert_eq!(servers[1].log, servers[2].log);
    }

    #[test]
    fn test_figure_6() {
        let mut servers = figure_6_servers();

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );

        assert!(servers.iter().skip(1).all(|x| x.log == servers[1].log));
        assert_eq!(servers[1].commit_index, servers[1].log.len() - 1);
    }

    #[test]
    fn test_figure_7() {
        let mut servers = vec![
            RaftServer::new(vec![LogEntry::default()]),
            RaftServer::new(make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6])),
            RaftServer::new(make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6])),
            RaftServer::new(make_log(vec![1, 1, 1, 4])),
            RaftServer::new(make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6, 6])),
            RaftServer::new(make_log(vec![1, 1, 1, 4, 4, 5, 5, 6, 6, 6, 7, 7])),
            RaftServer::new(make_log(vec![1, 1, 1, 4, 4, 4, 4])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3])),
        ];
        for server in &mut servers {
            server.current_term = 8;
        }
        servers[1].commit_index = 10;

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: (2..8).collect(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::ClientRequest {
                dest: 1,
                value: "x".to_string(),
            },
            &mut servers,
        );
        // Two rounds: the first replicates the new entry, the second settles any
        // follower that needed to back off.
        for _ in 0..2 {
            run_message(
                RaftMessage::AppendEntries {
                    dest: 1,
                    followers: (2..8).collect(),
                },
                &mut servers,
            );
        }

        assert!(servers.iter().skip(1).all(|x| servers[1].log == x.log));
        assert_eq!(servers[1].commit_index, servers[1].log.len() - 1);
    }

    #[test]
    fn test_commit() {
        let mut servers = vec![
            RaftServer::new(vec![LogEntry::default()]),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 2])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 2])),
            RaftServer::new(make_log(vec![1, 1, 1, 2, 2])),
        ];
        for server in &mut servers {
            server.current_term = 2;
        }

        run_message(
            RaftMessage::BecomeLeader {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );
        run_message(
            RaftMessage::ClientRequest {
                dest: 1,
                value: "x".to_string(),
            },
            &mut servers,
        );
        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );

        assert_eq!(servers[1].commit_index, 6);
        assert_eq!(servers[1].last_applied, 6);
        assert!(servers.iter().skip(2).all(|x| x.commit_index == 5));
        assert!(servers.iter().skip(2).all(|x| x.last_applied == 5));

        run_message(
            RaftMessage::AppendEntries {
                dest: 1,
                followers: vec![2, 3],
            },
            &mut servers,
        );
        assert!(servers.iter().skip(2).all(|x| x.commit_index == 6));
        assert!(servers.iter().skip(2).all(|x| x.last_applied == 6));
    }

    #[test]
    fn test_figure_6_election() {
        // Server 1 has the most up-to-date log and wins every vote.
        let mut servers = figure_6_servers();
        run_message(
            RaftMessage::TimeOut {
                dest: 1,
                followers: (2..6).collect(),
            },
            &mut servers,
        );
        assert_eq!(servers[1].state, ServerState::Leader);
        assert_eq!(
            servers[1].votes_granted.as_ref().unwrap(),
            &(1..6).collect::<HashSet<usize>>()
        );

        // Server 2 is behind everyone except server 4 and cannot win.
        let mut servers = figure_6_servers();
        run_message(
            RaftMessage::TimeOut {
                dest: 2,
                followers: vec![1, 3, 4, 5],
            },
            &mut servers,
        );
        assert_eq!(servers[2].state, ServerState::Candidate);
        assert_eq!(
            servers[2].votes_granted.as_ref().unwrap(),
            &[2, 4].iter().copied().collect::<HashSet<usize>>()
        );

        // Server 5 collects exactly a majority (itself, 2 and 4).
        let mut servers = figure_6_servers();
        run_message(
            RaftMessage::TimeOut {
                dest: 5,
                followers: (1..5).collect(),
            },
            &mut servers,
        );
        assert_eq!(servers[5].state, ServerState::Leader);
        assert_eq!(
            servers[5].votes_granted.as_ref().unwrap(),
            &[2, 4, 5].iter().copied().collect::<HashSet<usize>>()
        );
    }
}