use crate::{
error::Error,
id::{OpId, ReplicaId},
version::VersionVector,
};
use smallvec::SmallVec;
use std::collections::HashMap;
use std::hash::Hash;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A single operation on a [`Set`], as recorded in its log and exchanged
/// between replicas.
///
/// Every operation carries its own [`OpId`]; see [`SetOp::id`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SetOp<T> {
/// Insertion of `value`. The operation's `id` doubles as the tag that is
/// stored for `value` in the set.
Add {
/// Id of this operation (also the tag attached to `value`).
id: OpId,
/// The value being added.
value: T,
},
/// Removal of `value`, limited to the add-tags listed in `tags`.
Remove {
/// Id of this operation.
id: OpId,
/// The value being removed.
value: T,
/// Tags the removing replica held for `value` when it issued the
/// removal; only these tags are dropped when the operation is applied.
tags: Vec<OpId>,
},
}
impl<T> SetOp<T> {
    /// Returns the id carried by this operation, whichever variant it is.
    #[must_use]
    pub fn id(&self) -> OpId {
        // Both variants carry an `id`, so a single irrefutable or-pattern
        // covers the whole enum.
        let (SetOp::Add { id, .. } | SetOp::Remove { id, .. }) = self;
        *id
    }
}
/// A replicated set that tracks, for every value, the ids ("tags") of the
/// add operations currently supporting it.
///
/// A value counts as present while it has at least one tag. Removals only
/// drop the tags they list, so an add that the remover had not seen keeps
/// the value present.
#[derive(Clone, Debug)]
pub struct Set<T: Eq + Hash + Clone> {
/// Id of the replica that owns this instance; stamped on locally created ops.
replica: ReplicaId,
/// Lamport-style counter: bumped for each local op and raised to the
/// counter of every applied remote op.
clock: u64,
/// Value -> tags of the add operations currently supporting that value.
tags: HashMap<T, SmallVec<[OpId; 2]>>,
/// Every operation created or applied by this replica, in application order.
log: Vec<SetOp<T>>,
/// Ids of the operations this replica has already observed.
version: VersionVector,
}
impl<T: Eq + Hash + Clone> Set<T> {
    /// Creates an empty set owned by `replica`.
    #[must_use]
    pub fn new(replica: ReplicaId) -> Self {
        Self {
            replica,
            clock: 0,
            tags: HashMap::new(),
            log: Vec::new(),
            version: VersionVector::new(),
        }
    }

    /// Creates an empty set whose replica id is freshly obtained from
    /// `crate::id::new_replica_id()`.
    #[must_use]
    pub fn new_random() -> Self {
        Self::new(crate::id::new_replica_id())
    }

    /// Returns the id of the replica that owns this set.
    #[must_use]
    pub fn replica_id(&self) -> ReplicaId {
        self.replica
    }

