Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
18
19use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
20
21impl Agent {
22    pub(crate) fn new(builder: AgentBuilder) -> Self {
23        let event_tx = builder.event_tx.unwrap_or_else(|| {
24            let (tx, _) = mpsc::channel(100);
25            tx
26        });
27
28        Self {
29            provider: builder.provider,
30            model_name: builder.model_name,
31            tools: builder.tools,
32            messages: Vec::new(),
33            system_prompt: builder.system_prompt,
34            max_tokens: builder.max_tokens,
35            think: builder.think,
36            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
37            event_tx,
38            skills: builder.skills,
39            profile: builder.profile,
40            project_overview: builder.project_overview,
41            memory_summary: builder.memory_summary,
42            project_path: builder.project_path,
43            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
44            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
45            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
46            cancel_token: None,
47            compression_config: crate::compress::CompressionConfig::default(),
48            ask_rx: None,
49            proxy_tool_defs: builder.proxy_tool_defs,
50            proxy_executor: builder.proxy_executor,
51        }
52    }
53
54    /// Get event sender for streaming
55    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
56        self.event_tx.clone()
57    }
58
59    /// Set ask response channel (for TUI mode)
60    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
61        self.ask_rx = Some(rx);
62    }
63
64    /// 设置代理工具执行器
65    pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
66        self.proxy_executor = Some(executor);
67        self.proxy_tool_defs = tool_defs;
68    }
69
70    /// Set cancellation token
71    pub fn set_cancel_token(&mut self, token: CancellationToken) {
72        self.cancel_token = Some(token);
73    }
74
75    /// Set approve mode at runtime
76    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
77        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
78        log::info!("Agent approve mode changed: {} -> {}", old, mode);
79        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
80    }
81
82    /// Get a shared reference to the approve mode atomic.
83    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
84        self.approve_mode.clone()
85    }
86
87    /// Replace the internal approve mode with an externally-created shared atomic.
88    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
89        self.approve_mode = shared;
90    }
91
92    /// Update memory summary and rebuild system prompt.
93    /// Note: Uses build_system_prompt (without project_path) to preserve cache.
94    pub fn update_memory_summary(&mut self, summary: Option<String>) {
95        self.memory_summary = summary;
96        // Preserve cache by using build_system_prompt (no dynamic CodeGraph injection)
97        self.system_prompt = prompt::build_system_prompt(
98            &self.profile,
99            &self.skills,
100            self.project_overview.as_deref(),
101            self.memory_summary.as_deref(),
102        );
103    }
104
105    /// Refresh CodeGraph tools after /init or codegraph init.
106    /// This rebuilds both tools and system prompt with project_path.
107    /// Call this only when CodeGraph state changes (not every request) to preserve cache.
108    pub fn refresh_codegraph_tools(&mut self) {
109        if let Some(path) = &self.project_path {
110            // Check if CodeGraph should be injected now
111            let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
112
113            // Check if we currently have CodeGraph tools
114            let has_codegraph = self.tools.iter().any(|t| {
115                let name = t.definition().name;
116                name.starts_with("code_") && name != "code_review"
117            });
118
119            // Only update if state changed
120            if should_have_codegraph != has_codegraph {
121                // Add or remove CodeGraph tools
122                if should_have_codegraph {
123                    let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
124                    for tool in codegraph_tools {
125                        self.tools.push(Arc::from(tool));
126                    }
127                    // Update system prompt to include CodeGraph rules
128                    self.system_prompt = prompt::build_system_prompt_with_workflows(
129                        &self.profile,
130                        &self.skills,
131                        self.project_overview.as_deref(),
132                        self.memory_summary.as_deref(),
133                        Some(path),
134                    );
135                } else {
136                    // Remove CodeGraph tools
137                    self.tools.retain(|t| {
138                        let name = t.definition().name;
139                        !name.starts_with("code_") || name == "code_review"
140                    });
141                    // Update system prompt to remove CodeGraph rules
142                    self.system_prompt = prompt::build_system_prompt_with_workflows(
143                        &self.profile,
144                        &self.skills,
145                        self.project_overview.as_deref(),
146                        self.memory_summary.as_deref(),
147                        Some(path),
148                    );
149                }
150            }
151        }
152    }
153
154    /// Run chat loop with tool execution (streaming version).
155    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
156        self.emit(AgentEvent::session_started())?;
157
158        self.messages.push(Message {
159            role: Role::User,
160            content: MessageContent::Text(user_input.clone()),
161        });
162
163        let mut iterations = 0;
164        let mut should_continue = true;
165        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
166
167        while should_continue && iterations < MAX_ITERATIONS {
168            iterations += 1;
169
170            if let Some(token) = &self.cancel_token
171                && token.is_cancelled()
172            {
173                self.emit(AgentEvent::error(
174                    prompt::MSG_OPERATION_CANCELLED.to_string(),
175                    None,
176                    None,
177                ))?;
178                break;
179            }
180
181            // Warn when approaching iteration limit
182            if iterations == ITERATION_WARNING_THRESHOLD {
183                self.messages.push(Message {
184                    role: Role::User,
185                    content: MessageContent::Text(
186                        prompt::MSG_ITERATION_WARNING
187                            .replace("{iterations}", &iterations.to_string())
188                            .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
189                    ),
190                });
191            }
192
193            // Proactive compression: check context size BEFORE API call
194            // For long conversations, compress early to avoid timeout issues
195            let context_size = self.provider.context_size();
196            let estimated_tokens = estimate_total_tokens(&self.messages);
197
198            if should_compress(estimated_tokens, context_size, &self.compression_config) {
199                self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
200
201                match compress_messages(
202                    &self.messages,
203                    CompressionStrategy::SlidingWindow,
204                    &self.compression_config,
205                ) {
206                    Ok(compressed) => {
207                        let compressed_tokens = estimate_total_tokens(&compressed);
208                        self.messages = compressed;
209                        crate::debug::debug_log().compression(
210                            estimated_tokens,
211                            compressed_tokens,
212                            compressed_tokens as f32 / estimated_tokens as f32,
213                        );
214                    }
215                    Err(e) => {
216                        self.emit(AgentEvent::progress(
217                            format!("预压缩失败: {}", e),
218                            None,
219                        ))?;
220                    }
221                }
222            }
223
224            // 合并内置工具和代理工具定义,应用优先标记
225            let tool_defs: Vec<ToolDefinition> = {
226                let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
227                    let def = t.definition();
228                    let description = def.description_for_llm();
229                    ToolDefinition {
230                        name: def.name,
231                        description,
232                        parameters: def.parameters,
233                        is_priority: def.is_priority,
234                    }
235                }).collect();
236                // 添加代理工具定义
237                defs.extend(self.proxy_tool_defs.iter().map(|t| {
238                    let def = &t.definition;
239                    let description = def.description_for_llm();
240                    ToolDefinition {
241                        name: def.name.clone(),
242                        description,
243                        parameters: def.parameters.clone(),
244                        is_priority: def.is_priority,
245                    }
246                }));
247                defs
248            };
249            let request = ChatRequest {
250                system: Some(self.system_prompt.clone()),
251                messages: self.messages.clone(),
252                max_tokens: self.max_tokens,
253                tools: tool_defs,
254                think: self.think,
255                enable_caching: true,
256                server_tools: Vec::new(),
257            };
258
259            let response = self.call_streaming(&request).await?;
260
261            self.track_usage(&response.usage);
262
263            crate::debug::debug_log().api_call(
264                &self.model_name,
265                response.usage.input_tokens,
266                response.usage.cache_read_input_tokens > 0,
267            );
268
269            should_continue = self.process_response(&response).await?;
270
271            // If model wants to stop (no tool calls), check for pending todos
272            if !should_continue && iterations < MAX_ITERATIONS - 1
273                && self.has_pending_todos() {
274                    self.messages.push(Message {
275                        role: Role::User,
276                        content: MessageContent::Text(prompt::MSG_PENDING_TODOS.to_string()),
277                    });
278                    should_continue = true;
279                }
280
281            let context_size = self.provider.context_size();
282            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
283            let estimated_tokens = estimate_total_tokens(&self.messages);
284
285            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
286                api_tokens
287            } else {
288                estimated_tokens
289            };
290
291            // Only log compression check when context is getting full (> 30%)
292            // This avoids cluttering debug panel with meaningless checks
293            if let Some(ctx_size) = context_size {
294                // Send context size to TUI for accurate display
295                self.emit(AgentEvent::with_data(
296                    EventType::ContextSize,
297                    EventData::ContextSize {
298                        context_size: ctx_size as u64,
299                    },
300                ))?;
301
302                let usage_ratio = current_tokens as f64 / ctx_size as f64;
303                if usage_ratio >= 0.3 {
304                    crate::debug::debug_log().log(
305                        "checkcompress",
306                        &format!(
307                            "usage={:.1}%, tokens={}, context={}, threshold={}%",
308                            usage_ratio * 100.0,
309                            current_tokens,
310                            ctx_size,
311                            self.compression_config.threshold * 100.0
312                        ),
313                    );
314                }
315            }
316
317            if should_compress(current_tokens, context_size, &self.compression_config) {
318                self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
319
320                let original_tokens = current_tokens;
321
322                match compress_messages(
323                    &self.messages,
324                    CompressionStrategy::SlidingWindow,
325                    &self.compression_config,
326                ) {
327                    Ok(compressed) => {
328                        let compressed_tokens = estimate_total_tokens(&compressed);
329                        self.messages = compressed;
330                        self.total_input_tokens
331                            .store(compressed_tokens as u64, Ordering::Relaxed);
332                        self.last_input_tokens
333                            .store(compressed_tokens as u64, Ordering::Relaxed);
334
335                        let ratio = compressed_tokens as f32 / original_tokens as f32;
336                        crate::debug::debug_log().compression(
337                            original_tokens,
338                            compressed_tokens,
339                            ratio,
340                        );
341
342                        self.emit(AgentEvent::with_data(
343                            EventType::CompressionCompleted,
344                            EventData::Compression {
345                                original_tokens: original_tokens as u64,
346                                compressed_tokens: compressed_tokens as u64,
347                                ratio: compressed_tokens as f32 / original_tokens as f32,
348                            },
349                        ))?;
350                    }
351                    Err(e) => {
352                        self.emit(AgentEvent::progress(
353                            format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
354                            None,
355                        ))?;
356                    }
357                }
358            }
359        }
360        
361        // Check if we stopped due to reaching MAX_ITERATIONS
362        if iterations >= MAX_ITERATIONS && should_continue {
363            self.emit(AgentEvent::error(
364                prompt::MSG_MAX_ITERATIONS_REACHED
365                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
366                    .replace("{iterations}", &iterations.to_string()),
367                Some("MAX_ITERATIONS_REACHED".to_string()),
368                Some("agent/run.rs".to_string()),
369            ))?;
370        }
371        
372        self.emit(AgentEvent::usage_with_cache(
373            self.total_input_tokens.load(Ordering::Relaxed),
374            self.total_output_tokens.load(Ordering::Relaxed),
375            0,
376            0,
377        ))?;
378
379        self.emit(AgentEvent::session_ended())?;
380
381        Ok(Vec::new())
382    }
383
384    /// Restore message history (for session continue/resume)
385    pub fn set_messages(&mut self, messages: Vec<Message>) {
386        self.messages = messages;
387    }
388
389    /// Get current messages (for session saving)
390    pub fn get_messages(&self) -> &[Message] {
391        &self.messages
392    }
393
394    /// Get current token counts
395    pub fn get_token_counts(&self) -> (u64, u64) {
396        (
397            self.total_input_tokens.load(Ordering::Relaxed),
398            self.total_output_tokens.load(Ordering::Relaxed),
399        )
400    }
401
402    /// Clear message history
403    pub fn clear_history(&mut self) {
404        self.messages.clear();
405        self.total_input_tokens.store(0, Ordering::Relaxed);
406        self.total_output_tokens.store(0, Ordering::Relaxed);
407        self.last_input_tokens.store(0, Ordering::Relaxed);
408    }
409
410    /// Get message count
411    pub fn message_count(&self) -> usize {
412        self.messages.len()
413    }
414}