use core::{cell::UnsafeCell, mem::MaybeUninit};
use crate::ThreadId;
/// Number of slots a [`ThreadLocal`] gets when `MAX_THREADS` is not specified:
/// the bit width of `usize` on the target (`usize::BITS`).
pub const DEFAULT_MAX_THREADS: usize = usize::BITS as usize;
/// Fixed-capacity per-thread storage with `MAX_THREADS` slots.
///
/// The public accessors index the slot arrays with `ThreadId::current().as_usize()`,
/// so each thread reads and writes only the slot that belongs to its own ID.
/// Values are dropped together with the container, possibly on a different thread
/// than the one that created them, hence the `T: Send` bound.
pub struct ThreadLocal<T: Send, const MAX_THREADS: usize = DEFAULT_MAX_THREADS> {
/// Value slots, indexed by thread ID. A slot is initialized only while the
/// matching `is_present` flag is `true`.
entries: [UnsafeCell<MaybeUninit<T>>; MAX_THREADS],
/// Per-slot flags recording whether the corresponding entry holds a value.
is_present: [UnsafeCell<bool>; MAX_THREADS],
}
// SAFETY: every stored value is `T: Send`, so moving the whole container (and
// later dropping or iterating its values on another thread) is fine.
unsafe impl<const MAX_THREADS: usize, T: Send> Send for ThreadLocal<T, MAX_THREADS> {}
// SAFETY: through `&self` a thread only touches the slot indexed by its own
// `ThreadId`; access to other slots requires `&mut self` (`iter_mut`, `drop`).
// NOTE(review): this relies on `ThreadId::current()` never being shared by two
// live threads, and on any ID reuse being properly synchronized — confirm
// against the `ThreadId` implementation.
unsafe impl<const MAX_THREADS: usize, T: Send> Sync for ThreadLocal<T, MAX_THREADS> {}
impl<const MAX_THREADS: usize, T: Send> ThreadLocal<T, MAX_THREADS> {
    /// Template for an empty value slot; a `const` item lets the array-repeat
    /// expression in [`Self::new`] work without requiring `T: Copy`.
    const DEFAULT_ENTRY: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit());
    /// Template for a cleared presence flag, used the same way as `DEFAULT_ENTRY`.
    const DEFAULT_PRESENCE: UnsafeCell<bool> = UnsafeCell::new(false);

    /// Creates an empty container in which no thread has stored a value yet.
    pub fn new() -> Self {
        Self {
            entries: [Self::DEFAULT_ENTRY; MAX_THREADS],
            is_present: [Self::DEFAULT_PRESENCE; MAX_THREADS],
        }
    }

    /// Returns the calling thread's value, or `None` if it has not stored one.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread's ID is not smaller than `MAX_THREADS`.
    #[inline]
    pub fn get(&self) -> Option<&T> {
        self.get_inner(ThreadId::current().as_usize())
    }

    /// Returns the calling thread's value, storing `T::default()` first if the
    /// thread has no value yet.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Self::get_or`].
    #[inline]
    pub fn get_or_default(&self) -> &T
    where
        T: Default,
    {
        self.get_or(T::default)
    }

    /// Returns the calling thread's value, storing the result of `create` first
    /// if the thread has no value yet. `create` is not called when a value is
    /// already present.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread's ID is not smaller than `MAX_THREADS`, or
    /// if `create` itself initializes this thread's slot (reentrant
    /// initialization).
    #[inline]
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> T,
    {
        let thread_id = ThreadId::current().as_usize();
        match self.get_inner(thread_id) {
            Some(existing) => existing,
            None => self.insert(thread_id, create()),
        }
    }

    /// Fallible variant of [`Self::get_or`]: if `create` returns `Err`, the
    /// error is propagated and nothing is stored.
    ///
    /// # Errors
    ///
    /// Returns whatever error `create` produced.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Self::get_or`].
    #[inline]
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let thread_id = ThreadId::current().as_usize();
        match self.get_inner(thread_id) {
            Some(existing) => Ok(existing),
            None => Ok(self.insert(thread_id, create()?)),
        }
    }

    /// Iterates mutably over the values stored by all threads, in slot order.
    ///
    /// Taking `&mut self` guarantees that no other reference into the container
    /// exists while the iterator is alive.
    #[inline]
    pub fn iter_mut(&mut self) -> IterMut<'_, MAX_THREADS, T> {
        IterMut {
            thread_local: self,
            cursor: 0,
        }
    }

    /// Shared access to the value in slot `thread_id`, if one is present.
    ///
    /// Panics (via `is_present`) when `thread_id` is out of range.
    #[inline]
    fn get_inner(&self, thread_id: usize) -> Option<&T> {
        if !self.is_present(thread_id) {
            return None;
        }
        // SAFETY: `is_present` returned instead of panicking, so
        // `thread_id < MAX_THREADS`. The flag is only set by `insert` after the
        // slot has been written, so the entry is initialized, and `insert`
        // refuses to write to an occupied slot, so no `&mut` to it exists.
        Some(unsafe { (&*self.entries.get_unchecked(thread_id).get()).assume_init_ref() })
    }

    /// Raw pointer to the value in slot `thread_id`, if one is present.
    ///
    /// The pointer is derived from the `UnsafeCell` without creating an
    /// intermediate reference; callers must hold exclusive access to the
    /// container before turning it into `&mut T`.
    ///
    /// Panics (via `is_present`) when `thread_id` is out of range.
    fn get_inner_ptr(&self, thread_id: usize) -> Option<*mut T> {
        // `MaybeUninit<T>` has the same layout as `T`, so the pointer cast is
        // exactly what `MaybeUninit::as_mut_ptr` would return.
        self.is_present(thread_id)
            .then(|| self.entries[thread_id].get().cast::<T>())
    }

    /// Reports whether slot `thread_id` currently holds a value.
    ///
    /// This is also the bounds check for every slot access: it panics when
    /// `thread_id` is not smaller than `MAX_THREADS`.
    #[inline]
    fn is_present(&self, thread_id: usize) -> bool {
        match self.is_present.get(thread_id) {
            // SAFETY: through `&self` the flag is only written by `insert`,
            // which the public API calls with the current thread's own ID.
            Some(flag) => unsafe { *flag.get() },
            None => panic!(
                "Too many threads used with thid::ThreadLocal. ThreadLocal has {MAX_THREADS} \
                 slots but the ID of the thread that attempted to access it is {thread_id}."
            ),
        }
    }

    /// Stores `value` in the empty slot `thread_id` and returns a reference to it.
    ///
    /// Panics if the slot is already occupied. That can only happen when the
    /// `create` closure of `get_or`/`get_or_try` initialized the same slot
    /// before returning; overwriting it would leak the first value and
    /// invalidate references already handed out.
    #[cold]
    fn insert(&self, thread_id: usize, value: T) -> &T {
        assert!(
            !self.is_present(thread_id),
            "reentrant initialization of thid::ThreadLocal: the slot of thread {thread_id} \
             was filled while its initializer was still running"
        );
        // SAFETY: `is_present` bounds-checked `thread_id` and reported the slot
        // as empty, so no reference into it exists yet.
        unsafe {
            let slot = &mut *self.entries.get_unchecked(thread_id).get();
            let stored = slot.write(value);
            // Publish the slot only once it actually holds a value.
            *self.is_present.get_unchecked(thread_id).get() = true;
            stored
        }
    }
}
impl<const MAX_THREADS: usize, T: Send> Drop for ThreadLocal<T, MAX_THREADS> {
    /// Drops every value that was stored, walking the slots in index order.
    fn drop(&mut self) {
        let flags = self.is_present.iter_mut();
        let slots = self.entries.iter_mut();
        for (flag, slot) in flags.zip(slots) {
            // `&mut self` gives exclusive access, so the cells can be read
            // through `get_mut` without going through raw pointers.
            if *flag.get_mut() {
                // SAFETY: a set flag means the slot was written and has not
                // been dropped before.
                unsafe { slot.get_mut().assume_init_drop() }
            }
        }
    }
}
/// Iterator yielding `&mut T` for every value stored in a [`ThreadLocal`],
/// in slot order. Created by `ThreadLocal::iter_mut`.
pub struct IterMut<'a, const MAX_THREADS: usize, T: Send> {
/// Container being iterated. Stored as a shared reference, but obtained from
/// the `&mut self` borrow taken by `iter_mut`.
thread_local: &'a ThreadLocal<T, MAX_THREADS>,
/// Index of the next slot to inspect.
cursor: usize,
}
impl<'a, const MAX_THREADS: usize, T: Send> Iterator for IterMut<'a, MAX_THREADS, T> {
    type Item = &'a mut T;

