use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde_json::{Value, json};
use super::*;
use crate::ToolCall;
use crate::chat::{ChatMessage, ChatRole, ContentBlock, StopReason};
use crate::provider::{ChatParams, JsonSchema, ToolDefinition};
use crate::test_helpers::{mock_for, sample_response, sample_tool_response};
use crate::usage::Usage;
/// Compile-time guard: the listed tool-loop types must all be `Send + Sync`.
#[test]
fn key_types_are_send_and_sync() {
    // Instantiating this helper only type-checks when `T` satisfies both bounds,
    // so the test "runs" entirely at compile time.
    fn require_send_sync<T: Send + Sync>() {}

    require_send_sync::<ToolRegistry<()>>();
    require_send_sync::<OwnedToolLoopHandle<()>>();
    require_send_sync::<LoopEvent>();
    require_send_sync::<ToolLoopResult>();
    require_send_sync::<ToolLoopConfig>();
    require_send_sync::<ToolOutput>();
    require_send_sync::<ToolError>();
    require_send_sync::<LoopDetectionConfig>();
}
/// Builds the schema used by [`AddTool`]: an object with two required
/// numeric properties, `a` and `b`.
fn number_schema() -> JsonSchema {
    let raw = json!({
        "type": "object",
        "properties": {
            "a": {"type": "number"},
            "b": {"type": "number"}
        },
        "required": ["a", "b"]
    });
    JsonSchema::new(raw)
}
/// Test tool named `add` that sums the numeric arguments `a` and `b`.
struct AddTool;

impl ToolHandler<()> for AddTool {
    /// Describes the tool as `add`, using the two-number schema.
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "add".into(),
            description: "Add two numbers".into(),
            parameters: number_schema(),
            retry: None,
        }
    }

    /// Returns the sum of `input["a"]` and `input["b"]` rendered as text.
    ///
    /// A field that is missing or not a number is treated as `0.0`.
    fn execute<'a>(
        &'a self,
        input: Value,
        _ctx: &'a (),
    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>> {
        Box::pin(async move {
            let operand = |key: &str| input[key].as_f64().unwrap_or(0.0);
            let sum = operand("a") + operand("b");
            Ok(ToolOutput::new(sum.to_string()))
        })
    }
}
/// Test tool named `fail` whose execution always returns an error.
struct FailTool;

