use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use super::command_policy::{
swap_command_policy_hook_depth, swap_command_policy_stack, CommandPolicy,
};
use super::policy::{
swap_approval_policy_stack, swap_execution_policy_stack, swap_trusted_bridge_depth,
CapabilityPolicy, ToolApprovalPolicy,
};
use crate::autonomy::{swap_autonomy_policy_stack, AutonomyPolicy};
use crate::connectors::harn_module::swap_active_harn_connector_ctx;
use crate::connectors::ConnectorCtx;
use crate::llm::permissions::{swap_dynamic_permission_stack, DynamicPermissionPolicy};
use crate::runtime_context::{swap_runtime_context_overlay_stack, RuntimeContextOverlay};
use crate::stdlib::template::llm_context::{swap_llm_render_stack, LlmRenderContext};
/// Owned copy of every ambient stack/counter that [`Scoped`] installs around a poll.
///
/// Each field corresponds to the state behind one `swap_*` function (see `swap_in`).
/// The derived `Default` is the empty state: every stack empty and both depths zero.
#[derive(Default, Clone)]
pub(crate) struct AmbientExecutionScope {
    /// Exchanged via `swap_execution_policy_stack`.
    execution: Vec<CapabilityPolicy>,
    /// Exchanged via `swap_approval_policy_stack`.
    approval: Vec<ToolApprovalPolicy>,
    /// Exchanged via `swap_command_policy_stack`; inherited by `capture_inherited`.
    command: Vec<CommandPolicy>,
    /// Exchanged via `swap_dynamic_permission_stack`; inherited by `capture_inherited`.
    permissions: Vec<DynamicPermissionPolicy>,
    /// Exchanged via `swap_runtime_context_overlay_stack`; inherited by `capture_inherited`.
    runtime_context: Vec<RuntimeContextOverlay>,
    /// Exchanged via `swap_autonomy_policy_stack`; inherited by `capture_inherited`.
    autonomy: Vec<AutonomyPolicy>,
    /// Exchanged via `swap_llm_render_stack`.
    llm_render: Vec<LlmRenderContext>,
    /// Exchanged via `swap_active_harn_connector_ctx`.
    connector_ctx: Vec<ConnectorCtx>,
    /// Exchanged via `swap_trusted_bridge_depth`.
    trusted_depth: usize,
    /// Exchanged via `swap_command_policy_hook_depth`.
    command_hook_depth: usize,
}
/// Returns a copy of the stack behind `swap` and leaves the original stack in place.
///
/// `swap` must install its argument and return the previous contents. The stack is
/// taken out (temporarily replaced by an empty `Vec`), copied, and then put back. The
/// empty placeholder returned by the second swap is dropped.
fn clone_via_swap<T: Clone>(swap: impl Fn(Vec<T>) -> Vec<T>) -> Vec<T> {
    let taken = swap(Vec::new());
    let snapshot = taken.to_vec();
    drop(swap(taken));
    snapshot
}
impl AmbientExecutionScope {
    /// Builds a scope that copies part of the caller's current ambient state.
    ///
    /// The command-policy, dynamic-permission, runtime-context-overlay and autonomy
    /// stacks are cloned; the caller's stacks are left installed. Every other field
    /// starts from `Default` (empty stacks, zero depths).
    // NOTE(review): execution/approval/llm_render/connector_ctx and both depth counters
    // are not inherited. Confirm this isolation is intentional for child tasks.
    pub(crate) fn capture_inherited() -> Self {
        let command = clone_via_swap(swap_command_policy_stack);
        let permissions = clone_via_swap(swap_dynamic_permission_stack);
        let runtime_context = clone_via_swap(swap_runtime_context_overlay_stack);
        let autonomy = clone_via_swap(swap_autonomy_policy_stack);
        Self {
            command,
            permissions,
            runtime_context,
            autonomy,
            ..Self::default()
        }
    }

