Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::Tool;
17use crate::tools::ToolDefinition;
18use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
19
20use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
21
22impl Agent {
23    pub(crate) fn new(builder: AgentBuilder) -> Self {
24        let event_tx = builder.event_tx.unwrap_or_else(|| {
25            let (tx, _) = mpsc::channel(100);
26            tx
27        });
28
29        Self {
30            provider: builder.provider,
31            model_name: builder.model_name,
32            tools: builder.tools,
33            messages: Vec::new(),
34            system_prompt: builder.system_prompt,
35            max_tokens: builder.max_tokens,
36            think: builder.think,
37            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
38            event_tx,
39            skills: builder.skills,
40            profile: builder.profile,
41            project_overview: builder.project_overview,
42            memory_summary: builder.memory_summary,
43            project_path: builder.project_path,
44            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
45            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
46            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
47            cancel_token: None,
48            compression_config: crate::compress::CompressionConfig::default(),
49            ask_rx: None,
50            proxy_tool_defs: builder.proxy_tool_defs,
51            proxy_executor: builder.proxy_executor,
52            mcp_registry: builder.mcp_registry,
53            pending_input_rx: builder.pending_input_rx,
54            pending_inputs: Vec::new(),
55        }
56    }
57
58    /// Get event sender for streaming
59    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
60        self.event_tx.clone()
61    }
62
63    /// Set ask response channel (for TUI mode)
64    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
65        self.ask_rx = Some(rx);
66    }
67
68    /// 设置代理工具执行器
69    pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
70        self.proxy_executor = Some(executor);
71        self.proxy_tool_defs = tool_defs;
72    }
73
74    /// Set cancellation token
75    pub fn set_cancel_token(&mut self, token: CancellationToken) {
76        self.cancel_token = Some(token);
77    }
78
79    /// Set approve mode at runtime
80    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
81        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
82        log::info!("Agent approve mode changed: {} -> {}", old, mode);
83        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
84    }
85
86    /// Get a shared reference to the approve mode atomic.
87    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
88        self.approve_mode.clone()
89    }
90
91    /// Replace the internal approve mode with an externally-created shared atomic.
92    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
93        self.approve_mode = shared;
94    }
95
96    /// Update memory summary and rebuild system prompt.
97    /// Note: Uses build_system_prompt (without project_path) to preserve cache.
98    pub fn update_memory_summary(&mut self, summary: Option<String>) {
99        self.memory_summary = summary;
100        // Preserve cache by using build_system_prompt (no dynamic CodeGraph injection)
101        self.system_prompt = prompt::build_system_prompt(
102            &self.profile,
103            &self.skills,
104            self.project_overview.as_deref(),
105            self.memory_summary.as_deref(),
106        );
107    }
108
109    /// Refresh CodeGraph tools after /init or codegraph init.
110    /// This rebuilds both tools and system prompt with project_path.
111    /// Call this only when CodeGraph state changes (not every request) to preserve cache.
112    pub fn refresh_codegraph_tools(&mut self) {
113        if let Some(path) = &self.project_path {
114            // Check if CodeGraph should be injected now
115            let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
116
117            // Check if we currently have CodeGraph tools
118            let has_codegraph = self.tools.iter().any(|t| {
119                let name = t.definition().name;
120                name.starts_with("code_") && name != "code_review"
121            });
122
123            // Only update if state changed
124            if should_have_codegraph != has_codegraph {
125                // Add or remove CodeGraph tools
126                if should_have_codegraph {
127                    let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
128                    for tool in codegraph_tools {
129                        self.tools.push(Arc::from(tool));
130                    }
131                    // Update system prompt to include CodeGraph rules
132                    self.system_prompt = prompt::build_system_prompt_with_workflows(
133                        &self.profile,
134                        &self.skills,
135                        self.project_overview.as_deref(),
136                        self.memory_summary.as_deref(),
137                        Some(path),
138                        None, // LSP servers not available in agent context
139                    );
140                } else {
141                    // Remove CodeGraph tools
142                    self.tools.retain(|t| {
143                        let name = t.definition().name;
144                        !name.starts_with("code_") || name == "code_review"
145                    });
146                    // Update system prompt to remove CodeGraph rules
147                    self.system_prompt = prompt::build_system_prompt_with_workflows(
148                        &self.profile,
149                        &self.skills,
150                        self.project_overview.as_deref(),
151                        self.memory_summary.as_deref(),
152                        Some(path),
153                        None, // LSP servers not available in agent context
154                    );
155                }
156            }
157        }
158    }
159
160    /// Run chat loop with tool execution (streaming version).
161    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
162        self.emit(AgentEvent::session_started())?;
163
164        self.messages.push(Message {
165            role: Role::User,
166            content: MessageContent::Text(user_input.clone()),
167        });
168
169        let mut iterations = 0;
170        let mut should_continue = true;
171        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
172
173        while should_continue && iterations < MAX_ITERATIONS {
174            iterations += 1;
175
176            // Check for pending inputs BEFORE building request
177            // This ensures appended messages are sent in this iteration's API call
178            if self.has_pending_inputs() {
179                let pending = self.take_pending_inputs();
180                let merged = pending.join("\n\n---\n\n");
181                log::info!("Adding {} pending input messages to request", pending.len());
182
183                self.emit(AgentEvent::progress(
184                    format!("📝 收到 {} 条追加消息", pending.len()),
185                    None,
186                ))?;
187
188                self.messages.push(Message {
189                    role: Role::User,
190                    content: MessageContent::Text(merged),
191                });
192            }
193
194            if let Some(token) = &self.cancel_token
195                && token.is_cancelled()
196            {
197                self.emit(AgentEvent::error(
198                    prompt::MSG_OPERATION_CANCELLED.to_string(),
199                    None,
200                    None,
201                ))?;
202                break;
203            }
204
205            // Warn when approaching iteration limit (UI only, not in messages history)
206            if iterations == ITERATION_WARNING_THRESHOLD {
207                self.emit(AgentEvent::progress(
208                    prompt::MSG_ITERATION_WARNING_UI
209                        .replace("{iterations}", &iterations.to_string())
210                        .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
211                    None,
212                ))?;
213            }
214
215            // Proactive compression: check context size BEFORE API call
216            // For long conversations, compress early to avoid timeout issues
217            let context_size = self.provider.context_size();
218            let estimated_tokens = estimate_total_tokens(&self.messages);
219
220            if should_compress(estimated_tokens, context_size, &self.compression_config) {
221                self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
222
223                match compress_messages(
224                    &self.messages,
225                    CompressionStrategy::SlidingWindow,
226                    &self.compression_config,
227                ) {
228                    Ok(compressed) => {
229                        let compressed_tokens = estimate_total_tokens(&compressed);
230                        self.messages = compressed;
231                        crate::debug::debug_log().compression(
232                            estimated_tokens,
233                            compressed_tokens,
234                            compressed_tokens as f32 / estimated_tokens as f32,
235                        );
236                    }
237                    Err(e) => {
238                        self.emit(AgentEvent::progress(
239                            format!("预压缩失败: {}", e),
240                            None,
241                        ))?;
242                    }
243                }
244            }
245
246            // Build request with current messages (including any pending inputs)
247            let tool_defs: Vec<ToolDefinition> = {
248                let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
249                    let def = t.definition();
250                    let description = def.description_for_llm();
251                    ToolDefinition {
252                        name: def.name,
253                        description,
254                        parameters: def.parameters,
255                        is_priority: def.is_priority,
256                    }
257                }).collect();
258                // 添加代理工具定义
259                defs.extend(self.proxy_tool_defs.iter().map(|t| {
260                    let def = &t.definition;
261                    let description = def.description_for_llm();
262                    ToolDefinition {
263                        name: def.name.clone(),
264                        description,
265                        parameters: def.parameters.clone(),
266                        is_priority: def.is_priority,
267                    }
268                }));
269                defs
270            };
271            let request = ChatRequest {
272                system: Some(self.system_prompt.clone()),
273                messages: self.messages.clone(),
274                max_tokens: self.max_tokens,
275                tools: tool_defs,
276                think: self.think,
277                enable_caching: true,
278                server_tools: Vec::new(),
279            };
280
281            let response = self.call_streaming(&request).await?;
282
283            self.track_usage(&response.usage);
284
285            crate::debug::debug_log().api_call(
286                &self.model_name,
287                response.usage.input_tokens,
288                response.usage.cache_read_input_tokens > 0,
289            );
290
291            should_continue = self.process_response(&response).await?;
292
293            // If model wants to stop, check for pending inputs first, then pending todos
294            if !should_continue && iterations < MAX_ITERATIONS - 1 {
295                // Check for user appended messages (real-time input during processing)
296                if self.has_pending_inputs() {
297                    let pending = self.take_pending_inputs();
298                    let merged = pending.join("\n\n---\n\n");
299                    log::info!("Model stopped but user appended {} messages, continuing", pending.len());
300
301                    self.emit(AgentEvent::progress(
302                        format!("📝 处理 {} 条追加消息", pending.len()),
303                        None,
304                    ))?;
305
306                    self.messages.push(Message {
307                        role: Role::User,
308                        content: MessageContent::Text(merged),
309                    });
310                    should_continue = true;
311                } else {
312                    // No pending inputs, check for pending todos
313                    let pending = self.get_pending_todos();
314                    if !pending.is_empty() {
315                        // Generate specific reminder with pending tasks
316                        let pending_list = pending.iter()
317                            .map(|(status, content)| {
318                                let marker = match status.as_str() {
319                                    "in_progress" => "[~]",
320                                    "pending" => "[ ]",
321                                    _ => "[?]"
322                                };
323                                format!("  {} {}", marker, content)
324                            })
325                            .collect::<Vec<_>>()
326                            .join("\n");
327
328                        let reminder = format!(
329                            "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
330                            pending_list
331                        );
332
333                        self.messages.push(Message {
334                            role: Role::User,
335                            content: MessageContent::Text(reminder),
336                        });
337                        should_continue = true;
338                    }
339                }
340            }
341
342            let context_size = self.provider.context_size();
343            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
344            let estimated_tokens = estimate_total_tokens(&self.messages);
345
346            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
347                api_tokens
348            } else {
349                estimated_tokens
350            };
351
352            // Only log compression check when context is getting full (> 30%)
353            // This avoids cluttering debug panel with meaningless checks
354            if let Some(ctx_size) = context_size {
355                // Send context size to TUI for accurate display
356                self.emit(AgentEvent::with_data(
357                    EventType::ContextSize,
358                    EventData::ContextSize {
359                        context_size: ctx_size as u64,
360                    },
361                ))?;
362
363                let usage_ratio = current_tokens as f64 / ctx_size as f64;
364                if usage_ratio >= 0.3 {
365                    crate::debug::debug_log().log(
366                        "checkcompress",
367                        &format!(
368                            "usage={:.1}%, tokens={}, context={}, threshold={}%",
369                            usage_ratio * 100.0,
370                            current_tokens,
371                            ctx_size,
372                            self.compression_config.threshold * 100.0
373                        ),
374                    );
375                }
376            }
377
378            if should_compress(current_tokens, context_size, &self.compression_config) {
379                self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
380
381                let original_tokens = current_tokens;
382
383                match compress_messages(
384                    &self.messages,
385                    CompressionStrategy::SlidingWindow,
386                    &self.compression_config,
387                ) {
388                    Ok(compressed) => {
389                        let compressed_tokens = estimate_total_tokens(&compressed);
390                        self.messages = compressed;
391                        self.total_input_tokens
392                            .store(compressed_tokens as u64, Ordering::Relaxed);
393                        self.last_input_tokens
394                            .store(compressed_tokens as u64, Ordering::Relaxed);
395
396                        let ratio = compressed_tokens as f32 / original_tokens as f32;
397                        crate::debug::debug_log().compression(
398                            original_tokens,
399                            compressed_tokens,
400                            ratio,
401                        );
402
403                        self.emit(AgentEvent::with_data(
404                            EventType::CompressionCompleted,
405                            EventData::Compression {
406                                original_tokens: original_tokens as u64,
407                                compressed_tokens: compressed_tokens as u64,
408                                ratio: compressed_tokens as f32 / original_tokens as f32,
409                            },
410                        ))?;
411                    }
412                    Err(e) => {
413                        self.emit(AgentEvent::progress(
414                            format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
415                            None,
416                        ))?;
417                    }
418                }
419            }
420        }
421        
422        // Check if we stopped due to reaching MAX_ITERATIONS
423        if iterations >= MAX_ITERATIONS && should_continue {
424            self.emit(AgentEvent::error(
425                prompt::MSG_MAX_ITERATIONS_REACHED
426                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
427                    .replace("{iterations}", &iterations.to_string()),
428                Some("MAX_ITERATIONS_REACHED".to_string()),
429                Some("agent/run.rs".to_string()),
430            ))?;
431        }
432        
433        self.emit(AgentEvent::usage_with_cache(
434            self.total_input_tokens.load(Ordering::Relaxed),
435            self.total_output_tokens.load(Ordering::Relaxed),
436            0,
437            0,
438        ))?;
439
440        self.emit(AgentEvent::session_ended())?;
441
442        Ok(Vec::new())
443    }
444
445    /// Restore message history (for session continue/resume)
446    pub fn set_messages(&mut self, messages: Vec<Message>) {
447        self.messages = messages;
448    }
449
450    /// Get current messages (for session saving)
451    pub fn get_messages(&self) -> &[Message] {
452        &self.messages
453    }
454
455    /// Get available tools
456    pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
457        &self.tools
458    }
459
460    /// Get system prompt
461    pub fn get_system_prompt(&self) -> &str {
462        &self.system_prompt
463    }
464
465    /// Get current token counts
466    pub fn get_token_counts(&self) -> (u64, u64) {
467        (
468            self.total_input_tokens.load(Ordering::Relaxed),
469            self.total_output_tokens.load(Ordering::Relaxed),
470        )
471    }
472
473    /// Clear message history
474    pub fn clear_history(&mut self) {
475        self.messages.clear();
476        self.total_input_tokens.store(0, Ordering::Relaxed);
477        self.total_output_tokens.store(0, Ordering::Relaxed);
478        self.last_input_tokens.store(0, Ordering::Relaxed);
479    }
480
481    /// Get message count
482    pub fn message_count(&self) -> usize {
483        self.messages.len()
484    }
485
486    // ========================================================================
487    // MCP Runtime Management
488    // ========================================================================
489
490    /// 动态添加 MCP 服务器
491    /// 
492    /// # Example
493    /// ```ignore
494    /// use matrixcode_core::mcp::McpServerConfig;
495    /// 
496    /// let config = McpServerConfig::stdio("npx", vec!["-y", "@playwright/mcp@latest"]);
497    /// agent.add_mcp_server("playwright", config).await?;
498    /// ```
499    pub async fn add_mcp_server(&mut self, name: &str, config: crate::mcp::McpServerConfig) -> Result<()> {
500        if let Some(registry) = &self.mcp_registry {
501            let mut reg = registry.write().await;
502            reg.add_server(name.to_string(), config);
503            log::info!("MCP server '{}' added to registry", name);
504        } else {
505            log::warn!("MCP registry not initialized, cannot add server '{}'", name);
506        }
507        Ok(())
508    }
509
510    /// 移除 MCP 服务器
511    pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
512        if let Some(registry) = &self.mcp_registry {
513            let mut reg = registry.write().await;
514            reg.remove_server(name).await?;
515            log::info!("MCP server '{}' removed from registry", name);
516        }
517        Ok(())
518    }
519
520    /// 获取 MCP 服务器状态列表
521    pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
522        if let Some(registry) = &self.mcp_registry {
523            let reg = registry.read().await;
524            reg.server_status().await.values().cloned().collect()
525        } else {
526            Vec::new()
527        }
528    }
529
530    /// 启动指定的 MCP 服务器
531    pub async fn start_mcp_server(&self, name: &str) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
532        if let Some(registry) = &self.mcp_registry {
533            let reg = registry.read().await;
534            if let Some(placeholder) = reg.get_server(name) {
535                let tools = placeholder.start().await?;
536                log::info!("MCP server '{}' started with {} tools", name, tools.len());
537                Ok(tools)
538            } else {
539                Err(anyhow::anyhow!("MCP server '{}' not found in registry", name))
540            }
541        } else {
542            Err(anyhow::anyhow!("MCP registry not initialized"))
543        }
544    }
545}