use std::sync::{Arc, Mutex};
use anyhow::Result;
use crate::agents::{
ActionResult as AgentActionResult, AgentAction, SubagentProgress, collect_subagent_results,
execute_action, format_subagent_tool_result, spawn_subagents,
};
use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, ToolCall};
use crate::utils::MutexExt;
/// Default cap on the number of agent-loop iterations.
///
/// Not referenced in this file; presumably callers pass it as the
/// `max_iterations` argument of [`run_agent_loop`] — confirm at call sites.
pub const MAX_AGENT_ITERATIONS: usize = 25_usize;
/// Hooks through which [`run_agent_loop`] reports progress and polls for
/// user control. Implementors must be `Send`.
pub trait AgentObserver: Send {
    /// Polled in every iteration before the pending tool calls run, and again
    /// right before each model call; the returned [`LoopControl`] decides
    /// whether the loop continues, stops, or receives an injected message.
    fn check_interrupt(&mut self) -> LoopControl;

    /// Receives human-readable progress messages.
    fn on_status(&mut self, message: &str);

    /// Called immediately before each model call.
    fn on_generation_start(&mut self);

    /// Called after a successful model call with the total tokens it reported
    /// (0 when the response carries no usage information).
    fn on_generation_complete(&mut self, tokens: usize);

    /// Called once a tool call (or subagent) has finished, with the call's
    /// name and id, the action that ran, and its result. Not called for tool
    /// calls whose arguments could not be parsed.
    fn on_tool_result(
        &mut self,
        tool_name: &str,
        tool_call_id: &str,
        action: &AgentAction,
        result: &AgentActionResult,
    );

