use std::{collections::HashSet, future::Future, pin::Pin};
use futures::{StreamExt as _, stream::Stream};
use serde_json::Value;
use tracing::{Instrument, debug, error, info, info_span, warn};
use super::{
config::Agent,
result::{
NextStep, RunConfig, RunEvent, RunResult, StepInfo, ToolCallRecord, ToolCallRequest,
UserInput,
},
};
use crate::{
chat::{ChatProvider, ChatRequest, ChatResponse, ToolChoice},
error::{AgentError, Error, Result},
guardrail::{InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult},
hooks::{Hooks, NoopHooks, RunContext},
message::Message,
stream::{StreamAggregator, StreamChunk},
tool::{
BoxedTool, ConfirmationHandler, ToolConfirmationRequest, ToolConfirmationResponse,
ToolDefinition, ToolExecutionPolicy,
},
usage::Usage,
};
/// Outcome of processing a single LLM response, used by the run loops to
/// decide whether to keep iterating.
enum StepOutcome {
    /// The model produced a final answer; the run is finished.
    Done(RunResult),
    /// Tool calls were handled (executed or denied); another LLM step is needed.
    Continue,
}
/// Mutable state threaded through a single agent run (plain or streamed).
///
/// Built once by [`RunState::init`] and then driven step-by-step by the
/// run loops in `Runner`.
struct RunState<'a> {
    agent: &'a Agent,
    /// Chat provider resolved from the agent; `init` fails if none is configured.
    provider: &'a dyn ChatProvider,
    context: RunContext,
    /// Full transcript sent to the LLM each step: [system?] + history + user +
    /// assistant/tool messages appended as the run progresses.
    messages: Vec<Message>,
    /// One entry per completed step; drained into the final `RunResult`.
    step_history: Vec<StepInfo>,
    /// Token usage summed across LLM calls and managed sub-agent runs.
    cumulative_usage: Usage,
    /// Tool names approved via `ApproveAll`; future confirmations are skipped.
    auto_approved: HashSet<String>,
    /// The new user input message, kept for session persistence on completion.
    user_message: Message,
    system_prompt: String,
    /// Tool definitions exposed to the model (own tools + managed agents).
    all_definitions: Vec<ToolDefinition>,
    all_output_guardrails: Vec<&'a OutputGuardrail>,
    /// Guardrail results accumulated so far (sequential ones run in `init`).
    input_guardrail_results: Vec<InputGuardrailResult>,
    /// Input guardrails deferred so they can overlap the first LLM call.
    parallel_guardrails: Vec<&'a InputGuardrail>,
    max_steps: usize,
    /// Optional cap on concurrent tool executions within one step.
    max_tool_concurrency: Option<usize>,
    /// Whether the agent declares an output schema; final text is then parsed as JSON.
    structured_output: bool,
}
impl<'a> RunState<'a> {
    /// Builds the initial run state.
    ///
    /// Resolves the provider, assembles the message transcript (system
    /// prompt, persisted session history, then the new user input), collects
    /// tool definitions, partitions input guardrails into sequential vs.
    /// parallel, and runs the sequential ones up front.
    ///
    /// # Errors
    /// Fails when the agent has no provider, session history cannot be
    /// loaded, or a sequential input guardrail is triggered.
    async fn init(agent: &'a Agent, input: UserInput, config: &'a RunConfig) -> Result<Self> {
        let provider = agent.provider.as_deref().ok_or_else(|| {
            AgentError::runtime(format!(
                "Agent '{}' has no provider configured. Call .provider() before running.",
                agent.name
            ))
        })?;
        // Per-run override wins over the agent's default step cap.
        let max_steps = config.max_steps.unwrap_or(agent.max_steps);
        let context = RunContext::new().with_agent_name(&agent.name);
        let mut messages = Vec::new();
        let system_prompt = agent.resolve_instructions();
        if !system_prompt.is_empty() {
            messages.push(Message::system(&system_prompt));
        }
        let user_message = input.into_message();
        messages.push(user_message.clone());
        if let Some(ref session) = config.session {
            let history = session.get_messages(None).await?;
            if !history.is_empty() {
                // Splice prior history just before the new user message so the
                // transcript reads: [system?] + history + user.
                let insert_pos = messages.len().saturating_sub(1);
                messages.splice(insert_pos..insert_pos, history);
            }
        }
        let all_definitions = Runner::collect_all_definitions(agent);
        let tool_names: Vec<&str> = all_definitions.iter().map(ToolDefinition::name).collect();
        tracing::Span::current().record("agent.tools", tracing::field::debug(&tool_names));
        let all_input_guardrails = Runner::collect_input_guardrails(agent, config);
        let all_output_guardrails = Runner::collect_output_guardrails(agent, config);
        let mut input_guardrail_results = Vec::new();
        // Sequential guardrails must pass before the first LLM call; parallel
        // ones are deferred so they can overlap it (see the run loops).
        let sequential: Vec<_> = all_input_guardrails
            .iter()
            .filter(|g| !g.is_parallel())
            .copied()
            .collect();
        let parallel: Vec<_> = all_input_guardrails
            .iter()
            .filter(|g| g.is_parallel())
            .copied()
            .collect();
        if !sequential.is_empty() {
            let results =
                Runner::run_input_guardrails(&sequential, &context, &agent.name, &messages).await?;
            input_guardrail_results.extend(results);
        }
        Ok(Self {
            agent,
            provider,
            context,
            messages,
            step_history: Vec::new(),
            cumulative_usage: Usage::zero(),
            auto_approved: HashSet::new(),
            user_message,
            system_prompt,
            all_definitions,
            all_output_guardrails,
            input_guardrail_results,
            parallel_guardrails: parallel,
            max_steps,
            max_tool_concurrency: config.max_tool_concurrency,
            structured_output: agent.output_schema.is_some(),
        })
    }

