use crate::{merkle::MerkleTree, CrdtEntry, CrdtEvent, CrdtKv, CrdtValue};
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::DashMap;
use parking_lot::RwLock;
use pollen_clock::SharedClock;
use pollen_store::StoreBackend;
use pollen_types::{NodeId, Result};
use std::sync::Arc;
use tokio::sync::broadcast;
use tracing::info;
/// Key/value store holding the latest [`CrdtEntry`] per key in memory, mirrored to a
/// [`StoreBackend`] and summarised by a [`MerkleTree`].
pub struct CrdtStore {
    /// Identifier of the owning node. Stored at construction but not read by any code in
    /// this file, hence the lint allowance.
    #[allow(dead_code)]
    node_id: NodeId,
    /// Clock used to timestamp local `set` and `delete` operations.
    clock: SharedClock,
    /// Backend that snapshots are read from and written to.
    store: Arc<StoreBackend>,
    /// Latest known entry per key; tombstones are kept here as entries with `deleted` set.
    state: DashMap<String, CrdtEntry>,
    /// Merkle tree updated alongside `state` whenever an entry is applied.
    merkle: RwLock<MerkleTree>,
    /// Channel on which every applied change is published.
    event_tx: broadcast::Sender<CrdtEvent>,
    /// Per-prefix channels, keyed by the prefix string they were subscribed with.
    prefix_subscribers: DashMap<String, broadcast::Sender<CrdtEvent>>,
}
impl CrdtStore {
    /// Creates an empty store bound to the given node, clock and storage backend.
    ///
    /// The global event channel is created with a capacity of 1000 events. Nothing is read
    /// from `store` until [`CrdtStore::load`] is called.
    pub fn new(node_id: NodeId, clock: SharedClock, store: Arc<StoreBackend>) -> Self {
        // The initial receiver is dropped; consumers obtain their own via `subscribe`.
        let (event_tx, _) = broadcast::channel(1000);
        Self {
            node_id,
            clock,
            store,
            state: DashMap::new(),
            merkle: RwLock::new(MerkleTree::new()),
            event_tx,
            prefix_subscribers: DashMap::new(),
        }
    }

    /// Populates the in-memory state and the Merkle tree from persisted snapshots.
    ///
    /// Snapshots that cannot be decoded are skipped. Tombstones are restored into the
    /// in-memory state but kept out of the Merkle tree, the same rule `apply_delta` uses,
    /// so the tree does not depend on whether an entry arrived live or from disk.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the storage backend.
    pub async fn load(&self) -> Result<()> {
        // NOTE(review): this key listing is a placeholder that always yields an empty
        // vector, so the loop below never runs until a real enumeration is wired in.
        let keys = self.store.read(|_r| Ok(Vec::<String>::new())).await?;

        for key in keys {
            let lookup_key = key.clone();
            let snapshot = self
                .store
                .read(move |r| r.get_crdt_snapshot(&lookup_key))
                .await?;
            if let Some(data) = snapshot {
                // Undecodable snapshots are skipped rather than aborting the whole load.
                if let Ok(entry) = bincode::deserialize::<CrdtEntry>(&data) {
                    if entry.deleted {
                        self.merkle.write().remove(&key);
                    } else {
                        self.merkle.write().insert(&key, &entry.data);
                    }
                    self.state.insert(key, entry);
                }
            }
        }

        info!("Loaded {} CRDT entries from storage", self.state.len());
        Ok(())
    }

    /// Returns `true` when `incoming` must replace `existing`.
    ///
    /// A strictly newer timestamp wins; on equal timestamps a tombstone wins over a live
    /// entry.
    ///
    /// NOTE(review): two live entries with the same timestamp but different data never
    /// replace each other, so replicas that saw them in different orders keep different
    /// values. A deterministic tie-break would be needed to converge in that case.
    fn supersedes(incoming: &CrdtEntry, existing: &CrdtEntry) -> bool {
        incoming.timestamp > existing.timestamp
            || (incoming.timestamp == existing.timestamp
                && incoming.deleted
                && !existing.deleted)
    }

