Skip to main content

llm_sync/
crdt.rs

1// SPDX-License-Identifier: MIT
2//! CRDT implementations: GCounter, PNCounter, GSet, LWWRegister, ORMap.
3//!
4//! Every merge operation is commutative, associative, and idempotent.
5//! No merge operation ever blocks.
6
7use std::collections::{HashMap, HashSet};
8use serde::{Deserialize, Serialize};
9
10// ── GCounter (grow-only counter) ─────────────────────────────────────────────
11
/// A grow-only counter CRDT.
///
/// Every node owns one slot in `counts`. Replicas are merged by taking the
/// component-wise maximum of the slots, and the aggregate value is the sum
/// of all slots.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct GCounter {
    /// Per-node running totals, keyed by node identifier.
    counts: HashMap<String, u64>,
}
20
21impl GCounter {
22    /// Create a new, empty GCounter.
23    pub fn new() -> Self { Self::default() }
24
25    /// Increment this node's counter by `amount`.
26    ///
27    /// # Arguments
28    /// * `node` — Identifier of the node performing the increment.
29    /// * `amount` — Value to add (must be positive for growth semantics).
30    pub fn increment(&mut self, node: impl Into<String>, amount: u64) {
31        *self.counts.entry(node.into()).or_insert(0) += amount;
32    }
33
34    /// Return the aggregate value across all nodes.
35    pub fn value(&self) -> u64 { self.counts.values().sum() }
36
37    /// Merge with another GCounter, taking the component-wise maximum.
38    ///
39    /// The operation is commutative, associative, and idempotent.
40    pub fn merge(&self, other: &GCounter) -> GCounter {
41        let mut result = self.counts.clone();
42        for (k, &v) in &other.counts {
43            let e = result.entry(k.clone()).or_insert(0);
44            if v > *e { *e = v; }
45        }
46        GCounter { counts: result }
47    }
48}
49
50// ── PNCounter (increment/decrement counter) ──────────────────────────────────
51
/// A positive-negative counter CRDT that supports both increment and decrement.
///
/// Internally composed of two GCounters: one for increments, one for decrements.
/// The net value is the increments total minus the decrements total.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct PNCounter {
    /// Grow-only total of everything ever added.
    increments: GCounter,
    /// Grow-only total of everything ever subtracted.
    decrements: GCounter,
}
60
61impl PNCounter {
62    /// Create a new, zeroed PNCounter.
63    pub fn new() -> Self { Self::default() }
64
65    /// Increment by `amount` for `node`.
66    pub fn increment(&mut self, node: impl Into<String>, amount: u64) {
67        self.increments.increment(node, amount);
68    }
69
70    /// Decrement by `amount` for `node`.
71    pub fn decrement(&mut self, node: impl Into<String>, amount: u64) {
72        self.decrements.increment(node, amount);
73    }
74
75    /// Return the net value (increments − decrements).
76    pub fn value(&self) -> i64 {
77        self.increments.value() as i64 - self.decrements.value() as i64
78    }
79
80    /// Merge with another PNCounter.
81    pub fn merge(&self, other: &PNCounter) -> PNCounter {
82        PNCounter {
83            increments: self.increments.merge(&other.increments),
84            decrements: self.decrements.merge(&other.decrements),
85        }
86    }
87}
88
89// ── GSet (grow-only set) ─────────────────────────────────────────────────────
90
/// A grow-only set CRDT. Elements can be added but never removed.
///
/// Replicas are merged by set union.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct GSet {
    /// Every element inserted so far.
    items: HashSet<String>,
}
96
97impl GSet {
98    /// Create a new, empty GSet.
99    pub fn new() -> Self { Self::default() }
100
101    /// Insert an element.
102    pub fn insert(&mut self, item: impl Into<String>) { self.items.insert(item.into()); }
103
104    /// Return true if the element is present.
105    pub fn contains(&self, item: &str) -> bool { self.items.contains(item) }
106
107    /// Return the number of elements.
108    pub fn len(&self) -> usize { self.items.len() }
109
110    /// Return true if the set is empty.
111    pub fn is_empty(&self) -> bool { self.items.is_empty() }
112
113    /// Merge with another GSet (union).
114    pub fn merge(&self, other: &GSet) -> GSet {
115        GSet { items: self.items.union(&other.items).cloned().collect() }
116    }
117
118    /// Iterate over elements.
119    pub fn iter(&self) -> impl Iterator<Item = &str> {
120        self.items.iter().map(|s| s.as_str())
121    }
122}
123
124// ── LWWRegister (last-write-wins register) ───────────────────────────────────
125
/// A last-write-wins register CRDT.
///
/// Holds at most one value together with the logical timestamp and the
/// writer identifier of the write that produced it.
/// Writes with a higher logical timestamp overwrite earlier writes.
/// On merge, the replica with the higher timestamp wins.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(deserialize = "T: serde::de::DeserializeOwned"))]
pub struct LWWRegister<T: Clone + Serialize + serde::de::DeserializeOwned> {
    /// Currently stored value; `None` until the first write.
    value: Option<T>,
    /// Logical timestamp of the stored write (0 for a fresh register).
    timestamp: u64,
    /// Identifier of the node that performed the stored write (empty for a fresh register).
    writer: String,
}
137
138impl<T: Clone + Serialize + serde::de::DeserializeOwned> LWWRegister<T> {
139    /// Create a new, empty register.
140    pub fn new() -> Self {
141        Self { value: None, timestamp: 0, writer: String::new() }
142    }
143
144    /// Write a value with a logical timestamp.
145    ///
146    /// # Arguments
147    /// * `value` — The value to write.
148    /// * `timestamp` — Logical timestamp; higher values win.
149    /// * `writer` — Identifier of the writing node.
150    pub fn write(&mut self, value: T, timestamp: u64, writer: impl Into<String>) {
151        if timestamp >= self.timestamp {
152            self.value = Some(value);
153            self.timestamp = timestamp;
154            self.writer = writer.into();
155        }
156    }
157
158    /// Read the current value. Returns `None` if never written.
159    pub fn read(&self) -> Option<&T> { self.value.as_ref() }
160
161    /// Merge with another register. The one with the higher timestamp wins.
162    pub fn merge(&self, other: &LWWRegister<T>) -> LWWRegister<T> {
163        if other.timestamp > self.timestamp {
164            other.clone()
165        } else {
166            self.clone()
167        }
168    }
169
170    /// Return the current logical timestamp.
171    pub fn timestamp(&self) -> u64 { self.timestamp }
172}
173
174impl<T: Clone + Serialize + serde::de::DeserializeOwned> Default for LWWRegister<T> {
175    fn default() -> Self { Self::new() }
176}
177
178// ── ORMap (observed-remove map, LWW-semantics) ───────────────────────────────
179
/// An observed-remove map CRDT backed by per-key LWWRegisters.
///
/// Each key owns one `LWWRegister<String>`, so conflicting writes to the
/// same key are resolved with last-write-wins semantics.
///
/// Keys can be added and updated. No removal operation is implemented
/// here: deletion can be modeled via tombstone conventions at the
/// application layer (not enforced here).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ORMap {
    /// One last-write-wins register per key.
    entries: HashMap<String, LWWRegister<String>>,
}
188
189impl ORMap {
190    /// Create a new, empty ORMap.
191    pub fn new() -> Self { Self::default() }
192
193    /// Set a key to a value at the given logical timestamp.
194    ///
195    /// # Arguments
196    /// * `key` — The map key.
197    /// * `value` — The string value to store.
198    /// * `timestamp` — Logical timestamp for LWW resolution.
199    /// * `writer` — Identifier of the writing agent.
200    pub fn set(
201        &mut self,
202        key: impl Into<String>,
203        value: impl Into<String>,
204        timestamp: u64,
205        writer: impl Into<String>,
206    ) {
207        let k = key.into();
208        self.entries.entry(k).or_default().write(value.into(), timestamp, writer);
209    }
210
211    /// Get a value by key. Returns `None` if not set.
212    pub fn get(&self, key: &str) -> Option<&str> {
213        self.entries.get(key)?.read().map(|s| s.as_str())
214    }
215
216    /// Merge with another ORMap.
217    pub fn merge(&self, other: &ORMap) -> ORMap {
218        let mut result = self.entries.clone();
219        for (k, reg) in &other.entries {
220            let entry = result.entry(k.clone()).or_default();
221            *entry = entry.merge(reg);
222        }
223        ORMap { entries: result }
224    }
225
226    /// Return the number of keys.
227    pub fn len(&self) -> usize { self.entries.len() }
228
229    /// Return true if the map is empty.
230    pub fn is_empty(&self) -> bool { self.entries.is_empty() }
231}
232
#[cfg(test)]
mod tests {
    //! Unit tests for the CRDT types.
    //!
    //! Where a type implements `PartialEq`, the algebraic-law tests compare
    //! whole structures rather than only the aggregate value or length, so
    //! per-node and per-element differences cannot hide behind equal totals.

