pollen-crdt 0.1.0

CRDT synchronization for Pollen
Documentation
//! CRDT type implementations using rust-crdt library.
//!
//! This module provides type-safe wrappers around the rust-crdt types
//! for use in the Pollen distributed scheduler.

use crate::CrdtValue;
use crdts::{CmRDT, CvRDT};
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;

/// Observed-remove set without tombstones, exposed here as `OrSet`.
///
/// Replicas may add and remove elements concurrently. Removal follows
/// observed-remove semantics: when an add and a remove of the same element
/// race, the add wins and the element is present after merging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrSet<T: Clone + Eq + Hash + Debug + Send + Sync + 'static> {
    /// Backing ORSWOT from the `crdts` crate; actors are `u64` ids.
    inner: crdts::orswot::Orswot<T, u64>,
}

impl<T> OrSet<T>
where
    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Create a new empty OR-Set.
    pub fn new() -> Self {
        Self {
            inner: crdts::orswot::Orswot::new(),
        }
    }

    /// Add `element` to the set on behalf of `actor`.
    ///
    /// The add context is derived from the set's current causal context, so
    /// the resulting operation is tagged with a fresh dot for `actor`; the
    /// operation is then applied locally.
    pub fn add(&mut self, element: T, actor: u64) {
        let add_ctx = self.inner.read_ctx().derive_add_ctx(actor);
        let op = self.inner.add(element, add_ctx);
        self.inner.apply(op);
    }

    /// Remove `element` from the set.
    ///
    /// Only the adds of `element` that this replica has observed are removed,
    /// so an add made concurrently elsewhere survives a later merge
    /// (add-wins). Removing an element that is not present has no effect.
    ///
    /// `_actor` is accepted for signature symmetry but is not used.
    pub fn remove(&mut self, element: &T, _actor: u64) {
        // Derive the remove context from the element's own observed clock
        // rather than the whole set's clock, so the remove covers exactly the
        // adds that were seen for this element.
        let rm_ctx = self.inner.contains(element).derive_rm_ctx();
        let op = self.inner.rm(element.clone(), rm_ctx);
        self.inner.apply(op);
    }

    /// Check if the set contains `element`.
    pub fn contains(&self, element: &T) -> bool {
        // Query the single entry instead of materialising (and cloning) the
        // whole set via `read()`.
        self.inner.contains(element).val
    }

    /// Return a snapshot of all elements currently in the set.
    pub fn elements(&self) -> HashSet<T> {
        // `read()` already returns an owned snapshot; move it out instead of
        // cloning it a second time.
        self.inner.read().val
    }

    /// Get the number of elements.
    pub fn len(&self) -> usize {
        self.inner.read().val.len()
    }

    /// Check if the set has no elements.
    pub fn is_empty(&self) -> bool {
        self.inner.read().val.is_empty()
    }
}

impl<T> Default for OrSet<T>
where
    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> CrdtValue for OrSet<T>
where
    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    fn merge(&mut self, other: &Self) {
        self.inner.merge(other.inner.clone());
    }
}

/// Grow-only counter (G-Counter).
///
/// Increments are recorded per actor and the counter can never be
/// decremented; merged replicas report the sum of every actor's
/// contribution. Suited to counting events that happen across nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GCounter {
    /// Backing counter from the `crdts` crate; actors are `u64` ids.
    inner: crdts::gcounter::GCounter<u64>,
}

impl GCounter {
    /// Create a new counter starting at 0.
    pub fn new() -> Self {
        Self {
            inner: crdts::gcounter::GCounter::new(),
        }
    }

    /// Increment the counter by 1 for the given actor.
    pub fn increment(&mut self, actor: u64) {
        let dot = self.inner.inc(actor);
        self.inner.apply(dot);
    }

