1use crate::node::NodeId;
9use crate::vector_clock::VectorClock;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
/// A state-based replicated data type.
///
/// Implementors supply `merge`, which folds another replica's state into
/// `self`; `merged` is a non-mutating convenience built on top of it.
pub trait CRDT: Clone {
    /// Folds `other`'s state into `self` in place.
    fn merge(&mut self, other: &Self);

    /// Returns a fresh value equal to `self` merged with `other`.
    ///
    /// Neither `self` nor `other` is modified.
    fn merged(&self, other: &Self) -> Self {
        let mut combined = Clone::clone(self);
        combined.merge(other);
        combined
    }
}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36pub struct GCounter {
37 counts: HashMap<String, u64>,
38}
39
40impl GCounter {
41 pub fn new() -> Self {
43 Self {
44 counts: HashMap::new(),
45 }
46 }
47
48 pub fn increment(&mut self, node_id: &NodeId) {
50 let key = node_id.as_str().to_string();
51 *self.counts.entry(key).or_insert(0) += 1;
52 }
53
54 pub fn increment_by(&mut self, node_id: &NodeId, amount: u64) {
56 let key = node_id.as_str().to_string();
57 *self.counts.entry(key).or_insert(0) += amount;
58 }
59
60 pub fn value(&self) -> u64 {
62 self.counts.values().sum()
63 }
64
65 pub fn node_value(&self, node_id: &NodeId) -> u64 {
67 self.counts.get(node_id.as_str()).copied().unwrap_or(0)
68 }
69}
70
71impl CRDT for GCounter {
72 fn merge(&mut self, other: &Self) {
73 for (node, &value) in &other.counts {
74 let current = self.counts.entry(node.clone()).or_insert(0);
75 *current = (*current).max(value);
76 }
77 }
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct PNCounter {
87 positive: GCounter,
88 negative: GCounter,
89}
90
91impl PNCounter {
92 pub fn new() -> Self {
94 Self {
95 positive: GCounter::new(),
96 negative: GCounter::new(),
97 }
98 }
99
100 pub fn increment(&mut self, node_id: &NodeId) {
102 self.positive.increment(node_id);
103 }
104
105 pub fn increment_by(&mut self, node_id: &NodeId, amount: u64) {
107 self.positive.increment_by(node_id, amount);
108 }
109
110 pub fn decrement(&mut self, node_id: &NodeId) {
112 self.negative.increment(node_id);
113 }
114
115 pub fn decrement_by(&mut self, node_id: &NodeId, amount: u64) {
117 self.negative.increment_by(node_id, amount);
118 }
119
120 pub fn value(&self) -> i64 {
122 self.positive.value() as i64 - self.negative.value() as i64
123 }
124}
125
126impl CRDT for PNCounter {
127 fn merge(&mut self, other: &Self) {
128 self.positive.merge(&other.positive);
129 self.negative.merge(&other.negative);
130 }
131}
132
133#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct GSet<T: Clone + Eq + std::hash::Hash + Serialize> {
140 elements: HashSet<T>,
141}
142
143impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> GSet<T> {
144 pub fn new() -> Self {
146 Self {
147 elements: HashSet::new(),
148 }
149 }
150
151 pub fn add(&mut self, element: T) {
153 self.elements.insert(element);
154 }
155
156 pub fn contains(&self, element: &T) -> bool {
158 self.elements.contains(element)
159 }
160
161 pub fn len(&self) -> usize {
163 self.elements.len()
164 }
165
166 pub fn is_empty(&self) -> bool {
168 self.elements.is_empty()
169 }
170
171 pub fn iter(&self) -> impl Iterator<Item = &T> {
173 self.elements.iter()
174 }
175
176 pub fn to_vec(&self) -> Vec<T> {
178 self.elements.iter().cloned().collect()
179 }
180}
181
182impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for GSet<T> {
183 fn merge(&mut self, other: &Self) {
184 for element in &other.elements {
185 self.elements.insert(element.clone());
186 }
187 }
188}
189
190#[derive(Debug, Clone, Default, Serialize, Deserialize)]
196pub struct TwoPSet<T: Clone + Eq + std::hash::Hash + Serialize> {
197 added: HashSet<T>,
198 removed: HashSet<T>,
199}
200
201impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> TwoPSet<T> {
202 pub fn new() -> Self {
204 Self {
205 added: HashSet::new(),
206 removed: HashSet::new(),
207 }
208 }
209
210 pub fn add(&mut self, element: T) {
212 if !self.removed.contains(&element) {
213 self.added.insert(element);
214 }
215 }
216
217 pub fn remove(&mut self, element: T) {
219 if self.added.contains(&element) {
220 self.removed.insert(element);
221 }
222 }
223
224 pub fn contains(&self, element: &T) -> bool {
226 self.added.contains(element) && !self.removed.contains(element)
227 }
228
229 pub fn elements(&self) -> HashSet<T> {
231 self.added.difference(&self.removed).cloned().collect()
232 }
233
234 pub fn len(&self) -> usize {
236 self.elements().len()
237 }
238
239 pub fn is_empty(&self) -> bool {
241 self.elements().is_empty()
242 }
243}
244
245impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for TwoPSet<T> {
246 fn merge(&mut self, other: &Self) {
247 for element in &other.added {
248 self.added.insert(element.clone());
249 }
250 for element in &other.removed {
251 self.removed.insert(element.clone());
252 }
253 }
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct LWWRegister<T: Clone + Serialize> {
263 value: Option<T>,
264 timestamp: u64,
265 node_id: String,
266}
267
268impl<T: Clone + Serialize + for<'de> Deserialize<'de>> LWWRegister<T> {
269 pub fn new() -> Self {
271 Self {
272 value: None,
273 timestamp: 0,
274 node_id: String::new(),
275 }
276 }
277
278 pub fn with_value(value: T, node_id: &NodeId) -> Self {
280 Self {
281 value: Some(value),
282 timestamp: Self::now(),
283 node_id: node_id.as_str().to_string(),
284 }
285 }
286
287 fn now() -> u64 {
289 std::time::SystemTime::now()
290 .duration_since(std::time::UNIX_EPOCH)
291 .unwrap_or_default()
292 .as_nanos() as u64
293 }
294
295 pub fn set(&mut self, value: T, node_id: &NodeId) {
297 let ts = Self::now();
298 if ts > self.timestamp || (ts == self.timestamp && node_id.as_str() > self.node_id.as_str())
299 {
300 self.value = Some(value);
301 self.timestamp = ts;
302 self.node_id = node_id.as_str().to_string();
303 }
304 }
305
306 pub fn set_with_timestamp(&mut self, value: T, timestamp: u64, node_id: &NodeId) {
308 if timestamp > self.timestamp
309 || (timestamp == self.timestamp && node_id.as_str() > self.node_id.as_str())
310 {
311 self.value = Some(value);
312 self.timestamp = timestamp;
313 self.node_id = node_id.as_str().to_string();
314 }
315 }
316
317 pub fn get(&self) -> Option<&T> {
319 self.value.as_ref()
320 }
321
322 pub fn timestamp(&self) -> u64 {
324 self.timestamp
325 }
326
327 pub fn is_set(&self) -> bool {
329 self.value.is_some()
330 }
331}
332
333impl<T: Clone + Serialize + for<'de> Deserialize<'de>> Default for LWWRegister<T> {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339impl<T: Clone + Serialize + for<'de> Deserialize<'de>> CRDT for LWWRegister<T> {
340 fn merge(&mut self, other: &Self) {
341 if other.timestamp > self.timestamp
342 || (other.timestamp == self.timestamp && other.node_id > self.node_id)
343 {
344 self.value = other.value.clone();
345 self.timestamp = other.timestamp;
346 self.node_id = other.node_id.clone();
347 }
348 }
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
357pub struct MVRegister<T: Clone + Eq + std::hash::Hash + Serialize> {
358 values: HashMap<T, VectorClock>,
359}
360
361impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> MVRegister<T> {
362 pub fn new() -> Self {
364 Self {
365 values: HashMap::new(),
366 }
367 }
368
369 pub fn set(&mut self, value: T, clock: VectorClock) {
371 self.values.retain(|_, v| !v.happened_before(&clock));
373
374 let dominated = self.values.values().any(|v| clock.happened_before(v));
376 if !dominated {
377 self.values.insert(value, clock);
378 }
379 }
380
381 pub fn get(&self) -> Vec<&T> {
383 self.values.keys().collect()
384 }
385
386 pub fn get_with_clocks(&self) -> Vec<(&T, &VectorClock)> {
388 self.values.iter().collect()
389 }
390
391 pub fn has_conflict(&self) -> bool {
393 self.values.len() > 1
394 }
395
396 pub fn value_count(&self) -> usize {
398 self.values.len()
399 }
400}
401
402impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> Default
403 for MVRegister<T>
404{
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
410impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT
411 for MVRegister<T>
412{
413 fn merge(&mut self, other: &Self) {
414 for (value, clock) in &other.values {
415 self.set(value.clone(), clock.clone());
416 }
417 }
418}
419
420#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
426pub struct UniqueTag {
427 node_id: String,
428 counter: u64,
429}
430
431impl UniqueTag {
432 pub fn new(node_id: &NodeId, counter: u64) -> Self {
434 Self {
435 node_id: node_id.as_str().to_string(),
436 counter,
437 }
438 }
439}
440
441#[derive(Debug, Clone, Default, Serialize, Deserialize)]
443pub struct ORSet<T: Clone + Eq + std::hash::Hash + Serialize> {
444 elements: HashMap<T, HashSet<UniqueTag>>,
445 counters: HashMap<String, u64>,
446}
447
448impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> ORSet<T> {
449 pub fn new() -> Self {
451 Self {
452 elements: HashMap::new(),
453 counters: HashMap::new(),
454 }
455 }
456
457 pub fn add(&mut self, element: T, node_id: &NodeId) {
459 let counter = self
460 .counters
461 .entry(node_id.as_str().to_string())
462 .or_insert(0);
463 *counter += 1;
464
465 let tag = UniqueTag::new(node_id, *counter);
466 self.elements.entry(element).or_default().insert(tag);
467 }
468
469 pub fn remove(&mut self, element: &T) {
471 self.elements.remove(element);
472 }
473
474 pub fn contains(&self, element: &T) -> bool {
476 self.elements
477 .get(element)
478 .map(|tags| !tags.is_empty())
479 .unwrap_or(false)
480 }
481
482 pub fn elements(&self) -> Vec<&T> {
484 self.elements
485 .iter()
486 .filter(|(_, tags)| !tags.is_empty())
487 .map(|(elem, _)| elem)
488 .collect()
489 }
490
491 pub fn len(&self) -> usize {
493 self.elements
494 .iter()
495 .filter(|(_, tags)| !tags.is_empty())
496 .count()
497 }
498
499 pub fn is_empty(&self) -> bool {
501 self.len() == 0
502 }
503}
504
505impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for ORSet<T> {
506 fn merge(&mut self, other: &Self) {
507 for (element, tags) in &other.elements {
508 let our_tags = self.elements.entry(element.clone()).or_default();
509 for tag in tags {
510 our_tags.insert(tag.clone());
511 }
512 }
513
514 for (node, &counter) in &other.counters {
515 let our_counter = self.counters.entry(node.clone()).or_insert(0);
516 *our_counter = (*our_counter).max(counter);
517 }
518 }
519}
520
521#[derive(Debug, Clone, Serialize, Deserialize)]
527pub struct LWWMap<K: Clone + Eq + std::hash::Hash + Serialize, V: Clone + Serialize> {
528 entries: HashMap<K, LWWRegister<V>>,
529}
530
531impl<
532 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
533 V: Clone + Serialize + for<'de> Deserialize<'de>,
534 > LWWMap<K, V>
535{
536 pub fn new() -> Self {
538 Self {
539 entries: HashMap::new(),
540 }
541 }
542
543 pub fn set(&mut self, key: K, value: V, node_id: &NodeId) {
545 self.entries.entry(key).or_default().set(value, node_id);
546 }
547
548 pub fn get(&self, key: &K) -> Option<&V> {
550 self.entries.get(key).and_then(|r| r.get())
551 }
552
553 pub fn remove(&mut self, key: &K) {
555 self.entries.remove(key);
556 }
557
558 pub fn keys(&self) -> Vec<&K> {
560 self.entries
561 .iter()
562 .filter(|(_, v)| v.is_set())
563 .map(|(k, _)| k)
564 .collect()
565 }
566
567 pub fn len(&self) -> usize {
569 self.entries.iter().filter(|(_, v)| v.is_set()).count()
570 }
571
572 pub fn is_empty(&self) -> bool {
574 self.len() == 0
575 }
576}
577
578impl<
579 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
580 V: Clone + Serialize + for<'de> Deserialize<'de>,
581 > Default for LWWMap<K, V>
582{
583 fn default() -> Self {
584 Self::new()
585 }
586}
587
588impl<
589 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
590 V: Clone + Serialize + for<'de> Deserialize<'de>,
591 > CRDT for LWWMap<K, V>
592{
593 fn merge(&mut self, other: &Self) {
594 for (key, register) in &other.entries {
595 self.entries.entry(key.clone()).or_default().merge(register);
596 }
597 }
598}
599
600pub trait DeltaCRDT: CRDT {
606 type Delta: Clone + Serialize;
608
609 fn apply_delta(&mut self, delta: &Self::Delta);
611
612 fn generate_delta(&self) -> Option<Self::Delta>;
614}
615
616#[derive(Debug, Clone, Serialize, Deserialize)]
618pub struct GCounterDelta {
619 pub node_id: String,
620 pub value: u64,
621}
622
623impl DeltaCRDT for GCounter {
624 type Delta = GCounterDelta;
625
626 fn apply_delta(&mut self, delta: &Self::Delta) {
627 let current = self.counts.entry(delta.node_id.clone()).or_insert(0);
628 *current = (*current).max(delta.value);
629 }
630
631 fn generate_delta(&self) -> Option<Self::Delta> {
632 None
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcounter_basic() {
        let mut counter = GCounter::new();
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        counter.increment(&node_a);
        counter.increment(&node_a);
        counter.increment(&node_b);

        assert_eq!(counter.value(), 3);
        assert_eq!(counter.node_value(&node_a), 2);
        assert_eq!(counter.node_value(&node_b), 1);
    }

    #[test]
    fn test_gcounter_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut counter1 = GCounter::new();
        counter1.increment(&node_a);
        counter1.increment(&node_a);

        let mut counter2 = GCounter::new();
        counter2.increment(&node_b);
        counter2.increment(&node_b);
        counter2.increment(&node_b);

        counter1.merge(&counter2);
        assert_eq!(counter1.value(), 5);
    }

    #[test]
    fn test_pncounter() {
        let mut counter = PNCounter::new();
        let node = NodeId::new("A");

        counter.increment(&node);
        counter.increment(&node);
        counter.increment(&node);
        counter.decrement(&node);

        assert_eq!(counter.value(), 2);
    }

    #[test]
    fn test_pncounter_negative() {
        let mut counter = PNCounter::new();
        let node = NodeId::new("A");

        counter.decrement(&node);
        counter.decrement(&node);

        assert_eq!(counter.value(), -2);
    }

    #[test]
    fn test_gset() {
        let mut set: GSet<String> = GSet::new();

        set.add("apple".to_string());
        set.add("banana".to_string());
        // Re-adding an existing element must not grow the set.
        set.add("apple".to_string());
        assert_eq!(set.len(), 2);

        assert!(set.contains(&"apple".to_string()));
        assert!(!set.contains(&"cherry".to_string()));
    }

    #[test]
    fn test_gset_merge() {
        let mut set1: GSet<String> = GSet::new();
        set1.add("apple".to_string());

        let mut set2: GSet<String> = GSet::new();
        set2.add("banana".to_string());

        set1.merge(&set2);
        assert_eq!(set1.len(), 2);
    }

    #[test]
    fn test_twopset() {
        let mut set: TwoPSet<String> = TwoPSet::new();

        set.add("apple".to_string());
        set.add("banana".to_string());

        assert!(set.contains(&"apple".to_string()));

        set.remove("apple".to_string());
        assert!(!set.contains(&"apple".to_string()));

        // Removal is permanent in a 2P-Set: re-adding has no effect.
        set.add("apple".to_string());
        assert!(!set.contains(&"apple".to_string()));
    }

    #[test]
    fn test_lww_register() {
        let mut lww: LWWRegister<String> = LWWRegister::new();
        let node = NodeId::new("A");

        lww.set("value1".to_string(), &node);
        assert_eq!(lww.get(), Some(&"value1".to_string()));

        lww.set("value2".to_string(), &node);
        assert_eq!(lww.get(), Some(&"value2".to_string()));
    }

    #[test]
    fn test_lww_register_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut lww_old: LWWRegister<String> = LWWRegister::new();
        lww_old.set_with_timestamp("older".to_string(), 100, &node_a);

        let mut lww_new: LWWRegister<String> = LWWRegister::new();
        lww_new.set_with_timestamp("newer".to_string(), 200, &node_b);

        lww_old.merge(&lww_new);
        assert_eq!(lww_old.get(), Some(&"newer".to_string()));
    }

    #[test]
    fn test_mv_register() {
        let mut mv: MVRegister<String> = MVRegister::new();
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut clock1 = VectorClock::new();
        clock1.set(&node_a, 1);

        let mut clock2 = VectorClock::new();
        clock2.set(&node_b, 1);

        // The two clocks are concurrent, so both values must be kept.
        mv.set("value_a".to_string(), clock1);
        mv.set("value_b".to_string(), clock2);

        assert!(mv.has_conflict());
        assert_eq!(mv.value_count(), 2);
    }

    #[test]
    fn test_orset() {
        let mut set: ORSet<String> = ORSet::new();
        let node = NodeId::new("A");

        set.add("apple".to_string(), &node);
        set.add("banana".to_string(), &node);

        assert!(set.contains(&"apple".to_string()));
        assert_eq!(set.len(), 2);

        set.remove(&"apple".to_string());
        assert!(!set.contains(&"apple".to_string()));

        // A re-add uses a fresh tag, so the element is present again.
        set.add("apple".to_string(), &node);
        assert!(set.contains(&"apple".to_string()));
    }

    #[test]
    fn test_orset_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut set1: ORSet<String> = ORSet::new();
        set1.add("apple".to_string(), &node_a);

        let mut set2: ORSet<String> = ORSet::new();
        set2.add("banana".to_string(), &node_b);

        set1.merge(&set2);

        assert!(set1.contains(&"apple".to_string()));
        assert!(set1.contains(&"banana".to_string()));
    }

    #[test]
    fn test_orset_concurrent_add_remove() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut set1: ORSet<String> = ORSet::new();
        set1.add("apple".to_string(), &node_a);

        // Fork the replica so both start from the same state.
        let mut set2 = set1.clone();

        // Replica 1 removes the element it observed...
        set1.remove(&"apple".to_string());

        // ...while replica 2 concurrently adds it again under a new tag.
        set2.add("apple".to_string(), &node_b);

        // Add wins: the unobserved tag from replica 2 keeps the element alive.
        set1.merge(&set2);
        assert!(set1.contains(&"apple".to_string()));
    }

    #[test]
    fn test_lww_map() {
        let mut map: LWWMap<String, i32> = LWWMap::new();
        let node = NodeId::new("A");

        map.set("key1".to_string(), 100, &node);
        map.set("key2".to_string(), 200, &node);

        assert_eq!(map.get(&"key1".to_string()), Some(&100));
        assert_eq!(map.len(), 2);
    }

    #[test]
    fn test_lww_map_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut map1: LWWMap<String, i32> = LWWMap::new();
        map1.set("key1".to_string(), 100, &node_a);

        let mut map2: LWWMap<String, i32> = LWWMap::new();
        map2.set("key2".to_string(), 200, &node_b);

        map1.merge(&map2);

        assert_eq!(map1.get(&"key1".to_string()), Some(&100));
        assert_eq!(map1.get(&"key2".to_string()), Some(&200));
    }

    #[test]
    fn test_gcounter_delta() {
        let mut counter = GCounter::new();
        let node = NodeId::new("A");

        counter.increment(&node);

        let delta = GCounterDelta {
            node_id: "B".to_string(),
            value: 5,
        };

        counter.apply_delta(&delta);
        assert_eq!(counter.value(), 6);
    }

    #[test]
    fn test_crdt_merged() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut counter1 = GCounter::new();
        counter1.increment(&node_a);

        let mut counter2 = GCounter::new();
        counter2.increment(&node_b);

        let merged = counter1.merged(&counter2);

        // merged() must leave its receiver untouched.
        assert_eq!(counter1.value(), 1);
        assert_eq!(merged.value(), 2);
    }
}