    use super::*;

    // ── GCounter ──────────────────────────────────────────────────────────────

    #[test]
    fn test_gcounter_increment_and_value() {
        let mut c = GCounter::new();
        c.increment("a", 3);
        c.increment("b", 2);
        assert_eq!(c.value(), 5);
    }

    #[test]
    fn test_gcounter_merge_is_commutative() {
        // Overlapping nodes with different winners on each side.
        let mut c1 = GCounter::new();
        c1.increment("a", 5);
        c1.increment("b", 1);
        let mut c2 = GCounter::new();
        c2.increment("a", 2);
        c2.increment("b", 3);
        assert_eq!(c1.merge(&c2), c2.merge(&c1));
        assert_eq!(c1.merge(&c2).value(), 8);
    }

    #[test]
    fn test_gcounter_merge_is_idempotent() {
        let mut c = GCounter::new();
        c.increment("a", 10);
        assert_eq!(c.merge(&c), c);
        assert_eq!(c.merge(&c).value(), c.value());
    }

    #[test]
    fn test_gcounter_merge_is_associative() {
        let mut a = GCounter::new();
        a.increment("x", 1);
        let mut b = GCounter::new();
        b.increment("y", 2);
        b.increment("x", 4);
        let mut c = GCounter::new();
        c.increment("z", 3);
        assert_eq!(a.merge(&b).merge(&c), a.merge(&b.merge(&c)));
        assert_eq!(a.merge(&b).merge(&c).value(), 9);
    }

