Skip to main content

pollen_crdt/
types.rs

1//! CRDT type implementations using rust-crdt library.
2//!
3//! This module provides type-safe wrappers around the rust-crdt types
4//! for use in the Pollen distributed scheduler.
5
use crate::CrdtValue;
use crdts::{CmRDT, CvRDT};
use num_traits::{Signed, ToPrimitive};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
13
14/// Observed-Remove Set Without Tombstones (OR-Set).
15///
16/// A set CRDT where elements can be added and removed concurrently.
17/// Uses the Observed-Remove semantics where concurrent add/remove
18/// of the same element results in the element being present.
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct OrSet<T>
21where
22    T: Clone + Eq + Hash + Debug + Send + Sync + 'static,
23{
24    inner: crdts::orswot::Orswot<T, u64>,
25}
26
27impl<T> OrSet<T>
28where
29    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
30{
31    /// Create a new empty OR-Set.
32    pub fn new() -> Self {
33        Self {
34            inner: crdts::orswot::Orswot::new(),
35        }
36    }
37
38    /// Add an element to the set.
39    pub fn add(&mut self, element: T, actor: u64) {
40        let read_ctx = self.inner.read_ctx();
41        let op = self.inner.add(element, read_ctx.derive_add_ctx(actor));
42        self.inner.apply(op);
43    }
44
45    /// Remove an element from the set.
46    pub fn remove(&mut self, element: &T, _actor: u64) {
47        let read_ctx = self.inner.read_ctx();
48        let op = self.inner.rm(element.clone(), read_ctx.derive_rm_ctx());
49        self.inner.apply(op);
50    }
51
52    /// Check if the set contains an element.
53    pub fn contains(&self, element: &T) -> bool {
54        self.inner.read().val.contains(element)
55    }
56
57    /// Get all elements in the set.
58    pub fn elements(&self) -> HashSet<T> {
59        self.inner.read().val.clone()
60    }
61
62    /// Get the number of elements.
63    pub fn len(&self) -> usize {
64        self.inner.read().val.len()
65    }
66
67    /// Check if empty.
68    pub fn is_empty(&self) -> bool {
69        self.inner.read().val.is_empty()
70    }
71}
72
73impl<T> Default for OrSet<T>
74where
75    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
76{
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl<T> CrdtValue for OrSet<T>
83where
84    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
85{
86    fn merge(&mut self, other: &Self) {
87        self.inner.merge(other.inner.clone());
88    }
89}
90
91/// Grow-only Counter (G-Counter).
92///
93/// A counter that can only be incremented, never decremented.
94/// Useful for counting events across distributed nodes.
95#[derive(Clone, Debug, Serialize, Deserialize)]
96pub struct GCounter {
97    inner: crdts::gcounter::GCounter<u64>,
98}
99
100impl GCounter {
101    /// Create a new counter starting at 0.
102    pub fn new() -> Self {
103        Self {
104            inner: crdts::gcounter::GCounter::new(),
105        }
106    }
107
108    /// Increment the counter by 1 for the given actor.
109    pub fn increment(&mut self, actor: u64) {
110        self.inner.apply(self.inner.inc(actor));
111    }
112
113    /// Increment the counter by a specific amount for the given actor.
114    pub fn increment_by(&mut self, actor: u64, amount: u64) {
115        for _ in 0..amount {
116            self.inner.apply(self.inner.inc(actor));
117        }
118    }
119
120    /// Get the current value.
121    pub fn value(&self) -> u64 {
122        self.inner.read().to_u64().unwrap_or(0)
123    }
124}
125
126impl Default for GCounter {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl CrdtValue for GCounter {
133    fn merge(&mut self, other: &Self) {
134        self.inner.merge(other.inner.clone());
135    }
136}
137
138/// Positive-Negative Counter (PN-Counter).
139///
140/// A counter that can be both incremented and decremented.
141/// Implemented as a pair of G-Counters.
142#[derive(Clone, Debug, Serialize, Deserialize)]
143pub struct PnCounter {
144    inner: crdts::pncounter::PNCounter<u64>,
145}
146
147impl PnCounter {
148    /// Create a new counter starting at 0.
149    pub fn new() -> Self {
150        Self {
151            inner: crdts::pncounter::PNCounter::new(),
152        }
153    }
154
155    /// Increment the counter.
156    pub fn increment(&mut self, actor: u64) {
157        self.inner.apply(self.inner.inc(actor));
158    }
159
160    /// Decrement the counter.
161    pub fn decrement(&mut self, actor: u64) {
162        self.inner.apply(self.inner.dec(actor));
163    }
164
165    /// Get the current value.
166    pub fn value(&self) -> i64 {
167        self.inner.read().to_i64().unwrap_or(0)
168    }
169}
170
171impl Default for PnCounter {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl CrdtValue for PnCounter {
178    fn merge(&mut self, other: &Self) {
179        self.inner.merge(other.inner.clone());
180    }
181}
182
183/// Multi-Value Register (MV-Register).
184///
185/// A register that can hold multiple concurrent values.
186/// When there are concurrent writes, all values are preserved
187/// until a subsequent write that "observes" them.
188#[derive(Clone, Debug, Serialize, Deserialize)]
189pub struct MvRegister<T>
190where
191    T: Clone + Debug + Send + Sync + 'static,
192{
193    inner: crdts::mvreg::MVReg<T, u64>,
194}
195
196impl<T> MvRegister<T>
197where
198    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
199{
200    /// Create a new empty register.
201    pub fn new() -> Self {
202        Self {
203            inner: crdts::mvreg::MVReg::new(),
204        }
205    }
206
207    /// Set the value, observing any concurrent values.
208    pub fn set(&mut self, value: T, actor: u64) {
209        let read_ctx = self.inner.read_ctx();
210        let op = self.inner.write(value, read_ctx.derive_add_ctx(actor));
211        self.inner.apply(op);
212    }
213
214    /// Get all current values (may be multiple if there are concurrent writes).
215    pub fn values(&self) -> Vec<T> {
216        self.inner.read().val.into_iter().collect()
217    }
218
219    /// Get the first value if any.
220    pub fn value(&self) -> Option<T> {
221        self.inner.read().val.into_iter().next()
222    }
223}
224
225impl<T> Default for MvRegister<T>
226where
227    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
228{
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234impl<T> CrdtValue for MvRegister<T>
235where
236    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
237{
238    fn merge(&mut self, other: &Self) {
239        self.inner.merge(other.inner.clone());
240    }
241}
242
#[cfg(test)]
mod tests {
    //! Unit tests for the CRDT wrappers: local operations plus merges
    //! between independently-updated replicas.
    use super::*;

    #[test]
    fn test_orset_add_remove() {
        let mut set1: OrSet<String> = OrSet::new();
        let mut set2: OrSet<String> = OrSet::new();

        // Add on node 1
        set1.add("a".to_string(), 1);
        set1.add("b".to_string(), 1);

        // Add on node 2
        set2.add("b".to_string(), 2);
        set2.add("c".to_string(), 2);

        // Merge
        set1.merge(&set2);

        // Should have all elements
        assert!(set1.contains(&"a".to_string()));
        assert!(set1.contains(&"b".to_string()));
        assert!(set1.contains(&"c".to_string()));
        assert_eq!(set1.len(), 3);

        // Remove on node 1; the removal is local until merged.
        set1.remove(&"a".to_string(), 1);
        assert!(!set1.contains(&"a".to_string()));
        assert_eq!(set1.len(), 2);

        // After node 2 merges node 1, both replicas agree.
        set2.merge(&set1);
        assert!(!set2.contains(&"a".to_string()));
        assert_eq!(set2.elements(), set1.elements());
    }

    #[test]
    fn test_orset_concurrent_add_wins() {
        let mut set1: OrSet<String> = OrSet::new();
        set1.add("x".to_string(), 1);
        let mut set2 = set1.clone();

        // Node 2 removes the element it observed while node 1 concurrently
        // re-adds it; the unobserved add must survive the merge.
        set2.remove(&"x".to_string(), 2);
        set1.add("x".to_string(), 1);

        set1.merge(&set2);
        set2.merge(&set1);

        assert!(set1.contains(&"x".to_string()));
        assert!(set2.contains(&"x".to_string()));
    }

    #[test]
    fn test_gcounter_merge() {
        let mut counter1 = GCounter::new();
        let mut counter2 = GCounter::new();

        // Increment on node 1
        counter1.increment(1);
        counter1.increment(1);

        // Increment on node 2
        counter2.increment(2);
        counter2.increment(2);
        counter2.increment(2);

        // Merge
        counter1.merge(&counter2);

        // Should be sum of both
        assert_eq!(counter1.value(), 5);
    }

    #[test]
    fn test_gcounter_increment_by() {
        let mut counter = GCounter::new();
        counter.increment_by(1, 4);
        counter.increment_by(1, 0);
        counter.increment(1);
        assert_eq!(counter.value(), 5);

        let mut other = counter.clone();
        other.increment_by(2, 3);
        counter.merge(&other);
        assert_eq!(counter.value(), 8);
    }

    #[test]
    fn test_pncounter_merge() {
        let mut counter1 = PnCounter::new();
        let mut counter2 = PnCounter::new();

        // Node 1: +3
        counter1.increment(1);
        counter1.increment(1);
        counter1.increment(1);

        // Node 2: +2, -1 = +1
        counter2.increment(2);
        counter2.increment(2);
        counter2.decrement(2);

        // Merge
        counter1.merge(&counter2);

        // Should be 3 + 1 = 4
        assert_eq!(counter1.value(), 4);
    }

    #[test]
    fn test_pncounter_negative_value() {
        let mut counter = PnCounter::new();
        counter.decrement(1);
        counter.decrement(1);
        counter.increment(2);
        assert_eq!(counter.value(), -1);
    }

    #[test]
    fn test_mvregister_concurrent_writes() {
        let mut reg1: MvRegister<String> = MvRegister::new();
        let mut reg2: MvRegister<String> = MvRegister::new();

        // Concurrent writes
        reg1.set("value1".to_string(), 1);
        reg2.set("value2".to_string(), 2);

        // Merge - should have both values
        reg1.merge(&reg2);
        let values = reg1.values();
        assert_eq!(values.len(), 2);
        assert!(values.contains(&"value1".to_string()));
        assert!(values.contains(&"value2".to_string()));

        // A write that has observed both concurrent values supersedes them.
        reg1.set("value3".to_string(), 1);
        assert_eq!(reg1.values(), vec!["value3".to_string()]);

        reg2.merge(&reg1);
        assert_eq!(reg2.value(), Some("value3".to_string()));
    }
}