use core::any::Any;
use core::cell::RefCell;
use core::marker::PhantomData;
use core::sync::atomic::{AtomicUsize, Ordering};
extern crate alloc;
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
/// Source of ids handed out by [`Context::new`].
///
/// Starts at 1 so that id 0 stays reserved for the contexts produced by the
/// `const` constructor [`create_context`], which cannot touch this counter.
/// Otherwise the very first `Context::new()` would silently share a value
/// stack with every `create_context()` context. Only uniqueness of the
/// returned values matters (no other memory is published through this
/// counter), so `Relaxed` increments are sufficient.
static CONTEXT_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);
#[cfg_attr(doc, aquamarine::aquamarine)]
/// A typed handle identifying a slot in the per-thread context store.
///
/// A `Context<T>` holds no value itself, only an id. Values of type `T` are
/// provided, read and removed through the free functions of this module,
/// keyed by that id.
///
/// The type marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`.
/// Because no `T` is stored, the handle does not need to inherit `T`'s
/// auto traits. This keeps `Context<T>` `Send + Sync` for every `T`, so it can
/// live in a `static` even when `T` is, say, `Rc<_>`. Variance in `T` is still
/// covariant, exactly as with `PhantomData<T>`.
pub struct Context<T: 'static> {
    /// Key under which this context's values are stored.
    id: usize,
    /// Ties the handle to the value type `T` without storing or owning a `T`.
    _phantom: PhantomData<fn() -> T>,
}
impl<T: 'static> Context<T> {
pub fn new() -> Self {
Self {
id: CONTEXT_ID_COUNTER.fetch_add(1, Ordering::Relaxed),
_phantom: PhantomData,
}
}
pub fn id(&self) -> usize {
self.id
}
}
impl<T: 'static> Default for Context<T> {
fn default() -> Self {
Self::new()
}
}
// Written by hand instead of `#[derive(Clone, Copy)]`: the derive would add
// `T: Clone` / `T: Copy` bounds, but a context is only an id and is copyable
// whatever `T` is.
impl<T: 'static> Copy for Context<T> {}

impl<T: 'static> Clone for Context<T> {
    /// Copies the handle. The clone refers to the same id, and therefore to the
    /// same stored values.
    fn clone(&self) -> Self {
        *self
    }
}
impl<T: 'static> core::fmt::Debug for Context<T> {
    /// Formats as `Context { id: .., type: .. }`, where `type` is the name of
    /// `T` as reported by `core::any::type_name`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let type_label = core::any::type_name::<T>();
        let mut out = f.debug_struct("Context");
        out.field("id", &self.id);
        out.field("type", &type_label);
        out.finish()
    }
}
/// One entry on a context's value stack.
///
/// The value is type-erased (`dyn Any`) so that a single map can hold values
/// for contexts of every type. Readers recover the concrete type with
/// `downcast_ref` / `downcast`.
struct ContextValue {
// The provided value, boxed and type-erased.
value: Box<dyn Any>,
}
// Per-thread store of provided context values.
//
// Maps a context id (`Context::id`) to the stack of values provided for that id.
// `provide_context` pushes onto the end of the `Vec`, and the readers/removers
// work from the end, so the last element is the innermost (most recently
// provided) value. The key is the id alone, so contexts that share an id (for
// example those returned by `create_context`, which all use id 0) share one
// stack. Because the store is thread-local, a value provided on one thread is
// invisible to every other thread.
thread_local! {
static CONTEXT_STACK: RefCell<BTreeMap<usize, Vec<ContextValue>>> =
const { RefCell::new(BTreeMap::new()) };
}
/// Creates a context that can be built in `const` / `static` initializers.
///
/// Because this is a `const fn`, it cannot draw from the runtime id counter.
/// Every context it returns has the same id, `0`. All such contexts therefore
/// share one per-thread value stack, and two `create_context::<T>()` contexts
/// with the same `T` see each other's values. [`Context::new`] instead draws
/// its ids from a runtime counter.
pub const fn create_context<T: 'static>() -> Context<T> {
    const SHARED_ID: usize = 0;
    Context {
        _phantom: PhantomData,
        id: SHARED_ID,
    }
}
/// Makes `value` the innermost value of `ctx` on the current thread.
///
/// Values form a stack per context id. The new value shadows any previously
/// provided one until it is removed with [`remove_context`] (or by dropping a
/// [`ContextGuard`]). The value is visible only on the calling thread.
///
/// # Panics
///
/// Panics if this thread's context storage is already borrowed (a `RefCell`
/// `borrow_mut` conflict).
pub fn provide_context<T: Clone + 'static>(ctx: &Context<T>, value: T) {
    let entry = ContextValue {
        value: Box::new(value),
    };
    CONTEXT_STACK.with(|cell| {
        cell.borrow_mut().entry(ctx.id()).or_default().push(entry);
    });
}
/// Returns a clone of the innermost value provided for `ctx` on this thread,
/// or `None` if no value of type `T` is currently provided.
///
/// Contexts that share an id (e.g. every context returned by
/// [`create_context`]) share one stack, so that stack may also hold values of
/// other types. Such entries belong to other contexts and are skipped: the
/// result is the most recently provided value whose type is `T`. For a
/// context with an id of its own, this is simply the top of its stack.
///
/// The clone is made while this thread's storage is immutably borrowed.
///
/// # Panics
///
/// Panics if this thread's context storage is currently mutably borrowed
/// (a `RefCell` `borrow` conflict).
pub fn get_context<T: Clone + 'static>(ctx: &Context<T>) -> Option<T> {
    CONTEXT_STACK.with(|stack| -> Option<T> {
        let stack = stack.borrow();
        stack
            .get(&ctx.id())?
            .iter()
            .rev()
            .find_map(|cv| cv.value.downcast_ref::<T>())
            .cloned()
    })
}
pub fn remove_context<T: Clone + 'static>(ctx: &Context<T>) -> Option<T> {
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
stack.get_mut(&ctx.id()).and_then(|values| {
values
.pop()
.and_then(|cv| cv.value.downcast::<T>().ok().map(|b| *b))
})
})
}
/// Removes every value provided for every context on the current thread.
///
/// Hidden from the public docs. The test suite calls it to start from an
/// empty store. The whole map is moved out of the `RefCell` first and dropped
/// only after the borrow has ended. Destructors of stored values can
/// therefore call back into this module without hitting a `RefCell`
/// "already borrowed" panic.
#[doc(hidden)]
pub fn clear_all_contexts() {
    let drained = CONTEXT_STACK.with(|stack| stack.take());
    drop(drained);
}
/// Provides a context value for as long as the guard is alive.
///
/// [`ContextGuard::new`] pushes `value` onto `ctx`'s stack for the current
/// thread. Dropping the guard calls [`remove_context`], which removes the
/// innermost value for `ctx`, not necessarily the one this guard pushed.
/// Guards for the same context should therefore be dropped in reverse order
/// of creation, which is the natural order for guards held in nested scopes.
/// `remove_context` also works on the *current* thread's storage, so the
/// guard should be dropped on the thread that created it.
#[must_use = "the context value is removed as soon as the guard is dropped; bind it to a named variable"]
pub struct ContextGuard<T: Clone + 'static> {
    ctx: Context<T>,
}

