use std::cell::Cell;
use std::sync::Arc;
use crate::async_ctx::{self, TASK_CONTEXT};
use crate::snapshot::ContextSnapshot;
use crate::store::ContextStore;
use crate::sync_ctx;
/// Selects how a spawned task (or blocking closure) obtains its context
/// store from the code that spawns it.
///
/// The default is [`ContextInheritance::Fork`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContextInheritance {
    /// Use the result of the spawner's context `fork()`, falling back to a
    /// fresh `ContextStore::new()` when `fork()` yields nothing.
    #[default]
    Fork,
    /// Rebuild a store from the spawner's context `snapshot()`.
    Snapshot,
}
/// Spawns `future` on the tokio runtime, giving it a context store captured
/// from the caller's *async* context.
///
/// The store is captured eagerly, on the calling task, according to `mode`
/// (see [`ContextInheritance`]). It is then bound to the spawned future via
/// `TASK_CONTEXT.scope`, so the future sees it as its task context.
pub fn spawn_with_async_context<F>(
    mode: ContextInheritance,
    future: F,
) -> tokio::task::JoinHandle<F::Output>
where
    F: std::future::Future + Send + 'static,
    F::Output: Send + 'static,
{
    let inherited = Some(capture_async(mode));
    let scoped = TASK_CONTEXT.scope(Cell::new(inherited), future);
    tokio::spawn(scoped)
}
pub fn spawn_blocking_with_async_context<F, R>(
mode: ContextInheritance,
f: F,
) -> tokio::task::JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let store = capture_async(mode);
tokio::task::spawn_blocking(move || {
let saved = sync_ctx::try_apply(|s| std::mem::replace(s, store));
let result = f();
if let Some(original) = saved {
sync_ctx::try_apply(|s| *s = original);
}
result
})
}
/// Spawns `future` on the tokio runtime, giving it a context store captured
/// from the calling thread's *sync* context.
///
/// The store is captured eagerly, before spawning, according to `mode` (see
/// [`ContextInheritance`]). It is then bound to the spawned future via
/// `TASK_CONTEXT.scope`.
pub fn spawn_with_sync_context<F>(
    mode: ContextInheritance,
    future: F,
) -> tokio::task::JoinHandle<F::Output>
where
    F: std::future::Future + Send + 'static,
    F::Output: Send + 'static,
{
    let inherited = Some(capture_sync(mode));
    let scoped = TASK_CONTEXT.scope(Cell::new(inherited), future);
    tokio::spawn(scoped)
}
/// Builds the store a child should start with, taken from the current
/// *async* context.
///
/// - `Snapshot`: converts `async_ctx::snapshot()` into a store.
/// - `Fork`: uses `async_ctx::fork()`, or `ContextStore::new()` when it
///   yields nothing.
fn capture_async(mode: ContextInheritance) -> ContextStore {
    match mode {
        ContextInheritance::Snapshot => {
            let snap = async_ctx::snapshot();
            snapshot_to_store(snap)
        }
        ContextInheritance::Fork => {
            let forked = async_ctx::fork();
            forked.unwrap_or_else(ContextStore::new)
        }
    }
}
/// Builds the store a child should start with, taken from the current
/// *sync* context.
///
/// - `Snapshot`: converts `sync_ctx::snapshot()` into a store.
/// - `Fork`: uses `sync_ctx::fork()`, or `ContextStore::new()` when it
///   yields nothing.
fn capture_sync(mode: ContextInheritance) -> ContextStore {
    match mode {
        ContextInheritance::Snapshot => {
            let snap = sync_ctx::snapshot();
            snapshot_to_store(snap)
        }
        ContextInheritance::Fork => {
            let forked = sync_ctx::fork();
            forked.unwrap_or_else(ContextStore::new)
        }
    }
}
/// Materializes a [`ContextStore`] from a [`ContextSnapshot`].
///
/// Each `(key, value)` entry is copied over, with keys copied by value and
/// values shared through `Arc::clone`. The snapshot's scope chain is cloned
/// into the new store.
fn snapshot_to_store(snap: ContextSnapshot) -> ContextStore {
    let values = snap
        .values
        .iter()
        .map(|(key, value)| (*key, Arc::clone(value)))
        .collect();
    ContextStore::from_values_with_chain(values, snap.scope_chain.clone())
}