use std::sync::{Arc, Mutex};
use anyhow::Result;
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::agents::{
ActionResult as AgentActionResult, AgentAction, SubagentProgress, SubagentResult,
collect_subagent_results, execute_action, format_subagent_tool_result, spawn_subagents,
};
use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, StreamEvent, ToolCall};
use crate::utils::MutexExt;
/// Upper bound on agent loop iterations.
///
/// NOTE(review): not referenced in this part of the file; presumably callers pass it as the
/// `max_iterations` argument of `run_agent_loop` — confirm.
pub const MAX_AGENT_ITERATIONS: usize = 25;
/// What a single model call produced, as returned by `AgentObserver::call_model`.
#[derive(Debug, Clone, Default)]
pub struct ModelCallOutput {
/// Assistant text. The default `call_model` prefers streamed text and falls back to the
/// response's `content` when nothing was streamed.
pub content: String,
/// Tool calls requested by the model; `run_agent_loop` ends when this is empty.
pub tool_calls: Vec<ToolCall>,
/// Total tokens from the response's usage data, or 0 when no usage was reported.
pub tokens: usize,
}
/// Callbacks through which a caller observes and steers [`run_agent_loop`].
///
/// The required methods receive progress notifications and control decisions;
/// `call_model` and `run_subagents` ship with default implementations that an
/// implementor may override (for example to forward streamed output elsewhere).
#[async_trait]
pub trait AgentObserver: Send {
    /// Decides whether the loop continues, stops, or receives a queued user message.
    ///
    /// `run_agent_loop` polls this at the start of every iteration and again right
    /// before each model call.
    fn check_interrupt(&mut self) -> LoopControl;

    /// Receives a human-readable progress line.
    fn on_status(&mut self, message: &str);

    /// Reports the outcome of one executed tool call or finished subagent.
    fn on_tool_result(
        &mut self,
        tool_name: &str,
        tool_call_id: &str,
        action: &AgentAction,
        result: &AgentActionResult,
    );

    /// Reports a failed model call; `run_agent_loop` stops right after calling this.
    fn on_error(&mut self, error: &str);

    /// Signals that a model call is about to start.
    fn on_generation_start(&mut self);

    /// Signals that a model call succeeded, together with the tokens it reported.
    fn on_generation_complete(&mut self, tokens: usize);

    /// Sees each message just before the loop pushes it onto the conversation.
    /// Does nothing by default.
    fn on_message_appended(&mut self, _msg: &ChatMessage) {}

    /// Runs one chat completion and gathers its text, tool calls, and token usage.
    ///
    /// The default implementation streams through a callback and prefers whatever was
    /// streamed; it falls back to the final response's `content` / `tool_calls` only when
    /// the stream delivered nothing of that kind. The model's read lock is held until
    /// the output has been assembled.
    ///
    /// # Errors
    ///
    /// Returns an error when the model's `chat` call fails.
    async fn call_model(
        &mut self,
        model: Arc<RwLock<Box<dyn Model>>>,
        messages: &[ChatMessage],
        config: &ModelConfig,
    ) -> Result<ModelCallOutput> {
        // Buffers filled by the streaming callback while `chat` is running.
        let streamed_text = Arc::new(Mutex::new(String::new()));
        let streamed_calls = Arc::new(Mutex::new(Vec::<ToolCall>::new()));

        let callback: StreamCallback = {
            let text_sink = Arc::clone(&streamed_text);
            let call_sink = Arc::clone(&streamed_calls);
            Arc::new(move |event| match event {
                StreamEvent::Text(chunk) => {
                    text_sink.lock_mut_safe().push_str(&chunk);
                },
                StreamEvent::ToolCall(tc) => {
                    call_sink.lock_mut_safe().push(tc);
                },
                StreamEvent::Reasoning(_) | StreamEvent::Done { .. } => {},
            })
        };

        let guard = model.read().await;
        let response = guard
            .chat(messages, config, Some(callback))
            .await
            .map_err(|e| anyhow::anyhow!("{}", e))?;

        let buffered_text = std::mem::take(&mut *streamed_text.lock_mut_safe());
        let content = if buffered_text.is_empty() {
            response.content.clone()
        } else {
            buffered_text
        };

        let tokens = response.usage.map_or(0, |u| u.total_tokens);

        let buffered_calls = std::mem::take(&mut *streamed_calls.lock_mut_safe());
        let tool_calls = if buffered_calls.is_empty() {
            response.tool_calls.unwrap_or_default()
        } else {
            buffered_calls
        };

        Ok(ModelCallOutput {
            content,
            tool_calls,
            tokens,
        })
    }

