mod types;
pub use types::*;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use crate::error::{Result, TinyAgentsError};
use crate::harness::cache::{ResponseCache, cache_key};
use crate::harness::context::{RunConfig, RunContext};
use crate::harness::events::{AgentEvent, HarnessRunStatus, LimitKind};
use crate::harness::ids::{CallId, ComponentId, HarnessPhase};
use crate::harness::message::{Message, MessageDelta};
use crate::harness::middleware::{
AgentRun, BoxModelFuture, BoxToolFuture, ModelBaseCall, ToolBaseCall,
};
use crate::harness::model::{
ChatModel, ModelDelta, ModelRequest, ModelResolutionSource, ModelResponse, ModelStreamItem,
ResolvedModel, ResolvedModelBinding, ResponseFormat, StreamAccumulator, ToolChoice,
};
use crate::harness::retry::is_retryable;
use crate::harness::runtime::AgentHarness;
use crate::harness::structured::{StructuredExtractor, StructuredStrategy};
use crate::harness::tool::{Tool, ToolCall, ToolSchema};
use futures::StreamExt;
use serde_json::Value;
/// Agent-loop entry points and the model/tool invocation machinery behind them.
impl<State: Send + Sync, Ctx: Send + Sync> AgentHarness<State, Ctx> {
    /// Runs the agent loop (non-streaming) in a fresh [`RunContext`] built from
    /// `config` and `ctx_data`, returning the finished [`AgentRun`].
    ///
    /// # Errors
    ///
    /// Returns whatever error ended the loop (middleware, model, tool, limit,
    /// deadline or cancellation). Use [`Self::invoke_with_status`] to also get
    /// the final [`HarnessRunStatus`] on success.
    pub async fn invoke(
        &self,
        state: &State,
        ctx_data: Ctx,
        config: RunConfig,
        input: Vec<Message>,
    ) -> Result<AgentRun> {
        self.invoke_with_status(state, ctx_data, config, input)
            .await
            .map(|result| result.run)
    }

    /// [`Self::invoke`] with a default context value and a `RunConfig` whose
    /// run id is `"run"`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_default(&self, state: &State, input: Vec<Message>) -> Result<AgentRun>
    where
        Ctx: Default,
    {
        self.invoke(state, Ctx::default(), RunConfig::new("run"), input)
            .await
    }