impl ToolHandler<()> for FailTool {
    /// Describes the tool as `fail`, accepting any JSON object.
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "fail".into(),
            description: "Always fails".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        }
    }

    /// Ignores its arguments and resolves to a `ToolError` with the message
    /// `"intentional failure"`.
    fn execute<'a>(
        &'a self,
        _input: Value,
        _ctx: &'a (),
    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>> {
        Box::pin(async move {
            let failure = ToolError::new("intentional failure");
            Err(failure)
        })
    }
}
/// `ToolHandler<()>` must be usable behind a `dyn` reference.
#[test]
fn test_tool_handler_is_object_safe() {
    // Only compiles if the trait can be turned into a trait object.
    fn takes_trait_object(_: &dyn ToolHandler<()>) {}

    takes_trait_object(&AddTool);
}
/// The `Display` output of a `ToolError` is exactly its message.
#[test]
fn test_tool_error_display() {
    let error = ToolError::new("something broke");
    let rendered = error.to_string();
    assert_eq!(rendered, "something broke");
}
/// `AddTool::definition` reports the expected name and description.
#[test]
fn test_tool_handler_definition() {
    let definition = AddTool.definition();

    assert_eq!(definition.name, "add");
    assert_eq!(definition.description, "Add two numbers");
}
/// Calling `AddTool` directly with `2` and `3` produces the text `"5"`.
#[tokio::test]
async fn test_tool_handler_execute() {
    let arguments = json!({"a": 2, "b": 3});
    let output = AddTool.execute(arguments, &()).await.unwrap();
    assert_eq!(output.content, "5");
}
/// Calling `FailTool` directly yields an `Err` carrying the fixed message.
#[tokio::test]
async fn test_tool_handler_error() {
    let outcome = FailTool.execute(json!({}), &()).await;
    assert!(outcome.is_err());

    let error = outcome.unwrap_err();
    assert_eq!(error.message, "intentional failure");
}
/// A handler built with `tool_fn` exposes the definition it was given.
#[test]
fn test_fn_tool_handler() {
    let definition = ToolDefinition {
        name: "greet".into(),
        description: "Say hello".into(),
        parameters: JsonSchema::new(json!({"type": "object"})),
        retry: None,
    };

    let greeter = tool_fn(definition, |_input: Value| async {
        Ok("Hello!".to_string())
    });

    assert_eq!(greeter.definition().name, "greet");
}
/// A freshly created registry holds no tools and no definitions.
#[test]
fn test_registry_empty() {
    let empty = ToolRegistry::<()>::new();

    assert!(empty.is_empty());
    assert_eq!(empty.len(), 0);
    assert!(empty.definitions().is_empty());
}
/// After registering `AddTool`, it can be found by name and nothing else can.
#[test]
fn test_registry_register_and_get() {
    let mut tools = ToolRegistry::<()>::new();
    tools.register(AddTool);

    assert_eq!(tools.len(), 1);
    assert!(tools.contains("add"));
    assert!(!tools.contains("subtract"));
    assert!(tools.get("add").is_some());
}
/// `definitions()` lists one entry per registered tool.
#[test]
fn test_registry_definitions() {
    let mut tools = ToolRegistry::<()>::new();
    tools.register(AddTool);
    tools.register(FailTool);

    let listed = tools.definitions();
    assert_eq!(listed.len(), 2);

    // The order of the returned definitions is not asserted, only membership.
    let is_listed = |wanted: &str| listed.iter().any(|def| def.name.as_str() == wanted);
    assert!(is_listed("add"));
    assert!(is_listed("fail"));
}
/// Registering the same tool twice still leaves a single entry.
#[test]
fn test_registry_overwrite_duplicate() {
    let mut tools = ToolRegistry::<()>::new();
    for _ in 0..2 {
        tools.register(AddTool);
    }
    assert_eq!(tools.len(), 1);
}
#[tokio::test]
async fn test_registry_execute_valid() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let call = ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 10, "b": 20}),
};
let result = registry.execute(&call, &()).await;
assert!(!result.is_error);
assert_eq!(result.content, "30");
assert_eq!(result.tool_call_id, "call_1");
}
#[tokio::test]
async fn test_registry_execute_unknown_tool() {
let registry: ToolRegistry<()> = ToolRegistry::new();
let call = ToolCall {
id: "call_1".into(),
name: "nonexistent".into(),
arguments: json!({}),
};
let result = registry.execute(&call, &()).await;
assert!(result.is_error);
assert!(result.content.contains("Unknown tool"));
}
#[tokio::test]
async fn test_registry_execute_schema_validation_failure() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let call = ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": "not a number", "b": 5}),
};
let result = registry.execute(&call, &()).await;
assert!(result.is_error);
assert!(result.content.contains("Invalid arguments"));
}
#[tokio::test]
async fn test_registry_execute_missing_required_field() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let call = ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 5}),
};
let result = registry.execute(&call, &()).await;
assert!(result.is_error);
assert!(result.content.contains("Invalid arguments"));
}
#[tokio::test]
async fn test_registry_execute_handler_error() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(FailTool);
let call = ToolCall {
id: "call_1".into(),
name: "fail".into(),
arguments: json!({}),
};
let result = registry.execute(&call, &()).await;
assert!(result.is_error);
assert_eq!(result.content, "intentional failure");
}
#[tokio::test]
async fn test_registry_execute_all_sequential() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let calls = vec![
ToolCall {
id: "c1".into(),
name: "add".into(),
arguments: json!({"a": 1, "b": 2}),
},
ToolCall {
id: "c2".into(),
name: "add".into(),
arguments: json!({"a": 3, "b": 4}),
},
];
let results = registry.execute_all(&calls, &(), false).await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].content, "3");
assert_eq!(results[1].content, "7");
}
#[tokio::test]
async fn test_registry_execute_all_parallel() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let calls = vec![
ToolCall {
id: "c1".into(),
name: "add".into(),
arguments: json!({"a": 10, "b": 20}),
},
ToolCall {
id: "c2".into(),
name: "add".into(),
arguments: json!({"a": 30, "b": 40}),
},
];
let results = registry.execute_all(&calls, &(), true).await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].content, "30");
assert_eq!(results[1].content, "70");
}
#[tokio::test]
async fn test_registry_execute_all_with_failure() {
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
registry.register(FailTool);
let calls = vec![
ToolCall {
id: "c1".into(),
name: "add".into(),
arguments: json!({"a": 1, "b": 2}),
},
ToolCall {
id: "c2".into(),
name: "fail".into(),
arguments: json!({}),
},
];
let results = registry.execute_all(&calls, &(), true).await;
assert!(!results[0].is_error);
assert_eq!(results[0].content, "3");
assert!(results[1].is_error);
assert_eq!(results[1].content, "intentional failure");
}
/// `ToolRegistry` (with its default type parameter) must be `Send + Sync`.
#[test]
fn test_registry_is_send_sync() {
    // Only compiles when `T` satisfies both bounds.
    fn require_send_sync<T: Send + Sync>() {}

    require_send_sync::<ToolRegistry>();
}
/// The `Debug` output names the type and includes registered tool names.
#[test]
fn test_registry_debug() {
    let mut tools = ToolRegistry::<()>::new();
    tools.register(AddTool);

    let rendered = format!("{:?}", tools);

    assert!(rendered.contains("ToolRegistry"));
    assert!(rendered.contains("add"));
}
/// A plain-text reply with no tool calls ends the loop after one iteration.
#[tokio::test]
async fn test_tool_loop_no_tool_calls() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Hello!"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 1);
    assert_eq!(result.response.text(), Some("Hello!"));
}
/// One tool-call reply followed by a text reply takes two iterations.
#[tokio::test]
async fn test_tool_loop_one_iteration() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("The answer is 5"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 2 + 3?")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 2);
    assert_eq!(result.response.text(), Some("The answer is 5"));
    // Usage summed over both provider calls; presumably 100 input tokens per
    // sample response — confirm against `test_helpers`.
    assert_eq!(result.total_usage.input_tokens, 200);
}
/// Two consecutive tool-call replies followed by text take three iterations.
#[tokio::test]
async fn test_tool_loop_multiple_iterations() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 3, "b": 4}),
    }]));
    mock.queue_response(sample_response("Done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Chain calls")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 3);
    assert_eq!(result.response.text(), Some("Done"));
}
/// When the model keeps requesting tools, the loop stops at `max_iterations`
/// and reports `TerminationReason::MaxIterations` with that limit.
#[tokio::test]
async fn test_tool_loop_max_iterations_exceeded() {
    let mock = mock_for("test", "test-model");
    // Queue more tool-call replies (5) than the configured limit (3).
    for _ in 0..5 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        }]));
    }

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Loop")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 3,
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert!(matches!(
        result.termination_reason,
        TerminationReason::MaxIterations { limit: 3 }
    ));
}
/// When the approval callback denies a call, the follow-up request still
/// contains at least one tool-role message.
#[tokio::test]
async fn test_tool_loop_approval_deny() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("OK denied"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Denied")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        on_tool_call: Some(Arc::new(|_call| ToolApproval::Deny("not allowed".into()))),
        ..Default::default()
    };

    let _result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    // Index 1 is the second provider request, i.e. the one sent after the
    // tool call was handled.
    let recorded = mock.recorded_calls();
    let follow_up = &recorded[1];
    let tool_msgs: Vec<_> = follow_up
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    assert!(!tool_msgs.is_empty());
}
/// When the approval callback rewrites the arguments, the tool runs with the
/// modified values (100 + 200) rather than the model's (2 + 3).
#[tokio::test]
async fn test_tool_loop_approval_modify() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("Modified"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Modify")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        on_tool_call: Some(Arc::new(|_call| {
            ToolApproval::Modify(json!({"a": 100, "b": 200}))
        })),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 2);

    // Inspect the tool-role messages of the second provider request.
    let recorded = mock.recorded_calls();
    let tool_msgs: Vec<_> = recorded[1]
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    let tool_content = &tool_msgs[0].content;
    let has_300 = tool_content
        .iter()
        .any(|b| matches!(b, ContentBlock::ToolResult(r) if r.content == "300"));
    assert!(has_300, "Expected tool result with '300'");
}
/// Two tool calls in one reply produce two tool-role messages in the
/// follow-up request.
#[tokio::test]
async fn test_tool_loop_parallel_execution() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![
        ToolCall {
            id: "c1".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        },
        ToolCall {
            id: "c2".into(),
            name: "add".into(),
            arguments: json!({"a": 3, "b": 4}),
        },
    ]));
    mock.queue_response(sample_response("Both done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Parallel")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 2);

    // Inspect the tool-role messages of the second provider request.
    let recorded = mock.recorded_calls();
    let tool_msgs: Vec<_> = recorded[1]
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    assert_eq!(tool_msgs.len(), 2);
}
/// A call to an unregistered tool does not abort the loop: it proceeds to a
/// second iteration and finishes normally.
#[tokio::test]
async fn test_tool_loop_unknown_tool_sends_error_result() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "unknown_tool".into(),
        arguments: json!({}),
    }]));
    mock.queue_response(sample_response("Handled unknown"));

    // Deliberately empty registry so `unknown_tool` cannot be resolved.
    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Unknown")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 2);
}
/// Token usage reported in `total_usage` is summed across both provider calls.
#[tokio::test]
async fn test_tool_loop_usage_accumulation() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("Done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Usage test")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    // Presumably each sample response reports 100 input / 50 output tokens —
    // confirm against `test_helpers`.
    assert_eq!(result.total_usage.input_tokens, 200);
    assert_eq!(result.total_usage.output_tokens, 100);
}
#[tokio::test]
async fn test_tool_loop_stream_no_tools() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_stream(vec![
crate::stream::StreamEvent::TextDelta("Hello".into()),
crate::stream::StreamEvent::Done {
stop_reason: StopReason::EndTurn,
},
]);
let registry = Arc::new(ToolRegistry::new());
let params = ChatParams {
messages: vec![ChatMessage::user("Hi")],
..Default::default()
};
let stream = tool_loop_stream(
mock,
registry,
params,
ToolLoopConfig::default(),
Arc::new(()),
);
let events: Vec<_> = stream
.collect::<Vec<_>>()
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::TextDelta(t) if t == "Hello"))
);
assert!(events.iter().any(|e| matches!(
e,
LoopEvent::Done(result) if result.termination_reason == TerminationReason::Complete
)));
}
#[tokio::test]
async fn test_tool_loop_stream_one_iteration() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallStart {
index: 0,
id: "call_1".into(),
name: "add".into(),
},
crate::stream::StreamEvent::ToolCallDelta {
index: 0,
json_chunk: r#"{"a":2,"b":3}"#.into(),
},
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 2, "b": 3}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
mock.queue_stream(vec![
crate::stream::StreamEvent::TextDelta("The answer is 5".into()),
crate::stream::StreamEvent::Done {
stop_reason: StopReason::EndTurn,
},
]);
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("What is 2+3?")],
..Default::default()
};
let stream = tool_loop_stream(
mock,
registry,
params,
ToolLoopConfig::default(),
Arc::new(()),
);
let events: Vec<_> = stream
.collect::<Vec<_>>()
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::ToolCallStart { .. }))
);
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::TextDelta(t) if t == "The answer is 5"))
);
assert!(events.iter().any(|e| matches!(
e,
LoopEvent::Done(result) if result.termination_reason == TerminationReason::Complete
)));
}
#[tokio::test]
async fn test_tool_loop_stream_max_iterations() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
for _ in 0..5 {
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "c".into(),
name: "add".into(),
arguments: json!({"a": 1, "b": 2}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
}
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("Loop")],
..Default::default()
};
let config = ToolLoopConfig {
max_iterations: 2,
..Default::default()
};
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let results: Vec<_> = stream.collect::<Vec<_>>().await;
let events: Vec<LoopEvent> = results.into_iter().filter_map(Result::ok).collect();
assert!(events.iter().any(|e| matches!(
e,
LoopEvent::Done(result) if matches!(result.termination_reason, TerminationReason::MaxIterations { limit: 2 })
)));
}
/// `Usage += Usage` sums the required counters and combines the optional
/// ones: `Some + Some` adds, `None + Some` keeps the value, `None + None`
/// stays `None`.
#[test]
fn test_usage_add_assign() {
    let mut running = Usage {
        input_tokens: 100,
        output_tokens: 50,
        reasoning_tokens: Some(10),
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let increment = Usage {
        input_tokens: 200,
        output_tokens: 75,
        reasoning_tokens: Some(20),
        cache_read_tokens: Some(5),
        cache_write_tokens: None,
    };

    running += increment;

    assert_eq!(running.input_tokens, 300);
    assert_eq!(running.output_tokens, 125);
    assert_eq!(running.reasoning_tokens, Some(30));
    assert_eq!(running.cache_read_tokens, Some(5));
    assert!(running.cache_write_tokens.is_none());
}
/// Adding two default `Usage` values leaves `reasoning_tokens` as `None`.
#[test]
fn test_usage_add_assign_both_none() {
    let mut running = Usage::default();
    let nothing = Usage::default();

    running += nothing;

    assert!(running.reasoning_tokens.is_none());
}
/// `ChatMessage::tool_result_full` wraps a `ToolResult` in a tool-role
/// message whose first content block is that result.
#[test]
fn test_tool_result_full_message() {
    use crate::chat::ToolResult;

    let payload = ToolResult {
        tool_call_id: "c1".into(),
        content: "42".into(),
        is_error: false,
    };

    let message = ChatMessage::tool_result_full(payload);

    assert_eq!(message.role, ChatRole::Tool);
    let first_block = &message.content[0];
    assert!(matches!(first_block, ContentBlock::ToolResult(r) if r.content == "42"));
}
/// The `Debug` output of `LoopEvent::IterationStart` names the variant and
/// its `iteration` field.
#[test]
fn test_loop_event_debug() {
    let started = LoopEvent::IterationStart {
        iteration: 1,
        message_count: 5,
    };

    let rendered = format!("{:?}", started);

    assert!(rendered.contains("IterationStart"));
    assert!(rendered.contains("iteration: 1"));
}
/// Cloning a `LoopEvent` preserves its variant and payload.
#[test]
fn test_loop_event_clone() {
    let original = LoopEvent::ToolExecutionStart {
        call_id: "c1".into(),
        tool_name: "add".into(),
        arguments: json!({"a": 1}),
    };

    let copy = original.clone();

    let same_tool =
        matches!(copy, LoopEvent::ToolExecutionStart { tool_name, .. } if tool_name == "add");
    assert!(same_tool);
}
/// The streaming loop emits `IterationStart { iteration: 1, .. }`.
#[tokio::test]
async fn test_tool_loop_emits_iteration_start() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_response("Hello!"));

    let tools = Arc::new(ToolRegistry::<()>::new());
    let request = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let settings = ToolLoopConfig::default();

    let event_stream = tool_loop_stream(provider, tools, request, settings, Arc::new(()));
    let collected: Vec<Result<LoopEvent, _>> = event_stream.collect().await;
    // Error items are discarded here; only successful events are inspected.
    let events: Vec<LoopEvent> = collected.into_iter().filter_map(Result::ok).collect();

    let saw_first_iteration = events
        .iter()
        .any(|e| matches!(e, LoopEvent::IterationStart { iteration: 1, .. }));
    assert!(saw_first_iteration);
}
/// Running one `add` call emits matching `ToolExecutionStart` /
/// `ToolExecutionEnd` events and two `IterationStart` events.
#[tokio::test]
async fn test_tool_loop_emits_tool_execution_events() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    provider.queue_response(sample_response("The answer is 5"));

    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(AddTool);
    let tools = Arc::new(tools);

    let request = ChatParams {
        messages: vec![ChatMessage::user("What is 2 + 3?")],
        ..Default::default()
    };
    let settings = ToolLoopConfig::default();

    let event_stream = tool_loop_stream(provider, tools, request, settings, Arc::new(()));
    let collected: Vec<Result<LoopEvent, _>> = event_stream.collect().await;
    // Error items are discarded here; only successful events are inspected.
    let events: Vec<LoopEvent> = collected.into_iter().filter_map(Result::ok).collect();

    let saw_start = events.iter().any(
        |e| matches!(e, LoopEvent::ToolExecutionStart { tool_name, .. } if tool_name == "add"),
    );
    assert!(saw_start);

    let saw_end = events.iter().any(|e| {
        matches!(
            e,
            LoopEvent::ToolExecutionEnd { tool_name, result, .. }
                if tool_name == "add" && result.content == "5"
        )
    });
    assert!(saw_end);

    let iteration_count = events
        .iter()
        .filter(|e| matches!(e, LoopEvent::IterationStart { .. }))
        .count();
    assert_eq!(iteration_count, 2);
}
#[tokio::test]
async fn test_tool_loop_event_duration_is_positive() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_response(sample_tool_response(vec![ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 1, "b": 2}),
}]));
mock.queue_response(sample_response("Done"));
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("Add")],
..Default::default()
};
let config = ToolLoopConfig::default();
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let results: Vec<Result<LoopEvent, _>> = stream.collect().await;
let events: Vec<LoopEvent> = results.into_iter().filter_map(Result::ok).collect();
let end_event = events
.iter()
.find(|e| matches!(e, LoopEvent::ToolExecutionEnd { .. }));
assert!(end_event.is_some());
if let Some(LoopEvent::ToolExecutionEnd { duration, .. }) = end_event {
assert!(*duration >= std::time::Duration::ZERO);
}
}
/// With `parallel_tool_execution` enabled, two tool calls yield two start
/// events and two end events.
#[tokio::test]
async fn test_tool_loop_events_with_parallel_execution() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![
        ToolCall {
            id: "c1".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        },
        ToolCall {
            id: "c2".into(),
            name: "add".into(),
            arguments: json!({"a": 3, "b": 4}),
        },
    ]));
    provider.queue_response(sample_response("Both done"));

    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(AddTool);
    let tools = Arc::new(tools);

    let request = ChatParams {
        messages: vec![ChatMessage::user("Parallel")],
        ..Default::default()
    };
    let settings = ToolLoopConfig {
        parallel_tool_execution: true,
        ..Default::default()
    };

    let event_stream = tool_loop_stream(provider, tools, request, settings, Arc::new(()));
    let collected: Vec<Result<LoopEvent, _>> = event_stream.collect().await;
    // Error items are discarded here; only successful events are inspected.
    let events: Vec<LoopEvent> = collected.into_iter().filter_map(Result::ok).collect();

    let start_count = events
        .iter()
        .filter(|e| matches!(e, LoopEvent::ToolExecutionStart { .. }))
        .count();
    let end_count = events
        .iter()
        .filter(|e| matches!(e, LoopEvent::ToolExecutionEnd { .. }))
        .count();

    assert_eq!(start_count, 2);
    assert_eq!(end_count, 2);
}
#[tokio::test]
async fn test_tool_loop_stream_emits_events() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallStart {
index: 0,
id: "call_1".into(),
name: "add".into(),
},
crate::stream::StreamEvent::ToolCallDelta {
index: 0,
json_chunk: r#"{"a":2,"b":3}"#.into(),
},
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 2, "b": 3}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
mock.queue_stream(vec![
crate::stream::StreamEvent::TextDelta("The answer is 5".into()),
crate::stream::StreamEvent::Done {
stop_reason: StopReason::EndTurn,
},
]);
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("What is 2+3?")],
..Default::default()
};
let config = ToolLoopConfig::default();
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let results: Vec<Result<LoopEvent, _>> = stream.collect().await;
let events: Vec<LoopEvent> = results.into_iter().filter_map(Result::ok).collect();
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::IterationStart { .. }))
);
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::ToolExecutionStart { .. }))
);
assert!(
events
.iter()
.any(|e| matches!(e, LoopEvent::ToolExecutionEnd { .. }))
);
}
/// The `Debug` output of `ToolLoopConfig` names the type.
#[test]
fn test_tool_loop_config_debug() {
    let settings = ToolLoopConfig::default();
    let rendered = format!("{:?}", settings);
    assert!(rendered.contains("ToolLoopConfig"));
}
/// `StopDecision` compares equal variant-by-variant (including the reason
/// string) and unequal across variants.
#[test]
fn test_stop_decision_equality() {
    // Fresh value per call so both sides of the comparison are independent.
    let with_reason = || StopDecision::StopWithReason("done".into());

    assert_eq!(StopDecision::Continue, StopDecision::Continue);
    assert_eq!(StopDecision::Stop, StopDecision::Stop);
    assert_eq!(with_reason(), with_reason());
    assert_ne!(StopDecision::Continue, StopDecision::Stop);
}
/// The `Debug` output of `StopWithReason` includes the variant name and
/// the reason text.
#[test]
fn test_stop_decision_debug() {
    let verdict = StopDecision::StopWithReason("token limit".into());

    let rendered = format!("{:?}", verdict);

    assert!(rendered.contains("StopWithReason"));
    assert!(rendered.contains("token limit"));
}
/// The `Debug` output of `StopContext` shows the `iteration` and
/// `tool_calls_executed` fields.
#[test]
fn test_stop_context_debug() {
    use crate::chat::ChatResponse;

    let reply = ChatResponse {
        content: vec![ContentBlock::Text("test".into())],
        usage: Usage::default(),
        stop_reason: StopReason::EndTurn,
        model: "test".into(),
        metadata: std::collections::HashMap::new(),
    };
    let totals = Usage::default();
    let no_results: Vec<crate::chat::ToolResult> = vec![];

    let context = StopContext {
        iteration: 1,
        response: &reply,
        total_usage: &totals,
        tool_calls_executed: 5,
        last_tool_results: &no_results,
    };

    let rendered = format!("{:?}", context);

    assert!(rendered.contains("iteration: 1"));
    assert!(rendered.contains("tool_calls_executed: 5"));
}
/// With a `stop_when` callback set, `Debug` reports `has_stop_when: true`.
#[test]
fn test_tool_loop_config_debug_with_stop_when() {
    let settings = ToolLoopConfig {
        stop_when: Some(Arc::new(|_| StopDecision::Continue)),
        ..Default::default()
    };

    let rendered = format!("{:?}", settings);

    assert!(rendered.contains("has_stop_when: true"));
}
/// Without a `stop_when` callback, `Debug` reports `has_stop_when: false`.
#[test]
fn test_tool_loop_config_debug_without_stop_when() {
    let settings = ToolLoopConfig::default();

    let rendered = format!("{:?}", settings);

    assert!(rendered.contains("has_stop_when: false"));
}
/// A `stop_when` callback that always returns `Stop` ends the loop after the
/// first response, which still carries its tool calls.
#[tokio::test]
async fn test_tool_loop_stop_on_first_response() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 2 + 3?")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|_| StopDecision::Stop)),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 1);
    assert!(!result.response.tool_calls().is_empty());
}
/// A `stop_when` callback that stops once at least one tool call has been
/// executed ends the loop on the second iteration.
#[tokio::test]
async fn test_tool_loop_stop_after_tool_call_limit() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 3, "b": 4}),
    }]));
    // Queued but expected to stay unused: the loop should stop before it.
    mock.queue_response(sample_response("Final"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Chain calls")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|ctx| {
            if ctx.tool_calls_executed >= 1 {
                StopDecision::StopWithReason("Tool call limit".into())
            } else {
                StopDecision::Continue
            }
        })),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 2);
}
/// The `StopContext` passed to `stop_when` reflects loop progress: no tool
/// results on the first check, and the previous iteration's results on the
/// second.
#[tokio::test]
async fn test_tool_loop_stop_context_has_last_results() {
    use std::sync::Mutex;

    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("Done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    // Each callback invocation records (iteration, tool_calls_executed,
    // contents of last_tool_results).
    let captured = Arc::new(Mutex::new(Vec::new()));
    let captured_clone = Arc::clone(&captured);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Test")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(move |ctx| {
            captured_clone.lock().unwrap().push((
                ctx.iteration,
                ctx.tool_calls_executed,
                ctx.last_tool_results
                    .iter()
                    .map(|r| r.content.clone())
                    .collect::<Vec<_>>(),
            ));
            StopDecision::Continue
        })),
        ..Default::default()
    };

    let _result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    let checks = captured.lock().unwrap();

    // First check: nothing has been executed yet.
    assert_eq!(checks[0].0, 1);
    assert_eq!(checks[0].1, 0);
    assert!(checks[0].2.is_empty());

    // Second check: the single `add` call has run and produced "5".
    assert_eq!(checks[1].0, 2);
    assert_eq!(checks[1].1, 1);
    assert_eq!(checks[1].2, vec!["5".to_string()]);
}
/// A `stop_when` callback can end the loop as soon as the model calls a
/// designated tool (`final_answer`); the returned response is the one
/// containing that call.
#[tokio::test]
async fn test_tool_loop_stop_on_specific_tool() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "final_answer".into(),
        arguments: json!({"answer": "The result is 3"}),
    }]));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    registry.register(tool_fn(
        ToolDefinition {
            name: "final_answer".into(),
            description: "Provide final answer".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        },
        |_| async { Ok(ToolOutput::new("acknowledged")) },
    ));

    let params = ChatParams {
        messages: vec![ChatMessage::user("Calculate")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|ctx| {
            let has_final = ctx
                .response
                .tool_calls()
                .iter()
                .any(|c| c.name == "final_answer");
            if has_final {
                StopDecision::StopWithReason("Final answer provided".into())
            } else {
                StopDecision::Continue
            }
        })),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 2);
    assert!(
        result
            .response
            .tool_calls()
            .iter()
            .any(|c| c.name == "final_answer")
    );
}
#[tokio::test]
async fn test_tool_loop_stream_stop_early() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 2, "b": 3}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("What is 2+3?")],
..Default::default()
};
let config = ToolLoopConfig {
stop_when: Some(Arc::new(|_| StopDecision::Stop)),
..Default::default()
};
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let events: Vec<_> = stream.collect::<Vec<_>>().await;
assert!(events.iter().all(Result::is_ok));
assert!(events.iter().any(|r| matches!(r, Ok(LoopEvent::Done(_)))));
}
#[tokio::test]
async fn test_tool_loop_stream_stop_after_tool_execution() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "call_1".into(),
name: "add".into(),
arguments: json!({"a": 2, "b": 3}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
mock.queue_stream(vec![
crate::stream::StreamEvent::TextDelta("Final response".into()),
crate::stream::StreamEvent::Done {
stop_reason: StopReason::EndTurn,
},
]);
mock.queue_stream(vec![
crate::stream::StreamEvent::TextDelta("This should never appear".into()),
crate::stream::StreamEvent::Done {
stop_reason: StopReason::EndTurn,
},
]);
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(AddTool);
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("What is 2+3?")],
..Default::default()
};
let config = ToolLoopConfig {
stop_when: Some(Arc::new(|ctx| {
if ctx.tool_calls_executed >= 1 {
StopDecision::Stop
} else {
StopDecision::Continue
}
})),
..Default::default()
};
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let events: Vec<_> = stream.filter_map(|r| async { r.ok() }).collect().await;
let done_count = events
.iter()
.filter(|e| matches!(e, LoopEvent::Done(_)))
.count();
assert_eq!(done_count, 1);
let has_final_response = events
.iter()
.any(|e| matches!(e, LoopEvent::TextDelta(t) if t.contains("Final response")));
assert!(has_final_response);
let has_third_iteration = events
.iter()
.any(|e| matches!(e, LoopEvent::TextDelta(t) if t.contains("never appear")));
assert!(!has_third_iteration);
}
/// A `stop_when` callback that always returns `Continue` must not cut the
/// loop short: the `add` call is followed by the final text response, for
/// two iterations in total.
#[tokio::test]
async fn test_tool_loop_stop_continues_when_condition_not_met() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("The answer is 3"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Add 1 and 2")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|_| StopDecision::Continue)),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert_eq!(result.iterations, 2);
    assert_eq!(result.response.text(), Some("The answer is 3"));
}
/// `LoopDetectionConfig::default()` uses a threshold of 3 and the `Warn`
/// action.
#[test]
fn test_loop_detection_config_default() {
    let LoopDetectionConfig { threshold, action } = LoopDetectionConfig::default();
    assert_eq!(threshold, 3);
    assert_eq!(action, LoopAction::Warn);
}
/// Each `LoopAction` variant equals itself, and `Warn` differs from `Stop`.
#[test]
fn test_loop_action_equality() {
    let identical_pairs = [
        (LoopAction::Warn, LoopAction::Warn),
        (LoopAction::Stop, LoopAction::Stop),
        (LoopAction::InjectWarning, LoopAction::InjectWarning),
    ];
    for (lhs, rhs) in identical_pairs {
        assert_eq!(lhs, rhs);
    }
    assert_ne!(LoopAction::Warn, LoopAction::Stop);
}
/// Two consecutive `add` calls with different arguments must not be
/// reported as a loop (threshold 3).
#[test]
fn test_loop_detection_state_no_loop() {
    use super::loop_detection::LoopDetectionState;

    let first_batch = vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }];
    // Same tool, different arguments.
    let second_batch = vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 3, "b": 4}),
    }];

    let mut tracker = LoopDetectionState::default();
    let after_first = tracker.update(&first_batch, 3);
    assert!(after_first.is_none());
    let after_second = tracker.update(&second_batch, 3);
    assert!(after_second.is_none());
}
/// The third identical `search` call in a row (threshold 3) is reported as
/// a loop, with the tool name and a count of 3.
#[test]
fn test_loop_detection_state_detects_loop() {
    use super::loop_detection::LoopDetectionState;

    let repeated = vec![ToolCall {
        id: "c1".into(),
        name: "search".into(),
        arguments: json!({"query": "test"}),
    }];
    let mut tracker = LoopDetectionState::default();

    // The first two occurrences stay below the threshold.
    for _ in 0..2 {
        assert!(tracker.update(&repeated, 3).is_none());
    }

    match tracker.update(&repeated, 3) {
        Some((name, count)) => {
            assert_eq!(name, "search");
            assert_eq!(count, 3);
        }
        None => panic!("assertion failed: result.is_some()"),
    }
}
/// After `reset()`, a call that had already been seen twice is no longer
/// reported as a loop on its next occurrence.
#[test]
fn test_loop_detection_state_reset() {
    use super::loop_detection::LoopDetectionState;

    let repeated = vec![ToolCall {
        id: "c1".into(),
        name: "search".into(),
        arguments: json!({"query": "test"}),
    }];
    let mut tracker = LoopDetectionState::default();

    // Two identical updates, results deliberately ignored.
    for _ in 0..2 {
        tracker.update(&repeated, 3);
    }
    tracker.reset();

    let after_reset = tracker.update(&repeated, 3);
    assert!(after_reset.is_none());
}
/// The `Debug` output of a `ToolLoopConfig` with loop detection enabled
/// contains `loop_detection: Some`.
#[test]
fn test_tool_loop_config_debug_with_loop_detection() {
    let rendered = format!(
        "{:?}",
        ToolLoopConfig {
            loop_detection: Some(LoopDetectionConfig::default()),
            ..Default::default()
        }
    );
    assert!(rendered.contains("loop_detection: Some"));
}
/// With `LoopAction::Warn` and threshold 3, three identical `search` calls
/// must produce exactly one `LoopEvent::LoopDetected` carrying the tool
/// name, a consecutive count of 3, and the `Warn` action.
#[tokio::test]
async fn test_tool_loop_detects_loop_warn() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    for _ in 0..3 {
        provider.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "search".into(),
            arguments: json!({"query": "foo"}),
        }]));
    }
    provider.queue_response(sample_response("Done"));

    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(tool_fn(
        ToolDefinition {
            name: "search".into(),
            description: "Search".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        },
        |_| async { Ok(ToolOutput::new("result")) },
    ));
    let tools = Arc::new(tools);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Search")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        loop_detection: Some(LoopDetectionConfig {
            threshold: 3,
            action: LoopAction::Warn,
        }),
        ..Default::default()
    };

    // Keep only successful events; stream errors are discarded.
    let events: Vec<LoopEvent> = tool_loop_stream(provider, tools, params, config, Arc::new(()))
        .filter_map(|item| async { item.ok() })
        .collect()
        .await;

    // Pull the payload out of every loop-detection event.
    let detections: Vec<_> = events
        .iter()
        .filter_map(|event| match event {
            LoopEvent::LoopDetected {
                tool_name,
                consecutive_count,
                action,
            } => Some((tool_name, consecutive_count, action)),
            _ => None,
        })
        .collect();

    assert_eq!(detections.len(), 1);
    let (tool_name, consecutive_count, action) = detections[0];
    assert_eq!(tool_name, "search");
    assert_eq!(*consecutive_count, 3);
    assert_eq!(*action, LoopAction::Warn);
}
/// With `LoopAction::Stop` and threshold 3, repeated identical `search`
/// calls end the loop with `TerminationReason::LoopDetected` for `search`
/// and a count of 3, even though five responses are queued.
#[tokio::test]
async fn test_tool_loop_detects_loop_stop() {
    let mock = mock_for("test", "test-model");
    for _ in 0..5 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "search".into(),
            arguments: json!({"query": "foo"}),
        }]));
    }

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(tool_fn(
        ToolDefinition {
            name: "search".into(),
            description: "Search".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        },
        |_| async { Ok(ToolOutput::new("result")) },
    ));

    let params = ChatParams {
        messages: vec![ChatMessage::user("Search")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        loop_detection: Some(LoopDetectionConfig {
            threshold: 3,
            action: LoopAction::Stop,
        }),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert!(matches!(
        result.termination_reason,
        TerminationReason::LoopDetected { ref tool_name, count }
            if tool_name == "search" && count == 3
    ));
}
/// With `LoopAction::InjectWarning` and threshold 3, the loop keeps going
/// after three identical `search` calls, and the fourth request sent to the
/// provider contains a system message whose text includes
/// "identical arguments".
#[tokio::test]
async fn test_tool_loop_detects_loop_inject_warning() {
    let mock = mock_for("test", "test-model");
    for _ in 0..3 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "search".into(),
            arguments: json!({"query": "foo"}),
        }]));
    }
    mock.queue_response(sample_response("I'll try something different"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(tool_fn(
        ToolDefinition {
            name: "search".into(),
            description: "Search".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        },
        |_| async { Ok(ToolOutput::new("result")) },
    ));

    let params = ChatParams {
        messages: vec![ChatMessage::user("Search")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        loop_detection: Some(LoopDetectionConfig {
            threshold: 3,
            action: LoopAction::InjectWarning,
        }),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 4);

    // Inspect the fourth provider request (index 3) for the injected warning.
    let recorded = mock.recorded_calls();
    let last_call = &recorded[3];
    let has_warning = last_call.messages.iter().any(|m| {
        m.role == ChatRole::System
            && m.content
                .iter()
                .any(|b| matches!(b, ContentBlock::Text(t) if t.contains("identical arguments")))
    });
    assert!(has_warning, "Warning message should be in conversation");
}
/// Three `search` calls with different queries must not trip loop detection
/// (threshold 3, `Stop` action): the loop finishes normally after four
/// iterations.
#[tokio::test]
async fn test_tool_loop_no_false_positive_different_args() {
    let mock = mock_for("test", "test-model");
    // Same tool every time, but a different `query` argument per call.
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "search".into(),
        arguments: json!({"query": "foo"}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "search".into(),
        arguments: json!({"query": "bar"}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c3".into(),
        name: "search".into(),
        arguments: json!({"query": "baz"}),
    }]));
    mock.queue_response(sample_response("Done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(tool_fn(
        ToolDefinition {
            name: "search".into(),
            description: "Search".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: None,
        },
        |_| async { Ok(ToolOutput::new("result")) },
    ));

    let params = ChatParams {
        messages: vec![ChatMessage::user("Search")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        loop_detection: Some(LoopDetectionConfig {
            threshold: 3,
            action: LoopAction::Stop,
        }),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .expect("loop with varying arguments should succeed");
    assert_eq!(result.iterations, 4);
}
#[tokio::test]
async fn test_tool_loop_stream_detects_loop_stop() {
use futures::StreamExt;
let mock = Arc::new(mock_for("test", "test-model"));
for _ in 0..5 {
mock.queue_stream(vec![
crate::stream::StreamEvent::ToolCallComplete {
index: 0,
call: ToolCall {
id: "c".into(),
name: "search".into(),
arguments: json!({"query": "foo"}),
},
},
crate::stream::StreamEvent::Done {
stop_reason: StopReason::ToolUse,
},
]);
}
let mut registry: ToolRegistry<()> = ToolRegistry::new();
registry.register(tool_fn(
ToolDefinition {
name: "search".into(),
description: "Search".into(),
parameters: JsonSchema::new(json!({"type": "object"})),
retry: None,
},
|_| async { Ok(ToolOutput::new("result")) },
));
let registry = Arc::new(registry);
let params = ChatParams {
messages: vec![ChatMessage::user("Search")],
..Default::default()
};
let config = ToolLoopConfig {
loop_detection: Some(LoopDetectionConfig {
threshold: 3,
action: LoopAction::Stop,
}),
..Default::default()
};
let stream = tool_loop_stream(mock, registry, params, config, Arc::new(()));
let results: Vec<_> = stream.collect::<Vec<_>>().await;
let events: Vec<LoopEvent> = results.into_iter().filter_map(Result::ok).collect();
assert!(events.iter().any(|e| matches!(
e,
LoopEvent::Done(result) if matches!(result.termination_reason, TerminationReason::LoopDetected { .. })
)));
}
/// For a single call, the first signature component equals the tool name
/// and the second contains the argument key.
#[test]
fn test_compute_tool_calls_signature_single() {
    use super::loop_detection::compute_tool_calls_signature;

    let single = vec![ToolCall {
        id: "c1".into(),
        name: "search".into(),
        arguments: json!({"query": "test"}),
    }];
    let signature = compute_tool_calls_signature(&single);

    assert_eq!(signature.0, "search");
    let mentions_query = signature.1.contains("query");
    assert!(mentions_query);
}
/// For two calls, the tool names are joined with `+` and the argument part
/// contains a `|` character.
#[test]
fn test_compute_tool_calls_signature_multiple() {
    use super::loop_detection::compute_tool_calls_signature;

    let search = ToolCall {
        id: "c1".into(),
        name: "search".into(),
        arguments: json!({"query": "a"}),
    };
    let read = ToolCall {
        id: "c2".into(),
        name: "read".into(),
        arguments: json!({"file": "b"}),
    };
    let batch = vec![search, read];

    let signature = compute_tool_calls_signature(&batch);

    assert_eq!(signature.0, "search+read");
    let has_separator = signature.1.contains('|');
    assert!(has_separator);
}
/// An empty call list yields a signature with both components empty.
#[test]
fn test_compute_tool_calls_signature_empty() {
    use super::loop_detection::compute_tool_calls_signature;

    let no_calls: Vec<ToolCall> = Vec::new();
    let signature = compute_tool_calls_signature(&no_calls);

    let names_empty = signature.0.is_empty();
    let args_empty = signature.1.is_empty();
    assert!(names_empty);
    assert!(args_empty);
}
/// `TerminationReason::Complete` compares equal to itself.
#[test]
fn test_termination_reason_complete() {
    let expected = TerminationReason::Complete;
    let actual = TerminationReason::Complete;
    assert_eq!(actual, expected);
}
/// A `StopCondition` built without a reason matches the `reason: None`
/// pattern.
#[test]
fn test_termination_reason_stop_condition_without_reason() {
    let stop = TerminationReason::StopCondition { reason: None };
    let is_plain_stop = match stop {
        TerminationReason::StopCondition { reason } => reason.is_none(),
        _ => false,
    };
    assert!(is_plain_stop);
}
/// A `StopCondition` built with a reason exposes that exact reason text.
#[test]
fn test_termination_reason_stop_condition_with_reason() {
    let stop = TerminationReason::StopCondition {
        reason: Some("budget exceeded".into()),
    };
    let carries_reason = match stop {
        TerminationReason::StopCondition { reason: Some(text) } => text == "budget exceeded",
        _ => false,
    };
    assert!(carries_reason);
}
/// A `MaxIterations` value keeps the limit it was constructed with.
#[test]
fn test_termination_reason_max_iterations() {
    let capped = TerminationReason::MaxIterations { limit: 10 };
    let limit_is_ten = match capped {
        TerminationReason::MaxIterations { limit } => limit == 10,
        _ => false,
    };
    assert!(limit_is_ten);
}
/// A `LoopDetected` value keeps the tool name and count it was constructed
/// with.
#[test]
fn test_termination_reason_loop_detected() {
    let looped = TerminationReason::LoopDetected {
        tool_name: "search".into(),
        count: 5,
    };
    let fields_preserved = match looped {
        TerminationReason::LoopDetected { tool_name, count } => {
            tool_name == "search" && count == 5
        }
        _ => false,
    };
    assert!(fields_preserved);
}
/// Cloning a `TerminationReason` yields a value equal to the original.
#[test]
fn test_termination_reason_clone() {
    let original = TerminationReason::StopCondition {
        reason: Some("test".into()),
    };
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
/// The `Debug` rendering of `TerminationReason::Complete` contains the
/// variant name.
#[test]
fn test_termination_reason_debug() {
    let rendered = format!("{:?}", TerminationReason::Complete);
    assert!(rendered.contains("Complete"));
}
/// A plain text response with an empty registry and default config ends the
/// loop with `TerminationReason::Complete`.
#[tokio::test]
async fn test_tool_loop_returns_termination_reason_complete() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Done!"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hello")],
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, ToolLoopConfig::default(), &())
        .await
        .unwrap();

    assert_eq!(result.termination_reason, TerminationReason::Complete);
}
/// When `stop_when` returns `StopWithReason("manual stop")`, the result's
/// termination reason is `StopCondition` carrying that same text.
#[tokio::test]
async fn test_tool_loop_returns_termination_reason_stop_condition() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("First response"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hello")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|_ctx| {
            StopDecision::StopWithReason("manual stop".into())
        })),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    // `ref` keeps `result` intact, matching the other guards in this file.
    assert!(matches!(
        result.termination_reason,
        TerminationReason::StopCondition { reason: Some(ref r) } if r == "manual stop"
    ));
}
/// A zero-duration timeout ends the loop with `TerminationReason::Timeout`
/// whose `limit` equals `Duration::ZERO`, even though tool-call responses
/// are queued.
#[tokio::test]
async fn test_tool_loop_timeout_returns_immediately() {
    let mock = mock_for("test", "test-model");
    for _ in 0..5 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        }]));
    }

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Timeout test")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        timeout: Some(Duration::ZERO),
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert!(matches!(
        result.termination_reason,
        TerminationReason::Timeout { limit } if limit == Duration::ZERO
    ));
}
/// With `timeout: None`, a single text response completes the loop with
/// `TerminationReason::Complete`.
#[tokio::test]
async fn test_tool_loop_no_timeout_completes_normally() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Done quickly"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("No timeout")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        timeout: None,
        ..Default::default()
    };

    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();

    assert!(matches!(
        result.termination_reason,
        TerminationReason::Complete
    ));
}
/// Streaming loop with a zero-duration timeout: the stream must contain a
/// `Done` event whose termination reason is `Timeout`.
///
/// NOTE(review): this test queues with `queue_response` while most other
/// streaming tests use `queue_stream`; presumably the queued items are never
/// consumed because the timeout fires first. Confirm.
#[tokio::test]
async fn test_tool_loop_stream_timeout() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    (0..10).for_each(|_| {
        provider.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        }]));
    });

    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(AddTool);
    let tools = Arc::new(tools);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Stream timeout")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        timeout: Some(Duration::ZERO),
        ..Default::default()
    };

    // Stream errors are discarded; only successful events are checked.
    let events: Vec<LoopEvent> = tool_loop_stream(provider, tools, params, config, Arc::new(()))
        .filter_map(|item| async { item.ok() })
        .collect()
        .await;

    let timed_out = events.iter().any(|event| match event {
        LoopEvent::Done(result) => {
            matches!(result.termination_reason, TerminationReason::Timeout { .. })
        }
        _ => false,
    });
    assert!(
        timed_out,
        "Expected Done event with Timeout termination reason"
    );
}
/// Test tool that fails a configurable number of times before succeeding.
struct FlakeyTool {
    /// Number of `execute` calls made so far.
    fail_count: std::sync::atomic::AtomicU32,
    /// How many of the first calls should fail.
    max_failures: u32,
}

