use crate::{POISONED_OBJECT_RW_LOCK, ThreadMapLockError};
use std::{
cell::UnsafeCell,
collections::HashMap,
fmt::Debug,
mem::take,
ops::DerefMut,
sync::RwLock,
thread::{self, ThreadId},
};
/// An `UnsafeCell` that may be shared between threads.
///
/// It performs no synchronization of its own: whoever owns it must serialize
/// every access to the inner value externally (e.g. behind a lock).
struct UnsafeSyncCell<V>(UnsafeCell<V>);

// SAFETY: sharing `&UnsafeSyncCell<V>` across threads lets another thread reach
// the inner `V` (by `&`, `&mut`, or by value), so `V` must be `Send` -- the same
// requirement as `Mutex<V>: Sync`. Mutual exclusion between those accesses is
// the owner's responsibility (see the type docs).
unsafe impl<V: Send> Sync for UnsafeSyncCell<V> {}

impl<V: Debug> Debug for UnsafeSyncCell<V> {
    /// Formats the inner value, forwarding the caller's formatter options
    /// (alternate `{:#?}`, width, fill, ...).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: the caller must guarantee that no `&mut V` to this value is
        // live for the duration of this call, i.e. that it has exclusive access
        // to the containing structure while formatting.
        let value = unsafe { &*self.0.get() };
        Debug::fmt(value, f)
    }
}
#[doc = include_str!("../examples/doc_thread_map.rs")]
pub struct ThreadMap<V> {
    // One slot per thread that has accessed the map. `with_mut` mutates the
    // calling thread's own slot while holding only the read lock, so reading
    // any *other* thread's slot must be done under the write lock.
    state: RwLock<HashMap<ThreadId, UnsafeSyncCell<V>>>,
    // Builds a thread's starting value on its first access.
    value_init: fn() -> V,
}

impl<V: Debug> Debug for ThreadMap<V> {
    /// Formats every thread's value.
    ///
    /// A derived impl would format through `RwLock`'s `Debug`, which takes only
    /// a *read* lock. That would race with threads mutating their own values
    /// under read locks. This impl therefore takes the write lock with
    /// `try_write` and shows the state as `<locked>` if the lock is busy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ThreadMap");
        match self.state.try_write() {
            Ok(guard) => {
                out.field("state", &*guard);
            }
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                // We still hold the write lock; only the poison flag is set.
                out.field("state", &**poisoned.get_ref());
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                out.field("state", &format_args!("<locked>"));
            }
        }
        out.field("value_init", &self.value_init).finish()
    }
}
impl<V> ThreadMap<V> {
    /// Creates an empty map.
    ///
    /// `value_init` builds the starting value for a thread the first time that
    /// thread accesses the map.
    pub fn new(value_init: fn() -> V) -> Self {
        Self {
            state: RwLock::new(HashMap::new()),
            value_init,
        }
    }

    /// Calls `f` with a mutable reference to the calling thread's value and
    /// returns `f`'s result.
    ///
    /// On the thread's first access, the value is created with `value_init`,
    /// passed to `f`, and only stored once `f` returns. `fold`, `probe` and
    /// `drain` therefore never observe a value that `f` is still updating.
    ///
    /// `f` must not call back into this `ThreadMap`; re-entrancy is not
    /// detected. Once the thread's value exists, a nested
    /// `with_mut`/`with`/`get`/`set` would create a second `&mut V` aliasing
    /// the one `f` holds. A nested `fold`/`drain` would request the write lock
    /// while this thread holds the read lock.
    ///
    /// # Panics
    ///
    /// Panics (with `POISONED_OBJECT_RW_LOCK`) if the internal lock is
    /// poisoned, and propagates any panic raised by `value_init` or `f`.
    pub fn with_mut<W>(&self, f: impl FnOnce(&mut V) -> W) -> W {
        let tid = thread::current().id();
        {
            let lock = self.state.read().expect(POISONED_OBJECT_RW_LOCK);
            if let Some(cell) = lock.get(&tid) {
                // SAFETY: within this impl, only thread `tid` touches the entry
                // keyed by `tid` while holding just the read lock. `fold` and
                // `drain` read other threads' entries only under the write
                // lock, which cannot be granted while `lock` is alive. `value`
                // is thus the only live reference, barring re-entrancy from `f`.
                let value = unsafe { &mut *cell.0.get() };
                return f(value);
            }
        }
        // First access from this thread. Build and update the value with no
        // lock held. Other threads are not blocked while user code runs, and a
        // panic in `value_init` or `f` cannot poison the lock. Only this thread
        // ever inserts under `tid`, so no other thread races on this entry.
        let mut value = (self.value_init)();
        let result = f(&mut value);
        self.state
            .write()
            .expect(POISONED_OBJECT_RW_LOCK)
            .insert(tid, UnsafeSyncCell(UnsafeCell::new(value)));
        result
    }

    /// Calls `f` with a shared reference to the calling thread's value
    /// (created with `value_init` on first access) and returns `f`'s result.
    ///
    /// Same re-entrancy restriction and panics as [`Self::with_mut`].
    pub fn with<W>(&self, f: impl FnOnce(&V) -> W) -> W {
        self.with_mut(|v: &mut V| f(v))
    }

    /// Returns a clone of the calling thread's value (created with
    /// `value_init` on first access).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(&self) -> V
    where
        V: Clone,
    {
        self.with(|v| v.clone())
    }

