1use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
6use std::sync::Arc;
7use anyhow::Result;
8use tokio::sync::mpsc;
9
10use crate::event::{AgentEvent, EventType, EventData};
11use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason, Usage};
12use crate::tools::{Tool, ToolDefinition};
13use crate::approval::{ApproveMode, needs_approval};
14use crate::compress::{CompressionConfig, should_compress};
15use crate::cancel::CancellationToken;
16
/// Upper bound on the number of request/response rounds a single `Agent::run` call performs.
const MAX_ITERATIONS: usize = 50;
18
19#[allow(dead_code)] pub struct Agent {
22 provider: Box<dyn Provider>,
23 model_name: String, tools: Vec<Arc<dyn Tool>>,
25 messages: Vec<Message>,
26 system_prompt: String,
27 max_tokens: u32,
28 think: bool,
29 approve_mode: Arc<AtomicU8>,
30 event_tx: mpsc::Sender<AgentEvent>,
31
32 skills: Vec<crate::skills::Skill>,
34 profile: crate::prompt::PromptProfile,
35 project_overview: Option<String>,
36 memory_summary: Option<String>,
37
38 total_input_tokens: AtomicU64,
40 total_output_tokens: AtomicU64,
41 last_input_tokens: AtomicU64,
43 cancel_token: Option<CancellationToken>,
44 compression_config: CompressionConfig,
45
46 ask_rx: Option<mpsc::Receiver<String>>,
48}
49
50pub struct AgentBuilder {
52 provider: Box<dyn Provider>,
53 model_name: String,
54 tools: Vec<Arc<dyn Tool>>,
55 system_prompt: String,
56 max_tokens: u32,
57 think: bool,
58 approve_mode: ApproveMode,
59 event_tx: Option<mpsc::Sender<AgentEvent>>,
60 skills: Vec<crate::skills::Skill>,
62 profile: crate::prompt::PromptProfile,
63 project_overview: Option<String>,
64 memory_summary: Option<String>,
65}
66
67impl AgentBuilder {
68 pub fn new(provider: Box<dyn Provider>) -> Self {
69 Self {
70 provider,
71 model_name: "unknown".to_string(),
72 tools: Vec::new(),
73 system_prompt: "You are a helpful AI coding assistant.".to_string(),
74 max_tokens: 4096,
75 think: false,
76 approve_mode: ApproveMode::Ask,
77 event_tx: None,
78 skills: Vec::new(),
79 profile: crate::prompt::PromptProfile::Default,
80 project_overview: None,
81 memory_summary: None,
82 }
83 }
84
85 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.system_prompt = prompt.into();
87 self
88 }
89
90 pub fn model_name(mut self, name: impl Into<String>) -> Self {
91 self.model_name = name.into();
92 self
93 }
94
95 pub fn max_tokens(mut self, tokens: u32) -> Self {
96 self.max_tokens = tokens;
97 self
98 }
99
100 pub fn think(mut self, enabled: bool) -> Self {
101 self.think = enabled;
102 self
103 }
104
105 pub fn approve_mode(mut self, mode: ApproveMode) -> Self {
106 self.approve_mode = mode;
107 self
108 }
109
110 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
111 self.tools.push(tool);
112 self
113 }
114
115 pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
117 self.tools.extend(tools.into_iter().map(Arc::from));
118 self
119 }
120
121 pub fn event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
123 self.event_tx = Some(tx);
124 self
125 }
126
127 pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
129 self.skills = skills;
130 self
131 }
132
133 pub fn profile(mut self, profile: crate::prompt::PromptProfile) -> Self {
135 self.profile = profile;
136 self
137 }
138
139 pub fn overview(mut self, overview: impl Into<String>) -> Self {
141 self.project_overview = Some(overview.into());
142 self
143 }
144
145 pub fn memory(mut self, summary: impl Into<String>) -> Self {
147 self.memory_summary = Some(summary.into());
148 self
149 }
150
151 pub fn build(self) -> Agent {
152 Agent::new(self)
153 }
154}
155
156impl Agent {
157 fn new(builder: AgentBuilder) -> Self {
158 let event_tx = builder.event_tx.unwrap_or_else(|| {
160 let (tx, _) = mpsc::channel(100);
161 tx
162 });
163
164 Self {
165 provider: builder.provider,
166 model_name: builder.model_name,
167 tools: builder.tools,
168 messages: Vec::new(),
169 system_prompt: builder.system_prompt,
170 max_tokens: builder.max_tokens,
171 think: builder.think,
172 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
173 event_tx,
174 skills: builder.skills,
175 profile: builder.profile,
176 project_overview: builder.project_overview,
177 memory_summary: builder.memory_summary,
178 total_input_tokens: AtomicU64::new(0),
179 total_output_tokens: AtomicU64::new(0),
180 last_input_tokens: AtomicU64::new(0),
181 cancel_token: None,
182 compression_config: CompressionConfig::default(),
183 ask_rx: None,
184 }
185 }
186
187 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
189 self.event_tx.clone()
190 }
191
192 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
194 self.ask_rx = Some(rx);
195 }
196
197 pub fn set_cancel_token(&mut self, token: CancellationToken) {
199 self.cancel_token = Some(token);
200 }
201
202 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
204 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
205 log::info!("Agent approve mode changed: {} -> {}", old, mode);
206 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
207 }
208
209 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
212 self.approve_mode.clone()
213 }
214
215 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
219 self.approve_mode = shared;
220 }
221
222 pub fn update_memory_summary(&mut self, summary: Option<String>) {
225 self.memory_summary = summary;
226 self.system_prompt = crate::prompt::build_system_prompt(
228 &self.profile,
229 &self.skills,
230 self.project_overview.as_deref(),
231 self.memory_summary.as_deref(),
232 );
233 }
234
235 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
240 self.emit(AgentEvent::session_started())?;
242
243 self.messages.push(Message {
245 role: Role::User,
246 content: MessageContent::Text(user_input.clone()),
247 });
248
249 let mut iterations = 0;
251 let mut should_continue = true;
252
253 while should_continue && iterations < MAX_ITERATIONS {
254 iterations += 1;
255
256 if let Some(token) = &self.cancel_token
258 && token.is_cancelled()
259 {
260 self.emit(AgentEvent::error("Operation cancelled".to_string(), None, None))?;
261 break;
262 }
263
264 let tool_defs: Vec<ToolDefinition> = self.tools.iter().map(|t| t.definition()).collect();
266 let request = ChatRequest {
267 system: Some(self.system_prompt.clone()),
268 messages: self.messages.clone(),
269 max_tokens: self.max_tokens,
270 tools: tool_defs,
271 think: self.think,
272 enable_caching: true,
273 server_tools: Vec::new(),
274 };
275
276 let response = self.call_streaming(&request).await?;
280
281 self.track_usage(&response.usage);
283
284 crate::debug::debug_log().api_call(
286 &self.model_name,
287 response.usage.input_tokens,
288 response.usage.cache_read_input_tokens > 0
289 );
290
291 should_continue = self.process_response(&response).await?;
293
294 let context_size = self.provider.context_size();
296
297 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
300 let estimated_tokens = crate::compress::estimate_total_tokens(&self.messages);
301
302 let current_tokens = api_tokens.max(estimated_tokens);
304
305 crate::debug::debug_log().log(
307 "compression",
308 &format!("check: api_tokens={}, estimated_tokens={}, using={}, context_size={}, threshold={}",
309 api_tokens, estimated_tokens, current_tokens, context_size.unwrap_or(0), self.compression_config.threshold)
310 );
311
312 if should_compress(current_tokens, context_size, &self.compression_config) {
313 self.emit(AgentEvent::progress("Compressing context...", None))?;
314
315 let _original_count = self.messages.len();
316 let original_tokens = current_tokens;
317
318 match crate::compress::compress_messages(
320 &self.messages,
321 crate::compress::CompressionStrategy::SlidingWindow,
322 &self.compression_config,
323 ) {
324 Ok(compressed) => {
325 let compressed_tokens = crate::compress::estimate_total_tokens(&compressed);
326 self.messages = compressed;
327 self.total_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
328 self.last_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
329
330 let ratio = compressed_tokens as f32 / original_tokens as f32;
332 crate::debug::debug_log().compression(original_tokens, compressed_tokens, ratio);
333
334 self.emit(AgentEvent::with_data(
335 crate::event::EventType::CompressionCompleted,
336 crate::event::EventData::Compression {
337 original_tokens: original_tokens as u64,
338 compressed_tokens: compressed_tokens as u64,
339 ratio: compressed_tokens as f32 / original_tokens as f32,
340 },
341 ))?;
342 }
343 Err(e) => {
344 self.emit(AgentEvent::progress(
345 format!("Compression failed: {}", e),
346 None,
347 ))?;
348 }
349 }
350 }
351 }
352
353 self.emit(AgentEvent::usage_with_cache(
355 self.last_input_tokens.load(Ordering::Relaxed),
356 self.total_output_tokens.load(Ordering::Relaxed),
357 0, 0, ))?;
359
360 self.emit(AgentEvent::session_ended())?;
362
363 Ok(Vec::new())
364 }
365
366 async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
368 use crate::providers::StreamEvent;
369
370 const MAX_RETRIES: u32 = 5;
371 const RETRY_DELAY_MS: u64 = 1000; let mut attempt = 0;
374
375 loop {
376 attempt += 1;
377
378 let rx_result = self.provider.chat_stream(request.clone()).await;
380
381 match rx_result {
382 Ok(mut rx) => {
383 let mut response_content: Vec<ContentBlock> = Vec::new();
385 let mut current_text = String::new();
386 let mut current_thinking = String::new();
387 let mut usage = Usage {
388 input_tokens: 0,
389 output_tokens: 0,
390 cache_creation_input_tokens: 0,
391 cache_read_input_tokens: 0,
392 };
393 let mut should_retry = false;
394
395 while let Some(event) = rx.recv().await {
396 match event {
397 StreamEvent::FirstByte => {
398 }
400 StreamEvent::ThinkingDelta(delta) => {
401 if current_thinking.is_empty() {
402 self.emit(AgentEvent::thinking_start())?;
403 }
404 current_thinking.push_str(&delta);
405 self.emit(AgentEvent::thinking_delta(delta, None))?;
406 }
407 StreamEvent::TextDelta(delta) => {
408 if current_text.is_empty() {
409 self.emit(AgentEvent::text_start())?;
410 }
411 current_text.push_str(&delta);
412 self.emit(AgentEvent::text_delta(delta))?;
413 }
414 StreamEvent::ToolUseStart { id, name } => {
415 if !current_thinking.is_empty() {
417 self.emit(AgentEvent::thinking_end())?;
418 response_content.push(ContentBlock::Thinking {
419 thinking: current_thinking.clone(),
420 signature: None,
421 });
422 current_thinking.clear();
423 }
424 if !current_text.is_empty() {
426 self.emit(AgentEvent::text_end())?;
427 response_content.push(ContentBlock::Text { text: current_text.clone() });
428 current_text.clear();
429 }
430 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
431 }
432 StreamEvent::ToolInputDelta { bytes_so_far: _ } => {
433 }
435 StreamEvent::Usage { output_tokens } => {
436 self.emit(AgentEvent::usage_with_cache(
438 0, output_tokens as u64,
440 0, 0 ))?;
442 usage.output_tokens = output_tokens;
443 }
444 StreamEvent::Done(resp) => {
445 if !current_thinking.is_empty() {
447 self.emit(AgentEvent::thinking_end())?;
448 response_content.push(ContentBlock::Thinking {
449 thinking: current_thinking.clone(),
450 signature: None,
451 });
452 }
453 if !current_text.is_empty() {
455 self.emit(AgentEvent::text_end())?;
456 response_content.push(ContentBlock::Text { text: current_text.clone() });
457 }
458 for block in &resp.content {
460 if !response_content.iter().any(|b| b == block) {
461 response_content.push(block.clone());
462 }
463 }
464 usage = resp.usage;
465 }
466 StreamEvent::Error(msg) => {
467 if attempt < MAX_RETRIES {
469 self.emit(AgentEvent::progress(
470 format!("⚠️ Stream error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &msg),
471 None,
472 ))?;
473 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
475 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
476 should_retry = true;
477 break; } else {
479 self.emit(AgentEvent::error(msg.clone(), None, None))?;
480 return Err(anyhow::anyhow!("Stream error after {} retries: {}", MAX_RETRIES, msg));
481 }
482 }
483 }
484 }
485
486 if should_retry {
487 continue; }
489
490 return Ok(ChatResponse {
491 content: response_content,
492 stop_reason: StopReason::EndTurn,
493 usage,
494 });
495 }
496 Err(e) => {
497 if attempt < MAX_RETRIES {
499 let error_msg = e.to_string();
500 self.emit(AgentEvent::progress(
501 format!("⚠️ API error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &error_msg),
502 None,
503 ))?;
504 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
506 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
507 } else {
508 return Err(anyhow::anyhow!("API error after {} retries: {}", MAX_RETRIES, e));
509 }
510 }
511 }
512 }
513 }
514
515 async fn process_response(&mut self, response: &ChatResponse) -> Result<bool> {
517 let mut has_tool_use = false;
518 let mut assistant_content: Vec<ContentBlock> = Vec::new();
519 let mut tool_results: Vec<Message> = Vec::new();
520
521 for block in &response.content {
522 match block {
523 ContentBlock::Text { text } => {
525 assistant_content.push(ContentBlock::Text { text: text.clone() });
526 }
527
528 ContentBlock::Thinking { thinking, signature } => {
529 assistant_content.push(ContentBlock::Thinking {
530 thinking: thinking.clone(),
531 signature: signature.clone(),
532 });
533 }
534
535 ContentBlock::ToolUse { id, name, input } => {
536 has_tool_use = true;
537
538 let result = self.execute_tool(name, input.clone()).await;
543
544 let (content, is_error) = match result {
545 Ok(output) => (output, false),
546 Err(e) => (e.to_string(), true),
547 };
548
549 self.emit(AgentEvent::tool_result(id.clone(), name.clone(), content.clone(), is_error))?;
550
551 assistant_content.push(ContentBlock::ToolUse {
553 id: id.clone(),
554 name: name.clone(),
555 input: input.clone(),
556 });
557
558 tool_results.push(Message {
560 role: Role::User,
561 content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
562 tool_use_id: id.clone(),
563 content: format!("{}: {}", if is_error { "Error" } else { "Result" }, content),
564 }]),
565 });
566 }
567
568 _ => {}
569 }
570 }
571
572 if !assistant_content.is_empty() {
574 self.messages.push(Message {
575 role: Role::Assistant,
576 content: MessageContent::Blocks(assistant_content),
577 });
578 }
579
580 for msg in tool_results {
582 self.messages.push(msg);
583 }
584
585 Ok(has_tool_use)
587 }
588
589 async fn execute_tool(&mut self, name: &str, input: serde_json::Value) -> Result<String> {
591 let tool = self.tools.iter().find(|t| t.definition().name == name);
592
593 if let Some(tool) = tool {
594 let current_mode = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
596
597 log::debug!(
599 "Tool '{}' approval check: mode={}, risk={}, needs_approval={}",
600 name, current_mode, tool.risk_level(),
601 needs_approval(current_mode, tool.risk_level())
602 );
603
604 if needs_approval(current_mode, tool.risk_level()) {
606 if self.ask_rx.is_some() {
608 let detail = match name {
610 "bash" => format!("Command: {}", input["command"].as_str().unwrap_or("?")),
611 "write" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
612 "edit" | "multi_edit" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
613 _ => format!("Tool: {}", name),
614 };
615
616 let question = format!(
617 "⚠️ Tool '{}' requires approval (risk: {})\n{}\n\nAllow? (y/n)",
618 name, tool.risk_level(), detail
619 );
620
621 self.emit(AgentEvent::with_data(
623 EventType::AskQuestion,
624 EventData::AskQuestion { question, options: None },
625 ))?;
626
627 if let Some(rx) = &mut self.ask_rx {
629 match rx.recv().await {
630 Some(answer) => {
631 let answer_lower = answer.trim().to_lowercase();
632 if matches!(answer_lower.as_str(), "a" | "abort" | "q" | "quit" | "stop") {
634 self.emit(AgentEvent::with_data(
635 EventType::Error,
636 EventData::Error { message: "Aborted by user".into(), code: None, source: None },
637 ))?;
638 return Err(anyhow::anyhow!("Session aborted by user"));
639 }
640 let approved = matches!(
642 answer_lower.as_str(),
643 "y" | "yes" | "ok" | "approve" | ""
644 );
645 if !approved {
646 return Err(anyhow::anyhow!(
648 "Tool '{}' rejected by user (answer: '{}')", name, answer_lower
649 ));
650 }
651 }
652 None => {
653 return Err(anyhow::anyhow!("Approval channel closed"));
654 }
655 }
656 }
657 } else {
658 return Err(anyhow::anyhow!(
660 "Tool '{}' requires manual approval (risk: {}). Use --approve-mode auto to auto-approve.",
661 name, tool.risk_level()
662 ));
663 }
664 }
665
666 if name == "ask" && self.ask_rx.is_some() {
668 let question = input["question"].as_str().unwrap_or("").to_string();
669 let options = input.get("options").cloned();
670
671 self.emit(AgentEvent::with_data(
673 EventType::AskQuestion,
674 EventData::AskQuestion { question, options },
675 ))?;
676
677 if let Some(rx) = &mut self.ask_rx {
679 match rx.recv().await {
680 Some(answer) => return Ok(answer),
681 None => return Err(anyhow::anyhow!("Ask channel closed")),
682 }
683 }
684 }
685
686 self.emit(AgentEvent::progress(format!("Executing: {}", name), None))?;
688 tool.execute(input).await
689 } else {
690 Err(anyhow::anyhow!("Tool '{}' not found", name))
691 }
692 }
693
694 fn track_usage(&self, usage: &Usage) {
696 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
697 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
698 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
700
701 crate::debug::debug_log().log(
703 "usage",
704 &format!("tracked: input_tokens={}, output_tokens={}, cache_read={}, cache_created={}",
705 usage.input_tokens, usage.output_tokens, usage.cache_read_input_tokens, usage.cache_creation_input_tokens)
706 );
707
708 let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
710 usage.input_tokens as u64,
711 usage.output_tokens as u64,
712 usage.cache_read_input_tokens as u64,
713 usage.cache_creation_input_tokens as u64,
714 ));
715 }
716
717 #[allow(dead_code)]
719 fn estimate_context_size(&self) -> u32 {
720 (self.messages.len() as u32) * 100 + self.total_input_tokens.load(Ordering::Relaxed) as u32
722 }
723
724 fn emit(&self, event: AgentEvent) -> Result<()> {
726 match self.event_tx.try_send(event) {
728 Ok(_) => Ok(()),
729 Err(mpsc::error::TrySendError::Full(_)) => {
730 Ok(())
732 }
733 Err(mpsc::error::TrySendError::Closed(_)) => {
734 Err(anyhow::anyhow!("Event channel closed"))
736 }
737 }
738 }
739
740 pub fn set_messages(&mut self, messages: Vec<Message>) {
742 self.messages = messages;
743 }
744
745 pub fn get_messages(&self) -> &[Message] {
747 &self.messages
748 }
749
750 pub fn get_token_counts(&self) -> (u64, u64) {
752 (
753 self.total_input_tokens.load(Ordering::Relaxed),
754 self.total_output_tokens.load(Ordering::Relaxed),
755 )
756 }
757
758 pub fn clear_history(&mut self) {
760 self.messages.clear();
761 self.total_input_tokens.store(0, Ordering::Relaxed);
762 self.total_output_tokens.store(0, Ordering::Relaxed);
763 self.last_input_tokens.store(0, Ordering::Relaxed);
764 }
765
766 pub fn message_count(&self) -> usize {
768 self.messages.len()
769 }
770}
771