use crate::buffer::{ReplicaId, SeqNo};
use mdcs_core::lattice::Lattice;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// A batch of joined deltas shipped from replica `from` to replica `to`.
///
/// `delta` is the join of the sender's local mutations with sequence numbers in
/// `(from_seq, to_seq]`. The receiver applies it only once it has applied everything
/// from the same sender up to `from_seq`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DeltaInterval<D> {
    /// Replica that produced the deltas.
    pub from: ReplicaId,
    /// Intended recipient.
    pub to: ReplicaId,
    /// Join of all deltas in the interval.
    pub delta: D,
    /// Exclusive lower bound: the receiver must already have applied up to this seq.
    pub from_seq: SeqNo,
    /// Inclusive upper bound: sequence number of the newest mutation in `delta`.
    pub to_seq: SeqNo,
}
/// Acknowledgement sent by `from` back to the interval sender `to`.
///
/// `acked_seq` is the highest sequence number of `to` that `from` has applied.
/// Intervals are applied strictly in order, so everything up to it has been applied.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IntervalAck {
    /// Replica that applied the interval(s).
    pub from: ReplicaId,
    /// Replica that originally sent the interval(s).
    pub to: ReplicaId,
    /// Highest sequence number of `to` applied by `from`.
    pub acked_seq: SeqNo,
}
/// Messages exchanged between replicas by the causal anti-entropy protocol.
#[derive(Debug, Clone)]
pub enum CausalMessage<D> {
    /// A delta interval to be applied by its `to` replica.
    DeltaInterval(DeltaInterval<D>),
    /// Acknowledgement of applied intervals, routed to the original sender.
    Ack(IntervalAck),
    /// Replica `from` asks replica `to` for its full state.
    SnapshotRequest { from: ReplicaId, to: ReplicaId },
    /// Full state of `from` plus its mutation counter at snapshot time, sent to `to`.
    Snapshot {
        from: ReplicaId,
        to: ReplicaId,
        state: D,
        seq: SeqNo,
    },
}
/// The part of a replica that survives a crash and is what gets persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DurableState<S> {
    /// Identifier of the owning replica.
    pub replica_id: ReplicaId,
    /// Current lattice state.
    pub state: S,
    /// Sequence number of the latest local mutation (0 before any mutation).
    pub counter: SeqNo,
}
impl<S: Lattice> DurableState<S> {
pub fn new(replica_id: impl Into<ReplicaId>) -> Self {
Self {
replica_id: replica_id.into(),
state: S::bottom(),
counter: 0,
}
}
}
/// Per-peer accumulator of local deltas that have not yet been shipped to that peer.
#[derive(Debug, Clone)]
pub struct PeerDeltaBuffer<D: Lattice> {
    /// Join of all buffered deltas; `None` when nothing is pending.
    delta: Option<D>,
    /// Exclusive start of the next interval (end of the previously shipped one).
    from_seq: SeqNo,
    /// Sequence number of the most recently buffered delta.
    to_seq: SeqNo,
}
impl<D: Lattice> PeerDeltaBuffer<D> {
pub fn new() -> Self {
Self {
delta: None,
from_seq: 0,
to_seq: 0,
}
}
pub fn start_from(seq: SeqNo) -> Self {
Self {
delta: None,
from_seq: seq,
to_seq: seq,
}
}
pub fn push(&mut self, delta: D, seq: SeqNo) {
match &mut self.delta {
Some(existing) => {
existing.join_assign(&delta);
}
None => {
self.delta = Some(delta);
}
}
self.to_seq = seq;
}
pub fn has_pending(&self) -> bool {
self.delta.is_some()
}
pub fn take(&mut self) -> Option<(D, SeqNo, SeqNo)> {
self.delta.take().map(|d| {
let from = self.from_seq;
let to = self.to_seq;
self.from_seq = to;
(d, from, to)
})
}
pub fn clear(&mut self) {
self.delta = None;
self.from_seq = self.to_seq;
}
pub fn reset_from(&mut self, seq: SeqNo) {
self.delta = None;
self.from_seq = seq;
self.to_seq = seq;
}
}
impl<D: Lattice> Default for PeerDeltaBuffer<D> {
fn default() -> Self {
Self::new()
}
}
/// Per-replica protocol state that is not persisted and is rebuilt empty after a crash.
#[derive(Debug, Clone)]
pub struct VolatileState<D: Lattice> {
    /// Outgoing, not-yet-shipped deltas for each registered peer.
    pub delta_buffers: HashMap<ReplicaId, PeerDeltaBuffer<D>>,
    /// For each peer, the highest sequence number of that peer applied locally.
    pub peer_acks: HashMap<ReplicaId, SeqNo>,
}
impl<D: Lattice> VolatileState<D> {
pub fn new() -> Self {
Self {
delta_buffers: HashMap::new(),
peer_acks: HashMap::new(),
}
}
pub fn register_peer(&mut self, peer_id: ReplicaId) {
self.delta_buffers.entry(peer_id.clone()).or_default();
self.peer_acks.entry(peer_id).or_insert(0);
}
pub fn get_peer_ack(&self, peer_id: &str) -> SeqNo {
self.peer_acks.get(peer_id).copied().unwrap_or(0)
}
pub fn update_peer_ack(&mut self, peer_id: &str, seq: SeqNo) {
if let Some(ack) = self.peer_acks.get_mut(peer_id) {
*ack = (*ack).max(seq);
}
}
}
impl<D: Lattice> Default for VolatileState<D> {
fn default() -> Self {
Self::new()
}
}
/// A replica running causal delta-interval anti-entropy over lattice state `S`.
#[derive(Debug, Clone)]
pub struct CausalReplica<S: Lattice + Clone> {
    /// Persisted state: id, lattice state, mutation counter.
    durable: DurableState<S>,
    /// Outgoing buffers and per-peer acks; lost on crash.
    volatile: VolatileState<S>,
    /// Out-of-order intervals per sender, kept sorted by `from_seq`.
    pending: HashMap<ReplicaId, VecDeque<DeltaInterval<S>>>,
}
impl<S: Lattice + Clone> CausalReplica<S> {
    /// Creates a replica with bottom state, counter 0 and no registered peers.
    pub fn new(id: impl Into<ReplicaId>) -> Self {
        Self {
            durable: DurableState::new(id),
            volatile: VolatileState::new(),
            pending: HashMap::new(),
        }
    }

