use super::*;
use crate::agent::agent_loop::hooks::{BeforeToolCallFn, BeforeToolCallReturn};
use crate::agent::agent_loop::message::{ContentBlock, StopReason};
use crate::agent::agent_loop::result::{AfterToolCallResult, BeforeToolCallResult};
use crate::agent::agent_loop::types::{ConvertToLlmFn, ToolExecutionMode};
use std::pin::Pin;
use std::sync::Mutex;
/// Configurable test double for [`LoopTool`] that echoes its arguments back.
///
/// Besides returning `echoed: <args>`, it records the arguments of every
/// execution that completes and tracks how many executions sharing its
/// counter overlap, so tests can assert on dispatch order and concurrency.
struct EchoTool {
    /// Name the dispatcher resolves calls against.
    name: String,
    /// Per-tool execution mode reported to the dispatcher (`None` = no preference).
    execution_mode: Option<ToolExecutionMode>,
    /// When set, results carry `terminate: Some(true)`.
    terminate: bool,
    /// Optional rewrite applied by `prepare_arguments`.
    prepare_arguments_fn: Option<Box<dyn Fn(Value) -> Value + Send + Sync>>,
    /// Sleep applied to every execution, in milliseconds.
    delay_ms: Option<u64>,
    /// Extra sleep applied only to the call whose `value` is `"first"`.
    delay_first_ms: Option<u64>,
    /// Arguments of each execution that reached the end of `execute`, in the order they got there.
    executed_args: Arc<Mutex<Vec<Value>>>,
    /// `(currently running, maximum observed)` executions sharing this counter.
    concurrency: Arc<Mutex<(u32, u32)>>,
    /// Set when the `"second"` call sees another execution running
    /// (only checked when `delay_first_ms` is configured).
    parallel_observed: Arc<Mutex<bool>>,
}
impl std::fmt::Debug for EchoTool {
    /// Shows only the identifying/configuration fields; the argument shim and
    /// the shared instrumentation are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            execution_mode,
            terminate,
            ..
        } = self;
        let mut out = f.debug_struct("EchoTool");
        out.field("name", name);
        out.field("execution_mode", execution_mode);
        out.field("terminate", terminate);
        out.finish()
    }
}
impl EchoTool {
fn new(name: &str) -> Self {
Self {
name: name.to_string(),
prepare_arguments_fn: None,
execution_mode: None,
terminate: false,
executed_args: Arc::new(Mutex::new(Vec::new())),
delay_ms: None,
delay_first_ms: None,
concurrency: Arc::new(Mutex::new((0, 0))),
parallel_observed: Arc::new(Mutex::new(false)),
}
}
fn with_prepare(mut self, f: impl Fn(Value) -> Value + Send + Sync + 'static) -> Self {
self.prepare_arguments_fn = Some(Box::new(f));
self
}
fn with_terminate(mut self) -> Self {
self.terminate = true;
self
}
fn with_execution_mode(mut self, mode: ToolExecutionMode) -> Self {
self.execution_mode = Some(mode);
self
}
fn with_delay_ms(mut self, ms: u64) -> Self {
self.delay_ms = Some(ms);
self
}
fn with_delay_first_ms(mut self, ms: u64) -> Self {
self.delay_first_ms = Some(ms);
self
}
fn concurrency_snapshot(&self) -> (u32, u32) {
*self.concurrency.lock().unwrap()
}
fn parallel_was_observed(&self) -> bool {
*self.parallel_observed.lock().unwrap()
}
}
impl LoopTool for EchoTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"Echo tool"
}
fn label(&self) -> &str {
"Echo"
}
fn parameters(&self) -> &Value {
static EMPTY: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
EMPTY.get_or_init(|| serde_json::json!({"type": "object"}))
}
fn execution_mode(&self) -> Option<ToolExecutionMode> {
self.execution_mode
}
fn prepare_arguments(&self, args: Value) -> Value {
if let Some(f) = &self.prepare_arguments_fn {
f(args)
} else {
args
}
}
fn execute<'a>(
&'a self,
_tool_call_id: &'a str,
args: Value,
_signal: AbortSignal,
_on_update: LoopToolUpdate,
) -> Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>> {
let recorded = self.executed_args.clone();
let terminate = self.terminate;
let delay_ms = self.delay_ms;
let delay_first_ms = self.delay_first_ms;
let concurrency = self.concurrency.clone();
let parallel_observed = self.parallel_observed.clone();
Box::pin(async move {
{
let mut c = concurrency.lock().unwrap();
c.0 += 1;
if c.0 > c.1 {
c.1 = c.0;
}
}
let value_str = args
.get("value")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if let Some(ms) = delay_first_ms
&& value_str == "first"
{
tokio::time::sleep(std::time::Duration::from_millis(ms)).await;
}
if delay_first_ms.is_some() && value_str == "second" {
let c = concurrency.lock().unwrap();
if c.0 > 1 {
*parallel_observed.lock().unwrap() = true;
}
}
if let Some(ms) = delay_ms {
tokio::time::sleep(std::time::Duration::from_millis(ms)).await;
}
recorded.lock().unwrap().push(args.clone());
{
let mut c = concurrency.lock().unwrap();
c.0 -= 1;
}
let text = format!("echoed: {}", args);
Ok(LoopToolResult {
content: vec![serde_json::json!({"type": "text", "text": text})],
details: args,
terminate: if terminate { Some(true) } else { None },
})
})
}
}
/// Conversion hook that hands the message history to the model verbatim.
fn identity_converter() -> ConvertToLlmFn {
    Arc::new(|history: &[Value]| history.to_owned())
}
/// Builds the baseline [`LoopConfig`] shared by the dispatcher tests.
///
/// It uses identity message conversion and sequential tool execution. Every
/// optional hook or setting is `None`, tool search is disabled, the repair and
/// truncation-note state is fresh, and the escalation budget is 3.
fn build_config() -> LoopConfig {
    use crate::agent::agent_loop::tool_input_repair::RepairStats;
    use std::collections::HashMap;
    use std::sync::atomic::AtomicUsize;
    use std::sync::{Arc, Mutex};

    LoopConfig {
        convert_to_llm: identity_converter(),
        transform_context: None,
        compaction_hooks: None,
        get_api_key: None,
        api_key: None,
        tool_execution: ToolExecutionMode::Sequential,
        before_tool_call: None,
        after_tool_call: None,
        prepare_next_turn: None,
        should_stop_after_turn: None,
        get_steering_messages: None,
        get_followup_messages: None,
        reasoning: None,
        thinking_budgets: None,
        headers: HashMap::new(),
        metadata: HashMap::new(),
        request_timeout: None,
        provider_name: None,
        model_name: None,
        compact_model: None,
        storm_mutating_tools: None,
        storm_exempt_tools: None,
        repair_stats: Arc::new(RepairStats::new()),
        truncation_notes: Arc::new(Mutex::new(HashMap::new())),
        tool_def_filter: None,
        dynamic_tool_search: false,
        escalation_stream_fn: None,
        escalation_provider_name: None,
        escalation_pending: Arc::new(Mutex::new(None)),
        escalation_max_per_session: 3,
        escalation_remaining: Arc::new(AtomicUsize::new(3)),
        file_touch_tracker: None,
        verifier: None,
        critic_fn: None,
        goal: None,
        max_turns: None,
    }
}
/// Builds a [`Context`] with an empty system prompt and history that exposes
/// exactly one tool.
fn build_context(tool: Arc<dyn LoopTool>) -> Context {
    let tools = vec![tool];
    Context {
        tools,
        messages: Vec::new(),
        system_prompt: String::new(),
    }
}
/// One valid call: the tool runs once with the model's arguments, and the batch
/// holds a single non-error, non-terminating result. The emitted events are
/// execution start/end followed by the result message's start/end.
#[tokio::test]
async fn test_handle_tool_calls_and_results() {
    let echo = Arc::new(EchoTool::new("echo"));
    let context = build_context(echo.clone());
    let call = ContentBlock::ToolCall {
        id: "tool-1".to_string(),
        name: "echo".to_string(),
        arguments: serde_json::json!({"value": "hello"}),
    };
    let assistant_msg = AssistantMessage::new(vec![call], StopReason::ToolUse);
    let tool_calls = extract_tool_calls(&assistant_msg);
    assert_eq!(tool_calls.len(), 1);

    let config = build_config();
    let signal = AbortSignal::new();
    let inflight = InflightSet::new();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(64);
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &inflight,
    )
    .await;
    // Close the sending side so draining `rx` below terminates.
    drop(tx);

    {
        let recorded = echo.executed_args.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0]["value"], "hello");
    }
    assert_eq!(batch.messages.len(), 1);
    assert!(!batch.messages[0].is_error);
    assert!(!batch.terminate);

    let mut kinds = Vec::new();
    while let Some(event) = rx.recv().await {
        kinds.push(event.kind().to_string());
    }
    assert_eq!(
        kinds,
        [
            "tool_execution_start",
            "tool_execution_end",
            "message_start",
            "message_end",
        ]
    );
}
/// A `before_tool_call` hook that returns rewritten `args` changes what the
/// tool actually receives.
#[tokio::test]
async fn test_before_tool_call_mutates_args() {
    let echo = Arc::new(EchoTool::new("echo"));
    let context = build_context(echo.clone());
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "echo".to_string(),
            arguments: serde_json::json!({"value": "hello"}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    // Replace the string `value` with a number so the rewrite is unambiguous.
    let rewrite_value: BeforeToolCallFn = Arc::new(|ctx: BeforeToolCallContext| {
        Box::pin(async move {
            let mut rewritten = ctx.args.clone();
            if let Some(fields) = rewritten.as_object_mut() {
                fields.insert("value".to_string(), serde_json::json!(123));
            }
            BeforeToolCallReturn {
                result: None,
                args: rewritten,
            }
        })
    });
    let mut config = build_config();
    config.before_tool_call = Some(rewrite_value);

    let signal = AbortSignal::new();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(64);
    let _ = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);
    // Drain every emitted event before inspecting the tool.
    while rx.recv().await.is_some() {}

    let recorded = echo.executed_args.lock().unwrap();
    assert_eq!(recorded.len(), 1);
    assert_eq!(recorded[0]["value"], serde_json::json!(123));
}
/// A tool's `prepare_arguments` shim runs before execution. Here it folds the
/// legacy `{oldText, newText}` shape into `{edits: [{oldText, newText}]}`.
#[tokio::test]
async fn test_prepare_arguments_shim() {
    let legacy_edit_shim = |args: Value| -> Value {
        // Both keys must be present for the legacy shape to apply.
        let legacy_pair = args.as_object().and_then(|obj| {
            Some((obj.get("oldText")?.clone(), obj.get("newText")?.clone()))
        });
        match legacy_pair {
            Some((old_text, new_text)) => serde_json::json!({
                "edits": [{
                    "oldText": old_text,
                    "newText": new_text,
                }]
            }),
            None => args,
        }
    };
    let edit = Arc::new(EchoTool::new("edit").with_prepare(legacy_edit_shim));
    let context = build_context(edit.clone());
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "edit".to_string(),
            arguments: serde_json::json!({"oldText": "before", "newText": "after"}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    let config = build_config();
    let signal = AbortSignal::new();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(64);
    let _ = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);
    while rx.recv().await.is_some() {}

    let recorded = edit.executed_args.lock().unwrap();
    assert_eq!(recorded.len(), 1);
    let edits = recorded[0].get("edits").expect("shim should produce edits");
    let entries = edits.as_array().expect("edits is array");
    assert_eq!(entries.len(), 1);
    let only = &entries[0];
    assert_eq!(only["oldText"], "before");
    assert_eq!(only["newText"], "after");
}
/// A batch whose only result carries `terminate: Some(true)` terminates.
#[tokio::test]
async fn test_dispatcher_terminate_when_all_results_terminate() {
    let context = build_context(Arc::new(EchoTool::new("echo").with_terminate()));
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "echo".to_string(),
            arguments: serde_json::json!({}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    let config = build_config();
    let signal = AbortSignal::new();
    // Keep the receiver alive so event sends do not fail.
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;

    assert!(
        batch.terminate,
        "single terminate=true should set batch.terminate"
    );
}
/// An `after_tool_call` hook that only sets `terminate: Some(true)` makes an
/// otherwise non-terminating batch terminate.
#[tokio::test]
async fn test_after_tool_call_can_set_terminate() {
    let context = build_context(Arc::new(EchoTool::new("echo")));
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "echo".to_string(),
            arguments: serde_json::json!({}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    let force_terminate: crate::agent::agent_loop::hooks::AfterToolCallFn =
        Arc::new(|_ctx| {
            Box::pin(async move {
                Some(AfterToolCallResult {
                    content: None,
                    details: None,
                    is_error: None,
                    terminate: Some(true),
                })
            })
        });
    let mut config = build_config();
    config.after_tool_call = Some(force_terminate);

    let signal = AbortSignal::new();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;

    assert!(
        batch.terminate,
        "afterToolCall override should mark batch terminating"
    );
}
/// A call to an unregistered tool yields one error result that names the
/// missing tool.
#[tokio::test]
async fn test_tool_not_found_immediate_error() {
    let context = build_context(Arc::new(EchoTool::new("echo")));
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "nonexistent".to_string(),
            arguments: serde_json::json!({}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    let config = build_config();
    let signal = AbortSignal::new();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;

    assert_eq!(batch.messages.len(), 1);
    let result = &batch.messages[0];
    assert!(result.is_error);
    let ContentBlock::Text { text } = &result.content[0] else {
        panic!("expected text content block");
    };
    assert!(
        text.contains("nonexistent"),
        "error text should name the missing tool: {text}"
    );
}
/// A near-miss tool name (`ehco`) yields a single error result that suggests
/// the closest registered tool.
#[tokio::test]
async fn test_tool_not_found_suggests_closest() {
    let echo = Arc::new(EchoTool::new("echo"));
    let context = build_context(echo);
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "ehco".to_string(),
            arguments: serde_json::json!({}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let config = build_config();
    let signal = AbortSignal::new();
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    // Check the count first, as the sibling not-found test does. A regression
    // then fails with a clear assertion instead of an index-out-of-bounds panic.
    assert_eq!(batch.messages.len(), 1);
    assert!(batch.messages[0].is_error);
    match &batch.messages[0].content[0] {
        ContentBlock::Text { text } => assert!(
            text.contains("Did you mean `echo`?"),
            "should suggest the closest tool: {text}"
        ),
        _ => panic!("expected text content block"),
    }
}
/// A `before_tool_call` hook that blocks with a reason prevents execution and
/// surfaces the reason in an error result.
#[tokio::test]
async fn test_before_tool_call_block_with_reason() {
    let echo = Arc::new(EchoTool::new("echo"));
    let context = build_context(echo.clone());
    let assistant_msg = AssistantMessage::new(
        vec![ContentBlock::ToolCall {
            id: "tool-1".to_string(),
            name: "echo".to_string(),
            arguments: serde_json::json!({}),
        }],
        StopReason::ToolUse,
    );
    let tool_calls = extract_tool_calls(&assistant_msg);

    let deny: BeforeToolCallFn = Arc::new(|ctx: BeforeToolCallContext| {
        Box::pin(async move {
            let verdict = BeforeToolCallResult {
                block: Some(true),
                reason: Some("policy violation".to_string()),
            };
            BeforeToolCallReturn {
                result: Some(verdict),
                args: ctx.args,
            }
        })
    });
    let mut config = build_config();
    config.before_tool_call = Some(deny);

    let signal = AbortSignal::new();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let batch = execute_tool_calls_sequential(
        &context,
        &assistant_msg,
        &tool_calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;

    let nothing_ran = echo.executed_args.lock().unwrap().is_empty();
    assert!(nothing_ran);
    assert!(batch.messages[0].is_error);
    let ContentBlock::Text { text } = &batch.messages[0].content[0] else {
        panic!("expected text content block");
    };
    assert!(text.contains("policy violation"), "got: {text}");
}
/// `should_terminate_tool_batch` is true only for a non-empty batch in which
/// every result has `terminate == Some(true)`.
#[test]
fn should_terminate_invariants() {
    let outcome = |terminate: Option<bool>| FinalizedOutcome {
        tool_call: ToolCall {
            id: "x".into(),
            name: "x".into(),
            arguments: Value::Null,
        },
        result: LoopToolResult {
            content: vec![],
            details: Value::Null,
            terminate,
        },
        is_error: false,
    };

    // (per-result terminate flags, expected batch verdict)
    let cases = [
        (vec![], false),
        (vec![Some(false)], false),
        (vec![None], false),
        (vec![Some(true), Some(false)], false),
        (vec![Some(true)], true),
        (vec![Some(true), Some(true)], true),
    ];
    for (flags, expected) in cases {
        let batch: Vec<FinalizedOutcome> = flags.iter().map(|&flag| outcome(flag)).collect();
        assert_eq!(
            should_terminate_tool_batch(&batch),
            expected,
            "terminate flags: {flags:?}"
        );
    }
}
/// Two `echo` calls, `tool-1` with `value: "first"` and `tool-2` with
/// `value: "second"`, in that order.
fn two_echo_calls() -> Vec<ToolCall> {
    ["first", "second"]
        .into_iter()
        .zip(1..)
        .map(|(value, ordinal)| ToolCall {
            id: format!("tool-{ordinal}"),
            name: "echo".to_string(),
            arguments: serde_json::json!({ "value": value }),
        })
        .collect()
}
/// Wraps `calls` as `ToolCall` content blocks, in order, in an assistant
/// message that stopped for tool use.
fn assistant_with_calls(calls: &[ToolCall]) -> AssistantMessage {
    let mut blocks = Vec::with_capacity(calls.len());
    for call in calls {
        blocks.push(ContentBlock::ToolCall {
            id: call.id.clone(),
            name: call.name.clone(),
            arguments: call.arguments.clone(),
        });
    }
    AssistantMessage::new(blocks, StopReason::ToolUse)
}
/// In parallel mode, `tool_execution_end` events follow completion order, while
/// tool-result `message_end` events follow the order of the calls. The first
/// call is delayed so the second one finishes first.
#[tokio::test]
async fn test_tool_execution_end_completion_order_results_source_order() {
    let echo = Arc::new(EchoTool::new("echo").with_delay_first_ms(50));
    let context = build_context(echo.clone());
    let calls = two_echo_calls();
    let assistant = assistant_with_calls(&calls);
    let mut config = build_config();
    config.tool_execution = ToolExecutionMode::Parallel;

    let signal = AbortSignal::new();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(128);
    let _batch = execute_tool_calls_parallel(
        &context,
        &assistant,
        &calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);

    let mut end_order: Vec<String> = Vec::new();
    let mut result_order: Vec<String> = Vec::new();
    while let Some(event) = rx.recv().await {
        match event {
            LoopEvent::ToolExecutionEnd { tool_call_id, .. } => end_order.push(tool_call_id),
            LoopEvent::MessageEnd {
                message: LoopMessage::ToolResult(result),
            } => result_order.push(result.tool_call_id),
            _ => {}
        }
    }

    assert_eq!(
        end_order,
        vec!["tool-2".to_string(), "tool-1".to_string()],
        "tool_execution_end should be in completion order"
    );
    assert_eq!(
        result_order,
        vec!["tool-1".to_string(), "tool-2".to_string()],
        "tool-result message_end should be in source order"
    );
    assert!(
        echo.parallel_was_observed(),
        "second tool should have observed first still in flight"
    );
}
/// A tool that declares `Sequential` forces sequential dispatch even when the
/// config asks for `Parallel`, so two calls to it never overlap.
#[tokio::test]
async fn test_per_tool_sequential_forces_sequential_route() {
    let echo = Arc::new(
        EchoTool::new("echo")
            .with_delay_first_ms(20)
            .with_execution_mode(ToolExecutionMode::Sequential),
    );
    let context = build_context(echo.clone());
    let calls = two_echo_calls();
    let assistant = assistant_with_calls(&calls);
    let config = {
        let mut parallel = build_config();
        parallel.tool_execution = ToolExecutionMode::Parallel;
        parallel
    };

    let signal = AbortSignal::new();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(128);
    let batch = execute_tool_calls_from_msg(
        &context,
        &assistant,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);

    let max = echo.concurrency_snapshot().1;
    assert_eq!(
        max, 1,
        "per-tool Sequential should force max concurrency = 1, got {max}"
    );
    assert_eq!(batch.messages.len(), 2);
}
#[tokio::test]
async fn test_one_sequential_among_many_forces_sequential() {
let echo_seq = Arc::new(
EchoTool::new("echo_seq")
.with_execution_mode(ToolExecutionMode::Sequential)
.with_delay_ms(10),
);
let echo_par = Arc::new(EchoTool::new("echo_par").with_delay_ms(10));
let context = Context {
system_prompt: String::new(),
messages: Vec::new(),
tools: vec![echo_seq.clone(), echo_par.clone()],
};
let calls = vec![
ToolCall {
id: "tool-1".into(),
name: "echo_par".into(),
arguments: serde_json::json!({"v": 1}),
},
ToolCall {
id: "tool-2".into(),
name: "echo_seq".into(),
arguments: serde_json::json!({"v": 2}),
},
];
let assistant = assistant_with_calls(&calls);
let mut config = build_config();
config.tool_execution = ToolExecutionMode::Parallel;
let (tx, _rx) = mpsc::channel::<LoopEvent>(128);
let signal = AbortSignal::new();
let _ = execute_tool_calls_from_msg(
&context,
&assistant,
&config,
&signal,
&tx,
&InflightSet::new(),
)
.await;
drop(tx);
let (_, max_seq) = echo_seq.concurrency_snapshot();
let (_, max_par) = echo_par.concurrency_snapshot();
assert_eq!(max_seq, 1, "echo_seq max should be 1");
assert_eq!(max_par, 1, "echo_par max should be 1");
}
/// With every tool eligible for parallel execution, two calls overlap in time.
#[tokio::test]
async fn test_all_parallel_runs_concurrent() {
    let echo = Arc::new(EchoTool::new("echo").with_delay_first_ms(30));
    let context = build_context(echo.clone());
    let calls = two_echo_calls();
    let assistant = assistant_with_calls(&calls);
    let config = {
        let mut parallel = build_config();
        parallel.tool_execution = ToolExecutionMode::Parallel;
        parallel
    };

    let signal = AbortSignal::new();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(128);
    let _ = execute_tool_calls_from_msg(
        &context,
        &assistant,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);

    let max = echo.concurrency_snapshot().1;
    assert!(
        max >= 2,
        "parallel dispatch should run >=2 tools concurrently, got {max}"
    );
}
#[tokio::test]
async fn test_parallel_batch_not_terminating_when_mixed() {
let echo_term = Arc::new(EchoTool::new("term").with_terminate());
let echo_norm = Arc::new(EchoTool::new("norm"));
let context = Context {
system_prompt: String::new(),
messages: Vec::new(),
tools: vec![echo_term, echo_norm],
};
let calls = vec![
ToolCall {
id: "tool-1".into(),
name: "term".into(),
arguments: serde_json::json!({}),
},
ToolCall {
id: "tool-2".into(),
name: "norm".into(),
arguments: serde_json::json!({}),
},
];
let assistant = assistant_with_calls(&calls);
let mut config = build_config();
config.tool_execution = ToolExecutionMode::Parallel;
let (tx, _rx) = mpsc::channel::<LoopEvent>(128);
let signal = AbortSignal::new();
let batch = execute_tool_calls_from_msg(
&context,
&assistant,
&config,
&signal,
&tx,
&InflightSet::new(),
)
.await;
drop(tx);
assert!(
!batch.terminate,
"batch should NOT terminate when only some results have terminate=true"
);
assert_eq!(batch.messages.len(), 2);
}
/// In a parallel batch, an immediate not-found error and a delayed successful
/// call both produce results, and the result messages keep source order.
#[tokio::test]
async fn test_parallel_mixes_immediate_and_async() {
    let context = build_context(Arc::new(EchoTool::new("echo").with_delay_first_ms(20)));
    let calls = vec![
        ToolCall {
            id: "tool-1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        },
        ToolCall {
            id: "tool-2".into(),
            name: "echo".into(),
            arguments: serde_json::json!({"value": "first"}),
        },
    ];
    let assistant = assistant_with_calls(&calls);
    let mut config = build_config();
    config.tool_execution = ToolExecutionMode::Parallel;

    let signal = AbortSignal::new();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(128);
    let batch = execute_tool_calls_parallel(
        &context,
        &assistant,
        &calls,
        &config,
        &signal,
        &tx,
        &InflightSet::new(),
    )
    .await;
    drop(tx);

    assert_eq!(batch.messages.len(), 2);
    let (unknown, echoed) = (&batch.messages[0], &batch.messages[1]);
    assert!(unknown.is_error);
    assert!(!echoed.is_error);

    let mut result_ids: Vec<String> = Vec::new();
    while let Some(event) = rx.recv().await {
        match event {
            LoopEvent::MessageEnd {
                message: LoopMessage::ToolResult(result),
            } => result_ids.push(result.tool_call_id),
            _ => {}
        }
    }
    assert_eq!(
        result_ids,
        vec!["tool-1".to_string(), "tool-2".to_string()]
    );
}
/// Test tool whose execution sleeps for a fixed `delay` and never consults its
/// abort signal. Prompt cancellation must therefore come from the dispatcher.
#[derive(Debug)]
struct BlockingTool {
    /// How long each execution sleeps before returning `completed`.
    delay: core::time::Duration,
}
impl LoopTool for BlockingTool {
fn name(&self) -> &str {
"block"
}
fn description(&self) -> &str {
"Blocks for a fixed duration without polling signal."
}
fn label(&self) -> &str {
"Block"
}
fn parameters(&self) -> &Value {
static EMPTY: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
EMPTY.get_or_init(|| serde_json::json!({"type": "object"}))
}
fn execute<'a>(
&'a self,
_id: &'a str,
_args: Value,
_signal: AbortSignal, _on_update: LoopToolUpdate,
) -> std::pin::Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>> {
let delay = self.delay;
Box::pin(async move {
tokio::time::sleep(delay).await;
Ok(LoopToolResult {
content: vec![serde_json::json!({
"type": "text",
"text": "completed",
})],
details: Value::Null,
terminate: None,
})
})
}
}
#[tokio::test]
async fn aborted_tool_returns_aborted_error_promptly() {
let blocking = Arc::new(BlockingTool {
delay: std::time::Duration::from_secs(10),
});
let mut ctx = Context::default();
ctx.tools.push(blocking.clone());
let signal = AbortSignal::new();
signal.cancel();
let calls = vec![ToolCall {
id: "tc-1".to_string(),
name: "block".to_string(),
arguments: serde_json::json!({}),
}];
let assistant = AssistantMessage::new(
calls
.iter()
.map(|c| ContentBlock::ToolCall {
id: c.id.clone(),
name: c.name.clone(),
arguments: c.arguments.clone(),
})
.collect(),
StopReason::ToolUse,
);
let cfg = build_config();
let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
let started = std::time::Instant::now();
let batch = execute_tool_calls_sequential(
&ctx,
&assistant,
&calls,
&cfg,
&signal,
&tx,
&InflightSet::new(),
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed < std::time::Duration::from_secs(1),
"expected near-instant abort; elapsed {elapsed:?}"
);
assert_eq!(batch.messages.len(), 1);
let block = &batch.messages[0].content[0];
let text = match block {
ContentBlock::Text { text } => text.clone(),
other => panic!("expected Text block; got {other:?}"),
};
assert!(
text.contains("aborted"),
"expected aborted message; got: {text:?}"
);
assert!(batch.messages[0].is_error);
}
/// Cancelling the signal while a tool is mid-execution must drop the tool's
/// future rather than leave it running in the background.
///
/// The probe tool sets `started` once its future is first polled, then sleeps
/// for 10s. A guard owned by that future sets `dropped` when the future is
/// destroyed. `completed` is set only if the sleep runs to the end.
#[tokio::test]
async fn cancelled_tool_future_is_dropped_not_detached() {
    use std::sync::atomic::{AtomicBool, Ordering};
    // Sets the shared flag to `true` when dropped.
    struct DropFlag(Arc<AtomicBool>);
    impl Drop for DropFlag {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }
    // Tool that reports start, drop, and completion through shared flags.
    #[derive(Debug)]
    struct DropProbeTool {
        started: Arc<AtomicBool>,
        dropped: Arc<AtomicBool>,
        completed: Arc<AtomicBool>,
    }
    impl LoopTool for DropProbeTool {
        fn name(&self) -> &str {
            "probe"
        }
        fn description(&self) -> &str {
            "Signals start, then sleeps; flips a flag if dropped."
        }
        fn label(&self) -> &str {
            "Probe"
        }
        fn parameters(&self) -> &Value {
            static EMPTY: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
            EMPTY.get_or_init(|| serde_json::json!({"type": "object"}))
        }
        fn execute<'a>(
            &'a self,
            _id: &'a str,
            _args: Value,
            _signal: AbortSignal,
            _on_update: LoopToolUpdate,
        ) -> std::pin::Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>>
        {
            let started = self.started.clone();
            let dropped = self.dropped.clone();
            let completed = self.completed.clone();
            Box::pin(async move {
                // Owned by the future: dropping the future sets `dropped`.
                let _guard = DropFlag(dropped);
                started.store(true, Ordering::SeqCst);
                // Longer than the 5s timeout below, so reaching the next line
                // means the future was not cancelled.
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
                completed.store(true, Ordering::SeqCst);
                Ok(LoopToolResult {
                    content: vec![serde_json::json!({"type": "text", "text": "done"})],
                    details: Value::Null,
                    terminate: None,
                })
            })
        }
    }
    let started = Arc::new(AtomicBool::new(false));
    let dropped = Arc::new(AtomicBool::new(false));
    let completed = Arc::new(AtomicBool::new(false));
    let mut ctx = Context::default();
    ctx.tools.push(Arc::new(DropProbeTool {
        started: started.clone(),
        dropped: dropped.clone(),
        completed: completed.clone(),
    }));
    let signal = AbortSignal::new();
    let calls = vec![ToolCall {
        id: "tc-1".to_string(),
        name: "probe".to_string(),
        arguments: serde_json::json!({}),
    }];
    let assistant = AssistantMessage::new(
        calls
            .iter()
            .map(|c| ContentBlock::ToolCall {
                id: c.id.clone(),
                name: c.name.clone(),
                arguments: c.arguments.clone(),
            })
            .collect(),
        StopReason::ToolUse,
    );
    let cfg = build_config();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    // Yields until the tool reports it has started, then cancels the signal.
    let canceller = {
        let signal = signal.clone();
        let started = started.clone();
        async move {
            while !started.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
            signal.cancel();
        }
    };
    let inflight = InflightSet::new();
    let dispatch =
        execute_tool_calls_sequential(&ctx, &assistant, &calls, &cfg, &signal, &tx, &inflight);
    // Poll dispatch and canceller together on this task. The 5s timeout fails
    // the test if cancellation does not cut the tool's 10s sleep short.
    let (batch, _) = tokio::time::timeout(std::time::Duration::from_secs(5), async move {
        tokio::join!(dispatch, canceller)
    })
    .await
    .expect("cancellation must drop the future promptly, not wait out the 10s sleep");
    assert!(started.load(Ordering::SeqCst), "tool should have started");
    assert!(
        dropped.load(Ordering::SeqCst),
        "tool future must be dropped on cancel"
    );
    assert!(
        !completed.load(Ordering::SeqCst),
        "tool must NOT run to completion after cancel (no detached execution)"
    );
    assert!(
        batch.messages[0].is_error,
        "result should be the abort error"
    );
}
#[tokio::test]
async fn truncation_repair_end_to_end_through_dispatch() {
use crate::agent::agent_loop::tool_input_repair::RepairKind;
let echo = Arc::new(EchoTool::new("echo"));
let context = build_context(echo.clone());
let truncated = r#"{"value": "hello"#;
let assistant_msg = AssistantMessage::new(
vec![ContentBlock::ToolCall {
id: "tool-1".to_string(),
name: "echo".to_string(),
arguments: serde_json::Value::String(truncated.to_string()),
}],
StopReason::ToolUse,
);
let mut tool_calls = extract_tool_calls(&assistant_msg);
assert_eq!(tool_calls.len(), 1);
let (tx, mut rx) = mpsc::channel::<LoopEvent>(64);
let config = build_config();
let signal = AbortSignal::new();
crate::agent::agent_loop::run::apply_truncation_repair(
&mut tool_calls,
&config.repair_stats,
&config.truncation_notes,
);
let batch = execute_tool_calls_sequential(
&context,
&assistant_msg,
&tool_calls,
&config,
&signal,
&tx,
&InflightSet::new(),
)
.await;
drop(tx);
let recorded = echo.executed_args.lock().unwrap();
assert_eq!(
recorded.len(),
1,
"tool must have been invoked exactly once"
);
let received = &recorded[0];
assert!(
received.is_object(),
"args must reach execute as an Object, not Value::String; got: {received:?}",
);
assert_eq!(
received["value"], "hello",
"the closer must have closed the unterminated string preserving its content",
);
drop(recorded);
assert_eq!(batch.messages.len(), 1);
assert!(
!batch.messages[0].is_error,
"truncation-repaired call must dispatch as a normal success: {:?}",
batch.messages[0],
);
assert!(!batch.terminate);
let snap = config.repair_stats.snapshot();
assert_eq!(
snap.truncation_fixed, 1,
"RepairStats.truncation_fixed must increment by 1; got snapshot {:?}",
snap,
);
assert_eq!(snap.null_stripped, 0);
assert_eq!(snap.json_string_to_array, 0);
assert_eq!(snap.invalid, 0);
assert!(
RepairKind::ALL.contains(&RepairKind::TruncationFixed),
"TruncationFixed must appear in RepairKind::ALL for telemetry iteration",
);
let mut kinds = Vec::new();
while let Some(e) = rx.recv().await {
kinds.push(e.kind().to_string());
}
assert_eq!(
kinds,
vec![
"tool_execution_start",
"tool_execution_end",
"message_start",
"message_end",
],
"event sequence must match a non-truncated success path",
);
}
#[tokio::test]
async fn truncation_hard_fallback_does_not_fabricate_args() {
use crate::agent::agent_loop::tool::LoopTool;
use std::sync::OnceLock;
#[derive(Debug)]
struct StrictPathTool {
name: String,
executed: Arc<Mutex<Vec<Value>>>,
}
impl LoopTool for StrictPathTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"needs path"
}
fn label(&self) -> &str {
"strict"
}
fn parameters(&self) -> &Value {
static SCHEMA: OnceLock<Value> = OnceLock::new();
SCHEMA.get_or_init(|| {
serde_json::json!({
"type": "object",
"properties": { "path": { "type": "string" } },
"required": ["path"]
})
})
}
fn execute<'a>(
&'a self,
_id: &'a str,
args: Value,
_signal: AbortSignal,
_on_update: LoopToolUpdate,
) -> Pin<Box<dyn std::future::Future<Output = Result<LoopToolResult, String>> + Send + 'a>>
{
let executed = self.executed.clone();
Box::pin(async move {
executed.lock().unwrap().push(args);
Ok(LoopToolResult {
content: vec![serde_json::json!({"type": "text", "text": "ok"})],
details: serde_json::Value::Null,
terminate: None,
})
})
}
}
let tool = Arc::new(StrictPathTool {
name: "strict".to_string(),
executed: Arc::new(Mutex::new(Vec::new())),
});
let context = build_context(tool.clone());
let assistant_msg = AssistantMessage::new(
vec![ContentBlock::ToolCall {
id: "tool-1".to_string(),
name: "strict".to_string(),
arguments: serde_json::Value::String("}}}}}".to_string()),
}],
StopReason::ToolUse,
);
let tool_calls = extract_tool_calls(&assistant_msg);
let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
let config = build_config();
let signal = AbortSignal::new();
let batch = execute_tool_calls_sequential(
&context,
&assistant_msg,
&tool_calls,
&config,
&signal,
&tx,
&InflightSet::new(),
)
.await;
let executed = tool.executed.lock().unwrap();
assert!(
executed.is_empty(),
"strict tool must NOT receive fabricated args; got: {:?}",
*executed,
);
assert_eq!(batch.messages.len(), 1);
assert!(
batch.messages[0].is_error,
"hard-fallback must dispatch as an error so the model sees the failure",
);
let snap = config.repair_stats.snapshot();
assert_eq!(snap.truncation_fixed, 0);
assert_eq!(snap.invalid, 1);
}
/// Secrets in tool-output text are redacted when the result value is turned
/// into a content block.
#[test]
fn content_value_to_block_redacts_secrets() {
    const SECRET: &str = "sk-abcdefghijklmnopqrstuvwxyz0123456789";
    let raw = serde_json::json!({
        "type": "text",
        "text": format!("OPENAI_API_KEY={SECRET}"),
    });
    let text = match content_value_to_block(&raw) {
        ContentBlock::Text { text } => text,
        other => panic!("expected text block, got {other:?}"),
    };
    assert!(
        !text.contains(SECRET),
        "secret must be redacted at the result boundary; got {text}"
    );
    assert!(text.contains("[REDACTED]"), "got {text}");
}
/// Text without secrets passes through the conversion unchanged.
#[test]
fn content_value_to_block_passes_plain_text_through() {
    const PLAIN: &str = "build ok: 12 files";
    let raw = serde_json::json!({ "type": "text", "text": PLAIN });
    let text = match content_value_to_block(&raw) {
        ContentBlock::Text { text } => text,
        other => panic!("expected text block, got {other:?}"),
    };
    assert_eq!(text, PLAIN);
}
fn tc(id: &str, name: &str) -> ToolCall {
ToolCall {
id: id.to_string(),
name: name.to_string(),
arguments: serde_json::Value::Null,
}
}
fn trm(id: &str, name: &str) -> ToolResultMessage {
ToolResultMessage {
tool_call_id: id.to_string(),
tool_name: name.to_string(),
content: vec![],
details: serde_json::Value::Null,
is_error: false,
}
}
/// Only the call without a matching result gets a backfilled error result,
/// carrying that call's id and tool name.
#[test]
fn backfill_fills_only_the_unanswered_id() {
    let calls = [tc("a", "edit"), tc("b", "read"), tc("c", "bash")];
    let answered = [trm("a", "edit"), trm("c", "bash")];
    let backfilled = backfill_missing_tool_results(&calls, &answered);
    assert_eq!(backfilled.len(), 1, "exactly one orphan");
    let orphan = &backfilled[0];
    assert_eq!(orphan.tool_call_id, "b");
    assert_eq!(orphan.tool_name, "read");
    assert!(orphan.is_error, "backfill must be an error result");
}
/// No backfill is produced when every call already has a result.
#[test]
fn backfill_empty_when_every_call_is_answered() {
    let calls = [tc("a", "x"), tc("b", "y")];
    let answered = [trm("a", "x"), trm("b", "y")];
    let backfilled = backfill_missing_tool_results(&calls, &answered);
    assert!(backfilled.is_empty());
}
/// With no results at all, every call gets a backfilled error result.
#[test]
fn backfill_fills_all_when_none_answered() {
    let calls = [tc("a", "x"), tc("b", "y")];
    let backfilled = backfill_missing_tool_results(&calls, &[]);
    assert_eq!(backfilled.len(), 2);
    let ids: std::collections::HashSet<&str> = backfilled
        .iter()
        .map(|result| result.tool_call_id.as_str())
        .collect();
    assert!(["a", "b"].iter().all(|id| ids.contains(id)));
    assert!(backfilled.iter().all(|result| result.is_error));
}