causal_length/
map.rs

1use super::*;
2use crate::register::Register;
3use std::borrow::Borrow;
4use std::cmp::max;
5use std::collections::HashMap;
6
7/// Causal Length Map
8///
9/// A CRDT map based on an adaptation of the causal length set.
10///
11/// `Map` uses the tag for garbage collection of old removed members, and to
12/// resolve conflicting values for the same key and causal length.
13#[derive(Clone, Debug, Default, Eq, PartialEq)]
14pub struct Map<K, V, Tag, CL>
15where
16    K: Key + Ord,
17    V: Value + Hash + Eq + Ord,
18    Tag: TagT,
19    CL: CausalLength,
20{
21    map: HashMap<K, Register<V, Tag, CL>>,
22}
23
24impl<K, V, Tag, CL> Map<K, V, Tag, CL>
25where
26    K: Key + Ord,
27    V: Value + Hash + Eq + Ord,
28    Tag: TagT,
29    CL: CausalLength,
30{
31    /// Create an empty `Map`
32    pub fn new() -> Map<K, V, Tag, CL> {
33        Map {
34            map: HashMap::new(),
35        }
36    }
37
38    /// Returns a reference to the value and tag corresponding to the key.
39    pub fn get<Q>(&self, key: Q) -> Option<(&V, Tag)>
40    where
41        Q: Borrow<K>,
42    {
43        if let Some(e) = self.map.get(key.borrow()) {
44            if e.length.is_odd() {
45                return Some((&e.item, e.tag));
46            }
47        }
48        None
49    }
50
51    /// Returns true if the map contains a value for the specified key.
52    pub fn contains<Q>(&self, key: Q) -> bool
53    where
54        Q: Borrow<K>,
55    {
56        self.get(key).is_some()
57    }
58
59    /// Inserts a key, value, and tag into the map.
60    ///
61    /// If the map did not have this key present, [`None`] is returned.
62    ///
63    /// If the map did have this key present, the value is updated, and the old
64    /// value is returned, along with the old tag.
65    pub fn insert(&mut self, key: K, value: V, tag: Tag) -> Option<(V, Tag)> {
66        let one: CL = CL::one();
67        let e = self.map.entry(key);
68        match e {
69            std::collections::hash_map::Entry::Occupied(mut oe) => {
70                let oe = oe.get_mut();
71                // s{e |-> s(e)+1} if even
72                //s if odd s(e)
73                if oe.length.is_even() {
74                    oe.length = oe.length + one;
75                } else if oe.item != value {
76                    // Special adaptation for a map: we add two to the causal length
77                    // in cases where the key exists, but the value is not the same.
78                    // This is equivalent to removing and re-adding the key.
79                    oe.length = oe.length + one + one;
80                }
81                // always use the max value of tag
82                oe.tag = max(oe.tag, tag);
83                let r = oe.item.clone();
84                oe.item = value;
85                Some((r, oe.tag))
86            }
87            _ => {
88                e.or_insert_with(|| Register::make(value, tag, one));
89                None
90            }
91        }
92    }
93
94    /// Remove a key from the map, returning the stored value and tag if
95    /// the key was in the map.
96    pub fn remove(&mut self, key: K, tag: Tag) -> Option<(V, Tag)> {
97        let e = self.map.entry(key);
98        match e {
99            std::collections::hash_map::Entry::Occupied(mut oe) => {
100                let oe = oe.get_mut();
101                oe.tag = max(oe.tag, tag);
102
103                // {} if even(s(e))
104                // { e |-> s(e) + 1 } if odd(s(e))
105                if oe.length.is_odd() {
106                    oe.length = oe.length + CL::one();
107                    Some((oe.item.clone(), oe.tag))
108                } else {
109                    None
110                }
111            }
112            _ => None,
113        }
114        // ignore attempts to remove items that aren't present...
115    }
116
117    /// An iterator visiting all key, value, tag tuples in arbitrary order.
118    pub fn iter(&self) -> impl Iterator<Item = (K, V, Tag)> + '_ {
119        self.map
120            .iter()
121            .filter(|(_k, v)| v.length.is_odd())
122            .map(|(k, v)| (k.clone(), v.item.clone(), v.tag))
123    }
124
125    /// An iterator visiting all delta registers in arbitrary order.
126    pub fn register_iter(&self) -> impl Iterator<Item = Register<(K, V), Tag, CL>> + '_ {
127        self.map
128            .iter()
129            .map(|(k, v)| Register::make((k.clone(), v.item.clone()), v.tag, v.length))
130    }
131
132    /// Merge a delta [Register] into a map.
133    ///
134    /// Remove deltas with a tag value less than `min_tag` will be ignored.
135    pub fn merge_register(&mut self, delta: Register<(K, V), Tag, CL>, min_tag: Tag) {
136        if delta.length.is_even() && delta.tag < min_tag {
137            // ignore excessively old remove records
138            return;
139        }
140
141        match self.map.entry(delta.item.0.clone()) {
142            Entry::Occupied(mut e) => {
143                let e = e.get_mut();
144
145                let reg = Register::make(delta.item.1.clone(), delta.tag, delta.length);
146                e.merge(&reg);
147            }
148            Entry::Vacant(e) => {
149                e.insert(Register::make(delta.item.1, delta.tag, delta.length));
150            }
151        }
152    }
153
154    /// Merge two maps.
155    ///
156    /// Remove deltas with a tag value less than `min_tag` will be ignored.
157    pub fn merge(&mut self, other: &Self, min_tag: Tag) {
158        for delta in other.register_iter() {
159            self.merge_register(delta, min_tag);
160        }
161    }
162
163    /// Filter out old remove tombstone deltas from the map.
164    ///
165    /// Remove deltas with a tag value less than `min_tag` will be removed.
166    pub fn retain(&mut self, min_tag: Tag) {
167        self.map
168            .retain(|_k, v| v.length.is_odd() || min_tag < v.tag);
169    }
170}
171
#[cfg(feature = "serialization")]
mod serialization {
    //! Serde support for [`Map`].
    //!
    //! A map is written as a sequence of
    //! `(key, value, tag, causal length)` tuples, one per register, including
    //! the registers of removed keys.