    /// Rebuilds a replica from persisted durable state (e.g. after a crash).
    ///
    /// Volatile state (outgoing buffers, per-peer acks) and the out-of-order queue
    /// start empty, so peers must be registered again.
    pub fn restore(durable: DurableState<S>) -> Self {
        Self {
            durable,
            volatile: VolatileState::new(),
            pending: HashMap::new(),
        }
    }

    /// This replica's identifier.
    pub fn id(&self) -> &ReplicaId {
        &self.durable.replica_id
    }

    /// Current local lattice state.
    pub fn state(&self) -> &S {
        &self.durable.state
    }

    /// Sequence number of the latest local mutation.
    pub fn counter(&self) -> SeqNo {
        self.durable.counter
    }

    /// The durable part of the replica, e.g. for persisting it.
    pub fn durable_state(&self) -> &DurableState<S> {
        &self.durable
    }

    /// Starts tracking `peer_id`: outgoing buffer, ack entry (0) and out-of-order queue.
    ///
    /// Calling it again for a known peer leaves existing entries untouched.
    ///
    /// NOTE(review): the outgoing buffer starts empty at sequence 0. Mutations made
    /// before registration are therefore never shipped to this peer as deltas.
    pub fn register_peer(&mut self, peer_id: ReplicaId) {
        self.volatile.register_peer(peer_id.clone());
        self.pending.entry(peer_id).or_default();
    }

