1use std::collections::{BTreeMap, BTreeSet};
2
3use crate::Crdt;
4
/// An observed-remove set (OR-Set), a state-based CRDT.
///
/// Every insert attaches a `(actor, counter)` tag to the value. A remove
/// tombstones only the tags this replica currently holds for that value. An
/// insert made concurrently on another replica carries a tag the remover never
/// saw, so that insert survives a merge: a concurrent add wins over a remove.
/// Tags are presumably unique only if every replica uses a distinct `actor`
/// name (TODO: confirm with callers).
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself.
/// Code that merely names or stores an `ORSet<T>` therefore does not have to
/// repeat them.
///
/// The derived `PartialEq`/`Eq` compare every field, including `actor` and
/// `counter`. Two replicas with the same contents but different actors are
/// therefore not equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ORSet<T> {
    /// Name of the replica owning this instance; first half of every tag it mints.
    actor: String,
    /// Last counter value used for a tag.
    ///
    /// Incremented by each local insert and raised to the larger side's value on merge.
    counter: u64,
    /// Values currently in the set, each mapped to the live tags keeping it there.
    ///
    /// `remove` and `merge` never leave an empty tag set behind, and no tag
    /// here also appears in `tombstones`.
    elements: BTreeMap<T, BTreeSet<(String, u64)>>,
    /// Tags that have been removed.
    ///
    /// Only ever grown in this file, so a merge can reject a stale add.
    tombstones: BTreeSet<(String, u64)>,
}
39
40impl<T: Ord + Clone> ORSet<T> {
41 pub fn new(actor: impl Into<String>) -> Self {
43 Self {
44 actor: actor.into(),
45 counter: 0,
46 elements: BTreeMap::new(),
47 tombstones: BTreeSet::new(),
48 }
49 }
50
51 pub fn insert(&mut self, value: T) {
56 self.counter += 1;
57 let tag = (self.actor.clone(), self.counter);
58 self.elements.entry(value).or_default().insert(tag);
59 }
60
61 pub fn remove(&mut self, value: &T) -> bool {
68 if let Some(tags) = self.elements.remove(value) {
69 self.tombstones.extend(tags);
70 true
71 } else {
72 false
73 }
74 }
75
76 #[must_use]
78 pub fn contains(&self, value: &T) -> bool {
79 self.elements
80 .get(value)
81 .is_some_and(|tags| !tags.is_empty())
82 }
83
84 #[must_use]
86 pub fn len(&self) -> usize {
87 self.elements
88 .values()
89 .filter(|tags| !tags.is_empty())
90 .count()
91 }
92
93 #[must_use]
95 pub fn is_empty(&self) -> bool {
96 self.len() == 0
97 }
98
99 pub fn iter(&self) -> impl Iterator<Item = &T> {
101 self.elements
102 .iter()
103 .filter(|(_, tags)| !tags.is_empty())
104 .map(|(v, _)| v)
105 }
106
107 #[must_use]
109 pub fn actor(&self) -> &str {
110 &self.actor
111 }
112}
113
114impl<T: Ord + Clone> Crdt for ORSet<T> {
115 fn merge(&mut self, other: &Self) {
116 for (value, other_tags) in &other.elements {
118 let self_tags = self.elements.entry(value.clone()).or_default();
119 for tag in other_tags {
120 if !self.tombstones.contains(tag) {
122 self_tags.insert(tag.clone());
123 }
124 }
125 }
126
127 for tag in &other.tombstones {
129 for tags in self.elements.values_mut() {
130 tags.remove(tag);
131 }
132 }
133
134 self.tombstones.extend(other.tombstones.iter().cloned());
136
137 self.elements.retain(|_, tags| !tags.is_empty());
139
140 self.counter = self.counter.max(other.counter);
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_set_is_empty() {
        let s = ORSet::<String>::new("a");
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn insert_and_contains() {
        let mut s = ORSet::new("a");
        s.insert("x");
        assert!(s.contains(&"x"));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn remove_element() {
        let mut s = ORSet::new("a");
        s.insert("x");
        assert!(s.remove(&"x"));
        assert!(!s.contains(&"x"));
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn can_readd_after_remove() {
        let mut s = ORSet::new("a");
        s.insert("x");
        s.remove(&"x");
        assert!(!s.contains(&"x"));

        s.insert("x");
        assert!(s.contains(&"x"));
    }

    /// Replica `b` observes `a`'s insert. Then `a` removes `x` while `b`
    /// concurrently re-inserts it. The remove never saw the re-insert's tag,
    /// so `x` survives in both merge directions.
    #[test]
    fn concurrent_add_survives_remove() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");

        let mut s2 = ORSet::<&str>::new("b");
        s2.merge(&s1);

        s1.remove(&"x");
        s2.insert("x");

        let mut left = s1.clone();
        left.merge(&s2);
        let mut right = s2.clone();
        right.merge(&s1);

        assert!(left.contains(&"x"));
        assert!(right.contains(&"x"));
        assert_eq!(
            left.iter().collect::<Vec<_>>(),
            right.iter().collect::<Vec<_>>()
        );
    }

    /// A remove that observed every tag of a value takes effect on the
    /// replica it is merged into.
    #[test]
    fn observed_remove_propagates_through_merge() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");

        let mut s2 = ORSet::<&str>::new("b");
        s2.merge(&s1);
        assert!(s2.remove(&"x"));

        s1.merge(&s2);
        assert!(!s1.contains(&"x"));
        assert!(s1.is_empty());
    }

    #[test]
    fn merge_is_commutative() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");
        s1.insert("y");

        let mut s2 = ORSet::new("b");
        s2.insert("y");
        s2.insert("z");

        let mut left = s1.clone();
        left.merge(&s2);

        let mut right = s2.clone();
        right.merge(&s1);

        let left_elems: BTreeSet<_> = left.iter().collect();
        let right_elems: BTreeSet<_> = right.iter().collect();
        assert_eq!(left_elems, right_elems);
    }

    /// Removes on both sides, plus a concurrent insert, converge to the same
    /// contents regardless of merge direction.
    #[test]
    fn merge_is_commutative_with_removals() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");
        s1.insert("y");

        let mut s2 = ORSet::<&str>::new("b");
        s2.merge(&s1);
        s2.remove(&"y");
        s2.insert("z");

        s1.remove(&"x");

        let mut left = s1.clone();
        left.merge(&s2);
        let mut right = s2.clone();
        right.merge(&s1);

        let left_elems: Vec<_> = left.iter().collect();
        let right_elems: Vec<_> = right.iter().collect();
        assert_eq!(left_elems, vec![&"z"]);
        assert_eq!(left_elems, right_elems);
    }

    #[test]
    fn merge_is_idempotent() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");

        let mut s2 = ORSet::new("b");
        s2.insert("y");

        s1.merge(&s2);
        let after_first = s1.clone();
        s1.merge(&s2);

        assert_eq!(s1, after_first);
    }

    /// An insert the remover never observed survives the merge.
    #[test]
    fn add_wins_semantics() {
        let mut s1 = ORSet::new("a");
        s1.insert("x");
        s1.remove(&"x");

        let mut s2 = ORSet::new("b");
        s2.insert("x");

        s1.merge(&s2);
        assert!(s1.contains(&"x"));
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let mut s = ORSet::<&str>::new("a");
        assert!(!s.remove(&"x"));
    }

    #[test]
    fn iterate_elements() {
        let mut s = ORSet::new("a");
        s.insert(1);
        s.insert(2);
        s.insert(3);
        s.remove(&2);

        let elems: Vec<&i32> = s.iter().collect();
        assert_eq!(elems, vec![&1, &3]);
    }
}