Skip to main content

mdcs_core/
pncounter.rs

1//! PN-Counter (Positive-Negative Counter) CRDT
2//!
3//! A PN-Counter supports both increment and decrement operations by maintaining
4//! two separate counters: one for increments (P) and one for decrements (N).
5//! The value is P - N.
6//!
7//! Each replica has its own counter entry, and the join operation performs
8//! component-wise max across all replicas.
9
10use crate::lattice::Lattice;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeMap;
13
14/// A Positive-Negative Counter CRDT
15///
16/// Supports both increment and decrement by maintaining two separate counters.
17/// Value = sum(increments) - sum(decrements)
18#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub struct PNCounter<K: Ord + Clone> {
20    /// Per-replica increment counters
21    increments: BTreeMap<K, u64>,
22    /// Per-replica decrement counters
23    decrements: BTreeMap<K, u64>,
24}
25
26impl<K: Ord + Clone> PNCounter<K> {
27    /// Create a new PN-Counter
28    pub fn new() -> Self {
29        Self {
30            increments: BTreeMap::new(),
31            decrements: BTreeMap::new(),
32        }
33    }
34
35    /// Increment the counter for a specific replica
36    pub fn increment(&mut self, replica_id: K, amount: u64) {
37        let entry = self.increments.entry(replica_id).or_insert(0);
38        *entry = entry.saturating_add(amount);
39    }
40
41    /// Decrement the counter for a specific replica
42    pub fn decrement(&mut self, replica_id: K, amount: u64) {
43        let entry = self.decrements.entry(replica_id).or_insert(0);
44        *entry = entry.saturating_add(amount);
45    }
46
47    /// Get the current value (sum of increments - sum of decrements)
48    pub fn value(&self) -> i64 {
49        let inc_sum: u64 = self.increments.values().sum();
50        let dec_sum: u64 = self.decrements.values().sum();
51        (inc_sum as i64).saturating_sub(dec_sum as i64)
52    }
53
54    /// Get the increment counter for a replica
55    pub fn get_increment(&self, replica_id: &K) -> u64 {
56        self.increments.get(replica_id).copied().unwrap_or(0)
57    }
58
59    /// Get the decrement counter for a replica
60    pub fn get_decrement(&self, replica_id: &K) -> u64 {
61        self.decrements.get(replica_id).copied().unwrap_or(0)
62    }
63
64    /// Get a reference to all increment counters
65    pub fn increments(&self) -> &BTreeMap<K, u64> {
66        &self.increments
67    }
68
69    /// Get a reference to all decrement counters
70    pub fn decrements(&self) -> &BTreeMap<K, u64> {
71        &self.decrements
72    }
73}
74
75impl<K: Ord + Clone> Default for PNCounter<K> {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl<K: Ord + Clone> Lattice for PNCounter<K> {
82    fn bottom() -> Self {
83        Self::new()
84    }
85
86    /// Join operation performs component-wise max on both counters
87    /// This ensures that concurrent updates always converge to the same value
88    fn join(&self, other: &Self) -> Self {
89        let mut increments = self.increments.clone();
90        let mut decrements = self.decrements.clone();
91
92        // Merge other's increments (take max for each replica)
93        for (k, v) in &other.increments {
94            increments
95                .entry(k.clone())
96                .and_modify(|e| *e = (*e).max(*v))
97                .or_insert(*v);
98        }
99
100        // Merge other's decrements (take max for each replica)
101        for (k, v) in &other.decrements {
102            decrements
103                .entry(k.clone())
104                .and_modify(|e| *e = (*e).max(*v))
105                .or_insert(*v);
106        }
107
108        Self {
109            increments,
110            decrements,
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pncounter_basic_operations() {
        let mut counter = PNCounter::new();

        // Increment from replica "A"
        counter.increment("A", 5);
        assert_eq!(counter.value(), 5);

        // Decrement from replica "B"
        counter.decrement("B", 2);
        assert_eq!(counter.value(), 3);

        // Increment again
        counter.increment("A", 3);
        assert_eq!(counter.value(), 6);
    }

    #[test]
    fn test_pncounter_join_idempotent() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("B", 2);

        let joined = c1.join(&c1);
        // Lattice laws are properties of the whole state, not just of value():
        // two different states can share a value, so compare full states.
        assert_eq!(joined, c1);
        assert_eq!(joined.value(), 3);
    }

    #[test]
    fn test_pncounter_join_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);

        let mut c2 = PNCounter::new();
        c2.increment("B", 3);
        c2.decrement("A", 1);

        let joined1 = c1.join(&c2);
        let joined2 = c2.join(&c1);

        assert_eq!(joined1, joined2);
        assert_eq!(joined1.value(), 7);
        assert_eq!(joined1.get_increment(&"A"), 5);
        assert_eq!(joined1.get_increment(&"B"), 3);
        assert_eq!(joined1.get_decrement(&"A"), 1);
    }

    #[test]
    fn test_pncounter_join_associative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 1);

        let mut c2 = PNCounter::new();
        c2.increment("B", 2);

        let mut c3 = PNCounter::new();
        c3.decrement("C", 1);

        let left = c1.join(&c2).join(&c3);
        let right = c1.join(&c2.join(&c3));

        assert_eq!(left, right);
        assert_eq!(left.value(), 2);
    }

    #[test]
    fn test_pncounter_bottom_is_identity() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        counter.decrement("B", 2);

        let bottom = PNCounter::bottom();

        // Bottom must be a two-sided identity for join.
        assert_eq!(counter.join(&bottom), counter);
        assert_eq!(bottom.join(&counter), counter);
    }

    #[test]
    fn test_pncounter_join_takes_max_not_sum() {
        // Both replicas observed A's first increment of 5; merging must not
        // double-count it.
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        let mut c2 = c1.clone();
        c2.increment("A", 2);
        c2.decrement("A", 1);

        let joined = c1.join(&c2);
        assert_eq!(joined.get_increment(&"A"), 7);
        assert_eq!(joined.get_decrement(&"A"), 1);
        assert_eq!(joined.value(), 6);
    }

    #[test]
    fn test_pncounter_convergence_different_order() {
        let mut c1 = PNCounter::new();
        c1.increment("X", 10);
        c1.decrement("Y", 3);

        let mut c2 = PNCounter::new();
        c2.increment("Z", 5);
        c2.decrement("X", 2);

        // Apply updates in different order
        let mut state1 = PNCounter::bottom();
        state1.join_assign(&c1);
        state1.join_assign(&c2);

        let mut state2 = PNCounter::bottom();
        state2.join_assign(&c2);
        state2.join_assign(&c1);

        assert_eq!(state1, state2);
        assert_eq!(state1.value(), 10);
    }

    #[test]
    fn test_pncounter_serialization() {
        let mut counter = PNCounter::new();
        counter.increment("replica1", 100);
        counter.decrement("replica2", 25);

        let serialized = serde_json::to_string(&counter).unwrap();
        let deserialized: PNCounter<String> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.value(), counter.value());
        assert_eq!(deserialized.get_increment(&"replica1".to_string()), 100);
        assert_eq!(deserialized.get_decrement(&"replica2".to_string()), 25);
    }
}