Skip to main content

mermaid_cli/runtime/
agent_loop.rs

1//! Shared agent loop for tool-calling models.
2//!
3//! Used by **non-interactive mode** (`SilentObserver`), **sub-agents**
4//! (`SubagentObserver`), and the **TUI** (`TuiObserver`). The observer
5//! pattern lets each consumer plug in its own model-call and subagent-
6//! handling strategies: the default trait methods do simple sync/join
7//! work (good for non-interactive and subagents); the TUI overrides
8//! `call_model` and `run_subagents` to drive channel-based streaming
9//! and live rendering without forking the core loop.
10//!
11//! The trait is `Send`-bounded so futures returned by its default
12//! methods are Send (required by `tokio::spawn` in subagent code).
13//! The TUI observer is also Send: it holds `&mut Terminal<...>` and
14//! `&mut App`, both of which are Send since `io::Stdout` and the
15//! tokio sync primitives inside App are Send.
16
17use std::sync::{Arc, Mutex};
18
19use anyhow::Result;
20use async_trait::async_trait;
21use tokio::sync::RwLock;
22
23use crate::agents::{
24    ActionResult as AgentActionResult, AgentAction, SubagentProgress, SubagentResult,
25    collect_subagent_results, execute_action, format_subagent_tool_result, spawn_subagents,
26};
27use crate::models::{ChatMessage, Model, ModelConfig, StreamCallback, StreamEvent, ToolCall};
28use crate::utils::MutexExt;
29
/// Default maximum iterations for the agent loop.
///
/// NOTE(review): not referenced within this section of the file; presumably
/// callers pass it as the `max_iterations` argument of `run_agent_loop` —
/// confirm at the call sites.
pub const MAX_AGENT_ITERATIONS: usize = 25;
32
/// Output of a single model call, returned from the `call_model` hook.
#[derive(Debug, Clone, Default)]
pub struct ModelCallOutput {
    /// Assistant text for this call. The default `call_model` prefers the
    /// text accumulated from streamed chunks and falls back to the
    /// response's `content` when nothing was streamed.
    pub content: String,
    /// Tool calls requested by the model. An empty vec is what ends the
    /// agent loop.
    pub tool_calls: Vec<ToolCall>,
    /// Token count for this call. The default `call_model` reports the
    /// response usage's `total_tokens`, or 0 when no usage was returned.
    pub tokens: usize,
}
40
41/// How the agent loop communicates with its environment.
42///
43/// The six small sync hooks (`check_interrupt`, `on_status`, etc.) are
44/// called between steps. The two async hooks (`call_model`,
45/// `run_subagents`) do the heavy lifting and have default
46/// implementations suitable for non-interactive use — the TUI overrides
47/// them to thread channel-based streaming and live rendering through
48/// the shared loop without forking it.
49#[async_trait]
50pub trait AgentObserver: Send {
51    /// Called between steps to check for user interruption or injected messages.
52    /// Returns `LoopControl::Continue` to proceed, `Interrupt` to stop,
53    /// or `InjectMessage(text)` to redirect the agent with new user input.
54    fn check_interrupt(&mut self) -> LoopControl;
55
56    /// Called when the loop status changes (e.g., "Iteration 3 - executing tools")
57    fn on_status(&mut self, message: &str);
58
59    /// Called after a tool call is executed.
60    ///
61    /// Observers may use this to mirror tool results into side storage
62    /// (e.g., the TUI commits the tool message to session_state for
63    /// live rendering) — but the loop ALSO appends the tool message to
64    /// its own `messages` vec so the next model call sees it. When
65    /// that vec IS the side storage (TUI passes `&mut session_state
66    /// .messages`), both paths land on the same data.
67    fn on_tool_result(
68        &mut self,
69        tool_name: &str,
70        tool_call_id: &str,
71        action: &AgentAction,
72        result: &AgentActionResult,
73    );
74
75    /// Called when the model returns an error.
76    fn on_error(&mut self, error: &str);
77
78    /// Called when model generation starts (for status tracking).
79    fn on_generation_start(&mut self);
80
81    /// Called when model generation completes with token count.
82    fn on_generation_complete(&mut self, tokens: usize);
83
84    /// Called whenever the loop appends a message to its `messages` vec.
85    ///
86    /// Default: no-op. The TUI overrides this to mirror the message into
87    /// `app.session_state.messages` (its UI source of truth) without
88    /// requiring run_agent_loop to borrow session_state mutably (which
89    /// would alias through the observer).
90    fn on_message_appended(&mut self, _msg: &ChatMessage) {}
91
92    /// Call the model and return its response.
93    ///
94    /// Default: direct `model.chat_typed()` with a typed-event accumulator —
95    /// matches what non-interactive mode and subagents need. Reasoning
96    /// chunks are accumulated separately so they don't pollute the
97    /// returned text content. Tool calls are deduped against the response
98    /// (the adapter populates `ModelResponse.tool_calls` as a fallback if
99    /// the typed stream emitted none).
100    ///
101    /// Override for: channel-based streaming, mid-stream interrupt, UI
102    /// rendering between chunks.
103    async fn call_model(
104        &mut self,
105        model: Arc<RwLock<Box<dyn Model>>>,
106        messages: &[ChatMessage],
107        config: &ModelConfig,
108    ) -> Result<ModelCallOutput> {
109        let text = Arc::new(std::sync::Mutex::new(String::new()));
110        let typed_tool_calls = Arc::new(std::sync::Mutex::new(Vec::<ToolCall>::new()));
111        let text_clone = Arc::clone(&text);
112        let tool_clone = Arc::clone(&typed_tool_calls);
113        let callback: StreamCallback = Arc::new(move |event| match event {
114            StreamEvent::Text(chunk) => {
115                text_clone.lock_mut_safe().push_str(&chunk);
116            },
117            StreamEvent::ToolCall(tc) => {
118                tool_clone.lock_mut_safe().push(tc);
119            },
120            // Reasoning chunks dropped — the loop doesn't surface
121            // reasoning to observers. `ModelResponse.thinking` (still
122            // populated by the adapter) is available if needed.
123            StreamEvent::Reasoning(_) | StreamEvent::Done { .. } => {},
124        });
125
126        let model_guard = model.read().await;
127        let response = model_guard
128            .chat(messages, config, Some(callback))
129            .await
130            .map_err(|e| anyhow::anyhow!("{}", e))?;
131
132        let streamed_text = text.lock_mut_safe().clone();
133        let content = if !streamed_text.is_empty() {
134            streamed_text
135        } else {
136            response.content.clone()
137        };
138        let tokens = response.usage.map(|u| u.total_tokens).unwrap_or(0);
139        let streamed_tool_calls = std::mem::take(&mut *typed_tool_calls.lock_mut_safe());
140        let tool_calls = if !streamed_tool_calls.is_empty() {
141            streamed_tool_calls
142        } else {
143            response.tool_calls.unwrap_or_default()
144        };
145
146        Ok(ModelCallOutput {
147            content,
148            tool_calls,
149            tokens,
150        })
151    }
152
153    /// Run a batch of subagents to completion and return their results.
154    ///
155    /// Default: `spawn_subagents` + `collect_subagent_results` with no
156    /// rendering between polls. Override for live progress rendering.
157    async fn run_subagents(
158        &mut self,
159        specs: Vec<(String, String)>,
160        model: Arc<RwLock<Box<dyn Model>>>,
161        config: &ModelConfig,
162    ) -> Vec<SubagentResult> {
163        let progress = Arc::new(Mutex::new(Vec::<SubagentProgress>::new()));
164        let (handles, overflow) = spawn_subagents(specs, model, config, Arc::clone(&progress));
165        collect_subagent_results(handles, overflow).await
166    }
167}
168
/// Control flow for the agent loop.
///
/// Returned by `AgentObserver::check_interrupt`, which the loop polls
/// before executing tool calls and again before calling the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopControl {
    /// Continue normally
    Continue,
    /// User interrupted (Esc, Ctrl+C)
    Interrupt,
    /// User injected a new message that redirects the agent. The payload
    /// is appended to the conversation as a user message.
    InjectMessage(String),
}
178
/// Result of running the agent loop
pub struct AgentLoopResult {
    /// The model's final text response (from the last iteration with no tool calls).
    /// Left empty when the loop ends any other way: user interrupt, model
    /// error, or hitting the iteration cap.
    pub final_response: String,
    /// Value of the loop's iteration counter when it stopped. The counter is
    /// incremented before the cap is checked, so this reads
    /// `max_iterations + 1` when the loop stopped because the cap was exceeded.
    pub iterations: usize,
    /// Whether the loop was interrupted by the user
    pub interrupted: bool,
    /// All tool execution results across iterations, in the order the
    /// corresponding tool messages were appended.
    pub tool_results: Vec<ToolExecutionResult>,
    /// Total tokens used across all model calls, plus the tokens reported
    /// by every subagent result.
    pub total_tokens: usize,
}
192
/// Result of a single tool execution
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
    /// Id of the tool call this result answers. Taken from the model's
    /// tool call when present; otherwise synthesized by the loop
    /// (`call_<iteration>_<index>_<tool>` for regular tools,
    /// `call_agent_<iteration>_<index>` for subagents).
    pub tool_call_id: String,
    /// Name of the tool that was called (`"agent"` for subagent calls).
    pub tool_name: String,
    /// The action that was run. `AgentAction::ParseError` when the tool
    /// call could not be converted into an action; for subagents a
    /// `SpawnAgent` whose `prompt` is left empty.
    pub action: AgentAction,
    /// Whether the tool reported success.
    pub success: bool,
    /// Text sent back to the model as the tool message. Failures of
    /// regular tools are prefixed with `Error: `.
    pub output: String,
    /// Images attached to a successful regular tool result, if any.
    /// Always `None` for subagent and failed calls.
    pub images: Option<Vec<String>>,
}
203
204/// Run the agent loop: execute tool calls, feed results back, repeat.
205///
206/// Used by non-interactive mode, sub-agents, AND the TUI. The observer's
207/// `call_model` and `run_subagents` hooks let each caller customize the
208/// execution strategy without duplicating the loop skeleton.
209pub async fn run_agent_loop(
210    model: Arc<RwLock<Box<dyn Model>>>,
211    config: &ModelConfig,
212    messages: &mut Vec<ChatMessage>,
213    initial_tool_calls: Vec<ToolCall>,
214    observer: &mut dyn AgentObserver,
215    max_iterations: usize,
216) -> Result<AgentLoopResult> {
217    let mut current_tool_calls = initial_tool_calls;
218    let mut iteration = 0;
219    let mut all_tool_results = Vec::new();
220    let mut total_tokens = 0;
221    let mut final_response = String::new();
222    let mut interrupted = false;
223
224    while !current_tool_calls.is_empty() {
225        iteration += 1;
226        if iteration > max_iterations {
227            observer.on_status(&format!(
228                "Agent loop exceeded {} iterations",
229                max_iterations
230            ));
231            break;
232        }
233
234        observer.on_status(&format!("Agent loop iteration {}", iteration));
235
236        // Check for interruption or injected messages
237        match observer.check_interrupt() {
238            LoopControl::Continue => {},
239            LoopControl::Interrupt => {
240                interrupted = true;
241                break;
242            },
243            LoopControl::InjectMessage(msg) => {
244                // User typed a message during the loop -- redirect agent
245                observer.on_status("Processing queued message...");
246                let user_msg = ChatMessage::user(msg);
247                observer.on_message_appended(&user_msg);
248                messages.push(user_msg);
249                current_tool_calls.clear();
250                // Falls through to model call below
251            },
252        }
253
254        // If tool calls were cleared by InjectMessage, skip execution and go to model call
255        if !current_tool_calls.is_empty() {
256            // Every caller pre-pushes the assistant message with its tool_calls
257            // already set (via `ChatMessage::with_tool_calls`), and subsequent
258            // iterations push fresh assistants with_tool_calls below — so no
259            // in-place mutation of the previous assistant is needed here.
260
261            // Partition into regular tool calls and agent tool calls
262            let (regular_calls, agent_calls): (Vec<_>, Vec<_>) = current_tool_calls
263                .iter()
264                .partition(|tc| tc.function.name != "agent");
265
266            // Execute regular tool calls first (sequential, as before).
267            // The fallback id includes the in-iteration index so two calls
268            // to the same tool (e.g., two `read_file`s in parallel) don't
269            // collide when the model omits the `id` field.
270            for (idx, tc) in regular_calls.iter().enumerate() {
271                let tool_call_id = tc
272                    .id
273                    .clone()
274                    .unwrap_or_else(|| format!("call_{}_{}_{}", iteration, idx, tc.function.name));
275                let tool_name = tc.function.name.clone();
276
277                let agent_action = match tc.to_agent_action() {
278                    Ok(action) => action,
279                    Err(e) => {
280                        let error_msg = format!("Error: {}", e);
281                        let tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &error_msg);
282                        observer.on_message_appended(&tool_msg);
283                        messages.push(tool_msg);
284                        all_tool_results.push(ToolExecutionResult {
285                            tool_call_id,
286                            tool_name,
287                            action: AgentAction::ParseError {
288                                message: error_msg.clone(),
289                            },
290                            success: false,
291                            output: error_msg,
292                            images: None,
293                        });
294                        continue;
295                    },
296                };
297
298                let result = execute_action(&agent_action).await;
299                let (success, output, images) = match &result {
300                    AgentActionResult::Success { output, images } => {
301                        (true, output.clone(), images.clone())
302                    },
303                    AgentActionResult::Error { error } => {
304                        (false, format!("Error: {}", error), None)
305                    },
306                };
307
308                observer.on_tool_result(&tool_name, &tool_call_id, &agent_action, &result);
309
310                let mut tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
311                if let Some(ref imgs) = images {
312                    tool_msg = tool_msg.with_images(imgs.clone());
313                }
314                observer.on_message_appended(&tool_msg);
315                messages.push(tool_msg);
316                all_tool_results.push(ToolExecutionResult {
317                    tool_call_id,
318                    tool_name,
319                    action: agent_action,
320                    success,
321                    output,
322                    images,
323                });
324            }
325
326            // Execute agent tool calls in parallel via the observer hook
327            // (default: join_all; TUI: poll with live rendering).
328            if !agent_calls.is_empty() {
329                let agent_specs: Vec<(String, String)> = agent_calls
330                    .iter()
331                    .filter_map(|tc| match tc.to_agent_action() {
332                        Ok(AgentAction::SpawnAgent {
333                            prompt,
334                            description,
335                        }) => Some((prompt, description)),
336                        _ => None,
337                    })
338                    .collect();
339
340                if !agent_specs.is_empty() {
341                    let subagent_results = observer
342                        .run_subagents(agent_specs, Arc::clone(&model), config)
343                        .await;
344
345                    for (i, result) in subagent_results.iter().enumerate() {
346                        let tool_call_id = agent_calls
347                            .get(i)
348                            .and_then(|tc| tc.id.clone())
349                            .unwrap_or_else(|| format!("call_agent_{}_{}", iteration, i));
350                        let tool_name = "agent".to_string();
351                        let output = format_subagent_tool_result(result);
352
353                        observer.on_tool_result(
354                            &tool_name,
355                            &tool_call_id,
356                            &AgentAction::SpawnAgent {
357                                prompt: String::new(),
358                                description: result.description.clone(),
359                            },
360                            &if result.success {
361                                AgentActionResult::Success {
362                                    output: output.clone(),
363                                    images: None,
364                                }
365                            } else {
366                                AgentActionResult::Error {
367                                    error: output.clone(),
368                                }
369                            },
370                        );
371
372                        let tool_msg = ChatMessage::tool(&tool_call_id, &tool_name, &output);
373                        observer.on_message_appended(&tool_msg);
374                        messages.push(tool_msg);
375                        all_tool_results.push(ToolExecutionResult {
376                            tool_call_id,
377                            tool_name,
378                            action: AgentAction::SpawnAgent {
379                                prompt: String::new(),
380                                description: result.description.clone(),
381                            },
382                            success: result.success,
383                            output,
384                            images: None,
385                        });
386
387                        total_tokens += result.tokens;
388                    }
389                }
390            }
391
392            observer.on_status(&format!(
393                "Iteration {} - {} tool(s) executed, calling model...",
394                iteration,
395                current_tool_calls.len()
396            ));
397        }
398
399        // Check for interruption before model call
400        match observer.check_interrupt() {
401            LoopControl::Interrupt => {
402                interrupted = true;
403                break;
404            },
405            LoopControl::InjectMessage(msg) => {
406                let user_msg = ChatMessage::user(msg);
407                observer.on_message_appended(&user_msg);
408                messages.push(user_msg);
409            },
410            LoopControl::Continue => {},
411        }
412
413        // Call model via the observer hook (default: direct chat; TUI:
414        // channel-based streaming with live rendering).
415        observer.on_generation_start();
416        let model_result = observer
417            .call_model(Arc::clone(&model), messages, config)
418            .await;
419
420        match model_result {
421            Ok(out) => {
422                total_tokens += out.tokens;
423                observer.on_generation_complete(out.tokens);
424
425                let new_tool_calls = out.tool_calls;
426
427                // Add assistant message to history
428                if !out.content.is_empty() || !new_tool_calls.is_empty() {
429                    let msg = ChatMessage::assistant(out.content.clone())
430                        .with_tool_calls(new_tool_calls.clone());
431                    observer.on_message_appended(&msg);
432                    messages.push(msg);
433                }
434
435                if new_tool_calls.is_empty() {
436                    // No more tool calls -- agent loop complete
437                    final_response = out.content;
438                    observer.on_status(&format!(
439                        "Agent loop complete after {} iterations",
440                        iteration
441                    ));
442                    break;
443                } else {
444                    current_tool_calls = new_tool_calls;
445                }
446            },
447            Err(e) => {
448                observer.on_error(&e.to_string());
449                break;
450            },
451        }
452    }
453
454    Ok(AgentLoopResult {
455        final_response,
456        iterations: iteration,
457        interrupted,
458        tool_results: all_tool_results,
459        total_tokens,
460    })
461}