use std::cell::Cell;
use std::collections::HashMap;
use std::sync::Arc;
use crate::registry::Registry;
use crate::scope::ScopeNode;
use crate::value::ContextValue;
/// Maximum number of scopes `ContextStore::push_scope` actually records above
/// the root layer (depth 1).
///
/// Pushes beyond this limit only advance the depth counter ("dead" scopes):
/// no scope node is created for them.
pub(crate) const MAX_SCOPE_DEPTH: usize = 1 << 10;
/// Per-thread stack of context scopes.
///
/// The innermost scope is kept mutable in `current_values` / `current_name`;
/// every enclosing scope is an immutable `ScopeNode` reachable through
/// `scope_chain`. Field declaration order is also the drop order.
pub struct ContextStore {
/// Enclosing scopes, innermost first (linked through `ScopeNode::parent`).
pub(crate) scope_chain: Option<Arc<ScopeNode>>,
/// Values set in the innermost scope.
pub(crate) current_values: HashMap<&'static str, Arc<dyn ContextValue>>,
/// Name of the innermost scope, if it has one.
pub(crate) current_name: Option<String>,
/// Depth of the innermost scope; every constructor in this file starts at 1.
/// Past `MAX_SCOPE_DEPTH + 1` it keeps counting "dead" scopes that have no node.
pub(crate) depth: usize,
/// Names placed in front of the local names reported by `scope_chain()`
/// (presumably inherited from a remote caller — confirm at call sites).
pub(crate) remote_chain: Arc<Vec<String>>,
/// Local scope names at or below this depth are not reported by `scope_chain()`.
pub(crate) remote_chain_base_depth: usize,
/// Read-only snapshot of a parent store's scopes, consulted after `scope_chain`;
/// set by `fork_child`.
pub(crate) frozen_parent: Option<Arc<ScopeNode>>,
/// When set, scopes at or below this depth and the whole `frozen_parent` are
/// hidden from lookups and from `scope_chain()`.
pub(crate) scope_barrier: Option<usize>,
}
impl ContextStore {
pub(crate) fn new() -> Self {
Self {
scope_chain: None,
current_values: HashMap::new(),
current_name: None,
depth: 1,
remote_chain: Arc::new(Vec::new()),
remote_chain_base_depth: 0,
frozen_parent: None,
scope_barrier: None,
}
}
pub(crate) fn from_values_with_chain(
values: HashMap<&'static str, Arc<dyn ContextValue>>,
remote_chain: Vec<String>,
) -> Self {
Self {
scope_chain: None,
current_values: values,
current_name: None,
depth: 1,
remote_chain: Arc::new(remote_chain),
remote_chain_base_depth: 1,
frozen_parent: None,
scope_barrier: None,
}
}
pub(crate) fn fork_child(&self) -> Self {
let frozen_values: HashMap<&'static str, Arc<dyn ContextValue>> = self
.current_values
.iter()
.map(|(&k, v)| (k, Arc::clone(v)))
.collect();
let frozen = Arc::new(ScopeNode {
name: self.current_name.clone(),
values: frozen_values,
parent: self.scope_chain.clone(),
depth: self.depth,
remote_chain: Arc::clone(&self.remote_chain),
remote_chain_base_depth: self.remote_chain_base_depth,
saved_scope_barrier: self.scope_barrier,
});
Self {
scope_chain: None,
current_values: HashMap::new(),
current_name: None,
depth: 1,
remote_chain: Arc::clone(&self.remote_chain),
remote_chain_base_depth: self.remote_chain_base_depth,
frozen_parent: Some(frozen),
scope_barrier: None,
}
}
pub(crate) fn push_scope(&mut self, registry: &Registry<'_>, name: Option<String>) -> usize {
self.depth += 1;
if self.depth > MAX_SCOPE_DEPTH + 1 {
return self.depth;
}
let cached = registry.cached_keys();
let mut cached_values: Vec<(&'static str, Arc<dyn ContextValue>)> = Vec::new();
for &key in &cached {
if let Some(val) = self.get_value(key) {
cached_values.push((key, val));
}
}
let frozen_values = std::mem::take(&mut self.current_values);
let node = Arc::new(ScopeNode {
name: self.current_name.take(),
values: frozen_values,
parent: self.scope_chain.take(),
depth: self.depth - 1,
remote_chain: Arc::clone(&self.remote_chain),
remote_chain_base_depth: self.remote_chain_base_depth,
saved_scope_barrier: self.scope_barrier,
});
self.scope_chain = Some(node);
self.current_name = name;
for (key, val) in cached_values {
self.current_values.insert(key, val);
}
self.depth
}
pub(crate) fn pop_scope(
&mut self,
expected_depth: usize,
) -> Option<crate::scope::ScopeGarbage> {
if self.depth < expected_depth || expected_depth <= 1 {
return None;
}
let mut all_old: Option<HashMap<&'static str, Arc<dyn ContextValue>>> = None;
while self.depth >= expected_depth && self.depth > 1 {
if self.depth > MAX_SCOPE_DEPTH + 1 {
self.depth -= 1;
continue;
}
let node = match self.scope_chain.take() {
Some(n) => n,
None => break,
};
let old_current = std::mem::take(&mut self.current_values);
match &mut all_old {
Some(existing) => existing.extend(old_current),
None => all_old = Some(old_current),
}
match Arc::try_unwrap(node) {
Ok(owned) => {
self.scope_chain = owned.parent;
self.current_name = owned.name;
self.current_values = owned.values;
self.depth = owned.depth;
self.remote_chain = owned.remote_chain;
self.remote_chain_base_depth = owned.remote_chain_base_depth;
self.scope_barrier = owned.saved_scope_barrier;
}
Err(shared) => {
self.scope_chain = shared.parent.clone();
self.current_name = shared.name.clone();
self.current_values = shared
.values
.iter()
.map(|(&k, v)| (k, Arc::clone(v)))
.collect();
self.depth = shared.depth;
self.remote_chain = Arc::clone(&shared.remote_chain);
self.remote_chain_base_depth = shared.remote_chain_base_depth;
self.scope_barrier = shared.saved_scope_barrier;
}
}
}
all_old.map(|old| crate::scope::ScopeGarbage { _old_values: old })
}
pub(crate) fn set_value(
&mut self,
key: &'static str,
value: Arc<dyn ContextValue>,
) -> Option<Arc<dyn ContextValue>> {
self.current_values.insert(key, value)
}
pub(crate) fn get_value(&self, key: &str) -> Option<Arc<dyn ContextValue>> {
if let Some(v) = self.current_values.get(key) {
return Some(Arc::clone(v));
}
let barrier = self.scope_barrier.unwrap_or(0);
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
if n.depth <= barrier {
break;
}
if let Some(v) = n.values.get(key) {
return Some(Arc::clone(v));
}
node = n.parent.as_ref();
}
if self.scope_barrier.is_none() {
let mut node = self.frozen_parent.as_ref();
while let Some(n) = node {
if let Some(v) = n.values.get(key) {
return Some(Arc::clone(v));
}
node = n.parent.as_ref();
}
}
None
}
pub(crate) fn collect_values(&self) -> HashMap<&'static str, Arc<dyn ContextValue>> {
let mut result: HashMap<&'static str, Arc<dyn ContextValue>> = HashMap::new();
for (&k, v) in &self.current_values {
result.insert(k, Arc::clone(v));
}
let barrier = self.scope_barrier.unwrap_or(0);
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
if n.depth <= barrier {
break;
}
for (&k, v) in &n.values {
result.entry(k).or_insert_with(|| Arc::clone(v));
}
node = n.parent.as_ref();
}
if self.scope_barrier.is_none() {
let mut node = self.frozen_parent.as_ref();
while let Some(n) = node {
for (&k, v) in &n.values {
result.entry(k).or_insert_with(|| Arc::clone(v));
}
node = n.parent.as_ref();
}
}
result
}
pub(crate) fn scope_chain(&self) -> Vec<String> {
let mut local_names = Vec::new();
if let Some(name) = &self.current_name {
if self.depth > self.remote_chain_base_depth {
local_names.push(name.clone());
}
}
let barrier = self.scope_barrier.unwrap_or(0);
let mut node = self.scope_chain.as_ref();
while let Some(n) = node {
if n.depth <= barrier {
break;
}
if n.depth > self.remote_chain_base_depth {
if let Some(name) = &n.name {
local_names.push(name.clone());
}
}
node = n.parent.as_ref();
}
local_names.reverse();
let mut parent_names = Vec::new();
if self.scope_barrier.is_none() {
let mut node = self.frozen_parent.as_ref();
while let Some(n) = node {
if let Some(name) = &n.name {
parent_names.push(name.clone());
}
node = n.parent.as_ref();
}
parent_names.reverse();
}
let max_len = crate::config::max_scope_chain_len();
let mut chain: Vec<String> = (*self.remote_chain).clone();
chain.extend(parent_names);
chain.extend(local_names);
if max_len > 0 && chain.len() > max_len {
let start = chain.len() - max_len;
chain.drain(..start);
}
chain
}
}
thread_local! {
    /// This thread's context store.
    ///
    /// The cell holds `None` while `try_apply` has the store moved out.
    pub(crate) static CONTEXT: Cell<Option<ContextStore>> = {
        let initial = ContextStore::new();
        Cell::new(Some(initial))
    };
}
pub(crate) fn try_apply<R>(f: impl FnOnce(&mut ContextStore) -> R) -> Option<R> {
CONTEXT
.try_with(|cell| {
let mut store = cell.take()?;
let result = f(&mut store);
cell.set(Some(store));
Some(result)
})
.unwrap_or(None)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts the real `ScopeNode`s linked from `store.scope_chain`.
    fn real_node_count(store: &ContextStore) -> usize {
        let mut count = 0usize;
        let mut node = store.scope_chain.as_ref();
        while let Some(n) = node {
            count += 1;
            node = n.parent.as_ref();
        }
        count
    }

    /// Pushes past `MAX_SCOPE_DEPTH` create "dead" scopes.
    ///
    /// The depth keeps growing but no node is added. Popping a dead scope yields
    /// no garbage; popping the last real scope does.
    #[test]
    fn test_max_scope_depth_creates_dead_scopes() {
        let mut store = ContextStore::new();
        for i in 0..MAX_SCOPE_DEPTH {
            let registry = crate::registry::Registry::empty();
            let depth = store.push_scope(&registry, Some(format!("scope_{}", i)));
            assert_eq!(depth, i + 2);
        }
        assert_eq!(store.depth, MAX_SCOPE_DEPTH + 1);
        assert_eq!(real_node_count(&store), MAX_SCOPE_DEPTH);

        let registry = crate::registry::Registry::empty();
        let dead_depth_1 = store.push_scope(&registry, Some("dead_1".to_string()));
        assert_eq!(dead_depth_1, MAX_SCOPE_DEPTH + 2);
        let dead_depth_2 = store.push_scope(&registry, Some("dead_2".to_string()));
        assert_eq!(dead_depth_2, MAX_SCOPE_DEPTH + 3);
        assert_eq!(real_node_count(&store), MAX_SCOPE_DEPTH);

        store.set_value("key", Arc::new("dead_val".to_string()));
        assert!(store.get_value("key").is_some());

        let garbage = store.pop_scope(dead_depth_2);
        assert!(garbage.is_none());
        assert_eq!(store.depth, MAX_SCOPE_DEPTH + 2);
        let garbage = store.pop_scope(dead_depth_1);
        assert!(garbage.is_none());
        assert_eq!(store.depth, MAX_SCOPE_DEPTH + 1);

        let real_depth = store.depth;
        let garbage = store.pop_scope(real_depth);
        assert!(garbage.is_some());
        assert_eq!(store.depth, MAX_SCOPE_DEPTH);
    }

    /// A value set in a real scope stays visible while dead scopes are pushed
    /// on top of it and after they are popped.
    #[test]
    fn test_max_scope_depth_values_survive_dead_pop() {
        let mut store = ContextStore::new();
        let registry = crate::registry::Registry::empty();
        store.push_scope(&registry, None);
        store.set_value("persistent", Arc::new(42u64));
        for _ in 0..MAX_SCOPE_DEPTH + 5 {
            store.push_scope(&registry, None);
        }
        let val = store.get_value("persistent");
        assert!(val.is_some());
        for _ in 0..5 {
            let d = store.depth;
            store.pop_scope(d);
        }
        assert!(store.get_value("persistent").is_some());
    }

    /// Pushing past the limit and popping every returned depth in reverse
    /// returns the store to an empty root.
    #[test]
    fn test_max_scope_depth_full_roundtrip() {
        let mut store = ContextStore::new();
        let mut depths = Vec::new();
        let total = MAX_SCOPE_DEPTH + 10;
        for _ in 0..total {
            let registry = crate::registry::Registry::empty();
            let d = store.push_scope(&registry, None);
            depths.push(d);
        }
        assert_eq!(store.depth, total + 1);
        for &d in depths.iter().rev() {
            store.pop_scope(d);
        }
        assert_eq!(store.depth, 1);
        assert!(store.scope_chain.is_none());
    }
}