    /// System prompt as `Some(&str)` when non-empty; used for hook callbacks.
    fn system_ref(&self) -> Option<&str> {
        (!self.system_prompt.is_empty()).then_some(self.system_prompt.as_str())
    }

    /// Chat request for the current transcript (non-streaming).
    fn build_request(&self) -> ChatRequest {
        Runner::build_request(self.agent, &self.messages, &self.all_definitions)
    }

    /// Same as [`Self::build_request`] but with streaming enabled.
    fn build_stream_request(&self) -> ChatRequest {
        let mut req = self.build_request();
        req.stream = true;
        req
    }

    /// Adds an LLM response's token usage to the run totals and the context.
    fn accumulate_usage(&mut self, response: &ChatResponse) {
        if let Some(usage) = response.usage {
            self.cumulative_usage += usage;
            self.context.add_usage(usage);
        }
    }

    /// Adds usage reported by tool calls (managed sub-agent runs) to the totals.
    fn accumulate_tool_usage(&mut self, records: &[ToolCallRecord]) {
        for record in records {
            if record.sub_usage.total_tokens > 0 {
                self.cumulative_usage += record.sub_usage;
                self.context.add_usage(record.sub_usage);
            }
        }
    }

    /// Processes one LLM response: classifies it, applies tool execution
    /// policies, runs tools (after user confirmation where required), and
    /// records the step. Returns `Done` with the final result or `Continue`
    /// for another LLM round-trip.
    ///
    /// # Errors
    /// Propagates output-guardrail triggers, tool-execution errors, and a
    /// missing confirmation handler when approval is required.
    async fn process_step(
        &mut self,
        step: usize,
        response: ChatResponse,
        hooks: &dyn Hooks,
        agent_name: &str,
        config: &RunConfig,
    ) -> Result<StepOutcome> {
        let next_step = Runner::classify_response(&response, self.structured_output);
        let (next_step, forbidden) =
            Runner::apply_policies(next_step, self.agent, &self.auto_approved);
        match next_step {
            NextStep::FinalOutput { ref output } => {
                self.messages.push(response.message.clone());
                self.step_history.push(StepInfo {
                    step,
                    response: response.clone(),
                    tool_calls: Vec::new(),
                });
                let output_value = output.clone();
                // Output guardrails run (and may reject) before the end hook fires.
                let output_guardrail_results = Runner::run_output_guardrails(
                    &self.all_output_guardrails,
                    &self.context,
                    &self.agent.name,
                    &output_value,
                )
                .await?;
                hooks
                    .on_agent_end(&self.context, agent_name, &output_value)
                    .await;
                // Persist only the new user turn and the final assistant
                // message; session save failures are deliberately ignored.
                if let Some(ref session) = config.session {
                    let to_save = vec![self.user_message.clone(), response.message.clone()];
                    let _ = session.add_messages(&to_save).await;
                }
                tracing::Span::current().record("agent.result_steps", step);
                info!(
                    agent = %self.agent.name,
                    steps = step,
                    input_tokens = self.cumulative_usage.input_tokens,
                    output_tokens = self.cumulative_usage.output_tokens,
                    "Agent run completed",
                );
                let result = RunResult {
                    output: output_value,
                    usage: self.cumulative_usage,
                    steps: step,
                    // `take` drains state so the result owns the histories.
                    step_history: std::mem::take(&mut self.step_history),
                    agent_name: self.agent.name.clone(),
                    input_guardrail_results: std::mem::take(&mut self.input_guardrail_results),
                    output_guardrail_results,
                };
                Ok(StepOutcome::Done(result))
            }
            NextStep::ToolCalls { ref calls } => {
                self.messages.push(response.message.clone());
                // Policy-forbidden calls still need tool-result messages so
                // the transcript stays well-formed for the next LLM call.
                Runner::append_denied_messages(
                    &forbidden,
                    "forbidden by execution policy",
                    &mut self.messages,
                );
                let tool_records = Runner::execute_tool_calls(
                    calls,
                    self.agent,
                    &self.context,
                    hooks,
                    agent_name,
                    &mut self.messages,
                    self.max_tool_concurrency,
                )
                .await?;
                self.accumulate_tool_usage(&tool_records);
                self.step_history.push(StepInfo {
                    step,
                    response,
                    tool_calls: tool_records,
                });
                Ok(StepOutcome::Continue)
            }
            NextStep::NeedsApproval {
                ref pending_approval,
                ref approved,
            } => {
                self.messages.push(response.message.clone());
                Runner::append_denied_messages(
                    &forbidden,
                    "forbidden by execution policy",
                    &mut self.messages,
                );
                let handler = config.confirmation_handler.as_deref().ok_or_else(|| {
                    AgentError::runtime(
                        "Tool execution requires approval but no confirmation handler is configured",
                    )
                })?;
                // Ask the user about each pending call; "approve all" answers
                // are remembered in `auto_approved` for later steps.
                let (confirmed, denied) =
                    Runner::seek_confirmations(pending_approval, handler, &mut self.auto_approved)
                        .await;
                Runner::append_denied_messages(&denied, "denied by user", &mut self.messages);
                let executable: Vec<ToolCallRequest> =
                    approved.iter().chain(&confirmed).cloned().collect();
                let tool_records = if executable.is_empty() {
                    Vec::new()
                } else {
                    Runner::execute_tool_calls(
                        &executable,
                        self.agent,
                        &self.context,
                        hooks,
                        agent_name,
                        &mut self.messages,
                        self.max_tool_concurrency,
                    )
                    .await?
                };
                self.accumulate_tool_usage(&tool_records);
                self.step_history.push(StepInfo {
                    step,
                    response,
                    tool_calls: tool_records,
                });
                Ok(StepOutcome::Continue)
            }
        }
    }
}
/// Stateless entry point for executing agent runs; see [`Runner::run`] and
/// [`Runner::run_streamed`].
#[derive(Debug, Clone, Copy)]
pub struct Runner;
impl Runner {
    /// Runs `agent` to completion on `input` and returns the final result.
    ///
    /// Returns a boxed future so managed sub-agents can recursively invoke
    /// `run` from inside tool dispatch without an infinitely-sized future
    /// type. The entire run is instrumented with an `agent` tracing span.
    pub fn run<'a>(
        agent: &'a Agent,
        input: impl Into<UserInput>,
        config: RunConfig,
    ) -> Pin<Box<dyn Future<Output = Result<RunResult>> + Send + 'a>> {
        let input = input.into();
        let span = info_span!(
            "agent",
            agent.name = %agent.name,
            agent.model = %agent.model,
            gen_ai.system = "machi",
            agent.max_steps = agent.max_steps,
            agent.tools = tracing::field::Empty,
            agent.result_steps = tracing::field::Empty,
            error = tracing::field::Empty,
        );
        Box::pin(Self::run_inner(agent, input, config).instrument(span))
    }

    /// Non-streaming run loop: one LLM call per step until a final output is
    /// produced or `max_steps` is exhausted.
    ///
    /// # Errors
    /// Propagates provider failures, guardrail triggers, tool/approval
    /// errors, and the max-steps error when the step budget runs out.
    async fn run_inner(agent: &Agent, input: UserInput, config: RunConfig) -> Result<RunResult> {
        let noop = NoopHooks;
        let hooks: &dyn Hooks = config.hooks.as_deref().unwrap_or(&noop);
        let mut state = RunState::init(agent, input, &config).await?;
        hooks.on_agent_start(&state.context, &agent.name).await;
        for step in 1..=state.max_steps {
            state.context.advance_step();
            debug!(agent = %agent.name, step, "Starting step");
            let request = state.build_request();
            hooks
                .on_llm_start(
                    &state.context,
                    &agent.name,
                    state.system_ref(),
                    &state.messages,
                )
                .await;
            // On the first step, parallel input guardrails overlap the LLM
            // call; a triggered guardrail (`?`) discards the LLM result.
            let response = if step == 1 && !state.parallel_guardrails.is_empty() {
                let (guardrail_result, llm_result) = tokio::join!(
                    Self::run_input_guardrails(
                        &state.parallel_guardrails,
                        &state.context,
                        &agent.name,
                        &state.messages,
                    ),
                    state.provider.chat(&request),
                );
                state.input_guardrail_results.extend(guardrail_result?);
                llm_result
            } else {
                state.provider.chat(&request).await
            }
            .map_err(|e| {
                error!(error = %e, agent = %agent.name, step, "LLM call failed");
                tracing::Span::current().record("error", tracing::field::display(&e));
                e
            })?;
            hooks
                .on_llm_end(&state.context, &agent.name, &response)
                .await;
            state.accumulate_usage(&response);
            match state
                .process_step(step, response, hooks, &agent.name, &config)
                .await?
            {
                StepOutcome::Done(result) => return Ok(result),
                StepOutcome::Continue => {}
            }
        }
        // Step budget exhausted without a final answer.
        let err = Error::from(AgentError::max_steps(state.max_steps));
        error!(error = %err, agent = %agent.name, max_steps = state.max_steps, "Max steps exceeded");
        tracing::Span::current().record("error", tracing::field::display(&err));
        hooks.on_error(&state.context, &agent.name, &err).await;
        Err(err)
    }

    /// Runs the agent and yields [`RunEvent`]s (text/reasoning/audio deltas,
    /// tool lifecycle, step and run completion) instead of a single result.
    pub fn run_streamed<'a>(
        agent: &'a Agent,
        input: impl Into<UserInput>,
        config: RunConfig,
    ) -> Pin<Box<dyn Stream<Item = Result<RunEvent>> + Send + 'a>> {
        let input = input.into();
        Box::pin(Self::run_streamed_inner(agent, input, config))
    }

    /// Streaming run loop. Mirrors [`Self::run_inner`] but forwards provider
    /// stream chunks as events while aggregating them into a full
    /// `ChatResponse` for the usual step processing.
    #[allow(tail_expr_drop_order)]
    fn run_streamed_inner(
        agent: &Agent,
        input: UserInput,
        config: RunConfig,
    ) -> impl Stream<Item = Result<RunEvent>> + Send + '_ {
        async_stream::try_stream! {
            let noop = NoopHooks;
            let hooks: &dyn Hooks = config.hooks.as_deref().unwrap_or(&noop);
            let mut state = RunState::init(agent, input, &config).await?;
            info!(
                agent = %agent.name,
                model = %agent.model,
                tools = ?state.all_definitions.iter().map(ToolDefinition::name).collect::<Vec<_>>(),
                gen_ai.system = "machi",
                "Agent streamed run started",
            );
            hooks.on_agent_start(&state.context, &agent.name).await;
            yield RunEvent::RunStarted { agent_name: agent.name.clone() };
            for step in 1..=state.max_steps {
                state.context.advance_step();
                debug!(agent = %agent.name, step, "Starting streamed step");
                yield RunEvent::StepStarted { step };
                let request = state.build_stream_request();
                hooks
                    .on_llm_start(&state.context, &agent.name, state.system_ref(), &state.messages)
                    .await;
                // Unlike the non-streaming path, parallel guardrails run
                // before (not alongside) the first LLM call here, since the
                // chunk stream must be consumed as it arrives.
                if step == 1 && !state.parallel_guardrails.is_empty() {
                    let par_results = Self::run_input_guardrails(
                        &state.parallel_guardrails,
                        &state.context,
                        &agent.name,
                        &state.messages,
                    )
                    .await?;
                    state.input_guardrail_results.extend(par_results);
                }
                let mut chunk_stream = state.provider.chat_stream(&request).await?;
                let mut aggregator = StreamAggregator::new();
                // Forward each chunk as an event while folding it into the
                // aggregator that rebuilds the complete response.
                while let Some(chunk_result) = chunk_stream.next().await {
                    let chunk = chunk_result?;
                    match &chunk {
                        StreamChunk::Text(delta) => {
                            yield RunEvent::TextDelta(delta.clone());
                        }
                        StreamChunk::ReasoningContent(delta) => {
                            yield RunEvent::ReasoningDelta(delta.clone());
                        }
                        StreamChunk::Audio { data, transcript } => {
                            yield RunEvent::AudioDelta {
                                data: data.clone(),
                                transcript: transcript.clone(),
                            };
                        }
                        StreamChunk::ToolUseStart { id, name, .. } => {
                            yield RunEvent::ToolCallStarted {
                                id: id.clone(),
                                name: name.clone(),
                            };
                        }
                        _ => {}
                    }
                    aggregator.apply(&chunk);
                }
                let response = aggregator.into_chat_response();
                hooks.on_llm_end(&state.context, &agent.name, &response).await;
                state.accumulate_usage(&response);
                match state.process_step(step, response, hooks, &agent.name, &config).await? {
                    StepOutcome::Done(result) => {
                        if let Some(last_step) = result.step_history.last() {
                            yield RunEvent::StepCompleted {
                                step_info: Box::new(last_step.clone()),
                            };
                        }
                        yield RunEvent::RunCompleted {
                            result: Box::new(result),
                        };
                        return;
                    }
                    StepOutcome::Continue => {
                        // process_step pushed a StepInfo in every Continue path.
                        let last = state.step_history.last().expect("just pushed");
                        for record in &last.tool_calls {
                            yield RunEvent::ToolCallCompleted {
                                record: record.clone(),
                            };
                        }
                        yield RunEvent::StepCompleted {
                            step_info: Box::new(last.clone()),
                        };
                    }
                }
            }
            // Step budget exhausted: surface the error through the stream.
            let err = Error::from(AgentError::max_steps(state.max_steps));
            error!(error = %err, agent = %agent.name, max_steps = state.max_steps, "Max steps exceeded");
            hooks.on_error(&state.context, &agent.name, &err).await;
            Err(err)?;
        }
    }
}
}
impl Runner {
    /// Collects tool definitions exposed to the model: the agent's own tools
    /// plus one synthetic hand-off tool per managed agent.
    fn collect_all_definitions(agent: &Agent) -> Vec<ToolDefinition> {
        agent
            .tools
            .iter()
            .map(|t| t.definition())
            .chain(agent.managed_agents.iter().map(Agent::tool_definition))
            .collect()
    }