    /// Applies `entry` if it supersedes what is currently stored for its key.
    ///
    /// Returns `Ok(true)` when the entry was applied and `Ok(false)` when it was ignored
    /// because the stored entry is at least as recent. When applied, the entry is written
    /// to the in-memory state and Merkle tree, persisted, and then announced on the global
    /// channel and on every prefix channel whose prefix matches the key.
    ///
    /// # Errors
    ///
    /// Returns an error if the entry cannot be serialized or persisted. The in-memory
    /// state and Merkle tree have already been updated at that point, and no event is
    /// emitted.
    pub async fn apply_delta(&self, entry: CrdtEntry) -> Result<bool> {
        let key = entry.key.clone();

        // Decide and install while holding the map's lock for this key. A separate
        // `get` followed by `insert` would let two concurrent deltas both pass the
        // check, after which the older one could overwrite the newer one.
        let mut applied = false;
        {
            let _slot = self
                .state
                .entry(key.clone())
                .and_modify(|existing| {
                    if Self::supersedes(&entry, existing) {
                        *existing = entry.clone();
                        applied = true;
                    }
                })
                .or_insert_with(|| {
                    applied = true;
                    entry.clone()
                });

            // Updating the tree while the key is still locked keeps tree updates for one
            // key in the same order as the state updates. The guard is released at the
            // end of this block, before any `.await`.
            if applied {
                if entry.deleted {
                    self.merkle.write().remove(&key);
                } else {
                    self.merkle.write().insert(&key, &entry.data);
                }
            }
        }

        if !applied {
            return Ok(false);
        }

        let data = bincode::serialize(&entry)?;
        let key_for_storage = key.clone();
        let timestamp = entry.timestamp;
        self.store
            .write(move |w| w.save_crdt_snapshot(&key_for_storage, &data, timestamp))
            .await?;

        let event = if entry.deleted {
            CrdtEvent::Deleted { key: key.clone() }
        } else {
            CrdtEvent::Updated { key: key.clone() }
        };
        // Send results are ignored: a channel without receivers is not an error here.
        let _ = self.event_tx.send(event.clone());
        for sub in self.prefix_subscribers.iter() {
            if key.starts_with(sub.key()) {
                let _ = sub.value().send(event.clone());
            }
        }

        Ok(true)
    }

    /// Returns the current root hash of the Merkle tree.
    pub fn merkle_root(&self) -> Bytes {
        self.merkle.read().root_hash()
    }

    /// Returns the hashes the Merkle tree reports for `level`.
    pub fn merkle_level(&self, level: usize) -> Vec<(String, Bytes)> {
        self.merkle.read().level_hashes(level)
    }

    /// Returns clones of all entries whose key lies in the half-open range
    /// `[start, end)` under string ordering. Tombstones are included.
    pub fn entries_in_range(&self, start: &str, end: &str) -> Vec<CrdtEntry> {
        self.state
            .iter()
            .filter(|e| e.key().as_str() >= start && e.key().as_str() < end)
            .map(|e| e.value().clone())
            .collect()
    }

    /// Returns clones of every entry in the store, tombstones included.
    pub fn all_entries(&self) -> Vec<CrdtEntry> {
        self.state.iter().map(|e| e.value().clone()).collect()
    }
}
#[async_trait]
impl CrdtKv for CrdtStore {
    /// Reads and decodes the value stored under `key`.
    ///
    /// Returns `None` when the key is absent, when its entry is a tombstone, or when the
    /// stored bytes do not decode as `T`.
    fn get<T: CrdtValue>(&self, key: &str) -> Option<T> {
        let entry = self.state.get(key)?;
        if entry.deleted {
            return None;
        }
        bincode::deserialize(&entry.data).ok()
    }

    /// Serializes `value` and applies it under `key`, stamped with the current clock time.
    ///
    /// The write goes through `apply_delta`, whose applied/ignored result is discarded:
    /// a write that loses to an already stored entry still returns `Ok(())`.
    ///
    /// NOTE(review): `as u64` truncates if the clock's 128-bit value exceeds `u64::MAX`;
    /// confirm the clock's range.
    async fn set<T: CrdtValue>(&self, key: &str, value: T) -> Result<()> {
        let now = self.clock.now();
        let payload = bincode::serialize(&value)?;
        let entry = CrdtEntry {
            key: key.to_string(),
            // The concrete Rust type name is recorded alongside the payload.
            crdt_type: std::any::type_name::<T>().to_string(),
            data: Bytes::from(payload),
            timestamp: now.as_u128() as u64,
            deleted: false,
        };
        self.apply_delta(entry).await.map(|_| ())
    }

