1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore};
39use crate::tools::{ToolContext, ToolRegistry};
40use crate::types::{AgentConfig, AgentState, RetryConfig, ThreadId, TokenUsage, ToolResult};
41use anyhow::Result;
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::mpsc;
45use tokio::time::sleep;
46use tracing::{debug, error, info, warn};
47
48pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
60 provider: Option<P>,
61 tools: Option<ToolRegistry<Ctx>>,
62 hooks: Option<H>,
63 message_store: Option<M>,
64 state_store: Option<S>,
65 config: Option<AgentConfig>,
66 compaction_config: Option<CompactionConfig>,
67}
68
69impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
70 #[must_use]
72 pub const fn new() -> Self {
73 Self {
74 provider: None,
75 tools: None,
76 hooks: None,
77 message_store: None,
78 state_store: None,
79 config: None,
80 compaction_config: None,
81 }
82 }
83}
84
85impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
92 #[must_use]
94 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
95 AgentLoopBuilder {
96 provider: Some(provider),
97 tools: self.tools,
98 hooks: self.hooks,
99 message_store: self.message_store,
100 state_store: self.state_store,
101 config: self.config,
102 compaction_config: self.compaction_config,
103 }
104 }
105
106 #[must_use]
108 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
109 self.tools = Some(tools);
110 self
111 }
112
113 #[must_use]
115 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
116 AgentLoopBuilder {
117 provider: self.provider,
118 tools: self.tools,
119 hooks: Some(hooks),
120 message_store: self.message_store,
121 state_store: self.state_store,
122 config: self.config,
123 compaction_config: self.compaction_config,
124 }
125 }
126
127 #[must_use]
129 pub fn message_store<M2: MessageStore>(
130 self,
131 message_store: M2,
132 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
133 AgentLoopBuilder {
134 provider: self.provider,
135 tools: self.tools,
136 hooks: self.hooks,
137 message_store: Some(message_store),
138 state_store: self.state_store,
139 config: self.config,
140 compaction_config: self.compaction_config,
141 }
142 }
143
144 #[must_use]
146 pub fn state_store<S2: StateStore>(
147 self,
148 state_store: S2,
149 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
150 AgentLoopBuilder {
151 provider: self.provider,
152 tools: self.tools,
153 hooks: self.hooks,
154 message_store: self.message_store,
155 state_store: Some(state_store),
156 config: self.config,
157 compaction_config: self.compaction_config,
158 }
159 }
160
161 #[must_use]
163 pub fn config(mut self, config: AgentConfig) -> Self {
164 self.config = Some(config);
165 self
166 }
167
168 #[must_use]
184 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
185 self.compaction_config = Some(config);
186 self
187 }
188
189 #[must_use]
196 pub fn with_auto_compaction(self) -> Self {
197 self.with_compaction(CompactionConfig::default())
198 }
199
200 #[must_use]
218 pub fn with_skill(mut self, skill: Skill) -> Self
219 where
220 Ctx: Send + Sync + 'static,
221 {
222 if let Some(ref mut tools) = self.tools {
224 tools.filter(|name| skill.is_tool_allowed(name));
225 }
226
227 let mut config = self.config.take().unwrap_or_default();
229 if config.system_prompt.is_empty() {
230 config.system_prompt = skill.system_prompt;
231 } else {
232 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
233 }
234 self.config = Some(config);
235
236 self
237 }
238}
239
240impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
241where
242 Ctx: Send + Sync + 'static,
243 P: LlmProvider + 'static,
244{
245 #[must_use]
257 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
258 let provider = self.provider.expect("provider is required");
259 let tools = self.tools.unwrap_or_default();
260 let config = self.config.unwrap_or_default();
261
262 AgentLoop {
263 provider: Arc::new(provider),
264 tools: Arc::new(tools),
265 hooks: Arc::new(DefaultHooks),
266 message_store: Arc::new(InMemoryStore::new()),
267 state_store: Arc::new(InMemoryStore::new()),
268 config,
269 compaction_config: self.compaction_config,
270 }
271 }
272}
273
274impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
275where
276 Ctx: Send + Sync + 'static,
277 P: LlmProvider + 'static,
278 H: AgentHooks + 'static,
279 M: MessageStore + 'static,
280 S: StateStore + 'static,
281{
282 #[must_use]
292 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
293 let provider = self.provider.expect("provider is required");
294 let tools = self.tools.unwrap_or_default();
295 let hooks = self
296 .hooks
297 .expect("hooks is required when using build_with_stores");
298 let message_store = self
299 .message_store
300 .expect("message_store is required when using build_with_stores");
301 let state_store = self
302 .state_store
303 .expect("state_store is required when using build_with_stores");
304 let config = self.config.unwrap_or_default();
305
306 AgentLoop {
307 provider: Arc::new(provider),
308 tools: Arc::new(tools),
309 hooks: Arc::new(hooks),
310 message_store: Arc::new(message_store),
311 state_store: Arc::new(state_store),
312 config,
313 compaction_config: self.compaction_config,
314 }
315 }
316}
317
318pub struct AgentLoop<Ctx, P, H, M, S>
348where
349 P: LlmProvider,
350 H: AgentHooks,
351 M: MessageStore,
352 S: StateStore,
353{
354 provider: Arc<P>,
355 tools: Arc<ToolRegistry<Ctx>>,
356 hooks: Arc<H>,
357 message_store: Arc<M>,
358 state_store: Arc<S>,
359 config: AgentConfig,
360 compaction_config: Option<CompactionConfig>,
361}
362
363#[must_use]
365pub const fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
366 AgentLoopBuilder::new()
367}
368
369impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
370where
371 Ctx: Send + Sync + 'static,
372 P: LlmProvider + 'static,
373 H: AgentHooks + 'static,
374 M: MessageStore + 'static,
375 S: StateStore + 'static,
376{
377 #[must_use]
379 pub fn new(
380 provider: P,
381 tools: ToolRegistry<Ctx>,
382 hooks: H,
383 message_store: M,
384 state_store: S,
385 config: AgentConfig,
386 ) -> Self {
387 Self {
388 provider: Arc::new(provider),
389 tools: Arc::new(tools),
390 hooks: Arc::new(hooks),
391 message_store: Arc::new(message_store),
392 state_store: Arc::new(state_store),
393 config,
394 compaction_config: None,
395 }
396 }
397
398 #[must_use]
400 pub fn with_compaction(
401 provider: P,
402 tools: ToolRegistry<Ctx>,
403 hooks: H,
404 message_store: M,
405 state_store: S,
406 config: AgentConfig,
407 compaction_config: CompactionConfig,
408 ) -> Self {
409 Self {
410 provider: Arc::new(provider),
411 tools: Arc::new(tools),
412 hooks: Arc::new(hooks),
413 message_store: Arc::new(message_store),
414 state_store: Arc::new(state_store),
415 config,
416 compaction_config: Some(compaction_config),
417 }
418 }
419
420 pub fn run(
423 &self,
424 thread_id: ThreadId,
425 user_message: String,
426 tool_context: ToolContext<Ctx>,
427 ) -> mpsc::Receiver<AgentEvent>
428 where
429 Ctx: Clone,
430 {
431 let (tx, rx) = mpsc::channel(100);
432
433 let provider = Arc::clone(&self.provider);
434 let tools = Arc::clone(&self.tools);
435 let hooks = Arc::clone(&self.hooks);
436 let message_store = Arc::clone(&self.message_store);
437 let state_store = Arc::clone(&self.state_store);
438 let config = self.config.clone();
439 let compaction_config = self.compaction_config.clone();
440
441 tokio::spawn(async move {
442 let result = run_loop(
443 tx.clone(),
444 thread_id,
445 user_message,
446 tool_context,
447 provider,
448 tools,
449 hooks,
450 message_store,
451 state_store,
452 config,
453 compaction_config,
454 )
455 .await;
456
457 if let Err(e) = result {
458 let _ = tx.send(AgentEvent::error(e.to_string(), false)).await;
459 }
460 });
461
462 rx
463 }
464}
465
466#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
467async fn run_loop<Ctx, P, H, M, S>(
468 tx: mpsc::Sender<AgentEvent>,
469 thread_id: ThreadId,
470 user_message: String,
471 tool_context: ToolContext<Ctx>,
472 provider: Arc<P>,
473 tools: Arc<ToolRegistry<Ctx>>,
474 hooks: Arc<H>,
475 message_store: Arc<M>,
476 state_store: Arc<S>,
477 config: AgentConfig,
478 compaction_config: Option<CompactionConfig>,
479) -> Result<()>
480where
481 Ctx: Send + Sync + Clone + 'static,
482 P: LlmProvider,
483 H: AgentHooks,
484 M: MessageStore,
485 S: StateStore,
486{
487 let tool_context = tool_context.with_event_tx(tx.clone());
489
490 let start_time = Instant::now();
491 let mut turn = 0;
492 let mut total_usage = TokenUsage::default();
493
494 let mut state = state_store
496 .load(&thread_id)
497 .await?
498 .unwrap_or_else(|| AgentState::new(thread_id.clone()));
499
500 let user_msg = Message::user(&user_message);
502 message_store.append(&thread_id, user_msg).await?;
503
504 loop {
506 turn += 1;
507 state.turn_count = turn;
508
509 if turn > config.max_turns {
510 warn!(turn, max = config.max_turns, "Max turns reached");
511 tx.send(AgentEvent::error(
512 format!("Maximum turns ({}) reached", config.max_turns),
513 true,
514 ))
515 .await?;
516 break;
517 }
518
519 tx.send(AgentEvent::start(thread_id.clone(), turn)).await?;
521 hooks
522 .on_event(&AgentEvent::start(thread_id.clone(), turn))
523 .await;
524
525 let mut messages = message_store.get_history(&thread_id).await?;
527
528 if let Some(ref compact_config) = compaction_config {
530 let compactor = LlmContextCompactor::new(Arc::clone(&provider), compact_config.clone());
531 if compactor.needs_compaction(&messages) {
532 debug!(
533 turn,
534 message_count = messages.len(),
535 "Context compaction triggered"
536 );
537
538 match compactor.compact_history(messages).await {
539 Ok(result) => {
540 message_store
542 .replace_history(&thread_id, result.messages.clone())
543 .await?;
544
545 tx.send(AgentEvent::context_compacted(
547 result.original_count,
548 result.new_count,
549 result.original_tokens,
550 result.new_tokens,
551 ))
552 .await?;
553
554 info!(
555 original_count = result.original_count,
556 new_count = result.new_count,
557 original_tokens = result.original_tokens,
558 new_tokens = result.new_tokens,
559 "Context compacted successfully"
560 );
561
562 messages = result.messages;
564 }
565 Err(e) => {
566 warn!(error = %e, "Context compaction failed, continuing with full history");
567 messages = message_store.get_history(&thread_id).await?;
569 }
570 }
571 }
572 }
573
574 let llm_tools = if tools.is_empty() {
576 None
577 } else {
578 Some(tools.to_llm_tools())
579 };
580
581 let request = ChatRequest {
582 system: config.system_prompt.clone(),
583 messages,
584 tools: llm_tools,
585 max_tokens: config.max_tokens,
586 };
587
588 debug!(turn, "Calling LLM");
590 let max_retries = config.retry.max_retries;
591 let response = {
592 let mut attempt = 0u32;
593 loop {
594 let outcome = provider.chat(request.clone()).await?;
595 match outcome {
596 ChatOutcome::Success(response) => break Some(response),
597 ChatOutcome::RateLimited => {
598 attempt += 1;
599 if attempt > max_retries {
600 error!("Rate limited by LLM provider after {max_retries} retries");
601 tx.send(AgentEvent::error(
602 format!("Rate limited after {max_retries} retries"),
603 true,
604 ))
605 .await?;
606 break None;
607 }
608 let delay = calculate_backoff_delay(attempt, &config.retry);
609 warn!(
610 attempt,
611 delay_ms = delay.as_millis(),
612 "Rate limited, retrying after backoff"
613 );
614 tx.send(AgentEvent::text(format!(
615 "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
616 delay.as_secs_f64()
617 )))
618 .await?;
619 sleep(delay).await;
620 }
621 ChatOutcome::InvalidRequest(msg) => {
622 error!(msg, "Invalid request to LLM");
623 tx.send(AgentEvent::error(format!("Invalid request: {msg}"), false))
624 .await?;
625 break None;
626 }
627 ChatOutcome::ServerError(msg) => {
628 attempt += 1;
629 if attempt > max_retries {
630 error!(msg, "LLM server error after {max_retries} retries");
631 tx.send(AgentEvent::error(
632 format!("Server error after {max_retries} retries: {msg}"),
633 true,
634 ))
635 .await?;
636 break None;
637 }
638 let delay = calculate_backoff_delay(attempt, &config.retry);
639 warn!(
640 attempt,
641 delay_ms = delay.as_millis(),
642 error = msg,
643 "Server error, retrying after backoff"
644 );
645 tx.send(AgentEvent::text(format!(
646 "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
647 delay.as_secs_f64()
648 )))
649 .await?;
650 sleep(delay).await;
651 }
652 }
653 }
654 };
655
656 let Some(response) = response else {
658 break;
659 };
660
661 let turn_usage = TokenUsage {
663 input_tokens: response.usage.input_tokens,
664 output_tokens: response.usage.output_tokens,
665 };
666 total_usage.add(&turn_usage);
667 state.total_usage = total_usage.clone();
668
669 let (text_content, tool_uses) = extract_content(&response);
671
672 if let Some(text) = &text_content {
674 tx.send(AgentEvent::text(text.clone())).await?;
675 hooks.on_event(&AgentEvent::text(text.clone())).await;
676 }
677
678 if tool_uses.is_empty() {
680 info!(turn, "Agent completed (no tool use)");
681 break;
682 }
683
684 let assistant_msg = build_assistant_message(&response);
686 message_store.append(&thread_id, assistant_msg).await?;
687
688 let mut tool_results = Vec::new();
690 for (tool_id, tool_name, tool_input) in &tool_uses {
691 let Some(tool) = tools.get(tool_name) else {
692 let result = ToolResult::error(format!("Unknown tool: {tool_name}"));
693 tool_results.push((tool_id.clone(), result));
694 continue;
695 };
696
697 let tier = tool.tier();
698
699 tx.send(AgentEvent::tool_call_start(
701 tool_id,
702 tool_name,
703 tool_input.clone(),
704 tier,
705 ))
706 .await?;
707
708 let decision = hooks.pre_tool_use(tool_name, tool_input, tier).await;
710
711 match decision {
712 ToolDecision::Allow => {
713 let tool_start = Instant::now();
715 let result = match tool.execute(&tool_context, tool_input.clone()).await {
716 Ok(mut r) => {
717 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
718 r
719 }
720 Err(e) => ToolResult::error(format!("Tool error: {e}"))
721 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
722 };
723
724 hooks.post_tool_use(tool_name, &result).await;
726
727 tx.send(AgentEvent::tool_call_end(
729 tool_id,
730 tool_name,
731 result.clone(),
732 ))
733 .await?;
734
735 tool_results.push((tool_id.clone(), result));
736 }
737 ToolDecision::Block(reason) => {
738 let result = ToolResult::error(format!("Blocked: {reason}"));
739 tx.send(AgentEvent::tool_call_end(
740 tool_id,
741 tool_name,
742 result.clone(),
743 ))
744 .await?;
745 tool_results.push((tool_id.clone(), result));
746 }
747 ToolDecision::RequiresConfirmation(description) => {
748 tx.send(AgentEvent::ToolRequiresConfirmation {
749 id: tool_id.clone(),
750 name: tool_name.clone(),
751 input: tool_input.clone(),
752 description,
753 })
754 .await?;
755 let result = ToolResult::error("Awaiting user confirmation");
757 tool_results.push((tool_id.clone(), result));
758 }
759 ToolDecision::RequiresPin(description) => {
760 tx.send(AgentEvent::ToolRequiresPin {
761 id: tool_id.clone(),
762 name: tool_name.clone(),
763 input: tool_input.clone(),
764 description,
765 })
766 .await?;
767 let result = ToolResult::error("Awaiting PIN verification");
769 tool_results.push((tool_id.clone(), result));
770 }
771 }
772 }
773
774 for (tool_id, result) in &tool_results {
776 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
777 message_store.append(&thread_id, tool_result_msg).await?;
778 }
779
780 tx.send(AgentEvent::TurnComplete {
782 turn,
783 usage: turn_usage,
784 })
785 .await?;
786
787 if response.stop_reason == Some(StopReason::EndTurn) {
789 info!(turn, "Agent completed (end_turn)");
790 break;
791 }
792
793 state_store.save(&state).await?;
795 }
796
797 state_store.save(&state).await?;
799
800 let duration = start_time.elapsed();
802 tx.send(AgentEvent::done(thread_id, turn, total_usage, duration))
803 .await?;
804
805 Ok(())
806}
807
/// Narrows a millisecond count (as returned by `Duration::as_millis`) to `u64`,
/// saturating at `u64::MAX` instead of truncating.
// The `as u64` cast is guarded by the range check, so it cannot truncate.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    if millis <= u64::MAX as u128 {
        millis as u64
    } else {
        u64::MAX
    }
}
817
818fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
823 let base_delay = config
825 .base_delay_ms
826 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
827
828 let max_jitter = config.base_delay_ms.min(1000);
830 let jitter = if max_jitter > 0 {
831 u64::from(
832 std::time::SystemTime::now()
833 .duration_since(std::time::UNIX_EPOCH)
834 .unwrap_or_default()
835 .subsec_nanos(),
836 ) % max_jitter
837 } else {
838 0
839 };
840
841 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
842 Duration::from_millis(delay_ms)
843}
844
845fn extract_content(
846 response: &ChatResponse,
847) -> (Option<String>, Vec<(String, String, serde_json::Value)>) {
848 let mut text_parts = Vec::new();
849 let mut tool_uses = Vec::new();
850
851 for block in &response.content {
852 match block {
853 ContentBlock::Text { text } => {
854 text_parts.push(text.clone());
855 }
856 ContentBlock::ToolUse { id, name, input } => {
857 tool_uses.push((id.clone(), name.clone(), input.clone()));
858 }
859 ContentBlock::ToolResult { .. } => {
860 }
862 }
863 }
864
865 let text = if text_parts.is_empty() {
866 None
867 } else {
868 Some(text_parts.join("\n"))
869 };
870
871 (text, tool_uses)
872}
873
874fn build_assistant_message(response: &ChatResponse) -> Message {
875 let mut blocks = Vec::new();
876
877 for block in &response.content {
878 match block {
879 ContentBlock::Text { text } => {
880 blocks.push(ContentBlock::Text { text: text.clone() });
881 }
882 ContentBlock::ToolUse { id, name, input } => {
883 blocks.push(ContentBlock::ToolUse {
884 id: id.clone(),
885 name: name.clone(),
886 input: input.clone(),
887 });
888 }
889 ContentBlock::ToolResult { .. } => {}
890 }
891 }
892
893 Message {
894 role: Role::Assistant,
895 content: Content::Blocks(blocks),
896 }
897}
898
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::AllowAllHooks;
    use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
    use crate::stores::InMemoryStore;
    use crate::tools::{Tool, ToolContext, ToolRegistry};
    use crate::types::{AgentConfig, ToolResult, ToolTier};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scripted provider. The N-th `chat` call returns the N-th canned outcome;
    /// once the script runs out, every call returns a plain `"Done"` text reply.
    struct MockProvider {
        // Never mutated after construction, so no lock is needed. The atomic
        // counter is the only state shared across calls.
        responses: Vec<ChatOutcome>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatOutcome>) -> Self {
            Self {
                responses,
                call_count: AtomicUsize::new(0),
            }
        }

        /// A successful text-only reply that ends the turn.
        fn text_response(text: &str) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content: vec![ContentBlock::Text {
                    text: text.to_string(),
                }],
                model: "mock-model".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }

        /// A successful reply asking to call `tool_name` with `input`.
        fn tool_use_response(
            tool_id: &str,
            tool_name: &str,
            input: serde_json::Value,
        ) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content: vec![ContentBlock::ToolUse {
                    id: tool_id.to_string(),
                    name: tool_name.to_string(),
                    input,
                }],
                model: "mock-model".to_string(),
                stop_reason: Some(StopReason::ToolUse),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(self
                .responses
                .get(idx)
                .cloned()
                .unwrap_or_else(|| Self::text_response("Done")))
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    // Test-only `Clone` so the mock can hand out copies of its scripted outcomes.
    impl Clone for ChatOutcome {
        fn clone(&self) -> Self {
            match self {
                Self::Success(r) => Self::Success(r.clone()),
                Self::RateLimited => Self::RateLimited,
                Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
                Self::ServerError(s) => Self::ServerError(s.clone()),
            }
        }
    }

    /// Minimal tool that echoes its `message` input back.
    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn description(&self) -> &'static str {
            "Echo the input message"
        }

        fn input_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
        }

        fn tier(&self) -> ToolTier {
            ToolTier::Observe
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            input: serde_json::Value,
        ) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Echo: {message}")))
        }
    }

    // The builder falls back to the default config when none is supplied.
    #[test]
    fn test_builder_creates_agent_loop() {
        let provider = MockProvider::new(vec![]);
        let agent = builder::<()>().provider(provider).build();

        assert_eq!(agent.config.max_turns, 10);
        assert_eq!(agent.config.max_tokens, 4096);
    }

    #[test]
    fn test_builder_with_custom_config() {
        let provider = MockProvider::new(vec![]);
        let config = AgentConfig {
            max_turns: 5,
            max_tokens: 2048,
            system_prompt: "Custom prompt".to_string(),
            model: "custom-model".to_string(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        assert_eq!(agent.config.max_turns, 5);
        assert_eq!(agent.config.max_tokens, 2048);
        assert_eq!(agent.config.system_prompt, "Custom prompt");
    }

    #[test]
    fn test_builder_with_tools() {
        let provider = MockProvider::new(vec![]);
        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        assert_eq!(agent.tools.len(), 1);
    }

    #[test]
    fn test_builder_with_custom_stores() {
        let provider = MockProvider::new(vec![]);
        let message_store = InMemoryStore::new();
        let state_store = InMemoryStore::new();

        let agent = builder::<()>()
            .provider(provider)
            .hooks(AllowAllHooks)
            .message_store(message_store)
            .state_store(state_store)
            .build_with_stores();

        assert_eq!(agent.config.max_turns, 10);
    }

    // A plain text reply produces a text event and finishes with `Done`.
    #[tokio::test]
    async fn test_simple_text_response() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);

        let agent = builder::<()>().provider(provider).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        Ok(())
    }

    // A tool-use reply triggers tool start/end events before the final answer.
    #[tokio::test]
    async fn test_tool_execution() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
            MockProvider::text_response("Tool executed successfully"),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Run echo".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
        );

        Ok(())
    }

    // A model that keeps calling tools is stopped by `max_turns`.
    #[tokio::test]
    async fn test_max_turns_limit() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
            MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
            MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
            MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let config = AgentConfig {
            max_turns: 2,
            ..Default::default()
        };

        let agent = builder::<()>()
            .provider(provider)
            .tools(tools)
            .config(config)
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Loop".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
        }));

        Ok(())
    }

    // Calling an unregistered tool yields an error result, not a failed run.
    #[tokio::test]
    async fn test_unknown_tool_handling() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
            MockProvider::text_response("I couldn't find that tool."),
        ]);

        let tools = ToolRegistry::new();

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Call unknown".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        assert!(
            events.iter().any(|e| {
                matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
            })
        );

        Ok(())
    }

    // Persistent rate limiting ends in a recoverable error after retries.
    #[tokio::test]
    async fn test_rate_limit_handling() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
        ]);

        let config = AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
        }));

        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    // A single rate limit is retried and the run completes.
    #[tokio::test]
    async fn test_rate_limit_recovery() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            ChatOutcome::RateLimited,
            MockProvider::text_response("Recovered after rate limit"),
        ]);

        let config = AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    // Persistent server errors end in a recoverable error after retries.
    #[tokio::test]
    async fn test_server_error_handling() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
        ]);

        let config = AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
        }));

        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    // A single server error is retried and the run completes.
    #[tokio::test]
    async fn test_server_error_recovery() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            ChatOutcome::ServerError("Temporary error".to_string()),
            MockProvider::text_response("Recovered after server error"),
        ]);

        let config = AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[test]
    fn test_extract_content_text_only() {
        let response = ChatResponse {
            id: "msg_1".to_string(),
            content: vec![ContentBlock::Text {
                text: "Hello".to_string(),
            }],
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        };

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Hello".to_string()));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_tool_use() {
        let response = ChatResponse {
            id: "msg_1".to_string(),
            content: vec![ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "test_tool".to_string(),
                input: json!({"key": "value"}),
            }],
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        };

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert_eq!(tool_uses.len(), 1);
        assert_eq!(tool_uses[0].1, "test_tool");
    }

    #[test]
    fn test_extract_content_mixed() {
        let response = ChatResponse {
            id: "msg_1".to_string(),
            content: vec![
                ContentBlock::Text {
                    text: "Let me help".to_string(),
                },
                ContentBlock::ToolUse {
                    id: "tool_1".to_string(),
                    name: "helper".to_string(),
                    input: json!({}),
                },
            ],
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        };

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Let me help".to_string()));
        assert_eq!(tool_uses.len(), 1);
    }

    #[test]
    fn test_millis_to_u64() {
        assert_eq!(millis_to_u64(0), 0);
        assert_eq!(millis_to_u64(1000), 1000);
        assert_eq!(millis_to_u64(u128::from(u64::MAX)), u64::MAX);
        assert_eq!(millis_to_u64(u128::from(u64::MAX) + 1), u64::MAX);
    }

    #[test]
    fn test_build_assistant_message() {
        let response = ChatResponse {
            id: "msg_1".to_string(),
            content: vec![
                ContentBlock::Text {
                    text: "Response text".to_string(),
                },
                ContentBlock::ToolUse {
                    id: "tool_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"message": "test"}),
                },
            ],
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        };

        let msg = build_assistant_message(&response);
        assert_eq!(msg.role, Role::Assistant);

        if let Content::Blocks(blocks) = msg.content {
            assert_eq!(blocks.len(), 2);
        } else {
            panic!("Expected Content::Blocks");
        }
    }
}