one_file_raft/
lib.rs

1#![doc = include_str!("../README.md")]
2
3//!
4//! - Read the [Tutorial](`crate::docs::tutorial`);
5//! - Read the [Tutorial-cn](`crate::docs::tutorial_cn`);
6//!
7//! Features:
8//!
9//! - [x] Election(`Raft::elect()`)
10//! - [x] Log replication(`Raft::handle_replicate_req()`)
11//! - [x] Commit
12//! - [x] Write application data(`Raft::write()`)
13//! - [x] Membership store(`Store::configs`).
14//! - [x] Membership change: joint consensus.
15//! - [x] Event loop model(main loop: `Raft::run()`).
16//! - [x] Pseudo network simulated by mpsc channels(`Net`).
17//! - [x] Pseudo Log store simulated by in-memory store(`Store`).
18//! - [x] Raft log data is a simple `String`
19//! - [x] Metrics
20//!
21//! Not yet implemented:
22//! - [ ] State machine(`Raft::commit()` is a no-op entry)
23//! - [ ] Log compaction
24//! - [ ] Log purge
25//! - [ ] Heartbeat
26//! - [ ] Leader lease
27//! - [ ] Linearizable read
28//!
29//! Implementation details:
30//! - [x] Membership config takes effect once appended(not on-applied).
31//! - [x] Standalone Leader, it has to check vote when accessing local store.
32//! - [x] Leader access store directly(not via RPC).
33//! - [ ] Append log when vote?
34
35#![feature(map_try_insert)]
36
37mod display;
38pub mod docs;
39#[cfg(test)]
40mod tests;
41
42use std::cmp::max;
43use std::cmp::min;
44use std::cmp::Ordering;
45use std::collections::BTreeMap;
46use std::collections::BTreeSet;
47
48use derivative::Derivative;
49use derive_more::Display;
50use derive_new::new as New;
51use itertools::Itertools;
52use log::debug;
53use log::error;
54use log::info;
55use log::trace;
56use mpsc::UnboundedReceiver;
57use tokio::sync::mpsc;
58use tokio::sync::oneshot;
59use tokio::sync::watch;
60
61use crate::display::DisplayExt;
62
63#[derive(Debug, Clone, Default, Copy, PartialEq, Eq, Hash, Display)]
64#[display(fmt = "L({})", _0)]
65pub struct LeaderId(pub u64);
66
67impl PartialOrd for LeaderId {
68    fn partial_cmp(&self, b: &Self) -> Option<Ordering> {
69        [None, Some(Ordering::Equal)][(self.0 == b.0) as usize]
70    }
71}
72
73#[derive(Debug, Clone, Default, Copy, PartialEq, Eq, PartialOrd, Hash)]
74#[derive(New)]
75pub struct Vote {
76    pub term: u64,
77    pub committed: Option<()>,
78    pub voted_for: LeaderId,
79}
80
81#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Display)]
82#[display(fmt = "T{}-{}", term, index)]
83pub struct LogId {
84    term: u64,
85    index: u64,
86}
87
88#[derive(Debug, Clone, Default, New)]
89pub struct Log {
90    #[new(default)]
91    pub log_id: LogId,
92    pub data: Option<String>,
93    pub config: Option<Vec<BTreeSet<u64>>>,
94}
95
96#[derive(Debug, Default)]
97pub struct Net {
98    pub targets: BTreeMap<u64, mpsc::UnboundedSender<(u64, Event)>>,
99}
100
101impl Net {
102    fn send(&mut self, from: u64, target: u64, ev: Event) {
103        trace!("N{} send --> N{} {}", from, target, ev);
104        let tx = self.targets.get(&target).unwrap();
105        tx.send((from, ev)).unwrap();
106    }
107}
108
109#[derive(Debug)]
110pub struct Request {
111    vote: Vote,
112    last_log_id: LogId,
113
114    prev: LogId,
115    logs: Vec<Log>,
116    commit: u64,
117}
118
119#[derive(Debug)]
120pub struct Reply {
121    granted: bool,
122    vote: Vote,
123    log: Result<LogId, u64>,
124}
125
126#[derive(Display)]
127pub enum Event {
128    #[display(fmt = "Request({})", _0)]
129    Request(Request),
130    #[display(fmt = "Reply({})", _0)]
131    Reply(Reply),
132    #[display(fmt = "Write({})", _1)]
133    Write(oneshot::Sender<String>, Log),
134    #[display(fmt = "Func")]
135    Func(Box<dyn FnOnce(&mut Raft) + Send + 'static>),
136}
137
138#[derive(Debug, Clone, Copy, Derivative, PartialEq, Eq, New)]
139#[derivative(Default)]
140pub struct Progress {
141    acked: LogId,
142    len: u64,
143    /// It is a token to indicate if it can send an RPC, e.g., there is no inflight RPC sending.
144    /// It is set to `None` when an RPC is sent, and set to `Some(())` when the RPC is finished.
145    #[derivative(Default(value = "Some(())"))]
146    ready: Option<()>,
147}
148
149pub struct Leading {
150    granted_by: BTreeSet<u64>,
151    progresses: BTreeMap<u64, Progress>,
152    log_index_range: (u64, u64),
153}
154
155#[derive(Debug, Default, Clone, PartialEq, Eq, New)]
156pub struct Metrics {
157    pub vote: Vote,
158    pub last_log: LogId,
159    pub commit: u64,
160    pub config: Vec<BTreeSet<u64>>,
161    pub progresses: Option<BTreeMap<u64, Progress>>,
162}
163
164#[derive(Debug, Default)]
165pub struct Store {
166    id: u64,
167    vote: Vote,
168    configs: BTreeMap<u64, Vec<BTreeSet<u64>>>,
169    replies: BTreeMap<u64, oneshot::Sender<String>>,
170    logs: Vec<Log>,
171}
172
173impl Store {
174    pub fn new(membership: Vec<BTreeSet<u64>>) -> Self {
175        let mut configs = BTreeMap::new();
176        let vote = Vote::default();
177        configs.insert(0, membership);
178        let replies = BTreeMap::new();
179        Store { id: 0, vote, configs, replies, logs: vec![Log::default()] }
180    }
181
182    pub fn config(&self) -> &Vec<BTreeSet<u64>> {
183        self.configs.values().last().unwrap()
184    }
185
186    fn last(&self) -> LogId {
187        self.logs.last().map(|x| x.log_id).unwrap_or_default()
188    }
189
190    fn truncate(&mut self, log_id: LogId) {
191        debug!("truncate: {}", log_id);
192        self.replies.retain(|&x, _| x < log_id.index);
193        self.configs.retain(|&x, _| x < log_id.index);
194        self.logs.truncate(log_id.index as usize);
195    }
196
197    fn append(&mut self, logs: Vec<Log>) {
198        if logs.is_empty() {
199            return;
200        }
201        debug!("N{} append: [{}]", self.id, logs.iter().join(", "));
202        for log in logs {
203            if let Some(x) = self.get_log_id(log.log_id.index) {
204                if x != log.log_id {
205                    self.truncate(x);
206                } else {
207                    continue;
208                }
209            }
210            if let Some(ref membership) = log.config {
211                self.configs.insert(log.log_id.index, membership.clone());
212            }
213            self.logs.push(log);
214        }
215    }
216
217    fn get_log_id(&self, rel_index: u64) -> Option<LogId> {
218        self.logs.get(rel_index as usize).map(|x| x.log_id)
219    }
220
221    fn read_logs(&self, i: u64, n: u64) -> Vec<Log> {
222        if n == 0 {
223            return vec![];
224        }
225
226        let logs: Vec<_> = self.logs[i as usize..].iter().take(n as usize).cloned().collect();
227        debug!("N{} read_logs: [{i},+{n})={}", self.id, logs.iter().join(","));
228        logs
229    }
230}
231
232pub struct Raft {
233    pub id: u64,
234    pub leading: Option<Leading>,
235    pub commit: u64,
236    pub net: Net,
237    pub sto: Store,
238    pub metrics: watch::Sender<Metrics>,
239    pub rx: UnboundedReceiver<(u64, Event)>,
240}
241
242impl Raft {
243    pub fn new(id: u64, mut sto: Store, net: Net, rx: UnboundedReceiver<(u64, Event)>) -> Self {
244        let (metrics, _) = watch::channel(Metrics::default());
245        sto.id = id;
246        Raft { id, leading: None, commit: 0, net, sto, metrics, rx }
247    }
248
249    pub async fn run(mut self) -> Result<(), anyhow::Error> {
250        loop {
251            let mem = self.sto.config().clone();
252            #[allow(clippy::useless_asref)]
253            let ps = self.leading.as_ref().map(|x| x.progresses.clone());
254            let m = Metrics::new(self.sto.vote, self.sto.last(), self.commit, mem, ps);
255            self.metrics.send_replace(m);
256
257            let (from, ev) = self.rx.recv().await.ok_or(anyhow::anyhow!("closed"))?;
258            debug!("N{} recv <-- N{} {}", self.id, from, ev);
259            match ev {
260                Event::Request(req) => {
261                    let reply = self.handle_replicate_req(req);
262                    self.net.send(self.id, from, Event::Reply(reply));
263                }
264                Event::Reply(reply) => {
265                    self.handle_replicate_reply(from, reply);
266                }
267                Event::Write(tx, log) => {
268                    let res = self.write(tx, log.clone());
269                    if res.is_none() {
270                        error!("N{} leader can not write : {}", self.id, log);
271                    }
272                }
273                Event::Func(f) => {
274                    f(&mut self);
275                }
276            }
277        }
278    }
279
280    pub fn elect(&mut self) {
281        self.sto.vote = Vote::new(self.sto.vote.term + 1, None, LeaderId(self.id));
282
283        let noop_index = self.sto.last().index + 1;
284        let config = self.sto.config().clone();
285        let p = Progress::new(LogId::default(), noop_index, Some(()));
286
287        debug!("N{} elect: ids: {}", self.id, node_ids(&config).join(","));
288
289        self.leading = Some(Leading {
290            granted_by: BTreeSet::new(),
291            progresses: node_ids(&config).map(|id| (id, p)).collect(),
292            log_index_range: (noop_index, noop_index),
293        });
294
295        node_ids(&config).for_each(|id| self.send_if_idle(id, 0).unwrap_or(()));
296    }
297
298    pub fn write(&mut self, tx: oneshot::Sender<String>, mut log: Log) -> Option<LogId> {
299        self.sto.vote.committed?;
300        let l = self.leading.as_mut()?;
301
302        let log_id = LogId { term: self.sto.vote.term, index: l.log_index_range.1 };
303        l.log_index_range.1 += 1;
304        log.log_id = log_id;
305
306        if let Some(ref membership) = log.config {
307            if self.sto.configs.keys().last().copied().unwrap() > self.commit {
308                panic!("N{} can write membership: {} before committing the previous", self.id, log);
309            }
310            let ids = node_ids(membership);
311            l.progresses = ids.map(|x| (x, l.progresses.remove(&x).unwrap_or_default())).collect();
312            info!("N{} rebuild progresses: {}", self.id, l.progresses.display());
313        }
314        self.sto.replies.insert(log_id.index, tx);
315        self.sto.append(vec![log]);
316
317        // Mark it as sending, so that it won't be sent again.
318        l.progresses.insert(self.id, Progress::new(log_id, log_id.index, None));
319
320        node_ids(self.sto.config()).for_each(|id| self.send_if_idle(id, 10).unwrap_or(()));
321        Some(log_id)
322    }
323
324    pub fn handle_replicate_req(&mut self, req: Request) -> Reply {
325        let my_last = self.sto.last();
326        let (is_granted, vote) = self.check_vote(req.vote);
327        let is_upto_date = req.last_log_id >= my_last;
328
329        let req_last = req.logs.last().map(|x| x.log_id).unwrap_or(req.prev);
330
331        if is_granted && is_upto_date {
332            let log = if self.sto.get_log_id(req.prev.index) == Some(req.prev) {
333                self.sto.append(req.logs);
334                self.commit(min(req.commit, req_last.index));
335                Ok(req_last)
336            } else {
337                self.sto.truncate(req.prev);
338                Err(req.prev.index)
339            };
340
341            Reply { granted: true, vote, log }
342        } else {
343            Reply { granted: false, vote, log: Err(my_last.index + 1) }
344        }
345    }
346
347    pub fn handle_replicate_reply(&mut self, target: u64, reply: Reply) -> Option<Leading> {
348        let l = self.leading.as_mut()?;
349        let v = self.sto.vote;
350
351        let is_same_leader = reply.vote.term == v.term && reply.vote.voted_for == v.voted_for;
352
353        // 0. Set a replication channel to `ready`, once a reply is received.
354        if is_same_leader {
355            assert!(l.progresses[&target].ready.is_none());
356            l.progresses.get_mut(&target).unwrap().ready = Some(());
357        }
358
359        if reply.granted && is_same_leader {
360            // 1. Vote is granted, means that Log replication privilege is acquired.
361            if v.committed.is_none() {
362                debug!("N{} is granted by: N{}", self.id, target);
363                l.granted_by.insert(target);
364
365                if is_quorum(self.sto.config(), &l.granted_by) {
366                    self.sto.vote.committed = Some(());
367                    info!("N{} Leader established: {}", self.id, self.sto.vote);
368
369                    let (tx, _rx) = oneshot::channel();
370                    self.net.send(self.id, self.id, Event::Write(tx, Log::default()));
371                }
372            }
373
374            let p = l.progresses.get_mut(&target).unwrap();
375
376            // 2. Update the log replication progress
377
378            *p = match reply.log {
379                Ok(acked) => Progress::new(acked, max(p.len, acked.index + 1), Some(())),
380                Err(len) => Progress::new(p.acked, min(p.len, len), Some(())),
381            };
382            debug!("N{} progress N{target}={}", self.id, p);
383
384            // 3. Update committed index
385
386            let (noop_index, len) = l.log_index_range;
387            let acked = p.acked.index;
388
389            let acked_desc = l.progresses.values().map(|p| p.acked).sorted().rev();
390            let mut max_committed = acked_desc.filter(|acked| {
391                let greater_equal = l.progresses.iter().filter(|(_id, p)| p.acked >= *acked);
392                acked.index >= noop_index
393                    && is_quorum(self.sto.config(), greater_equal.map(|(id, _)| id))
394            });
395
396            if let Some(log_id) = max_committed.next() {
397                self.commit(log_id.index)
398            }
399
400            // 4. Keep sending
401            if len - 1 > acked {
402                self.send_if_idle(target, len - 1 - acked);
403            }
404        } else {
405            self.check_vote(reply.vote);
406        }
407        None
408    }
409
410    pub fn send_if_idle(&mut self, target: u64, n: u64) -> Option<()> {
411        let l = self.leading.as_mut().unwrap();
412
413        let p = l.progresses.get_mut(&target).unwrap();
414        trace!("send_if_idle: prog: N{}={:?}", target, p);
415        p.ready.take()?;
416
417        let prev = (p.acked.index + p.len) / 2;
418
419        let req = Request {
420            vote: self.sto.vote,
421            last_log_id: self.sto.last(),
422
423            prev: self.sto.get_log_id(prev).unwrap(),
424            logs: self.sto.read_logs(prev + 1, n),
425            commit: self.commit,
426        };
427
428        self.net.send(self.id, target, Event::Request(req));
429        Some(())
430    }
431
432    fn commit(&mut self, i: u64) {
433        if i > self.commit {
434            info!("N{} commit: {i}: {}", self.id, self.sto.logs[i as usize]);
435            self.commit = i;
436            let right = self.sto.replies.split_off(&(i + 1));
437            for (i, tx) in std::mem::replace(&mut self.sto.replies, right).into_iter() {
438                let _ = tx.send(format!("{}", i));
439            }
440        }
441    }
442
443    fn check_vote(&mut self, vote: Vote) -> (bool, Vote) {
444        trace!("N{} check_vote: my:{}, {}", self.id, self.sto.vote, vote);
445
446        if vote > self.sto.vote {
447            info!("N{} update_vote: {} --> {}", self.id, self.sto.vote, vote);
448            self.sto.vote = vote;
449
450            if vote.voted_for != LeaderId(self.id) && self.leading.is_some() {
451                info!("N{} Leading quit: vote:{}", self.id, self.sto.vote);
452                self.leading = None;
453            }
454        }
455
456        trace!("check_vote: ret: {}", self.sto.vote);
457        (vote == self.sto.vote, self.sto.vote)
458    }
459}
460
/// Tells whether the node ids in `granted` form a quorum in `config`.
///
/// `config` is a list of groups (more than one group is a joint config). The result is true
/// iff, in **every** group, strictly more than half of the members appear in `granted`.
/// Ids in `granted` that belong to no group are ignored. A config without groups is
/// trivially satisfied, while an empty group can never be.
pub fn is_quorum<'a>(config: &[BTreeSet<u64>], granted: impl IntoIterator<Item = &'a u64>) -> bool {
    let voters: BTreeSet<u64> = granted.into_iter().copied().collect();

    config.iter().all(|group| {
        let votes = group.iter().filter(|&id| voters.contains(id)).count();
        votes > group.len() / 2
    })
}
470
/// Returns every node id that appears in any group of `config`, deduplicated and in
/// ascending order.
///
/// The ids are collected eagerly, so the returned iterator does not borrow `config`.
pub fn node_ids(config: &[BTreeSet<u64>]) -> impl Iterator<Item = u64> + 'static {
    let mut merged = BTreeSet::new();
    for group in config {
        merged.extend(group.iter().copied());
    }
    merged.into_iter()
}