logicaffeine_data/crdt/
orset.rs1use super::causal::{Dot, DotContext, VClock};
7use super::delta::DeltaCrdt;
8use super::replica::{generate_replica_id, ReplicaId};
9use super::Merge;
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::hash::Hash;
14use std::marker::PhantomData;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(bound = "T: Serialize + serde::de::DeserializeOwned + Hash + Eq")]
19pub struct ORSetDelta<T> {
20 pub entries: HashMap<T, HashSet<Dot>>,
21 pub context: DotContext,
22}
23
/// Conflict policy consulted once per element while two sets are merged.
///
/// `resolve` takes no `self`, so a policy is selected purely through the type
/// parameter of the set that uses it.
pub trait SetBias
where
    Self: Default + Clone + Send + 'static,
{
    /// Decides whether an element is kept after a merge.
    ///
    /// * `local_has_dots` / `remote_has_dots` - whether the local / remote
    ///   side currently holds at least one dot for the element.
    /// * `local_removed` / `remote_removed` - whether the local / remote side
    ///   is considered to have removed the element.
    ///
    /// Returns `true` to keep the element, `false` to drop it.
    fn resolve(
        local_has_dots: bool,
        remote_has_dots: bool,
        local_removed: bool,
        remote_removed: bool,
    ) -> bool;
}
60
61#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
68pub struct AddWins;
69
70impl SetBias for AddWins {
71 fn resolve(
72 local_has_dots: bool,
73 remote_has_dots: bool,
74 _local_removed: bool,
75 _remote_removed: bool,
76 ) -> bool {
77 local_has_dots || remote_has_dots
78 }
79}
80
81#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
88pub struct RemoveWins;
89
90impl SetBias for RemoveWins {
91 fn resolve(
92 local_has_dots: bool,
93 remote_has_dots: bool,
94 local_removed: bool,
95 remote_removed: bool,
96 ) -> bool {
97 if local_removed || remote_removed {
99 false
100 } else {
101 local_has_dots || remote_has_dots
103 }
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(bound = "T: Serialize + serde::de::DeserializeOwned + Hash + Eq")]
114pub struct ORSet<T, B: SetBias = AddWins> {
115 entries: HashMap<T, HashSet<Dot>>,
117 context: DotContext,
119 replica_id: ReplicaId,
121 #[serde(skip)]
123 _bias: PhantomData<B>,
124}
125
126impl<T: Hash + Eq + Clone, B: SetBias> ORSet<T, B> {
127 pub fn new(replica_id: ReplicaId) -> Self {
129 Self {
130 entries: HashMap::new(),
131 context: DotContext::new(),
132 replica_id,
133 _bias: PhantomData,
134 }
135 }
136
137 pub fn new_random() -> Self {
139 Self::new(generate_replica_id())
140 }
141
142 pub fn add(&mut self, value: T) {
144 let dot = self.context.next(self.replica_id);
145 self.entries.entry(value).or_default().insert(dot);
146 }
147
148 pub fn insert(&mut self, value: T) {
150 self.add(value);
151 }
152
153 pub fn remove(&mut self, value: &T) {
157 self.entries.remove(value);
158 }
159
160 pub fn contains(&self, value: &T) -> bool {
162 self.entries
163 .get(value)
164 .map_or(false, |dots| !dots.is_empty())
165 }
166
167 pub fn len(&self) -> usize {
169 self.entries
170 .values()
171 .filter(|dots| !dots.is_empty())
172 .count()
173 }
174
175 pub fn is_empty(&self) -> bool {
177 self.len() == 0
178 }
179
180 pub fn iter(&self) -> impl Iterator<Item = &T> {
182 self.entries
183 .iter()
184 .filter(|(_, dots)| !dots.is_empty())
185 .map(|(v, _)| v)
186 }
187
188 pub fn replica_id(&self) -> ReplicaId {
190 self.replica_id
191 }
192}
193
194impl<T: Hash + Eq, B: SetBias> PartialEq for ORSet<T, B> {
195 fn eq(&self, other: &Self) -> bool {
196 self.entries == other.entries && self.context == other.context
197 }
198}
199
200impl<T: Hash + Eq + Clone, B: SetBias> Merge for ORSet<T, B> {
201 fn merge(&mut self, other: &Self) {
202 let all_keys: HashSet<T> = self
204 .entries
205 .keys()
206 .chain(other.entries.keys())
207 .cloned()
208 .collect();
209
210 for value in all_keys {
211 let my_dots_before: HashSet<Dot> = self
212 .entries
213 .get(&value)
214 .cloned()
215 .unwrap_or_default();
216 let other_dots: HashSet<Dot> = other
217 .entries
218 .get(&value)
219 .cloned()
220 .unwrap_or_default();
221
222 let my_removed = my_dots_before.is_empty()
227 && other_dots.iter().any(|dot| self.context.has_seen(dot));
228 let other_removed = other_dots.is_empty()
229 && my_dots_before.iter().any(|dot| other.context.has_seen(dot));
230
231 let mut combined_dots: HashSet<Dot> = HashSet::new();
233
234 for dot in &my_dots_before {
236 if !other.context.has_seen(dot) || other_dots.contains(dot) {
237 combined_dots.insert(*dot);
238 }
239 }
240
241 for dot in &other_dots {
243 if !self.context.has_seen(dot) || my_dots_before.contains(dot) {
244 combined_dots.insert(*dot);
245 }
246 }
247
248 let my_has_dots = !my_dots_before.is_empty();
249 let other_has_dots = !other_dots.is_empty();
250
251 let keep = B::resolve(my_has_dots, other_has_dots, my_removed, other_removed);
253
254 let my_dots = self.entries.entry(value).or_default();
255 if keep {
256 *my_dots = combined_dots;
257 } else {
258 my_dots.clear();
259 }
260 }
261
262 self.context.merge(&other.context);
264
265 self.entries.retain(|_, dots| !dots.is_empty());
267 }
268}
269
270impl<T: Hash + Eq + Clone + Serialize + DeserializeOwned + Send + 'static, B: SetBias> DeltaCrdt
271 for ORSet<T, B>
272{
273 type Delta = ORSetDelta<T>;
274
275 fn delta_since(&self, since: &VClock) -> Option<Self::Delta> {
276 let current = self.version();
277 if since.dominates(¤t) {
278 return None;
279 }
280
281 Some(ORSetDelta {
283 entries: self.entries.clone(),
284 context: self.context.clone(),
285 })
286 }
287
288 fn apply_delta(&mut self, delta: &Self::Delta) {
289 let temp: ORSet<T, B> = ORSet {
291 entries: delta.entries.clone(),
292 context: delta.context.clone(),
293 replica_id: 0, _bias: PhantomData,
295 };
296 self.merge(&temp);
297 }
298
299 fn version(&self) -> VClock {
300 self.context.version()
301 }
302}
303
304impl<T: Hash + Eq + Clone, B: SetBias> Default for ORSet<T, B> {
305 fn default() -> Self {
306 Self::new_random()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// An added element is reported as present; an unrelated one is not.
    #[test]
    fn test_orset_add_contains() {
        let mut set: ORSet<String> = ORSet::new(1);
        set.add("alice".to_string());
        assert!(set.contains(&"alice".to_string()));
        assert!(!set.contains(&"bob".to_string()));
    }

    /// Removing an element makes `contains` report it as absent.
    #[test]
    fn test_orset_remove() {
        let mut set: ORSet<String> = ORSet::new(1);
        set.add("alice".to_string());
        set.remove(&"alice".to_string());
        assert!(!set.contains(&"alice".to_string()));
    }

    /// With the default `AddWins` bias, an add on one replica that is
    /// concurrent with a remove on another survives the merge.
    #[test]
    fn test_orset_add_wins() {
        let mut a: ORSet<String> = ORSet::new(1);
        let mut b: ORSet<String> = ORSet::new(2);

        a.add("item".to_string());
        b.merge(&a);

        a.remove(&"item".to_string());
        b.add("item".to_string());

        a.merge(&b);
        assert!(a.contains(&"item".to_string()));
    }

    /// Merging in either direction yields the same membership.
    #[test]
    fn test_orset_merge_commutative() {
        let mut a: ORSet<String> = ORSet::new(1);
        let mut b: ORSet<String> = ORSet::new(2);

        a.add("x".to_string());
        b.add("y".to_string());

        let mut a1 = a.clone();
        let mut b1 = b.clone();
        a1.merge(&b);
        b1.merge(&a);

        assert_eq!(a1.len(), b1.len());

        // Equal lengths alone would also hold if both merges had lost data or
        // ended up with different elements, so pin the actual contents too.
        assert_eq!(a1.len(), 2);
        let x = "x".to_string();
        let y = "y".to_string();
        assert!(a1.contains(&x));
        assert!(a1.contains(&y));
        assert!(b1.contains(&x));
        assert!(b1.contains(&y));
    }
}