// oxirs_stream/distributed_state/mod.rs
1//! # Distributed State Management with CRDT Consistency
2//!
3//! Provides distributed state management primitives using Conflict-free Replicated Data Types
4//! (CRDTs) for eventual consistency, Merkle-verified checkpointing, and gossip-based
5//! state replication.
6//!
7//! ## Components
8//!
9//! - [`CrdtEventLog`]: CRDT-based distributed event log (G-Counter, PN-Counter, LWW-Register)
10//! - [`DistributedCheckpointer`]: Checkpoint stream state across nodes with Merkle verification
11//! - [`StateReplicationManager`]: Replicates stream state using a gossip protocol
12
13pub mod manager;
14pub use manager::*;
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
19
20use parking_lot::RwLock;
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tracing::{debug, info, warn};
24
25// ─── Error Types ─────────────────────────────────────────────────────────────
26
/// Errors that can occur during distributed state operations.
#[derive(Error, Debug)]
pub enum DistributedStateError {
    /// A node checkpoint's Merkle root did not match the root recomputed
    /// from its state bytes on arrival.
    #[error("Checkpoint verification failed: expected {expected}, got {actual}")]
    CheckpointVerificationFailed { expected: String, actual: String },

    /// The local state version lags behind a remote version.
    #[error("Stale state: local version {local} < remote version {remote}")]
    StaleState { local: u64, remote: u64 },

    /// Two states could not be merged for the given key.
    #[error("Merge conflict on key {key}: {detail}")]
    MergeConflict { key: String, detail: String },

    /// A gossip/replication operation failed.
    #[error("Replication error: {0}")]
    Replication(String),

    /// Checkpoint state could not be (de)serialised.
    #[error("Checkpoint serialisation error: {0}")]
    Serialisation(String),

    /// An operation referenced a node that was never registered.
    #[error("Node not registered: {0}")]
    NodeNotRegistered(String),
}