    /// Receives the text of an error that ended the loop (a failed model call).
    fn on_error(&mut self, error: &str);
}
/// Instruction returned by [`AgentObserver::check_interrupt`] that tells the
/// agent loop how to proceed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopControl {
    /// Keep running the loop.
    Continue,
    /// Stop the loop; the result is marked as interrupted.
    Interrupt,
    /// Append the string to the conversation as a user message. When returned
    /// before the pending tool calls run, those calls are skipped as well.
    InjectMessage(String),
}
/// Summary of a finished [`run_agent_loop`] run.
#[derive(Debug, Clone)]
pub struct AgentLoopResult {
    /// Text of the model's last reply; set only when the model answered
    /// without requesting tools, empty otherwise.
    pub final_response: String,
    /// Number of loop iterations counted by the agent loop.
    pub iterations: usize,
    /// `true` when the observer returned [`LoopControl::Interrupt`].
    pub interrupted: bool,
    /// Outcome of every tool call processed (including calls whose arguments
    /// failed to parse), in processing order.
    pub tool_results: Vec<ToolExecutionResult>,
    /// Tokens reported by the model calls plus tokens used by subagents.
    pub total_tokens: usize,
}
/// Record of one tool call handled by the agent loop.
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
    /// Id from the model's tool call, or a generated `call_…` fallback when
    /// the model supplied none.
    pub tool_call_id: String,
    /// Name of the tool the model called (`"agent"` for subagents).
    pub tool_name: String,
    /// The action that ran; `AgentAction::ParseError` when the call's
    /// arguments could not be parsed. Subagent entries carry an empty `prompt`.
    pub action: AgentAction,
    /// Whether the action succeeded.
    pub success: bool,
    /// Text sent back to the model as the tool message. Regular-tool and
    /// parse failures are prefixed with `Error: `.
    pub output: String,
    /// Images attached to the tool message, if any.
    pub images: Option<Vec<String>>,
}
pub async fn run_agent_loop(
model: Arc<tokio::sync::RwLock<Box<dyn Model>>>,
config: &ModelConfig,
messages: &mut Vec<ChatMessage>,
initial_tool_calls: Vec<ToolCall>,
observer: &mut dyn AgentObserver,
max_iterations: usize,
) -> Result<AgentLoopResult> {
let mut current_tool_calls = initial_tool_calls;
let mut iteration = 0;
let mut all_tool_results = Vec::new();
let mut total_tokens = 0;
let mut final_response = String::new();
let mut interrupted = false;
while !current_tool_calls.is_empty() {
iteration += 1;
if iteration > max_iterations {
observer.on_status(&format!(
"Agent loop exceeded {} iterations",
max_iterations
));
break;
}
observer.on_status(&format!("Agent loop iteration {}", iteration));
match observer.check_interrupt() {
LoopControl::Continue => {},
LoopControl::Interrupt => {
interrupted = true;
break;
},
LoopControl::InjectMessage(msg) => {
observer.on_status("Processing queued message...");
messages.push(ChatMessage::user(msg));
current_tool_calls.clear();
},
}
if !current_tool_calls.is_empty() {
if let Some(last_assistant) = messages
.iter_mut()
.rev()
.find(|m| matches!(m.role, crate::models::MessageRole::Assistant))
{
last_assistant.tool_calls = Some(current_tool_calls.clone());
}
let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
.iter()
.partition(|tc| tc.function.name != "agent");
for tc in ®ular_calls {
let tool_call_id = tc
.id
.clone()
.unwrap_or_else(|| format!("call_{}_{}", iteration, tc.function.name));
let tool_name = tc.function.name.clone();
let agent_action = match tc.to_agent_action() {
Ok(action) => action,
Err(e) => {
let error_msg = format!("Error: {}", e);
messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &error_msg));
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: AgentAction::ParseError {
message: error_msg.clone(),
},
success: false,
output: error_msg,
images: None,
});
continue;
},
};
let result = execute_action(&agent_action).await;
let (success, output, images) = match &result {
AgentActionResult::Success { output, images } => {
(true, output.clone(), images.clone())
},
AgentActionResult::Error { error } => {
(false, format!("Error: {}", error), None)
},
};
observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
let mut tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
if let Some(ref imgs) = images {
tool_msg = tool_msg.with_images(imgs.clone());
}
messages.push(tool_msg);
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: agent_action,
success,
output,
images,
});
}
if !agent_calls.is_empty() {
let agent_specs: Vec<(String, String)> = agent_calls
.iter()
.filter_map(|tc| match tc.to_agent_action() {
Ok(AgentAction::SpawnAgent {
prompt,
description,
}) => Some((prompt, description)),
_ => None,
})
.collect();
if !agent_specs.is_empty() {
let progress = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
let (handles, overflow) = spawn_subagents(
agent_specs,
Arc::clone(&model),
config,
Arc::clone(&progress),
);
let subagent_results = collect_subagent_results(handles, overflow).await;
for (i, result) in subagent_results.iter().enumerate() {
let tool_call_id = agent_calls
.get(i)
.and_then(|tc| tc.id.clone())
.unwrap_or_else(|| format!("call_agent_{}", i));
let tool_name = "agent".to_string();
let output = format_subagent_tool_result(result);
observer.on_tool_result(
&tool_name,
&tool_call_id,
&AgentAction::SpawnAgent {
prompt: String::new(),
description: result.description.clone(),
},
&if result.success {
AgentActionResult::Success {
output: output.clone(),
images: None,
}
} else {
AgentActionResult::Error {
error: output.clone(),
}
},
);
messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: AgentAction::SpawnAgent {
prompt: String::new(),
description: result.description.clone(),
},
success: result.success,
output,
images: None,
});
total_tokens += result.tokens;
}
}
}
observer.on_status(&format!(
"Iteration {} - {} tool(s) executed, calling model...",
iteration,
current_tool_calls.len()
));
}
match observer.check_interrupt() {
LoopControl::Interrupt => {
interrupted = true;
break;
},
LoopControl::InjectMessage(msg) => {
messages.push(ChatMessage::user(msg));
},
LoopControl::Continue => {},
}
observer.on_generation_start();
let response_text = Arc::new(std::sync::Mutex::new(String::new()));
let response_clone = Arc::clone(&response_text);
let callback: StreamCallback = Arc::new(move |chunk: &str| {
let mut resp = response_clone.lock_mut_safe();
resp.push_str(chunk);
});
let model_result = {
let model = model.read().await;
model.chat(messages, config, Some(callback)).await
};
match model_result {
Ok(response) => {
let content = {
let buf = response_text.lock_mut_safe();
if !buf.is_empty() {
buf.clone()
} else {
response.content.clone()
}
};
let tokens = response.usage.map(|u| u.total_tokens).unwrap_or(0);
total_tokens += tokens;
observer.on_generation_complete(tokens);
let new_tool_calls = response.tool_calls.unwrap_or_default();
if !content.is_empty() || !new_tool_calls.is_empty() {
let msg = ChatMessage::assistant(content.clone())
.with_tool_calls(new_tool_calls.clone());
messages.push(msg);
}
if new_tool_calls.is_empty() {
final_response = content;
observer.on_status(&format!(
"Agent loop complete after {} iterations",
iteration
));
break;
} else {
current_tool_calls = new_tool_calls;
}
},
Err(e) => {
observer.on_error(&e.to_string());
break;
},
}
}
Ok(AgentLoopResult {
final_response,
iterations: iteration,
interrupted,
tool_results: all_tool_results,
total_tokens,
})
}