use crate::lattice::Lattice;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
/// Identifier for a single write: the id of a replica paired with a sequence number.
///
/// The derived ordering compares `replica_id` first and `seq` second, which is the
/// order in which dots are stored inside the `BTreeSet`/`BTreeMap` containers below.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Dot {
/// Id of the replica the dot is attributed to.
pub replica_id: String,
/// Sequence number carried by the dot.
pub seq: u64,
}
impl Dot {
pub fn new(replica_id: impl Into<String>, seq: u64) -> Self {
Self {
replica_id: replica_id.into(),
seq,
}
}
}
/// A set of [`Dot`]s.
///
/// Dots are only ever added (see `add_dot` and `join`); nothing in this type removes them.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CausalContext {
/// Every dot recorded so far, kept in the derived `Dot` ordering.
dots: std::collections::BTreeSet<Dot>,
}
impl CausalContext {
pub fn new() -> Self {
Self {
dots: std::collections::BTreeSet::new(),
}
}
pub fn add_dot(&mut self, dot: Dot) {
self.dots.insert(dot);
}
pub fn contains(&self, dot: &Dot) -> bool {
self.dots.contains(dot)
}
pub fn join(&self, other: &CausalContext) -> CausalContext {
let mut joined = self.clone();
for dot in &other.dots {
joined.add_dot(dot.clone());
}
joined
}
}
impl Default for CausalContext {
fn default() -> Self {
Self::new()
}
}
/// A value that can be stored under a key in [`CRDTMap`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MapValue {
/// A signed 64-bit integer.
Int(i64),
/// An owned string.
Text(String),
/// An owned byte buffer.
Bytes(Vec<u8>),
}
/// A map whose values are tagged with the [`Dot`] of the write that produced them.
///
/// A key may hold several values at once (one per dot); a key whose inner map is
/// empty is treated as absent by `contains_key` and `keys`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CRDTMap<K: Ord + Clone> {
/// For each key, the values currently stored, indexed by the dot of the write.
entries: BTreeMap<K, BTreeMap<Dot, MapValue>>,
/// Every dot this map has recorded through `put`, `put_with_dot` or `join`.
context: CausalContext,
/// Sequence number the next `put` will use; not part of the serialized form.
local_seq: u64,
}
impl<K: Ord + Clone + Serialize> Serialize for CRDTMap<K> {
    /// Serializes the map as a struct with two fields: `entries`, a sequence of
    /// `(key, [(dot, value), ...])` pairs in key order, and `context`.
    ///
    /// `local_seq` is not written.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Borrowed mirror of `CRDTMap`; nested maps are flattened into vectors of
        // pairs so that non-string keys work in formats such as JSON.
        #[derive(Serialize)]
        struct SerializableCRDTMap<'a, K: Ord + Clone + Serialize> {
            entries: Vec<(&'a K, Vec<(&'a Dot, &'a MapValue)>)>,
            context: &'a CausalContext,
        }

        let mut entries = Vec::with_capacity(self.entries.len());
        for (key, dotted_values) in &self.entries {
            let pairs: Vec<(&Dot, &MapValue)> = dotted_values.iter().collect();
            entries.push((key, pairs));
        }