    /// Increment the counter by `amount` for the given actor.
    ///
    /// Costs the same regardless of `amount`; an `amount` of 0 leaves the
    /// counter unchanged.
    pub fn increment_by(&mut self, actor: u64, amount: u64) {
        // One dot carrying the actor's new total replaces `amount` separate
        // unit increments, which was O(amount) and effectively hung for
        // large amounts.
        let dot = self.inner.inc_many(actor, amount);
        self.inner.apply(dot);
    }

    /// Get the current value (the sum over all actors).
    ///
    /// Saturates at `u64::MAX` if the total does not fit in a `u64`.
    pub fn value(&self) -> u64 {
        // `to_u64` fails only when the total exceeds `u64::MAX` (e.g. several
        // actors each contributing close to the maximum). Saturating keeps the
        // counter monotonic, whereas falling back to 0 would make a grow-only
        // counter appear to shrink.
        self.inner.read().to_u64().unwrap_or(u64::MAX)
    }
}

impl Default for GCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl CrdtValue for GCounter {
    fn merge(&mut self, other: &Self) {
        self.inner.merge(other.inner.clone());
    }
}

/// Positive-negative counter (PN-Counter).
///
/// Supports both increments and decrements. Internally it is a pair of
/// grow-only counters, one for increments and one for decrements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PnCounter {
    /// Backing counter from the `crdts` crate; actors are `u64` ids.
    inner: crdts::pncounter::PNCounter<u64>,
}

impl PnCounter {
    /// Create a new counter starting at 0.
    pub fn new() -> Self {
        Self {
            inner: crdts::pncounter::PNCounter::new(),
        }
    }

    /// Increment the counter.
    pub fn increment(&mut self, actor: u64) {
        self.inner.apply(self.inner.inc(actor));
    }

    /// Decrement the counter.
    pub fn decrement(&mut self, actor: u64) {
        self.inner.apply(self.inner.dec(actor));
    }

    /// Get the current value.
    pub fn value(&self) -> i64 {
        self.inner.read().to_i64().unwrap_or(0)
    }
}

impl Default for PnCounter {
    fn default() -> Self {
        Self::new()
    }
}

impl CrdtValue for PnCounter {
    fn merge(&mut self, other: &Self) {
        self.inner.merge(other.inner.clone());
    }
}

/// Multi-value register (MV-Register).
///
/// Concurrent writes do not overwrite each other: every concurrently
/// written value is kept until a later write that has observed them
/// replaces them all.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MvRegister<T: Clone + Debug + Send + Sync + 'static> {
    /// Backing register from the `crdts` crate; actors are `u64` ids.
    inner: crdts::mvreg::MVReg<T, u64>,
}

impl<T> MvRegister<T>
where
    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Create a new empty register.
    pub fn new() -> Self {
        Self {
            inner: crdts::mvreg::MVReg::new(),
        }
    }

    /// Set the value, observing any concurrent values.
    pub fn set(&mut self, value: T, actor: u64) {
        let read_ctx = self.inner.read_ctx();
        let op = self.inner.write(value, read_ctx.derive_add_ctx(actor));
        self.inner.apply(op);
    }

    /// Get all current values (may be multiple if there are concurrent writes).
    pub fn values(&self) -> Vec<T> {
        self.inner.read().val.into_iter().collect()
    }

    /// Get the first value if any.
    pub fn value(&self) -> Option<T> {
        self.inner.read().val.into_iter().next()
    }
}

impl<T> Default for MvRegister<T>
where
    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> CrdtValue for MvRegister<T>
