use super::*;
use crate::register::Register;
use std::borrow::Borrow;
use std::cmp::max;
use std::collections::HashMap;
use std::hash::Hash;
/// Per-member state of a [`Set`]: the largest tag observed plus the causal length.
///
/// The member counts as present while the causal length is odd and as removed
/// while it is even.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SubRegister<Tag: TagT, CL: CausalLength> {
    /// Largest tag recorded for this member by any add, remove, or merge.
    tag: Tag,
    /// Causal length: odd while present, even once removed.
    length: CL,
}
/// A causal-length set.
///
/// Every member ever added is tracked together with a tag and a causal length.
/// A member is considered present while its causal length is odd.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Set<T, Tag, CL>
where
    T: Key,
    Tag: TagT,
    CL: CausalLength,
{
    map: HashMap<T, SubRegister<Tag, CL>>,
}

/// An empty set.
///
/// Implemented by hand rather than derived, because `#[derive(Default)]` would
/// additionally require `T`, `Tag` and `CL` to implement `Default`, which an empty
/// map does not need.
impl<T, Tag, CL> Default for Set<T, Tag, CL>
where
    T: Key,
    Tag: TagT,
    CL: CausalLength,
{
    fn default() -> Self {
        Set {
            map: HashMap::new(),
        }
    }
}
impl<T, Tag, CL> Set<T, Tag, CL>
where
    T: Key,
    Tag: TagT,
    CL: CausalLength,
{
    /// Creates an empty set.
    pub fn new() -> Set<T, Tag, CL> {
        Set {
            map: HashMap::new(),
        }
    }

    /// Returns the stored tag of `member` if it is currently present.
    ///
    /// Returns `None` both for members that were never tracked and for members whose
    /// causal length is even (i.e. removed).
    pub fn get<Q>(&self, member: Q) -> Option<Tag>
    where
        Q: Borrow<T>,
    {
        self.map
            .get(member.borrow())
            .filter(|e| e.length.is_odd())
            .map(|e| e.tag)
    }

    /// Returns `true` if `member` is currently present (odd causal length).
    pub fn contains<Q>(&self, member: Q) -> bool
    where
        Q: Borrow<T>,
    {
        self.get(member).is_some()
    }

    /// Adds `member`, recording `tag`.
    ///
    /// The member's state changes as follows:
    /// - An untracked member starts with causal length one.
    /// - A removed member (even length) is advanced to the next odd length.
    /// - An already-present member keeps its length.
    ///
    /// In every case the stored tag becomes the larger of the stored tag and `tag`.
    pub fn add(&mut self, member: T, tag: Tag) {
        let one: CL = CL::one();
        let e = self
            .map
            .entry(member)
            .or_insert(SubRegister { tag, length: one });
        if e.length.is_even() {
            e.length = e.length + one;
        }
        e.tag = max(e.tag, tag);
    }

    /// Removes `member`, recording `tag`.
    ///
    /// Only members already tracked are affected. A present member (odd length) is
    /// advanced to the next even length, and the stored tag becomes the larger of the
    /// stored tag and `tag`. Removing an untracked member is a no-op.
    pub fn remove(&mut self, member: T, tag: Tag) {
        self.map.entry(member).and_modify(|e| {
            if e.length.is_odd() {
                e.length = e.length + CL::one()
            }
            e.tag = max(e.tag, tag);
        });
    }

    /// Iterates over the present members and their tags, in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&T, Tag)> + '_ {
        self.map
            .iter()
            .filter(|(_k, v)| v.length.is_odd())
            .map(|(k, v)| (k, v.tag))
    }

    /// Iterates over every tracked member as a [`Register`] delta.
    ///
    /// Removed members (even causal length) are included. Each key is cloned.
    pub fn register_iter(&self) -> impl Iterator<Item = Register<T, Tag, CL>> + '_ {
        self.map.iter().map(|(k, v)| Register {
            item: k.clone(),
            tag: v.tag,
            length: v.length,
        })
    }

    /// Merges a single register delta into the set.
    ///
    /// A delta describing a removed member (even length) whose tag is older than
    /// `min_tag` is ignored. Otherwise, the stored tag and the stored causal length
    /// each become the maximum of the stored and incoming values. Taking a
    /// per-field maximum makes the result independent of the order in which deltas
    /// are merged.
    pub fn merge_register(&mut self, delta: Register<T, Tag, CL>, min_tag: Tag) {
        if delta.length.is_even() && delta.tag < min_tag {
            return;
        }
        let Register { item, tag, length } = delta;
        match self.map.entry(item) {
            Entry::Occupied(mut e) => {
                let e = e.get_mut();
                e.tag = max(e.tag, tag);
                e.length = max(e.length, length);
            }
            Entry::Vacant(e) => {
                e.insert(SubRegister { tag, length });
            }
        }
    }

    /// Merges every tracked member of `other`, including removed ones, into `self`.
    ///
    /// Each member goes through [`Set::merge_register`] with the same `min_tag`.
    pub fn merge(&mut self, other: &Self, min_tag: Tag) {
        for delta in other.register_iter() {
            self.merge_register(delta, min_tag);
        }
    }

    /// Discards removed members (even causal length) whose tag is older than
    /// `min_tag`. Present members are always kept.
    pub fn retain(&mut self, min_tag: Tag) {
        // Use the same cut-off as `merge_register`: a tombstone survives iff its tag is
        // at least `min_tag`. The previous strict `min_tag < tag` test dropped
        // tombstones tagged exactly `min_tag`, which a later `merge` with the same
        // `min_tag` would accept again, resurrecting the garbage-collected entry.
        self.map
            .retain(|_k, SubRegister { tag, length }| length.is_odd() || min_tag <= *tag);
    }
}
#[cfg(feature = "serialization")]
mod serialization {
    use super::*;
    use serde::de::{SeqAccess, Visitor};
    use serde::ser::SerializeSeq;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::fmt::Formatter;
    use std::marker::PhantomData;

