Skip to main content

abyo_crdt/
counter.rs

1//! Counter CRDT — signed integer with concurrent increment/decrement support.
2//!
3//! Implemented as a delta-based signed counter with version-vector idempotency.
4//! Each `add(delta)` produces an op carrying the delta and a unique [`OpId`];
5//! applying an already-seen op is a no-op, so concurrent and replayed deltas
6//! all resolve to the same final sum.
7//!
8//! Internally also tracks a per-replica positive/negative running sum so the
9//! state is exposed in PN-Counter form for inspection.
10//!
11//! ## Quick start
12//!
13//! ```
14//! use abyo_crdt::Counter;
15//!
16//! let mut alice = Counter::new(1);
17//! let mut bob = Counter::new(2);
18//!
19//! alice.add(5);
20//! bob.add(-2);
21//! bob.add(10);
22//!
23//! alice.merge(&bob);
24//! bob.merge(&alice);
25//!
26//! assert_eq!(alice.value(), 13);
27//! assert_eq!(bob.value(), 13);
28//! ```
29
30use crate::{
31    error::Error,
32    id::{OpId, ReplicaId},
33    version::VersionVector,
34};
35use std::collections::HashMap;
36
37#[cfg(feature = "serde")]
38use serde::{Deserialize, Serialize};
39
40// ---------------------------------------------------------------------------
41// Public op type
42// ---------------------------------------------------------------------------
43
44/// A single [`Counter`] CRDT operation.
45#[derive(Clone, Copy, Debug, PartialEq, Eq)]
46#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
47pub struct CounterOp {
48    /// Unique id of this op (drives idempotency).
49    pub id: OpId,
50    /// Signed delta to apply.
51    pub delta: i64,
52}
53
54impl CounterOp {
55    /// The id of this op.
56    #[must_use]
57    pub fn id(&self) -> OpId {
58        self.id
59    }
60}
61
62// ---------------------------------------------------------------------------
63// Counter CRDT
64// ---------------------------------------------------------------------------
65
66/// PN-Counter CRDT (signed). See the module docs for semantics.
67#[derive(Clone, Debug)]
68#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69pub struct Counter {
70    replica: ReplicaId,
71    clock: u64,
72    /// Per-replica positive cumulative sum (for inspection).
73    p: HashMap<ReplicaId, u128>,
74    /// Per-replica negative cumulative sum (for inspection).
75    n: HashMap<ReplicaId, u128>,
76    /// Cached signed sum, kept in sync with applied ops.
77    value: i128,
78    log: Vec<CounterOp>,
79    version: VersionVector,
80}
81
82impl Counter {
83    /// Create a counter at zero for the given replica.
84    #[must_use]
85    pub fn new(replica: ReplicaId) -> Self {
86        Self {
87            replica,
88            clock: 0,
89            p: HashMap::new(),
90            n: HashMap::new(),
91            value: 0,
92            log: Vec::new(),
93            version: VersionVector::new(),
94        }
95    }
96
97    /// Create a new instance with a random [`ReplicaId`] from OS entropy.
98    /// See [`crate::new_replica_id`].
99    #[must_use]
100    pub fn new_random() -> Self {
101        Self::new(crate::id::new_replica_id())
102    }
103
104    /// This replica's id.
105    #[must_use]
106    pub fn replica_id(&self) -> ReplicaId {
107        self.replica
108    }
109
110    /// Current signed value.
111    #[must_use]
112    pub fn value(&self) -> i128 {
113        self.value
114    }
115
116    /// Sum of all positive contributions across all replicas.
117    #[must_use]
118    pub fn positive_total(&self) -> u128 {
119        self.p.values().sum()
120    }
121
122    /// Sum of all negative contributions across all replicas.
123    #[must_use]
124    pub fn negative_total(&self) -> u128 {
125        self.n.values().sum()
126    }
127
128    /// Apply a signed `delta` to this replica. Returns the generated op.
129    pub fn add(&mut self, delta: i64) -> CounterOp {
130        self.clock = self
131            .clock
132            .checked_add(1)
133            .expect("Lamport clock overflow (>2^64 ops)");
134        let id = OpId::new(self.clock, self.replica);
135        let op = CounterOp { id, delta };
136        self.apply_internal(id, delta);
137        self.version.observe(id);
138        self.log.push(op);
139        op
140    }
141
142    /// Increment by an unsigned amount.
143    pub fn increment(&mut self, by: u64) -> CounterOp {
144        self.add(i64::try_from(by).expect("increment overflow"))
145    }
146
147    /// Decrement by an unsigned amount.
148    pub fn decrement(&mut self, by: u64) -> CounterOp {
149        self.add(-i64::try_from(by).expect("decrement overflow"))
150    }
151
152    /// Apply a remote op. Idempotent.
153    pub fn apply(&mut self, op: CounterOp) -> Result<(), Error> {
154        if self.version.contains(op.id) {
155            return Ok(());
156        }
157        self.apply_internal(op.id, op.delta);
158        self.version.observe(op.id);
159        self.clock = self.clock.max(op.id.counter);
160        self.log.push(op);
161        Ok(())
162    }
163
164    /// Merge all ops from `other` we haven't seen.
165    pub fn merge(&mut self, other: &Self) {
166        let mut to_apply: Vec<&CounterOp> = other
167            .log
168            .iter()
169            .filter(|op| !self.version.contains(op.id))
170            .collect();
171        to_apply.sort_by_key(|op| op.id);
172        for op in to_apply {
173            self.apply(*op).expect("counter apply cannot fail");
174        }
175    }
176
177    /// All ops in this counter's log.
178    #[must_use]
179    pub fn ops(&self) -> &[CounterOp] {
180        &self.log
181    }
182
183    /// Iterate over ops not yet seen by `since`.
184    pub fn ops_since<'a>(
185        &'a self,
186        since: &'a VersionVector,
187    ) -> impl Iterator<Item = &'a CounterOp> + 'a {
188        self.log.iter().filter(move |op| !since.contains(op.id))
189    }
190
191    /// This replica's current version vector.
192    #[must_use]
193    pub fn version(&self) -> &VersionVector {
194        &self.version
195    }
196
197    fn apply_internal(&mut self, id: OpId, delta: i64) {
198        self.value += i128::from(delta);
199        if delta >= 0 {
200            // delta is non-negative; cast to u64 is sign-safe.
201            #[allow(clippy::cast_sign_loss)]
202            let abs = delta as u64;
203            *self.p.entry(id.replica).or_insert(0) += u128::from(abs);
204        } else {
205            *self.n.entry(id.replica).or_insert(0) += u128::from(delta.unsigned_abs());
206        }
207    }
208}
209
210impl Default for Counter {
211    fn default() -> Self {
212        Self::new(0)
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built counter reports zero.
    #[test]
    fn empty_counter_is_zero() {
        assert_eq!(Counter::new(1).value(), 0);
    }

    /// Local deltas of either sign accumulate into the value.
    #[test]
    fn add_and_value() {
        let mut counter = Counter::new(1);
        for delta in [5, -2, 10] {
            counter.add(delta);
        }
        assert_eq!(counter.value(), 13);
    }

    /// The unsigned helpers map onto positive and negative deltas.
    #[test]
    fn increment_decrement_helpers() {
        let mut counter = Counter::new(1);
        counter.increment(10);
        counter.decrement(3);
        assert_eq!(counter.value(), 7);
    }

    /// Merging in either direction yields the sum of both replicas' deltas.
    #[test]
    fn merge_sums_concurrent_deltas() {
        let (mut left, mut right) = (Counter::new(1), Counter::new(2));
        left.add(10);
        right.add(20);

        let mut left_merged = left.clone();
        left_merged.merge(&right);
        let mut right_merged = right.clone();
        right_merged.merge(&left);

        for merged in [&left_merged, &right_merged] {
            assert_eq!(merged.value(), 30);
        }
    }

    /// Re-applying an already-seen op leaves the value untouched.
    #[test]
    fn idempotent_apply() {
        let mut source = Counter::new(1);
        let first = source.add(5);
        let second = source.add(7);

        let mut sink = Counter::new(2);
        // Each op is delivered twice; the second delivery is the duplicate.
        for op in [first, first, second, second] {
            sink.apply(op).unwrap();
        }
        assert_eq!(sink.value(), 12);
    }

    /// Positive and negative totals are tracked separately across replicas.
    #[test]
    fn pn_breakdown() {
        let mut local = Counter::new(1);
        let mut remote = Counter::new(2);
        local.add(10);
        local.add(-3);
        remote.add(20);
        remote.add(-5);

        local.merge(&remote);

        assert_eq!(local.value(), 22);
        assert_eq!(local.positive_total(), 30);
        assert_eq!(local.negative_total(), 8);
    }

    /// `ops_since` skips ops already covered by the given version vector.
    #[test]
    fn ops_since_returns_only_unseen() {
        let mut counter = Counter::new(1);
        counter.add(1);
        let snapshot = counter.version().clone();
        counter.add(2);
        counter.add(3);

        let unseen = counter.ops_since(&snapshot).count();
        assert_eq!(unseen, 2);
    }
}