use crate::rt::{self, VersionVec};
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// A cell whose accesses are validated against the active thread's causality
/// (obtained through `crate::rt`), so that mutable and immutable accesses which
/// are not ordered with respect to each other are reported as causality
/// violations.
#[derive(Debug)]
pub struct CausalCell<T> {
/// The wrapped value; callers only ever see it as a raw pointer through the
/// `with*` accessors.
data: UnsafeCell<T>,
/// Access-tracking state. Shared via `Arc` so that `CausalCheck` handles
/// returned by the deferred accessors can refer back to it.
state: Arc<Mutex<State>>,
}
/// Handle for one or more deferred cell accesses.
///
/// It is created by `CausalCell::with_deferred` / `with_deferred_mut`, can be
/// merged with `join`, and is resolved by `check`.
///
/// It is marked `#[must_use]` because a violation recorded for these accesses
/// is only reported when `check` is called.
#[derive(Debug)]
#[must_use]
pub struct CausalCheck {
/// One `(cell state, key into that state's deferred map)` pair per pending access.
deferred: Vec<(Arc<Mutex<State>>, usize)>,
}
/// Tracking state shared between a `CausalCell` and its outstanding
/// `CausalCheck` handles.
#[derive(Debug)]
struct State {
/// Versions of the accesses recorded so far. Deferred accesses are folded in
/// only once their `CausalCheck` is resolved.
causality: Causality,
/// Pending deferred accesses, keyed by the index stored in their `CausalCheck`.
deferred: HashMap<usize, Deferred>,
/// Next key to hand out for `deferred`; incremented on every deferred access.
next_index: usize,
}
/// A deferred access awaiting resolution by `CausalCheck::check`.
#[derive(Debug)]
struct Deferred {
/// `true` for a mutable access, `false` for an immutable one.
is_mut: bool,
/// Causality of the accessing thread at the moment the access was made.
thread_causality: VersionVec,
/// Verdict so far.
///
/// It is computed when the access is made. It is later overwritten with `Err`
/// if a conflicting, unordered access is observed before resolution.
result: Result<(), String>,
}
/// Joined causality of the accesses recorded on a cell, split by access kind.
#[derive(Debug, Clone)]
struct Causality {
/// Join of the creator's causality and that of every recorded immutable access.
immut_access_version: VersionVec,
/// Join of the creator's causality and that of every recorded mutable access.
mut_access_version: VersionVec,
}
impl<T> CausalCell<T> {
    /// Creates a cell wrapping `data`.
    ///
    /// Both recorded access versions start at the creating thread's current
    /// causality. Every later access must therefore be ordered after
    /// construction.
    pub fn new(data: T) -> CausalCell<T> {
        let v = rt::execution(|execution| execution.threads.active().causality.clone());
        CausalCell {
            data: UnsafeCell::new(data),
            state: Arc::new(Mutex::new(State {
                causality: Causality {
                    immut_access_version: v.clone(),
                    mut_access_version: v,
                },
                deferred: HashMap::new(),
                next_index: 0,
            })),
        }
    }

