Skip to main content

mermaid_cli/runtime/
agent_loop.rs

1//! Shared agent loop for tool-calling models.
2//!
3//! Both TUI and non-interactive modes delegate to this loop.
4//! The `AgentObserver` trait allows each mode to provide its own
5//! I/O behavior (rendering, interruption, logging).
6
7use std::sync::{Arc, Mutex};
8
9use anyhow::Result;
10
11use crate::agents::{
12    ActionResult as AgentActionResult, AgentAction, SubagentProgress, collect_subagent_results,
13    execute_action, format_subagent_tool_result, spawn_subagents,
14};
15use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, ToolCall};
16use crate::utils::MutexExt;
17
/// Iteration cap callers can pass to [`run_agent_loop`] as `max_iterations`
/// when they have no more specific limit.
pub const MAX_AGENT_ITERATIONS: usize = 25;
20
21/// How the agent loop communicates with its environment
22pub trait AgentObserver: Send {
23    /// Called between steps to check for user interruption or injected messages.
24    /// Returns `LoopControl::Continue` to proceed, `Interrupt` to stop,
25    /// or `InjectMessage(text)` to redirect the agent with new user input.
26    fn check_interrupt(&mut self) -> LoopControl;
27
28    /// Called when the loop status changes (e.g., "Iteration 3 - executing tools")
29    fn on_status(&mut self, message: &str);
30
31    /// Called after a tool call is executed
32    fn on_tool_result(
33        &mut self,
34        tool_name: &str,
35        tool_call_id: &str,
36        action: &AgentAction,
37        result: &AgentActionResult,
38    );
39
40    /// Called when the model returns an error
41    fn on_error(&mut self, error: &str);
42
43    /// Called when model generation starts (for status tracking)
44    fn on_generation_start(&mut self);
45
46    /// Called when model generation completes with token count
47    fn on_generation_complete(&mut self, tokens: usize);
48}
49
/// Control flow for the agent loop, as returned by
/// `AgentObserver::check_interrupt`.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged,
/// duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopControl {
    /// Continue normally
    Continue,
    /// User interrupted (Esc, Ctrl+C)
    Interrupt,
    /// User injected a new message that redirects the agent
    InjectMessage(String),
}
59
60/// Result of running the agent loop
61pub struct AgentLoopResult {
62    /// The model's final text response (from the last iteration with no tool calls)
63    pub final_response: String,
64    /// Number of iterations completed
65    pub iterations: usize,
66    /// Whether the loop was interrupted by the user
67    pub interrupted: bool,
68    /// All tool execution results across iterations
69    pub tool_results: Vec<ToolExecutionResult>,
70    /// Total tokens used across all model calls
71    pub total_tokens: usize,
72}
73
74/// Result of a single tool execution
75#[derive(Debug, Clone)]
76pub struct ToolExecutionResult {
77    pub tool_call_id: String,
78    pub tool_name: String,
79    pub action: AgentAction,
80    pub success: bool,
81    pub output: String,
82}
83
84/// Run the agent loop: execute tool calls, feed results back, repeat.
85///
86/// This is the core loop shared by TUI and non-interactive modes.
87/// The observer trait provides environment-specific behavior.
88pub async fn run_agent_loop(
89    model: Arc<tokio::sync::RwLock<Box<dyn Model>>>,
90    config: &ModelConfig,
91    messages: &mut Vec<ChatMessage>,
92    initial_tool_calls: Vec<ToolCall>,
93    observer: &mut dyn AgentObserver,
94    max_iterations: usize,
95) -> Result<AgentLoopResult> {
96    let mut current_tool_calls = initial_tool_calls;
97    let mut iteration = 0;
98    let mut all_tool_results = Vec::new();
99    let mut total_tokens = 0;
100    let mut final_response = String::new();
101    let mut interrupted = false;
102
103    while !current_tool_calls.is_empty() {
104        iteration += 1;
105        if iteration > max_iterations {
106            observer.on_status(&format!(
107                "Agent loop exceeded {} iterations",
108                max_iterations
109            ));
110            break;
111        }
112
113        observer.on_status(&format!("Agent loop iteration {}", iteration));
114
115        // Check for interruption or injected messages
116        match observer.check_interrupt() {
117            LoopControl::Continue => {},
118            LoopControl::Interrupt => {
119                interrupted = true;
120                break;
121            },
122            LoopControl::InjectMessage(msg) => {
123                // User typed a message during the loop -- redirect agent
124                observer.on_status("Processing queued message...");
125                messages.push(ChatMessage::user(msg));
126                current_tool_calls.clear();
127                // Falls through to model call below
128            },
129        }
130
131        // If tool calls were cleared by InjectMessage, skip execution and go to model call
132        if !current_tool_calls.is_empty() {
133            // Add tool_calls to the last assistant message in history
134            if let Some(last_assistant) = messages
135                .iter_mut()
136                .rev()
137                .find(|m| matches!(m.role, crate::models::MessageRole::Assistant))
138            {
139                last_assistant.tool_calls = Some(current_tool_calls.clone());
140            }
141
142            // Partition into regular tool calls and agent tool calls
143            let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
144                .iter()
145                .partition(|tc| tc.function.name != "agent");
146
147            // Execute regular tool calls first (sequential, as before)
148            for tc in &regular_calls {
149                let tool_call_id = tc
150                    .id
151                    .clone()
152                    .unwrap_or_else(|| format!("call_{}_{}", iteration, tc.function.name));
153                let tool_name = tc.function.name.clone();
154
155                let agent_action = match tc.to_agent_action() {
156                    Ok(action) => action,
157                    Err(e) => {
158                        let error_msg = format!("Error: {}", e);
159                        messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &error_msg));
160                        all_tool_results.push(ToolExecutionResult {
161                            tool_call_id,
162                            tool_name,
163                            action: AgentAction::ParseError {
164                                message: error_msg.clone(),
165                            },
166                            success: false,
167                            output: error_msg,
168                        });
169                        continue;
170                    },
171                };
172
173                let result = execute_action(&agent_action).await;
174                let (success, output) = match &result {
175                    AgentActionResult::Success { output } => (true, output.clone()),
176                    AgentActionResult::Error { error } => (false, format!("Error: {}", error)),
177                };
178
179                observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
180
181                messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
182                all_tool_results.push(ToolExecutionResult {
183                    tool_call_id,
184                    tool_name,
185                    action: agent_action,
186                    success,
187                    output,
188                });
189            }
190
191            // Execute agent tool calls in parallel (non-interactive: join_all directly)
192            if !agent_calls.is_empty() {
193                let agent_specs: Vec<(String, String)> = agent_calls
194                    .iter()
195                    .filter_map(|tc| match tc.to_agent_action() {
196                        Ok(AgentAction::SpawnAgent {
197                            prompt,
198                            description,
199                        }) => Some((prompt, description)),
200                        _ => None,
201                    })
202                    .collect();
203
204                if !agent_specs.is_empty() {
205                    let progress = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
206                    let (handles, overflow) = spawn_subagents(
207                        agent_specs,
208                        Arc::clone(&model),
209                        config,
210                        Arc::clone(&progress),
211                    );
212
213                    let subagent_results = collect_subagent_results(handles, overflow).await;
214
215                    for (i, result) in subagent_results.iter().enumerate() {
216                        let tool_call_id = agent_calls
217                            .get(i)
218                            .and_then(|tc| tc.id.clone())
219                            .unwrap_or_else(|| format!("call_agent_{}", i));
220                        let tool_name = "agent".to_string();
221                        let output = format_subagent_tool_result(result);
222
223                        observer.on_tool_result(
224                            &tool_name,
225                            &tool_call_id,
226                            &AgentAction::SpawnAgent {
227                                prompt: String::new(),
228                                description: result.description.clone(),
229                            },
230                            &if result.success {
231                                AgentActionResult::Success {
232                                    output: output.clone(),
233                                }
234                            } else {
235                                AgentActionResult::Error {
236                                    error: output.clone(),
237                                }
238                            },
239                        );
240
241                        messages.push(ChatMessage::tool(&tool_call_id, &tool_name, &output));
242                        all_tool_results.push(ToolExecutionResult {
243                            tool_call_id,
244                            tool_name,
245                            action: AgentAction::SpawnAgent {
246                                prompt: String::new(),
247                                description: result.description.clone(),
248                            },
249                            success: result.success,
250                            output,
251                        });
252
253                        total_tokens += result.tokens;
254                    }
255                }
256            }
257
258            observer.on_status(&format!(
259                "Iteration {} - {} tool(s) executed, calling model...",
260                iteration,
261                current_tool_calls.len()
262            ));
263        }
264
265        // Check for interruption before model call
266        match observer.check_interrupt() {
267            LoopControl::Interrupt => {
268                interrupted = true;
269                break;
270            },
271            LoopControl::InjectMessage(msg) => {
272                messages.push(ChatMessage::user(msg));
273            },
274            LoopControl::Continue => {},
275        }
276
277        // Call model with updated history
278        observer.on_generation_start();
279        let response_text = Arc::new(std::sync::Mutex::new(String::new()));
280        let response_clone = Arc::clone(&response_text);
281        let callback: StreamCallback = Arc::new(move |chunk: &str| {
282            let mut resp = response_clone.lock_mut_safe();
283            resp.push_str(chunk);
284        });
285
286        let model_result = {
287            let model = model.read().await;
288            model.chat(messages, config, Some(callback)).await
289        };
290
291        match model_result {
292            Ok(response) => {
293                let content = {
294                    let buf = response_text.lock_mut_safe();
295                    if !buf.is_empty() {
296                        buf.clone()
297                    } else {
298                        response.content.clone()
299                    }
300                };
301                let tokens = response.usage.map(|u| u.completion_tokens).unwrap_or(0);
302                total_tokens += tokens;
303                observer.on_generation_complete(tokens);
304
305                let new_tool_calls = response.tool_calls.unwrap_or_default();
306
307                // Add assistant message to history
308                if !content.is_empty() || !new_tool_calls.is_empty() {
309                    let msg = ChatMessage::assistant(content.clone())
310                        .with_tool_calls(new_tool_calls.clone());
311                    messages.push(msg);
312                }
313
314                if new_tool_calls.is_empty() {
315                    // No more tool calls -- agent loop complete
316                    final_response = content;
317                    observer.on_status(&format!(
318                        "Agent loop complete after {} iterations",
319                        iteration
320                    ));
321                    break;
322                } else {
323                    current_tool_calls = new_tool_calls;
324                }
325            },
326            Err(e) => {
327                observer.on_error(&e.to_string());
328                break;
329            },
330        }
331    }
332
333    Ok(AgentLoopResult {
334        final_response,
335        iterations: iteration,
336        interrupted,
337        tool_results: all_tool_results,
338        total_tokens,
339    })
340}