use std::cell::Cell;
use std::collections::HashMap;
use std::sync::Arc;
use crate::scope::{ScopeGarbage, ScopeGuard, ScopeNode};
use crate::value::ContextValue;
// Storage slots for the context store.
//
// Each slot holds `Option<ContextStore>` so that `with_store` can temporarily
// `take()` the store out of the cell while operating on it and `set()` it back
// afterwards; an empty slot therefore reads as "no store available".
thread_local! {
/// Per-thread store, used when no task-local store is reachable or when
/// `force_thread_local` is active. Starts out holding a fresh root store.
pub(crate) static CONTEXT: Cell<Option<ContextStore>> =
Cell::new(Some(ContextStore::new()));
/// Nesting counter maintained by `force_thread_local`; while it is non-zero,
/// `with_current_cell` skips `TASK_CONTEXT` and goes straight to `CONTEXT`.
static FORCE_THREAD_LOCAL_DEPTH: Cell<u32> = const { Cell::new(0) };
}
tokio::task_local! {
/// Per-task store, consulted first by `with_current_cell` when the caller is
/// running inside a scope that has set this task-local.
pub(crate) static TASK_CONTEXT: Cell<Option<ContextStore>>;
}
/// Context state held per thread (`CONTEXT`) or per task (`TASK_CONTEXT`):
/// the currently open scope plus a linked chain of frozen ancestor scopes.
pub(crate) struct ContextStore {
/// Frozen ancestor scopes, innermost first; each node links to its own parent.
pub(crate) scope_chain: Option<Arc<ScopeNode>>,
/// Values set in the current (innermost) scope; these shadow ancestor values.
pub(crate) current_values: HashMap<&'static str, Arc<dyn ContextValue>>,
/// Name of the current scope, if it was entered as a named scope.
pub(crate) current_name: Option<String>,
/// Nesting depth of the current scope; the root scope is depth 1 and cannot be popped.
pub(crate) depth: usize,
/// Scope names inherited from outside this store (installed by `set_remote_chain`
/// or `from_values_with_chain`); presumably propagated from a remote caller — TODO confirm.
pub(crate) remote_chain: Arc<Vec<String>>,
/// Depth at which `remote_chain` was installed; only scopes deeper than this
/// contribute their own names to `scope_chain()`.
pub(crate) remote_chain_base_depth: usize,
}
impl ContextStore {
pub(crate) fn new() -> Self {
Self {
scope_chain: None,
current_values: HashMap::new(),
current_name: None,
depth: 1,
remote_chain: Arc::new(Vec::new()),
remote_chain_base_depth: 0,
}
}
pub(crate) fn from_values_with_chain(
values: HashMap<&'static str, Arc<dyn ContextValue>>,
remote_chain: Vec<String>,
) -> Self {
Self {
scope_chain: None,
current_values: values,
current_name: None,
depth: 1,
remote_chain: Arc::new(remote_chain),
remote_chain_base_depth: 1,
}
}
pub(crate) fn push_scope(&mut self, name: Option<String>) -> usize {
let cached = crate::registry::cached_keys();
let mut cached_values: Vec<(&'static str, Arc<dyn ContextValue>)> = Vec::new();
for &key in &cached {
if let Some(val) = self.get_value(key) {
cached_values.push((key, val));
}
}
let frozen_values = std::mem::take(&mut self.current_values);
let node = Arc::new(ScopeNode {
name: self.current_name.take(),
values: frozen_values,
parent: self.scope_chain.take(),
depth: self.depth,
remote_chain: Arc::clone(&self.remote_chain),
remote_chain_base_depth: self.remote_chain_base_depth,
});
self.scope_chain = Some(node);
self.current_name = name;
self.depth += 1;
for (key, val) in cached_values {
self.current_values.insert(key, val);
}
self.depth
}
pub(crate) fn pop_scope(
&mut self,
expected_depth: usize,
) -> Option<ScopeGarbage> {
if self.depth != expected_depth || self.depth <= 1 {
return None;
}
let node = self.scope_chain.take()?;
let old_current = std::mem::take(&mut self.current_values);
match Arc::try_unwrap(node) {
Ok(owned) => {
self.scope_chain = owned.parent;
self.current_name = owned.name;
self.current_values = owned.values;
self.depth = owned.depth;
self.remote_chain = owned.remote_chain;
self.remote_chain_base_depth = owned.remote_chain_base_depth;
}
Err(shared) => {
self.scope_chain = shared.parent.clone();
self.current_name = shared.name.clone();
self.current_values = shared.values.iter()
.map(|(&k, v)| (k, Arc::clone(v)))
.collect();
self.depth = shared.depth;
self.remote_chain = Arc::clone(&shared.remote_chain);
self.remote_chain_base_depth = shared.remote_chain_base_depth;
}
}
Some(ScopeGarbage {
_old_values: old_current,
})
}
pub(crate) fn set_value(
&mut self,
key: &'static str,
value: Arc<dyn ContextValue>,
) -> Option<Arc<dyn ContextValue>> {
self.current_values.insert(key, value)
}
pub(crate) fn get_value(&self, key: &str) -> Option<Arc<dyn ContextValue>> {
if let Some(v) = self.current_values.get(key) {
return Some(Arc::clone(v));
}
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
if let Some(v) = n.values.get(key) {
return Some(Arc::clone(v));
}
node = n.parent.as_ref();
}
None
}
pub(crate) fn collect_values(&self) -> HashMap<&'static str, Arc<dyn ContextValue>> {
let mut result: HashMap<&'static str, Arc<dyn ContextValue>> = HashMap::new();
for (&k, v) in &self.current_values {
result.insert(k, Arc::clone(v));
}
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
for (&k, v) in &n.values {
result.entry(k).or_insert_with(|| Arc::clone(v));
}
node = n.parent.as_ref();
}
result
}
pub(crate) fn set_remote_chain(&mut self, chain: Vec<String>) {
self.remote_chain = Arc::new(chain);
self.remote_chain_base_depth = self.depth;
}
pub(crate) fn scope_chain(&self) -> Vec<String> {
let mut local_names = Vec::new();
if let Some(name) = &self.current_name {
if self.depth > self.remote_chain_base_depth {
local_names.push(name.clone());
}
}
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
if n.depth > self.remote_chain_base_depth {
if let Some(name) = &n.name {
local_names.push(name.clone());
}
}
node = n.parent.as_ref();
}
local_names.reverse();
let max_len = crate::config::max_scope_chain_len();
let mut chain: Vec<String> = (*self.remote_chain).clone();
chain.extend(local_names);
if max_len > 0 && chain.len() > max_len {
let start = chain.len() - max_len;
chain.drain(..start);
}
chain
}
}
/// Runs `f` against the cell that holds the context store for the current caller.
///
/// Lookup order:
/// 1. the tokio task-local `TASK_CONTEXT`, when the caller is inside a scope that
///    set it and `force_thread_local` is not active on this thread;
/// 2. the thread-local `CONTEXT`;
/// 3. a fresh, empty cell when the thread-local cannot be accessed (e.g. during
///    thread teardown), so callers observe "no store" instead of a panic.
///
/// `f` is invoked exactly once.
fn with_current_cell<R>(mut f: impl FnMut(&Cell<Option<ContextStore>>) -> R) -> R {
    let forced = FORCE_THREAD_LOCAL_DEPTH
        .try_with(|depth| depth.get())
        .map_or(false, |depth| depth > 0);

    if !forced {
        if let Ok(out) = TASK_CONTEXT.try_with(|cell| f(cell)) {
            return out;
        }
    }

    match CONTEXT.try_with(|cell| f(cell)) {
        Ok(out) => out,
        Err(_) => f(&Cell::new(None)),
    }
}
fn with_store<R>(f: impl FnOnce(&mut ContextStore) -> R) -> Option<R> {
let f = std::cell::Cell::new(Some(f));
with_current_cell(|cell| {
let mut store = cell.take()?; let func = f.take().expect("with_store closure called more than once");
let result = func(&mut store);
cell.set(Some(store));
Some(result)
})
}
/// Opens a new anonymous scope in the current context store.
///
/// The scope stays open until the returned guard is dropped. When no store is
/// reachable, a no-op guard (`ScopeGuard::noop`) is returned instead.
pub fn enter_scope() -> ScopeGuard {
    let guard = with_store(|store| {
        let depth = store.push_scope(None);
        ScopeGuard::new(depth)
    });
    match guard {
        Some(active) => active,
        None => ScopeGuard::noop(),
    }
}
/// Opens a new scope labelled `name` in the current context store.
///
/// The name becomes part of `scope_chain()` while the scope is open. The scope
/// closes when the returned guard is dropped. When no store is reachable, a no-op
/// guard is returned.
pub fn enter_named_scope(name: impl Into<String>) -> ScopeGuard {
    let label: String = name.into();
    let guard = with_store(move |store| {
        let depth = store.push_scope(Some(label));
        ScopeGuard::new(depth)
    });
    match guard {
        Some(active) => active,
        None => ScopeGuard::noop(),
    }
}
/// Closes the scope opened at `expected_depth`, if it is still the current one.
///
/// `usize::MAX` is treated as "nothing to close" (presumably the depth carried by
/// a no-op guard — TODO confirm). The popped scope's values come back out of
/// `with_store` and are therefore dropped only after the store is back in its cell.
pub(crate) fn leave_scope(expected_depth: usize) {
    if expected_depth != usize::MAX {
        let garbage = with_store(|store| store.pop_scope(expected_depth));
        drop(garbage);
    }
}
/// Runs `f` inside a fresh anonymous scope and returns its result.
///
/// The scope is closed after `f` returns, or while unwinding if `f` panics.
pub fn scope<R>(f: impl FnOnce() -> R) -> R {
    let guard = enter_scope();
    let output = f();
    drop(guard);
    output
}
pub(crate) fn do_fork() -> Option<crate::fork::ForkHandle> {
with_store(|store| crate::fork::create_fork_handle(store))
}
/// Drop guard for the async scope helpers: closes the scope at the recorded depth
/// when it is dropped (normal completion, cancellation, or unwinding).
struct ScopeCleanup(usize);