/// Result type for distributed state operations.
pub type StateResult<T> = Result<T, DistributedStateError>;
51
52// ─── CRDT Primitives ─────────────────────────────────────────────────────────
53
/// A Grow-only Counter (G-Counter) CRDT.
///
/// Each node maintains its own counter; the global value is the sum of all
/// node counters. Merging takes the per-node maximum, which makes merges
/// commutative, associative, and idempotent.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GCounter {
    /// Per-node increment counts, keyed by node ID.
    counts: HashMap<String, u64>,
}
63
64impl GCounter {
65    /// Creates a new empty G-Counter.
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    /// Increments the counter for a specific node.
71    pub fn increment(&mut self, node_id: &str) {
72        *self.counts.entry(node_id.to_string()).or_insert(0) += 1;
73    }
74
75    /// Increments the counter for a specific node by `delta`.
76    pub fn increment_by(&mut self, node_id: &str, delta: u64) {
77        *self.counts.entry(node_id.to_string()).or_insert(0) += delta;
78    }
79
80    /// Returns the current global value (sum of all node counters).
81    pub fn value(&self) -> u64 {
82        self.counts.values().sum()
83    }
84
85    /// Merges another G-Counter into this one by taking the per-node maximum.
86    pub fn merge(&mut self, other: &GCounter) {
87        for (node, &count) in &other.counts {
88            let local = self.counts.entry(node.clone()).or_insert(0);
89            if count > *local {
90                *local = count;
91            }
92        }
93    }
94}
95
/// A Positive-Negative Counter (PN-Counter) CRDT.
///
/// Supports both increment and decrement by maintaining two G-Counters:
/// one accumulating additions, one accumulating removals.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PnCounter {
    /// Total increments per node.
    positive: GCounter,
    /// Total decrements per node.
    negative: GCounter,
}
104
105impl PnCounter {
106    /// Creates a new empty PN-Counter.
107    pub fn new() -> Self {
108        Self::default()
109    }
110
111    /// Increments the counter for a node.
112    pub fn increment(&mut self, node_id: &str) {
113        self.positive.increment(node_id);
114    }
115
116    /// Decrements the counter for a node.
117    pub fn decrement(&mut self, node_id: &str) {
118        self.negative.increment(node_id);
119    }
120
121    /// Returns the current value (positive total - negative total).
122    pub fn value(&self) -> i64 {
123        self.positive.value() as i64 - self.negative.value() as i64
124    }
125
126    /// Merges another PN-Counter into this one.
127    pub fn merge(&mut self, other: &PnCounter) {
128        self.positive.merge(&other.positive);
129        self.negative.merge(&other.negative);
130    }
131}
132
/// A Last-Writer-Wins Register (LWW-Register) CRDT.
///
/// Stores a value tagged with a timestamp; merges take the most recent write.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + serde::de::DeserializeOwned")]
pub struct LwwRegister<T: Clone + Serialize + serde::de::DeserializeOwned> {
    /// Current value, or `None` if the register has never been written.
    value: Option<T>,
    /// Timestamp of the last write (microseconds since UNIX epoch);
    /// 0 means no write has occurred yet.
    timestamp: u64,
    /// ID of the node that performed the last write.
    writer_node: String,
}
146
147impl<T: Clone + Serialize + serde::de::DeserializeOwned> LwwRegister<T> {
148    /// Creates a new empty LWW-Register.
149    pub fn new() -> Self {
150        Self {
151            value: None,
152            timestamp: 0,
153            writer_node: String::new(),
154        }
155    }
156
157    /// Writes a value, timestamped with the current wall-clock time.
158    pub fn write(&mut self, value: T, node_id: &str) {
159        let ts = SystemTime::now()
160            .duration_since(UNIX_EPOCH)
161            .unwrap_or_default()
162            .as_micros() as u64;
163        self.write_at(value, node_id, ts);
164    }
165
166    /// Writes a value at an explicit timestamp (for deterministic testing).
167    pub fn write_at(&mut self, value: T, node_id: &str, timestamp: u64) {
168        if timestamp >= self.timestamp {
169            self.value = Some(value);
170            self.timestamp = timestamp;
171            self.writer_node = node_id.to_string();
172        }
173    }
174
175    /// Returns a reference to the current value, if set.
176    pub fn read(&self) -> Option<&T> {
177        self.value.as_ref()
178    }
179
180    /// Returns the timestamp of the last write.
181    pub fn timestamp(&self) -> u64 {
182        self.timestamp
183    }
184
185    /// Merges another LWW-Register by keeping the most recent write.
186    pub fn merge(&mut self, other: &LwwRegister<T>) {
187        if other.timestamp > self.timestamp {
188            self.value = other.value.clone();
189            self.timestamp = other.timestamp;
190            self.writer_node = other.writer_node.clone();
191        }
192    }
193}
194
195impl<T: Clone + Serialize + serde::de::DeserializeOwned> Default for LwwRegister<T> {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201// ─── CRDT Event Log ──────────────────────────────────────────────────────────
202
/// A single immutable entry stored in the CRDT event log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrdtLogEntry {
    /// Logical sequence number, assigned per origin node; the pair
    /// `(sequence, origin_node)` uniquely identifies an entry during merges.
    pub sequence: u64,
    /// ID of the node that produced this entry.
    pub origin_node: String,
    /// Wall-clock timestamp (microseconds since UNIX epoch).
    pub timestamp: u64,
    /// Opaque payload bytes.
    pub payload: Vec<u8>,
}
215
/// Point-in-time statistics for the CRDT event log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrdtEventLogStats {
    /// Total number of log entries (local appends plus merged remote entries).
    pub total_entries: u64,
    /// Number of distinct nodes that have contributed entries.
    pub contributing_nodes: usize,
    /// Total events recorded by the G-Counter across all nodes.
    pub event_counter: u64,
    /// Net activity (appends minus removals) from the PN-Counter.
    pub activity_counter: i64,
}
228
/// A CRDT-based distributed event log.
///
/// Combines a G-Counter (total events), a PN-Counter (net activity), and
/// LWW-Register entries to provide a causally consistent log across nodes.
/// All internal state sits behind `RwLock`s, so every method takes `&self`.
pub struct CrdtEventLog {
    /// ID of the local node; stamped on every locally produced entry.
    node_id: String,
    /// Append-only log entries (local appends plus merged remote entries)
    entries: Arc<RwLock<Vec<CrdtLogEntry>>>,
    /// G-Counter tracking total events added across all nodes
    event_counter: Arc<RwLock<GCounter>>,
    /// PN-Counter tracking net activity (appends minus removals)
    activity_counter: Arc<RwLock<PnCounter>>,
    /// LWW-Register per named key for last-write-wins key/value state
    registers: Arc<RwLock<HashMap<String, LwwRegister<Vec<u8>>>>>,
    /// Monotonically increasing sequence number for this node's entries
    next_sequence: Arc<RwLock<u64>>,
}
246
247impl CrdtEventLog {
248    /// Creates a new CRDT event log for the given node.
249    pub fn new(node_id: impl Into<String>) -> Self {
250        Self {
251            node_id: node_id.into(),
252            entries: Arc::new(RwLock::new(Vec::new())),
253            event_counter: Arc::new(RwLock::new(GCounter::new())),
254            activity_counter: Arc::new(RwLock::new(PnCounter::new())),
255            registers: Arc::new(RwLock::new(HashMap::new())),
256            next_sequence: Arc::new(RwLock::new(0)),
257        }
258    }
259
260    /// Appends an event to the log, returning the assigned sequence number.
261    pub fn append(&self, payload: Vec<u8>) -> u64 {
262        let mut seq = self.next_sequence.write();
263        let sequence = *seq;
264        *seq += 1;
265        drop(seq);
266
267        let timestamp = SystemTime::now()
268            .duration_since(UNIX_EPOCH)
269            .unwrap_or_default()
270            .as_micros() as u64;
271
272        let entry = CrdtLogEntry {
273            sequence,
274            origin_node: self.node_id.clone(),
275            timestamp,
276            payload,
277        };
278        self.entries.write().push(entry);
279        self.event_counter.write().increment(&self.node_id);
280        self.activity_counter.write().increment(&self.node_id);
281        debug!("Appended log entry seq={}", sequence);
282        sequence
283    }
284
285    /// Records a removal (decrement) event without removing from the immutable log.
286    pub fn record_removal(&self) {
287        self.activity_counter.write().decrement(&self.node_id);
288    }
289
290    /// Writes a named LWW-Register value.
291    pub fn set_register(&self, key: &str, value: Vec<u8>) {
292        let mut registers = self.registers.write();
293        let reg = registers.entry(key.to_string()).or_default();
294        reg.write(value, &self.node_id);
295    }
296
297    /// Reads a named LWW-Register value.
298    pub fn get_register(&self, key: &str) -> Option<Vec<u8>> {
299        self.registers.read().get(key)?.read().cloned()
300    }
301
302    /// Merges a remote CRDT event log state into this log.
303    pub fn merge_remote(&self, remote: &RemoteCrdtState) {
304        // Merge G-Counter
305        self.event_counter.write().merge(&remote.event_counter);
306        // Merge PN-Counter
307        self.activity_counter
308            .write()
309            .merge(&remote.activity_counter);
310        // Merge LWW-Registers
311        let mut local_regs = self.registers.write();
312        for (key, remote_reg) in &remote.registers {
313            let local_reg = local_regs.entry(key.clone()).or_default();
314            local_reg.merge(remote_reg);
315        }
316        // Append any new entries — deduplicate by (sequence, origin_node) pair
317        let mut entries = self.entries.write();
318        let existing_keys: std::collections::HashSet<(u64, String)> = entries
319            .iter()
320            .map(|e| (e.sequence, e.origin_node.clone()))
321            .collect();
322        for entry in &remote.entries {
323            let key = (entry.sequence, entry.origin_node.clone());
324            if !existing_keys.contains(&key) {
325                entries.push(entry.clone());
326            }
327        }
328        entries.sort_by_key(|e| e.sequence);
329        debug!("Merged remote CRDT state, total entries: {}", entries.len());
330    }
331
332    /// Exports current state for transmission to remote nodes.
333    pub fn export_state(&self) -> RemoteCrdtState {
334        RemoteCrdtState {
335            origin_node: self.node_id.clone(),
336            event_counter: self.event_counter.read().clone(),
337            activity_counter: self.activity_counter.read().clone(),
338            registers: self.registers.read().clone(),
339            entries: self.entries.read().clone(),
340        }
341    }
342
343    /// Returns CRDT event log statistics.
344    pub fn stats(&self) -> CrdtEventLogStats {
345        CrdtEventLogStats {
346            total_entries: self.entries.read().len() as u64,
347            contributing_nodes: {
348                let entries = self.entries.read();
349                entries
350                    .iter()
351                    .map(|e| e.origin_node.as_str())
352                    .collect::<std::collections::HashSet<_>>()
353                    .len()
354            },
355            event_counter: self.event_counter.read().value(),
356            activity_counter: self.activity_counter.read().value(),
357        }
358    }
359
360    /// Returns all log entries in sequence order.
361    pub fn entries(&self) -> Vec<CrdtLogEntry> {
362        self.entries.read().clone()
363    }
364}
365
/// Portable CRDT state snapshot for gossip transmission between nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteCrdtState {
    /// Node that exported this snapshot.
    pub origin_node: String,
    /// G-Counter of total events at export time.
    pub event_counter: GCounter,
    /// PN-Counter of net activity at export time.
    pub activity_counter: PnCounter,
    /// Named LWW-Registers at export time.
    pub registers: HashMap<String, LwwRegister<Vec<u8>>>,
    /// Full entry log at export time.
    pub entries: Vec<CrdtLogEntry>,
}
375
376// ─── Distributed Checkpointer ─────────────────────────────────────────────────
377
/// A single node's checkpoint snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCheckpoint {
    /// Checkpoint identifier, shared by all nodes contributing to the same
    /// global checkpoint.
    pub checkpoint_id: String,
    /// Node that produced this checkpoint.
    pub node_id: String,
    /// Logical timestamp (e.g. stream offset) at checkpoint time.
    pub logical_time: u64,
    /// Opaque checkpoint state bytes.
    pub state_bytes: Vec<u8>,
    /// Hash of `state_bytes` as produced by
    /// `DistributedCheckpointer::compute_merkle_root` (hex-encoded 64-bit
    /// FNV-1a digest — not SHA-256, despite the field name).
    pub merkle_root: String,
    /// Wall-clock creation time.
    pub created_at: SystemTime,
}
394
/// A global checkpoint aggregating per-node checkpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalDistributedCheckpoint {
    /// Global checkpoint identifier.
    pub checkpoint_id: String,
    /// Per-node checkpoints, keyed by node ID.
    pub node_checkpoints: HashMap<String, NodeCheckpoint>,
    /// Combined root: hash of the sorted, concatenated per-node roots.
    pub combined_merkle_root: String,
    /// Minimum logical time across all contributing nodes (0 if none).
    pub min_logical_time: u64,
    /// Maximum logical time across all contributing nodes (0 if none).
    pub max_logical_time: u64,
    /// Whether every expected node contributed a checkpoint.
    pub is_complete: bool,
    /// Creation time of this global checkpoint.
    pub created_at: SystemTime,
}
413
/// Statistics for the distributed checkpointer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointerStats {
    /// Total global checkpoints completed.
    pub completed_checkpoints: u64,
    /// Total node checkpoints that failed Merkle verification on arrival.
    pub failed_verifications: u64,
    /// ID of the most recently completed global checkpoint.
    pub latest_checkpoint_id: Option<String>,
    /// Average state size in bytes per node of the most recently completed
    /// checkpoint (not a running average across all checkpoints).
    pub avg_state_bytes: f64,
}
426
/// Checkpoints stream state across nodes with Merkle verification.
///
/// Each node submits its local checkpoint; once all expected nodes contribute,
/// a global checkpoint is formed and its combined Merkle root is verified.
pub struct DistributedCheckpointer {
    /// Node IDs that must all contribute before a checkpoint finalises.
    expected_nodes: std::collections::HashSet<String>,
    /// Active (incomplete) checkpoint collections, keyed by checkpoint_id
    active: Arc<RwLock<HashMap<String, Vec<NodeCheckpoint>>>>,
    /// Completed global checkpoints, keyed by checkpoint_id
    completed: Arc<RwLock<Vec<GlobalDistributedCheckpoint>>>,
    /// Running statistics; updated on submission and finalisation.
    stats: Arc<RwLock<CheckpointerStats>>,
}
439
440impl DistributedCheckpointer {
441    /// Creates a new checkpointer expecting contributions from the given nodes.
442    pub fn new(expected_nodes: std::collections::HashSet<String>) -> Self {
443        Self {
444            expected_nodes,
445            active: Arc::new(RwLock::new(HashMap::new())),
446            completed: Arc::new(RwLock::new(Vec::new())),
447            stats: Arc::new(RwLock::new(CheckpointerStats {
448                completed_checkpoints: 0,
449                failed_verifications: 0,
450                latest_checkpoint_id: None,
451                avg_state_bytes: 0.0,
452            })),
453        }
454    }
455
456    /// Submits a node checkpoint.
457    ///
458    /// Returns `Some(GlobalDistributedCheckpoint)` when all expected nodes
459    /// have contributed for this `checkpoint_id`.
460    pub fn submit_node_checkpoint(
461        &self,
462        checkpoint: NodeCheckpoint,
463    ) -> StateResult<Option<GlobalDistributedCheckpoint>> {
464        // Verify Merkle root on arrival
465        let computed = Self::compute_merkle_root(&checkpoint.state_bytes);
466        if computed != checkpoint.merkle_root {
467            self.stats.write().failed_verifications += 1;
468            return Err(DistributedStateError::CheckpointVerificationFailed {
469                expected: checkpoint.merkle_root.clone(),
470                actual: computed,
471            });
472        }
473        let checkpoint_id = checkpoint.checkpoint_id.clone();
474        {
475            let mut active = self.active.write();
476            active
477                .entry(checkpoint_id.clone())
478                .or_default()
479                .push(checkpoint);
480        }
481        self.try_finalise(&checkpoint_id)
482    }
483
484    /// Returns the latest completed global checkpoint, if any.
485    pub fn latest_checkpoint(&self) -> Option<GlobalDistributedCheckpoint> {
486        self.completed.read().last().cloned()
487    }
488
489    /// Returns all completed checkpoints.
490    pub fn all_checkpoints(&self) -> Vec<GlobalDistributedCheckpoint> {
491        self.completed.read().clone()
492    }
493
494    /// Returns checkpointer statistics.
495    pub fn stats(&self) -> CheckpointerStats {
496        self.stats.read().clone()
497    }
498
499    fn try_finalise(
500        &self,
501        checkpoint_id: &str,
502    ) -> StateResult<Option<GlobalDistributedCheckpoint>> {
503        let active = self.active.read();
504        let contributions = match active.get(checkpoint_id) {
505            Some(c) => c,
506            None => return Ok(None),
507        };
508        let contributed_ids: std::collections::HashSet<&str> =
509            contributions.iter().map(|c| c.node_id.as_str()).collect();
510        let expected_refs: std::collections::HashSet<&str> =
511            self.expected_nodes.iter().map(|s| s.as_str()).collect();
512
513        if contributed_refs_subset(&contributed_ids, &expected_refs) {
514            drop(active);
515            let global = self.build_global(checkpoint_id)?;
516            self.completed.write().push(global.clone());
517            let mut stats = self.stats.write();
518            stats.completed_checkpoints += 1;
519            stats.latest_checkpoint_id = Some(checkpoint_id.to_string());
520            let total_bytes: usize = global
521                .node_checkpoints
522                .values()
523                .map(|c| c.state_bytes.len())
524                .sum();
525            stats.avg_state_bytes = total_bytes as f64 / global.node_checkpoints.len() as f64;
526            self.active.write().remove(checkpoint_id);
527            Ok(Some(global))
528        } else {
529            Ok(None)
530        }
531    }
532
533    fn build_global(&self, checkpoint_id: &str) -> StateResult<GlobalDistributedCheckpoint> {
534        let active = self.active.read();
535        let contributions = active
536            .get(checkpoint_id)
537            .ok_or_else(|| DistributedStateError::Serialisation("no active checkpoint".into()))?;
538
539        let node_checkpoints: HashMap<String, NodeCheckpoint> = contributions
540            .iter()
541            .map(|c| (c.node_id.clone(), c.clone()))
542            .collect();
543
544        let mut sorted_roots: Vec<String> = node_checkpoints
545            .values()
546            .map(|c| c.merkle_root.clone())
547            .collect();
548        sorted_roots.sort();
549        let combined_data = sorted_roots.join("");
550        let combined_merkle_root = Self::compute_merkle_root(combined_data.as_bytes());
551
552        let min_logical_time = node_checkpoints
553            .values()
554            .map(|c| c.logical_time)
555            .min()
556            .unwrap_or(0);
557        let max_logical_time = node_checkpoints
558            .values()
559            .map(|c| c.logical_time)
560            .max()
561            .unwrap_or(0);
562
563        Ok(GlobalDistributedCheckpoint {
564            checkpoint_id: checkpoint_id.to_string(),
565            is_complete: node_checkpoints.len() == self.expected_nodes.len(),
566            node_checkpoints,
567            combined_merkle_root,
568            min_logical_time,
569            max_logical_time,
570            created_at: SystemTime::now(),
571        })
572    }
573
574    /// Computes a simple Merkle root as the hex-encoded SHA-256 of the data.
575    ///
576    /// In a production system this would build a full Merkle tree; here we
577    /// use a single-level hash for correctness without external crypto deps.
578    pub fn compute_merkle_root(data: &[u8]) -> String {
579        // FNV-1a 64-bit as a lightweight hash (no sha2 dep needed)
580        const FNV_OFFSET: u64 = 14695981039346656037;
581        const FNV_PRIME: u64 = 1099511628211;
582        let mut hash = FNV_OFFSET;
583        for byte in data {
584            hash ^= u64::from(*byte);
585            hash = hash.wrapping_mul(FNV_PRIME);
586        }
587        format!("{:016x}", hash)
588    }
589}
590
/// Returns `true` when every expected node ID is present in `contributed`.
///
/// Extra (unexpected) contributors do not prevent finalisation. The former
/// `|| contributed == expected` clause was redundant: set equality already
/// implies `expected.is_subset(contributed)`.
fn contributed_refs_subset(
    contributed: &std::collections::HashSet<&str>,
    expected: &std::collections::HashSet<&str>,
) -> bool {
    expected.is_subset(contributed)
}
597
598/// Helper to create a NodeCheckpoint with a correct Merkle root.
599pub fn make_node_checkpoint(
600    checkpoint_id: impl Into<String>,
601    node_id: impl Into<String>,
602    logical_time: u64,
603    state_bytes: Vec<u8>,
604) -> NodeCheckpoint {
605    let merkle_root = DistributedCheckpointer::compute_merkle_root(&state_bytes);
606    NodeCheckpoint {
607        checkpoint_id: checkpoint_id.into(),
608        node_id: node_id.into(),
609        logical_time,
610        state_bytes,
611        merkle_root,
612        created_at: SystemTime::now(),
613    }
614}
615
616// ─── State Replication Manager ───────────────────────────────────────────────
617
/// A replication message sent between nodes via gossip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GossipMessage {
    /// Source node ID.
    pub from_node: String,
    /// Target node ID (`None` = broadcast).
    pub to_node: Option<String>,
    /// Gossip round number at the sender.
    pub round: u64,
    /// State digest for comparison (hex hash of the sender's state).
    pub state_digest: String,
    /// Full state payload; applied by the receiver when the digest differs
    /// from its own.
    pub state_payload: Option<Vec<u8>>,
    /// Timestamp of this gossip message.
    pub timestamp: SystemTime,
}
634
/// Per-node replication state tracked by the manager.
#[derive(Debug, Clone)]
struct NodeReplicationState {
    /// Peer node ID.
    node_id: String,
    /// When a gossip message was last received from this peer.
    last_seen: Instant,
    /// Most recent state digest attributed to this peer.
    last_digest: String,
    /// Last gossip round number reported by this peer.
    round: u64,
}
643
/// Configuration for the state replication manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicationConfig {
    /// Number of nodes to gossip with per round (fanout)
    pub fanout: usize,
    /// Interval between gossip rounds
    pub gossip_interval: Duration,
    /// Number of rounds without contact after which a node is considered stale
    pub stale_rounds: u64,
}
654
655impl Default for ReplicationConfig {
656    fn default() -> Self {
657        Self {
658            fanout: 3,
659            gossip_interval: Duration::from_millis(500),
660            stale_rounds: 10,
661        }
662    }
663}
664
/// Statistics for the state replication manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicationStats {
    /// Total gossip messages produced for sending.
    pub messages_sent: u64,
    /// Total gossip messages received.
    pub messages_received: u64,
    /// Total state synchronisations triggered by digest mismatches.
    pub sync_count: u64,
    /// Number of peer nodes currently tracked.
    pub tracked_nodes: usize,
    /// Current gossip round number.
    pub current_round: u64,
}
679
/// Replicates stream state across nodes using a gossip protocol.
///
/// Each node periodically gossips its state digest to a deterministic subset
/// of peers; a receiver whose digest differs applies the attached payload.
pub struct StateReplicationManager {
    /// ID of the local node.
    node_id: String,
    /// Gossip parameters (fanout, interval, staleness threshold).
    config: ReplicationConfig,
    /// Tracked peer nodes, keyed by node ID
    peers: Arc<RwLock<HashMap<String, NodeReplicationState>>>,
    /// Digest (hex hash) of the local state bytes
    local_digest: Arc<RwLock<String>>,
    /// Local state bytes
    local_state: Arc<RwLock<Vec<u8>>>,
    /// Gossip round counter, incremented on each `produce_gossip` call
    current_round: Arc<RwLock<u64>>,
    /// Running replication statistics.
    stats: Arc<RwLock<ReplicationStats>>,
    /// Received gossip messages (buffer for processing)
    /// NOTE(review): within this file the inbox is only ever appended to —
    /// confirm a consumer drains it, otherwise it grows without bound.
    inbox: Arc<RwLock<Vec<GossipMessage>>>,
}
699
700impl StateReplicationManager {
701    /// Creates a new replication manager for the given node.
702    pub fn new(node_id: impl Into<String>, config: ReplicationConfig) -> Self {
703        Self {
704            node_id: node_id.into(),
705            config,
706            peers: Arc::new(RwLock::new(HashMap::new())),
707            local_digest: Arc::new(RwLock::new(String::new())),
708            local_state: Arc::new(RwLock::new(Vec::new())),
709            current_round: Arc::new(RwLock::new(0)),
710            stats: Arc::new(RwLock::new(ReplicationStats {
711                messages_sent: 0,
712                messages_received: 0,
713                sync_count: 0,
714                tracked_nodes: 0,
715                current_round: 0,
716            })),
717            inbox: Arc::new(RwLock::new(Vec::new())),
718        }
719    }
720
721    /// Registers a peer node for gossip.
722    pub fn add_peer(&self, node_id: impl Into<String>) {
723        let id = node_id.into();
724        self.peers.write().insert(
725            id.clone(),
726            NodeReplicationState {
727                node_id: id,
728                last_seen: Instant::now(),
729                last_digest: String::new(),
730                round: 0,
731            },
732        );
733        self.stats.write().tracked_nodes = self.peers.read().len();
734    }
735
736    /// Updates the local state, recomputing the digest.
737    pub fn update_local_state(&self, state: Vec<u8>) {
738        let digest = DistributedCheckpointer::compute_merkle_root(&state);
739        *self.local_state.write() = state;
740        *self.local_digest.write() = digest;
741    }
742
743    /// Produces gossip messages for the current round (up to `fanout` peers).
744    ///
745    /// Returns the list of gossip messages to send.
746    pub fn produce_gossip(&self) -> Vec<GossipMessage> {
747        let mut round = self.current_round.write();
748        *round += 1;
749        let current_round = *round;
750        drop(round);
751
752        let digest = self.local_digest.read().clone();
753        let state_payload = Some(self.local_state.read().clone());
754
755        let peers: Vec<String> = self.peers.read().keys().cloned().collect();
756        // Deterministic peer selection using modular arithmetic (no rand)
757        let fanout = self.config.fanout.min(peers.len());
758        let offset = (current_round as usize) % peers.len().max(1);
759        let selected: Vec<&String> = peers.iter().cycle().skip(offset).take(fanout).collect();
760
761        let mut messages = Vec::with_capacity(selected.len());
762        for peer_id in selected {
763            messages.push(GossipMessage {
764                from_node: self.node_id.clone(),
765                to_node: Some(peer_id.clone()),
766                round: current_round,
767                state_digest: digest.clone(),
768                state_payload: state_payload.clone(),
769                timestamp: SystemTime::now(),
770            });
771        }
772        self.stats.write().messages_sent += messages.len() as u64;
773        self.stats.write().current_round = current_round;
774        messages
775    }
776
777    /// Receives and processes an incoming gossip message.
778    ///
779    /// Returns `true` if a state synchronisation was triggered (digest differed).
780    pub fn receive_gossip(&self, msg: GossipMessage) -> StateResult<bool> {
781        self.stats.write().messages_received += 1;
782        self.inbox.write().push(msg.clone());
783
784        // Update peer tracking
785        {
786            let mut peers = self.peers.write();
787            if let Some(peer) = peers.get_mut(&msg.from_node) {
788                peer.last_seen = Instant::now();
789                peer.round = msg.round;
790            } else {
791                warn!("Gossip from unknown peer {}", msg.from_node);
792            }
793        }
794
795        let local_digest = self.local_digest.read().clone();
796        if msg.state_digest != local_digest {
797            // Digest differs — apply remote state if payload provided
798            if let Some(payload) = msg.state_payload {
799                info!("Syncing state from {} (round {})", msg.from_node, msg.round);
800                self.update_local_state(payload);
801                self.stats.write().sync_count += 1;
802                return Ok(true);
803            }
804        }
805        Ok(false)
806    }
807
808    /// Returns the current local state digest.
809    pub fn local_digest(&self) -> String {
810        self.local_digest.read().clone()
811    }
812
813    /// Returns replication statistics.
814    pub fn stats(&self) -> ReplicationStats {
815        self.stats.read().clone()
816    }
817
818    /// Returns IDs of stale peers (not heard from in `stale_rounds` rounds).
819    pub fn stale_peers(&self) -> Vec<String> {
820        let current_round = *self.current_round.read();
821        self.peers
822            .read()
823            .values()
824            .filter(|p| current_round.saturating_sub(p.round) > self.config.stale_rounds)
825            .map(|p| p.node_id.clone())
826            .collect()
827    }
828}
829
830// ─── Tests ───────────────────────────────────────────────────────────────────
831
#[cfg(test)]
mod tests {
    use super::*;

