use std::cell::Cell;
use std::sync::Arc;
use crate::scope::ScopeGuard;
use crate::snapshot::ContextSnapshot;
use crate::store::ContextStore;
use crate::value::ContextValue;
tokio::task_local! {
/// Per-task context store.
///
/// In this file it is installed by `with_context` (via `TASK_CONTEXT.scope`) and
/// accessed through `try_apply`, which moves the store out of the `Cell` while
/// its closure runs and puts it back afterwards. While the store is checked out
/// the slot holds `None`, so nested accesses in that window observe no context.
pub(crate) static TASK_CONTEXT: Cell<Option<ContextStore>>;
}
/// Pushes a scope named `name` onto the current task's context store.
///
/// When no task-local context store is available, the guard returned by
/// `ScopeGuard::noop()` is handed back instead.
pub fn push_scope(name: &str) -> ScopeGuard {
    match try_push_scope(name) {
        Some(guard) => guard,
        None => ScopeGuard::noop(),
    }
}
/// Pushes a scope named `name` and wraps the resulting depth in a
/// `ScopeGuard::new_async` guard.
///
/// Returns `None` when no task-local context store is available.
pub fn try_push_scope(name: &str) -> Option<ScopeGuard> {
    let owned_name = name.to_owned();
    try_apply(move |store| {
        let depth = store.push_scope(Some(owned_name));
        ScopeGuard::new_async(depth)
    })
}
/// Calls `ContextStore::pop_scope(expected_depth)` on the current task's store.
///
/// Does nothing when no task-local context store is available.
pub fn pop_scope(expected_depth: usize) {
    // Keep whatever the store hands back bound until the end of this function,
    // so it is dropped only after `try_apply` has returned the store to its slot
    // (presumably these are the values released by the popped scope).
    let _released = try_apply(move |store| store.pop_scope(expected_depth));
}
/// Calls `ContextStore::set_scope_barrier` on the current task's store; a no-op
/// when no task-local context store is available.
pub(crate) fn set_scope_barrier() {
    let _ = try_apply(|ctx| ctx.set_scope_barrier());
}
pub fn current_depth() -> Option<usize> {
try_apply(|store| store.depth)
}
/// Returns the current scope chain as reported by `ContextStore::scope_chain`,
/// or an empty `Vec` when no task-local context store is available.
pub fn scope_chain() -> Vec<String> {
    match try_apply(|store| store.scope_chain()) {
        Some(chain) => chain,
        None => Vec::new(),
    }
}
/// Stores `value` (wrapped in an `Arc`) under `key` in the current task's
/// context.
///
/// Silently does nothing when no task-local context store is available.
pub fn set_context<T>(key: &'static str, value: T)
where
    T: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    try_apply(move |store| {
        let shared = Arc::new(value);
        // Result of `set_value` is intentionally discarded inside the closure.
        store.set_value(key, shared);
    });
}
/// Returns a clone of the value stored under `key` in the current task's
/// context, if it exists and is a `T`.
///
/// Returns `None` when no task-local context store is available, when `key` is
/// not set, or when the stored value does not downcast to `T`.
pub fn get_context<T>(key: &str) -> Option<T>
where
    T: Clone + Send + Sync + 'static,
{
    // Only the `Arc` lookup runs while the store is checked out of its task-local
    // slot. The downcast and the caller-supplied `T::clone` run afterwards, so a
    // `Clone` impl that itself reads the context sees it, and a panicking `Clone`
    // cannot strand the store outside its slot.
    try_apply(|store| store.get_value(key))
        .flatten()
        .and_then(|arc| arc.as_any().downcast_ref::<T>().cloned())
}
/// Replaces the value under `key` with `f(current)`.
///
/// `current` is the value from `get_context::<T>(key)`, or `T::default()` if
/// that returns `None`. The read and the write are two separate context
/// accesses, and `f` is called even when no store is available (the write is
/// then a no-op).
pub fn update_context<T>(key: &'static str, f: impl FnOnce(T) -> T)
where
    T: Clone + Default + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
{
    let current = get_context::<T>(key).unwrap_or_else(T::default);
    set_context(key, f(current));
}
pub fn set_raw_value(key: &'static str, value: Arc<dyn ContextValue>) {
try_apply(|store| {
store.set_value(key, value);
});
}
/// Hands `chain` to `ContextStore::set_remote_chain` on the current task's
/// store; a no-op when no task-local context store is available.
pub(crate) fn set_remote_chain(chain: Vec<String>) {
    let _ = try_apply(move |ctx| ctx.set_remote_chain(chain));
}
/// Returns the type-erased value stored under `key`, or `None` when the key is
/// absent or no task-local context store is available.
pub fn get_raw_value(key: &str) -> Option<Arc<dyn ContextValue>> {
    // The `?` handles "no store"; the remaining `Option` is the lookup itself.
    try_apply(|store| store.get_value(key))?
}
/// Calls `f` with the value stored under `key`, viewed as `&dyn Any`.
///
/// Returns `None` (without calling `f`) when there is no such value or no
/// task-local context store. `f` runs after the lookup has completed, outside
/// the store access.
pub fn with_context_value<R>(key: &str, f: impl FnOnce(&dyn std::any::Any) -> R) -> Option<R> {
    let value = get_raw_value(key)?;
    Some(f(value.as_any()))
}
/// Captures the current task's context values (from `collect_values`) and scope
/// chain into a [`ContextSnapshot`].
///
/// Returns `ContextSnapshot::default()` when no task-local context store is
/// available.
pub fn snapshot() -> ContextSnapshot {
    // Field initialisers are evaluated in source order: values, then chain.
    let captured = try_apply(|store| ContextSnapshot {
        values: Arc::new(store.collect_values()),
        scope_chain: store.scope_chain(),
    });
    captured.unwrap_or_default()
}
pub fn attach(snap: ContextSnapshot) -> ScopeGuard {
let guard = push_scope("");
if !snap.scope_chain.is_empty() {
set_remote_chain(snap.scope_chain);
}
for (key, val) in snap.values.iter() {
set_raw_value(key, Arc::clone(val));
}
guard
}
pub fn serialize_context() -> Result<Vec<u8>, crate::error::ContextError> {
let values = snapshot()
.values
.iter()
.map(|(k, v)| (*k, Arc::clone(v)))
.collect();
let chain = scope_chain();
crate::wire::serialize_from(values, chain)
}
/// Decodes `bytes` with `crate::wire::deserialize_into` and returns the
/// resulting [`ScopeGuard`].
///
/// Presumably the inverse of `serialize_context`. NOTE(review): the meaning of
/// the `true` argument is defined in `crate::wire`; confirm it there.
///
/// # Errors
///
/// Returns the `ContextError` produced by `crate::wire::deserialize_into`.
pub fn deserialize_context(bytes: &[u8]) -> Result<ScopeGuard, crate::error::ContextError> {
    let guard = crate::wire::deserialize_into(bytes, true)?;
    Ok(guard)
}
pub async fn with_context<F>(snapshot: ContextSnapshot, fut: F) -> F::Output
where
F: std::future::Future,
{
let chain = snapshot.scope_chain.clone();
let values = snapshot
.values
.iter()
.map(|(k, v)| (*k, Arc::clone(v)))
.collect();
let store = ContextStore::from_values_with_chain(values, chain);
TASK_CONTEXT.scope(Cell::new(Some(store)), fut).await
}
pub async fn scope<F>(name: &str, fut: F) -> F::Output
where
F: std::future::Future,
{
let name_owned = name.to_string();
let depth = try_apply(|store| store.push_scope(Some(name_owned)));
match depth {
None => fut.await,
Some(depth) => {
struct ScopeCleanup(usize);
impl Drop for ScopeCleanup {
fn drop(&mut self) {
super::async_ctx::pop_scope(self.0);
}
}
let cleanup = ScopeCleanup(depth);
let result = fut.await;
std::mem::forget(cleanup);
pop_scope(depth);
result
}
}
}
/// Returns `ContextStore::fork_child` of the current task's store, or `None`
/// when no task-local context store is available.
pub(crate) fn fork() -> Option<ContextStore> {
    let fork_child = |parent: &mut ContextStore| parent.fork_child();
    try_apply(fork_child)
}
fn try_apply<R>(f: impl FnOnce(&mut ContextStore) -> R) -> Option<R> {
let found = TASK_CONTEXT.try_with(|cell| {
let mut store = cell.take()?;
let r = f(&mut store);
cell.set(Some(store));
Some(r)
});
match found {
Ok(Some(r)) => Some(r),
_ => None,
}
}