    /// Runs one subagent per `(prompt, description)` spec and waits for all results.
    ///
    /// The default implementation keeps its progress tracker private, so progress
    /// updates are not surfaced anywhere.
    async fn run_subagents(
        &mut self,
        specs: Vec<(String, String)>,
        model: Arc<RwLock<Box<dyn Model>>>,
        config: &ModelConfig,
    ) -> Vec<SubagentResult> {
        let tracker = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
        let (handles, overflow) = spawn_subagents(specs, model, config, Arc::clone(&tracker));
        collect_subagent_results(handles, overflow).await
    }
}
/// Decision returned by `AgentObserver::check_interrupt`.
pub enum LoopControl {
/// Proceed normally.
Continue,
/// Stop the loop; `run_agent_loop` marks its result as `interrupted`.
Interrupt,
/// Append the string as a user message. At the start of an iteration this also skips the
/// pending tool calls; right before a model call the message is simply appended.
InjectMessage(String),
}
/// Summary returned by `run_agent_loop`.
pub struct AgentLoopResult {
/// Text of the model's final reply without tool calls; empty if the loop ended otherwise.
pub final_response: String,
/// Value of the loop's iteration counter; it is `max_iterations + 1` when the iteration
/// limit is what stopped the loop.
pub iterations: usize,
/// True if the observer returned `LoopControl::Interrupt`.
pub interrupted: bool,
/// Per-tool-call outcomes recorded by the loop, in the order they were processed.
pub tool_results: Vec<ToolExecutionResult>,
/// Tokens reported by all model calls plus those reported by subagents.
pub total_tokens: usize,
}
/// Outcome of a single tool call processed by `run_agent_loop`.
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
/// ID used for the tool message: the model-supplied ID, or a generated `call_…` fallback.
pub tool_call_id: String,
/// Name of the called tool (`"agent"` for subagent calls).
pub tool_name: String,
/// The parsed action; `AgentAction::ParseError` when the call could not be parsed.
/// For subagent calls the recorded prompt is left empty.
pub action: AgentAction,
/// Whether the tool (or subagent) reported success.
pub success: bool,
/// Text sent back to the model in the tool message; regular-tool errors are prefixed
/// with `Error: `.
pub output: String,
/// Images attached to the tool message, if the tool produced any.
pub images: Option<Vec<String>>,
}
pub async fn run_agent_loop(
model: Arc<RwLock<Box<dyn Model>>>,
config: &ModelConfig,
messages: &mut Vec<ChatMessage>,
initial_tool_calls: Vec<ToolCall>,
observer: &mut dyn AgentObserver,
max_iterations: usize,
) -> Result<AgentLoopResult> {
let mut current_tool_calls = initial_tool_calls;
let mut iteration = 0;
let mut all_tool_results = Vec::new();
let mut total_tokens = 0;
let mut final_response = String::new();
let mut interrupted = false;
while !current_tool_calls.is_empty() {
iteration += 1;
if iteration > max_iterations {
observer.on_status(&format!(
"Agent loop exceeded {} iterations",
max_iterations
));
break;
}
observer.on_status(&format!("Agent loop iteration {}", iteration));
match observer.check_interrupt() {
LoopControl::Continue => {},
LoopControl::Interrupt => {
interrupted = true;
break;
},
LoopControl::InjectMessage(msg) => {
observer.on_status("Processing queued message...");
let user_msg = ChatMessage::user(msg);
observer.on_message_appended(&user_msg);
messages.push(user_msg);
current_tool_calls.clear();
},
}
if !current_tool_calls.is_empty() {
let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
.iter()
.partition(|tc| tc.function.name != "agent");
for (idx, tc) in regular_calls.iter().enumerate() {
let tool_call_id = tc
.id
.clone()
.unwrap_or_else(|| format!("call_{}_{}_{}", iteration, idx, tc.function.name));
let tool_name = tc.function.name.clone();
let agent_action = match tc.to_agent_action() {
Ok(action) => action,
Err(e) => {
let error_msg = format!("Error: {}", e);
let tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &error_msg);
observer.on_message_appended(&tool_msg);
messages.push(tool_msg);
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: AgentAction::ParseError {
message: error_msg.clone(),
},
success: false,
output: error_msg,
images: None,
});
continue;
},
};
let result = execute_action(&agent_action).await;
let (success, output, images) = match &result {
AgentActionResult::Success { output, images } => {
(true, output.clone(), images.clone())
},
AgentActionResult::Error { error } => {
(false, format!("Error: {}", error), None)
},
};
observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
let mut tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
if let Some(ref imgs) = images {
tool_msg = tool_msg.with_images(imgs.clone());
}
observer.on_message_appended(&tool_msg);
messages.push(tool_msg);
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: agent_action,
success,
output,
images,
});
}
if !agent_calls.is_empty() {
let agent_specs: Vec<(String, String)> = agent_calls
.iter()
.filter_map(|tc| match tc.to_agent_action() {
Ok(AgentAction::SpawnAgent {
prompt,
description,
}) => Some((prompt, description)),
_ => None,
})
.collect();
if !agent_specs.is_empty() {
let subagent_results = observer
.run_subagents(agent_specs, Arc::clone(&model), config)
.await;
for (i, result) in subagent_results.iter().enumerate() {
let tool_call_id = agent_calls
.get(i)
.and_then(|tc| tc.id.clone())
.unwrap_or_else(|| format!("call_agent_{}_{}", iteration, i));
let tool_name = "agent".to_string();
let output = format_subagent_tool_result(result);
observer.on_tool_result(
&tool_name,
&tool_call_id,
&AgentAction::SpawnAgent {
prompt: String::new(),
description: result.description.clone(),
},
&if result.success {
AgentActionResult::Success {
output: output.clone(),
images: None,
}
} else {
AgentActionResult::Error {
error: output.clone(),
}
},
);
let tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
observer.on_message_appended(&tool_msg);
messages.push(tool_msg);
all_tool_results.push(ToolExecutionResult {
tool_call_id,
tool_name,
action: AgentAction::SpawnAgent {
prompt: String::new(),
description: result.description.clone(),
},
success: result.success,
output,
images: None,
});
total_tokens += result.tokens;
}
}
}
observer.on_status(&format!(
"Iteration {} - {} tool(s) executed, calling model...",
iteration,
current_tool_calls.len()
));
}
match observer.check_interrupt() {
LoopControl::Interrupt => {
interrupted = true;
break;
},
LoopControl::InjectMessage(msg) => {
let user_msg = ChatMessage::user(msg);
observer.on_message_appended(&user_msg);
messages.push(user_msg);
},
LoopControl::Continue => {},
}
observer.on_generation_start();
let model_result = observer
.call_model(Arc::clone(&model), messages, config)
.await;
match model_result {
Ok(out) => {
total_tokens += out.tokens;
observer.on_generation_complete(out.tokens);
let new_tool_calls = out.tool_calls;
if !out.content.is_empty() || !new_tool_calls.is_empty() {
let msg = ChatMessage::assistant(out.content.clone())
.with_tool_calls(new_tool_calls.clone());
observer.on_message_appended(&msg);
messages.push(msg);
}
if new_tool_calls.is_empty() {
final_response = out.content;
observer.on_status(&format!(
"Agent loop complete after {} iterations",
iteration
));
break;
} else {
current_tool_calls = new_tool_calls;
}
},
Err(e) => {
observer.on_error(&e.to_string());
break;
},
}
}
Ok(AgentLoopResult {
final_response,
iterations: iteration,
interrupted,
tool_results: all_tool_results,
total_tokens,
})
}