#![deny(missing_docs)]
use core::fmt;
use core::ptr::NonNull;
use std::boxed::Box;
use std::error::Error;
#[cfg(windows)]
mod oskey {
    //! Fiber-local storage (FLS) backend used on Windows targets.
    use winapi::um::fibersapi;

    /// Index of an FLS slot, as returned by `FlsAlloc`.
    pub(crate) type Key = winapi::shared::minwindef::DWORD;

    // Lower-case on purpose so the alias reads like the C type it stands for.
    #[allow(non_camel_case_types)]
    pub(crate) type c_void = winapi::ctypes::c_void;

    /// Value `FlsAlloc` returns when no FLS index is available
    /// (`FLS_OUT_OF_INDEXES` in the Windows SDK).
    const FLS_OUT_OF_INDEXES: Key = 0xFFFF_FFFF;

    /// Allocates a new FLS slot, registering `dtor` as its callback.
    ///
    /// # Panics
    /// Panics if the system has no free FLS index.
    #[inline]
    pub(crate) unsafe fn create(dtor: Option<unsafe extern "system" fn(*mut c_void)>) -> Key {
        let key = fibersapi::FlsAlloc(dtor);
        // Without this check the failure sentinel would be handed out as a
        // real key: `FlsSetValue` on it cannot succeed, so
        // `ThreadLocal::try_with` would re-run the initialiser and leak a
        // freshly boxed value on every access.
        assert_ne!(key, FLS_OUT_OF_INDEXES, "FlsAlloc failed: out of FLS indexes");
        key
    }

    /// Stores `value` in the calling fiber's slot for `key`.
    /// Failure is only checked in debug builds.
    #[inline]
    pub(crate) unsafe fn set(key: Key, value: *mut c_void) {
        let r = fibersapi::FlsSetValue(key, value);
        debug_assert_ne!(r, 0);
    }

    /// Reads the calling fiber's value for `key`.
    #[inline]
    pub(crate) unsafe fn get(key: Key) -> *mut c_void {
        fibersapi::FlsGetValue(key)
    }

    /// Releases `key`. Failure is only checked in debug builds.
    #[inline]
    pub(crate) unsafe fn destroy(key: Key) {
        let r = fibersapi::FlsFree(key);
        debug_assert_ne!(r, 0);
    }
}
#[cfg(not(windows))]
mod oskey {
    //! POSIX thread-specific-data backend built on `pthread_key_*`.
    use core::mem::{self, MaybeUninit};

    /// Handle of a pthread thread-specific-data key.
    pub(crate) type Key = libc::pthread_key_t;

    // Lower-case on purpose so the alias reads like the C type it stands for.
    #[allow(non_camel_case_types)]
    pub(crate) type c_void = core::ffi::c_void;

    /// Creates a key, registering `dtor` as its destructor.
    ///
    /// # Panics
    /// Panics if `pthread_key_create` returns a non-zero status.
    #[inline]
    pub(crate) unsafe fn create(dtor: Option<unsafe extern "system" fn(*mut c_void)>) -> Key {
        let mut slot = MaybeUninit::<Key>::uninit();
        // libc expects an `extern "C"` callback. Per the Rust reference,
        // `extern "system"` is the C ABI on non-Windows targets, so only the
        // nominal function-pointer type changes here.
        let status = libc::pthread_key_create(slot.as_mut_ptr(), mem::transmute(dtor));
        assert_eq!(status, 0);
        slot.assume_init()
    }

    /// Stores `value` for the calling thread under `key`.
    /// The return code is only checked in debug builds.
    #[inline]
    pub(crate) unsafe fn set(key: Key, value: *mut c_void) {
        let status = libc::pthread_setspecific(key, value);
        debug_assert_eq!(status, 0);
    }

    /// Reads the calling thread's value for `key`.
    #[inline]
    pub(crate) unsafe fn get(key: Key) -> *mut c_void {
        libc::pthread_getspecific(key)
    }

