#![deny(warnings, missing_docs, clippy::all, rustdoc::broken_intra_doc_links)]
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;
use std::ops::Deref;
use std::ptr::NonNull;
use std::sync::Mutex;
/// Opaque key identifying one registration inside a [`ThreadMap`].
///
/// Handles come from a per-map counter that never wraps around, so they are
/// never reused within one map.
#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)]
struct Handle(usize);

/// A registry of values that are each owned by a [`PerThread`] handle but can
/// be inspected collectively through [`ThreadMap::for_each`].
///
/// Values are added with [`ThreadMap::register`], which requires a `'static`
/// map (for example a lazily initialised `static`). Dropping the returned
/// [`PerThread`] removes the value from the map again.
pub struct ThreadMap<T> {
    inner: Mutex<ThreadMapLocked<T>>,
}

/// Lock-protected state of a [`ThreadMap`].
struct ThreadMapLocked<T> {
    /// Value of the next handle to hand out.
    idx: usize,
    /// Pointer to the `val` field of every live `StableStorage` registered
    /// with this map, keyed by its handle.
    map: HashMap<Handle, NonNull<T>>,
}

impl<T> ThreadMapLocked<T> {
    /// Reserves and returns the next unused handle.
    ///
    /// Panics before writing anything if the counter is exhausted, instead of
    /// wrapping around and handing out a handle that may still be in use.
    fn next_handle(&mut self) -> Handle {
        let handle = Handle(self.idx);
        self.idx = self
            .idx
            .checked_add(1)
            .expect("ThreadMap handle counter overflowed");
        handle
    }
}

// SAFETY: the map stores pointers to `T` values whose owners (`PerThread`)
// remove them under the lock before freeing them. The only access the map
// gives out is `&T` (in `for_each`), potentially on a different thread than
// the owner's, which is sound exactly when `T: Sync`.
unsafe impl<T: Sync> Send for ThreadMap<T> {}
// SAFETY: as for `Send`; all shared state sits behind the `Mutex`.
unsafe impl<T: Sync> Sync for ThreadMap<T> {}

impl<T> Default for ThreadMap<T> {
    /// Creates an empty map.
    fn default() -> Self {
        Self {
            inner: Mutex::new(ThreadMapLocked {
                idx: 0,
                map: HashMap::new(),
            }),
        }
    }
}