    /// Cap on how many map entries are pre-allocated from a deserializer's size hint.
    ///
    /// Some formats report a sequence length read straight from the input, which may
    /// be untrusted. Trusting it blindly would let a tiny payload request an enormous
    /// up-front allocation. Past this cap the map simply grows as elements arrive.
    const MAX_PREALLOCATED_ENTRIES: usize = 4096;

    /// Serializes the set as a sequence of `(member, tag, causal length)` tuples.
    ///
    /// The sequence has one tuple per tracked member, including removed members
    /// (even causal length).
    impl<T, Tag, CL> Serialize for Set<T, Tag, CL>
    where
        T: Key + Serialize,
        Tag: TagT + Serialize,
        CL: CausalLength + Serialize,
    {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut seq = serializer.serialize_seq(Some(self.map.len()))?;
            // Serialize keys by reference: `&T` serializes exactly like `T`, so the
            // output is unchanged, but no key is cloned.
            for (item, reg) in &self.map {
                seq.serialize_element(&(item, reg.tag, reg.length))?;
            }
            seq.end()
        }
    }

    /// Serde visitor that rebuilds the member map from `(member, tag, causal length)`
    /// tuples.
    struct DeltaVisitor<T, Tag, CL>(PhantomData<T>, PhantomData<Tag>, PhantomData<CL>);

    impl<'de, T, Tag, CL> Visitor<'de> for DeltaVisitor<T, Tag, CL>
    where
        T: Key + Deserialize<'de>,
        Tag: TagT + Deserialize<'de>,
        CL: CausalLength + Deserialize<'de>,
    {
        type Value = HashMap<T, SubRegister<Tag, CL>>;

        fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
            formatter.write_str("a sequence of (key, tag, causal length) tuples")
        }

        fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let capacity = seq
                .size_hint()
                .unwrap_or(0)
                .min(MAX_PREALLOCATED_ENTRIES);
            let mut map: HashMap<T, SubRegister<Tag, CL>> = HashMap::with_capacity(capacity);
            // If the input repeats a member, the later tuple overwrites the earlier one.
            while let Some((item, tag, length)) = seq.next_element::<(T, Tag, CL)>()? {
                map.insert(item, SubRegister { tag, length });
            }
            Ok(map)
        }
    }

    impl<'de, T, Tag, CL> Deserialize<'de> for Set<T, Tag, CL>
    where
        T: Eq + Hash + Clone + Deserialize<'de>,
        Tag: TagT + Deserialize<'de>,
        CL: CausalLength + Deserialize<'de>,
    {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = DeltaVisitor::<T, Tag, CL>(PhantomData, PhantomData, PhantomData);
            let map = deserializer.deserialize_seq(visitor)?;
            Ok(Set { map })
        }
    }
}
#[cfg(feature = "serialization")]
pub use serialization::*;
use std::collections::hash_map::Entry;
#[cfg(test)]
mod tests {
    //! Unit tests, merge-law property tests, and a model-based test against `HashSet`.
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;
    use rand::seq::SliceRandom;

