use crate::{POISONED_OBJECT_RW_LOCK, POISONED_THREAD_LOCK};
use super::ThreadMapLockError;
use std::{
collections::HashMap,
mem::take,
ops::DerefMut,
sync::{Mutex, RwLock},
thread::{self, ThreadId},
};
#[doc = include_str!("../examples/doc_thread_map_x.rs")]
#[derive(Debug)]
pub struct ThreadMapX<V> {
    /// One value per thread, keyed by the owning thread's id. Each value sits behind
    /// its own `Mutex`, so threads working on existing entries only need the outer
    /// read lock; the write lock is taken when an entry is inserted or the map is drained.
    state: RwLock<HashMap<ThreadId, Mutex<V>>>,
    /// Builds a thread's initial value when that thread has no entry in `state` yet.
    value_init: fn() -> V,
}
impl<V> ThreadMapX<V> {
    /// Creates an empty map.
    ///
    /// `value_init` produces a thread's initial value. It is called whenever the
    /// calling thread has no entry in the map yet (see [`Self::with_mut`]).
    pub fn new(value_init: fn() -> V) -> Self {
        Self {
            state: RwLock::new(HashMap::new()),
            value_init,
        }
    }

    /// Runs `f` with mutable access to the calling thread's value and returns its result.
    ///
    /// If the calling thread has no entry yet, a new value is built with `value_init`,
    /// `f` is applied to it, and the updated value is then stored under the thread's id.
    ///
    /// # Panics
    ///
    /// Panics if the map's `RwLock` or this thread's `Mutex` is poisoned.
    ///
    /// Calling back into this map from inside `f` on the same thread is not supported.
    /// When the entry already exists, its `Mutex` is held while `f` runs, and std's
    /// `Mutex` is not reentrant (re-locking it may deadlock or panic).
    pub fn with_mut<W>(&self, f: impl FnOnce(&mut V) -> W) -> W {
        let tid = thread::current().id();
        {
            let map = self.state.read().expect(POISONED_OBJECT_RW_LOCK);
            if let Some(cell) = map.get(&tid) {
                let mut guard = cell.lock().expect(POISONED_THREAD_LOCK);
                return f(&mut *guard);
            }
            // The read guard is released here, before the first-access path below.
        }
        // First access from this thread. Run `value_init` and `f` *before* taking the
        // write lock. User code then never runs under the map-wide write lock, so a
        // panic in it cannot poison the whole map, and other threads are not blocked
        // while it runs. Entries are only inserted here, by the thread they belong to,
        // so (absent reentrant calls from `f`) no entry for `tid` can appear meanwhile.
        let mut value = (self.value_init)();
        let out = f(&mut value);
        self.state
            .write()
            .expect(POISONED_OBJECT_RW_LOCK)
            .insert(tid, Mutex::new(value));
        out
    }

    /// Runs `f` with shared access to the calling thread's value and returns its result.
    ///
    /// Behaves like [`Self::with_mut`], including lazy initialization and its panics.
    pub fn with<W>(&self, f: impl FnOnce(&V) -> W) -> W {
        self.with_mut(|v: &mut V| f(v))
    }

    /// Returns a clone of the calling thread's value, initializing it first if needed.
    ///
    /// # Panics
    ///
    /// Same as [`Self::with_mut`].
    pub fn get(&self) -> V
    where
        V: Clone,
    {
        self.with(|v| v.clone())
    }

    /// Replaces the calling thread's value with `v`.
    ///
    /// On a thread's first access, `value_init` still runs, and its result is
    /// immediately overwritten.
    ///
    /// # Panics
    ///
    /// Same as [`Self::with_mut`].
    pub fn set(&self, v: V) {
        self.with_mut(|v0| *v0 = v);
    }

    /// Removes every entry and returns the values keyed by thread id.
    ///
    /// The map is left empty. Threads that access it afterwards get freshly
    /// initialized values.
    ///
    /// # Errors
    ///
    /// Returns [`ThreadMapLockError`] if the map's `RwLock` or any per-thread `Mutex`
    /// is poisoned. In the `Mutex` case the map has already been emptied, and the
    /// drained values are discarded.
    pub fn drain(&self) -> Result<HashMap<ThreadId, V>, ThreadMapLockError> {
        // Swap the contents out under the write lock. The guard is a temporary, so the
        // lock is released at the end of this statement, before the per-thread
        // mutexes are unwrapped.
        let taken = take(self.state.write()?.deref_mut());
        taken
            .into_iter()
            .map(|(tid, cell)| Ok((tid, cell.into_inner()?)))
            .collect()
    }

