use pin_project_lite::pin_project;
use std::cell::RefCell;
use std::error::Error;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem, thread};
/// Declares one or more task-local keys.
///
/// Every `static NAME: Type` declaration (optionally preceded by attributes
/// and a visibility) is forwarded to `__task_local_inner!`, which defines a
/// `static` named `NAME` of type `LocalKey<Type>`. Declarations are separated
/// by `;`; the `;` after the last one is optional.
///
/// # Examples
///
/// ```ignore
/// task_local! {
///     pub static REQUEST_ID: u64;
///     static USER: String;
/// }
/// ```
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // Nothing left to declare: this ends the recursion.
    () => {};

    // A declaration terminated by `;`: emit it, then recurse on what follows.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty; $($tail:tt)*) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
        $crate::task_local!($($tail)*);
    };

    // A final declaration written without the trailing `;`.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
    }
}
// Implementation detail of `task_local!`: expands a single declaration into a
// `static` of type `$crate::task::LocalKey<$t>`.
//
// This definition is compiled when the `tokio_no_const_thread_local` cfg is
// NOT set; it initializes the backing thread-local with a `const { .. }`
// block. A second definition with the same name covers the opposite cfg.
#[doc(hidden)]
#[cfg(not(tokio_no_const_thread_local))]
#[macro_export]
macro_rules! __task_local_inner {
// Input: optional attributes, a visibility, the key name, and the value type.
($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
$(#[$attr])*
$vis static $name: $crate::task::LocalKey<$t> = {
// Backing storage: one `RefCell<Option<$t>>` per thread, starting out as
// `None` (no value set).
std::thread_local! {
static __KEY: std::cell::RefCell<Option<$t>> = const { std::cell::RefCell::new(None) };
}
// The key is built with a struct literal, which is why `LocalKey::inner`
// has to be a public field.
$crate::task::LocalKey { inner: __KEY }
};
};
}
// Implementation detail of `task_local!`: expands a single declaration into a
// `static` of type `$crate::task::LocalKey<$t>`.
//
// This definition is compiled only when the `tokio_no_const_thread_local` cfg
// is set; unlike its sibling it initializes the backing thread-local with a
// plain expression instead of a `const { .. }` block.
#[doc(hidden)]
#[cfg(tokio_no_const_thread_local)]
#[macro_export]
macro_rules! __task_local_inner {
// Input: optional attributes, a visibility, the key name, and the value type.
($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
$(#[$attr])*
$vis static $name: $crate::task::LocalKey<$t> = {
// Backing storage: one `RefCell<Option<$t>>` per thread, starting out as
// `None` (no value set).
std::thread_local! {
static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None);
}
// The key is built with a struct literal, which is why `LocalKey::inner`
// has to be a public field.
$crate::task::LocalKey { inner: __KEY }
};
};
}
/// A key for task-local data of type `T`.
///
/// Keys are declared with the `task_local!` macro. A value is attached to a
/// key with [`LocalKey::scope`] (for a future) or [`LocalKey::sync_scope`]
/// (for a closure) and read back with [`LocalKey::with`] or
/// [`LocalKey::try_with`].
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
// Thread-local cell holding the value that is currently in scope, or `None`
// when no scope is active. Public only because the `__task_local_inner!`
// expansion constructs `LocalKey` with a struct literal; hidden from docs.
#[doc(hidden)]
pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
impl<T: 'static> LocalKey<T> {
pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
where
F: Future,
{
TaskLocalFuture {
local: self,
slot: Some(value),
future: Some(f),
_pinned: PhantomPinned,
}
}
#[track_caller]
pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
where
F: FnOnce() -> R,
{
let mut value = Some(value);
match self.scope_inner(&mut value, f) {
Ok(res) => res,
Err(err) => err.panic(),
}
}
fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
where
F: FnOnce() -> R,
{
struct Guard<'a, T: 'static> {
local: &'static LocalKey<T>,
slot: &'a mut Option<T>,
}
impl<'a, T: 'static> Drop for Guard<'a, T> {
fn drop(&mut self) {
self.local.inner.with(|inner| {
let mut ref_mut = inner.borrow_mut();
mem::swap(self.slot, &mut *ref_mut);
});
}
}
self.inner.try_with(|inner| {
inner
.try_borrow_mut()
.map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
})??;
let guard = Guard { local: self, slot };
let res = f();
drop(guard);
Ok(res)
}
#[track_caller]
pub fn with<F, R>(&'static self, f: F) -> R
where
F: FnOnce(&T) -> R,
{
match self.try_with(f) {
Ok(res) => res,
Err(_) => panic!("cannot access a task-local storage value without setting it first"),
}
}
pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
where
F: FnOnce(&T) -> R,
{
let try_with_res = self.inner.try_with(|v| {
v.borrow().as_ref().map(f)
});
match try_with_res {
Ok(Some(res)) => Ok(res),
Ok(None) | Err(_) => Err(AccessError { _private: () }),
}
}
}
impl<T: Clone + 'static> LocalKey<T> {
    /// Returns a clone of the value currently stored in this key.
    ///
    /// Only `T: Clone` is required, so this works for cheap `Copy` values as
    /// well as for types such as `String` or `Arc<_>`.
    ///
    /// # Panics
    ///
    /// Panics if no value is available for this key (see [`LocalKey::with`]).
    #[track_caller]
    pub fn get(&'static self) -> T {
        self.with(T::clone)
    }
}
impl<T: 'static> fmt::Debug for LocalKey<T> {
    /// Writes the fixed placeholder `LocalKey { .. }`; the stored value is
    /// never inspected, so no bound on `T` is needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `str`'s `Display` impl pads according to the formatter's
        // width/precision flags.
        fmt::Display::fmt("LocalKey { .. }", f)
    }
}
pin_project! {
/// Future returned by `LocalKey::scope`.
///
/// Every time it is polled, the stored value is placed into the task-local
/// key for the duration of that poll of the wrapped future `F`, and taken
/// back out afterwards.
pub struct TaskLocalFuture<T, F>
where
T: 'static,
{
// Key that receives the value while the wrapped future runs.
local: &'static LocalKey<T>,
// Holds the value between polls; `scope_inner` swaps it with the content of
// the thread-local around each poll.
slot: Option<T>,
// The wrapped future. Reset to `None` once it has completed (in `poll`) or
// has been dropped early (in `PinnedDrop` below).
#[pin]
future: Option<F>,
// `PhantomPinned` does not implement `Unpin`; presumably present so that the
// whole future is never `Unpin`.
#[pin]
_pinned: PhantomPinned,
}
impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
fn drop(this: Pin<&mut Self>) {
let this = this.project();
// Skip the work when `F` has no drop glue or the future is already gone.
if mem::needs_drop::<F>() && this.future.is_some() {
// Drop the wrapped future while the task-local value is set, so its
// destructor can still read it. If the scope cannot be entered the error is
// ignored and the future is dropped by the regular field drop instead.
let mut future = this.future;
let _ = this.local.scope_inner(this.slot, || {
future.set(None);
});
}
}
}
}
impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
    type Output = F::Output;

    /// Polls the wrapped future with the task-local value installed.
    ///
    /// Once the wrapped future reports `Ready` it is dropped while the value
    /// is still installed.
    ///
    /// # Panics
    ///
    /// Panics if polled again after completion, if the task-local storage is
    /// currently borrowed, or if the underlying thread-local is no longer
    /// accessible.
    #[track_caller]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let mut inner = projected.future;

        let outcome = projected.local.scope_inner(projected.slot, || {
            // `None` here means the wrapped future was already consumed.
            let fut = inner.as_mut().as_pin_mut()?;
            let polled = fut.poll(cx);
            if polled.is_ready() {
                // Drop the finished future now, inside the scope.
                inner.set(None);
            }
            Some(polled)
        });

        // Panics stay in this body so `#[track_caller]` attributes them to
        // whoever polled us.
        match outcome {
            Ok(Some(polled)) => polled,
            Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
            Err(failure) => failure.panic(),
        }
    }
}
impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
where
    T: fmt::Debug,
{
    /// Formats as `TaskLocalFuture { value: .. }`, showing the stored value
    /// directly, or `<missing>` when the slot is currently empty. The wrapped
    /// future is not printed, so `F` needs no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Prints the inner value without the `Some(..)` wrapper.
        struct SlotView<'a, T>(Option<&'a T>);

        impl<'a, T: fmt::Debug> fmt::Debug for SlotView<'a, T> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                if let Some(inner) = self.0 {
                    return fmt::Debug::fmt(inner, f);
                }
                f.pad("<missing>")
            }
        }

        let mut out = f.debug_struct("TaskLocalFuture");
        out.field("value", &SlotView(self.slot.as_ref()));
        out.finish()
    }
}
/// Error returned by `LocalKey::try_with` when no task-local value is
/// available.
///
/// The private unit field prevents construction outside this module.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct AccessError {
    _private: (),
}

impl fmt::Debug for AccessError {
    /// Prints just the type name; the private field is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    /// Prints a fixed message, honouring width/precision flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("task-local value not set")
    }
}

impl Error for AccessError {}
/// Reasons why `LocalKey::scope_inner` could not install a value.
enum ScopeInnerErr {
    /// The thread-local's `RefCell` was already borrowed.
    BorrowError,
    /// The thread-local itself could not be accessed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with the message matching this error.
    ///
    /// `#[track_caller]` makes the reported location that of the caller.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
            ScopeInnerErr::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
        }
    }
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    /// A failed `try_borrow_mut` maps to [`ScopeInnerErr::BorrowError`].
    fn from(_source: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

impl From<std::thread::AccessError> for ScopeInnerErr {
    /// A failed thread-local `try_with` maps to [`ScopeInnerErr::AccessError`].
    fn from(_source: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}