use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use serde_json::json;
use crate::error::{Result, TinyAgentsError};
use crate::harness::context::{RunConfig, RunContext};
use crate::harness::limits::RunLimits;
use crate::harness::message::{AssistantMessage, ContentBlock, Message};
use crate::harness::middleware::{AgentRun, Middleware};
use crate::harness::model::{
ChatModel, ModelProfile, ModelRequest, ModelResponse, ResponseFormat, ToolChoice,
};
use crate::harness::providers::MockModel;
use crate::harness::retry::{FallbackPolicy, RetryPolicy};
use crate::harness::runtime::{AgentHarness, RunPolicy};
use crate::harness::tool::{Tool, ToolCall, ToolResult, ToolSchema};
use crate::harness::usage::Usage;
/// Minimal [`Tool`] test double: every call returns the same fixed text and
/// bumps an internal call counter.
struct FakeTool {
    /// Name the tool is registered and advertised under.
    name: &'static str,
    /// Text returned as the result of every call.
    reply: &'static str,
    /// Number of times `call` has run.
    calls: Mutex<usize>,
}

impl FakeTool {
    /// Builds a tool named `name` that always answers with `reply`.
    fn new(name: &'static str, reply: &'static str) -> Self {
        let calls = Mutex::new(0);
        Self { name, reply, calls }
    }
}

#[async_trait]
impl Tool<()> for FakeTool {
    fn name(&self) -> &str {
        self.name
    }

    fn description(&self) -> &str {
        "fake tool"
    }

    fn schema(&self) -> ToolSchema {
        // Accepts any JSON object; the fake ignores its arguments anyway.
        let parameters = json!({"type": "object"});
        ToolSchema::new(self.name, "fake tool", parameters)
    }

    async fn call(&self, _state: &(), call: ToolCall) -> Result<ToolResult> {
        {
            // Scope the guard so the lock is released before building the result.
            let mut count = self.calls.lock().unwrap();
            *count += 1;
        }
        Ok(ToolResult::text(call.id, self.name, self.reply))
    }
}
/// Builds a canned model response that requests exactly one tool call.
///
/// The call has id `id` and targets tool `name` with `arguments`. The message id is
/// `msg-{id}`, and both the message and the response report 7 input / 3 output tokens.
fn tool_call_response(id: &str, name: &str, arguments: serde_json::Value) -> ModelResponse {
    // The same token counts are attached to the message and to the response envelope.
    let usage = || Some(Usage::new(7, 3));
    let message = AssistantMessage {
        id: Some(format!("msg-{id}")),
        content: Vec::new(),
        tool_calls: vec![ToolCall::new(id, name, arguments)],
        usage: usage(),
    };
    ModelResponse {
        message,
        usage: usage(),
        finish_reason: Some("tool_calls".to_string()),
        raw: None,
        resolved_model: None,
    }
}
/// Builds a canned final-answer response containing a single text block.
///
/// `input` / `output` are the token counts reported on both the message and the
/// response; the finish reason is `"stop"`.
fn text_response(text: &str, input: u64, output: u64) -> ModelResponse {
    let usage = || Some(Usage::new(input, output));
    let body = vec![ContentBlock::Text(String::from(text))];
    ModelResponse {
        message: AssistantMessage {
            id: None,
            content: body,
            tool_calls: Vec::new(),
            usage: usage(),
        },
        usage: usage(),
        finish_reason: Some(String::from("stop")),
        raw: None,
        resolved_model: None,
    }
}
/// Middleware that appends a fixed user message to every model request and
/// sets the request's tool choice to `ToolChoice::None`.
struct InjectMiddleware {
    /// Text of the user message appended before each model call.
    text: &'static str,
}

#[async_trait]
impl Middleware<(), ()> for InjectMiddleware {
    fn name(&self) -> &str {
        "inject"
    }

    async fn before_model(
        &self,
        _ctx: &mut RunContext<()>,
        _state: &(),
        request: &mut ModelRequest,
    ) -> Result<()> {
        let injected = Message::user(self.text);
        request.messages.push(injected);
        // Tool use is switched off for the mutated request.
        request.tool_choice = ToolChoice::None;
        Ok(())
    }
}
/// Chat model whose profile advertises tool calling but neither native structured
/// output nor JSON-schema support. With `ResponseFormat::auto`, the harness is
/// therefore expected to request structured output through a forced tool call.
struct ToolStructuredModel {
    profile: ModelProfile,
}

