1use crate::{
31 error::Error,
32 id::{OpId, ReplicaId},
33 version::VersionVector,
34};
35use std::collections::HashMap;
36
37#[cfg(feature = "serde")]
38use serde::{Deserialize, Serialize};
39
/// A single counter operation: a signed change to the counter's value,
/// tagged with the id that identifies it across replicas.
///
/// Produced by [`Counter::add`] (and its `increment`/`decrement` wrappers)
/// and replayed on other replicas through [`Counter::apply`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CounterOp {
    /// Id of this op, built by `OpId::new` from the originating replica's
    /// Lamport clock value and its replica id.
    pub id: OpId,
    /// Signed amount added to the counter's value when the op is applied.
    pub delta: i64,
}
53
54impl CounterOp {
55 #[must_use]
57 pub fn id(&self) -> OpId {
58 self.id
59 }
60}
61
/// A replicated counter that tracks increments and decrements per replica.
///
/// Local changes are made with [`Counter::add`] and friends; ops from other
/// replicas are ingested with [`Counter::apply`] or [`Counter::merge`].
/// Already-applied op ids are skipped, so redelivery is harmless.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Counter {
    /// Replica id stamped on ops generated locally.
    replica: ReplicaId,
    /// Lamport clock: incremented for each local op and advanced to the
    /// largest op counter seen in `apply`.
    clock: u64,
    /// Per-replica sum of the non-negative deltas applied.
    p: HashMap<ReplicaId, u128>,
    /// Per-replica sum of the magnitudes of the negative deltas applied.
    n: HashMap<ReplicaId, u128>,
    /// Cached running total of every applied delta; `apply_internal` keeps it
    /// equal to `sum(p) - sum(n)`.
    value: i128,
    /// Every applied op (local and remote), in the order it was applied.
    log: Vec<CounterOp>,
    /// Ids of applied ops; consulted to skip ops that were already applied.
    version: VersionVector,
}
81
82impl Counter {
83 #[must_use]
85 pub fn new(replica: ReplicaId) -> Self {
86 Self {
87 replica,
88 clock: 0,
89 p: HashMap::new(),
90 n: HashMap::new(),
91 value: 0,
92 log: Vec::new(),
93 version: VersionVector::new(),
94 }
95 }
96
97 #[must_use]
100 pub fn new_random() -> Self {
101 Self::new(crate::id::new_replica_id())
102 }
103
104 #[must_use]
106 pub fn replica_id(&self) -> ReplicaId {
107 self.replica
108 }
109
110 #[must_use]
112 pub fn value(&self) -> i128 {
113 self.value
114 }
115
116 #[must_use]
118 pub fn positive_total(&self) -> u128 {
119 self.p.values().sum()
120 }
121
122 #[must_use]
124 pub fn negative_total(&self) -> u128 {
125 self.n.values().sum()
126 }
127
128 pub fn add(&mut self, delta: i64) -> CounterOp {
130 self.clock = self
131 .clock
132 .checked_add(1)
133 .expect("Lamport clock overflow (>2^64 ops)");
134 let id = OpId::new(self.clock, self.replica);
135 let op = CounterOp { id, delta };
136 self.apply_internal(id, delta);
137 self.version.observe(id);
138 self.log.push(op);
139 op
140 }
141
142 pub fn increment(&mut self, by: u64) -> CounterOp {
144 self.add(i64::try_from(by).expect("increment overflow"))
145 }
146
147 pub fn decrement(&mut self, by: u64) -> CounterOp {
149 self.add(-i64::try_from(by).expect("decrement overflow"))
150 }
151
152 pub fn apply(&mut self, op: CounterOp) -> Result<(), Error> {
154 if self.version.contains(op.id) {
155 return Ok(());
156 }
157 self.apply_internal(op.id, op.delta);
158 self.version.observe(op.id);
159 self.clock = self.clock.max(op.id.counter);
160 self.log.push(op);
161 Ok(())
162 }
163
164 pub fn merge(&mut self, other: &Self) {
166 let mut to_apply: Vec<&CounterOp> = other
167 .log
168 .iter()
169 .filter(|op| !self.version.contains(op.id))
170 .collect();
171 to_apply.sort_by_key(|op| op.id);
172 for op in to_apply {
173 self.apply(*op).expect("counter apply cannot fail");
174 }
175 }
176
177 #[must_use]
179 pub fn ops(&self) -> &[CounterOp] {
180 &self.log
181 }
182
183 pub fn ops_since<'a>(
185 &'a self,
186 since: &'a VersionVector,
187 ) -> impl Iterator<Item = &'a CounterOp> + 'a {
188 self.log.iter().filter(move |op| !since.contains(op.id))
189 }
190
191 #[must_use]
193 pub fn version(&self) -> &VersionVector {
194 &self.version
195 }
196
197 fn apply_internal(&mut self, id: OpId, delta: i64) {
198 self.value += i128::from(delta);
199 if delta >= 0 {
200 #[allow(clippy::cast_sign_loss)]
202 let abs = delta as u64;
203 *self.p.entry(id.replica).or_insert(0) += u128::from(abs);
204 } else {
205 *self.n.entry(id.replica).or_insert(0) += u128::from(delta.unsigned_abs());
206 }
207 }
208}
209
210impl Default for Counter {
211 fn default() -> Self {
212 Self::new(0)
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_counter_is_zero() {
        let c = Counter::new(1);
        assert_eq!(c.value(), 0);
    }

    #[test]
    fn add_and_value() {
        let mut c = Counter::new(1);
        c.add(5);
        c.add(-2);
        c.add(10);
        assert_eq!(c.value(), 13);
    }

    #[test]
    fn increment_decrement_helpers() {
        let mut c = Counter::new(1);
        c.increment(10);
        c.decrement(3);
        assert_eq!(c.value(), 7);
    }

    #[test]
    fn merge_sums_concurrent_deltas() {
        let mut a = Counter::new(1);
        let mut b = Counter::new(2);
        a.add(10);
        b.add(20);
        let mut a2 = a.clone();
        a2.merge(&b);
        let mut b2 = b.clone();
        b2.merge(&a);
        assert_eq!(a2.value(), 30);
        assert_eq!(b2.value(), 30);
    }

    #[test]
    fn idempotent_apply() {
        let mut a = Counter::new(1);
        let op1 = a.add(5);
        let op2 = a.add(7);
        let mut b = Counter::new(2);
        // Each op is delivered twice; the duplicates must be ignored.
        b.apply(op1).unwrap();
        b.apply(op1).unwrap();
        b.apply(op2).unwrap();
        b.apply(op2).unwrap();
        assert_eq!(b.value(), 12);
    }

    #[test]
    fn merge_is_idempotent() {
        let mut a = Counter::new(1);
        let mut b = Counter::new(2);
        a.add(4);
        b.add(-9);
        a.merge(&b);
        a.merge(&b);
        assert_eq!(a.value(), -5);
        assert_eq!(a.ops().len(), 2);
    }

    #[test]
    fn merge_relays_ops_from_third_replica() {
        let mut a = Counter::new(1);
        let mut b = Counter::new(2);
        let mut c = Counter::new(3);
        a.add(3);
        b.merge(&a);
        b.add(4);
        // `c` never talks to `a` directly; `a`'s op arrives via `b`'s log.
        c.merge(&b);
        assert_eq!(c.value(), 7);
        assert_eq!(c.ops().len(), 2);
    }

    #[test]
    fn lamport_clock_advances_past_observed_ops() {
        let mut a = Counter::new(1);
        let mut b = Counter::new(2);
        a.add(1);
        a.add(1);
        a.add(1);
        b.merge(&a);
        let op = b.add(1);
        assert_eq!(op.id().counter, 4);
    }

    #[test]
    fn pn_breakdown() {
        let mut a = Counter::new(1);
        let mut b = Counter::new(2);
        a.add(10);
        a.add(-3);
        b.add(20);
        b.add(-5);
        a.merge(&b);
        assert_eq!(a.value(), 22);
        assert_eq!(a.positive_total(), 30);
        assert_eq!(a.negative_total(), 8);
    }

    #[test]
    fn ops_since_returns_only_unseen() {
        let mut a = Counter::new(1);
        a.add(1);
        let v1 = a.version().clone();
        a.add(2);
        a.add(3);
        let new: Vec<&CounterOp> = a.ops_since(&v1).collect();
        assert_eq!(new.len(), 2);
    }
}