atomr_distributed_data/
sets.rs1use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5
6use serde::{Deserialize, Serialize};
7
8use crate::traits::CrdtMerge;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct GSet<T>
12where
13 T: Eq + Hash + Clone,
14{
15 items: HashSet<T>,
16}
17
18impl<T: Eq + Hash + Clone> Default for GSet<T> {
19 fn default() -> Self {
20 Self { items: HashSet::new() }
21 }
22}
23
24impl<T: Eq + Hash + Clone> GSet<T> {
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn add(&mut self, item: T) {
30 self.items.insert(item);
31 }
32
33 pub fn contains(&self, item: &T) -> bool {
34 self.items.contains(item)
35 }
36
37 pub fn iter(&self) -> impl Iterator<Item = &T> {
38 self.items.iter()
39 }
40
41 pub fn len(&self) -> usize {
42 self.items.len()
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.items.is_empty()
47 }
48}
49
50impl<T: Eq + Hash + Clone> CrdtMerge for GSet<T> {
51 fn merge(&mut self, other: &Self) {
52 for item in &other.items {
53 self.items.insert(item.clone());
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct OrSet<T>
63where
64 T: Eq + Hash + Clone,
65{
66 adds: HashMap<T, HashSet<u64>>,
67 removes: HashMap<T, HashSet<u64>>,
68 counter: u64,
69}
70
71impl<T: Eq + Hash + Clone> Default for OrSet<T> {
72 fn default() -> Self {
73 Self { adds: HashMap::new(), removes: HashMap::new(), counter: 0 }
74 }
75}
76
77impl<T: Eq + Hash + Clone> OrSet<T> {
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn add(&mut self, item: T) {
83 self.counter += 1;
84 self.adds.entry(item).or_default().insert(self.counter);
85 }
86
87 pub fn remove(&mut self, item: &T) {
88 if let Some(tags) = self.adds.get(item).cloned() {
89 self.removes.entry(item.clone()).or_default().extend(tags);
90 }
91 }
92
93 pub fn contains(&self, item: &T) -> bool {
94 match (self.adds.get(item), self.removes.get(item)) {
95 (Some(a), Some(r)) => a.difference(r).next().is_some(),
96 (Some(a), None) => !a.is_empty(),
97 _ => false,
98 }
99 }
100
101 pub fn iter(&self) -> impl Iterator<Item = &T> {
105 self.adds.iter().filter_map(|(k, add_tags)| {
106 let kept = match self.removes.get(k) {
107 Some(rem_tags) => add_tags.difference(rem_tags).next().is_some(),
108 None => !add_tags.is_empty(),
109 };
110 if kept {
111 Some(k)
112 } else {
113 None
114 }
115 })
116 }
117
118 pub fn len(&self) -> usize {
119 self.iter().count()
120 }
121
122 pub fn is_empty(&self) -> bool {
123 self.iter().next().is_none()
124 }
125}
126
127impl<T: Eq + Hash + Clone> CrdtMerge for OrSet<T> {
128 fn merge(&mut self, other: &Self) {
129 for (k, v) in &other.adds {
130 self.adds.entry(k.clone()).or_default().extend(v.iter().copied());
131 }
132 for (k, v) in &other.removes {
133 self.removes.entry(k.clone()).or_default().extend(v.iter().copied());
134 }
135 self.counter = self.counter.max(other.counter);
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gset_merges_union() {
        let mut left = GSet::<i32>::new();
        left.add(1);
        let mut right = GSet::<i32>::new();
        right.add(2);

        left.merge(&right);

        assert_eq!(left.len(), 2);
    }

    #[test]
    fn orset_add_then_remove() {
        let mut set = OrSet::<&'static str>::new();

        set.add("x");
        assert!(set.contains(&"x"));

        set.remove(&"x");
        assert!(!set.contains(&"x"));
    }

    #[test]
    fn orset_merge_preserves_re_add_after_concurrent_remove() {
        // Both replicas start from a state where "x" was added once.
        let mut origin = OrSet::<&'static str>::new();
        origin.add("x");
        let mut remover = origin.clone();

        // Concurrently: one replica removes "x", the other re-adds it.
        remover.remove(&"x");
        origin.add("x");

        origin.merge(&remover);
        assert!(origin.contains(&"x"));
    }
}