use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use super::Signal;
/// Opaque identifier distinguishing one [`Context`] from another.
///
/// Ids come from a process-wide counter, so two independently created
/// contexts never share an id (short of the `u64` counter wrapping).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ContextId(u64);

impl ContextId {
    /// Allocates the next unused identifier.
    fn new() -> Self {
        // `Relaxed` is enough: each `fetch_add` only has to hand out a distinct
        // value; no other memory accesses are ordered by this counter.
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        ContextId(raw)
    }
}
/// A typed key for a piece of context, optionally carrying a fallback value.
///
/// Every constructor allocates a fresh [`ContextId`]; a clone keeps the id of
/// the context it was cloned from, so both refer to the same context.
#[derive(Debug)]
pub struct Context<T> {
    id: ContextId,
    default: Option<T>,
    _marker: PhantomData<T>,
}

impl<T> Context<T> {
    /// Shared constructor used by [`Context::new`] and [`Context::with_default`].
    fn with_optional_default(default: Option<T>) -> Self {
        Context {
            id: ContextId::new(),
            default,
            _marker: PhantomData,
        }
    }

    /// Creates a context without a default value.
    pub fn new() -> Self {
        Self::with_optional_default(None)
    }

    /// Creates a context whose `use_context` lookups fall back to `default`
    /// when no value has been provided.
    pub fn with_default(default: T) -> Self {
        Self::with_optional_default(Some(default))
    }

    /// Returns the identifier this context is keyed by.
    pub fn id(&self) -> ContextId {
        self.id
    }

    /// Returns the fallback value, if one was supplied at construction.
    pub fn default(&self) -> Option<&T> {
        self.default.as_ref()
    }
}

impl<T> Default for Context<T> {
    /// Same as [`Context::new`]: a fresh id and no default value.
    fn default() -> Self {
        Context::new()
    }
}

impl<T: Clone> Clone for Context<T> {
    /// Copies the id (so the clone names the same context) and clones the default.
    fn clone(&self) -> Self {
        Context {
            id: self.id,
            default: self.default.clone(),
            _marker: PhantomData,
        }
    }
}
/// A reactive value tagged with the id of the [`Context`] it was created for.
///
/// NOTE(review): `Provider::new` does not insert its signal into the scope
/// stack or the global store, so `use_context` does not observe values held by
/// a `Provider`; use `provide` / `ContextScope::provide` for that.
pub struct Provider<T: Clone + Send + Sync + 'static> {
    context_id: ContextId,
    value: Signal<T>,
}

impl<T: Clone + Send + Sync + 'static> Provider<T> {
    /// Wraps `value` in a new `Signal` associated with `context`.
    pub fn new(context: &Context<T>, value: T) -> Self {
        let value = Signal::new(value);
        Provider {
            context_id: context.id(),
            value,
        }
    }

    /// Returns the current value, as produced by the underlying signal's `get`.
    pub fn get(&self) -> T {
        self.signal().get()
    }

    /// Replaces the current value via the underlying signal's `set`.
    pub fn set(&self, value: T) {
        self.signal().set(value);
    }

    /// Mutates the current value in place via the underlying signal's `update`.
    pub fn update(&self, f: impl FnOnce(&mut T)) {
        self.signal().update(f);
    }

    /// Borrows the underlying signal.
    pub fn signal(&self) -> &Signal<T> {
        &self.value
    }
}

impl<T: Clone + Send + Sync + 'static> Clone for Provider<T> {
    /// Clones the signal handle and keeps the same context id.
    fn clone(&self) -> Self {
        let value = self.value.clone();
        Provider {
            value,
            context_id: self.context_id,
        }
    }
}
/// Type-erased storage slot; every insertion in this module stores a
/// `Signal<T>` for the `T` of the `Context<T>` whose id is the map key.
type ContextValue = Arc<dyn Any + Send + Sync>;

