mod tid;
pub use self::tid::ThreadId;
use owned_alloc::{Cache, OwnedAlloc, UninitAlloc};
use ptr::check_null_align;
use std::{
fmt,
marker::PhantomData,
mem::{forget, replace},
ptr::{null_mut, NonNull},
sync::atomic::{AtomicPtr, Ordering::*},
};
/// Number of thread-id bits consumed per trie level; each table
/// therefore has `1 << BITS` (256) child slots.
const BITS: usize = 8;

/// Per-object thread-local storage: a lock-free trie of tables indexed
/// by successive [`ThreadId`] bit groups. Slots hold tagged pointers
/// (low bit set = subtable, clear = entry) — see `Node`/`Table` below.
pub struct ThreadLocal<T> {
    // Root table of the trie; owns the whole structure transitively.
    top: OwnedAlloc<Table<T>>,
}
impl<T> ThreadLocal<T> {
    /// Creates an empty storage.
    pub fn new() -> Self {
        // The low pointer bit is used as a tag (entry vs. subtable), so
        // the null representation must be suitably aligned for both
        // pointee types.
        check_null_align::<Table<T>>();
        check_null_align::<Entry<T>>();
        Self { top: Table::new_alloc() }
    }

    /// Removes and drops every stored entry, keeping the root table.
    pub fn clear(&mut self) {
        // Explicit stack instead of recursion to free the trie.
        let mut tables = Vec::new();
        // Safety: `&mut self` rules out concurrent readers/writers.
        unsafe { self.top.clear(&mut tables) }
        while let Some(mut table) = tables.pop() {
            unsafe { table.free_nodes(&mut tables) }
        }
    }

    /// Iterates over shared references to all currently stored values.
    pub fn iter(&self) -> Iter<T>
    where
        T: Sync,
    {
        Iter { curr_table: Some((&self.top, 0)), tables: Vec::new() }
    }

    /// Iterates over mutable references to all currently stored values.
    pub fn iter_mut(&mut self) -> IterMut<T>
    where
        T: Send,
    {
        IterMut { curr_table: Some((&mut self.top, 0)), tables: Vec::new() }
    }

    /// Gets the calling thread's entry, if present.
    #[inline]
    pub fn get(&self) -> Option<&T> {
        self.get_with_id(ThreadId::current())
    }

    /// Gets the entry associated with the given thread id, if present.
    pub fn get_with_id(&self, id: ThreadId) -> Option<&T> {
        let mut table = &*self.top;
        let mut shifted = id.bits();
        loop {
            // Each level consumes BITS bits of the id as a slot index.
            let index = shifted & (1 << BITS) - 1;
            let in_place = table.nodes[index].atomic.load(Acquire);
            if in_place.is_null() {
                // Empty slot: no entry for this id.
                break None;
            }
            if in_place as usize & 1 == 0 {
                // Untagged pointer: a leaf entry. It only matches if the
                // full id is equal (distinct ids can share a prefix).
                let entry = unsafe { &*(in_place as *mut Entry<T>) };
                break if entry.id == id {
                    Some(&entry.data)
                } else {
                    None
                };
            }
            // Tagged pointer: strip the tag and descend one level.
            let table_ptr = (in_place as usize & !1) as *mut Table<T>;
            table = unsafe { &*table_ptr };
            shifted >>= BITS;
        }
    }

    /// Returns the calling thread's entry, inserting one built from
    /// `init` if absent.
    #[inline]
    pub fn with_init<F>(&self, init: F) -> &T
    where
        F: FnOnce() -> T,
    {
        self.with_id_and_init(ThreadId::current(), init)
    }

