#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::boxed::Box;
use core::cell::UnsafeCell;
use core::fmt;
use core::hint;
use core::ops::Index;
use core::sync::atomic::{AtomicUsize, Ordering};
pub struct BoundedThreadLocal<'s, T> {
storage: Storage<'s, T>,
registered: AtomicUsize,
completed: AtomicUsize,
}
unsafe impl<T> Send for BoundedThreadLocal<'_, T> {}
unsafe impl<T> Sync for BoundedThreadLocal<'_, T> {}
impl<'s, T: Default> BoundedThreadLocal<'s, T> {
#[inline]
pub fn new(max_threads: usize) -> Self {
Self::with_init(max_threads, Default::default)
}
}
impl<'s, T> BoundedThreadLocal<'s, T> {
#[inline]
pub fn with_init(max_threads: usize, init: impl Fn() -> T) -> Self {
assert!(max_threads > 0, "`max_threads` must be greater than 0");
Self {
storage: Storage::Heap(
(0..max_threads).map(|_| Local(UnsafeCell::new(Some(init())))).collect(),
),
registered: AtomicUsize::new(0),
completed: AtomicUsize::new(0),
}
}
}
impl<'s, T> BoundedThreadLocal<'s, T> {
#[inline]
pub const unsafe fn with_buffer(buf: &'s [Local<T>]) -> Self {
Self {
storage: Storage::Buffer(buf),
registered: AtomicUsize::new(0),
completed: AtomicUsize::new(0),
}
}
#[inline]
pub fn thread_token(&self) -> Result<Token<'_, '_, T>, BoundsError> {
let token: usize = self.registered.fetch_add(1, Ordering::Relaxed);
assert!(token <= isize::max_value() as usize, "thread counter close to overflow");
if token < self.storage.len() {
let local = unsafe { self.storage[token].as_ptr_unchecked() };
Ok(Token { local, tls: self })
} else {
Err(BoundsError(()))
}
}
#[inline]
pub fn try_iter(&self) -> Result<Iter<T>, ConcurrentAccessErr> {
let (completed, len) = (self.completed.load(Ordering::Relaxed), self.storage.len());
if completed == len || completed == self.registered.load(Ordering::Relaxed) {
Ok(Iter { idx: 0, tls: self })
} else {
Err(ConcurrentAccessErr(()))
}
}
}
impl<'s, T> IntoIterator for BoundedThreadLocal<'s, T> {
type Item = T;
type IntoIter = IntoIter<'s, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
IntoIter { tls: self, idx: 0 }
}
}
impl<T> fmt::Debug for BoundedThreadLocal<'_, T> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BoundedThreadLocal")
.field("max_size", &self.storage.len())
.field("access_count", &self.registered.load(Ordering::Relaxed))
.finish()
}
}
#[derive(Debug, Default)]
#[repr(align(64))]
pub struct Local<T>(UnsafeCell<Option<T>>);
unsafe impl<T> Send for Local<T> {}
unsafe impl<T> Sync for Local<T> {}
impl<T> Local<T> {
#[inline]
pub const fn new(local: T) -> Self {
Self(UnsafeCell::new(Some(local)))
}
#[inline]
unsafe fn as_ptr_unchecked(&self) -> *mut T {
(*self.0.get()).as_mut().unwrap_or_else(|| hint::unreachable_unchecked())
}
#[inline]
unsafe fn take_unchecked(&self) -> T {
(*self.0.get()).take().unwrap_or_else(|| hint::unreachable_unchecked())
}
}
pub struct Token<'s, 'tls, T> {
local: *mut T,
tls: &'tls BoundedThreadLocal<'s, T>,
}
impl<T> Token<'_, '_, T> {
#[inline]
pub fn get(&self) -> &T {
unsafe { &(*self.local) }
}
#[inline]
pub fn get_mut(&mut self) -> &mut T {
unsafe { &mut (*self.local) }
}
}
impl<T: fmt::Debug> fmt::Debug for Token<'_, '_, T> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Token").field("slot", &self.get()).finish()
}
}
impl<T: fmt::Display> fmt::Display for Token<'_, '_, T> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.get(), f)
}
}
impl<T> Drop for Token<'_, '_, T> {
fn drop(&mut self) {
self.tls.completed.fetch_add(1, Ordering::Relaxed);
}
}
#[derive(Debug)]
pub struct Iter<'s, 'tls, T> {
idx: usize,
tls: &'tls BoundedThreadLocal<'s, T>,
}
impl<'s, 'tls, T> Iterator for Iter<'s, 'tls, T> {
type Item = &'tls T;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
if idx < self.tls.storage.len() {
self.idx += 1;
let local = unsafe { &*self.tls.storage[idx].as_ptr_unchecked() };
Some(local)
} else {
None
}
}
}
#[derive(Debug)]
pub struct IntoIter<'s, T> {
idx: usize,
tls: BoundedThreadLocal<'s, T>,
}
impl<T> Iterator for IntoIter<'_, T> {
type Item = T;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
if idx < self.tls.storage.len() {
self.idx += 1;
let local = unsafe { self.tls.storage[idx].take_unchecked() };
Some(local)
} else {
None
}
}
}
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct BoundsError(());
impl fmt::Display for BoundsError {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "exceeded bounds for `BoundedThreadLocal`")
}
}
#[cfg(feature = "std")]
impl std::error::Error for BoundsError {}
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct ConcurrentAccessErr(());
impl fmt::Display for ConcurrentAccessErr {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "concurrent access from live thread token (not all tokens have yet been dropped")
}
}
#[cfg(feature = "std")]
impl std::error::Error for ConcurrentAccessErr {}
#[derive(Debug)]
enum Storage<'s, T> {
Buffer(&'s [Local<T>]),
Heap(Box<[Local<T>]>),
}
impl<T> Storage<'_, T> {
#[inline]
fn len(&self) -> usize {
match self {
Storage::Buffer(slice) => slice.len(),
Storage::Heap(boxed) => boxed.len(),
}
}
}
impl<T> Index<usize> for Storage<'_, T> {
type Output = Local<T>;
#[inline]
fn index(&self, index: usize) -> &Self::Output {
match self {
&Storage::Buffer(slice) => &slice[index],
&Storage::Heap(ref boxed) => &boxed[index],
}
}
}
#[cfg(test)]
mod tests {
extern crate std;
#[cfg(any(feature = "alloc", feature = "std"))]
use std::sync::Arc;
use std::thread;
use std::vec::Vec;
use super::BoundedThreadLocal;
use crate::Local;
#[test]
fn static_buffer() {
static BUF: [Local<usize>; 4] =
[Local::new(0), Local::new(0), Local::new(0), Local::new(0)];
static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };
let handles: Vec<_> = (0..BUF.len())
.map(|_| {
thread::spawn(move || {
let mut token = TLS.thread_token().unwrap();
for _ in 0..10 {
*token.get_mut() += 1;
}
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
assert!(TLS.try_iter().unwrap().all(|&count| count == 10));
}
#[test]
fn into_iter() {
const THREADS: usize = 4;
let tls: Arc<BoundedThreadLocal<usize>> = Arc::new(BoundedThreadLocal::new(THREADS));
let handles: Vec<_> = (0..THREADS)
.map(|_| {
let tls = Arc::clone(&tls);
thread::spawn(move || {
let mut token = tls.thread_token().unwrap();
for _ in 0..10 {
*token.get_mut() += 1;
}
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
let counter = Arc::try_unwrap(tls).unwrap();
assert_eq!(counter.into_iter().sum::<usize>(), THREADS * 10);
}
}