use crate::CrdtMerge;
use fxhash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use std::hash::Hash;
use uuid::Uuid;
/// A dot: an `(actor, counter)` pair identifying one `add` performed by that actor.
pub type Dot = (String, u64);
/// Actor marker for dots synthesized while upgrading a legacy (`elements`/`tombstones`) payload.
const LEGACY_ACTOR: &str = "__legacy__";
/// Mints a fresh actor identifier: a newly generated v4 UUID rendered as a string.
fn new_actor() -> String {
    let id = Uuid::new_v4();
    id.to_string()
}
/// Observed-remove set whose additions are tracked as dots and whose history is summarized
/// by a per-actor version vector.
///
/// Serialization goes through `ORSetWireV2<T>` (`dots` and `vv` only); deserialization goes
/// through `ORSetWire<T>`. The `actor` field is not part of the wire form.
///
/// NOTE(review): the derived `Clone` copies `actor` too, so two clones that both call `add`
/// would mint identical dots; `fork` assigns a fresh actor instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(
from = "ORSetWire<T>",
into = "ORSetWireV2<T>",
bound(
serialize = "T: Serialize + Hash + Eq + Clone",
deserialize = "T: Deserialize<'de> + Hash + Eq + Clone"
)
)]
pub struct ORSet<T: Hash + Eq + Clone> {
// Dots currently supporting each element; an element counts as present only while its
// dot set is non-empty.
dots: FxHashMap<T, FxHashSet<Dot>>,
// Highest counter recorded per actor; bumped by `add` and joined (per-actor max) by `merge`.
vv: FxHashMap<String, u64>,
// Identifier under which this instance mints new dots.
actor: String,
}
impl<T: Hash + Eq + Clone> Default for ORSet<T> {
    /// Builds an empty set: no dots, an empty version vector, and a newly generated actor id.
    fn default() -> Self {
        let dots = FxHashMap::default();
        let vv = FxHashMap::default();
        let actor = new_actor();
        Self { dots, vv, actor }
    }
}
impl<T: Hash + Eq + Clone> ORSet<T> {
pub fn new() -> Self {
Self::default()
}
pub fn fork(&self) -> Self {
let mut forked = self.clone();
forked.actor = new_actor();
forked
}
pub fn add(&mut self, element: T) -> Dot {
let counter = self.vv.entry(self.actor.clone()).or_insert(0);
*counter += 1;
let dot: Dot = (self.actor.clone(), *counter);
let mut set = FxHashSet::default();
set.insert(dot.clone());
self.dots.insert(element, set);
dot
}
pub fn remove(&mut self, element: &T) {
self.dots.remove(element);
}
pub fn contains(&self, element: &T) -> bool {
self.dots.get(element).is_some_and(|dots| !dots.is_empty())
}
pub fn elements(&self) -> Vec<T> {
self.dots
.iter()
.filter(|(_, dots)| !dots.is_empty())
.map(|(elem, _)| elem.clone())
.collect()
}
pub fn len(&self) -> usize {
self.dots.values().filter(|dots| !dots.is_empty()).count()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl<T: Hash + Eq + Clone> PartialEq for ORSet<T> {
    /// Compares the dot stores and the version vectors; the `actor` field is ignored.
    fn eq(&self, other: &Self) -> bool {
        let Self { dots, vv, actor: _ } = self;
        *dots == other.dots && *vv == other.vv
    }
}
impl<T: Hash + Eq + Clone> CrdtMerge for ORSet<T> {
fn merge(&mut self, other: &Self) {
let mut keys: Vec<T> = Vec::new();
{
let mut seen: FxHashSet<&T> = FxHashSet::default();
for k in self.dots.keys().chain(other.dots.keys()) {
if seen.insert(k) {
keys.push(k.clone());
}
}
}
let empty: FxHashSet<Dot> = FxHashSet::default();
for key in keys {
let sd: FxHashSet<Dot> = self.dots.get(&key).cloned().unwrap_or_default();
let od: &FxHashSet<Dot> = other.dots.get(&key).unwrap_or(&empty);
let mut surviving: FxHashSet<Dot> = FxHashSet::default();
for d in sd.intersection(od) {
surviving.insert(d.clone());
}
for d in sd.difference(od) {
if d.1 > other.vv.get(&d.0).copied().unwrap_or(0) {
surviving.insert(d.clone());
}
}
for d in od.difference(&sd) {
if d.1 > self.vv.get(&d.0).copied().unwrap_or(0) {
surviving.insert(d.clone());
}
}
if surviving.is_empty() {
self.dots.remove(&key);
} else {
self.dots.insert(key, surviving);
}
}
for (actor, &counter) in &other.vv {
let entry = self.vv.entry(actor.clone()).or_insert(0);
if counter > *entry {
*entry = counter;
}
}
}
}
/// Serialization shape of `ORSet`: the dot store and the version vector.
/// It has no `actor` field, so the actor id is not written out.
#[derive(Serialize)]
#[serde(bound(serialize = "T: Serialize + Hash + Eq + Clone"))]
struct ORSetWireV2<T: Hash + Eq + Clone> {
// Dots per element, taken as-is from the set.
dots: FxHashMap<T, FxHashSet<Dot>>,
// Per-actor counters, taken as-is from the set.
vv: FxHashMap<String, u64>,
}
impl<T: Hash + Eq + Clone> From<ORSet<T>> for ORSetWireV2<T> {
fn from(set: ORSet<T>) -> Self {
ORSetWireV2 {
dots: set.dots,
vv: set.vv,
}
}
}
/// Deserialization shape accepting both payload layouts. Every field is optional and
/// defaults to `None` when absent.
///
/// * `dots` + `vv`: the current layout (what `ORSetWireV2` writes).
/// * `elements` + `tombstones`: the legacy layout, where each element maps to a set of UUID
///   tags and `tombstones` lists removed tags.
#[derive(Deserialize)]
#[serde(bound(deserialize = "T: Deserialize<'de> + Hash + Eq + Clone"))]
struct ORSetWire<T: Hash + Eq + Clone> {
// Current layout: dots per element.
#[serde(default)]
dots: Option<FxHashMap<T, FxHashSet<Dot>>>,
// Current layout: per-actor counters.
#[serde(default)]
vv: Option<FxHashMap<String, u64>>,
// Legacy layout: UUID tags per element.
#[serde(default)]
elements: Option<FxHashMap<T, FxHashSet<Uuid>>>,
// Legacy layout: removed tags; a tag listed here no longer keeps its element alive.
#[serde(default)]
tombstones: Option<FxHashSet<Uuid>>,
}
impl<T: Hash + Eq + Clone> From<ORSetWire<T>> for ORSet<T> {
fn from(wire: ORSetWire<T>) -> Self {
let ORSetWire {
dots,
vv,
elements,
tombstones,
} = wire;
if let (Some(dots), Some(vv)) = (dots, vv) {
return ORSet {
dots,
vv,
actor: new_actor(),
};
}
let tombstones = tombstones.unwrap_or_default();
let mut new_dots: FxHashMap<T, FxHashSet<Dot>> = FxHashMap::default();
let mut counter: u64 = 0;
if let Some(elements) = elements {
for (elem, tags) in elements {
if tags.iter().any(|tag| !tombstones.contains(tag)) {
counter += 1;
let mut set = FxHashSet::default();
set.insert((LEGACY_ACTOR.to_string(), counter));
new_dots.insert(elem, set);
}
}
}
let mut new_vv = FxHashMap::default();
if counter > 0 {
new_vv.insert(LEGACY_ACTOR.to_string(), counter);
}
ORSet {
dots: new_dots,
vv: new_vv,
actor: new_actor(),
}
}
}
// Unit tests for `ORSet`: local add/remove, merge behavior, serialized size, and decoding of
// both payload layouts through the crate's `Crdt` wrapper.
#[cfg(test)]
mod tests {
use super::*;
use crate::Crdt;
// An element is visible after `add` and gone after `remove`.
#[test]
fn test_add_remove() {
let mut os = ORSet::new();
os.add("apple".to_string());
assert!(os.contains(&"apple".to_string()));
os.remove(&"apple".to_string());
assert!(!os.contains(&"apple".to_string()));
}
// A re-add on `a` that the fork `b` has not seen survives merging in `b`'s removal.
#[test]
fn test_add_wins() {
let mut a = ORSet::new();
a.add("apple".to_string());
let mut b = a.fork();
b.remove(&"apple".to_string());
a.add("apple".to_string());
a.merge(&b);
assert!(a.contains(&"apple".to_string()));
}
// Merging two independently built sets yields the union of their elements.
#[test]
fn test_merge() {
let mut a = ORSet::new();
a.add(1);
a.add(2);
let mut b = ORSet::new();
b.add(2);
b.add(3);
a.merge(&b);
let elements = a.elements();
assert!(elements.contains(&1));
assert!(elements.contains(&2));
assert!(elements.contains(&3));
assert_eq!(elements.len(), 3);
}
// Merge order does not matter, and merging the same state twice changes nothing.
#[test]
fn merge_is_commutative_and_idempotent() {
let mut a = ORSet::new();
a.add("x".to_string());
let mut b = a.fork();
b.add("y".to_string());
b.remove(&"x".to_string());
let mut ab = a.clone();
ab.merge(&b);
let mut ba = b.clone();
ba.merge(&a);
assert_eq!(ab, ba, "merge must be commutative");
let mut ab2 = ab.clone();
ab2.merge(&b);
assert_eq!(ab, ab2, "merge must be idempotent");
}
// 1000 rounds of add/remove on two replicas with merges both ways must leave an empty set
// whose msgpack encoding stays under 256 bytes.
#[test]
fn serialized_size_bounded_under_churn() {
let mut a = ORSet::new();
let mut b = a.fork();
let size_after =
|set: &ORSet<String>| -> usize { Crdt::ORSet(set.clone()).to_msgpack().unwrap().len() };
for _ in 0..1000 {
a.add("k".to_string());
a.remove(&"k".to_string());
b.add("k".to_string());
b.remove(&"k".to_string());
a.merge(&b);
b.merge(&a);
}
let bytes = size_after(&a);
assert!(
bytes < 256,
"serialized churned ORSet should stay small, got {bytes} bytes"
);
assert!(a.is_empty());
}
// A legacy (`elements`/`tombstones`) payload decodes, keeps only the non-tombstoned element,
// and re-serializes with `dots` and `vv` fields.
#[test]
fn v1_payload_decodes_and_upgrades() {
let v1 = serde_json::json!({
"t": "os",
"d": {
"elements": {
"live": ["6f9619ff-8b86-d011-b42d-00cf4fc964ff"],
"dead": ["7f9619ff-8b86-d011-b42d-00cf4fc964ff"]
},
"tombstones": ["7f9619ff-8b86-d011-b42d-00cf4fc964ff"]
}
});
let crdt: Crdt = serde_json::from_value(v1).expect("v1 payload must decode");
let Crdt::ORSet(os) = crdt else {
panic!("expected ORSet");
};
assert!(os.contains(&"live".to_string()), "live element preserved");
assert!(
!os.contains(&"dead".to_string()),
"tombstoned element dropped"
);
assert_eq!(os.len(), 1);
let json = serde_json::to_value(Crdt::ORSet(os)).unwrap();
let d = json.get("d").unwrap();
assert!(d.get("dots").is_some(), "re-serializes as v2 (dots)");
assert!(d.get("vv").is_some(), "re-serializes as v2 (vv)");
}
// A msgpack round-trip preserves membership and compares equal to the original
// (`PartialEq` ignores the actor, which is regenerated on decode).
#[test]
fn v2_roundtrip_preserves_visibility() {
let mut os = ORSet::new();
os.add("a".to_string());
os.add("b".to_string());
os.remove(&"b".to_string());
let bytes = Crdt::ORSet(os.clone()).to_msgpack().unwrap();
let Crdt::ORSet(decoded) = Crdt::from_msgpack(&bytes).unwrap() else {
panic!("expected ORSet");
};
assert!(decoded.contains(&"a".to_string()));
assert!(!decoded.contains(&"b".to_string()));
assert_eq!(os, decoded, "v2 round-trip is state-preserving");
}
}