impl<T> ThreadMap<T> {
    /// Registers `val` with this map and returns the handle that owns it.
    ///
    /// The value stays visible to [`ThreadMap::for_each`] until the returned
    /// [`PerThread`] is dropped. Storing the handle in a `thread_local!` gives
    /// every thread its own entry that disappears when the thread exits.
    ///
    /// # Panics
    ///
    /// Panics if this map has already handed out `usize::MAX` handles.
    pub fn register(&'static self, val: T) -> PerThread<T>
    where
        T: 'static + Sync,
    {
        // Allocate before taking the lock to keep the critical section short.
        let mut boxed = Box::new(StableStorage {
            val,
            handle: Handle(0),
            map: NonNull::from(self),
        });
        let mut locked = self.lock();
        let handle = locked.next_handle();
        boxed.handle = handle;
        // From here on the allocation is owned through a raw pointer (see
        // `PerThread::storage`); `Drop for PerThread` reclaims it with
        // `Box::from_raw`.
        let storage = NonNull::from(Box::leak(boxed));
        // SAFETY: `storage` points to the live, fully initialised allocation
        // leaked just above.
        let val_ptr = NonNull::from(unsafe { &storage.as_ref().val });
        let previous = locked.map.insert(handle, val_ptr);
        debug_assert!(previous.is_none(), "ThreadMap handle reused");
        PerThread { storage }
    }

    /// Calls `cb` with a reference to every value currently registered.
    ///
    /// The iteration order is unspecified. The internal lock is held for the
    /// whole iteration, so `register` and `PerThread` drops on other threads
    /// wait until this returns. `cb` must not register values with, or drop a
    /// [`PerThread`] belonging to, this same map: re-locking a
    /// `std::sync::Mutex` on the thread that holds it never returns (it may
    /// deadlock or panic).
    pub fn for_each<F>(&self, mut cb: F)
    where
        F: FnMut(&T),
    {
        let locked = self.lock();
        for val in locked.map.values() {
            // SAFETY: every pointer in `map` refers to the `val` field of a live
            // `StableStorage`. Its owner removes the entry under this lock
            // before freeing the allocation, so it cannot be freed while we
            // hold the guard.
            cb(unsafe { val.as_ref() });
        }
    }

    /// Removes the entry for `h`; does nothing if it is not present.
    fn unregister(&self, h: Handle) {
        self.lock().map.remove(&h);
    }

    /// Acquires the lock, recovering the guard if the mutex is poisoned.
    ///
    /// Recovery is sound: the only panics that can happen while the guard is
    /// held are the user callback in `for_each` (which only reads) and the
    /// overflow check in `next_handle` (which fires before anything is
    /// written), so the data is never left half-updated. It also matters,
    /// because `unregister` runs inside `Drop`, where panicking (for example
    /// during thread-local destruction) is far worse than continuing.
    fn lock(&self) -> std::sync::MutexGuard<'_, ThreadMapLocked<T>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// Owning handle to a value registered in a [`ThreadMap`].
///
/// Dereferences to the value. Dropping it first removes the value from the map
/// and then frees it.
pub struct PerThread<T> {
    /// Allocation leaked with `Box::leak` in `register` and freed in `Drop`.
    ///
    /// Deliberately a raw pointer rather than a `Box`. Moving a `Box` asserts
    /// unique access to its contents, and under Stacked Borrows (as checked by
    /// Miri) that would invalidate the pointer the map keeps to `val`.
    storage: NonNull<StableStorage<T>>,
}

/// Heap record of one registration; its address does not change until it is
/// freed.
struct StableStorage<T> {
    /// The registered value; the owning map holds a pointer to this field.
    val: T,
    /// Key of this value's entry in the owning map.
    handle: Handle,
    /// The `'static` map this value is registered with.
    map: NonNull<ThreadMap<T>>,
}

// SAFETY: `PerThread` uniquely owns its `StableStorage`, like a `Box`, so
// sending it moves ownership of `T` to the other thread. The map it points to
// is `'static` and, since `register` requires `T: Sync`, is itself `Sync`.
unsafe impl<T: Send> Send for PerThread<T> {}
// SAFETY: `&PerThread<T>` only ever exposes `&T`.
unsafe impl<T: Sync> Sync for PerThread<T> {}

impl<T> PerThread<T> {
    /// Shared access to the registered value.
    fn value(&self) -> &T {
        // SAFETY: `storage` stays valid until `drop`, which needs `&mut self`.
        unsafe { &self.storage.as_ref().val }
    }
}

impl<T> Debug for PerThread<T>
where
    T: Debug,
{
    /// Formats the registered value as if it were not wrapped.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(self.value(), fmt)
    }
}

impl<T> Drop for PerThread<T> {
    fn drop(&mut self) {
        // SAFETY: `storage` is valid until it is freed at the end of this
        // function.
        let (handle, map) = unsafe {
            let storage = self.storage.as_ref();
            (storage.handle, storage.map)
        };
        // SAFETY: `map` was created from the `&'static ThreadMap` that
        // `register` was called on.
        unsafe { map.as_ref() }.unregister(handle);
        // SAFETY: the entry has been removed under the map's lock, so no
        // `for_each` can reach the value any more, and `storage` came from
        // `Box::leak`, so reconstructing the `Box` reclaims it exactly once.
        unsafe { drop(Box::from_raw(self.storage.as_ptr())) };
    }
}

impl<T> AsRef<T> for PerThread<T> {
    fn as_ref(&self) -> &T {
        self.value()
    }
}

impl<T> Deref for PerThread<T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::hash::Hash;

    use once_cell::sync::Lazy;

    use super::*;

