use crate::CrdtValue;
use crdts::{CmRDT, CvRDT};
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
/// Set CRDT that wraps [`crdts::orswot::Orswot`] and identifies actors by `u64`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OrSet<T: Clone + Eq + Hash + Debug + Send + Sync + 'static> {
    /// Underlying `Orswot` state; every operation on the wrapper goes through it.
    inner: crdts::orswot::Orswot<T, u64>,
}
impl<T> OrSet<T>
where
    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            inner: crdts::orswot::Orswot::new(),
        }
    }

    /// Adds `element` to the set on behalf of `actor`.
    ///
    /// The add operation is derived from the set's current read context and
    /// applied to the local state immediately.
    pub fn add(&mut self, element: T, actor: u64) {
        let add_ctx = self.inner.read_ctx().derive_add_ctx(actor);
        let op = self.inner.add(element, add_ctx);
        self.inner.apply(op);
    }

    /// Removes `element` from the set.
    ///
    /// `_actor` is accepted for symmetry with [`OrSet::add`] but is not used:
    /// the remove context is derived from the read context alone.
    pub fn remove(&mut self, element: &T, _actor: u64) {
        let rm_ctx = self.inner.read_ctx().derive_rm_ctx();
        // `rm` takes the member by value, so the borrowed element is cloned.
        let op = self.inner.rm(element.clone(), rm_ctx);
        self.inner.apply(op);
    }

    /// Returns `true` if `element` is currently a member of the set.
    pub fn contains(&self, element: &T) -> bool {
        // Ask the underlying set directly instead of materialising a full
        // snapshot of all members just to test a single one.
        self.inner.contains(element).val
    }

    /// Returns an owned snapshot of the current members.
    pub fn elements(&self) -> HashSet<T> {
        // `read()` already hands back an owned set, so no extra clone is needed.
        self.inner.read().val
    }

    /// Returns the number of members currently in the set.
    pub fn len(&self) -> usize {
        self.inner.read().val.len()
    }

    /// Returns `true` if the set currently has no members.
    pub fn is_empty(&self) -> bool {
        self.inner.read().val.is_empty()
    }
}
impl<T> Default for OrSet<T>
where
T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<T> CrdtValue for OrSet<T>
where
    T: Clone + Eq + Hash + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Merges the state of `other` into `self`.
    ///
    /// `other` is only borrowed, so its inner state is cloned before being
    /// handed to the underlying merge, which takes its argument by value.
    fn merge(&mut self, other: &Self) {
        let incoming = other.inner.clone();
        self.inner.merge(incoming);
    }
}
/// Counter CRDT that wraps [`crdts::gcounter::GCounter`] and identifies actors by `u64`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GCounter {
    /// Underlying `crdts` counter state.
    inner: crdts::gcounter::GCounter<u64>,
}
impl GCounter {
    /// Creates a counter with no recorded increments.
    pub fn new() -> Self {
        Self {
            inner: crdts::gcounter::GCounter::new(),
        }
    }

    /// Increments the counter by one on behalf of `actor`.
    pub fn increment(&mut self, actor: u64) {
        let op = self.inner.inc(actor);
        self.inner.apply(op);
    }

    /// Increments the counter by `amount` on behalf of `actor`.
    ///
    /// The whole step is applied as a single operation, so the cost does not
    /// grow with `amount`. An `amount` of zero leaves the counter untouched.
    pub fn increment_by(&mut self, actor: u64, amount: u64) {
        if amount == 0 {
            return;
        }
        let op = self.inner.inc_many(actor, amount);
        self.inner.apply(op);
    }

    /// Returns the current total.
    ///
    /// If the total does not fit in a `u64` the result saturates at
    /// `u64::MAX`, so a counter that only grows never appears to drop to zero.
    pub fn value(&self) -> u64 {
        self.inner.read().to_u64().unwrap_or(u64::MAX)
    }
}
impl Default for GCounter {
fn default() -> Self {
Self::new()
}
}
impl CrdtValue for GCounter {
    /// Merges the state of `other` into `self`.
    ///
    /// The inner state of `other` is cloned because the underlying merge
    /// consumes its argument.
    fn merge(&mut self, other: &Self) {
        let incoming = other.inner.clone();
        self.inner.merge(incoming);
    }
}
/// Counter CRDT that wraps [`crdts::pncounter::PNCounter`] and identifies actors by `u64`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PnCounter {
    /// Underlying `crdts` counter state.
    inner: crdts::pncounter::PNCounter<u64>,
}
impl PnCounter {
    /// Creates a counter with no recorded increments or decrements.
    pub fn new() -> Self {
        let inner = crdts::pncounter::PNCounter::new();
        Self { inner }
    }

    /// Increments the counter by one on behalf of `actor`.
    pub fn increment(&mut self, actor: u64) {
        let op = self.inner.inc(actor);
        self.inner.apply(op);
    }

    /// Decrements the counter by one on behalf of `actor`.
    pub fn decrement(&mut self, actor: u64) {
        let op = self.inner.dec(actor);
        self.inner.apply(op);
    }