    // ── G-Counter tests ──────────────────────────────────────────────────────

    #[test]
    fn test_g_counter_basic() {
        let mut counter = GCounter::new();
        counter.increment("node-1");
        counter.increment("node-1");
        counter.increment("node-2");
        assert_eq!(counter.value(), 3);
    }

    #[test]
    fn test_g_counter_merge() {
        let mut left = GCounter::new();
        left.increment_by("node-1", 5);

        let mut right = GCounter::new();
        right.increment_by("node-1", 3);
        right.increment_by("node-2", 7);

        left.merge(&right);
        // Per-node max on merge: node-1 → max(5, 3) = 5, node-2 → max(0, 7) = 7.
        assert_eq!(left.value(), 12);
    }

    #[test]
    fn test_g_counter_merge_idempotent() {
        // Merging a counter with a copy of itself must not change its value.
        let mut counter = GCounter::new();
        counter.increment_by("node-1", 10);
        let snapshot = counter.clone();
        counter.merge(&snapshot);
        assert_eq!(counter.value(), 10);
    }

    // ── PN-Counter tests ─────────────────────────────────────────────────────

    #[test]
    fn test_pn_counter_basic() {
        let mut counter = PnCounter::new();
        for _ in 0..3 {
            counter.increment("node-1");
        }
        counter.decrement("node-1");
        // +3 increments and -1 decrement net to 2.
        assert_eq!(counter.value(), 2);
    }