        SerializableCRDTMap {
            entries,
            context: &self.context,
        }
        .serialize(serializer)
    }
}
impl<'de, K: Ord + Clone + Deserialize<'de>> Deserialize<'de> for CRDTMap<K> {
    /// Rebuilds a map from the form written by the `Serialize` impl: a struct with
    /// an `entries` sequence of `(key, [(dot, value), ...])` pairs and a `context`.
    ///
    /// `local_seq` is not part of the serialized form. It is restored to one past
    /// the highest sequence number found in the context or the entries (0 when
    /// there are no dots), so that a later `put` cannot mint a dot that the
    /// deserialized state already contains.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Owned mirror of the serialized layout.
        #[derive(Deserialize)]
        struct DeserializableCRDTMap<K: Ord + Clone> {
            entries: Vec<(K, Vec<(Dot, MapValue)>)>,
            context: CausalContext,
        }

        let deserialized = DeserializableCRDTMap::<K>::deserialize(deserializer)?;
        let entries: BTreeMap<K, BTreeMap<Dot, MapValue>> = deserialized
            .entries
            .into_iter()
            .map(|(k, v)| (k, v.into_iter().collect()))
            .collect();

        // Resetting the counter to 0 would let the next `put` reuse a dot that is
        // already recorded. Scan both the context and the entries, because the
        // input is external data and the two are not guaranteed to agree.
        let local_seq = deserialized
            .context
            .dots
            .iter()
            .chain(entries.values().flat_map(|dotted| dotted.keys()))
            .map(|dot| dot.seq)
            .max()
            .map_or(0, |highest| highest.saturating_add(1));

        Ok(Self {
            entries,
            context: deserialized.context,
            local_seq,
        })
    }
}
impl<K: Ord + Clone> CRDTMap<K> {
pub fn new() -> Self {
Self {
entries: BTreeMap::new(),
context: CausalContext::new(),
local_seq: 0,
}
}
pub fn put(&mut self, replica_id: &str, key: K, value: MapValue) -> Dot {
let dot = Dot::new(replica_id, self.local_seq);
self.local_seq += 1;
let entry = self.entries.entry(key).or_default();
entry.clear();
entry.insert(dot.clone(), value);
self.context.add_dot(dot.clone());
dot
}
pub fn get(&self, key: &K) -> Option<&MapValue> {
self.entries
.get(key)
.and_then(|entry| entry.values().next())
}
pub fn get_all(&self, key: &K) -> Vec<&MapValue> {
self.entries
.get(key)
.map(|entry| entry.values().collect())
.unwrap_or_default()
}
pub fn remove(&mut self, key: &K) {
if let Some(entry) = self.entries.get_mut(key) {
entry.clear();
}
}
pub fn contains_key(&self, key: &K) -> bool {
self.entries
.get(key)
.map(|entry| !entry.is_empty())
.unwrap_or(false)
}
pub fn keys(&self) -> impl Iterator<Item = &K> {
self.entries
.iter()
.filter_map(|(k, v)| if !v.is_empty() { Some(k) } else { None })
}
pub fn context(&self) -> &CausalContext {
&self.context
}
pub fn put_with_dot(&mut self, key: K, dot: Dot, value: MapValue) {
let entry = self.entries.entry(key).or_default();
entry.insert(dot.clone(), value);
self.context.add_dot(dot);
}
}
impl<K: Ord + Clone> Default for CRDTMap<K> {
fn default() -> Self {
Self::new()
}
}
impl<K: Ord + Clone> Lattice for CRDTMap<K> {
fn bottom() -> Self {
Self::new()
}
fn join(&self, other: &Self) -> Self {
let mut entries = self.entries.clone();
for (key, other_entry) in &other.entries {
let entry = entries.entry(key.clone()).or_default();
for (dot, value) in other_entry.iter() {
entry.insert(dot.clone(), value.clone());
}
}
Self {
entries,
context: self.context.join(&other.context),
local_seq: self.local_seq.max(other.local_seq),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the owned `String` keys used throughout these tests.
    fn key(name: &str) -> String {
        name.to_string()
    }

    /// `put` followed by `get` returns the stored value for each key.
    #[test]
    fn test_map_basic_operations() {
        let mut map: CRDTMap<String> = CRDTMap::new();

        map.put("replica1", key("key1"), MapValue::Int(42));
        assert_eq!(map.get(&key("key1")), Some(&MapValue::Int(42)));

        let hello = MapValue::Text("hello".to_string());
        map.put("replica1", key("key2"), hello.clone());
        assert_eq!(map.get(&key("key2")), Some(&hello));
    }

    /// A removed key is no longer reported by `contains_key`.
    #[test]
    fn test_map_remove() {
        let mut map: CRDTMap<String> = CRDTMap::new();
        let target = key("key1");

        map.put("replica1", target.clone(), MapValue::Int(42));
        assert!(map.contains_key(&target));

        map.remove(&target);
        assert!(!map.contains_key(&target));
    }

    /// Joining a map with itself keeps its content.
    #[test]
    fn test_map_join_idempotent() {
        let mut map1: CRDTMap<String> = CRDTMap::new();
        map1.put("replica1", key("key1"), MapValue::Int(42));

        let joined = map1.join(&map1);
        assert_eq!(joined.get(&key("key1")), Some(&MapValue::Int(42)));
    }

    /// Both join orders expose the same values.
    #[test]
    fn test_map_join_commutative() {
        let mut map1: CRDTMap<String> = CRDTMap::new();
        map1.put("replica1", key("key1"), MapValue::Int(42));

        let world = MapValue::Text("world".to_string());
        let mut map2: CRDTMap<String> = CRDTMap::new();
        map2.put("replica2", key("key2"), world.clone());

        let joined1 = map1.join(&map2);
        let joined2 = map2.join(&map1);

        for joined in [&joined1, &joined2].iter() {
            assert_eq!(joined.get(&key("key1")), Some(&MapValue::Int(42)));
            assert_eq!(joined.get(&key("key2")), Some(&world));
        }
    }

    /// Both groupings of a three-way join expose the same values.
    #[test]
    fn test_map_join_associative() {
        let mut map1: CRDTMap<String> = CRDTMap::new();
        map1.put("replica1", key("key1"), MapValue::Int(1));
        let mut map2: CRDTMap<String> = CRDTMap::new();
        map2.put("replica2", key("key2"), MapValue::Int(2));
        let mut map3: CRDTMap<String> = CRDTMap::new();
        map3.put("replica3", key("key3"), MapValue::Int(3));

        let left = map1.join(&map2).join(&map3);
        let right = map1.join(&map2.join(&map3));

        for grouped in [&left, &right].iter() {
            assert_eq!(grouped.get(&key("key1")), Some(&MapValue::Int(1)));
            assert_eq!(grouped.get(&key("key2")), Some(&MapValue::Int(2)));
            assert_eq!(grouped.get(&key("key3")), Some(&MapValue::Int(3)));
        }
    }

    /// Writes to different keys on different replicas are all kept by a join.
    #[test]
    fn test_map_concurrent_writes_different_keys() {
        let mut map1: CRDTMap<String> = CRDTMap::new();
        map1.put("replica1", key("key1"), MapValue::Int(10));
        let mut map2: CRDTMap<String> = CRDTMap::new();
        map2.put("replica2", key("key2"), MapValue::Int(20));

        let merged = map1.join(&map2);

        assert_eq!(merged.get(&key("key1")), Some(&MapValue::Int(10)));
        assert_eq!(merged.get(&key("key2")), Some(&MapValue::Int(20)));
    }

    /// A JSON round trip keeps the stored values readable.
    #[test]
    fn test_map_serialization() {
        let hello = MapValue::Text("hello".to_string());
        let mut map: CRDTMap<String> = CRDTMap::new();
        map.put("replica1", key("key1"), MapValue::Int(42));
        map.put("replica1", key("key2"), hello.clone());

        let serialized = serde_json::to_string(&map).unwrap();
        let deserialized: CRDTMap<String> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.get(&key("key1")), Some(&MapValue::Int(42)));
        assert_eq!(deserialized.get(&key("key2")), Some(&hello));
    }
}