use crate::error::AgentError;
use crate::event::{AgentEvent, BudgetType, ToolConfigChangeOperation, ToolConfigChangedPayload};
use crate::hooks::{
HookDecision, HookInvocation, HookLlmRequest, HookLlmResponse, HookPatch, HookPoint,
HookToolCall, HookToolResult,
};
use crate::state::LoopState;
#[cfg(target_arch = "wasm32")]
use crate::tokio;
use crate::types::{
AssistantBlock, BlockAssistantMessage, Message, RunResult, ToolCallView, ToolDef, ToolResult,
UserMessage,
};
use serde_json::Value;
use serde_json::value::RawValue;
use std::sync::Arc;
use tokio::sync::mpsc;
use super::{Agent, AgentLlmClient, AgentSessionStore, AgentToolDispatcher, LlmStreamResult};
impl<C, T, S> Agent<C, T, S>
where
C: AgentLlmClient + ?Sized + 'static,
T: AgentToolDispatcher + ?Sized + 'static,
S: AgentSessionStore + ?Sized + 'static,
{
/// Streams one LLM response, retrying failures the error marks as recoverable.
///
/// The first attempt is made immediately. Before every retry the task sleeps
/// for `retry_policy.delay_for_attempt(n)`, where `n` is the number of retries
/// already performed. An error is retried only when `is_recoverable()` is true
/// and `retry_policy.should_retry(n)` still allows it; otherwise it is returned
/// to the caller unchanged.
async fn call_llm_with_retry(
    &self,
    messages: &[Message],
    tools: &[Arc<ToolDef>],
    max_tokens: u32,
    temperature: Option<f32>,
    provider_params: Option<&Value>,
) -> Result<LlmStreamResult, AgentError> {
    let mut retries_done: u32 = 0;
    loop {
        if retries_done != 0 {
            tokio::time::sleep(self.retry_policy.delay_for_attempt(retries_done)).await;
        }
        let response = self
            .client
            .stream_response(messages, tools, max_tokens, temperature, provider_params)
            .await;
        let failure = match response {
            Ok(stream) => return Ok(stream),
            Err(failure) => failure,
        };
        // Short-circuit keeps `should_retry` from being consulted for fatal errors.
        let retryable = failure.is_recoverable() && self.retry_policy.should_retry(retries_done);
        if !retryable {
            return Err(failure);
        }
        tracing::warn!(
            "LLM call failed (attempt {}), retrying: {}",
            retries_done + 1,
            failure
        );
        retries_done += 1;
    }
}
/// Performs end-of-turn housekeeping.
///
/// Runs the `TurnBoundary` hooks first; a `Deny` decision aborts with
/// `AgentError::HookDenied`. Otherwise it drains the comms inbox and appends
/// any collected sub-agent results to the session as a single
/// `Message::ToolResults` entry (nothing is appended when there are none).
async fn drain_turn_boundary(
    &mut self,
    turn_count: u32,
    event_tx: Option<&mpsc::Sender<AgentEvent>>,
) -> Result<(), AgentError> {
    let invocation = HookInvocation {
        point: HookPoint::TurnBoundary,
        session_id: self.session.id().clone(),
        turn_number: Some(turn_count),
        prompt: None,
        error: None,
        llm_request: None,
        llm_response: None,
        tool_call: None,
        tool_result: None,
    };
    let report = self.execute_hooks(invocation, event_tx).await?;
    if let Some(HookDecision::Deny {
        reason_code,
        message,
        payload,
        ..
    }) = report.decision
    {
        return Err(AgentError::HookDenied {
            point: HookPoint::TurnBoundary,
            reason_code,
            message,
            payload,
        });
    }

    self.drain_comms_inbox().await;

    let finished = self.collect_sub_agent_results().await;
    if finished.is_empty() {
        return Ok(());
    }
    let results = finished
        .into_iter()
        .map(|done| ToolResult::new(done.id.to_string(), done.content, done.is_error))
        .collect::<Vec<ToolResult>>();
    self.session.push(Message::ToolResults { results });
    Ok(())
}
#[allow(unused_assignments)]
pub(super) async fn run_loop(
&mut self,
event_tx: Option<mpsc::Sender<AgentEvent>>,
) -> Result<RunResult, AgentError> {
let mut turn_count = 0u32;
let max_turns = self.config.max_turns.unwrap_or(100);
let mut tool_call_count = 0u32;
let mut event_stream_open = true;
macro_rules! emit_event {
($event:expr) => {
{
let event = $event;
crate::event_tap::tap_try_send(&self.event_tap, &event);
if event_stream_open {
if let Some(ref tx) = event_tx {
if tx.send(event).await.is_err() {
event_stream_open = false;
tracing::warn!(
"agent event stream receiver dropped; continuing without streaming events"
);
}
}
}
}
};
}
loop {
if self.state == LoopState::CallingLlm {
self.drain_comms_inbox().await;
}
if turn_count >= max_turns {
self.state.transition(LoopState::Completed)?;
return Ok(self.build_result(turn_count, tool_call_count).await);
}
if self.budget.is_exhausted() {
emit_event!(AgentEvent::BudgetWarning {
budget_type: BudgetType::Tokens,
used: self.session.total_tokens(),
limit: self.budget.remaining(),
percent: 1.0,
});
self.state.transition(LoopState::Completed)?;
return Ok(self.build_result(turn_count, tool_call_count).await);
}
if self.state == LoopState::CallingLlm
&& let Some(ref compactor) = self.compactor
{
let ctx = crate::agent::compact::build_compaction_context(
self.session.messages(),
self.last_input_tokens,
self.last_compaction_turn,
turn_count,
);
if compactor.should_compact(&ctx) {
let outcome = crate::agent::compact::run_compaction(
self.client.as_ref(),
compactor,
self.session.messages(),
self.last_input_tokens,
turn_count,
&event_tx,
&self.event_tap,
)
.await;
if let Ok(outcome) = outcome {
*self.session.messages_mut() = outcome.new_messages;
self.session.record_usage(outcome.summary_usage.clone());
self.budget.record_usage(&outcome.summary_usage);
self.last_input_tokens = 0;
self.last_compaction_turn = Some(turn_count);
if let Some(ref memory_store) = self.memory_store {
let store = Arc::clone(memory_store);
let session_id = self.session.id().clone();
let discarded = outcome.discarded;
tokio::spawn(async move {
for message in &discarded {
let content = message.as_indexable_text();
if !content.is_empty() {
let metadata = crate::memory::MemoryMetadata {
session_id: session_id.clone(),
turn: Some(turn_count),
indexed_at: crate::time_compat::SystemTime::now(),
};
if let Err(e) = store.index(&content, metadata).await {
tracing::warn!(
"failed to index compaction discard into memory: {e}"
);
}
}
}
});
}
}
}
}
match self.state {
LoopState::CallingLlm => {
let ext = self.tools.poll_external_updates().await;
for notice in &ext.notices {
emit_event!(AgentEvent::ToolConfigChanged {
payload: ToolConfigChangedPayload {
operation: notice.operation.clone(),
target: notice.server.clone(),
status: notice.status.clone(),
persisted: false,
applied_at_turn: Some(turn_count),
},
});
}
const MCP_PENDING_PREFIX: &str = "[SYSTEM NOTICE][MCP_PENDING] ";
self.session.messages_mut().retain(
|m| !matches!(m, Message::User(u) if u.text_content().starts_with(MCP_PENDING_PREFIX)),
);
if !ext.pending.is_empty() {
self.session.push(Message::User(UserMessage::text(format!(
"{MCP_PENDING_PREFIX}Servers connecting: {}. \
Tools will appear when ready.",
ext.pending.join(", ")
))));
}
let tool_defs = {
let dispatcher_tools = self.tools.tools();
match self.tool_scope.apply_staged(dispatcher_tools.clone()) {
Ok(applied) => {
if applied.changed() {
let status = format!(
"boundary_applied(base_changed={},visible_changed={},revision={})",
applied.base_changed(),
applied.visible_changed(),
applied.applied_revision.0
);
emit_event!(AgentEvent::ToolConfigChanged {
payload: ToolConfigChangedPayload {
operation: ToolConfigChangeOperation::Reload,
target: "tool_scope".to_string(),
status: status.clone(),
persisted: false,
applied_at_turn: Some(turn_count),
},
});
self.session.push(Message::User(UserMessage::text(format!(
"[SYSTEM NOTICE][TOOL_SCOPE] Tool configuration changed at turn boundary: {status}"
))));
}
applied.tools
}
Err(err) => {
let status = format!("warning_fallback_all({err})");
tracing::warn!(
error = %err,
"tool scope boundary apply failed; falling back to full dispatcher tools"
);
emit_event!(AgentEvent::ToolConfigChanged {
payload: ToolConfigChangedPayload {
operation: ToolConfigChangeOperation::Reload,
target: "tool_scope".to_string(),
status: status.clone(),
persisted: false,
applied_at_turn: Some(turn_count),
},
});
self.session.push(Message::User(UserMessage::text(format!(
"[SYSTEM NOTICE][TOOL_SCOPE][WARNING] Tool scope apply failed ({err}); falling back to full tool set."
))));
dispatcher_tools
}
}
};
emit_event!(AgentEvent::TurnStarted {
turn_number: turn_count,
});
let mut effective_max_tokens = self.config.max_tokens_per_turn;
let mut effective_temperature = self.config.temperature;
let mut effective_provider_params = self.config.provider_params.clone();
let pre_llm_report = self
.execute_hooks(
HookInvocation {
point: HookPoint::PreLlmRequest,
session_id: self.session.id().clone(),
turn_number: Some(turn_count),
prompt: None,
error: None,
llm_request: Some(HookLlmRequest {
max_tokens: effective_max_tokens,
temperature: effective_temperature,
provider_params: effective_provider_params.clone(),
message_count: self.session.messages().len(),
}),
llm_response: None,
tool_call: None,
tool_result: None,
},
event_tx.as_ref(),
)
.await?;
if let Some(HookDecision::Deny {
reason_code,
message,
payload,
..
}) = pre_llm_report.decision
{
return Err(AgentError::HookDenied {
point: HookPoint::PreLlmRequest,
reason_code,
message,
payload,
});
}
for outcome in &pre_llm_report.outcomes {
for patch in &outcome.patches {
if let HookPatch::LlmRequest {
max_tokens,
temperature,
provider_params,
} = patch
{
emit_event!(AgentEvent::HookRewriteApplied {
hook_id: outcome.hook_id.to_string(),
point: HookPoint::PreLlmRequest,
patch: HookPatch::LlmRequest {
max_tokens: *max_tokens,
temperature: *temperature,
provider_params: provider_params.clone(),
},
});
if let Some(value) = max_tokens {
effective_max_tokens = *value;
}
if temperature.is_some() {
effective_temperature = *temperature;
}
if provider_params.is_some() {
effective_provider_params = provider_params.clone();
}
}
}
}
if self.extraction_mode {
effective_temperature = Some(0.0_f32);
let mut params =
effective_provider_params.unwrap_or_else(|| serde_json::json!({}));
if let Some(output_schema) = &self.config.output_schema
&& let Some(obj) = params.as_object_mut()
{
obj.insert("structured_output".to_string(), output_schema.to_value());
}
effective_provider_params = Some(params);
}
let empty_tools: Arc<[Arc<crate::types::ToolDef>]> = Arc::from([]);
let call_tool_defs = if self.extraction_mode {
&empty_tools
} else {
&tool_defs
};
let boundary_system_context = self.take_pending_system_context_boundary();
let request_messages =
self.llm_messages_with_runtime_system_context(&boundary_system_context);
let result = self
.call_llm_with_retry(
&request_messages,
call_tool_defs,
effective_max_tokens,
effective_temperature,
effective_provider_params.as_ref(),
)
.await?;
self.budget.record_usage(&result.usage);
self.last_input_tokens = result.usage.input_tokens;
self.session.record_usage(result.usage.clone());
let (blocks, stop_reason, usage) = result.into_parts();
let mut assistant_msg = BlockAssistantMessage {
blocks,
stop_reason,
};
let mut assistant_text = assistant_msg.to_string();
let post_llm_report = self
.execute_hooks(
HookInvocation {
point: HookPoint::PostLlmResponse,
session_id: self.session.id().clone(),
turn_number: Some(turn_count),
prompt: None,
error: None,
llm_request: None,
llm_response: Some(HookLlmResponse {
assistant_text: assistant_text.clone(),
tool_call_names: assistant_msg
.tool_calls()
.map(|call| call.name.to_string())
.collect(),
stop_reason: Some(stop_reason),
usage: Some(usage.clone()),
}),
tool_call: None,
tool_result: None,
},
event_tx.as_ref(),
)
.await?;
if let Some(HookDecision::Deny {
reason_code,
message,
payload,
..
}) = post_llm_report.decision
{
return Err(AgentError::HookDenied {
point: HookPoint::PostLlmResponse,
reason_code,
message,
payload,
});
}
for outcome in &post_llm_report.outcomes {
for patch in &outcome.patches {
if let HookPatch::AssistantText { text } = patch {
emit_event!(AgentEvent::HookRewriteApplied {
hook_id: outcome.hook_id.to_string(),
point: HookPoint::PostLlmResponse,
patch: HookPatch::AssistantText { text: text.clone() },
});
rewrite_assistant_text(&mut assistant_msg.blocks, text.clone());
assistant_text = assistant_msg.to_string();
}
}
}
if !assistant_text.is_empty() {
emit_event!(AgentEvent::TextComplete {
content: assistant_text.clone(),
});
}
if assistant_msg.has_tool_calls() {
self.session
.push(Message::BlockAssistant(assistant_msg.clone()));
for tc in assistant_msg.tool_calls() {
let args_value: Value = serde_json::from_str(tc.args.get())
.unwrap_or_else(|_| Value::String(tc.args.get().to_string()));
emit_event!(AgentEvent::ToolCallRequested {
id: tc.id.to_string(),
name: tc.name.to_string(),
args: args_value,
});
}
self.state.transition(LoopState::WaitingForOps)?;
let tool_calls: Vec<ToolCallOwned> = assistant_msg
.tool_calls()
.map(ToolCallOwned::from_view)
.collect();
let tools_ref = Arc::clone(&self.tools);
let mut executable_tool_calls = Vec::new();
let mut tool_results = Vec::with_capacity(tool_calls.len());
let pre_tool_reports =
futures::future::join_all(tool_calls.iter().map(|tc| {
let args_value: Value = serde_json::from_str(tc.args.get())
.unwrap_or_else(|_| Value::String(tc.args.get().to_string()));
self.execute_hooks(
HookInvocation {
point: HookPoint::PreToolExecution,
session_id: self.session.id().clone(),
turn_number: Some(turn_count),
prompt: None,
error: None,
llm_request: None,
llm_response: None,
tool_call: Some(HookToolCall {
tool_use_id: tc.id.clone(),
name: tc.name.clone(),
args: args_value,
}),
tool_result: None,
},
event_tx.as_ref(),
)
}))
.await;
for (mut tc, pre_tool_report) in
tool_calls.into_iter().zip(pre_tool_reports.into_iter())
{
let pre_tool_report = pre_tool_report?;
if let Some(HookDecision::Deny {
reason_code,
message,
payload,
..
}) = pre_tool_report.decision
{
let denied_payload = serde_json::json!({
"error": "hook_denied",
"reason_code": serde_json::to_value(reason_code).unwrap_or_else(|_| Value::String("runtime_error".to_string())),
"message": message,
"payload": payload,
});
let denied_content = serde_json::to_string(&denied_payload)
.unwrap_or_else(|_| {
"{\"error\":\"hook_denied\",\"message\":\"denied by hook\"}"
.to_string()
});
tool_results.push(ToolResult::new(
tc.id.clone(),
denied_content,
true,
));
emit_event!(AgentEvent::ToolExecutionCompleted {
id: tc.id.clone(),
name: tc.name.clone(),
result: tool_results
.last()
.map(ToolResult::text_content)
.unwrap_or_default(),
is_error: true,
duration_ms: 0,
has_images: false,
});
emit_event!(AgentEvent::ToolResultReceived {
id: tc.id.clone(),
name: tc.name.clone(),
is_error: true,
});
self.budget.record_tool_call();
tool_call_count += 1;
continue;
}
for outcome in &pre_tool_report.outcomes {
for patch in &outcome.patches {
if let HookPatch::ToolArgs { args } = patch {
emit_event!(AgentEvent::HookRewriteApplied {
hook_id: outcome.hook_id.to_string(),
point: HookPoint::PreToolExecution,
patch: HookPatch::ToolArgs { args: args.clone() },
});
tc.set_args(args.clone());
}
}
}
emit_event!(AgentEvent::ToolExecutionStarted {
id: tc.id.clone(),
name: tc.name.clone(),
});
executable_tool_calls.push(tc);
}
let visible_tool_names: std::collections::HashSet<String> =
tool_defs.iter().map(|t| t.name.clone()).collect();
let dispatch_futures: Vec<_> = executable_tool_calls
.into_iter()
.map(|tc| {
let tools_ref = Arc::clone(&tools_ref);
let visible = visible_tool_names.contains(&tc.name);
async move {
let start = crate::time_compat::Instant::now();
let dispatch_result = if visible {
tools_ref.dispatch(tc.as_view()).await
} else {
Err(crate::error::ToolError::NotFound {
name: tc.name.clone(),
})
};
let duration_ms = start.elapsed().as_millis() as u64;
(tc, dispatch_result, duration_ms)
}
})
.collect();
let dispatch_results = futures::future::join_all(dispatch_futures).await;
for (tc, dispatch_result, duration_ms) in dispatch_results {
let mut tool_result = match dispatch_result {
Ok(result) => result,
Err(crate::error::ToolError::CallbackPending {
tool_name: callback_tool,
args: callback_args,
}) => {
let mut merged_args =
callback_args.as_object().cloned().unwrap_or_default();
merged_args.insert(
"tool_use_id".to_string(),
Value::String(tc.id.clone()),
);
return Err(AgentError::CallbackPending {
tool_name: callback_tool,
args: Value::Object(merged_args),
});
}
Err(e) => {
let payload = e.to_error_payload();
let serialized = serde_json::to_string(&payload)
.unwrap_or_else(|_| {
"{\"error\":\"tool_error\",\"message\":\"tool error\"}"
.to_string()
});
ToolResult::new(tc.id.clone(), serialized, true)
}
};
if tool_result.tool_use_id.is_empty() {
tool_result.tool_use_id = tc.id.clone();
}
let post_tool_report = self
.execute_hooks(
HookInvocation {
point: HookPoint::PostToolExecution,
session_id: self.session.id().clone(),
turn_number: Some(turn_count),
prompt: None,
error: None,
llm_request: None,
llm_response: None,
tool_call: None,
tool_result: Some(HookToolResult {
tool_use_id: tc.id.clone(),
name: tc.name.clone(),
content: tool_result.text_content(),
is_error: tool_result.is_error,
has_images: tool_result.has_images(),
}),
},
event_tx.as_ref(),
)
.await?;
if let Some(HookDecision::Deny {
reason_code,
message,
payload,
..
}) = post_tool_report.decision
{
let denied_payload = serde_json::json!({
"error": "hook_denied",
"reason_code": serde_json::to_value(reason_code).unwrap_or_else(|_| Value::String("runtime_error".to_string())),
"message": message,
"payload": payload,
});
tool_result.set_text_content(
serde_json::to_string(&denied_payload).unwrap_or_else(|_| {
"{\"error\":\"hook_denied\",\"message\":\"denied by hook\"}"
.to_string()
}),
);
tool_result.is_error = true;
}
for outcome in &post_tool_report.outcomes {
for patch in &outcome.patches {
if let HookPatch::ToolResult { content, is_error } = patch {
emit_event!(AgentEvent::HookRewriteApplied {
hook_id: outcome.hook_id.to_string(),
point: HookPoint::PostToolExecution,
patch: HookPatch::ToolResult {
content: content.clone(),
is_error: *is_error,
},
});
crate::hooks::apply_tool_result_patch(
&mut tool_result,
content.clone(),
*is_error,
);
}
}
}
emit_event!(AgentEvent::ToolExecutionCompleted {
id: tc.id.clone(),
name: tc.name.clone(),
result: tool_result.text_content(),
is_error: tool_result.is_error,
duration_ms,
has_images: tool_result.has_images(),
});
emit_event!(AgentEvent::ToolResultReceived {
id: tc.id.clone(),
name: tc.name.clone(),
is_error: tool_result.is_error,
});
tool_results.push(tool_result);
self.budget.record_tool_call();
tool_call_count += 1;
}
self.session.push(Message::ToolResults {
results: tool_results,
});
self.state.transition(LoopState::DrainingEvents)?;
self.drain_turn_boundary(turn_count, event_tx.as_ref())
.await?;
self.state.transition(LoopState::CallingLlm)?;
turn_count += 1;
} else if self.extraction_mode {
self.session.push(Message::BlockAssistant(assistant_msg));
self.state.transition(LoopState::DrainingEvents)?;
self.drain_turn_boundary(turn_count, event_tx.as_ref())
.await?;
emit_event!(AgentEvent::TurnCompleted { stop_reason, usage });
let content = assistant_text.trim();
let json_content = super::extraction::strip_code_fences(content);
match serde_json::from_str::<serde_json::Value>(json_content) {
Ok(parsed) => {
let output_schema =
self.config.output_schema.as_ref().ok_or_else(|| {
AgentError::InternalError(
"extraction_mode without output_schema".into(),
)
})?;
let normalized = super::extraction::unwrap_named_object_wrapper(
parsed,
output_schema,
);
let validation_error: Option<String>;
#[cfg(feature = "jsonschema")]
{
let compiled =
self.client.compile_schema(output_schema).map_err(|e| {
AgentError::InvalidOutputSchema(e.to_string())
})?;
let validator = jsonschema::Validator::new(&compiled.schema)
.map_err(|e| {
AgentError::InvalidOutputSchema(e.to_string())
})?;
validation_error =
if let Err(error) = validator.validate(&normalized) {
Some(format!("Schema validation failed: {error}"))
} else {
None
};
}
#[cfg(not(feature = "jsonschema"))]
{
tracing::warn!(
"Structured output schema validation unavailable \
(jsonschema feature disabled). Accepting parsed \
JSON without schema validation."
);
validation_error = None;
}
if let Some(error) = validation_error {
self.extraction_attempts += 1;
if self.extraction_attempts
< self.config.structured_output_retries + 1
{
self.extraction_last_error = Some(error.clone());
let retry_prompt = format!(
"The previous output was invalid: {error}. \
Please provide valid JSON matching the schema. \
Output ONLY the JSON, no additional text."
);
self.session
.push(Message::User(UserMessage::text(retry_prompt)));
self.state.transition(LoopState::CallingLlm)?;
turn_count += 1;
continue;
}
self.state.transition(LoopState::Completed)?;
if let Err(e) = self.store.save(&self.session).await {
tracing::warn!("Failed to save session: {}", e);
}
return Err(AgentError::StructuredOutputValidationFailed {
attempts: self.config.structured_output_retries + 1,
reason: error,
last_output: self
.session
.last_assistant_text()
.unwrap_or_default(),
});
}
self.extraction_result = Some(normalized);
self.state.transition(LoopState::Completed)?;
if let Err(e) = self.store.save(&self.session).await {
tracing::warn!("Failed to save session: {}", e);
}
return Ok(RunResult {
text: self.session.last_assistant_text().unwrap_or_default(),
session_id: self.session.id().clone(),
usage: self.session.total_usage(),
turns: turn_count + 1,
tool_calls: tool_call_count,
structured_output: self.extraction_result.take(),
schema_warnings: self.extraction_schema_warnings.take(),
skill_diagnostics: None,
});
}
Err(e) => {
let error = format!("Invalid JSON: {e}");
self.extraction_attempts += 1;
if self.extraction_attempts
< self.config.structured_output_retries + 1
{
self.extraction_last_error = Some(error);
let retry_prompt = format!(
"The previous output was invalid: Invalid JSON: {e}. \
Please provide valid JSON matching the schema. \
Output ONLY the JSON, no additional text."
);
self.session
.push(Message::User(UserMessage::text(retry_prompt)));
self.state.transition(LoopState::CallingLlm)?;
turn_count += 1;
continue;
}
self.state.transition(LoopState::Completed)?;
if let Err(e) = self.store.save(&self.session).await {
tracing::warn!("Failed to save session: {}", e);
}
return Err(AgentError::StructuredOutputValidationFailed {
attempts: self.config.structured_output_retries + 1,
reason: error,
last_output: self
.session
.last_assistant_text()
.unwrap_or_default(),
});
}
}
} else {
let final_text = assistant_text.clone();
self.session.push(Message::BlockAssistant(assistant_msg));
self.state.transition(LoopState::DrainingEvents)?;
self.drain_turn_boundary(turn_count, event_tx.as_ref())
.await?;
emit_event!(AgentEvent::TurnCompleted { stop_reason, usage });
if let Some(output_schema) = self.config.output_schema.as_ref()
&& !self.extraction_mode
{
self.extraction_mode = true;
self.extraction_attempts = 0;
self.extraction_result = None;
self.extraction_last_error = None;
let compiled = self
.client
.compile_schema(output_schema)
.map_err(|e| AgentError::InvalidOutputSchema(e.to_string()))?;
self.extraction_schema_warnings = if compiled.warnings.is_empty() {
None
} else {
Some(compiled.warnings.clone())
};
let prompt =
self.config.extraction_prompt.clone().unwrap_or_else(|| {
super::extraction::DEFAULT_EXTRACTION_PROMPT.to_string()
});
self.session.push(Message::User(UserMessage::text(prompt)));
self.state.transition(LoopState::CallingLlm)?;
turn_count += 1;
continue;
}
self.state.transition(LoopState::Completed)?;
if let Err(e) = self.store.save(&self.session).await {
tracing::warn!("Failed to save session: {}", e);
}
return Ok(RunResult {
text: final_text,
session_id: self.session.id().clone(),
usage: self.session.total_usage(),
turns: turn_count + 1,
tool_calls: tool_call_count,
structured_output: None,
schema_warnings: None,
skill_diagnostics: self.collect_skill_diagnostics().await,
});
}
}
LoopState::WaitingForOps => {
unreachable!("WaitingForOps handled inline");
}
LoopState::DrainingEvents => {
self.state.transition(LoopState::Completed)?;
}
LoopState::Cancelling => {
self.state.transition(LoopState::Completed)?;
return Ok(self.build_result(turn_count, tool_call_count).await);
}
LoopState::ErrorRecovery => {
self.state.transition(LoopState::CallingLlm)?;
}
LoopState::Completed => {
return Ok(self.build_result(turn_count, tool_call_count).await);
}
}
}
}
/// Builds a `RunResult` from the current session state.
///
/// The text is the last assistant text (empty if none). Structured output and
/// schema warnings are always `None`. Skill diagnostics are collected from the
/// skill runtime, if one is configured.
async fn build_result(&self, turns: u32, tool_calls: u32) -> RunResult {
    let text = self.session.last_assistant_text().unwrap_or_default();
    let session_id = self.session.id().clone();
    let usage = self.session.total_usage();
    let skill_diagnostics = self.collect_skill_diagnostics().await;
    RunResult {
        text,
        session_id,
        usage,
        turns,
        tool_calls,
        structured_output: None,
        schema_warnings: None,
        skill_diagnostics,
    }
}
/// Snapshots skill-runtime diagnostics for inclusion in a `RunResult`.
///
/// Returns `None` when no skill engine is configured or its health snapshot
/// fails. A failure to list quarantined skills degrades to an empty list.
async fn collect_skill_diagnostics(&self) -> Option<crate::skills::SkillRuntimeDiagnostics> {
    let engine = match &self.skill_engine {
        Some(engine) => engine,
        None => return None,
    };
    let source_health = match engine.health_snapshot().await {
        Ok(snapshot) => snapshot,
        Err(_) => return None,
    };
    // Only queried once the health snapshot has succeeded.
    let quarantined = engine
        .quarantined_diagnostics()
        .await
        .unwrap_or_default();
    let diagnostics = crate::skills::SkillRuntimeDiagnostics {
        source_health,
        quarantined,
    };
    Some(diagnostics)
}
}
/// Replaces the assistant's visible text with `replacement`.
///
/// The first `Text` block keeps its position and metadata but receives
/// `replacement`. Every later `Text` block is removed, leaving exactly one
/// text block. All non-text blocks (for example tool calls) are kept in their
/// original order. If there is no `Text` block at all, a new one with no
/// metadata is inserted at the front.
pub(crate) fn rewrite_assistant_text(blocks: &mut Vec<AssistantBlock>, replacement: String) {
    let Some(first_text) = blocks
        .iter()
        .position(|block| matches!(block, AssistantBlock::Text { .. }))
    else {
        blocks.insert(
            0,
            AssistantBlock::Text {
                text: replacement,
                meta: None,
            },
        );
        return;
    };
    if let AssistantBlock::Text { text, .. } = &mut blocks[first_text] {
        *text = replacement;
    }
    // Drop every text block after the first in one O(n) pass. Repeated
    // `Vec::remove` would shift the tail each time. `retain` visits elements
    // exactly once in order, so the counter is each block's original index.
    let mut index = 0usize;
    blocks.retain(|block| {
        let keep = index <= first_text || !matches!(block, AssistantBlock::Text { .. });
        index += 1;
        keep
    });
}
/// Owned copy of a tool call, detached from the assistant message it came from.
///
/// Owning the data lets a call be moved into its own dispatch future and have
/// its arguments replaced by a hook patch.
#[derive(Debug, Clone)]
struct ToolCallOwned {
    /// Tool-use id from the assistant message.
    id: String,
    /// Name of the tool to dispatch.
    name: String,
    /// Raw JSON arguments.
    args: Box<RawValue>,
}

