Skip to main content

mermaid_cli/runtime/
agent_loop.rs

//! Shared agent loop for tool-calling models.
//!
//! Used by **non-interactive mode** (`SilentObserver`) and **sub-agents**
//! (`SubagentObserver`). The TUI has its own agent loop in
//! `tui::loop_coordinator::run_agent_loop` because it needs direct `Terminal`
//! access for live rendering and interrupt handling — `Terminal` is not `Send`,
//! so it cannot be passed through the `AgentObserver` trait.
//!
//! The `AgentObserver` trait allows each non-TUI consumer to provide its own
//! I/O behavior (interruption checks, status logging, tool result handling).
11
12use std::sync::{Arc, Mutex};
13
14use anyhow::Result;
15
16use crate::agents::{
17    ActionResult as AgentActionResult, AgentAction, SubagentProgress, collect_subagent_results,
18    execute_action, format_subagent_tool_result, spawn_subagents,
19};
20use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, ToolCall};
21use crate::utils::MutexExt;
22
/// Default cap on agent-loop iterations, where one iteration executes the
/// pending tool calls and then re-queries the model. Intended as the
/// `max_iterations` argument to [`run_agent_loop`].
pub const MAX_AGENT_ITERATIONS: usize = 25;
25
26/// How the agent loop communicates with its environment
27pub trait AgentObserver: Send {
28    /// Called between steps to check for user interruption or injected messages.
29    /// Returns `LoopControl::Continue` to proceed, `Interrupt` to stop,
30    /// or `InjectMessage(text)` to redirect the agent with new user input.
31    fn check_interrupt(&mut self) -> LoopControl;
32
33    /// Called when the loop status changes (e.g., "Iteration 3 - executing tools")
34    fn on_status(&mut self, message: &str);
35
36    /// Called after a tool call is executed
37    fn on_tool_result(
38        &mut self,
39        tool_name: &str,
40        tool_call_id: &str,
41        action: &AgentAction,
42        result: &AgentActionResult,
43    );
44
45    /// Called when the model returns an error
46    fn on_error(&mut self, error: &str);
47
48    /// Called when model generation starts (for status tracking)
49    fn on_generation_start(&mut self);
50
51    /// Called when model generation completes with token count
52    fn on_generation_complete(&mut self, tokens: usize);
53}
54
/// Decision returned by [`AgentObserver::check_interrupt`] that steers the
/// agent loop.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so observers and tests can log,
/// copy, and compare decisions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopControl {
    /// Continue normally.
    Continue,
    /// User interrupted (Esc, Ctrl+C); the loop stops and reports `interrupted`.
    Interrupt,
    /// User injected a new message that redirects the agent; the text is
    /// appended to the conversation as a user message.
    InjectMessage(String),
}
64
65/// Result of running the agent loop
66pub struct AgentLoopResult {
67    /// The model's final text response (from the last iteration with no tool calls)
68    pub final_response: String,
69    /// Number of iterations completed
70    pub iterations: usize,
71    /// Whether the loop was interrupted by the user
72    pub interrupted: bool,
73    /// All tool execution results across iterations
74    pub tool_results: Vec<ToolExecutionResult>,
75    /// Total tokens used across all model calls
76    pub total_tokens: usize,
77}
78
79/// Result of a single tool execution
80#[derive(Debug, Clone)]
81pub struct ToolExecutionResult {
82    pub tool_call_id: String,
83    pub tool_name: String,
84    pub action: AgentAction,
85    pub success: bool,
86    pub output: String,
87    pub images: Option<Vec<String>>,
88}
89
90/// Run the agent loop: execute tool calls, feed results back, repeat.
91///
92/// Used by non-interactive mode and sub-agents. The TUI has its own
93/// implementation in `tui::loop_coordinator::run_agent_loop` — see that
94/// function's documentation for the rationale.
95pub async fn run_agent_loop(
96    model: Arc<tokio::sync::RwLock<Box<dyn Model>>>,
97    config: &ModelConfig,
98    messages: &mut Vec<ChatMessage>,
99    initial_tool_calls: Vec<ToolCall>,
100    observer: &mut dyn AgentObserver,
101    max_iterations: usize,
102) -> Result<AgentLoopResult> {
103    let mut current_tool_calls = initial_tool_calls;
104    let mut iteration = 0;
105    let mut all_tool_results = Vec::new();
106    let mut total_tokens = 0;
107    let mut final_response = String::new();
108    let mut interrupted = false;
109
110    while !current_tool_calls.is_empty() {
111        iteration += 1;
112        if iteration > max_iterations {
113            observer.on_status(&format!(
114                "Agent loop exceeded {} iterations",
115                max_iterations
116            ));
117            break;
118        }
119
120        observer.on_status(&format!("Agent loop iteration {}", iteration));
121
122        // Check for interruption or injected messages
123        match observer.check_interrupt() {
124            LoopControl::Continue => {},
125            LoopControl::Interrupt => {
126                interrupted = true;
127                break;
128            },
129            LoopControl::InjectMessage(msg) => {
130                // User typed a message during the loop -- redirect agent
131                observer.on_status("Processing queued message...");
132                messages.push(ChatMessage::user(msg));
133                current_tool_calls.clear();
134                // Falls through to model call below
135            },
136        }
137
138        // If tool calls were cleared by InjectMessage, skip execution and go to model call
139        if !current_tool_calls.is_empty() {
140            // Add tool_calls to the last assistant message in history
141            if let Some(last_assistant) = messages
142                .iter_mut()
143                .rev()
144                .find(|m| matches!(m.role, crate::models::MessageRole::Assistant))
145            {
146                last_assistant.tool_calls = Some(current_tool_calls.clone());
147            }
148
149            // Partition into regular tool calls and agent tool calls
150            let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
151                .iter()
152                .partition(|tc| tc.function.name != "agent");
153
154            // Execute regular tool calls first (sequential, as before)
155            for tc in &regular_calls {
156                let tool_call_id = tc
157                    .id
158                    .clone()
159                    .unwrap_or_else(|| format!("call_{}_{}", iteration, tc.function.name));
160                let tool_name = tc.function.name.clone();
161
162                let agent_action = match tc.to_agent_action() {
163                    Ok(action) => action,
164                    Err(e) => {
165                        let error_msg = format!("Error: {}", e);
166                        messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &error_msg));
167                        all_tool_results.push(ToolExecutionResult {
168                            tool_call_id,
169                            tool_name,
170                            action: AgentAction::ParseError {
171                                message: error_msg.clone(),
172                            },
173                            success: false,
174                            output: error_msg,
175                            images: None,
176                        });
177                        continue;
178                    },
179                };
180
181                let result = execute_action(&agent_action).await;
182                let (success, output, images) = match &result {
183                    AgentActionResult::Success { output, images } => {
184                        (true, output.clone(), images.clone())
185                    },
186                    AgentActionResult::Error { error } => {
187                        (false, format!("Error: {}", error), None)
188                    },
189                };
190
191                observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
192
193                let mut tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
194                if let Some(ref imgs) = images {
195                    tool_msg = tool_msg.with_images(imgs.clone());
196                }
197                messages.push(tool_msg);
198                all_tool_results.push(ToolExecutionResult {
199                    tool_call_id,
200                    tool_name,
201                    action: agent_action,
202                    success,
203                    output,
204                    images,
205                });
206            }
207
208            // Execute agent tool calls in parallel (non-interactive: join_all directly)
209            if !agent_calls.is_empty() {
210                let agent_specs: Vec<(String, String)> = agent_calls
211                    .iter()
212                    .filter_map(|tc| match tc.to_agent_action() {
213                        Ok(AgentAction::SpawnAgent {
214                            prompt,
215                            description,
216                        }) => Some((prompt, description)),
217                        _ => None,
218                    })
219                    .collect();
220
221                if !agent_specs.is_empty() {
222                    let progress = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
223                    let (handles, overflow) = spawn_subagents(
224                        agent_specs,
225                        Arc::clone(&model),
226                        config,
227                        Arc::clone(&progress),
228                    );
229
230                    let subagent_results = collect_subagent_results(handles, overflow).await;
231
232                    for (i, result) in subagent_results.iter().enumerate() {
233                        let tool_call_id = agent_calls
234                            .get(i)
235                            .and_then(|tc| tc.id.clone())
236                            .unwrap_or_else(|| format!("call_agent_{}", i));
237                        let tool_name = "agent".to_string();
238                        let output = format_subagent_tool_result(result);
239
240                        observer.on_tool_result(
241                            &tool_name,
242                            &tool_call_id,
243                            &AgentAction::SpawnAgent {
244                                prompt: String::new(),
245                                description: result.description.clone(),
246                            },
247                            &if result.success {
248                                AgentActionResult::Success {
249                                    output: output.clone(),
250                                    images: None,
251                                }
252                            } else {
253                                AgentActionResult::Error {
254                                    error: output.clone(),
255                                }
256                            },
257                        );
258
259                        messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
260                        all_tool_results.push(ToolExecutionResult {
261                            tool_call_id,
262                            tool_name,
263                            action: AgentAction::SpawnAgent {
264                                prompt: String::new(),
265                                description: result.description.clone(),
266                            },
267                            success: result.success,
268                            output,
269                            images: None,
270                        });
271
272                        total_tokens += result.tokens;
273                    }
274                }
275            }
276
277            observer.on_status(&format!(
278                "Iteration {} - {} tool(s) executed, calling model...",
279                iteration,
280                current_tool_calls.len()
281            ));
282        }
283
284        // Check for interruption before model call
285        match observer.check_interrupt() {
286            LoopControl::Interrupt => {
287                interrupted = true;
288                break;
289            },
290            LoopControl::InjectMessage(msg) => {
291                messages.push(ChatMessage::user(msg));
292            },
293            LoopControl::Continue => {},
294        }
295
296        // Call model with updated history
297        observer.on_generation_start();
298        let response_text = Arc::new(std::sync::Mutex::new(String::new()));
299        let response_clone = Arc::clone(&response_text);
300        let callback: StreamCallback = Arc::new(move |chunk: &str| {
301            let mut resp = response_clone.lock_mut_safe();
302            resp.push_str(chunk);
303        });
304
305        let model_result = {
306            let model = model.read().await;
307            model.chat(messages, config, Some(callback)).await
308        };
309
310        match model_result {
311            Ok(response) => {
312                let content = {
313                    let buf = response_text.lock_mut_safe();
314                    if !buf.is_empty() {
315                        buf.clone()
316                    } else {
317                        response.content.clone()
318                    }
319                };
320                let tokens = response.usage.map(|u| u.total_tokens).unwrap_or(0);
321                total_tokens += tokens;
322                observer.on_generation_complete(tokens);
323
324                let new_tool_calls = response.tool_calls.unwrap_or_default();
325
326                // Add assistant message to history
327                if !content.is_empty() || !new_tool_calls.is_empty() {
328                    let msg = ChatMessage::assistant(content.clone())
329                        .with_tool_calls(new_tool_calls.clone());
330                    messages.push(msg);
331                }
332
333                if new_tool_calls.is_empty() {
334                    // No more tool calls -- agent loop complete
335                    final_response = content;
336                    observer.on_status(&format!(
337                        "Agent loop complete after {} iterations",
338                        iteration
339                    ));
340                    break;
341                } else {
342                    current_tool_calls = new_tool_calls;
343                }
344            },
345            Err(e) => {
346                observer.on_error(&e.to_string());
347                break;
348            },
349        }
350    }
351
352    Ok(AgentLoopResult {
353        final_response,
354        iterations: iteration,
355        interrupted,
356        tool_results: all_tool_results,
357        total_tokens,
358    })
359}