    /// Deletes `key`. The return code is only checked in debug builds.
    #[inline]
    pub(crate) unsafe fn destroy(key: Key) {
        let status = libc::pthread_key_delete(key);
        debug_assert_eq!(status, 0);
    }
}
use oskey::c_void;
/// A per-thread value stored under an OS-level key: fiber-local storage on
/// Windows, `pthread` thread-specific data elsewhere.
///
/// When a thread reaches an empty slot, it builds its own `T` with the
/// initialiser given to [`ThreadLocal::new`]. The key's destructor drops that
/// value when the OS runs destructors for the thread.
pub struct ThreadLocal<T> {
    /// Initialiser run when the calling thread's slot is empty.
    init: fn() -> T,
    /// OS key under which each thread stores a pointer to its boxed value.
    key: oskey::Key,
}
impl<T: Default> Default for ThreadLocal<T> {
fn default() -> Self {
ThreadLocal::new(Default::default)
}
}
impl<T> fmt::Debug for ThreadLocal<T> {
    /// Formats as `ThreadLocal { .. }` without reading any thread's value.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `pad` writes its argument verbatim; it is not a format string.
        // Doubled braces would be printed literally as `{{ .. }}`.
        f.pad("ThreadLocal { .. }")
    }
}
/// Error returned by [`ThreadLocal::try_with`] when the calling thread's
/// value is being torn down by the key destructor.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct AccessError {
    // Private field so the error can only be constructed inside this module.
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment/precision flags, exactly as `Display for str` does.
        f.pad("already destroyed")
    }
}

impl Error for AccessError {}
/// Heap cell that a thread's slot points to.
///
/// The key is stored next to the value because `thread_local_drop` only
/// receives the raw pointer and must reset the matching slot.
struct ThreadLocalValue<T> {
    /// Copy of the owning `ThreadLocal`'s key.
    key: oskey::Key,
    /// The per-thread value itself.
    inner: T,
}
/// Sentinel that `thread_local_drop` stores in a thread's slot right before
/// dropping that thread's value. `try_with` answers it with `AccessError`
/// instead of building a new value.
///
/// `NonNull::dangling()` is non-null and aligned for `c_void` (presumably
/// address 1). It is assumed never to equal a boxed `ThreadLocalValue`,
/// whose allocation is aligned at least to its `key` field.
const GUARD: NonNull<c_void> = NonNull::<c_void>::dangling();
/// Key destructor registered by `ThreadLocal::new`. It receives the raw value
/// the OS held in the slot of the thread whose destructors are running.
///
/// # Safety
/// `ptr` must be null, `GUARD`, or a pointer from
/// `Box::into_raw(Box<ThreadLocalValue<T>>)` that nothing else will free.
unsafe extern "system" fn thread_local_drop<T>(ptr: *mut c_void) {
    // A null value means there is nothing to free. Checking it avoids handing
    // null to `NonNull::new_unchecked`, which would be undefined behaviour.
    let ptr = match NonNull::new(ptr as *mut ThreadLocalValue<T>) {
        Some(ptr) => ptr,
        None => return,
    };
    if ptr == GUARD.cast() {
        // The slot still holds the sentinel stored below on an earlier call;
        // the value has already been dropped.
        return;
    }
    let value = Box::from_raw(ptr.as_ptr());
    // Publish the sentinel before `value` is dropped, so code running inside
    // `T::drop` sees `AccessError` rather than re-initialising the slot.
    oskey::set(value.key, GUARD.as_ptr());
    drop(value);
}
impl<T> ThreadLocal<T> {
    /// Creates a thread-local whose per-thread value is built by `f` when a
    /// thread first finds its slot empty.
    ///
    /// The OS key is allocated immediately. It is registered with a destructor
    /// that frees a thread's value when the OS runs key destructors.
    ///
    /// # Panics
    /// On non-Windows targets, panics if `pthread_key_create` fails.
    pub fn new(f: fn() -> T) -> Self {
        let key = unsafe { oskey::create(Some(thread_local_drop::<T>)) };
        Self { key, init: f }
    }

