use std::cell::Cell;
use std::collections::HashMap;
use std::sync::Arc;
use crate::scope::ScopeNode;
use crate::storage::{self, ContextStore, TASK_CONTEXT};
use crate::value::ContextValue;
/// Snapshot of a task's context, captured so it can later be installed into
/// another task (see [`with_fork`]).
#[derive(Clone)]
pub struct ForkHandle {
    /// Frozen copy of the forking scope; becomes the parent scope chain of the
    /// store built from this handle. `None` for an empty handle.
    frozen: Option<Arc<ScopeNode>>,
    /// Remote chain shared (via `Arc`) with the store that was forked.
    remote_chain: Arc<Vec<String>>,
    /// Remote-chain base depth copied from the store that was forked.
    remote_chain_base_depth: usize,
    /// Values resolved at fork time for the registry's cached keys; these seed
    /// the forked store's current values.
    cached_values: HashMap<&'static str, Arc<dyn ContextValue>>,
}

impl ForkHandle {
    /// Returns a handle carrying no snapshot: no frozen scope, an empty remote
    /// chain with base depth 0, and no cached values.
    fn empty() -> Self {
        let remote_chain = Arc::default();
        let cached_values = HashMap::new();
        Self {
            frozen: None,
            remote_chain,
            remote_chain_base_depth: 0,
            cached_values,
        }
    }
}
/// Captures the calling task's context so it can be re-installed elsewhere
/// with [`with_fork`].
///
/// If `storage::do_fork` produces no handle, an empty handle is returned, so
/// the caller always receives something that can be passed to [`with_fork`].
pub fn fork() -> ForkHandle {
    match storage::do_fork() {
        Some(captured) => captured,
        None => ForkHandle::empty(),
    }
}
/// Drives `future` to completion with the context captured in `handle`
/// installed in `TASK_CONTEXT` for the duration of that future.
pub async fn with_fork<F>(handle: ForkHandle, future: F) -> F::Output
where
    F: std::future::Future,
{
    let slot = Cell::new(Some(build_forked_store(handle)));
    TASK_CONTEXT.scope(slot, future).await
}
/// Spawns `future` with `tokio::spawn`, running it inside a fork of the
/// caller's current context.
///
/// The context is captured synchronously here, at spawn time, rather than when
/// the spawned task is first polled.
pub fn spawn_with_fork_async<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: std::future::Future + Send + 'static,
    F::Output: Send + 'static,
{
    tokio::spawn(with_fork(fork(), future))
}
fn build_forked_store(handle: ForkHandle) -> ContextStore {
let (depth, remote_chain, remote_chain_base_depth) = match &handle.frozen {
Some(node) => (
node.depth + 1,
Arc::clone(&handle.remote_chain),
handle.remote_chain_base_depth,
),
None => (1, handle.remote_chain, handle.remote_chain_base_depth),
};
ContextStore {
scope_chain: handle.frozen,
current_values: handle.cached_values,
current_name: None,
depth,
remote_chain,
remote_chain_base_depth,
}
}
/// Captures `store` as a [`ForkHandle`] that another task can later install.
///
/// Two things are recorded:
/// * a frozen [`ScopeNode`] holding the store's current name, a copy of its
///   current values, its parent scope chain, its depth and its remote-chain
///   data; and
/// * for each key returned by `crate::registry::cached_keys()`, the value that
///   `store.get_value` resolves for it (keys that resolve to nothing are
///   skipped).
pub(crate) fn create_fork_handle(store: &ContextStore) -> ForkHandle {
    let values_snapshot: HashMap<&'static str, Arc<dyn ContextValue>> = store
        .current_values
        .iter()
        .map(|(key, value)| (*key, Arc::clone(value)))
        .collect();
    let snapshot = ScopeNode {
        parent: store.scope_chain.clone(),
        name: store.current_name.clone(),
        values: values_snapshot,
        depth: store.depth,
        remote_chain: Arc::clone(&store.remote_chain),
        remote_chain_base_depth: store.remote_chain_base_depth,
    };

    let mut resolved: HashMap<&'static str, Arc<dyn ContextValue>> = HashMap::new();
    for key in crate::registry::cached_keys() {
        if let Some(value) = store.get_value(key) {
            resolved.insert(key, value);
        }
    }

    ForkHandle {
        frozen: Some(Arc::new(snapshot)),
        cached_values: resolved,
        remote_chain_base_depth: store.remote_chain_base_depth,
        remote_chain: Arc::clone(&store.remote_chain),
    }
}