    /// Advances past empty slots and yields the next stored value, if any.
    fn next(&mut self) -> Option<Self::Item> {
        while self.cursor < MAX_THREADS {
            // Claim the slot index up front so every slot is visited exactly once.
            let slot = self.cursor;
            self.cursor += 1;

            if let Some(value_ptr) = self.thread_local.get_inner_ptr(slot) {
                // SAFETY: the iterator originates from `iter_mut(&mut self)`,
                // so the container is exclusively borrowed for `'a`, and the
                // strictly increasing cursor never yields the same slot twice.
                return Some(unsafe { &mut *value_ptr });
            }
        }
        None
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A stored non-`Copy` value is reachable through `get_or` and is dropped
    /// together with the container.
    #[test]
    fn test_thread_local_noncopy() {
        // Records its own destruction in the flag it borrows.
        struct TestDrop<'a> {
            some_value: u32,
            did_drop: &'a mut bool,
        }

        impl Drop for TestDrop<'_> {
            fn drop(&mut self) {
                *self.did_drop = true;
            }
        }

        let mut did_drop = false;
        let thread_local = ThreadLocal::<TestDrop>::new();
        assert!(thread_local.get().is_none());

        let stored = thread_local.get_or(|| TestDrop {
            some_value: 42,
            did_drop: &mut did_drop,
        });
        assert_eq!(stored.some_value, 42);

