Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::Tool;
17use crate::tools::ToolDefinition;
18use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
19
20use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
21
22impl Agent {
23    pub(crate) fn new(builder: AgentBuilder) -> Self {
24        let event_tx = builder.event_tx.unwrap_or_else(|| {
25            let (tx, _) = mpsc::channel(100);
26            tx
27        });
28
29        Self {
30            provider: builder.provider,
31            model_name: builder.model_name,
32            tools: builder.tools,
33            messages: Vec::new(),
34            system_prompt: builder.system_prompt,
35            max_tokens: builder.max_tokens,
36            think: builder.think,
37            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
38            event_tx,
39            skills: builder.skills,
40            profile: builder.profile,
41            project_overview: builder.project_overview,
42            memory_summary: builder.memory_summary,
43            project_path: builder.project_path,
44            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
45            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
46            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
47            cancel_token: None,
48            compression_config: crate::compress::CompressionConfig::default(),
49            ask_rx: None,
50            proxy_tool_defs: builder.proxy_tool_defs,
51            proxy_executor: builder.proxy_executor,
52            mcp_registry: builder.mcp_registry,
53            pending_input_rx: builder.pending_input_rx,
54            pending_inputs: Vec::new(),
55        }
56    }
57
58    /// Get event sender for streaming
59    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
60        self.event_tx.clone()
61    }
62
63    /// Set ask response channel (for TUI mode)
64    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
65        self.ask_rx = Some(rx);
66    }
67
68    /// 设置代理工具执行器
69    pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
70        self.proxy_executor = Some(executor);
71        self.proxy_tool_defs = tool_defs;
72    }
73
74    /// Set cancellation token
75    pub fn set_cancel_token(&mut self, token: CancellationToken) {
76        self.cancel_token = Some(token);
77    }
78
79    /// Set approve mode at runtime
80    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
81        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
82        log::info!("Agent approve mode changed: {} -> {}", old, mode);
83        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
84    }
85
86    /// Get a shared reference to the approve mode atomic.
87    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
88        self.approve_mode.clone()
89    }
90
91    /// Replace the internal approve mode with an externally-created shared atomic.
92    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
93        self.approve_mode = shared;
94    }
95
96    /// Update memory summary and rebuild system prompt.
97    /// Note: Uses build_system_prompt (without project_path) to preserve cache.
98    pub fn update_memory_summary(&mut self, summary: Option<String>) {
99        self.memory_summary = summary;
100        // Preserve cache by using build_system_prompt (no dynamic CodeGraph injection)
101        self.system_prompt = prompt::build_system_prompt(
102            &self.profile,
103            &self.skills,
104            self.project_overview.as_deref(),
105            self.memory_summary.as_deref(),
106        );
107    }
108
109    /// Refresh CodeGraph tools after /init or codegraph init.
110    /// This rebuilds both tools and system prompt with project_path.
111    /// Call this only when CodeGraph state changes (not every request) to preserve cache.
112    pub fn refresh_codegraph_tools(&mut self) {
113        if let Some(path) = &self.project_path {
114            // Check if CodeGraph should be injected now
115            let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
116
117            // Check if we currently have CodeGraph tools
118            let has_codegraph = self.tools.iter().any(|t| {
119                let name = t.definition().name;
120                name.starts_with("code_") && name != "code_review"
121            });
122
123            // Only update if state changed
124            if should_have_codegraph != has_codegraph {
125                // Add or remove CodeGraph tools
126                if should_have_codegraph {
127                    let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
128                    for tool in codegraph_tools {
129                        self.tools.push(Arc::from(tool));
130                    }
131                    // Update system prompt to include CodeGraph rules
132                    self.system_prompt = prompt::build_system_prompt_with_workflows(
133                        &self.profile,
134                        &self.skills,
135                        self.project_overview.as_deref(),
136                        self.memory_summary.as_deref(),
137                        Some(path),
138                        None, // LSP servers not available in agent context
139                    );
140                } else {
141                    // Remove CodeGraph tools
142                    self.tools.retain(|t| {
143                        let name = t.definition().name;
144                        !name.starts_with("code_") || name == "code_review"
145                    });
146                    // Update system prompt to remove CodeGraph rules
147                    self.system_prompt = prompt::build_system_prompt_with_workflows(
148                        &self.profile,
149                        &self.skills,
150                        self.project_overview.as_deref(),
151                        self.memory_summary.as_deref(),
152                        Some(path),
153                        None, // LSP servers not available in agent context
154                    );
155                }
156            }
157        }
158    }
159
160    /// Run chat loop with tool execution (streaming version).
161    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
162        self.emit(AgentEvent::session_started())?;
163
164        self.messages.push(Message {
165            role: Role::User,
166            content: MessageContent::Text(user_input.clone()),
167        });
168
169        let mut iterations = 0;
170        let mut should_continue = true;
171        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
172
173        while should_continue && iterations < MAX_ITERATIONS {
174            iterations += 1;
175
176            // Check for pending inputs BEFORE building request
177            // This ensures appended messages are sent in this iteration's API call
178            if self.has_pending_inputs() {
179                let pending = self.take_pending_inputs();
180                let count = pending.len();
181                let merged = pending.join("\n\n---\n\n");
182                log::info!("Adding {} pending input messages to request", count);
183
184                // Send queue processed event to TUI with messages content
185                self.emit(AgentEvent::queue_processed(count, pending.clone()))?;
186
187                self.messages.push(Message {
188                    role: Role::User,
189                    content: MessageContent::Text(merged),
190                });
191            }
192
193            if let Some(token) = &self.cancel_token
194                && token.is_cancelled()
195            {
196                self.emit(AgentEvent::error(
197                    prompt::MSG_OPERATION_CANCELLED.to_string(),
198                    None,
199                    None,
200                ))?;
201                break;
202            }
203
204            // Warn when approaching iteration limit (UI only, not in messages history)
205            if iterations == ITERATION_WARNING_THRESHOLD {
206                self.emit(AgentEvent::progress(
207                    prompt::MSG_ITERATION_WARNING_UI
208                        .replace("{iterations}", &iterations.to_string())
209                        .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
210                    None,
211                ))?;
212            }
213
214            // Proactive compression: check context size BEFORE API call
215            // For long conversations, compress early to avoid timeout issues
216            let context_size = self.provider.context_size();
217            let estimated_tokens = estimate_total_tokens(&self.messages);
218
219            if should_compress(estimated_tokens, context_size, &self.compression_config) {
220                self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
221
222                match compress_messages(
223                    &self.messages,
224                    CompressionStrategy::SlidingWindow,
225                    &self.compression_config,
226                ) {
227                    Ok(compressed) => {
228                        let compressed_tokens = estimate_total_tokens(&compressed);
229                        self.messages = compressed;
230                        crate::debug::debug_log().compression(
231                            estimated_tokens,
232                            compressed_tokens,
233                            compressed_tokens as f32 / estimated_tokens as f32,
234                        );
235                    }
236                    Err(e) => {
237                        self.emit(AgentEvent::progress(
238                            format!("预压缩失败: {}", e),
239                            None,
240                        ))?;
241                    }
242                }
243            }
244
245            // Build request with current messages (including any pending inputs)
246            let tool_defs: Vec<ToolDefinition> = {
247                let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
248                    let def = t.definition();
249                    let description = def.description_for_llm();
250                    ToolDefinition {
251                        name: def.name,
252                        description,
253                        parameters: def.parameters,
254                        is_priority: def.is_priority,
255                    }
256                }).collect();
257                // 添加代理工具定义
258                defs.extend(self.proxy_tool_defs.iter().map(|t| {
259                    let def = &t.definition;
260                    let description = def.description_for_llm();
261                    ToolDefinition {
262                        name: def.name.clone(),
263                        description,
264                        parameters: def.parameters.clone(),
265                        is_priority: def.is_priority,
266                    }
267                }));
268                defs
269            };
270            let request = ChatRequest {
271                system: Some(self.system_prompt.clone()),
272                messages: self.messages.clone(),
273                max_tokens: self.max_tokens,
274                tools: tool_defs,
275                think: self.think,
276                enable_caching: true,
277                server_tools: Vec::new(),
278            };
279
280            let response = self.call_streaming(&request).await?;
281
282            self.track_usage(&response.usage);
283
284            crate::debug::debug_log().api_call(
285                &self.model_name,
286                response.usage.input_tokens,
287                response.usage.cache_read_input_tokens > 0,
288            );
289
290            should_continue = self.process_response(&response).await?;
291
292            // If model wants to stop, check for pending inputs first (higher priority than todos)
293            // This ensures appended messages are processed before session ends
294            if !should_continue && iterations < MAX_ITERATIONS - 1 {
295                // Final drain of pending inputs before checking todos
296                self.drain_pending_inputs();
297                
298                if self.has_pending_inputs() {
299                    log::info!("Agent: found pending inputs at session end, continuing loop");
300                    should_continue = true;
301                    continue;  // Will be processed at start of next iteration
302                }
303                
304                // Then check for pending todos
305                let pending = self.get_pending_todos();
306                if !pending.is_empty() {
307                    let pending_list = pending.iter()
308                        .map(|(status, content)| {
309                            let marker = match status.as_str() {
310                                "in_progress" => "[~]",
311                                "pending" => "[ ]",
312                                _ => "[?]"
313                            };
314                            format!("  {} {}", marker, content)
315                        })
316                        .collect::<Vec<_>>()
317                        .join("\n");
318
319                    let reminder = format!(
320                        "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
321                        pending_list
322                    );
323
324                    self.messages.push(Message {
325                        role: Role::User,
326                        content: MessageContent::Text(reminder),
327                    });
328                    should_continue = true;
329                }
330            }
331
332            let context_size = self.provider.context_size();
333            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
334            let estimated_tokens = estimate_total_tokens(&self.messages);
335
336            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
337                api_tokens
338            } else {
339                estimated_tokens
340            };
341
342            // Only log compression check when context is getting full (> 30%)
343            // This avoids cluttering debug panel with meaningless checks
344            if let Some(ctx_size) = context_size {
345                // Send context size to TUI for accurate display
346                self.emit(AgentEvent::with_data(
347                    EventType::ContextSize,
348                    EventData::ContextSize {
349                        context_size: ctx_size as u64,
350                    },
351                ))?;
352
353                let usage_ratio = current_tokens as f64 / ctx_size as f64;
354                if usage_ratio >= 0.3 {
355                    crate::debug::debug_log().log(
356                        "checkcompress",
357                        &format!(
358                            "usage={:.1}%, tokens={}, context={}, threshold={}%",
359                            usage_ratio * 100.0,
360                            current_tokens,
361                            ctx_size,
362                            self.compression_config.threshold * 100.0
363                        ),
364                    );
365                }
366            }
367
368            if should_compress(current_tokens, context_size, &self.compression_config) {
369                self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
370
371                let original_tokens = current_tokens;
372
373                match compress_messages(
374                    &self.messages,
375                    CompressionStrategy::SlidingWindow,
376                    &self.compression_config,
377                ) {
378                    Ok(compressed) => {
379                        let compressed_tokens = estimate_total_tokens(&compressed);
380                        self.messages = compressed;
381                        self.total_input_tokens
382                            .store(compressed_tokens as u64, Ordering::Relaxed);
383                        self.last_input_tokens
384                            .store(compressed_tokens as u64, Ordering::Relaxed);
385
386                        let ratio = compressed_tokens as f32 / original_tokens as f32;
387                        crate::debug::debug_log().compression(
388                            original_tokens,
389                            compressed_tokens,
390                            ratio,
391                        );
392
393                        self.emit(AgentEvent::with_data(
394                            EventType::CompressionCompleted,
395                            EventData::Compression {
396                                original_tokens: original_tokens as u64,
397                                compressed_tokens: compressed_tokens as u64,
398                                ratio: compressed_tokens as f32 / original_tokens as f32,
399                            },
400                        ))?;
401                    }
402                    Err(e) => {
403                        self.emit(AgentEvent::progress(
404                            format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
405                            None,
406                        ))?;
407                    }
408                }
409            }
410        }
411        
412        // Check if we stopped due to reaching MAX_ITERATIONS
413        if iterations >= MAX_ITERATIONS && should_continue {
414            self.emit(AgentEvent::error(
415                prompt::MSG_MAX_ITERATIONS_REACHED
416                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
417                    .replace("{iterations}", &iterations.to_string()),
418                Some("MAX_ITERATIONS_REACHED".to_string()),
419                Some("agent/run.rs".to_string()),
420            ))?;
421        }
422        
423        self.emit(AgentEvent::usage_with_cache(
424            self.total_input_tokens.load(Ordering::Relaxed),
425            self.total_output_tokens.load(Ordering::Relaxed),
426            0,
427            0,
428        ))?;
429
430        self.emit(AgentEvent::session_ended())?;
431
432        Ok(Vec::new())
433    }
434
435    /// Restore message history (for session continue/resume)
436    pub fn set_messages(&mut self, messages: Vec<Message>) {
437        self.messages = messages;
438    }
439
440    /// Get current messages (for session saving)
441    pub fn get_messages(&self) -> &[Message] {
442        &self.messages
443    }
444
445    /// Get available tools
446    pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
447        &self.tools
448    }
449
450    /// Get system prompt
451    pub fn get_system_prompt(&self) -> &str {
452        &self.system_prompt
453    }
454
455    /// Get current token counts
456    pub fn get_token_counts(&self) -> (u64, u64) {
457        (
458            self.total_input_tokens.load(Ordering::Relaxed),
459            self.total_output_tokens.load(Ordering::Relaxed),
460        )
461    }
462
463    /// Clear message history
464    pub fn clear_history(&mut self) {
465        self.messages.clear();
466        self.total_input_tokens.store(0, Ordering::Relaxed);
467        self.total_output_tokens.store(0, Ordering::Relaxed);
468        self.last_input_tokens.store(0, Ordering::Relaxed);
469    }
470
471    /// Get message count
472    pub fn message_count(&self) -> usize {
473        self.messages.len()
474    }
475
476    // ========================================================================
477    // MCP Runtime Management
478    // ========================================================================
479
480    /// 动态添加 MCP 服务器
481    /// 
482    /// # Example
483    /// ```ignore
484    /// use matrixcode_core::mcp::McpServerConfig;
485    /// 
486    /// let config = McpServerConfig::stdio("npx", vec!["-y", "@playwright/mcp@latest"]);
487    /// agent.add_mcp_server("playwright", config).await?;
488    /// ```
489    pub async fn add_mcp_server(&mut self, name: &str, config: crate::mcp::McpServerConfig) -> Result<()> {
490        if let Some(registry) = &self.mcp_registry {
491            let mut reg = registry.write().await;
492            reg.add_server(name.to_string(), config);
493            log::info!("MCP server '{}' added to registry", name);
494        } else {
495            log::warn!("MCP registry not initialized, cannot add server '{}'", name);
496        }
497        Ok(())
498    }
499
500    /// 移除 MCP 服务器
501    pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
502        if let Some(registry) = &self.mcp_registry {
503            let mut reg = registry.write().await;
504            reg.remove_server(name).await?;
505            log::info!("MCP server '{}' removed from registry", name);
506        }
507        Ok(())
508    }
509
510    /// 获取 MCP 服务器状态列表
511    pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
512        if let Some(registry) = &self.mcp_registry {
513            let reg = registry.read().await;
514            reg.server_status().await.values().cloned().collect()
515        } else {
516            Vec::new()
517        }
518    }
519
520    /// 启动指定的 MCP 服务器
521    pub async fn start_mcp_server(&self, name: &str) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
522        if let Some(registry) = &self.mcp_registry {
523            let reg = registry.read().await;
524            if let Some(placeholder) = reg.get_server(name) {
525                let tools = placeholder.start().await?;
526                log::info!("MCP server '{}' started with {} tools", name, tools.len());
527                Ok(tools)
528            } else {
529                Err(anyhow::anyhow!("MCP server '{}' not found in registry", name))
530            }
531        } else {
532            Err(anyhow::anyhow!("MCP registry not initialized"))
533        }
534    }
535}