use super::{
Either, NoExtensionTypes, TypeVariantValue, Value, ValueRef,
mvreg::MvRegValue,
orarray::Uid,
snapshot::{self, AllValues, CollapsedValue, SingleValueError, SingleValueIssue, ToValue},
};
use crate::{
CausalContext, CausalDotStore, DotMap, DotStoreJoin, ExtensionType, Identifier, MvReg, OrArray,
dotstores::{DotChange, DotStore, DryJoinOutput},
sentinel::{KeySentinel, TypeSentinel, ValueSentinel, Visit},
};
use std::{borrow::Borrow, fmt, hash::Hash, ops::Index};
/// An observed-remove map from keys `K` to [`TypeVariantValue`]s.
///
/// Every entry can carry a nested map, an array, a multi-value register and a custom
/// extension value (of type `C`) side by side. The typed `apply_to_*` methods write one of
/// these variants and put the dots of the others into the delta's causal context.
///
/// This is a newtype over [`DotMap`]; dot-store operations are delegated to the inner map.
#[derive(Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(::serde::Deserialize, ::serde::Serialize))]
pub struct OrMap<K: Hash + Eq, C = NoExtensionTypes>(pub(super) DotMap<K, TypeVariantValue<C>>);
impl<K, C> Default for OrMap<K, C>
where
K: Hash + Eq,
{
fn default() -> Self {
Self(Default::default())
}
}
/// Debug output is exactly that of the wrapped [`DotMap`].
impl<K, C> fmt::Debug for OrMap<K, C>
where
    K: Hash + Eq + fmt::Debug,
    C: fmt::Debug + ExtensionType,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.0;
        fmt::Debug::fmt(inner, f)
    }
}
impl<K, C> FromIterator<(K, TypeVariantValue<C>)> for OrMap<K, C>
where
K: Eq + Hash,
{
fn from_iter<T: IntoIterator<Item = (K, TypeVariantValue<C>)>>(iter: T) -> Self {
Self(DotMap::from_iter(iter))
}
}
/// Indexes the map by (a borrowed form of) its key.
///
/// Lookup semantics, including what happens for a missing key, are those of the inner
/// [`DotMap`]'s `Index` implementation.
impl<K, Q, C> Index<&Q> for OrMap<K, C>
where
    K: Eq + Hash + Borrow<Q>,
    Q: Eq + Hash + ?Sized,
{
    type Output = TypeVariantValue<C>;

    fn index(&self, key: &Q) -> &TypeVariantValue<C> {
        &self.0[key]
    }
}
impl<K, C> DotStore for OrMap<K, C>
where
K: Hash + Eq + fmt::Debug + Clone,
C: ExtensionType,
{
fn dots(&self) -> CausalContext {
self.0.dots()
}
fn add_dots_to(&self, other: &mut CausalContext) {
self.0.add_dots_to(other);
}
fn is_bottom(&self) -> bool {
self.0.is_bottom()
}
fn subset_for_inflation_from(&self, frontier: &CausalContext) -> Self {
Self(DotMap::subset_for_inflation_from(&self.0, frontier))
}
}
impl<K, C, S> DotStoreJoin<S> for OrMap<K, C>
where
K: Hash + Eq + fmt::Debug + Clone,
C: ExtensionType + DotStoreJoin<S> + fmt::Debug + Clone + PartialEq,
S: Visit<K>
+ Visit<String>
+ Visit<Uid>
+ KeySentinel
+ TypeSentinel<C::ValueKind>
+ ValueSentinel<MvRegValue>,
{
fn join(
(m1, cc1): (Self, &CausalContext),
(m2, cc2): (Self, &CausalContext),
on_dot_change: &mut dyn FnMut(DotChange),
sentinel: &mut S,
) -> Result<Self, S::Error>
where
Self: Sized,
S: KeySentinel,
{
Ok(Self(DotMap::join(
(m1.0, cc1),
(m2.0, cc2),
on_dot_change,
sentinel,
)?))
}
fn dry_join(
(m1, cc1): (&Self, &CausalContext),
(m2, cc2): (&Self, &CausalContext),
sentinel: &mut S,
) -> Result<DryJoinOutput, S::Error>
where
Self: Sized,
S: KeySentinel,
{
DotMap::dry_join((&m1.0, cc1), (&m2.0, cc2), sentinel)
}
}
/// Snapshot conversion for a borrowed [`OrMap`].
///
/// Each entry is first coerced to a single type variant (`coerce_to_value_ref`). That
/// variant is then turned into either all of its concurrent values (`values`) or one
/// collapsed value (`value`).
impl<'doc, K, C> ToValue for &'doc OrMap<K, C>
where
    K: Hash + Eq + fmt::Display,
    C: ExtensionType,
{
    type Values = snapshot::OrMap<'doc, K, AllValues<'doc, C::ValueRef<'doc>>>;
    type Value = snapshot::OrMap<'doc, K, CollapsedValue<'doc, C::ValueRef<'doc>>>;
    type LeafValue = Either<MvRegValue, <C::ValueRef<'doc> as ToValue>::LeafValue>;

    /// Returns every entry with all of its concurrent values; never fails.
    fn values(self) -> Self::Values {
        let mut ret_map = snapshot::OrMap::default();
        for (key, inner_map) in self.0.iter() {
            let v = match inner_map.coerce_to_value_ref() {
                ValueRef::Map(m) => AllValues::Map(m.values()),
                ValueRef::Array(a) => AllValues::Array(a.values()),
                ValueRef::Register(r) => AllValues::Register(r.values()),
                ValueRef::Custom(c) => AllValues::Custom(c.values()),
            };
            ret_map.map.insert(key.borrow(), v);
        }
        ret_map
    }

    /// Collapses every entry to a single value.
    ///
    /// Registers whose issue is [`SingleValueIssue::Cleared`] are left out of the result
    /// instead of failing.
    ///
    /// # Errors
    ///
    /// Returns the first other failure from any entry. This entry's key is appended to the
    /// error's `path` whichever variant failed (map, array, register or custom). An error
    /// raised deep inside nested collections therefore carries every enclosing key, in
    /// leaf-to-root order.
    fn value(self) -> Result<Self::Value, Box<SingleValueError<Self::LeafValue>>> {
        let mut ret_map = snapshot::OrMap::default();
        for (key, inner_map) in self.0.iter() {
            let entry = match inner_map.coerce_to_value_ref() {
                ValueRef::Map(m) => m.value().map(CollapsedValue::Map).map(Some),
                ValueRef::Array(a) => a.value().map(CollapsedValue::Array).map(Some),
                ValueRef::Register(r) => match r.value() {
                    Ok(v) => Ok(Some(CollapsedValue::Register(v))),
                    // A cleared register has nothing to show; omit the key entirely.
                    Err(e) if e.issue == SingleValueIssue::Cleared => Ok(None),
                    Err(e) => Err(e.map_values(Either::Left)),
                },
                ValueRef::Custom(c) => c
                    .value()
                    .map(CollapsedValue::Custom)
                    .map(Some)
                    .map_err(|e| e.map_values(Either::Right)),
            };
            // Record where the failure happened. Previously only register failures got the
            // key, so errors from nested maps/arrays/custom values lost this path segment.
            let entry = entry.map_err(|mut e| {
                e.path.push(key.to_string());
                e
            })?;
            if let Some(v) = entry {
                ret_map.map.insert(key.borrow(), v);
            }
        }
        Ok(ret_map)
    }
}
/// Direct (non-delta) accessors and mutators, all forwarded to the inner [`DotMap`].
impl<K: Hash + Eq, C> OrMap<K, C> {
    /// Borrows the underlying [`DotMap`].
    pub fn inner(&self) -> &DotMap<K, TypeVariantValue<C>> {
        &self.0
    }

