use std::collections::{HashMap, HashSet};
use tokio::sync::oneshot;
use crate::error::{KubericError, Result};
use crate::types::{Lsn, ReplicaId};
/// Tracks writes registered on the primary until a write quorum of replicas
/// acknowledges them, and resolves catch-up waits during reconfiguration.
///
/// The tracker has no internal synchronization; every mutating method takes
/// `&mut self`.
pub struct QuorumTracker {
    /// Registered writes that have not committed yet, keyed by LSN.
    pending: HashMap<Lsn, PendingOp>,
    /// Replicas of the current configuration.
    current_members: HashSet<ReplicaId>,
    /// Minimum number of `current_members` that must ack a write to commit it.
    current_write_quorum: u32,
    /// Replicas of the previous configuration. While non-empty, a write must
    /// also reach `previous_write_quorum` within this set; cleared by
    /// `set_current_configuration`.
    previous_members: HashSet<ReplicaId>,
    /// Minimum number of `previous_members` acks required while
    /// `previous_members` is non-empty.
    previous_write_quorum: u32,
    /// Replicas checked by a `Write`-mode catch-up wait.
    must_catch_up_ids: HashSet<ReplicaId>,
    /// Highest LSN each replica is known to have acknowledged. Entries are
    /// only ever raised, never lowered or removed.
    replica_acked_lsn: HashMap<ReplicaId, Lsn>,
    /// Highest LSN committed so far. Commits can happen out of LSN order, so
    /// lower LSNs may still be pending.
    committed_lsn: Lsn,
    /// Highest LSN ever passed to `register`.
    highest_lsn: Lsn,
    /// Value of `highest_lsn` when the last catch-up configuration was installed.
    catch_up_baseline_lsn: Lsn,
    /// Catch-up waits that were not yet satisfied when they were requested.
    catch_up_waiters: Vec<CatchUpWaiter>,
}
/// A parked `wait_for_catch_up` request.
struct CatchUpWaiter {
    /// Which replicas must have caught up: `Write` checks `must_catch_up_ids`,
    /// `All` checks every current member.
    mode: crate::types::ReplicaSetQuorumMode,
    /// Completed with `Ok(())` once caught up, or with an error by `fail_all`.
    reply: oneshot::Sender<Result<()>>,
}
/// A registered write waiting for quorum acknowledgement.
struct PendingOp {
    /// Replicas that have acknowledged this LSN; the primary is added when
    /// the write is registered.
    acked_by: HashSet<ReplicaId>,
    /// The writer's completion channel. It is taken (left as `None`) when the
    /// op commits, so the result is sent at most once.
    reply: Option<oneshot::Sender<Result<Lsn>>>,
    /// LSN of this write.
    lsn: Lsn,
}
impl QuorumTracker {
pub fn new() -> Self {
Self {
pending: HashMap::new(),
current_members: HashSet::new(),
current_write_quorum: 0,
previous_members: HashSet::new(),
previous_write_quorum: 0,
must_catch_up_ids: HashSet::new(),
replica_acked_lsn: HashMap::new(),
committed_lsn: 0,
highest_lsn: 0,
catch_up_baseline_lsn: 0,
catch_up_waiters: Vec::new(),
}
}
pub fn committed_lsn(&self) -> Lsn {
self.committed_lsn
}
pub fn set_catch_up_configuration(
&mut self,
current_members: HashSet<ReplicaId>,
current_write_quorum: u32,
previous_members: HashSet<ReplicaId>,
previous_write_quorum: u32,
must_catch_up_ids: HashSet<ReplicaId>,
member_progress: HashMap<ReplicaId, Lsn>,
) {
self.current_members = current_members;
self.current_write_quorum = current_write_quorum;
self.previous_members = previous_members;
self.previous_write_quorum = previous_write_quorum;
self.must_catch_up_ids = must_catch_up_ids;
self.catch_up_baseline_lsn = self.highest_lsn;
for (id, progress) in &member_progress {
self.replica_acked_lsn
.entry(*id)
.and_modify(|v| {
if *progress > *v {
*v = *progress;
}
})
.or_insert(*progress);
}
self.notify_catch_up_waiters();
}
pub fn set_current_configuration(
&mut self,
current_members: HashSet<ReplicaId>,
current_write_quorum: u32,
) {
self.current_members = current_members;
self.current_write_quorum = current_write_quorum;
self.previous_members.clear();
self.previous_write_quorum = 0;
self.must_catch_up_ids.clear();
}
pub fn register(
&mut self,
lsn: Lsn,
primary_id: ReplicaId,
reply: oneshot::Sender<Result<Lsn>>,
) {
if lsn > self.highest_lsn {
self.highest_lsn = lsn;
}
let mut acked_by = HashSet::new();
acked_by.insert(primary_id);
self.replica_acked_lsn
.entry(primary_id)
.and_modify(|v| {
if lsn > *v {
*v = lsn;
}
})
.or_insert(lsn);
let mut op = PendingOp {
acked_by,
reply: Some(reply),
lsn,
};
if self.is_quorum_met(&op.acked_by) {
self.commit_op(&mut op);
}
if op.reply.is_some() {
self.pending.insert(lsn, op);
} else {
self.notify_catch_up_waiters();
}
}
pub fn ack(&mut self, lsn: Lsn, replica_id: ReplicaId) {
self.replica_acked_lsn
.entry(replica_id)
.and_modify(|v| {
if lsn > *v {
*v = lsn;
}
})
.or_insert(lsn);
if let Some(op) = self.pending.get_mut(&lsn) {
op.acked_by.insert(replica_id);
} else {
self.notify_catch_up_waiters();
return;
}
let quorum_met = {
let op = self.pending.get(&lsn).unwrap();
self.is_quorum_met(&op.acked_by)
};
if quorum_met {
let mut op = self.pending.remove(&lsn).unwrap();
self.commit_op(&mut op);
self.try_commit_pending();
self.notify_catch_up_waiters();
}
}
pub fn fail_all(&mut self, error: KubericError) {
for (_, mut op) in self.pending.drain() {
if let Some(reply) = op.reply.take() {
let _ = reply.send(Err(match &error {
KubericError::NotPrimary => KubericError::NotPrimary,
KubericError::Closed => KubericError::Closed,
_ => KubericError::Internal(error.to_string().into()),
}));
}
}
for waiter in self.catch_up_waiters.drain(..) {
let _ = waiter.reply.send(Err(match &error {
KubericError::NotPrimary => KubericError::NotPrimary,
KubericError::Closed => KubericError::Closed,
_ => KubericError::Internal(error.to_string().into()),
}));
}
}
pub fn pending_count(&self) -> usize {
self.pending.len()
}
pub fn wait_for_catch_up(
&mut self,
mode: crate::types::ReplicaSetQuorumMode,
reply: oneshot::Sender<Result<()>>,
) {
if self.is_caught_up(mode) {
let _ = reply.send(Ok(()));
} else {
self.catch_up_waiters.push(CatchUpWaiter { mode, reply });
}
}
fn is_caught_up(&self, mode: crate::types::ReplicaSetQuorumMode) -> bool {
if !self.pending.is_empty() {
return false;
}
let check_lsn = self.highest_lsn;
if check_lsn <= self.catch_up_baseline_lsn {
return true;
}
match mode {
crate::types::ReplicaSetQuorumMode::Write => {
for &id in &self.must_catch_up_ids {
let acked = self.replica_acked_lsn.get(&id).copied().unwrap_or(0);
if acked < check_lsn {
return false;
}
}
}
crate::types::ReplicaSetQuorumMode::All => {
for &id in &self.current_members {
let acked = self.replica_acked_lsn.get(&id).copied().unwrap_or(0);
if acked < check_lsn {
return false;
}
}
}
}
true
}
fn is_quorum_met(&self, acked_by: &HashSet<ReplicaId>) -> bool {
let cc_met =
self.count_acks_in_set(acked_by, &self.current_members) >= self.current_write_quorum;
if self.previous_members.is_empty() {
return cc_met;
}
let pc_met =
self.count_acks_in_set(acked_by, &self.previous_members) >= self.previous_write_quorum;
cc_met && pc_met
}
fn count_acks_in_set(
&self,
acked_by: &HashSet<ReplicaId>,
members: &HashSet<ReplicaId>,
) -> u32 {
acked_by.intersection(members).count() as u32
}
fn commit_op(&mut self, op: &mut PendingOp) {
if op.lsn > self.committed_lsn {
self.committed_lsn = op.lsn;
}
if let Some(reply) = op.reply.take() {
let _ = reply.send(Ok(op.lsn));
}
}
fn notify_catch_up_waiters(&mut self) {
if self.catch_up_waiters.is_empty() {
return;
}
let waiters = std::mem::take(&mut self.catch_up_waiters);
for waiter in waiters {
if self.is_caught_up(waiter.mode) {
let _ = waiter.reply.send(Ok(()));
} else {
self.catch_up_waiters.push(waiter);
}
}
}
fn try_commit_pending(&mut self) {
let mut to_remove = Vec::new();
for (lsn, op) in &self.pending {
if self.is_quorum_met(&op.acked_by) {
to_remove.push(*lsn);
}
}
for lsn in to_remove {
if let Some(mut op) = self.pending.remove(&lsn) {
self.commit_op(&mut op);
}
}
}
}
impl Default for QuorumTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_single_replica_commits_immediately() {
        let mut tracker = QuorumTracker::new();
        let primary_id = 1;
        tracker.set_current_configuration(HashSet::from([primary_id]), 1);
        let (tx, rx) = oneshot::channel();
        tracker.register(1, primary_id, tx);
        let lsn = rx.await.unwrap().unwrap();
        assert_eq!(lsn, 1);
        assert_eq!(tracker.committed_lsn(), 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_three_replicas_quorum() {
        let mut tracker = QuorumTracker::new();
        let primary_id = 1;
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        let (tx, rx) = oneshot::channel();
        tracker.register(1, primary_id, tx);
        assert_eq!(tracker.pending_count(), 1);
        tracker.ack(1, 2);
        let lsn = rx.await.unwrap().unwrap();
        assert_eq!(lsn, 1);
        assert_eq!(tracker.committed_lsn(), 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_dual_config_quorum() {
        let mut tracker = QuorumTracker::new();
        let primary_id = 1;
        tracker.set_catch_up_configuration(
            HashSet::from([1, 2, 3]),
            2,
            HashSet::from([1, 2]),
            2,
            HashSet::new(),
            HashMap::new(),
        );
        let (tx, rx) = oneshot::channel();
        tracker.register(1, primary_id, tx);
        assert_eq!(tracker.pending_count(), 1);
        // Replica 3 satisfies the current quorum but not the previous one.
        tracker.ack(1, 3);
        assert_eq!(tracker.pending_count(), 1);
        tracker.ack(1, 2);
        let lsn = rx.await.unwrap().unwrap();
        assert_eq!(lsn, 1);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_out_of_order_acks() {
        let mut tracker = QuorumTracker::new();
        let primary_id = 1;
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        tracker.register(1, primary_id, tx1);
        tracker.register(2, primary_id, tx2);
        tracker.ack(2, 2);
        let lsn2 = rx2.await.unwrap().unwrap();
        assert_eq!(lsn2, 2);
        tracker.ack(1, 2);
        let lsn1 = rx1.await.unwrap().unwrap();
        assert_eq!(lsn1, 1);
        assert_eq!(tracker.committed_lsn(), 2);
    }

    #[tokio::test]
    async fn test_fail_all() {
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        tracker.register(1, 1, tx1);
        tracker.register(2, 1, tx2);
        tracker.fail_all(KubericError::NotPrimary);
        let result1 = rx1.await.unwrap();
        assert!(matches!(result1, Err(KubericError::NotPrimary)));
        let result2 = rx2.await.unwrap();
        assert!(matches!(result2, Err(KubericError::NotPrimary)));
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_fail_all_fails_catch_up_waiters() {
        use crate::types::ReplicaSetQuorumMode;
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        let (tx, _rx) = oneshot::channel();
        tracker.register(1, 1, tx);
        let (wait_tx, wait_rx) = oneshot::channel();
        // A write is pending, so the waiter is parked.
        tracker.wait_for_catch_up(ReplicaSetQuorumMode::All, wait_tx);
        tracker.fail_all(KubericError::Closed);
        let result = wait_rx.await.unwrap();
        assert!(matches!(result, Err(KubericError::Closed)));
        assert_eq!(tracker.pending_count(), 0);
    }

    #[test]
    fn test_ack_for_unknown_lsn_is_harmless() {
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        tracker.ack(7, 2);
        assert_eq!(tracker.committed_lsn(), 0);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_must_catch_up_enforcement() {
        use crate::types::ReplicaSetQuorumMode;
        let mut tracker = QuorumTracker::new();
        tracker.set_catch_up_configuration(
            HashSet::from([1, 2, 3]),
            2,
            HashSet::new(),
            0,
            HashSet::from([2]),
            HashMap::from([(2, 0), (3, 0)]),
        );
        let (tx, rx) = oneshot::channel();
        tracker.register(1, 1, tx);
        tracker.ack(1, 3);
        let lsn = rx.await.unwrap().unwrap();
        assert_eq!(lsn, 1);
        assert_eq!(tracker.pending_count(), 0);
        let (wait_tx, mut wait_rx) = oneshot::channel();
        tracker.wait_for_catch_up(ReplicaSetQuorumMode::Write, wait_tx);
        // Replica 2 must catch up but has not acknowledged LSN 1 yet.
        assert!(wait_rx.try_recv().is_err());
        tracker.ack(1, 2);
        let result = wait_rx.await.unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_wait_catch_up_all_mode() {
        use crate::types::ReplicaSetQuorumMode;
        let mut tracker = QuorumTracker::new();
        tracker.set_current_configuration(HashSet::from([1, 2, 3]), 2);
        let (tx, _rx) = oneshot::channel();
        tracker.register(1, 1, tx);
        let (wait_tx, mut wait_rx) = oneshot::channel();
        tracker.wait_for_catch_up(ReplicaSetQuorumMode::All, wait_tx);
        assert!(wait_rx.try_recv().is_err());
        tracker.ack(1, 2);
        assert!(wait_rx.try_recv().is_err());
        tracker.ack(1, 3);
        let result = wait_rx.await.unwrap();
        assert!(result.is_ok());
    }
}