    /// Returns the number of values currently present.
    ///
    /// Entries with no remaining tags are skipped: they may still exist in
    /// state restored from a snapshot written by an older version.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tags.values().filter(|t| !t.is_empty()).count()
    }

    /// Returns `true` when no value is currently present.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when `value` has at least one live tag.
    #[must_use]
    pub fn contains(&self, value: &T) -> bool {
        self.tags.get(value).is_some_and(|t| !t.is_empty())
    }

    /// Iterates over the values currently present, in the (unspecified)
    /// iteration order of the underlying `HashMap`.
    pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
        self.tags
            .iter()
            .filter_map(|(v, t)| if t.is_empty() { None } else { Some(v) })
    }

    /// Adds `value` locally and returns the operation to ship to other
    /// replicas.
    ///
    /// Adding a value that is already present attaches an additional tag
    /// to it.
    ///
    /// # Panics
    ///
    /// Panics if the local clock would overflow `u64`.
    pub fn add(&mut self, value: T) -> SetOp<T> {
        let id = self.next_op_id();
        let op = SetOp::Add {
            id,
            value: value.clone(),
        };
        self.tags.entry(value).or_default().push(id);
        self.version.observe(id);
        self.log.push(op.clone());
        op
    }

    /// Removes `value` locally and returns the operation to ship to other
    /// replicas, or `None` (without advancing the clock) when the value is
    /// not present.
    ///
    /// The returned operation lists exactly the tags held locally for
    /// `value` at the time of the call.
    ///
    /// # Panics
    ///
    /// Panics if the local clock would overflow `u64`.
    pub fn remove(&mut self, value: &T) -> Option<SetOp<T>> {
        let observed: Vec<OpId> = match self.tags.get(value) {
            Some(t) if !t.is_empty() => t.iter().copied().collect(),
            _ => return None,
        };
        let id = self.next_op_id();
        // Every tag held for `value` was just observed, so the whole entry
        // goes away. Dropping it (instead of leaving an empty vector behind)
        // keeps `tags` from growing with every value ever removed.
        self.tags.remove(value);
        let op = SetOp::Remove {
            id,
            value: value.clone(),
            tags: observed,
        };
        self.version.observe(id);
        self.log.push(op.clone());
        Some(op)
    }

    /// Applies an operation (typically one received from another replica).
    ///
    /// Operations whose id the version vector already contains are ignored,
    /// which makes re-delivery harmless. Otherwise the operation is applied,
    /// recorded in the version vector and the log, and the local clock is
    /// raised to at least the operation's counter.
    ///
    /// NOTE(review): a `Remove` only drops tags that are present locally; if
    /// it is applied before the `Add` it refers to, that `Add` will survive.
    /// Presumably callers deliver operations in causal order — confirm.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub fn apply(&mut self, op: SetOp<T>) -> Result<(), Error> {
        let op_id = op.id();
        if self.version.contains(op_id) {
            return Ok(());
        }
        match &op {
            SetOp::Add { id, value } => {
                self.tags.entry(value.clone()).or_default().push(*id);
            }
            SetOp::Remove { value, tags, .. } => {
                let emptied = match self.tags.get_mut(value) {
                    Some(slot) => {
                        slot.retain(|t| !tags.contains(t));
                        slot.is_empty()
                    }
                    None => false,
                };
                // Do not keep tombstone entries around for values that have
                // no supporting tag left.
                if emptied {
                    self.tags.remove(value);
                }
            }
        }
        self.version.observe(op_id);
        self.clock = self.clock.max(op_id.counter);
        self.log.push(op);
        Ok(())
    }

    /// Applies every operation from `other`'s log that this replica has not
    /// observed yet, in ascending operation-id order.
    pub fn merge(&mut self, other: &Self) {
        let mut to_apply: Vec<&SetOp<T>> = other
            .log
            .iter()
            .filter(|op| !self.version.contains(op.id()))
            .collect();
        to_apply.sort_by_key(|op| op.id());
        for op in to_apply {
            self.apply(op.clone()).expect("set apply cannot fail");
        }
    }

    /// Returns the full operation log, in application order.
    #[must_use]
    pub fn ops(&self) -> &[SetOp<T>] {
        &self.log
    }

    /// Iterates over the logged operations whose id is not contained in
    /// `since`, in application order.
    pub fn ops_since<'a>(
        &'a self,
        since: &'a VersionVector,
    ) -> impl Iterator<Item = &'a SetOp<T>> + 'a {
        self.log.iter().filter(move |op| !since.contains(op.id()))
    }

    /// Returns the version vector of operations observed so far.
    #[must_use]
    pub fn version(&self) -> &VersionVector {
        &self.version
    }

    /// Advances the local clock by one tick and returns the id for the next
    /// locally created operation.
    fn next_op_id(&mut self) -> OpId {
        self.clock = self
            .clock
            .checked_add(1)
            .expect("Lamport clock overflow (>2^64 ops)");
        OpId::new(self.clock, self.replica)
    }
}
impl<T: Eq + Hash + Clone> Default for Set<T> {
fn default() -> Self {
Self::new(0)
}
}
/// Serialized form of a [`Set`]: the same fields, with the tag map flattened
/// into a list of `(value, tags)` pairs.
// NOTE(review): presumably the pair list exists so that `T` does not have to
// be usable as a map key in the target format — confirm.
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
struct SetSnapshot<T> {
/// Owning replica id.
replica: ReplicaId,
/// Local clock value at snapshot time.
clock: u64,
/// One `(value, tags)` pair per entry of the tag map.
tags: Vec<(T, SmallVec<[OpId; 2]>)>,
/// Version vector at snapshot time.
version: VersionVector,
/// Full operation log at snapshot time.
log: Vec<SetOp<T>>,
}
#[cfg(feature = "serde")]
impl<T> Serialize for Set<T>
where
    T: Eq + Hash + Clone + Serialize,
{
    /// Serializes the set by copying its state into a `SetSnapshot` and
    /// serializing that.
    ///
    /// The tag pairs are emitted in the iteration order of the underlying
    /// `HashMap`, which is unspecified.
    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        // Flatten the tag map into owned `(value, tags)` pairs.
        let mut pairs = Vec::with_capacity(self.tags.len());
        for (value, ids) in &self.tags {
            pairs.push((value.clone(), ids.clone()));
        }

        SetSnapshot {
            replica: self.replica,
            clock: self.clock,
            tags: pairs,
            version: self.version.clone(),
            log: self.log.clone(),
        }
        .serialize(ser)
    }
}
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for Set<T>
where
    T: Eq + Hash + Clone + Deserialize<'de>,
{
    /// Deserializes a `SetSnapshot` and rebuilds the set from it, collecting
    /// the `(value, tags)` pairs back into the tag map.
    ///
    /// The snapshot's contents are taken as-is; no consistency checks are
    /// performed between its fields.
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        let SetSnapshot {
            replica,
            clock,
            tags,
            version,
            log,
        } = SetSnapshot::<T>::deserialize(de)?;

        Ok(Self {
            replica,
            clock,
            tags: tags.into_iter().collect(),
            version,
            log,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created set holds nothing.
    #[test]
    fn empty_set() {
        let s: Set<&str> = Set::new(1);
        assert!(s.is_empty());
        assert!(!s.contains(&"x"));
    }

    /// Added values are reported by `contains` and counted by `len`.
    #[test]
    fn add_and_contains() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        s.add("b");
        assert!(s.contains(&"a"));
        assert!(s.contains(&"b"));
        assert_eq!(s.len(), 2);
    }

    /// Removing a present value yields an op and makes the value absent.
    #[test]
    fn remove_drops_value() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        let op = s.remove(&"a");
        assert!(op.is_some());
        assert!(!s.contains(&"a"));
        assert!(s.is_empty());
    }

    /// Removing a value that was never added produces no op.
    #[test]
    fn remove_absent_returns_none() {
        let mut s: Set<&str> = Set::new(1);
        assert!(s.remove(&"x").is_none());
    }

    /// A value can be added again after it was removed.
    #[test]
    fn readd_after_remove() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        assert!(s.remove(&"a").is_some());
        s.add("a");
        assert!(s.contains(&"a"));
        assert_eq!(s.len(), 1);
    }

    /// An add that the remover had not observed keeps the value present on
    /// both sides after merging.
    #[test]
    fn add_wins_over_concurrent_remove() {
        let mut a: Set<&str> = Set::new(1);
        let mut b: Set<&str> = Set::new(2);
        a.add("x");
        b.merge(&a);

        // Concurrent: `a` removes what it has seen, `b` adds a fresh tag.
        a.remove(&"x");
        b.add("x");

        let mut a2 = a.clone();
        a2.merge(&b);
        let mut b2 = b.clone();
        b2.merge(&a);
        assert!(a2.contains(&"x"));
        assert!(b2.contains(&"x"));
    }

    /// A single local remove observes every tag of the value, so the value
    /// is gone even when it had been added twice.
    #[test]
    fn double_add_then_single_remove_drops_value() {
        let mut a: Set<&str> = Set::new(1);
        a.add("x");
        a.add("x");
        a.remove(&"x");
        assert!(!a.contains(&"x"));
        assert!(a.is_empty());
    }

    /// Applying the same ops twice has no additional effect.
    #[test]
    fn idempotent_apply() {
        let mut a: Set<&str> = Set::new(1);
        let op1 = a.add("x");
        let op2 = a.add("y");
        let mut b: Set<&str> = Set::new(2);
        b.apply(op1.clone()).unwrap();
        b.apply(op2.clone()).unwrap();
        b.apply(op1).unwrap();
        b.apply(op2).unwrap();
        assert!(b.contains(&"x"));
        assert!(b.contains(&"y"));
        assert_eq!(b.len(), 2);
    }

    /// Merging a into b and b into a yields the same membership.
    #[test]
    fn merge_is_commutative() {
        let mut a1: Set<&str> = Set::new(1);
        let mut a2: Set<&str> = Set::new(1);
        let mut b1: Set<&str> = Set::new(2);
        let mut b2: Set<&str> = Set::new(2);
        a1.add("x");
        a2.add("x");
        b1.add("y");
        b2.add("y");
        a1.merge(&b1);
        b2.merge(&a2);
        assert_eq!(a1.len(), b2.len());
        assert!(a1.contains(&"x") && a1.contains(&"y"));
        assert!(b2.contains(&"x") && b2.contains(&"y"));
    }
}