1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::prompt::preprocess::{preprocess_with_skills, ProcessResult};
16use crate::providers::{ChatRequest, Message, MessageContent, Role};
17use crate::tools::Tool;
18use crate::tools::ToolDefinition;
19use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
20
21use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
22
23impl Agent {
24 pub(crate) fn new(builder: AgentBuilder) -> Self {
25 let event_tx = builder.event_tx.unwrap_or_else(|| {
26 let (tx, _) = mpsc::channel(100);
27 tx
28 });
29
30 Self {
35 provider: builder.provider,
36 model_name: builder.model_name,
37 tools: builder.tools,
38 messages: Vec::new(),
39 system_prompt: builder.system_prompt,
40 max_tokens: builder.max_tokens,
41 context_size_override: builder.context_size_override,
42 think: builder.think,
43 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
44 event_tx,
45 skills: builder.skills,
46 profile: builder.profile,
47 project_overview: builder.project_overview,
48 memory_summary: builder.memory_summary,
49 project_path: builder.project_path,
50 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
51 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
52 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
53 cancel_token: None,
54 compression_config: crate::compress::CompressionConfig::default(),
55 ask_rx: None,
56 proxy_tool_defs: builder.proxy_tool_defs,
57 proxy_executor: builder.proxy_executor,
58 mcp_registry: builder.mcp_registry,
59 lsp_registry: builder.lsp_registry,
60 pending_input_rx: builder.pending_input_rx,
61 pending_inputs: Vec::new(),
62 previewed_tool_inputs: std::collections::HashSet::new(),
63 todo_reminder_count: std::collections::HashMap::new(),
64 read_history: crate::tools::ReadHistoryTracker::new(),
65 }
66 }
67
68 pub(crate) fn effective_context_size(&self) -> Option<u32> {
70 self.context_size_override
71 .or_else(|| self.provider.context_size())
72 }
73
74 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
76 self.event_tx.clone()
77 }
78
79 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
81 self.ask_rx = Some(rx);
82 }
83
84 pub fn set_proxy_executor(
86 &mut self,
87 executor: Arc<dyn ProxyToolExecutor>,
88 tool_defs: Vec<ProxyToolDef>,
89 ) {
90 self.proxy_executor = Some(executor);
91 self.proxy_tool_defs = tool_defs;
92 }
93
94 pub fn set_cancel_token(&mut self, token: CancellationToken) {
96 self.cancel_token = Some(token);
97 }
98
99 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
101 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
102 log::info!("Agent approve mode changed: {} -> {}", old, mode);
103 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
104 }
105
106 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
108 self.approve_mode.clone()
109 }
110
111 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
113 self.approve_mode = shared;
114 }
115
116 pub fn update_memory_summary(&mut self, summary: Option<String>) {
119 self.memory_summary = summary;
120 self.system_prompt = prompt::build_system_prompt(
122 &self.profile,
123 &self.skills,
124 self.project_overview.as_deref(),
125 self.memory_summary.as_deref(),
126 );
127 }
128
129 pub fn refresh_codegraph_tools(&mut self) {
133 if let Some(path) = &self.project_path {
134 let should_have_codegraph =
136 crate::tools::codegraph::should_inject_codegraph_tools(path);
137
138 let has_codegraph = self.tools.iter().any(|t| {
140 let name = t.definition().name;
141 name.starts_with("code_") && name != "code_review"
142 });
143
144 if should_have_codegraph != has_codegraph {
146 if should_have_codegraph {
148 let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
149 for tool in codegraph_tools {
150 self.tools.push(Arc::from(tool));
151 }
152 self.system_prompt = prompt::build_system_prompt_with_workflows(
154 &self.profile,
155 &self.skills,
156 self.project_overview.as_deref(),
157 self.memory_summary.as_deref(),
158 Some(path),
159 None, );
161 } else {
162 self.tools.retain(|t| {
164 let name = t.definition().name;
165 !name.starts_with("code_") || name == "code_review"
166 });
167 self.system_prompt = prompt::build_system_prompt_with_workflows(
169 &self.profile,
170 &self.skills,
171 self.project_overview.as_deref(),
172 self.memory_summary.as_deref(),
173 Some(path),
174 None, );
176 }
177 }
178 }
179 }
180
181 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
183 self.emit(AgentEvent::session_started())?;
184
185 let preprocess_result = self.preprocess_input(&user_input);
187
188 let processed_input = match preprocess_result {
190 ProcessResult::SkillTriggered {
191 skill_id,
192 confidence,
193 skill_body,
194 } => {
195 log::info!(
196 "Skill triggered: {} (confidence: {:.2})",
197 skill_id,
198 confidence
199 );
200 self.emit(AgentEvent::progress(
201 format!("🎯 触发技能: {}", skill_id),
202 None,
203 ))?;
204
205 if let Some(body) = skill_body {
207 let enhanced_input = format!(
209 "<command-name>{}</command-name>\n\n{}\n\n---\n\nUser request: {}",
210 skill_id,
211 body,
212 user_input
213 );
214 enhanced_input
215 } else {
216 let enhanced_input = format!(
218 "User invoked skill '{}'. Use the `skill` tool with name '{}' to load its instructions before proceeding.\n\nUser request: {}",
219 skill_id,
220 skill_id,
221 user_input
222 );
223 enhanced_input
224 }
225 }
226 ProcessResult::WorkflowTriggered {
227 workflow_id,
228 inputs,
229 } => {
230 log::info!("Workflow triggered: {} with inputs: {:?}", workflow_id, inputs);
231 self.emit(AgentEvent::progress(
232 format!("🔄 触发工作流: {}", workflow_id),
233 None,
234 ))?;
235 let inputs_json = serde_json::to_string_pretty(&inputs).unwrap_or_default();
237 let enhanced_input = format!(
238 "Workflow '{}' triggered with extracted inputs:\n{}\n\nUser request: {}",
239 workflow_id,
240 inputs_json,
241 user_input
242 );
243 enhanced_input
244 }
245 ProcessResult::Continue => {
246 user_input
248 }
249 };
250
251 self.messages.push(Message {
253 role: Role::User,
254 content: MessageContent::Text(processed_input),
255 });
256
257 let mut iterations = 0;
258 let mut should_continue = true;
259 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
260
261 while should_continue && iterations < MAX_ITERATIONS {
262 iterations += 1;
263
264 self.drain_pending_inputs();
267 if self.has_pending_inputs() {
268 let pending = self.take_pending_inputs();
269 let count = pending.len();
270 let merged = pending.join("\n\n---\n\n");
271 log::info!("Adding {} pending input messages to request", count);
272
273 self.emit(AgentEvent::queue_processed(count, pending.clone()))?;
275
276 self.messages.push(Message {
277 role: Role::User,
278 content: MessageContent::Text(merged),
279 });
280 }
281
282 if let Some(token) = &self.cancel_token
283 && token.is_cancelled()
284 {
285 self.emit(AgentEvent::error(
286 prompt::MSG_OPERATION_CANCELLED.to_string(),
287 None,
288 None,
289 ))?;
290 break;
291 }
292
293 if iterations == ITERATION_WARNING_THRESHOLD {
295 self.emit(AgentEvent::progress(
296 prompt::MSG_ITERATION_WARNING_UI
297 .replace("{iterations}", &iterations.to_string())
298 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
299 None,
300 ))?;
301 }
302
303 let context_size = self.effective_context_size();
306 let estimated_tokens = estimate_total_tokens(&self.messages);
307
308 if should_compress(estimated_tokens, context_size, &self.compression_config) {
309 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
310
311 match compress_messages(
312 &self.messages,
313 CompressionStrategy::SlidingWindow,
314 &self.compression_config,
315 ) {
316 Ok(compressed) => {
317 let compressed_tokens = estimate_total_tokens(&compressed);
318 self.messages = compressed;
319 crate::debug::debug_log().compression(
320 estimated_tokens,
321 compressed_tokens,
322 compressed_tokens as f32 / estimated_tokens as f32,
323 );
324 }
325 Err(e) => {
326 self.emit(AgentEvent::progress(format!("预压缩失败: {}", e), None))?;
327 }
328 }
329 }
330
331 let tool_defs: Vec<ToolDefinition> = {
333 let mut defs: Vec<ToolDefinition> = self
334 .tools
335 .iter()
336 .map(|t| {
337 let def = t.definition();
338 let description = def.description_for_llm();
339 ToolDefinition {
340 name: def.name,
341 description,
342 parameters: def.parameters,
343 is_priority: def.is_priority,
344 }
345 })
346 .collect();
347 defs.extend(self.proxy_tool_defs.iter().map(|t| {
349 let def = &t.definition;
350 let description = def.description_for_llm();
351 ToolDefinition {
352 name: def.name.clone(),
353 description,
354 parameters: def.parameters.clone(),
355 is_priority: def.is_priority,
356 }
357 }));
358 defs
359 };
360 let request = ChatRequest {
361 system: Some(self.system_prompt.clone()),
362 messages: self.messages.clone(),
363 max_tokens: self.max_tokens,
364 tools: tool_defs,
365 think: self.think,
366 enable_caching: true,
367 server_tools: Vec::new(),
368 };
369
370 let response = self.call_streaming(&request).await?;
371
372 self.track_usage(&response.usage);
373
374 crate::debug::debug_log().api_call(
375 &self.model_name,
376 response.usage.input_tokens,
377 response.usage.cache_read_input_tokens > 0,
378 );
379
380 should_continue = self.process_response(&response).await?;
381
382 if !should_continue && iterations < MAX_ITERATIONS - 1 {
385 self.drain_pending_inputs();
387
388 if self.has_pending_inputs() {
389 log::info!("Agent: found pending inputs at session end, continuing loop");
390 should_continue = true;
391 continue; }
393
394 if self.last_message_was_todo_reminder() {
397 log::info!("Skipping todo check: reminder already sent in recent messages");
398 } else {
399 const MAX_TODO_REMINDERS: usize = 2;
400
401 let reminder_count_clone = self.todo_reminder_count.clone();
403 let (pending, all_at_limit) = self.get_pending_todos_with_limit(
404 &reminder_count_clone,
405 MAX_TODO_REMINDERS
406 );
407
408 if !pending.is_empty() {
409 for (_, content) in &pending {
411 *self.todo_reminder_count.entry(content.clone()).or_insert(0) += 1;
412 }
413
414 let pending_list = pending
415 .iter()
416 .map(|(status, content)| {
417 let marker = match status.as_str() {
418 "in_progress" => "[~]",
419 "pending" => "[ ]",
420 _ => "[?]",
421 };
422 format!(" {} {}", marker, content)
423 })
424 .collect::<Vec<_>>()
425 .join("\n");
426
427 let reminder = format!(
428 "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
429 pending_list
430 );
431
432 self.messages.push(Message {
433 role: Role::User,
434 content: MessageContent::Text(reminder),
435 });
436 should_continue = true;
437 } else if all_at_limit && !self.todo_reminder_count.is_empty() {
438 let remaining_count = self.todo_reminder_count.len();
441 self.emit(AgentEvent::progress(
442 format!(
443 "⚠️ 会话结束:{} 个待办项未完成(已提醒 {} 次,达到上限)",
444 remaining_count, MAX_TODO_REMINDERS
445 ),
446 None,
447 ))?;
448 log::warn!(
449 "Session ending with {} incomplete todos (reminder limit reached)",
450 remaining_count
451 );
452 }
453 }
454 }
455
456 let context_size = self.effective_context_size();
457 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
458 let estimated_tokens = estimate_total_tokens(&self.messages);
459
460 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
461 api_tokens
462 } else {
463 estimated_tokens
464 };
465
466 if let Some(ctx_size) = context_size {
469 self.emit(AgentEvent::with_data(
471 EventType::ContextSize,
472 EventData::ContextSize {
473 context_size: ctx_size as u64,
474 },
475 ))?;
476
477 let usage_ratio = current_tokens as f64 / ctx_size as f64;
478 if usage_ratio >= 0.3 {
479 crate::debug::debug_log().log(
480 "checkcompress",
481 &format!(
482 "usage={:.1}%, tokens={}, context={}, threshold={}%",
483 usage_ratio * 100.0,
484 current_tokens,
485 ctx_size,
486 self.compression_config.threshold * 100.0
487 ),
488 );
489 }
490 }
491
492 if should_compress(current_tokens, context_size, &self.compression_config) {
493 self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
494
495 let original_tokens = current_tokens;
496
497 match compress_messages(
498 &self.messages,
499 CompressionStrategy::SlidingWindow,
500 &self.compression_config,
501 ) {
502 Ok(compressed) => {
503 let compressed_tokens = estimate_total_tokens(&compressed);
504 self.messages = compressed;
505 self.total_input_tokens
506 .store(compressed_tokens as u64, Ordering::Relaxed);
507 self.last_input_tokens
508 .store(compressed_tokens as u64, Ordering::Relaxed);
509
510 let ratio = compressed_tokens as f32 / original_tokens as f32;
511 crate::debug::debug_log().compression(
512 original_tokens,
513 compressed_tokens,
514 ratio,
515 );
516
517 self.emit(AgentEvent::with_data(
518 EventType::CompressionCompleted,
519 EventData::Compression {
520 original_tokens: original_tokens as u64,
521 compressed_tokens: compressed_tokens as u64,
522 ratio: compressed_tokens as f32 / original_tokens as f32,
523 },
524 ))?;
525 }
526 Err(e) => {
527 self.emit(AgentEvent::progress(
528 format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
529 None,
530 ))?;
531 }
532 }
533 }
534 }
535
536 if iterations >= MAX_ITERATIONS && should_continue {
538 self.emit(AgentEvent::error(
539 prompt::MSG_MAX_ITERATIONS_REACHED
540 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
541 .replace("{iterations}", &iterations.to_string()),
542 Some("MAX_ITERATIONS_REACHED".to_string()),
543 Some("agent/run.rs".to_string()),
544 ))?;
545 }
546
547 self.emit(AgentEvent::usage_with_cache(
548 self.total_input_tokens.load(Ordering::Relaxed),
549 self.total_output_tokens.load(Ordering::Relaxed),
550 0,
551 0,
552 ))?;
553
554 self.emit(AgentEvent::session_ended())?;
555
556 Ok(Vec::new())
557 }
558
559 pub fn set_messages(&mut self, messages: Vec<Message>) {
561 self.messages = messages;
562 }
563
564 pub fn get_messages(&self) -> &[Message] {
566 &self.messages
567 }
568
569 pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
571 &self.tools
572 }
573
574 pub fn get_system_prompt(&self) -> &str {
576 &self.system_prompt
577 }
578
579 pub fn get_token_counts(&self) -> (u64, u64) {
581 (
582 self.total_input_tokens.load(Ordering::Relaxed),
583 self.total_output_tokens.load(Ordering::Relaxed),
584 )
585 }
586
587 pub fn clear_history(&mut self) {
589 self.messages.clear();
590 self.total_input_tokens.store(0, Ordering::Relaxed);
591 self.total_output_tokens.store(0, Ordering::Relaxed);
592 self.last_input_tokens.store(0, Ordering::Relaxed);
593 }
594
595 pub fn message_count(&self) -> usize {
597 self.messages.len()
598 }
599
600 pub fn preprocess_input(&self, user_input: &str) -> ProcessResult {
616 preprocess_with_skills(user_input, &self.skills)
618 }
619
620 pub fn inject_skill_context(&self, skill_id: &str, skill_body: Option<&str>) -> String {
632 if let Some(body) = skill_body {
633 format!(
634 "<command-name>{}</command-name>\n\n{}\n\n**Important**: Follow the skill instructions above before responding to the user request below.",
635 skill_id,
636 body.trim_end()
637 )
638 } else {
639 format!(
640 "Skill '{}' was triggered but not auto-loaded. The model should call the `skill` tool with name '{}' to load its instructions.",
641 skill_id,
642 skill_id
643 )
644 }
645 }
646
647 pub async fn add_mcp_server(
661 &mut self,
662 name: &str,
663 config: crate::mcp::McpServerConfig,
664 ) -> Result<()> {
665 if let Some(registry) = &self.mcp_registry {
666 let mut reg = registry.write().await;
667 reg.add_server(name.to_string(), config);
668 log::info!("MCP server '{}' added to registry", name);
669 } else {
670 log::warn!("MCP registry not initialized, cannot add server '{}'", name);
671 }
672 Ok(())
673 }
674
675 pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
677 if let Some(registry) = &self.mcp_registry {
678 let mut reg = registry.write().await;
679 reg.remove_server(name).await?;
680 log::info!("MCP server '{}' removed from registry", name);
681 }
682 Ok(())
683 }
684
685 pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
687 if let Some(registry) = &self.mcp_registry {
688 let reg = registry.read().await;
689 reg.server_status().await.values().cloned().collect()
690 } else {
691 Vec::new()
692 }
693 }
694
695 pub async fn start_mcp_server(
697 &self,
698 name: &str,
699 ) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
700 if let Some(registry) = &self.mcp_registry {
701 let reg = registry.read().await;
702 if let Some(placeholder) = reg.get_server(name) {
703 let tools = placeholder.start().await?;
704 log::info!("MCP server '{}' started with {} tools", name, tools.len());
705 Ok(tools)
706 } else {
707 Err(anyhow::anyhow!(
708 "MCP server '{}' not found in registry",
709 name
710 ))
711 }
712 } else {
713 Err(anyhow::anyhow!("MCP registry not initialized"))
714 }
715 }
716}