use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::hash::Hash;
#[cfg(debug_assertions)]
mod reentry {
    //! Debug-build detector for closures that re-enter the storage they were
    //! called from.
    //!
    //! Every [`ReentryFlag`] is identified by its own address. A thread-local
    //! table records, per flag, how many read holds the current thread has
    //! open and whether it has a write hold open. Opening a hold that
    //! conflicts with one already open on the same thread panics.

    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::marker::PhantomData;

    std::thread_local! {
        /// Holds currently open on this thread, keyed by flag address.
        static LOCAL_VALUES_HOLDS: RefCell<HashMap<usize, LocalValuesState>> =
            RefCell::new(HashMap::new());
    }

    /// What the current thread holds on one flag.
    #[derive(Clone, Copy, Default)]
    struct LocalValuesState {
        read_depth: usize,
        write_held: bool,
    }

    impl LocalValuesState {
        /// True when neither a read nor a write hold is open.
        fn is_idle(&self) -> bool {
            self.read_depth == 0 && !self.write_held
        }
    }

    /// Runs `update` on the table entry for `key`, creating a default entry
    /// first when the key has none.
    fn open_hold<F>(key: usize, update: F)
    where
        F: FnOnce(&mut LocalValuesState),
    {
        LOCAL_VALUES_HOLDS.with(|table| {
            let mut table = table.borrow_mut();
            update(table.entry(key).or_default());
        });
    }

    /// Runs `update` on the table entry for `key` if there is one, then
    /// forgets the entry once no hold remains on it.
    fn close_hold<F>(key: usize, update: F)
    where
        F: FnOnce(&mut LocalValuesState),
    {
        LOCAL_VALUES_HOLDS.with(|table| {
            let mut table = table.borrow_mut();
            let now_idle = match table.get_mut(&key) {
                Some(state) => {
                    update(state);
                    state.is_idle()
                }
                None => return,
            };
            if now_idle {
                table.remove(&key);
            }
        });
    }

    /// Marker whose address serves as the key into the thread-local table.
    ///
    /// The single `u8` field gives the type a non-zero size (presumably so
    /// that two live flags never share an address).
    pub(super) struct ReentryFlag {
        _address: u8,
    }

    impl ReentryFlag {
        /// Creates a flag with no holds recorded against it.
        pub(super) fn new() -> Self {
            ReentryFlag { _address: 0 }
        }

        /// Opens a read hold for the current thread.
        ///
        /// Read holds nest. The hold is closed when the returned guard drops.
        ///
        /// # Panics
        ///
        /// Panics if the current thread already has a write hold on this flag.
        pub(super) fn acquire_values_read(&self) -> ReentryReadGuard<'_> {
            let key = self.key();
            open_hold(key, |state| {
                assert!(
                    !state.write_held,
                    "openpit::storage: closure re-entered the same storage: \
                     Storage::with called from inside a Storage::with_mut \
                     closure on the same storage; a write hold conflicts \
                     with any other hold on the values domain"
                );
                state.read_depth += 1;
            });
            ReentryReadGuard {
                key,
                _marker: PhantomData,
            }
        }