    /// Applies a local mutation and returns the delta it produced.
    ///
    /// `mutator` sees the current state and returns a delta. The delta is joined into
    /// the state and queued, under the new sequence number, for every registered peer.
    pub fn mutate<F>(&mut self, mutator: F) -> S
    where
        F: FnOnce(&S) -> S,
    {
        self.durable.counter += 1;
        let seq = self.durable.counter;
        let delta = mutator(&self.durable.state);
        self.durable.state.join_assign(&delta);
        for buffer in self.volatile.delta_buffers.values_mut() {
            buffer.push(delta.clone(), seq);
        }
        delta
    }

    /// Moves the buffered deltas for `peer_id` into a [`DeltaInterval`].
    ///
    /// Returns `None` for an unknown peer or when nothing is buffered.
    pub fn prepare_interval(&mut self, peer_id: &str) -> Option<DeltaInterval<S>> {
        let buffer = self.volatile.delta_buffers.get_mut(peer_id)?;
        buffer
            .take()
            .map(|(delta, from_seq, to_seq)| DeltaInterval {
                from: self.durable.replica_id.clone(),
                to: peer_id.to_string(),
                delta,
                from_seq,
                to_seq,
            })
    }

    /// An interval is deliverable when it starts exactly where we left off with its sender.
    fn is_causally_ready(&self, interval: &DeltaInterval<S>) -> bool {
        let last_acked = self.volatile.get_peer_ack(&interval.from);
        interval.from_seq == last_acked
    }

    /// Everything the interval carries has already been applied (duplicate / retransmission).
    fn is_already_applied(&self, interval: &DeltaInterval<S>) -> bool {
        interval.to_seq <= self.volatile.get_peer_ack(&interval.from)
    }

    /// Builds a cumulative ack telling `peer_id` how far we have applied its deltas.
    fn cumulative_ack(&self, peer_id: &str) -> IntervalAck {
        IntervalAck {
            from: self.durable.replica_id.clone(),
            to: peer_id.to_string(),
            acked_seq: self.volatile.get_peer_ack(peer_id),
        }
    }

    /// Handles an incoming interval, registering its sender if unknown.
    ///
    /// * Deliverable: the interval is applied together with any queued successors that
    ///   thereby become deliverable. The returned ack is cumulative and covers all of them.
    /// * Already applied (duplicate): nothing changes and the current progress is
    ///   re-acknowledged.
    /// * Otherwise: the interval is queued (sorted by `from_seq`) and `None` is returned.
    pub fn receive_interval(&mut self, interval: DeltaInterval<S>) -> Option<IntervalAck> {
        if !self.volatile.peer_acks.contains_key(&interval.from) {
            self.register_peer(interval.from.clone());
        }
        if self.is_causally_ready(&interval) {
            self.durable.state.join_assign(&interval.delta);
            self.volatile
                .update_peer_ack(&interval.from, interval.to_seq);
            self.try_apply_pending(&interval.from);
            Some(self.cumulative_ack(&interval.from))
        } else if self.is_already_applied(&interval) {
            // Queueing a stale interval would be harmful: it can never become ready and,
            // sorting first, would block every deliverable interval behind it.
            Some(self.cumulative_ack(&interval.from))
        } else {
            let pending = self.pending.entry(interval.from.clone()).or_default();
            let pos = pending.iter().position(|p| p.from_seq > interval.from_seq);
            match pos {
                Some(i) => pending.insert(i, interval),
                None => pending.push_back(interval),
            }
            None
        }
    }

    /// Applies queued intervals from `peer_id` that have become deliverable, in order.
    ///
    /// Queued intervals already covered by the current ack are discarded so they cannot
    /// block the queue. Returns one ack per applied interval.
    fn try_apply_pending(&mut self, peer_id: &str) -> Vec<IntervalAck> {
        let mut acks = Vec::new();
        if let Some(pending) = self.pending.get_mut(peer_id) {
            while let Some(front) = pending.front() {
                let last_acked = self.volatile.get_peer_ack(peer_id);
                if front.from_seq == last_acked {
                    let interval = pending
                        .pop_front()
                        .expect("front() just returned Some");
                    self.durable.state.join_assign(&interval.delta);
                    self.volatile.update_peer_ack(peer_id, interval.to_seq);
                    acks.push(IntervalAck {
                        from: self.durable.replica_id.clone(),
                        to: interval.from.clone(),
                        acked_seq: interval.to_seq,
                    });
                } else if front.to_seq <= last_acked {
                    // Stale duplicate: its contents are already joined into the state.
                    pending.pop_front();
                } else {
                    break;
                }
            }
        }
        acks
    }

