1use crate::node::NodeId;
9use crate::vector_clock::VectorClock;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
/// A state-based (convergent) replicated data type.
///
/// Implementations are expected to make `merge` commutative, associative and
/// idempotent so that replicas exchanging state in any order converge.
pub trait CRDT: Clone {
    /// Folds `other`'s state into `self`.
    fn merge(&mut self, other: &Self);

    /// Returns the merge of `self` and `other` as a fresh value.
    ///
    /// `self` is left unchanged; the result is a clone of `self` with `other`
    /// merged into it.
    fn merged(&self, other: &Self) -> Self {
        let mut combined = Self::clone(self);
        combined.merge(other);
        combined
    }
}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36pub struct GCounter {
37 counts: HashMap<String, u64>,
38}
39
40impl GCounter {
41 pub fn new() -> Self {
43 Self {
44 counts: HashMap::new(),
45 }
46 }
47
48 pub fn increment(&mut self, node_id: &NodeId) {
50 let key = node_id.as_str().to_string();
51 *self.counts.entry(key).or_insert(0) += 1;
52 }
53
54 pub fn increment_by(&mut self, node_id: &NodeId, amount: u64) {
56 let key = node_id.as_str().to_string();
57 *self.counts.entry(key).or_insert(0) += amount;
58 }
59
60 pub fn value(&self) -> u64 {
62 self.counts.values().sum()
63 }
64
65 pub fn node_value(&self, node_id: &NodeId) -> u64 {
67 self.counts.get(node_id.as_str()).copied().unwrap_or(0)
68 }
69}
70
71impl CRDT for GCounter {
72 fn merge(&mut self, other: &Self) {
73 for (node, &value) in &other.counts {
74 let current = self.counts.entry(node.clone()).or_insert(0);
75 *current = (*current).max(value);
76 }
77 }
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct PNCounter {
87 positive: GCounter,
88 negative: GCounter,
89}
90
91impl PNCounter {
92 pub fn new() -> Self {
94 Self {
95 positive: GCounter::new(),
96 negative: GCounter::new(),
97 }
98 }
99
100 pub fn increment(&mut self, node_id: &NodeId) {
102 self.positive.increment(node_id);
103 }
104
105 pub fn increment_by(&mut self, node_id: &NodeId, amount: u64) {
107 self.positive.increment_by(node_id, amount);
108 }
109
110 pub fn decrement(&mut self, node_id: &NodeId) {
112 self.negative.increment(node_id);
113 }
114
115 pub fn decrement_by(&mut self, node_id: &NodeId, amount: u64) {
117 self.negative.increment_by(node_id, amount);
118 }
119
120 pub fn value(&self) -> i64 {
122 self.positive.value() as i64 - self.negative.value() as i64
123 }
124}
125
126impl CRDT for PNCounter {
127 fn merge(&mut self, other: &Self) {
128 self.positive.merge(&other.positive);
129 self.negative.merge(&other.negative);
130 }
131}
132
133#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct GSet<T: Clone + Eq + std::hash::Hash + Serialize> {
140 elements: HashSet<T>,
141}
142
143impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> GSet<T> {
144 pub fn new() -> Self {
146 Self {
147 elements: HashSet::new(),
148 }
149 }
150
151 pub fn add(&mut self, element: T) {
153 self.elements.insert(element);
154 }
155
156 pub fn contains(&self, element: &T) -> bool {
158 self.elements.contains(element)
159 }
160
161 pub fn len(&self) -> usize {
163 self.elements.len()
164 }
165
166 pub fn is_empty(&self) -> bool {
168 self.elements.is_empty()
169 }
170
171 pub fn iter(&self) -> impl Iterator<Item = &T> {
173 self.elements.iter()
174 }
175
176 pub fn to_vec(&self) -> Vec<T> {
178 self.elements.iter().cloned().collect()
179 }
180}
181
182impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for GSet<T> {
183 fn merge(&mut self, other: &Self) {
184 for element in &other.elements {
185 self.elements.insert(element.clone());
186 }
187 }
188}
189
190#[derive(Debug, Clone, Default, Serialize, Deserialize)]
196pub struct TwoPSet<T: Clone + Eq + std::hash::Hash + Serialize> {
197 added: HashSet<T>,
198 removed: HashSet<T>,
199}
200
201impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> TwoPSet<T> {
202 pub fn new() -> Self {
204 Self {
205 added: HashSet::new(),
206 removed: HashSet::new(),
207 }
208 }
209
210 pub fn add(&mut self, element: T) {
212 if !self.removed.contains(&element) {
213 self.added.insert(element);
214 }
215 }
216
217 pub fn remove(&mut self, element: T) {
219 if self.added.contains(&element) {
220 self.removed.insert(element);
221 }
222 }
223
224 pub fn contains(&self, element: &T) -> bool {
226 self.added.contains(element) && !self.removed.contains(element)
227 }
228
229 pub fn elements(&self) -> HashSet<T> {
231 self.added
232 .difference(&self.removed)
233 .cloned()
234 .collect()
235 }
236
237 pub fn len(&self) -> usize {
239 self.elements().len()
240 }
241
242 pub fn is_empty(&self) -> bool {
244 self.elements().is_empty()
245 }
246}
247
248impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for TwoPSet<T> {
249 fn merge(&mut self, other: &Self) {
250 for element in &other.added {
251 self.added.insert(element.clone());
252 }
253 for element in &other.removed {
254 self.removed.insert(element.clone());
255 }
256 }
257}
258
259#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct LWWRegister<T: Clone + Serialize> {
266 value: Option<T>,
267 timestamp: u64,
268 node_id: String,
269}
270
271impl<T: Clone + Serialize + for<'de> Deserialize<'de>> LWWRegister<T> {
272 pub fn new() -> Self {
274 Self {
275 value: None,
276 timestamp: 0,
277 node_id: String::new(),
278 }
279 }
280
281 pub fn with_value(value: T, node_id: &NodeId) -> Self {
283 Self {
284 value: Some(value),
285 timestamp: Self::now(),
286 node_id: node_id.as_str().to_string(),
287 }
288 }
289
290 fn now() -> u64 {
292 std::time::SystemTime::now()
293 .duration_since(std::time::UNIX_EPOCH)
294 .unwrap_or_default()
295 .as_nanos() as u64
296 }
297
298 pub fn set(&mut self, value: T, node_id: &NodeId) {
300 let ts = Self::now();
301 if ts > self.timestamp || (ts == self.timestamp && node_id.as_str() > &self.node_id) {
302 self.value = Some(value);
303 self.timestamp = ts;
304 self.node_id = node_id.as_str().to_string();
305 }
306 }
307
308 pub fn set_with_timestamp(&mut self, value: T, timestamp: u64, node_id: &NodeId) {
310 if timestamp > self.timestamp
311 || (timestamp == self.timestamp && node_id.as_str() > &self.node_id)
312 {
313 self.value = Some(value);
314 self.timestamp = timestamp;
315 self.node_id = node_id.as_str().to_string();
316 }
317 }
318
319 pub fn get(&self) -> Option<&T> {
321 self.value.as_ref()
322 }
323
324 pub fn timestamp(&self) -> u64 {
326 self.timestamp
327 }
328
329 pub fn is_set(&self) -> bool {
331 self.value.is_some()
332 }
333}
334
335impl<T: Clone + Serialize + for<'de> Deserialize<'de>> Default for LWWRegister<T> {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341impl<T: Clone + Serialize + for<'de> Deserialize<'de>> CRDT for LWWRegister<T> {
342 fn merge(&mut self, other: &Self) {
343 if other.timestamp > self.timestamp
344 || (other.timestamp == self.timestamp && other.node_id > self.node_id)
345 {
346 self.value = other.value.clone();
347 self.timestamp = other.timestamp;
348 self.node_id = other.node_id.clone();
349 }
350 }
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct MVRegister<T: Clone + Eq + std::hash::Hash + Serialize> {
360 values: HashMap<T, VectorClock>,
361}
362
363impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> MVRegister<T> {
364 pub fn new() -> Self {
366 Self {
367 values: HashMap::new(),
368 }
369 }
370
371 pub fn set(&mut self, value: T, clock: VectorClock) {
373 self.values.retain(|_, v| !v.happened_before(&clock));
375
376 let dominated = self.values.values().any(|v| clock.happened_before(v));
378 if !dominated {
379 self.values.insert(value, clock);
380 }
381 }
382
383 pub fn get(&self) -> Vec<&T> {
385 self.values.keys().collect()
386 }
387
388 pub fn get_with_clocks(&self) -> Vec<(&T, &VectorClock)> {
390 self.values.iter().collect()
391 }
392
393 pub fn has_conflict(&self) -> bool {
395 self.values.len() > 1
396 }
397
398 pub fn value_count(&self) -> usize {
400 self.values.len()
401 }
402}
403
404impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> Default
405 for MVRegister<T>
406{
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
412impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT
413 for MVRegister<T>
414{
415 fn merge(&mut self, other: &Self) {
416 for (value, clock) in &other.values {
417 self.set(value.clone(), clock.clone());
418 }
419 }
420}
421
422#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
428pub struct UniqueTag {
429 node_id: String,
430 counter: u64,
431}
432
433impl UniqueTag {
434 pub fn new(node_id: &NodeId, counter: u64) -> Self {
436 Self {
437 node_id: node_id.as_str().to_string(),
438 counter,
439 }
440 }
441}
442
443#[derive(Debug, Clone, Default, Serialize, Deserialize)]
445pub struct ORSet<T: Clone + Eq + std::hash::Hash + Serialize> {
446 elements: HashMap<T, HashSet<UniqueTag>>,
447 counters: HashMap<String, u64>,
448}
449
450impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> ORSet<T> {
451 pub fn new() -> Self {
453 Self {
454 elements: HashMap::new(),
455 counters: HashMap::new(),
456 }
457 }
458
459 pub fn add(&mut self, element: T, node_id: &NodeId) {
461 let counter = self
462 .counters
463 .entry(node_id.as_str().to_string())
464 .or_insert(0);
465 *counter += 1;
466
467 let tag = UniqueTag::new(node_id, *counter);
468 self.elements
469 .entry(element)
470 .or_insert_with(HashSet::new)
471 .insert(tag);
472 }
473
474 pub fn remove(&mut self, element: &T) {
476 self.elements.remove(element);
477 }
478
479 pub fn contains(&self, element: &T) -> bool {
481 self.elements
482 .get(element)
483 .map(|tags| !tags.is_empty())
484 .unwrap_or(false)
485 }
486
487 pub fn elements(&self) -> Vec<&T> {
489 self.elements
490 .iter()
491 .filter(|(_, tags)| !tags.is_empty())
492 .map(|(elem, _)| elem)
493 .collect()
494 }
495
496 pub fn len(&self) -> usize {
498 self.elements
499 .iter()
500 .filter(|(_, tags)| !tags.is_empty())
501 .count()
502 }
503
504 pub fn is_empty(&self) -> bool {
506 self.len() == 0
507 }
508}
509
510impl<T: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>> CRDT for ORSet<T> {
511 fn merge(&mut self, other: &Self) {
512 for (element, tags) in &other.elements {
513 let our_tags = self.elements.entry(element.clone()).or_insert_with(HashSet::new);
514 for tag in tags {
515 our_tags.insert(tag.clone());
516 }
517 }
518
519 for (node, &counter) in &other.counters {
520 let our_counter = self.counters.entry(node.clone()).or_insert(0);
521 *our_counter = (*our_counter).max(counter);
522 }
523 }
524}
525
526#[derive(Debug, Clone, Serialize, Deserialize)]
532pub struct LWWMap<K: Clone + Eq + std::hash::Hash + Serialize, V: Clone + Serialize> {
533 entries: HashMap<K, LWWRegister<V>>,
534}
535
536impl<
537 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
538 V: Clone + Serialize + for<'de> Deserialize<'de>,
539 > LWWMap<K, V>
540{
541 pub fn new() -> Self {
543 Self {
544 entries: HashMap::new(),
545 }
546 }
547
548 pub fn set(&mut self, key: K, value: V, node_id: &NodeId) {
550 self.entries
551 .entry(key)
552 .or_insert_with(LWWRegister::new)
553 .set(value, node_id);
554 }
555
556 pub fn get(&self, key: &K) -> Option<&V> {
558 self.entries.get(key).and_then(|r| r.get())
559 }
560
561 pub fn remove(&mut self, key: &K) {
563 self.entries.remove(key);
564 }
565
566 pub fn keys(&self) -> Vec<&K> {
568 self.entries
569 .iter()
570 .filter(|(_, v)| v.is_set())
571 .map(|(k, _)| k)
572 .collect()
573 }
574
575 pub fn len(&self) -> usize {
577 self.entries.iter().filter(|(_, v)| v.is_set()).count()
578 }
579
580 pub fn is_empty(&self) -> bool {
582 self.len() == 0
583 }
584}
585
586impl<
587 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
588 V: Clone + Serialize + for<'de> Deserialize<'de>,
589 > Default for LWWMap<K, V>
590{
591 fn default() -> Self {
592 Self::new()
593 }
594}
595
596impl<
597 K: Clone + Eq + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
598 V: Clone + Serialize + for<'de> Deserialize<'de>,
599 > CRDT for LWWMap<K, V>
600{
601 fn merge(&mut self, other: &Self) {
602 for (key, register) in &other.entries {
603 self.entries
604 .entry(key.clone())
605 .or_insert_with(LWWRegister::new)
606 .merge(register);
607 }
608 }
609}
610
611pub trait DeltaCRDT: CRDT {
617 type Delta: Clone + Serialize;
619
620 fn apply_delta(&mut self, delta: &Self::Delta);
622
623 fn generate_delta(&self) -> Option<Self::Delta>;
625}
626
627#[derive(Debug, Clone, Serialize, Deserialize)]
629pub struct GCounterDelta {
630 pub node_id: String,
631 pub value: u64,
632}
633
634impl DeltaCRDT for GCounter {
635 type Delta = GCounterDelta;
636
637 fn apply_delta(&mut self, delta: &Self::Delta) {
638 let current = self.counts.entry(delta.node_id.clone()).or_insert(0);
639 *current = (*current).max(delta.value);
640 }
641
642 fn generate_delta(&self) -> Option<Self::Delta> {
643 None
645 }
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcounter_basic() {
        let mut counter = GCounter::new();
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        counter.increment(&node_a);
        counter.increment(&node_a);
        counter.increment(&node_b);

        assert_eq!(counter.value(), 3);
        assert_eq!(counter.node_value(&node_a), 2);
        assert_eq!(counter.node_value(&node_b), 1);
    }

    #[test]
    fn test_gcounter_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut counter1 = GCounter::new();
        counter1.increment(&node_a);
        counter1.increment(&node_a);

        let mut counter2 = GCounter::new();
        counter2.increment(&node_b);
        counter2.increment(&node_b);
        counter2.increment(&node_b);

        counter1.merge(&counter2);
        assert_eq!(counter1.value(), 5);
    }

    #[test]
    fn test_pncounter() {
        let mut counter = PNCounter::new();
        let node = NodeId::new("A");

        counter.increment(&node);
        counter.increment(&node);
        counter.increment(&node);
        counter.decrement(&node);

        assert_eq!(counter.value(), 2);
    }

    #[test]
    fn test_pncounter_negative() {
        let mut counter = PNCounter::new();
        let node = NodeId::new("A");

        counter.decrement(&node);
        counter.decrement(&node);

        assert_eq!(counter.value(), -2);
    }

    #[test]
    fn test_gset() {
        let mut set: GSet<String> = GSet::new();

        set.add("apple".to_string());
        set.add("banana".to_string());
        // Re-adding an existing element does not change the size.
        set.add("apple".to_string());
        assert_eq!(set.len(), 2);

        assert!(set.contains(&"apple".to_string()));
        assert!(!set.contains(&"cherry".to_string()));
    }

    #[test]
    fn test_gset_merge() {
        let mut set1: GSet<String> = GSet::new();
        set1.add("apple".to_string());

        let mut set2: GSet<String> = GSet::new();
        set2.add("banana".to_string());

        set1.merge(&set2);
        assert_eq!(set1.len(), 2);
    }

    #[test]
    fn test_twopset() {
        let mut set: TwoPSet<String> = TwoPSet::new();

        set.add("apple".to_string());
        set.add("banana".to_string());

        assert!(set.contains(&"apple".to_string()));

        set.remove("apple".to_string());
        assert!(!set.contains(&"apple".to_string()));

        // Removal is permanent in a 2P-Set: re-adding has no effect.
        set.add("apple".to_string());
        assert!(!set.contains(&"apple".to_string()));
    }

    #[test]
    fn test_lww_register() {
        let mut reg: LWWRegister<String> = LWWRegister::new();
        let node = NodeId::new("A");

        reg.set("value1".to_string(), &node);
        assert_eq!(reg.get(), Some(&"value1".to_string()));

        reg.set("value2".to_string(), &node);
        assert_eq!(reg.get(), Some(&"value2".to_string()));
    }

    #[test]
    fn test_lww_register_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut reg1: LWWRegister<String> = LWWRegister::new();
        reg1.set_with_timestamp("older".to_string(), 100, &node_a);

        let mut reg2: LWWRegister<String> = LWWRegister::new();
        reg2.set_with_timestamp("newer".to_string(), 200, &node_b);

        reg1.merge(&reg2);
        assert_eq!(reg1.get(), Some(&"newer".to_string()));
    }

    #[test]
    fn test_mv_register() {
        let mut reg: MVRegister<String> = MVRegister::new();
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut clock1 = VectorClock::new();
        clock1.set(&node_a, 1);

        let mut clock2 = VectorClock::new();
        clock2.set(&node_b, 1);

        // Neither clock precedes the other, so both values are retained.
        reg.set("value_a".to_string(), clock1);
        reg.set("value_b".to_string(), clock2);

        assert!(reg.has_conflict());
        assert_eq!(reg.value_count(), 2);
    }

    #[test]
    fn test_orset() {
        let mut set: ORSet<String> = ORSet::new();
        let node = NodeId::new("A");

        set.add("apple".to_string(), &node);
        set.add("banana".to_string(), &node);

        assert!(set.contains(&"apple".to_string()));
        assert_eq!(set.len(), 2);

        set.remove(&"apple".to_string());
        assert!(!set.contains(&"apple".to_string()));

        // Re-adding after a removal uses a fresh tag, so the element returns.
        set.add("apple".to_string(), &node);
        assert!(set.contains(&"apple".to_string()));
    }

    #[test]
    fn test_orset_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut set1: ORSet<String> = ORSet::new();
        set1.add("apple".to_string(), &node_a);

        let mut set2: ORSet<String> = ORSet::new();
        set2.add("banana".to_string(), &node_b);

        set1.merge(&set2);

        assert!(set1.contains(&"apple".to_string()));
        assert!(set1.contains(&"banana".to_string()));
    }

    #[test]
    fn test_orset_concurrent_add_remove() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut set1: ORSet<String> = ORSet::new();
        set1.add("apple".to_string(), &node_a);

        // Both replicas start from the same state.
        let mut set2 = set1.clone();

        // Replica 1 removes the element it has observed...
        set1.remove(&"apple".to_string());

        // ...while replica 2 concurrently adds it again under a new tag.
        set2.add("apple".to_string(), &node_b);

        set1.merge(&set2);
        // Replica 1's remove never observed the concurrent add's tag, so the
        // element survives the merge.
        assert!(set1.contains(&"apple".to_string()));
    }

    #[test]
    fn test_lww_map() {
        let mut map: LWWMap<String, i32> = LWWMap::new();
        let node = NodeId::new("A");

        map.set("key1".to_string(), 100, &node);
        map.set("key2".to_string(), 200, &node);

        assert_eq!(map.get(&"key1".to_string()), Some(&100));
        assert_eq!(map.len(), 2);
    }

    #[test]
    fn test_lww_map_merge() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut map1: LWWMap<String, i32> = LWWMap::new();
        map1.set("key1".to_string(), 100, &node_a);

        let mut map2: LWWMap<String, i32> = LWWMap::new();
        map2.set("key2".to_string(), 200, &node_b);

        map1.merge(&map2);

        assert_eq!(map1.get(&"key1".to_string()), Some(&100));
        assert_eq!(map1.get(&"key2".to_string()), Some(&200));
    }

    #[test]
    fn test_gcounter_delta() {
        let mut counter = GCounter::new();
        let node = NodeId::new("A");

        counter.increment(&node);

        let delta = GCounterDelta {
            node_id: "B".to_string(),
            value: 5,
        };

        counter.apply_delta(&delta);
        assert_eq!(counter.value(), 6);
    }

    #[test]
    fn test_crdt_merged() {
        let node_a = NodeId::new("A");
        let node_b = NodeId::new("B");

        let mut counter1 = GCounter::new();
        counter1.increment(&node_a);

        let mut counter2 = GCounter::new();
        counter2.increment(&node_b);

        let merged = counter1.merged(&counter2);

        // `merged` must leave the receiver unchanged.
        assert_eq!(counter1.value(), 1);
        assert_eq!(merged.value(), 2);
    }
}