impl ToolStructuredModel {
    /// Creates the model with a tool-calling-only profile.
    fn new() -> Self {
        Self {
            profile: ModelProfile {
                tool_calling: true,
                native_structured_output: false,
                json_schema: false,
                ..ModelProfile::default()
            },
        }
    }
}

#[async_trait]
impl ChatModel<()> for ToolStructuredModel {
    fn profile(&self) -> Option<&ModelProfile> {
        Some(&self.profile)
    }

    /// Answers the forced structured-output tool with a fixed JSON payload.
    ///
    /// # Panics
    /// Panics if the request does not force the `answer` tool, or if no tool
    /// named `answer` is offered in `request.tools`.
    async fn invoke(&self, _state: &(), request: ModelRequest) -> Result<ModelResponse> {
        let name = "answer";
        assert_eq!(request.tool_choice, ToolChoice::Tool(name.to_string()));
        // Look the tool up by the forced name instead of taking `tools.last()`.
        // That form depended on the harness's tool ordering, and an empty tool list
        // silently produced a call with an empty tool name.
        assert!(
            request.tools.iter().any(|tool| tool.name == name),
            "structured-output tool `{name}` must be offered to the model"
        );
        Ok(tool_call_response(
            "s1",
            name,
            json!({"value":"viatool","score":7}),
        ))
    }
}
/// Chat model that fails every call with `TinyAgentsError::Model("transient boom")`
/// and records how many times it was invoked.
struct FailingModel {
    /// Number of `invoke` calls observed so far.
    attempts: Mutex<usize>,
}