impl FlakeyTool {
    /// Creates a tool whose first `max_failures` executions return an error.
    fn new(max_failures: u32) -> Self {
        let fail_count = std::sync::atomic::AtomicU32::new(0);
        Self {
            fail_count,
            max_failures,
        }
    }
}
impl ToolHandler<()> for FlakeyTool {
    /// Describes the `flakey` tool with a fast retry policy: up to 3 retries,
    /// 1 ms initial and 10 ms maximum backoff, no jitter, no retry predicate.
    fn definition(&self) -> ToolDefinition {
        let retry_policy = crate::provider::ToolRetryConfig {
            max_retries: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            jitter: 0.0,
            retry_if: None,
        };
        ToolDefinition {
            name: "flakey".into(),
            description: "A tool that fails sometimes".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: Some(retry_policy),
        }
    }

    /// Counts the call, then fails while the number of prior calls is below
    /// `max_failures` and succeeds afterwards. Input and context are ignored.
    fn execute<'a>(
        &'a self,
        _input: Value,
        _ctx: &'a (),
    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>> {
        use std::sync::atomic::Ordering;

        Box::pin(async move {
            // `fetch_add` returns the value before the increment.
            let prior_calls = self.fail_count.fetch_add(1, Ordering::SeqCst);
            if prior_calls >= self.max_failures {
                Ok(ToolOutput::new("success after retries"))
            } else {
                Err(ToolError::new("transient failure"))
            }
        })
    }
}
/// A `FlakeyTool` that fails twice succeeds within its retry budget:
/// the registry result is not an error and carries the success text.
#[tokio::test]
async fn test_tool_retry_succeeds_after_failures() {
    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(FlakeyTool::new(2));

    let request = ToolCall {
        id: "c1".into(),
        name: "flakey".into(),
        arguments: json!({}),
    };
    let outcome = tools.execute(&request, &()).await;

    assert!(!outcome.is_error);
    assert_eq!(outcome.content, "success after retries");
}
/// A `FlakeyTool` that fails ten times outlasts its retry budget: the
/// registry result is an error mentioning "transient failure".
#[tokio::test]
async fn test_tool_retry_exhausted() {
    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(FlakeyTool::new(10));

    let request = ToolCall {
        id: "c1".into(),
        name: "flakey".into(),
        arguments: json!({}),
    };
    let outcome = tools.execute(&request, &()).await;

    assert!(outcome.is_error);
    let mentions_failure = outcome.content.contains("transient failure");
    assert!(mentions_failure);
}
/// Test tool whose retry policy only retries errors containing "TRANSIENT".
struct SelectiveRetryTool {
    /// Number of `execute` calls made so far.
    call_count: std::sync::atomic::AtomicU32,
}