    /// Handles an acknowledgement from a peer.
    ///
    /// The peer's buffer only holds deltas that have *not* been shipped yet, because
    /// `prepare_interval` moves everything else out. It is therefore cleared only when
    /// the ack already covers every sequence number it contains. Clearing it
    /// unconditionally would drop mutations made after the acknowledged interval was
    /// sent. The peer would then never receive them and would stall on the gap.
    pub fn receive_ack(&mut self, ack: &IntervalAck) {
        if let Some(buffer) = self.volatile.delta_buffers.get_mut(&ack.from) {
            if ack.acked_seq >= buffer.to_seq {
                buffer.clear();
            }
        }
    }

    /// Returns a copy of the full state together with the local mutation counter.
    pub fn snapshot(&self) -> (S, SeqNo) {
        (self.durable.state.clone(), self.durable.counter)
    }

    /// Joins a full-state snapshot received from `from`, taken at its sequence `seq`.
    ///
    /// The sender is registered if unknown; otherwise the ack update would be silently
    /// ignored. Queued intervals that the new ack makes deliverable are then applied.
    pub fn apply_snapshot(&mut self, state: S, seq: SeqNo, from: &str) {
        if !self.volatile.peer_acks.contains_key(from) {
            self.register_peer(from.to_string());
        }
        self.durable.state.join_assign(&state);
        self.volatile.update_peer_ack(from, seq);
        self.try_apply_pending(from);
    }

    /// Ids of all known peers, including senders registered on first contact.
    pub fn peers(&self) -> impl Iterator<Item = &ReplicaId> {
        self.volatile.peer_acks.keys()
    }

    /// Returns `true` if any peer has buffered deltas not yet shipped.
    pub fn has_pending_deltas(&self) -> bool {
        self.volatile
            .delta_buffers
            .values()
            .any(|b| b.has_pending())
    }

    /// Number of out-of-order intervals waiting for their predecessors, across all senders.
    pub fn pending_count(&self) -> usize {
        self.pending.values().map(|v| v.len()).sum()
    }
}
/// Backend for persisting a replica's [`DurableState`].
pub trait DurableStorage<S: Lattice> {
    /// Stores `state`, keyed by its `replica_id`.
    fn persist(&mut self, state: &DurableState<S>) -> Result<(), StorageError>;
    /// Loads the state stored for `replica_id`.
    ///
    /// `MemoryStorage` returns `Ok(None)` when nothing is stored. NOTE(review): other
    /// backends might use `StorageError::NotFound` instead; confirm the intended contract.
    fn load(&self, replica_id: &str) -> Result<Option<DurableState<S>>, StorageError>;
    /// Presumably flushes buffered writes to stable storage (intent inferred from the
    /// name). `MemoryStorage`'s implementation is a no-op.
    fn sync(&mut self) -> Result<(), StorageError>;
}
/// Errors reported by durable storage backends.
#[derive(Debug, Clone)]
pub enum StorageError {
    /// Underlying I/O failure, with a description.
    IoError(String),
    /// State could not be (de)serialized, with a description.
    SerializationError(String),
    /// No stored state was found.
    NotFound,
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound => f.write_str("State not found"),
            Self::IoError(detail) => write!(f, "IO error: {}", detail),
            Self::SerializationError(detail) => write!(f, "Serialization error: {}", detail),
        }
    }
}

impl std::error::Error for StorageError {}
/// In-memory storage backend keyed by replica id; nothing survives the process.
///
/// `Default` is implemented by hand rather than derived, because the derive would
/// needlessly require `S: Default`.
#[derive(Debug)]
pub struct MemoryStorage<S> {
    /// Latest persisted state per replica.
    states: HashMap<ReplicaId, DurableState<S>>,
}