    /// Asserts that `map` holds exactly the values in `expected`, each once.
    ///
    /// Values are copied out first and checked only after `for_each` returns.
    /// A failing assertion inside the callback would panic while the map's
    /// mutex is held and poison it, which can turn one failure into unrelated
    /// follow-up panics, e.g. in `PerThread` destructors on other threads.
    fn assert_map_content<T>(map: &ThreadMap<T>, expected: &HashSet<T>)
    where
        T: Clone + fmt::Debug + Hash + Eq,
    {
        let mut values = Vec::new();
        map.for_each(|el| values.push(el.clone()));
        let set: HashSet<T> = values.iter().cloned().collect();
        assert_eq!(set.len(), values.len(), "duplicate values in map: {:?}", values);
        assert_eq!(&set, expected);
    }

    /// Lazily registered statics show up in the map once first accessed.
    #[test]
    fn test_single_thread() {
        static TEST_MAP: Lazy<ThreadMap<i64>> = Lazy::new(ThreadMap::default);
        static TEST_VAL1: Lazy<PerThread<i64>> = Lazy::new(|| TEST_MAP.register(42));
        static TEST_VAL2: Lazy<PerThread<i64>> = Lazy::new(|| TEST_MAP.register(431));
        let mut expected_values = HashSet::new();
        assert_map_content(&*TEST_MAP, &expected_values);
        assert_eq!(**TEST_VAL1, 42);
        expected_values.insert(**TEST_VAL1);
        assert_map_content(&*TEST_MAP, &expected_values);
        assert_eq!(**TEST_VAL2, 431);
        expected_values.insert(**TEST_VAL2);
        assert_map_content(&*TEST_MAP, &expected_values);
    }

    /// Dropping a `PerThread` removes its value from the map.
    #[test]
    fn test_drop_unregisters() {
        static TEST_MAP: Lazy<ThreadMap<i64>> = Lazy::new(ThreadMap::default);
        let val = TEST_MAP.register(5);
        assert_eq!(*val, 5);
        let mut expected_values = HashSet::new();
        expected_values.insert(5);
        assert_map_content(&*TEST_MAP, &expected_values);
        drop(val);
        expected_values.clear();
        assert_map_content(&*TEST_MAP, &expected_values);
    }

    /// Thread-local handles register on first use and unregister when their
    /// thread exits. The two rendezvous channels keep the main thread's checks
    /// in lock-step with the worker's progress.
    #[test]
    fn test_integration_with_thread_local() {
        use std::sync::mpsc::sync_channel;
        struct Ack;
        static TEST_MAP: Lazy<ThreadMap<i64>> = Lazy::new(ThreadMap::default);
        thread_local! {
            static TEST_VAL1: PerThread<i64> = TEST_MAP.register(7);
            static TEST_VAL2: PerThread<i64> = TEST_MAP.register(42);
        }
        let (sender, receiver) = sync_channel(0);
        let (r_sender, r_receiver) = sync_channel(0);
        let test_thread = ::std::thread::spawn(move || {
            receiver.recv().unwrap();
            TEST_VAL1.with(|val| assert_eq!(**val, 7));
            r_sender.send(Ack).unwrap();
            receiver.recv().unwrap();
            TEST_VAL2.with(|val| assert_eq!(**val, 42));
            r_sender.send(Ack).unwrap();
            // Keep the thread (and its thread-locals) alive until told to exit.
            receiver.recv().unwrap();
        });
        let mut expected_values = HashSet::new();
        assert_map_content(&*TEST_MAP, &expected_values);
        sender.send(Ack).unwrap();
        r_receiver.recv().unwrap();
        expected_values.insert(7);
        assert_map_content(&*TEST_MAP, &expected_values);
        sender.send(Ack).unwrap();
        r_receiver.recv().unwrap();
        expected_values.insert(42);
        assert_map_content(&*TEST_MAP, &expected_values);
        sender.send(Ack).unwrap();
        test_thread.join().unwrap();
        // The worker's thread-local destructors have unregistered both values.
        expected_values.clear();
        assert_map_content(&*TEST_MAP, &expected_values);
    }
}