    /// Applies a tombstone for `key`, stamped with the current clock time.
    ///
    /// As with `set`, the applied/ignored result of `apply_delta` is discarded.
    async fn delete(&self, key: &str) -> Result<()> {
        let now = self.clock.now();
        let tombstone = CrdtEntry::tombstone(key.to_string(), now.as_u128() as u64);
        self.apply_delta(tombstone).await.map(|_| ())
    }

    /// Subscribes to change events for keys starting with `prefix`.
    ///
    /// An empty prefix subscribes to the global channel. Any other prefix gets its own
    /// channel (capacity 100), created on first use and shared by later subscribers to
    /// the same prefix.
    fn subscribe(&self, prefix: &str) -> broadcast::Receiver<CrdtEvent> {
        if prefix.is_empty() {
            return self.event_tx.subscribe();
        }
        // Clone the sender out of the map so the map guard is released before subscribing.
        let sender = self
            .prefix_subscribers
            .entry(prefix.to_string())
            .or_insert_with(|| broadcast::channel(100).0)
            .clone();
        sender.subscribe()
    }

    /// Placeholder: performs no synchronization and always succeeds.
    async fn sync_with(&self, _peer: NodeId) -> Result<()> {
        Ok(())
    }

    /// Returns the keys of all entries that are not tombstones.
    fn keys(&self) -> Vec<String> {
        self.state
            .iter()
            .filter_map(|slot| {
                if slot.value().deleted {
                    None
                } else {
                    Some(slot.key().clone())
                }
            })
            .collect()
    }

    /// Returns the keys of all non-tombstone entries whose key starts with `prefix`.
    fn keys_with_prefix(&self, prefix: &str) -> Vec<String> {
        self.state
            .iter()
            .filter_map(|slot| {
                let live = !slot.value().deleted;
                if live && slot.key().starts_with(prefix) {
                    Some(slot.key().clone())
                } else {
                    None
                }
            })
            .collect()
    }
}
/// Last-writer-wins register: a value paired with the timestamp it was written at.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LwwRegister<T> {
    /// The value currently held by the register.
    value: T,
    /// Timestamp associated with `value`; compared when two registers are merged.
    timestamp: u64,
}
impl<T: Clone + Send + Sync + 'static + serde::Serialize + serde::de::DeserializeOwned> LwwRegister<T> {
pub fn new(value: T, timestamp: u64) -> Self {
Self { value, timestamp }
}
pub fn value(&self) -> &T {
&self.value
}
pub fn timestamp(&self) -> u64 {
self.timestamp
}
}
impl<T: Clone + Send + Sync + 'static + serde::Serialize + serde::de::DeserializeOwned> CrdtValue
    for LwwRegister<T>
{
    /// Adopts `other`'s value and timestamp when `other` is strictly newer.
    ///
    /// On equal or older timestamps `self` is left untouched.
    fn merge(&mut self, other: &Self) {
        if other.timestamp <= self.timestamp {
            return;
        }
        self.timestamp = other.timestamp;
        self.value = other.value.clone();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pollen_store::{MemoryStore, StoreBackend};

    /// A value written with `set` is readable with `get`, and is gone after `delete`.
    #[tokio::test]
    async fn test_crdt_store_basic() {
        let node_id = NodeId::new();
        let backend = Arc::new(StoreBackend::Memory(MemoryStore::new()));
        let crdt = CrdtStore::new(node_id, pollen_clock::new_clock_with_id(node_id), backend);

        crdt.set("test:key", LwwRegister::new("hello".to_string(), 1))
            .await
            .unwrap();
        let stored: Option<LwwRegister<String>> = crdt.get("test:key");
        let stored = stored.expect("value should be readable after set");
        assert_eq!(stored.value(), "hello");

        crdt.delete("test:key").await.unwrap();
        assert!(crdt.get::<LwwRegister<String>>("test:key").is_none());
    }
}