    /// Replaces the calling thread's value with `v`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set(&self, v: V) {
        self.with_mut(|slot| *slot = v);
    }

    /// Removes every thread's value and returns them keyed by thread id,
    /// leaving the map empty. A thread that accesses the map afterwards
    /// starts again from `value_init`.
    ///
    /// # Errors
    ///
    /// Returns `ThreadMapLockError` if the internal lock is poisoned.
    pub fn drain(&self) -> Result<HashMap<ThreadId, V>, ThreadMapLockError> {
        // The write guard is a temporary released at the end of this statement;
        // by then the whole map has been moved out.
        let taken = take(self.state.write()?.deref_mut());
        Ok(taken
            .into_iter()
            .map(|(tid, cell)| (tid, cell.0.into_inner()))
            .collect())
    }

    /// Folds `f` over every stored `(thread id, value)` pair, starting from
    /// `z`. Iteration order is the `HashMap`'s (unspecified).
    ///
    /// Takes the *write* lock so that no thread can be inside `with_mut`
    /// mutating its value while `f` reads it. `f` must not call back into this
    /// `ThreadMap`: the lock is not re-entrant.
    ///
    /// # Errors
    ///
    /// Returns `ThreadMapLockError` if the internal lock is poisoned.
    pub fn fold<W>(
        &self,
        z: W,
        f: impl FnMut(W, (ThreadId, &V)) -> W,
    ) -> Result<W, ThreadMapLockError> {
        let lock = self.state.write()?;
        let w = lock
            .iter()
            .map(|(tid, cell)| {
                // SAFETY: the write guard excludes every `with_mut` caller, so
                // no `&mut` to any stored value is live.
                (*tid, unsafe { &*cell.0.get() })
            })
            .fold(z, f);
        Ok(w)
    }

    /// Like [`Self::fold`], but passes only the values to `f`.
    ///
    /// # Errors
    ///
    /// Returns `ThreadMapLockError` if the internal lock is poisoned.
    pub fn fold_values<W>(
        &self,
        z: W,
        mut f: impl FnMut(W, &V) -> W,
    ) -> Result<W, ThreadMapLockError> {
        self.fold(z, |w, (_, v)| f(w, v))
    }

    /// Returns a snapshot (clones) of every thread's current value, keyed by
    /// thread id, without removing anything.
    ///
    /// # Errors
    ///
    /// Returns `ThreadMapLockError` if the internal lock is poisoned.
    pub fn probe(&self) -> Result<HashMap<ThreadId, V>, ThreadMapLockError>
    where
        V: Clone,
    {
        let z = HashMap::<ThreadId, V>::new();
        self.fold(z, |mut w, (tid, v)| {
            w.insert(tid, v.clone());
            w
        })
    }
}
impl<V: Default> Default for ThreadMap<V> {
    /// Creates an empty map whose per-thread values start as `V::default()`.
    fn default() -> Self {
        let value_init: fn() -> V = <V as Default>::default;
        Self::new(value_init)
    }
}
// Unwrapping is the intended failure mode inside tests.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod test {
    use super::ThreadMap;
    use std::{collections::HashMap, thread, time::Duration};

    /// Worker threads spawned by each test.
    const NTHREADS: i32 = 20;
    /// Updates each worker applies to its own slot.
    const NITER: i32 = 10;
    /// Pause before every update, to encourage interleaving between workers.
    const SLEEP_MICROS: u64 = 10;

    /// Stores `i` as the slot's latest index and adds it to the slot's running total.
    fn update_value(slot: &mut (i32, i32), i: i32) {
        let (latest, total) = slot;
        *latest = i;
        *total += i;
    }

    /// Snapshots every thread's `(latest, total)` pair as a `latest -> total` map.
    fn snapshot(tm: &ThreadMap<(i32, i32)>) -> HashMap<i32, i32> {
        tm.probe().unwrap().into_values().collect()
    }

    /// Workers, the scope's own thread, and the aggregate views
    /// (`fold_values`, `probe`, `drain`) must all agree.
    #[test]
    fn test_lifecycle() {
        let tm: ThreadMap<(i32, i32)> = ThreadMap::default();
        thread::scope(|s| {
            let tm = &tm;
            for i in 0..NTHREADS {
                s.spawn(move || {
                    (0..NITER).for_each(|_| {
                        thread::sleep(Duration::from_micros(SLEEP_MICROS));
                        tm.with_mut(|p: &mut (i32, i32)| update_value(p, i));
                    });
                    assert_eq!((i, i * NITER), tm.get());
                });
            }
            let probed = snapshot(tm);
            println!("probed={probed:?}");
            (0..NITER).for_each(|_| tm.with_mut(|p: &mut (i32, i32)| update_value(p, NTHREADS)));
            let probed = snapshot(tm);
            println!("probed={probed:?}");
        });
        let expected: HashMap<i32, i32> = (0..=NTHREADS).map(|i| (i, i * NITER)).collect();
        let expected_sum: i32 = expected.values().sum();
        let sum = tm.fold_values(0, |acc, &(_, total)| acc + total).unwrap();
        assert_eq!(expected_sum, sum);
        let probed = snapshot(&tm);
        println!("probed={probed:?}");
        assert_eq!(expected, probed);
        let drained: HashMap<_, _> = tm.drain().unwrap().into_values().collect();
        assert_eq!(expected, drained);
    }

    /// `set` overwrites each thread's slot; the slots then sum to the values written.
    #[test]
    fn test_set() {
        let tm: ThreadMap<i32> = ThreadMap::default();
        thread::scope(|s| {
            let tm = &tm;
            (0..NTHREADS).for_each(|i| {
                s.spawn(move || {
                    tm.set(i);
                    assert_eq!(i, tm.get());
                });
            });
        });
        let sum = tm.fold_values(0, |acc, v| acc + v).unwrap();
        assert_eq!((0..NTHREADS).sum::<i32>(), sum);
    }
}