    /// Runs `f` with a reference to the calling thread's value, creating the
    /// value first if the slot is empty.
    ///
    /// # Panics
    /// Panics if [`try_with`](Self::try_with) would return `AccessError`.
    pub fn with<R, F: FnOnce(&T) -> R>(&self, f: F) -> R {
        self.try_with(f)
            .expect("cannot access a TLS value during or after it is destroyed")
    }

    /// Runs `f` with a reference to the calling thread's value, creating the
    /// value first if the slot is empty.
    ///
    /// # Errors
    /// Returns `AccessError`, without calling `f`, when the slot holds the
    /// sentinel the key destructor stores right before dropping this thread's
    /// value. For example, this happens when called from inside `T`'s own `Drop`.
    ///
    /// NOTE(review): the slot is filled only after `init` returns. An
    /// initialiser that accesses this same `ThreadLocal` therefore re-enters
    /// initialisation, presumably without bound.
    pub fn try_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Result<R, AccessError> {
        // SAFETY: `self.key` is created in `new` and only destroyed in `Drop`,
        // so it is a live key for as long as `&self` exists.
        let raw = unsafe { oskey::get(self.key) } as *mut ThreadLocalValue<T>;
        let slot = match NonNull::new(raw) {
            Some(existing) => existing,
            None => {
                // Empty slot: build the value, leak the box, and hand ownership
                // to the OS slot; `thread_local_drop` reclaims it later.
                let boxed = Box::new(ThreadLocalValue {
                    inner: (self.init)(),
                    key: self.key,
                });
                let fresh = unsafe { NonNull::new_unchecked(Box::into_raw(boxed)) };
                unsafe { oskey::set(self.key, fresh.as_ptr() as *mut _) };
                fresh
            }
        };
        if slot == GUARD.cast() {
            return Err(AccessError { _private: () });
        }
        // SAFETY: `slot` is not the sentinel, so it points to a live boxed
        // `ThreadLocalValue<T>` owned by this thread's slot.
        Ok(f(unsafe { &slot.as_ref().inner }))
    }
}
impl<T> Drop for ThreadLocal<T> {
    /// Releases the OS key.
    ///
    /// NOTE(review): POSIX `pthread_key_delete` does not invoke key
    /// destructors. On non-Windows targets, values still stored by live
    /// threads (including the calling one) are therefore presumably leaked.
    /// Confirm this is intended.
    fn drop(&mut self) {
        // SAFETY: the key was allocated in `new` and is released exactly once, here.
        unsafe { oskey::destroy(self.key) }
    }
}
#[cfg(test)]
pub(crate) mod tests {
    //! Behavioural tests for `ThreadLocal`: the OS-key properties it relies on,
    //! per-thread isolation, and destructor behaviour at thread exit.
    use super::ThreadLocal;
    use core::cell::{Cell, UnsafeCell};
    use crossbeam_utils::thread::scope;
    use once_cell::sync::Lazy;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::RwLock;
    use std::thread;

    // Shared test lock. `assumptions` and `dtors_in_dtors_in_dtors` take it for
    // writing and so run exclusively; every other test takes a read lock.
    // Exclusivity presumably matters most for `assumptions`, which asserts that
    // a freed key is handed out again by the next `create`.
    pub static LOCK: Lazy<RwLock<()>> = Lazy::new(|| RwLock::new(()));