impl Drop for ScopeCleanup {
    fn drop(&mut self) {
        let depth = self.0;
        leave_scope(depth);
    }
}
pub async fn scope_async<F, R>(f: F) -> R
where
F: std::future::Future<Output = R>,
{
let depth = with_store(|store| store.push_scope(None));
match depth {
None => f.await, Some(depth) => {
let cleanup = ScopeCleanup(depth);
let result = f.await;
std::mem::forget(cleanup);
leave_scope(depth);
result
}
}
}
pub async fn named_scope_async<F, R>(name: impl Into<String>, f: F) -> R
where
F: std::future::Future<Output = R>,
{
let name = name.into();
let depth = with_store(|store| store.push_scope(Some(name)));
match depth {
None => f.await,
Some(depth) => {
let cleanup = ScopeCleanup(depth);
let result = f.await;
std::mem::forget(cleanup);
leave_scope(depth);
result
}
}
}
/// Looks up `key` in the current store, innermost scope first.
///
/// Returns `None` when the key is unset or no store is reachable.
pub(crate) fn get_value(key: &str) -> Option<Arc<dyn ContextValue>> {
    match with_store(|store| store.get_value(key)) {
        Some(found) => found,
        None => None,
    }
}
/// Sets `key` to `value` in the current scope of the current store.
///
/// Any value displaced from the same scope is returned out of `with_store`, so it
/// is dropped only after the store is back in its cell. When no store is
/// reachable, the call has no effect.
pub(crate) fn set_value(key: &'static str, value: Arc<dyn ContextValue>) {
    let previous = with_store(move |store| store.set_value(key, value));
    drop(previous);
}
/// Snapshot of every value visible from the current scope (inner scopes win).
///
/// Returns an empty map when no store is reachable.
pub(crate) fn collect_values() -> HashMap<&'static str, Arc<dyn ContextValue>> {
    with_store(|store| store.collect_values()).unwrap_or_else(HashMap::new)
}
/// Returns the current scope-name chain, outermost first.
///
/// The chain is the inherited remote chain followed by locally entered named
/// scopes. Returns an empty vector when no store is reachable.
pub fn scope_chain() -> Vec<String> {
    match with_store(|store| store.scope_chain()) {
        Some(chain) => chain,
        None => Vec::new(),
    }
}
pub(crate) fn collect_scope_chain() -> Vec<String> {
scope_chain()
}
/// Installs `chain` as the inherited scope-name chain, anchored at the current depth.
///
/// Has no effect when no store is reachable.
pub(crate) fn set_remote_chain(chain: Vec<String>) {
    let _ = with_store(move |store| store.set_remote_chain(chain));
}
pub fn force_thread_local<R>(f: impl FnOnce() -> R) -> R {
FORCE_THREAD_LOCAL_DEPTH.with(|c| c.set(c.get() + 1));
struct DepthGuard;
impl Drop for DepthGuard {
fn drop(&mut self) {
crate::storage::FORCE_THREAD_LOCAL_DEPTH.with(|c| c.set(c.get() - 1));
}
}
let _guard = DepthGuard;
f()
}