    /// Returns the current total, or `0` if it cannot be represented as an `i64`.
    pub fn value(&self) -> i64 {
        let total = self.inner.read();
        total.to_i64().unwrap_or_default()
    }
}
impl Default for PnCounter {
fn default() -> Self {
Self::new()
}
}
impl CrdtValue for PnCounter {
    /// Merges the state of `other` into `self`.
    ///
    /// The inner state of `other` is cloned because the underlying merge
    /// consumes its argument.
    fn merge(&mut self, other: &Self) {
        let incoming = other.inner.clone();
        self.inner.merge(incoming);
    }
}
/// Register CRDT that wraps [`crdts::mvreg::MVReg`] and identifies actors by `u64`.
///
/// A read may yield more than one value; see [`MvRegister::values`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MvRegister<T: Clone + Debug + Send + Sync + 'static> {
    /// Underlying `MVReg` state.
    inner: crdts::mvreg::MVReg<T, u64>,
}
impl<T> MvRegister<T>
where
    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Creates a register that holds no value.
    pub fn new() -> Self {
        let inner = crdts::mvreg::MVReg::new();
        Self { inner }
    }

    /// Writes `value` into the register on behalf of `actor`.
    ///
    /// The write is derived from the register's current read context and
    /// applied to the local state immediately.
    pub fn set(&mut self, value: T, actor: u64) {
        let write_ctx = self.inner.read_ctx().derive_add_ctx(actor);
        let op = self.inner.write(value, write_ctx);
        self.inner.apply(op);
    }

    /// Returns every value the register currently holds.
    pub fn values(&self) -> Vec<T> {
        self.inner.read().val
    }

    /// Returns the first value reported by the register, or `None` if it holds none.
    ///
    /// When several values are present only the first one listed is returned;
    /// use [`MvRegister::values`] to see all of them.
    pub fn value(&self) -> Option<T> {
        self.values().into_iter().next()
    }
}
impl<T> Default for MvRegister<T>
where
T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<T> CrdtValue for MvRegister<T>
where
    T: Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Merges the state of `other` into `self`.
    ///
    /// The inner state of `other` is cloned because the underlying merge
    /// consumes its argument.
    fn merge(&mut self, other: &Self) {
        let incoming = other.inner.clone();
        self.inner.merge(incoming);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two replicas add overlapping elements; after merging, the union is
    /// visible and a subsequent removal takes effect.
    #[test]
    fn test_orset_add_remove() {
        let mut set1: OrSet<String> = OrSet::new();
        let mut set2: OrSet<String> = OrSet::new();
        set1.add("a".to_string(), 1);
        set1.add("b".to_string(), 1);
        set2.add("b".to_string(), 2);
        set2.add("c".to_string(), 2);
        set1.merge(&set2);
        assert!(set1.contains(&"a".to_string()));
        assert!(set1.contains(&"b".to_string()));
        assert!(set1.contains(&"c".to_string()));
        assert_eq!(set1.len(), 3);

        // Exercise the removal path that the test name promises.
        set1.remove(&"a".to_string(), 1);
        assert!(!set1.contains(&"a".to_string()));
        assert!(set1.contains(&"b".to_string()));
        assert!(set1.contains(&"c".to_string()));
        assert_eq!(set1.len(), 2);
    }

    /// Increments from distinct actors add up after a merge.
    #[test]
    fn test_gcounter_merge() {
        let mut counter1 = GCounter::new();
        let mut counter2 = GCounter::new();
        counter1.increment(1);
        counter1.increment(1);
        counter2.increment(2);
        counter2.increment(2);
        counter2.increment(2);
        counter1.merge(&counter2);
        assert_eq!(counter1.value(), 5);
    }

    /// `increment_by` advances the counter by exactly the requested amount.
    #[test]
    fn test_gcounter_increment_by() {
        let mut counter = GCounter::new();
        counter.increment_by(1, 0);
        assert_eq!(counter.value(), 0);
        counter.increment_by(1, 3);
        counter.increment(1);
        assert_eq!(counter.value(), 4);
    }

    /// Increments and decrements from distinct actors combine after a merge.
    #[test]
    fn test_pncounter_merge() {
        let mut counter1 = PnCounter::new();
        let mut counter2 = PnCounter::new();
        counter1.increment(1);
        counter1.increment(1);
        counter1.increment(1);
        counter2.increment(2);
        counter2.increment(2);
        counter2.decrement(2);
        counter1.merge(&counter2);
        assert_eq!(counter1.value(), 4);
    }

    /// Writes made independently on two replicas must both survive a merge.
    #[test]
    fn test_mvregister_concurrent_writes() {
        let mut reg1: MvRegister<String> = MvRegister::new();
        let mut reg2: MvRegister<String> = MvRegister::new();
        reg1.set("value1".to_string(), 1);
        reg2.set("value2".to_string(), 2);
        reg1.merge(&reg2);
        let values = reg1.values();
        // Neither write has seen the other, so both values are retained.
        assert_eq!(values.len(), 2);
        assert!(values.contains(&"value1".to_string()));
        assert!(values.contains(&"value2".to_string()));
    }
}