use core::any::{TypeId, type_name};
use core::cell::UnsafeCell;
use core::error::Error;
use core::fmt::{Debug, Display, Formatter, Result as FmtResult};
use core::marker::PhantomData;
use core::mem::{MaybeUninit, take, transmute_copy};
use core::ptr::{NonNull, copy_nonoverlapping};
use rustc_hash::{FxBuildHasher, FxHashMap};
/// A type-erased shared reference: the raw bytes of a `&T`, thin or wide.
///
/// `ptr` receives the data-pointer word and `meta` the pointer metadata (slice
/// length, vtable pointer, ...). For a thin `&T` only `ptr` is written and
/// `meta` is left uninitialised. The concrete `T` is not recorded, so whoever
/// reads the value back must already know it.
///
/// NOTE(review): this assumes wide references are laid out data-pointer first,
/// which holds for current rustc but is not a documented language guarantee.
/// If it ever changed, a zero length (empty slice) could land in the `NonNull` field.
#[repr(C)]
#[derive(Copy, Clone)]
struct ErasedRef {
    ptr: NonNull<()>,
    meta: MaybeUninit<usize>,
}
impl ErasedRef {
    /// Erases `value` into its raw reference bytes.
    const fn from_ref<T: ?Sized>(value: &T) -> Self {
        // Compile-time guard: the byte copy below would write past the end of
        // `Self` if a reference ever needed more than two words.
        const { assert!(size_of::<&T>() <= size_of::<Self>()) };
        let mut this = MaybeUninit::<Self>::uninit();
        // SAFETY: the source is exactly the `size_of::<&T>()` bytes of the local
        // `value`; the destination has at least that many bytes (asserted above);
        // the two locals do not overlap, and `u8` pointers need no alignment.
        unsafe {
            copy_nonoverlapping(
                (&raw const value).cast::<u8>(),
                (&raw mut this).cast::<u8>(),
                size_of::<&T>(),
            )
        };
        // SAFETY: the leading word came from a reference, so `ptr` is non-null;
        // `meta` is `MaybeUninit` and may legitimately stay uninitialised.
        unsafe { this.assume_init() }
    }
    /// Reinterprets the stored bytes as a `&'a T`.
    ///
    /// # Safety
    ///
    /// `self` must have been produced by `from_ref::<T>` for this exact `T`, and
    /// the original referent must remain valid (and not mutably aliased) for all
    /// of `'a`.
    pub(crate) const unsafe fn as_ref<'a, T: ?Sized>(self) -> &'a T {
        // Same size invariant as `from_ref`, checked at compile time.
        const { assert!(size_of::<&T>() <= size_of::<Self>()) };
        // SAFETY: per the contract, the leading `size_of::<&T>()` bytes are an
        // initialised `&T` written by `from_ref::<T>`; `transmute_copy` reads
        // exactly those bytes and nothing more.
        unsafe { transmute_copy::<Self, &'a T>(&self) }
    }
}
/// Map from a value's `TypeId` to the erased reference that provides it.
type ErasedRefMap = FxHashMap<TypeId, ErasedRef>;
/// Interior-mutable holder for a context map.
///
/// Wrapping `UnsafeCell` makes this type `!Sync`, so a `&ContextCell` can never
/// be shared across threads.
#[repr(transparent)]
pub struct ContextCell(UnsafeCell<ErasedRefMap>);
impl ContextCell {
    /// Creates a cell holding an empty map; `const` so it can seed a thread-local.
    const fn new() -> Self {
        let empty = FxHashMap::with_hasher(FxBuildHasher);
        Self(UnsafeCell::new(empty))
    }
    /// Runs `f` with a mutable borrow of the map that ends when `f` returns.
    fn with_map_mut<R>(&self, f: impl FnOnce(&mut ErasedRefMap) -> R) -> R {
        // SAFETY: the cell is `!Sync`, so only the current thread can reach it.
        // None of the methods here lets a borrow of the map escape, and the
        // closures passed below run no user code (TypeId hashing, Copy values).
        // Therefore no other borrow of the map overlaps this one.
        f(unsafe { &mut *self.0.get() })
    }
    /// Stores `value` under `key`, overwriting any previous entry for that key.
    fn insert(&self, key: TypeId, value: ErasedRef) {
        self.with_map_mut(|map| {
            map.insert(key, value);
        });
    }
    /// Moves the map out when it has entries, leaving an empty map in its place.
    /// Returns `None` (and leaves the map untouched) when it is already empty.
    ///
    /// NOTE(review): despite its name, this takes the map only when it is *not*
    /// empty.
    fn take_if_empty(&self) -> Option<ErasedRefMap> {
        self.with_map_mut(|map| if map.is_empty() { None } else { Some(take(map)) })
    }
    /// Replaces the whole map with `states`, dropping the current one.
    fn replace_with(&self, states: ErasedRefMap) {
        self.with_map_mut(move |map| *map = states);
    }
    /// Removes every entry from the map.
    fn clear(&self) {
        self.with_map_mut(|map| map.clear());
    }
    /// Returns a copy of the entry stored under `key`, if any.
    fn get(&self, key: TypeId) -> Option<ErasedRef> {
        // SAFETY: the shared borrow lasts only for this lookup, and no mutable
        // borrow can be live at the same time (see `with_map_mut`).
        unsafe { (*self.0.get()).get(&key) }.copied()
    }
}
/// A value that can be installed as the current thread's context.
///
/// Implemented for `()` (installs nothing), for `&T` with `T: 'static`
/// (installs the reference keyed by `TypeId::of::<T>()`), and for tuples of up
/// to 12 contexts.
pub trait Context: Sized {
    /// Writes this context's references into `cx`. Implementation detail.
    #[doc(hidden)]
    fn __insert(self, cx: &ContextCell);
}
impl Context for () {
    // The unit context contributes no entries.
    fn __insert(self, _cx: &ContextCell) {}
}
impl<T> Context for &T
where
    T: ?Sized + 'static,
{
    fn __insert(self, cx: &ContextCell) {
        // Keyed by the referent's type; a later `&T` of the same `T`
        // overwrites an earlier one.
        let key = TypeId::of::<T>();
        let erased = ErasedRef::from_ref(self);
        cx.insert(key, erased);
    }
}
// Invokes `$macro` once for every tuple arity from 1 to 12.
//
// Entry arm (just a macro name): emits the 1-tuple `(T,)` invocation carrying
// the `fake_variadic` docs (applied under `docsrs`), then recurses over the
// twelve identifiers `A`..`L`.
//
// Recursive arm: first recurses on the tail, then emits the invocation for the
// full current list. Arities 2 through 12 are therefore generated, each marked
// `doc(hidden)` under `docsrs` so rustdoc shows only the variadic stand-in.
//
// Single-identifier arm: stops the recursion without emitting anything; the
// 1-tuple was already produced by the entry arm.
macro_rules! for_all_tuples {
    ($macro:ident $t:ident $($ts:ident)+) => {
        for_all_tuples!($macro $($ts)+);
        $macro!(#[cfg_attr(docsrs, doc(hidden))] $t $($ts)+);
    };
    ($macro:ident $t:ident) => {};
    ($macro:ident) => {
        $macro!(#[cfg_attr(docsrs, doc(fake_variadic), doc = "This trait is implemented for tuples up to 12 items long.")] T);
        for_all_tuples!($macro A B C D E F G H I J K L);
    };
}
// Implements `Context` for a tuple whose element types are the given
// identifiers, applying `$attr` to the impl. Inserting the tuple inserts each
// element in order, so when two `&T` elements share a `TypeId` the later one
// overwrites the earlier.
macro_rules! impl_context {
    (#[$attr:meta] $($t:ident)*) => {
        #[$attr]
        impl<$($t: Context,)*> Context for ($($t,)*) {
            // The type-parameter identifiers are reused as binding names when
            // destructuring `self`, hence the upper-case locals.
            #[allow(non_snake_case)]
            fn __insert(self, cx: &ContextCell) {
                let ($($t,)*) = self;
                $($t.__insert(cx);)*
            }
        }
    };
}
// Generate `Context` impls for tuples of 1 to 12 elements.
for_all_tuples!(impl_context);
/// Error returned when a requested type has no entry in the current context.
///
/// No `T` is stored. The marker is `fn() -> *const T` rather than `T` so that
/// the error is always `Send + Sync`, even for `!Send` types such as
/// `dyn Trait`. That lets it be boxed as `Box<dyn Error + Send + Sync>`. The
/// marker stays covariant in `T`.
pub struct MissingContextError<T: ?Sized>(PhantomData<fn() -> *const T>);
impl<T: ?Sized> Debug for MissingContextError<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.debug_tuple("MissingContextError").finish()
    }
}
impl<T: ?Sized> Display for MissingContextError<T> {
    /// Names the missing type via `type_name`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "type \"{}\" is missing from context", type_name::<T>())
    }
}
impl<T: ?Sized> Error for MissingContextError<T> {}
/// A cheap, copyable handle for reading entries from a `ContextCell`.
#[repr(transparent)]
#[derive(Copy, Clone)]
pub struct ContextRef<'cx>(&'cx ContextCell);
impl<'cx> ContextRef<'cx> {
    /// Looks up the reference stored for type `T`.
    ///
    /// # Errors
    ///
    /// Returns `MissingContextError<T>` when nothing is stored under
    /// `TypeId::of::<T>()`.
    pub fn get<T: 'static + ?Sized>(self) -> Result<&'cx T, MissingContextError<T>> {
        match self.0.get(TypeId::of::<T>()) {
            // SAFETY: in this file, entries under `TypeId::of::<T>()` are only
            // inserted by `impl Context for &T`, from a `&T` of this exact `T`.
            // NOTE(review): nothing in the types ties `'cx` to the referent's
            // lifetime; callers of the swap guard must keep referents alive.
            Some(erased) => Ok(unsafe { erased.as_ref::<'cx, T>() }),
            None => Err(MissingContextError(PhantomData)),
        }
    }
}
thread_local! {
    // Per-thread context map: replaced by `swap_global_context` (restored by
    // `ContextSwapGuard`) and read through `with_global_context`. Const-initialised,
    // so no lazy initialisation runs on first access.
    static CONTEXT: ContextCell = const { ContextCell::new() };
}
/// Restores the thread's previous context when dropped.
///
/// Holds the map that was active before `swap_global_context` ran, or `None`
/// if that map was empty, in which case dropping simply clears the context.
#[repr(transparent)]
#[must_use = "dropping the guard immediately restores the previous context"]
pub(crate) struct ContextSwapGuard(Option<ErasedRefMap>);
impl Drop for ContextSwapGuard {
    fn drop(&mut self) {
        let saved = self.0.take();
        // `try_with` instead of `with`: if the thread-local has already been
        // destroyed (guard dropped during thread teardown), there is nothing
        // left to restore. `with` would panic inside `drop`, which aborts the
        // process if we are already unwinding.
        let _ = CONTEXT.try_with(|cx| match saved {
            Some(prev_context) => cx.replace_with(prev_context),
            None => cx.clear(),
        });
    }
}
/// Replaces the current thread's context with `context` until the returned
/// guard is dropped.
///
/// Any previous entries are moved into the guard, so they are not visible
/// while it lives, and they are put back when it drops.
///
/// NOTE(review): the guard does not borrow from `context`. Nothing stops a
/// caller from dropping a referent, or leaking the guard with `mem::forget`,
/// while its reference is still installed. Callers must uphold this.
pub(crate) fn swap_global_context<C>(context: C) -> ContextSwapGuard
where
    C: Context,
{
    CONTEXT.with(move |cell| {
        // The guard exists before any insertion, so a panic inside
        // `__insert` still puts the previous map back during unwinding.
        let restore_on_drop = ContextSwapGuard(cell.take_if_empty());
        context.__insert(cell);
        restore_on_drop
    })
}
/// Calls `f` with a `ContextRef` to the current thread's context and returns
/// whatever `f` returns. Like `LocalKey::with`, it panics if the thread-local
/// has already been destroyed.
pub(crate) fn with_global_context<R>(f: impl FnOnce(ContextRef) -> R) -> R {
    CONTEXT.with(move |cell| {
        let view = ContextRef(cell);
        f(view)
    })
}