    #[test]
    fn test_gcounter_merge_takes_max_per_node() {
        let mut c1 = GCounter::new();
        c1.increment("a", 10);
        let mut c2 = GCounter::new();
        c2.increment("a", 5);
        assert_eq!(c1.merge(&c2).value(), 10);
        assert_eq!(c2.merge(&c1).value(), 10);
    }

    #[test]
    fn test_gcounter_new_is_zero() {
        assert_eq!(GCounter::new().value(), 0);
    }

    // ── PNCounter ─────────────────────────────────────────────────────────────

    #[test]
    fn test_pncounter_increment_decrement() {
        let mut c = PNCounter::new();
        c.increment("a", 10);
        c.decrement("a", 3);
        assert_eq!(c.value(), 7);
    }

    #[test]
    fn test_pncounter_merge_is_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("a", 5);
        let mut c2 = PNCounter::new();
        c2.decrement("b", 2);
        assert_eq!(c1.merge(&c2), c2.merge(&c1));
        assert_eq!(c1.merge(&c2).value(), 3);
    }

    #[test]
    fn test_pncounter_merge_is_idempotent() {
        let mut c = PNCounter::new();
        c.increment("x", 7);
        assert_eq!(c.merge(&c), c);
        assert_eq!(c.merge(&c).value(), c.value());
    }

    #[test]
    fn test_pncounter_new_is_zero() {
        assert_eq!(PNCounter::new().value(), 0);
    }

    #[test]
    fn test_pncounter_can_go_negative() {
        let mut c = PNCounter::new();
        c.decrement("a", 5);
        assert_eq!(c.value(), -5);
    }

    // ── GSet ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_gset_insert_and_contains() {
        let mut s = GSet::new();
        s.insert("alpha");
        assert!(s.contains("alpha"));
        assert!(!s.contains("beta"));
    }

    #[test]
    fn test_gset_merge_is_commutative() {
        let mut s1 = GSet::new();
        s1.insert("a");
        let mut s2 = GSet::new();
        s2.insert("b");
        assert_eq!(s1.merge(&s2), s2.merge(&s1));
    }

    #[test]
    fn test_gset_merge_is_idempotent() {
        let mut s = GSet::new();
        s.insert("x");
        assert_eq!(s.merge(&s), s);
    }

    #[test]
    fn test_gset_merge_union() {
        let mut s1 = GSet::new();
        s1.insert("a");
        s1.insert("b");
        let mut s2 = GSet::new();
        s2.insert("b");
        s2.insert("c");
        let merged = s1.merge(&s2);
        assert_eq!(merged.len(), 3);
        assert!(merged.contains("a") && merged.contains("b") && merged.contains("c"));
    }

    #[test]
    fn test_gset_is_empty_on_new() {
        assert!(GSet::new().is_empty());
    }

    #[test]
    fn test_gset_iter_yields_all_elements() {
        let mut s = GSet::new();
        s.insert("a");
        s.insert("b");
        let mut items: Vec<&str> = s.iter().collect();
        // HashSet iteration order is unspecified, so sort before comparing.
        items.sort();
        assert_eq!(items, vec!["a", "b"]);
    }

    // ── LWWRegister ───────────────────────────────────────────────────────────

    #[test]
    fn test_lww_register_write_and_read() {
        let mut r: LWWRegister<String> = LWWRegister::new();
        r.write("hello".into(), 1, "agent-1");
        assert_eq!(r.read().unwrap(), "hello");
    }

    #[test]
    fn test_lww_register_higher_timestamp_wins() {
        let mut r: LWWRegister<String> = LWWRegister::new();
        r.write("first".into(), 1, "a");
        r.write("second".into(), 2, "b");
        assert_eq!(r.read().unwrap(), "second");
        assert_eq!(r.timestamp(), 2);
    }

    #[test]
    fn test_lww_register_lower_timestamp_ignored() {
        let mut r: LWWRegister<String> = LWWRegister::new();
        r.write("latest".into(), 10, "a");
        r.write("old".into(), 5, "b");
        assert_eq!(r.read().unwrap(), "latest");
        assert_eq!(r.timestamp(), 10);
    }

    #[test]
    fn test_lww_register_merge_picks_higher_ts() {
        let mut r1: LWWRegister<String> = LWWRegister::new();
        r1.write("old".into(), 1, "a");
        let mut r2: LWWRegister<String> = LWWRegister::new();
        r2.write("new".into(), 5, "b");
        assert_eq!(r1.merge(&r2).read().unwrap(), "new");
    }

    #[test]
    fn test_lww_register_merge_is_commutative() {
        let mut r1: LWWRegister<String> = LWWRegister::new();
        r1.write("v1".into(), 3, "a");
        let mut r2: LWWRegister<String> = LWWRegister::new();
        r2.write("v2".into(), 7, "b");
        assert_eq!(r1.merge(&r2).read(), r2.merge(&r1).read());
        assert_eq!(r1.merge(&r2).timestamp(), r2.merge(&r1).timestamp());
    }

    #[test]
    fn test_lww_register_new_is_empty() {
        let r: LWWRegister<String> = LWWRegister::new();
        assert!(r.read().is_none());
        assert_eq!(r.timestamp(), 0);
    }

    // ── ORMap ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_ormap_set_and_get() {
        let mut m = ORMap::new();
        m.set("key", "value", 1, "a");
        assert_eq!(m.get("key").unwrap(), "value");
    }

    #[test]
    fn test_ormap_get_missing_key_returns_none() {
        let m = ORMap::new();
        assert!(m.get("absent").is_none());
    }

    #[test]
    fn test_ormap_merge_lww_semantics() {
        let mut m1 = ORMap::new();
        m1.set("k", "v1", 1, "a");
        let mut m2 = ORMap::new();
        m2.set("k", "v2", 2, "b");
        assert_eq!(m1.merge(&m2).get("k").unwrap(), "v2");
        assert_eq!(m2.merge(&m1).get("k").unwrap(), "v2");
    }

    #[test]
    fn test_ormap_merge_is_commutative() {
        let mut m1 = ORMap::new();
        m1.set("x", "val1", 3, "a");
        let mut m2 = ORMap::new();
        m2.set("y", "val2", 1, "b");
        let merged1 = m1.merge(&m2);
        let merged2 = m2.merge(&m1);
        assert_eq!(merged1.get("x"), merged2.get("x"));
        assert_eq!(merged1.get("y"), merged2.get("y"));
        assert_eq!(merged1.len(), merged2.len());
    }

    #[test]
    fn test_ormap_merge_is_idempotent() {
        let mut m = ORMap::new();
        m.set("k", "v", 5, "a");
        let m2 = m.merge(&m);
        assert_eq!(m2.get("k"), m.get("k"));
        assert_eq!(m2.len(), m.len());
    }

    #[test]
    fn test_ormap_is_empty_on_new() {
        assert!(ORMap::new().is_empty());
    }

    #[test]
    fn test_ormap_len_counts_unique_keys() {
        let mut m = ORMap::new();
        m.set("a", "1", 1, "x");
        m.set("b", "2", 2, "x");
        assert_eq!(m.len(), 2);
    }
}