1use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
6use std::sync::Arc;
7use anyhow::Result;
8use tokio::sync::mpsc;
9
10use crate::event::{AgentEvent, EventType, EventData};
11use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason, Usage};
12use crate::tools::{Tool, ToolDefinition};
13use crate::approval::{ApproveMode, needs_approval};
14use crate::compress::{CompressionConfig, should_compress};
15use crate::cancel::CancellationToken;
16
/// Upper bound on model round-trips (request, then tool execution) that a
/// single `Agent::run` call may perform before it stops.
const MAX_ITERATIONS: usize = 50;
18
19#[allow(dead_code)] pub struct Agent {
22 provider: Box<dyn Provider>,
23 model_name: String, tools: Vec<Arc<dyn Tool>>,
25 messages: Vec<Message>,
26 system_prompt: String,
27 max_tokens: u32,
28 think: bool,
29 approve_mode: Arc<AtomicU8>,
30 event_tx: mpsc::Sender<AgentEvent>,
31
32 skills: Vec<crate::skills::Skill>,
34 profile: crate::prompt::PromptProfile,
35 project_overview: Option<String>,
36 memory_summary: Option<String>,
37
38 total_input_tokens: AtomicU64,
40 total_output_tokens: AtomicU64,
41 last_input_tokens: AtomicU64,
43 cancel_token: Option<CancellationToken>,
44 compression_config: CompressionConfig,
45
46 ask_rx: Option<mpsc::Receiver<String>>,
48}
49
50pub struct AgentBuilder {
52 provider: Box<dyn Provider>,
53 model_name: String,
54 tools: Vec<Arc<dyn Tool>>,
55 system_prompt: String,
56 max_tokens: u32,
57 think: bool,
58 approve_mode: ApproveMode,
59 event_tx: Option<mpsc::Sender<AgentEvent>>,
60 skills: Vec<crate::skills::Skill>,
62 profile: crate::prompt::PromptProfile,
63 project_overview: Option<String>,
64 memory_summary: Option<String>,
65}
66
67impl AgentBuilder {
68 pub fn new(provider: Box<dyn Provider>) -> Self {
69 Self {
70 provider,
71 model_name: "unknown".to_string(),
72 tools: Vec::new(),
73 system_prompt: "You are a helpful AI coding assistant.".to_string(),
74 max_tokens: 4096,
75 think: false,
76 approve_mode: ApproveMode::Ask,
77 event_tx: None,
78 skills: Vec::new(),
79 profile: crate::prompt::PromptProfile::Default,
80 project_overview: None,
81 memory_summary: None,
82 }
83 }
84
85 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.system_prompt = prompt.into();
87 self
88 }
89
90 pub fn model_name(mut self, name: impl Into<String>) -> Self {
91 self.model_name = name.into();
92 self
93 }
94
95 pub fn max_tokens(mut self, tokens: u32) -> Self {
96 self.max_tokens = tokens;
97 self
98 }
99
100 pub fn think(mut self, enabled: bool) -> Self {
101 self.think = enabled;
102 self
103 }
104
105 pub fn approve_mode(mut self, mode: ApproveMode) -> Self {
106 self.approve_mode = mode;
107 self
108 }
109
110 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
111 self.tools.push(tool);
112 self
113 }
114
115 pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
117 self.tools.extend(tools.into_iter().map(Arc::from));
118 self
119 }
120
121 pub fn event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
123 self.event_tx = Some(tx);
124 self
125 }
126
127 pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
129 self.skills = skills;
130 self
131 }
132
133 pub fn profile(mut self, profile: crate::prompt::PromptProfile) -> Self {
135 self.profile = profile;
136 self
137 }
138
139 pub fn overview(mut self, overview: impl Into<String>) -> Self {
141 self.project_overview = Some(overview.into());
142 self
143 }
144
145 pub fn memory(mut self, summary: impl Into<String>) -> Self {
147 self.memory_summary = Some(summary.into());
148 self
149 }
150
151 pub fn build(self) -> Agent {
152 Agent::new(self)
153 }
154}
155
156impl Agent {
157 fn new(builder: AgentBuilder) -> Self {
158 let event_tx = builder.event_tx.unwrap_or_else(|| {
160 let (tx, _) = mpsc::channel(100);
161 tx
162 });
163
164 Self {
165 provider: builder.provider,
166 model_name: builder.model_name,
167 tools: builder.tools,
168 messages: Vec::new(),
169 system_prompt: builder.system_prompt,
170 max_tokens: builder.max_tokens,
171 think: builder.think,
172 approve_mode: Arc::new(AtomicU8::new(builder.approve_mode.to_u8())),
173 event_tx,
174 skills: builder.skills,
175 profile: builder.profile,
176 project_overview: builder.project_overview,
177 memory_summary: builder.memory_summary,
178 total_input_tokens: AtomicU64::new(0),
179 total_output_tokens: AtomicU64::new(0),
180 last_input_tokens: AtomicU64::new(0),
181 cancel_token: None,
182 compression_config: CompressionConfig::default(),
183 ask_rx: None,
184 }
185 }
186
187 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
189 self.event_tx.clone()
190 }
191
192 pub fn set_ask_channel(&mut self, rx: mpsc::Receiver<String>) {
194 self.ask_rx = Some(rx);
195 }
196
197 pub fn set_cancel_token(&mut self, token: CancellationToken) {
199 self.cancel_token = Some(token);
200 }
201
202 pub fn set_approve_mode(&mut self, mode: ApproveMode) {
204 let old = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
205 log::info!("Agent approve mode changed: {} -> {}", old, mode);
206 self.approve_mode.store(mode.to_u8(), Ordering::Relaxed);
207 }
208
209 pub fn approve_mode_shared(&self) -> Arc<AtomicU8> {
212 self.approve_mode.clone()
213 }
214
215 pub fn set_approve_mode_shared(&mut self, shared: Arc<AtomicU8>) {
219 self.approve_mode = shared;
220 }
221
222 pub fn update_memory_summary(&mut self, summary: Option<String>) {
225 self.memory_summary = summary;
226 self.system_prompt = crate::prompt::build_system_prompt(
228 &self.profile,
229 &self.skills,
230 self.project_overview.as_deref(),
231 self.memory_summary.as_deref(),
232 );
233 }
234
235 pub async fn run(&mut self, user_input: String) -> Result<Vec<AgentEvent>> {
240 self.emit(AgentEvent::session_started())?;
242
243 self.messages.push(Message {
245 role: Role::User,
246 content: MessageContent::Text(user_input.clone()),
247 });
248
249 let mut iterations = 0;
251 let mut should_continue = true;
252
253 while should_continue && iterations < MAX_ITERATIONS {
254 iterations += 1;
255
256 if let Some(token) = &self.cancel_token
258 && token.is_cancelled()
259 {
260 self.emit(AgentEvent::error("Operation cancelled".to_string(), None, None))?;
261 break;
262 }
263
264 let tool_defs: Vec<ToolDefinition> = self.tools.iter().map(|t| t.definition()).collect();
266 let request = ChatRequest {
267 system: Some(self.system_prompt.clone()),
268 messages: self.messages.clone(),
269 max_tokens: self.max_tokens,
270 tools: tool_defs,
271 think: self.think,
272 enable_caching: true,
273 server_tools: Vec::new(),
274 };
275
276 let response = self.call_streaming(&request).await?;
280
281 self.track_usage(&response.usage);
283
284 crate::debug::debug_log().api_call(
286 &self.model_name,
287 response.usage.input_tokens,
288 response.usage.cache_read_input_tokens > 0
289 );
290
291 should_continue = self.process_response(&response).await?;
293
294 let context_size = self.provider.context_size();
296
297 let api_tokens = self.last_input_tokens.load(Ordering::Relaxed) as u32;
300 let estimated_tokens = crate::compress::estimate_total_tokens(&self.messages);
301
302 let current_tokens = if api_tokens > 0 && api_tokens >= estimated_tokens / 2 {
304 api_tokens } else {
306 estimated_tokens };
308
309 crate::debug::debug_log().log(
311 "compression",
312 &format!("check: api={}, estimated={}, using={}, context={}, threshold={}",
313 api_tokens, estimated_tokens, current_tokens, context_size.unwrap_or(0), self.compression_config.threshold)
314 );
315
316 if should_compress(current_tokens, context_size, &self.compression_config) {
317 self.emit(AgentEvent::progress("Compressing context...", None))?;
318
319 let _original_count = self.messages.len();
320 let original_tokens = current_tokens;
321
322 match crate::compress::compress_messages(
324 &self.messages,
325 crate::compress::CompressionStrategy::SlidingWindow,
326 &self.compression_config,
327 ) {
328 Ok(compressed) => {
329 let compressed_tokens = crate::compress::estimate_total_tokens(&compressed);
330 self.messages = compressed;
331 self.total_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
332 self.last_input_tokens.store(compressed_tokens as u64, Ordering::Relaxed);
333
334 let ratio = compressed_tokens as f32 / original_tokens as f32;
336 crate::debug::debug_log().compression(original_tokens, compressed_tokens, ratio);
337
338 self.emit(AgentEvent::with_data(
339 crate::event::EventType::CompressionCompleted,
340 crate::event::EventData::Compression {
341 original_tokens: original_tokens as u64,
342 compressed_tokens: compressed_tokens as u64,
343 ratio: compressed_tokens as f32 / original_tokens as f32,
344 },
345 ))?;
346 }
347 Err(e) => {
348 self.emit(AgentEvent::progress(
349 format!("Compression failed: {}", e),
350 None,
351 ))?;
352 }
353 }
354 }
355 }
356
357 self.emit(AgentEvent::usage_with_cache(
359 self.last_input_tokens.load(Ordering::Relaxed),
360 self.total_output_tokens.load(Ordering::Relaxed),
361 0, 0, ))?;
363
364 self.emit(AgentEvent::session_ended())?;
366
367 Ok(Vec::new())
368 }
369
370 async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
372 use crate::providers::StreamEvent;
373
374 const MAX_RETRIES: u32 = 5;
375 const RETRY_DELAY_MS: u64 = 1000; let mut attempt = 0;
378
379 loop {
380 attempt += 1;
381
382 let rx_result = self.provider.chat_stream(request.clone()).await;
384
385 match rx_result {
386 Ok(mut rx) => {
387 let mut response_content: Vec<ContentBlock> = Vec::new();
389 let mut current_text = String::new();
390 let mut current_thinking = String::new();
391 let mut usage = Usage {
392 input_tokens: 0,
393 output_tokens: 0,
394 cache_creation_input_tokens: 0,
395 cache_read_input_tokens: 0,
396 };
397 let mut should_retry = false;
398
399 while let Some(event) = rx.recv().await {
400 match event {
401 StreamEvent::FirstByte => {
402 }
404 StreamEvent::ThinkingDelta(delta) => {
405 if current_thinking.is_empty() {
406 self.emit(AgentEvent::thinking_start())?;
407 }
408 current_thinking.push_str(&delta);
409 self.emit(AgentEvent::thinking_delta(delta, None))?;
410 }
411 StreamEvent::TextDelta(delta) => {
412 if current_text.is_empty() {
413 self.emit(AgentEvent::text_start())?;
414 }
415 current_text.push_str(&delta);
416 self.emit(AgentEvent::text_delta(delta))?;
417 }
418 StreamEvent::ToolUseStart { id, name } => {
419 if !current_thinking.is_empty() {
421 self.emit(AgentEvent::thinking_end())?;
422 response_content.push(ContentBlock::Thinking {
423 thinking: current_thinking.clone(),
424 signature: None,
425 });
426 current_thinking.clear();
427 }
428 if !current_text.is_empty() {
430 self.emit(AgentEvent::text_end())?;
431 response_content.push(ContentBlock::Text { text: current_text.clone() });
432 current_text.clear();
433 }
434 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
435 }
436 StreamEvent::ToolInputDelta { bytes_so_far: _ } => {
437 }
439 StreamEvent::Usage { output_tokens } => {
440 self.emit(AgentEvent::usage_with_cache(
442 0, output_tokens as u64,
444 0, 0 ))?;
446 usage.output_tokens = output_tokens;
447 }
448 StreamEvent::Done(resp) => {
449 if !current_thinking.is_empty() {
451 self.emit(AgentEvent::thinking_end())?;
452 response_content.push(ContentBlock::Thinking {
453 thinking: current_thinking.clone(),
454 signature: None,
455 });
456 }
457 if !current_text.is_empty() {
459 self.emit(AgentEvent::text_end())?;
460 response_content.push(ContentBlock::Text { text: current_text.clone() });
461 }
462 for block in &resp.content {
464 if !response_content.iter().any(|b| b == block) {
465 response_content.push(block.clone());
466 }
467 }
468 usage = resp.usage;
469 }
470 StreamEvent::Error(msg) => {
471 if attempt < MAX_RETRIES {
473 self.emit(AgentEvent::progress(
474 format!("⚠️ Stream error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &msg),
475 None,
476 ))?;
477 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
479 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
480 should_retry = true;
481 break; } else {
483 self.emit(AgentEvent::error(msg.clone(), None, None))?;
484 return Err(anyhow::anyhow!("Stream error after {} retries: {}", MAX_RETRIES, msg));
485 }
486 }
487 }
488 }
489
490 if should_retry {
491 continue; }
493
494 return Ok(ChatResponse {
495 content: response_content,
496 stop_reason: StopReason::EndTurn,
497 usage,
498 });
499 }
500 Err(e) => {
501 if attempt < MAX_RETRIES {
503 let error_msg = e.to_string();
504 self.emit(AgentEvent::progress(
505 format!("⚠️ API error, retrying ({}/{}): {}", attempt, MAX_RETRIES, &error_msg),
506 None,
507 ))?;
508 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
510 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
511 } else {
512 return Err(anyhow::anyhow!("API error after {} retries: {}", MAX_RETRIES, e));
513 }
514 }
515 }
516 }
517 }
518
519 async fn process_response(&mut self, response: &ChatResponse) -> Result<bool> {
521 let mut has_tool_use = false;
522 let mut assistant_content: Vec<ContentBlock> = Vec::new();
523 let mut tool_results: Vec<Message> = Vec::new();
524
525 for block in &response.content {
526 match block {
527 ContentBlock::Text { text } => {
529 assistant_content.push(ContentBlock::Text { text: text.clone() });
530 }
531
532 ContentBlock::Thinking { thinking, signature } => {
533 assistant_content.push(ContentBlock::Thinking {
534 thinking: thinking.clone(),
535 signature: signature.clone(),
536 });
537 }
538
539 ContentBlock::ToolUse { id, name, input } => {
540 has_tool_use = true;
541
542 let result = self.execute_tool(name, input.clone()).await;
547
548 let (content, is_error) = match result {
549 Ok(output) => (output, false),
550 Err(e) => (e.to_string(), true),
551 };
552
553 self.emit(AgentEvent::tool_result(id.clone(), name.clone(), content.clone(), is_error))?;
554
555 assistant_content.push(ContentBlock::ToolUse {
557 id: id.clone(),
558 name: name.clone(),
559 input: input.clone(),
560 });
561
562 tool_results.push(Message {
564 role: Role::User,
565 content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
566 tool_use_id: id.clone(),
567 content: format!("{}: {}", if is_error { "Error" } else { "Result" }, content),
568 }]),
569 });
570 }
571
572 _ => {}
573 }
574 }
575
576 if !assistant_content.is_empty() {
578 self.messages.push(Message {
579 role: Role::Assistant,
580 content: MessageContent::Blocks(assistant_content),
581 });
582 }
583
584 for msg in tool_results {
586 self.messages.push(msg);
587 }
588
589 Ok(has_tool_use)
591 }
592
593 async fn execute_tool(&mut self, name: &str, input: serde_json::Value) -> Result<String> {
595 let tool = self.tools.iter().find(|t| t.definition().name == name);
596
597 if let Some(tool) = tool {
598 let current_mode = ApproveMode::from_u8(self.approve_mode.load(Ordering::Relaxed));
600
601 log::debug!(
603 "Tool '{}' approval check: mode={}, risk={}, needs_approval={}",
604 name, current_mode, tool.risk_level(),
605 needs_approval(current_mode, tool.risk_level())
606 );
607
608 if needs_approval(current_mode, tool.risk_level()) {
610 if self.ask_rx.is_some() {
612 let detail = match name {
614 "bash" => format!("Command: {}", input["command"].as_str().unwrap_or("?")),
615 "write" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
616 "edit" | "multi_edit" => format!("File: {}", input["path"].as_str().unwrap_or("?")),
617 _ => format!("Tool: {}", name),
618 };
619
620 let question = format!(
621 "⚠️ Tool '{}' requires approval (risk: {})\n{}\n\nAllow? (y/n)",
622 name, tool.risk_level(), detail
623 );
624
625 self.emit(AgentEvent::with_data(
627 EventType::AskQuestion,
628 EventData::AskQuestion { question, options: None },
629 ))?;
630
631 if let Some(rx) = &mut self.ask_rx {
633 match rx.recv().await {
634 Some(answer) => {
635 let answer_lower = answer.trim().to_lowercase();
636 if matches!(answer_lower.as_str(), "a" | "abort" | "q" | "quit" | "stop") {
638 self.emit(AgentEvent::with_data(
639 EventType::Error,
640 EventData::Error { message: "Aborted by user".into(), code: None, source: None },
641 ))?;
642 return Err(anyhow::anyhow!("Session aborted by user"));
643 }
644 let approved = matches!(
646 answer_lower.as_str(),
647 "y" | "yes" | "ok" | "approve" | ""
648 );
649 if !approved {
650 return Err(anyhow::anyhow!(
652 "Tool '{}' rejected by user (answer: '{}')", name, answer_lower
653 ));
654 }
655 }
656 None => {
657 return Err(anyhow::anyhow!("Approval channel closed"));
658 }
659 }
660 }
661 } else {
662 return Err(anyhow::anyhow!(
664 "Tool '{}' requires manual approval (risk: {}). Use --approve-mode auto to auto-approve.",
665 name, tool.risk_level()
666 ));
667 }
668 }
669
670 if name == "ask" && self.ask_rx.is_some() {
672 let question = input["question"].as_str().unwrap_or("").to_string();
673 let options = input.get("options").cloned();
674
675 self.emit(AgentEvent::with_data(
677 EventType::AskQuestion,
678 EventData::AskQuestion { question, options },
679 ))?;
680
681 if let Some(rx) = &mut self.ask_rx {
683 match rx.recv().await {
684 Some(answer) => return Ok(answer),
685 None => return Err(anyhow::anyhow!("Ask channel closed")),
686 }
687 }
688 }
689
690 self.emit(AgentEvent::progress(format!("Executing: {}", name), None))?;
692 tool.execute(input).await
693 } else {
694 Err(anyhow::anyhow!("Tool '{}' not found", name))
695 }
696 }
697
698 fn track_usage(&self, usage: &Usage) {
700 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
701 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
702 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
704
705 crate::debug::debug_log().log(
707 "usage",
708 &format!("tracked: input_tokens={}, output_tokens={}, cache_read={}, cache_created={}",
709 usage.input_tokens, usage.output_tokens, usage.cache_read_input_tokens, usage.cache_creation_input_tokens)
710 );
711
712 let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
714 usage.input_tokens as u64,
715 usage.output_tokens as u64,
716 usage.cache_read_input_tokens as u64,
717 usage.cache_creation_input_tokens as u64,
718 ));
719 }
720
721 #[allow(dead_code)]
723 fn estimate_context_size(&self) -> u32 {
724 (self.messages.len() as u32) * 100 + self.total_input_tokens.load(Ordering::Relaxed) as u32
726 }
727
728 fn emit(&self, event: AgentEvent) -> Result<()> {
730 match self.event_tx.try_send(event) {
732 Ok(_) => Ok(()),
733 Err(mpsc::error::TrySendError::Full(_)) => {
734 Ok(())
736 }
737 Err(mpsc::error::TrySendError::Closed(_)) => {
738 Err(anyhow::anyhow!("Event channel closed"))
740 }
741 }
742 }
743
744 pub fn set_messages(&mut self, messages: Vec<Message>) {
746 self.messages = messages;
747 }
748
749 pub fn get_messages(&self) -> &[Message] {
751 &self.messages
752 }
753
754 pub fn get_token_counts(&self) -> (u64, u64) {
756 (
757 self.total_input_tokens.load(Ordering::Relaxed),
758 self.total_output_tokens.load(Ordering::Relaxed),
759 )
760 }
761
762 pub fn clear_history(&mut self) {
764 self.messages.clear();
765 self.total_input_tokens.store(0, Ordering::Relaxed);
766 self.total_output_tokens.store(0, Ordering::Relaxed);
767 self.last_input_tokens.store(0, Ordering::Relaxed);
768 }
769
770 pub fn message_count(&self) -> usize {
772 self.messages.len()
773 }
774}
775