1#![doc = include_str!("../README.md")]
2
3#![feature(map_try_insert)]
36
37mod display;
38pub mod docs;
39#[cfg(test)]
40mod tests;
41
42use std::cmp::max;
43use std::cmp::min;
44use std::cmp::Ordering;
45use std::collections::BTreeMap;
46use std::collections::BTreeSet;
47
48use derivative::Derivative;
49use derive_more::Display;
50use derive_new::new as New;
51use itertools::Itertools;
52use log::debug;
53use log::error;
54use log::info;
55use log::trace;
56use mpsc::UnboundedReceiver;
57use tokio::sync::mpsc;
58use tokio::sync::oneshot;
59use tokio::sync::watch;
60
61use crate::display::DisplayExt;
62
63#[derive(Debug, Clone, Default, Copy, PartialEq, Eq, Hash, Display)]
64#[display(fmt = "L({})", _0)]
65pub struct LeaderId(pub u64);
66
67impl PartialOrd for LeaderId {
68 fn partial_cmp(&self, b: &Self) -> Option<Ordering> {
69 [None, Some(Ordering::Equal)][(self.0 == b.0) as usize]
70 }
71}
72
73#[derive(Debug, Clone, Default, Copy, PartialEq, Eq, PartialOrd, Hash)]
74#[derive(New)]
75pub struct Vote {
76 pub term: u64,
77 pub committed: Option<()>,
78 pub voted_for: LeaderId,
79}
80
81#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Display)]
82#[display(fmt = "T{}-{}", term, index)]
83pub struct LogId {
84 term: u64,
85 index: u64,
86}
87
88#[derive(Debug, Clone, Default, New)]
89pub struct Log {
90 #[new(default)]
91 pub log_id: LogId,
92 pub data: Option<String>,
93 pub config: Option<Vec<BTreeSet<u64>>>,
94}
95
96#[derive(Debug, Default)]
97pub struct Net {
98 pub targets: BTreeMap<u64, mpsc::UnboundedSender<(u64, Event)>>,
99}
100
101impl Net {
102 fn send(&mut self, from: u64, target: u64, ev: Event) {
103 trace!("N{} send --> N{} {}", from, target, ev);
104 let tx = self.targets.get(&target).unwrap();
105 tx.send((from, ev)).unwrap();
106 }
107}
108
109#[derive(Debug)]
110pub struct Request {
111 vote: Vote,
112 last_log_id: LogId,
113
114 prev: LogId,
115 logs: Vec<Log>,
116 commit: u64,
117}
118
119#[derive(Debug)]
120pub struct Reply {
121 granted: bool,
122 vote: Vote,
123 log: Result<LogId, u64>,
124}
125
126#[derive(Display)]
127pub enum Event {
128 #[display(fmt = "Request({})", _0)]
129 Request(Request),
130 #[display(fmt = "Reply({})", _0)]
131 Reply(Reply),
132 #[display(fmt = "Write({})", _1)]
133 Write(oneshot::Sender<String>, Log),
134 #[display(fmt = "Func")]
135 Func(Box<dyn FnOnce(&mut Raft) + Send + 'static>),
136}
137
138#[derive(Debug, Clone, Copy, Derivative, PartialEq, Eq, New)]
139#[derivative(Default)]
140pub struct Progress {
141 acked: LogId,
142 len: u64,
143 #[derivative(Default(value = "Some(())"))]
146 ready: Option<()>,
147}
148
149pub struct Leading {
150 granted_by: BTreeSet<u64>,
151 progresses: BTreeMap<u64, Progress>,
152 log_index_range: (u64, u64),
153}
154
155#[derive(Debug, Default, Clone, PartialEq, Eq, New)]
156pub struct Metrics {
157 pub vote: Vote,
158 pub last_log: LogId,
159 pub commit: u64,
160 pub config: Vec<BTreeSet<u64>>,
161 pub progresses: Option<BTreeMap<u64, Progress>>,
162}
163
164#[derive(Debug, Default)]
165pub struct Store {
166 id: u64,
167 vote: Vote,
168 configs: BTreeMap<u64, Vec<BTreeSet<u64>>>,
169 replies: BTreeMap<u64, oneshot::Sender<String>>,
170 logs: Vec<Log>,
171}
172
173impl Store {
174 pub fn new(membership: Vec<BTreeSet<u64>>) -> Self {
175 let mut configs = BTreeMap::new();
176 let vote = Vote::default();
177 configs.insert(0, membership);
178 let replies = BTreeMap::new();
179 Store { id: 0, vote, configs, replies, logs: vec![Log::default()] }
180 }
181
182 pub fn config(&self) -> &Vec<BTreeSet<u64>> {
183 self.configs.values().last().unwrap()
184 }
185
186 fn last(&self) -> LogId {
187 self.logs.last().map(|x| x.log_id).unwrap_or_default()
188 }
189
190 fn truncate(&mut self, log_id: LogId) {
191 debug!("truncate: {}", log_id);
192 self.replies.retain(|&x, _| x < log_id.index);
193 self.configs.retain(|&x, _| x < log_id.index);
194 self.logs.truncate(log_id.index as usize);
195 }
196
197 fn append(&mut self, logs: Vec<Log>) {
198 if logs.is_empty() {
199 return;
200 }
201 debug!("N{} append: [{}]", self.id, logs.iter().join(", "));
202 for log in logs {
203 if let Some(x) = self.get_log_id(log.log_id.index) {
204 if x != log.log_id {
205 self.truncate(x);
206 } else {
207 continue;
208 }
209 }
210 if let Some(ref membership) = log.config {
211 self.configs.insert(log.log_id.index, membership.clone());
212 }
213 self.logs.push(log);
214 }
215 }
216
217 fn get_log_id(&self, rel_index: u64) -> Option<LogId> {
218 self.logs.get(rel_index as usize).map(|x| x.log_id)
219 }
220
221 fn read_logs(&self, i: u64, n: u64) -> Vec<Log> {
222 if n == 0 {
223 return vec![];
224 }
225
226 let logs: Vec<_> = self.logs[i as usize..].iter().take(n as usize).cloned().collect();
227 debug!("N{} read_logs: [{i},+{n})={}", self.id, logs.iter().join(","));
228 logs
229 }
230}
231
232pub struct Raft {
233 pub id: u64,
234 pub leading: Option<Leading>,
235 pub commit: u64,
236 pub net: Net,
237 pub sto: Store,
238 pub metrics: watch::Sender<Metrics>,
239 pub rx: UnboundedReceiver<(u64, Event)>,
240}
241
242impl Raft {
243 pub fn new(id: u64, mut sto: Store, net: Net, rx: UnboundedReceiver<(u64, Event)>) -> Self {
244 let (metrics, _) = watch::channel(Metrics::default());
245 sto.id = id;
246 Raft { id, leading: None, commit: 0, net, sto, metrics, rx }
247 }
248
249 pub async fn run(mut self) -> Result<(), anyhow::Error> {
250 loop {
251 let mem = self.sto.config().clone();
252 #[allow(clippy::useless_asref)]
253 let ps = self.leading.as_ref().map(|x| x.progresses.clone());
254 let m = Metrics::new(self.sto.vote, self.sto.last(), self.commit, mem, ps);
255 self.metrics.send_replace(m);
256
257 let (from, ev) = self.rx.recv().await.ok_or(anyhow::anyhow!("closed"))?;
258 debug!("N{} recv <-- N{} {}", self.id, from, ev);
259 match ev {
260 Event::Request(req) => {
261 let reply = self.handle_replicate_req(req);
262 self.net.send(self.id, from, Event::Reply(reply));
263 }
264 Event::Reply(reply) => {
265 self.handle_replicate_reply(from, reply);
266 }
267 Event::Write(tx, log) => {
268 let res = self.write(tx, log.clone());
269 if res.is_none() {
270 error!("N{} leader can not write : {}", self.id, log);
271 }
272 }
273 Event::Func(f) => {
274 f(&mut self);
275 }
276 }
277 }
278 }
279
280 pub fn elect(&mut self) {
281 self.sto.vote = Vote::new(self.sto.vote.term + 1, None, LeaderId(self.id));
282
283 let noop_index = self.sto.last().index + 1;
284 let config = self.sto.config().clone();
285 let p = Progress::new(LogId::default(), noop_index, Some(()));
286
287 debug!("N{} elect: ids: {}", self.id, node_ids(&config).join(","));
288
289 self.leading = Some(Leading {
290 granted_by: BTreeSet::new(),
291 progresses: node_ids(&config).map(|id| (id, p)).collect(),
292 log_index_range: (noop_index, noop_index),
293 });
294
295 node_ids(&config).for_each(|id| self.send_if_idle(id, 0).unwrap_or(()));
296 }
297
298 pub fn write(&mut self, tx: oneshot::Sender<String>, mut log: Log) -> Option<LogId> {
299 self.sto.vote.committed?;
300 let l = self.leading.as_mut()?;
301
302 let log_id = LogId { term: self.sto.vote.term, index: l.log_index_range.1 };
303 l.log_index_range.1 += 1;
304 log.log_id = log_id;
305
306 if let Some(ref membership) = log.config {
307 if self.sto.configs.keys().last().copied().unwrap() > self.commit {
308 panic!("N{} can write membership: {} before committing the previous", self.id, log);
309 }
310 let ids = node_ids(membership);
311 l.progresses = ids.map(|x| (x, l.progresses.remove(&x).unwrap_or_default())).collect();
312 info!("N{} rebuild progresses: {}", self.id, l.progresses.display());
313 }
314 self.sto.replies.insert(log_id.index, tx);
315 self.sto.append(vec![log]);
316
317 l.progresses.insert(self.id, Progress::new(log_id, log_id.index, None));
319
320 node_ids(self.sto.config()).for_each(|id| self.send_if_idle(id, 10).unwrap_or(()));
321 Some(log_id)
322 }
323
324 pub fn handle_replicate_req(&mut self, req: Request) -> Reply {
325 let my_last = self.sto.last();
326 let (is_granted, vote) = self.check_vote(req.vote);
327 let is_upto_date = req.last_log_id >= my_last;
328
329 let req_last = req.logs.last().map(|x| x.log_id).unwrap_or(req.prev);
330
331 if is_granted && is_upto_date {
332 let log = if self.sto.get_log_id(req.prev.index) == Some(req.prev) {
333 self.sto.append(req.logs);
334 self.commit(min(req.commit, req_last.index));
335 Ok(req_last)
336 } else {
337 self.sto.truncate(req.prev);
338 Err(req.prev.index)
339 };
340
341 Reply { granted: true, vote, log }
342 } else {
343 Reply { granted: false, vote, log: Err(my_last.index + 1) }
344 }
345 }
346
347 pub fn handle_replicate_reply(&mut self, target: u64, reply: Reply) -> Option<Leading> {
348 let l = self.leading.as_mut()?;
349 let v = self.sto.vote;
350
351 let is_same_leader = reply.vote.term == v.term && reply.vote.voted_for == v.voted_for;
352
353 if is_same_leader {
355 assert!(l.progresses[&target].ready.is_none());
356 l.progresses.get_mut(&target).unwrap().ready = Some(());
357 }
358
359 if reply.granted && is_same_leader {
360 if v.committed.is_none() {
362 debug!("N{} is granted by: N{}", self.id, target);
363 l.granted_by.insert(target);
364
365 if is_quorum(self.sto.config(), &l.granted_by) {
366 self.sto.vote.committed = Some(());
367 info!("N{} Leader established: {}", self.id, self.sto.vote);
368
369 let (tx, _rx) = oneshot::channel();
370 self.net.send(self.id, self.id, Event::Write(tx, Log::default()));
371 }
372 }
373
374 let p = l.progresses.get_mut(&target).unwrap();
375
376 *p = match reply.log {
379 Ok(acked) => Progress::new(acked, max(p.len, acked.index + 1), Some(())),
380 Err(len) => Progress::new(p.acked, min(p.len, len), Some(())),
381 };
382 debug!("N{} progress N{target}={}", self.id, p);
383
384 let (noop_index, len) = l.log_index_range;
387 let acked = p.acked.index;
388
389 let acked_desc = l.progresses.values().map(|p| p.acked).sorted().rev();
390 let mut max_committed = acked_desc.filter(|acked| {
391 let greater_equal = l.progresses.iter().filter(|(_id, p)| p.acked >= *acked);
392 acked.index >= noop_index
393 && is_quorum(self.sto.config(), greater_equal.map(|(id, _)| id))
394 });
395
396 if let Some(log_id) = max_committed.next() {
397 self.commit(log_id.index)
398 }
399
400 if len - 1 > acked {
402 self.send_if_idle(target, len - 1 - acked);
403 }
404 } else {
405 self.check_vote(reply.vote);
406 }
407 None
408 }
409
410 pub fn send_if_idle(&mut self, target: u64, n: u64) -> Option<()> {
411 let l = self.leading.as_mut().unwrap();
412
413 let p = l.progresses.get_mut(&target).unwrap();
414 trace!("send_if_idle: prog: N{}={:?}", target, p);
415 p.ready.take()?;
416
417 let prev = (p.acked.index + p.len) / 2;
418
419 let req = Request {
420 vote: self.sto.vote,
421 last_log_id: self.sto.last(),
422
423 prev: self.sto.get_log_id(prev).unwrap(),
424 logs: self.sto.read_logs(prev + 1, n),
425 commit: self.commit,
426 };
427
428 self.net.send(self.id, target, Event::Request(req));
429 Some(())
430 }
431
432 fn commit(&mut self, i: u64) {
433 if i > self.commit {
434 info!("N{} commit: {i}: {}", self.id, self.sto.logs[i as usize]);
435 self.commit = i;
436 let right = self.sto.replies.split_off(&(i + 1));
437 for (i, tx) in std::mem::replace(&mut self.sto.replies, right).into_iter() {
438 let _ = tx.send(format!("{}", i));
439 }
440 }
441 }
442
443 fn check_vote(&mut self, vote: Vote) -> (bool, Vote) {
444 trace!("N{} check_vote: my:{}, {}", self.id, self.sto.vote, vote);
445
446 if vote > self.sto.vote {
447 info!("N{} update_vote: {} --> {}", self.id, self.sto.vote, vote);
448 self.sto.vote = vote;
449
450 if vote.voted_for != LeaderId(self.id) && self.leading.is_some() {
451 info!("N{} Leading quit: vote:{}", self.id, self.sto.vote);
452 self.leading = None;
453 }
454 }
455
456 trace!("check_vote: ret: {}", self.sto.vote);
457 (vote == self.sto.vote, self.sto.vote)
458 }
459}
460
/// Returns `true` when `granted` (duplicates ignored) contains a strict majority of every group
/// in `config`.
///
/// An empty group can never reach a majority. An empty `config` is trivially satisfied.
pub fn is_quorum<'a>(config: &[BTreeSet<u64>], granted: impl IntoIterator<Item = &'a u64>) -> bool {
    let voters: BTreeSet<u64> = granted.into_iter().copied().collect();
    config.iter().all(|group| {
        let votes = group.iter().filter(|id| voters.contains(id)).count();
        votes > group.len() / 2
    })
}
470
/// Returns every node id appearing in any group of `config`, deduplicated and in ascending
/// order.
///
/// The ids are collected up front, so the iterator does not borrow `config`.
pub fn node_ids(config: &[BTreeSet<u64>]) -> impl Iterator<Item = u64> + 'static {
    let mut ids = BTreeSet::new();
    for group in config {
        ids.extend(group.iter().copied());
    }
    ids.into_iter()
}