#[async_trait]
impl ChatModel<()> for FailingModel {
    async fn invoke(&self, _state: &(), _request: ModelRequest) -> Result<ModelResponse> {
        {
            let mut seen = self.attempts.lock().unwrap();
            *seen += 1;
        }
        let failure = TinyAgentsError::Model("transient boom".to_string());
        Err(failure)
    }
}
/// A plain text reply finishes the run after one model call and no tool calls.
#[tokio::test]
async fn single_model_call_no_tools() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::constant("hello there")));

    let prompt = vec![Message::user("hi")];
    let run = harness
        .invoke_default(&(), prompt)
        .await
        .expect("run succeeds");

    assert_eq!(run.model_calls, 1);
    assert_eq!(run.tool_calls, 0);
    assert_eq!(run.steps, 1);
    assert_eq!(run.text(), Some("hello there".to_string()));
    // The prompt plus the assistant's reply.
    assert_eq!(run.messages.len(), 2);
}
/// The model requests a tool, the harness runs it, and a second model call finishes.
#[tokio::test]
async fn model_requests_tool_then_finishes() {
    let scripted = MockModel::with_responses(vec![
        tool_call_response("call-1", "lookup", json!({"q": "x"})),
        text_response("done", 4, 2),
    ]);
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(scripted));
    harness.register_tool(Arc::new(FakeTool::new("lookup", "tool-output")));

    let run = harness
        .invoke_default(&(), vec![Message::user("please look up")])
        .await
        .expect("run succeeds");

    assert_eq!(run.model_calls, 2);
    assert_eq!(run.tool_calls, 1);
    assert_eq!(run.steps, 2);
    assert_eq!(run.text(), Some("done".to_string()));
    assert_eq!(run.messages.len(), 4);
    // Index 2 holds the tool's result message.
    let tool_message = &run.messages[2];
    assert!(matches!(tool_message, Message::Tool(_)));
    assert_eq!(tool_message.text(), "tool-output");
}
/// A tool-calling loop fails with `LimitExceeded` once the model-call budget is spent.
#[tokio::test]
async fn max_model_calls_limit_triggers_limit_exceeded() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    // The mock presumably asks for `spin` on every turn, so only a limit ends the run.
    let looping = MockModel::with_tool_call("spin", json!({}));
    harness.register_model("mock", Arc::new(looping));
    harness.register_tool(Arc::new(FakeTool::new("spin", "again")));

    let limits = RunLimits::default().with_max_model_calls(1);
    harness.with_policy(RunPolicy {
        limits,
        ..RunPolicy::default()
    });

    let outcome = harness.invoke_default(&(), vec![Message::user("go")]).await;
    let err = outcome.expect_err("limit should be exceeded");
    assert!(
        matches!(err, TinyAgentsError::LimitExceeded(_)),
        "got {err:?}"
    );
}
/// With a tool-call budget of zero, a model that requests a tool fails the run
/// with `LimitExceeded`, even though plenty of model calls remain.
#[tokio::test]
async fn max_tool_calls_limit_triggers_limit_exceeded() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let looping = MockModel::with_tool_call("spin", json!({}));
    harness.register_model("mock", Arc::new(looping));
    harness.register_tool(Arc::new(FakeTool::new("spin", "again")));

    let limits = RunLimits::default()
        .with_max_model_calls(10)
        .with_max_tool_calls(0);
    harness.with_policy(RunPolicy {
        limits,
        ..RunPolicy::default()
    });

    let outcome = harness.invoke_default(&(), vec![Message::user("go")]).await;
    let err = outcome.expect_err("tool limit should be exceeded");
    assert!(
        matches!(err, TinyAgentsError::LimitExceeded(_)),
        "got {err:?}"
    );
}
/// `before_model` middleware can rewrite the request the model receives.
#[tokio::test]
async fn before_model_middleware_mutates_request() {
    let injector = InjectMiddleware { text: "injected" };
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::echo()));
    harness.push_middleware(Arc::new(injector));

    let run = harness
        .invoke_default(&(), vec![Message::user("original")])
        .await
        .expect("run succeeds");

    // The echo model's reply reflects the injected message, not the original prompt.
    let expected = Some(String::from("injected"));
    assert_eq!(run.text(), expected);
}
/// Token usage from every model call is summed into the run's usage totals.
#[tokio::test]
async fn usage_accumulates_across_calls() {
    let scripted = MockModel::with_responses(vec![
        tool_call_response("call-1", "lookup", json!({})),
        text_response("done", 4, 2),
    ]);
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(scripted));
    harness.register_tool(Arc::new(FakeTool::new("lookup", "out")));

    let run = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("run succeeds");

    // (7, 3) from the tool-call reply plus (4, 2) from the final text reply.
    let totals = &run.usage;
    assert_eq!(totals.calls, 2);
    assert_eq!(totals.usage.input_tokens, 11);
    assert_eq!(totals.usage.output_tokens, 5);
}
/// Requesting a tool that was never registered fails with `ToolNotFound(name)`.
#[tokio::test]
async fn tool_not_found_errors() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let asks_for_missing = MockModel::with_tool_call("missing", json!({}));
    harness.register_model("mock", Arc::new(asks_for_missing));

    let err = harness
        .invoke_default(&(), vec![Message::user("go")])
        .await
        .expect_err("tool should be missing");

    let reported = match err {
        TinyAgentsError::ToolNotFound(name) => name,
        other => panic!("expected ToolNotFound, got {other:?}"),
    };
    assert_eq!(reported, "missing");
}
/// With a JSON-schema response format, the model's JSON text is parsed into `run.structured`.
#[tokio::test]
async fn structured_output_is_extracted() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let json_reply = MockModel::constant(r#"{"value":"hi","score":42}"#);
    harness.register_model("mock", Arc::new(json_reply));

    let format = ResponseFormat::json_schema("answer", json!({"type": "object"}));
    harness.with_policy(RunPolicy {
        default_response_format: Some(format),
        ..RunPolicy::default()
    });

    let run = harness
        .invoke_default(&(), vec![Message::user("answer")])
        .await
        .expect("run succeeds");

    let structured = run.structured.expect("structured output present");
    assert_eq!(structured["value"], "hi");
    assert_eq!(structured["score"], 42);
}
/// `ResponseFormat::auto` against the default mock extracts structured output from
/// the model's plain JSON text reply.
#[tokio::test]
async fn auto_format_uses_provider_schema_for_native_model() {
    let mut harness: AgentHarness<()> = AgentHarness::new();
    let json_reply = MockModel::constant(r#"{"value":"native","score":1}"#);
    harness.register_model("mock", Arc::new(json_reply));

    let format = ResponseFormat::auto("answer", json!({"type": "object"}));
    harness.with_policy(RunPolicy {
        default_response_format: Some(format),
        ..RunPolicy::default()
    });

    let run = harness
        .invoke_default(&(), vec![Message::user("answer")])
        .await
        .expect("run succeeds");

    let structured = run.structured.expect("structured output present");
    assert_eq!(structured["value"], "native");
}
/// `ResponseFormat::auto` against a tool-calling-only model produces structured
/// output from a single forced tool call.
#[tokio::test]
async fn auto_format_uses_tool_call_for_non_native_model() {
    let format = ResponseFormat::auto("answer", json!({"type": "object"}));
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("tool", Arc::new(ToolStructuredModel::new()));
    harness.with_policy(RunPolicy {
        default_response_format: Some(format),
        ..RunPolicy::default()
    });

    let run = harness
        .invoke_default(&(), vec![Message::user("answer")])
        .await
        .expect("run succeeds");

    let structured = run.structured.expect("structured output present");
    assert_eq!(structured["value"], "viatool");
    assert_eq!(structured["score"], 7);
    // The tool-call answer is consumed directly; no follow-up model call happens.
    assert_eq!(run.model_calls, 1);
}
/// Running a harness with no registered model fails with `ModelNotFound`.
#[tokio::test]
async fn no_model_registered_errors() {
    let harness: AgentHarness<()> = AgentHarness::new();
    let outcome = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await;
    let err = outcome.expect_err("no model");
    assert!(
        matches!(err, TinyAgentsError::ModelNotFound(_)),
        "got {err:?}"
    );
}
/// The failing primary is retried up to the attempt limit, then the backup model answers.
#[tokio::test]
async fn retry_then_fallback_succeeds() {
    let primary = Arc::new(FailingModel {
        attempts: Mutex::new(0),
    });
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("primary", primary.clone());
    harness.register_model("backup", Arc::new(MockModel::constant("recovered")));

    let policy = RunPolicy {
        retry: RetryPolicy::default().with_max_attempts(2),
        fallback: Some(FallbackPolicy::new(["primary", "backup"])),
        ..RunPolicy::default()
    };
    harness.with_policy(policy);

    let run = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect("fallback recovers");

    assert_eq!(run.text(), Some("recovered".to_string()));
    let attempts = *primary.attempts.lock().unwrap();
    assert_eq!(attempts, 2);
}
/// Without a fallback, the model error surfaces once the retry budget is exhausted.
#[tokio::test]
async fn non_retryable_or_exhausted_without_fallback_errors() {
    // Keep a handle on the model so the number of attempts can be checked afterwards.
    let failing = Arc::new(FailingModel {
        attempts: Mutex::new(0),
    });
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("primary", failing.clone());
    harness.with_policy(RunPolicy {
        retry: RetryPolicy::default().with_max_attempts(1),
        ..RunPolicy::default()
    });
    let err = harness
        .invoke_default(&(), vec![Message::user("hi")])
        .await
        .expect_err("no fallback, error propagates");
    assert!(matches!(err, TinyAgentsError::Model(_)), "got {err:?}");
    // With a single allowed attempt and no fallback, the model must have been tried
    // exactly once. Previously the retry budget itself went unverified.
    assert_eq!(*failing.attempts.lock().unwrap(), 1);
}
/// `invoke_with_status` returns the run together with a completed execution status.
#[tokio::test]
async fn invoke_with_status_reports_completed() {
    use crate::harness::ids::{ExecutionStatus, HarnessPhase};

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(MockModel::constant("ok")));

    let config = RunConfig::new("run-x");
    let result = harness
        .invoke_with_status(&(), (), config, vec![Message::user("hi")])
        .await
        .expect("run succeeds");

    let status = &result.status;
    assert_eq!(status.status, ExecutionStatus::Completed);
    assert_eq!(status.current_phase, HarnessPhase::Done);
    assert_eq!(status.model_calls, 1);
    assert_eq!(result.run.text(), Some("ok".to_string()));
}
/// A freshly constructed `AgentRun` has recorded no activity at all.
#[test]
fn agent_run_default_is_empty() {
    let run = AgentRun::new();
    // The test name promises an empty run, so check every counter and the transcript,
    // not only the model-call count.
    assert_eq!(run.model_calls, 0);
    assert_eq!(run.tool_calls, 0);
    assert_eq!(run.steps, 0);
    assert!(run.messages.is_empty());
}
/// Middleware that records each streamed model delta: it bumps a shared counter and
/// stores the delta's `content`. Both are `Arc`s, so the test can read them afterwards.
struct DeltaRecorder {
    /// Number of `on_model_delta` invocations.
    count: Arc<Mutex<usize>>,
    /// `content` of each delta, in arrival order.
    texts: Arc<Mutex<Vec<String>>>,
}

