1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
18
19use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
20
21impl Agent {
22 pub(crate) fn new(builder: AgentBuilder) -> Self {
23 let event_tx = builder.event_tx.unwrap_or_else(|| {
24 let (tx, _) = mpsc::channel(100);
25 tx
26 });
27
28 Self {
29 provider: builder.provider,
30 model_name: builder.model_name,
31 tools: builder.tools,
32 messages: Vec::new(),
33 system_prompt: builder.system_prompt,
34 max_tokens: builder.max_tokens,
35 think: builder.think,
36 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
37 event_tx,
38 skills: builder.skills,
39 profile: builder.profile,
40 project_overview: builder.project_overview,
41 memory_summary: builder.memory_summary,
42 project_path: builder.project_path,
43 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
44 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
45 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
46 cancel_token: None,
47 compression_config: crate::compress::CompressionConfig::default(),
48 ask_rx: None,
49 proxy_tool_defs: builder.proxy_tool_defs,
50 proxy_executor: builder.proxy_executor,
51 }
52 }
53
54 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
56 self.event_tx.clone()
57 }
58
59 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
61 self.ask_rx = Some(rx);
62 }
63
64 pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
66 self.proxy_executor = Some(executor);
67 self.proxy_tool_defs = tool_defs;
68 }
69
70 pub fn set_cancel_token(&mut self, token: CancellationToken) {
72 self.cancel_token = Some(token);
73 }
74
75 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
77 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
78 log::info!("Agent approve mode changed: {} -> {}", old, mode);
79 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
80 }
81
82 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
84 self.approve_mode.clone()
85 }
86
87 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
89 self.approve_mode = shared;
90 }
91
92 pub fn update_memory_summary(&mut self, summary: Option<String>) {
95 self.memory_summary = summary;
96 self.system_prompt = prompt::build_system_prompt(
98 &self.profile,
99 &self.skills,
100 self.project_overview.as_deref(),
101 self.memory_summary.as_deref(),
102 );
103 }
104
105 pub fn refresh_codegraph_tools(&mut self) {
109 if let Some(path) = &self.project_path {
110 let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
112
113 let has_codegraph = self.tools.iter().any(|t| {
115 let name = t.definition().name;
116 name.starts_with("code_") && name != "code_review"
117 });
118
119 if should_have_codegraph != has_codegraph {
121 if should_have_codegraph {
123 let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
124 for tool in codegraph_tools {
125 self.tools.push(Arc::from(tool));
126 }
127 self.system_prompt = prompt::build_system_prompt_with_workflows(
129 &self.profile,
130 &self.skills,
131 self.project_overview.as_deref(),
132 self.memory_summary.as_deref(),
133 Some(path),
134 );
135 } else {
136 self.tools.retain(|t| {
138 let name = t.definition().name;
139 !name.starts_with("code_") || name == "code_review"
140 });
141 self.system_prompt = prompt::build_system_prompt_with_workflows(
143 &self.profile,
144 &self.skills,
145 self.project_overview.as_deref(),
146 self.memory_summary.as_deref(),
147 Some(path),
148 );
149 }
150 }
151 }
152 }
153
154 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
156 self.emit(AgentEvent::session_started())?;
157
158 self.messages.push(Message {
159 role: Role::User,
160 content: MessageContent::Text(user_input.clone()),
161 });
162
163 let mut iterations = 0;
164 let mut should_continue = true;
165 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
166
167 while should_continue && iterations < MAX_ITERATIONS {
168 iterations += 1;
169
170 if let Some(token) = &self.cancel_token
171 && token.is_cancelled()
172 {
173 self.emit(AgentEvent::error(
174 prompt::MSG_OPERATION_CANCELLED.to_string(),
175 None,
176 None,
177 ))?;
178 break;
179 }
180
181 if iterations == ITERATION_WARNING_THRESHOLD {
183 self.messages.push(Message {
184 role: Role::User,
185 content: MessageContent::Text(
186 prompt::MSG_ITERATION_WARNING
187 .replace("{iterations}", &iterations.to_string())
188 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
189 ),
190 });
191 }
192
193 let context_size = self.provider.context_size();
196 let estimated_tokens = estimate_total_tokens(&self.messages);
197
198 if should_compress(estimated_tokens, context_size, &self.compression_config) {
199 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
200
201 match compress_messages(
202 &self.messages,
203 CompressionStrategy::SlidingWindow,
204 &self.compression_config,
205 ) {
206 Ok(compressed) => {
207 let compressed_tokens = estimate_total_tokens(&compressed);
208 self.messages = compressed;
209 crate::debug::debug_log().compression(
210 estimated_tokens,
211 compressed_tokens,
212 compressed_tokens as f32 / estimated_tokens as f32,
213 );
214 }
215 Err(e) => {
216 self.emit(AgentEvent::progress(
217 format!("预压缩失败: {}", e),
218 None,
219 ))?;
220 }
221 }
222 }
223
224 let tool_defs: Vec<ToolDefinition> = {
226 let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
227 let def = t.definition();
228 let description = def.description_for_llm();
229 ToolDefinition {
230 name: def.name,
231 description,
232 parameters: def.parameters,
233 is_priority: def.is_priority,
234 }
235 }).collect();
236 defs.extend(self.proxy_tool_defs.iter().map(|t| {
238 let def = &t.definition;
239 let description = def.description_for_llm();
240 ToolDefinition {
241 name: def.name.clone(),
242 description,
243 parameters: def.parameters.clone(),
244 is_priority: def.is_priority,
245 }
246 }));
247 defs
248 };
249 let request = ChatRequest {
250 system: Some(self.system_prompt.clone()),
251 messages: self.messages.clone(),
252 max_tokens: self.max_tokens,
253 tools: tool_defs,
254 think: self.think,
255 enable_caching: true,
256 server_tools: Vec::new(),
257 };
258
259 let response = self.call_streaming(&request).await?;
260
261 self.track_usage(&response.usage);
262
263 crate::debug::debug_log().api_call(
264 &self.model_name,
265 response.usage.input_tokens,
266 response.usage.cache_read_input_tokens > 0,
267 );
268
269 should_continue = self.process_response(&response).await?;
270
271 if !should_continue && iterations < MAX_ITERATIONS - 1
273 && self.has_pending_todos() {
274 self.messages.push(Message {
275 role: Role::User,
276 content: MessageContent::Text(prompt::MSG_PENDING_TODOS.to_string()),
277 });
278 should_continue = true;
279 }
280
281 let context_size = self.provider.context_size();
282 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
283 let estimated_tokens = estimate_total_tokens(&self.messages);
284
285 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
286 api_tokens
287 } else {
288 estimated_tokens
289 };
290
291 if let Some(ctx_size) = context_size {
294 self.emit(AgentEvent::with_data(
296 EventType::ContextSize,
297 EventData::ContextSize {
298 context_size: ctx_size as u64,
299 },
300 ))?;
301
302 let usage_ratio = current_tokens as f64 / ctx_size as f64;
303 if usage_ratio >= 0.3 {
304 crate::debug::debug_log().log(
305 "checkcompress",
306 &format!(
307 "usage={:.1}%, tokens={}, context={}, threshold={}%",
308 usage_ratio * 100.0,
309 current_tokens,
310 ctx_size,
311 self.compression_config.threshold * 100.0
312 ),
313 );
314 }
315 }
316
317 if should_compress(current_tokens, context_size, &self.compression_config) {
318 self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
319
320 let original_tokens = current_tokens;
321
322 match compress_messages(
323 &self.messages,
324 CompressionStrategy::SlidingWindow,
325 &self.compression_config,
326 ) {
327 Ok(compressed) => {
328 let compressed_tokens = estimate_total_tokens(&compressed);
329 self.messages = compressed;
330 self.total_input_tokens
331 .store(compressed_tokens as u64, Ordering::Relaxed);
332 self.last_input_tokens
333 .store(compressed_tokens as u64, Ordering::Relaxed);
334
335 let ratio = compressed_tokens as f32 / original_tokens as f32;
336 crate::debug::debug_log().compression(
337 original_tokens,
338 compressed_tokens,
339 ratio,
340 );
341
342 self.emit(AgentEvent::with_data(
343 EventType::CompressionCompleted,
344 EventData::Compression {
345 original_tokens: original_tokens as u64,
346 compressed_tokens: compressed_tokens as u64,
347 ratio: compressed_tokens as f32 / original_tokens as f32,
348 },
349 ))?;
350 }
351 Err(e) => {
352 self.emit(AgentEvent::progress(
353 format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
354 None,
355 ))?;
356 }
357 }
358 }
359 }
360
361 if iterations >= MAX_ITERATIONS && should_continue {
363 self.emit(AgentEvent::error(
364 prompt::MSG_MAX_ITERATIONS_REACHED
365 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
366 .replace("{iterations}", &iterations.to_string()),
367 Some("MAX_ITERATIONS_REACHED".to_string()),
368 Some("agent/run.rs".to_string()),
369 ))?;
370 }
371
372 self.emit(AgentEvent::usage_with_cache(
373 self.total_input_tokens.load(Ordering::Relaxed),
374 self.total_output_tokens.load(Ordering::Relaxed),
375 0,
376 0,
377 ))?;
378
379 self.emit(AgentEvent::session_ended())?;
380
381 Ok(Vec::new())
382 }
383
384 pub fn set_messages(&mut self, messages: Vec<Message>) {
386 self.messages = messages;
387 }
388
389 pub fn get_messages(&self) -> &[Message] {
391 &self.messages
392 }
393
394 pub fn get_token_counts(&self) -> (u64, u64) {
396 (
397 self.total_input_tokens.load(Ordering::Relaxed),
398 self.total_output_tokens.load(Ordering::Relaxed),
399 )
400 }
401
402 pub fn clear_history(&mut self) {
404 self.messages.clear();
405 self.total_input_tokens.store(0, Ordering::Relaxed);
406 self.total_output_tokens.store(0, Ordering::Relaxed);
407 self.last_input_tokens.store(0, Ordering::Relaxed);
408 }
409
410 pub fn message_count(&self) -> usize {
412 self.messages.len()
413 }
414}