    /// Builds the chat request for one step: model + transcript, tool
    /// definitions (with auto tool choice and parallel calls enabled) when
    /// any exist, and the structured-output response format when configured.
    fn build_request(
        agent: &Agent,
        messages: &[Message],
        definitions: &[ToolDefinition],
    ) -> ChatRequest {
        let mut request = ChatRequest::with_messages(&agent.model, messages.to_vec());
        if !definitions.is_empty() {
            request = request
                .tools(definitions.to_vec())
                .tool_choice(ToolChoice::Auto)
                .parallel_tool_calls(true);
        }
        if let Some(ref schema) = agent.output_schema {
            request = request.response_format(schema.to_response_format());
        }
        request
    }

    /// Classifies an LLM response: tool calls take priority; otherwise the
    /// text becomes the final output. With structured output enabled the
    /// text is parsed as JSON, falling back to the raw string on parse error.
    fn classify_response(response: &ChatResponse, structured_output: bool) -> NextStep {
        if let Some(tool_calls) = response.tool_calls() {
            let calls: Vec<ToolCallRequest> =
                tool_calls.iter().map(ToolCallRequest::from).collect();
            if !calls.is_empty() {
                return NextStep::ToolCalls { calls };
            }
        }
        let output = if structured_output {
            response.text().map_or(Value::Null, |text| {
                serde_json::from_str(&text).unwrap_or(Value::String(text))
            })
        } else {
            response.text().map_or(Value::Null, Value::String)
        };
        NextStep::FinalOutput { output }
    }

