use anyhow::Result;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::mpsc;
use crate::approval::ApproveMode;
use crate::cancel::CancellationToken;
use crate::compress::{
CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
};
use crate::event::{AgentEvent, EventData, EventType};
use crate::prompt;
use crate::providers::{ChatRequest, Message, MessageContent, Role};
use crate::tools::Tool;
use crate::tools::ToolDefinition;
use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
impl Agent {
pub(crate) fn new(builder: AgentBuilder) -> Self {
let event_tx = builder.event_tx.unwrap_or_else(|| {
let (tx, _) = mpsc::channel(100);
tx
});
Self {
provider: builder.provider,
model_name: builder.model_name,
tools: builder.tools,
messages: Vec::new(),
system_prompt: builder.system_prompt,
max_tokens: builder.max_tokens,
think: builder.think,
approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
event_tx,
skills: builder.skills,
profile: builder.profile,
project_overview: builder.project_overview,
memory_summary: builder.memory_summary,
project_path: builder.project_path,
total_input_tokens: std::sync::atomic::AtomicU64::new(0),
total_output_tokens: std::sync::atomic::AtomicU64::new(0),
last_input_tokens: std::sync::atomic::AtomicU64::new(0),
cancel_token: None,
compression_config: crate::compress::CompressionConfig::default(),
ask_rx: None,
proxy_tool_defs: builder.proxy_tool_defs,
proxy_executor: builder.proxy_executor,
mcp_registry: builder.mcp_registry,
pending_input_rx: builder.pending_input_rx,
pending_inputs: Vec::new(),
}
}
/// Returns another sender handle for the channel this agent emits its events on.
pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
    mpsc::Sender::clone(&self.event_tx)
}
/// Installs `rx` as the agent's ask channel, discarding any receiver set earlier.
pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
    // The previously installed receiver (if any) is dropped here.
    let _ = self.ask_rx.replace(rx);
}
pub fn set_proxy_executor(&mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) {
self.proxy_executor = Some(executor);
self.proxy_tool_defs = tool_defs;
}
/// Attaches `token` as the cancellation token checked by `run`, discarding
/// any token set earlier.
pub fn set_cancel_token(&mut self, token: CancellationToken) {
    // The previously attached token (if any) is dropped here.
    let _ = self.cancel_token.replace(token);
}
/// Switches the agent to `mode` and logs the transition.
///
/// The new value is written into the shared atomic, so every holder of the
/// handle returned by `approve_mode_shared` observes the change.
pub fn set_approve_mode(&mut self, mode: ApproveMode) {
    // A single `swap` both installs the new mode and yields the value it
    // replaced. A separate load followed by a store could log a stale
    // "old" mode if another holder of the shared atomic wrote in between.
    let previous_raw = self.approve_mode.swap(mode.to_u8(), Ordering::Relaxed);
    let old = ApproveMode::from_u8(previous_raw);
    log::info!("Agent approve mode changed: {} -> {}", old, mode);
}
/// Returns a shared handle to the atomic that stores the current approve
/// mode (encoded as `u8`).
pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
    Arc::clone(&self.approve_mode)
}
/// Makes the agent use `shared` as its approve-mode storage, releasing its
/// reference to the atomic it held before.
pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
    // The agent's reference to the previous atomic is dropped here.
    let _ = std::mem::replace(&mut self.approve_mode, shared);
}
pub fn update_memory_summary(&mut self, summary: Option<String>) {
self.memory_summary = summary;
self.system_prompt = prompt::build_system_prompt(
&self.profile,
&self.skills,
self.project_overview.as_deref(),
self.memory_summary.as_deref(),
);
}
pub fn refresh_codegraph_tools(&mut self) {
if let Some(path) = &self.project_path {
let should_have_codegraph = crate::tools::codegraph::should_inject_codegraph_tools(path);
let has_codegraph = self.tools.iter().any(|t| {
let name = t.definition().name;
name.starts_with("code_") && name != "code_review"
});
if should_have_codegraph != has_codegraph {
if should_have_codegraph {
let codegraph_tools = crate::tools::codegraph::codegraph_tools(path);
for tool in codegraph_tools {
self.tools.push(Arc::from(tool));
}
self.system_prompt = prompt::build_system_prompt_with_workflows(
&self.profile,
&self.skills,
self.project_overview.as_deref(),
self.memory_summary.as_deref(),
Some(path),
None, );
} else {
self.tools.retain(|t| {
let name = t.definition().name;
!name.starts_with("code_") || name == "code_review"
});
self.system_prompt = prompt::build_system_prompt_with_workflows(
&self.profile,
&self.skills,
self.project_overview.as_deref(),
self.memory_summary.as_deref(),
Some(path),
None, );
}
}
}
}
/// Runs one agent session for `user_input`.
///
/// The input is appended to the conversation as a user message, then the
/// model is called in a loop until it stops, the cancellation token fires,
/// or `MAX_ITERATIONS` is reached. Each iteration:
///
/// 1. merges any pending (appended) user inputs into the conversation,
/// 2. checks for cancellation and emits the iteration warning if due,
/// 3. pre-compresses the conversation when the estimate exceeds the threshold,
/// 4. sends the request and hands the response to `process_response`,
/// 5. resumes the loop if the model stopped but appended inputs or unfinished
///    todos remain,
/// 6. reports the context size and compresses again if needed.
///
/// Progress, errors, context size, compression results and the final usage
/// are reported through `self.emit`. The returned vector is always empty.
///
/// # Errors
///
/// Propagates errors from `self.emit`, `self.call_streaming` and
/// `self.process_response`. Compression failures are reported as progress
/// events and do not abort the run.
pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
    self.emit(AgentEvent::session_started())?;
    // `user_input` is owned and not used again, so it is moved rather than cloned.
    self.messages.push(Message {
        role: Role::User,
        content: MessageContent::Text(user_input),
    });
    let mut iterations = 0;
    let mut should_continue = true;
    const ITERATION_WARNING_THRESHOLD: usize = MAX_ITERATIONS - 10;
    while should_continue && iterations < MAX_ITERATIONS {
        iterations += 1;

        // Fold inputs the user appended while the agent was busy into a
        // single user message before the next request.
        if self.has_pending_inputs() {
            let pending = self.take_pending_inputs();
            let merged = pending.join("\n\n---\n\n");
            log::info!("Adding {} pending input messages to request", pending.len());
            self.emit(AgentEvent::progress(
                format!("📝 收到 {} 条追加消息", pending.len()),
                None,
            ))?;
            self.messages.push(Message {
                role: Role::User,
                content: MessageContent::Text(merged),
            });
        }

        // Stop before issuing another request once cancellation was requested.
        if let Some(token) = &self.cancel_token
            && token.is_cancelled()
        {
            self.emit(AgentEvent::error(
                prompt::MSG_OPERATION_CANCELLED.to_string(),
                None,
                None,
            ))?;
            break;
        }

        // Warn exactly once, ten iterations before the hard limit.
        if iterations == ITERATION_WARNING_THRESHOLD {
            self.emit(AgentEvent::progress(
                prompt::MSG_ITERATION_WARNING_UI
                    .replace("{iterations}", &iterations.to_string())
                    .replace("{max_iterations}", &MAX_ITERATIONS.to_string()),
                None,
            ))?;
        }

        // Pre-request compression, based purely on the local token estimate.
        let context_size = self.provider.context_size();
        let estimated_tokens = estimate_total_tokens(&self.messages);
        if should_compress(estimated_tokens, context_size, &self.compression_config) {
            self.emit(AgentEvent::progress("⚠️ 上下文过大,正在预压缩...", None))?;
            match compress_messages(
                &self.messages,
                CompressionStrategy::SlidingWindow,
                &self.compression_config,
            ) {
                Ok(compressed) => {
                    let compressed_tokens = estimate_total_tokens(&compressed);
                    self.messages = compressed;
                    crate::debug::debug_log().compression(
                        estimated_tokens,
                        compressed_tokens,
                        compressed_tokens as f32 / estimated_tokens as f32,
                    );
                }
                Err(e) => {
                    // Best effort: report the failure and send the request uncompressed.
                    self.emit(AgentEvent::progress(
                        format!("预压缩失败: {}", e),
                        None,
                    ))?;
                }
            }
        }

        // Tool definitions sent to the model: the agent's own tools first,
        // then the proxy tools, each with its LLM-facing description.
        let tool_defs: Vec<ToolDefinition> = {
            let mut defs: Vec<ToolDefinition> = self.tools.iter().map(|t| {
                let def = t.definition();
                let description = def.description_for_llm();
                ToolDefinition {
                    name: def.name,
                    description,
                    parameters: def.parameters,
                    is_priority: def.is_priority,
                }
            }).collect();
            defs.extend(self.proxy_tool_defs.iter().map(|t| {
                let def = &t.definition;
                let description = def.description_for_llm();
                ToolDefinition {
                    name: def.name.clone(),
                    description,
                    parameters: def.parameters.clone(),
                    is_priority: def.is_priority,
                }
            }));
            defs
        };
        let request = ChatRequest {
            system: Some(self.system_prompt.clone()),
            messages: self.messages.clone(),
            max_tokens: self.max_tokens,
            tools: tool_defs,
            think: self.think,
            enable_caching: true,
            server_tools: Vec::new(),
        };
        let response = self.call_streaming(&request).await?;
        self.track_usage(&response.usage);
        crate::debug::debug_log().api_call(
            &self.model_name,
            response.usage.input_tokens,
            response.usage.cache_read_input_tokens > 0,
        );
        should_continue = self.process_response(&response).await?;

        // The model stopped. Resume only while at least two iterations
        // remain in the budget, and only if there is still work to do.
        if !should_continue && iterations < MAX_ITERATIONS - 1 {
            if self.has_pending_inputs() {
                // Inputs appended during the request take priority.
                let pending = self.take_pending_inputs();
                let merged = pending.join("\n\n---\n\n");
                log::info!("Model stopped but user appended {} messages, continuing", pending.len());
                self.emit(AgentEvent::progress(
                    format!("📝 处理 {} 条追加消息", pending.len()),
                    None,
                ))?;
                self.messages.push(Message {
                    role: Role::User,
                    content: MessageContent::Text(merged),
                });
                should_continue = true;
            } else {
                // Otherwise remind the model about todos that are not yet done.
                let pending = self.get_pending_todos();
                if !pending.is_empty() {
                    let pending_list = pending.iter()
                        .map(|(status, content)| {
                            let marker = match status.as_str() {
                                "in_progress" => "[~]",
                                "pending" => "[ ]",
                                _ => "[?]"
                            };
                            format!(" {} {}", marker, content)
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    let reminder = format!(
                        "📋 任务尚未完成。以下待办项需要处理:\n{}\n\n请继续执行,或在 todo_write 中标记为 completed。如遇阻塞请说明原因。",
                        pending_list
                    );
                    self.messages.push(Message {
                        role: Role::User,
                        content: MessageContent::Text(reminder),
                    });
                    should_continue = true;
                }
            }
        }

        // Post-request context accounting. The stored last-input-token count
        // is used when it is non-zero and at least half of the local
        // estimate; otherwise the estimate is used.
        let context_size = self.provider.context_size();
        let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
        let estimated_tokens = estimate_total_tokens(&self.messages);
        let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
            api_tokens
        } else {
            estimated_tokens
        };
        if let Some(ctx_size) = context_size {
            self.emit(AgentEvent::with_data(
                EventType::ContextSize,
                EventData::ContextSize {
                    context_size: ctx_size as u64,
                },
            ))?;
            let usage_ratio = current_tokens as f64 / ctx_size as f64;
            // Only log once the context is at least 30% full.
            if usage_ratio >= 0.3 {
                crate::debug::debug_log().log(
                    "checkcompress",
                    &format!(
                        "usage={:.1}%, tokens={}, context={}, threshold={}%",
                        usage_ratio * 100.0,
                        current_tokens,
                        ctx_size,
                        self.compression_config.threshold * 100.0
                    ),
                );
            }
        }
        if should_compress(current_tokens, context_size, &self.compression_config) {
            self.emit(AgentEvent::progress(prompt::MSG_COMPRESSING_CONTEXT, None))?;
            let original_tokens = current_tokens;
            match compress_messages(
                &self.messages,
                CompressionStrategy::SlidingWindow,
                &self.compression_config,
            ) {
                Ok(compressed) => {
                    let compressed_tokens = estimate_total_tokens(&compressed);
                    self.messages = compressed;
                    // NOTE(review): this overwrites `total_input_tokens` (presumably a
                    // running total) with the compressed context size — confirm intended.
                    self.total_input_tokens
                        .store(compressed_tokens as u64, Ordering::Relaxed);
                    self.last_input_tokens
                        .store(compressed_tokens as u64, Ordering::Relaxed);
                    // Computed once and shared by the debug log and the event.
                    let ratio = compressed_tokens as f32 / original_tokens as f32;
                    crate::debug::debug_log().compression(
                        original_tokens,
                        compressed_tokens,
                        ratio,
                    );
                    self.emit(AgentEvent::with_data(
                        EventType::CompressionCompleted,
                        EventData::Compression {
                            original_tokens: original_tokens as u64,
                            compressed_tokens: compressed_tokens as u64,
                            ratio,
                        },
                    ))?;
                }
                Err(e) => {
                    // Best effort: report the failure and keep the uncompressed history.
                    self.emit(AgentEvent::progress(
                        format!("{}{}", prompt::MSG_COMPRESSION_FAILED, e),
                        None,
                    ))?;
                }
            }
        }
    }

    // The loop ended because the budget ran out while work was still pending.
    if iterations >= MAX_ITERATIONS && should_continue {
        self.emit(AgentEvent::error(
            prompt::MSG_MAX_ITERATIONS_REACHED
                .replace("{max_iterations}", &MAX_ITERATIONS.to_string())
                .replace("{iterations}", &iterations.to_string()),
            Some("MAX_ITERATIONS_REACHED".to_string()),
            Some("agent/run.rs".to_string()),
        ))?;
    }
    self.emit(AgentEvent::usage_with_cache(
        self.total_input_tokens.load(Ordering::Relaxed),
        self.total_output_tokens.load(Ordering::Relaxed),
        0,
        0,
    ))?;
    self.emit(AgentEvent::session_ended())?;
    Ok(Vec::new())
}
pub fn set_messages(&mut self, messages: Vec<Message>) {
self.messages = messages;
}
/// Returns the current conversation history as a slice.
pub fn get_messages(&self) -> &[Message] {
    self.messages.as_slice()
}
/// Returns the agent's own tools as a slice (proxy tool definitions are
/// stored separately and are not included).
pub fn get_tools(&self) -> &[Arc<dyn Tool>] {
    self.tools.as_slice()
}
/// Returns the system prompt currently in use.
pub fn get_system_prompt(&self) -> &str {
    self.system_prompt.as_str()
}
/// Returns the stored token counters as `(total_input, total_output)`.
pub fn get_token_counts(&self) -> (u64, u64) {
    let input = self.total_input_tokens.load(Ordering::Relaxed);
    let output = self.total_output_tokens.load(Ordering::Relaxed);
    (input, output)
}
/// Empties the conversation history and resets all three token counters to zero.
pub fn clear_history(&mut self) {
    self.messages.clear();
    let counters = [
        &self.total_input_tokens,
        &self.total_output_tokens,
        &self.last_input_tokens,
    ];
    for counter in counters {
        counter.store(0, Ordering::Relaxed);
    }
}
pub fn message_count(&self) -> usize {
self.messages.len()
}
/// Registers an MCP server under `name` in the agent's MCP registry.
///
/// When the agent has no registry, a warning is logged and the call still
/// returns `Ok(())`; as written, this method never returns an error.
pub async fn add_mcp_server(&mut self, name: &str, config: crate::mcp::McpServerConfig) -> Result<()> {
    let Some(registry) = self.mcp_registry.as_ref() else {
        log::warn!("MCP registry not initialized, cannot add server '{}'", name);
        return Ok(());
    };

    {
        // Write access is held only while the server is inserted and logged.
        let mut reg = registry.write().await;
        reg.add_server(name.to_owned(), config);
        log::info!("MCP server '{}' added to registry", name);
    }
    Ok(())
}
/// Removes the MCP server registered under `name` from the agent's MCP registry.
///
/// When the agent has no registry, a warning is logged and the call returns
/// `Ok(())`, mirroring `add_mcp_server`.
///
/// # Errors
///
/// Propagates any error returned by the registry's `remove_server`.
pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
    if let Some(registry) = &self.mcp_registry {
        let mut reg = registry.write().await;
        reg.remove_server(name).await?;
        log::info!("MCP server '{}' removed from registry", name);
    } else {
        // Previously a silent no-op; surface it the same way `add_mcp_server` does.
        log::warn!("MCP registry not initialized, cannot remove server '{}'", name);
    }
    Ok(())
}
/// Returns a copy of the status of every server known to the MCP registry,
/// or an empty vector when the agent has no registry.
pub async fn mcp_server_status(&self) -> Vec<crate::mcp::ServerStatus> {
    let Some(registry) = self.mcp_registry.as_ref() else {
        return Vec::new();
    };

    let reg = registry.read().await;
    let statuses = reg.server_status().await;
    statuses.values().cloned().collect()
}
/// Starts the MCP server registered under `name` and returns the tools it provides.
///
/// # Errors
///
/// Returns an error when the agent has no MCP registry, when no server is
/// registered under `name`, or when starting the server fails.
pub async fn start_mcp_server(&self, name: &str) -> Result<Vec<Arc<crate::mcp::McpToolWrapper>>> {
    let registry = self
        .mcp_registry
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("MCP registry not initialized"))?;

    // NOTE(review): read access to the registry stays held while the server
    // starts — confirm that blocking registry writers for that long is acceptable.
    let reg = registry.read().await;
    let placeholder = reg
        .get_server(name)
        .ok_or_else(|| anyhow::anyhow!("MCP server '{}' not found in registry", name))?;

    let tools = placeholder.start().await?;
    log::info!("MCP server '{}' started with {} tools", name, tools.len());
    Ok(tools)
}
}