    /// Returns the entry for `id`, inserting one built from `init` if
    /// absent. `init` runs at most once, and only if insertion is
    /// actually attempted.
    pub fn with_id_and_init<F>(&self, id: ThreadId, init: F) -> &T
    where
        F: FnOnce() -> T,
    {
        let mut table = &*self.top;
        let mut depth = 1;
        let mut shifted = id.bits();
        let mut index = shifted & (1 << BITS) - 1;
        let mut in_place = table.nodes[index].atomic.load(Acquire);
        // Defer allocating the entry until a slot actually needs it.
        let mut init = LazyInit::Pending(move || Entry { id, data: init() });
        // Cache one spare table so a lost split race can be retried
        // without reallocating.
        let mut tbl_cache = Cache::<OwnedAlloc<Table<T>>>::new();
        loop {
            if in_place.is_null() {
                // Empty slot: try to claim it with our (lazily built)
                // entry. Entry alignment guarantees a clear tag bit.
                let nnptr = init.get();
                debug_assert!(nnptr.as_ptr() as usize & 1 == 0);
                match table.nodes[index].atomic.compare_exchange(
                    in_place,
                    nnptr.as_ptr() as *mut (),
                    AcqRel,
                    Acquire,
                ) {
                    Ok(_) => {
                        break unsafe { &(*nnptr.as_ptr()).data };
                    },
                    Err(new) => in_place = new,
                }
            } else if in_place as usize & 1 == 0 {
                // Slot holds an entry.
                let entry = unsafe { &*(in_place as *mut Entry<T>) };
                if entry.id == id {
                    // Our entry is already present; on this path the
                    // initializer must not have been run.
                    debug_assert!(init.is_pending());
                    break &entry.data;
                }
                // Prefix collision with another id: split the slot by
                // moving the existing entry one level down into a fresh
                // subtable, then install that subtable.
                let new_tbl = tbl_cache.take_or(Table::new_alloc);
                let other_shifted = entry.id.bits() >> depth * BITS;
                let other_index = other_shifted & (1 << BITS) - 1;
                new_tbl.nodes[other_index].atomic.store(in_place, Relaxed);
                let new_tbl_ptr = new_tbl.into_raw();
                match table.nodes[index].atomic.compare_exchange(
                    in_place,
                    (new_tbl_ptr.as_ptr() as usize | 1) as *mut (),
                    AcqRel,
                    // BUGFIX: the failure ordering was `Release`, which
                    // `compare_exchange` forbids for failure (it panics
                    // at runtime); `Acquire` is required anyway so the
                    // freshly observed pointer may be dereferenced on
                    // the retry path.
                    Acquire,
                ) {
                    Ok(_) => {
                        // Descend into the table we just installed.
                        table = unsafe { &*new_tbl_ptr.as_ptr() };
                        depth += 1;
                        shifted >>= BITS;
                        index = shifted & (1 << BITS) - 1;
                        in_place = table.nodes[index].atomic.load(Acquire);
                    },
                    Err(new) => {
                        // Lost the race: reclaim the speculative table,
                        // clear the slot we pre-filled, cache it for a
                        // later split attempt, and retry.
                        let new_tbl =
                            unsafe { OwnedAlloc::from_raw(new_tbl_ptr) };
                        new_tbl.nodes[other_index]
                            .atomic
                            .store(null_mut(), Relaxed);
                        tbl_cache.store(new_tbl);
                        in_place = new;
                    },
                }
            } else {
                // Tagged pointer: descend into the existing subtable.
                let table_ptr = (in_place as usize & !1) as *mut Table<T>;
                table = unsafe { &*table_ptr };
                depth += 1;
                shifted >>= BITS;
                index = shifted & (1 << BITS) - 1;
                in_place = table.nodes[index].atomic.load(Acquire);
            }
        }
    }

    /// Returns the calling thread's entry, inserting `T::default()` if
    /// absent.
    #[inline]
    pub fn with_default(&self) -> &T
    where
        T: Default,
    {
        self.with_init(T::default)
    }