    /// Folds `f` over every `(thread id, value)` entry currently in the map, starting from `z`.
    ///
    /// The map is read-locked for the whole traversal, and each per-thread `Mutex` is
    /// locked while `f` sees its value. Visiting order follows `HashMap` iteration
    /// order and is therefore unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ThreadMapLockError`] if the map's `RwLock` or any visited per-thread
    /// `Mutex` is poisoned. The fold stops at the first such error.
    pub fn fold<W>(
        &self,
        z: W,
        mut f: impl FnMut(W, (ThreadId, &V)) -> W,
    ) -> Result<W, ThreadMapLockError> {
        self.state.read()?.iter().try_fold(z, |acc, (tid, cell)| {
            let guard = cell.lock()?;
            Ok(f(acc, (*tid, &*guard)))
        })
    }

    /// Like [`Self::fold`], but `f` only receives the values, not the thread ids.
    ///
    /// # Errors
    ///
    /// Same as [`Self::fold`].
    pub fn fold_values<W>(
        &self,
        z: W,
        mut f: impl FnMut(W, &V) -> W,
    ) -> Result<W, ThreadMapLockError> {
        self.fold(z, |w, (_, v)| f(w, v))
    }

    /// Returns a snapshot of every thread's current value, cloned out of the map.
    ///
    /// # Errors
    ///
    /// Same as [`Self::fold`].
    pub fn probe(&self) -> Result<HashMap<ThreadId, V>, ThreadMapLockError>
    where
        V: Clone,
    {
        self.fold(HashMap::new(), |mut snapshot, (tid, v)| {
            snapshot.insert(tid, v.clone());
            snapshot
        })
    }
}
impl<V> Default for ThreadMapX<V>
where
    V: Default,
{
    /// Builds an empty map whose per-thread values start out as `V::default()`.
    fn default() -> Self {
        let init: fn() -> V = <V as Default>::default;
        Self::new(init)
    }
}
// Tests unwrap freely: any lock-poisoning error should fail the test outright.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod test {
    use super::ThreadMapX;
    use std::{collections::HashMap, thread, time::Duration};

    /// Number of worker threads spawned by each test.
    const NTHREADS: i32 = 20;
    /// Number of updates each worker applies to its own value.
    const NITER: i32 = 10;
    /// Pause before each update, giving the workers a chance to interleave.
    const SLEEP_MICROS: u64 = 10;

    /// Stores `i` as the latest index in `pair.0` and adds it to the running total in `pair.1`.
    fn update_value(pair: &mut (i32, i32), i: i32) {
        let (latest, total) = pair;
        *latest = i;
        *total += i;
    }

    /// Snapshots `tm` as a map from each thread's latest index to its running total.
    fn snapshot(tm: &ThreadMapX<(i32, i32)>) -> HashMap<i32, i32> {
        tm.probe().unwrap().into_values().collect()
    }

    #[test]
    fn test_lifecycle() {
        let tm: ThreadMapX<(i32, i32)> = ThreadMapX::default();

        thread::scope(|scope| {
            let shared = &tm;
            for i in 0..NTHREADS {
                scope.spawn(move || {
                    for _ in 0..NITER {
                        thread::sleep(Duration::from_micros(SLEEP_MICROS));
                        shared.with_mut(move |pair: &mut (i32, i32)| update_value(pair, i));
                    }
                    assert_eq!((i, i * NITER), shared.get());
                });
            }

            // The scope's calling thread joins in as one more worker, with index NTHREADS.
            let probed = snapshot(shared);
            println!("probed={probed:?}");
            for _ in 0..NITER {
                shared.with_mut(move |pair: &mut (i32, i32)| update_value(pair, NTHREADS));
            }
            let probed = snapshot(shared);
            println!("probed={probed:?}");
        });

        let expected: HashMap<i32, i32> = (0..=NTHREADS).map(|i| (i, i * NITER)).collect();
        let expected_sum: i32 = expected.values().sum();
        let sum = tm.fold_values(0, |acc, &(_, total)| acc + total).unwrap();
        assert_eq!(expected_sum, sum);

        let probed = snapshot(&tm);
        println!("probed={probed:?}");
        assert_eq!(expected, probed);

        let drained: HashMap<i32, i32> = tm.drain().unwrap().into_values().collect();
        assert_eq!(expected, drained);
    }

    #[test]
    fn test_set() {
        let tm: ThreadMapX<i32> = ThreadMapX::default();

        thread::scope(|scope| {
            let shared = &tm;
            for i in 0..NTHREADS {
                scope.spawn(move || {
                    shared.set(i);
                    assert_eq!(i, shared.get());
                });
            }
        });

        let sum = tm.fold_values(0, |acc, v| acc + v).unwrap();
        assert_eq!((0..NTHREADS).sum::<i32>(), sum);
    }
}