1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore};
39use crate::tools::{ToolContext, ToolRegistry};
40use crate::types::{AgentConfig, AgentState, RetryConfig, ThreadId, TokenUsage, ToolResult};
41use anyhow::Result;
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::mpsc;
45use tokio::time::sleep;
46use tracing::{debug, error, info, warn};
47
/// Builder for [`AgentLoop`].
///
/// `Ctx` is the context type handed to tools through `ToolContext<Ctx>`. The remaining
/// type parameters track which components have been supplied: `P` (provider),
/// `H` (hooks), `M` (message store) and `S` (state store). Each starts as `()` and is
/// replaced by the corresponding type-changing setter.
pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
    /// LLM provider; both `build` and `build_with_stores` panic if it is unset.
    provider: Option<P>,
    /// Tool registry; `ToolRegistry::default()` is used if unset.
    tools: Option<ToolRegistry<Ctx>>,
    /// Hooks; only consulted by `build_with_stores` (`build` always uses `DefaultHooks`).
    hooks: Option<H>,
    /// Message store; only consulted by `build_with_stores`.
    message_store: Option<M>,
    /// State store; only consulted by `build_with_stores`.
    state_store: Option<S>,
    /// Agent configuration; `AgentConfig::default()` is used if unset.
    config: Option<AgentConfig>,
    /// Context-compaction settings; `None` leaves compaction disabled.
    compaction_config: Option<CompactionConfig>,
}
68
69impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
70 #[must_use]
72 pub const fn new() -> Self {
73 Self {
74 provider: None,
75 tools: None,
76 hooks: None,
77 message_store: None,
78 state_store: None,
79 config: None,
80 compaction_config: None,
81 }
82 }
83}
84
85impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
92 #[must_use]
94 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
95 AgentLoopBuilder {
96 provider: Some(provider),
97 tools: self.tools,
98 hooks: self.hooks,
99 message_store: self.message_store,
100 state_store: self.state_store,
101 config: self.config,
102 compaction_config: self.compaction_config,
103 }
104 }
105
106 #[must_use]
108 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
109 self.tools = Some(tools);
110 self
111 }
112
113 #[must_use]
115 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
116 AgentLoopBuilder {
117 provider: self.provider,
118 tools: self.tools,
119 hooks: Some(hooks),
120 message_store: self.message_store,
121 state_store: self.state_store,
122 config: self.config,
123 compaction_config: self.compaction_config,
124 }
125 }
126
127 #[must_use]
129 pub fn message_store<M2: MessageStore>(
130 self,
131 message_store: M2,
132 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
133 AgentLoopBuilder {
134 provider: self.provider,
135 tools: self.tools,
136 hooks: self.hooks,
137 message_store: Some(message_store),
138 state_store: self.state_store,
139 config: self.config,
140 compaction_config: self.compaction_config,
141 }
142 }
143
144 #[must_use]
146 pub fn state_store<S2: StateStore>(
147 self,
148 state_store: S2,
149 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
150 AgentLoopBuilder {
151 provider: self.provider,
152 tools: self.tools,
153 hooks: self.hooks,
154 message_store: self.message_store,
155 state_store: Some(state_store),
156 config: self.config,
157 compaction_config: self.compaction_config,
158 }
159 }
160
161 #[must_use]
163 pub fn config(mut self, config: AgentConfig) -> Self {
164 self.config = Some(config);
165 self
166 }
167
168 #[must_use]
184 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
185 self.compaction_config = Some(config);
186 self
187 }
188
189 #[must_use]
196 pub fn with_auto_compaction(self) -> Self {
197 self.with_compaction(CompactionConfig::default())
198 }
199
200 #[must_use]
218 pub fn with_skill(mut self, skill: Skill) -> Self
219 where
220 Ctx: Send + Sync + 'static,
221 {
222 if let Some(ref mut tools) = self.tools {
224 tools.filter(|name| skill.is_tool_allowed(name));
225 }
226
227 let mut config = self.config.take().unwrap_or_default();
229 if config.system_prompt.is_empty() {
230 config.system_prompt = skill.system_prompt;
231 } else {
232 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
233 }
234 self.config = Some(config);
235
236 self
237 }
238}
239
240impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
241where
242 Ctx: Send + Sync + 'static,
243 P: LlmProvider + 'static,
244{
245 #[must_use]
257 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
258 let provider = self.provider.expect("provider is required");
259 let tools = self.tools.unwrap_or_default();
260 let config = self.config.unwrap_or_default();
261
262 AgentLoop {
263 provider: Arc::new(provider),
264 tools: Arc::new(tools),
265 hooks: Arc::new(DefaultHooks),
266 message_store: Arc::new(InMemoryStore::new()),
267 state_store: Arc::new(InMemoryStore::new()),
268 config,
269 compaction_config: self.compaction_config,
270 }
271 }
272}
273
274impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
275where
276 Ctx: Send + Sync + 'static,
277 P: LlmProvider + 'static,
278 H: AgentHooks + 'static,
279 M: MessageStore + 'static,
280 S: StateStore + 'static,
281{
282 #[must_use]
292 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
293 let provider = self.provider.expect("provider is required");
294 let tools = self.tools.unwrap_or_default();
295 let hooks = self
296 .hooks
297 .expect("hooks is required when using build_with_stores");
298 let message_store = self
299 .message_store
300 .expect("message_store is required when using build_with_stores");
301 let state_store = self
302 .state_store
303 .expect("state_store is required when using build_with_stores");
304 let config = self.config.unwrap_or_default();
305
306 AgentLoop {
307 provider: Arc::new(provider),
308 tools: Arc::new(tools),
309 hooks: Arc::new(hooks),
310 message_store: Arc::new(message_store),
311 state_store: Arc::new(state_store),
312 config,
313 compaction_config: self.compaction_config,
314 }
315 }
316}
317
/// The agent's main loop: alternates LLM calls and tool execution for a thread and
/// streams progress as [`AgentEvent`]s.
///
/// Components are held in `Arc`s so that each call to `run` can hand cheap clones of
/// them to the spawned background task.
pub struct AgentLoop<Ctx, P, H, M, S>
where
    P: LlmProvider,
    H: AgentHooks,
    M: MessageStore,
    S: StateStore,
{
    provider: Arc<P>,
    tools: Arc<ToolRegistry<Ctx>>,
    hooks: Arc<H>,
    message_store: Arc<M>,
    state_store: Arc<S>,
    config: AgentConfig,
    /// Context-compaction settings; `None` disables compaction.
    compaction_config: Option<CompactionConfig>,
}
362
363#[must_use]
365pub const fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
366 AgentLoopBuilder::new()
367}
368
369impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
370where
371 Ctx: Send + Sync + 'static,
372 P: LlmProvider + 'static,
373 H: AgentHooks + 'static,
374 M: MessageStore + 'static,
375 S: StateStore + 'static,
376{
377 #[must_use]
379 pub fn new(
380 provider: P,
381 tools: ToolRegistry<Ctx>,
382 hooks: H,
383 message_store: M,
384 state_store: S,
385 config: AgentConfig,
386 ) -> Self {
387 Self {
388 provider: Arc::new(provider),
389 tools: Arc::new(tools),
390 hooks: Arc::new(hooks),
391 message_store: Arc::new(message_store),
392 state_store: Arc::new(state_store),
393 config,
394 compaction_config: None,
395 }
396 }
397
398 #[must_use]
400 pub fn with_compaction(
401 provider: P,
402 tools: ToolRegistry<Ctx>,
403 hooks: H,
404 message_store: M,
405 state_store: S,
406 config: AgentConfig,
407 compaction_config: CompactionConfig,
408 ) -> Self {
409 Self {
410 provider: Arc::new(provider),
411 tools: Arc::new(tools),
412 hooks: Arc::new(hooks),
413 message_store: Arc::new(message_store),
414 state_store: Arc::new(state_store),
415 config,
416 compaction_config: Some(compaction_config),
417 }
418 }
419
420 pub fn run(
423 &self,
424 thread_id: ThreadId,
425 user_message: String,
426 tool_context: ToolContext<Ctx>,
427 ) -> mpsc::Receiver<AgentEvent>
428 where
429 Ctx: Clone,
430 {
431 let (tx, rx) = mpsc::channel(100);
432
433 let provider = Arc::clone(&self.provider);
434 let tools = Arc::clone(&self.tools);
435 let hooks = Arc::clone(&self.hooks);
436 let message_store = Arc::clone(&self.message_store);
437 let state_store = Arc::clone(&self.state_store);
438 let config = self.config.clone();
439 let compaction_config = self.compaction_config.clone();
440
441 tokio::spawn(async move {
442 let result = run_loop(
443 tx.clone(),
444 thread_id,
445 user_message,
446 tool_context,
447 provider,
448 tools,
449 hooks,
450 message_store,
451 state_store,
452 config,
453 compaction_config,
454 )
455 .await;
456
457 if let Err(e) = result {
458 let _ = tx.send(AgentEvent::error(e.to_string(), false)).await;
459 }
460 });
461
462 rx
463 }
464}
465
466#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
467async fn run_loop<Ctx, P, H, M, S>(
468 tx: mpsc::Sender<AgentEvent>,
469 thread_id: ThreadId,
470 user_message: String,
471 tool_context: ToolContext<Ctx>,
472 provider: Arc<P>,
473 tools: Arc<ToolRegistry<Ctx>>,
474 hooks: Arc<H>,
475 message_store: Arc<M>,
476 state_store: Arc<S>,
477 config: AgentConfig,
478 compaction_config: Option<CompactionConfig>,
479) -> Result<()>
480where
481 Ctx: Send + Sync + Clone + 'static,
482 P: LlmProvider,
483 H: AgentHooks,
484 M: MessageStore,
485 S: StateStore,
486{
487 let tool_context = tool_context.with_event_tx(tx.clone());
489
490 let start_time = Instant::now();
491 let mut turn = 0;
492 let mut total_usage = TokenUsage::default();
493
494 let mut state = state_store
496 .load(&thread_id)
497 .await?
498 .unwrap_or_else(|| AgentState::new(thread_id.clone()));
499
500 let user_msg = Message::user(&user_message);
502 message_store.append(&thread_id, user_msg).await?;
503
504 loop {
506 turn += 1;
507 state.turn_count = turn;
508
509 if turn > config.max_turns {
510 warn!(turn, max = config.max_turns, "Max turns reached");
511 tx.send(AgentEvent::error(
512 format!("Maximum turns ({}) reached", config.max_turns),
513 true,
514 ))
515 .await?;
516 break;
517 }
518
519 tx.send(AgentEvent::start(thread_id.clone(), turn)).await?;
521 hooks
522 .on_event(&AgentEvent::start(thread_id.clone(), turn))
523 .await;
524
525 let mut messages = message_store.get_history(&thread_id).await?;
527
528 if let Some(ref compact_config) = compaction_config {
530 let compactor = LlmContextCompactor::new(Arc::clone(&provider), compact_config.clone());
531 if compactor.needs_compaction(&messages) {
532 debug!(
533 turn,
534 message_count = messages.len(),
535 "Context compaction triggered"
536 );
537
538 match compactor.compact_history(messages).await {
539 Ok(result) => {
540 message_store
542 .replace_history(&thread_id, result.messages.clone())
543 .await?;
544
545 tx.send(AgentEvent::context_compacted(
547 result.original_count,
548 result.new_count,
549 result.original_tokens,
550 result.new_tokens,
551 ))
552 .await?;
553
554 info!(
555 original_count = result.original_count,
556 new_count = result.new_count,
557 original_tokens = result.original_tokens,
558 new_tokens = result.new_tokens,
559 "Context compacted successfully"
560 );
561
562 messages = result.messages;
564 }
565 Err(e) => {
566 warn!(error = %e, "Context compaction failed, continuing with full history");
567 messages = message_store.get_history(&thread_id).await?;
569 }
570 }
571 }
572 }
573
574 let llm_tools = if tools.is_empty() {
576 None
577 } else {
578 Some(tools.to_llm_tools())
579 };
580
581 let request = ChatRequest {
582 system: config.system_prompt.clone(),
583 messages,
584 tools: llm_tools,
585 max_tokens: config.max_tokens,
586 };
587
588 debug!(turn, "Calling LLM");
590 let max_retries = config.retry.max_retries;
591 let response = {
592 let mut attempt = 0u32;
593 loop {
594 let outcome = provider.chat(request.clone()).await?;
595 match outcome {
596 ChatOutcome::Success(response) => break Some(response),
597 ChatOutcome::RateLimited => {
598 attempt += 1;
599 if attempt > max_retries {
600 error!("Rate limited by LLM provider after {max_retries} retries");
601 tx.send(AgentEvent::error(
602 format!("Rate limited after {max_retries} retries"),
603 true,
604 ))
605 .await?;
606 break None;
607 }
608 let delay = calculate_backoff_delay(attempt, &config.retry);
609 warn!(
610 attempt,
611 delay_ms = delay.as_millis(),
612 "Rate limited, retrying after backoff"
613 );
614 tx.send(AgentEvent::text(format!(
615 "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
616 delay.as_secs_f64()
617 )))
618 .await?;
619 sleep(delay).await;
620 }
621 ChatOutcome::InvalidRequest(msg) => {
622 error!(msg, "Invalid request to LLM");
623 tx.send(AgentEvent::error(format!("Invalid request: {msg}"), false))
624 .await?;
625 break None;
626 }
627 ChatOutcome::ServerError(msg) => {
628 attempt += 1;
629 if attempt > max_retries {
630 error!(msg, "LLM server error after {max_retries} retries");
631 tx.send(AgentEvent::error(
632 format!("Server error after {max_retries} retries: {msg}"),
633 true,
634 ))
635 .await?;
636 break None;
637 }
638 let delay = calculate_backoff_delay(attempt, &config.retry);
639 warn!(
640 attempt,
641 delay_ms = delay.as_millis(),
642 error = msg,
643 "Server error, retrying after backoff"
644 );
645 tx.send(AgentEvent::text(format!(
646 "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
647 delay.as_secs_f64()
648 )))
649 .await?;
650 sleep(delay).await;
651 }
652 }
653 }
654 };
655
656 let Some(response) = response else {
658 break;
659 };
660
661 let turn_usage = TokenUsage {
663 input_tokens: response.usage.input_tokens,
664 output_tokens: response.usage.output_tokens,
665 };
666 total_usage.add(&turn_usage);
667 state.total_usage = total_usage.clone();
668
669 let (text_content, tool_uses) = extract_content(&response);
671
672 if let Some(text) = &text_content {
674 tx.send(AgentEvent::text(text.clone())).await?;
675 hooks.on_event(&AgentEvent::text(text.clone())).await;
676 }
677
678 if tool_uses.is_empty() {
680 info!(turn, "Agent completed (no tool use)");
681 break;
682 }
683
684 let assistant_msg = build_assistant_message(&response);
686 message_store.append(&thread_id, assistant_msg).await?;
687
688 let mut tool_results = Vec::new();
690 for (tool_id, tool_name, tool_input) in &tool_uses {
691 let Some(tool) = tools.get(tool_name) else {
692 let result = ToolResult::error(format!("Unknown tool: {tool_name}"));
693 tool_results.push((tool_id.clone(), result));
694 continue;
695 };
696
697 let tier = tool.tier();
698
699 tx.send(AgentEvent::tool_call_start(
701 tool_id,
702 tool_name,
703 tool_input.clone(),
704 tier,
705 ))
706 .await?;
707
708 let decision = hooks.pre_tool_use(tool_name, tool_input, tier).await;
710
711 match decision {
712 ToolDecision::Allow => {
713 let tool_start = Instant::now();
715 let result = match tool.execute(&tool_context, tool_input.clone()).await {
716 Ok(mut r) => {
717 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
718 r
719 }
720 Err(e) => ToolResult::error(format!("Tool error: {e}"))
721 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
722 };
723
724 hooks.post_tool_use(tool_name, &result).await;
726
727 tx.send(AgentEvent::tool_call_end(
729 tool_id,
730 tool_name,
731 result.clone(),
732 ))
733 .await?;
734
735 tool_results.push((tool_id.clone(), result));
736 }
737 ToolDecision::Block(reason) => {
738 let result = ToolResult::error(format!("Blocked: {reason}"));
739 tx.send(AgentEvent::tool_call_end(
740 tool_id,
741 tool_name,
742 result.clone(),
743 ))
744 .await?;
745 tool_results.push((tool_id.clone(), result));
746 }
747 ToolDecision::RequiresConfirmation(description) => {
748 tx.send(AgentEvent::ToolRequiresConfirmation {
749 id: tool_id.clone(),
750 name: tool_name.clone(),
751 input: tool_input.clone(),
752 description,
753 })
754 .await?;
755 let result = ToolResult::error("Awaiting user confirmation");
757 tool_results.push((tool_id.clone(), result));
758 }
759 ToolDecision::RequiresPin(description) => {
760 tx.send(AgentEvent::ToolRequiresPin {
761 id: tool_id.clone(),
762 name: tool_name.clone(),
763 input: tool_input.clone(),
764 description,
765 })
766 .await?;
767 let result = ToolResult::error("Awaiting PIN verification");
769 tool_results.push((tool_id.clone(), result));
770 }
771 }
772 }
773
774 for (tool_id, result) in &tool_results {
776 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
777 message_store.append(&thread_id, tool_result_msg).await?;
778 }
779
780 tx.send(AgentEvent::TurnComplete {
782 turn,
783 usage: turn_usage,
784 })
785 .await?;
786
787 if response.stop_reason == Some(StopReason::EndTurn) {
789 info!(turn, "Agent completed (end_turn)");
790 break;
791 }
792
793 state_store.save(&state).await?;
795 }
796
797 state_store.save(&state).await?;
799
800 let duration = start_time.elapsed();
802 tx.send(AgentEvent::done(thread_id, turn, total_usage, duration))
803 .await?;
804
805 Ok(())
806}
807
/// Narrows a millisecond count (as returned by `Duration::as_millis`) to `u64`,
/// saturating at `u64::MAX` instead of truncating.
// The cast below only runs after the range check, so it cannot truncate; the
// lint is allowed because `u64::try_from` is not usable in a `const fn`.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    if millis <= u64::MAX as u128 {
        millis as u64
    } else {
        u64::MAX
    }
}
817
818fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
823 let base_delay = config
825 .base_delay_ms
826 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
827
828 let max_jitter = config.base_delay_ms.min(1000);
830 let jitter = if max_jitter > 0 {
831 u64::from(
832 std::time::SystemTime::now()
833 .duration_since(std::time::UNIX_EPOCH)
834 .unwrap_or_default()
835 .subsec_nanos(),
836 ) % max_jitter
837 } else {
838 0
839 };
840
841 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
842 Duration::from_millis(delay_ms)
843}
844
845fn extract_content(
846 response: &ChatResponse,
847) -> (Option<String>, Vec<(String, String, serde_json::Value)>) {
848 let mut text_parts = Vec::new();
849 let mut tool_uses = Vec::new();
850
851 for block in &response.content {
852 match block {
853 ContentBlock::Text { text } => {
854 text_parts.push(text.clone());
855 }
856 ContentBlock::ToolUse {
857 id, name, input, ..
858 } => {
859 tool_uses.push((id.clone(), name.clone(), input.clone()));
860 }
861 ContentBlock::ToolResult { .. } => {
862 }
864 }
865 }
866
867 let text = if text_parts.is_empty() {
868 None
869 } else {
870 Some(text_parts.join("\n"))
871 };
872
873 (text, tool_uses)
874}
875
876fn build_assistant_message(response: &ChatResponse) -> Message {
877 let mut blocks = Vec::new();
878
879 for block in &response.content {
880 match block {
881 ContentBlock::Text { text } => {
882 blocks.push(ContentBlock::Text { text: text.clone() });
883 }
884 ContentBlock::ToolUse {
885 id,
886 name,
887 input,
888 thought_signature,
889 } => {
890 blocks.push(ContentBlock::ToolUse {
891 id: id.clone(),
892 name: name.clone(),
893 input: input.clone(),
894 thought_signature: thought_signature.clone(),
895 });
896 }
897 ContentBlock::ToolResult { .. } => {}
898 }
899 }
900
901 Message {
902 role: Role::Assistant,
903 content: Content::Blocks(blocks),
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::AllowAllHooks;
    use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
    use crate::stores::InMemoryStore;
    use crate::tools::{Tool, ToolContext, ToolRegistry};
    use crate::types::{AgentConfig, ToolResult, ToolTier};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::RwLock;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scripted LLM provider: replays `responses` in call order, then answers with a
    /// plain "Done" text reply once the script is exhausted.
    struct MockProvider {
        responses: RwLock<Vec<ChatOutcome>>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatOutcome>) -> Self {
            Self {
                responses: RwLock::new(responses),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Wraps `content` in a successful response with fixed id, model and usage.
        fn success(content: Vec<ContentBlock>, stop_reason: StopReason) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content,
                model: "mock-model".to_string(),
                stop_reason: Some(stop_reason),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }

        /// A final text answer (`StopReason::EndTurn`).
        fn text_response(text: &str) -> ChatOutcome {
            Self::success(
                vec![ContentBlock::Text {
                    text: text.to_string(),
                }],
                StopReason::EndTurn,
            )
        }

        /// A single tool call (`StopReason::ToolUse`).
        fn tool_use_response(
            tool_id: &str,
            tool_name: &str,
            input: serde_json::Value,
        ) -> ChatOutcome {
            Self::success(
                vec![ContentBlock::ToolUse {
                    id: tool_id.to_string(),
                    name: tool_name.to_string(),
                    input,
                    thought_signature: None,
                }],
                StopReason::ToolUse,
            )
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let scripted = self.responses.read().unwrap().get(idx).cloned();
            Ok(scripted.unwrap_or_else(|| Self::text_response("Done")))
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    // Test-only Clone so scripted outcomes can be replayed by value.
    impl Clone for ChatOutcome {
        fn clone(&self) -> Self {
            match self {
                Self::Success(r) => Self::Success(r.clone()),
                Self::RateLimited => Self::RateLimited,
                Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
                Self::ServerError(s) => Self::ServerError(s.clone()),
            }
        }
    }

    /// Observe-tier tool that answers `Echo: <message>`.
    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn description(&self) -> &'static str {
            "Echo the input message"
        }

        fn input_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
        }

        fn tier(&self) -> ToolTier {
            ToolTier::Observe
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            input: serde_json::Value,
        ) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Echo: {message}")))
        }
    }

    /// A registry containing only [`EchoTool`].
    fn echo_registry() -> ToolRegistry<()> {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry
    }

    /// Default config with the fast retry profile, so retry tests finish quickly.
    fn fast_retry_config() -> AgentConfig {
        AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        }
    }

    /// Runs `agent` on a fresh thread and collects every event until the channel closes.
    async fn run_and_collect(
        agent: &AgentLoop<(), MockProvider, DefaultHooks, InMemoryStore, InMemoryStore>,
        prompt: &str,
    ) -> Vec<AgentEvent> {
        let mut rx = agent.run(ThreadId::new(), prompt.to_string(), ToolContext::new(()));
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }
        events
    }

    fn has_done(events: &[AgentEvent]) -> bool {
        events.iter().any(|e| matches!(e, AgentEvent::Done { .. }))
    }

    fn has_text_containing(events: &[AgentEvent], needle: &str) -> bool {
        events
            .iter()
            .any(|e| matches!(e, AgentEvent::Text { text } if text.contains(needle)))
    }

    /// A response carrying `content`, with no stop reason and zero usage.
    fn response_with(content: Vec<ContentBlock>) -> ChatResponse {
        ChatResponse {
            id: "msg_1".to_string(),
            content,
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        }
    }

    #[test]
    fn test_builder_creates_agent_loop() {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![]))
            .build();

        assert_eq!(agent.config.max_turns, 10);
        assert_eq!(agent.config.max_tokens, 4096);
    }

    #[test]
    fn test_builder_with_custom_config() {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![]))
            .config(AgentConfig {
                max_turns: 5,
                max_tokens: 2048,
                system_prompt: "Custom prompt".to_string(),
                model: "custom-model".to_string(),
                ..Default::default()
            })
            .build();

        assert_eq!(agent.config.max_turns, 5);
        assert_eq!(agent.config.max_tokens, 2048);
        assert_eq!(agent.config.system_prompt, "Custom prompt");
    }

    #[test]
    fn test_builder_with_tools() {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![]))
            .tools(echo_registry())
            .build();

        assert_eq!(agent.tools.len(), 1);
    }

    #[test]
    fn test_builder_with_custom_stores() {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![]))
            .hooks(AllowAllHooks)
            .message_store(InMemoryStore::new())
            .state_store(InMemoryStore::new())
            .build_with_stores();

        assert_eq!(agent.config.max_turns, 10);
    }

    #[tokio::test]
    async fn test_simple_text_response() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![MockProvider::text_response(
                "Hello, user!",
            )]))
            .build();

        let events = run_and_collect(&agent, "Hi").await;

        assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
        assert!(has_done(&events));
        Ok(())
    }

    #[tokio::test]
    async fn test_tool_execution() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![
                MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
                MockProvider::text_response("Tool executed successfully"),
            ]))
            .tools(echo_registry())
            .build();

        let events = run_and_collect(&agent, "Run echo").await;

        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_max_turns_limit() -> anyhow::Result<()> {
        // Every scripted reply asks for another tool call, so only max_turns stops it.
        let script = (1..=4)
            .map(|i| {
                MockProvider::tool_use_response(
                    &format!("tool_{i}"),
                    "echo",
                    json!({"message": i.to_string()}),
                )
            })
            .collect();

        let agent = builder::<()>()
            .provider(MockProvider::new(script))
            .tools(echo_registry())
            .config(AgentConfig {
                max_turns: 2,
                ..Default::default()
            })
            .build();

        let events = run_and_collect(&agent, "Loop").await;

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
        }));
        Ok(())
    }

    #[tokio::test]
    async fn test_unknown_tool_handling() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![
                MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
                MockProvider::text_response("I couldn't find that tool."),
            ]))
            .tools(ToolRegistry::new())
            .build();

        let events = run_and_collect(&agent, "Call unknown").await;

        assert!(has_done(&events));
        assert!(has_text_containing(&events, "couldn't find"));
        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_handling() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![ChatOutcome::RateLimited; 6]))
            .config(fast_retry_config())
            .build();

        let events = run_and_collect(&agent, "Hi").await;

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
        }));
        assert!(has_text_containing(&events, "retrying"));
        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_recovery() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![
                ChatOutcome::RateLimited,
                MockProvider::text_response("Recovered after rate limit"),
            ]))
            .config(fast_retry_config())
            .build();

        let events = run_and_collect(&agent, "Hi").await;

        assert!(has_done(&events));
        assert!(has_text_containing(&events, "retrying"));
        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_handling() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![
                ChatOutcome::ServerError(
                    "Internal error".to_string()
                );
                6
            ]))
            .config(fast_retry_config())
            .build();

        let events = run_and_collect(&agent, "Hi").await;

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
        }));
        assert!(has_text_containing(&events, "retrying"));
        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_recovery() -> anyhow::Result<()> {
        let agent = builder::<()>()
            .provider(MockProvider::new(vec![
                ChatOutcome::ServerError("Temporary error".to_string()),
                MockProvider::text_response("Recovered after server error"),
            ]))
            .config(fast_retry_config())
            .build();

        let events = run_and_collect(&agent, "Hi").await;

        assert!(has_done(&events));
        assert!(has_text_containing(&events, "retrying"));
        Ok(())
    }

    #[test]
    fn test_extract_content_text_only() {
        let response = response_with(vec![ContentBlock::Text {
            text: "Hello".to_string(),
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text.as_deref(), Some("Hello"));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_tool_use() {
        let response = response_with(vec![ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "test_tool".to_string(),
            input: json!({"key": "value"}),
            thought_signature: None,
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert_eq!(tool_uses.len(), 1);
        assert_eq!(tool_uses[0].1, "test_tool");
    }

    #[test]
    fn test_extract_content_mixed() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Let me help".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "helper".to_string(),
                input: json!({}),
                thought_signature: None,
            },
        ]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text.as_deref(), Some("Let me help"));
        assert_eq!(tool_uses.len(), 1);
    }

    #[test]
    fn test_millis_to_u64() {
        assert_eq!(millis_to_u64(0), 0);
        assert_eq!(millis_to_u64(1000), 1000);
        assert_eq!(millis_to_u64(u128::from(u64::MAX)), u64::MAX);
        assert_eq!(millis_to_u64(u128::from(u64::MAX) + 1), u64::MAX);
    }

    #[test]
    fn test_build_assistant_message() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Response text".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "echo".to_string(),
                input: json!({"message": "test"}),
                thought_signature: None,
            },
        ]);

        let msg = build_assistant_message(&response);
        assert_eq!(msg.role, Role::Assistant);

        let Content::Blocks(blocks) = msg.content else {
            panic!("Expected Content::Blocks");
        };
        assert_eq!(blocks.len(), 2);
    }
}