1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17
18use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
19
20impl Agent {
21 pub(crate) fn new(builder: AgentBuilder) -> Self {
22 let event_tx = builder.event_tx.unwrap_or_else(|| {
23 let (tx, _) = mpsc::channel(100);
24 tx
25 });
26
27 Self {
28 provider: builder.provider,
29 model_name: builder.model_name,
30 tools: builder.tools,
31 messages: Vec::new(),
32 system_prompt: builder.system_prompt,
33 max_tokens: builder.max_tokens,
34 think: builder.think,
35 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
36 event_tx,
37 skills: builder.skills,
38 profile: builder.profile,
39 project_overview: builder.project_overview,
40 memory_summary: builder.memory_summary,
41 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
42 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
43 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
44 cancel_token: None,
45 compression_config: crate::compress::CompressionConfig::default(),
46 ask_rx: None,
47 }
48 }
49
50 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
52 self.event_tx.clone()
53 }
54
55 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
57 self.ask_rx = Some(rx);
58 }
59
60 pub fn set_cancel_token(&mut self, token: CancellationToken) {
62 self.cancel_token = Some(token);
63 }
64
65 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
67 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
68 log::info!("Agent approve mode changed: {} -> {}", old, mode);
69 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
70 }
71
72 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
74 self.approve_mode.clone()
75 }
76
77 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
79 self.approve_mode = shared;
80 }
81
82 pub fn update_memory_summary(&mut self, summary: Option<String>) {
84 self.memory_summary = summary;
85 self.system_prompt = prompt::build_system_prompt(
86 &self.profile,
87 &self.skills,
88 self.project_overview.as_deref(),
89 self.memory_summary.as_deref(),
90 );
91 }
92
93 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
95 self.emit(AgentEvent::session_started())?;
96
97 self.messages.push(Message {
98 role: Role::User,
99 content: MessageContent::Text(user_input.clone()),
100 });
101
102 let mut iterations = 0;
103 let mut should_continue = true;
104 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
105
106 while should_continue && iterations < MAX_ITERATIONS {
107 iterations += 1;
108
109 if let Some(token) = &self.cancel_token
110 && token.is_cancelled()
111 {
112 self.emit(AgentEvent::error(
113 prompt::MSG_OPERATION_CANCELLED.to_string(),
114 None,
115 None,
116 ))?;
117 break;
118 }
119
120 if iterations == ITERATION_WARNING_THRESHOLD {
122 self.messages.push(Message {
123 role: Role::User,
124 content: MessageContent::Text(
125 prompt::MSG_ITERATION_WARNING
126 .replace("{iterations}", &iterations.to_string())
127 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
128 ),
129 });
130 }
131
132 let context_size = self.provider.context_size();
135 let estimated_tokens = estimate_total_tokens(&self.messages);
136
137 if should_compress(estimated_tokens, context_size, &self.compression_config) {
138 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
139
140 match compress_messages(
141 &self.messages,
142 CompressionStrategy::SlidingWindow,
143 &self.compression_config,
144 ) {
145 Ok(compressed) => {
146 let compressed_tokens = estimate_total_tokens(&compressed);
147 self.messages = compressed;
148 crate::debug::debug_log().compression(
149 estimated_tokens,
150 compressed_tokens,
151 compressed_tokens as f32 / estimated_tokens as f32,
152 );
153 }
154 Err(e) => {
155 self.emit(AgentEvent::progress(
156 format!("预压缩失败: {}", e),
157 None,
158 ))?;
159 }
160 }
161 }
162
163 let tool_defs: Vec<ToolDefinition> =
164 self.tools.iter().map(|t| t.definition()).collect();
165 let request = ChatRequest {
166 system: Some(self.system_prompt.clone()),
167 messages: self.messages.clone(),
168 max_tokens: self.max_tokens,
169 tools: tool_defs,
170 think: self.think,
171 enable_caching: true,
172 server_tools: Vec::new(),
173 };
174
175 let response = self.call_streaming(&request).await?;
176
177 self.track_usage(&response.usage);
178
179 crate::debug::debug_log().api_call(
180 &self.model_name,
181 response.usage.input_tokens,
182 response.usage.cache_read_input_tokens > 0,
183 );
184
185 should_continue = self.process_response(&response).await?;
186
187 if !should_continue && iterations < MAX_ITERATIONS - 1 {
189 if self.has_pending_todos() {
190 self.messages.push(Message {
191 role: Role::User,
192 content: MessageContent::Text(prompt::MSG_PENDING_TODOS.to_string()),
193 });
194 should_continue = true;
195 }
196 }
197
198 let context_size = self.provider.context_size();
199 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
200 let estimated_tokens = estimate_total_tokens(&self.messages);
201
202 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
203 api_tokens
204 } else {
205 estimated_tokens
206 };
207
208 if let Some(ctx_size) = context_size {
211 let usage_ratio = current_tokens as f64 / ctx_size as f64;
212 if usage_ratio >= 0.3 {
213 crate::debug::debug_log().log(
214 "compression",
215 &format!(
216 "check: usage={:.1}%, tokens={}, context={}, threshold={}%",
217 usage_ratio * 100.0,
218 current_tokens,
219 ctx_size,
220 self.compression_config.threshold * 100.0
221 ),
222 );
223 }
224 }
225
226 if should_compress(current_tokens, context_size, &self.compression_config) {
227 self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
228
229 let original_tokens = current_tokens;
230
231 match compress_messages(
232 &self.messages,
233 CompressionStrategy::SlidingWindow,
234 &self.compression_config,
235 ) {
236 Ok(compressed) => {
237 let compressed_tokens = estimate_total_tokens(&compressed);
238 self.messages = compressed;
239 self.total_input_tokens
240 .store(compressed_tokens as u64, Ordering::Relaxed);
241 self.last_input_tokens
242 .store(compressed_tokens as u64, Ordering::Relaxed);
243
244 let ratio = compressed_tokens as f32 / original_tokens as f32;
245 crate::debug::debug_log().compression(
246 original_tokens,
247 compressed_tokens,
248 ratio,
249 );
250
251 self.emit(AgentEvent::with_data(
252 EventType::CompressionCompleted,
253 EventData::Compression {
254 original_tokens: original_tokens as u64,
255 compressed_tokens: compressed_tokens as u64,
256 ratio: compressed_tokens as f32 / original_tokens as f32,
257 },
258 ))?;
259 }
260 Err(e) => {
261 self.emit(AgentEvent::progress(
262 format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
263 None,
264 ))?;
265 }
266 }
267 }
268 }
269
270 if iterations >= MAX_ITERATIONS && should_continue {
272 self.emit(AgentEvent::error(
273 prompt::MSG_MAX_ITERATIONS_REACHED
274 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
275 .replace("{iterations}", &iterations.to_string()),
276 Some("MAX_ITERATIONS_REACHED".to_string()),
277 Some("agent/run.rs".to_string()),
278 ))?;
279 }
280
281 self.emit(AgentEvent::usage_with_cache(
282 self.total_input_tokens.load(Ordering::Relaxed),
283 self.total_output_tokens.load(Ordering::Relaxed),
284 0,
285 0,
286 ))?;
287
288 self.emit(AgentEvent::session_ended())?;
289
290 Ok(Vec::new())
291 }
292
293 pub fn set_messages(&mut self, messages: Vec<Message>) {
295 self.messages = messages;
296 }
297
298 pub fn get_messages(&self) -> &[Message] {
300 &self.messages
301 }
302
303 pub fn get_token_counts(&self) -> (u64, u64) {
305 (
306 self.total_input_tokens.load(Ordering::Relaxed),
307 self.total_output_tokens.load(Ordering::Relaxed),
308 )
309 }
310
311 pub fn clear_history(&mut self) {
313 self.messages.clear();
314 self.total_input_tokens.store(0, Ordering::Relaxed);
315 self.total_output_tokens.store(0, Ordering::Relaxed);
316 self.last_input_tokens.store(0, Ordering::Relaxed);
317 }
318
319 pub fn message_count(&self) -> usize {
321 self.messages.len()
322 }
323}