use std::{
any::Any,
boxed::Box,
cell::{BorrowError, BorrowMutError, Cell, RefCell, UnsafeCell},
collections::btree_map::BTreeMap,
ptr,
rc::Rc,
sync::{
LazyLock,
atomic::{AtomicU32, Ordering},
},
};
use crate::executor::EXECUTOR;
/// Handle to a task-local value of type `T`.
///
/// The handle does not contain the value itself: it carries the initializer
/// and a numeric id, and `with` uses both to look the value up in (or insert
/// it into) the current `TaskLocalStorage`.
#[derive(Debug)]
pub struct LocalKey<T: 'static> {
// Produces the initial value when a storage has no entry for this key yet.
init: fn() -> T,
// Numeric id of this key, drawn lazily from a shared atomic counter on first use.
key: LazyLock<u32>,
}
// SAFETY: a `LocalKey<T>` stores only a `fn() -> T` pointer and a `LazyLock<u32>`;
// no `T` lives inside the key, so sharing the key between threads shares no `T`.
unsafe impl<T> Sync for LocalKey<T> {}
// SAFETY: same reasoning as for `Sync` — moving the key to another thread moves no `T`.
unsafe impl<T> Send for LocalKey<T> {}
/// Declares one or more task-local statics.
///
/// Each `static NAME: Type = init;` item expands to a static of type
/// `$crate::task::LocalKey<Type>` whose initializer evaluates `init`.
/// Attributes (including doc comments) and the visibility written on an item
/// are applied to the generated static.
#[macro_export]
macro_rules! task_local {
    // A single declaration (also the last step of the recursion below).
    {
        $(#[$attr:meta])*
        $vis:vis static $name:ident: $type:ty = $init:expr;
    } => {
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$type> = {
            fn init() -> $type { $init }
            $crate::task::LocalKey::new(init)
        };
    };
    // Several declarations: peel off the first one and recurse on the rest.
    {
        $(#[$attr:meta])*
        $vis:vis static $name:ident: $type:ty = $init:expr;
        $($rest:tt)*
    } => {
        // The attributes must be forwarded here, otherwise every item except
        // the last one would lose its `#[...]` attributes and doc comments.
        $crate::task_local!($(#[$attr])* $vis static $name: $type = $init;);
        $crate::task_local!($($rest)*);
    }
}
pub use task_local;
impl<T: 'static> LocalKey<T> {
    /// Creates a key whose value is produced by `init`.
    ///
    /// The numeric id is not assigned here; it is taken from a shared counter
    /// the first time the key is dereferenced.
    #[doc(hidden)]
    pub const fn new(init: fn() -> T) -> Self {
        /// Returns the next unused key id.
        fn next_key() -> u32 {
            static LOCAL_KEY_COUNTER: AtomicU32 = AtomicU32::new(0);
            LOCAL_KEY_COUNTER.fetch_add(1, Ordering::Relaxed)
        }

        Self {
            key: LazyLock::new(next_key),
            init,
        }
    }

    /// Runs `f` with a reference to this key's value in the current storage,
    /// creating the value with the key's initializer if it does not exist yet.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        TaskLocalStorage::with_current(|storage| {
            // SAFETY: the id in `self.key` comes from a counter that hands each
            // key its own value, and it is always paired with `self.init`, so the
            // slot it names holds a `T`.
            let value: &T = unsafe { storage.get_or_init(*self.key, self.init) };
            f(value)
        })
    }
}
impl<T: 'static> LocalKey<Cell<T>> {
    /// Returns a copy of the value held in the task-local cell.
    pub fn get(&'static self) -> T
    where
        T: Copy,
    {
        self.with(|cell| cell.get())
    }

    /// Stores `value` in the task-local cell, dropping the previous value.
    pub fn set(&'static self, value: T) {
        self.with(move |cell| cell.set(value))
    }

    /// Moves the value out of the task-local cell, leaving `T::default()` behind.
    pub fn take(&'static self) -> T
    where
        T: Default,
    {
        self.with(|cell| cell.take())
    }

    /// Stores `value` in the task-local cell and returns the value it replaced.
    pub fn replace(&'static self, value: T) -> T {
        self.with(move |cell| cell.replace(value))
    }
}
impl<T: 'static> LocalKey<RefCell<T>> {
    /// Calls `f` with a shared borrow of the task-local value.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently mutably borrowed.
    pub fn with_borrow<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        self.with(|cell| f(&cell.borrow()))
    }

    /// Calls `f` with a mutable borrow of the task-local value.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently borrowed.
    pub fn with_borrow_mut<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R,
    {
        self.with(|cell| f(&mut cell.borrow_mut()))
    }

    /// Non-panicking variant of `with_borrow`.
    ///
    /// # Errors
    ///
    /// Returns `BorrowError` without calling `f` if the value is currently
    /// mutably borrowed.
    pub fn try_with_borrow<F, R>(&'static self, f: F) -> Result<R, BorrowError>
    where
        F: FnOnce(&T) -> R,
    {
        self.with(|cell| cell.try_borrow().map(|value| f(&value)))
    }

    /// Non-panicking variant of `with_borrow_mut`.
    ///
    /// # Errors
    ///
    /// Returns `BorrowMutError` without calling `f` if the value is currently
    /// borrowed.
    pub fn try_with_borrow_mut<F, R>(&'static self, f: F) -> Result<R, BorrowMutError>
    where
        F: FnOnce(&mut T) -> R,
    {
        // The closure receives `&mut T`, like `with_borrow_mut`; handing out a
        // shared reference here would make the mutable borrow pointless.
        self.with(|cell| cell.try_borrow_mut().map(|mut value| f(&mut value)))
    }

    /// Overwrites the task-local value with `value`.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently borrowed.
    pub fn set(&'static self, value: T) {
        self.with_borrow_mut(|refmut| *refmut = value);
    }

    /// Moves the task-local value out, leaving `T::default()` in its place.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently borrowed.
    pub fn take(&'static self) -> T
    where
        T: Default,
    {
        self.with(RefCell::take)
    }

    /// Stores `value` and returns the value it replaced.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently borrowed.
    pub fn replace(&'static self, value: T) -> T {
        self.with(|cell| cell.replace(value))
    }
}
/// A task-local value whose concrete type has been erased, so that values of
/// different types can be kept in the same map.
struct ErasedTaskLocal {
    /// The stored value, boxed behind `dyn Any`.
    value: Box<dyn Any>,
}

impl ErasedTaskLocal {
    /// Boxes `value` and erases its static type.
    #[doc(hidden)]
    fn new<T: 'static>(value: T) -> Self {
        let value: Box<dyn Any> = Box::new(value);
        Self { value }
    }

    /// Returns the stored value as a `&T`.
    ///
    /// With debug assertions enabled the type is checked and a mismatch
    /// panics; otherwise the pointer is reinterpreted without any check.
    ///
    /// # Safety
    ///
    /// `T` must be the exact type this value was created with in `new`.
    unsafe fn get<T: 'static>(&self) -> &T {
        let erased: &dyn Any = &*self.value;
        if cfg!(debug_assertions) {
            return erased.downcast_ref::<T>().unwrap();
        }
        // SAFETY: the caller guarantees the boxed value is a `T`, so the data
        // pointer of the trait object points at a valid `T`.
        unsafe { &*ptr::from_ref(erased).cast::<T>() }
    }
}
// Per-thread storage that `TaskLocalStorage::with_current` hands out when the
// executor has no task storage installed.
thread_local! {
static FALLBACK_TLS: TaskLocalStorage = const { TaskLocalStorage::new() };
}
/// Container for task-local values, indexed by the numeric id of the
/// `LocalKey` each value belongs to.
#[derive(Debug)]
pub(crate) struct TaskLocalStorage {
// Kept in an `UnsafeCell` because `get_or_init` inserts entries through `&self`.
locals: UnsafeCell<BTreeMap<u32, ErasedTaskLocal>>,
}
impl TaskLocalStorage {
    /// Creates a storage with no values in it.
    pub(crate) const fn new() -> Self {
        Self {
            locals: UnsafeCell::new(BTreeMap::new()),
        }
    }

