1use anyhow::Result;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU8, Ordering};
6use tokio::sync::mpsc;
7
8use crate::approval::ApproveMode;
9use crate::cancel::CancellationToken;
10use crate::compress::{
11 CompressionStrategy, compress_messages, estimate_total_tokens, should_compress,
12};
13use crate::event::{AgentEvent, EventData, EventType};
14use crate::prompt;
15use crate::providers::{ChatRequest, Message, MessageContent, Role};
16use crate::tools::ToolDefinition;
17
18use super::types::{Agent, AgentBuilder, MAX_ITERATIONS};
19
20impl Agent {
21 pub(crate) fn new(builder: AgentBuilder) -> Self {
22 let event_tx = builder.event_tx.unwrap_or_else(|| {
23 let (tx, _) = mpsc::channel(100);
24 tx
25 });
26
27 Self {
28 provider: builder.provider,
29 model_name: builder.model_name,
30 tools: builder.tools,
31 messages: Vec::new(),
32 system_prompt: builder.system_prompt,
33 max_tokens: builder.max_tokens,
34 think: builder.think,
35 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
36 event_tx,
37 skills: builder.skills,
38 profile: builder.profile,
39 project_overview: builder.project_overview,
40 memory_summary: builder.memory_summary,
41 total_input_tokens: std::sync::atomic::AtomicU64::new(0),
42 total_output_tokens: std::sync::atomic::AtomicU64::new(0),
43 last_input_tokens: std::sync::atomic::AtomicU64::new(0),
44 cancel_token: None,
45 compression_config: crate::compress::CompressionConfig::default(),
46 ask_rx: None,
47 }
48 }
49
50 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
52 self.event_tx.clone()
53 }
54
55 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
57 self.ask_rx = Some(rx);
58 }
59
60 pub fn set_cancel_token(&mut self, token: CancellationToken) {
62 self.cancel_token = Some(token);
63 }
64
65 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
67 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
68 log::info!("Agent approve mode changed: {} -> {}", old, mode);
69 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
70 }
71
72 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
74 self.approve_mode.clone()
75 }
76
77 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
79 self.approve_mode = shared;
80 }
81
82 pub fn update_memory_summary(&mut self, summary: Option<String>) {
84 self.memory_summary = summary;
85 self.system_prompt = prompt::build_system_prompt(
86 &self.profile,
87 &self.skills,
88 self.project_overview.as_deref(),
89 self.memory_summary.as_deref(),
90 );
91 }
92
93 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
95 self.emit(AgentEvent::session_started())?;
96
97 self.messages.push(Message {
98 role: Role::User,
99 content: MessageContent::Text(user_input.clone()),
100 });
101
102 let mut iterations = 0;
103 let mut should_continue = true;
104
105 while should_continue && iterations < MAX_ITERATIONS {
106 iterations += 1;
107
108 if let Some(token) = &self.cancel_token
109 && token.is_cancelled()
110 {
111 self.emit(AgentEvent::error(
112 "Operation cancelled".to_string(),
113 None,
114 None,
115 ))?;
116 break;
117 }
118
119 let tool_defs: Vec<ToolDefinition> =
120 self.tools.iter().map(|t| t.definition()).collect();
121 let request = ChatRequest {
122 system: Some(self.system_prompt.clone()),
123 messages: self.messages.clone(),
124 max_tokens: self.max_tokens,
125 tools: tool_defs,
126 think: self.think,
127 enable_caching: true,
128 server_tools: Vec::new(),
129 };
130
131 let response = self.call_streaming(&request).await?;
132
133 self.track_usage(&response.usage);
134
135 crate::debug::debug_log().api_call(
136 &self.model_name,
137 response.usage.input_tokens,
138 response.usage.cache_read_input_tokens > 0,
139 );
140
141 should_continue = self.process_response(&response).await?;
142
143 let context_size = self.provider.context_size();
144 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
145 let estimated_tokens = estimate_total_tokens(&self.messages);
146
147 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
148 api_tokens
149 } else {
150 estimated_tokens
151 };
152
153 crate::debug::debug_log().log(
154 "compression",
155 &format!(
156 "check: api={}, estimated={}, using={}, context={}, threshold={}",
157 api_tokens,
158 estimated_tokens,
159 current_tokens,
160 context_size.unwrap_or(0),
161 self.compression_config.threshold
162 ),
163 );
164
165 if should_compress(current_tokens, context_size, &self.compression_config) {
166 self.emit(AgentEvent::progress("Compressing context...", None))?;
167
168 let original_tokens = current_tokens;
169
170 match compress_messages(
171 &self.messages,
172 CompressionStrategy::SlidingWindow,
173 &self.compression_config,
174 ) {
175 Ok(compressed) => {
176 let compressed_tokens = estimate_total_tokens(&compressed);
177 self.messages = compressed;
178 self.total_input_tokens
179 .store(compressed_tokens as u64, Ordering::Relaxed);
180 self.last_input_tokens
181 .store(compressed_tokens as u64, Ordering::Relaxed);
182
183 let ratio = compressed_tokens as f32 / original_tokens as f32;
184 crate::debug::debug_log().compression(
185 original_tokens,
186 compressed_tokens,
187 ratio,
188 );
189
190 self.emit(AgentEvent::with_data(
191 EventType::CompressionCompleted,
192 EventData::Compression {
193 original_tokens: original_tokens as u64,
194 compressed_tokens: compressed_tokens as u64,
195 ratio: compressed_tokens as f32 / original_tokens as f32,
196 },
197 ))?;
198 }
199 Err(e) => {
200 self.emit(AgentEvent::progress(
201 format!("Compression failed: {}", e),
202 None,
203 ))?;
204 }
205 }
206 }
207 }
208
209 self.emit(AgentEvent::usage_with_cache(
210 self.total_input_tokens.load(Ordering::Relaxed),
211 self.total_output_tokens.load(Ordering::Relaxed),
212 0,
213 0,
214 ))?;
215
216 self.emit(AgentEvent::session_ended())?;
217
218 Ok(Vec::new())
219 }
220
221 pub fn set_messages(&mut self, messages: Vec<Message>) {
223 self.messages = messages;
224 }
225
226 pub fn get_messages(&self) -> &[Message] {
228 &self.messages
229 }
230
231 pub fn get_token_counts(&self) -> (u64, u64) {
233 (
234 self.total_input_tokens.load(Ordering::Relaxed),
235 self.total_output_tokens.load(Ordering::Relaxed),
236 )
237 }
238
239 pub fn clear_history(&mut self) {
241 self.messages.clear();
242 self.total_input_tokens.store(0, Ordering::Relaxed);
243 self.total_output_tokens.store(0, Ordering::Relaxed);
244 self.last_input_tokens.store(0, Ordering::Relaxed);
245 }
246
247 pub fn message_count(&self) -> usize {
249 self.messages.len()
250 }
251}