    /// Number of entries, as reported by the inner [`DotMap`].
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Whether the inner [`DotMap`] reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Looks up the value stored under `key`.
    pub fn get<Q: Hash + Eq + ?Sized>(&self, key: &Q) -> Option<&TypeVariantValue<C>>
    where
        K: Borrow<Q>,
    {
        self.inner().get(key)
    }

    /// Mutably borrows the value stored under `key`.
    ///
    /// Forwards to the [`DotMap`] method of the same name; presumably the
    /// `_and_invalidate` suffix means cached state in the inner map is discarded —
    /// confirm in `DotMap`.
    pub fn get_mut_and_invalidate<Q: Hash + Eq + ?Sized>(
        &mut self,
        key: &Q,
    ) -> Option<&mut TypeVariantValue<C>>
    where
        K: Borrow<Q>,
    {
        self.0.get_mut_and_invalidate(key)
    }

    /// Stores `value` under `key` in place; no delta is produced.
    #[doc(hidden)]
    pub fn insert(&mut self, key: K, value: TypeVariantValue<C>) {
        self.0.insert(key, value);
    }

    /// Iterates mutably over all entries via the inner [`DotMap`]'s
    /// `iter_mut_and_invalidate`.
    pub fn iter_mut_and_invalidate(
        &mut self,
    ) -> impl ExactSizeIterator<Item = (&K, &mut TypeVariantValue<C>)> {
        self.0.iter_mut_and_invalidate()
    }

