1use crate::lattice::Lattice;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::collections::BTreeMap;
12
13#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
16pub struct Dot {
17 pub replica_id: String,
18 pub seq: u64,
19}
20
21impl Dot {
22 pub fn new(replica_id: impl Into<String>, seq: u64) -> Self {
23 Self {
24 replica_id: replica_id.into(),
25 seq,
26 }
27 }
28}
29
30#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
32pub struct CausalContext {
33 dots: std::collections::BTreeSet<Dot>,
35}
36
37impl CausalContext {
38 pub fn new() -> Self {
39 Self {
40 dots: std::collections::BTreeSet::new(),
41 }
42 }
43
44 pub fn add_dot(&mut self, dot: Dot) {
45 self.dots.insert(dot);
46 }
47
48 pub fn contains(&self, dot: &Dot) -> bool {
49 self.dots.contains(dot)
50 }
51
52 pub fn join(&self, other: &CausalContext) -> CausalContext {
53 let mut joined = self.clone();
54 for dot in &other.dots {
55 joined.add_dot(dot.clone());
56 }
57 joined
58 }
59}
60
61impl Default for CausalContext {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
70pub enum MapValue {
71 Int(i64),
72 Text(String),
73 Bytes(Vec<u8>),
74 }
77
78#[derive(Clone, Debug, PartialEq, Eq)]
84pub struct CRDTMap<K: Ord + Clone> {
85 entries: BTreeMap<K, BTreeMap<Dot, MapValue>>,
87 context: CausalContext,
89 local_seq: u64,
91}
92
93impl<K: Ord + Clone + Serialize> Serialize for CRDTMap<K> {
95 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
96 where
97 S: Serializer,
98 {
99 #[derive(Serialize)]
101 struct SerializableCRDTMap<'a, K: Ord + Clone + Serialize> {
102 entries: Vec<(&'a K, Vec<(&'a Dot, &'a MapValue)>)>,
103 context: &'a CausalContext,
104 }
105
106 let entries: Vec<_> = self
107 .entries
108 .iter()
109 .map(|(k, v)| (k, v.iter().collect::<Vec<_>>()))
110 .collect();
111
112 let serializable = SerializableCRDTMap {
113 entries,
114 context: &self.context,
115 };
116
117 serializable.serialize(serializer)
118 }
119}
120
121impl<'de, K: Ord + Clone + Deserialize<'de>> Deserialize<'de> for CRDTMap<K> {
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: Deserializer<'de>,
125 {
126 #[derive(Deserialize)]
127 struct DeserializableCRDTMap<K: Ord + Clone> {
128 entries: Vec<(K, Vec<(Dot, MapValue)>)>,
129 context: CausalContext,
130 }
131
132 let deserialized = DeserializableCRDTMap::<K>::deserialize(deserializer)?;
133
134 let entries: BTreeMap<K, BTreeMap<Dot, MapValue>> = deserialized
135 .entries
136 .into_iter()
137 .map(|(k, v)| (k, v.into_iter().collect()))
138 .collect();
139
140 Ok(Self {
141 entries,
142 context: deserialized.context,
143 local_seq: 0,
144 })
145 }
146}
147
148impl<K: Ord + Clone> CRDTMap<K> {
149 pub fn new() -> Self {
151 Self {
152 entries: BTreeMap::new(),
153 context: CausalContext::new(),
154 local_seq: 0,
155 }
156 }
157
158 pub fn put(&mut self, replica_id: &str, key: K, value: MapValue) -> Dot {
160 let dot = Dot::new(replica_id, self.local_seq);
161 self.local_seq += 1;
162
163 let entry = self.entries.entry(key).or_default();
165
166 entry.clear();
168 entry.insert(dot.clone(), value);
169
170 self.context.add_dot(dot.clone());
172
173 dot
174 }
175
176 pub fn get(&self, key: &K) -> Option<&MapValue> {
179 self.entries
180 .get(key)
181 .and_then(|entry| entry.values().next())
182 }
183
184 pub fn get_all(&self, key: &K) -> Vec<&MapValue> {
186 self.entries
187 .get(key)
188 .map(|entry| entry.values().collect())
189 .unwrap_or_default()
190 }
191
192 pub fn remove(&mut self, key: &K) {
194 if let Some(entry) = self.entries.get_mut(key) {
195 entry.clear();
197 }
198 }
199
200 pub fn contains_key(&self, key: &K) -> bool {
202 self.entries
203 .get(key)
204 .map(|entry| !entry.is_empty())
205 .unwrap_or(false)
206 }
207
208 pub fn keys(&self) -> impl Iterator<Item = &K> {
210 self.entries
211 .iter()
212 .filter_map(|(k, v)| if !v.is_empty() { Some(k) } else { None })
213 }
214
215 pub fn context(&self) -> &CausalContext {
217 &self.context
218 }
219
220 pub fn put_with_dot(&mut self, key: K, dot: Dot, value: MapValue) {
222 let entry = self.entries.entry(key).or_default();
223 entry.insert(dot.clone(), value);
224 self.context.add_dot(dot);
225 }
226}
227
228impl<K: Ord + Clone> Default for CRDTMap<K> {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234impl<K: Ord + Clone> Lattice for CRDTMap<K> {
235 fn bottom() -> Self {
236 Self::new()
237 }
238
239 fn join(&self, other: &Self) -> Self {
242 let mut entries = self.entries.clone();
243 let mut context = self.context.clone();
244
245 for (key, other_entry) in &other.entries {
247 let entry = entries.entry(key.clone()).or_default();
248 for (dot, value) in other_entry {
249 entry.insert(dot.clone(), value.clone());
250 }
251 }
252
253 context = context.join(&other.context);
255
256 Self {
257 entries,
258 context,
259 local_seq: self.local_seq.max(other.local_seq),
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `String` key from a literal.
    fn key(name: &str) -> String {
        String::from(name)
    }

    /// Builds a `MapValue::Text` from a literal.
    fn text(s: &str) -> MapValue {
        MapValue::Text(String::from(s))
    }

    #[test]
    fn test_map_basic_operations() {
        let mut map = CRDTMap::<String>::new();

        map.put("replica1", key("key1"), MapValue::Int(42));
        assert_eq!(map.get(&key("key1")), Some(&MapValue::Int(42)));

        map.put("replica1", key("key2"), text("hello"));
        assert_eq!(map.get(&key("key2")), Some(&text("hello")));
    }

    #[test]
    fn test_map_remove() {
        let mut map = CRDTMap::<String>::new();
        let k = key("key1");

        map.put("replica1", k.clone(), MapValue::Int(42));
        assert!(map.contains_key(&k));

        map.remove(&k);
        assert!(!map.contains_key(&k));
    }

    #[test]
    fn test_map_join_idempotent() {
        let mut map = CRDTMap::<String>::new();
        map.put("replica1", key("key1"), MapValue::Int(42));

        let joined = map.join(&map);
        assert_eq!(joined.get(&key("key1")), Some(&MapValue::Int(42)));
    }

    #[test]
    fn test_map_join_commutative() {
        let mut map1 = CRDTMap::<String>::new();
        map1.put("replica1", key("key1"), MapValue::Int(42));

        let mut map2 = CRDTMap::<String>::new();
        map2.put("replica2", key("key2"), text("world"));

        // Both merge orders must expose the same values.
        for joined in [map1.join(&map2), map2.join(&map1)] {
            assert_eq!(joined.get(&key("key1")), Some(&MapValue::Int(42)));
            assert_eq!(joined.get(&key("key2")), Some(&text("world")));
        }
    }

    #[test]
    fn test_map_join_associative() {
        let mut map1 = CRDTMap::<String>::new();
        map1.put("replica1", key("key1"), MapValue::Int(1));

        let mut map2 = CRDTMap::<String>::new();
        map2.put("replica2", key("key2"), MapValue::Int(2));

        let mut map3 = CRDTMap::<String>::new();
        map3.put("replica3", key("key3"), MapValue::Int(3));

        let left = map1.join(&map2).join(&map3);
        let right = map1.join(&map2.join(&map3));

        // Either grouping must expose all three writes.
        for merged in [&left, &right] {
            for (name, n) in [("key1", 1), ("key2", 2), ("key3", 3)] {
                assert_eq!(merged.get(&key(name)), Some(&MapValue::Int(n)));
            }
        }
    }

    #[test]
    fn test_map_concurrent_writes_different_keys() {
        let mut map1 = CRDTMap::<String>::new();
        map1.put("replica1", key("key1"), MapValue::Int(10));

        let mut map2 = CRDTMap::<String>::new();
        map2.put("replica2", key("key2"), MapValue::Int(20));

        let merged = map1.join(&map2);
        assert_eq!(merged.get(&key("key1")), Some(&MapValue::Int(10)));
        assert_eq!(merged.get(&key("key2")), Some(&MapValue::Int(20)));
    }

    #[test]
    fn test_map_serialization() {
        let mut map = CRDTMap::<String>::new();
        map.put("replica1", key("key1"), MapValue::Int(42));
        map.put("replica1", key("key2"), text("hello"));

        let json = serde_json::to_string(&map).unwrap();
        let restored: CRDTMap<String> = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.get(&key("key1")), Some(&MapValue::Int(42)));
        assert_eq!(restored.get(&key("key2")), Some(&text("hello")));
    }
}