    /// Applies per-tool execution policies to a `ToolCalls` step.
    ///
    /// Partitions requested calls into approved, pending-confirmation, and
    /// forbidden sets. Tools the user previously approved via "approve all"
    /// bypass confirmation. Returns the (possibly downgraded) next step plus
    /// the forbidden calls, which the caller must answer with denial messages.
    /// Non-`ToolCalls` steps pass through untouched.
    fn apply_policies(
        next_step: NextStep,
        agent: &Agent,
        auto_approved: &HashSet<String>,
    ) -> (NextStep, Vec<ToolCallRequest>) {
        let NextStep::ToolCalls { calls } = next_step else {
            return (next_step, Vec::new());
        };
        let mut approved = Vec::new();
        let mut pending = Vec::new();
        let mut forbidden = Vec::new();
        for call in calls {
            // Tools without an explicit policy default to Auto (run freely).
            let policy = agent
                .tool_policies
                .get(&call.name)
                .copied()
                .unwrap_or(ToolExecutionPolicy::Auto);
            match policy {
                ToolExecutionPolicy::Auto => approved.push(call),
                ToolExecutionPolicy::RequireConfirmation => {
                    if auto_approved.contains(&call.name) {
                        approved.push(call);
                    } else {
                        pending.push(call);
                    }
                }
                ToolExecutionPolicy::Forbidden => forbidden.push(call),
            }
        }
        let result = if pending.is_empty() {
            NextStep::ToolCalls { calls: approved }
        } else {
            NextStep::NeedsApproval {
                pending_approval: pending,
                approved,
            }
        };
        (result, forbidden)
    }

