//! Thread-local storage whose per-thread values are removed when the owning
//! [`DroppingThreadLocal`] is dropped, rather than living until each thread exits.
//!
//! Each thread stores its value behind an `Arc`. [`SharedRef`] handles returned by
//! [`DroppingThreadLocal::get`] and friends keep a value alive on their own, and
//! [`DroppingThreadLocal::snapshot_iter`] lets one thread observe the values of others.
#![deny(
// currently there is no unsafe code
unsafe_code,
// public library should have docs
missing_docs,
)]
extern crate alloc;
use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::ops::Deref;
use std::thread::ThreadId;
use intid::IntegerId;
use parking_lot::Mutex;
use crate::local_ids::{LiveLocalId, OwnedLocalId};
mod local_ids;
/// A thread-local value whose per-thread instances are removed when it is dropped.
///
/// Each thread that initializes this local (via [`DroppingThreadLocal::set`],
/// [`DroppingThreadLocal::get_or_init`], ...) gets its own `Arc<T>`. Dropping the
/// `DroppingThreadLocal` takes the value out of every thread that is registered at that
/// moment. Outstanding [`SharedRef`] handles still keep their value alive.
///
/// `T` must be `Send + Sync` because other threads can observe a value through
/// [`DroppingThreadLocal::snapshot_iter`]. A value may also be released on whichever
/// thread drops the local.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    // Slot identifier. Its index selects this local's entry in each thread's slot table.
    id: OwnedLocalId,
    // Behaves like owning an `Arc<T>` for auto-trait and variance purposes.
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    /// Shows only the current thread's value (`None` if this thread never initialized it).
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let current = self.get();
        let mut out = f.debug_struct("DroppingThreadLocal");
        // `as_deref` goes through `SharedRef`'s `Deref` to yield `Option<&T>`.
        out.field("local_data", &current.as_deref());
        out.finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