    // Checks the raw `oskey` properties `ThreadLocal` depends on.
    #[test]
    fn assumptions() {
        use super::oskey;
        use core::ptr::{self, NonNull};
        use core::sync::atomic::{AtomicBool, Ordering};
        let _l = LOCK.write().unwrap();
        static CALLED: AtomicBool = AtomicBool::new(false);
        unsafe extern "system" fn call(_: *mut oskey::c_void) {
            CALLED.store(true, Ordering::Release);
        }
        unsafe {
            // A fresh key reads as null and round-trips a stored pointer.
            let key = oskey::create(None);
            assert_eq!(oskey::get(key), ptr::null_mut());
            oskey::set(key, NonNull::dangling().as_ptr());
            assert_eq!(oskey::get(key), NonNull::dangling().as_ptr());
            oskey::destroy(key);
            // The freed key is reused, and the reused key reads as null again
            // instead of exposing the value stored before `destroy`.
            let key2 = oskey::create(None);
            assert_eq!(key, key2);
            assert_eq!(oskey::get(key), ptr::null_mut());
            oskey::destroy(key2);
            let key = oskey::create(Some(call));
            scope(|s| {
                // A thread that only reads the key does not trigger its destructor.
                s.spawn(|_| {
                    oskey::get(key);
                })
                .join()
                .unwrap();
                assert_eq!(CALLED.load(Ordering::Acquire), false);
                // A non-null value at thread exit does trigger it.
                s.spawn(|_| {
                    oskey::set(key, NonNull::dangling().as_ptr());
                })
                .join()
                .unwrap();
                assert_eq!(CALLED.load(Ordering::Acquire), true);
                CALLED.store(false, Ordering::Release);
                // Resetting the value to null before exit suppresses it again.
                s.spawn(|_| {
                    oskey::set(key, NonNull::dangling().as_ptr());
                    oskey::set(key, ptr::null_mut());
                })
                .join()
                .unwrap();
                assert_eq!(CALLED.load(Ordering::Acquire), false);
            })
            .unwrap();
        }
    }

    // Sends `()` on its channel when dropped, so a test can observe destruction.
    struct Foo(Sender<()>);
    impl Drop for Foo {
        fn drop(&mut self) {
            let Foo(ref s) = *self;
            s.send(()).unwrap();
        }
    }

    // A value stored by a spawned thread is dropped when that thread exits.
    #[test]
    fn smoke_dtor() {
        let _l = LOCK.read().unwrap();
        let foo = ThreadLocal::new(|| UnsafeCell::new(None));
        scope(|s| {
            let foo = &foo;
            let (tx, rx) = channel();
            let _t = s.spawn(move |_| unsafe {
                let mut tx = Some(tx);
                foo.with(|f| {
                    *f.get() = Some(Foo(tx.take().unwrap()));
                });
            });
            // Receives the `()` sent by `Foo::drop` in the spawned thread.
            rx.recv().unwrap();
        })
        .unwrap();
    }

    // Each thread gets its own, independently initialised value.
    #[test]
    fn smoke_no_dtor() {
        let _l = LOCK.read().unwrap();
        let foo = ThreadLocal::new(|| Cell::new(1));
        scope(|s| {
            let foo = &foo;
            foo.with(|f| {
                assert_eq!(f.get(), 1);
                f.set(2);
            });
            let (tx, rx) = channel();
            let _t = s.spawn(move |_| {
                // The spawned thread does not see the main thread's `2`.
                foo.with(|f| {
                    assert_eq!(f.get(), 1);
                });
                tx.send(()).unwrap();
            });
            rx.recv().unwrap();
            foo.with(|f| {
                assert_eq!(f.get(), 2);
            });
        })
        .unwrap();
    }