impl SelectiveRetryTool {
    /// Creates the tool with its call counter at zero.
    fn new() -> Self {
        let call_count = std::sync::atomic::AtomicU32::new(0);
        Self { call_count }
    }
}
impl ToolHandler<()> for SelectiveRetryTool {
    /// Describes the `selective` tool. Its retry policy carries a `retry_if`
    /// predicate that accepts only messages containing "TRANSIENT".
    fn definition(&self) -> ToolDefinition {
        let retry_policy = crate::provider::ToolRetryConfig {
            max_retries: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            jitter: 0.0,
            retry_if: Some(Arc::new(|msg: &str| msg.contains("TRANSIENT"))),
        };
        ToolDefinition {
            name: "selective".into(),
            description: "A tool with selective retry".into(),
            parameters: JsonSchema::new(json!({"type": "object"})),
            retry: Some(retry_policy),
        }
    }

    /// Fails with "permanent failure" on the first call and succeeds on any
    /// later call. Input and context are ignored.
    fn execute<'a>(
        &'a self,
        _input: Value,
        _ctx: &'a (),
    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>> {
        use std::sync::atomic::Ordering;

        Box::pin(async move {
            // `fetch_add` returns the value before the increment.
            match self.call_count.fetch_add(1, Ordering::SeqCst) {
                0 => Err(ToolError::new("permanent failure")),
                _ => Ok(ToolOutput::new("should not reach")),
            }
        })
    }
}
/// The `retry_if` predicate rejects "permanent failure", so the registry
/// must return the first error. A retry would reach the tool's success
/// branch and make `is_error` false.
#[tokio::test]
async fn test_tool_retry_predicate_prevents_retry() {
    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(SelectiveRetryTool::new());

    let request = ToolCall {
        id: "c1".into(),
        name: "selective".into(),
        arguments: json!({}),
    };
    let outcome = tools.execute(&request, &()).await;

    assert!(
        outcome.is_error,
        "Expected is_error=true but got false. Content: {}",
        outcome.content
    );
    let mentions_permanent = outcome.content.contains("permanent failure");
    assert!(mentions_permanent);
}
/// Pins the defaults of `ToolRetryConfig`: 3 retries, 100 ms initial and
/// 5 s maximum backoff, multiplier 2.0, jitter 0.5, no retry predicate.
#[test]
fn test_tool_retry_config_default() {
    use crate::provider::ToolRetryConfig;

    let defaults = ToolRetryConfig::default();

    assert_eq!(defaults.max_retries, 3);
    assert_eq!(defaults.initial_backoff, Duration::from_millis(100));
    assert_eq!(defaults.max_backoff, Duration::from_secs(5));

    // Floats are compared with an epsilon tolerance.
    let multiplier_delta = (defaults.backoff_multiplier - 2.0).abs();
    let jitter_delta = (defaults.jitter - 0.5).abs();
    assert!(multiplier_delta < f64::EPSILON);
    assert!(jitter_delta < f64::EPSILON);

    assert!(defaults.retry_if.is_none());
}
/// Two default `ToolRetryConfig` values are equal; changing `max_retries`
/// makes them unequal.
#[test]
fn test_tool_retry_config_partial_eq() {
    use crate::provider::ToolRetryConfig;

    let baseline = ToolRetryConfig::default();
    let same_defaults = ToolRetryConfig::default();
    assert_eq!(baseline, same_defaults);

    let more_retries = ToolRetryConfig {
        max_retries: 5,
        ..Default::default()
    };
    assert_ne!(baseline, more_retries);
}
/// The `Debug` output of `ToolRetryConfig` names the type and its
/// `max_retries` field.
#[test]
fn test_tool_retry_config_debug() {
    let rendered = format!("{:?}", crate::provider::ToolRetryConfig::default());
    for expected in ["ToolRetryConfig", "max_retries"] {
        assert!(rendered.contains(expected));
    }
}
use super::loop_resumable::{ToolLoopHandle, TurnResult};
use crate::test_helpers::sample_tool_response_with_text;
/// A text-only response makes the first `next_turn()` return `Completed`
/// with `TerminationReason::Complete` and the response text, after which the
/// handle reports itself finished.
#[tokio::test]
async fn test_resumable_no_tools_completes() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Hello!"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.termination_reason, TerminationReason::Complete);
            assert_eq!(done.response.text(), Some("Hello!"));
        }
        _ => panic!("expected Completed"),
    }
    assert!(handle.is_finished());
}
/// One tool round-trip through the resumable handle.
///
/// The first turn yields with the `add` call and its result ("5"); after
/// `continue_loop()` the second turn completes with the final text and an
/// iteration count of 2.
#[tokio::test]
async fn test_resumable_one_tool_iteration() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("The answer is 5"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 2 + 3?")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 1);
            assert_eq!(turn.tool_calls.len(), 1);
            assert_eq!(turn.tool_calls[0].name, "add");
            assert_eq!(turn.results.len(), 1);
            assert_eq!(turn.results[0].content, "5");
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.iterations, 2);
            assert_eq!(done.termination_reason, TerminationReason::Complete);
            assert_eq!(done.response.text(), Some("The answer is 5"));
        }
        _ => panic!("expected Completed"),
    }
}
/// Messages passed to `inject_and_continue` must appear in the next request
/// sent to the provider.
#[tokio::test]
async fn test_resumable_inject_messages() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("Done with injection"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Add numbers")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            turn.inject_and_continue(vec![ChatMessage::system("Additional context from worker")]);
        }
        _ => panic!("expected Yielded"),
    }
    assert!(matches!(handle.next_turn().await, TurnResult::Completed(_)));

    // The second recorded request (index 1) should carry the injected text.
    let recorded = mock.recorded_calls();
    let last_call = &recorded[1];
    let has_injection = last_call.messages.iter().any(|m| {
        m.content.iter().any(|b| {
            matches!(b, ContentBlock::Text(t) if t.contains("Additional context from worker"))
        })
    });
    assert!(has_injection, "Injected message should be in conversation");
}
/// Calling `stop(Some(reason))` on a yielded turn makes the next turn
/// complete with `StopCondition` carrying that reason, and no second
/// provider request is recorded.
#[tokio::test]
async fn test_resumable_stop_command() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("Should not appear"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Stop early")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            turn.stop(Some("task_spawn detected".into()));
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert!(matches!(
                done.termination_reason,
                TerminationReason::StopCondition {
                    reason: Some(ref r)
                } if r == "task_spawn detected"
            ));
        }
        _ => panic!("expected Completed"),
    }

    // Only the first request reached the provider.
    assert_eq!(mock.recorded_calls().len(), 1);
}
/// Token usage accumulates across turns of the resumable handle.
///
/// The expected totals (100, 200, 300 input tokens; 150 output tokens at
/// completion) imply each canned response reports 100 input / 50 output
/// tokens; those values come from the test helpers, not visible here.
/// `into_result()` must report the same input total.
#[tokio::test]
async fn test_resumable_usage_accumulation() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 3, "b": 4}),
    }]));
    mock.queue_response(sample_response("All done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Chain")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.total_usage.input_tokens, 100);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.total_usage.input_tokens, 200);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.total_usage.input_tokens, 300);
            assert_eq!(done.total_usage.output_tokens, 150);
        }
        _ => panic!("expected Completed"),
    }

    let result = handle.into_result();
    assert_eq!(result.total_usage.input_tokens, 300);
}
/// With `max_iterations: 2`, the handle yields twice (iterations 1 and 2)
/// and then completes with `TerminationReason::MaxIterations { limit: 2 }`.
#[tokio::test]
async fn test_resumable_max_iterations() {
    let mock = mock_for("test", "test-model");
    for _ in 0..5 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        }]));
    }

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Loop")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 2,
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, config, &());

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 1);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 2);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert!(matches!(
                done.termination_reason,
                TerminationReason::MaxIterations { limit: 2 }
            ));
        }
        _ => panic!("expected Completed"),
    }
}
/// A context already at depth 1 with `max_depth: Some(1)` makes the first
/// `next_turn()` return an error of `MaxDepthExceeded { current: 1, limit: 1 }`
/// and leaves the handle finished.
#[tokio::test]
async fn test_resumable_depth_exceeded() {
    /// Minimal context that only tracks a nesting depth.
    #[derive(Clone)]
    struct DepthCtx(u32);

    impl super::LoopDepth for DepthCtx {
        fn loop_depth(&self) -> u32 {
            self.0
        }

        fn with_depth(&self, depth: u32) -> Self {
            DepthCtx(depth)
        }
    }

    let mock = mock_for("test", "test-model");
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_depth: Some(1),
        ..Default::default()
    };
    let registry_typed: ToolRegistry<DepthCtx> = ToolRegistry::new();

    // The context depth equals the configured limit.
    let ctx = DepthCtx(1);
    let mut handle = ToolLoopHandle::new(&mock, &registry_typed, params, config, &ctx);

    match handle.next_turn().await {
        TurnResult::Error(err) => {
            assert!(matches!(
                err.error,
                crate::LlmError::MaxDepthExceeded {
                    current: 1,
                    limit: 1
                }
            ));
        }
        _ => panic!("expected Error"),
    }
    assert!(handle.is_finished());
}
/// `messages()` exposes the conversation as it grows: 1 message before the
/// first turn, 3 on the yielded turn and on the handle afterwards, and at
/// least 3 after the final turn.
#[tokio::test]
async fn test_resumable_messages_access() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("Done"));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    assert_eq!(handle.messages().len(), 1);

    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.messages().len(), 3);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    assert_eq!(handle.messages().len(), 3);

    // The outcome of the last turn is not inspected here.
    let _ = handle.next_turn().await;
    assert!(handle.messages().len() >= 3);
}
/// After a single completed turn, `into_result()` reports
/// `TerminationReason::Complete`, one iteration and 100 input tokens.
#[tokio::test]
async fn test_resumable_into_result() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Quick"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    // Drive the loop once; the turn value itself is not needed.
    let _ = handle.next_turn().await;

    let result = handle.into_result();
    assert_eq!(result.termination_reason, TerminationReason::Complete);
    assert_eq!(result.iterations, 1);
    assert_eq!(result.total_usage.input_tokens, 100);
}
/// Calling `next_turn()` again after completion keeps returning
/// `Completed`, even though only one response was queued.
#[tokio::test]
async fn test_resumable_repeated_next_after_completion() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Done"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    assert!(matches!(handle.next_turn().await, TurnResult::Completed(_)));
    assert!(matches!(handle.next_turn().await, TurnResult::Completed(_)));
}
/// A `stop_when` callback returning `Stop` makes the first turn complete
/// (rather than yield) with `StopCondition { reason: None }`.
#[tokio::test]
async fn test_resumable_stop_condition_callback() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));

    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Test stop")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        stop_when: Some(Arc::new(|_| StopDecision::Stop)),
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, config, &());

    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert!(matches!(
                done.termination_reason,
                TerminationReason::StopCondition { reason: None }
            ));
        }
        _ => panic!("expected Completed"),
    }
}
/// The `Debug` output of a freshly created handle names the type and
/// mentions `iterations` and `finished`.
#[tokio::test]
async fn test_resumable_debug_impl() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_response("Debug"));

    let registry: ToolRegistry<()> = ToolRegistry::new();
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());

    let debug = format!("{handle:?}");
    assert!(debug.contains("ToolLoopHandle"));
    assert!(debug.contains("iterations"));
    assert!(debug.contains("finished"));
}
/// When the model sends text alongside a tool call, the yielded turn must
/// expose that text through `assistant_content` and `assistant_text()`.
#[tokio::test]
async fn test_resumable_assistant_content_exposed() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response_with_text(
        "I'll help with that",
        vec![ToolCall {
            id: "c1".into(),
            name: "add".into(),
            arguments: json!({"a": 5, "b": 10}),
        }],
    ));
    mock.queue_response(sample_response("The answer is 15"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 5 + 10?")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            // Exactly one content block is expected: the assistant's text.
            assert_eq!(turn.assistant_content.len(), 1);
            assert!(
                matches!(&turn.assistant_content[0], ContentBlock::Text(t) if t == "I'll help with that")
            );
            assert_eq!(turn.assistant_text(), Some("I'll help with that".into()));
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.response.text(), Some("The answer is 15"));
        }
        _ => panic!("expected Completed"),
    }
}
/// When the tool-call response carries no text, the yielded turn must have
/// empty `assistant_content` and `assistant_text()` must return `None`.
#[tokio::test]
async fn test_resumable_assistant_content_empty_when_no_text() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert!(turn.assistant_content.is_empty());
            assert!(turn.assistant_text().is_none());
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    // The queued "Done" response should end the loop; check that instead of
    // silently discarding the second turn's outcome.
    assert!(
        matches!(handle.next_turn().await, TurnResult::Completed(_)),
        "expected Completed after continuing"
    );
}
/// Drives three iterations using a mix of `continue_loop` and
/// `inject_and_continue`, then checks that the injected user message is part
/// of the third request sent to the provider.
#[tokio::test]
async fn test_resumable_multi_iteration_with_mixed_commands() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 10, "b": 20}),
    }]));
    mock.queue_response(sample_response("All done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Mix commands")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 1);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 2);
            turn.inject_and_continue(vec![ChatMessage::user("Worker completed task X")]);
        }
        _ => panic!("expected Yielded"),
    }
    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.iterations, 3);
        }
        _ => panic!("expected Completed"),
    }
    let recorded = mock.recorded_calls();
    // The message injected after iteration 2 must show up in the third request.
    let last_call = recorded
        .get(2)
        .expect("provider should have recorded three calls");
    let has_worker_msg = last_call.messages.iter().any(|m| {
        m.content
            .iter()
            .any(|b| matches!(b, ContentBlock::Text(t) if t.contains("Worker completed task X")))
    });
    assert!(has_worker_msg);
}
/// With `timeout` set to `Duration::ZERO`, the first turn must complete with
/// `TerminationReason::Timeout`.
#[tokio::test]
async fn test_resumable_timeout() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Timeout")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        // A zero budget means the deadline has effectively already passed.
        timeout: Some(Duration::ZERO),
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, config, &());
    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert!(matches!(
                done.termination_reason,
                TerminationReason::Timeout { .. }
            ));
        }
        _ => panic!("expected Completed with timeout"),
    }
}
/// Runs a two-response tool loop through `tool_loop_stream` and checks that
/// iteration-start and tool-execution start/end events are all emitted.
#[tokio::test]
async fn test_resumable_events_via_stream() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    let add_call = ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    };
    provider.queue_response(sample_tool_response(vec![add_call]));
    provider.queue_response(sample_response("Done"));

    let tools = {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Events")],
        ..Default::default()
    };
    let stream = tool_loop_stream(
        provider,
        tools,
        params,
        ToolLoopConfig::default(),
        Arc::new(()),
    );

    // Only successfully produced events are inspected; `Err` items are dropped.
    let events: Vec<LoopEvent> = stream
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .filter_map(Result::ok)
        .collect();

    assert!(
        events
            .iter()
            .any(|ev| matches!(ev, LoopEvent::IterationStart { .. }))
    );
    assert!(
        events
            .iter()
            .any(|ev| matches!(ev, LoopEvent::ToolExecutionStart { .. }))
    );
    assert!(
        events
            .iter()
            .any(|ev| matches!(ev, LoopEvent::ToolExecutionEnd { .. }))
    );
}
/// Events produced during a turn must be delivered on that turn's result
/// (`turn.events` / `done.events`), leaving `drain_events()` empty afterwards.
#[tokio::test]
async fn test_resumable_drain_events() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    }]));
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Drain test")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert!(
                turn.events
                    .iter()
                    .any(|e| matches!(e, LoopEvent::IterationStart { .. })),
                "should have IterationStart"
            );
            assert!(
                turn.events
                    .iter()
                    .any(|e| matches!(e, LoopEvent::ToolExecutionStart { .. })),
                "should have ToolExecutionStart"
            );
            assert!(
                turn.events
                    .iter()
                    .any(|e| matches!(e, LoopEvent::ToolExecutionEnd { .. })),
                "should have ToolExecutionEnd"
            );
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    // The yielded turn already carried its events, so nothing is left to drain.
    assert!(handle.drain_events().is_empty());
    match handle.next_turn().await {
        TurnResult::Completed(done) => {
            assert_eq!(done.response.text(), Some("Done"));
            assert!(
                done.events
                    .iter()
                    .any(|e| matches!(e, LoopEvent::IterationStart { .. })),
                "completion turn should have IterationStart"
            );
        }
        _ => panic!("expected Completed"),
    }
    // Same expectation after the completing turn.
    assert!(handle.drain_events().is_empty());
}
use super::loop_owned::{OwnedToolLoopHandle, OwnedTurnResult};
/// Compile-time check: `OwnedToolLoopHandle<()>` satisfies `Send + 'static`.
#[test]
fn test_owned_handle_is_send_static() {
    // Instantiating this generic only type-checks when the bound holds.
    fn require_send_static<T>()
    where
        T: Send + 'static,
    {
    }

    require_send_static::<OwnedToolLoopHandle<()>>();
}
/// An owned handle with an empty registry and a plain text response must
/// complete on the first turn and then report itself as finished.
#[tokio::test]
async fn test_owned_no_tools_completes() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_response("Hello!"));

    let empty_registry = Arc::new(ToolRegistry::<()>::new());
    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let mut handle = OwnedToolLoopHandle::new(
        provider,
        empty_registry,
        params,
        ToolLoopConfig::default(),
        &(),
    );

    {
        let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
            panic!("expected Completed");
        };
        assert_eq!(done.termination_reason, TerminationReason::Complete);
        assert_eq!(done.response.text(), Some("Hello!"));
    }

    assert!(handle.is_finished());
}
/// Drives an owned handle through one `add` tool call followed by the final
/// answer, checking the yielded turn's contents and the completion summary.
#[tokio::test]
async fn test_owned_one_tool_iteration() {
    let provider = Arc::new(mock_for("test", "test-model"));
    let add_call = ToolCall {
        id: "call_1".into(),
        name: "add".into(),
        arguments: json!({"a": 2, "b": 3}),
    };
    provider.queue_response(sample_tool_response(vec![add_call]));
    provider.queue_response(sample_response("The answer is 5"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 2 + 3?")],
        ..Default::default()
    };
    let mut handle =
        OwnedToolLoopHandle::new(provider, tools, params, ToolLoopConfig::default(), &());

    // Turn 1: the tool call is executed and its result is handed back.
    {
        let OwnedTurnResult::Yielded(turn) = handle.next_turn().await else {
            panic!("expected Yielded");
        };
        assert_eq!(turn.iteration, 1);
        assert_eq!(turn.tool_calls.len(), 1);
        assert_eq!(turn.tool_calls[0].name, "add");
        assert_eq!(turn.results.len(), 1);
        assert_eq!(turn.results[0].content, "5");
        turn.continue_loop();
    }

    // Turn 2: the plain text response completes the loop.
    {
        let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
            panic!("expected Completed");
        };
        assert_eq!(done.iterations, 2);
        assert_eq!(done.response.text(), Some("The answer is 5"));
    }
}
/// A message passed to `inject_and_continue` on an owned handle must be part
/// of the second request recorded by the provider.
#[tokio::test]
async fn test_owned_inject_messages() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    provider.queue_response(sample_response("Done with injection"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Add")],
        ..Default::default()
    };
    // Keep our own reference to the provider so its recorded calls can be read later.
    let mut handle = OwnedToolLoopHandle::new(
        provider.clone(),
        tools,
        params,
        ToolLoopConfig::default(),
        &(),
    );

    if let OwnedTurnResult::Yielded(turn) = handle.next_turn().await {
        turn.inject_and_continue(vec![ChatMessage::system("Extra context")]);
    } else {
        panic!("expected Yielded");
    }
    assert!(matches!(
        handle.next_turn().await,
        OwnedTurnResult::Completed(_)
    ));

    let calls = provider.recorded_calls();
    let second_request = &calls[1];
    let injected = second_request
        .messages
        .iter()
        .flat_map(|msg| msg.content.iter())
        .any(|block| matches!(block, ContentBlock::Text(t) if t.contains("Extra context")));
    assert!(injected, "Injected message should be in conversation");
}
/// Calling `stop` on a yielded turn must end the loop with a `StopCondition`
/// carrying the given reason, without a second provider call being made.
#[tokio::test]
async fn test_owned_stop_command() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    provider.queue_response(sample_response("Should not appear"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Stop early")],
        ..Default::default()
    };
    let mut handle = OwnedToolLoopHandle::new(
        provider.clone(),
        tools,
        params,
        ToolLoopConfig::default(),
        &(),
    );

    if let OwnedTurnResult::Yielded(turn) = handle.next_turn().await {
        turn.stop(Some("task_spawn detected".into()));
    } else {
        panic!("expected Yielded");
    }

    {
        let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
            panic!("expected Completed");
        };
        let stopped_with_reason = matches!(
            &done.termination_reason,
            TerminationReason::StopCondition { reason: Some(r) } if r == "task_spawn detected"
        );
        assert!(stopped_with_reason);
    }

    // Only the first request may have reached the provider.
    assert_eq!(provider.recorded_calls().len(), 1);
}
/// Token usage must accumulate across turns of an owned handle and be
/// reported consistently by the completion result and `into_result()`.
#[tokio::test]
async fn test_owned_usage_accumulation() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    provider.queue_response(sample_response("Done"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Chain")],
        ..Default::default()
    };
    let mut handle =
        OwnedToolLoopHandle::new(provider, tools, params, ToolLoopConfig::default(), &());

    // After one provider call.
    {
        let OwnedTurnResult::Yielded(turn) = handle.next_turn().await else {
            panic!("expected Yielded");
        };
        assert_eq!(turn.total_usage.input_tokens, 100);
        turn.continue_loop();
    }

    // After two provider calls.
    {
        let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
            panic!("expected Completed");
        };
        assert_eq!(done.total_usage.input_tokens, 200);
        assert_eq!(done.total_usage.output_tokens, 100);
    }

    let summary = handle.into_result();
    assert_eq!(summary.total_usage.input_tokens, 200);
}
/// Moves an owned handle into a `tokio::spawn` task, drives it to completion
/// there, and checks the result returned from the task.
#[tokio::test]
async fn test_owned_inside_tokio_spawn() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 10, "b": 20}),
    }]));
    provider.queue_response(sample_response("Result: 30"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Compute")],
        ..Default::default()
    };
    let mut handle =
        OwnedToolLoopHandle::new(provider, tools, params, ToolLoopConfig::default(), &());

    let task = tokio::spawn(async move {
        // Keep continuing yielded turns until the loop reports completion.
        loop {
            match handle.next_turn().await {
                OwnedTurnResult::Completed(done) => break done,
                OwnedTurnResult::Error(err) => panic!("unexpected error: {}", err.error),
                OwnedTurnResult::Yielded(turn) => turn.continue_loop(),
            }
        }
    });
    let outcome = task.await.expect("spawned task should complete");

    assert_eq!(outcome.response.text(), Some("Result: 30"));
    assert_eq!(outcome.iterations, 2);
    assert_eq!(outcome.total_usage.input_tokens, 200);
}
/// Converting a borrowed `ToolLoopHandle` into an owned handle mid-flight must
/// carry over the iteration count and message history, and the owned handle
/// must continue from iteration 2 through to completion.
#[tokio::test]
async fn test_owned_into_owned_preserves_state() {
    // The borrowed handle only runs the first turn, so the borrowed mock needs
    // just one response; the remaining ones are queued on `new_mock` below.
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 1, "b": 2}),
    }]));
    let mut registry = ToolRegistry::<()>::new();
    registry.register(AddTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Convert mid-flight")],
        ..Default::default()
    };
    let mut handle = ToolLoopHandle::new(&mock, &registry, params, ToolLoopConfig::default(), &());
    match handle.next_turn().await {
        TurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 1);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    // Fresh, `Arc`-owned provider and registry for the owned half of the run.
    let new_mock = Arc::new(mock_for("test", "test-model"));
    new_mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "c2".into(),
        name: "add".into(),
        arguments: json!({"a": 3, "b": 4}),
    }]));
    new_mock.queue_response(sample_response("Final"));
    let registry_arc = Arc::new({
        let mut r = ToolRegistry::<()>::new();
        r.register(AddTool);
        r
    });
    let mut owned = handle.into_owned(new_mock, registry_arc);
    // State accumulated by the borrowed handle must survive the conversion.
    assert_eq!(owned.iterations(), 1);
    assert!(owned.messages().len() >= 3);
    match owned.next_turn().await {
        OwnedTurnResult::Yielded(turn) => {
            assert_eq!(turn.iteration, 2);
            turn.continue_loop();
        }
        _ => panic!("expected Yielded"),
    }
    match owned.next_turn().await {
        OwnedTurnResult::Completed(done) => {
            assert_eq!(done.iterations, 3);
            assert_eq!(done.response.text(), Some("Final"));
        }
        _ => panic!("expected Completed"),
    }
}
/// Text sent alongside a tool call must be available via `assistant_text()`
/// on the turn yielded by an owned handle.
#[tokio::test]
async fn test_owned_assistant_content_exposed() {
    let provider = Arc::new(mock_for("test", "test-model"));
    let add_call = ToolCall {
        id: "c1".into(),
        name: "add".into(),
        arguments: json!({"a": 5, "b": 10}),
    };
    provider.queue_response(sample_tool_response_with_text(
        "I'll help with that",
        vec![add_call],
    ));
    provider.queue_response(sample_response("The answer is 15"));

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is 5 + 10?")],
        ..Default::default()
    };
    let mut handle =
        OwnedToolLoopHandle::new(provider, tools, params, ToolLoopConfig::default(), &());

    {
        let OwnedTurnResult::Yielded(turn) = handle.next_turn().await else {
            panic!("expected Yielded");
        };
        assert_eq!(turn.assistant_text(), Some("I'll help with that".into()));
        turn.continue_loop();
    }

    {
        let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
            panic!("expected Completed");
        };
        assert_eq!(done.response.text(), Some("The answer is 15"));
    }
}
/// The `Debug` rendering of a fresh `OwnedToolLoopHandle` must mention the
/// type name as well as `iterations` and `finished`.
#[tokio::test]
async fn test_owned_debug_impl() {
    let provider = Arc::new(mock_for("test", "test-model"));
    provider.queue_response(sample_response("Debug"));

    let params = ChatParams {
        messages: vec![ChatMessage::user("Hi")],
        ..Default::default()
    };
    let handle = OwnedToolLoopHandle::new(
        provider,
        Arc::new(ToolRegistry::<()>::new()),
        params,
        ToolLoopConfig::default(),
        &(),
    );

    let rendered = format!("{:?}", handle);
    for needle in ["OwnedToolLoopHandle", "iterations", "finished"] {
        assert!(rendered.contains(needle));
    }
}
/// With `max_iterations: 2` and a provider that keeps requesting tools, the
/// owned handle must yield twice and then complete with `MaxIterations`.
#[tokio::test]
async fn test_owned_max_iterations() {
    let provider = Arc::new(mock_for("test", "test-model"));
    // Queue more tool-call responses than the configured limit of 2.
    (0..5).for_each(|_| {
        provider.queue_response(sample_tool_response(vec![ToolCall {
            id: "c".into(),
            name: "add".into(),
            arguments: json!({"a": 1, "b": 2}),
        }]));
    });

    let tools = {
        let mut reg = ToolRegistry::<()>::new();
        reg.register(AddTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Loop")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 2,
        ..Default::default()
    };
    let mut handle = OwnedToolLoopHandle::new(provider, tools, params, config, &());

    // The two permitted iterations are yielded in order.
    for expected_iteration in [1, 2] {
        let OwnedTurnResult::Yielded(turn) = handle.next_turn().await else {
            panic!("expected Yielded");
        };
        assert_eq!(turn.iteration, expected_iteration);
        turn.continue_loop();
    }

    // The third request for a turn hits the cap.
    let OwnedTurnResult::Completed(done) = handle.next_turn().await else {
        panic!("expected Completed");
    };
    assert!(matches!(
        done.termination_reason,
        TerminationReason::MaxIterations { limit: 2 }
    ));
}
#[tokio::test]
async fn test_owned_depth_exceeded() {
#[derive(Clone)]
struct DepthCtx(u32);
impl super::LoopDepth for DepthCtx {
fn loop_depth(&self) -> u32 {
self.0
}
fn with_depth(&self, depth: u32) -> Self {
DepthCtx(depth)
}
}
let mock = Arc::new(mock_for("test", "test-model"));
let registry = Arc::new(ToolRegistry::<DepthCtx>::new());
let params = ChatParams {
messages: vec![ChatMessage::user("Hi")],
..Default::default()
};
let config = ToolLoopConfig {
max_depth: Some(1),
..Default::default()
};
let ctx = DepthCtx(1);
let mut handle = OwnedToolLoopHandle::new(mock, registry, params, config, &ctx);
match handle.next_turn().await {
OwnedTurnResult::Error(err) => {
assert!(matches!(
err.error,
crate::LlmError::MaxDepthExceeded {
current: 1,
limit: 1
}
));
}
_ => panic!("expected Error"),
}
assert!(handle.is_finished());
}
/// Test tool that returns a string of `x` characters whose length is taken
/// from the `size` argument (100 when `size` is absent or not an unsigned integer).
struct BigOutputTool;

impl ToolHandler<()> for BigOutputTool {
    fn definition(&self) -> ToolDefinition {
        let schema = json!({
            "type": "object",
            "properties": {
                "size": {"type": "number"}
            },
            "required": ["size"]
        });
        ToolDefinition {
            name: "big_output".into(),
            description: "Returns a large string".into(),
            parameters: JsonSchema::new(schema),
            retry: None,
        }
    }

    fn execute<'a>(
        &'a self,
        input: Value,
        _ctx: &'a (),
    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>> {
        let fut = async move {
            let requested = input["size"].as_u64().unwrap_or(100);
            // Panics if the requested size does not fit in `usize`.
            let len = usize::try_from(requested).unwrap();
            Ok(ToolOutput::new("x".repeat(len)))
        };
        Box::pin(fut)
    }
}
/// With observation masking enabled, the fifth provider request must contain
/// four tool results of which the three oldest are replaced by a `[Masked`
/// placeholder while the most recent keeps its full 4000-character content.
#[tokio::test]
async fn test_masking_replaces_old_large_results() {
    let mock = mock_for("test", "test-model");
    for i in 0..4 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: format!("call_{i}"),
            name: "big_output".into(),
            arguments: json!({"size": 4000}),
        }]));
    }
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        masking: Some(ObservationMaskingConfig {
            max_iterations_to_keep: 2,
            min_tokens_to_mask: 500,
        }),
        ..Default::default()
    };
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 5);
    let calls = mock.recorded_calls();
    assert!(calls.len() >= 5);
    // Inspect what the provider was sent on the final (fifth) request.
    let last_call = &calls[4];
    let tool_result_msgs: Vec<_> = last_call
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    assert_eq!(tool_result_msgs.len(), 4);
    for msg in &tool_result_msgs[..3] {
        if let Some(ContentBlock::ToolResult(tr)) = msg.content.first() {
            assert!(
                tr.content.contains("[Masked"),
                "Expected masked placeholder, got: {}",
                // Truncate by characters: byte slicing could panic if the
                // cut-off lands inside a multi-byte character.
                tr.content.chars().take(80).collect::<String>()
            );
        } else {
            panic!("Expected ToolResult content block");
        }
    }
    if let Some(ContentBlock::ToolResult(tr)) = tool_result_msgs[3].content.first() {
        assert!(
            !tr.content.contains("[Masked"),
            "Iteration 4 result should not be masked"
        );
        assert_eq!(tr.content.len(), 4000);
    } else {
        panic!("Expected ToolResult content block");
    }
}
/// 100-character tool results must never be replaced by a `[Masked`
/// placeholder under a masking config with `min_tokens_to_mask: 500`.
#[tokio::test]
async fn test_masking_preserves_small_results() {
    let mock = mock_for("test", "test-model");
    for i in 0..4 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: format!("call_{i}"),
            name: "big_output".into(),
            arguments: json!({"size": 100}),
        }]));
    }
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        masking: Some(ObservationMaskingConfig {
            max_iterations_to_keep: 2,
            min_tokens_to_mask: 500,
        }),
        ..Default::default()
    };
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 5);
    let last_call = &mock.recorded_calls()[4];
    let tool_results: Vec<_> = last_call
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    // Guard against a vacuous pass: one tool result per tool iteration is expected.
    assert_eq!(tool_results.len(), 4);
    for msg in &tool_results {
        let Some(ContentBlock::ToolResult(tr)) = msg.content.first() else {
            panic!("Expected ToolResult content block");
        };
        assert!(
            !tr.content.contains("[Masked"),
            "Small results should not be masked"
        );
    }
}
/// An error tool result must stay unmasked even after later iterations
/// produce large results with masking enabled.
#[tokio::test]
async fn test_masking_preserves_error_results() {
    let mock = mock_for("test", "test-model");
    // First iteration calls the always-failing tool; the next three produce
    // large outputs.
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_0".into(),
        name: "fail".into(),
        arguments: json!({}),
    }]));
    for i in 1..4 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: format!("call_{i}"),
            name: "big_output".into(),
            arguments: json!({"size": 4000}),
        }]));
    }
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    registry.register(FailTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let config = ToolLoopConfig {
        masking: Some(ObservationMaskingConfig {
            max_iterations_to_keep: 2,
            min_tokens_to_mask: 500,
        }),
        ..Default::default()
    };
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 5);
    let last_call = &mock.recorded_calls()[4];
    let tool_results: Vec<_> = last_call
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    assert!(
        !tool_results.is_empty(),
        "Expected at least one tool result message"
    );
    // Fail loudly instead of silently skipping the checks when the first
    // block is not a tool result.
    let Some(ContentBlock::ToolResult(tr)) = tool_results[0].content.first() else {
        panic!("Expected ToolResult content block");
    };
    assert!(tr.is_error, "First result should be an error");
    assert!(
        !tr.content.contains("[Masked"),
        "Error results should never be masked"
    );
}
/// With masking enabled, the event stream must contain at least one
/// `ObservationsMasked` event reporting a positive count and token saving.
#[tokio::test]
async fn test_masking_emits_observations_masked_event() {
    use futures::StreamExt;

    let provider = Arc::new(mock_for("test", "test-model"));
    for i in 0..3 {
        let call = ToolCall {
            id: format!("call_{i}"),
            name: "big_output".into(),
            arguments: json!({"size": 4000}),
        };
        provider.queue_response(sample_tool_response(vec![call]));
    }
    provider.queue_response(sample_response("Done"));

    let tools = {
        let mut reg: ToolRegistry<()> = ToolRegistry::new();
        reg.register(BigOutputTool);
        Arc::new(reg)
    };
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let masking = ObservationMaskingConfig {
        max_iterations_to_keep: 1,
        min_tokens_to_mask: 500,
    };
    let config = ToolLoopConfig {
        masking: Some(masking),
        ..Default::default()
    };

    let stream = tool_loop_stream(provider, tools, params, config, Arc::new(()));
    let collected: Vec<_> = stream.collect().await;
    // Only successfully produced events are inspected.
    let events: Vec<LoopEvent> = collected.into_iter().filter_map(Result::ok).collect();

    let first_masked = events
        .iter()
        .find(|ev| matches!(ev, LoopEvent::ObservationsMasked { .. }));
    assert!(
        first_masked.is_some(),
        "Expected at least one ObservationsMasked event"
    );
    if let Some(LoopEvent::ObservationsMasked {
        masked_count,
        tokens_saved,
    }) = first_masked
    {
        assert!(*masked_count >= 1, "Should mask at least 1 result");
        assert!(*tokens_saved > 0, "Should save some tokens");
    }
}
/// With the default config (no masking configured here), every tool result in
/// the final request must keep its full 4000-character content.
#[tokio::test]
async fn test_no_masking_without_config() {
    let mock = mock_for("test", "test-model");
    for i in 0..4 {
        mock.queue_response(sample_tool_response(vec![ToolCall {
            id: format!("call_{i}"),
            name: "big_output".into(),
            arguments: json!({"size": 4000}),
        }]));
    }
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Go")],
        ..Default::default()
    };
    let config = ToolLoopConfig::default();
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.iterations, 5);
    let last_call = &mock.recorded_calls()[4];
    let tool_results: Vec<_> = last_call
        .messages
        .iter()
        .filter(|m| m.role == ChatRole::Tool)
        .collect();
    // Guard against a vacuous pass: one tool result per tool iteration is expected.
    assert_eq!(tool_results.len(), 4);
    for msg in &tool_results {
        let Some(ContentBlock::ToolResult(tr)) = msg.content.first() else {
            panic!("Expected ToolResult content block");
        };
        assert!(
            !tr.content.contains("[Masked"),
            "No masking should occur without config"
        );
        assert_eq!(tr.content.len(), 4000);
    }
}
/// `ObservationMaskingConfig::default()` keeps 2 iterations and masks from
/// 500 tokens.
#[test]
fn test_observation_masking_config_default() {
    let ObservationMaskingConfig {
        max_iterations_to_keep,
        min_tokens_to_mask,
    } = ObservationMaskingConfig::default();

    assert_eq!(max_iterations_to_keep, 2);
    assert_eq!(min_tokens_to_mask, 500);
}
/// A copy of an `ObservationMaskingConfig` keeps both field values, and the
/// `Debug` output of the original names both fields.
#[test]
fn test_observation_masking_config_debug_clone() {
    let original = ObservationMaskingConfig {
        max_iterations_to_keep: 3,
        min_tokens_to_mask: 100,
    };

    // `original` is still used below, so this binding is a copy, not a move.
    let duplicate = original;
    assert_eq!(duplicate.max_iterations_to_keep, 3);
    assert_eq!(duplicate.min_tokens_to_mask, 100);

    let rendered = format!("{:?}", original);
    for field in ["max_iterations_to_keep", "min_tokens_to_mask"] {
        assert!(rendered.contains(field));
    }
}
/// Test extractor that always produces a short summary string and reports a
/// configurable extraction threshold.
struct TestExtractor {
    /// Value returned from `extraction_threshold`.
    threshold: u32,
}