        // Dropping the container must run the destructor of the stored value.
        drop(thread_local);
        assert!(did_drop);
    }

    /// Several threads each initialize and mutate their own value; afterwards
    /// `iter_mut` exposes all of them.
    #[cfg(feature = "std")]
    #[test]
    fn test_thread_local_end_to_end() {
        use std::cell::Cell;
        use std::sync::Arc;
        use std::thread;

        const WORKERS: u32 = 4;

        let shared = Arc::new(ThreadLocal::<Vec<Cell<u32>>>::new());

        // Spawn every worker before joining any of them.
        let mut workers = Vec::new();
        for x in 0..WORKERS {
            let shared = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                let (first, second) = (x * 10, x * 10 + 5);
                let v = shared.get_or(|| vec![Cell::new(first), Cell::new(second)]);
                assert_eq!(v[0].get(), first);
                assert_eq!(v[1].get(), second);

                // Mutate through fresh lookups; `v` must observe the changes
                // because both refer to the same slot.
                for idx in 0..2 {
                    let again = shared.get().unwrap();
                    again[idx].set(again[idx].get() + 1);
                }
                assert_eq!(v[0].get(), first + 1);
                assert_eq!(v[1].get(), second + 1);
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }

        let mut owned = Arc::into_inner(shared).unwrap();
        let mut entries: Vec<_> = owned.iter_mut().collect();
        entries.sort_by_key(|entry| entry[0].get());

        // Worker `x` leaves `[10x + 1, 10x + 6]` behind.
        let mut expected: Vec<Vec<Cell<u32>>> = (0..WORKERS)
            .map(|x| vec![Cell::new(x * 10 + 1), Cell::new(x * 10 + 6)])
            .collect();
        assert_eq!(entries, expected.iter_mut().collect::<Vec<_>>());
    }
}