    /// Returns the entry for `id`, inserting `T::default()` if absent.
    #[inline]
    pub fn with_id_and_default(&self, id: ThreadId) -> &T
    where
        T: Default,
    {
        self.with_id_and_init(id, T::default)
    }
}
impl<T> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        // Free the trie iteratively with an explicit stack (avoids deep
        // recursion through nested tables).
        let mut tables = Vec::new();
        // Safety: in `drop` we hold the only reference to the storage.
        unsafe { self.top.free_nodes(&mut tables) }
        while let Some(mut table) = tables.pop() {
            unsafe { table.free_nodes(&mut tables) }
        }
    }
}
impl<T> fmt::Debug for ThreadLocal<T>
where
    T: fmt::Debug,
{
    /// Shows only the calling thread's slot (other threads' entries are
    /// not displayed).
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(fmtr, "ThreadLocal {{ storage: ")?;
        if let Some(val) = self.get() {
            write!(fmtr, "Some({:?})", val)?;
        } else {
            write!(fmtr, "None")?;
        }
        write!(fmtr, "}}")
    }
}
impl<T> Default for ThreadLocal<T> {
fn default() -> Self {
Self::new()
}
}
// SAFETY: cross-thread access to the trie is coordinated through the
// atomic tagged pointers in each node.
// NOTE(review): there is no `T: Send`/`T: Sync` bound on either impl,
// yet `get_with_id` lets one thread observe another thread's value —
// presumably these should be bounded (e.g. `where T: Send`); TODO
// confirm the intended soundness argument before relying on this.
unsafe impl<T> Send for ThreadLocal<T> {}
unsafe impl<T> Sync for ThreadLocal<T> {}
impl<T> IntoIterator for ThreadLocal<T>
where
    T: Send,
{
    type IntoIter = IntoIter<T>;
    type Item = T;

    fn into_iter(self) -> Self::IntoIter {
        // Steal the top table without running `Drop for ThreadLocal`,
        // which would otherwise free the very nodes the iterator is
        // about to consume.
        let raw = self.top.raw();
        forget(self);
        // Safety-relevant: `raw` came from the owned allocation we just
        // forgot, so reconstructing the owner does not double-own it.
        let top = unsafe { OwnedAlloc::from_raw(raw) };
        IntoIter { curr_table: Some((top, 0)), tables: Vec::new() }
    }
}
impl<'tls, T> IntoIterator for &'tls ThreadLocal<T>
where
T: Sync,
{
type IntoIter = Iter<'tls, T>;
type Item = &'tls T;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'tls, T> IntoIterator for &'tls mut ThreadLocal<T>
where
T: Send,
{
type IntoIter = IterMut<'tls, T>;
type Item = &'tls mut T;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
/// Shared-reference iterator over all entries currently stored.
pub struct Iter<'tls, T>
where
    T: 'tls,
{
    // Subtables discovered but not yet visited (depth-first order).
    tables: Vec<&'tls Table<T>>,
    // Table currently being scanned plus the next slot index in it.
    curr_table: Option<(&'tls Table<T>, usize)>,
}
impl<'tls, T> Iterator for Iter<'tls, T> {
    type Item = &'tls T;

    /// Scans the current table slot by slot; entries are yielded,
    /// subtables are stacked for a later pass, and exhausted tables are
    /// replaced by popping the stack.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (table, index) = self.curr_table.take()?;
            match table.nodes.get(index).map(|node| node.atomic.load(Acquire)) {
                // Empty slot: keep scanning this table.
                Some(ptr) if ptr.is_null() => {
                    self.curr_table = Some((table, index + 1))
                },
                // Tag bit clear: a leaf entry — yield its data.
                Some(ptr) if ptr as usize & 1 == 0 => {
                    let ptr = ptr as *mut Entry<T>;
                    self.curr_table = Some((table, index + 1));
                    break Some(unsafe { &(*ptr).data });
                },
                // Tag bit set: a subtable — queue it for traversal.
                Some(ptr) => {
                    let ptr = (ptr as usize & !1) as *mut Table<T>;
                    // BUGFIX: previously created `&mut *ptr` here, which
                    // is undefined behavior in this *shared* iterator —
                    // other threads may concurrently hold references
                    // into the same table. A shared reference suffices
                    // (the Vec stores `&'tls Table<T>`).
                    self.tables.push(unsafe { &*ptr });
                    self.curr_table = Some((table, index + 1));
                },
                // Table exhausted: pop the next pending subtable.
                None => self.curr_table = self.tables.pop().map(|tbl| (tbl, 0)),
            };
        }
    }
}
/// Mutable-reference iterator over all entries currently stored.
pub struct IterMut<'tls, T>
where
    T: 'tls,
{
    // Subtables discovered but not yet visited (depth-first order).
    tables: Vec<&'tls mut Table<T>>,
    // Table currently being scanned plus the next slot index in it.
    curr_table: Option<(&'tls mut Table<T>, usize)>,
}
impl<'tls, T> Iterator for IterMut<'tls, T> {
    type Item = &'tls mut T;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (table, index) = self.curr_table.take()?;
            // `atomic.get_mut()` is safe here: the iterator was built
            // from `&mut ThreadLocal`, so no concurrent access exists.
            match table.nodes.get_mut(index).map(|node| *node.atomic.get_mut())
            {
                // Empty slot: keep scanning this table.
                Some(ptr) if ptr.is_null() => {
                    self.curr_table = Some((table, index + 1))
                },
                // Tag bit clear: a leaf entry — yield mutable data.
                Some(ptr) if ptr as usize & 1 == 0 => {
                    let ptr = ptr as *mut Entry<T>;
                    self.curr_table = Some((table, index + 1));
                    break Some(unsafe { &mut (*ptr).data });
                },
                // Tag bit set: a subtable — queue it for traversal.
                Some(ptr) => {
                    let ptr = (ptr as usize & !1) as *mut Table<T>;
                    self.tables.push(unsafe { &mut *ptr });
                    self.curr_table = Some((table, index + 1));
                },
                // Table exhausted: pop the next pending subtable.
                None => self.curr_table = self.tables.pop().map(|tbl| (tbl, 0)),
            };
        }
    }
}
impl<'tls, T> fmt::Debug for IterMut<'tls, T> {
    /// Formats the iterator's internal traversal state.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(
            fmtr,
            "IterMut {{ tables: {:?}, curr_table: {:?} }}",
            self.tables, self.curr_table
        )
    }
}
/// Owning iterator: consumes the storage and yields values by move.
pub struct IntoIter<T> {
    // Owned subtables still pending consumption.
    tables: Vec<OwnedAlloc<Table<T>>>,
    // Owned table currently being drained plus the next slot index.
    curr_table: Option<(OwnedAlloc<Table<T>>, usize)>,
}
impl<T> Iterator for IntoIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (mut table, index) = self.curr_table.take()?;
            // `get_mut` is fine: this iterator owns the whole trie.
            match table.nodes.get_mut(index).map(|node| *node.atomic.get_mut())
            {
                // Empty slot: keep scanning this table.
                Some(ptr) if ptr.is_null() => {
                    self.curr_table = Some((table, index + 1))
                },
                // Tag bit clear: an entry — reclaim its allocation and
                // move the value out (the allocation is freed here).
                Some(ptr) if ptr as usize & 1 == 0 => {
                    let ptr = ptr as *mut Entry<T>;
                    let alloc = unsafe {
                        OwnedAlloc::from_raw(NonNull::new_unchecked(ptr))
                    };
                    let (entry, _) = alloc.move_inner();
                    self.curr_table = Some((table, index + 1));
                    break Some(entry.data);
                },
                // Tag bit set: a subtable — take ownership and queue it.
                Some(ptr) => {
                    let ptr = (ptr as usize & !1) as *mut Table<T>;
                    self.tables.push(unsafe {
                        OwnedAlloc::from_raw(NonNull::new_unchecked(ptr))
                    });
                    self.curr_table = Some((table, index + 1));
                },
                // Table exhausted (its allocation drops here); pop the
                // next pending one.
                None => self.curr_table = self.tables.pop().map(|tbl| (tbl, 0)),
            };
        }
    }
}
impl<T> fmt::Debug for IntoIter<T> {
    /// Formats the iterator's internal traversal state.
    /// BUGFIX: the type name was misreported as "IterMut" (copy-paste
    /// from the `IterMut` Debug impl); it now prints "IntoIter".
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(
            fmtr,
            "IntoIter {} tables: {:?}, curr_table: {:?} {}",
            '{', self.tables, self.curr_table, '}'
        )
    }
}
/// One slot of a table: a tagged atomic pointer. Null means empty; a
/// set low bit means the pointer is a child `Table`, a clear low bit
/// means it is an `Entry`.
struct Node<T> {
    atomic: AtomicPtr<()>,
    // Marks logical ownership of the `T` values reachable through the
    // pointer, without storing one inline.
    _marker: PhantomData<T>,
}
impl<T> Node<T> {
    /// Frees whatever `ptr` designates: null is a no-op, an entry is
    /// deallocated on the spot, and a subtable is pushed onto
    /// `tbl_stack` for the caller to free iteratively.
    ///
    /// # Safety
    /// `ptr` must be null or a (possibly tagged) pointer previously
    /// stored in a node, with no other live users of the allocation.
    unsafe fn free_ptr(
        ptr: *mut (),
        tbl_stack: &mut Vec<OwnedAlloc<Table<T>>>,
    ) {
        if ptr.is_null() {
            return;
        }
        if ptr as usize & 1 == 0 {
            // Entry: reconstruct the owner and drop it immediately,
            // freeing both the `T` and its allocation.
            OwnedAlloc::from_raw(NonNull::new_unchecked(ptr as *mut Entry<T>));
        } else {
            // Subtable: strip the tag bit and defer freeing to the
            // caller's explicit stack (avoids recursion).
            let table_ptr = (ptr as usize & !1) as *mut Table<T>;
            debug_assert!(!table_ptr.is_null());
            tbl_stack
                .push(OwnedAlloc::from_raw(NonNull::new_unchecked(table_ptr)));
        }
    }
}
impl<T> fmt::Debug for Node<T> {
    /// Formats the node's raw tagged pointer.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(fmtr, "Node {{ pointer: {:?} }}", self.atomic)
    }
}
// Alignment of at least 2 guarantees the low pointer bit of any table
// address is zero, leaving it free to serve as the entry-vs-table tag.
#[repr(align(/* at least */ 2))]
struct Table<T> {
    // One slot per possible value of a BITS-wide id chunk (256 slots).
    nodes: [Node<T>; 1 << BITS],
}
impl<T> Table<T> {
    /// Allocates a fresh table with every slot initialized to null.
    #[inline]
    fn new_alloc() -> OwnedAlloc<Self> {
        unsafe { UninitAlloc::<Self>::new().init_in_place(|this| this.init()) }
    }

