#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::boxed::Box;
use core::cell::UnsafeCell;
use core::fmt;
use core::marker::PhantomData;
use core::ops::Index;
use core::sync::atomic::{AtomicUsize, Ordering};
/// A single slot that can back a [`BoundedThreadLocal`] built with
/// [`BoundedThreadLocal::with_buffer`] or [`BoundedThreadLocal::with_buffer_and_init`].
///
/// A freshly created slot is empty; the container fills it when it is constructed.
#[derive(Debug)]
// Over-aligned to 128 bytes, presumably so that neighbouring slots claimed by different threads
// do not share a cache line (false sharing).
#[repr(align(128))]
pub struct Local<T>(UnsafeCell<Option<T>>);
impl<T> Default for Local<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T> Local<T> {
#[inline]
pub const fn new() -> Self {
Self(UnsafeCell::new(None))
}
}
/// A fixed-capacity container holding one `T` per slot, with slots meant to be claimed one per
/// thread.
///
/// Each call to [`BoundedThreadLocal::thread_token`] claims the next unused slot and returns a
/// [`Token`] with exclusive access to it. Once every slot is claimed, further calls fail with
/// [`BoundsError`]. Slots are never handed back.
pub struct BoundedThreadLocal<'s, T> {
    // Backing slots: either borrowed from the caller or heap-allocated.
    storage: Storage<'s, T>,
    // Number of `thread_token` calls so far (including failed ones). The value returned by each
    // `fetch_add` is the index of the slot being claimed, so it may exceed `storage.len()`.
    registered: AtomicUsize,
}
// SAFETY: `thread_token` gives each caller exclusive (`&mut`) access to a *distinct* slot.
// `iter` needs `&mut self` and `into_iter` consumes `self`. So a `T` may be created on one
// thread, mutated on another, and dropped or moved out on a third, but it is never accessed
// through shared references from several threads at once. That requires `T: Send` for both
// traits (the same bound `Mutex<T>` uses) and never `T: Sync`.
// Without the bound, a non-`Send` type such as `Rc<_>` could be cloned by `init` into several
// slots and then mutated from different threads, racing on its non-atomic reference count.
unsafe impl<T: Send> Send for BoundedThreadLocal<'_, T> {}
unsafe impl<T: Send> Sync for BoundedThreadLocal<'_, T> {}
impl<'s, T: Default> BoundedThreadLocal<'s, T> {
    /// Creates a container over `buffer`, filling every slot with `T::default()`.
    #[inline]
    pub fn with_buffer(buffer: &'s [Local<T>]) -> Self {
        let fill = T::default;
        Self::with_buffer_and_init(buffer, fill)
    }

    /// Creates a heap-backed container with `max_threads` slots, each set to `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `max_threads` is zero.
    #[cfg(any(feature = "alloc", feature = "std"))]
    #[inline]
    pub fn new(max_threads: usize) -> Self {
        let fill = T::default;
        Self::with_init(max_threads, fill)
    }
}
impl<'s, T> BoundedThreadLocal<'s, T> {
#[inline]
pub fn with_buffer_and_init(buffer: &'s [Local<T>], init: impl Fn() -> T) -> Self {
for local in buffer {
let slot = unsafe { &mut *local.0.get() };
*slot = Some(init());
}
Self { storage: Storage::Buffer(buffer), registered: AtomicUsize::new(0) }
}
#[inline]
pub fn with_init(max_threads: usize, init: impl Fn() -> T) -> Self {
assert!(max_threads > 0, "`max_threads` must be greater than 0");
Self {
storage: Storage::Heap(
(0..max_threads).map(|_| Local(UnsafeCell::new(Some(init())))).collect(),
),
registered: Default::default(),
}
}
#[inline]
pub fn thread_token(&self) -> Result<Token<'_, T>, BoundsError> {
let token: usize = self.registered.fetch_add(1, Ordering::Relaxed);
assert!(token <= isize::max_value() as usize, "thread counter too close to overflow");
if token < self.storage.len() {
let slot = &self.storage[token].0;
let local = unsafe { (&mut *slot.get()).as_mut().unwrap_or_else(|| unreachable!()) };
Ok(Token { local, _marker: PhantomData })
} else {
Err(BoundsError(()))
}
}
#[inline]
pub fn iter(&mut self) -> IterMut<'s, '_, T> {
IterMut { idx: 0, tls: self }
}
}
impl<'s, T> IntoIterator for BoundedThreadLocal<'s, T> {
    type Item = T;
    type IntoIter = IntoIter<'s, T>;