        /// Opens the write hold for the current thread.
        ///
        /// The hold is closed when the returned guard drops.
        ///
        /// # Panics
        ///
        /// Panics if the current thread already has any read hold, or a write
        /// hold, on this flag. The read conflict is checked first.
        pub(super) fn acquire_values_write(&self) -> ReentryWriteGuard<'_> {
            let key = self.key();
            open_hold(key, |state| {
                assert!(
                    state.read_depth == 0,
                    "openpit::storage: closure re-entered the same storage: \
                     Storage::with_mut called from inside a Storage::with \
                     closure on the same storage; a write hold conflicts \
                     with any read hold on the values domain"
                );
                assert!(
                    !state.write_held,
                    "openpit::storage: closure re-entered the same storage: \
                     Storage::with_mut called from inside another \
                     Storage::with_mut closure on the same storage; write \
                     holds do not stack"
                );
                state.write_held = true;
            });
            ReentryWriteGuard {
                key,
                _marker: PhantomData,
            }
        }

        /// Address of this flag, used as its identity in the table.
        fn key(&self) -> usize {
            self as *const Self as usize
        }
    }

    /// Closes one read hold when dropped.
    pub(super) struct ReentryReadGuard<'a> {
        key: usize,
        // Ties the guard to a borrow of the flag it was opened on.
        _marker: PhantomData<&'a ReentryFlag>,
    }

    impl Drop for ReentryReadGuard<'_> {
        fn drop(&mut self) {
            close_hold(self.key, |state| state.read_depth -= 1);
        }
    }

    /// Closes the write hold when dropped.
    pub(super) struct ReentryWriteGuard<'a> {
        key: usize,
        // Ties the guard to a borrow of the flag it was opened on.
        _marker: PhantomData<&'a ReentryFlag>,
    }

    impl Drop for ReentryWriteGuard<'_> {
        fn drop(&mut self) {
            close_hold(self.key, |state| state.write_held = false);
        }
    }
}
/// Keyed store whose values are handed to caller-supplied closures instead of
/// being returned by reference.
///
/// All methods take `&self`; interior mutability comes from the `UnsafeCell`s
/// below, and the accessors take guards from `locking_policy` before touching
/// them.
pub struct Storage<Key, Value, LockingPolicy>
where
LockingPolicy: super::policy::LockingPolicy,
{
/// Key -> value cell. The map itself sits in an `UnsafeCell`, and every value
/// lives in its own heap-allocated `UnsafeCell` (presumably so a value keeps
/// its address while the map is modified -- confirm with the author).
data: UnsafeCell<HashMap<Key, Box<UnsafeCell<Value>>>>,
/// Source of the index guards and per-key value guards taken by the accessors.
locking_policy: LockingPolicy,
/// Debug-build-only marker used to detect a closure re-entering this storage
/// on the same thread.
#[cfg(debug_assertions)]
reentry: reentry::ReentryFlag,
}
// SAFETY: NOTE(review) -- rationale reconstructed from the field types, please
// confirm. `Storage` owns its keys, its boxed values and its policy, and holds
// no references or thread-affine state in release builds, so it is sent only
// when each of those parts is itself `Send`. The debug-only re-entrancy marker
// is a plain byte; its thread-local bookkeeping only exists while a guard
// borrows the storage, during which the storage cannot be moved.
unsafe impl<Key, Value, LockingPolicy> Send for Storage<Key, Value, LockingPolicy>
where
LockingPolicy: super::policy::LockingPolicy + Send,
Key: Send,
Value: Send,
{
}
// SAFETY: NOTE(review) -- rationale reconstructed, please confirm against
// `super::policy`. Sharing is offered only for policies carrying the
// `FullySynchronized` marker; presumably that marker promises that the index
// and values guards really exclude conflicting access from other threads,
// which is what the `UnsafeCell` accesses in this file depend on. Keys and
// values must be `Send + Sync` because any thread holding `&Storage` can read
// them, mutate them, or drop them.
unsafe impl<Key, Value, LockingPolicy> Sync for Storage<Key, Value, LockingPolicy>
where
LockingPolicy: super::policy::FullySynchronized + Sync,
Key: Send + Sync,
Value: Send + Sync,
{
}
impl<Key, Value, LockingPolicy> Storage<Key, Value, LockingPolicy>
where
LockingPolicy: super::policy::LockingPolicy,
{
pub(super) fn with_locking_policy(locking_policy: LockingPolicy) -> Self {
Self {
locking_policy,
data: UnsafeCell::new(HashMap::new()),
#[cfg(debug_assertions)]
reentry: reentry::ReentryFlag::new(),
}
}
pub(super) fn with_locking_policy_and_capacity(
locking_policy: LockingPolicy,
capacity: usize,
) -> Self {
Self {
locking_policy,
data: UnsafeCell::new(HashMap::with_capacity(capacity)),
#[cfg(debug_assertions)]
reentry: reentry::ReentryFlag::new(),
}
}
pub fn len(&self) -> usize {
let _index = self.locking_policy.read_index();
unsafe { (*self.data.get()).len() }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl<Key, Value, LockingPolicy> Storage<Key, Value, LockingPolicy>
where
    Key: Hash + Eq,
    LockingPolicy: super::policy::LockingPolicy,
{
    /// Runs `reader` on the value stored under `key`.
    ///
    /// Returns `None`, without calling `reader`, when the key is absent.
    /// Guards are taken in the order: index read guard, then values read
    /// guard for `key`; both are held while `reader` runs.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside a
    /// closure that holds a write hold on this storage.
    pub fn with<Reader, Output>(&self, key: &Key, reader: Reader) -> Option<Output>
    where
        Reader: FnOnce(&Value) -> Output,
    {
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_read();
        let _index_guard = self.locking_policy.read_index();
        // SAFETY: relies on the index read guard above keeping index writers
        // out while this shared borrow of the map is alive (the guarantee
        // itself belongs to the locking policy, defined outside this file).
        let data = unsafe { &*self.data.get() };
        let value_box = data.get(key)?;
        let value_cell: &UnsafeCell<Value> = value_box.as_ref();
        let _values_guard = self.locking_policy.read_values(key);
        // SAFETY: relies on the values read guard above excluding writers of
        // this key's value for as long as the shared reference is used.
        let value: &Value = unsafe { &*value_cell.get() };
        Some(reader(value))
    }

    /// Runs `mutator` on the value stored under `key`, if there is one.
    ///
    /// Returns `None`, without calling `mutator`, when the key is absent.
    /// Takes the index *read* guard and then the values write guard for `key`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside any
    /// closure that already holds a read or write hold on this storage.
    pub fn with_mut_if_present<Mutator, Output>(
        &self,
        key: &Key,
        mutator: Mutator,
    ) -> Option<Output>
    where
        Mutator: FnOnce(&mut Value) -> Output,
    {
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        let _index_guard = self.locking_policy.read_index();
        // SAFETY: relies on the index read guard above keeping index writers
        // out while the map is borrowed.
        let data = unsafe { &*self.data.get() };
        let value_box = data.get(key)?;
        let value_cell: &UnsafeCell<Value> = value_box.as_ref();
        let _values_guard = self.locking_policy.write_values(key);
        // SAFETY: relies on the values write guard above granting exclusive
        // access to this key's value.
        let value: &mut Value = unsafe { &mut *value_cell.get() };
        Some(mutator(value))
    }

    /// Same as [`Storage::with_mut_if_present`], except that the index
    /// *write* guard is held for the duration of the call.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside any
    /// closure that already holds a read or write hold on this storage.
    pub(crate) fn with_mut_if_present_exclusive_index<Mutator, Output>(
        &self,
        key: &Key,
        mutator: Mutator,
    ) -> Option<Output>
    where
        Mutator: FnOnce(&mut Value) -> Output,
    {
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        let _index_guard = self.locking_policy.write_index();
        // SAFETY: relies on the index write guard above; the map is only read
        // through this shared borrow.
        let data = unsafe { &*self.data.get() };
        let value_box = data.get(key)?;
        let value_cell: &UnsafeCell<Value> = value_box.as_ref();
        let _values_guard = self.locking_policy.write_values(key);
        // SAFETY: relies on the values write guard above granting exclusive
        // access to this key's value.
        let value: &mut Value = unsafe { &mut *value_cell.get() };
        Some(mutator(value))
    }

    /// Runs `mutator` on the value under `key`, inserting `default()` first
    /// when the key is absent.
    ///
    /// The `bool` passed to `mutator` is `true` only when this call inserted
    /// the value, and `false` when the entry already existed.
    ///
    /// The lookup is first attempted under the index read guard. Only when
    /// the key is missing is that guard released and the index write guard
    /// taken; the entry is then looked up again, so an entry that appeared in
    /// between is used as-is and `default` is not called.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside any
    /// closure that already holds a read or write hold on this storage.
    pub fn with_mut<Mutator, Output, Initializer>(
        &self,
        key: Key,
        default: Initializer,
        mutator: Mutator,
    ) -> Output
    where
        Mutator: FnOnce(&mut Value, bool) -> Output,
        Initializer: FnOnce() -> Value,
        Key: Clone,
    {
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        {
            // Fast path: the entry exists, so the index is only read.
            let index_guard = self.locking_policy.read_index();
            // SAFETY: relies on the index read guard above keeping index
            // writers out while the map is borrowed.
            let data = unsafe { &*self.data.get() };
            if let Some(value_box) = data.get(&key) {
                let value_cell: &UnsafeCell<Value> = value_box.as_ref();
                let _values_guard = self.locking_policy.write_values(&key);
                // SAFETY: relies on the values write guard above granting
                // exclusive access to this key's value.
                let value: &mut Value = unsafe { &mut *value_cell.get() };
                let result = mutator(value, false);
                // The index guard is released explicitly here; the values
                // guard is released afterwards, when it leaves scope.
                drop(index_guard);
                return result;
            }
            // Key missing: the index read guard ends with this scope, before
            // the write guard below is requested.
        }
        let _index_guard = self.locking_policy.write_index();
        // SAFETY: relies on the index write guard above granting exclusive
        // access to the map.
        let data = unsafe { &mut *self.data.get() };
        use std::collections::hash_map::Entry;
        // `entry` consumes `key`; keep a copy for the values guard.
        let lookup_key = key.clone();
        let (value_cell, already_present): (&UnsafeCell<Value>, bool) = match data.entry(key) {
            Entry::Occupied(entry) => (&**entry.into_mut(), true),
            Entry::Vacant(entry) => (&**entry.insert(Box::new(UnsafeCell::new(default()))), false),
        };
        let _values_guard = self.locking_policy.write_values(&lookup_key);
        // SAFETY: relies on the values write guard above granting exclusive
        // access to this key's value.
        let value: &mut Value = unsafe { &mut *value_cell.get() };
        mutator(value, !already_present)
    }

    /// Fallible variant of [`Storage::with_mut`]: a value inserted by this
    /// call is removed again when `mutator` returns `Err`.
    ///
    /// Equivalent to [`Storage::with_mut_or_insert_prune_new_if`] with a
    /// predicate that never asks for pruning.
    pub fn with_mut_or_insert<Mutator, Output, Error, Initializer>(
        &self,
        key: Key,
        default: Initializer,
        mutator: Mutator,
    ) -> Result<Output, Error>
    where
        Mutator: FnOnce(&mut Value, bool) -> Result<Output, Error>,
        Initializer: FnOnce() -> Value,
        Key: Clone,
    {
        self.with_mut_or_insert_prune_new_if(key, default, |_| false, mutator)
    }

    /// Runs `mutator` on the value under `key`, inserting `default()` first
    /// when the key is absent, and undoes that insertion when it should not
    /// be kept.
    ///
    /// The `bool` passed to `mutator` is `true` only when this call inserted
    /// the value. A value inserted by this call is removed again before
    /// returning when `mutator` returned `Err`, or when it returned `Ok` and
    /// `prune_if` returns `true` for the value. Entries that already existed
    /// are never removed here, and `prune_if` is not called for them.
    ///
    /// `mutator`'s result is returned unchanged.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside any
    /// closure that already holds a read or write hold on this storage.
    pub fn with_mut_or_insert_prune_new_if<Mutator, Output, Error, Initializer, Pred>(
        &self,
        key: Key,
        default: Initializer,
        prune_if: Pred,
        mutator: Mutator,
    ) -> Result<Output, Error>
    where
        Mutator: FnOnce(&mut Value, bool) -> Result<Output, Error>,
        Initializer: FnOnce() -> Value,
        Pred: FnOnce(&Value) -> bool,
        Key: Clone,
    {
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        {
            // Fast path: the entry exists, so the index is only read and
            // nothing can need pruning.
            let index_guard = self.locking_policy.read_index();
            // SAFETY: relies on the index read guard above keeping index
            // writers out while the map is borrowed.
            let data = unsafe { &*self.data.get() };
            if let Some(value_box) = data.get(&key) {
                let value_cell: &UnsafeCell<Value> = value_box.as_ref();
                let _values_guard = self.locking_policy.write_values(&key);
                let result = {
                    // SAFETY: relies on the values write guard above granting
                    // exclusive access to this key's value.
                    let value: &mut Value = unsafe { &mut *value_cell.get() };
                    mutator(value, false)
                };
                drop(index_guard);
                return result;
            }
        }
        let _index_guard = self.locking_policy.write_index();
        // SAFETY: relies on the index write guard above granting exclusive
        // access to the map.
        let data = unsafe { &mut *self.data.get() };
        use std::collections::hash_map::Entry;
        // `entry` consumes `key`; keep a copy for the values guard and for a
        // possible removal.
        let lookup_key = key.clone();
        let (value_cell, already_present): (&UnsafeCell<Value>, bool) = match data.entry(key) {
            Entry::Occupied(entry) => (&**entry.into_mut(), true),
            Entry::Vacant(entry) => (&**entry.insert(Box::new(UnsafeCell::new(default()))), false),
        };
        let _values_guard = self.locking_policy.write_values(&lookup_key);
        let result = {
            // SAFETY: relies on the values write guard above granting
            // exclusive access to this key's value; the mutable borrow ends
            // with this block.
            let value: &mut Value = unsafe { &mut *value_cell.get() };
            mutator(value, !already_present)
        };
        if !already_present {
            let should_remove = match &result {
                Err(_) => true,
                // SAFETY: the mutable borrow above has ended and the values
                // write guard is still held.
                Ok(_) => prune_if(unsafe { &*value_cell.get() }),
            };
            if should_remove {
                data.remove(&lookup_key);
            }
        }
        result
    }

    /// Removes the entry stored under `key`.
    ///
    /// Returns `true` when an entry was removed and `false` when the key was
    /// absent. Takes the index write guard and then the values write guard
    /// for `key`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside a
    /// closure that already holds a read or write hold on this storage:
    /// removing an entry there could drop a value the enclosing closure still
    /// borrows. The panic text comes from the shared re-entrancy detector and
    /// therefore names `Storage::with` / `Storage::with_mut`.
    pub fn remove(&self, key: &Key) -> bool {
        // Removal is a write on the values domain, so it registers the same
        // debug hold as the other mutating entry points.
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        let _index_guard = self.locking_policy.write_index();
        let _values_guard = self.locking_policy.write_values(key);
        // SAFETY: relies on the index write guard above granting exclusive
        // access to the map.
        let data = unsafe { &mut *self.data.get() };
        data.remove(key).is_some()
    }

    /// Removes the entry stored under `key` when `predicate` returns `true`
    /// for its value.
    ///
    /// Returns `true` when an entry was removed. When the key is absent,
    /// `predicate` is not called and `false` is returned. Takes the index
    /// write guard and then the values write guard for `key`; both are held
    /// while `predicate` runs.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when called on the same thread from inside a
    /// closure that already holds a read or write hold on this storage, and
    /// when `predicate` itself re-enters this storage through a hold-taking
    /// method. The panic text comes from the shared re-entrancy detector and
    /// therefore names `Storage::with` / `Storage::with_mut`.
    pub fn remove_if<Pred>(&self, key: &Key, predicate: Pred) -> bool
    where
        Pred: FnOnce(&Value) -> bool,
    {
        // `predicate` is a user closure run over a live `&Value`; hold the
        // debug write hold so re-entry from inside it is detected.
        #[cfg(debug_assertions)]
        let _reentry = self.reentry.acquire_values_write();
        let _index_guard = self.locking_policy.write_index();
        let _values_guard = self.locking_policy.write_values(key);
        // SAFETY: relies on the index write guard above granting exclusive
        // access to the map.
        let data = unsafe { &mut *self.data.get() };
        match data.get(key) {
            // SAFETY: relies on the values write guard above; the value is
            // only read, and the borrow ends before the entry is removed.
            Some(value_box) if predicate(unsafe { &*value_box.get() }) => {
                data.remove(key);
                true
            }
            _ => false,
        }
    }
}