1use ahash::AHashMap;
7use serde::{Deserialize, Serialize};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct VectorClock {
15 clock: AHashMap<String, u64>,
16}
17
18impl VectorClock {
19 pub fn new() -> Self {
21 Self {
22 clock: AHashMap::new(),
23 }
24 }
25
26 pub fn increment(&mut self, node_id: &str) {
28 let counter = self.clock.entry(node_id.to_string()).or_insert(0);
29 *counter = counter.saturating_add(1);
30 }
31
32 pub fn get(&self, node_id: &str) -> u64 {
34 self.clock.get(node_id).copied().unwrap_or(0)
35 }
36
37 pub fn merge(&mut self, other: &VectorClock) {
39 for (node_id, &other_count) in &other.clock {
40 let count = self.clock.entry(node_id.clone()).or_insert(0);
41 *count = (*count).max(other_count);
42 }
43 }
44
45 pub fn compare(&self, other: &VectorClock) -> ClockOrdering {
47 let mut less = false;
48 let mut greater = false;
49
50 for (node_id, &self_count) in &self.clock {
52 let other_count = other.get(node_id);
53 match self_count.cmp(&other_count) {
54 Ordering::Less => less = true,
55 Ordering::Greater => greater = true,
56 Ordering::Equal => {}
57 }
58 }
59
60 for node_id in other.clock.keys() {
62 if !self.clock.contains_key(node_id) {
63 less = true;
64 }
65 }
66
67 match (less, greater) {
68 (true, false) => ClockOrdering::Before,
69 (false, true) => ClockOrdering::After,
70 (false, false) => ClockOrdering::Equal,
71 (true, true) => ClockOrdering::Concurrent,
72 }
73 }
74
75 pub fn is_concurrent(&self, other: &VectorClock) -> bool {
77 matches!(self.compare(other), ClockOrdering::Concurrent)
78 }
79
80 pub fn happens_before(&self, other: &VectorClock) -> bool {
82 matches!(self.compare(other), ClockOrdering::Before)
83 }
84}
85
86impl Default for VectorClock {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
/// Result of [`VectorClock::compare`], read as "`self` is … `other`".
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClockOrdering {
    /// `self` causally precedes `other`: no counter is larger and at least one is smaller.
    Before,
    /// `self` causally follows `other`: no counter is smaller and at least one is larger.
    After,
    /// Every counter matches.
    Equal,
    /// Each clock is ahead of the other on at least one node.
    Concurrent,
}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct LwwRegister<T> {
108 value: T,
109 timestamp: VectorClock,
110 logical_time: u64,
111 node_id: String,
112}
113
114impl<T: Clone> LwwRegister<T> {
115 pub fn new(value: T, node_id: String) -> Self {
117 let mut timestamp = VectorClock::new();
118 timestamp.increment(&node_id);
119 Self {
120 value,
121 timestamp,
122 logical_time: 1,
123 node_id,
124 }
125 }
126
127 pub fn value(&self) -> &T {
129 &self.value
130 }
131
132 pub fn update(&mut self, value: T) {
134 self.value = value;
135 self.timestamp.increment(&self.node_id);
136 self.logical_time += 1;
137 }
138
139 pub fn merge(&mut self, other: &LwwRegister<T>) {
141 match self.timestamp.compare(&other.timestamp) {
142 ClockOrdering::Before => {
143 self.value = other.value.clone();
144 self.timestamp = other.timestamp.clone();
145 self.logical_time = other.logical_time;
146 }
147 ClockOrdering::Concurrent => {
148 let should_adopt_other = other.logical_time > self.logical_time
151 || (other.logical_time == self.logical_time && self.node_id < other.node_id);
152
153 if should_adopt_other {
154 self.value = other.value.clone();
155 self.timestamp = other.timestamp.clone();
156 self.logical_time = other.logical_time;
157 }
158 }
159 ClockOrdering::After | ClockOrdering::Equal => {
160 }
162 }
163 }
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct GSet<T: Eq + std::hash::Hash> {
169 elements: HashSet<T>,
170}
171
172impl<T: Eq + std::hash::Hash> GSet<T> {
173 pub fn new() -> Self {
175 Self {
176 elements: HashSet::new(),
177 }
178 }
179
180 pub fn insert(&mut self, element: T) {
182 self.elements.insert(element);
183 }
184
185 pub fn contains(&self, element: &T) -> bool {
187 self.elements.contains(element)
188 }
189
190 pub fn len(&self) -> usize {
192 self.elements.len()
193 }
194
195 pub fn is_empty(&self) -> bool {
197 self.elements.is_empty()
198 }
199
200 pub fn merge(&mut self, other: &GSet<T>)
202 where
203 T: Clone,
204 {
205 for element in &other.elements {
206 self.elements.insert(element.clone());
207 }
208 }
209}
210
211impl<T: Eq + std::hash::Hash> Default for GSet<T> {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct TwoPhaseSet<T: Eq + std::hash::Hash + Clone> {
220 added: HashSet<T>,
221 removed: HashSet<T>,
222}
223
224impl<T: Eq + std::hash::Hash + Clone> TwoPhaseSet<T> {
225 pub fn new() -> Self {
227 Self {
228 added: HashSet::new(),
229 removed: HashSet::new(),
230 }
231 }
232
233 pub fn insert(&mut self, element: T) {
235 if !self.removed.contains(&element) {
236 self.added.insert(element);
237 }
238 }
239
240 pub fn remove(&mut self, element: &T) -> bool {
242 if self.added.contains(element) {
243 self.removed.insert(element.clone());
244 true
245 } else {
246 false
247 }
248 }
249
250 pub fn contains(&self, element: &T) -> bool {
252 self.added.contains(element) && !self.removed.contains(element)
253 }
254
255 pub fn elements(&self) -> impl Iterator<Item = &T> {
257 self.added.iter().filter(|e| !self.removed.contains(e))
258 }
259
260 pub fn len(&self) -> usize {
262 self.elements().count()
263 }
264
265 pub fn is_empty(&self) -> bool {
267 self.len() == 0
268 }
269
270 pub fn merge(&mut self, other: &TwoPhaseSet<T>) {
272 for element in &other.added {
273 self.added.insert(element.clone());
274 }
275 for element in &other.removed {
276 self.removed.insert(element.clone());
277 }
278 }
279}
280
281impl<T: Eq + std::hash::Hash + Clone> Default for TwoPhaseSet<T> {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287pub type CrdtSet<T> = TwoPhaseSet<T>;
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CrdtMap<K, V>
293where
294 K: Eq + std::hash::Hash + Clone,
295 V: Clone,
296{
297 entries: HashMap<K, LwwRegister<V>>,
298 node_id: String,
299}
300
301impl<K, V> CrdtMap<K, V>
302where
303 K: Eq + std::hash::Hash + Clone,
304 V: Clone,
305{
306 pub fn new(node_id: String) -> Self {
308 Self {
309 entries: HashMap::new(),
310 node_id,
311 }
312 }
313
314 pub fn insert(&mut self, key: K, value: V) {
316 if let Some(register) = self.entries.get_mut(&key) {
317 register.update(value);
318 } else {
319 let register = LwwRegister::new(value, self.node_id.clone());
320 self.entries.insert(key, register);
321 }
322 }
323
324 pub fn get(&self, key: &K) -> Option<&V> {
326 self.entries.get(key).map(|r| r.value())
327 }
328
329 pub fn contains_key(&self, key: &K) -> bool {
331 self.entries.contains_key(key)
332 }
333
334 pub fn len(&self) -> usize {
336 self.entries.len()
337 }
338
339 pub fn is_empty(&self) -> bool {
341 self.entries.is_empty()
342 }
343
344 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
346 self.entries.iter().map(|(k, v)| (k, v.value()))
347 }
348
349 pub fn merge(&mut self, other: &CrdtMap<K, V>) {
351 for (key, other_register) in &other.entries {
352 if let Some(register) = self.entries.get_mut(key) {
353 register.merge(other_register);
354 } else {
355 self.entries.insert(key.clone(), other_register.clone());
356 }
357 }
358 }
359}
360
361pub struct ConflictResolver {
363 node_id: String,
364}
365
366impl ConflictResolver {
367 pub fn new(node_id: String) -> Self {
369 Self { node_id }
370 }
371
372 pub fn create_map<K, V>(&self) -> CrdtMap<K, V>
374 where
375 K: Eq + std::hash::Hash + Clone,
376 V: Clone,
377 {
378 CrdtMap::new(self.node_id.clone())
379 }
380
381 pub fn create_set<T: Eq + std::hash::Hash + Clone>(&self) -> CrdtSet<T> {
383 CrdtSet::new()
384 }
385
386 pub fn resolve_lww<T: Clone>(
388 &self,
389 local: &T,
390 local_clock: &VectorClock,
391 remote: &T,
392 remote_clock: &VectorClock,
393 ) -> T {
394 match local_clock.compare(remote_clock) {
395 ClockOrdering::Before => remote.clone(),
396 ClockOrdering::After | ClockOrdering::Equal => local.clone(),
397 ClockOrdering::Concurrent => {
398 if self.node_id.as_str() < "remote" {
400 remote.clone()
401 } else {
402 local.clone()
403 }
404 }
405 }
406 }
407
408 pub fn node_id(&self) -> &str {
410 &self.node_id
411 }
412}
413
414impl fmt::Display for VectorClock {
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416 write!(f, "{{")?;
417 for (i, (node, count)) in self.clock.iter().enumerate() {
418 if i > 0 {
419 write!(f, ", ")?;
420 }
421 write!(f, "{}: {}", node, count)?;
422 }
423 write!(f, "}}")
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_clock_increment() {
        let mut clock = VectorClock::new();
        clock.increment("node1");
        clock.increment("node1");
        clock.increment("node2");

        assert_eq!(clock.get("node1"), 2);
        assert_eq!(clock.get("node2"), 1);
        assert_eq!(clock.get("node3"), 0);
    }

    #[test]
    fn test_vector_clock_merge() {
        let mut clock1 = VectorClock::new();
        clock1.increment("node1");
        clock1.increment("node1");

        let mut clock2 = VectorClock::new();
        clock2.increment("node2");

        clock1.merge(&clock2);
        assert_eq!(clock1.get("node1"), 2);
        assert_eq!(clock1.get("node2"), 1);
    }

    #[test]
    fn test_vector_clock_compare() {
        let mut clock1 = VectorClock::new();
        clock1.increment("node1");

        let mut clock2 = VectorClock::new();
        clock2.increment("node1");
        clock2.increment("node1");

        assert_eq!(clock1.compare(&clock2), ClockOrdering::Before);
        assert_eq!(clock2.compare(&clock1), ClockOrdering::After);

        let mut clock3 = VectorClock::new();
        clock3.increment("node2");

        assert_eq!(clock1.compare(&clock3), ClockOrdering::Concurrent);
    }

    #[test]
    fn test_lww_register() {
        let mut reg1 = LwwRegister::new(42, "node1".to_string());
        let mut reg2 = LwwRegister::new(100, "node2".to_string());

        reg1.update(50);
        reg2.merge(&reg1);

        assert_eq!(*reg2.value(), 50);
    }

    #[test]
    fn test_gset() {
        let mut set1 = GSet::new();
        set1.insert(1);
        set1.insert(2);

        let mut set2 = GSet::new();
        set2.insert(2);
        set2.insert(3);

        set1.merge(&set2);

        assert_eq!(set1.len(), 3);
        assert!(set1.contains(&1));
        assert!(set1.contains(&2));
        assert!(set1.contains(&3));
    }

    #[test]
    fn test_two_phase_set() {
        let mut set = TwoPhaseSet::new();
        set.insert(1);
        set.insert(2);
        set.insert(3);

        assert_eq!(set.len(), 3);
        assert!(set.contains(&2));

        set.remove(&2);
        assert_eq!(set.len(), 2);
        assert!(!set.contains(&2));

        // 2P-Set semantics: once tombstoned, an element cannot be re-added.
        set.insert(2);
        assert!(!set.contains(&2));
    }

    #[test]
    fn test_two_phase_set_merge() {
        let mut set1 = TwoPhaseSet::new();
        set1.insert(1);
        set1.insert(2);

        let mut set2 = TwoPhaseSet::new();
        set2.insert(2);
        set2.insert(3);
        set2.remove(&2);

        set1.merge(&set2);

        assert!(set1.contains(&1));
        // The remote tombstone for 2 wins over the local add.
        assert!(!set1.contains(&2));
        assert!(set1.contains(&3));
    }

    #[test]
    fn test_crdt_map() {
        let mut map1 = CrdtMap::new("node1".to_string());
        map1.insert("key1", 100);
        map1.insert("key2", 200);

        let mut map2 = CrdtMap::new("node2".to_string());
        map2.insert("key2", 250);
        map2.insert("key3", 300);

        map1.merge(&map2);

        assert_eq!(map1.get(&"key1"), Some(&100));
        assert!(map1.contains_key(&"key2"));
        assert_eq!(map1.get(&"key3"), Some(&300));
    }

    #[test]
    fn test_conflict_resolver() {
        let resolver = ConflictResolver::new("node1".to_string());
        assert_eq!(resolver.node_id(), "node1");

        let map: CrdtMap<String, i32> = resolver.create_map();
        assert!(map.is_empty());

        let set: CrdtSet<i32> = resolver.create_set();
        assert!(set.is_empty());
    }
}