    /// Consumes the container and yields each slot's value by value, in slot order.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let tls = self;
        IntoIter { tls, idx: 0 }
    }
}
impl<T> fmt::Debug for BoundedThreadLocal<'_, T> {
    /// Shows the slot capacity (`max_size`) and the number of `thread_token` calls so far
    /// (`access_count`). Slot values are not printed, so `T` need not be `Debug`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let max_size = self.storage.len();
        let access_count = self.registered.load(Ordering::SeqCst);
        let mut builder = f.debug_struct("BoundedThreadLocal");
        builder.field("max_size", &max_size);
        builder.field("access_count", &access_count);
        builder.finish()
    }
}
/// Exclusive handle to one claimed slot of a [`BoundedThreadLocal`], obtained from
/// [`BoundedThreadLocal::thread_token`].
pub struct Token<'a, T> {
    // The claimed slot's value.
    local: &'a mut T,
    // `*const ()` makes `Token` neither `Send` nor `Sync`, presumably to keep the handle on the
    // thread that claimed the slot.
    _marker: PhantomData<*const ()>,
}
impl<'a, T> Token<'a, T> {
    /// Returns a shared reference to this token's slot value.
    #[inline]
    pub fn get(&self) -> &T {
        &*self.local
    }

    /// Runs `func` with exclusive access to this token's slot value.
    #[inline]
    pub fn update(&mut self, func: impl FnOnce(&mut T)) {
        let value: &mut T = &mut *self.local;
        func(value)
    }
}
impl<'a, T: fmt::Debug> fmt::Debug for Token<'a, T> {
    /// Formats as `Token { slot: <value> }`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let slot: &T = self.get();
        f.debug_struct("Token").field("slot", slot).finish()
    }
}
impl<'a, T: fmt::Display> fmt::Display for Token<'a, T> {
    /// Displays the slot value exactly as `T` displays itself, honouring the caller's
    /// formatting flags.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <T as fmt::Display>::fmt(self.get(), f)
    }
}
/// Error returned by [`BoundedThreadLocal::thread_token`] once every slot has been claimed.
#[derive(Copy, Clone, Hash, Eq, Ord, PartialEq, PartialOrd)]
// The private `()` field prevents construction outside this module.
pub struct BoundsError(());
impl fmt::Debug for BoundsError {
    /// Prints the bare type name, `BoundsError` (the private unit field is not shown).
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("BoundsError")
    }
}
impl fmt::Display for BoundsError {
    /// Human-readable message explaining that the container's capacity was exhausted.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("exceeded bounds for `BoundedThreadLocal`")
    }
}
// With the `std` feature, `BoundsError` implements `std::error::Error` (relying on its `Debug`
// and `Display` impls), so it can be used as e.g. a `Box<dyn Error>`.
#[cfg(feature = "std")]
impl std::error::Error for BoundsError {}
/// Iterator over the values of a [`BoundedThreadLocal`], created by
/// [`BoundedThreadLocal::iter`].
///
/// Despite its name it yields shared references (`&T`). The mutable borrow of the container
/// only guarantees that no [`Token`] is alive during iteration.
#[derive(Debug)]
pub struct IterMut<'s, 'tls, T> {
    // Index of the next slot to yield.
    idx: usize,
    // Exclusive borrow of the container being iterated.
    tls: &'tls mut BoundedThreadLocal<'s, T>,
}
impl<'s, 'tls, T> Iterator for IterMut<'s, 'tls, T> {
type Item = &'tls T;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
if idx < self.tls.storage.len() {
self.idx += 1;
let local = &self.tls.storage[idx];
let slot = unsafe { &*local.0.get() };
Some(slot.as_ref().unwrap_or_else(|| unreachable!()))
} else {
None
}
}
}
/// Owning iterator that moves every value out of a [`BoundedThreadLocal`], in slot order.
///
/// Each yielded value is taken out of its slot, leaving the slot empty. For a buffer-backed
/// container, that means the caller's buffer slots are emptied.
#[derive(Debug)]
pub struct IntoIter<'s, T> {
    // Index of the next slot to take from.
    idx: usize,
    // The consumed container.
    tls: BoundedThreadLocal<'s, T>,
}
impl<T> Iterator for IntoIter<'_, T> {
type Item = T;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
if idx < self.tls.storage.len() {
self.idx += 1;
let local = &self.tls.storage[idx];
let slot = unsafe { &mut *local.0.get() };
Some(slot.take().unwrap_or_else(|| unreachable!()))
} else {
None
}
}
}
/// Backing memory for the slots of a [`BoundedThreadLocal`].
#[derive(Debug)]
enum Storage<'s, T> {
    /// Slots borrowed from the caller (the `with_buffer*` constructors).
    Buffer(&'s [Local<T>]),
    /// Heap-allocated slots (`with_init` / `new`); only present with the `alloc` or `std`
    /// feature.
    #[cfg(any(feature = "alloc", feature = "std"))]
    Heap(Box<[Local<T>]>),
}
impl<T> Storage<'_, T> {
#[inline]
fn len(&self) -> usize {
match self {
Storage::Buffer(slice) => slice.len(),
#[cfg(any(feature = "alloc", feature = "std"))]
Storage::Heap(boxed) => boxed.len(),
}
}
}
impl<T> Index<usize> for Storage<'_, T> {
    type Output = Local<T>;

    /// Returns the slot at `index`, panicking exactly like slice indexing when out of range.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        let slots: &[Local<T>] = match self {
            Storage::Buffer(slice) => *slice,
            #[cfg(any(feature = "alloc", feature = "std"))]
            Storage::Heap(boxed) => &**boxed,
        };
        &slots[index]
    }
}
#[cfg(test)]
mod tests {
    use super::{BoundedThreadLocal, Local};