    /// Writes a null node into every slot, in place.
    ///
    /// # Safety
    /// `self` may point at uninitialized memory: slots are only
    /// written (via raw `write`), never read or dropped.
    #[inline]
    unsafe fn init(&mut self) {
        for node_ref in &mut self.nodes as &mut [_] {
            (node_ref as *mut Node<T>).write(Node {
                atomic: AtomicPtr::new(null_mut()),
                _marker: PhantomData,
            })
        }
    }

    /// Frees every slot's pointee; discovered subtables are pushed onto
    /// `tbl_stack` instead of freed recursively.
    ///
    /// # Safety
    /// Caller must have exclusive access to the whole subtree.
    #[inline]
    unsafe fn free_nodes(&mut self, tbl_stack: &mut Vec<OwnedAlloc<Table<T>>>) {
        for node in &mut self.nodes as &mut [Node<T>] {
            Node::free_ptr(*node.atomic.get_mut(), tbl_stack);
        }
    }

    /// Like `free_nodes`, but also resets each slot to null so the
    /// table itself stays usable afterwards.
    ///
    /// # Safety
    /// Caller must have exclusive access to the whole subtree.
    #[inline]
    unsafe fn clear(&mut self, tbl_stack: &mut Vec<OwnedAlloc<Table<T>>>) {
        for node in &mut self.nodes as &mut [Node<T>] {
            let ptr = node.atomic.get_mut();
            Node::free_ptr(*ptr, tbl_stack);
            *ptr = null_mut();
        }
    }
}
impl<T> fmt::Debug for Table<T> {
    /// Formats all 256 slots of the table.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        let nodes: &[Node<T>] = &self.nodes;
        write!(fmtr, "Table {{ nodes: {:?} }}", nodes)
    }
}
// The 64-byte alignment also guarantees the tag bit of an entry pointer
// is zero (debug-asserted at the insertion site).
// NOTE(review): 64 presumably matches a cache line to avoid false
// sharing between threads' entries — TODO confirm; alignment of 2 would
// suffice for the tag bit alone.
#[repr(align(64))]
struct Entry<T> {
    data: T,
    id: ThreadId,
}
/// A run-at-most-once deferred allocation: `Pending` holds the
/// initializer closure, `Done` the pointer to the allocated value.
enum LazyInit<T, F> {
    Done(NonNull<T>),
    Pending(F),
}
impl<T, F> LazyInit<T, F>
where
    F: FnOnce() -> T,
{
    /// Whether the initializer has not been run yet.
    fn is_pending(&self) -> bool {
        matches!(self, LazyInit::Pending(_))
    }

    /// Returns the pointer to the allocated value, running the
    /// initializer (and allocating) on the first call only.
    fn get(&mut self) -> NonNull<T> {
        // Temporarily park a dangling pointer so the closure can be
        // moved out of `self`; `self` is overwritten again below.
        let ptr = match replace(self, LazyInit::Done(NonNull::dangling())) {
            LazyInit::Done(ptr) => ptr,
            LazyInit::Pending(init) => OwnedAlloc::new(init()).into_raw(),
        };
        *self = LazyInit::Done(ptr);
        ptr
    }
}
#[cfg(test)]
mod test {
    use super::ThreadLocal;
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    // Each thread stores its own index and must read it back unchanged,
    // even while all other threads insert concurrently.
    #[test]
    fn threads_with_their_id() {
        const THREADS: usize = 32;
        let tls = Arc::new(ThreadLocal::new());
        let mut threads = Vec::with_capacity(THREADS);
        // Barrier keeps every thread alive until all have inserted,
        // maximizing concurrent pressure on the trie.
        let barrier = Arc::new(Barrier::new(THREADS));
        for i in 0 .. THREADS {
            let tls = tls.clone();
            let barrier = barrier.clone();
            threads.push(thread::spawn(move || {
                assert_eq!(*tls.with_init(|| i), i);
                barrier.wait();
            }))
        }
        for thread in threads {
            thread.join().unwrap();
        }
    }

