mod thread_id;
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{self, AtomicBool, AtomicPtr, Ordering};
use std::{mem, ptr};
pub use thread_id::Thread;
/// Per-thread storage addressed by [`Thread`] ids.
///
/// Values live in a fixed number of lazily-allocated bucket arrays
/// (`thread_id::BUCKETS` of them); each thread id maps to one slot in one
/// bucket. Bucket capacities grow geometrically (see `Iter::next`, which
/// doubles `bucket_size` per bucket).
pub struct ThreadLocal<T> {
    // Null pointer == bucket not allocated yet; set exactly once via CAS
    // in `initialize` (or eagerly in `with_capacity`).
    buckets: [AtomicPtr<Entry<T>>; thread_id::BUCKETS],
}
/// One thread's slot inside a bucket.
struct Entry<T> {
    // Set to `true` (Release, in `write`) only after `value` has been
    // initialized; readers check it before calling `assume_init_ref`.
    present: AtomicBool,
    // Initialized if and only if `present` is `true`.
    value: UnsafeCell<MaybeUninit<T>>,
}
// SAFETY: the container owns its `T` values, so moving or sharing it across
// threads requires `T: Send`. Cross-thread access to individual entries only
// happens through the `unsafe` methods (`load_or`, `iter`), whose callers
// must uphold the aliasing/visibility contract.
// NOTE(review): only `T: Send` (not `T: Sync`) is required for `Sync` here,
// presumably because each entry is written solely by its owning thread —
// confirm against the `thread_id` module's guarantees.
unsafe impl<T: Send> Send for ThreadLocal<T> {}
unsafe impl<T: Send> Sync for ThreadLocal<T> {}
impl<T> ThreadLocal<T> {
    /// Creates a `ThreadLocal` with buckets pre-allocated to cover at least
    /// `capacity` threads.
    ///
    /// Note that even `capacity == 0` eagerly allocates bucket 0, since
    /// `init` is `0` and the range below is inclusive (`..=init`).
    /// NOTE(review): presumably intentional (bucket 0 is almost always
    /// needed) — confirm.
    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
        // Index of the last bucket that must exist to hold `capacity` slots.
        let init = match capacity {
            0 => 0,
            n => Thread::new(n).bucket,
        };
        let mut buckets = [ptr::null_mut(); thread_id::BUCKETS];
        for (i, bucket) in buckets[..=init].iter_mut().enumerate() {
            let bucket_size = Thread::bucket_capacity(i);
            *bucket = allocate_bucket::<T>(bucket_size);
        }
        ThreadLocal {
            // SAFETY: `AtomicPtr<T>` is `#[repr(transparent)]` over `*mut T`,
            // so the array layouts match.
            buckets: unsafe { mem::transmute(buckets) },
        }
    }

    /// Loads the value for `thread`, initializing it with `T::default` on
    /// first access.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::load_or`].
    #[inline]
    pub unsafe fn load(&self, thread: Thread) -> &T
    where
        T: Default,
    {
        unsafe { self.load_or(T::default, thread) }
    }

    /// Loads the value for `thread`, initializing it with `create` on first
    /// access.
    ///
    /// # Safety
    ///
    /// `thread.bucket` and `thread.entry` must be valid indices (as produced
    /// by the `thread_id` module). NOTE(review): the Relaxed `present` load
    /// below additionally relies on each entry being written only by its
    /// owning thread — callers must not race on the same `thread` id.
    #[inline]
    pub unsafe fn load_or(&self, create: impl Fn() -> T, thread: Thread) -> &T {
        // SAFETY: `thread.bucket < BUCKETS` per the caller contract.
        let bucket = unsafe { self.buckets.get_unchecked(thread.bucket) };
        // Acquire pairs with the Release CAS in `initialize`, making the
        // freshly allocated entries visible.
        let mut bucket_ptr = bucket.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            // First access to this bucket from any thread: allocate it.
            bucket_ptr = self.initialize(bucket, thread);
        }
        // SAFETY: `thread.entry` is within this bucket's capacity.
        let entry = unsafe { &*bucket_ptr.add(thread.entry) };
        // Relaxed suffices: only the owning thread writes this flag, so there
        // is no cross-thread write to synchronize with on this path.
        if !entry.present.load(Ordering::Relaxed) {
            unsafe { self.write(entry, create) }
        }
        // SAFETY: `present` is true, so the value has been initialized.
        unsafe { (*entry.value.get()).assume_init_ref() }
    }

    /// Test-only, non-initializing lookup for the current thread's value.
    #[cfg(test)]
    fn try_load(&self) -> Option<&T> {
        let thread = Thread::current();
        let bucket_ptr =
            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            return None;
        }
        let entry = unsafe { &*bucket_ptr.add(thread.entry) };
        if !entry.present.load(Ordering::Relaxed) {
            return None;
        }
        unsafe { Some((*entry.value.get()).assume_init_ref()) }
    }

    /// Initializes `entry` with `create()` and publishes it.
    ///
    /// Cold/outlined: runs at most once per thread per `ThreadLocal`.
    #[cold]
    #[inline(never)]
    unsafe fn write(&self, entry: &Entry<T>, create: impl Fn() -> T) {
        // SAFETY: caller guarantees the slot is uninitialized and owned by
        // the current thread, so this write does not race.
        unsafe { entry.value.get().write(MaybeUninit::new(create())) };
        // Release: publish the value write to `iter`'s Acquire load of
        // `present`.
        entry.present.store(true, Ordering::Release);
        // NOTE(review): the reason for this trailing SeqCst fence is not
        // visible in this file — presumably it orders the publication against
        // unrelated SeqCst operations elsewhere; confirm before removing.
        atomic::fence(Ordering::SeqCst);
    }

    /// Allocates the bucket for `thread`, racing other threads; exactly one
    /// allocation wins and losers free their copy.
    #[cold]
    #[inline(never)]
    fn initialize(&self, bucket: &AtomicPtr<Entry<T>>, thread: Thread) -> *mut Entry<T> {
        let new_bucket = allocate_bucket(Thread::bucket_capacity(thread.bucket));
        match bucket.compare_exchange(
            ptr::null_mut(),
            new_bucket,
            // Release on success: readers that Acquire-load the pointer see
            // the initialized (all-`present == false`) entries.
            Ordering::Release,
            // Acquire on failure: synchronize with the winner's Release so
            // the returned bucket's contents are visible.
            Ordering::Acquire,
        ) {
            Ok(_) => new_bucket,
            // Lost the race: reconstruct and drop our boxed slice, then use
            // the winner's bucket.
            Err(other) => unsafe {
                let _ = Box::from_raw(ptr::slice_from_raw_parts_mut(
                    new_bucket,
                    Thread::bucket_capacity(thread.bucket),
                ));
                other
            },
        }
    }

    /// Returns an iterator over all values initialized so far (a racy
    /// snapshot: entries written after an element is passed are skipped).
    ///
    /// # Safety
    ///
    /// NOTE(review): the precise conditions making this `unsafe` (e.g. no
    /// concurrent reclamation of entries) are not visible in this file —
    /// document them at the call sites.
    #[inline]
    pub unsafe fn iter(&self) -> Iter<'_, T> {
        Iter {
            index: 0,
            bucket: 0,
            thread_local: self,
            // Bucket 0's capacity; doubled as the iterator advances buckets.
            bucket_size: Thread::bucket_capacity(0),
        }
    }
}
impl<T> Drop for ThreadLocal<T> {
fn drop(&mut self) {
for (i, bucket) in self.buckets.iter_mut().enumerate() {
let bucket_ptr = *bucket.get_mut();
if bucket_ptr.is_null() {
continue;
}
let bucket_size = Thread::bucket_capacity(i);
let _ =
unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket_ptr, bucket_size)) };
}
}
}
impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        // A slot that was never initialized holds no value to drop.
        if !*self.present.get_mut() {
            return;
        }
        // SAFETY: `present` is only set after `value` was initialized, and
        // `&mut self` guarantees exclusive access.
        unsafe { ptr::drop_in_place(self.value.get_mut().as_mut_ptr()) }
    }
}
/// Iterator over the initialized values of a [`ThreadLocal`].
pub struct Iter<'a, T> {
    // Index of the bucket currently being scanned.
    bucket: usize,
    // Index of the next entry to inspect inside that bucket.
    index: usize,
    // Capacity of the current bucket; doubled when advancing to the next.
    bucket_size: usize,
    thread_local: &'a ThreadLocal<T>,
}
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Yields the next initialized value, scanning bucket by bucket.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        while self.bucket < thread_id::BUCKETS {
            // SAFETY: the loop condition bounds `self.bucket`.
            let bucket = unsafe {
                self.thread_local
                    .buckets
                    .get_unchecked(self.bucket)
                    .load(Ordering::Acquire)
            };
            // A null bucket is merely unallocated; later buckets may still
            // exist (threads map to buckets independently), so keep scanning.
            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    // SAFETY: `self.index < bucket_size`, the bucket's
                    // allocated capacity.
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    // Acquire pairs with the Release store in `write`,
                    // making the initialized value visible before we read it.
                    if entry.present.load(Ordering::Acquire) {
                        // SAFETY: `present` is true, so the value is
                        // initialized.
                        return Some(unsafe { (*entry.value.get()).assume_init_ref() });
                    }
                }
            }
            // Advance to the next bucket, which is twice as large.
            self.index = 0;
            self.bucket += 1;
            self.bucket_size <<= 1;
        }
        None
    }
}
fn allocate_bucket<T>(capacity: usize) -> *mut Entry<T> {
let entries = (0..capacity)
.map(|_| Entry::<T> {
present: AtomicBool::new(false),
value: UnsafeCell::new(MaybeUninit::uninit()),
})
.collect::<Box<[Entry<T>]>>();
Box::into_raw(entries) as *mut _
}
#[cfg(test)]
#[allow(clippy::redundant_closure)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Returns a closure yielding 0, 1, 2, ... on successive calls, so tests
    /// can tell which call created which value.
    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    /// Repeated loads on one thread initialize exactly once.
    #[test]
    fn same_thread() {
        unsafe {
            let create = make_create();
            let tls = ThreadLocal::with_capacity(1);
            assert_eq!(None, tls.try_load());
            assert_eq!(0, *tls.load_or(|| create(), Thread::current()));
            assert_eq!(Some(&0), tls.try_load());
            assert_eq!(0, *tls.load_or(|| create(), Thread::current()));
            assert_eq!(Some(&0), tls.try_load());
            assert_eq!(0, *tls.load_or(|| create(), Thread::current()));
            assert_eq!(Some(&0), tls.try_load());
        }
    }

    /// Each thread gets its own, independently created value.
    #[test]
    fn different_thread() {
        unsafe {
            let create = make_create();
            let tls = Arc::new(ThreadLocal::with_capacity(1));
            assert_eq!(None, tls.try_load());
            assert_eq!(0, *tls.load_or(|| create(), Thread::current()));
            assert_eq!(Some(&0), tls.try_load());
            let tls2 = tls.clone();
            let create2 = create.clone();
            thread::spawn(move || {
                assert_eq!(None, tls2.try_load());
                assert_eq!(1, *tls2.load_or(|| create2(), Thread::current()));
                assert_eq!(Some(&1), tls2.try_load());
            })
            .join()
            .unwrap();
            assert_eq!(Some(&0), tls.try_load());
            assert_eq!(0, *tls.load_or(|| create(), Thread::current()));
        }
    }

    /// `iter` sees values written by threads that have since exited.
    #[test]
    fn iter() {
        unsafe {
            let tls = Arc::new(ThreadLocal::with_capacity(1));
            tls.load_or(|| Box::new(1), Thread::current());
            let tls2 = tls.clone();
            thread::spawn(move || {
                tls2.load_or(|| Box::new(2), Thread::current());
                let tls3 = tls2.clone();
                thread::spawn(move || {
                    tls3.load_or(|| Box::new(3), Thread::current());
                })
                .join()
                .unwrap();
                drop(tls2);
            })
            .join()
            .unwrap();
            let tls = Arc::try_unwrap(tls).unwrap_or_else(|_| panic!("."));
            let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
            v.sort_unstable();
            assert_eq!(vec![1, 2, 3], v);
        }
    }

    /// An iterator does not observe entries passed before they were written.
    #[test]
    fn iter_snapshot() {
        unsafe {
            let tls = Arc::new(ThreadLocal::with_capacity(1));
            tls.load_or(|| Box::new(1), Thread::current());
            let iterator = tls.iter();
            tls.load_or(|| Box::new(2), Thread::current());
            let v = iterator.map(|x| **x).collect::<Vec<i32>>();
            assert_eq!(vec![1], v);
        }
    }

    /// Dropping the `ThreadLocal` drops each initialized value exactly once.
    #[test]
    fn test_drop() {
        let local = ThreadLocal::with_capacity(1);
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }
        let dropped = Arc::new(AtomicUsize::new(0));
        unsafe {
            local.load_or(|| Dropped(dropped.clone()), Thread::current());
        }
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    /// Many concurrent writers spanning multiple buckets are all visible to
    /// `iter` once the barrier has been passed.
    #[test]
    fn iter_many() {
        let tls = Arc::new(ThreadLocal::with_capacity(0));
        let barrier = Arc::new(Barrier::new(65));
        for _ in 0..64 {
            let tls = tls.clone();
            let barrier = barrier.clone();
            thread::spawn(move || {
                unsafe {
                    tls.load_or(|| 1, Thread::current());
                }
                barrier.wait();
            });
        }
        barrier.wait();
        unsafe { assert_eq!(tls.iter().count(), 64) }
    }
}