impl<S> MemoryStorage<S> {
    /// Creates an empty store. No bounds on `S` are needed just to construct it.
    pub fn new() -> Self {
        Self {
            states: HashMap::new(),
        }
    }
}

impl<S> Default for MemoryStorage<S> {
    /// Same as [`MemoryStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<S> DurableStorage<S> for MemoryStorage<S>
where
    S: Lattice + Clone + Serialize + for<'de> Deserialize<'de>,
{
    /// Stores a copy of `state`, replacing any earlier entry for the same replica.
    fn persist(&mut self, state: &DurableState<S>) -> Result<(), StorageError> {
        let key = state.replica_id.clone();
        self.states.insert(key, state.clone());
        Ok(())
    }

    /// Returns a copy of the stored state, or `Ok(None)` if nothing was persisted.
    fn load(&self, replica_id: &str) -> Result<Option<DurableState<S>>, StorageError> {
        let stored = self.states.get(replica_id).cloned();
        Ok(stored)
    }

    /// Nothing is buffered in memory, so there is nothing to flush.
    fn sync(&mut self) -> Result<(), StorageError> {
        Ok(())
    }
}
/// Simulated lossy FIFO network used to exercise the causal protocol in tests.
#[derive(Debug)]
pub struct CausalNetworkSimulator<D> {
    /// Messages awaiting delivery, in send order.
    in_flight: VecDeque<CausalMessage<D>>,
    /// Messages dropped by `send`, held until `retransmit_lost`.
    lost: Vec<CausalMessage<D>>,
    /// Probability in `[0, 1]` that `send` drops a message.
    loss_rate: f64,
    /// State of the deterministic linear congruential generator.
    rng_state: u64,
}
impl<D: Clone> CausalNetworkSimulator<D> {
pub fn new(loss_rate: f64) -> Self {
Self {
in_flight: VecDeque::new(),
lost: Vec::new(),
loss_rate,
rng_state: 42,
}
}
fn next_random(&mut self) -> f64 {
self.rng_state = self.rng_state.wrapping_mul(1103515245).wrapping_add(12345);
((self.rng_state >> 16) & 0x7fff) as f64 / 32768.0
}
pub fn send(&mut self, msg: CausalMessage<D>) {
if self.next_random() < self.loss_rate {
self.lost.push(msg);
} else {
self.in_flight.push_back(msg);
}
}
pub fn receive(&mut self) -> Option<CausalMessage<D>> {
self.in_flight.pop_front()
}
pub fn retransmit_lost(&mut self) {
for msg in self.lost.drain(..) {
self.in_flight.push_back(msg);
}
}
pub fn is_empty(&self) -> bool {
self.in_flight.is_empty()
}
pub fn in_flight_count(&self) -> usize {
self.in_flight.len()
}
pub fn lost_count(&self) -> usize {
self.lost.len()
}
}
/// A set of fully connected replicas sharing one simulated network, for tests.
#[derive(Debug)]
pub struct CausalCluster<S: Lattice + Clone> {
    /// Replicas, indexed 0..n and named `causal_{index}`.
    replicas: Vec<CausalReplica<S>>,
    /// Network carrying intervals, acks and snapshots between them.
    network: CausalNetworkSimulator<S>,
}
impl<S: Lattice + Clone> CausalCluster<S> {
    /// Builds `n` fully connected replicas named `causal_0` .. `causal_{n-1}`.
    ///
    /// The replicas communicate over a network that drops each sent message with
    /// probability `loss_rate`.
    pub fn new(n: usize, loss_rate: f64) -> Self {
        let mut replicas = Vec::with_capacity(n);
        for i in 0..n {
            let mut replica = CausalReplica::new(format!("causal_{}", i));
            for j in (0..n).filter(|&j| j != i) {
                replica.register_peer(format!("causal_{}", j));
            }
            replicas.push(replica);
        }
        Self {
            replicas,
            network: CausalNetworkSimulator::new(loss_rate),
        }
    }

