Skip to main content

atomr_distributed_data/
counters.rs

//! Grow-only counter (`GCounter`) and positive/negative counter (`PNCounter`), modeled on the akka.net types of the same names.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use crate::traits::{CrdtMerge, DeltaCrdt};
8
9#[derive(Debug, Default, Clone, Serialize, Deserialize)]
10pub struct GCounter {
11    state: HashMap<String, u64>,
12    /// Accumulated since the last `take_delta`. Skipped on
13    /// serialization so peers never see another node's pending
14    /// delta — they receive deltas through the explicit
15    /// `Replicator::propagate_delta` path.
16    #[serde(skip)]
17    pending_delta: HashMap<String, u64>,
18}
19
20impl GCounter {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn increment(&mut self, node: &str, delta: u64) {
26        let key = node.to_string();
27        *self.state.entry(key.clone()).or_default() += delta;
28        *self.pending_delta.entry(key).or_default() += delta;
29    }
30
31    pub fn value(&self) -> u64 {
32        self.state.values().copied().sum()
33    }
34}
35
36impl CrdtMerge for GCounter {
37    fn merge(&mut self, other: &Self) {
38        for (k, v) in &other.state {
39            let slot = self.state.entry(k.clone()).or_default();
40            *slot = (*slot).max(*v);
41        }
42    }
43}
44
45impl DeltaCrdt for GCounter {
46    /// Delta is just the per-node increments accumulated since the
47    /// last take. Merging adds to the recipient's per-node count.
48    type Delta = HashMap<String, u64>;
49
50    fn take_delta(&mut self) -> Option<Self::Delta> {
51        if self.pending_delta.is_empty() {
52            return None;
53        }
54        Some(std::mem::take(&mut self.pending_delta))
55    }
56
57    fn merge_delta(&mut self, delta: &Self::Delta) {
58        for (k, v) in delta {
59            let slot = self.state.entry(k.clone()).or_default();
60            *slot += *v;
61        }
62    }
63}
64
65#[derive(Debug, Default, Clone, Serialize, Deserialize)]
66pub struct PNCounter {
67    inc: GCounter,
68    dec: GCounter,
69}
70
71impl PNCounter {
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    pub fn increment(&mut self, node: &str, delta: u64) {
77        self.inc.increment(node, delta);
78    }
79
80    pub fn decrement(&mut self, node: &str, delta: u64) {
81        self.dec.increment(node, delta);
82    }
83
84    pub fn value(&self) -> i64 {
85        self.inc.value() as i64 - self.dec.value() as i64
86    }
87}
88
89impl CrdtMerge for PNCounter {
90    fn merge(&mut self, other: &Self) {
91        self.inc.merge(&other.inc);
92        self.dec.merge(&other.dec);
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gcounter_merges_take_max_per_node() {
        let mut a = GCounter::new();
        let mut b = GCounter::new();
        a.increment("n1", 5);
        b.increment("n1", 3);
        b.increment("n2", 7);
        a.merge(&b);
        assert_eq!(a.value(), 5 + 7);
    }

    /// State-based merge must be commutative and idempotent, or replicas
    /// can diverge depending on gossip order or re-delivery.
    #[test]
    fn gcounter_merge_is_commutative_and_idempotent() {
        let mut a = GCounter::new();
        a.increment("n1", 2);
        a.increment("n2", 9);
        let mut b = GCounter::new();
        b.increment("n1", 6);

        let mut ab = a.clone();
        ab.merge(&b);
        let mut ba = b.clone();
        ba.merge(&a);
        assert_eq!(ab.value(), 6 + 9);
        assert_eq!(ab.value(), ba.value());

        // Re-merging the same state changes nothing.
        ab.merge(&b);
        assert_eq!(ab.value(), 6 + 9);
    }

    #[test]
    fn pncounter_supports_positive_negative() {
        let mut c = PNCounter::new();
        c.increment("n1", 10);
        c.decrement("n1", 3);
        assert_eq!(c.value(), 7);
    }

    #[test]
    fn pncounter_merge_combines_both_halves() {
        let mut a = PNCounter::new();
        a.increment("n1", 4);
        a.decrement("n1", 1);
        let mut b = PNCounter::new();
        b.increment("n2", 2);
        b.decrement("n2", 7);

        a.merge(&b);
        assert_eq!(a.value(), (4 - 1) + (2 - 7));

        // Idempotent: merging the same replica again is a no-op.
        a.merge(&b);
        assert_eq!(a.value(), -2);
    }

    #[test]
    fn delta_take_and_clear() {
        let mut c = GCounter::new();
        c.increment("a", 3);
        c.increment("b", 2);
        let delta = c.take_delta().expect("non-empty");
        assert_eq!(delta.get("a"), Some(&3));
        assert_eq!(delta.get("b"), Some(&2));
        // Cleared on take.
        assert!(c.take_delta().is_none());
    }

    /// Pins the additive delta semantics: a delta carries increments, and
    /// the receiver adds them rather than max-joining.
    #[test]
    fn delta_merge_adds_to_remote() {
        let mut local = GCounter::new();
        local.increment("a", 5);
        let _ = local.take_delta();

        let mut remote = GCounter::new();
        remote.increment("a", 1); // remote saw 1 from "a"
        let _ = remote.take_delta();

        // Local emits an additional +3 delta; remote applies it.
        local.increment("a", 3);
        let delta = local.take_delta().unwrap();
        remote.merge_delta(&delta);
        assert_eq!(remote.value(), 1 + 3);
    }
}