1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::Tool;
17use crate::tools::ToolDefinition;
18use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
19
20use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
21
22impl Agent {
23 pub(crate) fn new(builder: AgentBuilder) -> Self {
24 let event_tx = builder.event_tx.unwrap_or_else(|| {
25 let (tx, _) = mpsc::channel(100);
26 tx
27 });
28
29 Self {
30 provider: builder.provider,
31 model_name: builder.model_name,
32 tools: builder.tools,
33 messages: Vec::new(),
34 system_prompt: builder.system_prompt,
35 max_tokens: builder.max_tokens,
36 think: builder.think,
37 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
38 event_tx,
39 skills: builder.skills,
40 profile: builder.profile,
41 project_overview: builder.project_overview,
42 memory_summary: builder.memory_summary,
43 project_path: builder.project_path,
44 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
45 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
46 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
47 cancel_token: None,
48 compression_config: crate::compress::CompressionConfig::default(),
49 ask_rx: None,
50 proxy_tool_defs: builder.proxy_tool_defs,
51 proxy_executor: builder.proxy_executor,
52 mcp_registry: builder.mcp_registry,
53 pending_input_rx: builder.pending_input_rx,
54 pending_inputs: Vec::new(),
55 }
56 }
57
58 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
60 self.event_tx.clone()
61 }
62
63 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
65 self.ask_rx = Some(rx);
66 }
67
68 pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
70 self.proxy_executor = Some(executor);
71 self.proxy_tool_defs = tool_defs;
72 }
73
74 pub fn set_cancel_token(&mut self, token: CancellationToken) {
76 self.cancel_token = Some(token);
77 }
78
79 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
81 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
82 log::info!("Agent approve mode changed: {} -> {}", old, mode);
83 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
84 }
85
86 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
88 self.approve_mode.clone()
89 }
90
91 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
93 self.approve_mode = shared;
94 }
95
96 pub fn update_memory_summary(&mut self, summary: Option<String>) {
99 self.memory_summary = summary;
100 self.system_prompt = prompt::build_system_prompt(
102 &self.profile,
103 &self.skills,
104 self.project_overview.as_deref(),
105 self.memory_summary.as_deref(),
106 );
107 }
108
109 pub fn refresh_codegraph_tools(&mut self) {
113 if let Some(path) = &self.project_path {
114 let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
116
117 let has_codegraph = self.tools.iter().any(|t| {
119 let name = t.definition().name;
120 name.starts_with("code_") && name != "code_review"
121 });
122
123 if should_have_codegraph != has_codegraph {
125 if should_have_codegraph {
127 let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
128 for tool in codegraph_tools {
129 self.tools.push(Arc::from(tool));
130 }
131 self.system_prompt = prompt::build_system_prompt_with_workflows(
133 &self.profile,
134 &self.skills,
135 self.project_overview.as_deref(),
136 self.memory_summary.as_deref(),
137 Some(path),
138 None, );
140 } else {
141 self.tools.retain(|t| {
143 let name = t.definition().name;
144 !name.starts_with("code_") || name == "code_review"
145 });
146 self.system_prompt = prompt::build_system_prompt_with_workflows(
148 &self.profile,
149 &self.skills,
150 self.project_overview.as_deref(),
151 self.memory_summary.as_deref(),
152 Some(path),
153 None, );
155 }
156 }
157 }
158 }
159
160 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
162 self.emit(AgentEvent::session_started())?;
163
164 self.messages.push(Message {
165 role: Role::User,
166 content: MessageContent::Text(user_input.clone()),
167 });
168
169 let mut iterations = 0;
170 let mut should_continue = true;
171 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
172
173 while should_continue && iterations < MAX_ITERATIONS {
174 iterations += 1;
175
176 if self.has_pending_inputs() {
179 let pending = self.take_pending_inputs();
180 let merged = pending.join("\n\n---\n\n");
181 log::info!("Adding {} pending input messages to request", pending.len());
182
183 self.emit(AgentEvent::progress(
184 format!("📝 收到 {} 条追加消息", pending.len()),
185 None,
186 ))?;
187
188 self.messages.push(Message {
189 role: Role::User,
190 content: MessageContent::Text(merged),
191 });
192 }
193
194 if let Some(token) = &self.cancel_token
195 && token.is_cancelled()
196 {
197 self.emit(AgentEvent::error(
198 prompt::MSG_OPERATION_CANCELLED.to_string(),
199 None,
200 None,
201 ))?;
202 break;
203 }
204
205 if iterations == ITERATION_WARNING_THRESHOLD {
207 self.emit(AgentEvent::progress(
208 prompt::MSG_ITERATION_WARNING_UI
209 .replace("{iterations}", &iterations.to_string())
210 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
211 None,
212 ))?;
213 }
214
215 let context_size = self.provider.context_size();
218 let estimated_tokens = estimate_total_tokens(&self.messages);
219
220 if should_compress(estimated_tokens, context_size, &self.compression_config) {
221 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
222
223 match compress_messages(
224 &self.messages,
225 CompressionStrategy::SlidingWindow,
226 &self.compression_config,
227 ) {
228 Ok(compressed) => {
229 let compressed_tokens = estimate_total_tokens(&compressed);
230 self.messages = compressed;
231 crate::debug::debug_log().compression(
232 estimated_tokens,
233 compressed_tokens,
234 compressed_tokens as f32 / estimated_tokens as f32,
235 );
236 }
237 Err(e) => {
238 self.emit(AgentEvent::progress(
239 format!("预压缩失败: {}", e),
240 None,
241 ))?;
242 }
243 }
244 }
245
246 let tool_defs: Vec<ToolDefinition> = {
248 let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
249 let def = t.definition();
250 let description = def.description_for_llm();
251 ToolDefinition {
252 name: def.name,
253 description,
254 parameters: def.parameters,
255 is_priority: def.is_priority,
256 }
257 }).collect();
258 defs.extend(self.proxy_tool_defs.iter().map(|t| {
260 let def = &t.definition;
261 let description = def.description_for_llm();
262 ToolDefinition {
263 name: def.name.clone(),
264 description,
265 parameters: def.parameters.clone(),
266 is_priority: def.is_priority,
267 }
268 }));
269 defs
270 };
271 let request = ChatRequest {
272 system: Some(self.system_prompt.clone()),
273 messages: self.messages.clone(),
274 max_tokens: self.max_tokens,
275 tools: tool_defs,
276 think: self.think,
277 enable_caching: true,
278 server_tools: Vec::new(),
279 };
280
281 let response = self.call_streaming(&request).await?;
282
283 self.track_usage(&response.usage);
284
285 crate::debug::debug_log().api_call(
286 &self.model_name,
287 response.usage.input_tokens,
288 response.usage.cache_read_input_tokens > 0,
289 );
290
291 should_continue = self.process_response(&response).await?;
292
293 if !should_continue && iterations < MAX_ITERATIONS - 1 {
295 if self.has_pending_inputs() {
297 let pending = self.take_pending_inputs();
298 let merged = pending.join("\n\n---\n\n");
299 log::info!("Model stopped but user appended {} messages, continuing", pending.len());
300
301 self.emit(AgentEvent::progress(
302 format!("📝 处理 {} 条追加消息", pending.len()),
303 None,
304 ))?;
305
306 self.messages.push(Message {
307 role: Role::User,
308 content: MessageContent::Text(merged),
309 });
310 should_continue = true;
311 } else {
312 let pending = self.get_pending_todos();
314 if !pending.is_empty() {
315 let pending_list = pending.iter()
317 .map(|(status, content)| {
318 let marker = match status.as_str() {
319 "in_progress" => "[~]",
320 "pending" => "[ ]",
321 _ => "[?]"
322 };
323 format!(" {} {}", marker, content)
324 })
325 .collect::<Vec<_>>()
326 .join("\n");
327
328 let reminder = format!(
329 "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
330 pending_list
331 );
332
333 self.messages.push(Message {
334 role: Role::User,
335 content: MessageContent::Text(reminder),
336 });
337 should_continue = true;
338 }
339 }
340 }
341
342 let context_size = self.provider.context_size();
343 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
344 let estimated_tokens = estimate_total_tokens(&self.messages);
345
346 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
347 api_tokens
348 } else {
349 estimated_tokens
350 };
351
352 if let Some(ctx_size) = context_size {
355 self.emit(AgentEvent::with_data(
357 EventType::ContextSize,
358 EventData::ContextSize {
359 context_size: ctx_size as u64,
360 },
361 ))?;
362
363 let usage_ratio = current_tokens as f64 / ctx_size as f64;
364 if usage_ratio >= 0.3 {
365 crate::debug::debug_log().log(
366 "checkcompress",
367 &format!(
368 "usage={:.1}%, tokens={}, context={}, threshold={}%",
369 usage_ratio * 100.0,
370 current_tokens,
371 ctx_size,
372 self.compression_config.threshold * 100.0
373 ),
374 );
375 }
376 }
377
378 if should_compress(current_tokens, context_size, &self.compression_config) {
379 self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
380
381 let original_tokens = current_tokens;
382
383 match compress_messages(
384 &self.messages,
385 CompressionStrategy::SlidingWindow,
386 &self.compression_config,
387 ) {
388 Ok(compressed) => {
389 let compressed_tokens = estimate_total_tokens(&compressed);
390 self.messages = compressed;
391 self.total_input_tokens
392 .store(compressed_tokens as u64, Ordering::Relaxed);
393 self.last_input_tokens
394 .store(compressed_tokens as u64, Ordering::Relaxed);
395
396 let ratio = compressed_tokens as f32 / original_tokens as f32;
397 crate::debug::debug_log().compression(
398 original_tokens,
399 compressed_tokens,
400 ratio,
401 );
402
403 self.emit(AgentEvent::with_data(
404 EventType::CompressionCompleted,
405 EventData::Compression {
406 original_tokens: original_tokens as u64,
407 compressed_tokens: compressed_tokens as u64,
408 ratio: compressed_tokens as f32 / original_tokens as f32,
409 },
410 ))?;
411 }
412 Err(e) => {
413 self.emit(AgentEvent::progress(
414 format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
415 None,
416 ))?;
417 }
418 }
419 }
420 }
421
422 if iterations >= MAX_ITERATIONS && should_continue {
424 self.emit(AgentEvent::error(
425 prompt::MSG_MAX_ITERATIONS_REACHED
426 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
427 .replace("{iterations}", &iterations.to_string()),
428 Some("MAX_ITERATIONS_REACHED".to_string()),
429 Some("agent/run.rs".to_string()),
430 ))?;
431 }
432
433 self.emit(AgentEvent::usage_with_cache(
434 self.total_input_tokens.load(Ordering::Relaxed),
435 self.total_output_tokens.load(Ordering::Relaxed),
436 0,
437 0,
438 ))?;
439
440 self.emit(AgentEvent::session_ended())?;
441
442 Ok(Vec::new())
443 }
444
445 pub fn set_messages(&mut self, messages: Vec<Message>) {
447 self.messages = messages;
448 }
449
450 pub fn get_messages(&self) -> &[Message] {
452 &self.messages
453 }
454
455 pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
457 &self.tools
458 }
459
460 pub fn get_system_prompt(&self) -> &str {
462 &self.system_prompt
463 }
464
465 pub fn get_token_counts(&self) -> (u64, u64) {
467 (
468 self.total_input_tokens.load(Ordering::Relaxed),
469 self.total_output_tokens.load(Ordering::Relaxed),
470 )
471 }
472
473 pub fn clear_history(&mut self) {
475 self.messages.clear();
476 self.total_input_tokens.store(0, Ordering::Relaxed);
477 self.total_output_tokens.store(0, Ordering::Relaxed);
478 self.last_input_tokens.store(0, Ordering::Relaxed);
479 }
480
481 pub fn message_count(&self) -> usize {
483 self.messages.len()
484 }
485
486 pub async fn add_mcp_server(&mut self, name: &str, config: crate::mcp::McpServerConfig) -> Result<()> {
500 if let Some(registry) = &self.mcp_registry {
501 let mut reg = registry.write().await;
502 reg.add_server(name.to_string(), config);
503 log::info!("MCP server '{}' added to registry", name);
504 } else {
505 log::warn!("MCP registry not initialized, cannot add server '{}'", name);
506 }
507 Ok(())
508 }
509
510 pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
512 if let Some(registry) = &self.mcp_registry {
513 let mut reg = registry.write().await;
514 reg.remove_server(name).await?;
515 log::info!("MCP server '{}' removed from registry", name);
516 }
517 Ok(())
518 }
519
520 pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
522 if let Some(registry) = &self.mcp_registry {
523 let reg = registry.read().await;
524 reg.server_status().await.values().cloned().collect()
525 } else {
526 Vec::new()
527 }
528 }
529
530 pub async fn start_mcp_server(&self, name: &str) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
532 if let Some(registry) = &self.mcp_registry {
533 let reg = registry.read().await;
534 if let Some(placeholder) = reg.get_server(name) {
535 let tools = placeholder.start().await?;
536 log::info!("MCP server '{}' started with {} tools", name, tools.len());
537 Ok(tools)
538 } else {
539 Err(anyhow::anyhow!("MCP server '{}' not found in registry", name))
540 }
541 } else {
542 Err(anyhow::anyhow!("MCP registry not initialized"))
543 }
544 }
545}