#[async_trait]
impl Middleware<(), ()> for DeltaRecorder {
    fn name(&self) -> &str {
        "delta-recorder"
    }

    async fn on_model_delta(
        &self,
        _ctx: &mut RunContext<()>,
        _state: &(),
        delta: &mut crate::harness::model::ModelDelta,
    ) -> Result<()> {
        let chunk = delta.content.clone();
        *self.count.lock().unwrap() += 1;
        self.texts.lock().unwrap().push(chunk);
        Ok(())
    }
}
/// Streaming invokes `on_model_delta` once per chunk and concatenates the chunks
/// into the final text.
#[tokio::test]
async fn invoke_streaming_fires_on_model_delta_per_delta_and_accumulates() {
    use crate::harness::testkit::StreamingMock;

    let chunks = ["Hel", "lo, ", "world"];
    let count = Arc::new(Mutex::new(0usize));
    let texts = Arc::new(Mutex::new(Vec::new()));
    let recorder = DeltaRecorder {
        count: count.clone(),
        texts: texts.clone(),
    };

    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("stream", Arc::new(StreamingMock::from_text_chunks(chunks)));
    harness.push_middleware(Arc::new(recorder));

    let config = RunConfig::new("stream-run");
    let run = harness
        .invoke_streaming(&(), (), config, vec![Message::user("hi")])
        .await
        .expect("streaming run succeeds");

    assert_eq!(run.model_calls, 1);
    assert_eq!(run.text(), Some("Hello, world".to_string()));
    assert_eq!(*count.lock().unwrap(), 3);
    let expected: Vec<String> = chunks.iter().map(|chunk| chunk.to_string()).collect();
    assert_eq!(*texts.lock().unwrap(), expected);
}
/// Each streamed delta is also published to the run's event sink as `model.delta`.
#[tokio::test]
async fn invoke_streaming_emits_model_delta_events() {
    use crate::harness::testkit::{EventRecorder, StreamingMock};

    let mut harness: AgentHarness<()> = AgentHarness::new();
    let streaming = StreamingMock::from_text_chunks(["a", "b"]);
    harness.register_model("stream", Arc::new(streaming));

    let recorder = EventRecorder::new();
    let config = RunConfig::new("stream-run");
    let ctx = RunContext::new(config, ()).with_events(recorder.sink());
    let run = harness
        .invoke_streaming_in_context(&(), ctx, vec![Message::user("hi")])
        .await
        .expect("streaming run succeeds");
    assert_eq!(run.text(), Some("ab".to_string()));

    let delta_events = recorder
        .kinds()
        .iter()
        .filter(|&kind| kind == "model.delta")
        .count();
    assert_eq!(delta_events, 2, "one model.delta event per streamed delta");
}
use crate::harness::cancel::CancellationToken;
/// Chat model that requests tool `name` on every call and increments a counter
/// shared with the test each time it is invoked.
struct CountingToolModel {
    /// Tool requested in every response.
    name: &'static str,
    /// Shared invocation counter.
    invocations: Arc<Mutex<usize>>,
}