    /// Installs every field of `self` as the ambient state and returns what was there.
    ///
    /// Each `swap_*` call is expected to return the previous value, which is what
    /// `clone_via_swap` and `RestoreGuard` rely on. Calling `swap_in` on the returned
    /// scope therefore undoes this call.
    fn swap_in(self) -> Self {
        let Self {
            execution,
            approval,
            command,
            permissions,
            runtime_context,
            autonomy,
            llm_render,
            connector_ctx,
            trusted_depth,
            command_hook_depth,
        } = self;
        Self {
            execution: swap_execution_policy_stack(execution),
            approval: swap_approval_policy_stack(approval),
            command: swap_command_policy_stack(command),
            permissions: swap_dynamic_permission_stack(permissions),
            runtime_context: swap_runtime_context_overlay_stack(runtime_context),
            autonomy: swap_autonomy_policy_stack(autonomy),
            llm_render: swap_llm_render_stack(llm_render),
            connector_ctx: swap_active_harn_connector_ctx(connector_ctx),
            trusted_depth: swap_trusted_bridge_depth(trusted_depth),
            command_hook_depth: swap_command_policy_hook_depth(command_hook_depth),
        }
    }
}
pin_project! {
    /// Future adapter that polls `inner` with its own [`AmbientExecutionScope`] installed.
    ///
    /// Created by [`scope_ambient`]. On each poll the stored scope is swapped into the
    /// ambient stacks and `inner` is polled. The caller's previous state is then swapped
    /// back, and whatever `inner` left on the stacks is kept here for the next poll.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub(crate) struct Scoped<F> {
        // The wrapped future; structurally pinned so it can be polled through `Pin`.
        #[pin]
        inner: F,
        // The task's ambient state while it is not being polled. It is `None` only
        // transiently inside `poll`, between taking it and the `RestoreGuard` putting
        // it back.
        // NOTE(review): there is no custom drop, so `inner` is dropped WITHOUT this
        // scope installed. If its destructor touches the ambient stacks, it will see the
        // caller's state. Confirm that is acceptable.
        scope: Option<AmbientExecutionScope>,
    }
}
/// Wraps `inner` so that every poll runs with `scope` installed as the ambient state.
///
/// Changes `inner` makes to the ambient stacks stay with the returned future across
/// polls and do not leak into the caller's state between polls.
pub(crate) fn scope_ambient<F: Future>(scope: AmbientExecutionScope, inner: F) -> Scoped<F> {
    let scope = Some(scope);
    Scoped { scope, inner }
}
/// Drop guard used by `Scoped::poll` to reinstall the caller's ambient state.
///
/// Restoration lives in `Drop`, so it also runs when the inner poll unwinds from a panic.
struct RestoreGuard<'a> {
    /// The ambient state that was installed before the poll; swapped back in on drop.
    outer: Option<AmbientExecutionScope>,
    /// The owning future's scope slot. On drop it receives the state displaced by the
    /// restore, i.e. whatever the inner future left installed.
    slot: &'a mut Option<AmbientExecutionScope>,
}
impl Drop for RestoreGuard<'_> {
fn drop(&mut self) {
if let Some(outer) = self.outer.take() {
*self.slot = Some(outer.swap_in());
}
}
}
/// Polls the wrapped future with the task's ambient scope installed for the duration.
impl<F: Future> Future for Scoped<F> {
    type Output = F::Output;

    /// Each poll:
    /// 1. takes the stored task scope (an empty default if none is stored),
    /// 2. swaps it into the ambient state, keeping the previous state as `outer`,
    /// 3. arms a `RestoreGuard` that, when dropped at the end of this function, swaps
    ///    `outer` back and stores the task's (possibly modified) state in `scope`,
    /// 4. polls `inner`.
    ///
    /// Because step 3 happens in `Drop`, the caller's state is restored even if
    /// `inner.poll` unwinds from a panic.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();
        // Between polls this should always be `Some`; the default is only a fallback.
        let task_scope = this.scope.take().unwrap_or_default();
        let outer = task_scope.swap_in();
        // The guard borrows the scope slot; `inner` is a separate projected field, so it
        // can still be polled while the guard is alive.
        let _restore = RestoreGuard {
            outer: Some(outer),
            slot: this.scope,
        };
        this.inner.poll(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestration::{current_execution_policy, push_execution_policy};

    /// Builds a policy whose single allowed tool name identifies it in assertions.
    fn policy_named(tool: &str) -> CapabilityPolicy {
        CapabilityPolicy {
            tools: vec![tool.to_string()],
            ..Default::default()
        }
    }

    /// Two interleaved scoped tasks must each see only the policy they pushed.
    #[tokio::test]
    async fn scoped_tasks_do_not_cross_wire_execution_policy() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let alpha = tokio::task::spawn_local(scope_ambient(
                    AmbientExecutionScope::default(),
                    async {
                        push_execution_policy(policy_named("alpha"));
                        tokio::task::yield_now().await;
                        tokio::task::yield_now().await;
                        current_execution_policy().map(|p| p.tools)
                    },
                ));
                let beta = tokio::task::spawn_local(scope_ambient(
                    AmbientExecutionScope::default(),
                    async {
                        push_execution_policy(policy_named("beta"));
                        tokio::task::yield_now().await;
                        tokio::task::yield_now().await;
                        current_execution_policy().map(|p| p.tools)
                    },
                ));
                assert_eq!(alpha.await.unwrap(), Some(vec!["alpha".to_string()]));
                assert_eq!(beta.await.unwrap(), Some(vec!["beta".to_string()]));
            })
            .await;
        assert!(current_execution_policy().is_none());
    }

    /// A policy pushed inside a scoped task must not remain installed after it completes.
    #[tokio::test]
    async fn scope_is_restored_after_completion() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                tokio::task::spawn_local(scope_ambient(AmbientExecutionScope::default(), async {
                    push_execution_policy(policy_named("gamma"));
                    tokio::task::yield_now().await;
                }))
                .await
                .unwrap();
            })
            .await;
        assert!(current_execution_policy().is_none());
    }

    /// The outer state must be restored even when the scoped future panics mid-poll.
    /// This exercises `RestoreGuard`'s unwind path.
    #[tokio::test]
    async fn scope_is_restored_when_inner_future_panics() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let result = tokio::task::spawn_local(scope_ambient(
                    AmbientExecutionScope::default(),
                    async {
                        push_execution_policy(policy_named("delta"));
                        tokio::task::yield_now().await;
                        panic!("intentional panic inside scoped future");
                    },
                ))
                .await;
                assert!(matches!(result, Err(ref err) if err.is_panic()));
                assert!(current_execution_policy().is_none());
            })
            .await;
        assert!(current_execution_policy().is_none());
    }
}