#[inline]
fn default() -> Self {
DroppingThreadLocal::new()
}
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Creates a new thread-local that holds no value on any thread.
    ///
    /// A fresh slot identifier is reserved with `OwnedLocalId::alloc`.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: OwnedLocalId::alloc(),
            marker: PhantomData,
        }
    }

    /// Returns the current thread's value, or `None` if this thread has not initialized it.
    ///
    /// # Panics
    ///
    /// Panics if the current thread's thread-local storage is being (or has been) destroyed,
    /// as documented for [`std::thread::LocalKey::with`].
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id.id())?.downcast::<T>().expect("unexpected type"),
            })
        })
    }

    /// Stores `value` for the current thread unless the thread already has one.
    ///
    /// Returns `Ok` with a handle to the newly stored value. Returns `Err` with the existing
    /// value if this thread was already initialized; in that case `value` is dropped.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`DroppingThreadLocal::get`].
    pub fn set(&self, value: T) -> Result<SharedRef<T>, SharedRef<T>> {
        // A single thread-local access covers both the lookup and the insertion.
        THREAD_STATE.with(|thread| {
            if let Some(existing) = thread.get(self.id.id()) {
                return Err(SharedRef {
                    thread_id: thread.id,
                    value: existing.downcast::<T>().expect("unexpected type"),
                });
            }
            let value = Arc::new(value);
            // Only the stored copy is type-erased. The typed `Arc<T>` is handed back
            // directly, so no (infallible) downcast is needed.
            let erased: DynArc = Arc::<T>::clone(&value);
            thread.init(&self.id, &erased);
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }

    /// Returns the current thread's value, initializing it with `func` if absent.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`DroppingThreadLocal::get_or_try_init`].
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        match self.get_or_try_init::<core::convert::Infallible>(|| Ok(func())) {
            Ok(success) => success,
        }
    }

    /// Returns the current thread's value, initializing it with `func` if absent.
    ///
    /// `func` runs without any internal lock held.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `func`; nothing is stored in that case.
    ///
    /// # Panics
    ///
    /// - Under the same conditions as [`DroppingThreadLocal::get`].
    /// - If `func` itself initializes this same local on the current thread. The
    ///   subsequent insertion then detects a double initialization.
    pub fn get_or_try_init<E>(&self, func: impl FnOnce() -> Result<T, E>) -> Result<SharedRef<T>, E> {
        THREAD_STATE.with(|thread| {
            let value = match thread.get(self.id.id()) {
                Some(existing) => existing.downcast::<T>().expect("unexpected type"),
                None => {
                    let new_value = Arc::new(func()?);
                    // Store a type-erased copy but keep the typed handle for the caller.
                    let erased: DynArc = Arc::<T>::clone(&new_value);
                    thread.init(&self.id, &erased);
                    new_value
                }
            };
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }

    /// Returns an iterator over this local's values on every registered thread.
    ///
    /// The set of threads is captured when this method is called. Threads that have since
    /// exited, or that never initialized this local, are skipped during iteration.
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        let threads = snapshot_live_threads();
        SnapshotIter {
            local_id: self.id.clone(),
            // `None` when no thread has registered yet: the iterator is then empty.
            iter: threads.map(IntoIterator::into_iter),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    /// Removes this local's value from every thread registered at the time of the drop.
    ///
    /// Values still referenced through a [`SharedRef`] stay alive until that handle is gone.
    fn drop(&mut self) {
        let Some(threads) = snapshot_live_threads() else { return };
        for (thread_id, weak) in threads {
            // Threads whose state is already gone clear their own slots on teardown.
            let Some(state) = weak.upgrade() else { continue };
            assert_eq!(state.id, thread_id);
            // The slot is emptied while the lock is held (the guard is a temporary of this
            // statement). The value itself is released only afterwards, so its destructor
            // never runs under the per-thread lock.
            let removed: Option<DynArc> = state
                .values
                .lock()
                .get_mut(self.id.index())
                .and_then(Option::take)
                .map(|(_, value)| value);
            drop(removed);
        }
    }
}
/// Iterator over a [`DroppingThreadLocal`]'s values across threads.
///
/// Created by [`DroppingThreadLocal::snapshot_iter`]. The set of threads is fixed when the
/// iterator is created. Threads that have exited, or that hold no value for this local, are
/// skipped.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    // Identifies which slot to read in each thread's slot table.
    local_id: OwnedLocalId,
    // Remaining snapshotted threads; `None` when no thread had registered at snapshot time.
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // `Rc` is neither `Send` nor `Sync`, so this marker makes `SnapshotIter` neither.
    // NOTE(review): confirm that restriction is intentional.
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
type Item = SharedRef<T>;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
loop {
let (thread_id, thread) = self.iter.as_mut()?.next()?;
let Some(thread) = Weak::upgrade(&thread) else { continue };
let Some(arc) = ({
let lock = thread.values.lock();
lock.get(self.local_id.index())
.and_then(Option::as_ref)
.map(|(_id, value)| Arc::clone(value))
}) else {
continue;
};
return Some(SharedRef {
thread_id,
value: arc.downcast::<T>().expect("mismatched type"),
});
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
match self.iter {
Some(ref iter) => {
(0, Some(iter.len()))
}
None => (0, Some(0)),
}
}
}
// When `iter` is `None`, `next` always returns `None`. When it is `Some`, fusedness relies on
// imbl's `ConsumingIter` continuing to return `None` once exhausted.
// NOTE(review): confirm against imbl.
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}
/// A handle to one thread's value of a [`DroppingThreadLocal`], tagged with that thread's id.
///
/// The value is held in an `Arc`, so a `SharedRef` keeps it alive even after the owning
/// [`DroppingThreadLocal`] has removed it from the thread. Equality, ordering and hashing
/// look only at the value, never at the thread id.
#[derive(Debug)]
pub struct SharedRef<T> {
    /// Thread whose slot this value was read from.
    thread_id: ThreadId,
    /// The shared value.
    value: Arc<T>,
}

// Written by hand because `#[derive(Clone)]` would require `T: Clone`. Cloning a
// `SharedRef` only bumps the `Arc` reference count and never clones `T`.
impl<T> Clone for SharedRef<T> {
    #[inline]
    fn clone(&self) -> Self {
        SharedRef {
            thread_id: self.thread_id,
            value: Arc::clone(&self.value),
        }
    }
}

impl<T> SharedRef<T> {
    /// Returns the id of the thread that owns the referenced value.
    ///
    /// Called as `SharedRef::thread_id(&r)`. This is presumably an associated function, like
    /// `Arc::strong_count`, so that it cannot shadow a `thread_id` method on `T` reached
    /// through `Deref`.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}

impl<T: Hash> Hash for SharedRef<T> {
    /// Hashes only the value, consistent with `PartialEq`.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}

impl<T: Eq> Eq for SharedRef<T> {}

impl<T: PartialEq> PartialEq for SharedRef<T> {
    /// Compares the values; the owning thread ids are ignored.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}

impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&*self.value, &*other.value)
    }
}

impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&*self.value, &*other.value)
    }
}
/// Registry map from each registered thread's id to a weak handle on its state.
///
/// A persistent `imbl` map is used, presumably so that `snapshot_live_threads` can clone
/// it cheaply (imbl maps share structure) while holding the lock only briefly.
type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Global registry of threads that have initialized `THREAD_STATE`.
///
/// `None` until the first thread registers; the `THREAD_STATE` initializer creates the map
/// with `get_or_insert_default`. Entries are removed by `LiveThreadState`'s `Drop`.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
let lock = LIVE_THREADS.lock();
lock.as_ref().cloned()
}
thread_local! {
    /// This thread's slot table, registered in `LIVE_THREADS` on first access.
    ///
    /// Only a `Weak` handle is registered globally, so the registry itself does not keep
    /// the state alive. Panics if the registry already has an entry for this `ThreadId`.
    static THREAD_STATE: Arc<LiveThreadState> = {
        let thread_id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id: thread_id,
            values: Mutex::new(Vec::new()),
        });
        {
            let mut registry = LIVE_THREADS.lock();
            match registry.get_or_insert_default().entry(thread_id) {
                imbl::hashmap::Entry::Vacant(slot) => {
                    slot.insert(Arc::downgrade(&state));
                }
                imbl::hashmap::Entry::Occupied(_) => panic!("reinitialized thread"),
            }
        }
        state
    };
}
/// Type-erased shared value as stored in a thread's slot table.
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
/// Per-thread bookkeeping: the owning thread's id plus one slot per local index.
struct LiveThreadState {
    /// The thread this state was created for, as recorded by the `THREAD_STATE` initializer.
    id: ThreadId,
    /// Slots indexed by `OwnedLocalId::index()`. `None` means the local was never
    /// initialized on this thread, or its value was removed when the local was dropped.
    /// The stored `OwnedLocalId` clone presumably keeps the index reserved while the value is
    /// present. TODO confirm against `local_ids`.
    values: Mutex<Vec<Option<(OwnedLocalId, DynArc)>>>,
}
impl LiveThreadState {
    /// Returns a new strong handle to the value in slot `id`, if that slot is initialized.
    #[inline]
    fn get(&self, id: LiveLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(&lock.get(id.to_int())?.as_ref()?.1))
    }

    /// Stores `new_value` in the slot for `id`, growing the slot table as needed.
    ///
    /// Callers invoke this only after finding the slot empty, so it is kept out of line as
    /// a cold path.
    ///
    /// # Panics
    ///
    /// Panics if the slot is already occupied.
    #[cold]
    #[inline(never)]
    fn init(&self, id: &OwnedLocalId, new_value: &DynArc) {
        let mut lock = self.values.lock();
        let index = id.index();
        if lock.get(index).is_some_and(Option::is_some) {
            panic!("unexpected double initialization of thread-local value");
        }
        if lock.len() <= index {
            // Pad with empty slots in one step so that `index` becomes addressable.
            lock.resize_with(index + 1, || None);
        }
        lock[index] = Some((id.clone(), Arc::clone(new_value)));
    }
}
impl Drop for LiveThreadState {
    /// Drops every value still stored for this thread, then unregisters the thread.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so no lock is needed to clear the slots. This
        // happens before `LIVE_THREADS` is locked below, so the value destructors run
        // without this function holding the registry lock.
        self.values.get_mut().clear();
        let mut registry = LIVE_THREADS.lock();
        if let Some(threads) = registry.as_mut() {
            let _removed: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
        }
    }
}