#[async_trait]
impl ChatModel<()> for CountingToolModel {
    async fn invoke(&self, _state: &(), _request: ModelRequest) -> Result<ModelResponse> {
        {
            let mut calls = self.invocations.lock().unwrap();
            *calls += 1;
        }
        let response = tool_call_response("call-1", self.name, json!({}));
        Ok(response)
    }
}
/// Tool whose only side effect is cancelling `token` when called, and which then
/// returns the text `"cancelled"`.
struct CancelOnCallTool {
    /// Token cancelled on every call. The tests hand the run context a clone of it.
    token: CancellationToken,
}

#[async_trait]
impl Tool<()> for CancelOnCallTool {
    fn name(&self) -> &str {
        "cancel_me"
    }

    fn description(&self) -> &str {
        "cancels the run"
    }

    fn schema(&self) -> ToolSchema {
        let parameters = json!({"type": "object"});
        ToolSchema::new("cancel_me", "cancels the run", parameters)
    }

    async fn call(&self, _state: &(), call: ToolCall) -> Result<ToolResult> {
        let token = &self.token;
        token.cancel();
        Ok(ToolResult::text(call.id, "cancel_me", "cancelled"))
    }
}
/// A token that is already cancelled when the run starts aborts it before any model call.
#[tokio::test]
async fn token_cancelled_before_run_yields_cancelled() {
    let invocations = Arc::new(Mutex::new(0usize));
    let model = CountingToolModel {
        name: "cancel_me",
        invocations: invocations.clone(),
    };
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", Arc::new(model));

    let token = CancellationToken::new();
    token.cancel();
    let ctx = RunContext::new(RunConfig::new("cancel-run"), ()).with_cancellation(token);

    let outcome = harness
        .invoke_in_context(&(), ctx, vec![Message::user("hi")])
        .await;
    let err = outcome.expect_err("a pre-cancelled run must not complete");
    assert!(matches!(err, TinyAgentsError::Cancelled), "got {err:?}");
    assert_eq!(*invocations.lock().unwrap(), 0);
}
/// A tool that cancels the run mid-loop stops it before the next model call.
#[tokio::test]
async fn cancelled_mid_run_stops_before_next_model_call() {
    let invocations = Arc::new(Mutex::new(0usize));
    let token = CancellationToken::new();

    let mut harness: AgentHarness<()> = AgentHarness::new();
    let model = CountingToolModel {
        name: "cancel_me",
        invocations: Arc::clone(&invocations),
    };
    harness.register_model("mock", Arc::new(model));
    // The tool cancels a clone of the same token the run context watches.
    let canceller = CancelOnCallTool {
        token: token.clone(),
    };
    harness.register_tool(Arc::new(canceller));

    let ctx = RunContext::new(RunConfig::new("cancel-mid"), ()).with_cancellation(token);
    let err = harness
        .invoke_in_context(&(), ctx, vec![Message::user("go")])
        .await
        .expect_err("cancellation during a tool call must stop the run");

    assert!(matches!(err, TinyAgentsError::Cancelled), "got {err:?}");
    // Only the first model call, the one that requested the tool, ran.
    assert_eq!(*invocations.lock().unwrap(), 1);
}
/// A model call that outlasts the run's timeout budget fails with `Timeout`.
#[tokio::test]
async fn slow_model_call_is_timed_out_by_remaining_budget() {
    use std::time::Duration;

    use crate::harness::testkit::SlowModel;

    // SlowModel configured with 200 ms, against a 20 ms run budget.
    let slow = SlowModel::new(Duration::from_millis(200), "too late");
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("slow", Arc::new(slow));

    let config = RunConfig::new("timeout-run").with_timeout_ms(20);
    let outcome = harness
        .invoke(&(), (), config, vec![Message::user("hi")])
        .await;
    let err = outcome.expect_err("a model call slower than the budget must time out");
    assert!(matches!(err, TinyAgentsError::Timeout(_)), "got {err:?}");
}
/// An immediate model reply completes normally under the same 20 ms budget.
#[tokio::test]
async fn fast_model_call_succeeds_under_same_budget() {
    let config = RunConfig::new("fast-run").with_timeout_ms(20);
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("fast", Arc::new(MockModel::constant("done")));

    let run = harness
        .invoke(&(), (), config, vec![Message::user("hi")])
        .await
        .expect("a fast model call completes within the budget");
    let expected = Some(String::from("done"));
    assert_eq!(run.text(), expected);
}
/// The timeout budget also applies to streaming model calls.
#[tokio::test]
async fn slow_streaming_model_call_is_timed_out() {
    use std::time::Duration;

    use crate::harness::testkit::SlowModel;

    let slow = SlowModel::new(Duration::from_millis(200), "too late");
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("slow", Arc::new(slow));

    let config = RunConfig::new("timeout-stream-run").with_timeout_ms(20);
    let outcome = harness
        .invoke_streaming(&(), (), config, vec![Message::user("hi")])
        .await;
    let err = outcome.expect_err("a slow streaming model call must time out");
    assert!(matches!(err, TinyAgentsError::Timeout(_)), "got {err:?}");
}
/// With a response cache attached, an identical second request is answered from the
/// cache: the model is not invoked again, and `cache.miss` / `cache.hit` events are emitted.
#[tokio::test]
async fn response_cache_serves_repeated_request_without_calling_model() {
    use crate::harness::cache::InMemoryResponseCache;
    use crate::harness::testkit::EventRecorder;

    // The second scripted reply must never be served if caching works.
    let model = Arc::new(MockModel::with_responses(vec![
        text_response("first-answer", 4, 2),
        text_response("second-answer", 4, 2),
    ]));
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", model.clone());
    harness.with_response_cache(Arc::new(InMemoryResponseCache::new()));

    // First run: cold cache.
    let first_events = EventRecorder::new();
    let first_ctx =
        RunContext::new(RunConfig::new("cache-run"), ()).with_events(first_events.sink());
    let first = harness
        .invoke_in_context(&(), first_ctx, vec![Message::user("same question")])
        .await
        .expect("first run succeeds");
    assert_eq!(model.call_count(), 1, "model invoked once on first run");
    assert_eq!(first.text(), Some("first-answer".to_string()));
    let first_kinds = first_events.kinds();
    assert!(
        first_kinds.iter().any(|k| k == "cache.miss"),
        "first run should emit a cache miss"
    );
    assert!(
        !first_kinds.iter().any(|k| k == "cache.hit"),
        "first run should not emit a cache hit"
    );

    // Second run: same prompt, so it should be served from the cache.
    let second_events = EventRecorder::new();
    let second_ctx =
        RunContext::new(RunConfig::new("cache-run-2"), ()).with_events(second_events.sink());
    let second = harness
        .invoke_in_context(&(), second_ctx, vec![Message::user("same question")])
        .await
        .expect("second run succeeds");
    assert_eq!(
        model.call_count(),
        1,
        "model must NOT be invoked again on a cache hit"
    );
    assert_eq!(
        second.text(),
        Some("first-answer".to_string()),
        "cached response text is reused"
    );
    assert!(
        second_events.kinds().iter().any(|k| k == "cache.hit"),
        "second run should emit a cache hit"
    );
    // A cache hit still counts as one model call for the run.
    assert_eq!(second.model_calls, 1);
}
/// Without a response cache, every run reaches the model.
#[tokio::test]
async fn no_cache_attached_invokes_model_each_run() {
    let model = Arc::new(MockModel::echo());
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", model.clone());

    // Two identical runs; each `expect` label names the run it guards.
    for label in ["first run succeeds", "second run succeeds"] {
        harness
            .invoke_default(&(), vec![Message::user("hello")])
            .await
            .expect(label);
    }

    assert_eq!(
        model.call_count(),
        2,
        "without a cache the model is invoked on every run"
    );
}
/// A `cache_policy` set on the request by middleware overrides the harness-level cache:
/// with the response cache disabled per request, every run reaches the model.
#[tokio::test]
async fn request_cache_policy_overrides_run_policy_to_disable_caching() {
    use crate::harness::cache::{CachePolicy, InMemoryResponseCache};

    /// Middleware that disables response caching on every model request.
    struct DisableCaching;

    #[async_trait]
    impl Middleware<(), ()> for DisableCaching {
        fn name(&self) -> &str {
            "disable-caching"
        }

        async fn before_model(
            &self,
            _ctx: &mut RunContext<()>,
            _state: &(),
            request: &mut ModelRequest,
        ) -> Result<()> {
            let policy = CachePolicy {
                response_cache_enabled: false,
                protect_prompt_prefix: false,
            };
            request.cache_policy = Some(policy);
            Ok(())
        }
    }

    let model = Arc::new(MockModel::echo());
    let mut harness: AgentHarness<()> = AgentHarness::new();
    harness.register_model("mock", model.clone());
    harness.with_response_cache(Arc::new(InMemoryResponseCache::new()));
    harness.push_middleware(Arc::new(DisableCaching));

    for label in ["first run succeeds", "second run succeeds"] {
        harness
            .invoke_default(&(), vec![Message::user("hello")])
            .await
            .expect(label);
    }

    assert_eq!(
        model.call_count(),
        2,
        "request-level cache_policy disabling caching must bypass the cache"
    );
}