thread_local! {
    /// Per-thread stack of frames pushed by `ContextScope`; the last element
    /// is the innermost scope and is searched first.
    static CONTEXT_STACK: RefCell<Vec<HashMap<ContextId, ContextValue>>> = RefCell::default();
    /// Per-thread fallback store written by `provide` / `provide_signal`.
    static GLOBAL_CONTEXTS: RefCell<HashMap<ContextId, ContextValue>> = RefCell::default();
}
/// Creates a new [`Context`] without a default value (shorthand for `Context::new`).
pub fn create_context<T>() -> Context<T> {
    Context::<T>::new()
}
/// Creates a new [`Context`] with the given fallback value
/// (shorthand for `Context::with_default`).
pub fn create_context_with_default<T>(default: T) -> Context<T> {
    Context::<T>::with_default(default)
}
/// Provides `value` for `context` in the current thread's global store,
/// replacing any value previously provided there for the same context.
///
/// Values provided through a live `ContextScope` take precedence over this one
/// in `use_context`.
pub fn provide<T: Clone + Send + Sync + 'static>(context: &Context<T>, value: T) {
    GLOBAL_CONTEXTS.with(|globals| {
        let erased: ContextValue = Arc::new(Signal::new(value));
        globals.borrow_mut().insert(context.id, erased);
    });
}
/// Like [`provide`], but also returns the signal so the caller can change the
/// provided value later.
///
/// A clone of the returned signal is what gets stored in the global store.
pub fn provide_signal<T: Clone + Send + Sync + 'static>(
    context: &Context<T>,
    value: T,
) -> Signal<T> {
    let handle = Signal::new(value);
    GLOBAL_CONTEXTS.with(|globals| {
        let erased: ContextValue = Arc::new(handle.clone());
        globals.borrow_mut().insert(context.id, erased);
    });
    handle
}
/// Reads the current value of `context` on this thread.
///
/// Lookup order: scope frames from innermost to outermost, then the global
/// store, then the context's own default. Returns `None` if none of those
/// yield a value.
pub fn use_context<T: Clone + Send + Sync + 'static>(context: &Context<T>) -> Option<T> {
    let key = context.id;
    let scoped = CONTEXT_STACK.with(|cell| {
        cell.borrow()
            .iter()
            .rev()
            .filter_map(|frame| frame.get(&key))
            .find_map(|entry| entry.downcast_ref::<Signal<T>>().map(|signal| signal.get()))
    });
    scoped
        .or_else(|| {
            GLOBAL_CONTEXTS.with(|cell| {
                cell.borrow()
                    .get(&key)
                    .and_then(|entry| entry.downcast_ref::<Signal<T>>())
                    .map(|signal| signal.get())
            })
        })
        .or_else(|| context.default.clone())
}
/// Returns a clone of the signal backing `context` on this thread.
///
/// Searches scope frames from innermost to outermost, then the global store.
/// Unlike `use_context`, the context's default value is NOT consulted; `None`
/// means nothing was provided.
pub fn use_context_signal<T: Clone + Send + Sync + 'static>(
    context: &Context<T>,
) -> Option<Signal<T>> {
    let key = context.id;
    let scoped = CONTEXT_STACK.with(|cell| {
        cell.borrow()
            .iter()
            .rev()
            .filter_map(|frame| frame.get(&key))
            .find_map(|entry| entry.downcast_ref::<Signal<T>>().cloned())
    });
    scoped.or_else(|| {
        GLOBAL_CONTEXTS.with(|cell| {
            cell.borrow()
                .get(&key)
                .and_then(|entry| entry.downcast_ref::<Signal<T>>())
                .cloned()
        })
    })
}
/// Reports whether a value has been provided for `context` on this thread,
/// in any scope frame or in the global store.
///
/// The context's default value does not count: a context that only has a
/// default reports `false`.
pub fn has_context<T: Clone + Send + Sync + 'static>(context: &Context<T>) -> bool {
    let key = context.id;
    let provided_in_scope =
        CONTEXT_STACK.with(|cell| cell.borrow().iter().any(|frame| frame.contains_key(&key)));
    provided_in_scope || GLOBAL_CONTEXTS.with(|cell| cell.borrow().contains_key(&key))
}
/// Removes the value provided for `context` from this thread's global store.
///
/// Values provided through a live `ContextScope` are left untouched, and
/// `use_context` still falls back to the context's default, if any.
pub fn clear_context<T>(context: &Context<T>) {
    let key = context.id();
    GLOBAL_CONTEXTS.with(|globals| {
        globals.borrow_mut().remove(&key);
    });
}
/// Removes every value provided on the current thread, both in the global
/// store and in all scope frames.
///
/// The scope frames themselves are emptied rather than removed. Each frame is
/// owned by a live `ContextScope` guard, which will pop it when dropped.
/// Deleting the frames here would make that guard's later `provide` calls get
/// lost or land in another guard's frame, and could make its `Drop` remove a
/// frame it did not push.
pub fn clear_all_contexts() {
    GLOBAL_CONTEXTS.with(|store| {
        store.borrow_mut().clear();
    });
    CONTEXT_STACK.with(|stack| {
        for frame in stack.borrow_mut().iter_mut() {
            frame.clear();
        }
    });
}
pub struct ContextScope {
_private: (),
}
impl ContextScope {
pub fn new() -> Self {
CONTEXT_STACK.with(|stack| {
stack.borrow_mut().push(HashMap::new());
});
Self { _private: () }
}
pub fn provide<T: Clone + Send + Sync + 'static>(&self, context: &Context<T>, value: T) {
let signal = Signal::new(value);
let boxed: ContextValue = Arc::new(signal);
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
if let Some(scope) = stack.last_mut() {
scope.insert(context.id, boxed);
}
});
}
pub fn provide_signal<T: Clone + Send + Sync + 'static>(
&self,
context: &Context<T>,
value: T,
) -> Signal<T> {
let signal = Signal::new(value);
let boxed: ContextValue = Arc::new(signal.clone());
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
if let Some(scope) = stack.last_mut() {
scope.insert(context.id, boxed);
}
});
signal
}
}
impl Default for ContextScope {
fn default() -> Self {
Self::new()
}
}
impl Drop for ContextScope {
fn drop(&mut self) {
CONTEXT_STACK.with(|stack| {
stack.borrow_mut().pop();
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of frames currently on this thread's scope stack.
    fn stack_depth() -> usize {
        CONTEXT_STACK.with(|stack| stack.borrow().len())
    }

    #[test]
    fn test_context_new() {
        let ctx: Context<i32> = Context::new();
        assert!(ctx.default().is_none());
    }

    #[test]
    fn test_context_new_string() {
        let ctx: Context<String> = Context::new();
        assert!(ctx.default().is_none());
    }

    #[test]
    fn test_context_with_default() {
        let ctx: Context<i32> = Context::with_default(42);
        assert_eq!(ctx.default(), Some(&42));
    }

    #[test]
    fn test_context_with_default_string() {
        let ctx: Context<String> = Context::with_default("hello".to_string());
        assert_eq!(ctx.default(), Some(&"hello".to_string()));
    }

    #[test]
    fn test_context_default_trait() {
        let ctx: Context<i32> = Default::default();
        assert!(ctx.default().is_none());
    }

    #[test]
    fn test_context_id() {
        let ctx1: Context<i32> = Context::new();
        let ctx2: Context<i32> = Context::new();
        assert_ne!(ctx1.id(), ctx2.id());
    }

    #[test]
    fn test_context_clone_without_default() {
        let ctx1: Context<i32> = Context::new();
        let ctx2 = ctx1.clone();
        assert_eq!(ctx1.id(), ctx2.id());
        assert!(ctx2.default().is_none());
    }

    #[test]
    fn test_context_clone_with_default() {
        let ctx1: Context<i32> = Context::with_default(42);
        let ctx2 = ctx1.clone();
        assert_eq!(ctx1.id(), ctx2.id());
        assert_eq!(ctx2.default(), Some(&42));
    }

    #[test]
    fn test_create_context() {
        let ctx = create_context::<i32>();
        assert!(ctx.default().is_none());
    }

    #[test]
    fn test_create_context_string() {
        let ctx = create_context::<String>();
        assert!(ctx.default().is_none());
    }

    #[test]
    fn test_create_context_with_default() {
        let ctx = create_context_with_default(42);
        assert_eq!(ctx.default(), Some(&42));
    }

    #[test]
    fn test_provide_and_use_context() {
        let ctx = create_context::<i32>();
        provide(&ctx, 42);
        assert_eq!(use_context(&ctx), Some(42));
        clear_context(&ctx);
    }

    #[test]
    fn test_provide_and_use_context_string() {
        let ctx = create_context::<String>();
        provide(&ctx, "hello".to_string());
        assert_eq!(use_context(&ctx), Some("hello".to_string()));
        clear_context(&ctx);
    }

    #[test]
    fn test_use_context_not_provided() {
        let ctx = create_context::<i32>();
        assert_eq!(use_context(&ctx), None);
    }

    #[test]
    fn test_use_context_with_default_not_provided() {
        let ctx = create_context_with_default(42);
        assert_eq!(use_context(&ctx), Some(42));
    }

    #[test]
    fn test_use_context_default_overridden_by_provide() {
        let ctx = create_context_with_default(42);
        provide(&ctx, 100);
        assert_eq!(use_context(&ctx), Some(100));
        clear_context(&ctx);
    }

    #[test]
    fn test_provide_signal_and_use_context() {
        let ctx = create_context::<i32>();
        let signal = provide_signal(&ctx, 42);
        assert_eq!(signal.get(), 42);
        assert_eq!(use_context(&ctx), Some(42));
        clear_context(&ctx);
    }

    #[test]
    fn test_provide_signal_update() {
        let ctx = create_context::<i32>();
        let signal = provide_signal(&ctx, 42);
        assert_eq!(use_context(&ctx), Some(42));
        signal.set(100);
        assert_eq!(use_context(&ctx), Some(100));
        clear_context(&ctx);
    }

    #[test]
    fn test_use_context_signal() {
        let ctx = create_context::<i32>();
        provide(&ctx, 42);
        let signal = use_context_signal(&ctx);
        assert!(signal.is_some());
        assert_eq!(signal.unwrap().get(), 42);
        clear_context(&ctx);
    }

    #[test]
    fn test_use_context_signal_not_provided() {
        let ctx = create_context::<i32>();
        let signal = use_context_signal(&ctx);
        assert!(signal.is_none());
    }

    #[test]
    fn test_use_context_signal_from_scope() {
        let ctx = create_context::<i32>();
        let scope = ContextScope::new();
        scope.provide(&ctx, 5);
        let signal = use_context_signal(&ctx).expect("the scope provided a signal");
        assert_eq!(signal.get(), 5);
        signal.set(6);
        assert_eq!(use_context(&ctx), Some(6));
    }

    #[test]
    fn test_has_context_false() {
        let ctx = create_context::<i32>();
        assert!(!has_context(&ctx));
    }

    #[test]
    fn test_has_context_true() {
        let ctx = create_context::<i32>();
        provide(&ctx, 42);
        assert!(has_context(&ctx));
        clear_context(&ctx);
    }

    #[test]
    fn test_has_context_in_scope() {
        let ctx = create_context::<i32>();
        {
            let scope = ContextScope::new();
            scope.provide(&ctx, 1);
            assert!(has_context(&ctx));
        }
        assert!(!has_context(&ctx));
    }

    #[test]
    fn test_clear_context() {
        let ctx = create_context::<i32>();
        provide(&ctx, 42);
        assert!(has_context(&ctx));
        clear_context(&ctx);
        assert!(!has_context(&ctx));
    }

    #[test]
    fn test_clear_context_leaves_scoped_value() {
        let ctx = create_context::<i32>();
        let scope = ContextScope::new();
        scope.provide(&ctx, 1);
        clear_context(&ctx);
        assert_eq!(use_context(&ctx), Some(1));
    }

    #[test]
    fn test_clear_all_contexts() {
        let ctx1 = create_context::<i32>();
        let ctx2 = create_context::<String>();
        provide(&ctx1, 42);
        provide(&ctx2, "hello".to_string());
        assert!(has_context(&ctx1));
        assert!(has_context(&ctx2));
        clear_all_contexts();
        assert!(!has_context(&ctx1));
        assert!(!has_context(&ctx2));
    }

    #[test]
    fn test_context_scope_new() {
        let before = stack_depth();
        {
            let _scope = ContextScope::new();
            assert_eq!(stack_depth(), before + 1);
        }
        assert_eq!(stack_depth(), before);
    }

    #[test]
    fn test_context_scope_default() {
        let before = stack_depth();
        {
            let _scope = ContextScope::default();
            assert_eq!(stack_depth(), before + 1);
        }
        assert_eq!(stack_depth(), before);
    }

    #[test]
    fn test_context_scope_provide() {
        let ctx = create_context::<i32>();
        {
            let scope = ContextScope::new();
            scope.provide(&ctx, 42);
            assert_eq!(use_context(&ctx), Some(42));
        }
        assert_eq!(use_context(&ctx), None);
    }

    #[test]
    fn test_context_scope_provide_signal() {
        let ctx = create_context::<i32>();
        let _signal = {
            let scope = ContextScope::new();
            let signal = scope.provide_signal(&ctx, 42);
            assert_eq!(signal.get(), 42);
            assert_eq!(use_context(&ctx), Some(42));
            signal
        };
        assert_eq!(use_context(&ctx), None);
    }

    #[test]
    fn test_context_scope_nested() {
        let ctx = create_context::<i32>();
        let scope1 = ContextScope::new();
        scope1.provide(&ctx, 10);
        assert_eq!(use_context(&ctx), Some(10));
        {
            let scope2 = ContextScope::new();
            scope2.provide(&ctx, 20);
            assert_eq!(use_context(&ctx), Some(20));
        }
        assert_eq!(use_context(&ctx), Some(10));
    }

    #[test]
    fn test_context_scope_shadows_global() {
        let ctx = create_context::<i32>();
        provide(&ctx, 1);
        {
            let scope = ContextScope::new();
            scope.provide(&ctx, 2);
            assert_eq!(use_context(&ctx), Some(2));
        }
        assert_eq!(use_context(&ctx), Some(1));
        clear_context(&ctx);
    }

    #[test]
    fn test_with_context_scope() {
        let ctx = create_context::<i32>();
        let seen = with_context_scope(|scope| {
            scope.provide(&ctx, 7);
            use_context(&ctx)
        });
        assert_eq!(seen, Some(7));
        assert_eq!(use_context(&ctx), None);
    }

    #[test]
    fn test_provider_new() {
        let ctx = create_context::<i32>();
        let provider = Provider::new(&ctx, 42);
        assert_eq!(provider.get(), 42);
    }

    #[test]
    fn test_provider_get() {
        // `get` hands back an owned value; mutating it must not touch the provider.
        let ctx = create_context::<Vec<i32>>();
        let provider = Provider::new(&ctx, vec![1, 2]);
        let mut snapshot = provider.get();
        snapshot.push(3);
        assert_eq!(snapshot, vec![1, 2, 3]);
        assert_eq!(provider.get(), vec![1, 2]);
    }

    #[test]
    fn test_provider_set() {
        let ctx = create_context::<i32>();
        let provider = Provider::new(&ctx, 42);
        assert_eq!(provider.get(), 42);
        provider.set(100);
        assert_eq!(provider.get(), 100);
    }

    #[test]
    fn test_provider_update() {
        let ctx = create_context::<i32>();
        let provider = Provider::new(&ctx, 42);
        provider.update(|v| *v *= 2);
        assert_eq!(provider.get(), 84);
    }

    #[test]
    fn test_provider_signal() {
        let ctx = create_context::<i32>();
        let provider = Provider::new(&ctx, 42);
        let signal = provider.signal();
        assert_eq!(signal.get(), 42);
    }

    #[test]
    fn test_provider_clone() {
        let ctx = create_context::<i32>();
        let provider1 = Provider::new(&ctx, 42);
        let provider2 = provider1.clone();
        assert_eq!(provider2.get(), 42);
    }

    #[test]
    fn test_context_with_bool() {
        let ctx = create_context::<bool>();
        provide(&ctx, true);
        assert_eq!(use_context(&ctx), Some(true));
        clear_context(&ctx);
    }

    #[test]
    fn test_context_with_vec() {
        let ctx = create_context::<Vec<i32>>();
        provide(&ctx, vec![1, 2, 3]);
        assert_eq!(use_context(&ctx), Some(vec![1, 2, 3]));
        clear_context(&ctx);
    }

    #[test]
    fn test_context_with_option() {
        let ctx = create_context::<Option<i32>>();
        provide(&ctx, Some(42));
        assert_eq!(use_context(&ctx), Some(Some(42)));
        clear_context(&ctx);
    }

    #[test]
    fn test_multiple_independent_contexts() {
        let ctx1 = create_context::<i32>();
        let ctx2 = create_context::<String>();
        provide(&ctx1, 42);
        provide(&ctx2, "hello".to_string());
        assert_eq!(use_context(&ctx1), Some(42));
        assert_eq!(use_context(&ctx2), Some("hello".to_string()));
        clear_context(&ctx1);
        clear_context(&ctx2);
    }
}
/// Runs `f` with a fresh [`ContextScope`] and returns its result.
///
/// The scope is a temporary that is dropped after `f` returns (or while
/// unwinding if `f` panics), so values provided through it are only visible
/// for the duration of the call.
pub fn with_context_scope<F, R>(f: F) -> R
where
    F: FnOnce(&ContextScope) -> R,
{
    f(&ContextScope::new())
}