Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17
18use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
19
20impl Agent {
21    pub(crate) fn new(builder: AgentBuilder) -> Self {
22        let event_tx = builder.event_tx.unwrap_or_else(|| {
23            let (tx, _) = mpsc::channel(100);
24            tx
25        });
26
27        Self {
28            provider: builder.provider,
29            model_name: builder.model_name,
30            tools: builder.tools,
31            messages: Vec::new(),
32            system_prompt: builder.system_prompt,
33            max_tokens: builder.max_tokens,
34            think: builder.think,
35            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
36            event_tx,
37            skills: builder.skills,
38            profile: builder.profile,
39            project_overview: builder.project_overview,
40            memory_summary: builder.memory_summary,
41            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
42            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
43            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
44            cancel_token: None,
45            compression_config: crate::compress::CompressionConfig::default(),
46            ask_rx: None,
47        }
48    }
49
50    /// Get event sender for streaming
51    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
52        self.event_tx.clone()
53    }
54
55    /// Set ask response channel (for TUI mode)
56    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
57        self.ask_rx = Some(rx);
58    }
59
60    /// Set cancellation token
61    pub fn set_cancel_token(&mut self, token: CancellationToken) {
62        self.cancel_token = Some(token);
63    }
64
65    /// Set approve mode at runtime
66    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
67        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
68        log::info!("Agent approve mode changed: {} -> {}", old, mode);
69        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
70    }
71
72    /// Get a shared reference to the approve mode atomic.
73    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
74        self.approve_mode.clone()
75    }
76
77    /// Replace the internal approve mode with an externally-created shared atomic.
78    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
79        self.approve_mode = shared;
80    }
81
82    /// Update memory summary and rebuild system prompt.
83    pub fn update_memory_summary(&mut self, summary: Option<String>) {
84        self.memory_summary = summary;
85        self.system_prompt = prompt::build_system_prompt(
86            &self.profile,
87            &self.skills,
88            self.project_overview.as_deref(),
89            self.memory_summary.as_deref(),
90        );
91    }
92
93    /// Run chat loop with tool execution (streaming version).
94    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
95        self.emit(AgentEvent::session_started())?;
96
97        self.messages.push(Message {
98            role: Role::User,
99            content: MessageContent::Text(user_input.clone()),
100        });
101
102        let mut iterations = 0;
103        let mut should_continue = true;
104        const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
105
106        while should_continue && iterations < MAX_ITERATIONS {
107            iterations += 1;
108
109            if let Some(token) = &self.cancel_token
110                && token.is_cancelled()
111            {
112                self.emit(AgentEvent::error(
113                    prompt::MSG_OPERATION_CANCELLED.to_string(),
114                    None,
115                    None,
116                ))?;
117                break;
118            }
119
120            // Warn when approaching iteration limit
121            if iterations == ITERATION_WARNING_THRESHOLD {
122                self.messages.push(Message {
123                    role: Role::User,
124                    content: MessageContent::Text(
125                        prompt::MSG_ITERATION_WARNING
126                            .replace("{iterations}", &iterations.to_string())
127                            .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
128                    ),
129                });
130            }
131
132            // Proactive compression: check context size BEFORE API call
133            // For long conversations, compress early to avoid timeout issues
134            let context_size = self.provider.context_size();
135            let estimated_tokens = estimate_total_tokens(&self.messages);
136
137            if should_compress(estimated_tokens, context_size, &self.compression_config) {
138                self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
139
140                match compress_messages(
141                    &self.messages,
142                    CompressionStrategy::SlidingWindow,
143                    &self.compression_config,
144                ) {
145                    Ok(compressed) => {
146                        let compressed_tokens = estimate_total_tokens(&compressed);
147                        self.messages = compressed;
148                        crate::debug::debug_log().compression(
149                            estimated_tokens,
150                            compressed_tokens,
151                            compressed_tokens as f32 / estimated_tokens as f32,
152                        );
153                    }
154                    Err(e) => {
155                        self.emit(AgentEvent::progress(
156                            format!("预压缩失败: {}", e),
157                            None,
158                        ))?;
159                    }
160                }
161            }
162
163            let tool_defs: Vec<ToolDefinition> =
164                self.tools.iter().map(|t| t.definition()).collect();
165            let request = ChatRequest {
166                system: Some(self.system_prompt.clone()),
167                messages: self.messages.clone(),
168                max_tokens: self.max_tokens,
169                tools: tool_defs,
170                think: self.think,
171                enable_caching: true,
172                server_tools: Vec::new(),
173            };
174
175            let response = self.call_streaming(&request).await?;
176
177            self.track_usage(&response.usage);
178
179            crate::debug::debug_log().api_call(
180                &self.model_name,
181                response.usage.input_tokens,
182                response.usage.cache_read_input_tokens > 0,
183            );
184
185            should_continue = self.process_response(&response).await?;
186
187            // If model wants to stop (no tool calls), check for pending todos
188            if !should_continue && iterations < MAX_ITERATIONS - 1 {
189                if self.has_pending_todos() {
190                    self.messages.push(Message {
191                        role: Role::User,
192                        content: MessageContent::Text(prompt::MSG_PENDING_TODOS.to_string()),
193                    });
194                    should_continue = true;
195                }
196            }
197
198            let context_size = self.provider.context_size();
199            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
200            let estimated_tokens = estimate_total_tokens(&self.messages);
201
202            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
203                api_tokens
204            } else {
205                estimated_tokens
206            };
207
208            crate::debug::debug_log().log(
209                "compression",
210                &format!(
211                    "check: api={}, estimated={}, using={}, context={}, threshold={}",
212                    api_tokens,
213                    estimated_tokens,
214                    current_tokens,
215                    context_size.unwrap_or(0),
216                    self.compression_config.threshold
217                ),
218            );
219
220            if should_compress(current_tokens, context_size, &self.compression_config) {
221                self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
222
223                let original_tokens = current_tokens;
224
225                match compress_messages(
226                    &self.messages,
227                    CompressionStrategy::SlidingWindow,
228                    &self.compression_config,
229                ) {
230                    Ok(compressed) => {
231                        let compressed_tokens = estimate_total_tokens(&compressed);
232                        self.messages = compressed;
233                        self.total_input_tokens
234                            .store(compressed_tokens as u64, Ordering::Relaxed);
235                        self.last_input_tokens
236                            .store(compressed_tokens as u64, Ordering::Relaxed);
237
238                        let ratio = compressed_tokens as f32 / original_tokens as f32;
239                        crate::debug::debug_log().compression(
240                            original_tokens,
241                            compressed_tokens,
242                            ratio,
243                        );
244
245                        self.emit(AgentEvent::with_data(
246                            EventType::CompressionCompleted,
247                            EventData::Compression {
248                                original_tokens: original_tokens as u64,
249                                compressed_tokens: compressed_tokens as u64,
250                                ratio: compressed_tokens as f32 / original_tokens as f32,
251                            },
252                        ))?;
253                    }
254                    Err(e) => {
255                        self.emit(AgentEvent::progress(
256                            format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
257                            None,
258                        ))?;
259                    }
260                }
261            }
262        }
263        
264        // Check if we stopped due to reaching MAX_ITERATIONS
265        if iterations >= MAX_ITERATIONS && should_continue {
266            self.emit(AgentEvent::error(
267                prompt::MSG_MAX_ITERATIONS_REACHED
268                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
269                    .replace("{iterations}", &iterations.to_string()),
270                Some("MAX_ITERATIONS_REACHED".to_string()),
271                Some("agent/run.rs".to_string()),
272            ))?;
273        }
274        
275        self.emit(AgentEvent::usage_with_cache(
276            self.total_input_tokens.load(Ordering::Relaxed),
277            self.total_output_tokens.load(Ordering::Relaxed),
278            0,
279            0,
280        ))?;
281
282        self.emit(AgentEvent::session_ended())?;
283
284        Ok(Vec::new())
285    }
286
287    /// Restore message history (for session continue/resume)
288    pub fn set_messages(&mut self, messages: Vec<Message>) {
289        self.messages = messages;
290    }
291
292    /// Get current messages (for session saving)
293    pub fn get_messages(&self) -> &[Message] {
294        &self.messages
295    }
296
297    /// Get current token counts
298    pub fn get_token_counts(&self) -> (u64, u64) {
299        (
300            self.total_input_tokens.load(Ordering::Relaxed),
301            self.total_output_tokens.load(Ordering::Relaxed),
302        )
303    }
304
305    /// Clear message history
306    pub fn clear_history(&mut self) {
307        self.messages.clear();
308        self.total_input_tokens.store(0, Ordering::Relaxed);
309        self.total_output_tokens.store(0, Ordering::Relaxed);
310        self.last_input_tokens.store(0, Ordering::Relaxed);
311    }
312
313    /// Get message count
314    pub fn message_count(&self) -> usize {
315        self.messages.len()
316    }
317}