    #[test]
    fn test_add() {
        let later_time = 1;
        let mut cls: Set<&str, u32, u16> = Set::new();
        // Re-adding a present member must not advance its causal length.
        cls.add("foo", later_time);
        cls.add("foo", later_time);
        cls.add("foo", later_time);
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get("foo"),
            Some(&SubRegister {
                tag: later_time,
                length: 1
            })
        );
        assert!(cls.contains("foo"));
        assert_eq!(cls.get("bar"), None);
    }

    #[test]
    fn test_remove() {
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls: Set<&str, u32, u16> = Set::new();
        cls.add("foo", time_1);
        cls.add("bar", time_1);
        cls.remove("foo", time_2);
        cls.remove("bar", time_2);
        cls.add("bar", time_3);
        assert_eq!(cls.map.len(), 2);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&SubRegister {
                tag: time_3,
                length: 3
            })
        );
        assert_eq!(
            cls.map.get(&"foo"),
            Some(&SubRegister {
                tag: time_2,
                length: 2
            })
        );
        let values: Vec<(&&str, u32)> = cls.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], (&"bar", time_3));
    }

    #[test]
    fn test_remove_untracked_is_noop() {
        let mut cls: Set<&str, u32, u16> = Set::new();
        // Removing a member that was never added must not create a tombstone.
        cls.remove("foo", 1);
        assert!(cls.map.is_empty());
        assert!(!cls.contains("foo"));
    }

    #[test]
    fn test_merge() {
        let time_0 = 0;
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls1: Set<&str, u32, u16> = Set::new();
        let mut cls2: Set<&str, u32, u16> = Set::new();
        cls1.add("foo", time_1);
        cls1.add("bar", time_1);
        cls2.merge(&cls1, time_0);
        cls2.remove("foo", time_2);
        cls1.remove("bar", time_2);
        cls1.remove("bar", time_2);
        cls2.merge(&cls1, time_0);
        cls2.add("bar", time_3);
        assert_eq!(cls2.map.len(), 2);
        assert_eq!(
            cls2.map.get(&"bar"),
            Some(&SubRegister {
                tag: time_3,
                length: 3
            })
        );
        assert_eq!(
            cls2.map.get(&"foo"),
            Some(&SubRegister {
                tag: time_2,
                length: 2
            })
        );
        let values: Vec<(&&str, u32)> = cls2.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], (&"bar", time_3));
    }

    #[test]
    fn test_retain() {
        let time_0 = 0;
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls: Set<&str, u32, u16> = Set::new();
        cls.add("foo", time_0);
        cls.add("bar", time_0);
        cls.remove("foo", time_1);
        cls.remove("bar", time_1);
        cls.add("bar", time_2);
        assert_eq!(cls.map.len(), 2);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&SubRegister {
                tag: time_2,
                length: 3
            })
        );
        assert_eq!(
            cls.map.get(&"foo"),
            Some(&SubRegister {
                tag: time_1,
                length: 2
            })
        );
        let values: Vec<(&&str, u32)> = cls.iter().collect();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], (&"bar", time_2));
        cls.retain(time_3);
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&SubRegister {
                tag: time_2,
                length: 3
            })
        );
        // A stale tombstone delta must not override the newer, present state.
        cls.merge_register(
            Register {
                item: "bar",
                tag: time_2,
                length: 2,
            },
            time_0,
        );
        assert_eq!(cls.map.len(), 1);
        assert_eq!(
            cls.map.get(&"bar"),
            Some(&SubRegister {
                tag: time_2,
                length: 3
            })
        );
    }

    #[cfg(feature = "serialization")]
    #[test]
    fn test_serialization() {
        let time_1 = 1;
        let time_2 = 2;
        let time_3 = 3;
        let mut cls: Set<&str, u32, u16> = Set::new();
        cls.add("foo", time_1);
        cls.add("bar", time_1);
        cls.remove("foo", time_2);
        cls.remove("bar", time_2);
        cls.add("bar", time_3);
        let data = serde_json::to_vec(&cls).unwrap();
        let cls2: Set<&str, u32, u16> = serde_json::from_slice(&data).unwrap();
        assert_eq!(cls.map, cls2.map);
    }

    #[test]
    fn test_order_independence() {
        let mut m: Set<&str, u32, u16> = Set::new();
        let mut v: Vec<Register<&str, u32, u16>> = (0..1000u16)
            .map(|i| Register {
                item: "foo",
                tag: u32::from(i),
                length: i,
            })
            .collect();
        v.shuffle(&mut rand::thread_rng());
        for r in v {
            m.merge_register(r, 0);
        }
        assert_eq!(
            m.map.get("foo"),
            Some(&SubRegister {
                tag: 999,
                length: 999
            })
        );
    }

    /// Fold step used by the property tests: merges one register into the accumulator.
    fn merge(mut acc: Set<u8, u8, u8>, el: &Register<u8, u8, u8>) -> Set<u8, u8, u8> {
        acc.merge_register(el.clone(), 0);
        acc
    }

    #[quickcheck]
    fn is_merge_commutative(xs: Vec<Register<u8, u8, u8>>) -> bool {
        let left = xs.iter().fold(Set::default(), merge);
        let right = xs.iter().rfold(Set::default(), merge);
        left == right
    }

    #[quickcheck]
    fn is_merge_order_independent(xs: Vec<Register<u8, u8, u8>>) -> bool {
        let mut copy = xs.clone();
        copy.shuffle(&mut rand::thread_rng());
        let left = xs.iter().fold(Set::default(), merge);
        let right = copy.iter().rfold(Set::default(), merge);
        left == right
    }

    /// A single operation for the model-based test.
    #[derive(Clone, Debug)]
    enum Op {
        Insert(u8),
        Get(u8),
        Delete(u8),
    }

    /// Keys are drawn from `0..KEY_SPACE` so that operations frequently collide.
    const KEY_SPACE: u8 = 20;

    impl Arbitrary for Op {
        fn arbitrary(g: &mut Gen) -> Op {
            let k: u8 = u8::arbitrary(g) % KEY_SPACE;
            // Gets are twice as likely as inserts or deletes.
            match u8::arbitrary(g) % 4 {
                0 => Op::Insert(k),
                1 => Op::Delete(k),
                _ => Op::Get(k),
            }
        }
    }

    #[quickcheck]
    fn implementation_matches_model(ops: Vec<Op>) -> bool {
        let mut implementation: Set<u8, u8, u8> = Set::new();
        let mut model = std::collections::HashSet::new();
        for op in ops {
            match op {
                Op::Insert(k) => {
                    implementation.add(k, 0);
                    model.insert(k);
                }
                Op::Get(k) => {
                    if implementation.contains(k) != model.contains(&k) {
                        return false;
                    }
                }
                Op::Delete(k) => {
                    implementation.remove(k, 0);
                    model.remove(&k);
                }
            }
        }
        true
    }
}