1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::Tool;
17use crate::tools::ToolDefinition;
18use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
19
20use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
21
22impl Agent {
23 pub(crate) fn new(builder: AgentBuilder) -> Self {
24 let event_tx = builder.event_tx.unwrap_or_else(|| {
25 let (tx, _) = mpsc::channel(100);
26 tx
27 });
28
29 Self {
30 provider: builder.provider,
31 model_name: builder.model_name,
32 tools: builder.tools,
33 messages: Vec::new(),
34 system_prompt: builder.system_prompt,
35 max_tokens: builder.max_tokens,
36 context_size_override: builder.context_size_override,
37 think: builder.think,
38 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
39 event_tx,
40 skills: builder.skills,
41 profile: builder.profile,
42 project_overview: builder.project_overview,
43 memory_summary: builder.memory_summary,
44 project_path: builder.project_path,
45 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
46 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
47 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
48 cancel_token: None,
49 compression_config: crate::compress::CompressionConfig::default(),
50 ask_rx: None,
51 proxy_tool_defs: builder.proxy_tool_defs,
52 proxy_executor: builder.proxy_executor,
53 mcp_registry: builder.mcp_registry,
54 pending_input_rx: builder.pending_input_rx,
55 pending_inputs: Vec::new(),
56 previewed_tool_inputs: std::collections::HashSet::new(),
57 todo_reminder_count: std::collections::HashMap::new(),
58 }
59 }
60
61 pub(crate) fn effective_context_size(&self) -> Option<u32> {
63 self.context_size_override
64 .or_else(|| self.provider.context_size())
65 }
66
67 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
69 self.event_tx.clone()
70 }
71
72 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
74 self.ask_rx = Some(rx);
75 }
76
77 pub fn set_proxy_executor(
79 &mut self,
80 executor: Arc<dyn ProxyToolExecutor>,
81 tool_defs: Vec<ProxyToolDef>,
82 ) {
83 self.proxy_executor = Some(executor);
84 self.proxy_tool_defs = tool_defs;
85 }
86
87 pub fn set_cancel_token(&mut self, token: CancellationToken) {
89 self.cancel_token = Some(token);
90 }
91
92 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
94 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
95 log::info!("Agent approve mode changed: {} -> {}", old, mode);
96 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
97 }
98
99 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
101 self.approve_mode.clone()
102 }
103
104 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
106 self.approve_mode = shared;
107 }
108
109 pub fn update_memory_summary(&mut self, summary: Option<String>) {
112 self.memory_summary = summary;
113 self.system_prompt = prompt::build_system_prompt(
115 &self.profile,
116 &self.skills,
117 self.project_overview.as_deref(),
118 self.memory_summary.as_deref(),
119 );
120 }
121
122 pub fn refresh_codegraph_tools(&mut self) {
126 if let Some(path) = &self.project_path {
127 let should_have_codegraph =
129 crate::tools::codegraph::should_inject_codegraph_tools(path);
130
131 let has_codegraph = self.tools.iter().any(|t| {
133 let name = t.definition().name;
134 name.starts_with("code_") && name != "code_review"
135 });
136
137 if should_have_codegraph != has_codegraph {
139 if should_have_codegraph {
141 let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
142 for tool in codegraph_tools {
143 self.tools.push(Arc::from(tool));
144 }
145 self.system_prompt = prompt::build_system_prompt_with_workflows(
147 &self.profile,
148 &self.skills,
149 self.project_overview.as_deref(),
150 self.memory_summary.as_deref(),
151 Some(path),
152 None, );
154 } else {
155 self.tools.retain(|t| {
157 let name = t.definition().name;
158 !name.starts_with("code_") || name == "code_review"
159 });
160 self.system_prompt = prompt::build_system_prompt_with_workflows(
162 &self.profile,
163 &self.skills,
164 self.project_overview.as_deref(),
165 self.memory_summary.as_deref(),
166 Some(path),
167 None, );
169 }
170 }
171 }
172 }
173
174 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
176 self.emit(AgentEvent::session_started())?;
177
178 self.messages.push(Message {
179 role: Role::User,
180 content: MessageContent::Text(user_input.clone()),
181 });
182
183 let mut iterations = 0;
184 let mut should_continue = true;
185 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
186
187 while should_continue && iterations < MAX_ITERATIONS {
188 iterations += 1;
189
190 self.drain_pending_inputs();
193 if self.has_pending_inputs() {
194 let pending = self.take_pending_inputs();
195 let count = pending.len();
196 let merged = pending.join("\n\n---\n\n");
197 log::info!("Adding {} pending input messages to request", count);
198
199 self.emit(AgentEvent::queue_processed(count, pending.clone()))?;
201
202 self.messages.push(Message {
203 role: Role::User,
204 content: MessageContent::Text(merged),
205 });
206 }
207
208 if let Some(token) = &self.cancel_token
209 && token.is_cancelled()
210 {
211 self.emit(AgentEvent::error(
212 prompt::MSG_OPERATION_CANCELLED.to_string(),
213 None,
214 None,
215 ))?;
216 break;
217 }
218
219 if iterations == ITERATION_WARNING_THRESHOLD {
221 self.emit(AgentEvent::progress(
222 prompt::MSG_ITERATION_WARNING_UI
223 .replace("{iterations}", &iterations.to_string())
224 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
225 None,
226 ))?;
227 }
228
229 let context_size = self.effective_context_size();
232 let estimated_tokens = estimate_total_tokens(&self.messages);
233
234 if should_compress(estimated_tokens, context_size, &self.compression_config) {
235 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
236
237 match compress_messages(
238 &self.messages,
239 CompressionStrategy::SlidingWindow,
240 &self.compression_config,
241 ) {
242 Ok(compressed) => {
243 let compressed_tokens = estimate_total_tokens(&compressed);
244 self.messages = compressed;
245 crate::debug::debug_log().compression(
246 estimated_tokens,
247 compressed_tokens,
248 compressed_tokens as f32 / estimated_tokens as f32,
249 );
250 }
251 Err(e) => {
252 self.emit(AgentEvent::progress(format!("预压缩失败: {}", e), None))?;
253 }
254 }
255 }
256
257 let tool_defs: Vec<ToolDefinition> = {
259 let mut defs: Vec<ToolDefinition> = self
260 .tools
261 .iter()
262 .map(|t| {
263 let def = t.definition();
264 let description = def.description_for_llm();
265 ToolDefinition {
266 name: def.name,
267 description,
268 parameters: def.parameters,
269 is_priority: def.is_priority,
270 }
271 })
272 .collect();
273 defs.extend(self.proxy_tool_defs.iter().map(|t| {
275 let def = &t.definition;
276 let description = def.description_for_llm();
277 ToolDefinition {
278 name: def.name.clone(),
279 description,
280 parameters: def.parameters.clone(),
281 is_priority: def.is_priority,
282 }
283 }));
284 defs
285 };
286 let request = ChatRequest {
287 system: Some(self.system_prompt.clone()),
288 messages: self.messages.clone(),
289 max_tokens: self.max_tokens,
290 tools: tool_defs,
291 think: self.think,
292 enable_caching: true,
293 server_tools: Vec::new(),
294 };
295
296 let response = self.call_streaming(&request).await?;
297
298 self.track_usage(&response.usage);
299
300 crate::debug::debug_log().api_call(
301 &self.model_name,
302 response.usage.input_tokens,
303 response.usage.cache_read_input_tokens > 0,
304 );
305
306 should_continue = self.process_response(&response).await?;
307
308 if !should_continue && iterations < MAX_ITERATIONS - 1 {
311 self.drain_pending_inputs();
313
314 if self.has_pending_inputs() {
315 log::info!("Agent: found pending inputs at session end, continuing loop");
316 should_continue = true;
317 continue; }
319
320 if self.last_message_was_todo_reminder() {
323 log::info!("Skipping todo check: reminder already sent in recent messages");
324 } else {
325 const MAX_TODO_REMINDERS: usize = 2;
326
327 let reminder_count_clone = self.todo_reminder_count.clone();
329 let (pending, all_at_limit) = self.get_pending_todos_with_limit(
330 &reminder_count_clone,
331 MAX_TODO_REMINDERS
332 );
333
334 if !pending.is_empty() {
335 for (_, content) in &pending {
337 *self.todo_reminder_count.entry(content.clone()).or_insert(0) += 1;
338 }
339
340 let pending_list = pending
341 .iter()
342 .map(|(status, content)| {
343 let marker = match status.as_str() {
344 "in_progress" => "[~]",
345 "pending" => "[ ]",
346 _ => "[?]",
347 };
348 format!(" {} {}", marker, content)
349 })
350 .collect::<Vec<_>>()
351 .join("\n");
352
353 let reminder = format!(
354 "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
355 pending_list
356 );
357
358 self.messages.push(Message {
359 role: Role::User,
360 content: MessageContent::Text(reminder),
361 });
362 should_continue = true;
363 } else if all_at_limit && !self.todo_reminder_count.is_empty() {
364 let remaining_count = self.todo_reminder_count.len();
367 self.emit(AgentEvent::progress(
368 format!(
369 "⚠️ 会话结束:{} 个待办项未完成(已提醒 {} 次,达到上限)",
370 remaining_count, MAX_TODO_REMINDERS
371 ),
372 None,
373 ))?;
374 log::warn!(
375 "Session ending with {} incomplete todos (reminder limit reached)",
376 remaining_count
377 );
378 }
379 }
380 }
381
382 let context_size = self.effective_context_size();
383 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
384 let estimated_tokens = estimate_total_tokens(&self.messages);
385
386 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
387 api_tokens
388 } else {
389 estimated_tokens
390 };
391
392 if let Some(ctx_size) = context_size {
395 self.emit(AgentEvent::with_data(
397 EventType::ContextSize,
398 EventData::ContextSize {
399 context_size: ctx_size as u64,
400 },
401 ))?;
402
403 let usage_ratio = current_tokens as f64 / ctx_size as f64;
404 if usage_ratio >= 0.3 {
405 crate::debug::debug_log().log(
406 "checkcompress",
407 &format!(
408 "usage={:.1}%, tokens={}, context={}, threshold={}%",
409 usage_ratio * 100.0,
410 current_tokens,
411 ctx_size,
412 self.compression_config.threshold * 100.0
413 ),
414 );
415 }
416 }
417
418 if should_compress(current_tokens, context_size, &self.compression_config) {
419 self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
420
421 let original_tokens = current_tokens;
422
423 match compress_messages(
424 &self.messages,
425 CompressionStrategy::SlidingWindow,
426 &self.compression_config,
427 ) {
428 Ok(compressed) => {
429 let compressed_tokens = estimate_total_tokens(&compressed);
430 self.messages = compressed;
431 self.total_input_tokens
432 .store(compressed_tokens as u64, Ordering::Relaxed);
433 self.last_input_tokens
434 .store(compressed_tokens as u64, Ordering::Relaxed);
435
436 let ratio = compressed_tokens as f32 / original_tokens as f32;
437 crate::debug::debug_log().compression(
438 original_tokens,
439 compressed_tokens,
440 ratio,
441 );
442
443 self.emit(AgentEvent::with_data(
444 EventType::CompressionCompleted,
445 EventData::Compression {
446 original_tokens: original_tokens as u64,
447 compressed_tokens: compressed_tokens as u64,
448 ratio: compressed_tokens as f32 / original_tokens as f32,
449 },
450 ))?;
451 }
452 Err(e) => {
453 self.emit(AgentEvent::progress(
454 format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
455 None,
456 ))?;
457 }
458 }
459 }
460 }
461
462 if iterations >= MAX_ITERATIONS && should_continue {
464 self.emit(AgentEvent::error(
465 prompt::MSG_MAX_ITERATIONS_REACHED
466 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
467 .replace("{iterations}", &iterations.to_string()),
468 Some("MAX_ITERATIONS_REACHED".to_string()),
469 Some("agent/run.rs".to_string()),
470 ))?;
471 }
472
473 self.emit(AgentEvent::usage_with_cache(
474 self.total_input_tokens.load(Ordering::Relaxed),
475 self.total_output_tokens.load(Ordering::Relaxed),
476 0,
477 0,
478 ))?;
479
480 self.emit(AgentEvent::session_ended())?;
481
482 Ok(Vec::new())
483 }
484
485 pub fn set_messages(&mut self, messages: Vec<Message>) {
487 self.messages = messages;
488 }
489
490 pub fn get_messages(&self) -> &[Message] {
492 &self.messages
493 }
494
495 pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
497 &self.tools
498 }
499
500 pub fn get_system_prompt(&self) -> &str {
502 &self.system_prompt
503 }
504
505 pub fn get_token_counts(&self) -> (u64, u64) {
507 (
508 self.total_input_tokens.load(Ordering::Relaxed),
509 self.total_output_tokens.load(Ordering::Relaxed),
510 )
511 }
512
513 pub fn clear_history(&mut self) {
515 self.messages.clear();
516 self.total_input_tokens.store(0, Ordering::Relaxed);
517 self.total_output_tokens.store(0, Ordering::Relaxed);
518 self.last_input_tokens.store(0, Ordering::Relaxed);
519 }
520
521 pub fn message_count(&self) -> usize {
523 self.messages.len()
524 }
525
526 pub async fn add_mcp_server(
540 &mut self,
541 name: &str,
542 config: crate::mcp::McpServerConfig,
543 ) -> Result<()> {
544 if let Some(registry) = &self.mcp_registry {
545 let mut reg = registry.write().await;
546 reg.add_server(name.to_string(), config);
547 log::info!("MCP server '{}' added to registry", name);
548 } else {
549 log::warn!("MCP registry not initialized, cannot add server '{}'", name);
550 }
551 Ok(())
552 }
553
554 pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
556 if let Some(registry) = &self.mcp_registry {
557 let mut reg = registry.write().await;
558 reg.remove_server(name).await?;
559 log::info!("MCP server '{}' removed from registry", name);
560 }
561 Ok(())
562 }
563
564 pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
566 if let Some(registry) = &self.mcp_registry {
567 let reg = registry.read().await;
568 reg.server_status().await.values().cloned().collect()
569 } else {
570 Vec::new()
571 }
572 }
573
574 pub async fn start_mcp_server(
576 &self,
577 name: &str,
578 ) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
579 if let Some(registry) = &self.mcp_registry {
580 let reg = registry.read().await;
581 if let Some(placeholder) = reg.get_server(name) {
582 let tools = placeholder.start().await?;
583 log::info!("MCP server '{}' started with {} tools", name, tools.len());
584 Ok(tools)
585 } else {
586 Err(anyhow::anyhow!(
587 "MCP server '{}' not found in registry",
588 name
589 ))
590 }
591 } else {
592 Err(anyhow::anyhow!("MCP registry not initialized"))
593 }
594 }
595}