where
    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    fn merge(&mut self, other: &Self) {
        self.inner.merge(other.inner.clone());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn s(v: &str) -> String {
        v.to_string()
    }

    #[test]
    fn test_orset_add_remove() {
        let mut set1: OrSet<String> = OrSet::new();
        let mut set2: OrSet<String> = OrSet::new();

        // Add on node 1
        set1.add(s("a"), 1);
        set1.add(s("b"), 1);

        // Add on node 2
        set2.add(s("b"), 2);
        set2.add(s("c"), 2);

        // Merge
        set1.merge(&set2);

        // Should have all elements
        assert!(set1.contains(&s("a")));
        assert!(set1.contains(&s("b")));
        assert!(set1.contains(&s("c")));
        assert_eq!(set1.len(), 3);

        // Remove an element whose adds from both nodes have been observed.
        set1.remove(&s("b"), 1);
        assert!(!set1.contains(&s("b")));
        assert_eq!(set1.len(), 2);

        // Removing an absent element is a no-op.
        set1.remove(&s("zzz"), 1);
        assert_eq!(set1.len(), 2);

        // Re-merging the stale replica must not resurrect the removed element.
        set1.merge(&set2);
        assert!(!set1.contains(&s("b")));
        let expected: HashSet<String> = [s("a"), s("c")].into_iter().collect();
        assert_eq!(set1.elements(), expected);
    }

    #[test]
    fn test_orset_concurrent_add_wins() {
        let mut set1: OrSet<String> = OrSet::new();
        let mut set2: OrSet<String> = OrSet::new();

        set1.add(s("x"), 1);
        set2.merge(&set1);

        // Concurrently: node 1 removes "x" while node 2 re-adds it.
        set1.remove(&s("x"), 1);
        set2.add(s("x"), 2);

        // The unobserved add on node 2 survives the merge.
        set1.merge(&set2);
        assert!(set1.contains(&s("x")));
        assert!(!set1.is_empty());
    }

    #[test]
    fn test_gcounter_merge() {
        let mut counter1 = GCounter::new();
        let mut counter2 = GCounter::new();

        // Increment on node 1
        counter1.increment(1);
        counter1.increment(1);

        // Increment on node 2
        counter2.increment(2);
        counter2.increment(2);
        counter2.increment(2);

        // Merge
        counter1.merge(&counter2);

        // Should be sum of both
        assert_eq!(counter1.value(), 5);

        // Merging the same state again is idempotent.
        counter1.merge(&counter2);
        assert_eq!(counter1.value(), 5);
    }

    #[test]
    fn test_gcounter_increment_by() {
        let mut counter = GCounter::new();

        counter.increment_by(1, 0);
        assert_eq!(counter.value(), 0);

        counter.increment_by(1, 5);
        counter.increment(1);
        assert_eq!(counter.value(), 6);

        let mut other = GCounter::new();
        other.increment_by(2, 4);
        counter.merge(&other);
        assert_eq!(counter.value(), 10);
    }

    #[test]
    fn test_pncounter_merge() {
        let mut counter1 = PnCounter::new();
        let mut counter2 = PnCounter::new();

        // Node 1: +3
        counter1.increment(1);
        counter1.increment(1);
        counter1.increment(1);

        // Node 2: +2, -1 = +1
        counter2.increment(2);
        counter2.increment(2);
        counter2.decrement(2);

        // Merge
        counter1.merge(&counter2);

        // Should be 3 + 1 = 4
        assert_eq!(counter1.value(), 4);
    }

    #[test]
    fn test_pncounter_can_go_negative() {
        let mut counter = PnCounter::new();
        counter.decrement(1);
        counter.decrement(1);
        counter.increment(2);
        assert_eq!(counter.value(), -1);
    }

    #[test]
    fn test_mvregister_concurrent_writes() {
        let mut reg1: MvRegister<String> = MvRegister::new();
        let mut reg2: MvRegister<String> = MvRegister::new();
        assert_eq!(reg1.value(), None);

        // Concurrent writes
        reg1.set(s("value1"), 1);
        reg2.set(s("value2"), 2);

        // Merge: both concurrent values must survive.
        reg1.merge(&reg2);
        let values = reg1.values();
        assert_eq!(values.len(), 2);
        assert!(values.contains(&s("value1")));
        assert!(values.contains(&s("value2")));

        // A later write that observed both values supersedes them.
        reg1.set(s("value3"), 1);
        assert_eq!(reg1.values(), vec![s("value3")]);
        assert_eq!(reg1.value(), Some(s("value3")));
    }
}