1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{CompressionConfig, CompressionStrategy, compress_messages, estimate_total_tokens, should_compress};
11use crate::event::{AgentEvent, EventData, EventType};
12use crate::prompt::{PromptProfile, preprocess::{preprocess_with_skills, ProcessResult}};
13use crate::providers::{ChatRequest, Message, MessageContent, Role};
14use crate::skills::Skill;
15use crate::tools::Tool;
16use crate::tools::ToolDefinition;
17use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
18
19use super::core::{AgentConfig, AgentState};
20use super::context::AgentContext;
21use super::session::SessionManager;
22use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
23
24#[allow(dead_code)]
25impl Agent {
26 pub(crate) fn new(builder: AgentBuilder) -> Self {
27 let event_tx = builder.event_tx.unwrap_or_else(|| {
29 let (tx, _) = mpsc::channel(100);
30 tx
31 });
32
33 let config = AgentConfig {
35 max_tokens: builder.max_tokens,
36 context_size_override: builder.context_size_override,
37 think: builder.think,
38 compression: builder.compression_config,
39 verify_strategy: builder.verify_strategy,
40 project_path: builder.project_path.clone(),
41 ..AgentConfig::default()
42 };
43
44 let state = AgentState::new();
45
46 let context = AgentContext::with_context(
48 builder.profile,
49 builder.skills,
50 builder.project_overview,
51 builder.memory_summary,
52 builder.project_path,
53 );
54
55 let session = SessionManager::with_all_channels(
57 event_tx.clone(),
58 None, builder.pending_input_rx,
60 );
61
62 Self {
63 config,
65 state,
66 context,
67 session,
68
69 provider: builder.provider,
71 model_name: builder.model_name,
72 tools: builder.tools,
73
74 event_tx,
76
77 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
79
80 proxy_tool_defs: builder.proxy_tool_defs,
82 proxy_executor: builder.proxy_executor,
83
84 mcp_registry: builder.mcp_registry,
86 lsp_registry: builder.lsp_registry,
87 }
88 }
89
90 pub(crate) fn messages(&self) -> &Vec<Message> {
95 self.state.messages()
96 }
97
98 pub(crate) fn messages_mut(&mut self) -> &mut Vec<Message> {
100 self.state.messages_mut()
101 }
102
103 pub(crate) fn system_prompt(&self) -> &str {
105 self.context.system_prompt()
106 }
107
108 pub(crate) fn max_tokens(&self) -> u32 {
110 self.config.max_tokens()
111 }
112
113 pub(crate) fn context_size_override(&self) -> Option<u32> {
115 self.config.context_size_override()
116 }
117
118 pub(crate) fn think(&self) -> bool {
120 self.config.think()
121 }
122
123 pub(crate) fn verify_strategy(&self) -> crate::tools::code_quality_hook::VerificationStrategy {
125 self.config.verify_strategy()
126 }
127
128 pub(crate) fn verify_project_path(&self) -> Option<&std::path::Path> {
130 self.config.project_path()
131 }
132
133 pub(crate) fn compression_config(&self) -> &CompressionConfig {
135 self.config.compression_config()
136 }
137
138 pub(crate) fn compression_config_mut(&mut self) -> &mut CompressionConfig {
140 self.config.compression_config_mut()
141 }
142
143 pub(crate) fn cancel_token(&self) -> Option<&CancellationToken> {
145 self.session.cancel_token()
146 }
147
148 pub(crate) fn event_tx(&self) -> &mpsc::Sender<AgentEvent> {
150 &self.event_tx
151 }
152
153 pub(crate) fn skills(&self) -> &[Skill] {
155 self.context.skills()
156 }
157
158 pub(crate) fn profile(&self) -> &PromptProfile {
160 self.context.profile()
161 }
162
163 pub(crate) fn project_overview(&self) -> Option<&str> {
165 self.context.project_overview()
166 }
167
168 pub(crate) fn memory_summary(&self) -> Option<&str> {
170 self.context.memory_summary()
171 }
172
173 pub(crate) fn project_path(&self) -> Option<&std::path::PathBuf> {
175 self.context.project_path()
176 }
177
178 pub(crate) fn is_cancelled(&self) -> bool {
180 self.session.is_cancelled()
181 }
182
183 pub(crate) fn total_input_tokens(&self) -> u64 {
185 self.state.total_input_tokens()
186 }
187
188 pub(crate) fn total_output_tokens(&self) -> u64 {
190 self.state.total_output_tokens()
191 }
192
193 pub(crate) fn last_input_tokens(&self) -> u64 {
195 self.state.last_input_tokens()
196 }
197
198 pub(crate) fn todo_reminder_count(&self) -> &std::collections::HashMap<String, usize> {
200 self.state.todo_reminder_count_map()
201 }
202
203 pub(crate) fn todo_reminder_count_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
205 self.state.todo_reminder_count_map_mut()
206 }
207
208 pub(crate) fn pending_inputs(&self) -> &Vec<String> {
210 self.state.pending_inputs_vec()
211 }
212
213 pub(crate) fn pending_inputs_mut(&mut self) -> &mut Vec<String> {
215 self.state.pending_inputs_vec_mut()
216 }
217
218 pub(crate) fn ask_rx(&mut self) -> Option<&mut mpsc::Receiver<String>> {
220 self.session.ask_rx()
221 }
222
223 pub(crate) fn effective_context_size(&self) -> Option<u32> {
225 self.config.context_size_override()
226 .or_else(|| self.provider.context_size())
227 }
228
229 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
231 self.event_tx.clone()
232 }
233
234 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
236 self.session.set_ask_channel(rx);
237 }
238
239 pub(crate) fn has_ask_channel(&self) -> bool {
241 self.session.has_ask_channel()
242 }
243
244 pub(crate) fn ask_channel(&mut self) -> Option<&mut mpsc::Receiver<String>> {
246 self.session.ask_rx()
247 }
248
249 pub fn set_proxy_executor(
251 &mut self,
252 executor: Arc<dyn ProxyToolExecutor>,
253 tool_defs: Vec<ProxyToolDef>,
254 ) {
255 self.proxy_executor = Some(executor);
256 self.proxy_tool_defs = tool_defs;
257 }
258
259 pub fn set_cancel_token(&mut self, token: CancellationToken) {
261 self.session.set_cancel_token(token);
262 }
263
264 pub(crate) fn get_cancel_token(&self) -> Option<&CancellationToken> {
266 self.session.cancel_token()
267 }
268
269 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
271 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
272 log::info!("Agent approve mode changed: {} -> {}", old, mode);
273 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
274 }
275
276 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
278 self.approve_mode.clone()
279 }
280
281 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
283 self.approve_mode = shared;
284 }
285
286 pub fn update_memory_summary(&mut self, summary: Option<String>) {
289 self.context.update_memory(summary);
290 }
292
293 pub fn refresh_codegraph_tools(&mut self) {
297 if let Some(path) = self.context.project_path() {
298 let should_have_codegraph =
300 crate::tools::codegraph::should_inject_codegraph_tools(path);
301
302 let has_codegraph = self.tools.iter().any(|t| {
304 let name = t.definition().name;
305 name.starts_with("code_") && name != "code_review"
306 });
307
308 if should_have_codegraph != has_codegraph {
310 if should_have_codegraph {
312 let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
313 for tool in codegraph_tools {
314 self.tools.push(Arc::from(tool));
315 }
316 } else {
317 self.tools.retain(|t| {
319 let name = t.definition().name;
320 !name.starts_with("code_") || name == "code_review"
321 });
322 }
323 self.context.rebuild_system_prompt_with_workflows(Some(path.clone()));
325 }
326 }
327 }
328
329 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
331 self.emit(AgentEvent::session_started())?;
332
333 let preprocess_result = self.preprocess_input(&user_input);
335
336 let processed_input = match preprocess_result {
338 ProcessResult::SkillTriggered {
339 skill_id,
340 confidence,
341 skill_body,
342 } => {
343 log::info!(
344 "Skill triggered: {} (confidence: {:.2})",
345 skill_id,
346 confidence
347 );
348 self.emit(AgentEvent::progress(
349 format!("🎯 触发技能: {}", skill_id),
350 None,
351 ))?;
352
353 if let Some(body) = skill_body {
355 let enhanced_input = format!(
357 "<command-name>{}</command-name>\n\n{}\n\n---\n\nUser request: {}",
358 skill_id,
359 body,
360 user_input
361 );
362 enhanced_input
363 } else {
364 let enhanced_input = format!(
366 "User invoked skill '{}'. Use the `skill` tool with name '{}' to load its instructions before proceeding.\n\nUser request: {}",
367 skill_id,
368 skill_id,
369 user_input
370 );
371 enhanced_input
372 }
373 }
374 ProcessResult::WorkflowTriggered {
375 workflow_id,
376 inputs,
377 } => {
378 log::info!("Workflow triggered: {} with inputs: {:?}", workflow_id, inputs);
379 self.emit(AgentEvent::progress(
380 format!("🔄 触发工作流: {}", workflow_id),
381 None,
382 ))?;
383 let inputs_json = serde_json::to_string_pretty(&inputs).unwrap_or_default();
385 let enhanced_input = format!(
386 "Workflow '{}' triggered with extracted inputs:\n{}\n\nUser request: {}",
387 workflow_id,
388 inputs_json,
389 user_input
390 );
391 enhanced_input
392 }
393 ProcessResult::Continue => {
394 user_input
396 }
397 };
398
399 self.state.add_message(Message {
401 role: Role::User,
402 content: MessageContent::Text(processed_input),
403 });
404
405 let mut iterations = 0;
406 let mut should_continue = true;
407 const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
408
409 while should_continue && iterations < MAX_ITERATIONS {
410 iterations += 1;
411
412 self.drain_pending_inputs();
415 if self.has_pending_inputs() {
416 let pending = self.take_pending_inputs();
417 let count = pending.len();
418 let merged = pending.join("\n\n---\n\n");
419 log::info!("Adding {} pending input messages to request", count);
420
421 self.emit(AgentEvent::queue_processed(count, pending.clone()))?;
423
424 self.state.add_message(Message {
425 role: Role::User,
426 content: MessageContent::Text(merged),
427 });
428 }
429
430 if self.session.is_cancelled() {
431 self.emit(AgentEvent::error(
432 crate::prompt::MSG_OPERATION_CANCELLED.to_string(),
433 None,
434 None,
435 ))?;
436 break;
437 }
438
439 if iterations == ITERATION_WARNING_THRESHOLD {
441 self.emit(AgentEvent::progress(
442 crate::prompt::MSG_ITERATION_WARNING_UI
443 .replace("{iterations}", &iterations.to_string())
444 .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
445 None,
446 ))?;
447 }
448
449 let context_size = self.effective_context_size();
452 let estimated_tokens = estimate_total_tokens(self.state.messages());
453
454 if should_compress(estimated_tokens, context_size, self.config.compression_config()) {
455 self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
456
457 match compress_messages(
458 self.state.messages(),
459 CompressionStrategy::SlidingWindow,
460 self.config.compression_config(),
461 ) {
462 Ok(compressed) => {
463 let compressed_tokens = estimate_total_tokens(&compressed);
464 self.state.set_messages(compressed);
465 crate::debug::debug_log().compression(
466 estimated_tokens,
467 compressed_tokens,
468 compressed_tokens as f32 / estimated_tokens as f32,
469 );
470 }
471 Err(e) => {
472 self.emit(AgentEvent::progress(format!("预压缩失败: {}", e), None))?;
473 }
474 }
475 }
476
477 let tool_defs: Vec<ToolDefinition> = {
479 let mut defs: Vec<ToolDefinition> = self
480 .tools
481 .iter()
482 .map(|t| {
483 let def = t.definition();
484 let description = def.description_for_llm();
485 ToolDefinition {
486 name: def.name,
487 description,
488 parameters: def.parameters,
489 is_priority: def.is_priority,
490 }
491 })
492 .collect();
493 defs.extend(self.proxy_tool_defs.iter().map(|t| {
495 let def = &t.definition;
496 let description = def.description_for_llm();
497 ToolDefinition {
498 name: def.name.clone(),
499 description,
500 parameters: def.parameters.clone(),
501 is_priority: def.is_priority,
502 }
503 }));
504 defs
505 };
506 let request = ChatRequest {
507 system: Some(self.system_prompt().to_string()),
508 messages: self.state.messages().clone(),
509 max_tokens: self.max_tokens(),
510 tools: tool_defs,
511 think: self.think(),
512 enable_caching: true,
513 server_tools: Vec::new(),
514 };
515
516 let response = self.call_streaming(&request).await?;
517
518 self.track_usage(&response.usage);
519
520 crate::debug::debug_log().api_call(
521 &self.model_name,
522 response.usage.input_tokens,
523 response.usage.cache_read_input_tokens > 0,
524 );
525
526 should_continue = self.process_response(&response).await?;
527
528 if !should_continue && iterations < MAX_ITERATIONS - 1 {
531 self.drain_pending_inputs();
533
534 if self.has_pending_inputs() {
535 log::info!("Agent: found pending inputs at session end, continuing loop");
536 should_continue = true;
537 continue; }
539
540 if self.last_message_was_todo_reminder() {
543 log::info!("Skipping todo check: reminder already sent in recent messages");
544 } else {
545 const MAX_TODO_REMINDERS: usize = 2;
546
547 let reminder_count_clone = self.state.todo_reminder_count_map().clone();
549 let (pending, all_at_limit) = self.get_pending_todos_with_limit(
550 &reminder_count_clone,
551 MAX_TODO_REMINDERS
552 );
553
554 if !pending.is_empty() {
555 for (_, content) in &pending {
557 self.state.increment_todo_reminder(content.clone());
558 }
559
560 let pending_list = pending
561 .iter()
562 .map(|(status, content)| {
563 let marker = match status.as_str() {
564 "in_progress" => "[~]",
565 "pending" => "[ ]",
566 _ => "[?]",
567 };
568 format!(" {} {}", marker, content)
569 })
570 .collect::<Vec<_>>()
571 .join("\n");
572
573 let reminder = format!(
574 "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
575 pending_list
576 );
577
578 self.state.add_message(Message {
579 role: Role::User,
580 content: MessageContent::Text(reminder),
581 });
582 should_continue = true;
583 } else if all_at_limit && !self.state.todo_reminder_count_map().is_empty() {
584 let remaining_count = self.state.todo_reminder_count_map().len();
587 self.emit(AgentEvent::progress(
588 format!(
589 "⚠️ 会话结束:{} 个待办项未完成(已提醒 {} 次,达到上限)",
590 remaining_count, MAX_TODO_REMINDERS
591 ),
592 None,
593 ))?;
594 log::warn!(
595 "Session ending with {} incomplete todos (reminder limit reached)",
596 remaining_count
597 );
598 }
599 }
600 }
601
602 let context_size = self.effective_context_size();
603 let api_tokens = self.state.last_input_tokens() as u32;
604 let estimated_tokens = estimate_total_tokens(self.state.messages());
605
606 let current_tokens = if api_tokens > 0 {
611 api_tokens
612 } else {
613 estimated_tokens
614 };
615
616 if let Some(ctx_size) = context_size {
619 self.emit(AgentEvent::with_data(
621 EventType::ContextSize,
622 EventData::ContextSize {
623 context_size: ctx_size as u64,
624 },
625 ))?;
626
627 let usage_ratio = current_tokens as f64 / ctx_size as f64;
628 if usage_ratio >= 0.3 {
629 crate::debug::debug_log().log(
630 "checkcompress",
631 &format!(
632 "usage={:.1}%, tokens={}, context={}, threshold={}%",
633 usage_ratio * 100.0,
634 current_tokens,
635 ctx_size,
636 self.config.compression_config().threshold * 100.0
637 ),
638 );
639 }
640 }
641
642 if should_compress(current_tokens, context_size, self.config.compression_config()) {
643 self.emit(AgentEvent::progress(crate::prompt::MSG_COMPRESSING_CONTEXT, None))?;
644
645 let original_tokens = current_tokens;
646
647 match compress_messages(
648 self.state.messages(),
649 CompressionStrategy::SlidingWindow,
650 self.config.compression_config(),
651 ) {
652 Ok(compressed) => {
653 let compressed_tokens = estimate_total_tokens(&compressed);
654 self.state.set_messages(compressed);
655 self.state.set_total_input_tokens(compressed_tokens as u64);
656 self.state.set_last_input_tokens(compressed_tokens as u64);
657
658 let ratio = compressed_tokens as f32 / original_tokens as f32;
659 crate::debug::debug_log().compression(
660 original_tokens,
661 compressed_tokens,
662 ratio,
663 );
664
665 self.emit(AgentEvent::with_data(
666 EventType::CompressionCompleted,
667 EventData::Compression {
668 original_tokens: original_tokens as u64,
669 compressed_tokens: compressed_tokens as u64,
670 ratio: compressed_tokens as f32 / original_tokens as f32,
671 },
672 ))?;
673 }
674 Err(e) => {
675 self.emit(AgentEvent::progress(
676 format!("{}{}", crate::prompt::MSG_COMPRESSION_FAILED, e),
677 None,
678 ))?;
679 }
680 }
681 }
682 }
683
684 if iterations >= MAX_ITERATIONS && should_continue {
686 self.emit(AgentEvent::error(
687 crate::prompt::MSG_MAX_ITERATIONS_REACHED
688 .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
689 .replace("{iterations}", &iterations.to_string()),
690 Some("MAX_ITERATIONS_REACHED".to_string()),
691 Some("agent/run.rs".to_string()),
692 ))?;
693 }
694
695 self.emit(AgentEvent::usage_with_cache(
696 self.state.total_input_tokens(),
697 self.state.total_output_tokens(),
698 0,
699 0,
700 ))?;
701
702 self.emit(AgentEvent::session_ended())?;
703
704 Ok(Vec::new())
705 }
706
707 pub fn set_messages(&mut self, messages: Vec<Message>) {
709 self.state.set_messages(messages);
710 }
711
712 pub fn get_messages(&self) -> &[Message] {
714 self.messages()
715 }
716
717 pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
719 &self.tools
720 }
721
722 pub fn get_system_prompt(&self) -> &str {
724 self.system_prompt()
725 }
726
727 pub fn get_token_counts(&self) -> (u64, u64) {
729 (
730 self.state.total_input_tokens(),
731 self.state.total_output_tokens(),
732 )
733 }
734
735 pub fn clear_history(&mut self) {
737 self.messages_mut().clear();
738 self.state.set_total_input_tokens(0);
739 self.state.set_total_output_tokens(0);
740 self.state.set_last_input_tokens(0);
741 }
742
743 pub fn message_count(&self) -> usize {
745 self.messages().len()
746 }
747
748 pub fn preprocess_input(&self, user_input: &str) -> ProcessResult {
764 preprocess_with_skills(user_input, self.skills())
766 }
767
768 pub fn inject_skill_context(&self, skill_id: &str, skill_body: Option<&str>) -> String {
780 if let Some(body) = skill_body {
781 format!(
782 "<command-name>{}</command-name>\n\n{}\n\n**Important**: Follow the skill instructions above before responding to the user request below.",
783 skill_id,
784 body.trim_end()
785 )
786 } else {
787 format!(
788 "Skill '{}' was triggered but not auto-loaded. The model should call the `skill` tool with name '{}' to load its instructions.",
789 skill_id,
790 skill_id
791 )
792 }
793 }
794
795 pub async fn add_mcp_server(
809 &mut self,
810 name: &str,
811 config: crate::mcp::McpServerConfig,
812 ) -> Result<()> {
813 if let Some(registry) = &self.mcp_registry {
814 let mut reg = registry.write().await;
815 reg.add_server(name.to_string(), config);
816 log::info!("MCP server '{}' added to registry", name);
817 } else {
818 log::warn!("MCP registry not initialized, cannot add server '{}'", name);
819 }
820 Ok(())
821 }
822
823 pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
825 if let Some(registry) = &self.mcp_registry {
826 let mut reg = registry.write().await;
827 reg.remove_server(name).await?;
828 log::info!("MCP server '{}' removed from registry", name);
829 }
830 Ok(())
831 }
832
833 pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
835 if let Some(registry) = &self.mcp_registry {
836 let reg = registry.read().await;
837 reg.server_status().await.values().cloned().collect()
838 } else {
839 Vec::new()
840 }
841 }
842
843 pub async fn start_mcp_server(
845 &self,
846 name: &str,
847 ) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
848 if let Some(registry) = &self.mcp_registry {
849 let reg = registry.read().await;
850 if let Some(placeholder) = reg.get_server(name) {
851 let tools = placeholder.start().await?;
852 log::info!("MCP server '{}' started with {} tools", name, tools.len());
853 Ok(tools)
854 } else {
855 Err(anyhow::anyhow!(
856 "MCP server '{}' not found in registry",
857 name
858 ))
859 }
860 } else {
861 Err(anyhow::anyhow!("MCP registry not initialized"))
862 }
863 }
864}