1use std::sync::{Arc, Mutex};
8
9use anyhow::Result;
10
11use crate::agents::{
12 ActionResult as AgentActionResult, AgentAction, SubagentProgress, collect_subagent_results,
13 execute_action, format_subagent_tool_result, spawn_subagents,
14};
15use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, ToolCall};
16use crate::utils::MutexExt;
17
/// Iteration cap for the agent loop; presumably what callers pass as the
/// `max_iterations` argument of [`run_agent_loop`].
pub const MAX_AGENT_ITERATIONS: usize = 25;
20
21pub trait AgentObserver: Send {
23 fn check_interrupt(&mut self) -> LoopControl;
27
28 fn on_status(&mut self, message: &str);
30
31 fn on_tool_result(
33 &mut self,
34 tool_name: &str,
35 tool_call_id: &str,
36 action: &AgentAction,
37 result: &AgentActionResult,
38 );
39
40 fn on_error(&mut self, error: &str);
42
43 fn on_generation_start(&mut self);
45
46 fn on_generation_complete(&mut self, tokens: usize);
48}
49
/// What the agent loop should do next, as decided by
/// [`AgentObserver::check_interrupt`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopControl {
    /// Keep running normally.
    Continue,
    /// Stop the loop; the returned result is marked as interrupted.
    Interrupt,
    /// Append the contained text to the conversation as a user message. When
    /// returned at the start of an iteration, pending tool calls are discarded.
    InjectMessage(String),
}
59
60pub struct AgentLoopResult {
62 pub final_response: String,
64 pub iterations: usize,
66 pub interrupted: bool,
68 pub tool_results: Vec<ToolExecutionResult>,
70 pub total_tokens: usize,
72}
73
74#[derive(Debug, Clone)]
76pub struct ToolExecutionResult {
77 pub tool_call_id: String,
78 pub tool_name: String,
79 pub action: AgentAction,
80 pub success: bool,
81 pub output: String,
82}
83
84pub async fn run_agent_loop(
89 model: Arc<tokio::sync::RwLock<Box<dyn Model>>>,
90 config: &ModelConfig,
91 messages: &mut Vec<ChatMessage>,
92 initial_tool_calls: Vec<ToolCall>,
93 observer: &mut dyn AgentObserver,
94 max_iterations: usize,
95) -> Result<AgentLoopResult> {
96 let mut current_tool_calls = initial_tool_calls;
97 let mut iteration = 0;
98 let mut all_tool_results = Vec::new();
99 let mut total_tokens = 0;
100 let mut final_response = String::new();
101 let mut interrupted = false;
102
103 while !current_tool_calls.is_empty() {
104 iteration += 1;
105 if iteration > max_iterations {
106 observer.on_status(&format!(
107 "Agent loop exceeded {} iterations",
108 max_iterations
109 ));
110 break;
111 }
112
113 observer.on_status(&format!("Agent loop iteration {}", iteration));
114
115 match observer.check_interrupt() {
117 LoopControl::Continue => {},
118 LoopControl::Interrupt => {
119 interrupted = true;
120 break;
121 },
122 LoopControl::InjectMessage(msg) => {
123 observer.on_status("Processing queued message...");
125 messages.push(ChatMessage::user(msg));
126 current_tool_calls.clear();
127 },
129 }
130
131 if !current_tool_calls.is_empty() {
133 if let Some(last_assistant) = messages
135 .iter_mut()
136 .rev()
137 .find(|m| matches!(m.role, crate::models::MessageRole::Assistant))
138 {
139 last_assistant.tool_calls = Some(current_tool_calls.clone());
140 }
141
142 let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
144 .iter()
145 .partition(|tc| tc.function.name != "agent");
146
147 for tc in ®ular_calls {
149 let tool_call_id = tc
150 .id
151 .clone()
152 .unwrap_or_else(|| format!("call_{}_{}", iteration, tc.function.name));
153 let tool_name = tc.function.name.clone();
154
155 let agent_action = match tc.to_agent_action() {
156 Ok(action) => action,
157 Err(e) => {
158 let error_msg = format!("Error: {}", e);
159 messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &error_msg));
160 all_tool_results.push(ToolExecutionResult {
161 tool_call_id,
162 tool_name,
163 action: AgentAction::ParseError {
164 message: error_msg.clone(),
165 },
166 success: false,
167 output: error_msg,
168 });
169 continue;
170 },
171 };
172
173 let result = execute_action(&agent_action).await;
174 let (success, output) = match &result {
175 AgentActionResult::Success { output } => (true, output.clone()),
176 AgentActionResult::Error { error } => (false, format!("Error: {}", error)),
177 };
178
179 observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
180
181 messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
182 all_tool_results.push(ToolExecutionResult {
183 tool_call_id,
184 tool_name,
185 action: agent_action,
186 success,
187 output,
188 });
189 }
190
191 if !agent_calls.is_empty() {
193 let agent_specs: Vec<(String, String)> = agent_calls
194 .iter()
195 .filter_map(|tc| match tc.to_agent_action() {
196 Ok(AgentAction::SpawnAgent {
197 prompt,
198 description,
199 }) => Some((prompt, description)),
200 _ => None,
201 })
202 .collect();
203
204 if !agent_specs.is_empty() {
205 let progress = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
206 let (handles, overflow) = spawn_subagents(
207 agent_specs,
208 Arc::clone(&model),
209 config,
210 Arc::clone(&progress),
211 );
212
213 let subagent_results = collect_subagent_results(handles, overflow).await;
214
215 for (i, result) in subagent_results.iter().enumerate() {
216 let tool_call_id = agent_calls
217 .get(i)
218 .and_then(|tc| tc.id.clone())
219 .unwrap_or_else(|| format!("call_agent_{}", i));
220 let tool_name = "agent".to_string();
221 let output = format_subagent_tool_result(result);
222
223 observer.on_tool_result(
224 &tool_name,
225 &tool_call_id,
226 &AgentAction::SpawnAgent {
227 prompt: String::new(),
228 description: result.description.clone(),
229 },
230 &if result.success {
231 AgentActionResult::Success {
232 output: output.clone(),
233 }
234 } else {
235 AgentActionResult::Error {
236 error: output.clone(),
237 }
238 },
239 );
240
241 messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
242 all_tool_results.push(ToolExecutionResult {
243 tool_call_id,
244 tool_name,
245 action: AgentAction::SpawnAgent {
246 prompt: String::new(),
247 description: result.description.clone(),
248 },
249 success: result.success,
250 output,
251 });
252
253 total_tokens += result.tokens;
254 }
255 }
256 }
257
258 observer.on_status(&format!(
259 "Iteration {} - {} tool(s) executed, calling model...",
260 iteration,
261 current_tool_calls.len()
262 ));
263 }
264
265 match observer.check_interrupt() {
267 LoopControl::Interrupt => {
268 interrupted = true;
269 break;
270 },
271 LoopControl::InjectMessage(msg) => {
272 messages.push(ChatMessage::user(msg));
273 },
274 LoopControl::Continue => {},
275 }
276
277 observer.on_generation_start();
279 let response_text = Arc::new(std::sync::Mutex::new(String::new()));
280 let response_clone = Arc::clone(&response_text);
281 let callback: StreamCallback = Arc::new(move |chunk: &str| {
282 let mut resp = response_clone.lock_mut_safe();
283 resp.push_str(chunk);
284 });
285
286 let model_result = {
287 let model = model.read().await;
288 model.chat(messages, config, Some(callback)).await
289 };
290
291 match model_result {
292 Ok(response) => {
293 let content = {
294 let buf = response_text.lock_mut_safe();
295 if !buf.is_empty() {
296 buf.clone()
297 } else {
298 response.content.clone()
299 }
300 };
301 let tokens = response.usage.map(|u| u.completion_tokens).unwrap_or(0);
302 total_tokens += tokens;
303 observer.on_generation_complete(tokens);
304
305 let new_tool_calls = response.tool_calls.unwrap_or_default();
306
307 if !content.is_empty() || !new_tool_calls.is_empty() {
309 let msg = ChatMessage::assistant(content.clone())
310 .with_tool_calls(new_tool_calls.clone());
311 messages.push(msg);
312 }
313
314 if new_tool_calls.is_empty() {
315 final_response = content;
317 observer.on_status(&format!(
318 "Agent loop complete after {} iterations",
319 iteration
320 ));
321 break;
322 } else {
323 current_tool_calls = new_tool_calls;
324 }
325 },
326 Err(e) => {
327 observer.on_error(&e.to_string());
328 break;
329 },
330 }
331 }
332
333 Ok(AgentLoopResult {
334 final_response,
335 iterations: iteration,
336 interrupted,
337 tool_results: all_tool_results,
338 total_tokens,
339 })
340}