use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use rustagents::RustAgentsError;
use rustagents::harness::context::RunContext;
use rustagents::harness::events::RecordingListener;
use rustagents::harness::limits::RunLimits;
use rustagents::harness::message::{AssistantMessage, ContentBlock, Message};
use rustagents::harness::middleware::Middleware;
use rustagents::harness::model::ModelResponse;
use rustagents::harness::providers::MockModel;
use rustagents::harness::runtime::{AgentHarness, RunPolicy};
use rustagents::harness::testkit::{FakeTool, Trajectory};
use rustagents::harness::tool::ToolCall;
use rustagents::harness::usage::Usage;
/// Test-only middleware that plugs a shared [`RecordingListener`] into a run's event stream.
///
/// The test body keeps its own `Arc` handle to the same listener, so it can inspect
/// the recorded events once `invoke_default` returns.
struct CaptureMiddleware {
    /// Listener subscribed to `ctx.events` in `before_agent`; shared with the test body.
    listener: Arc<RecordingListener>,
}

#[async_trait]
impl Middleware<(), ()> for CaptureMiddleware {
    fn name(&self) -> &str {
        "capture"
    }

    /// In the `before_agent` hook, subscribe a clone of the shared listener; always `Ok(())`.
    async fn before_agent(
        &self,
        run_ctx: &mut RunContext<()>,
        _agent_state: &(),
    ) -> rustagents::Result<()> {
        // `.clone()` (not `Arc::clone`) keeps the value a concrete `Arc<RecordingListener>`
        // so any unsized coercion `subscribe` expects still happens at the call site.
        let shared = self.listener.clone();
        run_ctx.events.subscribe(shared);
        Ok(())
    }
}
/// Rebuilds a [`Trajectory`] from every event `listener` has recorded so far.
///
/// Only a shared borrow of the listener is needed, so this takes `&RecordingListener`
/// rather than `&Arc<RecordingListener>`. Callers holding an `Arc` can keep passing
/// `&listener`; deref coercion turns `&Arc<RecordingListener>` into `&RecordingListener`.
fn trajectory(listener: &RecordingListener) -> Trajectory {
    // Each recorded entry wraps the raw event; keep only the event payload.
    let events = listener
        .events()
        .into_iter()
        .map(|recorded| recorded.event)
        .collect();
    Trajectory::from_events(events)
}
/// Builds a scripted [`ModelResponse`] whose assistant message requests exactly one tool call.
///
/// * `id` – tool-call id; the message id is derived from it as `msg-{id}`.
/// * `name` / `arguments` – tool to invoke and its JSON arguments.
///
/// Token usage is fixed at 7 input / 3 output and is attached to both the message and the
/// response. `usage_accumulates_across_model_calls` depends on these numbers.
fn tool_call_response(id: &str, name: &str, arguments: serde_json::Value) -> ModelResponse {
    let (input_tokens, output_tokens): (u64, u64) = (7, 3);

    let requested_call = ToolCall::new(id, name, arguments);
    let message = AssistantMessage {
        id: Some(format!("msg-{id}")),
        content: Vec::new(),
        tool_calls: vec![requested_call],
        usage: Some(Usage::new(input_tokens, output_tokens)),
    };

    ModelResponse {
        message,
        usage: Some(Usage::new(input_tokens, output_tokens)),
        finish_reason: Some("tool_calls".to_string()),
        raw: None,
        resolved_model: None,
    }
}
/// Builds a scripted final-answer [`ModelResponse`] containing a single text block.
///
/// * `text` – the assistant's reply text.
/// * `input` / `output` – token counts, attached to both the message and the response.
///
/// The response carries no tool calls and `finish_reason` is `"stop"`.
fn text_response(text: &str, input: u64, output: u64) -> ModelResponse {
    // Same usage value on the message and on the response wrapper.
    let usage = || Some(Usage::new(input, output));

    ModelResponse {
        message: AssistantMessage {
            id: None,
            content: vec![ContentBlock::Text(text.to_owned())],
            tool_calls: Vec::new(),
            usage: usage(),
        },
        usage: usage(),
        finish_reason: Some(String::from("stop")),
        raw: None,
        resolved_model: None,
    }
}
/// A constant mock model should complete the run after one model call and no tool calls.
#[tokio::test]
async fn single_turn_response_completes_with_one_model_call() {
    let recorder = Arc::new(RecordingListener::new());
    let model = Arc::new(MockModel::constant("hello there"));
    let capture = Arc::new(CaptureMiddleware {
        listener: Arc::clone(&recorder),
    });

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness
        .register_model("mock", model)
        .set_default_model("mock")
        .push_middleware(capture);

    let outcome = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("run succeeds");
    assert_eq!(outcome.model_calls, 1);
    assert_eq!(outcome.tool_calls, 0);
    assert!(outcome.text().is_some(), "a final response should be produced");

    // Named `traj` on purpose: `assert!` without a message embeds the expression text.
    let traj = trajectory(&recorder);
    traj.assert_model_called_times(1);
    traj.assert_completed();
    assert!(!traj.failed());
}
/// The model first requests `lookup`, then answers with text. The run should execute the tool
/// once, make two model calls, and finish.
#[tokio::test]
async fn multi_step_tool_loop_calls_tool_then_finishes() {
    let recorder = Arc::new(RecordingListener::new());
    let script = vec![
        tool_call_response("call-1", "lookup", json!({ "q": "x" })),
        text_response("done", 4, 2),
    ];

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness
        .register_model("mock", Arc::new(MockModel::with_responses(script)))
        .set_default_model("mock")
        .register_tool(Arc::new(FakeTool::returning("lookup", "tool-output")))
        .push_middleware(Arc::new(CaptureMiddleware {
            listener: Arc::clone(&recorder),
        }));

    let outcome = harness
        .invoke_default(&(), vec![Message::user("please look up")])
        .await
        .expect("run succeeds");
    assert_eq!(outcome.model_calls, 2);
    assert_eq!(outcome.tool_calls, 1);

    let recorded = trajectory(&recorder);
    recorded.assert_tool_called("lookup");
    assert_eq!(recorded.tool_call_count("lookup"), 1);
    recorded.assert_model_called_times(2);
    recorded.assert_completed();
    recorded
        .assert_order(&["lookup", "model.completed"])
        .expect("tool runs before the final model completion");
}
/// A mock that keeps requesting the `spin` tool (presumably on every turn) must trip a
/// one-model-call cap and surface `RustAgentsError::LimitExceeded`.
#[tokio::test]
async fn max_model_calls_limit_returns_limit_exceeded() {
    let policy = RunPolicy {
        limits: RunLimits::default().with_max_model_calls(1),
        ..RunPolicy::default()
    };
    let looping_model = Arc::new(MockModel::with_tool_call("spin", json!({})));

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness
        .register_model("mock", looping_model)
        .set_default_model("mock")
        .register_tool(Arc::new(FakeTool::returning("spin", "again")))
        .with_policy(policy);

    let failure = harness
        .invoke_default(&(), vec![Message::user("go")])
        .await
        .expect_err("the model-call cap should be exceeded");
    match failure {
        RustAgentsError::LimitExceeded(_) => {}
        other => panic!("got {other:?}"),
    }
}
/// Token usage from the two scripted turns (7/3 from the tool-call turn, 4/2 from the text
/// turn) should be summed on the run's usage report.
#[tokio::test]
async fn usage_accumulates_across_model_calls() {
    let responses = vec![
        tool_call_response("call-1", "lookup", json!({})),
        text_response("done", 4, 2),
    ];

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness
        .register_model("mock", Arc::new(MockModel::with_responses(responses)))
        .set_default_model("mock")
        .register_tool(Arc::new(FakeTool::returning("lookup", "out")));

    let outcome = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("run succeeds");

    let report = &outcome.usage;
    assert_eq!(report.calls, 2);
    // 7 + 4 = 11 input, 3 + 2 = 5 output, 16 in total.
    assert_eq!(report.usage.input_tokens, 11);
    assert_eq!(report.usage.output_tokens, 5);
    assert_eq!(report.usage.total_tokens, 16);
}