    /// Immutably accesses the contents after validating causality.
    ///
    /// # Panics
    ///
    /// Panics if the last recorded mutable access is not ordered before the
    /// current thread (see [`CausalCell::check`]).
    pub fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const T) -> R,
    {
        rt::critical(|| {
            self.check();
            self.with_unchecked(f)
        })
    }

    /// Immutably accesses the contents, deferring the causality verdict.
    ///
    /// The access is registered immediately. Any violation is reported only
    /// when the returned [`CausalCheck`] is checked.
    ///
    /// Like [`CausalCell::with`], `f` runs after the bookkeeping is finished:
    /// outside the `rt::execution` callback and with the cell's
    /// (non-reentrant) state mutex released. So `f` may touch this or other
    /// cells without deadlocking or nesting runtime calls.
    pub fn with_deferred<F, R>(&self, f: F) -> (R, CausalCheck)
    where
        F: FnOnce(*const T) -> R,
    {
        rt::critical(|| {
            let check = self.defer_access(false);
            (self.with_unchecked(f), check)
        })
    }

    /// Mutably accesses the contents after validating causality.
    ///
    /// # Panics
    ///
    /// Panics if the last recorded mutable or immutable access is not ordered
    /// before the current thread (see [`CausalCell::check_mut`]).
    pub fn with_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut T) -> R,
    {
        rt::critical(|| {
            self.check_mut();
            self.with_mut_unchecked(f)
        })
    }

    /// Mutably accesses the contents, deferring the causality verdict.
    ///
    /// The access is registered immediately. Any violation is reported only
    /// when the returned [`CausalCheck`] is checked.
    ///
    /// Like [`CausalCell::with_mut`], `f` runs outside the `rt::execution`
    /// callback and with the cell's state mutex released.
    pub fn with_deferred_mut<F, R>(&self, f: F) -> (R, CausalCheck)
    where
        F: FnOnce(*mut T) -> R,
    {
        rt::critical(|| {
            let check = self.defer_access(true);
            (self.with_mut_unchecked(f), check)
        })
    }

    /// Calls `f` with a raw pointer to the contents, without any causality
    /// tracking.
    pub fn with_unchecked<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const T) -> R,
    {
        f(self.data.get())
    }

    /// Calls `f` with a raw mutable pointer to the contents, without any
    /// causality tracking.
    pub fn with_mut_unchecked<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut T) -> R,
    {
        f(self.data.get())
    }

    /// Validates an immutable access by the current thread and records it.
    ///
    /// It also re-checks every pending deferred access of this cell. Pending
    /// mutable accesses that are unordered with the current thread are marked
    /// as violations.
    ///
    /// # Panics
    ///
    /// Panics if the last recorded mutable access is not ordered before the
    /// current thread, or if the state mutex is poisoned.
    pub fn check(&self) {
        rt::execution(|execution| {
            let thread_causality = &execution.threads.active().causality;
            let mut state = self.state.lock().unwrap();
            state.causality.check(thread_causality).unwrap();
            state.causality.immut_access_version.join(thread_causality);
            for deferred in state.deferred.values_mut() {
                deferred.check(thread_causality);
            }
        })
    }

    /// Validates a mutable access by the current thread and records it.
    ///
    /// It also re-checks every pending deferred access of this cell. Any
    /// pending access, mutable or not, that is unordered with the current
    /// thread is marked as a violation.
    ///
    /// # Panics
    ///
    /// Panics if the last recorded mutable or immutable access is not ordered
    /// before the current thread, or if the state mutex is poisoned.
    pub fn check_mut(&self) {
        rt::execution(|execution| {
            let thread_causality = &execution.threads.active().causality;
            let mut state = self.state.lock().unwrap();
            state.causality.check_mut(thread_causality).unwrap();
            state.causality.mut_access_version.join(thread_causality);
            for deferred in state.deferred.values_mut() {
                deferred.check_mut(thread_causality);
            }
        })
    }

    /// Registers a pending access of the given kind and returns the handle
    /// that later resolves it.
    ///
    /// Shared by `with_deferred` (`is_mut == false`) and `with_deferred_mut`
    /// (`is_mut == true`). The state lock is released before this returns.
    fn defer_access(&self, is_mut: bool) -> CausalCheck {
        rt::execution(|execution| {
            let thread_causality = &execution.threads.active().causality;
            let mut state = self.state.lock().unwrap();

            // Validate only against accesses resolved so far. Conflicts with
            // accesses that are still pending are flagged later, when one of
            // them is resolved or an immediate check re-examines the map.
            let result = if is_mut {
                state.causality.check_mut(thread_causality)
            } else {
                state.causality.check(thread_causality)
            };

            let index = state.next_index;
            state.next_index += 1;
            state.deferred.insert(
                index,
                Deferred {
                    is_mut,
                    thread_causality: thread_causality.clone(),
                    result,
                },
            );

            CausalCheck {
                deferred: vec![(self.state.clone(), index)],
            }
        })
    }
}
impl CausalCheck {
pub fn check(mut self) {
for (state, index) in self.deferred.drain(..) {
let mut state = state.lock().unwrap();
let deferred = state.deferred.remove(&index).unwrap();
deferred.result.unwrap();
if deferred.is_mut {
state
.causality
.mut_access_version
.join(&deferred.thread_causality);
} else {
state
.causality
.immut_access_version
.join(&deferred.thread_causality);
}
for other in state.deferred.values_mut() {
if deferred.is_mut {
other.check_mut(&deferred.thread_causality);
} else {
other.check(&deferred.thread_causality);
}
}
}
}
pub fn join(&mut self, other: CausalCheck) {
self.deferred.extend(other.deferred.into_iter());
}
}
impl Default for CausalCheck {
    /// An empty handle tracking no deferred accesses; checking it does nothing.
    fn default() -> Self {
        let deferred = Vec::new();
        CausalCheck { deferred }
    }
}
impl Causality {
    /// Validates an immutable access by a thread whose causality is
    /// `thread_causality`.
    ///
    /// It succeeds only if the last recorded mutable access is ordered before
    /// (`<=`) the thread; otherwise it returns a diagnostic message. The access
    /// itself is not recorded here.
    fn check(&self, thread_causality: &VersionVec) -> Result<(), String> {
        let write_ordered = self.mut_access_version <= *thread_causality;
        if write_ordered {
            Ok(())
        } else {
            Err(format!(
                "Causality violation: \
                 Concurrent mutable access and immutable access(es): \
                 cell.with: v={:?}; mut v: {:?}; thread={:?}",
                self.immut_access_version, self.mut_access_version, thread_causality
            ))
        }
    }

