#![cfg_attr(not(feature = "boxed"), feature(type_alias_impl_trait))]
#![cfg_attr(test, feature(exit_status_error))]
#![cfg_attr(docsrs, feature(doc_cfg))]
// At most one async runtime feature may be enabled. Both the `tokio-runtime`
// and `async-std-runtime` features import `spawn_blocking` and `JoinHandle`
// under the same names, so enabling both would produce conflicting imports;
// this assertion turns that into a clear compile-time error instead.
assert_cfg!(not(all(
    feature = "tokio-runtime",
    feature = "async-std-runtime"
)));
extern crate self as async_local;
#[cfg(not(loom))]
use std::thread::LocalKey;
use std::{future::Future, marker::PhantomData, ops::Deref, ptr::addr_of};
#[cfg(feature = "async-std-runtime")]
use async_std::task::{spawn_blocking, JoinHandle};
pub use derive_async_local::AsContext;
#[cfg(loom)]
use loom::thread::LocalKey;
use shutdown_barrier::guard_thread_shutdown;
use static_assertions::assert_cfg;
#[cfg(feature = "tokio-runtime")]
use tokio::task::{spawn_blocking, JoinHandle};
/// Wrapper marking a `Sync` value as something a thread local can hand out by
/// pointer.
///
/// [`LocalRef`] and [`RefGuard`] both store a `*const Context<T>`, and
/// `Context<T>` dereferences to the wrapped `T`.
pub struct Context<T: Sync>(T);
impl<T> Context<T>
where
    T: Sync,
{
    /// Wraps `inner` in a `Context`.
    ///
    /// This is a `const fn`, so a `Context` can also be built in `const` and
    /// `static` initializers and in `thread_local!` `const { ... }` initializers.
    pub const fn new(inner: T) -> Context<T> {
        Context(inner)
    }
}
impl<T: Sync> AsRef<Context<T>> for Context<T> {
    /// Returns `self`: a `Context` is trivially its own context, which lets
    /// `Context<T>` satisfy the `AsRef` supertrait of [`AsContext`].
    fn as_ref(&self) -> &Self {
        self
    }
}
impl<T> Deref for Context<T>
where
T: Sync,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Types that can be stored in a thread local and accessed through
/// [`AsyncLocal`].
///
/// An implementor exposes a [`Context`] of its `Target` through the `AsRef`
/// supertrait. The `AsyncLocal` impl for `LocalKey` takes the address of that
/// `Context` and hands it out as a [`LocalRef`] or [`RefGuard`].
/// `derive_async_local::AsContext` is re-exported alongside this trait
/// (presumably a derive macro for it — confirm).
///
/// # Safety
///
/// NOTE(review): the exact contract is not spelled out in this file. Handles
/// store a raw pointer to the `Context` returned by `as_ref`, so
/// implementations presumably must return a `Context` that lives inside
/// `self` (not a temporary or a value stored elsewhere). That keeps the pointer
/// valid for as long as the thread-local value is. Confirm against the derive
/// macro.
pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
    /// The value type wrapped by the exposed [`Context`]. It must be `Sync`
    /// because handles to it are `Send`/`Sync` and are shared across threads.
    type Target: Sync;
}
// SAFETY: `Context<T>` is its own context (`as_ref` returns `self`), so any
// pointer taken from it points at the value itself.
// NOTE(review): confirm this satisfies the full `AsContext` contract.
unsafe impl<T: Sync> AsContext for Context<T> {
    type Target = T;
}
/// A raw, lifetime-free handle to a [`Context<T>`] owned by a thread local.
///
/// `LocalRef` stores only a `*const Context<T>`. It is `Copy`, and it is
/// `Send`/`Sync` whenever `T: Sync`, so it can be moved into other tasks or
/// threads. Nothing in the type ties it to the lifetime of the pointee, which
/// is why the functions that create one from a context are `unsafe`.
///
/// The derived comparison, `Hash` and `Debug` impls operate on the pointer
/// address, not on the pointed-to value. As usual for derives, they also
/// require the corresponding trait on `T`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct LocalRef<T: Sync + 'static>(*const Context<T>);
impl<T> LocalRef<T>
where
T: Sync + 'static,
{
unsafe fn new(context: &Context<T>) -> Self {
guard_thread_shutdown();
LocalRef(addr_of!(*context))
}
pub unsafe fn guarded_ref<'a>(&self) -> RefGuard<'a, T> {
RefGuard {
inner: self.0,
_marker: PhantomData,
}
}
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "tokio-runtime", feature = "async-std-runtime")))
)]
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
where
F: for<'a> FnOnce(&'a LocalRef<T>) -> R + Send + 'static,
R: Send + 'static,
{
spawn_blocking(move || f(&self))
}
}
impl<T> Deref for LocalRef<T>
where
T: Sync,
{
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { (*self.0).deref() }
}
}
impl<T> Clone for LocalRef<T>
where
    T: Sync + 'static,
{
    /// Copies the handle. Both copies point at the same context.
    fn clone(&self) -> Self {
        // `LocalRef` is `Copy`, so the canonical `Clone` is a plain copy
        // (clippy::non_canonical_clone_impl).
        *self
    }
}
// Implemented by hand rather than derived so that it does not require
// `T: Copy`; only the pointer is duplicated.
impl<T> Copy for LocalRef<T> where T: Sync + 'static {}
// SAFETY: `LocalRef` only exposes shared `&T` access to the context (through
// `Deref`), and `&T` may cross threads exactly when `T: Sync`. Keeping the
// pointee alive is a separate obligation, placed on whoever created the
// handle through an `unsafe` constructor.
unsafe impl<T> Send for LocalRef<T> where T: Sync {}
unsafe impl<T> Sync for LocalRef<T> where T: Sync {}
/// A lifetime-branded handle to a [`Context<T>`] owned by a thread local.
///
/// Like [`LocalRef`], it stores only a `*const Context<T>`, is `Copy`, and is
/// `Send`/`Sync` when `T: Sync`. Unlike `LocalRef`, it also carries a lifetime
/// `'a`.
///
/// The derived comparison, `Hash` and `Debug` impls operate on the pointer
/// address, not on the pointed-to value.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct RefGuard<'a, T: Sync + 'static> {
    // Pointer to the thread-local context. Its validity is the responsibility
    // of whoever created the guard through an `unsafe` constructor.
    inner: *const Context<T>,
    // Zero-sized marker. `fn(&'a ()) -> &'a ()` is contravariant in its
    // argument and covariant in its return, which makes `RefGuard` invariant in
    // `'a`: subtyping can neither shorten nor lengthen the guard's lifetime.
    _marker: PhantomData<fn(&'a ()) -> &'a ()>,
}
impl<'a, T> RefGuard<'a, T>
where
    T: Sync + 'static,
{
    /// Creates a guard pointing at `context`.
    ///
    /// Calls `shutdown_barrier::guard_thread_shutdown` before capturing the
    /// pointer. NOTE(review): presumably this holds back thread teardown while
    /// handles may be alive; confirm.
    ///
    /// # Safety
    ///
    /// `'a` is chosen by the caller and is not tied to the borrow of `context`.
    /// The caller must ensure `context` stays alive for all of `'a`.
    unsafe fn new(context: &Context<T>) -> Self {
        guard_thread_shutdown();
        RefGuard {
            inner: addr_of!(*context),
            _marker: PhantomData,
        }
    }

    /// Moves this guard onto the runtime's blocking thread pool
    /// (`spawn_blocking`), runs `f` with it there, and returns the task's
    /// `JoinHandle`.
    ///
    /// `spawn_blocking` requires a `'static` closure, so the guard handed to
    /// `f` is re-branded with the `'static` lifetime. The context's validity
    /// for the duration of the task therefore rests on the `unsafe` contract
    /// under which this guard was created.
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "tokio-runtime", feature = "async-std-runtime")))
    )]
    #[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
    pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
    where
        F: for<'b> FnOnce(RefGuard<'b, T>) -> R + Send + 'static,
        R: Send + 'static,
    {
        // Rebuild the guard field-by-field with an explicit `'static` brand.
        // This is exactly what a lifetime-only `mem::transmute` would do, but it
        // needs no `unsafe` and cannot silently change the source or target type
        // through inference.
        let ref_guard: RefGuard<'static, T> = RefGuard {
            inner: self.inner,
            _marker: PhantomData,
        };
        spawn_blocking(move || f(ref_guard))
    }
}
impl<'a, T> Deref for RefGuard<'a, T>
where
T: Sync,
{
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { (*self.inner).deref() }
}
}
impl<T> Clone for RefGuard<'_, T>
where
    T: Sync + 'static,
{
    /// Copies the guard. The copy points at the same context and carries the
    /// same lifetime.
    fn clone(&self) -> Self {
        // `RefGuard` is `Copy`, so the canonical `Clone` is a plain copy
        // (clippy::non_canonical_clone_impl).
        *self
    }
}
// Implemented by hand rather than derived so that it does not require
// `T: Copy`; only the pointer and the zero-sized marker are duplicated.
impl<T> Copy for RefGuard<'_, T> where T: Sync + 'static {}
// SAFETY: `RefGuard` only exposes shared `&T` access to the context (through
// `Deref`), and `&T` may cross threads exactly when `T: Sync`. The
// `PhantomData<fn(..)>` marker is itself `Send + Sync`; only the raw pointer
// field requires these manual impls. Keeping the pointee alive is the
// obligation of whoever created the guard through an `unsafe` constructor.
unsafe impl<'a, T> Send for RefGuard<'a, T> where T: Sync {}
unsafe impl<'a, T> Sync for RefGuard<'a, T> where T: Sync {}
/// Access to a thread local's [`Context`] from async code.
///
/// Implemented in this crate for `LocalKey<T>` (std's, or loom's under
/// `cfg(loom)`) where `T: AsContext`. Each method hands out a handle
/// ([`RefGuard`] or [`LocalRef`]) to the current thread's value rather than a
/// borrow scoped to a closure, so the handle can be held across `.await` points
/// or moved to other threads. The `LocalKey` implementation debug-asserts that
/// `T` does not need to be dropped.
#[async_t::async_trait]
pub trait AsyncLocal<T>
where
    T: 'static + AsContext,
{
    /// Calls `f` with a [`RefGuard`] to the current thread's context and awaits
    /// the future it returns.
    ///
    /// The guard is branded with the `'async_trait` lifetime introduced by the
    /// `async_t::async_trait` macro.
    async fn with_async<F, R, Fut>(&'static self, f: F) -> R
    where
        F: FnOnce(RefGuard<'async_trait, T::Target>) -> Fut + Send,
        Fut: Future<Output = R> + Send;
    /// Runs `f` on the runtime's blocking thread pool (`spawn_blocking`) with a
    /// [`RefGuard`] to the *calling* thread's context, and returns the task's
    /// `JoinHandle`.
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "tokio-runtime", feature = "async-std-runtime")))
    )]
    #[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
    fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
    where
        F: for<'a> FnOnce(RefGuard<'a, T::Target>) -> R + Send + 'static,
        R: Send + 'static;
    /// Returns a lifetime-free [`LocalRef`] to the current thread's context.
    ///
    /// # Safety
    ///
    /// The returned handle carries no lifetime. The caller must not use it, or
    /// any copy of it, after the thread-local value it points to is destroyed.
    /// NOTE(review): the `LocalKey` implementation goes through `LocalRef::new`,
    /// which calls `shutdown_barrier::guard_thread_shutdown`. Presumably that
    /// holds back thread teardown; confirm the exact guarantee before relying on
    /// it.
    unsafe fn local_ref(&'static self) -> LocalRef<T::Target>;
    /// Returns a [`RefGuard`] to the current thread's context, branded with a
    /// caller-chosen lifetime `'a`.
    ///
    /// # Safety
    ///
    /// `'a` is unbounded. The caller must ensure the thread-local value stays
    /// alive for every use of the guard within `'a`.
    unsafe fn guarded_ref<'a>(&'static self) -> RefGuard<'a, T::Target>;
}
#[async_t::async_trait]
impl<T> AsyncLocal<T> for LocalKey<T>
where
T: AsContext + 'static,
{
async fn with_async<F, R, Fut>(&'static self, f: F) -> R
where
F: FnOnce(RefGuard<'async_trait, T::Target>) -> Fut + Send,
Fut: Future<Output = R> + Send,
{
let local_ref = unsafe { self.guarded_ref() };
f(local_ref).await
}
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "tokio-runtime", feature = "async-std-runtime")))
)]
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
where
F: for<'a> FnOnce(RefGuard<'a, T::Target>) -> R + Send + 'static,
R: Send + 'static,
{
let guarded_ref = unsafe { self.guarded_ref() };
spawn_blocking(move || f(guarded_ref))
}
unsafe fn local_ref(&'static self) -> LocalRef<T::Target> {
debug_assert!(
!std::mem::needs_drop::<T>(),
"AsyncLocal cannot be used with thread locals types that impl std::ops::Drop"
);
self.with(|value| LocalRef::new(value.as_ref()))
}
unsafe fn guarded_ref<'a>(&'static self) -> RefGuard<'a, T::Target> {
debug_assert!(
!std::mem::needs_drop::<T>(),
"AsyncLocal cannot be used with thread locals types that impl std::ops::Drop"
);
self.with(|value| RefGuard::new(value.as_ref()))
}
}
// Unit tests. Runtime-specific tests are gated on the selected runtime
// feature, and the `cfg(loom)` test runs under the loom model checker.
// `cfg(test)` replaces the redundant single-predicate `cfg(all(test))`
// (clippy::non_minimal_cfg).
#[cfg(test)]
mod tests {
    #[cfg(not(loom))]
    use std::sync::atomic::AtomicUsize;
    #[cfg(all(
        any(feature = "tokio-runtime", feature = "async-std-runtime"),
        not(loom)
    ))]
    use std::sync::atomic::Ordering;
    #[cfg(feature = "async-std-runtime")]
    use async_std::task::yield_now;
    #[cfg(loom)]
    use loom::{
        sync::atomic::{AtomicUsize, Ordering},
        thread_local,
    };
    #[cfg(feature = "tokio-runtime")]
    use tokio::task::yield_now;
    use super::*;
    thread_local! {
        static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
    }
    #[cfg(all(not(loom), feature = "tokio-runtime"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn with_blocking() {
        COUNTER
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await
            .unwrap();
        let guarded_ref = unsafe { COUNTER.guarded_ref() };
        guarded_ref
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await
            .unwrap();
        let local_ref = unsafe { COUNTER.local_ref() };
        local_ref
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await
            .unwrap();
    }
    #[cfg(all(not(loom), feature = "async-std-runtime"))]
    #[async_std::test]
    async fn with_blocking() {
        COUNTER
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await;
        let guarded_ref = unsafe { COUNTER.guarded_ref() };
        guarded_ref
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await;
        let local_ref = unsafe { COUNTER.local_ref() };
        local_ref
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await;
    }
    // NOTE(review): `ContextGuard` is not defined or imported anywhere in this
    // file, so this test does not compile under `--cfg loom` as written. Without
    // such a guard, `drop(counter)` below could race with the spawned thread's
    // use of `local_ref`. Restore the guard type or rewrite the test.
    #[cfg(loom)]
    #[test]
    fn guard_protects_context() {
        loom::model(|| {
            let counter = Context::new(AtomicUsize::new(0));
            let local_ref = unsafe { LocalRef::new(&counter) };
            let guard = ContextGuard::new(addr_of!(counter));
            loom::thread::spawn(move || {
                let count = local_ref.fetch_add(1, Ordering::Relaxed);
                assert_eq!(count, 0);
                drop(guard);
            });
            drop(counter);
        });
    }
    #[cfg(all(
        not(loom),
        any(feature = "tokio-runtime", feature = "async-std-runtime")
    ))]
    #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(feature = "async-std-runtime", async_std::test)]
    async fn ref_spans_await() {
        let counter = unsafe { COUNTER.local_ref() };
        yield_now().await;
        counter.fetch_add(1, Ordering::SeqCst);
    }
    #[cfg(all(
        not(loom),
        any(feature = "tokio-runtime", feature = "async-std-runtime")
    ))]
    #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(feature = "async-std-runtime", async_std::test)]
    async fn with_async() {
        COUNTER
            .with_async(|counter| async move {
                yield_now().await;
                counter.fetch_add(1, Ordering::Release);
            })
            .await;
    }
    #[cfg(all(
        not(loom),
        any(feature = "tokio-runtime", feature = "async-std-runtime")
    ))]
    #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(feature = "async-std-runtime", async_std::test)]
    async fn bound_to_async_trait_lifetime() {
        struct Counter;
        #[async_t::async_trait]
        trait Countable {
            #[allow(clippy::needless_lifetimes)]
            async fn add_one(ref_guard: RefGuard<'async_trait, AtomicUsize>) -> usize;
        }
        #[async_t::async_trait]
        impl Countable for Counter {
            async fn add_one(counter: RefGuard<'async_trait, AtomicUsize>) -> usize {
                yield_now().await;
                counter.fetch_add(1, Ordering::Release)
            }
        }
        let counter = unsafe { COUNTER.guarded_ref() };
        Counter::add_one(counter).await;
    }
}