use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use serde_json::json;
use crate::error::{Result, RustAgentsError};
use crate::harness::context::{RunConfig, RunContext};
use crate::harness::limits::RunLimits;
use crate::harness::message::{AssistantMessage, ContentBlock, Message};
use crate::harness::middleware::{AgentRun, Middleware};
use crate::harness::model::{ChatModel, ModelRequest, ModelResponse, ResponseFormat, ToolChoice};
use crate::harness::providers::MockModel;
use crate::harness::retry::{FallbackPolicy, RetryPolicy};
use crate::harness::runtime::{AgentHarness, RunPolicy};
use crate::harness::tool::{Tool, ToolCall, ToolResult, ToolSchema};
use crate::harness::usage::Usage;
/// Test double for a tool: always answers with a fixed reply and counts how
/// many times it has been invoked.
struct FakeTool {
    /// Name the tool is registered (and looked up) under.
    name: &'static str,
    /// Text returned verbatim as the tool result.
    reply: &'static str,
    /// Number of times `call` has been invoked.
    calls: Mutex<usize>,
}

impl FakeTool {
    /// Builds a tool named `name` that always answers with `reply`.
    fn new(name: &'static str, reply: &'static str) -> Self {
        let calls = Mutex::new(0);
        Self { name, reply, calls }
    }
}

#[async_trait]
impl Tool<()> for FakeTool {
    fn name(&self) -> &str {
        self.name
    }

    fn description(&self) -> &str {
        "fake tool"
    }

    /// Advertises a plain object schema; the call arguments are ignored anyway.
    fn schema(&self) -> ToolSchema {
        let parameters = json!({"type": "object"});
        ToolSchema::new(self.name, "fake tool", parameters)
    }

    /// Bumps the invocation counter, then returns `reply` tagged with the
    /// incoming call's id and this tool's name.
    async fn call(&self, _state: &(), call: ToolCall) -> Result<ToolResult> {
        {
            // Scope the guard so the lock is released before building the result.
            let mut count = self.calls.lock().unwrap();
            *count += 1;
        }
        Ok(ToolResult::text(call.id, self.name, self.reply))
    }
}
/// Builds a canned model response that requests exactly one call of tool
/// `name` with `arguments`, using `id` as the tool-call id.
///
/// Token usage is fixed at 7 input / 3 output and is attached both to the
/// assistant message and to the response envelope.
fn tool_call_response(id: &str, name: &str, arguments: serde_json::Value) -> ModelResponse {
    let requested_call = ToolCall::new(id, name, arguments);
    let message = AssistantMessage {
        id: Some(format!("msg-{id}")),
        content: Vec::new(),
        tool_calls: vec![requested_call],
        usage: Some(Usage::new(7, 3)),
    };
    ModelResponse {
        message,
        usage: Some(Usage::new(7, 3)),
        finish_reason: Some(String::from("tool_calls")),
        raw: None,
        resolved_model: None,
    }
}
/// Builds a canned final-answer response whose only content is `text`.
///
/// `input`/`output` token counts are attached both to the assistant message
/// and to the response envelope; no tool calls are requested.
fn text_response(text: &str, input: u64, output: u64) -> ModelResponse {
    // Fresh `Usage` for each slot that needs one.
    let usage = || Some(Usage::new(input, output));
    ModelResponse {
        message: AssistantMessage {
            id: None,
            content: vec![ContentBlock::Text(String::from(text))],
            tool_calls: Vec::new(),
            usage: usage(),
        },
        usage: usage(),
        finish_reason: Some(String::from("stop")),
        raw: None,
        resolved_model: None,
    }
}
/// Middleware that, before every model call, appends a fixed user message to
/// the request and sets the request's tool choice to `ToolChoice::None`.
struct InjectMiddleware {
    /// Content of the user message appended to each request.
    text: &'static str,
}

#[async_trait]
impl Middleware<(), ()> for InjectMiddleware {
    fn name(&self) -> &str {
        "inject"
    }

    async fn before_model(
        &self,
        _ctx: &mut RunContext<()>,
        _state: &(),
        request: &mut ModelRequest,
    ) -> Result<()> {
        let injected = Message::user(self.text);
        request.messages.push(injected);
        request.tool_choice = ToolChoice::None;
        Ok(())
    }
}
/// Chat model that fails every invocation with `RustAgentsError::Model`,
/// recording how many times it was invoked.
struct FailingModel {
    /// Number of `invoke` calls received so far.
    attempts: Mutex<usize>,
}