    /// Executes the given tool calls with bounded concurrency, then appends
    /// one tool-result message per call to `messages` (in call order).
    ///
    /// `max_concurrency` caps how many tool futures run at once; `None`
    /// runs all calls concurrently.
    ///
    /// # Errors
    /// Currently infallible in practice (individual tool failures are
    /// captured inside each `ToolCallRecord`), but kept as `Result` for
    /// interface stability.
    async fn execute_tool_calls(
        calls: &[ToolCallRequest],
        agent: &Agent,
        context: &RunContext,
        hooks: &dyn Hooks,
        agent_name: &str,
        messages: &mut Vec<Message>,
        max_concurrency: Option<usize>,
    ) -> Result<Vec<ToolCallRecord>> {
        let concurrency = max_concurrency.unwrap_or(calls.len()).max(1);
        // `buffered` keeps up to `concurrency` futures in flight, starting a
        // new one as soon as any completes, and yields results in input
        // order. The previous chunked `join_all` approach stalled the whole
        // next batch behind the slowest tool in the current one.
        let records: Vec<ToolCallRecord> = futures::stream::iter(
            calls
                .iter()
                .map(|call| Self::execute_single_tool(call, agent, context, hooks, agent_name)),
        )
        .buffered(concurrency)
        .collect()
        .await;
        // Results arrive in call order, so tool messages line up with the
        // model's requested call order.
        for record in &records {
            messages.push(Message::tool(&record.id, &record.result));
        }
        Ok(records)
    }