    /// Replica at `idx`. Panics if `idx` is out of bounds.
    pub fn replica(&self, idx: usize) -> &CausalReplica<S> {
        &self.replicas[idx]
    }

    /// Mutable replica at `idx`. Panics if `idx` is out of bounds.
    pub fn replica_mut(&mut self, idx: usize) -> &mut CausalReplica<S> {
        &mut self.replicas[idx]
    }

    /// Applies a local mutation at replica `replica_idx` and returns the produced delta.
    pub fn mutate<F>(&mut self, replica_idx: usize, mutator: F) -> S
    where
        F: FnOnce(&S) -> S,
    {
        self.replicas[replica_idx].mutate(mutator)
    }

    /// Sends every peer of replica `from_idx` its pending delta interval, if any.
    ///
    /// Peer ids come from a `HashMap`, whose iteration order is randomized per process.
    /// They are sorted so that the seeded loss generator sees messages in a reproducible
    /// order and lossy simulations are repeatable.
    pub fn broadcast_intervals(&mut self, from_idx: usize) {
        let replica = &mut self.replicas[from_idx];
        let mut peer_ids: Vec<_> = replica.peers().cloned().collect();
        peer_ids.sort();
        for peer_id in &peer_ids {
            if let Some(interval) = replica.prepare_interval(peer_id) {
                self.network.send(CausalMessage::DeltaInterval(interval));
            }
        }
    }

    /// Index of the first replica whose id equals `id`.
    fn position_of(&self, id: &str) -> Option<usize> {
        self.replicas.iter().position(|r| r.id() == id)
    }

    /// Delivers one message from the network to its addressee.
    ///
    /// Any resulting reply (ack or snapshot) is sent back onto the network. Returns
    /// `false` when no message was in flight. Messages addressed to unknown replicas
    /// are dropped.
    pub fn process_one(&mut self) -> bool {
        let msg = match self.network.receive() {
            Some(msg) => msg,
            None => return false,
        };
        match msg {
            CausalMessage::DeltaInterval(interval) => {
                if let Some(idx) = self.position_of(&interval.to) {
                    // The interval is owned here, so move it rather than cloning its delta.
                    if let Some(ack) = self.replicas[idx].receive_interval(interval) {
                        self.network.send(CausalMessage::Ack(ack));
                    }
                }
            }
            CausalMessage::Ack(ack) => {
                if let Some(idx) = self.position_of(&ack.to) {
                    self.replicas[idx].receive_ack(&ack);
                }
            }
            CausalMessage::SnapshotRequest { from, to } => {
                if let Some(idx) = self.position_of(&to) {
                    let (state, seq) = self.replicas[idx].snapshot();
                    // Reply goes back to the requester.
                    self.network.send(CausalMessage::Snapshot {
                        from: to,
                        to: from,
                        state,
                        seq,
                    });
                }
            }
            CausalMessage::Snapshot {
                from,
                to,
                state,
                seq,
            } => {
                if let Some(idx) = self.position_of(&to) {
                    self.replicas[idx].apply_snapshot(state, seq, &from);
                }
            }
        }
        true
    }

    /// Processes messages until nothing is in flight (lost messages stay lost).
    pub fn drain_network(&mut self) {
        while self.process_one() {}
    }

    /// Every replica broadcasts its pending intervals, then the network is drained.
    pub fn full_sync_round(&mut self) {
        for i in 0..self.replicas.len() {
            self.broadcast_intervals(i);
        }
        self.drain_network();
    }

    /// Returns `true` if all replicas hold equal states (trivially so for fewer than two).
    pub fn is_converged(&self) -> bool {
        match self.replicas.split_first() {
            Some((first, rest)) => rest.iter().all(|r| r.state() == first.state()),
            None => true,
        }
    }

    /// Re-queues every lost message and drains the network.
    pub fn retransmit_and_process(&mut self) {
        self.network.retransmit_lost();
        self.drain_network();
    }