    /// Heap-backed container shared across real threads: all token updates must be visible
    /// once the container is consumed. Needs `alloc`/`std`, like `BoundedThreadLocal::new`.
    #[cfg(any(feature = "alloc", feature = "std"))]
    #[test]
    fn into_iter() {
        use std::sync::Arc;
        use std::thread;

        const THREADS: usize = 4;
        let tls: Arc<BoundedThreadLocal<usize>> = Arc::new(BoundedThreadLocal::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let tls = Arc::clone(&tls);
                thread::spawn(move || {
                    let mut token = tls.thread_token().unwrap();
                    for _ in 0..10 {
                        token.update(|curr| *curr += 1);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let counter = Arc::try_unwrap(tls).unwrap();
        assert_eq!(counter.into_iter().sum::<usize>(), THREADS * 10);
    }

    /// Buffer-backed container (no allocation): tokens get distinct slots in index order, the
    /// capacity is enforced, and `iter` observes the updates.
    #[test]
    fn buffer_tokens_and_bounds() {
        let buffer: [Local<usize>; 2] = [Local::new(), Local::new()];
        let mut tls = BoundedThreadLocal::with_buffer(&buffer);
        {
            let mut first = tls.thread_token().unwrap();
            let second = tls.thread_token().unwrap();
            first.update(|v| *v += 3);
            assert_eq!(*first.get(), 3);
            assert_eq!(*second.get(), 0);
            assert!(tls.thread_token().is_err());
        }
        let mut values = tls.iter();
        assert_eq!(values.next(), Some(&3));
        assert_eq!(values.next(), Some(&0));
        assert_eq!(values.next(), None);
    }
}