use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use futures::pin_mut;
use tokio::sync::broadcast;
use crate::trust_graph::{policy_for_autonomy_tier, AutonomyTier};
use crate::value::VmValue;
use super::state::{ACTIVE_DISPATCH_CONTEXT, ACTIVE_DISPATCH_WAIT_LEASE};
use super::types::{
DispatchContext, DispatchError, DispatchExecutionPolicyGuard, DispatchWaitLease, Dispatcher,
};
use super::util::{
dispatch_cancel_requested, dispatch_error_from_vm_error, event_to_handler_value, recv_cancel,
split_binding_key,
};
use super::TriggerEvent;
impl Dispatcher {
    /// Invokes `closure` as the local handler for `event` and waits for it to finish.
    ///
    /// The event is converted with `event_to_handler_value` and passed as the single argument.
    /// The closure runs on a child VM of `self.base_vm`, under an execution policy. That policy
    /// is the current execution policy intersected with the policy for `autonomy_tier`, or just
    /// the tier policy when no policy is active.
    ///
    /// While the handler runs, these thread-local slots are replaced:
    /// - `ACTIVE_DISPATCH_CONTEXT` holds a `DispatchContext` describing this dispatch;
    /// - `ACTIVE_DISPATCH_WAIT_LEASE` holds `wait_lease`;
    /// - the HITL state is reset.
    ///
    /// The previous values are restored on every exit path. That includes early `?` returns,
    /// panics, and the returned future being dropped (e.g. by a timeout).
    ///
    /// A cancel token is installed on the VM and registered in `self.state.cancel_tokens` for
    /// the duration of the call. The token is set when:
    /// - the dispatcher is already shutting down;
    /// - `cancel_rx` yields a signal;
    /// - the event log reports a cancel request for this binding/event (checked every 100 ms).
    ///
    /// # Errors
    ///
    /// - Any error from `event_to_handler_value` or `dispatch_cancel_requested`.
    /// - `DispatchError::Local` if the policy intersection fails.
    /// - `DispatchError::Cancelled` if the cancel token was set, regardless of the handler's
    ///   own result.
    /// - Otherwise, the handler's VM error mapped through `dispatch_error_from_vm_error`.
    // The argument list mirrors the fields of `DispatchContext` plus the per-call
    // cancellation plumbing, so bundling it would only move the same fields elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub(super) async fn invoke_vm_callable(
        &self,
        closure: &crate::value::VmClosure,
        binding_key: &str,
        event: &TriggerEvent,
        replay_of_event_id: Option<&String>,
        agent_id: &str,
        action: &str,
        autonomy_tier: AutonomyTier,
        wait_lease: Option<DispatchWaitLease>,
        cancel_rx: &mut broadcast::Receiver<()>,
    ) -> Result<VmValue, DispatchError> {
        /// Runs the wrapped cleanup closure exactly once when dropped.
        ///
        /// Used so cleanup happens on early returns, panics, and future cancellation alike.
        struct OnDrop<F: FnOnce()>(Option<F>);

        impl<F: FnOnce()> Drop for OnDrop<F> {
            fn drop(&mut self) {
                if let Some(cleanup) = self.0.take() {
                    cleanup();
                }
            }
        }

        let mut vm = self.base_vm.child_vm();
        let cancel_token = Arc::new(std::sync::atomic::AtomicBool::new(false));
        if self.state.shutting_down.load(Ordering::SeqCst) {
            cancel_token.store(true, Ordering::SeqCst);
        }
        self.state
            .cancel_tokens
            .lock()
            .expect("dispatcher cancel tokens poisoned")
            .push(cancel_token.clone());
        // Deregister the token no matter how this function exits. Poisoning is recovered
        // rather than panicking, because a panic inside `Drop` during unwinding aborts.
        let token_registration = OnDrop(Some(|| {
            self.state
                .cancel_tokens
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .retain(|token| !Arc::ptr_eq(token, &cancel_token));
        }));
        vm.install_cancel_token(cancel_token.clone());

        let arg = event_to_handler_value(event)?;
        let args = [arg];
        let tier_policy = policy_for_autonomy_tier(autonomy_tier);
        let effective_policy = match crate::orchestration::current_execution_policy() {
            Some(parent) => parent
                .intersect(&tier_policy)
                .map_err(|error| DispatchError::Local(error.to_string()))?,
            None => tier_policy,
        };
        crate::orchestration::push_execution_policy(effective_policy);
        let _policy_guard = DispatchExecutionPolicyGuard;

        // Futures are lazy, so the handler does not start before the context below is installed.
        let future = vm.call_closure_pub(closure, &args);
        pin_mut!(future);

        let (binding_id, binding_version) = split_binding_key(binding_key);
        let prior_context = ACTIVE_DISPATCH_CONTEXT.with(|slot| {
            slot.borrow_mut().replace(DispatchContext {
                trigger_event: event.clone(),
                replay_of_event_id: replay_of_event_id.cloned(),
                binding_id,
                binding_version,
                agent_id: agent_id.to_string(),
                action: action.to_string(),
                autonomy_tier,
            })
        });
        let prior_wait_lease = ACTIVE_DISPATCH_WAIT_LEASE
            .with(|slot| std::mem::replace(&mut *slot.borrow_mut(), wait_lease));
        let prior_hitl_state = crate::stdlib::hitl::take_hitl_state();
        crate::stdlib::hitl::reset_hitl_state();
        // Put back the caller's ambient dispatch state on every exit path. Declared after
        // `future`, so on an abnormal drop it runs before the handler future is dropped. That
        // matches the normal-path order.
        let ambient_restore = OnDrop(Some(move || {
            ACTIVE_DISPATCH_CONTEXT.with(|slot| {
                *slot.borrow_mut() = prior_context;
            });
            ACTIVE_DISPATCH_WAIT_LEASE.with(|slot| {
                *slot.borrow_mut() = prior_wait_lease;
            });
            crate::stdlib::hitl::restore_hitl_state(prior_hitl_state);
        }));

        let mut poll = tokio::time::interval(Duration::from_millis(100));
        let result = loop {
            tokio::select! {
                result = &mut future => break result,
                _ = recv_cancel(cancel_rx) => {
                    cancel_token.store(true, Ordering::SeqCst);
                }
                _ = poll.tick() => {
                    if dispatch_cancel_requested(
                        &self.event_log,
                        binding_key,
                        &event.id.0,
                        replay_of_event_id,
                    )
                    .await? {
                        cancel_token.store(true, Ordering::SeqCst);
                    }
                }
            }
        };

        // Normal completion: restore ambient state, then deregister the token, before the
        // final cancellation lookup below (same ordering as the explicit cleanup it replaces).
        drop(ambient_restore);
        drop(token_registration);

        if cancel_token.load(Ordering::SeqCst) {
            // Distinguish an event-log cancel request from other causes
            // (shutdown flag or `cancel_rx` signal).
            if dispatch_cancel_requested(
                &self.event_log,
                binding_key,
                &event.id.0,
                replay_of_event_id,
            )
            .await?
            {
                Err(DispatchError::Cancelled(
                    "trigger cancel request cancelled local handler".to_string(),
                ))
            } else {
                Err(DispatchError::Cancelled(
                    "dispatcher shutdown cancelled local handler".to_string(),
                ))
            }
        } else {
            result.map_err(dispatch_error_from_vm_error)
        }
    }

    /// Runs [`Self::invoke_vm_callable`] without a wait lease, optionally bounded by `timeout`.
    ///
    /// When `timeout` is `Some` and elapses first, the in-flight invocation is dropped and
    /// `DispatchError::Local("predicate evaluation timed out")` is returned. When `timeout` is
    /// `None`, the invocation runs to completion.
    ///
    /// NOTE(review): the timeout message assumes this is only used for predicate evaluation;
    /// confirm against callers.
    // Forwards the same argument list as `invoke_vm_callable`, plus the timeout.
    #[allow(clippy::too_many_arguments)]
    pub(super) async fn invoke_vm_callable_with_timeout(
        &self,
        closure: &crate::value::VmClosure,
        binding_key: &str,
        event: &TriggerEvent,
        replay_of_event_id: Option<&String>,
        agent_id: &str,
        action: &str,
        autonomy_tier: AutonomyTier,
        cancel_rx: &mut broadcast::Receiver<()>,
        timeout: Option<Duration>,
    ) -> Result<VmValue, DispatchError> {
        let future = self.invoke_vm_callable(
            closure,
            binding_key,
            event,
            replay_of_event_id,
            agent_id,
            action,
            autonomy_tier,
            None,
            cancel_rx,
        );
        pin_mut!(future);
        if let Some(timeout) = timeout {
            match tokio::time::timeout(timeout, future).await {
                Ok(result) => result,
                Err(_) => Err(DispatchError::Local(
                    "predicate evaluation timed out".to_string(),
                )),
            }
        } else {
            future.await
        }
    }
}