1use crate::CrdtMerge;
5use fxhash::{FxHashMap, FxHashSet};
6use serde::{Deserialize, Serialize};
7use std::hash::Hash;
8use uuid::Uuid;
9
10pub type Dot = (String, u64);
13
14const LEGACY_ACTOR: &str = "__legacy__";
17
18fn new_actor() -> String {
20 Uuid::new_v4().to_string()
21}
22
/// Observed-remove set (OR-Set) with add-wins merge semantics.
///
/// Every add tags the element with a fresh dot `(actor, counter)`; removing an element
/// drops its dots. The version vector `vv` records, per actor, the highest counter this
/// replica has seen, which lets [`CrdtMerge::merge`] tell "removed here" apart from
/// "never seen here" without keeping tombstones.
///
/// Cloning copies the actor id as well; use [`ORSet::fork`] to create a replica that
/// will make its own edits. On the wire the set is written as [`ORSetWireV2`] (dots and
/// version vector only) and read through [`ORSetWire`], which also accepts the legacy
/// v1 `elements`/`tombstones` layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(
    from = "ORSetWire<T>",
    into = "ORSetWireV2<T>",
    bound(
        serialize = "T: Serialize + Hash + Eq + Clone",
        deserialize = "T: Deserialize<'de> + Hash + Eq + Clone"
    )
)]
pub struct ORSet<T: Hash + Eq + Clone> {
    /// Live dots per element; an element is present iff its dot set is non-empty.
    dots: FxHashMap<T, FxHashSet<Dot>>,
    /// Version vector: highest counter seen per actor. A dot `(a, n)` counts as seen
    /// when `n <= vv[a]` (a missing actor counts as 0).
    vv: FxHashMap<String, u64>,
    /// This replica's id, used to mint new dots. Not serialized; regenerated on decode.
    actor: String,
}
62
63impl<T: Hash + Eq + Clone> Default for ORSet<T> {
64 fn default() -> Self {
65 Self {
66 dots: FxHashMap::default(),
67 vv: FxHashMap::default(),
68 actor: new_actor(),
69 }
70 }
71}
72
73impl<T: Hash + Eq + Clone> ORSet<T> {
74 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub fn fork(&self) -> Self {
85 let mut forked = self.clone();
86 forked.actor = new_actor();
87 forked
88 }
89
90 pub fn add(&mut self, element: T) -> Dot {
93 let counter = self.vv.entry(self.actor.clone()).or_insert(0);
94 *counter += 1;
95 let dot: Dot = (self.actor.clone(), *counter);
96 let mut set = FxHashSet::default();
97 set.insert(dot.clone());
98 self.dots.insert(element, set);
101 dot
102 }
103
104 pub fn remove(&mut self, element: &T) {
109 self.dots.remove(element);
110 }
111
112 pub fn contains(&self, element: &T) -> bool {
114 self.dots.get(element).is_some_and(|dots| !dots.is_empty())
115 }
116
117 pub fn elements(&self) -> Vec<T> {
119 self.dots
120 .iter()
121 .filter(|(_, dots)| !dots.is_empty())
122 .map(|(elem, _)| elem.clone())
123 .collect()
124 }
125
126 pub fn len(&self) -> usize {
128 self.dots.values().filter(|dots| !dots.is_empty()).count()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.len() == 0
134 }
135}
136
137impl<T: Hash + Eq + Clone> PartialEq for ORSet<T> {
140 fn eq(&self, other: &Self) -> bool {
141 self.dots == other.dots && self.vv == other.vv
142 }
143}
144
145impl<T: Hash + Eq + Clone> CrdtMerge for ORSet<T> {
146 fn merge(&mut self, other: &Self) {
147 let mut keys: Vec<T> = Vec::new();
149 {
150 let mut seen: FxHashSet<&T> = FxHashSet::default();
151 for k in self.dots.keys().chain(other.dots.keys()) {
152 if seen.insert(k) {
153 keys.push(k.clone());
154 }
155 }
156 }
157
158 let empty: FxHashSet<Dot> = FxHashSet::default();
159 for key in keys {
160 let sd: FxHashSet<Dot> = self.dots.get(&key).cloned().unwrap_or_default();
162 let od: &FxHashSet<Dot> = other.dots.get(&key).unwrap_or(&empty);
163
164 let mut surviving: FxHashSet<Dot> = FxHashSet::default();
165 for d in sd.intersection(od) {
167 surviving.insert(d.clone());
168 }
169 for d in sd.difference(od) {
172 if d.1 > other.vv.get(&d.0).copied().unwrap_or(0) {
173 surviving.insert(d.clone());
174 }
175 }
176 for d in od.difference(&sd) {
178 if d.1 > self.vv.get(&d.0).copied().unwrap_or(0) {
179 surviving.insert(d.clone());
180 }
181 }
182
183 if surviving.is_empty() {
184 self.dots.remove(&key);
185 } else {
186 self.dots.insert(key, surviving);
187 }
188 }
189
190 for (actor, &counter) in &other.vv {
192 let entry = self.vv.entry(actor.clone()).or_insert(0);
193 if counter > *entry {
194 *entry = counter;
195 }
196 }
197 }
198}
199
/// Serialized (v2) form of an [`ORSet`]: only the dots and the version vector.
///
/// The replica's actor id is not included, so it never travels with the payload.
#[derive(Serialize)]
#[serde(bound(serialize = "T: Serialize + Hash + Eq + Clone"))]
struct ORSetWireV2<T: Hash + Eq + Clone> {
    /// Live dots per element, taken from [`ORSet`].
    dots: FxHashMap<T, FxHashSet<Dot>>,
    /// Version vector (highest counter seen per actor), taken from [`ORSet`].
    vv: FxHashMap<String, u64>,
}
209
210impl<T: Hash + Eq + Clone> From<ORSet<T>> for ORSetWireV2<T> {
211 fn from(set: ORSet<T>) -> Self {
212 ORSetWireV2 {
213 dots: set.dots,
214 vv: set.vv,
215 }
216 }
217}
218
/// Decoding shape for [`ORSet`]. Every field is optional so that one struct can read
/// both the current (v2: `dots` + `vv`) and the legacy (v1: `elements` + `tombstones`)
/// payload layouts; the conversion into [`ORSet`] decides which layout was sent.
#[derive(Deserialize)]
#[serde(bound(deserialize = "T: Deserialize<'de> + Hash + Eq + Clone"))]
struct ORSetWire<T: Hash + Eq + Clone> {
    /// v2: live dots per element.
    #[serde(default)]
    dots: Option<FxHashMap<T, FxHashSet<Dot>>>,
    /// v2: version vector (highest counter seen per actor).
    #[serde(default)]
    vv: Option<FxHashMap<String, u64>>,
    /// v1: the UUID tags attached to each element.
    #[serde(default)]
    elements: Option<FxHashMap<T, FxHashSet<Uuid>>>,
    /// v1: removed tags; an element is live while at least one of its tags is not listed here.
    #[serde(default)]
    tombstones: Option<FxHashSet<Uuid>>,
}
235
236impl<T: Hash + Eq + Clone> From<ORSetWire<T>> for ORSet<T> {
237 fn from(wire: ORSetWire<T>) -> Self {
238 let ORSetWire {
239 dots,
240 vv,
241 elements,
242 tombstones,
243 } = wire;
244
245 if let (Some(dots), Some(vv)) = (dots, vv) {
247 return ORSet {
248 dots,
249 vv,
250 actor: new_actor(),
251 };
252 }
253
254 let tombstones = tombstones.unwrap_or_default();
257 let mut new_dots: FxHashMap<T, FxHashSet<Dot>> = FxHashMap::default();
258 let mut counter: u64 = 0;
259 if let Some(elements) = elements {
260 for (elem, tags) in elements {
261 if tags.iter().any(|tag| !tombstones.contains(tag)) {
262 counter += 1;
263 let mut set = FxHashSet::default();
264 set.insert((LEGACY_ACTOR.to_string(), counter));
265 new_dots.insert(elem, set);
266 }
267 }
268 }
269 let mut new_vv = FxHashMap::default();
270 if counter > 0 {
271 new_vv.insert(LEGACY_ACTOR.to_string(), counter);
272 }
273 ORSet {
274 dots: new_dots,
275 vv: new_vv,
276 actor: new_actor(),
277 }
278 }
279}
280
// Behavioral tests for `ORSet`: local add/remove, merge semantics, and the serialized
// formats (v2 round-trip and v1 upgrade) via the crate-level `Crdt` wrapper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Crdt;

    // A local remove undoes a local add.
    #[test]
    fn test_add_remove() {
        let mut os = ORSet::new();
        os.add("apple".to_string());
        assert!(os.contains(&"apple".to_string()));

        os.remove(&"apple".to_string());
        assert!(!os.contains(&"apple".to_string()));
    }

    // `b` removes the element while `a` concurrently re-adds it with a new dot that `b`
    // has not seen; after merging, the add must win.
    #[test]
    fn test_add_wins() {
        let mut a = ORSet::new();
        a.add("apple".to_string());

        let mut b = a.fork();
        b.remove(&"apple".to_string());

        a.add("apple".to_string());

        a.merge(&b);

        assert!(a.contains(&"apple".to_string()));
    }

    // Two unrelated replicas: merge yields the union (2 appears once).
    #[test]
    fn test_merge() {
        let mut a = ORSet::new();
        a.add(1);
        a.add(2);

        let mut b = ORSet::new();
        b.add(2);
        b.add(3);

        a.merge(&b);

        let elements = a.elements();
        assert!(elements.contains(&1));
        assert!(elements.contains(&2));
        assert!(elements.contains(&3));
        assert_eq!(elements.len(), 3);
    }

    // Join-semilattice laws: a ⊔ b == b ⊔ a, and merging the same state twice is a no-op.
    #[test]
    fn merge_is_commutative_and_idempotent() {
        let mut a = ORSet::new();
        a.add("x".to_string());
        let mut b = a.fork();
        b.add("y".to_string());
        b.remove(&"x".to_string());

        let mut ab = a.clone();
        ab.merge(&b);
        let mut ba = b.clone();
        ba.merge(&a);
        assert_eq!(ab, ba, "merge must be commutative");

        let mut ab2 = ab.clone();
        ab2.merge(&b);
        assert_eq!(ab, ab2, "merge must be idempotent");
    }

    // Repeated add/remove cycles on two replicas must not grow the encoding: removal
    // discards dots and the version vector holds one counter per actor.
    #[test]
    fn serialized_size_bounded_under_churn() {
        let mut a = ORSet::new();
        let mut b = a.fork();

        let size_after =
            |set: &ORSet<String>| -> usize { Crdt::ORSet(set.clone()).to_msgpack().unwrap().len() };

        for _ in 0..1000 {
            a.add("k".to_string());
            a.remove(&"k".to_string());
            b.add("k".to_string());
            b.remove(&"k".to_string());
            a.merge(&b);
            b.merge(&a);
        }

        let bytes = size_after(&a);
        assert!(
            bytes < 256,
            "serialized churned ORSet should stay small, got {bytes} bytes"
        );
        assert!(a.is_empty());
    }

    // A legacy v1 payload decodes: "live" has an untombstoned tag and survives, "dead"'s
    // only tag is tombstoned and is dropped. Re-encoding emits the v2 fields.
    // NOTE(review): "t"/"d" appear to be `Crdt`'s type-tag and payload keys — defined on
    // `Crdt`, not visible here.
    #[test]
    fn v1_payload_decodes_and_upgrades() {
        let v1 = serde_json::json!({
            "t": "os",
            "d": {
                "elements": {
                    "live": ["6f9619ff-8b86-d011-b42d-00cf4fc964ff"],
                    "dead": ["7f9619ff-8b86-d011-b42d-00cf4fc964ff"]
                },
                "tombstones": ["7f9619ff-8b86-d011-b42d-00cf4fc964ff"]
            }
        });

        let crdt: Crdt = serde_json::from_value(v1).expect("v1 payload must decode");
        let Crdt::ORSet(os) = crdt else {
            panic!("expected ORSet");
        };
        assert!(os.contains(&"live".to_string()), "live element preserved");
        assert!(
            !os.contains(&"dead".to_string()),
            "tombstoned element dropped"
        );
        assert_eq!(os.len(), 1);

        let json = serde_json::to_value(Crdt::ORSet(os)).unwrap();
        let d = json.get("d").unwrap();
        assert!(d.get("dots").is_some(), "re-serializes as v2 (dots)");
        assert!(d.get("vv").is_some(), "re-serializes as v2 (vv)");
    }

    // msgpack round-trip keeps visibility and full state (`PartialEq` ignores the actor,
    // which is regenerated on decode).
    #[test]
    fn v2_roundtrip_preserves_visibility() {
        let mut os = ORSet::new();
        os.add("a".to_string());
        os.add("b".to_string());
        os.remove(&"b".to_string());

        let bytes = Crdt::ORSet(os.clone()).to_msgpack().unwrap();
        let Crdt::ORSet(decoded) = Crdt::from_msgpack(&bytes).unwrap() else {
            panic!("expected ORSet");
        };
        assert!(decoded.contains(&"a".to_string()));
        assert!(!decoded.contains(&"b".to_string()));
        assert_eq!(os, decoded, "v2 round-trip is state-preserving");
    }
}