use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt;
use std::hash::{BuildHasherDefault, Hasher};
use std::marker::PhantomData;
use std::sync::Arc;
thread_local! {
static CURRENT_CONTEXT: RefCell<Context> = RefCell::new(Context::default());
}
/// A set of values keyed by their Rust type.
///
/// At most one value per type is stored. Values live behind `Arc`s, so
/// cloning a `Context` copies only the handles, never the values themselves.
#[derive(Default, Clone)]
pub struct Context {
    entries: HashMap<TypeId, Arc<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>,
}
impl Context {
    /// Creates an empty context that holds no values.
    pub fn new() -> Self {
        Context::default()
    }

    /// Returns a clone of the context currently attached to this thread.
    ///
    /// Only the map of `Arc` handles is copied; the stored values are shared.
    pub fn current() -> Self {
        Context::map_current(|cx| cx.clone())
    }

    /// Runs `f` with a shared borrow of this thread's current context and
    /// returns its result, without cloning the context.
    ///
    /// # Panics
    ///
    /// Panics if `f` attaches a context or drops a [`ContextGuard`] that has a
    /// saved context. Both need to replace the thread-local value while `f`
    /// still holds a shared borrow of it. May also panic if called while this
    /// thread's thread-local storage is being torn down.
    pub fn map_current<T>(f: impl FnOnce(&Context) -> T) -> T {
        CURRENT_CONTEXT.with(|cx| f(&cx.borrow()))
    }

    /// Returns a copy of the current context with `value` added, replacing any
    /// existing value of type `T`.
    ///
    /// Equivalent to `Context::current().with_value(value)`, but clones the
    /// current context only once.
    #[must_use]
    pub fn current_with_value<T: 'static + Send + Sync>(value: T) -> Self {
        // Reuse `with_value` so the insertion logic lives in one place. The
        // insert may drop a previous `Arc` for `T`, but the attached context
        // still holds a clone of it, so no user destructor runs while the
        // thread-local is borrowed.
        Context::map_current(|cx| cx.with_value(value))
    }

    /// Returns a reference to the value of type `T` stored in this context,
    /// or `None` if no value of that type has been added.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.entries
            .get(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_ref())
    }

    /// Returns a copy of this context with `value` added, replacing any
    /// existing value of the same type `T`. `self` is left unchanged.
    #[must_use]
    pub fn with_value<T: 'static + Send + Sync>(&self, value: T) -> Self {
        let mut new_context = self.clone();
        new_context
            .entries
            .insert(TypeId::of::<T>(), Arc::new(value));
        new_context
    }

    /// Makes this context the current context for this thread. Returns a guard
    /// that restores the previously current context when it is dropped.
    ///
    /// If the thread-local slot can no longer be accessed (for example, during
    /// thread teardown), this context is not attached. In that case the
    /// returned guard does nothing on drop.
    ///
    /// Guards should be dropped in reverse order of creation. Each guard
    /// restores exactly the context it saved, regardless of drop order.
    pub fn attach(self) -> ContextGuard {
        // `try_with` instead of `with`: attaching during TLS teardown degrades
        // to a no-op rather than a panic.
        let previous_cx = CURRENT_CONTEXT
            .try_with(|current| current.replace(self))
            .ok();
        ContextGuard {
            previous_cx,
            _marker: PhantomData,
        }
    }
}
impl fmt::Debug for Context {
    /// Prints `Context { entries: N }`, where `N` is the number of stored
    /// values; the values themselves are never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entry_count = self.entries.len();
        let mut builder = f.debug_struct("Context");
        builder.field("entries", &entry_count);
        builder.finish()
    }
}
/// Restores the previously current context of this thread when dropped.
///
/// Returned by [`Context::attach`].
// No `Debug` impl is provided for this guard, so the crate-level lint is
// silenced here.
#[allow(missing_debug_implementations)]
#[must_use = "dropping the guard immediately restores the previous context"]
pub struct ContextGuard {
    // The context that was current before `attach`. `None` if the thread-local
    // slot could not be accessed at attach time; the guard is then a no-op.
    previous_cx: Option<Context>,
    // A raw-pointer marker makes the guard `!Send` and `!Sync`. It must be
    // dropped on the thread whose thread-local context it restores.
    _marker: PhantomData<*const ()>,
}
impl Drop for ContextGuard {
    /// Puts the saved context back into this thread's slot, if a context was
    /// saved and the slot is still accessible. Otherwise it does nothing.
    fn drop(&mut self) {
        let restored = match self.previous_cx.take() {
            Some(saved) => saved,
            None => return,
        };
        // `try_with` avoids panicking when the guard is dropped during
        // thread-local teardown. The displaced context is returned by
        // `replace` and dropped after the slot's borrow has ended.
        let _ = CURRENT_CONTEXT.try_with(move |slot| slot.replace(restored));
    }
}
/// A pass-through hasher for `TypeId` keys.
///
/// `TypeId`s are already compiler-generated hashes, so the `u64` that
/// `TypeId`'s `Hash` impl feeds through `write_u64` is used directly as the
/// hash value.
#[derive(Clone, Default, Debug)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Fallback for input that does not arrive as a single `u64`.
    ///
    /// This path is not expected to run for `TypeId` keys. How `TypeId` hashes
    /// itself is a standard-library implementation detail, though. The default
    /// `write_u128`, for instance, forwards here. So the bytes are folded
    /// FNV-1a-style instead of panicking in the middle of a `HashMap`
    /// operation.
    fn write(&mut self, bytes: &[u8]) {
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(FNV_PRIME);
        }
    }

    /// Stores `id` as the hash unchanged (the `TypeId` fast path).
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Attaching an inner context layers a value on top of the outer one, and
    /// dropping the inner guard brings the outer context back.
    #[test]
    fn nested_contexts() {
        #[derive(Debug, PartialEq)]
        struct ValueA(&'static str);
        #[derive(Debug, PartialEq)]
        struct ValueB(u64);

        let _outer_guard = Context::new().with_value(ValueA("a")).attach();

        let outer = Context::current();
        assert_eq!(outer.get(), Some(&ValueA("a")));
        assert_eq!(outer.get::<ValueB>(), None);

        {
            let _inner_guard = Context::current_with_value(ValueB(42)).attach();

            let inner = Context::current();
            assert_eq!(inner.get(), Some(&ValueA("a")));
            assert_eq!(inner.get(), Some(&ValueB(42)));

            let saw_both = Context::map_current(|cx| {
                assert_eq!(cx.get(), Some(&ValueA("a")));
                assert_eq!(cx.get(), Some(&ValueB(42)));
                true
            });
            assert!(saw_both);
        }

        let restored = Context::current();
        assert_eq!(restored.get(), Some(&ValueA("a")));
        assert_eq!(restored.get::<ValueB>(), None);

        let saw_only_outer = Context::map_current(|cx| {
            assert_eq!(cx.get(), Some(&ValueA("a")));
            assert_eq!(cx.get::<ValueB>(), None);
            true
        });
        assert!(saw_only_outer);
    }
}