impl ToolCallOwned {
    /// Copies the id, name and raw arguments out of a borrowed view.
    ///
    /// The view's arguments are already a validated `RawValue`, so they are
    /// cloned directly. The old code re-parsed their text and fell back to
    /// `{}` on a parse failure, which could never happen.
    fn from_view(view: ToolCallView<'_>) -> Self {
        Self {
            id: view.id.to_string(),
            name: view.name.to_string(),
            args: view.args.to_owned(),
        }
    }

    /// Borrows this call as a `ToolCallView` for dispatch.
    fn as_view(&self) -> ToolCallView<'_> {
        ToolCallView {
            id: &self.id,
            name: &self.name,
            args: &self.args,
        }
    }

    /// Replaces the arguments with the serialized form of `args`.
    ///
    /// Falls back to an empty JSON object if the serialized text is rejected
    /// by `RawValue::from_string`.
    fn set_args(&mut self, args: Value) {
        let raw = RawValue::from_string(args.to_string()).unwrap_or_else(|_| fallback_raw_value());
        self.args = raw;
    }
}
/// Returns an owned raw JSON value holding an empty object (`{}`).
///
/// Used as the substitute when tool arguments cannot be turned into raw JSON.
/// The `expect` cannot fire because the literal is valid JSON.
#[allow(clippy::unwrap_used, clippy::expect_used)]
fn fallback_raw_value() -> Box<RawValue> {
    let empty_object = String::from("{}");
    RawValue::from_string(empty_object).expect("static JSON is valid")
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::manual_async_fn)]
mod tests {
use super::rewrite_assistant_text;
use crate::agent::{AgentBuilder, AgentLlmClient, AgentSessionStore, AgentToolDispatcher};
use crate::budget::{Budget, BudgetLimits};
use crate::error::{AgentError, ToolError};
use crate::skills::{
ResolvedSkill, SkillCollection, SkillDescriptor, SkillEngine, SkillFilter, SkillId,
SkillKey, SkillName, SourceUuid,
};
use crate::state::LoopState;
use crate::tool_scope::{EXTERNAL_TOOL_FILTER_METADATA_KEY, ToolFilter};
use crate::types::{
AssistantBlock, Message, StopReason, ToolCallView, ToolDef, ToolResult, Usage,
};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tokio::sync::{Notify, mpsc};
#[test]
fn rewrite_assistant_text_rewrites_all_text_blocks() {
let mut blocks = vec![
AssistantBlock::Text {
text: "first".to_string(),
meta: None,
},
AssistantBlock::ToolUse {
id: "t1".to_string(),
name: "tool".to_string(),
args: serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
meta: None,
},
AssistantBlock::Text {
text: "second".to_string(),
meta: None,
},
];
rewrite_assistant_text(&mut blocks, "redacted".to_string());
let text_blocks: Vec<&str> = blocks
.iter()
.filter_map(|b| match b {
AssistantBlock::Text { text, .. } => Some(text.as_str()),
_ => None,
})
.collect();
assert_eq!(text_blocks, vec!["redacted"]);
}
struct StaticLlmClient;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentLlmClient for StaticLlmClient {
async fn stream_response(
&self,
_messages: &[Message],
_tools: &[Arc<ToolDef>],
_max_tokens: u32,
_temperature: Option<f32>,
_provider_params: Option<&Value>,
) -> Result<super::LlmStreamResult, AgentError> {
Ok(super::LlmStreamResult::new(
vec![AssistantBlock::Text {
text: "ok".to_string(),
meta: None,
}],
StopReason::EndTurn,
Usage::default(),
))
}
fn provider(&self) -> &'static str {
"mock"
}
}
struct RecordingLlmClient {
seen_user_messages: Mutex<Vec<String>>,
}
impl RecordingLlmClient {
fn new() -> Self {
Self {
seen_user_messages: Mutex::new(Vec::new()),
}
}
fn seen(&self) -> Vec<String> {
self.seen_user_messages.lock().unwrap().clone()
}
}
struct RecordingSkillEngine {
seen_ids: Mutex<Vec<SkillId>>,
}
impl RecordingSkillEngine {
fn new() -> Self {
Self {
seen_ids: Mutex::new(Vec::new()),
}
}
fn seen(&self) -> Vec<SkillId> {
self.seen_ids.lock().unwrap().clone()
}
}
impl SkillEngine for RecordingSkillEngine {
fn inventory_section(
&self,
) -> impl Future<Output = Result<String, crate::skills::SkillError>> + Send {
async move { Ok(String::new()) }
}
fn resolve_and_render(
&self,
ids: &[SkillId],
) -> impl Future<Output = Result<Vec<ResolvedSkill>, crate::skills::SkillError>> + Send
{
let ids = ids.to_vec();
async move {
let mut seen = self.seen_ids.lock().unwrap();
seen.extend_from_slice(&ids);
drop(seen);
Ok(vec![ResolvedSkill {
id: ids.first().cloned().unwrap_or_else(|| {
SkillId("dc256086-0d2f-4f61-a307-320d4148107f/email-extractor".to_string())
}),
name: "email-extractor".to_string(),
rendered_body: "<skill>injected canonical skill</skill>".to_string(),
byte_size: 34,
}])
}
}
fn collections(
&self,
) -> impl Future<Output = Result<Vec<SkillCollection>, crate::skills::SkillError>> + Send
{
async move { Ok(vec![]) }
}
fn list_skills(
&self,
_filter: &SkillFilter,
) -> impl Future<Output = Result<Vec<SkillDescriptor>, crate::skills::SkillError>> + Send
{
async move { Ok(vec![]) }
}
fn quarantined_diagnostics(
&self,
) -> impl Future<
Output = Result<
Vec<crate::skills::SkillQuarantineDiagnostic>,
crate::skills::SkillError,
>,
> + Send {
async move { Ok(Vec::new()) }
}
fn health_snapshot(
&self,
) -> impl Future<
Output = Result<crate::skills::SourceHealthSnapshot, crate::skills::SkillError>,
> + Send {
async move { Ok(crate::skills::SourceHealthSnapshot::default()) }
}
fn list_artifacts(
&self,
id: &SkillId,
) -> impl Future<
Output = Result<Vec<crate::skills::SkillArtifact>, crate::skills::SkillError>,
> + Send {
let missing = id.clone();
async move { Err(crate::skills::SkillError::NotFound { id: missing }) }
}
fn read_artifact(
&self,
id: &SkillId,
_artifact_path: &str,
) -> impl Future<
Output = Result<crate::skills::SkillArtifactContent, crate::skills::SkillError>,
> + Send {
let missing = id.clone();
async move { Err(crate::skills::SkillError::NotFound { id: missing }) }
}
fn invoke_function(
&self,
id: &SkillId,
_function_name: &str,
_arguments: Value,
) -> impl Future<Output = Result<Value, crate::skills::SkillError>> + Send {
let missing = id.clone();
async move { Err(crate::skills::SkillError::NotFound { id: missing }) }
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentLlmClient for RecordingLlmClient {
async fn stream_response(
&self,
messages: &[Message],
_tools: &[Arc<ToolDef>],
_max_tokens: u32,
_temperature: Option<f32>,
_provider_params: Option<&Value>,
) -> Result<super::LlmStreamResult, AgentError> {
let mut seen = self.seen_user_messages.lock().unwrap();
for msg in messages {
if let Message::User(user) = msg {
seen.push(user.text_content());
}
}
drop(seen);
Ok(super::LlmStreamResult::new(
vec![AssistantBlock::Text {
text: "ok".to_string(),
meta: None,
}],
StopReason::EndTurn,
Usage::default(),
))
}
fn provider(&self) -> &'static str {
"mock"
}
}
struct NoTools;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for NoTools {
fn tools(&self) -> Arc<[Arc<ToolDef>]> {
Arc::new([])
}
async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
Err(ToolError::NotFound {
name: call.name.to_string(),
})
}
}
struct NoopStore;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentSessionStore for NoopStore {
async fn save(&self, _session: &crate::session::Session) -> Result<(), AgentError> {
Ok(())
}
async fn load(&self, _id: &str) -> Result<Option<crate::session::Session>, AgentError> {
Ok(None)
}
}
/// Tool dispatcher that advertises a fixed list of tools and records the name
/// of every tool it is asked to dispatch.
struct FullToolDispatcher {
    tools: Arc<[Arc<ToolDef>]>,
    dispatched_names: Mutex<Vec<String>>,
}
impl FullToolDispatcher {
    /// Creates one `ToolDef` per entry in `tool_names`, in order. Each gets the
    /// description `"<name> tool"` and an `{"type": "object"}` input schema.
    fn new(tool_names: &[&str]) -> Self {
        let tools = tool_names
            .iter()
            .map(|name| {
                Arc::new(ToolDef {
                    name: (*name).to_string(),
                    description: format!("{name} tool"),
                    input_schema: serde_json::json!({ "type": "object" }),
                })
            })
            .collect::<Vec<_>>()
            .into();
        Self {
            tools,
            dispatched_names: Mutex::new(Vec::new()),
        }
    }
    /// Returns a snapshot of the dispatched tool names, in dispatch order.
    fn dispatched(&self) -> Vec<String> {
        self.dispatched_names.lock().unwrap().clone()
    }
}
// Use the same Send/?Send split as the other `AgentToolDispatcher` impl in this
// module (`NoTools`). A bare `#[async_trait]` always requires `Send` futures and
// so diverges from it on wasm32.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentToolDispatcher for FullToolDispatcher {
    fn tools(&self) -> Arc<[Arc<ToolDef>]> {
        Arc::clone(&self.tools)
    }
    /// Records `call.name` and returns a `ToolResult` for `call.id` with the
    /// content `"dispatched <name>"`. The third argument is `false`, which is
    /// presumably the is-error flag; verify against `ToolResult::new`.
    async fn dispatch(&self, call: ToolCallView<'_>) -> Result<ToolResult, ToolError> {
        self.dispatched_names
            .lock()
            .unwrap()
            .push(call.name.to_string());
        Ok(ToolResult::new(
            call.id.to_string(),
            format!("dispatched {}", call.name),
            false,
        ))
    }
}
/// LLM double that counts its calls and records the tool names offered on
/// each call.
struct VisibilityRecordingLlmClient {
    call_count: Mutex<u32>,
    seen_tools: Mutex<Vec<Vec<String>>>,
}
impl VisibilityRecordingLlmClient {
    fn new() -> Self {
        let call_count = Mutex::new(0);
        let seen_tools = Mutex::new(Vec::new());
        Self {
            call_count,
            seen_tools,
        }
    }
    /// Snapshot of the tool-name lists offered so far, one entry per call.
    fn seen_tools(&self) -> Vec<Vec<String>> {
        let recorded = self.seen_tools.lock().unwrap();
        recorded.to_vec()
    }
}
/// LLM double that records the tool names offered on each call.
struct SingleTurnVisibilityClient {
    seen_tools: Mutex<Vec<Vec<String>>>,
}
impl SingleTurnVisibilityClient {
    fn new() -> Self {
        let seen_tools = Mutex::new(Vec::new());
        Self { seen_tools }
    }
    /// Snapshot of the tool-name lists offered so far, one entry per call.
    fn seen_tools(&self) -> Vec<Vec<String>> {
        let recorded = self.seen_tools.lock().unwrap();
        recorded.to_vec()
    }
}
// Use the same Send/?Send split as the other async-trait impls in this module
// (e.g. `NoTools`, `MockDrainCommsRuntime`). A bare `#[async_trait]` always
// requires `Send` futures and so diverges from them on wasm32.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentLlmClient for SingleTurnVisibilityClient {
    /// Records the names of the offered tools, then always ends the turn with
    /// the text `"done"`.
    async fn stream_response(
        &self,
        _messages: &[Message],
        tools: &[Arc<ToolDef>],
        _max_tokens: u32,
        _temperature: Option<f32>,
        _provider_params: Option<&Value>,
    ) -> Result<super::LlmStreamResult, AgentError> {
        let offered: Vec<String> = tools.iter().map(|tool| tool.name.clone()).collect();
        self.seen_tools.lock().unwrap().push(offered);
        Ok(super::LlmStreamResult::new(
            vec![AssistantBlock::Text {
                text: "done".to_string(),
                meta: None,
            }],
            StopReason::EndTurn,
            Usage::default(),
        ))
    }
    fn provider(&self) -> &'static str {
        "mock"
    }
}
// Use the same Send/?Send split as the other async-trait impls in this module
// (e.g. `NoTools`, `MockDrainCommsRuntime`). A bare `#[async_trait]` always
// requires `Send` futures and so diverges from them on wasm32.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentLlmClient for VisibilityRecordingLlmClient {
    /// Records the offered tool names. On the first call it requests the
    /// `secret` tool (id `call-1`, `{}` args) with `StopReason::ToolUse`. On
    /// every later call it ends the turn with the text `"done"`.
    async fn stream_response(
        &self,
        _messages: &[Message],
        tools: &[Arc<ToolDef>],
        _max_tokens: u32,
        _temperature: Option<f32>,
        _provider_params: Option<&Value>,
    ) -> Result<super::LlmStreamResult, AgentError> {
        let offered: Vec<String> = tools.iter().map(|tool| tool.name.clone()).collect();
        self.seen_tools.lock().unwrap().push(offered);
        // Read and increment under a single lock so each call observes its own index.
        let is_first_call = {
            let mut calls = self.call_count.lock().unwrap();
            let first = *calls == 0;
            *calls += 1;
            first
        };
        let response = if is_first_call {
            super::LlmStreamResult::new(
                vec![AssistantBlock::ToolUse {
                    id: "call-1".to_string(),
                    name: "secret".to_string(),
                    args: serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
                    meta: None,
                }],
                StopReason::ToolUse,
                Usage::default(),
            )
        } else {
            super::LlmStreamResult::new(
                vec![AssistantBlock::Text {
                    text: "done".to_string(),
                    meta: None,
                }],
                StopReason::EndTurn,
                Usage::default(),
            )
        };
        Ok(response)
    }
    fn provider(&self) -> &'static str {
        "mock"
    }
}
/// Comms runtime double: the first drain returns every queued message, and
/// later drains return an empty list.
struct MockDrainCommsRuntime {
    queued: tokio::sync::Mutex<Vec<String>>,
    notify: Arc<Notify>,
}
impl MockDrainCommsRuntime {
    /// Creates a runtime with `messages` queued and a fresh inbox `Notify`.
    fn with_messages(messages: Vec<String>) -> Self {
        let notify = Arc::new(Notify::new());
        let queued = tokio::sync::Mutex::new(messages);
        Self { queued, notify }
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl crate::agent::CommsRuntime for MockDrainCommsRuntime {
    async fn drain_messages(&self) -> Vec<String> {
        let mut queue = self.queued.lock().await;
        // Leave an empty queue behind so subsequent drains yield nothing.
        std::mem::take(&mut *queue)
    }
    fn inbox_notify(&self) -> Arc<Notify> {
        Arc::clone(&self.notify)
    }
}
/// Comms runtime double that returns one pre-staged batch per drain, in order,
/// and an empty list once the batches run out.
struct StagedDrainCommsRuntime {
    batches: tokio::sync::Mutex<Vec<Vec<String>>>,
    notify: Arc<Notify>,
}
impl StagedDrainCommsRuntime {
    /// Creates a runtime that hands out `batches` front-to-back.
    fn with_batches(batches: Vec<Vec<String>>) -> Self {
        let notify = Arc::new(Notify::new());
        let batches = tokio::sync::Mutex::new(batches);
        Self { batches, notify }
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl crate::agent::CommsRuntime for StagedDrainCommsRuntime {
    async fn drain_messages(&self) -> Vec<String> {
        let mut pending = self.batches.lock().await;
        if pending.is_empty() {
            return Vec::new();
        }
        pending.remove(0)
    }
    fn inbox_notify(&self) -> Arc<Notify> {
        Arc::clone(&self.notify)
    }
}
/// Builds an agent around `client` with no tools (`NoTools`) and a no-op
/// session store (`NoopStore`), using default builder settings.
async fn build_agent<C>(client: Arc<C>) -> crate::agent::Agent<C, NoTools, NoopStore>
where
    C: AgentLlmClient + ?Sized + 'static,
{
    let tools = Arc::new(NoTools);
    let store = Arc::new(NoopStore);
    AgentBuilder::new().build(client, tools, store).await
}
#[tokio::test]
async fn calling_llm_with_max_turns_zero_completes_with_zero_turns() {
    // Enter the loop at CallingLlm with no turns allowed: it must finish in
    // Completed having run zero turns.
    let mut agent = build_agent(Arc::new(StaticLlmClient)).await;
    agent.state = LoopState::CallingLlm;
    agent.config.max_turns = Some(0);
    let outcome = agent.run_loop(None).await.unwrap();
    assert_eq!(outcome.turns, 0);
    assert_eq!(agent.state, LoopState::Completed);
}
#[tokio::test]
async fn calling_llm_with_budget_exhausted_completes_with_zero_turns() {
    // A zero token budget must stop the loop before any turn runs, even
    // though max_turns would allow ten.
    let exhausted = BudgetLimits {
        max_tokens: Some(0),
        max_duration: None,
        max_tool_calls: None,
    };
    let mut agent = build_agent(Arc::new(StaticLlmClient)).await;
    agent.budget = Budget::new(exhausted);
    agent.state = LoopState::CallingLlm;
    agent.config.max_turns = Some(10);
    let outcome = agent.run_loop(None).await.unwrap();
    assert_eq!(outcome.turns, 0);
    assert_eq!(agent.state, LoopState::Completed);
}
#[tokio::test]
async fn completed_with_max_turns_zero_returns_invalid_transition() {
    // Running the loop from the terminal Completed state must be rejected as a
    // Completed -> Completed transition.
    let mut agent = build_agent(Arc::new(StaticLlmClient)).await;
    agent.config.max_turns = Some(0);
    agent.state = LoopState::Completed;
    let err = agent
        .run_loop(None)
        .await
        .expect_err("expected transition error");
    // A different error variant is an ordinary test failure, not an
    // impossible state, so report it with `panic!` rather than `unreachable!`.
    let (from, to) = match err {
        AgentError::InvalidStateTransition { from, to } => (from, to),
        other => panic!("expected InvalidStateTransition, got {other:?}"),
    };
    assert_eq!(from, "Completed");
    assert_eq!(to, "Completed");
}
#[tokio::test]
async fn error_recovery_with_max_turns_zero_completes() {
    // From ErrorRecovery with no turns allowed, the loop must end in
    // Completed having run zero turns.
    let mut agent = build_agent(Arc::new(StaticLlmClient)).await;
    agent.state = LoopState::ErrorRecovery;
    agent.config.max_turns = Some(0);
    let outcome = agent.run_loop(None).await.unwrap();
    assert_eq!(outcome.turns, 0);
    assert_eq!(agent.state, LoopState::Completed);
}
#[tokio::test]
async fn error_recovery_drains_comms_message_when_transitioning_to_calling_llm() {
    const QUEUED: &str = "queued during recovery";
    let client = Arc::new(RecordingLlmClient::new());
    let comms = Arc::new(MockDrainCommsRuntime::with_messages(vec![QUEUED.to_string()]));
    let mut agent = AgentBuilder::new()
        .with_comms_runtime(comms)
        .build(client.clone(), Arc::new(NoTools), Arc::new(NoopStore))
        .await;
    agent.state = LoopState::ErrorRecovery;
    agent.config.max_turns = Some(1);
    let outcome = agent.run_loop(None).await.unwrap();
    assert_eq!(outcome.turns, 1);
    // The queued comms message should reach the LLM as user input.
    let seen = client.seen();
    let drained = seen.iter().any(|m| m.contains(QUEUED));
    assert!(
        drained,
        "expected queued comms message to be drained into LLM input, saw: {seen:?}"
    );
}
#[tokio::test]
async fn no_tool_completion_drains_late_comms_before_returning() {
    // The first drain yields nothing; the late message only shows up on the
    // second drain.
    let batches = vec![Vec::new(), vec!["late boundary message".to_string()]];
    let comms = Arc::new(StagedDrainCommsRuntime::with_batches(batches));
    let mut agent = AgentBuilder::new()
        .with_comms_runtime(comms)
        .build(
            Arc::new(StaticLlmClient),
            Arc::new(NoTools),
            Arc::new(NoopStore),
        )
        .await;
    let outcome = agent.run("prompt".to_string().into()).await.unwrap();
    assert_eq!(outcome.text, "ok");
    let transcript_has_late_message = agent.session().messages().iter().any(|message| {
        matches!(
            message,
            Message::User(user) if user.text_content().contains("late boundary message")
        )
    });
    assert!(
        transcript_has_late_message,
        "completion path should drain late comms messages into the final session transcript"
    );
}
#[tokio::test]
async fn run_with_events_emits_run_completed_for_max_turns_zero() {
    let mut agent = build_agent(Arc::new(StaticLlmClient)).await;
    agent.config.max_turns = Some(0);
    let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(32);
    let outcome = agent
        .run_with_events("prompt".to_string().into(), tx)
        .await
        .unwrap();
    assert_eq!(outcome.turns, 0);
    // Drain the channel, keeping the final text of every RunCompleted event.
    let completed_texts: Vec<_> = std::iter::from_fn(|| rx.try_recv().ok())
        .filter_map(|event| match event {
            crate::event::AgentEvent::RunCompleted { result, .. } => Some(result),
            _ => None,
        })
        .collect();
    assert!(
        !completed_texts.is_empty(),
        "successful early exits should still emit RunCompleted"
    );
    for text in completed_texts {
        assert_eq!(text, "");
    }
}
#[tokio::test]
async fn run_completed_event_uses_hook_rewritten_text() {
    use crate::hooks::{
        HookEngine, HookEngineError, HookExecutionReport, HookInvocation, HookOutcome,
        HookPatch, HookPoint,
    };
    // Hook engine that, at RunCompleted only, patches the run result text to
    // "patched-final-text"; every other point gets an empty report.
    struct RewriteRunCompletedHook;
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl HookEngine for RewriteRunCompletedHook {
        async fn execute(
            &self,
            invocation: HookInvocation,
            _overrides: Option<&crate::config::HookRunOverrides>,
        ) -> Result<HookExecutionReport, HookEngineError> {
            let report = if invocation.point == HookPoint::RunCompleted {
                let rewrite = HookOutcome {
                    hook_id: crate::hooks::HookId::new("rewrite-run-completed"),
                    point: HookPoint::RunCompleted,
                    priority: 0,
                    registration_index: 0,
                    decision: None,
                    patches: vec![HookPatch::RunResult {
                        text: "patched-final-text".to_string(),
                    }],
                    published_patches: Vec::new(),
                    error: None,
                    duration_ms: None,
                };
                HookExecutionReport {
                    outcomes: vec![rewrite],
                    decision: None,
                    patches: Vec::new(),
                    published_patches: Vec::new(),
                }
            } else {
                HookExecutionReport::empty()
            };
            Ok(report)
        }
    }
    let mut agent = AgentBuilder::new()
        .with_hook_engine(Arc::new(RewriteRunCompletedHook))
        .build(
            Arc::new(StaticLlmClient),
            Arc::new(NoTools),
            Arc::new(NoopStore),
        )
        .await;
    let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(32);
    let outcome = agent
        .run_with_events("prompt".to_string().into(), tx)
        .await
        .unwrap();
    assert_eq!(outcome.text, "patched-final-text");
    // Keep the text of the last RunCompleted event seen, if any.
    let run_completed_text = std::iter::from_fn(|| rx.try_recv().ok())
        .filter_map(|event| match event {
            crate::event::AgentEvent::RunCompleted { result, .. } => Some(result),
            _ => None,
        })
        .last();
    assert_eq!(
        run_completed_text.as_deref(),
        Some("patched-final-text"),
        "RunCompleted should reflect the hook-rewritten final result"
    );
}
#[tokio::test]
async fn run_completed_hook_failure_emits_run_failed_without_run_completed() {
    use crate::hooks::{
        HookDecision, HookEngine, HookEngineError, HookExecutionReport, HookInvocation,
        HookPoint, HookReasonCode,
    };
    // Hook engine that denies the RunCompleted point and returns an empty
    // report everywhere else.
    struct DenyRunCompletedHook;
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl HookEngine for DenyRunCompletedHook {
        async fn execute(
            &self,
            invocation: HookInvocation,
            _overrides: Option<&crate::config::HookRunOverrides>,
        ) -> Result<HookExecutionReport, HookEngineError> {
            if invocation.point == HookPoint::RunCompleted {
                let deny = HookDecision::Deny {
                    hook_id: crate::hooks::HookId::new("deny-run-completed"),
                    reason_code: HookReasonCode::PolicyViolation,
                    message: "deny completed".to_string(),
                    payload: None,
                };
                return Ok(HookExecutionReport {
                    decision: Some(deny),
                    ..HookExecutionReport::empty()
                });
            }
            Ok(HookExecutionReport::empty())
        }
    }
    let mut agent = AgentBuilder::new()
        .with_hook_engine(Arc::new(DenyRunCompletedHook))
        .build(
            Arc::new(StaticLlmClient),
            Arc::new(NoTools),
            Arc::new(NoopStore),
        )
        .await;
    let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(32);
    let err = agent
        .run_with_events("prompt".to_string().into(), tx)
        .await
        .expect_err("RunCompleted hook denial should fail the run");
    assert!(matches!(
        err,
        AgentError::HookDenied {
            point: HookPoint::RunCompleted,
            ..
        }
    ));
    let events: Vec<crate::event::AgentEvent> =
        std::iter::from_fn(|| rx.try_recv().ok()).collect();
    let saw_run_failed = events
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::RunFailed { .. }));
    let saw_run_completed = events
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::RunCompleted { .. }));
    assert!(
        saw_run_failed,
        "hook-denied completion should emit RunFailed"
    );
    assert!(
        !saw_run_completed,
        "hook-denied completion should not also emit RunCompleted"
    );
}
#[tokio::test]
async fn turn_boundary_denial_blocks_boundary_side_effects_and_turn_completed() {
    use crate::hooks::{
        HookDecision, HookEngine, HookEngineError, HookExecutionReport, HookInvocation,
        HookPoint, HookReasonCode,
    };
    // Hook engine that denies the TurnBoundary point and returns an empty
    // report everywhere else.
    struct DenyTurnBoundaryHook;
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl HookEngine for DenyTurnBoundaryHook {
        async fn execute(
            &self,
            invocation: HookInvocation,
            _overrides: Option<&crate::config::HookRunOverrides>,
        ) -> Result<HookExecutionReport, HookEngineError> {
            if invocation.point == HookPoint::TurnBoundary {
                let deny = HookDecision::Deny {
                    hook_id: crate::hooks::HookId::new("deny-turn-boundary"),
                    reason_code: HookReasonCode::PolicyViolation,
                    message: "deny boundary".to_string(),
                    payload: None,
                };
                return Ok(HookExecutionReport {
                    decision: Some(deny),
                    ..HookExecutionReport::empty()
                });
            }
            Ok(HookExecutionReport::empty())
        }
    }
    // The late message only shows up on the second drain.
    let batches = vec![Vec::new(), vec!["late boundary message".to_string()]];
    let comms = Arc::new(StagedDrainCommsRuntime::with_batches(batches));
    let mut agent = AgentBuilder::new()
        .with_hook_engine(Arc::new(DenyTurnBoundaryHook))
        .with_comms_runtime(comms)
        .build(
            Arc::new(StaticLlmClient),
            Arc::new(NoTools),
            Arc::new(NoopStore),
        )
        .await;
    let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(32);
    let err = agent
        .run_with_events("prompt".to_string().into(), tx)
        .await
        .expect_err("TurnBoundary denial should fail the run");
    assert!(matches!(
        err,
        AgentError::HookDenied {
            point: HookPoint::TurnBoundary,
            ..
        }
    ));
    let late_message_committed = agent.session().messages().iter().any(|message| {
        matches!(
            message,
            Message::User(user) if user.text_content().contains("late boundary message")
        )
    });
    assert!(
        !late_message_committed,
        "boundary-denied turns should not commit late comms boundary side effects"
    );
    let events: Vec<crate::event::AgentEvent> =
        std::iter::from_fn(|| rx.try_recv().ok()).collect();
    let saw_turn_completed = events
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::TurnCompleted { .. }));
    let saw_run_failed = events
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::RunFailed { .. }));
    assert!(
        !saw_turn_completed,
        "boundary denial should not emit TurnCompleted before failing the run"
    );
    assert!(saw_run_failed, "boundary denial should emit RunFailed");
}
#[tokio::test]
async fn run_without_primary_channel_still_emits_run_lifecycle_to_tap() {
    use crate::event_tap::EventTapState;
    use std::sync::atomic::AtomicBool;
    // Install the tap before building the agent. `run` below is called without
    // an event sender, so the tap is the only observer of lifecycle events.
    let tap = crate::event_tap::new_event_tap();
    let (tap_tx, mut tap_rx) = mpsc::channel(128);
    *tap.lock() = Some(EventTapState {
        tx: tap_tx,
        truncated: AtomicBool::new(false),
    });
    let mut agent = AgentBuilder::new()
        .with_event_tap(tap)
        .build(
            Arc::new(StaticLlmClient),
            Arc::new(NoTools),
            Arc::new(NoopStore),
        )
        .await;
    let outcome = agent
        .run("tap-only prompt".to_string().into())
        .await
        .unwrap();
    assert_eq!(outcome.text, "ok");
    let tapped: Vec<_> = std::iter::from_fn(|| tap_rx.try_recv().ok()).collect();
    let saw_run_started = tapped
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::RunStarted { .. }));
    let saw_run_completed = tapped
        .iter()
        .any(|event| matches!(event, crate::event::AgentEvent::RunCompleted { .. }));
    assert!(saw_run_started, "tap should receive RunStarted");
    assert!(saw_run_completed, "tap should receive RunCompleted");
}
#[tokio::test]
async fn pending_skill_keys_are_resolved_and_injected_into_runtime_prompt() {
    let client = Arc::new(RecordingLlmClient::new());
    let skill_engine = Arc::new(RecordingSkillEngine::new());
    let skill_runtime = Arc::new(crate::skills::SkillRuntime::new(skill_engine.clone()));
    let mut agent = AgentBuilder::new()
        .with_skill_engine(skill_runtime)
        .build(client.clone(), Arc::new(NoTools), Arc::new(NoopStore))
        .await;
    let email_extractor = SkillKey {
        source_uuid: SourceUuid::parse("dc256086-0d2f-4f61-a307-320d4148107f")
            .expect("valid source uuid"),
        skill_name: SkillName::parse("email-extractor").expect("valid skill name"),
    };
    agent.pending_skill_references = Some(vec![email_extractor]);
    agent.config.max_turns = Some(1);
    let outcome = agent
        .run("plain user prompt".to_string().into())
        .await
        .expect("run should succeed");
    assert_eq!(outcome.turns, 1);
    // The engine should be asked for the canonical "<source>/<name>" id.
    let seen_ids = skill_engine.seen();
    let forwarded = seen_ids
        .iter()
        .any(|id| id.0 == "dc256086-0d2f-4f61-a307-320d4148107f/email-extractor");
    assert!(
        forwarded,
        "expected canonical skill id to be forwarded to skill engine, saw: {seen_ids:?}"
    );
    // The rendered skill must appear in what the LLM was sent.
    let seen_messages = client.seen();
    let injected = seen_messages
        .iter()
        .any(|msg| msg.contains("<skill>injected canonical skill</skill>"));
    assert!(
        injected,
        "expected runtime prompt to include rendered skill injection, saw: {seen_messages:?}"
    );
}
#[tokio::test]
async fn provider_receives_filtered_tools_and_dispatch_blocks_hidden_tools() {
    let client = Arc::new(VisibilityRecordingLlmClient::new());
    let tools = Arc::new(FullToolDispatcher::new(&["visible", "secret"]));
    let mut agent = AgentBuilder::new()
        .build(client.clone(), tools.clone(), Arc::new(NoopStore))
        .await;
    // Hide `secret`, even though the client requests it on its first call.
    let hide_secret = ToolFilter::Deny(std::iter::once("secret".to_string()).collect());
    agent.stage_external_tool_filter(hide_secret).unwrap();
    agent.config.max_turns = Some(2);
    let outcome = agent.run("prompt".to_string().into()).await.unwrap();
    assert_eq!(outcome.text, "done");
    // Both LLM calls should have been offered only the visible tool.
    assert_eq!(client.seen_tools(), vec![vec!["visible".to_string()]; 2]);
    let dispatched = tools.dispatched();
    assert!(
        dispatched.is_empty(),
        "hidden tools should not be dispatched, but got: {dispatched:?}"
    );
}
#[tokio::test]
async fn run_loop_boundary_applies_filter_and_emits_tool_config_changed_and_notice() {
let client = Arc::new(SingleTurnVisibilityClient::new());
let tools = Arc::new(FullToolDispatcher::new(&["visible", "secret"]));
let mut agent = AgentBuilder::new()
.build(client.clone(), tools, Arc::new(NoopStore))
.await;
agent
.stage_external_tool_filter(ToolFilter::Deny(
["secret".to_string()].into_iter().collect(),
))
.unwrap();
let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(128);
let result = agent
.run_with_events("prompt".to_string().into(), tx)
.await
.unwrap();
assert_eq!(result.text, "done");
assert_eq!(client.seen_tools(), vec![vec!["visible".to_string()]]);
let mut saw_config_event = false;
while let Ok(event) = rx.try_recv() {
if let crate::event::AgentEvent::ToolConfigChanged { payload } = event {
assert_eq!(
payload.operation,
crate::event::ToolConfigChangeOperation::Reload
);
assert_eq!(payload.target, "tool_scope");
assert!(payload.status.contains("boundary_applied"));
saw_config_event = true;
}
}
assert!(
saw_config_event,
"expected ToolConfigChanged event on boundary visibility change"
);
let notices: Vec<String> = agent
.session()
.messages()
.iter()
.filter_map(|msg| match msg {
Message::User(user)
if user.text_content().contains("[SYSTEM NOTICE][TOOL_SCOPE]") =>
{
Some(user.text_content())
}
_ => None,
})
.collect();
assert_eq!(notices.len(), 1);
}
#[tokio::test]
async fn run_loop_fails_safe_to_full_tools_with_warning_event_and_notice() {
    let client = Arc::new(SingleTurnVisibilityClient::new());
    let tools = Arc::new(FullToolDispatcher::new(&["visible", "secret"]));
    let mut agent = AgentBuilder::new()
        .build(client.clone(), tools, Arc::new(NoopStore))
        .await;
    let hide_secret = ToolFilter::Deny(["secret".to_string()].into_iter().collect());
    agent.stage_external_tool_filter(hide_secret).unwrap();
    // Force the boundary application to fail once; the assertions below expect
    // the provider to be offered the full, unfiltered tool set.
    agent.inject_tool_scope_boundary_failure_once_for_test();
    let (tx, mut rx) = mpsc::channel::<crate::event::AgentEvent>(128);
    let outcome = agent
        .run_with_events("prompt".to_string().into(), tx)
        .await
        .unwrap();
    assert_eq!(outcome.text, "done");
    let all_tools = vec!["visible".to_string(), "secret".to_string()];
    assert_eq!(client.seen_tools(), vec![all_tools]);
    // Drain every event while looking for the fallback warning.
    let saw_warning_event =
        std::iter::from_fn(|| rx.try_recv().ok()).fold(false, |found, event| {
            found
                || matches!(
                    event,
                    crate::event::AgentEvent::ToolConfigChanged { payload }
                        if payload.status.contains("warning_fallback_all")
                )
        });
    assert!(
        saw_warning_event,
        "expected warning ToolConfigChanged event during fail-safe fallback"
    );
    let notice_count = agent
        .session()
        .messages()
        .iter()
        .filter(|msg| {
            matches!(
                msg,
                Message::User(user) if user.text_content().contains("[TOOL_SCOPE][WARNING]")
            )
        })
        .count();
    assert_eq!(notice_count, 1);
}
#[tokio::test]
async fn builder_restores_persisted_external_filter_from_session_metadata() {
    let client = Arc::new(VisibilityRecordingLlmClient::new());
    let tools = Arc::new(FullToolDispatcher::new(&["visible", "secret"]));
    // Persist a filter hiding `secret` in the session metadata, then resume it.
    let persisted_filter = serde_json::to_value(ToolFilter::Deny(
        ["secret".to_string()].into_iter().collect(),
    ))
    .unwrap();
    let mut session = crate::Session::new();
    session.set_metadata(EXTERNAL_TOOL_FILTER_METADATA_KEY, persisted_filter);
    let mut agent = AgentBuilder::new()
        .resume_session(session)
        .build(client.clone(), tools, Arc::new(NoopStore))
        .await;
    agent.config.max_turns = Some(2);
    let outcome = agent.run("prompt".to_string().into()).await.unwrap();
    assert_eq!(outcome.text, "done");
    let only_visible = vec!["visible".to_string()];
    assert_eq!(
        client.seen_tools(),
        vec![only_visible.clone(), only_visible]
    );
}
#[tokio::test]
async fn builder_prunes_unknown_persisted_filter_tools() {
    let client = Arc::new(VisibilityRecordingLlmClient::new());
    let tools = Arc::new(FullToolDispatcher::new(&["visible", "secret"]));
    // The persisted allow-list names `missing`, which the dispatcher does not provide.
    let allow_list = ["visible".to_string(), "missing".to_string()];
    let persisted_filter =
        serde_json::to_value(ToolFilter::Allow(allow_list.into_iter().collect())).unwrap();
    let mut session = crate::Session::new();
    session.set_metadata(EXTERNAL_TOOL_FILTER_METADATA_KEY, persisted_filter);
    let mut agent = AgentBuilder::new()
        .resume_session(session)
        .build(client.clone(), tools, Arc::new(NoopStore))
        .await;
    agent.config.max_turns = Some(2);
    let outcome = agent.run("prompt".to_string().into()).await.unwrap();
    assert_eq!(outcome.text, "done");
    let expected = vec![vec!["visible".to_string()]; 2];
    assert_eq!(client.seen_tools(), expected);
}
}