impl ToolResultExtractor for TestExtractor {
    /// Always returns `Some`, with content naming the tool, the user query and
    /// the output length, and a fixed extracted-token estimate of 20.
    fn extract<'a>(
        &'a self,
        tool_name: &'a str,
        output: &'a str,
        user_query: &'a str,
    ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
        let fut = async move {
            let content = format!(
                "[Extracted from {tool_name} for query: {user_query}] {} chars",
                output.len()
            );
            let original_tokens_est = crate::context::estimate_tokens(output);
            Some(ExtractedResult {
                content,
                original_tokens_est,
                extracted_tokens_est: 20,
            })
        };
        Box::pin(fut)
    }

    fn extraction_threshold(&self) -> u32 {
        self.threshold
    }
}
/// An 80000-character tool result with an extractor threshold of 100 must
/// reach the provider as the extractor's condensed text, which includes the
/// user's query.
#[tokio::test]
async fn test_extraction_condenses_large_results() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_0".into(),
        name: "big_output".into(),
        arguments: json!({"size": 80000}),
    }]));
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("What is the weather?")],
        tools: Some(registry.definitions()),
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 10,
        result_extractor: Some(Arc::new(TestExtractor { threshold: 100 })),
        ..Default::default()
    };
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.response.text().unwrap(), "Done");
    let calls = mock.recorded_calls();
    assert_eq!(calls.len(), 2);
    // The second request carries the tool result as the model will see it.
    let second_call = &calls[1];
    let tool_result_msg = second_call
        .messages
        .iter()
        .find(|m| m.role == ChatRole::Tool)
        .expect("should have tool result message");
    let content = match &tool_result_msg.content[0] {
        ContentBlock::ToolResult(tr) => &tr.content,
        _ => panic!("expected ToolResult"),
    };
    assert!(
        content.contains("[Extracted from big_output"),
        "content should be extracted: {content}"
    );
    assert!(
        content.contains("What is the weather"),
        "extraction should include user query: {content}"
    );
}
/// A 100-character tool result with an extractor threshold of 1000 must be
/// passed to the provider unchanged.
#[tokio::test]
async fn test_extraction_skips_small_results() {
    let mock = mock_for("test", "test-model");
    mock.queue_response(sample_tool_response(vec![ToolCall {
        id: "call_0".into(),
        name: "big_output".into(),
        arguments: json!({"size": 100}),
    }]));
    mock.queue_response(sample_response("Done"));
    let mut registry: ToolRegistry<()> = ToolRegistry::new();
    registry.register(BigOutputTool);
    let params = ChatParams {
        messages: vec![ChatMessage::user("Test query")],
        tools: Some(registry.definitions()),
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 10,
        result_extractor: Some(Arc::new(TestExtractor { threshold: 1000 })),
        ..Default::default()
    };
    let result = tool_loop(&mock, &registry, params, config, &())
        .await
        .unwrap();
    assert_eq!(result.response.text().unwrap(), "Done");
    let calls = mock.recorded_calls();
    // The second request carries the tool result as the model will see it.
    let second_call = calls
        .get(1)
        .expect("provider should have recorded two calls");
    let tool_result_msg = second_call
        .messages
        .iter()
        .find(|m| m.role == ChatRole::Tool)
        .expect("should have tool result message");
    let content = match &tool_result_msg.content[0] {
        ContentBlock::ToolResult(tr) => &tr.content,
        _ => panic!("expected ToolResult"),
    };
    assert_eq!(content, &"x".repeat(100));
}
/// Streaming a loop whose tool result gets extracted must emit a
/// `ToolResultExtracted` event naming the tool and carrying both token
/// estimates.
#[tokio::test]
async fn test_extraction_emits_event_via_stream() {
    use futures::StreamExt;

    let provider = mock_for("test", "test-model");
    let big_call = ToolCall {
        id: "call_0".into(),
        name: "big_output".into(),
        arguments: json!({"size": 80000}),
    };
    provider.queue_response(sample_tool_response(vec![big_call]));
    provider.queue_response(sample_response("Done"));

    let mut tools: ToolRegistry<()> = ToolRegistry::new();
    tools.register(BigOutputTool);

    let params = ChatParams {
        messages: vec![ChatMessage::user("Extract test")],
        tools: Some(tools.definitions()),
        ..Default::default()
    };
    let config = ToolLoopConfig {
        max_iterations: 10,
        result_extractor: Some(Arc::new(TestExtractor { threshold: 100 })),
        ..Default::default()
    };
    let mut stream = tool_loop_stream(
        Arc::new(provider),
        Arc::new(tools),
        params,
        config,
        Arc::new(()),
    );

    let mut found_extracted = false;
    while let Some(item) = stream.next().await {
        // Skip errors and every event that is not an extraction notice.
        let Ok(LoopEvent::ToolResultExtracted {
            tool_name,
            original_tokens,
            extracted_tokens,
        }) = item
        else {
            continue;
        };
        assert_eq!(tool_name, "big_output");
        assert!(original_tokens > 100);
        assert_eq!(extracted_tokens, 20);
        found_extracted = true;
    }

    assert!(
        found_extracted,
        "Should have emitted ToolResultExtracted event"
    );
}
/// `ExtractedResult` can be cloned, and its `Debug` output includes both
/// token estimates.
#[test]
fn test_extracted_result_type() {
    let original = ExtractedResult {
        content: "summary".into(),
        original_tokens_est: 5000,
        extracted_tokens_est: 100,
    };

    let duplicate = original.clone();
    assert_eq!(duplicate.content, "summary");

    let rendered = format!("{:?}", original);
    for needle in ["5000", "100"] {
        assert!(rendered.contains(needle));
    }
}