    /// Executes one tool call inside a `tool` tracing span.
    ///
    /// Dispatch order: a managed agent with a matching name wins over a
    /// plain tool; an unknown name produces a failed record rather than an
    /// error, so the model can see and recover from the mistake.
    async fn execute_single_tool(
        call: &ToolCallRequest,
        agent: &Agent,
        context: &RunContext,
        hooks: &dyn Hooks,
        agent_name: &str,
    ) -> ToolCallRecord {
        let tool_span = info_span!(
            "tool",
            tool.name = %call.name,
            tool.id = %call.id,
            tool.input = %call.arguments,
            tool.output = tracing::field::Empty,
            tool.success = tracing::field::Empty,
            error = tracing::field::Empty,
        );
        async {
            hooks.on_tool_start(context, agent_name, &call.name).await;
            let (result_str, success, sub_usage) =
                if let Some(sub) = agent.managed_agents.iter().find(|a| a.name == call.name) {
                    Self::dispatch_managed_agent(sub, &call.arguments).await
                } else if let Some(tool) = agent.tools.iter().find(|t| t.name() == call.name) {
                    // Plain tools report no sub-agent token usage.
                    let (r, s) = Self::dispatch_tool(tool, call).await;
                    (r, s, Usage::zero())
                } else {
                    warn!(tool = %call.name, "Tool not found");
                    (
                        format!("Tool '{}' not found", call.name),
                        false,
                        Usage::zero(),
                    )
                };
            let current = tracing::Span::current();
            current.record("tool.success", success);
            current.record("tool.output", result_str.as_str());
            if !success {
                current.record("error", result_str.as_str());
            }
            hooks
                .on_tool_end(context, agent_name, &call.name, &result_str)
                .await;
            ToolCallRecord {
                id: call.id.clone(),
                name: call.name.clone(),
                arguments: call.arguments.clone(),
                result: result_str,
                success,
                sub_usage,
            }
        }
        .instrument(tool_span)
        .await
    }