    // While a value is being dropped at thread exit, `try_with` on the same
    // key reports an error.
    #[test]
    fn states() {
        let _l = LOCK.read().unwrap();
        struct Foo;
        impl Drop for Foo {
            fn drop(&mut self) {
                assert!(FOO.try_with(|_| ()).is_err());
            }
        }
        static FOO: Lazy<ThreadLocal<Foo>> = Lazy::new(|| ThreadLocal::new(|| Foo));
        thread::spawn(|| {
            assert!(FOO.try_with(|_| ()).is_ok());
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // Two keys whose values' destructors store values into each other.
    // Sequence: the explicit `drop(S1)` (HITS = 1) stores an `S2` in K2.
    // Dropping that `S2` at thread exit (HITS = 2) stores an `S1` in K1.
    // Dropping that `S1` (HITS = 3) may find K2 either still holding the
    // teardown sentinel (`try_with` is `Err`) or already cleared and
    // re-initialised (`Ok`), so both branches expect HITS == 3.
    #[test]
    fn circular() {
        let _l = LOCK.read().unwrap();
        struct S1;
        struct S2;
        static K1: Lazy<ThreadLocal<UnsafeCell<Option<S1>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));
        static K2: Lazy<ThreadLocal<UnsafeCell<Option<S2>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));
        static mut HITS: u32 = 0;
        impl Drop for S1 {
            fn drop(&mut self) {
                unsafe {
                    HITS += 1;
                    if K2.try_with(|_| ()).is_err() {
                        assert_eq!(HITS, 3);
                    } else {
                        if HITS == 1 {
                            K2.with(|s| *s.get() = Some(S2));
                        } else {
                            assert_eq!(HITS, 3);
                        }
                    }
                }
            }
        }
        impl Drop for S2 {
            fn drop(&mut self) {
                unsafe {
                    HITS += 1;
                    assert!(K1.try_with(|_| ()).is_ok());
                    assert_eq!(HITS, 2);
                    K1.with(|s| *s.get() = Some(S1));
                }
            }
        }
        thread::spawn(move || {
            drop(S1);
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // A value's destructor cannot access its own key.
    #[test]
    fn self_referential() {
        let _l = LOCK.read().unwrap();
        struct S1;
        static K1: Lazy<ThreadLocal<UnsafeCell<Option<S1>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));
        impl Drop for S1 {
            fn drop(&mut self) {
                assert!(K1.try_with(|_| ()).is_err());
            }
        }
        thread::spawn(move || unsafe {
            K1.with(|s| *s.get() = Some(S1));
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // A destructor for one key stores a new value into another key; that
    // second value must itself be dropped later during thread exit.
    // `rx.recv()` succeeds only if the `Foo` stored from `S1::drop` is dropped
    // and sends. Otherwise every sender is dropped and `recv` returns `Err`.
    #[test]
    fn dtors_in_dtors_in_dtors() {
        let _l = LOCK.write().unwrap();
        struct S1(Sender<()>);
        static K: Lazy<(
            ThreadLocal<UnsafeCell<Option<S1>>>,
            ThreadLocal<UnsafeCell<Option<Foo>>>,
        )> = Lazy::new(|| {
            (
                ThreadLocal::new(|| UnsafeCell::new(None)),
                ThreadLocal::new(|| UnsafeCell::new(None)),
            )
        });
        impl Drop for S1 {
            fn drop(&mut self) {
                let S1(ref tx) = *self;
                unsafe {
                    let _ = K.1.try_with(|s| *s.get() = Some(Foo(tx.clone())));
                }
            }
        }
        let (tx, rx) = channel();
        let _t = thread::spawn(move || unsafe {
            let mut tx = Some(tx);
            K.0.with(|s| *s.get() = Some(S1(tx.take().unwrap())));
        });
        rx.recv().unwrap();
    }
}
#[cfg(test)]
mod dynamic_tests {
    //! Single-threaded checks that `ThreadLocal` works with initialisers that
    //! produce various kinds of values.
    use super::tests::LOCK;
    use super::ThreadLocal;
    use core::cell::RefCell;
    use std::collections::HashMap;
    use std::vec;

    /// The initialiser may call other functions; `with` sees its result.
    #[test]
    fn smoke() {
        let _guard = LOCK.read().unwrap();
        fn square(x: i32) -> i32 {
            x * x
        }
        let squared = ThreadLocal::new(|| square(3));
        squared.with(|value| assert_eq!(*value, 9));
    }

    /// A collection built by the initialiser is visible through `with`.
    #[test]
    fn hashmap() {
        let _guard = LOCK.read().unwrap();
        fn map() -> RefCell<HashMap<i32, i32>> {
            let mut table = HashMap::new();
            table.insert(1, 2);
            RefCell::new(table)
        }
        let cached = ThreadLocal::new(map);
        cached.with(|m| assert_eq!(m.borrow()[&1], 2));
    }

    /// Interior mutability lets `with` mutate the per-thread value in place.
    #[test]
    fn refcell_vec() {
        let _guard = LOCK.read().unwrap();
        let list = ThreadLocal::new(|| RefCell::new(vec![1, 2, 3]));
        list.with(|cell| {
            assert_eq!(cell.borrow().len(), 3);
            cell.borrow_mut().push(4);
            assert_eq!(cell.borrow()[3], 4);
        });
    }
}