1use std::collections::{HashMap, HashSet};
2
3use tokio::sync::oneshot;
4
5use crate::error::{KubericError, Result};
6use crate::types::{Lsn, ReplicaId};
7
8pub struct QuorumTracker {
14 pending: HashMap<Lsn, PendingOp>,
16 current_members: HashSet<ReplicaId>,
18 current_write_quorum: u32,
19 previous_members: HashSet<ReplicaId>,
21 previous_write_quorum: u32,
22 must_catch_up_ids: HashSet<ReplicaId>,
24 replica_acked_lsn: HashMap<ReplicaId, Lsn>,
26 committed_lsn: Lsn,
28 highest_lsn: Lsn,
30 catch_up_baseline_lsn: Lsn,
34 catch_up_waiters: Vec<CatchUpWaiter>,
36}
37
38struct CatchUpWaiter {
39 mode: crate::types::ReplicaSetQuorumMode,
40 reply: oneshot::Sender<Result<()>>,
41}
42
43struct PendingOp {
44 acked_by: HashSet<ReplicaId>,
46 reply: Option<oneshot::Sender<Result<Lsn>>>,
48 lsn: Lsn,
49}
50
51impl QuorumTracker {
52 pub fn new() -> Self {
53 Self {
54 pending: HashMap::new(),
55 current_members: HashSet::new(),
56 current_write_quorum: 0,
57 previous_members: HashSet::new(),
58 previous_write_quorum: 0,
59 must_catch_up_ids: HashSet::new(),
60 replica_acked_lsn: HashMap::new(),
61 committed_lsn: 0,
62 highest_lsn: 0,
63 catch_up_baseline_lsn: 0,
64 catch_up_waiters: Vec::new(),
65 }
66 }
67
68 pub fn committed_lsn(&self) -> Lsn {
69 self.committed_lsn
70 }
71
72 pub fn set_catch_up_configuration(
76 &mut self,
77 current_members: HashSet<ReplicaId>,
78 current_write_quorum: u32,
79 previous_members: HashSet<ReplicaId>,
80 previous_write_quorum: u32,
81 must_catch_up_ids: HashSet<ReplicaId>,
82 member_progress: HashMap<ReplicaId, Lsn>,
83 ) {
84 self.current_members = current_members;
85 self.current_write_quorum = current_write_quorum;
86 self.previous_members = previous_members;
87 self.previous_write_quorum = previous_write_quorum;
88 self.must_catch_up_ids = must_catch_up_ids;
89 self.catch_up_baseline_lsn = self.highest_lsn;
90
91 for (id, progress) in &member_progress {
94 self.replica_acked_lsn
95 .entry(*id)
96 .and_modify(|v| {
97 if *progress > *v {
98 *v = *progress;
99 }
100 })
101 .or_insert(*progress);
102 }
103 self.notify_catch_up_waiters();
104 }
105
106 pub fn set_current_configuration(
108 &mut self,
109 current_members: HashSet<ReplicaId>,
110 current_write_quorum: u32,
111 ) {
112 self.current_members = current_members;
113 self.current_write_quorum = current_write_quorum;
114 self.previous_members.clear();
115 self.previous_write_quorum = 0;
116 self.must_catch_up_ids.clear();
117 }
118
119 pub fn register(
122 &mut self,
123 lsn: Lsn,
124 primary_id: ReplicaId,
125 reply: oneshot::Sender<Result<Lsn>>,
126 ) {
127 if lsn > self.highest_lsn {
128 self.highest_lsn = lsn;
129 }
130
131 let mut acked_by = HashSet::new();
132 acked_by.insert(primary_id);
133
134 self.replica_acked_lsn
136 .entry(primary_id)
137 .and_modify(|v| {
138 if lsn > *v {
139 *v = lsn;
140 }
141 })
142 .or_insert(lsn);
143
144 let mut op = PendingOp {
145 acked_by,
146 reply: Some(reply),
147 lsn,
148 };
149
150 if self.is_quorum_met(&op.acked_by) {
152 self.commit_op(&mut op);
153 }
154
155 if op.reply.is_some() {
157 self.pending.insert(lsn, op);
158 } else {
159 self.notify_catch_up_waiters();
160 }
161 }
162
163 pub fn ack(&mut self, lsn: Lsn, replica_id: ReplicaId) {
166 self.replica_acked_lsn
168 .entry(replica_id)
169 .and_modify(|v| {
170 if lsn > *v {
171 *v = lsn;
172 }
173 })
174 .or_insert(lsn);
175
176 if let Some(op) = self.pending.get_mut(&lsn) {
177 op.acked_by.insert(replica_id);
178 } else {
179 self.notify_catch_up_waiters();
182 return;
183 }
184
185 let quorum_met = {
187 let op = self.pending.get(&lsn).unwrap();
188 self.is_quorum_met(&op.acked_by)
189 };
190
191 if quorum_met {
192 let mut op = self.pending.remove(&lsn).unwrap();
193 self.commit_op(&mut op);
194 self.try_commit_pending();
195 self.notify_catch_up_waiters();
196 }
197 }
198
199 pub fn fail_all(&mut self, error: KubericError) {
201 for (_, mut op) in self.pending.drain() {
202 if let Some(reply) = op.reply.take() {
203 let _ = reply.send(Err(match &error {
204 KubericError::NotPrimary => KubericError::NotPrimary,
205 KubericError::Closed => KubericError::Closed,
206 _ => KubericError::Internal(error.to_string().into()),
207 }));
208 }
209 }
210 for waiter in self.catch_up_waiters.drain(..) {
212 let _ = waiter.reply.send(Err(match &error {
213 KubericError::NotPrimary => KubericError::NotPrimary,
214 KubericError::Closed => KubericError::Closed,
215 _ => KubericError::Internal(error.to_string().into()),
216 }));
217 }
218 }
219
220 pub fn pending_count(&self) -> usize {
222 self.pending.len()
223 }
224
225 pub fn wait_for_catch_up(
233 &mut self,
234 mode: crate::types::ReplicaSetQuorumMode,
235 reply: oneshot::Sender<Result<()>>,
236 ) {
237 if self.is_caught_up(mode) {
238 let _ = reply.send(Ok(()));
239 } else {
240 self.catch_up_waiters.push(CatchUpWaiter { mode, reply });
241 }
242 }
243
244 fn is_caught_up(&self, mode: crate::types::ReplicaSetQuorumMode) -> bool {
245 if !self.pending.is_empty() {
246 return false;
247 }
248 let check_lsn = self.highest_lsn;
251 if check_lsn <= self.catch_up_baseline_lsn {
252 return true;
254 }
255 match mode {
256 crate::types::ReplicaSetQuorumMode::Write => {
257 for &id in &self.must_catch_up_ids {
258 let acked = self.replica_acked_lsn.get(&id).copied().unwrap_or(0);
259 if acked < check_lsn {
260 return false;
261 }
262 }
263 }
264 crate::types::ReplicaSetQuorumMode::All => {
265 for &id in &self.current_members {
266 let acked = self.replica_acked_lsn.get(&id).copied().unwrap_or(0);
267 if acked < check_lsn {
268 return false;
269 }
270 }
271 }
272 }
273 true
274 }
275
276 fn is_quorum_met(&self, acked_by: &HashSet<ReplicaId>) -> bool {
281 let cc_met =
282 self.count_acks_in_set(acked_by, &self.current_members) >= self.current_write_quorum;
283
284 if self.previous_members.is_empty() {
285 return cc_met;
286 }
287
288 let pc_met =
289 self.count_acks_in_set(acked_by, &self.previous_members) >= self.previous_write_quorum;
290
291 cc_met && pc_met
292 }
293
294 fn count_acks_in_set(
295 &self,
296 acked_by: &HashSet<ReplicaId>,
297 members: &HashSet<ReplicaId>,
298 ) -> u32 {
299 acked_by.intersection(members).count() as u32
300 }
301
302 fn commit_op(&mut self, op: &mut PendingOp) {
303 if op.lsn > self.committed_lsn {
304 self.committed_lsn = op.lsn;
305 }
306 if let Some(reply) = op.reply.take() {
307 let _ = reply.send(Ok(op.lsn));
308 }
309 }
310
311 fn notify_catch_up_waiters(&mut self) {
312 if self.catch_up_waiters.is_empty() {
313 return;
314 }
315 let waiters = std::mem::take(&mut self.catch_up_waiters);
316 for waiter in waiters {
317 if self.is_caught_up(waiter.mode) {
318 let _ = waiter.reply.send(Ok(()));
319 } else {
320 self.catch_up_waiters.push(waiter);
321 }
322 }
323 }
324
325 fn try_commit_pending(&mut self) {
326 let mut to_remove = Vec::new();
327 for (lsn, op) in &self.pending {
328 if self.is_quorum_met(&op.acked_by) {
329 to_remove.push(*lsn);
330 }
331 }
332 for lsn in to_remove {
333 if let Some(mut op) = self.pending.remove(&lsn) {
334 self.commit_op(&mut op);
335 }
336 }
337 }
338}
339
340impl Default for QuorumTracker {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a member set from a slice of replica ids.
    fn members(ids: &[ReplicaId]) -> HashSet<ReplicaId> {
        ids.iter().copied().collect()
    }

    #[tokio::test]
    async fn test_single_replica_commits_immediately() {
        let primary = 1;
        let mut tracker = QuorumTracker::new();
        // One member with write quorum 1: the primary's own ack is enough.
        tracker.set_current_configuration(members(&[primary]), 1);

        let (reply_tx, reply_rx) = oneshot::channel();
        tracker.register(1, primary, reply_tx);

        let committed = reply_rx.await.unwrap().unwrap();
        assert_eq!(committed, 1);
        assert_eq!(tracker.committed_lsn(), 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_three_replicas_quorum() {
        let primary = 1;
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(members(&[1, 2, 3]), 2);

        let (reply_tx, reply_rx) = oneshot::channel();
        tracker.register(1, primary, reply_tx);
        // Only the primary has acked: 1 of the 2 required.
        assert_eq!(tracker.pending_count(), 1);

        tracker.ack(1, 2);

        let committed = reply_rx.await.unwrap().unwrap();
        assert_eq!(committed, 1);
        assert_eq!(tracker.committed_lsn(), 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_dual_config_quorum() {
        let primary = 1;
        let mut tracker = QuorumTracker::new();
        tracker.set_catch_up_configuration(
            members(&[1, 2, 3]),
            2,
            members(&[1, 2]),
            2,
            HashSet::new(),
            HashMap::new(),
        );

        let (reply_tx, reply_rx) = oneshot::channel();
        tracker.register(1, primary, reply_tx);
        assert_eq!(tracker.pending_count(), 1);

        // {1, 3} meets the current quorum but has only 1 of 2 previous members.
        tracker.ack(1, 3);
        assert_eq!(tracker.pending_count(), 1);

        // Replica 2 completes the previous-configuration quorum as well.
        tracker.ack(1, 2);

        let committed = reply_rx.await.unwrap().unwrap();
        assert_eq!(committed, 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_out_of_order_acks() {
        let primary = 1;
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(members(&[1, 2, 3]), 2);

        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();
        tracker.register(1, primary, first_tx);
        tracker.register(2, primary, second_tx);

        // LSN 2 reaches quorum first and commits without waiting for LSN 1.
        tracker.ack(2, 2);
        assert_eq!(second_rx.await.unwrap().unwrap(), 2);

        tracker.ack(1, 2);
        assert_eq!(first_rx.await.unwrap().unwrap(), 1);

        assert_eq!(tracker.committed_lsn(), 2);
    }

    #[tokio::test]
    async fn test_fail_all() {
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(members(&[1, 2, 3]), 2);

        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();
        tracker.register(1, 1, first_tx);
        tracker.register(2, 1, second_tx);

        tracker.fail_all(KubericError::NotPrimary);

        for receiver in [first_rx, second_rx] {
            let outcome = receiver.await.unwrap();
            assert!(matches!(outcome, Err(KubericError::NotPrimary)));
        }
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_must_catch_up_enforcement() {
        use crate::types::ReplicaSetQuorumMode;

        let mut tracker = QuorumTracker::new();
        tracker.set_catch_up_configuration(
            members(&[1, 2, 3]),
            2,
            HashSet::new(),
            0,
            members(&[2]),
            HashMap::from([(2, 0), (3, 0)]),
        );

        let (reply_tx, reply_rx) = oneshot::channel();
        tracker.register(1, 1, reply_tx);

        // Replica 3 completes the write quorum; replica 2 (must catch up) lags.
        tracker.ack(1, 3);
        assert_eq!(reply_rx.await.unwrap().unwrap(), 1);
        assert_eq!(tracker.pending_count(), 0);

        let (wait_tx, mut wait_rx) = oneshot::channel();
        tracker.wait_for_catch_up(ReplicaSetQuorumMode::Write, wait_tx);
        assert!(wait_rx.try_recv().is_err());

        // Once replica 2 acks the highest LSN the waiter is released.
        tracker.ack(1, 2);
        let outcome = wait_rx.await.unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_wait_catch_up_all_mode() {
        use crate::types::ReplicaSetQuorumMode;

        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(members(&[1, 2, 3]), 2);

        let (reply_tx, _reply_rx) = oneshot::channel();
        tracker.register(1, 1, reply_tx);

        let (wait_tx, mut wait_rx) = oneshot::channel();
        tracker.wait_for_catch_up(ReplicaSetQuorumMode::All, wait_tx);
        assert!(wait_rx.try_recv().is_err());

        // LSN 1 commits here, but replica 3 has not acked it yet.
        tracker.ack(1, 2);
        assert!(wait_rx.try_recv().is_err());

        tracker.ack(1, 3);
        let outcome = wait_rx.await.unwrap();
        assert!(outcome.is_ok());
    }
}