    /// Validates a mutable access by a thread whose causality is
    /// `thread_causality`.
    ///
    /// Both the last mutable and the last immutable access must be ordered
    /// before (`<=`) the thread. A concurrent write is reported in preference
    /// to a read/write conflict. The access itself is not recorded here.
    fn check_mut(&self, thread_causality: &VersionVec) -> Result<(), String> {
        let write_ordered = self.mut_access_version <= *thread_causality;
        if !write_ordered {
            Err(format!(
                "Causality violation: \
                 Concurrent mutable accesses: \
                 cell.with_mut: v={:?}; mut v={:?}; thread={:?}",
                self.immut_access_version, self.mut_access_version, thread_causality,
            ))
        } else if !(self.immut_access_version <= *thread_causality) {
            Err(format!(
                "Causality violation: \
                 Concurrent mutable access and immutable access(es): \
                 cell.with_mut: v={:?}; mut v={:?}; thread={:?}",
                self.immut_access_version, self.mut_access_version, thread_causality,
            ))
        } else {
            Ok(())
        }
    }
}
impl Deferred {
    /// Re-evaluates this pending access against an immutable access made by a
    /// thread whose causality is `thread_causality`.
    ///
    /// Only a pending *mutable* access can conflict with a read. If the two are
    /// unordered, the stored verdict becomes an error. A verdict that is
    /// already an error is left untouched.
    fn check(&mut self, thread_causality: &VersionVec) {
        let conflicts =
            self.result.is_ok() && self.is_mut && self.unordered_with(thread_causality);
        if conflicts {
            self.result = Err(String::from(
                "Causality violation: concurrent mutable access and immutable access(es)",
            ));
        }
    }

    /// Re-evaluates this pending access against a mutable access made by a
    /// thread whose causality is `thread_causality`.
    ///
    /// Any pending access, mutable or not, conflicts with an unordered write.
    /// A verdict that is already an error is left untouched.
    fn check_mut(&mut self, thread_causality: &VersionVec) {
        if self.result.is_ok() && self.unordered_with(thread_causality) {
            self.result = Err(String::from("Causality violation: concurrent mutable accesses"));
        }
    }

    /// `true` when this access's causality and `thread_causality` are
    /// incomparable, i.e. `partial_cmp` returns `None`.
    fn unordered_with(&self, thread_causality: &VersionVec) -> bool {
        self.thread_causality
            .partial_cmp(thread_causality)
            .is_none()
    }
}