Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::Tool;
17use crate::tools::ToolDefinition;
18use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
19
20use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
21
22impl Agent {
23    pub(crate) fn new(builder: AgentBuilder) -> Self {
24        let event_tx = builder.event_tx.unwrap_or_else(|| {
25            let (tx, _) = mpsc::channel(100);
26            tx
27        });
28
29        Self {
30            provider: builder.provider,
31            model_name: builder.model_name,
32            tools: builder.tools,
33            messages: Vec::new(),
34            system_prompt: builder.system_prompt,
35            max_tokens: builder.max_tokens,
36            context_size_override: builder.context_size_override,
37            think: builder.think,
38            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
39            event_tx,
40            skills: builder.skills,
41            profile: builder.profile,
42            project_overview: builder.project_overview,
43            memory_summary: builder.memory_summary,
44            project_path: builder.project_path,
45            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
46            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
47            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
48            cancel_token: None,
49            compression_config: crate::compress::CompressionConfig::default(),
50            ask_rx: None,
51            proxy_tool_defs: builder.proxy_tool_defs,
52            proxy_executor: builder.proxy_executor,
53            mcp_registry: builder.mcp_registry,
54            pending_input_rx: builder.pending_input_rx,
55            pending_inputs: Vec::new(),
56            previewed_tool_inputs: std::collections::HashSet::new(),
57            todo_reminder_count: std::collections::HashMap::new(),
58        }
59    }
60
61    /// Effective context window size, preferring explicit configuration over model inference.
62    pub(crate) fn effective_context_size(&self) -> Option<u32> {
63        self.context_size_override
64            .or_else(|| self.provider.context_size())
65    }
66
67    /// Get event sender for streaming
68    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
69        self.event_tx.clone()
70    }
71
72    /// Set ask response channel (for TUI mode)
73    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
74        self.ask_rx = Some(rx);
75    }
76
77    /// 设置代理工具执行器
78    pub fn set_proxy_executor(
79        &mut self,
80        executor: Arc<dyn ProxyToolExecutor>,
81        tool_defs: Vec<ProxyToolDef>,
82    ) {
83        self.proxy_executor = Some(executor);
84        self.proxy_tool_defs = tool_defs;
85    }
86
87    /// Set cancellation token
88    pub fn set_cancel_token(&mut self, token: CancellationToken) {
89        self.cancel_token = Some(token);
90    }
91
92    /// Set approve mode at runtime
93    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
94        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
95        log::info!("Agent approve mode changed: {} -> {}", old, mode);
96        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
97    }
98
99    /// Get a shared reference to the approve mode atomic.
100    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
101        self.approve_mode.clone()
102    }
103
104    /// Replace the internal approve mode with an externally-created shared atomic.
105    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
106        self.approve_mode = shared;
107    }
108
109    /// Update memory summary and rebuild system prompt.
110    /// Note: Uses build_system_prompt (without project_path) to preserve cache.
111    pub fn update_memory_summary(&mut self, summary: Option<String>) {
112        self.memory_summary = summary;
113        // Preserve cache by using build_system_prompt (no dynamic CodeGraph injection)
114        self.system_prompt = prompt::build_system_prompt(
115            &self.profile,
116            &self.skills,
117            self.project_overview.as_deref(),
118            self.memory_summary.as_deref(),
119        );
120    }
121
122    /// Refresh CodeGraph tools after /init or codegraph init.
123    /// This rebuilds both tools and system prompt with project_path.
124    /// Call this only when CodeGraph state changes (not every request) to preserve cache.
125    pub fn refresh_codegraph_tools(&mut self) {
126        if let Some(path) = &self.project_path {
127            // Check if CodeGraph should be injected now
128            let should_have_codegraph =
129                crate::tools::codegraph::should_inject_codegraph_tools(path);
130
131            // Check if we currently have CodeGraph tools
132            let has_codegraph = self.tools.iter().any(|t| {
133                let name = t.definition().name;
134                name.starts_with("code_") && name != "code_review"
135            });
136
137            // Only update if state changed
138            if should_have_codegraph != has_codegraph {
139                // Add or remove CodeGraph tools
140                if should_have_codegraph {
141                    let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
142                    for tool in codegraph_tools {
143                        self.tools.push(Arc::from(tool));
144                    }
145                    // Update system prompt to include CodeGraph rules
146                    self.system_prompt = prompt::build_system_prompt_with_workflows(
147                        &self.profile,
148                        &self.skills,
149                        self.project_overview.as_deref(),
150                        self.memory_summary.as_deref(),
151                        Some(path),
152                        None, // LSP servers not available in agent context
153                    );
154                } else {
155                    // Remove CodeGraph tools
156                    self.tools.retain(|t| {
157                        let name = t.definition().name;
158                        !name.starts_with("code_") || name == "code_review"
159                    });
160                    // Update system prompt to remove CodeGraph rules
161                    self.system_prompt = prompt::build_system_prompt_with_workflows(
162                        &self.profile,
163                        &self.skills,
164                        self.project_overview.as_deref(),
165                        self.memory_summary.as_deref(),
166                        Some(path),
167                        None, // LSP servers not available in agent context
168                    );
169                }
170            }
171        }
172    }
173
174    /// Run chat loop with tool execution (streaming version).
175    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
176        self.emit(AgentEvent::session_started())?;
177
178        self.messages.push(Message {
179            role: Role::User,
180            content: MessageContent::Text(user_input.clone()),
181        });
182
183        let mut iterations = 0;
184        let mut should_continue = true;
185        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
186
187        while should_continue && iterations < MAX_ITERATIONS {
188            iterations += 1;
189
190            // Check for pending inputs BEFORE building request
191            // This ensures appended messages are sent in this iteration's API call
192            self.drain_pending_inputs();
193            if self.has_pending_inputs() {
194                let pending = self.take_pending_inputs();
195                let count = pending.len();
196                let merged = pending.join("\n\n---\n\n");
197                log::info!("Adding {} pending input messages to request", count);
198
199                // Send queue processed event to TUI with messages content
200                self.emit(AgentEvent::queue_processed(count, pending.clone()))?;
201
202                self.messages.push(Message {
203                    role: Role::User,
204                    content: MessageContent::Text(merged),
205                });
206            }
207
208            if let Some(token) = &self.cancel_token
209                && token.is_cancelled()
210            {
211                self.emit(AgentEvent::error(
212                    prompt::MSG_OPERATION_CANCELLED.to_string(),
213                    None,
214                    None,
215                ))?;
216                break;
217            }
218
219            // Warn when approaching iteration limit (UI only, not in messages history)
220            if iterations == ITERATION_WARNING_THRESHOLD {
221                self.emit(AgentEvent::progress(
222                    prompt::MSG_ITERATION_WARNING_UI
223                        .replace("{iterations}", &iterations.to_string())
224                        .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
225                    None,
226                ))?;
227            }
228
229            // Proactive compression: check context size BEFORE API call
230            // For long conversations, compress early to avoid timeout issues
231            let context_size = self.effective_context_size();
232            let estimated_tokens = estimate_total_tokens(&self.messages);
233
234            if should_compress(estimated_tokens, context_size, &self.compression_config) {
235                self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
236
237                match compress_messages(
238                    &self.messages,
239                    CompressionStrategy::SlidingWindow,
240                    &self.compression_config,
241                ) {
242                    Ok(compressed) => {
243                        let compressed_tokens = estimate_total_tokens(&compressed);
244                        self.messages = compressed;
245                        crate::debug::debug_log().compression(
246                            estimated_tokens,
247                            compressed_tokens,
248                            compressed_tokens as f32 / estimated_tokens as f32,
249                        );
250                    }
251                    Err(e) => {
252                        self.emit(AgentEvent::progress(format!("预压缩失败: {}", e), None))?;
253                    }
254                }
255            }
256
257            // Build request with current messages (including any pending inputs)
258            let tool_defs: Vec<ToolDefinition> = {
259                let mut defs: Vec<ToolDefinition> = self
260                    .tools
261                    .iter()
262                    .map(|t| {
263                        let def = t.definition();
264                        let description = def.description_for_llm();
265                        ToolDefinition {
266                            name: def.name,
267                            description,
268                            parameters: def.parameters,
269                            is_priority: def.is_priority,
270                        }
271                    })
272                    .collect();
273                // 添加代理工具定义
274                defs.extend(self.proxy_tool_defs.iter().map(|t| {
275                    let def = &t.definition;
276                    let description = def.description_for_llm();
277                    ToolDefinition {
278                        name: def.name.clone(),
279                        description,
280                        parameters: def.parameters.clone(),
281                        is_priority: def.is_priority,
282                    }
283                }));
284                defs
285            };
286            let request = ChatRequest {
287                system: Some(self.system_prompt.clone()),
288                messages: self.messages.clone(),
289                max_tokens: self.max_tokens,
290                tools: tool_defs,
291                think: self.think,
292                enable_caching: true,
293                server_tools: Vec::new(),
294            };
295
296            let response = self.call_streaming(&request).await?;
297
298            self.track_usage(&response.usage);
299
300            crate::debug::debug_log().api_call(
301                &self.model_name,
302                response.usage.input_tokens,
303                response.usage.cache_read_input_tokens > 0,
304            );
305
306            should_continue = self.process_response(&response).await?;
307
308            // If model wants to stop, check for pending inputs first (higher priority than todos)
309            // This ensures appended messages are processed before session ends
310            if !should_continue && iterations < MAX_ITERATIONS - 1 {
311                // Final drain of pending inputs before checking todos
312                self.drain_pending_inputs();
313
314                if self.has_pending_inputs() {
315                    log::info!("Agent: found pending inputs at session end, continuing loop");
316                    should_continue = true;
317                    continue; // Will be processed at start of next iteration
318                }
319
320                // Then check for pending todos
321                // First check if we just sent a reminder (prevent immediate duplicate)
322                if self.last_message_was_todo_reminder() {
323                    log::info!("Skipping todo check: reminder already sent in recent messages");
324                } else {
325                    const MAX_TODO_REMINDERS: usize = 2;
326                    
327                    // Clone todo_reminder_count to avoid borrow conflict
328                    let reminder_count_clone = self.todo_reminder_count.clone();
329                    let (pending, all_at_limit) = self.get_pending_todos_with_limit(
330                        &reminder_count_clone,
331                        MAX_TODO_REMINDERS
332                    );
333                    
334                    if !pending.is_empty() {
335                        // Update reminder counts for todos we're about to remind about
336                        for (_, content) in &pending {
337                            *self.todo_reminder_count.entry(content.clone()).or_insert(0) += 1;
338                        }
339                        
340                        let pending_list = pending
341                            .iter()
342                            .map(|(status, content)| {
343                                let marker = match status.as_str() {
344                                    "in_progress" => "[~]",
345                                    "pending" => "[ ]",
346                                    _ => "[?]",
347                                };
348                                format!("  {} {}", marker, content)
349                            })
350                            .collect::<Vec<_>>()
351                            .join("\n");
352
353                        let reminder = format!(
354                            "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
355                            pending_list
356                        );
357
358                        self.messages.push(Message {
359                            role: Role::User,
360                            content: MessageContent::Text(reminder),
361                        });
362                        should_continue = true;
363                    } else if all_at_limit && !self.todo_reminder_count.is_empty() {
364                        // All todos have reached reminder limit, allow session to end
365                        // but inform user that todos remain incomplete
366                        let remaining_count = self.todo_reminder_count.len();
367                        self.emit(AgentEvent::progress(
368                            format!(
369                                "⚠️ 会话结束:{} 个待办项未完成(已提醒 {} 次,达到上限)",
370                                remaining_count, MAX_TODO_REMINDERS
371                            ),
372                            None,
373                        ))?;
374                        log::warn!(
375                            "Session ending with {} incomplete todos (reminder limit reached)",
376                            remaining_count
377                        );
378                    }
379                }
380            }
381
382            let context_size = self.effective_context_size();
383            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
384            let estimated_tokens = estimate_total_tokens(&self.messages);
385
386            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
387                api_tokens
388            } else {
389                estimated_tokens
390            };
391
392            // Only log compression check when context is getting full (> 30%)
393            // This avoids cluttering debug panel with meaningless checks
394            if let Some(ctx_size) = context_size {
395                // Send context size to TUI for accurate display
396                self.emit(AgentEvent::with_data(
397                    EventType::ContextSize,
398                    EventData::ContextSize {
399                        context_size: ctx_size as u64,
400                    },
401                ))?;
402
403                let usage_ratio = current_tokens as f64 / ctx_size as f64;
404                if usage_ratio >= 0.3 {
405                    crate::debug::debug_log().log(
406                        "checkcompress",
407                        &format!(
408                            "usage={:.1}%, tokens={}, context={}, threshold={}%",
409                            usage_ratio * 100.0,
410                            current_tokens,
411                            ctx_size,
412                            self.compression_config.threshold * 100.0
413                        ),
414                    );
415                }
416            }
417
418            if should_compress(current_tokens, context_size, &self.compression_config) {
419                self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
420
421                let original_tokens = current_tokens;
422
423                match compress_messages(
424                    &self.messages,
425                    CompressionStrategy::SlidingWindow,
426                    &self.compression_config,
427                ) {
428                    Ok(compressed) => {
429                        let compressed_tokens = estimate_total_tokens(&compressed);
430                        self.messages = compressed;
431                        self.total_input_tokens
432                            .store(compressed_tokens as u64, Ordering::Relaxed);
433                        self.last_input_tokens
434                            .store(compressed_tokens as u64, Ordering::Relaxed);
435
436                        let ratio = compressed_tokens as f32 / original_tokens as f32;
437                        crate::debug::debug_log().compression(
438                            original_tokens,
439                            compressed_tokens,
440                            ratio,
441                        );
442
443                        self.emit(AgentEvent::with_data(
444                            EventType::CompressionCompleted,
445                            EventData::Compression {
446                                original_tokens: original_tokens as u64,
447                                compressed_tokens: compressed_tokens as u64,
448                                ratio: compressed_tokens as f32 / original_tokens as f32,
449                            },
450                        ))?;
451                    }
452                    Err(e) => {
453                        self.emit(AgentEvent::progress(
454                            format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
455                            None,
456                        ))?;
457                    }
458                }
459            }
460        }
461
462        // Check if we stopped due to reaching MAX_ITERATIONS
463        if iterations >= MAX_ITERATIONS && should_continue {
464            self.emit(AgentEvent::error(
465                prompt::MSG_MAX_ITERATIONS_REACHED
466                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
467                    .replace("{iterations}", &iterations.to_string()),
468                Some("MAX_ITERATIONS_REACHED".to_string()),
469                Some("agent/run.rs".to_string()),
470            ))?;
471        }
472
473        self.emit(AgentEvent::usage_with_cache(
474            self.total_input_tokens.load(Ordering::Relaxed),
475            self.total_output_tokens.load(Ordering::Relaxed),
476            0,
477            0,
478        ))?;
479
480        self.emit(AgentEvent::session_ended())?;
481
482        Ok(Vec::new())
483    }
484
485    /// Restore message history (for session continue/resume)
486    pub fn set_messages(&mut self, messages: Vec<Message>) {
487        self.messages = messages;
488    }
489
490    /// Get current messages (for session saving)
491    pub fn get_messages(&self) -> &[Message] {
492        &self.messages
493    }
494
495    /// Get available tools
496    pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
497        &self.tools
498    }
499
500    /// Get system prompt
501    pub fn get_system_prompt(&self) -> &str {
502        &self.system_prompt
503    }
504
505    /// Get current token counts
506    pub fn get_token_counts(&self) -> (u64, u64) {
507        (
508            self.total_input_tokens.load(Ordering::Relaxed),
509            self.total_output_tokens.load(Ordering::Relaxed),
510        )
511    }
512
513    /// Clear message history
514    pub fn clear_history(&mut self) {
515        self.messages.clear();
516        self.total_input_tokens.store(0, Ordering::Relaxed);
517        self.total_output_tokens.store(0, Ordering::Relaxed);
518        self.last_input_tokens.store(0, Ordering::Relaxed);
519    }
520
521    /// Get message count
522    pub fn message_count(&self) -> usize {
523        self.messages.len()
524    }
525
526    // ========================================================================
527    // MCP Runtime Management
528    // ========================================================================
529
530    /// 动态添加 MCP 服务器
531    ///
532    /// # Example
533    /// ```ignore
534    /// use matrixcode_core::mcp::McpServerConfig;
535    ///
536    /// let config = McpServerConfig::stdio("npx", vec!["-y", "@playwright/mcp@latest"]);
537    /// agent.add_mcp_server("playwright", config).await?;
538    /// ```
539    pub async fn add_mcp_server(
540        &mut self,
541        name: &str,
542        config: crate::mcp::McpServerConfig,
543    ) -> Result<()> {
544        if let Some(registry) = &self.mcp_registry {
545            let mut reg = registry.write().await;
546            reg.add_server(name.to_string(), config);
547            log::info!("MCP server '{}' added to registry", name);
548        } else {
549            log::warn!("MCP registry not initialized, cannot add server '{}'", name);
550        }
551        Ok(())
552    }
553
554    /// 移除 MCP 服务器
555    pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
556        if let Some(registry) = &self.mcp_registry {
557            let mut reg = registry.write().await;
558            reg.remove_server(name).await?;
559            log::info!("MCP server '{}' removed from registry", name);
560        }
561        Ok(())
562    }
563
564    /// 获取 MCP 服务器状态列表
565    pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
566        if let Some(registry) = &self.mcp_registry {
567            let reg = registry.read().await;
568            reg.server_status().await.values().cloned().collect()
569        } else {
570            Vec::new()
571        }
572    }
573
574    /// 启动指定的 MCP 服务器
575    pub async fn start_mcp_server(
576        &self,
577        name: &str,
578    ) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
579        if let Some(registry) = &self.mcp_registry {
580            let reg = registry.read().await;
581            if let Some(placeholder) = reg.get_server(name) {
582                let tools = placeholder.start().await?;
583                log::info!("MCP server '{}' started with {} tools", name, tools.len());
584                Ok(tools)
585            } else {
586                Err(anyhow::anyhow!(
587                    "MCP server '{}' not found in registry",
588                    name
589                ))
590            }
591        } else {
592            Err(anyhow::anyhow!("MCP registry not initialized"))
593        }
594    }
595}