    /// Runs `scope` with `value` installed as the executor's current storage.
    ///
    /// Whatever was installed before is put back once `scope` finishes. The
    /// restore is done by a drop guard, so it also happens when `scope`
    /// panics and the panic unwinds through this function.
    pub(crate) fn scope(value: Rc<TaskLocalStorage>, scope: impl FnOnce()) {
        /// Reinstalls the saved storage when dropped.
        struct RestoreOnDrop(Option<Rc<TaskLocalStorage>>);

        impl Drop for RestoreOnDrop {
            fn drop(&mut self) {
                let outer_scope = self.0.take();
                EXECUTOR.with(|ex| {
                    *ex.tls.borrow_mut() = outer_scope;
                });
            }
        }

        let _restore = RestoreOnDrop(EXECUTOR.with(|ex| (*ex.tls.borrow_mut()).replace(value)));
        scope();
    }

    /// Calls `f` with the storage currently installed on the executor, or
    /// with this thread's `FALLBACK_TLS` when none is installed.
    pub(crate) fn with_current<F, R>(f: F) -> R
    where
        F: FnOnce(&Self) -> R,
    {
        EXECUTOR.with(|ex| {
            if let Some(tls) = ex.tls.borrow().as_ref() {
                f(tls)
            } else {
                FALLBACK_TLS.with(|fallback| f(fallback))
            }
        })
    }

    /// Returns the value stored under `key`, first creating it with `init`
    /// if the storage has no entry for `key`.
    ///
    /// # Safety
    ///
    /// Every call that uses a given `key` on this storage must use the same
    /// `T`: the stored value is handed back as `&T`, and that conversion is
    /// only type-checked when debug assertions are enabled.
    pub(crate) unsafe fn get_or_init<T: 'static>(&self, key: u32, init: fn() -> T) -> &T {
        let locals = self.locals.get();
        // SAFETY: `locals` comes from the `UnsafeCell` owned by `self`, so it is
        // valid for the lifetime of `&self`; the type requirement of
        // `ErasedTaskLocal::get` is forwarded to the caller (see `# Safety`).
        unsafe {
            // The map is only touched before and after `init()`, never across
            // it, so an initializer that reaches this storage again does not
            // run while a reference into the map is being held here.
            #[expect(
                clippy::map_entry,
                reason = "cannot hold mutable reference over init() call"
            )]
            if !(*locals).contains_key(&key) {
                let new_value = ErasedTaskLocal::new(init());
                (*locals).insert(key, new_value);
            }
            (*locals).get(&key).unwrap().get()
        }
    }
}