    #[test]
    fn test_pn_counter_merge() {
        let mut left = PnCounter::new();
        left.increment("node-1");

        let mut right = PnCounter::new();
        right.increment("node-2");
        right.decrement("node-2");

        left.merge(&right);
        // node-1 contributes +1; node-2 nets +1 - 1 = 0 → total 1.
        assert_eq!(left.value(), 1);
    }

    // ── LWW-Register tests ───────────────────────────────────────────────────

    #[test]
    fn test_lww_register_write_and_read() {
        let mut register: LwwRegister<String> = LwwRegister::new();
        register.write_at("hello".to_string(), "node-1", 100);
        assert_eq!(register.read(), Some(&"hello".to_string()));
    }

    #[test]
    fn test_lww_register_last_write_wins() {
        let mut register: LwwRegister<String> = LwwRegister::new();
        register.write_at("first".to_string(), "node-1", 100);
        register.write_at("second".to_string(), "node-2", 200);
        register.write_at("old".to_string(), "node-3", 50);
        // The highest timestamp (200) wins, regardless of write order.
        assert_eq!(register.read(), Some(&"second".to_string()));
    }

    #[test]
    fn test_lww_register_merge() {
        let mut local: LwwRegister<String> = LwwRegister::new();
        local.write_at("r1-value".to_string(), "node-1", 100);

        let mut remote: LwwRegister<String> = LwwRegister::new();
        remote.write_at("r2-value".to_string(), "node-2", 200);

        // Remote write is newer, so merge adopts its value and timestamp.
        local.merge(&remote);
        assert_eq!(local.read(), Some(&"r2-value".to_string()));
        assert_eq!(local.timestamp(), 200);
    }