    /// Runs the agent loop (non-streaming) in a fresh [`RunContext`] and returns
    /// both the run and its final status.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_with_status(
        &self,
        state: &State,
        ctx_data: Ctx,
        config: RunConfig,
        input: Vec<Message>,
    ) -> Result<AgentLoopResult> {
        let ctx = RunContext::new(config, ctx_data);
        self.drive(state, ctx, input, false).await
    }

    /// Runs the agent loop (non-streaming) inside a caller-supplied context.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_in_context(
        &self,
        state: &State,
        ctx: RunContext<Ctx>,
        input: Vec<Message>,
    ) -> Result<AgentRun> {
        self.drive(state, ctx, input, false)
            .await
            .map(|result| result.run)
    }

    /// Runs the agent loop (non-streaming) inside a caller-supplied context and
    /// returns both the run and its final status.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_in_context_with_status(
        &self,
        state: &State,
        ctx: RunContext<Ctx>,
        input: Vec<Message>,
    ) -> Result<AgentLoopResult> {
        self.drive(state, ctx, input, false).await
    }

    /// Like [`Self::invoke`], but calls models through their streaming API.
    /// Each text/tool-call delta is emitted as `AgentEvent::ModelDelta` and
    /// passed to the `on_model_delta` middleware hooks.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_streaming(
        &self,
        state: &State,
        ctx_data: Ctx,
        config: RunConfig,
        input: Vec<Message>,
    ) -> Result<AgentRun> {
        let ctx = RunContext::new(config, ctx_data);
        self.drive(state, ctx, input, true)
            .await
            .map(|result| result.run)
    }

    /// [`Self::invoke_streaming`] with a default context value and a
    /// `RunConfig` whose run id is `"run"`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_streaming_default(
        &self,
        state: &State,
        input: Vec<Message>,
    ) -> Result<AgentRun>
    where
        Ctx: Default,
    {
        self.invoke_streaming(state, Ctx::default(), RunConfig::new("run"), input)
            .await
    }

    /// Streaming variant of [`Self::invoke_in_context`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_streaming_in_context(
        &self,
        state: &State,
        ctx: RunContext<Ctx>,
        input: Vec<Message>,
    ) -> Result<AgentRun> {
        self.drive(state, ctx, input, true)
            .await
            .map(|result| result.run)
    }

    /// Streaming variant of [`Self::invoke_in_context_with_status`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::invoke`].
    pub async fn invoke_streaming_in_context_with_status(
        &self,
        state: &State,
        ctx: RunContext<Ctx>,
        input: Vec<Message>,
    ) -> Result<AgentLoopResult> {
        self.drive(state, ctx, input, true).await
    }

    /// Shared driver behind every public entry point.
    ///
    /// Sets up the run status, runs [`Self::run_loop`] and finalises the
    /// status. On failure it emits `AgentEvent::RunFailed`, marks the status
    /// failed, notifies the `on_error` middleware (best-effort) and returns the
    /// original error.
    async fn drive(
        &self,
        state: &State,
        mut ctx: RunContext<Ctx>,
        input: Vec<Message>,
        streaming: bool,
    ) -> Result<AgentLoopResult> {
        let run_id = ctx.config.run_id.clone();
        let mut status = HarnessRunStatus::new(run_id.clone(), ComponentId::new("agent_loop"));
        if let Some(thread) = ctx.config.thread_id.clone() {
            status = status.with_thread(thread);
        }
        let mut run = AgentRun::new();
        let outcome = self
            .run_loop(state, &mut ctx, &mut run, &mut status, input, streaming)
            .await;
        match outcome {
            Ok(()) => {
                status.mark_completed();
                Ok(AgentLoopResult { run, status })
            }
            Err(error) => {
                let message = error.to_string();
                let record = ctx.emit(AgentEvent::RunFailed {
                    run_id,
                    error: message.clone(),
                });
                status.set_last_event(record.id);
                status.mark_failed(message);
                // Error hooks are best-effort: a failing hook must not mask the
                // error that actually ended the run.
                let _ = self.middleware.run_on_error(&mut ctx, &error).await;
                Err(error)
            }
        }
    }

    /// Core agent loop: alternates model calls and tool executions until the
    /// model replies without tool calls (or calls the structured-output tool),
    /// steering pauses the run, or a limit, deadline or cancellation aborts it.
    ///
    /// Updates `run` (counters, usage, transcript, final response, structured
    /// output) and `status` (phase, counters, active calls, last event) as it
    /// goes.
    async fn run_loop(
        &self,
        state: &State,
        ctx: &mut RunContext<Ctx>,
        run: &mut AgentRun,
        status: &mut HarnessRunStatus,
        input: Vec<Message>,
        streaming: bool,
    ) -> Result<()> {
        let record = ctx.emit(AgentEvent::RunStarted {
            run_id: ctx.run_id().clone(),
            thread_id: ctx.thread_id().cloned(),
        });
        status.set_last_event(record.id);
        status.mark_running(HarnessPhase::Idle);
        let mut messages = input;
        status.mark_running(HarnessPhase::Middleware);
        self.middleware.run_before_agent(ctx, state).await?;
        loop {
            if ctx.cancellation.is_cancelled() {
                return Err(TinyAgentsError::Cancelled);
            }
            match crate::harness::steering::apply_pending_steering(ctx, &mut messages)? {
                crate::harness::steering::SteeringOutcome::Cancel => {
                    return Err(TinyAgentsError::Cancelled);
                }
                // Pausing leaves the loop without a `final_response`; the run is
                // still finalised (after-agent middleware, `RunCompleted`) below.
                crate::harness::steering::SteeringOutcome::Pause => break,
                crate::harness::steering::SteeringOutcome::Continue => {}
            }
            Self::enforce_wall_clock_deadline(ctx)?;
            if run.model_calls >= self.policy.limits.max_model_calls {
                ctx.emit(AgentEvent::LimitReached {
                    kind: LimitKind::ModelCalls,
                });
                return Err(self.model_call_limit_error());
            }

            // Build the request for this turn.
            status.mark_running(HarnessPhase::BuildingRequest);
            let mut request = ModelRequest::new(messages.clone()).with_tools(self.tools.schemas());
            if let Some(format) = &self.policy.default_response_format {
                request = request.with_response_format(format.clone());
            }
            // The per-turn output cap can only tighten an existing `max_tokens`.
            if let Some(cap) = ctx.config.max_turn_output_tokens {
                request.max_tokens =
                    Some(request.max_tokens.map_or(cap, |current| current.min(cap)));
            }
            status.mark_running(HarnessPhase::Middleware);
            self.middleware
                .run_before_model(ctx, state, &mut request)
                .await?;
            // Context-level counter; its failure is reported with the same
            // message as the harness-policy check above.
            ctx.record_model_call()
                .map_err(|_| self.model_call_limit_error())?;
            let binding = self
                .models
                .resolve_request(&request, None, None)
                .ok_or_else(|| {
                    TinyAgentsError::ModelNotFound(
                        request
                            .model
                            .clone()
                            .unwrap_or_else(|| "<default>".to_string()),
                    )
                })?;
            let model_name = binding.resolved.name.clone();
            // Decide how structured output is obtained. `Auto` is rewritten into
            // either a provider JSON schema or a forced tool call, depending on
            // the model profile; the plan is kept for extraction afterwards.
            let structured_plan: Option<(StructuredStrategy, String, Value)> =
                match request.response_format.clone() {
                    Some(ResponseFormat::Auto { name, schema }) => {
                        let strategy = StructuredStrategy::for_profile(binding.model.profile());
                        match strategy {
                            StructuredStrategy::ProviderSchema => {
                                request.response_format =
                                    Some(ResponseFormat::json_schema(name.clone(), schema.clone()));
                            }
                            StructuredStrategy::ToolCall => {
                                request.response_format = Some(ResponseFormat::Text);
                                request.tools.push(ToolSchema {
                                    name: name.clone(),
                                    description: format!("Return the result as `{name}`."),
                                    parameters: schema.clone(),
                                    format: crate::harness::tool::ToolFormat::Json,
                                });
                                request.tool_choice = ToolChoice::Tool(name.clone());
                            }
                        }
                        Some((strategy, name, schema))
                    }
                    Some(ResponseFormat::JsonSchema { name, schema }) => {
                        Some((StructuredStrategy::ProviderSchema, name, schema))
                    }
                    _ => None,
                };

            // Invoke the model through the middleware chain.
            let call_id = CallId::new(format!("{}-model-{}", ctx.run_id(), run.model_calls + 1));
            status.mark_running(HarnessPhase::Model);
            status.active_model_call = Some(call_id.clone());
            let record = ctx.emit(AgentEvent::ModelStarted {
                call_id: call_id.clone(),
                model: model_name,
            });
            status.set_last_event(record.id);
            let base = ModelCallBase {
                harness: self,
                call_id: call_id.clone(),
                resolved: binding.resolved,
                model: binding.model,
                streaming,
            };
            let mut response = self
                .middleware
                .run_wrapped_model(ctx, state, request, &base)
                .await?
                .into_response();
            status.mark_running(HarnessPhase::Middleware);
            self.middleware
                .run_after_model(ctx, state, &mut response)
                .await?;
            run.model_calls += 1;
            run.steps += 1;
            status.model_calls = run.model_calls;
            status.active_model_call = None;
            if let Some(usage) = response.usage {
                run.usage.record(usage);
                status.usage = run.usage;
                let record = ctx.emit(AgentEvent::UsageRecorded { usage });
                status.set_last_event(record.id);
            }
            let record = ctx.emit(AgentEvent::ModelCompleted {
                call_id,
                usage: response.usage,
            });
            status.set_last_event(record.id);
            messages.push(Message::Assistant(response.message.clone()));

            // Finish when there is nothing to execute, or when the model called
            // the synthetic structured-output tool.
            let tool_calls = response.tool_calls().to_vec();
            let structured_tool_hit = matches!(
                &structured_plan,
                Some((StructuredStrategy::ToolCall, name, _))
                    if tool_calls.iter().any(|c| &c.name == name)
            );
            if tool_calls.is_empty() || structured_tool_hit {
                if let Some((strategy, name, schema)) = &structured_plan {
                    let extractor =
                        StructuredExtractor::new(*strategy, name.clone(), schema.clone());
                    let output = extractor.extract(&response)?;
                    run.structured = Some(output.value);
                }
                run.final_response = Some(response);
                break;
            }

            // Execute requested tools sequentially, appending each result to
            // the transcript for the next turn.
            status.mark_running(HarnessPhase::Tools);
            for mut call in tool_calls {
                if ctx.cancellation.is_cancelled() {
                    return Err(TinyAgentsError::Cancelled);
                }
                Self::enforce_wall_clock_deadline(ctx)?;
                if run.tool_calls >= self.policy.limits.max_tool_calls {
                    ctx.emit(AgentEvent::LimitReached {
                        kind: LimitKind::ToolCalls,
                    });
                    return Err(self.tool_call_limit_error());
                }
                ctx.record_tool_call()
                    .map_err(|_| self.tool_call_limit_error())?;
                // Middleware may rewrite the call (including its name) before
                // the tool is looked up and validated.
                self.middleware
                    .run_before_tool(ctx, state, &mut call)
                    .await?;
                let tool = self
                    .tools
                    .get(&call.name)
                    .ok_or_else(|| TinyAgentsError::ToolNotFound(call.name.clone()))?;
                tool.schema().validate_call(&call)?;
                let tool_call_id = CallId::new(call.id.clone());
                let tool_name = call.name.clone();
                status.active_tool_calls.push(tool_call_id.clone());
                let record = ctx.emit(AgentEvent::ToolStarted {
                    call_id: tool_call_id.clone(),
                    tool_name: tool_name.clone(),
                });
                status.set_last_event(record.id);
                let base = ToolCallBase { tool };
                let mut result = self
                    .middleware
                    .run_wrapped_tool(ctx, state, call, &base)
                    .await?
                    .into_result();
                self.middleware
                    .run_after_tool(ctx, state, &mut result)
                    .await?;
                run.tool_calls += 1;
                status.tool_calls = run.tool_calls;
                status.active_tool_calls.retain(|c| c != &tool_call_id);
                let record = ctx.emit(AgentEvent::ToolCompleted {
                    call_id: tool_call_id,
                    tool_name,
                });
                status.set_last_event(record.id);
                messages.push(Message::tool(
                    result.call_id.clone(),
                    result.content.clone(),
                ));
            }
        }
        run.messages = messages;
        status.mark_running(HarnessPhase::Middleware);
        self.middleware.run_after_agent(ctx, state, run).await?;
        let record = ctx.emit(AgentEvent::RunCompleted {
            run_id: ctx.run_id().clone(),
        });
        status.set_last_event(record.id);
        Ok(())
    }

    /// Emits a wall-clock `LimitReached` event and fails with
    /// [`TinyAgentsError::Timeout`] when the run's deadline check fails.
    fn enforce_wall_clock_deadline(ctx: &mut RunContext<Ctx>) -> Result<()> {
        if ctx.check_deadline().is_ok() {
            return Ok(());
        }
        ctx.emit(AgentEvent::LimitReached {
            kind: LimitKind::WallClock,
        });
        Err(TinyAgentsError::Timeout(format!(
            "run `{}` exceeded its wall-clock deadline",
            ctx.run_id()
        )))
    }

    /// Error reported when the model-call limit is exhausted.
    fn model_call_limit_error(&self) -> TinyAgentsError {
        TinyAgentsError::LimitExceeded(format!(
            "max model calls ({}) reached",
            self.policy.limits.max_model_calls
        ))
    }

    /// Error reported when the tool-call limit is exhausted.
    fn tool_call_limit_error(&self) -> TinyAgentsError {
        TinyAgentsError::LimitExceeded(format!(
            "max tool calls ({}) reached",
            self.policy.limits.max_tool_calls
        ))
    }

    /// Returns the response cache and this request's cache key when caching
    /// applies.
    ///
    /// A cache must be configured. The request's own cache policy, when
    /// present, overrides the harness-wide `response_cache_enabled` setting.
    fn response_cache_decision(
        &self,
        request: &ModelRequest,
    ) -> Option<(Arc<dyn ResponseCache>, String)> {
        let cache = self.response_cache.as_ref()?;
        let enabled = request
            .cache_policy
            .as_ref()
            .map_or(self.policy.cache.response_cache_enabled, |policy| {
                policy.response_cache_enabled
            });
        // The key is only computed when caching is actually enabled.
        enabled.then(|| (Arc::clone(cache), cache_key(request)))
    }

    /// Serves a model call from the response cache when possible. Otherwise it
    /// invokes the model (with retries and fallback) and stores the fresh
    /// response.
    ///
    /// Emits `CacheHit`/`CacheMiss` events only when caching is enabled. A
    /// cached response without resolution info is stamped with this call's
    /// resolved model.
    ///
    /// # Errors
    ///
    /// Cache read/write errors and model errors are propagated.
    async fn invoke_model_with_retry(
        &self,
        state: &State,
        ctx: &mut RunContext<Ctx>,
        request: &ModelRequest,
        call_id: &CallId,
        binding: ResolvedModelBinding<State>,
        streaming: bool,
    ) -> Result<ModelResponse> {
        let decision = self.response_cache_decision(request);
        if let Some((cache, key)) = decision.as_ref() {
            if let Some(mut cached) = cache.get(key).await? {
                ctx.emit(AgentEvent::CacheHit {
                    call_id: call_id.clone(),
                    key: key.clone(),
                });
                if cached.resolved_model.is_none() {
                    cached.resolved_model = Some(binding.resolved.clone());
                }
                return Ok(cached);
            }
            ctx.emit(AgentEvent::CacheMiss {
                call_id: call_id.clone(),
                key: key.clone(),
            });
        }
        let response = self
            .invoke_model_resolving(state, ctx, request, call_id, binding, streaming)
            .await?;
        if let Some((cache, key)) = decision.as_ref() {
            cache.put(key, response.clone()).await?;
        }
        Ok(response)
    }

    /// Invokes the bound model, retrying retryable failures per the retry
    /// policy and then walking the fallback chain when a model gives up.
    ///
    /// Each attempt is bounded by the remaining wall-clock budget. Between
    /// retries the policy's backoff is awaited, clamped to that budget, and
    /// the wait is cut short by cancellation.
    ///
    /// # Errors
    ///
    /// Returns [`TinyAgentsError::Cancelled`] on cancellation. Otherwise it
    /// returns the last model's error once no fallback model remains.
    async fn invoke_model_resolving(
        &self,
        state: &State,
        ctx: &mut RunContext<Ctx>,
        request: &ModelRequest,
        call_id: &CallId,
        binding: ResolvedModelBinding<State>,
        streaming: bool,
    ) -> Result<ModelResponse> {
        let mut current_name = binding.resolved.name.clone();
        let mut model = binding.model;
        let mut resolved = binding.resolved;
        let run_id = ctx.run_id().clone();
        loop {
            // Retry the current model until it succeeds or its error is final.
            let mut attempt = 0usize;
            let outcome = loop {
                if ctx.cancellation.is_cancelled() {
                    return Err(TinyAgentsError::Cancelled);
                }
                let remaining = self.call_budget(ctx);
                let attempt_result = if streaming {
                    let fut =
                        self.invoke_model_streaming_once(state, ctx, &model, request, call_id);
                    Self::with_call_budget(remaining, run_id.as_str(), fut).await
                } else {
                    let fut = model.invoke(state, request.clone());
                    Self::with_call_budget(remaining, run_id.as_str(), fut).await
                };
                match attempt_result {
                    Ok(response) => break Ok(response),
                    Err(error) => {
                        if !is_retryable(&error) || !self.policy.retry.should_retry(attempt) {
                            break Err(error);
                        }
                    }
                }
                attempt += 1;
                ctx.emit(AgentEvent::RetryScheduled {
                    call_id: call_id.clone(),
                    attempt,
                });
                // Wait out the policy's backoff so a failing provider is not
                // retried back-to-back; cancellation interrupts the wait.
                let delay = self.retry_backoff(ctx, attempt);
                if !delay.is_zero() {
                    let cancellation = ctx.cancellation.clone();
                    tokio::select! {
                        biased;
                        _ = cancellation.cancelled() => {
                            return Err(TinyAgentsError::Cancelled);
                        }
                        _ = tokio::time::sleep(delay) => {}
                    }
                }
            };
            match outcome {
                Ok(mut response) => {
                    if response.resolved_model.is_none() {
                        response.resolved_model = Some(resolved);
                    }
                    return Ok(response);
                }
                Err(error) => {
                    // Move to the next configured fallback model, if any.
                    let next = self
                        .policy
                        .fallback
                        .as_ref()
                        .and_then(|fallback| fallback.next_after(&current_name))
                        .map(str::to_owned);
                    match next.and_then(|name| self.models.get(&name).map(|m| (name, m))) {
                        Some((name, next_model)) => {
                            resolved = ResolvedModel {
                                name: name.clone(),
                                requested: Some(name.clone()),
                                source: ModelResolutionSource::Hint,
                            };
                            current_name = name;
                            model = next_model;
                        }
                        None => return Err(error),
                    }
                }
            }
        }
    }

    /// Backoff to wait before retry number `attempt`. It is clamped to the
    /// remaining call budget so the wait cannot outlive the run's deadline.
    fn retry_backoff(&self, ctx: &RunContext<Ctx>, attempt: usize) -> Duration {
        let backoff = self.policy.retry.backoff_for_attempt(attempt);
        match self.call_budget(ctx) {
            Some(budget) => backoff.min(budget),
            None => backoff,
        }
    }

    /// Remaining wall-clock time available to a model call.
    ///
    /// This is the tighter of the run config's remaining time and the policy's
    /// `max_wall_clock_ms` minus the elapsed time (floored at zero). It is
    /// `None` when neither limit is set.
    fn call_budget(&self, ctx: &RunContext<Ctx>) -> Option<Duration> {
        let config_budget = ctx.remaining_wall_clock();
        let policy_budget = self
            .policy
            .limits
            .max_wall_clock_ms
            .map(|ms| Duration::from_millis(ms).saturating_sub(ctx.limits.elapsed()));
        match (config_budget, policy_budget) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (a, b) => a.or(b),
        }
    }

    /// Awaits `fut`, bounded by `budget` when one is given. Hitting the bound
    /// yields [`TinyAgentsError::Timeout`].
    async fn with_call_budget<F>(budget: Option<Duration>, run_id: &str, fut: F) -> F::Output
    where
        F: Future<Output = Result<ModelResponse>>,
    {
        match budget {
            Some(budget) => match tokio::time::timeout(budget, fut).await {
                Ok(result) => result,
                Err(_) => Err(TinyAgentsError::Timeout(format!(
                    "model call for run `{run_id}` exceeded its remaining wall-clock budget \
                     ({} ms)",
                    budget.as_millis()
                ))),
            },
            None => fut.await,
        }
    }

    /// Performs one streaming model call and accumulates the stream into a
    /// single [`ModelResponse`].
    ///
    /// Each text or tool-call delta is emitted as `AgentEvent::ModelDelta` and
    /// passed to the `on_model_delta` middleware. Cancellation is checked
    /// before each stream item.
    ///
    /// # Errors
    ///
    /// Returns [`TinyAgentsError::Cancelled`] on cancellation. Stream-open,
    /// middleware and accumulation errors are propagated.
    async fn invoke_model_streaming_once(
        &self,
        state: &State,
        ctx: &mut RunContext<Ctx>,
        model: &Arc<dyn ChatModel<State>>,
        request: &ModelRequest,
        call_id: &CallId,
    ) -> Result<ModelResponse> {
        let mut stream = model.stream(state, request.clone()).await?;
        let mut accumulator = StreamAccumulator::new();
        let cancellation = ctx.cancellation.clone();
        loop {
            let item = tokio::select! {
                biased;
                _ = cancellation.cancelled() => {
                    return Err(TinyAgentsError::Cancelled);
                }
                next = stream.next() => match next {
                    Some(item) => item,
                    None => break,
                },
            };
            // Only text and tool-call deltas are surfaced to observers; other
            // items are still fed to the accumulator below.
            let message_delta = match &item {
                ModelStreamItem::MessageDelta(delta) => Some(delta.clone()),
                ModelStreamItem::ToolCallDelta(tool_delta) => Some(MessageDelta {
                    text: String::new(),
                    tool_call: Some(tool_delta.clone()),
                }),
                _ => None,
            };
            if let Some(message_delta) = message_delta {
                let mut model_delta = ModelDelta {
                    call_id: call_id.as_str().to_string(),
                    content: message_delta.text.clone(),
                    tool_call: message_delta.tool_call.clone(),
                };
                ctx.emit(AgentEvent::ModelDelta {
                    call_id: call_id.clone(),
                    delta: message_delta,
                });
                self.middleware
                    .run_on_model_delta(ctx, state, &mut model_delta)
                    .await?;
            }
            accumulator.push(&item);
        }
        accumulator.finish()
    }
}
/// Innermost model step handed to the middleware chain.
///
/// Invokes the bound model through the harness's cached, retrying invocation
/// path for a single model call of a run.
struct ModelCallBase<'h, State, Ctx>
where
    State: Send + Sync,
    Ctx: Send + Sync,
{
    /// Harness whose cache, retry and fallback policy governs the call.
    harness: &'h AgentHarness<State, Ctx>,
    /// Identifier of this model call, attached to the events it emits.
    call_id: CallId,
    /// Resolution metadata describing `model`.
    resolved: ResolvedModel,
    /// Model used for the first attempt (fallbacks may replace it).
    model: Arc<dyn ChatModel<State>>,
    /// Whether the model is called through its streaming API.
    streaming: bool,
}
impl<State, Ctx> ModelBaseCall<State, Ctx> for ModelCallBase<'_, State, Ctx>
where
    State: Send + Sync,
    Ctx: Send + Sync,
{
    /// Runs `request` against the bound model via the harness's cached,
    /// retrying invocation path. A fresh binding is built from this base each
    /// time the returned future runs.
    fn call<'a>(
        &'a self,
        ctx: &'a mut RunContext<Ctx>,
        state: &'a State,
        request: ModelRequest,
    ) -> BoxModelFuture<'a> {
        let Self {
            harness,
            call_id,
            resolved,
            model,
            streaming,
        } = self;
        Box::pin(async move {
            let binding = ResolvedModelBinding {
                resolved: resolved.clone(),
                model: Arc::clone(model),
            };
            harness
                .invoke_model_with_retry(state, ctx, &request, call_id, binding, *streaming)
                .await
        })
    }
}
/// Innermost tool step handed to the middleware chain: executes one tool call
/// with an execution context derived from the run.
struct ToolCallBase<State>
where
    State: Send + Sync,
{
    /// Tool looked up from the harness's tool registry for the current call.
    tool: Arc<dyn Tool<State>>,
}
impl<State, Ctx> ToolBaseCall<State, Ctx> for ToolCallBase<State>
where
    State: Send + Sync,
    Ctx: Send + Sync,
{
    /// Executes `call` on the wrapped tool, passing an execution context built
    /// from the current run context.
    fn call<'a>(
        &'a self,
        ctx: &'a mut RunContext<Ctx>,
        state: &'a State,
        call: ToolCall,
    ) -> BoxToolFuture<'a> {
        let tool = &self.tool;
        Box::pin(async move {
            let execution = crate::harness::tool::ToolExecutionContext::from_run_context(ctx);
            tool.call_with_context(state, call, execution).await
        })
    }
}
#[cfg(test)]
mod test;