    /// Keeps only the entries for which `f` returns `true`, via the inner [`DotMap`]'s
    /// `retain_and_invalidate`.
    pub fn retain_and_invalidate(&mut self, f: impl FnMut(&K, &mut TypeVariantValue<C>) -> bool) {
        self.0.retain_and_invalidate(f)
    }
}
// Generates one `apply_to_*` method on `OrMap`. The method edits a single type variant of
// the value stored under a key.
//
// - `$name`: name of the generated method.
// - `$frag`: string literal naming the targeted variant; spliced into the method's docs.
// - `$field`: the `TypeVariantValue` field handed to the caller's closure.
// - `[$others]`: the remaining `TypeVariantValue` fields. Their dots at the key are added
//   to the returned causal context.
// - `$innerType`: the CRDT type held in `$field` and produced by the closure.
macro_rules! apply_to_X {
    ($name:ident, $frag:literal, $field:ident, [$($others:ident),*], $innerType:ty) => {
        // Previously the generated docs were only `$frag` (e.g. "an [`OrMap`]"), which did
        // not describe the method at all.
        #[doc = concat!(
            "Applies `o` to the value under key `k`, treated as ",
            $frag,
            ", and returns the resulting delta.\n\n",
            "`o` is called with the current value of that variant (taken from a default ",
            "entry if `k` is absent), `cc` and `id`. Its result is stored under `k` in the ",
            "returned store.\n\n",
            "Dots held at `k` by the other type variants are added to the returned causal ",
            "context, so joining the delta drops those variants. This is the same mechanism ",
            "[`remove`](Self::remove) relies on."
        )]
        pub fn $name<'data, O>(
            &'data self,
            o: O,
            k: K,
            cc: &'_ CausalContext,
            id: Identifier,
        ) -> CausalDotStore<Self>
        where
            O: for<'cc, 'v> FnOnce(
                &'v $innerType,
                &'cc CausalContext,
                Identifier,
            ) -> CausalDotStore<$innerType>,
        {
            let CausalDotStore {
                store: ret_map,
                context: mut ret_cc,
            } = self.apply(
                move |m, cc, id| o(&m.$field, cc, id).map_store(Value::from),
                k.clone(),
                cc,
                id,
            );
            // Cover the other variants' existing dots so a join discards them in favour of
            // the variant just written.
            if let Some(inner) = self.0.get(&k) {
                $( inner.$others.add_dots_to(&mut ret_cc); )*
            }
            CausalDotStore {
                store: ret_map,
                context: ret_cc,
            }
        }
    };
}
impl<K, C> OrMap<K, C>
where
K: Hash + Eq + fmt::Debug + Clone,
C: ExtensionType,
{
pub fn create(&self, _cc: &CausalContext, _id: Identifier) -> CausalDotStore<Self> {
CausalDotStore {
store: Self(Default::default()),
context: CausalContext::default(),
}
}
apply_to_X!(
apply_to_map,
"an [`OrMap`]",
map,
[array, reg, custom],
OrMap<String, C>
);
apply_to_X!(
apply_to_array,
"an [`OrArray`]",
array,
[map, reg, custom],
OrArray<C>
);
apply_to_X!(
apply_to_register,
"an [`MvReg`]",
reg,
[map, array, custom],
MvReg
);
pub fn apply_to_custom<'data, O>(
&'data self,
o: O,
k: K,
cc: &'_ CausalContext,
id: Identifier,
) -> CausalDotStore<Self>
where
O: for<'cc, 'v> FnOnce(&'v C, &'cc CausalContext, Identifier) -> CausalDotStore<C::Value>,
{
let CausalDotStore {
store: ret_map,
context: mut ret_cc,
} = self.apply(
move |m, cc, id| {
let y = o(&m.custom, cc, id);
y.map_store(Value::Custom)
},
k.clone(),
cc,
id,
);
if let Some(inner) = self.0.get(&k) {
inner.map.add_dots_to(&mut ret_cc);
inner.array.add_dots_to(&mut ret_cc);
inner.reg.add_dots_to(&mut ret_cc);
}
CausalDotStore {
store: ret_map,
context: ret_cc,
}
}
pub fn apply<'data, O>(
&'data self,
o: O,
key: K,
cc: &'_ CausalContext,
id: Identifier,
) -> CausalDotStore<Self>
where
O: for<'cc, 'v> FnOnce(
&'v TypeVariantValue<C>,
&'cc CausalContext,
Identifier,
) -> CausalDotStore<Value<C>>,
{
let mut ret_dot_map = Self::default();
let v = if let Some(v) = self.get(&key) {
v
} else {
&TypeVariantValue::default()
};
let CausalDotStore {
store: new_v,
context: ret_cc,
} = o(v, cc, id);
ret_dot_map.0.set(key, new_v.into());
CausalDotStore {
store: ret_dot_map,
context: ret_cc,
}
}
pub fn remove<Q>(&self, k: &Q, _cc: &CausalContext, _id: Identifier) -> CausalDotStore<Self>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let Some(inner_map) = self.0.get(k) else {
return CausalDotStore::new();
};
let ret_cc = inner_map.dots();
CausalDotStore {
store: Self(Default::default()),
context: ret_cc,
}
}
pub fn clear(&self, _cc: &CausalContext, _id: Identifier) -> CausalDotStore<Self> {
let ret_cc = self.dots();
CausalDotStore {
store: Self(Default::default()),
context: ret_cc,
}
}
pub fn remove_immediately<Q>(&mut self, k: &Q) -> Option<TypeVariantValue<C>>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.0.remove(k)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
crdts::{
NoExtensionTypes,
test_util::{Ops, join_harness},
},
sentinel::{DummySentinel, test::ValueCountingValidator},
};
use std::collections::BTreeMap;
type OrMap<K> = super::OrMap<K, NoExtensionTypes>;
#[test]
fn empty() {
let cds = CausalDotStore::<OrMap<String>>::default();
assert!(cds.store.is_bottom());
assert!(cds.store.value().unwrap().is_empty());
assert_eq!(cds.store.values().len(), 0);
}
#[test]
fn created_is_bottom() {
let map = OrMap::<String>::default();
let cc = CausalContext::new();
let id = Identifier::new(0, 0);
let m = map.create(&cc, id);
assert!(m.store.is_bottom());
assert_eq!(map, m.store);
}
#[test]
fn cleared_is_bottom() {
let map = OrMap::<String>::default();
let cc = CausalContext::new();
let id = Identifier::new(0, 0);
let m = map.create(&cc, id);
let m = m.store.clear(&m.context, id);
assert!(m.store.is_bottom());
}
#[test]
fn set_get_remove() {
let map = OrMap::<String>::default();
let cc = CausalContext::new();
let id = Identifier::new(0, 0);
let m = map.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
);
assert!(!m.store.is_bottom());
assert_eq!(
m.store.value().unwrap().get(&String::from("foo")).cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(true)))
);
assert_eq!(m.store.len(), 1);
assert_eq!(
m.context.next_dot_for(id).sequence().get() - 1,
1
);
let m = m.store.remove("foo", &cc, id);
assert!(m.store.is_bottom()); assert_eq!(m.store.value().unwrap().get(&String::from("foo")), None);
assert_eq!(m.store.len(), 0);
assert_eq!(m.context.next_dot_for(id).sequence().get() - 1, 1);
}
#[test]
fn set_one_key_then_another() {
let map = CausalDotStore::<OrMap<String>>::new();
let id = Identifier::new(0, 0);
let delta = map.store.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"true".into(),
&map.context,
id,
);
assert!(!delta.store.is_bottom());
assert_eq!(
delta
.store
.value()
.unwrap()
.get(&String::from("true"))
.cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(true)))
);
assert_eq!(delta.store.len(), 1);
let map = map.join(delta, &mut DummySentinel).unwrap();
let delta = map.store.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(false), cc, id),
"false".into(),
&map.context,
id,
);
assert!(!delta.store.is_bottom());
assert_eq!(
delta
.store
.value()
.unwrap()
.get(&String::from("false"))
.cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(false)))
);
assert_eq!(delta.store.len(), 1);
let map = map.join(delta, &mut DummySentinel).unwrap();
assert!(!map.store.is_bottom());
assert_eq!(
map.store
.value()
.unwrap()
.get(&String::from("true"))
.cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(true)))
);
assert_eq!(
map.store
.value()
.unwrap()
.get(&String::from("false"))
.cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(false)))
);
assert_eq!(map.store.len(), 2);
}
#[test]
fn independent_keys() {
join_harness(
OrMap::<String>::default(),
|cds, _| cds,
|m, cc, id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::U64(42), cc, id),
"bar".into(),
&cc,
id,
)
},
DummySentinel,
|CausalDotStore { store: m, .. }, _| {
assert!(!m.is_bottom());
assert_eq!(
m.value().unwrap().get(&String::from("foo")).cloned(),
Some(CollapsedValue::Register(&MvRegValue::Bool(true)))
);
assert_eq!(
m.value().unwrap().get(&String::from("bar")).cloned(),
Some(CollapsedValue::Register(&MvRegValue::U64(42)))
);
},
);
}
#[test]
fn conflicting_reg_value() {
join_harness(
OrMap::<String>::default(),
|cds, _| cds,
|m, cc, id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::U64(42), cc, id),
"foo".into(),
&cc,
id,
)
},
ValueCountingValidator::default(),
|CausalDotStore { store: m, .. }, sentinel| {
assert!(!m.is_bottom());
let values = m.values();
let AllValues::Register(v) = values.get(&String::from("foo")).unwrap() else {
panic!("foo isn't a register even though we only wrote registers");
};
assert_eq!(v.len(), 2);
assert!(v.contains(&MvRegValue::Bool(true)));
assert!(v.contains(&MvRegValue::U64(42)));
assert_eq!(sentinel.added, BTreeMap::from([(MvRegValue::U64(42), 1)]));
assert!(sentinel.removed.is_empty());
},
);
}
#[test]
fn concurrent_clear() {
join_harness(
OrMap::<String>::default(),
|CausalDotStore {
store: m,
context: cc,
},
id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| m.clear(&cc, id),
|m, cc, id| m.clear(&cc, id),
DummySentinel,
|CausalDotStore { store: m, .. }, _| {
assert!(m.is_bottom());
let values = m.values();
assert_eq!(values.len(), 0);
},
);
}
#[test]
fn remove_reg_value() {
join_harness(
OrMap::<String>::default(),
|CausalDotStore {
store: m,
context: cc,
},
id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| m.clear(&cc, id),
|m, cc, _| CausalDotStore {
store: m.clone(),
context: cc,
},
ValueCountingValidator::new(true),
|CausalDotStore { store: m, .. }, sentinel| {
assert!(m.is_bottom());
let values = m.values();
assert_eq!(values.get(&String::from("foo")), None);
assert!(sentinel.added.is_empty());
assert!(sentinel.removed.is_empty());
},
);
join_harness(
OrMap::<String>::default(),
|CausalDotStore {
store: m,
context: cc,
},
id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, _| CausalDotStore {
store: m.clone(),
context: cc,
},
|m, cc, id| m.clear(&cc, id),
ValueCountingValidator::new(true),
|CausalDotStore { store: m, .. }, sentinel| {
assert!(m.is_bottom());
let values = m.values();
assert_eq!(values.get(&String::from("foo")), None);
assert!(sentinel.added.is_empty());
assert_eq!(
sentinel.removed,
BTreeMap::from([(MvRegValue::Bool(true), 1)])
);
},
);
}
#[test]
fn update_vs_remove() {
join_harness(
OrMap::<String>::default(),
|CausalDotStore {
store: m,
context: cc,
},
id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::U64(42), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.remove("foo", &cc, id)
},
ValueCountingValidator::default(),
|CausalDotStore { store: m, .. }, sentinel| {
assert!(!m.is_bottom());
let values = m.values();
let AllValues::Register(v) = values.get(&String::from("foo")).unwrap() else {
panic!("foo isn't a register even though we only wrote registers");
};
assert_eq!(v, [MvRegValue::Bool(true)]);
assert!(sentinel.added.is_empty());
assert!(sentinel.removed.is_empty());
},
);
}
#[test]
fn nested_update_vs_remove() {
join_harness(
OrMap::<String>::default(),
|CausalDotStore {
store: m,
context: cc,
},
id| {
m.apply_to_map(
|_old, cc, id| {
OrMap::default().apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::U64(42), cc, id),
"bar".into(),
cc,
id,
)
},
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.apply_to_map(
|old, cc, id| {
old.apply_to_register(
|_old, cc, id| MvReg::default().write(MvRegValue::Bool(true), cc, id),
"baz".into(),
cc,
id,
)
},
"foo".into(),
&cc,
id,
)
},
|m, cc, id| {
m.remove("foo", &cc, id)
},
ValueCountingValidator::default(),
|CausalDotStore { store: m, .. }, sentinel| {
assert!(!m.is_bottom());
let values = m.values();
let AllValues::Map(m) = values.get(&String::from("foo")).unwrap() else {
panic!("foo isn't a map even though we only wrote map");
};
assert_eq!(values.len(), 1);
let AllValues::Register(r) = m
.get(&String::from("baz"))
.expect("baz key isn't preserved")
else {
panic!("baz isn't a register though we only wrote a register ")
};
assert_eq!(m.len(), 1);
assert_eq!(r, [MvRegValue::Bool(true)]);
assert!(sentinel.added.is_empty());
assert!(sentinel.removed.is_empty());
},
);
}
#[quickcheck]
fn order_invariant(ops: Ops<OrMap<String>>, seed: u64) -> quickcheck::TestResult {
ops.check_order_invariance(seed)
}
}