#[async_trait]
impl ChatModel<()> for FailingModel {
    async fn invoke(&self, _state: &(), _request: ModelRequest) -> Result<ModelResponse> {
        let mut attempts = self.attempts.lock().unwrap();
        *attempts += 1;
        // Release the lock before producing the error.
        drop(attempts);
        Err(RustAgentsError::Model("transient boom".to_string()))
    }
}
/// With a constant model and no tools, the run finishes after one model call
/// and one step, leaving two messages in the transcript.
#[tokio::test]
async fn single_model_call_no_tools() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::constant("hello there")));

    let prompt = vec![Message::user("hi")];
    let run = harness.invoke_default(&(), prompt).await.expect("run succeeds");

    assert_eq!((run.model_calls, run.tool_calls, run.steps), (1, 0, 1));
    assert_eq!(run.text().as_deref(), Some("hello there"));
    assert_eq!(run.messages.len(), 2);
}
/// A tool request followed by a final answer takes two model calls and two
/// steps. The tool's output appears in the transcript as a tool message, and
/// the registered tool is actually executed exactly once.
#[tokio::test]
async fn model_requests_tool_then_finishes() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model(
        "mock",
        Arc::new(MockModel::with_responses(vec![
            tool_call_response("call-1", "lookup", json!({"q": "x"})),
            text_response("done", 4, 2),
        ])),
    );
    // Keep a handle so the tool's invocation counter can be checked after the run.
    let lookup = Arc::new(FakeTool::new("lookup", "tool-output"));
    harness.register_tool(lookup.clone());
    let run = harness
        .invoke_default(&(), vec![Message::user("please look up")])
        .await
        .expect("run succeeds");
    assert_eq!(run.model_calls, 2);
    assert_eq!(run.tool_calls, 1);
    assert_eq!(run.steps, 2);
    assert_eq!(run.text(), Some("done".to_string()));
    assert_eq!(run.messages.len(), 4);
    assert!(matches!(run.messages[2], Message::Tool(_)));
    assert_eq!(run.messages[2].text(), "tool-output");
    // `run.tool_calls` is the harness's own bookkeeping; this confirms the tool really ran.
    assert_eq!(*lookup.calls.lock().unwrap(), 1);
}
/// A mock model that requests the `spin` tool, combined with a budget of a
/// single model call, must make the run fail with `LimitExceeded`.
#[tokio::test]
async fn max_model_calls_limit_triggers_limit_exceeded() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let spinning_model = MockModel::with_tool_call("spin", json!({}));
    harness.register_model("mock", Arc::new(spinning_model));
    harness.register_tool(Arc::new(FakeTool::new("spin", "again")));

    let limits = RunLimits::default().with_max_model_calls(1);
    harness.with_policy(RunPolicy {
        limits,
        ..RunPolicy::default()
    });

    let outcome = harness.invoke_default(&(), vec![Message::user("go")]).await;
    match outcome.expect_err("limit should be exceeded") {
        RustAgentsError::LimitExceeded(_) => {}
        other => panic!("got {other:?}"),
    }
}
/// With a model-call budget of 10 but a tool-call budget of zero, a model that
/// requests the `spin` tool must make the run fail with `LimitExceeded`.
#[tokio::test]
async fn max_tool_calls_limit_triggers_limit_exceeded() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let spinning_model = MockModel::with_tool_call("spin", json!({}));
    harness.register_model("mock", Arc::new(spinning_model));
    harness.register_tool(Arc::new(FakeTool::new("spin", "again")));

    let limits = RunLimits::default()
        .with_max_model_calls(10)
        .with_max_tool_calls(0);
    harness.with_policy(RunPolicy {
        limits,
        ..RunPolicy::default()
    });

    let err = harness
        .invoke_default(&(), vec![Message::user("go")])
        .await
        .expect_err("tool limit should be exceeded");
    let hit_limit = matches!(err, RustAgentsError::LimitExceeded(_));
    assert!(hit_limit, "got {err:?}");
}
/// The echo model's reply must reflect the message added by
/// `InjectMiddleware` in `before_model`, not the original prompt.
#[tokio::test]
async fn before_model_middleware_mutates_request() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::echo()));
    let middleware = InjectMiddleware { text: "injected" };
    harness.push_middleware(Arc::new(middleware));

    let prompt = vec![Message::user("original")];
    let run = harness.invoke_default(&(), prompt).await.expect("run succeeds");

    assert_eq!(run.text().as_deref(), Some("injected"));
}
/// Token usage is summed over every model call in the run: the tool-call
/// response (7 in / 3 out) plus the final answer (4 in / 2 out).
#[tokio::test]
async fn usage_accumulates_across_calls() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let scripted = MockModel::with_responses(vec![
        tool_call_response("call-1", "lookup", json!({})),
        text_response("done", 4, 2),
    ]);
    harness.register_model("mock", Arc::new(scripted));
    harness.register_tool(Arc::new(FakeTool::new("lookup", "out")));

    let run = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("run succeeds");

    let expected_input = 7 + 4;
    let expected_output = 3 + 2;
    assert_eq!(run.usage.calls, 2);
    assert_eq!(run.usage.usage.input_tokens, expected_input);
    assert_eq!(run.usage.usage.output_tokens, expected_output);
}
/// Requesting a tool that was never registered surfaces as `ToolNotFound`,
/// carrying the requested tool name.
#[tokio::test]
async fn tool_not_found_errors() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let model = MockModel::with_tool_call("missing", json!({}));
    harness.register_model("mock", Arc::new(model));

    let err = harness
        .invoke_default(&(), vec![Message::user("go")])
        .await
        .expect_err("tool should be missing");

    let name = match err {
        RustAgentsError::ToolNotFound(name) => name,
        other => panic!("expected ToolNotFound, got {other:?}"),
    };
    assert_eq!(name, "missing");
}
/// When a JSON-schema response format is configured as the default, the
/// model's JSON text is parsed into `run.structured`.
#[tokio::test]
async fn structured_output_is_extracted() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let model = MockModel::constant(r#"{"value":"hi","score":42}"#);
    harness.register_model("mock", Arc::new(model));

    let format = ResponseFormat::json_schema("answer", json!({"type": "object"}));
    harness.with_policy(RunPolicy {
        default_response_format: Some(format),
        ..RunPolicy::default()
    });

    let run = harness
        .invoke_default(&(), vec![Message::user("answer")])
        .await
        .expect("run succeeds");

    let Some(structured) = run.structured else {
        panic!("structured output present");
    };
    assert_eq!(structured["value"], "hi");
    assert_eq!(structured["score"], 42);
}
/// Invoking a harness that has no registered model fails with `ModelNotFound`.
#[tokio::test]
async fn no_model_registered_errors() {
    let harness: AgentHarness<()> = AgentHarness::new();
    let prompt = vec![Message::user("hi")];
    let err = harness
        .invoke_default(&(), prompt)
        .await
        .expect_err("no model");
    match err {
        RustAgentsError::ModelNotFound(_) => {}
        other => panic!("got {other:?}"),
    }
}
/// The primary model is retried up to the attempt budget (2). The fallback
/// chain then hands over to the backup model, whose reply becomes the result.
#[tokio::test]
async fn retry_then_fallback_succeeds() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let failing = Arc::new(FailingModel {
        attempts: Mutex::new(0),
    });
    harness.register_model("primary", failing.clone());
    harness.register_model("backup", Arc::new(MockModel::constant("recovered")));

    let retry = RetryPolicy::default().with_max_attempts(2);
    let fallback = FallbackPolicy::new(["primary", "backup"]);
    harness.with_policy(RunPolicy {
        retry,
        fallback: Some(fallback),
        ..RunPolicy::default()
    });

    let run = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("fallback recovers");

    assert_eq!(run.text().as_deref(), Some("recovered"));
    let primary_attempts = *failing.attempts.lock().unwrap();
    assert_eq!(primary_attempts, 2);
}
/// With a single permitted attempt and no fallback configured, the model's
/// error is propagated unchanged and the model is invoked exactly once.
#[tokio::test]
async fn non_retryable_or_exhausted_without_fallback_errors() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    // Keep a handle so the number of attempts can be verified after the run.
    let failing = Arc::new(FailingModel {
        attempts: Mutex::new(0),
    });
    harness.register_model("primary", failing.clone());
    harness.with_policy(RunPolicy {
        retry: RetryPolicy::default().with_max_attempts(1),
        ..RunPolicy::default()
    });
    let err = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect_err("no fallback, error propagates");
    assert!(matches!(err, RustAgentsError::Model(_)), "got {err:?}");
    // The attempt budget is 1: no extra retries, and the model was really called.
    assert_eq!(*failing.attempts.lock().unwrap(), 1);
}
/// A successful `invoke_with_status` run reports `Completed` in the `Done`
/// phase, alongside the run's own output.
#[tokio::test]
async fn invoke_with_status_reports_completed() {
    use crate::harness::ids::{ExecutionStatus, HarnessPhase};

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::constant("ok")));

    let config = RunConfig::new("run-x");
    let prompt = vec![Message::user("hi")];
    let result = harness
        .invoke_with_status(&(), (), config, prompt)
        .await
        .expect("run succeeds");

    let status = &result.status;
    assert_eq!(status.status, ExecutionStatus::Completed);
    assert_eq!(status.current_phase, HarnessPhase::Done);
    assert_eq!(status.model_calls, 1);
    assert_eq!(result.run.text().as_deref(), Some("ok"));
}
/// A freshly constructed `AgentRun` has no recorded activity: no model or
/// tool calls, no steps, and an empty transcript.
#[test]
fn agent_run_default_is_empty() {
    let run = AgentRun::new();
    assert_eq!(run.model_calls, 0);
    assert_eq!(run.tool_calls, 0);
    assert_eq!(run.steps, 0);
    assert!(run.messages.is_empty());
}