    use super::*;
    use serde::de::{SeqAccess, Visitor};
    use serde::ser::SerializeSeq;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::fmt::Formatter;
    use std::marker::PhantomData;

    /// Upper bound on the number of entries pre-allocated from a
    /// deserializer's size hint. The hint may come straight from the input,
    /// so it is not trusted for large up-front allocations; the map still
    /// grows as needed while elements are read.
    const MAX_PREALLOCATED_ENTRIES: usize = 4096;

    impl<K, V, Tag, CL> Serialize for Map<K, V, Tag, CL>
    where
        K: Key + Ord + Serialize,
        V: Value + Hash + Ord + Serialize,
        Tag: TagT + Serialize,
        CL: CausalLength + Serialize,
    {
        /// Serializes every register as a `(key, value, tag, length)` tuple.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut seq = serializer.serialize_seq(Some(self.map.len()))?;
            for (key, reg) in &self.map {
                // A tuple of references serializes exactly like a tuple of
                // values, so no key or value needs to be cloned.
                seq.serialize_element(&(key, &reg.item, &reg.tag, &reg.length))?;
            }
            seq.end()
        }
    }

    /// Visitor that rebuilds the register table from a sequence of tuples.
    struct DeltaVisitor<K, V, Tag, CL>(
        PhantomData<K>,
        PhantomData<V>,
        PhantomData<Tag>,
        PhantomData<CL>,
    );

    impl<'de, K, V, Tag, CL> Visitor<'de> for DeltaVisitor<K, V, Tag, CL>
    where
        K: Key + Ord + Deserialize<'de>,
        V: Value + Hash + Ord + Deserialize<'de>,
        Tag: TagT + Deserialize<'de>,
        CL: CausalLength + Deserialize<'de>,
    {
        type Value = HashMap<K, Register<V, Tag, CL>>;

        fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
            formatter.write_str("a sequence of (key, value, tag, causal length) tuples")
        }

        /// Reads tuples until the sequence ends. If a key appears more than
        /// once, the last tuple read for it replaces the earlier ones.
        fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let capacity = seq
                .size_hint()
                .map_or(0, |hint| hint.min(MAX_PREALLOCATED_ENTRIES));
            let mut map: HashMap<K, Register<V, Tag, CL>> = HashMap::with_capacity(capacity);
            while let Some((key, value, tag, length)) = seq.next_element::<(K, V, Tag, CL)>()? {
                map.insert(key, Register::make(value, tag, length));
            }
            Ok(map)
        }
    }

    impl<'de, K, V, Tag, CL> Deserialize<'de> for Map<K, V, Tag, CL>
    where
        K: Key + Ord + Deserialize<'de>,
        V: Value + Hash + Ord + Deserialize<'de>,
        Tag: TagT + Deserialize<'de>,
        CL: CausalLength + Deserialize<'de>,
    {
        /// Deserializes a map from a sequence of
        /// `(key, value, tag, causal length)` tuples.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor =
                DeltaVisitor::<K, V, Tag, CL>(PhantomData, PhantomData, PhantomData, PhantomData);
            let map = deserializer.deserialize_seq(visitor)?;

            Ok(Map { map })
        }
    }
}
252
253impl<K, V, Tag, CL> From<Set<(K, V), Tag, CL>> for Map<K, V, Tag, CL>
254where
255    K: Key + Ord,
256    V: Value + Hash + Eq + Ord,
257    Tag: TagT,
258    CL: CausalLength,
259{
260    fn from(s: Set<(K, V), Tag, CL>) -> Self {
261        let mut m = Self::new();
262        for item in s.register_iter() {
263            m.merge_register(item, Tag::default());
264        }
265        m
266    }
267}
268
269impl<K, V, Tag, CL> From<Map<K, V, Tag, CL>> for Set<(K, V), Tag, CL>
270where
271    K: Key + Ord,
272    V: Value + Hash + Eq + Ord,
273    Tag: TagT,
274    CL: CausalLength,
275{
276    fn from(m: Map<K, V, Tag, CL>) -> Self {
277        let mut s = Self::new();
278        for item in m.register_iter() {
279            s.merge_register(item, Tag::default());
280        }
281        s
282    }
283}
284
285impl<K, V, Tag, CL> From<Map<K, V, Tag, CL>> for HashMap<K, (V, Tag)>
286where
287    K: Key + Ord,
288    V: Value + Hash + Eq + Ord,
289    Tag: TagT,
290    CL: CausalLength,
291{
292    fn from(m: Map<K, V, Tag, CL>) -> Self {
293        let mut h = Self::new();
294        for item in m.register_iter() {
295            if let Some(((k, v), tag)) = item.get() {
296                h.insert(k.clone(), (v.clone(), tag));
297            }
298        }
299        h
300    }
301}
302
303#[cfg(feature = "serialization")]
304pub use serialization::*;
305use std::collections::hash_map::Entry;
306
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck_macros::quickcheck;
    use rand::seq::SliceRandom;

    /// Re-inserting the same value must not change the causal length.
    #[test]
    fn test_add() {
        let later_time = 1;
        let mut cls: Map<&str, bool, u16, u16> = Map::new();

        cls.insert("foo", true, later_time);
        cls.insert("foo", true, later_time);
        cls.insert("foo", true, later_time);
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get("foo"),
            Some(&Register {
                item: true,
                tag: later_time,
                length: 1
            })
        );
        assert!(cls.contains("foo"));
        assert_eq!(cls.get("bar"), None);
    }

    /// Removed keys stay as tombstones and disappear from `iter`.
    #[test]
    fn test_remove() {
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls: Map<&str, bool, u32, u16> = Map::new();

        cls.insert("foo", true, time_1);
        cls.insert("bar", false, time_1);
        cls.remove("foo", time_2);
        cls.remove("bar", time_2);
        cls.insert("bar", true, time_3);
        // check map
        assert_eq!(cls.map.len(), 2);
        assert_eq!(
            cls.map.get("bar"),
            Some(&Register {
                item: true,
                tag: time_3,
                length: 3
            })
        );
        assert_eq!(
            cls.map.get("foo"),
            Some(&Register {
                item: true,
                tag: time_2,
                length: 2
            })
        );
        // check edges
        let values: Vec<(&str, bool, u32)> = cls.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], ("bar", true, time_3));
    }

    /// Merging a map that removed keys propagates those removals.
    #[test]
    fn test_merge() {
        let time_0 = 0;
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls1: Map<&str, u32, u32, u16> = Map::new();
        let mut cls2: Map<&str, u32, u32, u16> = Map::new();

        cls1.insert("foo", 128, time_1);
        cls1.insert("bar", 256, time_1);
        cls2.merge(&cls1, time_0);
        cls2.insert("foo", 128, time_2);
        cls1.remove("foo", time_2);
        cls1.remove("bar", time_2);
        cls2.merge(&cls1, time_0);
        cls2.insert("bar", 256, time_3);

        assert_eq!(cls2.map.len(), 2);
        assert_eq!(
            cls2.map.get(&"bar"),
            Some(&Register {
                item: 256,
                tag: time_3,
                length: 3
            })
        );
        assert_eq!(
            cls2.map.get(&"foo"),
            Some(&Register {
                item: 128,
                tag: time_2,
                length: 2
            })
        );

        let values: Vec<(&str, u32, u32)> = cls2.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], ("bar", 256, time_3));
    }

    /// `retain` drops old tombstones and keeps present keys.
    #[test]
    fn test_retain() {
        let time_0 = 0;
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls: Map<&str, u32, u32, u16> = Map::new();

        cls.insert("foo", 128, time_0);
        cls.insert("bar", 256, time_0);
        cls.remove("foo", time_1);
        cls.remove("bar", time_1);
        cls.insert("bar", 256, time_2);
        // check map
        assert_eq!(cls.map.len(), 2);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&Register {
                item: 256,
                tag: time_2,
                length: 3
            })
        );
        assert_eq!(
            cls.map.get(&"foo"),
            Some(&Register {
                item: 128,
                tag: time_1,
                length: 2
            })
        );
        // check edges
        let values: Vec<(&str, u32, u32)> = cls.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], ("bar", 256, time_2));
        // now clear old removes
        cls.retain(time_3);
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&Register {
                item: 256,
                tag: time_2,
                length: 3
            })
        );
        // attempt to merge an out of date remove
        cls.merge_register(
            Register {
                item: ("bar", 512),
                tag: time_2,
                length: 2,
            },
            time_0,
        );
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&Register {
                item: 256,
                tag: time_2,
                length: 3
            })
        );
    }

    /// Overwriting a present key with a different value adds two to the
    /// causal length.
    #[test]
    fn test_overwrite() {
        let time_0 = 0;
        let time_1 = 1;
        let time_2 = 2;
        let mut cls: Map<&str, u32, u32, u16> = Map::new();

        cls.insert("bar", 256, time_0);
        cls.insert("bar", 256, time_1);
        // now try an overwrite
        cls.insert("bar", 512, time_2);
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&Register {
                item: 512,
                tag: time_2,
                length: 3
            })
        );
    }

    /// A map survives a JSON round trip unchanged.
    #[cfg(feature = "serialization")]
    #[test]
    fn test_serialization() {
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut m: Map<&str, bool, u32, u16> = Map::new();

        m.insert("foo", true, time_1);
        m.insert("bar", false, time_1);
        m.remove("foo", time_2);
        m.remove("bar", time_2);
        m.insert("bar", true, time_3);

        // Fail right here if serialization breaks, instead of handing an
        // empty string to the deserializer and reporting a parse error.
        let data = serde_json::to_string(&m).expect("map should serialize to JSON");
        let cls2: Map<&str, bool, u32, u16> =
            serde_json::from_str(&data).expect("serialized map should deserialize");
        assert_eq!(m.map, cls2.map);
    }

    /// Builds 1000 registers for the key "foo" whose value, tag, and causal
    /// length all equal their index, in a random order.
    fn shuffled_foo_registers() -> Vec<Register<(&'static str, usize), u32, u16>> {
        let mut registers: Vec<Register<(&'static str, usize), u32, u16>> = (0u16..1000)
            .map(|i| Register {
                item: ("foo", usize::from(i)),
                tag: u32::from(i),
                length: i,
            })
            .collect();
        registers.shuffle(&mut rand::thread_rng());
        registers
    }

    /// Merging the same registers in two different random orders converges
    /// to the same state.
    #[test]
    fn test_order_independence() {
        let mut m1: Map<&str, usize, u32, u16> = Map::new();
        let mut m2: Map<&str, usize, u32, u16> = Map::new();

        for r in shuffled_foo_registers() {
            m1.merge_register(r, 0);
        }
        assert_eq!(
            m1.map.get("foo"),
            Some(&Register {
                item: 999,
                tag: 999,
                length: 999
            })
        );

        for r in shuffled_foo_registers() {
            m2.merge_register(r, 0);
        }
        assert_eq!(m1, m2);
    }

    /// Fold helper: merges one register into the accumulator map.
    fn merge(mut acc: Map<u8, u8, u8, u8>, el: &Register<(u8, u8), u8, u8>) -> Map<u8, u8, u8, u8> {
        acc.merge_register(el.clone(), 0);
        acc
    }

    #[quickcheck]
    fn test_merge_commutative(xs: Vec<Register<(u8, u8), u8, u8>>) -> bool {
        let left: HashMap<u8, (u8, u8)> = xs.iter().fold(Map::default(), merge).into();
        let right: HashMap<u8, (u8, u8)> = xs.iter().rfold(Map::default(), merge).into();
        left == right
    }

    #[quickcheck]
    fn is_merge_order_independent(xs: Vec<Register<(u8, u8), u8, u8>>) -> bool {
        let mut copy = xs.clone();
        copy.shuffle(&mut rand::thread_rng());
        let left: HashMap<u8, (u8, u8)> = xs.iter().fold(Map::default(), merge).into();
        let right: HashMap<u8, (u8, u8)> = copy.iter().rfold(Map::default(), merge).into();
        left == right
    }

    /// Checks `Map` against a plain `HashMap` under random operations.
    mod simple_model {
        use super::*;
        use quickcheck::{Arbitrary, Gen};
        #[derive(Clone, Debug)]
        enum Op {
            Insert(u8, u8),
            Get(u8),
            Delete(u8),
        }

        const KEY_SPACE: u8 = 20;

        impl Arbitrary for Op {
            fn arbitrary(g: &mut Gen) -> Op {
                let k: u8 = u8::arbitrary(g) % KEY_SPACE;
                let v: u8 = u8::arbitrary(g);
                let n: u8 = u8::arbitrary(g) % 4;

                // n is in 0..4: one part inserts, one part deletes, and the
                // remaining two parts are reads.
                match n {
                    0 => Op::Insert(k, v),
                    1 => Op::Delete(k),
                    _ => Op::Get(k),
                }
            }
        }

        #[quickcheck]
        fn implementation_matches_model(ops: Vec<Op>) -> bool {
            let mut implementation: Map<u8, u8, u8, u8> = Map::new();
            let mut model = std::collections::HashMap::new();

            for op in ops {
                match op {
                    Op::Insert(k, v) => {
                        implementation.insert(k, v, 0);
                        model.insert(k, v);
                    }
                    Op::Get(k) => {
                        if implementation.get(&k).map(|i| i.0) != model.get(&k) {
                            return false;
                        }
                    }
                    Op::Delete(k) => {
                        implementation.remove(k, 0);
                        model.remove(&k);
                    }
                }
            }

            true
        }
    }
}