impl<T: Clone + 'static> ContextGuard<T> {
    /// Provides `value` for `ctx` on the current thread and returns a guard
    /// that removes the innermost value for `ctx` when dropped.
    pub fn new(ctx: &Context<T>, value: T) -> Self {
        provide_context(ctx, value);
        Self { ctx: *ctx }
    }
}

impl<T: Clone + 'static> Drop for ContextGuard<T> {
    fn drop(&mut self) {
        // The removed value is not needed; it is dropped right here.
        remove_context::<T>(&self.ctx);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `String` through the same `alloc` path the module uses.
    fn owned(text: &str) -> alloc::string::String {
        alloc::string::String::from(text)
    }

    /// Starts a test from an empty per-thread store.
    fn setup() {
        clear_all_contexts();
    }

    #[test]
    fn test_context_creation() {
        let first: Context<i32> = Context::new();
        let second: Context<i32> = Context::new();
        assert_ne!(first.id(), second.id());
    }

    #[test]
    fn test_provide_and_get_context() {
        setup();
        let ctx = Context::<i32>::new();
        assert_eq!(get_context(&ctx), None);

        provide_context(&ctx, 42);
        assert_eq!(get_context(&ctx), Some(42));
    }

    #[test]
    fn test_nested_contexts() {
        setup();
        let ctx = Context::<alloc::string::String>::new();

        provide_context(&ctx, owned("outer"));
        assert_eq!(get_context(&ctx), Some(owned("outer")));

        // A second value shadows the first...
        provide_context(&ctx, owned("inner"));
        assert_eq!(get_context(&ctx), Some(owned("inner")));

        // ...removing it exposes the first again...
        remove_context::<alloc::string::String>(&ctx);
        assert_eq!(get_context(&ctx), Some(owned("outer")));

        // ...and removing that one leaves nothing provided.
        remove_context::<alloc::string::String>(&ctx);
        assert_eq!(get_context(&ctx), None);
    }

    #[test]
    fn test_multiple_contexts() {
        setup();
        let numbers = Context::<i32>::new();
        let words = Context::<alloc::string::String>::new();

        provide_context(&numbers, 42);
        provide_context(&words, owned("hello"));

        // Each context sees only its own value.
        assert_eq!(get_context(&numbers), Some(42));
        assert_eq!(get_context(&words), Some(owned("hello")));
    }

    #[test]
    fn test_context_guard() {
        setup();
        let ctx = Context::<i32>::new();
        {
            let _guard = ContextGuard::new(&ctx, 42);
            assert_eq!(get_context(&ctx), Some(42));
        }
        // The guard went out of scope and took its value with it.
        assert_eq!(get_context(&ctx), None);
    }

    #[test]
    fn test_context_clone() {
        let original = Context::<i32>::new();
        let copy = original;
        // `Context` is `Copy`, so `original` is still usable after the move.
        assert_eq!(original.id(), copy.id());
    }

    #[test]
    fn test_context_debug() {
        let ctx = Context::<i32>::new();
        let rendered = alloc::format!("{:?}", ctx);
        assert!(rendered.contains("Context"));
        assert!(rendered.contains("i32"));
    }
}