Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17
18use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
19
20impl Agent {
21    pub(crate) fn new(builder: AgentBuilder) -> Self {
22        let event_tx = builder.event_tx.unwrap_or_else(|| {
23            let (tx, _) = mpsc::channel(100);
24            tx
25        });
26
27        Self {
28            provider: builder.provider,
29            model_name: builder.model_name,
30            tools: builder.tools,
31            messages: Vec::new(),
32            system_prompt: builder.system_prompt,
33            max_tokens: builder.max_tokens,
34            think: builder.think,
35            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
36            event_tx,
37            skills: builder.skills,
38            profile: builder.profile,
39            project_overview: builder.project_overview,
40            memory_summary: builder.memory_summary,
41            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
42            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
43            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
44            cancel_token: None,
45            compression_config: crate::compress::CompressionConfig::default(),
46            ask_rx: None,
47        }
48    }
49
50    /// Get event sender for streaming
51    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
52        self.event_tx.clone()
53    }
54
55    /// Set ask response channel (for TUI mode)
56    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
57        self.ask_rx = Some(rx);
58    }
59
60    /// Set cancellation token
61    pub fn set_cancel_token(&mut self, token: CancellationToken) {
62        self.cancel_token = Some(token);
63    }
64
65    /// Set approve mode at runtime
66    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
67        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
68        log::info!("Agent approve mode changed: {} -> {}", old, mode);
69        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
70    }
71
72    /// Get a shared reference to the approve mode atomic.
73    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
74        self.approve_mode.clone()
75    }
76
77    /// Replace the internal approve mode with an externally-created shared atomic.
78    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
79        self.approve_mode = shared;
80    }
81
82    /// Update memory summary and rebuild system prompt.
83    pub fn update_memory_summary(&mut self, summary: Option<String>) {
84        self.memory_summary = summary;
85        self.system_prompt = prompt::build_system_prompt(
86            &self.profile,
87            &self.skills,
88            self.project_overview.as_deref(),
89            self.memory_summary.as_deref(),
90        );
91    }
92
93    /// Run chat loop with tool execution (streaming version).
94    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
95        self.emit(AgentEvent::session_started())?;
96
97        self.messages.push(Message {
98            role: Role::User,
99            content: MessageContent::Text(user_input.clone()),
100        });
101
102        let mut iterations = 0;
103        let mut should_continue = true;
104        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
105
106        while should_continue && iterations < MAX_ITERATIONS {
107            iterations += 1;
108
109            if let Some(token) = &self.cancel_token
110                && token.is_cancelled()
111            {
112                self.emit(AgentEvent::error(
113                    "Operation cancelled".to_string(),
114                    None,
115                    None,
116                ))?;
117                break;
118            }
119
120            // Warn when approaching iteration limit
121            if iterations == ITERATION_WARNING_THRESHOLD {
122                self.messages.push(Message {
123                    role: Role::User,
124                    content: MessageContent::Text(
125                        "⚠️ 接近最大迭代次数限制(当前 {iterations}/{MAX_ITERATIONS})。\
126                         请检查任务进度:\n\
127                         1. 如果有未完成的子任务,优先完成最关键的项\n\
128                         2. 使用 todo_write 查看和更新任务状态\n\
129                         3. 确保在限制内完成或在最后输出剩余任务摘要".replace("{iterations}", &iterations.to_string()).replace("{MAX_ITERATIONS}", &MAX_ITERATIONS.to_string())
130                    ),
131                });
132            }
133
134            let tool_defs: Vec<ToolDefinition> =
135                self.tools.iter().map(|t| t.definition()).collect();
136            let request = ChatRequest {
137                system: Some(self.system_prompt.clone()),
138                messages: self.messages.clone(),
139                max_tokens: self.max_tokens,
140                tools: tool_defs,
141                think: self.think,
142                enable_caching: true,
143                server_tools: Vec::new(),
144            };
145
146            let response = self.call_streaming(&request).await?;
147
148            self.track_usage(&response.usage);
149
150            crate::debug::debug_log().api_call(
151                &self.model_name,
152                response.usage.input_tokens,
153                response.usage.cache_read_input_tokens > 0,
154            );
155
156            should_continue = self.process_response(&response).await?;
157
158            // If model wants to stop (no tool calls), check for pending todos
159            if !should_continue && iterations < MAX_ITERATIONS - 1 {
160                if self.has_pending_todos() {
161                    self.messages.push(Message {
162                        role: Role::User,
163                        content: MessageContent::Text(
164                            "📋 检测到未完成的待办任务。请继续执行剩余任务,或在 todo_write 中将已完成的任务标记为 completed。\n\
165                             注意:只有所有任务都完成后才能结束。如果遇到阻塞,请说明原因。".to_string()
166                        ),
167                    });
168                    should_continue = true;
169                }
170            }
171
172            let context_size = self.provider.context_size();
173            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
174            let estimated_tokens = estimate_total_tokens(&self.messages);
175
176            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
177                api_tokens
178            } else {
179                estimated_tokens
180            };
181
182            crate::debug::debug_log().log(
183                "compression",
184                &format!(
185                    "check: api={}, estimated={}, using={}, context={}, threshold={}",
186                    api_tokens,
187                    estimated_tokens,
188                    current_tokens,
189                    context_size.unwrap_or(0),
190                    self.compression_config.threshold
191                ),
192            );
193
194            if should_compress(current_tokens, context_size, &self.compression_config) {
195                self.emit(AgentEvent::progress("Compressing context...", None))?;
196
197                let original_tokens = current_tokens;
198
199                match compress_messages(
200                    &self.messages,
201                    CompressionStrategy::SlidingWindow,
202                    &self.compression_config,
203                ) {
204                    Ok(compressed) => {
205                        let compressed_tokens = estimate_total_tokens(&compressed);
206                        self.messages = compressed;
207                        self.total_input_tokens
208                            .store(compressed_tokens as u64, Ordering::Relaxed);
209                        self.last_input_tokens
210                            .store(compressed_tokens as u64, Ordering::Relaxed);
211
212                        let ratio = compressed_tokens as f32 / original_tokens as f32;
213                        crate::debug::debug_log().compression(
214                            original_tokens,
215                            compressed_tokens,
216                            ratio,
217                        );
218
219                        self.emit(AgentEvent::with_data(
220                            EventType::CompressionCompleted,
221                            EventData::Compression {
222                                original_tokens: original_tokens as u64,
223                                compressed_tokens: compressed_tokens as u64,
224                                ratio: compressed_tokens as f32 / original_tokens as f32,
225                            },
226                        ))?;
227                    }
228                    Err(e) => {
229                        self.emit(AgentEvent::progress(
230                            format!("Compression failed: {}", e),
231                            None,
232                        ))?;
233                    }
234                }
235            }
236        }
237        
238        // Check if we stopped due to reaching MAX_ITERATIONS
239        if iterations >= MAX_ITERATIONS && should_continue {
240            self.emit(AgentEvent::error(
241                format!(
242                    "⚠️ Reached maximum iterations limit ({} iterations).\n\n\
243                    **Task status**: The task may not be fully complete.\n\n\
244                    **What happened**: Agent stopped after {} iterations to prevent infinite loops.\n\n\
245                    **Next steps**:\n\
246                    1. Check if the task is complete\n\
247                    2. If incomplete, you can:\n\
248                       - Continue with more specific instructions\n\
249                       - Break down the task into smaller subtasks\n\
250                       - Use '/resume' to continue from current state\n\n\
251                    **Why this limit exists**: Prevents runaway operations and resource exhaustion.\n\
252                    **Adjustable**: Future versions will allow custom iteration limits.",
253                    MAX_ITERATIONS, iterations
254                ),
255                Some("MAX_ITERATIONS_REACHED".to_string()),
256                Some("agent/run.rs".to_string()),
257            ))?;
258        }
259        
260        self.emit(AgentEvent::usage_with_cache(
261            self.total_input_tokens.load(Ordering::Relaxed),
262            self.total_output_tokens.load(Ordering::Relaxed),
263            0,
264            0,
265        ))?;
266
267        self.emit(AgentEvent::session_ended())?;
268
269        Ok(Vec::new())
270    }
271
272    /// Restore message history (for session continue/resume)
273    pub fn set_messages(&mut self, messages: Vec<Message>) {
274        self.messages = messages;
275    }
276
277    /// Get current messages (for session saving)
278    pub fn get_messages(&self) -> &[Message] {
279        &self.messages
280    }
281
282    /// Get current token counts
283    pub fn get_token_counts(&self) -> (u64, u64) {
284        (
285            self.total_input_tokens.load(Ordering::Relaxed),
286            self.total_output_tokens.load(Ordering::Relaxed),
287        )
288    }
289
290    /// Clear message history
291    pub fn clear_history(&mut self) {
292        self.messages.clear();
293        self.total_input_tokens.store(0, Ordering::Relaxed);
294        self.total_output_tokens.store(0, Ordering::Relaxed);
295        self.last_input_tokens.store(0, Ordering::Relaxed);
296    }
297
298    /// Get message count
299    pub fn message_count(&self) -> usize {
300        self.messages.len()
301    }
302}