Skip to main content

matrixcode_core/agent/
run.rs

1//! Agent run loop and public methods.
2
3use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11    CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17
18use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
19
20impl Agent {
21    pub(crate) fn new(builder: AgentBuilder) -> Self {
22        let event_tx = builder.event_tx.unwrap_or_else(|| {
23            let (tx, _) = mpsc::channel(100);
24            tx
25        });
26
27        Self {
28            provider: builder.provider,
29            model_name: builder.model_name,
30            tools: builder.tools,
31            messages: Vec::new(),
32            system_prompt: builder.system_prompt,
33            max_tokens: builder.max_tokens,
34            think: builder.think,
35            approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
36            event_tx,
37            skills: builder.skills,
38            profile: builder.profile,
39            project_overview: builder.project_overview,
40            memory_summary: builder.memory_summary,
41            total_input_tokens: std::sync::atomic::AtomicU64::new(0),
42            total_output_tokens: std::sync::atomic::AtomicU64::new(0),
43            last_input_tokens: std::sync::atomic::AtomicU64::new(0),
44            cancel_token: None,
45            compression_config: crate::compress::CompressionConfig::default(),
46            ask_rx: None,
47        }
48    }
49
50    /// Get event sender for streaming
51    pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
52        self.event_tx.clone()
53    }
54
55    /// Set ask response channel (for TUI mode)
56    pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
57        self.ask_rx = Some(rx);
58    }
59
60    /// Set cancellation token
61    pub fn set_cancel_token(&mut self, token: CancellationToken) {
62        self.cancel_token = Some(token);
63    }
64
65    /// Set approve mode at runtime
66    pub fn set_approve_mode(&mut self, mode: ApproveMode) {
67        let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
68        log::info!("Agent approve mode changed: {} -> {}", old, mode);
69        self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
70    }
71
72    /// Get a shared reference to the approve mode atomic.
73    pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
74        self.approve_mode.clone()
75    }
76
77    /// Replace the internal approve mode with an externally-created shared atomic.
78    pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
79        self.approve_mode = shared;
80    }
81
82    /// Update memory summary and rebuild system prompt.
83    pub fn update_memory_summary(&mut self, summary: Option<String>) {
84        self.memory_summary = summary;
85        self.system_prompt = prompt::build_system_prompt(
86            &self.profile,
87            &self.skills,
88            self.project_overview.as_deref(),
89            self.memory_summary.as_deref(),
90        );
91    }
92
93    /// Run chat loop with tool execution (streaming version).
94    pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
95        self.emit(AgentEvent::session_started())?;
96
97        self.messages.push(Message {
98            role: Role::User,
99            content: MessageContent::Text(user_input.clone()),
100        });
101
102        let mut iterations = 0;
103        let mut should_continue = true;
104
105        while should_continue && iterations < MAX_ITERATIONS {
106            iterations += 1;
107
108            if let Some(token) = &self.cancel_token
109                && token.is_cancelled()
110            {
111                self.emit(AgentEvent::error(
112                    "Operation cancelled".to_string(),
113                    None,
114                    None,
115                ))?;
116                break;
117            }
118
119            let tool_defs: Vec<ToolDefinition> =
120                self.tools.iter().map(|t| t.definition()).collect();
121            let request = ChatRequest {
122                system: Some(self.system_prompt.clone()),
123                messages: self.messages.clone(),
124                max_tokens: self.max_tokens,
125                tools: tool_defs,
126                think: self.think,
127                enable_caching: true,
128                server_tools: Vec::new(),
129            };
130
131            let response = self.call_streaming(&request).await?;
132
133            self.track_usage(&response.usage);
134
135            crate::debug::debug_log().api_call(
136                &self.model_name,
137                response.usage.input_tokens,
138                response.usage.cache_read_input_tokens > 0,
139            );
140
141            should_continue = self.process_response(&response).await?;
142
143            let context_size = self.provider.context_size();
144            let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
145            let estimated_tokens = estimate_total_tokens(&self.messages);
146
147            let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
148                api_tokens
149            } else {
150                estimated_tokens
151            };
152
153            crate::debug::debug_log().log(
154                "compression",
155                &format!(
156                    "check: api={}, estimated={}, using={}, context={}, threshold={}",
157                    api_tokens,
158                    estimated_tokens,
159                    current_tokens,
160                    context_size.unwrap_or(0),
161                    self.compression_config.threshold
162                ),
163            );
164
165            if should_compress(current_tokens, context_size, &self.compression_config) {
166                self.emit(AgentEvent::progress("Compressing context...", None))?;
167
168                let original_tokens = current_tokens;
169
170                match compress_messages(
171                    &self.messages,
172                    CompressionStrategy::SlidingWindow,
173                    &self.compression_config,
174                ) {
175                    Ok(compressed) => {
176                        let compressed_tokens = estimate_total_tokens(&compressed);
177                        self.messages = compressed;
178                        self.total_input_tokens
179                            .store(compressed_tokens as u64, Ordering::Relaxed);
180                        self.last_input_tokens
181                            .store(compressed_tokens as u64, Ordering::Relaxed);
182
183                        let ratio = compressed_tokens as f32 / original_tokens as f32;
184                        crate::debug::debug_log().compression(
185                            original_tokens,
186                            compressed_tokens,
187                            ratio,
188                        );
189
190                        self.emit(AgentEvent::with_data(
191                            EventType::CompressionCompleted,
192                            EventData::Compression {
193                                original_tokens: original_tokens as u64,
194                                compressed_tokens: compressed_tokens as u64,
195                                ratio: compressed_tokens as f32 / original_tokens as f32,
196                            },
197                        ))?;
198                    }
199                    Err(e) => {
200                        self.emit(AgentEvent::progress(
201                            format!("Compression failed: {}", e),
202                            None,
203                        ))?;
204                    }
205                }
206            }
207        }
208
209        self.emit(AgentEvent::usage_with_cache(
210            self.total_input_tokens.load(Ordering::Relaxed),
211            self.total_output_tokens.load(Ordering::Relaxed),
212            0,
213            0,
214        ))?;
215
216        self.emit(AgentEvent::session_ended())?;
217
218        Ok(Vec::new())
219    }
220
221    /// Restore message history (for session continue/resume)
222    pub fn set_messages(&mut self, messages: Vec<Message>) {
223        self.messages = messages;
224    }
225
226    /// Get current messages (for session saving)
227    pub fn get_messages(&self) -> &[Message] {
228        &self.messages
229    }
230
231    /// Get current token counts
232    pub fn get_token_counts(&self) -> (u64, u64) {
233        (
234            self.total_input_tokens.load(Ordering::Relaxed),
235            self.total_output_tokens.load(Ordering::Relaxed),
236        )
237    }
238
239    /// Clear message history
240    pub fn clear_history(&mut self) {
241        self.messages.clear();
242        self.total_input_tokens.store(0, Ordering::Relaxed);
243        self.total_output_tokens.store(0, Ordering::Relaxed);
244        self.last_input_tokens.store(0, Ordering::Relaxed);
245    }
246
247    /// Get message count
248    pub fn message_count(&self) -> usize {
249        self.messages.len()
250    }
251}