    // ── CRDT Event Log tests ─────────────────────────────────────────────────

    #[test]
    fn test_crdt_event_log_append() {
        let event_log = CrdtEventLog::new("node-1");
        let first_seq = event_log.append(b"event-0".to_vec());
        let second_seq = event_log.append(b"event-1".to_vec());
        // Sequence numbers are assigned starting at 0.
        assert_eq!(first_seq, 0);
        assert_eq!(second_seq, 1);
        let summary = event_log.stats();
        assert_eq!(summary.total_entries, 2);
        assert_eq!(summary.event_counter, 2);
        assert_eq!(summary.activity_counter, 2);
    }

    #[test]
    fn test_crdt_event_log_registers() {
        let event_log = CrdtEventLog::new("node-a");
        event_log.set_register("config", b"v1".to_vec());
        assert_eq!(event_log.get_register("config"), Some(b"v1".to_vec()));
        event_log.set_register("config", b"v2".to_vec());
        // LWW semantics: both writes happen "now" and may share a timestamp,
        // so we only assert a value survives rather than pinning which one.
        assert!(event_log.get_register("config").is_some());
    }

    #[test]
    fn test_crdt_event_log_merge() {
        let local = CrdtEventLog::new("node-1");
        local.append(b"n1-event".to_vec());

        let remote = CrdtEventLog::new("node-2");
        remote.append(b"n2-event".to_vec());
        remote.append(b"n2-event-2".to_vec());

        let exported = remote.export_state();
        local.merge_remote(&exported);

        let summary = local.stats();
        // 1 local entry + 2 merged remote entries = 3 total, from 2 nodes.
        assert_eq!(summary.total_entries, 3);
        assert_eq!(summary.contributing_nodes, 2);
    }

