1use crate::lattice::Lattice;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeMap;
13
14#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub struct PNCounter<K: Ord + Clone> {
20 increments: BTreeMap<K, u64>,
22 decrements: BTreeMap<K, u64>,
24}
25
26impl<K: Ord + Clone> PNCounter<K> {
27 pub fn new() -> Self {
29 Self {
30 increments: BTreeMap::new(),
31 decrements: BTreeMap::new(),
32 }
33 }
34
35 pub fn increment(&mut self, replica_id: K, amount: u64) {
37 let entry = self.increments.entry(replica_id).or_insert(0);
38 *entry = entry.saturating_add(amount);
39 }
40
41 pub fn decrement(&mut self, replica_id: K, amount: u64) {
43 let entry = self.decrements.entry(replica_id).or_insert(0);
44 *entry = entry.saturating_add(amount);
45 }
46
47 pub fn value(&self) -> i64 {
49 let inc_sum: u64 = self.increments.values().sum();
50 let dec_sum: u64 = self.decrements.values().sum();
51 (inc_sum as i64).saturating_sub(dec_sum as i64)
52 }
53
54 pub fn get_increment(&self, replica_id: &K) -> u64 {
56 self.increments.get(replica_id).copied().unwrap_or(0)
57 }
58
59 pub fn get_decrement(&self, replica_id: &K) -> u64 {
61 self.decrements.get(replica_id).copied().unwrap_or(0)
62 }
63
64 pub fn increments(&self) -> &BTreeMap<K, u64> {
66 &self.increments
67 }
68
69 pub fn decrements(&self) -> &BTreeMap<K, u64> {
71 &self.decrements
72 }
73}
74
75impl<K: Ord + Clone> Default for PNCounter<K> {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl<K: Ord + Clone> Lattice for PNCounter<K> {
82 fn bottom() -> Self {
83 Self::new()
84 }
85
86 fn join(&self, other: &Self) -> Self {
89 let mut increments = self.increments.clone();
90 let mut decrements = self.decrements.clone();
91
92 for (k, v) in &other.increments {
94 increments
95 .entry(k.clone())
96 .and_modify(|e| *e = (*e).max(*v))
97 .or_insert(*v);
98 }
99
100 for (k, v) in &other.decrements {
102 decrements
103 .entry(k.clone())
104 .and_modify(|e| *e = (*e).max(*v))
105 .or_insert(*v);
106 }
107
108 Self {
109 increments,
110 decrements,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pncounter_basic_operations() {
        let mut counter = PNCounter::new();

        counter.increment("A", 5);
        assert_eq!(counter.value(), 5);

        counter.decrement("B", 2);
        assert_eq!(counter.value(), 3);

        counter.increment("A", 3);
        assert_eq!(counter.value(), 6);

        // Per-replica totals accumulate, and replicas never seen read as zero.
        assert_eq!(counter.get_increment(&"A"), 8);
        assert_eq!(counter.get_decrement(&"B"), 2);
        assert_eq!(counter.get_increment(&"missing"), 0);
        assert_eq!(counter.get_decrement(&"missing"), 0);
    }

    #[test]
    fn test_pncounter_join_idempotent() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("B", 2);

        let joined = c1.join(&c1);
        // Compare full state, not just the value: different states can share a value.
        assert_eq!(joined, c1);
        assert_eq!(joined.value(), 3);
    }

    #[test]
    fn test_pncounter_join_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);

        let mut c2 = PNCounter::new();
        c2.increment("B", 3);
        c2.decrement("A", 1);

        let joined1 = c1.join(&c2);
        let joined2 = c2.join(&c1);

        assert_eq!(joined1, joined2);
        assert_eq!(joined1.value(), joined2.value());
        assert_eq!(joined1.get_increment(&"A"), 5);
        assert_eq!(joined1.get_increment(&"B"), 3);
        assert_eq!(joined1.get_decrement(&"A"), 1);
    }

    #[test]
    fn test_pncounter_join_associative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 1);

        let mut c2 = PNCounter::new();
        c2.increment("B", 2);

        let mut c3 = PNCounter::new();
        c3.decrement("C", 1);

        let left = c1.join(&c2).join(&c3);
        let right = c1.join(&c2.join(&c3));

        assert_eq!(left, right);
        assert_eq!(left.value(), 2);
    }

    #[test]
    fn test_pncounter_bottom_is_identity() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        counter.decrement("B", 2);

        let bottom = PNCounter::bottom();

        // Bottom must be a two-sided identity for join.
        assert_eq!(counter.join(&bottom), counter);
        assert_eq!(bottom.join(&counter), counter);
    }

    #[test]
    fn test_pncounter_join_takes_per_replica_max() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("A", 1);

        let mut c2 = PNCounter::new();
        c2.increment("A", 3);
        c2.decrement("A", 4);

        // Join keeps the larger total per replica and per direction; it does not add them.
        let joined = c1.join(&c2);
        assert_eq!(joined.get_increment(&"A"), 5);
        assert_eq!(joined.get_decrement(&"A"), 4);
        assert_eq!(joined.value(), 1);
    }

    #[test]
    fn test_pncounter_convergence_different_order() {
        let mut c1 = PNCounter::new();
        c1.increment("X", 10);
        c1.decrement("Y", 3);

        let mut c2 = PNCounter::new();
        c2.increment("Z", 5);
        c2.decrement("X", 2);

        let mut state1 = PNCounter::bottom();
        state1.join_assign(&c1);
        state1.join_assign(&c2);

        let mut state2 = PNCounter::bottom();
        state2.join_assign(&c2);
        state2.join_assign(&c1);

        assert_eq!(state1, state2);
        assert_eq!(state1.value(), 10);
    }

    #[test]
    fn test_pncounter_serialization() {
        let mut counter = PNCounter::new();
        counter.increment("replica1", 100);
        counter.decrement("replica2", 25);

        let serialized = serde_json::to_string(&counter).unwrap();
        let deserialized: PNCounter<String> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.value(), counter.value());
        assert_eq!(deserialized.get_increment(&"replica1".to_string()), 100);
        assert_eq!(deserialized.get_decrement(&"replica2".to_string()), 25);
    }
}