Skip to main content

atomr_distributed_data/
maps.rs

//! Map-shaped CRDTs, ported from Akka.NET's `ORMap`, `LWWMap`,
//! `PNCounterMap` and `ORMultiMap`.
//!
//! Phase 8 of `docs/full-port-plan.md`. Four flavours of CRDT map:
//!
//! * [`ORMap`] — keys can be added and removed concurrently; each key's
//!   value is itself a CRDT (`V: CrdtMerge`).
//! * [`LWWMap`] — each key maps to a last-write-wins value; the write with
//!   the highest timestamp wins.
//! * [`PNCounterMap`] — keys map to `PNCounter`s; merge is a per-key
//!   PNCounter merge.
//! * [`ORMultiMap`] — keys map to observed-remove sets (`OrSet<V>`); merge is
//!   a per-key set merge.
11
12use std::collections::HashMap;
13use std::hash::Hash;
14
15use serde::{Deserialize, Serialize};
16
17use crate::counters::PNCounter;
18use crate::sets::OrSet;
19use crate::traits::CrdtMerge;
20
21// -- ORMap ---------------------------------------------------------
22
23/// Observed-remove map of K → V (V itself a CRDT).
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ORMap<K, V>
26where
27    K: Eq + Hash + Clone,
28    V: CrdtMerge,
29{
30    entries: HashMap<K, (u64, V)>, // (add-tag, value)
31    tombstones: HashMap<K, u64>,
32    counter: u64,
33}
34
35impl<K: Eq + Hash + Clone, V: CrdtMerge> Default for ORMap<K, V> {
36    fn default() -> Self {
37        Self { entries: HashMap::new(), tombstones: HashMap::new(), counter: 0 }
38    }
39}
40
41impl<K: Eq + Hash + Clone, V: CrdtMerge> ORMap<K, V> {
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Insert or update an entry. Bumps the per-key add-tag so a
47    /// concurrent `remove` (with an older tag) can be merged
48    /// correctly.
49    pub fn put(&mut self, key: K, value: V) {
50        self.counter += 1;
51        self.entries.insert(key, (self.counter, value));
52    }
53
54    /// Update the value for `key` in-place (CRDT merge).
55    pub fn update(&mut self, key: K, value: V) {
56        self.counter += 1;
57        match self.entries.get_mut(&key) {
58            Some((tag, existing)) => {
59                existing.merge(&value);
60                *tag = self.counter;
61            }
62            None => {
63                self.entries.insert(key, (self.counter, value));
64            }
65        }
66    }
67
68    pub fn remove(&mut self, key: &K) {
69        if let Some((tag, _)) = self.entries.get(key) {
70            self.tombstones.insert(key.clone(), *tag);
71        }
72    }
73
74    pub fn get(&self, key: &K) -> Option<&V> {
75        let (add_tag, v) = self.entries.get(key)?;
76        match self.tombstones.get(key) {
77            Some(tomb) if tomb >= add_tag => None,
78            _ => Some(v),
79        }
80    }
81
82    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
83        self.entries.iter().filter_map(|(k, (add, v))| match self.tombstones.get(k) {
84            Some(tomb) if tomb >= add => None,
85            _ => Some((k, v)),
86        })
87    }
88}
89
90impl<K: Eq + Hash + Clone, V: CrdtMerge> CrdtMerge for ORMap<K, V> {
91    fn merge(&mut self, other: &Self) {
92        for (k, (other_tag, other_v)) in &other.entries {
93            match self.entries.get_mut(k) {
94                Some((tag, existing)) => {
95                    existing.merge(other_v);
96                    *tag = (*tag).max(*other_tag);
97                }
98                None => {
99                    self.entries.insert(k.clone(), (*other_tag, other_v.clone()));
100                }
101            }
102        }
103        for (k, t) in &other.tombstones {
104            let cur = self.tombstones.entry(k.clone()).or_insert(0);
105            *cur = (*cur).max(*t);
106        }
107        self.counter = self.counter.max(other.counter);
108    }
109}
110
111// -- LWWMap --------------------------------------------------------
112
113/// Last-write-wins map of K → V.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct LWWMap<K, V>
116where
117    K: Eq + Hash + Clone,
118    V: Clone,
119{
120    entries: HashMap<K, (u128, V)>, // (timestamp, value)
121}
122
123impl<K: Eq + Hash + Clone, V: Clone> Default for LWWMap<K, V> {
124    fn default() -> Self {
125        Self { entries: HashMap::new() }
126    }
127}
128
129impl<K: Eq + Hash + Clone, V: Clone> LWWMap<K, V> {
130    pub fn new() -> Self {
131        Self::default()
132    }
133
134    pub fn put(&mut self, key: K, value: V, timestamp: u128) {
135        match self.entries.get(&key) {
136            Some((ts, _)) if *ts >= timestamp => {} // older write — drop
137            _ => {
138                self.entries.insert(key, (timestamp, value));
139            }
140        }
141    }
142
143    pub fn get(&self, key: &K) -> Option<&V> {
144        self.entries.get(key).map(|(_, v)| v)
145    }
146
147    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
148        self.entries.iter().map(|(k, (_, v))| (k, v))
149    }
150}
151
152impl<K: Eq + Hash + Clone, V: Clone> CrdtMerge for LWWMap<K, V> {
153    fn merge(&mut self, other: &Self) {
154        for (k, (ts, v)) in &other.entries {
155            match self.entries.get(k) {
156                Some((my_ts, _)) if my_ts >= ts => {}
157                _ => {
158                    self.entries.insert(k.clone(), (*ts, v.clone()));
159                }
160            }
161        }
162    }
163}
164
165// -- PNCounterMap --------------------------------------------------
166
167/// Map of K → `PNCounter`.
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct PNCounterMap<K>
170where
171    K: Eq + Hash + Clone,
172{
173    entries: HashMap<K, PNCounter>,
174}
175
176impl<K: Eq + Hash + Clone> Default for PNCounterMap<K> {
177    fn default() -> Self {
178        Self { entries: HashMap::new() }
179    }
180}
181
182impl<K: Eq + Hash + Clone> PNCounterMap<K> {
183    pub fn new() -> Self {
184        Self::default()
185    }
186
187    pub fn increment(&mut self, key: K, node: &str, delta: u64) {
188        self.entries.entry(key).or_default().increment(node, delta);
189    }
190
191    pub fn decrement(&mut self, key: K, node: &str, delta: u64) {
192        self.entries.entry(key).or_default().decrement(node, delta);
193    }
194
195    pub fn value(&self, key: &K) -> i64 {
196        self.entries.get(key).map(|c| c.value()).unwrap_or(0)
197    }
198
199    pub fn iter(&self) -> impl Iterator<Item = (&K, i64)> {
200        self.entries.iter().map(|(k, c)| (k, c.value()))
201    }
202}
203
204impl<K: Eq + Hash + Clone> CrdtMerge for PNCounterMap<K> {
205    fn merge(&mut self, other: &Self) {
206        for (k, v) in &other.entries {
207            self.entries.entry(k.clone()).or_default().merge(v);
208        }
209    }
210}
211
212// -- ORMultiMap --------------------------------------------------
213
214/// Map of K → set-of-V, where the set is itself an `OrSet<V>`.
215/// Akka.DistributedData's `ORMultiMap`. Phase 8.B.
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct ORMultiMap<K, V>
218where
219    K: Eq + Hash + Clone,
220    V: Eq + Hash + Clone,
221{
222    entries: HashMap<K, OrSet<V>>,
223}
224
225impl<K: Eq + Hash + Clone, V: Eq + Hash + Clone> Default for ORMultiMap<K, V> {
226    fn default() -> Self {
227        Self { entries: HashMap::new() }
228    }
229}
230
231impl<K: Eq + Hash + Clone, V: Eq + Hash + Clone> ORMultiMap<K, V> {
232    pub fn new() -> Self {
233        Self::default()
234    }
235
236    pub fn add(&mut self, key: K, value: V) {
237        self.entries.entry(key).or_default().add(value);
238    }
239
240    pub fn remove(&mut self, key: &K, value: &V) {
241        if let Some(set) = self.entries.get_mut(key) {
242            set.remove(value);
243        }
244    }
245
246    pub fn contains(&self, key: &K, value: &V) -> bool {
247        self.entries.get(key).map(|s| s.contains(value)).unwrap_or(false)
248    }
249
250    pub fn key_count(&self) -> usize {
251        self.entries.len()
252    }
253}
254
255impl<K: Eq + Hash + Clone, V: Eq + Hash + Clone> CrdtMerge for ORMultiMap<K, V> {
256    fn merge(&mut self, other: &Self) {
257        for (k, set) in &other.entries {
258            self.entries.entry(k.clone()).or_default().merge(set);
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::counters::GCounter;

    /// An `ORMap` holding a single live entry under `"k"`.
    fn seeded_ormap() -> ORMap<&'static str, GCounter> {
        let mut map = ORMap::new();
        map.put("k", GCounter::new());
        map
    }

    /// An `LWWMap` holding `value` under `"k"`, written at `timestamp`.
    fn lww_with(value: i32, timestamp: u128) -> LWWMap<&'static str, i32> {
        let mut map = LWWMap::new();
        map.put("k", value, timestamp);
        map
    }

    #[test]
    fn ormap_concurrent_put_and_remove_resolves_to_remove() {
        let mut local = seeded_ormap();
        let mut remote = local.clone();
        remote.remove(&"k");

        local.merge(&remote);

        assert!(local.get(&"k").is_none());
    }

    #[test]
    fn ormap_concurrent_re_add_after_remove() {
        let mut local = seeded_ormap();
        let mut remote = local.clone();
        remote.remove(&"k");

        // The re-add on `local` carries a newer tag than `remote`'s
        // tombstone, so it survives the merge.
        local.put("k", GCounter::new());
        local.merge(&remote);

        assert!(local.get(&"k").is_some());
    }

    #[test]
    fn lwwmap_higher_timestamp_wins() {
        // Incoming write is newer: it replaces the local value.
        let mut receiver = lww_with(1, 100);
        receiver.merge(&lww_with(2, 200));
        assert_eq!(receiver.get(&"k"), Some(&2));

        // Incoming write is older: the local value must not be displaced.
        let mut receiver = lww_with(1, 200);
        receiver.merge(&lww_with(2, 100));
        assert_eq!(receiver.get(&"k"), Some(&1));
    }

    #[test]
    fn pncounter_map_per_key_counts() {
        let mut counts: PNCounterMap<&'static str> = PNCounterMap::new();
        counts.increment("alice", "n1", 5);
        counts.increment("bob", "n1", 3);
        counts.decrement("alice", "n1", 2);
        assert_eq!(counts.value(&"alice"), 3);
        assert_eq!(counts.value(&"bob"), 3);

        let mut other: PNCounterMap<&'static str> = PNCounterMap::new();
        other.increment("alice", "n2", 7);
        counts.merge(&other);

        assert_eq!(counts.value(&"alice"), 10);
        assert_eq!(counts.value(&"bob"), 3);
    }
}