    /// Number of replicas.
    pub fn len(&self) -> usize {
        self.replicas.len()
    }

    /// Returns `true` if the cluster has no replicas.
    pub fn is_empty(&self) -> bool {
        self.replicas.is_empty()
    }

    /// Simulates a crash of replica `idx`: only its durable state survives.
    ///
    /// Peers are re-registered with fresh outgoing buffers and acks reset to 0.
    ///
    /// NOTE(review): replica acceptance requires `from_seq` to equal the last ack. Peers
    /// that had already shipped intervals to this replica keep their later `from_seq`,
    /// so their next intervals will queue rather than apply. This method performs no
    /// re-synchronization (e.g. via snapshots).
    pub fn crash_and_recover(&mut self, idx: usize) {
        let durable = self.replicas[idx].durable_state().clone();
        let mut recovered = CausalReplica::restore(durable);
        for j in (0..self.replicas.len()).filter(|&j| j != idx) {
            recovered.register_peer(format!("causal_{}", j));
        }
        self.replicas[idx] = recovered;
    }

    /// Total number of out-of-order intervals queued across all replicas.
    pub fn total_pending(&self) -> usize {
        self.replicas.iter().map(|r| r.pending_count()).sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mdcs_core::gset::GSet;
    use mdcs_core::pncounter::PNCounter;

    /// Builds a `GSet` delta containing exactly `values`.
    fn gset(values: &[i32]) -> GSet<i32> {
        let mut set = GSet::new();
        for &value in values {
            set.insert(value);
        }
        set
    }

    #[test]
    fn test_causal_replica_basic() {
        let mut replica: CausalReplica<GSet<i32>> = CausalReplica::new("test1");
        replica.mutate(|_| gset(&[42]));
        assert!(replica.state().contains(&42));
        assert_eq!(replica.counter(), 1);
    }

    #[test]
    fn test_causal_interval_generation() {
        let mut replica: CausalReplica<GSet<i32>> = CausalReplica::new("test1");
        replica.register_peer("peer1".to_string());
        replica.mutate(|_| gset(&[1]));
        replica.mutate(|_| gset(&[2]));

        let interval = replica.prepare_interval("peer1").unwrap();
        assert_eq!(interval.from_seq, 0);
        assert_eq!(interval.to_seq, 2);
        assert!(interval.delta.contains(&1));
        assert!(interval.delta.contains(&2));
    }

    #[test]
    fn test_causal_delivery() {
        let mut r1: CausalReplica<GSet<i32>> = CausalReplica::new("r1");
        let mut r2: CausalReplica<GSet<i32>> = CausalReplica::new("r2");
        r1.register_peer("r2".to_string());
        r2.register_peer("r1".to_string());
        r1.mutate(|_| gset(&[1]));
        r1.mutate(|_| gset(&[2]));

        let interval = r1.prepare_interval("r2").unwrap();
        assert_eq!(interval.from_seq, 0);
        assert_eq!(interval.to_seq, 2);

        let ack = r2.receive_interval(interval).unwrap();
        assert_eq!(ack.acked_seq, 2);
        assert!(r2.state().contains(&1));
        assert!(r2.state().contains(&2));
    }

    #[test]
    fn test_out_of_order_buffering() {
        let mut replica: CausalReplica<GSet<i32>> = CausalReplica::new("r1");
        replica.register_peer("peer".to_string());
        let out_of_order = DeltaInterval {
            from: "peer".to_string(),
            to: "r1".to_string(),
            delta: gset(&[999]),
            from_seq: 5,
            to_seq: 6,
        };

        let result = replica.receive_interval(out_of_order);
        assert!(result.is_none());
        assert_eq!(replica.pending_count(), 1);
        assert!(!replica.state().contains(&999));
    }

    #[test]
    fn test_cluster_convergence() {
        let mut cluster: CausalCluster<GSet<i32>> = CausalCluster::new(3, 0.0);
        for i in 0..3 {
            let val = (i + 1) as i32;
            cluster.mutate(i, move |_| gset(&[val]));
        }
        assert!(!cluster.is_converged());

        cluster.full_sync_round();
        assert!(cluster.is_converged());
        for i in 0..3 {
            for val in 1..=3 {
                assert!(cluster.replica(i).state().contains(&val));
            }
        }
    }

    #[test]
    fn test_cluster_with_loss() {
        let mut cluster: CausalCluster<GSet<i32>> = CausalCluster::new(3, 0.3);
        for i in 0..3 {
            let val = (i + 1) as i32;
            cluster.mutate(i, move |_| gset(&[val]));
        }
        for _ in 0..10 {
            cluster.full_sync_round();
            cluster.retransmit_and_process();
        }
        assert!(cluster.is_converged());
    }

    #[test]
    fn test_crash_recovery() {
        let mut cluster: CausalCluster<GSet<i32>> = CausalCluster::new(2, 0.0);
        cluster.mutate(0, |_| gset(&[1]));
        cluster.full_sync_round();
        assert!(cluster.is_converged());

        cluster.mutate(0, |_| gset(&[2]));
        let counter_before = cluster.replica(0).counter();
        cluster.crash_and_recover(0);

        let recovered = cluster.replica(0);
        assert_eq!(recovered.counter(), counter_before);
        assert!(recovered.state().contains(&1));
        assert!(recovered.state().contains(&2));
        assert!(!recovered.has_pending_deltas());
    }

    #[test]
    fn test_pncounter_causal() {
        let mut cluster: CausalCluster<PNCounter<String>> = CausalCluster::new(2, 0.0);
        cluster.mutate(0, |_s| {
            let mut delta = PNCounter::new();
            delta.increment("r0".to_string(), 1);
            delta
        });
        cluster.mutate(1, |_s| {
            let mut delta = PNCounter::new();
            delta.decrement("r1".to_string(), 1);
            delta
        });

        cluster.full_sync_round();
        assert!(cluster.is_converged());
        assert_eq!(cluster.replica(0).state().value(), 0);
    }

    #[test]
    fn test_causal_ordering_preserved() {
        let mut r1: CausalReplica<GSet<i32>> = CausalReplica::new("r1");
        let mut r2: CausalReplica<GSet<i32>> = CausalReplica::new("r2");
        r1.register_peer("r2".to_string());
        r2.register_peer("r1".to_string());
        for i in 1..=3 {
            r1.mutate(move |_| gset(&[i]));
        }

        let interval_1_3 = DeltaInterval {
            from: "r1".to_string(),
            to: "r2".to_string(),
            delta: gset(&[3]),
            from_seq: 2,
            to_seq: 3,
        };
        let interval_0_2 = DeltaInterval {
            from: "r1".to_string(),
            to: "r2".to_string(),
            delta: gset(&[1, 2]),
            from_seq: 0,
            to_seq: 2,
        };

        // The later interval arrives first and must wait for its predecessor.
        let result = r2.receive_interval(interval_1_3.clone());
        assert!(result.is_none());
        assert!(!r2.state().contains(&3));

        // Delivering the predecessor releases the buffered successor as well.
        let result = r2.receive_interval(interval_0_2);
        assert!(result.is_some());
        assert!(r2.state().contains(&1));
        assert!(r2.state().contains(&2));
        assert!(r2.state().contains(&3));
        assert_eq!(r2.pending_count(), 0);
    }

    #[test]
    fn test_durable_storage() {
        let mut storage: MemoryStorage<GSet<i32>> = MemoryStorage::new();
        let mut replica: CausalReplica<GSet<i32>> = CausalReplica::new("test");
        replica.mutate(|_| gset(&[42]));

        storage.persist(replica.durable_state()).unwrap();
        let loaded = storage.load("test").unwrap().unwrap();
        assert_eq!(loaded.counter, 1);
        assert!(loaded.state.contains(&42));

        let recovered = CausalReplica::restore(loaded);
        assert!(recovered.state().contains(&42));
    }
}