1use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use anyhow::Result;
8use tokio::sync::mpsc;
9
10use crate::event::{AgentEvent, EventType, EventData};
11use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason, Usage};
12use crate::tools::{Tool, ToolDefinition};
13use crate::approval::{ApproveMode, needs_approval};
14use crate::compress::{CompressionConfig, should_compress};
15use crate::cancel::CancellationToken;
16
17const MAX_ITERATIONS: usize = 50;
18
19#[allow(dead_code)] pub struct Agent {
22 provider: Box<dyn Provider>,
23 model_name: String, tools: Vec<Arc<dyn Tool>>,
25 messages: Vec<Message>,
26 system_prompt: String,
27 max_tokens: u32,
28 think: bool,
29 approve_mode: ApproveMode,
30 event_tx: mpsc::Sender<AgentEvent>,
31
32 skills: Vec<crate::skills::Skill>,
34 profile: crate::prompt::PromptProfile,
35 project_overview: Option<String>,
36 memory_summary: Option<String>,
37
38 total_input_tokens: AtomicU64,
40 total_output_tokens: AtomicU64,
41 last_input_tokens: AtomicU64,
43 cancel_token: Option<CancellationToken>,
44 compression_config: CompressionConfig,
45
46 ask_rx: Option<mpsc::Receiver<String>>,
48}
49
50pub struct AgentBuilder {
52 provider: Box<dyn Provider>,
53 model_name: String,
54 tools: Vec<Arc<dyn Tool>>,
55 system_prompt: String,
56 max_tokens: u32,
57 think: bool,
58 approve_mode: ApproveMode,
59 event_tx: Option<mpsc::Sender<AgentEvent>>,
60 skills: Vec<crate::skills::Skill>,
62 profile: crate::prompt::PromptProfile,
63 project_overview: Option<String>,
64 memory_summary: Option<String>,
65}
66
67impl AgentBuilder {
68 pub fn new(provider: Box<dyn Provider>) -> Self {
69 Self {
70 provider,
71 model_name: "unknown".to_string(),
72 tools: Vec::new(),
73 system_prompt: "You are a helpful AI coding assistant.".to_string(),
74 max_tokens: 4096,
75 think: false,
76 approve_mode: ApproveMode::Ask,
77 event_tx: None,
78 skills: Vec::new(),
79 profile: crate::prompt::PromptProfile::Default,
80 project_overview: None,
81 memory_summary: None,
82 }
83 }
84
85 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.system_prompt = prompt.into();
87 self
88 }
89
90 pub fn model_name(mut self, name: impl Into<String>) -> Self {
91 self.model_name = name.into();
92 self
93 }
94
95 pub fn max_tokens(mut self, tokens: u32) -> Self {
96 self.max_tokens = tokens;
97 self
98 }
99
100 pub fn think(mut self, enabled: bool) -> Self {
101 self.think = enabled;
102 self
103 }
104
105 pub fn approve_mode(mut self, mode: ApproveMode) -> Self {
106 self.approve_mode = mode;
107 self
108 }
109
110 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
111 self.tools.push(tool);
112 self
113 }
114
115 pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
117 self.tools.extend(tools.into_iter().map(Arc::from));
118 self
119 }
120
121 pub fn event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
123 self.event_tx = Some(tx);
124 self
125 }
126
127 pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
129 self.skills = skills;
130 self
131 }
132
133 pub fn profile(mut self, profile: crate::prompt::PromptProfile) -> Self {
135 self.profile = profile;
136 self
137 }
138
139 pub fn overview(mut self, overview: impl Into<String>) -> Self {
141 self.project_overview = Some(overview.into());
142 self
143 }
144
145 pub fn memory(mut self, summary: impl Into<String>) -> Self {
147 self.memory_summary = Some(summary.into());
148 self
149 }
150
151 pub fn build(self) -> Agent {
152 Agent::new(self)
153 }
154}
155
156impl Agent {
157 fn new(builder: AgentBuilder) -> Self {
158 let event_tx = builder.event_tx.unwrap_or_else(|| {
160 let (tx, _) = mpsc::channel(100);
161 tx
162 });
163
164 Self {
165 provider: builder.provider,
166 model_name: builder.model_name,
167 tools: builder.tools,
168 messages: Vec::new(),
169 system_prompt: builder.system_prompt,
170 max_tokens: builder.max_tokens,
171 think: builder.think,
172 approve_mode: builder.approve_mode,
173 event_tx,
174 skills: builder.skills,
175 profile: builder.profile,
176 project_overview: builder.project_overview,
177 memory_summary: builder.memory_summary,
178 total_input_tokens: AtomicU64::new(0),
179 total_output_tokens: AtomicU64::new(0),
180 last_input_tokens: AtomicU64::new(0),
181 cancel_token: None,
182 compression_config: CompressionConfig::default(),
183 ask_rx: None,
184 }
185 }
186
187 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
189 self.event_tx.clone()
190 }
191
192 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
194 self.ask_rx = Some(rx);
195 }
196
197 pub fn set_cancel_token(&mut self, token: CancellationToken) {
199 self.cancel_token = Some(token);
200 }
201
202 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
204 log::info!("Agent approve mode changed: {} -> {}", self.approve_mode, mode);
205 self.approve_mode = mode;
206 }
207
208 pub fn update_memory_summary(&mut self, summary: Option<String>) {
211 self.memory_summary = summary;
212 self.system_prompt = crate::prompt::build_system_prompt(
214 &self.profile,
215 &self.skills,
216 self.project_overview.as_deref(),
217 self.memory_summary.as_deref(),
218 );
219 }
220
221 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
223 let collector = EventCollector::new();
224
225 self.emit(AgentEvent::session_started())?;
227
228 self.messages.push(Message {
230 role: Role::User,
231 content: MessageContent::Text(user_input.clone()),
232 });
233
234 let mut iterations = 0;
236 let mut should_continue = true;
237
238 while should_continue && iterations < MAX_ITERATIONS {
239 iterations += 1;
240
241 if let Some(token) = &self.cancel_token
243 && token.is_cancelled()
244 {
245 self.emit(AgentEvent::error("Operation cancelled".to_string(), None, None))?;
246 break;
247 }
248
249 let tool_defs: Vec<ToolDefinition> = self.tools.iter().map(|t| t.definition()).collect();
251 let request = ChatRequest {
252 system: Some(self.system_prompt.clone()),
253 messages: self.messages.clone(),
254 max_tokens: self.max_tokens,
255 tools: tool_defs,
256 think: self.think,
257 enable_caching: true,
258 server_tools: Vec::new(),
259 };
260
261 self.emit(AgentEvent::progress(
263 if iterations == 1 { "Thinking..." } else { "Processing..." },
264 None,
265 ))?;
266
267 let response = self.call_streaming(&request).await?;
269
270 self.track_usage(&response.usage);
272
273 crate::debug::debug_log().api_call(
275 &self.model_name,
276 response.usage.input_tokens,
277 response.usage.cache_read_input_tokens > 0
278 );
279
280 should_continue = self.process_response(&response).await?;
282
283 let context_size = self.provider.context_size();
285 let current_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
286 if should_compress(current_tokens, context_size, &self.compression_config) {
287 self.emit(AgentEvent::progress("Compressing context...", None))?;
288
289 let _original_count = self.messages.len();
290 let original_tokens = current_tokens;
291
292 match crate::compress::compress_messages(
294 &self.messages,
295 crate::compress::CompressionStrategy::SlidingWindow,
296 &self.compression_config,
297 ) {
298 Ok(compressed) => {
299 let compressed_tokens = crate::compress::estimate_total_tokens(&compressed);
300 self.messages = compressed;
301 self.total_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
302 self.last_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
303
304 let ratio = compressed_tokens as f32 / original_tokens as f32;
306 crate::debug::debug_log().compression(original_tokens, compressed_tokens, ratio);
307
308 self.emit(AgentEvent::with_data(
309 crate::event::EventType::CompressionCompleted,
310 crate::event::EventData::Compression {
311 original_tokens: original_tokens as u64,
312 compressed_tokens: compressed_tokens as u64,
313 ratio: compressed_tokens as f32 / original_tokens as f32,
314 },
315 ))?;
316 }
317 Err(e) => {
318 self.emit(AgentEvent::progress(
319 format!("Compression failed: {}", e),
320 None,
321 ))?;
322 }
323 }
324 }
325 }
326
327 self.emit(AgentEvent::usage_with_cache(
329 self.last_input_tokens.load(Ordering::Relaxed),
330 self.total_output_tokens.load(Ordering::Relaxed),
331 0, 0, ))?;
333
334 self.emit(AgentEvent::session_ended())?;
336
337 Ok(collector.events().to_vec())
338 }
339
340 async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
342 use crate::providers::StreamEvent;
343
344 const MAX_RETRIES: u32 = 5;
345 const RETRY_DELAY_MS: u64 = 1000; let mut attempt = 0;
348
349 loop {
350 attempt += 1;
351
352 let rx_result = self.provider.chat_stream(request.clone()).await;
354
355 match rx_result {
356 Ok(mut rx) => {
357 let mut response_content: Vec<ContentBlock> = Vec::new();
359 let mut current_text = String::new();
360 let mut current_thinking = String::new();
361 let mut usage = Usage {
362 input_tokens: 0,
363 output_tokens: 0,
364 cache_creation_input_tokens: 0,
365 cache_read_input_tokens: 0,
366 };
367
368 while let Some(event) = rx.recv().await {
369 match event {
370 StreamEvent::FirstByte => {
371 }
373 StreamEvent::ThinkingDelta(delta) => {
374 if current_thinking.is_empty() {
375 self.emit(AgentEvent::thinking_start())?;
376 }
377 current_thinking.push_str(&delta);
378 self.emit(AgentEvent::thinking_delta(delta, None))?;
379 }
380 StreamEvent::TextDelta(delta) => {
381 if current_text.is_empty() {
382 self.emit(AgentEvent::text_start())?;
383 }
384 current_text.push_str(&delta);
385 self.emit(AgentEvent::text_delta(delta))?;
386 }
387 StreamEvent::ToolUseStart { id, name } => {
388 if !current_text.is_empty() {
390 self.emit(AgentEvent::text_end())?;
391 response_content.push(ContentBlock::Text { text: current_text.clone() });
392 current_text.clear();
393 }
394 if !current_thinking.is_empty() {
396 self.emit(AgentEvent::thinking_end())?;
397 response_content.push(ContentBlock::Thinking {
398 thinking: current_thinking.clone(),
399 signature: None,
400 });
401 current_thinking.clear();
402 }
403 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
404 }
405 StreamEvent::ToolInputDelta { bytes_so_far: _ } => {
406 }
408 StreamEvent::Done(resp) => {
409 if !current_text.is_empty() {
411 self.emit(AgentEvent::text_end())?;
412 response_content.push(ContentBlock::Text { text: current_text.clone() });
413 }
414 if !current_thinking.is_empty() {
416 self.emit(AgentEvent::thinking_end())?;
417 response_content.push(ContentBlock::Thinking {
418 thinking: current_thinking.clone(),
419 signature: None,
420 });
421 }
422 for block in &resp.content {
424 if !response_content.iter().any(|b| b == block) {
425 response_content.push(block.clone());
426 }
427 }
428 usage = resp.usage;
429 }
430 StreamEvent::Error(msg) => {
431 if attempt < MAX_RETRIES {
433 self.emit(AgentEvent::progress(
434 format!("⚠️ Stream error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &msg),
435 None,
436 ))?;
437 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
439 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
440 continue; } else {
442 self.emit(AgentEvent::error(msg.clone(), None, None))?;
443 return Err(anyhow::anyhow!("Stream error after {} retries: {}", MAX_RETRIES, msg));
444 }
445 }
446 }
447 }
448
449 return Ok(ChatResponse {
450 content: response_content,
451 stop_reason: StopReason::EndTurn,
452 usage,
453 });
454 }
455 Err(e) => {
456 if attempt < MAX_RETRIES {
458 let error_msg = e.to_string();
459 self.emit(AgentEvent::progress(
460 format!("⚠️ API error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &error_msg),
461 None,
462 ))?;
463 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
465 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
466 } else {
467 return Err(anyhow::anyhow!("API error after {} retries: {}", MAX_RETRIES, e));
468 }
469 }
470 }
471 }
472 }
473
474 async fn process_response(&mut self, response: &ChatResponse) -> Result<bool> {
476 let mut has_tool_use = false;
477 let mut assistant_content: Vec<ContentBlock> = Vec::new();
478 let mut tool_results: Vec<Message> = Vec::new();
479
480 for block in &response.content {
481 match block {
482 ContentBlock::Text { text } => {
484 assistant_content.push(ContentBlock::Text { text: text.clone() });
485 }
486
487 ContentBlock::Thinking { thinking, signature } => {
488 assistant_content.push(ContentBlock::Thinking {
489 thinking: thinking.clone(),
490 signature: signature.clone(),
491 });
492 }
493
494 ContentBlock::ToolUse { id, name, input } => {
495 has_tool_use = true;
496
497 self.emit(AgentEvent::tool_use_start(id.clone(), name.clone(), Some(input.clone())))?;
498
499 let result = self.execute_tool(name, input.clone()).await;
501
502 let (content, is_error) = match result {
503 Ok(output) => (output, false),
504 Err(e) => (e.to_string(), true),
505 };
506
507 self.emit(AgentEvent::tool_result(id.clone(), content.clone(), is_error))?;
508
509 assistant_content.push(ContentBlock::ToolUse {
511 id: id.clone(),
512 name: name.clone(),
513 input: input.clone(),
514 });
515
516 tool_results.push(Message {
518 role: Role::User,
519 content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
520 tool_use_id: id.clone(),
521 content: format!("{}: {}", if is_error { "Error" } else { "Result" }, content),
522 }]),
523 });
524 }
525
526 _ => {}
527 }
528 }
529
530 if !assistant_content.is_empty() {
532 self.messages.push(Message {
533 role: Role::Assistant,
534 content: MessageContent::Blocks(assistant_content),
535 });
536 }
537
538 for msg in tool_results {
540 self.messages.push(msg);
541 }
542
543 Ok(has_tool_use)
545 }
546
547 async fn execute_tool(&mut self, name: &str, input: serde_json::Value) -> Result<String> {
549 let tool = self.tools.iter().find(|t| t.definition().name == name);
550
551 if let Some(tool) = tool {
552 log::debug!(
554 "Tool '{}' approval check: mode={}, risk={}, needs_approval={}",
555 name, self.approve_mode, tool.risk_level(),
556 needs_approval(self.approve_mode, tool.risk_level())
557 );
558
559 if needs_approval(self.approve_mode, tool.risk_level()) {
561 if self.ask_rx.is_some() {
563 let detail = match name {
565 "bash" => format!("Command: {}", input["command"].as_str().unwrap_or("?")),
566 "write" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
567 "edit" | "multi_edit" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
568 _ => format!("Tool: {}", name),
569 };
570
571 let question = format!(
572 "⚠️ Tool '{}' requires approval (risk: {})\n{}\n\nAllow? (y/n)",
573 name, tool.risk_level(), detail
574 );
575
576 self.emit(AgentEvent::with_data(
578 EventType::AskQuestion,
579 EventData::AskQuestion { question, options: None },
580 ))?;
581
582 if let Some(rx) = &mut self.ask_rx {
584 match rx.recv().await {
585 Some(answer) => {
586 let answer_lower = answer.trim().to_lowercase();
587 if matches!(answer_lower.as_str(), "a" | "abort" | "q" | "quit" | "stop") {
589 self.emit(AgentEvent::with_data(
590 EventType::Error,
591 EventData::Error { message: "Aborted by user".into(), code: None, source: None },
592 ))?;
593 return Err(anyhow::anyhow!("Session aborted by user"));
594 }
595 let approved = matches!(
597 answer_lower.as_str(),
598 "y" | "yes" | "ok" | "approve" | ""
599 );
600 if !approved {
601 return Err(anyhow::anyhow!(
603 "Tool '{}' rejected by user (answer: '{}')", name, answer_lower
604 ));
605 }
606 }
607 None => {
608 return Err(anyhow::anyhow!("Approval channel closed"));
609 }
610 }
611 }
612 } else {
613 return Err(anyhow::anyhow!(
615 "Tool '{}' requires manual approval (risk: {}). Use --approve-mode auto to auto-approve.",
616 name, tool.risk_level()
617 ));
618 }
619 }
620
621 if name == "ask" && self.ask_rx.is_some() {
623 let question = input["question"].as_str().unwrap_or("").to_string();
624 let options = input.get("options").cloned();
625
626 self.emit(AgentEvent::with_data(
628 EventType::AskQuestion,
629 EventData::AskQuestion { question, options },
630 ))?;
631
632 if let Some(rx) = &mut self.ask_rx {
634 match rx.recv().await {
635 Some(answer) => return Ok(answer),
636 None => return Err(anyhow::anyhow!("Ask channel closed")),
637 }
638 }
639 }
640
641 self.emit(AgentEvent::progress(format!("Executing: {}", name), None))?;
643 tool.execute(input).await
644 } else {
645 Err(anyhow::anyhow!("Tool '{}' not found", name))
646 }
647 }
648
649 fn track_usage(&self, usage: &Usage) {
651 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
652 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
653 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
655
656 let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
658 usage.input_tokens as u64,
659 usage.output_tokens as u64,
660 usage.cache_read_input_tokens as u64,
661 usage.cache_creation_input_tokens as u64,
662 ));
663 }
664
665 #[allow(dead_code)]
667 fn estimate_context_size(&self) -> u32 {
668 (self.messages.len() as u32) * 100 + self.total_input_tokens.load(Ordering::Relaxed) as u32
670 }
671
672 fn emit(&self, event: AgentEvent) -> Result<()> {
674 match self.event_tx.try_send(event) {
676 Ok(_) => Ok(()),
677 Err(mpsc::error::TrySendError::Full(_)) => {
678 Ok(())
680 }
681 Err(mpsc::error::TrySendError::Closed(_)) => {
682 Err(anyhow::anyhow!("Event channel closed"))
684 }
685 }
686 }
687
688 pub fn set_messages(&mut self, messages: Vec<Message>) {
690 self.messages = messages;
691 }
692
693 pub fn get_messages(&self) -> &[Message] {
695 &self.messages
696 }
697
698 pub fn get_token_counts(&self) -> (u64, u64) {
700 (
701 self.total_input_tokens.load(Ordering::Relaxed),
702 self.total_output_tokens.load(Ordering::Relaxed),
703 )
704 }
705
706 pub fn clear_history(&mut self) {
708 self.messages.clear();
709 self.total_input_tokens.store(0, Ordering::Relaxed);
710 self.total_output_tokens.store(0, Ordering::Relaxed);
711 self.last_input_tokens.store(0, Ordering::Relaxed);
712 }
713
714 pub fn message_count(&self) -> usize {
716 self.messages.len()
717 }
718}
719
720#[derive(Default)]
722pub struct EventCollector {
723 events: Vec<AgentEvent>,
724}
725
726impl EventCollector {
727 pub fn new() -> Self {
728 Self::default()
729 }
730
731 pub fn events(&self) -> &[AgentEvent] {
732 &self.events
733 }
734}