use crate::budget::Budget;
use crate::error::AgentError;
use crate::event::{AgentEvent, ScopedAgentEvent, StreamScopeFrame};
use crate::hooks::{HookDecision, HookInvocation, HookPatch, HookPoint};
use crate::ops::{
ForkBranch, ForkBudgetPolicy, OperationId, OperationResult, SpawnSpec, ToolAccessPolicy,
};
use crate::retry::RetryPolicy;
use crate::service::TurnToolOverlay;
use crate::session::Session;
use crate::state::LoopState;
#[cfg(target_arch = "wasm32")]
use crate::tokio;
use crate::tool_scope::{
EXTERNAL_TOOL_FILTER_METADATA_KEY, ToolFilter, ToolScopeRevision, ToolScopeStageError,
};
use crate::types::{Message, RunResult};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::mpsc;
use super::{
Agent, AgentBuilder, AgentLlmClient, AgentSessionStore, AgentToolDispatcher,
FilteredToolDispatcher,
};
/// Spawns a background task that relays a child agent's events onto a scoped stream.
///
/// Each [`AgentEvent`] read from `child_event_rx` is wrapped in a
/// [`ScopedAgentEvent`] whose scope path is the parent's path followed by
/// `child_scope_frame`. If the parent path is empty, a `Primary` frame with
/// the session id `"unknown"` is used as the root instead.
///
/// The task ends when the child channel yields `None` or when a send on
/// `scoped_tx` fails.
fn spawn_scoped_forwarder(
    mut child_event_rx: mpsc::Receiver<AgentEvent>,
    scoped_tx: mpsc::Sender<ScopedAgentEvent>,
    parent_scope_path: Arc<Vec<StreamScopeFrame>>,
    child_scope_frame: StreamScopeFrame,
) -> tokio::task::JoinHandle<()> {
    // Start from the parent's frames, or from a placeholder root when there are none.
    let mut scope_path: Vec<StreamScopeFrame> = match parent_scope_path.as_slice() {
        [] => vec![StreamScopeFrame::Primary {
            session_id: "unknown".to_string(),
        }],
        frames => frames.to_vec(),
    };
    scope_path.push(child_scope_frame);

    tokio::spawn(async move {
        loop {
            let Some(event) = child_event_rx.recv().await else {
                break;
            };
            let delivery = scoped_tx
                .send(ScopedAgentEvent::new(scope_path.clone(), event))
                .await;
            if delivery.is_err() {
                break;
            }
        }
    })
}
/// Async interface for running an agent against a prompt.
///
/// The trait is expanded through `async_trait`; on `wasm32` targets the
/// `?Send` form is used, on all other targets the default form is used.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait AgentRunner: Send {
/// Runs the agent with `prompt` and returns the run result or an error.
async fn run(&mut self, prompt: String) -> Result<RunResult, AgentError>;
/// Runs the agent with `prompt`, additionally taking a sender (`tx`) for
/// [`AgentEvent`]s, and returns the run result or an error.
async fn run_with_events(
&mut self,
prompt: String,
tx: mpsc::Sender<AgentEvent>,
) -> Result<RunResult, AgentError>;
}
impl<C, T, S> Agent<C, T, S>
where
C: AgentLlmClient + ?Sized,
T: AgentToolDispatcher + ?Sized,
S: AgentSessionStore + ?Sized,
{
pub fn stage_external_tool_filter(
&mut self,
filter: ToolFilter,
) -> Result<ToolScopeRevision, ToolScopeStageError> {
let handle = self.tool_scope.handle();
let revision = handle.stage_external_filter(filter.clone())?;
let _ = handle.staged_revision();
if let Ok(value) = serde_json::to_value(filter) {
self.session
.set_metadata(EXTERNAL_TOOL_FILTER_METADATA_KEY, value);
}
Ok(revision)
}
/// Installs or clears the per-turn tool overlay on the tool scope.
///
/// With `None`, the current turn overlay is cleared. With `Some(overlay)`,
/// the overlay's `allowed_tools` (if present) and `blocked_tools` (empty
/// when absent) are converted into sets and passed to the tool scope.
///
/// # Errors
///
/// Propagates the error returned by the tool scope when setting the overlay.
pub fn set_flow_tool_overlay(
    &mut self,
    overlay: Option<TurnToolOverlay>,
) -> Result<(), ToolScopeStageError> {
    let handle = self.tool_scope.handle();

    let Some(overlay) = overlay else {
        handle.clear_turn_overlay();
        return Ok(());
    };

    let allow: Option<HashSet<_>> = overlay
        .allowed_tools
        .map(|tools| tools.into_iter().collect());
    let deny: HashSet<_> = overlay
        .blocked_tools
        .unwrap_or_default()
        .into_iter()
        .collect();

    handle.set_turn_overlay(allow, deny)?;
    Ok(())
}
/// Test-only hook: forwards to the tool scope's
/// `inject_boundary_failure_once_for_test`.
#[cfg(test)]
pub(crate) fn inject_tool_scope_boundary_failure_once_for_test(&self) {
self.tool_scope.inject_boundary_failure_once_for_test();
}
}
impl<C, T, S> Agent<C, T, S>
where
C: AgentLlmClient + ?Sized + 'static,
T: AgentToolDispatcher + ?Sized + 'static,
S: AgentSessionStore + ?Sized + 'static,
{
/// Returns a new [`AgentBuilder`] created with `AgentBuilder::new()`.
pub fn builder() -> AgentBuilder {
AgentBuilder::new()
}
/// Returns a shared reference to the agent's session.
pub fn session(&self) -> &Session {
&self.session
}
/// Returns a mutable reference to the agent's session.
pub fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
/// Returns a shared reference to the agent's budget.
pub fn budget(&self) -> &Budget {
&self.budget
}
/// Returns a shared reference to the agent's current loop state.
pub fn state(&self) -> &LoopState {
&self.state
}
/// Returns a shared reference to the agent's retry policy.
pub fn retry_policy(&self) -> &RetryPolicy {
&self.retry_policy
}
/// Returns the agent's depth value. `spawn` and `fork` compare
/// `depth + 1` against the sub-agent manager's `max_depth` limit.
pub fn depth(&self) -> u32 {
self.depth
}
/// Returns a shared reference to the agent's event tap.
pub fn event_tap(&self) -> &crate::event_tap::EventTap {
&self.event_tap
}
/// Returns another `Arc` handle to the shared, mutex-guarded system-context
/// state. The returned handle points at the same state the agent uses.
pub fn system_context_state(
&self,
) -> Arc<std::sync::Mutex<crate::session::SessionSystemContextState>> {
Arc::clone(&self.system_context_state)
}
/// Returns a clone of the session with the current system-context state written into it.
///
/// The agent's own session is not modified. A poisoned lock is recovered
/// from (with a warning) rather than treated as fatal, and a failure to
/// write the state into the cloned session is only logged.
pub fn session_with_system_context_state(&self) -> Session {
    let mut snapshot_session = self.session.clone();

    // Copy the state out so the lock is released before touching the session.
    let current_state = self
        .system_context_state
        .lock()
        .unwrap_or_else(|poisoned| {
            tracing::warn!("system-context state lock poisoned while cloning session");
            poisoned.into_inner()
        })
        .clone();

    if let Err(err) = snapshot_session.set_system_context_state(current_state) {
        tracing::warn!(error = %err, "failed to serialize system-context state into session");
    }
    snapshot_session
}
/// Writes the current system-context state into the agent's own session.
///
/// A poisoned lock is recovered from with a warning; a failure to write the
/// state into the session is logged and otherwise ignored.
pub(crate) fn sync_system_context_state_to_session(&mut self) {
    // Copy the state out so the lock is released before mutating the session.
    let current_state = self
        .system_context_state
        .lock()
        .unwrap_or_else(|poisoned| {
            tracing::warn!("system-context state lock poisoned while syncing session");
            poisoned.into_inner()
        })
        .clone();

    if let Err(err) = self.session.set_system_context_state(current_state) {
        tracing::warn!(error = %err, "failed to serialize system-context state into session");
    }
}
/// Applies any pending system-context blocks to the session.
///
/// While holding the state lock, the pending blocks are copied and marked
/// as applied. After the lock is released, the copied blocks are appended
/// to the session and the state is synced into the session.
///
/// Returns the number of blocks applied, or `0` when nothing was pending.
pub(crate) fn apply_pending_system_context_boundary(&mut self) -> usize {
    let newly_applied = {
        let mut guard = self.system_context_state.lock().unwrap_or_else(|poisoned| {
            tracing::warn!("system-context state lock poisoned while applying boundary");
            poisoned.into_inner()
        });
        if guard.pending.is_empty() {
            None
        } else {
            let blocks = guard.pending.clone();
            guard.mark_pending_applied();
            Some(blocks)
        }
    };

    let Some(blocks) = newly_applied else {
        return 0;
    };
    self.session.append_system_context_blocks(&blocks);
    self.sync_system_context_state_to_session();
    blocks.len()
}
pub(crate) async fn checkpoint_current_session(&mut self) {
self.sync_system_context_state_to_session();
if let Some(ref cp) = self.checkpointer {
cp.checkpoint(&self.session).await;
}
}
/// Launches a sub-agent described by `spec` in a background task and returns
/// its operation id.
///
/// Steps, in order:
/// 1. Reject the request if `depth + 1` exceeds the manager's `max_depth`,
///    or if the manager reports that no further agent can be spawned.
/// 2. Compute the tools the child may use from `spec.tool_access`, and
///    reject an `AllowList` that names a tool the parent does not have.
/// 3. Build the child session from the messages chosen by the manager's
///    context strategy, plus the optional system prompt from `spec`.
/// 4. Register the operation under the label `"spawn"`.
/// 5. Spawn a task that builds the child agent, runs `spec.prompt`, and
///    reports completion or failure to the sub-agent manager.
///
/// The spawned task's join handle is not kept, so the task runs detached.
///
/// # Errors
///
/// Returns `AgentError::DepthLimitExceeded`, `AgentError::SubAgentLimitExceeded`,
/// `AgentError::InvalidToolAccess`, or whatever error `register` returns.
pub async fn spawn(&self, spec: SpawnSpec) -> Result<OperationId, AgentError> {
if self.depth + 1 > self.sub_agent_manager.limits.max_depth {
return Err(AgentError::DepthLimitExceeded {
depth: self.depth + 1,
max: self.sub_agent_manager.limits.max_depth,
});
}
if !self.sub_agent_manager.can_spawn().await {
return Err(AgentError::SubAgentLimitExceeded {
limit: self.sub_agent_manager.limits.max_concurrent_agents,
});
}
let all_tools = self.tools.tools();
let allowed_tools = self
.sub_agent_manager
.apply_tool_access_policy(&all_tools, &spec.tool_access);
// An allow-list may only name tools the parent actually exposes.
if let ToolAccessPolicy::AllowList(ref names) = spec.tool_access {
for name in names {
if !all_tools.iter().any(|t| &t.name == name) {
return Err(AgentError::InvalidToolAccess { tool: name.clone() });
}
}
}
let messages = self
.sub_agent_manager
.apply_context_strategy(&self.session, &spec.context);
let mut sub_session = Session::new();
for msg in messages {
sub_session.push(msg);
}
if let Some(sys_prompt) = &spec.system_prompt {
sub_session.set_system_prompt(sys_prompt.clone());
}
let op_id = OperationId::new();
self.sub_agent_manager
.register(op_id.clone(), "spawn".to_string())
.await?;
// Everything below is cloned so it can be moved into the spawned task.
let client = self.client.clone();
let store = self.store.clone();
let prompt = spec.prompt.clone();
let budget = spec.budget.clone();
let sub_agent_manager = self.sub_agent_manager.clone();
let op_id_clone = op_id.clone();
// NOTE(review): `depth` is only used in the log line at the end of this
// function; nothing visible here passes it to the child builder — confirm
// whether the child's depth is set elsewhere.
let depth = self.depth + 1;
let model = self.config.model.clone();
let max_tokens = self.config.max_tokens_per_turn;
let parent_scoped_event_tx = self.default_scoped_event_tx.clone();
let parent_scope_path = self.default_scope_path.clone();
let allowed_tool_names: Vec<String> =
allowed_tools.iter().map(|t| t.name.clone()).collect();
let filtered_tools = Arc::new(FilteredToolDispatcher::new(
self.tools.clone(),
allowed_tool_names,
));
tokio::spawn(async move {
let start = crate::time_compat::Instant::now();
let mut sub_agent = AgentBuilder::new()
.model(&model)
.max_tokens_per_turn(max_tokens)
.budget(budget)
.resume_session(sub_session)
.build(client, filtered_tools, store)
.await;
// When the parent has a scoped event sender, run with events and forward
// them under a `SubAgent` scope frame; otherwise run without events.
let (result, forwarder_task) = if let Some(scoped_tx) = parent_scoped_event_tx {
let (child_event_tx, child_event_rx) = mpsc::channel::<AgentEvent>(64);
let child_scope_frame = StreamScopeFrame::SubAgent {
agent_id: op_id_clone.to_string(),
tool_call_id: None,
label: Some("spawn".to_string()),
};
let forwarder = spawn_scoped_forwarder(
child_event_rx,
scoped_tx,
Arc::new(parent_scope_path.clone()),
child_scope_frame,
);
(
sub_agent.run_with_events(prompt, child_event_tx).await,
Some(forwarder),
)
} else {
(sub_agent.run(prompt).await, None)
};
// Wait for the forwarder to finish before reporting the outcome.
if let Some(forwarder) = forwarder_task {
let _ = forwarder.await;
}
match result {
Ok(run_result) => {
sub_agent_manager
.complete(
&op_id_clone,
OperationResult {
id: op_id_clone.clone(),
content: run_result.text,
is_error: false,
duration_ms: start.elapsed().as_millis() as u64,
tokens_used: run_result.usage.total_tokens(),
},
)
.await;
}
Err(e) => {
sub_agent_manager.fail(&op_id_clone, e.to_string()).await;
}
}
});
tracing::info!(
"Spawned sub-agent {} at depth {} with {} tools",
op_id,
depth,
allowed_tools.len()
);
Ok(op_id)
}
/// Forks the current conversation into one background sub-agent per branch.
///
/// Each branch receives a copy of the parent's session messages, a budget
/// from the manager's fork-budget allocation, and a tool dispatcher filtered
/// by the branch's tool-access policy (all parent tools when the branch has
/// no policy). The spawned tasks run detached and report completion or
/// failure to the sub-agent manager; successful results are prefixed with
/// `[branch name]`.
///
/// Every branch's `AllowList` is validated before any branch is registered
/// or spawned, so an invalid tool name in a later branch cannot leave
/// earlier branches running with operation ids the caller never receives.
///
/// # Errors
///
/// Returns `AgentError::DepthLimitExceeded`, `AgentError::SubAgentLimitExceeded`,
/// `AgentError::InvalidToolAccess`, or whatever error `register` returns.
/// A `register` failure part-way through can still leave earlier branches
/// running.
pub async fn fork(
    &self,
    branches: Vec<ForkBranch>,
    budget_policy: ForkBudgetPolicy,
) -> Result<Vec<OperationId>, AgentError> {
    if self.depth + 1 > self.sub_agent_manager.limits.max_depth {
        return Err(AgentError::DepthLimitExceeded {
            depth: self.depth + 1,
            max: self.sub_agent_manager.limits.max_depth,
        });
    }
    let running = self.sub_agent_manager.running_ids().await.len();
    if running + branches.len() > self.sub_agent_manager.limits.max_concurrent_agents {
        return Err(AgentError::SubAgentLimitExceeded {
            limit: self.sub_agent_manager.limits.max_concurrent_agents,
        });
    }

    // Take one snapshot of the parent's tools for validation and filtering.
    let all_tools = self.tools.tools();

    // Validate every branch up front so a bad branch fails the whole fork
    // before anything has been registered or spawned.
    for branch in &branches {
        if let Some(ToolAccessPolicy::AllowList(names)) = &branch.tool_access {
            for name in names {
                if !all_tools.iter().any(|t| &t.name == name) {
                    return Err(AgentError::InvalidToolAccess { tool: name.clone() });
                }
            }
        }
    }

    let remaining_tokens = self.budget.remaining();
    let budgets = self.sub_agent_manager.allocate_fork_budget(
        remaining_tokens,
        branches.len(),
        &budget_policy,
    );

    let mut op_ids = Vec::with_capacity(branches.len());
    for (i, branch) in branches.into_iter().enumerate() {
        let op_id = OperationId::new();
        self.sub_agent_manager
            .register(op_id.clone(), branch.name.clone())
            .await?;

        let allowed_tools = match &branch.tool_access {
            Some(policy) => self
                .sub_agent_manager
                .apply_tool_access_policy(&all_tools, policy),
            None => all_tools.to_vec(),
        };
        let allowed_tool_names: Vec<String> =
            allowed_tools.iter().map(|t| t.name.clone()).collect();
        let filtered_tools = Arc::new(FilteredToolDispatcher::new(
            self.tools.clone(),
            allowed_tool_names,
        ));

        // Everything below is cloned so it can be moved into the spawned task.
        let client = self.client.clone();
        let store = self.store.clone();
        let prompt = branch.prompt.clone();
        // Assumes the allocation yields one budget per branch, in branch order.
        let budget = budgets[i].clone();
        let sub_agent_manager = self.sub_agent_manager.clone();
        let op_id_clone = op_id.clone();
        let model = self.config.model.clone();
        let max_tokens = self.config.max_tokens_per_turn;
        let branch_name = branch.name.clone();
        let parent_scoped_event_tx = self.default_scoped_event_tx.clone();
        let parent_scope_path = self.default_scope_path.clone();

        let mut fork_session = Session::new();
        for msg in self.session.messages() {
            fork_session.push(msg.clone());
        }

        tokio::spawn(async move {
            let start = crate::time_compat::Instant::now();
            let mut sub_agent = AgentBuilder::new()
                .model(&model)
                .max_tokens_per_turn(max_tokens)
                .budget(budget)
                .resume_session(fork_session)
                .build(client, filtered_tools, store)
                .await;

            // When the parent has a scoped event sender, run with events and
            // forward them under a `SubAgent` scope frame labelled with the
            // branch name; otherwise run without events.
            let (result, forwarder_task) = if let Some(scoped_tx) = parent_scoped_event_tx {
                let (child_event_tx, child_event_rx) = mpsc::channel::<AgentEvent>(64);
                let child_scope_frame = StreamScopeFrame::SubAgent {
                    agent_id: op_id_clone.to_string(),
                    tool_call_id: None,
                    label: Some(branch_name.clone()),
                };
                let forwarder = spawn_scoped_forwarder(
                    child_event_rx,
                    scoped_tx,
                    // The path is owned by this task and used once: no clone needed.
                    Arc::new(parent_scope_path),
                    child_scope_frame,
                );
                (
                    sub_agent.run_with_events(prompt, child_event_tx).await,
                    Some(forwarder),
                )
            } else {
                (sub_agent.run(prompt).await, None)
            };

            // Wait for the forwarder to finish before reporting the outcome.
            if let Some(forwarder) = forwarder_task {
                let _ = forwarder.await;
            }

            match result {
                Ok(run_result) => {
                    sub_agent_manager
                        .complete(
                            &op_id_clone,
                            OperationResult {
                                id: op_id_clone.clone(),
                                content: format!("[{}] {}", branch_name, run_result.text),
                                is_error: false,
                                duration_ms: start.elapsed().as_millis() as u64,
                                tokens_used: run_result.usage.total_tokens(),
                            },
                        )
                        .await;
                }
                Err(e) => {
                    sub_agent_manager.fail(&op_id_clone, e.to_string()).await;
                }
            }
        });

        tracing::info!(
            "Forked branch '{}' as {} at depth {}",
            branch.name,
            op_id,
            self.depth + 1
        );
        op_ids.push(op_id);
    }
    Ok(op_ids)
}
/// Asks the sub-agent manager to cancel the operation identified by `op_id`.
pub async fn cancel_sub_agent(&self, op_id: &OperationId) {
self.sub_agent_manager.cancel(op_id).await;
}
/// Returns the results the sub-agent manager hands back from
/// `collect_completed`.
pub async fn collect_sub_agent_results(&self) -> Vec<OperationResult> {
self.sub_agent_manager.collect_completed().await
}
/// Returns whether the sub-agent manager reports any running sub-agents.
pub async fn has_running_sub_agents(&self) -> bool {
self.sub_agent_manager.has_running().await
}
/// Executes the `RunStarted` hooks for `prompt`.
///
/// # Errors
///
/// Propagates any error from hook execution, and returns
/// `AgentError::HookDenied` when the hook report carries a `Deny` decision.
async fn run_started_hooks(
    &self,
    prompt: &str,
    event_tx: Option<&mpsc::Sender<AgentEvent>>,
) -> Result<(), AgentError> {
    let invocation = HookInvocation {
        point: HookPoint::RunStarted,
        session_id: self.session.id().clone(),
        turn_number: None,
        prompt: Some(prompt.to_string()),
        error: None,
        llm_request: None,
        llm_response: None,
        tool_call: None,
        tool_result: None,
    };
    let report = self.execute_hooks(invocation, event_tx).await?;

    match report.decision {
        Some(HookDecision::Deny {
            reason_code,
            message,
            payload,
            ..
        }) => Err(AgentError::HookDenied {
            point: HookPoint::RunStarted,
            reason_code,
            message,
            payload,
        }),
        _ => Ok(()),
    }
}
/// Executes the `RunCompleted` hooks and applies any result-text rewrites.
///
/// For every `HookPatch::RunResult` patch in the report, a
/// `HookRewriteApplied` event is emitted, `result.text` is replaced, and the
/// last assistant message in the session is rewritten to match. The session
/// is then saved; a save failure is logged and does not fail the call.
///
/// # Errors
///
/// Propagates any error from hook execution, and returns
/// `AgentError::HookDenied` when the hook report carries a `Deny` decision.
async fn run_completed_hooks(
    &mut self,
    result: &mut RunResult,
    event_tx: Option<&mpsc::Sender<AgentEvent>>,
) -> Result<(), AgentError> {
    let invocation = HookInvocation {
        point: HookPoint::RunCompleted,
        session_id: self.session.id().clone(),
        turn_number: Some(result.turns),
        prompt: None,
        error: None,
        llm_request: None,
        llm_response: None,
        tool_call: None,
        tool_result: None,
    };
    let report = self.execute_hooks(invocation, event_tx).await?;

    if let Some(HookDecision::Deny {
        reason_code,
        message,
        payload,
        ..
    }) = report.decision
    {
        return Err(AgentError::HookDenied {
            point: HookPoint::RunCompleted,
            reason_code,
            message,
            payload,
        });
    }

    for outcome in &report.outcomes {
        for patch in &outcome.patches {
            // Only result-text patches are handled at this hook point.
            let HookPatch::RunResult { text } = patch else {
                continue;
            };
            let rewrite_event = AgentEvent::HookRewriteApplied {
                hook_id: outcome.hook_id.to_string(),
                point: HookPoint::RunCompleted,
                patch: HookPatch::RunResult { text: text.clone() },
            };
            crate::event_tap::tap_emit(&self.event_tap, event_tx, rewrite_event).await;
            result.text.clone_from(text);
            self.apply_run_result_text_patch(text);
        }
    }

    let saved = self.store.save(&self.session).await;
    if let Err(err) = saved {
        tracing::warn!("Failed to save session after run_completed hooks: {}", err);
    }
    Ok(())
}
/// Rewrites the text of the most recent assistant message to `text`.
///
/// Messages are scanned from newest to oldest. The first `BlockAssistant`
/// or `Assistant` message found is rewritten and the session is touched.
/// If the session has no assistant message, nothing happens.
fn apply_run_result_text_patch(&mut self, text: &str) {
    use super::state::rewrite_assistant_text;

    let mut rewritten = false;
    for message in self.session.messages_mut().iter_mut().rev() {
        match message {
            Message::BlockAssistant(block_assistant) => {
                rewrite_assistant_text(&mut block_assistant.blocks, text.to_string());
                rewritten = true;
                break;
            }
            Message::Assistant(assistant) => {
                assistant.content = text.to_string();
                rewritten = true;
                break;
            }
            // Not an assistant message: keep looking further back.
            _ => {}
        }
    }

    if rewritten {
        self.session.touch();
    }
}
/// Executes the `RunFailed` hooks, passing the failure rendered as a string.
///
/// # Errors
///
/// Propagates any error from hook execution, and returns
/// `AgentError::HookDenied` when the hook report carries a `Deny` decision.
async fn run_failed_hooks(
    &self,
    error: &AgentError,
    event_tx: Option<&mpsc::Sender<AgentEvent>>,
) -> Result<(), AgentError> {
    let invocation = HookInvocation {
        point: HookPoint::RunFailed,
        session_id: self.session.id().clone(),
        turn_number: None,
        prompt: None,
        error: Some(error.to_string()),
        llm_request: None,
        llm_response: None,
        tool_call: None,
        tool_result: None,
    };
    let report = self.execute_hooks(invocation, event_tx).await?;

    match report.decision {
        Some(HookDecision::Deny {
            reason_code,
            message,
            payload,
            ..
        }) => Err(AgentError::HookDenied {
            point: HookPoint::RunFailed,
            reason_code,
            message,
            payload,
        }),
        _ => Ok(()),
    }
}
/// Runs the agent on `user_input` without an explicit event sender.
/// Delegates to `run_inner` with `None`.
pub async fn run(&mut self, user_input: String) -> Result<RunResult, AgentError> {
self.run_inner(user_input, None).await
}
/// Runs the agent on `user_input`, passing `event_tx` as the event sender.
/// Delegates to `run_inner` with `Some(event_tx)`.
pub async fn run_with_events(
&mut self,
user_input: String,
event_tx: mpsc::Sender<AgentEvent>,
) -> Result<RunResult, AgentError> {
self.run_inner(user_input, Some(event_tx)).await
}
/// Runs the agent on the user message already pending in the session,
/// without an explicit event sender. Delegates to `run_pending_inner` with
/// `None`.
pub async fn run_pending(&mut self) -> Result<RunResult, AgentError> {
self.run_pending_inner(None).await
}
/// Runs the agent on the user message already pending in the session,
/// passing `event_tx` as the event sender. Delegates to `run_pending_inner`
/// with `Some(event_tx)`.
pub async fn run_pending_with_events(
&mut self,
event_tx: mpsc::Sender<AgentEvent>,
) -> Result<RunResult, AgentError> {
self.run_pending_inner(Some(event_tx)).await
}
/// Shared implementation behind `run` and `run_with_events`.
///
/// Falls back to the agent's default event sender when `event_tx` is `None`,
/// sets the state to `CallingLlm`, applies any pending skill references to
/// the input, and pushes the resulting text as a user message. It then emits
/// `RunStarted` (when a sender is available), runs the started hooks, and
/// drives the run loop.
///
/// On success the completed hooks run and the result is returned. On failure
/// the failed hooks run (their own errors are only logged), `RunFailed` is
/// emitted when a sender is available, and the original error is returned.
async fn run_inner(
    &mut self,
    user_input: String,
    event_tx: Option<mpsc::Sender<AgentEvent>>,
) -> Result<RunResult, AgentError> {
    let event_tx = event_tx.or_else(|| self.default_event_tx.clone());
    self.state = LoopState::CallingLlm;

    let run_prompt = self.apply_skill_ref(user_input).await;
    self.session.push(Message::User(crate::types::UserMessage {
        content: run_prompt.clone(),
    }));

    if let Some(tx) = event_tx.as_ref() {
        let started = AgentEvent::RunStarted {
            session_id: self.session.id().clone(),
            prompt: run_prompt.clone(),
        };
        let _ = crate::event_tap::tap_emit(&self.event_tap, Some(tx), started).await;
    }
    self.run_started_hooks(&run_prompt, event_tx.as_ref())
        .await?;

    let err = match self.run_loop(event_tx.clone()).await {
        Ok(mut result) => {
            self.run_completed_hooks(&mut result, event_tx.as_ref())
                .await?;
            return Ok(result);
        }
        Err(err) => err,
    };

    // Failure path: a hook error must not mask the original run error.
    if let Err(hook_err) = self.run_failed_hooks(&err, event_tx.as_ref()).await {
        tracing::warn!(?hook_err, "run_failed hook execution failed");
    }
    if let Some(tx) = event_tx.as_ref() {
        let failed = AgentEvent::RunFailed {
            session_id: self.session.id().clone(),
            error: err.to_string(),
        };
        let _ = crate::event_tap::tap_emit(&self.event_tap, Some(tx), failed).await;
    }
    Err(err)
}
/// Shared implementation behind `run_pending` and `run_pending_with_events`.
///
/// Requires the last message in the session to be a user message; its
/// content is used as the run prompt and no new message is pushed. After
/// that the flow is: set the state to `CallingLlm`, emit `RunStarted` (when a
/// sender is available), run the started hooks, and drive the run loop.
///
/// On success the completed hooks run and the result is returned. On failure
/// the failed hooks run (their own errors are only logged), `RunFailed` is
/// emitted when a sender is available, and the original error is returned.
///
/// # Errors
///
/// Returns `AgentError::ConfigError` when the last session message is not a
/// user message, in addition to errors from hooks and the run loop.
async fn run_pending_inner(
    &mut self,
    event_tx: Option<mpsc::Sender<AgentEvent>>,
) -> Result<RunResult, AgentError> {
    let event_tx = event_tx.or_else(|| self.default_event_tx.clone());

    let prompt = match self.session.messages().last() {
        Some(Message::User(user)) => user.content.clone(),
        _ => {
            return Err(AgentError::ConfigError(
                "run_pending requires a pending user message in the session".to_string(),
            ));
        }
    };
    self.state = LoopState::CallingLlm;

    if let Some(tx) = event_tx.as_ref() {
        let started = AgentEvent::RunStarted {
            session_id: self.session.id().clone(),
            prompt: prompt.clone(),
        };
        let _ = crate::event_tap::tap_emit(&self.event_tap, Some(tx), started).await;
    }
    self.run_started_hooks(&prompt, event_tx.as_ref()).await?;

    let err = match self.run_loop(event_tx.clone()).await {
        Ok(mut result) => {
            self.run_completed_hooks(&mut result, event_tx.as_ref())
                .await?;
            return Ok(result);
        }
        Err(err) => err,
    };

    // Failure path: a hook error must not mask the original run error.
    if let Err(hook_err) = self.run_failed_hooks(&err, event_tx.as_ref()).await {
        tracing::warn!(?hook_err, "run_failed hook execution failed");
    }
    if let Some(tx) = event_tx.as_ref() {
        let failed = AgentEvent::RunFailed {
            session_id: self.session.id().clone(),
            error: err.to_string(),
        };
        let _ = crate::event_tap::tap_emit(&self.event_tap, Some(tx), failed).await;
    }
    Err(err)
}
/// Requests cancellation by transitioning the loop state to `Cancelling`.
///
/// Does nothing when the state is already terminal. The outcome of the
/// transition itself is ignored.
pub fn cancel(&mut self) {
    if self.state.is_terminal() {
        return;
    }
    let _ = self.state.transition(LoopState::Cancelling);
}
async fn apply_skill_ref(&mut self, user_input: String) -> String {
let engine = match &self.skill_engine {
Some(e) => e.clone(),
None => return user_input,
};
let mut prefix_parts: Vec<String> = Vec::new();
if let Some(refs) = self.pending_skill_references.take()
&& !refs.is_empty()
{
let canonical_ids: Vec<crate::skills::SkillId> = refs
.into_iter()
.map(|key| {
crate::skills::SkillId(format!("{}/{}", key.source_uuid, key.skill_name))
})
.collect();
match engine.resolve_and_render(&canonical_ids).await {
Ok(resolved) => {
for skill in &resolved {
tracing::info!(
skill_id = %skill.id.0,
"Per-turn skill activation via skill_references"
);
prefix_parts.push(skill.rendered_body.clone());
}
}
Err(e) => {
tracing::warn!(
error = %e,
"Failed to resolve source-pinned skill_references"
);
}
}
}
if prefix_parts.is_empty() {
return user_input;
}
if user_input.is_empty() {
prefix_parts.join("\n\n")
} else {
format!("{}\n\n{user_input}", prefix_parts.join("\n\n"))
}
}
}