    #[test]
    fn test_crdt_event_log_record_removal() {
        let event_log = CrdtEventLog::new("node-x");
        event_log.append(b"e1".to_vec());
        event_log.append(b"e2".to_vec());
        event_log.record_removal();
        let summary = event_log.stats();
        // Grow-only event counter stays at 2; PN activity counter nets
        // +2 - 1 = 1 after the removal.
        assert_eq!(summary.event_counter, 2);
        assert_eq!(summary.activity_counter, 1);
    }

    // ── Distributed Checkpointer tests ───────────────────────────────────────

    #[test]
    fn test_distributed_checkpointer_completes_on_all_nodes() {
        let expected_nodes: std::collections::HashSet<String> =
            ["n1", "n2"].iter().map(|s| s.to_string()).collect();
        let checkpointer = DistributedCheckpointer::new(expected_nodes);

        // First of two expected nodes: the global checkpoint is still pending.
        let first = make_node_checkpoint("ckpt-1", "n1", 100, b"state-n1".to_vec());
        let outcome = checkpointer
            .submit_node_checkpoint(first)
            .expect("submit should succeed");
        assert!(outcome.is_none(), "should not complete with 1/2 nodes");

        // Second node arrives: the global checkpoint must now complete.
        let second = make_node_checkpoint("ckpt-1", "n2", 110, b"state-n2".to_vec());
        let outcome = checkpointer
            .submit_node_checkpoint(second)
            .expect("submit should succeed");
        assert!(outcome.is_some(), "should complete with 2/2 nodes");

        let global = outcome.expect("must be Some");
        assert_eq!(global.checkpoint_id, "ckpt-1");
        assert_eq!(global.node_checkpoints.len(), 2);
        assert_eq!(global.min_logical_time, 100);
        assert_eq!(global.max_logical_time, 110);
        assert!(global.is_complete);
    }

