use std::collections::hash_map::Entry;
use std::hash::Hash;
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use crate::Track;
use crate::track::{Call, Sink};
/// Records the calls made on a tracked value.
///
/// The state lives behind a [`Mutex`] so that it can be read and updated
/// through a shared reference.
pub struct Constraint<C>(Mutex<ConstraintRepr<C>>);
/// The state stored inside a [`Constraint`]'s mutex.
struct ConstraintRepr<C> {
/// Calls for which `is_mutable()` returned `false`, each paired with the
/// `u128` recorded for its return value. Repeated calls are stored only once.
immutable: CallSequence<C>,
/// Calls for which `is_mutable()` returned `true`, in the order they were
/// emitted.
mutable: Vec<C>,
}
impl<C> Constraint<C> {
    /// Creates a constraint with no recorded calls.
    pub fn new() -> Self {
        Default::default()
    }

    /// Checks whether `value` is consistent with the recorded immutable calls.
    ///
    /// Every immutable call that is still stored is replayed on `value` via
    /// [`Track::call`]. Returns `true` only if each result equals the `u128`
    /// recorded for that call. Entries that have already been taken out of
    /// the sequence are skipped. The recorded mutable calls are not consulted.
    pub fn validate<T>(&self, value: &T) -> bool
    where
        T: Track<Call = C> + ?Sized,
    {
        let guard = self.0.lock();
        guard
            .immutable
            .vec
            .iter()
            // Iterating an `&Option` yields the payload only when present.
            .flatten()
            .all(|(call, expected)| value.call(call) == *expected)
    }

    /// Moves the recorded calls out, leaving this constraint empty.
    ///
    /// Returns the immutable call sequence and the list of mutable calls, in
    /// that order.
    pub(crate) fn take(&self) -> (CallSequence<C>, Vec<C>) {
        let mut guard = self.0.lock();
        let repr = &mut *guard;
        let immutable = std::mem::take(&mut repr.immutable);
        let mutable = std::mem::take(&mut repr.mutable);
        (immutable, mutable)
    }
}
impl<C> Default for Constraint<C> {
    /// Creates a constraint whose immutable and mutable call records are both
    /// empty.
    fn default() -> Self {
        let repr = ConstraintRepr {
            immutable: CallSequence::new(),
            mutable: Vec::new(),
        };
        Self(Mutex::new(repr))
    }
}
impl<C: Call> Sink for Constraint<C> {
    type Call = C;

    /// Records `call` together with the `u128` standing for its return value.
    ///
    /// Mutable calls are always appended to the mutable list and the method
    /// returns `true`; `ret` is not stored for them. Immutable calls go into
    /// the deduplicating sequence, and the return value is whatever
    /// `CallSequence::insert` reports: `true` if the call was newly stored,
    /// `false` if an equal-hash call was already present.
    fn emit(&self, call: C, ret: u128) -> bool {
        let mut guard = self.0.lock();
        if !call.is_mutable() {
            return guard.immutable.insert(call, ret);
        }
        guard.mutable.push(call);
        true
    }
}
/// An insertion-ordered, deduplicated list of calls, each paired with the
/// `u128` recorded for its return value.
pub struct CallSequence<C> {
/// The calls in insertion order. A slot becomes `None` once `next` or
/// `extract` has taken its entry.
vec: Vec<Option<(C, u128)>>,
/// Maps the hash of a call (as computed by `crate::hash::hash`) to the
/// index of its slot in `vec`.
map: FxHashMap<u128, usize>,
/// Position from which `next` resumes scanning. It only advances past slots
/// that were found to be `None`.
cursor: usize,
}
impl<C> CallSequence<C> {
    /// Creates an empty sequence with its cursor at the start.
    pub fn new() -> Self {
        let vec = Vec::new();
        let map = FxHashMap::default();
        Self { vec, map, cursor: 0 }
    }
}
impl<C: Hash> CallSequence<C> {
    /// Appends `call` with its recorded return value `ret` unless a call with
    /// the same hash is already known.
    ///
    /// Returns `true` if the pair was stored and `false` if it was treated as
    /// a duplicate.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if the earlier entry with the
    /// same hash is still stored and carries a different return value.
    pub fn insert(&mut self, call: C, ret: u128) -> bool {
        let key = crate::hash::hash(&call);
        let vacant = match self.map.entry(key) {
            Entry::Vacant(vacant) => vacant,
            // The binding is only read when debug assertions are enabled.
            #[allow(unused_variables)]
            Entry::Occupied(occupied) => {
                #[cfg(debug_assertions)]
                if let Some((_, recorded)) = &self.vec[*occupied.get()] {
                    if ret != *recorded {
                        panic!(
                            "comemo: found differing return values. \
                             is there an impure tracked function?"
                        )
                    }
                }
                return false;
            }
        };

        // First time this hash is seen: store the pair and remember its slot.
        let index = self.vec.len();
        self.vec.push(Some((call, ret)));
        vacant.insert(index);
        true
    }

    /// Takes the earliest entry that is still stored, in insertion order.
    ///
    /// Slots that were already emptied are skipped and the cursor moves past
    /// them. Returns `None` once no stored entry remains at or after the
    /// cursor.
    pub fn next(&mut self) -> Option<(C, u128)> {
        loop {
            // Running off the end of the vector means nothing is left.
            match self.vec.get_mut(self.cursor)?.take() {
                Some(pair) => return Some(pair),
                None => self.cursor += 1,
            }
        }
    }

    /// Takes the entry whose hash matches `call` and returns its recorded
    /// return value.
    ///
    /// Returns `None` if no call with that hash was ever inserted or if its
    /// entry has already been taken.
    pub fn extract(&mut self, call: &C) -> Option<u128> {
        let key = crate::hash::hash(&call);
        let index = match self.map.get(&key) {
            Some(&index) => index,
            None => return None,
        };
        let ret = self.vec[index].take().map(|(_, ret)| ret);
        // Anything the cursor has already passed must have been empty.
        debug_assert!(self.cursor <= index || ret.is_none());
        ret
    }
}
impl<C> Default for CallSequence<C> {
fn default() -> Self {
Self::new()
}
}
impl<C: Hash> FromIterator<(C, u128)> for CallSequence<C> {
fn from_iter<T: IntoIterator<Item = (C, u128)>>(iter: T) -> Self {
let mut seq = CallSequence::new();
for (call, ret) in iter {
seq.insert(call, ret);
}
seq
}
}