    // Shared iteration may run concurrently with insertion; every value
    // it observes must be one some thread actually stored.
    #[test]
    fn iter() {
        const THREADS: usize = 32;
        let tls = Arc::new(ThreadLocal::new());
        let mut threads = Vec::with_capacity(THREADS);
        let barrier = Arc::new(Barrier::new(THREADS));
        for i in 0 .. THREADS {
            let tls = tls.clone();
            let barrier = barrier.clone();
            threads.push(thread::spawn(move || {
                tls.with_init(|| i);
                barrier.wait();
            }))
        }
        for entry in &*tls {
            assert!(*entry < THREADS);
        }
    }

    // After joining all writers, mutable iteration and owning iteration
    // together must visit every stored value exactly once each.
    #[test]
    fn iter_mut() {
        const THREADS: usize = 32;
        let tls = Arc::new(ThreadLocal::new());
        let mut threads = Vec::with_capacity(THREADS);
        let barrier = Arc::new(Barrier::new(THREADS));
        for i in 0 .. THREADS {
            let tls = tls.clone();
            let barrier = barrier.clone();
            threads.push(thread::spawn(move || {
                tls.with_init(|| i);
                barrier.wait();
            }))
        }
        for thread in threads {
            thread.join().unwrap();
        }
        let mut done = [0; THREADS];
        let mut tls = Arc::try_unwrap(tls).unwrap();
        // First pass: mutate every entry through `iter_mut`.
        for entry in &mut tls {
            done[*entry] += 1;
            *entry = (*entry + 1) % THREADS;
        }
        // Second pass: consume the storage; the rotated values still
        // cover 0..THREADS, so each counter ends at exactly 2.
        for entry in tls {
            done[entry] += 1;
        }
        for &status in &done as &[_] {
            assert_eq!(status, 2);
        }
    }
}