    #[test]
    fn test_distributed_checkpointer_rejects_bad_merkle() {
        let expected_nodes: std::collections::HashSet<String> =
            ["n1"].iter().map(|s| s.to_string()).collect();
        let checkpointer = DistributedCheckpointer::new(expected_nodes);

        let mut tampered = make_node_checkpoint("ckpt-bad", "n1", 50, b"data".to_vec());
        tampered.merkle_root = "deadbeef".to_string(); // deliberately wrong

        let outcome = checkpointer.submit_node_checkpoint(tampered);
        assert!(
            matches!(
                outcome,
                Err(DistributedStateError::CheckpointVerificationFailed { .. })
            ),
            "should reject bad Merkle root"
        );

        // The failed verification must be counted in the stats.
        let summary = checkpointer.stats();
        assert_eq!(summary.failed_verifications, 1);
    }

    // ── State Replication Manager tests ──────────────────────────────────────

    #[test]
    fn test_state_replication_gossip_produced() {
        let config = ReplicationConfig {
            fanout: 2,
            gossip_interval: Duration::from_millis(100),
            stale_rounds: 5,
        };
        let manager = StateReplicationManager::new("node-1", config);
        manager.add_peer("node-2");
        manager.add_peer("node-3");
        manager.update_local_state(b"my-state".to_vec());

        let outgoing = manager.produce_gossip();
        assert!(!outgoing.is_empty(), "should produce gossip messages");
        assert!(outgoing.len() <= 2, "fanout should be respected");
        // Every message carries our node id and a non-empty digest.
        for message in &outgoing {
            assert_eq!(message.from_node, "node-1");
            assert!(!message.state_digest.is_empty());
        }
    }