    /// Runs a managed sub-agent as a tool call. The sub-agent's task comes
    /// from the `task` argument (empty string if absent). Returns the
    /// serialized output, a success flag, and the sub-run's token usage.
    async fn dispatch_managed_agent(sub_agent: &Agent, args: &Value) -> (String, bool, Usage) {
        let task = args.get("task").and_then(Value::as_str).unwrap_or_default();
        info!(
            from_agent = tracing::field::Empty,
            to_agent = %sub_agent.name,
            "Handoff to managed agent",
        );
        match Self::run(sub_agent, task, RunConfig::default()).await {
            Ok(result) => {
                let output = serde_json::to_string(&result.output)
                    .unwrap_or_else(|_| result.output.to_string());
                (output, true, result.usage)
            }
            Err(e) => (
                format!("Managed agent '{}' failed: {e}", sub_agent.name),
                false,
                Usage::zero(),
            ),
        }
    }

    /// Invokes a plain tool with the call's JSON arguments. Errors are
    /// folded into the result string (success = false) so the model can see
    /// them instead of aborting the run.
    async fn dispatch_tool(tool: &BoxedTool, call: &ToolCallRequest) -> (String, bool) {
        match tool.call_json(call.arguments.clone()).await {
            Ok(value) => {
                let output = serde_json::to_string(&value).unwrap_or_else(|_| value.to_string());
                (output, true)
            }
            Err(e) => {
                warn!(tool = %call.name, error = %e, "Tool execution failed");
                (format!("Tool error: {e}"), false)
            }
        }
    }

