1use crate::{
39 error::Error,
40 id::{OpId, ReplicaId},
41 version::VersionVector,
42};
43use smallvec::SmallVec;
44use std::collections::HashMap;
45use std::hash::Hash;
46
47#[cfg(feature = "serde")]
48use serde::{Deserialize, Serialize};
49
50#[derive(Clone, Debug, PartialEq, Eq)]
56#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
57pub enum SetOp<T> {
58 Add {
60 id: OpId,
62 value: T,
64 },
65 Remove {
71 id: OpId,
73 value: T,
75 tags: Vec<OpId>,
77 },
78}
79
80impl<T> SetOp<T> {
81 #[must_use]
83 pub fn id(&self) -> OpId {
84 match self {
85 SetOp::Add { id, .. } | SetOp::Remove { id, .. } => *id,
86 }
87 }
88}
89
90#[derive(Clone, Debug)]
96pub struct Set<T: Eq + Hash + Clone> {
97 replica: ReplicaId,
98 clock: u64,
99 tags: HashMap<T, SmallVec<[OpId; 2]>>,
102 log: Vec<SetOp<T>>,
103 version: VersionVector,
104}
105
106impl<T: Eq + Hash + Clone> Set<T> {
107 #[must_use]
109 pub fn new(replica: ReplicaId) -> Self {
110 Self {
111 replica,
112 clock: 0,
113 tags: HashMap::new(),
114 log: Vec::new(),
115 version: VersionVector::new(),
116 }
117 }
118
119 #[must_use]
122 pub fn new_random() -> Self {
123 Self::new(crate::id::new_replica_id())
124 }
125
126 #[must_use]
128 pub fn replica_id(&self) -> ReplicaId {
129 self.replica
130 }
131
132 #[must_use]
134 pub fn len(&self) -> usize {
135 self.tags.values().filter(|t| !t.is_empty()).count()
136 }
137
138 #[must_use]
140 pub fn is_empty(&self) -> bool {
141 self.len() == 0
142 }
143
144 pub fn contains(&self, value: &T) -> bool {
146 self.tags.get(value).is_some_and(|t| !t.is_empty())
147 }
148
149 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
151 self.tags
152 .iter()
153 .filter_map(|(v, t)| if t.is_empty() { None } else { Some(v) })
154 }
155
156 pub fn add(&mut self, value: T) -> SetOp<T> {
160 self.clock = self
161 .clock
162 .checked_add(1)
163 .expect("Lamport clock overflow (>2^64 ops)");
164 let id = OpId::new(self.clock, self.replica);
165 let op = SetOp::Add {
166 id,
167 value: value.clone(),
168 };
169 self.tags.entry(value).or_default().push(id);
170 self.version.observe(id);
171 self.log.push(op.clone());
172 op
173 }
174
175 pub fn remove(&mut self, value: &T) -> Option<SetOp<T>> {
177 let observed: Vec<OpId> = match self.tags.get(value) {
178 Some(t) if !t.is_empty() => t.iter().copied().collect(),
179 _ => return None,
180 };
181 self.clock = self
182 .clock
183 .checked_add(1)
184 .expect("Lamport clock overflow (>2^64 ops)");
185 let id = OpId::new(self.clock, self.replica);
186 let op = SetOp::Remove {
187 id,
188 value: value.clone(),
189 tags: observed.clone(),
190 };
191 if let Some(slot) = self.tags.get_mut(value) {
193 slot.retain(|t| !observed.contains(t));
194 }
195 self.version.observe(id);
196 self.log.push(op.clone());
197 Some(op)
198 }
199
200 pub fn apply(&mut self, op: SetOp<T>) -> Result<(), Error> {
202 let op_id = op.id();
203 if self.version.contains(op_id) {
204 return Ok(());
205 }
206 match &op {
207 SetOp::Add { id, value } => {
208 self.tags.entry(value.clone()).or_default().push(*id);
209 }
210 SetOp::Remove { id: _, value, tags } => {
211 if let Some(slot) = self.tags.get_mut(value) {
212 slot.retain(|t| !tags.contains(t));
213 }
214 }
215 }
216 self.version.observe(op_id);
217 self.clock = self.clock.max(op_id.counter);
218 self.log.push(op);
219 Ok(())
220 }
221
222 pub fn merge(&mut self, other: &Self) {
224 let mut to_apply: Vec<&SetOp<T>> = other
225 .log
226 .iter()
227 .filter(|op| !self.version.contains(op.id()))
228 .collect();
229 to_apply.sort_by_key(|op| op.id());
230 for op in to_apply {
231 self.apply(op.clone()).expect("set apply cannot fail");
232 }
233 }
234
235 #[must_use]
237 pub fn ops(&self) -> &[SetOp<T>] {
238 &self.log
239 }
240
241 pub fn ops_since<'a>(
243 &'a self,
244 since: &'a VersionVector,
245 ) -> impl Iterator<Item = &'a SetOp<T>> + 'a {
246 self.log.iter().filter(move |op| !since.contains(op.id()))
247 }
248
249 #[must_use]
251 pub fn version(&self) -> &VersionVector {
252 &self.version
253 }
254}
255
256impl<T: Eq + Hash + Clone> Default for Set<T> {
257 fn default() -> Self {
258 Self::new(0)
259 }
260}
261
/// Serialized form of a `Set`.
///
/// The tag map is stored as a list of `(value, tags)` pairs rather than a map,
/// presumably so that non-string values can be stored in formats such as JSON
/// whose map keys must be strings (NOTE(review): confirm intent).
/// Field names and order define the wire format; do not reorder.
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
struct SetSnapshot<T> {
    /// Replica id of the serialized set.
    replica: ReplicaId,
    /// Lamport clock value at serialization time.
    clock: u64,
    /// `(value, live tags)` pairs of the tag map.
    tags: Vec<(T, SmallVec<[OpId; 2]>)>,
    /// Version vector of the serialized set.
    version: VersionVector,
    /// Full operation log.
    log: Vec<SetOp<T>>,
}
275
#[cfg(feature = "serde")]
impl<T> Serialize for Set<T>
where
    T: Eq + Hash + Clone + Serialize,
{
    /// Serializes the set in the `SetSnapshot` layout read back by `Deserialize`.
    ///
    /// Every field is borrowed rather than cloned, so serializing no longer
    /// deep-copies the whole operation log, tag map and version vector.
    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        // Borrowing twin of `SetSnapshot`: same struct name, field names and
        // field order, and references serialize exactly like the values they
        // point to, so the output matches the owned snapshot's.
        #[derive(Serialize)]
        #[serde(rename = "SetSnapshot")]
        struct SnapshotRef<'a, V> {
            replica: ReplicaId,
            clock: u64,
            tags: Vec<(&'a V, &'a SmallVec<[OpId; 2]>)>,
            version: &'a VersionVector,
            log: &'a [SetOp<V>],
        }

        SnapshotRef {
            replica: self.replica,
            clock: self.clock,
            tags: self.tags.iter().collect(),
            version: &self.version,
            log: &self.log,
        }
        .serialize(ser)
    }
}
296
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for Set<T>
where
    T: Eq + Hash + Clone + Deserialize<'de>,
{
    /// Rebuilds a set from the `SetSnapshot` layout written by `Serialize`.
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        let SetSnapshot {
            replica,
            clock,
            tags,
            version,
            log,
        } = SetSnapshot::<T>::deserialize(de)?;
        // Pairs become map entries; if a value repeats, its last pair wins.
        let tags = tags.into_iter().collect::<HashMap<_, _>>();
        Ok(Self {
            replica,
            clock,
            tags,
            version,
            log,
        })
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_set() {
        let s: Set<&str> = Set::new(1);
        assert!(s.is_empty());
        assert!(!s.contains(&"x"));
    }

    #[test]
    fn add_and_contains() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        s.add("b");
        assert!(s.contains(&"a"));
        assert!(s.contains(&"b"));
        assert_eq!(s.len(), 2);
    }

    #[test]
    fn remove_drops_value() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        let op = s.remove(&"a");
        assert!(op.is_some());
        assert!(!s.contains(&"a"));
    }

    #[test]
    fn remove_absent_returns_none() {
        let mut s: Set<&str> = Set::new(1);
        assert!(s.remove(&"x").is_none());
    }

    #[test]
    fn second_remove_returns_none() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        assert!(s.remove(&"a").is_some());
        assert!(s.remove(&"a").is_none());
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn add_wins_over_concurrent_remove() {
        let mut a: Set<&str> = Set::new(1);
        let mut b: Set<&str> = Set::new(2);
        a.add("x");
        b.merge(&a);

        // Concurrently: `a` removes the tag it saw while `b` adds a new one.
        a.remove(&"x");
        b.add("x");

        let mut a2 = a.clone();
        a2.merge(&b);
        let mut b2 = b.clone();
        b2.merge(&a);

        assert!(a2.contains(&"x"));
        assert!(b2.contains(&"x"));
    }

    /// A local remove cancels every tag it has observed, so a value added twice
    /// is gone after a single remove.
    #[test]
    fn double_add_then_single_remove_drops_value() {
        let mut a: Set<&str> = Set::new(1);
        a.add("x");
        a.add("x");
        a.remove(&"x");
        assert!(!a.contains(&"x"));
    }

    #[test]
    fn remove_replicates_via_apply() {
        let mut a: Set<&str> = Set::new(1);
        let mut b: Set<&str> = Set::new(2);
        b.apply(a.add("x")).unwrap();
        assert!(b.contains(&"x"));
        b.apply(a.remove(&"x").unwrap()).unwrap();
        assert!(!b.contains(&"x"));
        assert!(b.is_empty());
    }

    #[test]
    fn idempotent_apply() {
        let mut a: Set<&str> = Set::new(1);
        let op1 = a.add("x");
        let op2 = a.add("y");

        let mut b: Set<&str> = Set::new(2);
        b.apply(op1.clone()).unwrap();
        b.apply(op2.clone()).unwrap();
        b.apply(op1).unwrap();
        b.apply(op2).unwrap();
        assert!(b.contains(&"x"));
        assert!(b.contains(&"y"));
        assert_eq!(b.len(), 2);
        assert_eq!(b.ops().len(), 2);
    }

    #[test]
    fn apply_advances_lamport_clock() {
        let mut a: Set<&str> = Set::new(1);
        for v in ["a", "b", "c"] {
            a.add(v);
        }
        let remote = a.add("d");

        let mut b: Set<&str> = Set::new(2);
        b.apply(remote.clone()).unwrap();
        let local = b.add("e");
        assert!(local.id().counter > remote.id().counter);
    }

    #[test]
    fn iter_yields_only_live_values() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        s.add("b");
        s.remove(&"a");
        let live: Vec<&&str> = s.iter().collect();
        assert_eq!(live, vec![&"b"]);
    }

    #[test]
    fn ops_since_skips_seen_ops() {
        let mut s: Set<&str> = Set::new(1);
        s.add("a");
        let seen = s.version().clone();
        let op = s.add("b");
        let delta: Vec<&SetOp<&str>> = s.ops_since(&seen).collect();
        assert_eq!(delta, vec![&op]);
    }

    #[test]
    fn merge_is_idempotent() {
        let mut a: Set<&str> = Set::new(1);
        let mut b: Set<&str> = Set::new(2);
        a.add("x");
        b.add("y");
        a.merge(&b);
        let ops_after_first = a.ops().len();
        a.merge(&b);
        assert_eq!(a.ops().len(), ops_after_first);
        assert_eq!(a.len(), 2);
    }

    #[test]
    fn merge_is_commutative() {
        let mut a1: Set<&str> = Set::new(1);
        let mut a2: Set<&str> = Set::new(1);
        let mut b1: Set<&str> = Set::new(2);
        let mut b2: Set<&str> = Set::new(2);
        a1.add("x");
        a2.add("x");
        b1.add("y");
        b2.add("y");
        a1.merge(&b1);
        b2.merge(&a2);
        assert_eq!(a1.len(), b2.len());
        assert!(a1.contains(&"x") && a1.contains(&"y"));
        assert!(b2.contains(&"x") && b2.contains(&"y"));
    }
}