    #[test]
    fn test_state_replication_receive_sync() {
        let receiver = StateReplicationManager::new("node-2", ReplicationConfig::default());
        receiver.add_peer("node-1");
        receiver.update_local_state(b"old-state".to_vec());

        // Build a gossip message carrying a state that differs from ours.
        let incoming_state = b"new-state-from-node-1".to_vec();
        let incoming_digest = DistributedCheckpointer::compute_merkle_root(&incoming_state);

        let message = GossipMessage {
            from_node: "node-1".to_string(),
            to_node: Some("node-2".to_string()),
            round: 1,
            state_digest: incoming_digest.clone(),
            state_payload: Some(incoming_state.clone()),
            timestamp: SystemTime::now(),
        };

        let applied = receiver
            .receive_gossip(message)
            .expect("receive should succeed");
        assert!(applied, "should detect and apply diverged state");
        assert_eq!(receiver.local_digest(), incoming_digest);

        let summary = receiver.stats();
        assert_eq!(summary.messages_received, 1);
        assert_eq!(summary.sync_count, 1);
    }

    #[test]
    fn test_state_replication_no_sync_when_same_digest() {
        let manager = StateReplicationManager::new("node-x", ReplicationConfig::default());
        manager.add_peer("node-y");
        manager.update_local_state(b"shared-state".to_vec());

        // Gossip whose digest matches ours must be a no-op sync.
        let message = GossipMessage {
            from_node: "node-y".to_string(),
            to_node: Some("node-x".to_string()),
            round: 1,
            state_digest: manager.local_digest(),
            state_payload: None,
            timestamp: SystemTime::now(),
        };

        let applied = manager.receive_gossip(message).expect("receive should succeed");
        assert!(!applied, "should not sync when digests match");
        assert_eq!(manager.stats().sync_count, 0);
    }

    #[test]
    fn test_merkle_root_deterministic() {
        let payload = b"hello world";
        let first = DistributedCheckpointer::compute_merkle_root(payload);
        let second = DistributedCheckpointer::compute_merkle_root(payload);
        // Same input → same root; different input → different root.
        assert_eq!(first, second);

        let other = DistributedCheckpointer::compute_merkle_root(b"different");
        assert_ne!(first, other);
    }
}