    /// Asks the confirmation handler about each pending call, sequentially
    /// (deliberate: confirmations are typically interactive prompts).
    /// `ApproveAll` answers are recorded in `auto_approved` so the same tool
    /// skips confirmation in later steps. Returns (confirmed, denied).
    async fn seek_confirmations(
        pending: &[ToolCallRequest],
        handler: &dyn ConfirmationHandler,
        auto_approved: &mut HashSet<String>,
    ) -> (Vec<ToolCallRequest>, Vec<ToolCallRequest>) {
        let mut confirmed = Vec::new();
        let mut denied = Vec::new();
        for call in pending {
            let request =
                ToolConfirmationRequest::new(&call.id, &call.name, call.arguments.clone());
            let response = handler.confirm(&request).await;
            match response {
                ToolConfirmationResponse::Approved => confirmed.push(call.clone()),
                ToolConfirmationResponse::ApproveAll => {
                    auto_approved.insert(call.name.clone());
                    confirmed.push(call.clone());
                }
                ToolConfirmationResponse::Denied => denied.push(call.clone()),
            }
        }
        (confirmed, denied)
    }

    /// Appends a tool-result message for each denied/forbidden call so every
    /// tool-call id in the transcript gets an answer (providers require it).
    fn append_denied_messages(
        denied: &[ToolCallRequest],
        reason: &str,
        messages: &mut Vec<Message>,
    ) {
        for call in denied {
            messages.push(Message::tool(
                &call.id,
                format!("Tool '{}' was {reason}.", call.name),
            ));
        }
    }

    /// Merges agent-level and run-level input guardrails (agent's first).
    fn collect_input_guardrails<'a>(
        agent: &'a Agent,
        config: &'a RunConfig,
    ) -> Vec<&'a InputGuardrail> {
        agent
            .input_guardrails
            .iter()
            .chain(config.input_guardrails.iter())
            .collect()
    }

    /// Merges agent-level and run-level output guardrails (agent's first).
    fn collect_output_guardrails<'a>(
        agent: &'a Agent,
        config: &'a RunConfig,
    ) -> Vec<&'a OutputGuardrail> {
        agent
            .output_guardrails
            .iter()
            .chain(config.output_guardrails.iter())
            .collect()
    }

    /// Runs input guardrails one at a time, short-circuiting with an error
    /// on the first triggered guardrail.
    ///
    /// # Errors
    /// Returns the guardrail's own error, or an `input_guardrail_triggered`
    /// error carrying the triggering guardrail's name and info.
    async fn run_input_guardrails(
        guardrails: &[&InputGuardrail],
        context: &RunContext,
        agent_name: &str,
        messages: &[Message],
    ) -> Result<Vec<InputGuardrailResult>> {
        let mut results = Vec::with_capacity(guardrails.len());
        for guardrail in guardrails {
            let result = guardrail.run(context, agent_name, messages).await?;
            if result.is_triggered() {
                return Err(AgentError::input_guardrail_triggered(
                    &result.guardrail_name,
                    result.output.output_info.clone(),
                )
                .into());
            }
            results.push(result);
        }
        Ok(results)
    }

    /// Runs all output guardrails concurrently; any triggered guardrail
    /// fails the run (checked in order after all complete).
    ///
    /// # Errors
    /// Returns a guardrail's own error, or an `output_guardrail_triggered`
    /// error for the first triggered guardrail in declaration order.
    async fn run_output_guardrails(
        guardrails: &[&OutputGuardrail],
        context: &RunContext,
        agent_name: &str,
        output: &Value,
    ) -> Result<Vec<OutputGuardrailResult>> {
        if guardrails.is_empty() {
            return Ok(Vec::new());
        }
        let futs: Vec<_> = guardrails
            .iter()
            .map(|g| g.run(context, agent_name, output))
            .collect();
        let all_results = futures::future::join_all(futs).await;
        let mut results = Vec::with_capacity(all_results.len());
        for r in all_results {
            let result = r?;
            if result.is_triggered() {
                return Err(AgentError::output_guardrail_triggered(
                    &result.guardrail_name,
                    result.output.output_info.clone(),
                )
                .into());
            }
            results.push(result);
        }
        Ok(results)
    }
}