1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore, ToolExecutionStore};
39use crate::tools::{ErasedAsyncTool, ErasedToolStatus, ToolContext, ToolRegistry};
40use crate::types::{
41 AgentConfig, AgentContinuation, AgentError, AgentInput, AgentRunState, AgentState,
42 ExecutionStatus, PendingToolCallInfo, RetryConfig, ThreadId, TokenUsage, ToolExecution,
43 ToolOutcome, ToolResult, TurnOutcome,
44};
45use futures::StreamExt;
46use std::sync::Arc;
47use std::time::{Duration, Instant};
48use tokio::sync::mpsc;
49use tokio::time::sleep;
50use tracing::{debug, error, info, warn};
51
/// Outcome of one `execute_turn` call, consumed by `run_loop` and `run_single_turn`.
enum InternalTurnResult {
    /// The turn finished and the conversation is not over: `run_loop` checkpoints the state and
    /// runs another turn, `run_single_turn` reports `TurnOutcome::NeedsMoreTurns`.
    Continue { turn_usage: TokenUsage },
    /// The run is complete; callers save the final state and emit the done event.
    Done,
    /// A tool call is waiting for confirmation. The fields are forwarded unchanged to the
    /// public `AwaitingConfirmation` variants, `continuation` being what a later
    /// `AgentInput::Resume` hands back.
    AwaitingConfirmation {
        tool_call_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        description: String,
        continuation: Box<AgentContinuation>,
    },
    /// The turn failed; the error is returned to the caller as-is.
    Error(AgentError),
}
72
/// Bookkeeping for a run, passed by mutable reference to `execute_turn`.
struct TurnContext {
    /// Thread the run belongs to.
    thread_id: ThreadId,
    /// Turn counter; reported in the done event and as `total_turns`.
    turn: usize,
    /// Token usage accumulated over the run; reported in the done event.
    total_usage: TokenUsage,
    /// Agent state that is persisted through the state store after a turn.
    state: AgentState,
    /// When the run started; used to compute the duration reported in the done event.
    start_time: Instant,
}
83
/// Payload of an `AgentInput::Resume`, kept together until the resume path consumes it.
struct ResumeData {
    /// Continuation captured when the run paused for confirmation.
    continuation: Box<AgentContinuation>,
    /// Id of the tool call the caller is answering; `run_loop` checks it against the
    /// continuation's awaiting tool call.
    tool_call_id: String,
    /// `true` when the tool call was approved, `false` when it was rejected.
    confirmed: bool,
    /// Optional rejection reason; a default text is used when it is `None`.
    rejection_reason: Option<String>,
}
91
/// Starting point of a run as computed by `initialize_from_input`.
struct InitializedState {
    /// Turn counter to start from (0 for new text input, restored otherwise).
    turn: usize,
    /// Token usage to start from (default for new text input, restored otherwise).
    total_usage: TokenUsage,
    /// Agent state to start from.
    state: AgentState,
    /// Present only when the input was `AgentInput::Resume`.
    resume_data: Option<ResumeData>,
}
99
/// Result of `execute_tool_call` for a single pending tool call.
enum ToolExecutionOutcome {
    /// The call reached a final result: it ran, was blocked by the hooks, or named an
    /// unknown tool (the latter two produce an error `ToolResult`).
    Completed { tool_id: String, result: ToolResult },
    /// The hooks asked for confirmation; the tool has not been executed.
    RequiresConfirmation {
        tool_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        description: String,
    },
}
113
/// Builder for [`AgentLoop`].
///
/// `P`, `H`, `M` and `S` track the types of the provider, hooks, message store and state store
/// supplied so far. A fresh builder uses `()` for all four and each typed setter swaps in the
/// concrete type.
pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
    /// LLM provider; both build methods panic when it is missing.
    provider: Option<P>,
    /// Tool registry; the build methods fall back to the registry's `Default`.
    tools: Option<ToolRegistry<Ctx>>,
    /// Hooks; required by `build_with_stores`, replaced by `DefaultHooks` in `build`.
    hooks: Option<H>,
    /// Message store; required by `build_with_stores`, in-memory in `build`.
    message_store: Option<M>,
    /// State store; required by `build_with_stores`, in-memory in `build`.
    state_store: Option<S>,
    /// Agent configuration; the build methods fall back to `AgentConfig::default()`.
    config: Option<AgentConfig>,
    /// Optional compaction settings, forwarded unchanged to the built loop.
    compaction_config: Option<CompactionConfig>,
    /// Optional tool execution store, forwarded unchanged to the built loop.
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
135
impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
    /// Creates an empty builder: every slot starts out as `None`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            provider: None,
            tools: None,
            hooks: None,
            message_store: None,
            state_store: None,
            config: None,
            compaction_config: None,
            execution_store: None,
        }
    }
}
152
impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
    /// Equivalent to [`AgentLoopBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
158
159impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
160 #[must_use]
162 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
163 AgentLoopBuilder {
164 provider: Some(provider),
165 tools: self.tools,
166 hooks: self.hooks,
167 message_store: self.message_store,
168 state_store: self.state_store,
169 config: self.config,
170 compaction_config: self.compaction_config,
171 execution_store: self.execution_store,
172 }
173 }
174
175 #[must_use]
177 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
178 self.tools = Some(tools);
179 self
180 }
181
182 #[must_use]
184 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
185 AgentLoopBuilder {
186 provider: self.provider,
187 tools: self.tools,
188 hooks: Some(hooks),
189 message_store: self.message_store,
190 state_store: self.state_store,
191 config: self.config,
192 compaction_config: self.compaction_config,
193 execution_store: self.execution_store,
194 }
195 }
196
197 #[must_use]
199 pub fn message_store<M2: MessageStore>(
200 self,
201 message_store: M2,
202 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
203 AgentLoopBuilder {
204 provider: self.provider,
205 tools: self.tools,
206 hooks: self.hooks,
207 message_store: Some(message_store),
208 state_store: self.state_store,
209 config: self.config,
210 compaction_config: self.compaction_config,
211 execution_store: self.execution_store,
212 }
213 }
214
215 #[must_use]
217 pub fn state_store<S2: StateStore>(
218 self,
219 state_store: S2,
220 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
221 AgentLoopBuilder {
222 provider: self.provider,
223 tools: self.tools,
224 hooks: self.hooks,
225 message_store: self.message_store,
226 state_store: Some(state_store),
227 config: self.config,
228 compaction_config: self.compaction_config,
229 execution_store: self.execution_store,
230 }
231 }
232
233 #[must_use]
251 pub fn execution_store(mut self, store: impl ToolExecutionStore + 'static) -> Self {
252 self.execution_store = Some(Arc::new(store));
253 self
254 }
255
256 #[must_use]
258 pub fn config(mut self, config: AgentConfig) -> Self {
259 self.config = Some(config);
260 self
261 }
262
263 #[must_use]
279 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
280 self.compaction_config = Some(config);
281 self
282 }
283
284 #[must_use]
291 pub fn with_auto_compaction(self) -> Self {
292 self.with_compaction(CompactionConfig::default())
293 }
294
295 #[must_use]
313 pub fn with_skill(mut self, skill: Skill) -> Self
314 where
315 Ctx: Send + Sync + 'static,
316 {
317 if let Some(ref mut tools) = self.tools {
319 tools.filter(|name| skill.is_tool_allowed(name));
320 }
321
322 let mut config = self.config.take().unwrap_or_default();
324 if config.system_prompt.is_empty() {
325 config.system_prompt = skill.system_prompt;
326 } else {
327 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
328 }
329 self.config = Some(config);
330
331 self
332 }
333}
334
335impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
336where
337 Ctx: Send + Sync + 'static,
338 P: LlmProvider + 'static,
339{
340 #[must_use]
352 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
353 let provider = self.provider.expect("provider is required");
354 let tools = self.tools.unwrap_or_default();
355 let config = self.config.unwrap_or_default();
356
357 AgentLoop {
358 provider: Arc::new(provider),
359 tools: Arc::new(tools),
360 hooks: Arc::new(DefaultHooks),
361 message_store: Arc::new(InMemoryStore::new()),
362 state_store: Arc::new(InMemoryStore::new()),
363 config,
364 compaction_config: self.compaction_config,
365 execution_store: self.execution_store,
366 }
367 }
368}
369
370impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
371where
372 Ctx: Send + Sync + 'static,
373 P: LlmProvider + 'static,
374 H: AgentHooks + 'static,
375 M: MessageStore + 'static,
376 S: StateStore + 'static,
377{
378 #[must_use]
388 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
389 let provider = self.provider.expect("provider is required");
390 let tools = self.tools.unwrap_or_default();
391 let hooks = self
392 .hooks
393 .expect("hooks is required when using build_with_stores");
394 let message_store = self
395 .message_store
396 .expect("message_store is required when using build_with_stores");
397 let state_store = self
398 .state_store
399 .expect("state_store is required when using build_with_stores");
400 let config = self.config.unwrap_or_default();
401
402 AgentLoop {
403 provider: Arc::new(provider),
404 tools: Arc::new(tools),
405 hooks: Arc::new(hooks),
406 message_store: Arc::new(message_store),
407 state_store: Arc::new(state_store),
408 config,
409 compaction_config: self.compaction_config,
410 execution_store: self.execution_store,
411 }
412 }
413}
414
/// The agent loop: an LLM provider, a tool registry, hooks and stores bundled with the
/// configuration used to drive a conversation thread.
///
/// Provider, tools, hooks and stores are held behind `Arc`s so each run can take its own
/// handle to them.
pub struct AgentLoop<Ctx, P, H, M, S>
where
    P: LlmProvider,
    H: AgentHooks,
    M: MessageStore,
    S: StateStore,
{
    /// LLM provider handed to every run.
    provider: Arc<P>,
    /// Registry of tools available to the agent.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks handed to every run (tool authorisation, post-tool notification).
    hooks: Arc<H>,
    /// Store receiving the conversation messages.
    message_store: Arc<M>,
    /// Store used to load and save the agent state.
    state_store: Arc<S>,
    /// Agent configuration, cloned for every run.
    config: AgentConfig,
    /// Compaction settings; `None` when compaction was not configured.
    compaction_config: Option<CompactionConfig>,
    /// Optional tool execution store passed through to turn execution.
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
464
465#[must_use]
467pub fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
468 AgentLoopBuilder::new()
469}
470
471impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
472where
473 Ctx: Send + Sync + 'static,
474 P: LlmProvider + 'static,
475 H: AgentHooks + 'static,
476 M: MessageStore + 'static,
477 S: StateStore + 'static,
478{
479 #[must_use]
481 pub fn new(
482 provider: P,
483 tools: ToolRegistry<Ctx>,
484 hooks: H,
485 message_store: M,
486 state_store: S,
487 config: AgentConfig,
488 ) -> Self {
489 Self {
490 provider: Arc::new(provider),
491 tools: Arc::new(tools),
492 hooks: Arc::new(hooks),
493 message_store: Arc::new(message_store),
494 state_store: Arc::new(state_store),
495 config,
496 compaction_config: None,
497 execution_store: None,
498 }
499 }
500
501 #[must_use]
503 pub fn with_compaction(
504 provider: P,
505 tools: ToolRegistry<Ctx>,
506 hooks: H,
507 message_store: M,
508 state_store: S,
509 config: AgentConfig,
510 compaction_config: CompactionConfig,
511 ) -> Self {
512 Self {
513 provider: Arc::new(provider),
514 tools: Arc::new(tools),
515 hooks: Arc::new(hooks),
516 message_store: Arc::new(message_store),
517 state_store: Arc::new(state_store),
518 config,
519 compaction_config: Some(compaction_config),
520 execution_store: None,
521 }
522 }
523
524 pub fn run(
574 &self,
575 thread_id: ThreadId,
576 input: AgentInput,
577 tool_context: ToolContext<Ctx>,
578 ) -> (
579 mpsc::Receiver<AgentEvent>,
580 tokio::sync::oneshot::Receiver<AgentRunState>,
581 )
582 where
583 Ctx: Clone,
584 {
585 let (event_tx, event_rx) = mpsc::channel(100);
586 let (state_tx, state_rx) = tokio::sync::oneshot::channel();
587
588 let provider = Arc::clone(&self.provider);
589 let tools = Arc::clone(&self.tools);
590 let hooks = Arc::clone(&self.hooks);
591 let message_store = Arc::clone(&self.message_store);
592 let state_store = Arc::clone(&self.state_store);
593 let config = self.config.clone();
594 let compaction_config = self.compaction_config.clone();
595 let execution_store = self.execution_store.clone();
596
597 tokio::spawn(async move {
598 let result = run_loop(
599 event_tx,
600 thread_id,
601 input,
602 tool_context,
603 provider,
604 tools,
605 hooks,
606 message_store,
607 state_store,
608 config,
609 compaction_config,
610 execution_store,
611 )
612 .await;
613
614 let _ = state_tx.send(result);
615 });
616
617 (event_rx, state_rx)
618 }
619
620 pub async fn run_turn(
676 &self,
677 thread_id: ThreadId,
678 input: AgentInput,
679 tool_context: ToolContext<Ctx>,
680 ) -> (mpsc::Receiver<AgentEvent>, TurnOutcome)
681 where
682 Ctx: Clone,
683 {
684 let (event_tx, event_rx) = mpsc::channel(100);
685
686 let provider = Arc::clone(&self.provider);
687 let tools = Arc::clone(&self.tools);
688 let hooks = Arc::clone(&self.hooks);
689 let message_store = Arc::clone(&self.message_store);
690 let state_store = Arc::clone(&self.state_store);
691 let config = self.config.clone();
692 let compaction_config = self.compaction_config.clone();
693 let execution_store = self.execution_store.clone();
694
695 let result = run_single_turn(TurnParameters {
696 tx: event_tx,
697 thread_id,
698 input,
699 tool_context,
700 provider,
701 tools,
702 hooks,
703 message_store,
704 state_store,
705 config,
706 compaction_config,
707 execution_store,
708 })
709 .await;
710
711 (event_rx, result)
712 }
713}
714
715async fn initialize_from_input<M, S>(
726 input: AgentInput,
727 thread_id: &ThreadId,
728 message_store: &Arc<M>,
729 state_store: &Arc<S>,
730) -> Result<InitializedState, AgentError>
731where
732 M: MessageStore,
733 S: StateStore,
734{
735 match input {
736 AgentInput::Text(user_message) => {
737 let state = match state_store.load(thread_id).await {
739 Ok(Some(s)) => s,
740 Ok(None) => AgentState::new(thread_id.clone()),
741 Err(e) => {
742 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
743 }
744 };
745
746 let user_msg = Message::user(&user_message);
748 if let Err(e) = message_store.append(thread_id, user_msg).await {
749 return Err(AgentError::new(
750 format!("Failed to append message: {e}"),
751 false,
752 ));
753 }
754
755 Ok(InitializedState {
756 turn: 0,
757 total_usage: TokenUsage::default(),
758 state,
759 resume_data: None,
760 })
761 }
762 AgentInput::Resume {
763 continuation,
764 tool_call_id,
765 confirmed,
766 rejection_reason,
767 } => {
768 if continuation.thread_id != *thread_id {
770 return Err(AgentError::new(
771 format!(
772 "Thread ID mismatch: continuation is for {}, but resuming on {}",
773 continuation.thread_id, thread_id
774 ),
775 false,
776 ));
777 }
778
779 Ok(InitializedState {
780 turn: continuation.turn,
781 total_usage: continuation.total_usage.clone(),
782 state: continuation.state.clone(),
783 resume_data: Some(ResumeData {
784 continuation,
785 tool_call_id,
786 confirmed,
787 rejection_reason,
788 }),
789 })
790 }
791 AgentInput::Continue => {
792 let state = match state_store.load(thread_id).await {
794 Ok(Some(s)) => s,
795 Ok(None) => {
796 return Err(AgentError::new(
797 "Cannot continue: no state found for thread",
798 false,
799 ));
800 }
801 Err(e) => {
802 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
803 }
804 };
805
806 Ok(InitializedState {
808 turn: state.turn_count,
809 total_usage: state.total_usage.clone(),
810 state,
811 resume_data: None,
812 })
813 }
814 }
815}
816
817#[allow(clippy::too_many_lines)]
826async fn execute_tool_call<Ctx, H>(
827 pending: &PendingToolCallInfo,
828 tool_context: &ToolContext<Ctx>,
829 tools: &ToolRegistry<Ctx>,
830 hooks: &Arc<H>,
831 tx: &mpsc::Sender<AgentEvent>,
832) -> ToolExecutionOutcome
833where
834 Ctx: Send + Sync + Clone + 'static,
835 H: AgentHooks,
836{
837 if let Some(async_tool) = tools.get_async(&pending.name) {
839 let tier = async_tool.tier();
840
841 let _ = tx
843 .send(AgentEvent::tool_call_start(
844 &pending.id,
845 &pending.name,
846 &pending.display_name,
847 pending.input.clone(),
848 tier,
849 ))
850 .await;
851
852 let decision = hooks
854 .pre_tool_use(&pending.name, &pending.input, tier)
855 .await;
856
857 return match decision {
858 ToolDecision::Allow => {
859 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
860
861 hooks.post_tool_use(&pending.name, &result).await;
862
863 send_event(
864 tx,
865 hooks,
866 AgentEvent::tool_call_end(
867 &pending.id,
868 &pending.name,
869 &pending.display_name,
870 result.clone(),
871 ),
872 )
873 .await;
874
875 ToolExecutionOutcome::Completed {
876 tool_id: pending.id.clone(),
877 result,
878 }
879 }
880 ToolDecision::Block(reason) => {
881 let result = ToolResult::error(format!("Blocked: {reason}"));
882 send_event(
883 tx,
884 hooks,
885 AgentEvent::tool_call_end(
886 &pending.id,
887 &pending.name,
888 &pending.display_name,
889 result.clone(),
890 ),
891 )
892 .await;
893 ToolExecutionOutcome::Completed {
894 tool_id: pending.id.clone(),
895 result,
896 }
897 }
898 ToolDecision::RequiresConfirmation(description) => {
899 send_event(
900 tx,
901 hooks,
902 AgentEvent::ToolRequiresConfirmation {
903 id: pending.id.clone(),
904 name: pending.name.clone(),
905 input: pending.input.clone(),
906 description: description.clone(),
907 },
908 )
909 .await;
910
911 ToolExecutionOutcome::RequiresConfirmation {
912 tool_id: pending.id.clone(),
913 tool_name: pending.name.clone(),
914 display_name: pending.display_name.clone(),
915 input: pending.input.clone(),
916 description,
917 }
918 }
919 };
920 }
921
922 let Some(tool) = tools.get(&pending.name) else {
924 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
925 return ToolExecutionOutcome::Completed {
926 tool_id: pending.id.clone(),
927 result,
928 };
929 };
930
931 let tier = tool.tier();
932
933 let _ = tx
935 .send(AgentEvent::tool_call_start(
936 &pending.id,
937 &pending.name,
938 &pending.display_name,
939 pending.input.clone(),
940 tier,
941 ))
942 .await;
943
944 let decision = hooks
946 .pre_tool_use(&pending.name, &pending.input, tier)
947 .await;
948
949 match decision {
950 ToolDecision::Allow => {
951 let tool_start = Instant::now();
952 let result = match tool.execute(tool_context, pending.input.clone()).await {
953 Ok(mut r) => {
954 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
955 r
956 }
957 Err(e) => ToolResult::error(format!("Tool error: {e}"))
958 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
959 };
960
961 hooks.post_tool_use(&pending.name, &result).await;
962
963 send_event(
964 tx,
965 hooks,
966 AgentEvent::tool_call_end(
967 &pending.id,
968 &pending.name,
969 &pending.display_name,
970 result.clone(),
971 ),
972 )
973 .await;
974
975 ToolExecutionOutcome::Completed {
976 tool_id: pending.id.clone(),
977 result,
978 }
979 }
980 ToolDecision::Block(reason) => {
981 let result = ToolResult::error(format!("Blocked: {reason}"));
982 send_event(
983 tx,
984 hooks,
985 AgentEvent::tool_call_end(
986 &pending.id,
987 &pending.name,
988 &pending.display_name,
989 result.clone(),
990 ),
991 )
992 .await;
993 ToolExecutionOutcome::Completed {
994 tool_id: pending.id.clone(),
995 result,
996 }
997 }
998 ToolDecision::RequiresConfirmation(description) => {
999 send_event(
1000 tx,
1001 hooks,
1002 AgentEvent::ToolRequiresConfirmation {
1003 id: pending.id.clone(),
1004 name: pending.name.clone(),
1005 input: pending.input.clone(),
1006 description: description.clone(),
1007 },
1008 )
1009 .await;
1010
1011 ToolExecutionOutcome::RequiresConfirmation {
1012 tool_id: pending.id.clone(),
1013 tool_name: pending.name.clone(),
1014 display_name: pending.display_name.clone(),
1015 input: pending.input.clone(),
1016 description,
1017 }
1018 }
1019 }
1020}
1021
1022async fn execute_async_tool<Ctx>(
1028 pending: &PendingToolCallInfo,
1029 tool: &Arc<dyn ErasedAsyncTool<Ctx>>,
1030 tool_context: &ToolContext<Ctx>,
1031 tx: &mpsc::Sender<AgentEvent>,
1032) -> ToolResult
1033where
1034 Ctx: Send + Sync + Clone,
1035{
1036 let tool_start = Instant::now();
1037
1038 let outcome = match tool.execute(tool_context, pending.input.clone()).await {
1040 Ok(o) => o,
1041 Err(e) => {
1042 return ToolResult::error(format!("Tool error: {e}"))
1043 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()));
1044 }
1045 };
1046
1047 match outcome {
1048 ToolOutcome::Success(mut result) | ToolOutcome::Failed(mut result) => {
1050 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1051 result
1052 }
1053
1054 ToolOutcome::InProgress {
1056 operation_id,
1057 message,
1058 } => {
1059 let _ = tx
1061 .send(AgentEvent::tool_progress(
1062 &pending.id,
1063 &pending.name,
1064 &pending.display_name,
1065 "started",
1066 &message,
1067 None,
1068 ))
1069 .await;
1070
1071 let mut stream = tool.check_status_stream(tool_context, &operation_id);
1073
1074 while let Some(status) = stream.next().await {
1075 match status {
1076 ErasedToolStatus::Progress {
1077 stage,
1078 message,
1079 data,
1080 } => {
1081 let _ = tx
1082 .send(AgentEvent::tool_progress(
1083 &pending.id,
1084 &pending.name,
1085 &pending.display_name,
1086 stage,
1087 message,
1088 data,
1089 ))
1090 .await;
1091 }
1092 ErasedToolStatus::Completed(mut result)
1093 | ErasedToolStatus::Failed(mut result) => {
1094 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1095 return result;
1096 }
1097 }
1098 }
1099
1100 ToolResult::error("Async tool stream ended without completion")
1102 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()))
1103 }
1104 }
1105}
1106
1107async fn execute_confirmed_tool<Ctx, H>(
1112 awaiting_tool: &PendingToolCallInfo,
1113 confirmed: bool,
1114 rejection_reason: Option<String>,
1115 tool_context: &ToolContext<Ctx>,
1116 tools: &ToolRegistry<Ctx>,
1117 hooks: &Arc<H>,
1118 tx: &mpsc::Sender<AgentEvent>,
1119) -> ToolResult
1120where
1121 Ctx: Send + Sync + Clone + 'static,
1122 H: AgentHooks,
1123{
1124 if confirmed {
1125 if let Some(async_tool) = tools.get_async(&awaiting_tool.name) {
1127 let result = execute_async_tool(awaiting_tool, async_tool, tool_context, tx).await;
1128
1129 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1130
1131 let _ = tx
1132 .send(AgentEvent::tool_call_end(
1133 &awaiting_tool.id,
1134 &awaiting_tool.name,
1135 &awaiting_tool.display_name,
1136 result.clone(),
1137 ))
1138 .await;
1139
1140 return result;
1141 }
1142
1143 if let Some(tool) = tools.get(&awaiting_tool.name) {
1145 let tool_start = Instant::now();
1146 let result = match tool
1147 .execute(tool_context, awaiting_tool.input.clone())
1148 .await
1149 {
1150 Ok(mut r) => {
1151 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1152 r
1153 }
1154 Err(e) => ToolResult::error(format!("Tool error: {e}"))
1155 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
1156 };
1157
1158 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1159
1160 let _ = tx
1161 .send(AgentEvent::tool_call_end(
1162 &awaiting_tool.id,
1163 &awaiting_tool.name,
1164 &awaiting_tool.display_name,
1165 result.clone(),
1166 ))
1167 .await;
1168
1169 result
1170 } else {
1171 ToolResult::error(format!("Unknown tool: {}", awaiting_tool.name))
1172 }
1173 } else {
1174 let reason = rejection_reason.unwrap_or_else(|| "User rejected".to_string());
1175 let result = ToolResult::error(format!("Rejected: {reason}"));
1176 send_event(
1177 tx,
1178 hooks,
1179 AgentEvent::tool_call_end(
1180 &awaiting_tool.id,
1181 &awaiting_tool.name,
1182 &awaiting_tool.display_name,
1183 result.clone(),
1184 ),
1185 )
1186 .await;
1187 result
1188 }
1189}
1190
1191async fn append_tool_results<M>(
1193 tool_results: &[(String, ToolResult)],
1194 thread_id: &ThreadId,
1195 message_store: &Arc<M>,
1196) -> Result<(), AgentError>
1197where
1198 M: MessageStore,
1199{
1200 for (tool_id, result) in tool_results {
1201 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
1202 if let Err(e) = message_store.append(thread_id, tool_result_msg).await {
1203 return Err(AgentError::new(
1204 format!("Failed to append tool result: {e}"),
1205 false,
1206 ));
1207 }
1208 }
1209 Ok(())
1210}
1211
1212async fn call_llm_with_retry<P, H>(
1214 provider: &Arc<P>,
1215 request: ChatRequest,
1216 config: &AgentConfig,
1217 tx: &mpsc::Sender<AgentEvent>,
1218 hooks: &Arc<H>,
1219) -> Result<ChatResponse, AgentError>
1220where
1221 P: LlmProvider,
1222 H: AgentHooks,
1223{
1224 let max_retries = config.retry.max_retries;
1225 let mut attempt = 0u32;
1226
1227 loop {
1228 let outcome = match provider.chat(request.clone()).await {
1229 Ok(o) => o,
1230 Err(e) => {
1231 return Err(AgentError::new(format!("LLM error: {e}"), false));
1232 }
1233 };
1234
1235 match outcome {
1236 ChatOutcome::Success(response) => return Ok(response),
1237 ChatOutcome::RateLimited => {
1238 attempt += 1;
1239 if attempt > max_retries {
1240 error!("Rate limited by LLM provider after {max_retries} retries");
1241 let error_msg = format!("Rate limited after {max_retries} retries");
1242 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1243 return Err(AgentError::new(error_msg, true));
1244 }
1245 let delay = calculate_backoff_delay(attempt, &config.retry);
1246 warn!(
1247 attempt,
1248 delay_ms = delay.as_millis(),
1249 "Rate limited, retrying after backoff"
1250 );
1251 let _ = tx
1252 .send(AgentEvent::text(format!(
1253 "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1254 delay.as_secs_f64()
1255 )))
1256 .await;
1257 sleep(delay).await;
1258 }
1259 ChatOutcome::InvalidRequest(msg) => {
1260 error!(msg, "Invalid request to LLM");
1261 return Err(AgentError::new(format!("Invalid request: {msg}"), false));
1262 }
1263 ChatOutcome::ServerError(msg) => {
1264 attempt += 1;
1265 if attempt > max_retries {
1266 error!(msg, "LLM server error after {max_retries} retries");
1267 let error_msg = format!("Server error after {max_retries} retries: {msg}");
1268 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1269 return Err(AgentError::new(error_msg, true));
1270 }
1271 let delay = calculate_backoff_delay(attempt, &config.retry);
1272 warn!(
1273 attempt,
1274 delay_ms = delay.as_millis(),
1275 error = msg,
1276 "Server error, retrying after backoff"
1277 );
1278 send_event(
1279 tx,
1280 hooks,
1281 AgentEvent::text(format!(
1282 "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1283 delay.as_secs_f64()
1284 )),
1285 )
1286 .await;
1287 sleep(delay).await;
1288 }
1289 }
1290 }
1291}
1292
1293#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1298async fn run_loop<Ctx, P, H, M, S>(
1299 tx: mpsc::Sender<AgentEvent>,
1300 thread_id: ThreadId,
1301 input: AgentInput,
1302 tool_context: ToolContext<Ctx>,
1303 provider: Arc<P>,
1304 tools: Arc<ToolRegistry<Ctx>>,
1305 hooks: Arc<H>,
1306 message_store: Arc<M>,
1307 state_store: Arc<S>,
1308 config: AgentConfig,
1309 compaction_config: Option<CompactionConfig>,
1310 execution_store: Option<Arc<dyn ToolExecutionStore>>,
1311) -> AgentRunState
1312where
1313 Ctx: Send + Sync + Clone + 'static,
1314 P: LlmProvider,
1315 H: AgentHooks,
1316 M: MessageStore,
1317 S: StateStore,
1318{
1319 let tool_context = tool_context.with_event_tx(tx.clone());
1321 let start_time = Instant::now();
1322
1323 let init_state =
1325 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1326 Ok(s) => s,
1327 Err(e) => return AgentRunState::Error(e),
1328 };
1329
1330 let InitializedState {
1331 turn,
1332 total_usage,
1333 state,
1334 resume_data,
1335 } = init_state;
1336
1337 if let Some(resume) = resume_data {
1338 let ResumeData {
1339 continuation: cont,
1340 tool_call_id,
1341 confirmed,
1342 rejection_reason,
1343 } = resume;
1344 let mut tool_results = cont.completed_results.clone();
1345 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1346
1347 if awaiting_tool.id != tool_call_id {
1348 let message = format!(
1349 "Tool call ID mismatch: expected {}, got {}",
1350 awaiting_tool.id, tool_call_id
1351 );
1352 let recoverable = false;
1353 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1354 return AgentRunState::Error(AgentError::new(&message, recoverable));
1355 }
1356
1357 let result = execute_confirmed_tool(
1358 awaiting_tool,
1359 confirmed,
1360 rejection_reason,
1361 &tool_context,
1362 &tools,
1363 &hooks,
1364 &tx,
1365 )
1366 .await;
1367 tool_results.push((awaiting_tool.id.clone(), result));
1368
1369 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1370 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1371 ToolExecutionOutcome::Completed { tool_id, result } => {
1372 tool_results.push((tool_id, result));
1373 }
1374 ToolExecutionOutcome::RequiresConfirmation {
1375 tool_id,
1376 tool_name,
1377 display_name,
1378 input,
1379 description,
1380 } => {
1381 let pending_idx = cont
1382 .pending_tool_calls
1383 .iter()
1384 .position(|p| p.id == tool_id)
1385 .unwrap_or(0);
1386
1387 let new_continuation = AgentContinuation {
1388 thread_id: thread_id.clone(),
1389 turn,
1390 total_usage: total_usage.clone(),
1391 turn_usage: cont.turn_usage.clone(),
1392 pending_tool_calls: cont.pending_tool_calls.clone(),
1393 awaiting_index: pending_idx,
1394 completed_results: tool_results,
1395 state: state.clone(),
1396 };
1397
1398 return AgentRunState::AwaitingConfirmation {
1399 tool_call_id: tool_id,
1400 tool_name,
1401 display_name,
1402 input,
1403 description,
1404 continuation: Box::new(new_continuation),
1405 };
1406 }
1407 }
1408 }
1409
1410 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1411 send_event(
1412 &tx,
1413 &hooks,
1414 AgentEvent::Error {
1415 message: e.message.clone(),
1416 recoverable: e.recoverable,
1417 },
1418 )
1419 .await;
1420 return AgentRunState::Error(e);
1421 }
1422
1423 send_event(
1424 &tx,
1425 &hooks,
1426 AgentEvent::TurnComplete {
1427 turn,
1428 usage: cont.turn_usage.clone(),
1429 },
1430 )
1431 .await;
1432 }
1433
1434 let mut ctx = TurnContext {
1435 thread_id: thread_id.clone(),
1436 turn,
1437 total_usage,
1438 state,
1439 start_time,
1440 };
1441
1442 loop {
1443 let result = execute_turn(
1444 &tx,
1445 &mut ctx,
1446 &tool_context,
1447 &provider,
1448 &tools,
1449 &hooks,
1450 &message_store,
1451 &config,
1452 compaction_config.as_ref(),
1453 execution_store.as_ref(),
1454 )
1455 .await;
1456
1457 match result {
1458 InternalTurnResult::Continue { .. } => {
1459 if let Err(e) = state_store.save(&ctx.state).await {
1460 warn!(error = %e, "Failed to save state checkpoint");
1461 }
1462 }
1463 InternalTurnResult::Done => {
1464 break;
1465 }
1466 InternalTurnResult::AwaitingConfirmation {
1467 tool_call_id,
1468 tool_name,
1469 display_name,
1470 input,
1471 description,
1472 continuation,
1473 } => {
1474 return AgentRunState::AwaitingConfirmation {
1475 tool_call_id,
1476 tool_name,
1477 display_name,
1478 input,
1479 description,
1480 continuation,
1481 };
1482 }
1483 InternalTurnResult::Error(e) => {
1484 return AgentRunState::Error(e);
1485 }
1486 }
1487 }
1488
1489 if let Err(e) = state_store.save(&ctx.state).await {
1490 warn!(error = %e, "Failed to save final state");
1491 }
1492
1493 let duration = ctx.start_time.elapsed();
1494 send_event(
1495 &tx,
1496 &hooks,
1497 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1498 )
1499 .await;
1500
1501 AgentRunState::Done {
1502 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1503 input_tokens: u64::from(ctx.total_usage.input_tokens),
1504 output_tokens: u64::from(ctx.total_usage.output_tokens),
1505 }
1506}
1507
/// Owned inputs of `run_single_turn`, bundled into one struct instead of a long argument list.
struct TurnParameters<Ctx, P, H, M, S> {
    /// Sender side of the event channel for this turn.
    tx: mpsc::Sender<AgentEvent>,
    /// Thread the turn runs on.
    thread_id: ThreadId,
    /// Caller input (new text, a resume, or a continue request).
    input: AgentInput,
    /// Context handed to tools; the event sender is attached to it before use.
    tool_context: ToolContext<Ctx>,
    /// LLM provider.
    provider: Arc<P>,
    /// Registry of available tools.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks consulted during the turn.
    hooks: Arc<H>,
    /// Store receiving conversation messages.
    message_store: Arc<M>,
    /// Store used to load and save the agent state.
    state_store: Arc<S>,
    /// Agent configuration for the turn.
    config: AgentConfig,
    /// Compaction settings, if compaction is enabled.
    compaction_config: Option<CompactionConfig>,
    /// Optional tool execution store.
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
1522
1523async fn run_single_turn<Ctx, P, H, M, S>(
1529 TurnParameters {
1530 tx,
1531 thread_id,
1532 input,
1533 tool_context,
1534 provider,
1535 tools,
1536 hooks,
1537 message_store,
1538 state_store,
1539 config,
1540 compaction_config,
1541 execution_store,
1542 }: TurnParameters<Ctx, P, H, M, S>,
1543) -> TurnOutcome
1544where
1545 Ctx: Send + Sync + Clone + 'static,
1546 P: LlmProvider,
1547 H: AgentHooks,
1548 M: MessageStore,
1549 S: StateStore,
1550{
1551 let tool_context = tool_context.with_event_tx(tx.clone());
1552 let start_time = Instant::now();
1553
1554 let init_state =
1555 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1556 Ok(s) => s,
1557 Err(e) => {
1558 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1559 return TurnOutcome::Error(e);
1560 }
1561 };
1562
1563 let InitializedState {
1564 turn,
1565 total_usage,
1566 state,
1567 resume_data,
1568 } = init_state;
1569
1570 if let Some(resume_data_val) = resume_data {
1571 return handle_resume_case(ResumeCaseParameters {
1572 resume_data: resume_data_val,
1573 turn,
1574 total_usage,
1575 state,
1576 thread_id,
1577 tool_context,
1578 tools,
1579 hooks,
1580 tx,
1581 message_store,
1582 state_store,
1583 })
1584 .await;
1585 }
1586
1587 let mut ctx = TurnContext {
1588 thread_id: thread_id.clone(),
1589 turn,
1590 total_usage,
1591 state,
1592 start_time,
1593 };
1594
1595 let result = execute_turn(
1596 &tx,
1597 &mut ctx,
1598 &tool_context,
1599 &provider,
1600 &tools,
1601 &hooks,
1602 &message_store,
1603 &config,
1604 compaction_config.as_ref(),
1605 execution_store.as_ref(),
1606 )
1607 .await;
1608
1609 match result {
1611 InternalTurnResult::Continue { turn_usage } => {
1612 if let Err(e) = state_store.save(&ctx.state).await {
1614 warn!(error = %e, "Failed to save state checkpoint");
1615 }
1616
1617 TurnOutcome::NeedsMoreTurns {
1618 turn: ctx.turn,
1619 turn_usage,
1620 total_usage: ctx.total_usage,
1621 }
1622 }
1623 InternalTurnResult::Done => {
1624 if let Err(e) = state_store.save(&ctx.state).await {
1626 warn!(error = %e, "Failed to save final state");
1627 }
1628
1629 let duration = ctx.start_time.elapsed();
1631 send_event(
1632 &tx,
1633 &hooks,
1634 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1635 )
1636 .await;
1637
1638 TurnOutcome::Done {
1639 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1640 input_tokens: u64::from(ctx.total_usage.input_tokens),
1641 output_tokens: u64::from(ctx.total_usage.output_tokens),
1642 }
1643 }
1644 InternalTurnResult::AwaitingConfirmation {
1645 tool_call_id,
1646 tool_name,
1647 display_name,
1648 input,
1649 description,
1650 continuation,
1651 } => TurnOutcome::AwaitingConfirmation {
1652 tool_call_id,
1653 tool_name,
1654 display_name,
1655 input,
1656 description,
1657 continuation,
1658 },
1659 InternalTurnResult::Error(e) => TurnOutcome::Error(e),
1660 }
1661}
1662
1663struct ResumeCaseParameters<Ctx, H, M, S> {
1664 resume_data: ResumeData,
1665 turn: usize,
1666 total_usage: TokenUsage,
1667 state: AgentState,
1668 thread_id: ThreadId,
1669 tool_context: ToolContext<Ctx>,
1670 tools: Arc<ToolRegistry<Ctx>>,
1671 hooks: Arc<H>,
1672 tx: mpsc::Sender<AgentEvent>,
1673 message_store: Arc<M>,
1674 state_store: Arc<S>,
1675}
1676
1677async fn handle_resume_case<Ctx, H, M, S>(
1678 ResumeCaseParameters {
1679 resume_data,
1680 turn,
1681 total_usage,
1682 state,
1683 thread_id,
1684 tool_context,
1685 tools,
1686 hooks,
1687 tx,
1688 message_store,
1689 state_store,
1690 }: ResumeCaseParameters<Ctx, H, M, S>,
1691) -> TurnOutcome
1692where
1693 Ctx: Send + Sync + Clone + 'static,
1694 H: AgentHooks,
1695 M: MessageStore,
1696 S: StateStore,
1697{
1698 let ResumeData {
1699 continuation: cont,
1700 tool_call_id,
1701 confirmed,
1702 rejection_reason,
1703 } = resume_data;
1704 let mut tool_results = cont.completed_results.clone();
1706 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1707
1708 if awaiting_tool.id != tool_call_id {
1710 let message = format!(
1711 "Tool call ID mismatch: expected {}, got {}",
1712 awaiting_tool.id, tool_call_id
1713 );
1714 let recoverable = false;
1715 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1716 return TurnOutcome::Error(AgentError::new(&message, recoverable));
1717 }
1718
1719 let result = execute_confirmed_tool(
1720 awaiting_tool,
1721 confirmed,
1722 rejection_reason,
1723 &tool_context,
1724 &tools,
1725 &hooks,
1726 &tx,
1727 )
1728 .await;
1729 tool_results.push((awaiting_tool.id.clone(), result));
1730
1731 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1732 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1733 ToolExecutionOutcome::Completed { tool_id, result } => {
1734 tool_results.push((tool_id, result));
1735 }
1736 ToolExecutionOutcome::RequiresConfirmation {
1737 tool_id,
1738 tool_name,
1739 display_name,
1740 input,
1741 description,
1742 } => {
1743 let pending_idx = cont
1744 .pending_tool_calls
1745 .iter()
1746 .position(|p| p.id == tool_id)
1747 .unwrap_or(0);
1748
1749 let new_continuation = AgentContinuation {
1750 thread_id: thread_id.clone(),
1751 turn,
1752 total_usage: total_usage.clone(),
1753 turn_usage: cont.turn_usage.clone(),
1754 pending_tool_calls: cont.pending_tool_calls.clone(),
1755 awaiting_index: pending_idx,
1756 completed_results: tool_results,
1757 state: state.clone(),
1758 };
1759
1760 return TurnOutcome::AwaitingConfirmation {
1761 tool_call_id: tool_id,
1762 tool_name,
1763 display_name,
1764 input,
1765 description,
1766 continuation: Box::new(new_continuation),
1767 };
1768 }
1769 }
1770 }
1771
1772 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1773 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1774 return TurnOutcome::Error(e);
1775 }
1776
1777 send_event(
1778 &tx,
1779 &hooks,
1780 AgentEvent::TurnComplete {
1781 turn,
1782 usage: cont.turn_usage.clone(),
1783 },
1784 )
1785 .await;
1786
1787 let mut updated_state = state;
1788 updated_state.turn_count = turn;
1789 if let Err(e) = state_store.save(&updated_state).await {
1790 warn!(error = %e, "Failed to save state checkpoint");
1791 }
1792
1793 TurnOutcome::NeedsMoreTurns {
1794 turn,
1795 turn_usage: cont.turn_usage.clone(),
1796 total_usage,
1797 }
1798}
1799
1800async fn try_get_cached_result(
1809 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1810 tool_call_id: &str,
1811) -> Option<ToolResult> {
1812 let store = execution_store?;
1813 let execution = store.get_execution(tool_call_id).await.ok()??;
1814
1815 match execution.status {
1816 ExecutionStatus::Completed => execution.result,
1817 ExecutionStatus::InFlight => {
1818 warn!(
1821 tool_call_id = tool_call_id,
1822 tool_name = execution.tool_name,
1823 "Found in-flight execution from previous attempt, re-executing"
1824 );
1825 None
1826 }
1827 }
1828}
1829
1830async fn record_execution_start(
1832 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1833 pending: &PendingToolCallInfo,
1834 thread_id: &ThreadId,
1835 started_at: time::OffsetDateTime,
1836) {
1837 if let Some(store) = execution_store {
1838 let execution = ToolExecution::new_in_flight(
1839 &pending.id,
1840 thread_id.clone(),
1841 &pending.name,
1842 &pending.display_name,
1843 pending.input.clone(),
1844 started_at,
1845 );
1846 if let Err(e) = store.record_execution(execution).await {
1847 warn!(
1848 tool_call_id = pending.id,
1849 error = %e,
1850 "Failed to record execution start"
1851 );
1852 }
1853 }
1854}
1855
1856async fn record_execution_complete(
1858 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1859 pending: &PendingToolCallInfo,
1860 thread_id: &ThreadId,
1861 result: &ToolResult,
1862 started_at: time::OffsetDateTime,
1863) {
1864 if let Some(store) = execution_store {
1865 let mut execution = ToolExecution::new_in_flight(
1866 &pending.id,
1867 thread_id.clone(),
1868 &pending.name,
1869 &pending.display_name,
1870 pending.input.clone(),
1871 started_at,
1872 );
1873 execution.complete(result.clone());
1874 if let Err(e) = store.update_execution(execution).await {
1875 warn!(
1876 tool_call_id = pending.id,
1877 error = %e,
1878 "Failed to record execution completion"
1879 );
1880 }
1881 }
1882}
1883
1884#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1889async fn execute_turn<Ctx, P, H, M>(
1890 tx: &mpsc::Sender<AgentEvent>,
1891 ctx: &mut TurnContext,
1892 tool_context: &ToolContext<Ctx>,
1893 provider: &Arc<P>,
1894 tools: &Arc<ToolRegistry<Ctx>>,
1895 hooks: &Arc<H>,
1896 message_store: &Arc<M>,
1897 config: &AgentConfig,
1898 compaction_config: Option<&CompactionConfig>,
1899 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1900) -> InternalTurnResult
1901where
1902 Ctx: Send + Sync + Clone + 'static,
1903 P: LlmProvider,
1904 H: AgentHooks,
1905 M: MessageStore,
1906{
1907 ctx.turn += 1;
1908 ctx.state.turn_count = ctx.turn;
1909
1910 if ctx.turn > config.max_turns {
1911 warn!(turn = ctx.turn, max = config.max_turns, "Max turns reached");
1912 send_event(
1913 tx,
1914 hooks,
1915 AgentEvent::error(
1916 format!("Maximum turns ({}) reached", config.max_turns),
1917 true,
1918 ),
1919 )
1920 .await;
1921 return InternalTurnResult::Error(AgentError::new(
1922 format!("Maximum turns ({}) reached", config.max_turns),
1923 true,
1924 ));
1925 }
1926
1927 send_event(
1929 tx,
1930 hooks,
1931 AgentEvent::start(ctx.thread_id.clone(), ctx.turn),
1932 )
1933 .await;
1934
1935 let mut messages = match message_store.get_history(&ctx.thread_id).await {
1937 Ok(m) => m,
1938 Err(e) => {
1939 send_event(
1940 tx,
1941 hooks,
1942 AgentEvent::error(format!("Failed to get history: {e}"), false),
1943 )
1944 .await;
1945 return InternalTurnResult::Error(AgentError::new(
1946 format!("Failed to get history: {e}"),
1947 false,
1948 ));
1949 }
1950 };
1951
1952 if let Some(compact_config) = compaction_config {
1954 let compactor = LlmContextCompactor::new(Arc::clone(provider), compact_config.clone());
1955 if compactor.needs_compaction(&messages) {
1956 debug!(
1957 turn = ctx.turn,
1958 message_count = messages.len(),
1959 "Context compaction triggered"
1960 );
1961
1962 match compactor.compact_history(messages.clone()).await {
1963 Ok(result) => {
1964 if let Err(e) = message_store
1965 .replace_history(&ctx.thread_id, result.messages.clone())
1966 .await
1967 {
1968 warn!(error = %e, "Failed to replace history after compaction");
1969 } else {
1970 send_event(
1971 tx,
1972 hooks,
1973 AgentEvent::context_compacted(
1974 result.original_count,
1975 result.new_count,
1976 result.original_tokens,
1977 result.new_tokens,
1978 ),
1979 )
1980 .await;
1981
1982 info!(
1983 original_count = result.original_count,
1984 new_count = result.new_count,
1985 original_tokens = result.original_tokens,
1986 new_tokens = result.new_tokens,
1987 "Context compacted successfully"
1988 );
1989
1990 messages = result.messages;
1991 }
1992 }
1993 Err(e) => {
1994 warn!(error = %e, "Context compaction failed, continuing with full history");
1995 }
1996 }
1997 }
1998 }
1999
2000 let llm_tools = if tools.is_empty() {
2002 None
2003 } else {
2004 Some(tools.to_llm_tools())
2005 };
2006
2007 let request = ChatRequest {
2008 system: config.system_prompt.clone(),
2009 messages,
2010 tools: llm_tools,
2011 max_tokens: config.max_tokens,
2012 thinking: config.thinking.clone(),
2013 };
2014
2015 debug!(turn = ctx.turn, "Calling LLM");
2017 let response = match call_llm_with_retry(provider, request, config, tx, hooks).await {
2018 Ok(r) => r,
2019 Err(e) => {
2020 return InternalTurnResult::Error(e);
2021 }
2022 };
2023
2024 let turn_usage = TokenUsage {
2026 input_tokens: response.usage.input_tokens,
2027 output_tokens: response.usage.output_tokens,
2028 };
2029 ctx.total_usage.add(&turn_usage);
2030 ctx.state.total_usage = ctx.total_usage.clone();
2031
2032 let (thinking_content, text_content, tool_uses) = extract_content(&response);
2034
2035 if let Some(thinking) = &thinking_content {
2037 send_event(tx, hooks, AgentEvent::thinking(thinking.clone())).await;
2038 }
2039
2040 if let Some(text) = &text_content {
2042 send_event(tx, hooks, AgentEvent::text(text.clone())).await;
2043 }
2044
2045 if tool_uses.is_empty() {
2047 info!(turn = ctx.turn, "Agent completed (no tool use)");
2048 return InternalTurnResult::Done;
2049 }
2050
2051 let assistant_msg = build_assistant_message(&response);
2053 if let Err(e) = message_store.append(&ctx.thread_id, assistant_msg).await {
2054 send_event(
2055 tx,
2056 hooks,
2057 AgentEvent::error(format!("Failed to append assistant message: {e}"), false),
2058 )
2059 .await;
2060 return InternalTurnResult::Error(AgentError::new(
2061 format!("Failed to append assistant message: {e}"),
2062 false,
2063 ));
2064 }
2065
2066 let pending_tool_calls: Vec<PendingToolCallInfo> = tool_uses
2068 .iter()
2069 .map(|(id, name, input)| {
2070 let display_name = tools
2071 .get(name)
2072 .map(|t| t.display_name().to_string())
2073 .or_else(|| tools.get_async(name).map(|t| t.display_name().to_string()))
2074 .unwrap_or_default();
2075 PendingToolCallInfo {
2076 id: id.clone(),
2077 name: name.clone(),
2078 display_name,
2079 input: input.clone(),
2080 }
2081 })
2082 .collect();
2083
2084 let mut tool_results = Vec::new();
2086 for (idx, pending) in pending_tool_calls.iter().enumerate() {
2087 if let Some(cached_result) = try_get_cached_result(execution_store, &pending.id).await {
2089 debug!(
2090 tool_call_id = pending.id,
2091 tool_name = pending.name,
2092 "Using cached result from previous execution"
2093 );
2094 tool_results.push((pending.id.clone(), cached_result));
2095 continue;
2096 }
2097
2098 if let Some(async_tool) = tools.get_async(&pending.name) {
2100 let tier = async_tool.tier();
2101
2102 send_event(
2104 tx,
2105 hooks,
2106 AgentEvent::tool_call_start(
2107 &pending.id,
2108 &pending.name,
2109 &pending.display_name,
2110 pending.input.clone(),
2111 tier,
2112 ),
2113 )
2114 .await;
2115
2116 let decision = hooks
2118 .pre_tool_use(&pending.name, &pending.input, tier)
2119 .await;
2120
2121 match decision {
2122 ToolDecision::Allow => {
2123 let started_at = time::OffsetDateTime::now_utc();
2125 record_execution_start(execution_store, pending, &ctx.thread_id, started_at)
2126 .await;
2127
2128 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
2129
2130 record_execution_complete(
2132 execution_store,
2133 pending,
2134 &ctx.thread_id,
2135 &result,
2136 started_at,
2137 )
2138 .await;
2139
2140 hooks.post_tool_use(&pending.name, &result).await;
2141
2142 send_event(
2143 tx,
2144 hooks,
2145 AgentEvent::tool_call_end(
2146 &pending.id,
2147 &pending.name,
2148 &pending.display_name,
2149 result.clone(),
2150 ),
2151 )
2152 .await;
2153
2154 tool_results.push((pending.id.clone(), result));
2155 }
2156 ToolDecision::Block(reason) => {
2157 let result = ToolResult::error(format!("Blocked: {reason}"));
2158 send_event(
2159 tx,
2160 hooks,
2161 AgentEvent::tool_call_end(
2162 &pending.id,
2163 &pending.name,
2164 &pending.display_name,
2165 result.clone(),
2166 ),
2167 )
2168 .await;
2169 tool_results.push((pending.id.clone(), result));
2170 }
2171 ToolDecision::RequiresConfirmation(description) => {
2172 send_event(
2174 tx,
2175 hooks,
2176 AgentEvent::ToolRequiresConfirmation {
2177 id: pending.id.clone(),
2178 name: pending.name.clone(),
2179 input: pending.input.clone(),
2180 description: description.clone(),
2181 },
2182 )
2183 .await;
2184
2185 let continuation = AgentContinuation {
2186 thread_id: ctx.thread_id.clone(),
2187 turn: ctx.turn,
2188 total_usage: ctx.total_usage.clone(),
2189 turn_usage: turn_usage.clone(),
2190 pending_tool_calls: pending_tool_calls.clone(),
2191 awaiting_index: idx,
2192 completed_results: tool_results,
2193 state: ctx.state.clone(),
2194 };
2195
2196 return InternalTurnResult::AwaitingConfirmation {
2197 tool_call_id: pending.id.clone(),
2198 tool_name: pending.name.clone(),
2199 display_name: pending.display_name.clone(),
2200 input: pending.input.clone(),
2201 description,
2202 continuation: Box::new(continuation),
2203 };
2204 }
2205 }
2206 continue;
2207 }
2208
2209 let Some(tool) = tools.get(&pending.name) else {
2211 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
2212 tool_results.push((pending.id.clone(), result));
2213 continue;
2214 };
2215
2216 let tier = tool.tier();
2217
2218 send_event(
2220 tx,
2221 hooks,
2222 AgentEvent::tool_call_start(
2223 &pending.id,
2224 &pending.name,
2225 &pending.display_name,
2226 pending.input.clone(),
2227 tier,
2228 ),
2229 )
2230 .await;
2231
2232 let decision = hooks
2234 .pre_tool_use(&pending.name, &pending.input, tier)
2235 .await;
2236
2237 match decision {
2238 ToolDecision::Allow => {
2239 let started_at = time::OffsetDateTime::now_utc();
2241 record_execution_start(execution_store, pending, &ctx.thread_id, started_at).await;
2242
2243 let tool_start = Instant::now();
2244 let result = match tool.execute(tool_context, pending.input.clone()).await {
2245 Ok(mut r) => {
2246 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
2247 r
2248 }
2249 Err(e) => ToolResult::error(format!("Tool error: {e}"))
2250 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
2251 };
2252
2253 record_execution_complete(
2255 execution_store,
2256 pending,
2257 &ctx.thread_id,
2258 &result,
2259 started_at,
2260 )
2261 .await;
2262
2263 hooks.post_tool_use(&pending.name, &result).await;
2264
2265 send_event(
2266 tx,
2267 hooks,
2268 AgentEvent::tool_call_end(
2269 &pending.id,
2270 &pending.name,
2271 &pending.display_name,
2272 result.clone(),
2273 ),
2274 )
2275 .await;
2276
2277 tool_results.push((pending.id.clone(), result));
2278 }
2279 ToolDecision::Block(reason) => {
2280 let result = ToolResult::error(format!("Blocked: {reason}"));
2281 send_event(
2282 tx,
2283 hooks,
2284 AgentEvent::tool_call_end(
2285 &pending.id,
2286 &pending.name,
2287 &pending.display_name,
2288 result.clone(),
2289 ),
2290 )
2291 .await;
2292 tool_results.push((pending.id.clone(), result));
2293 }
2294 ToolDecision::RequiresConfirmation(description) => {
2295 send_event(
2297 tx,
2298 hooks,
2299 AgentEvent::ToolRequiresConfirmation {
2300 id: pending.id.clone(),
2301 name: pending.name.clone(),
2302 input: pending.input.clone(),
2303 description: description.clone(),
2304 },
2305 )
2306 .await;
2307
2308 let continuation = AgentContinuation {
2309 thread_id: ctx.thread_id.clone(),
2310 turn: ctx.turn,
2311 total_usage: ctx.total_usage.clone(),
2312 turn_usage: turn_usage.clone(),
2313 pending_tool_calls: pending_tool_calls.clone(),
2314 awaiting_index: idx,
2315 completed_results: tool_results,
2316 state: ctx.state.clone(),
2317 };
2318
2319 return InternalTurnResult::AwaitingConfirmation {
2320 tool_call_id: pending.id.clone(),
2321 tool_name: pending.name.clone(),
2322 display_name: pending.display_name.clone(),
2323 input: pending.input.clone(),
2324 description,
2325 continuation: Box::new(continuation),
2326 };
2327 }
2328 }
2329 }
2330
2331 if let Err(e) = append_tool_results(&tool_results, &ctx.thread_id, message_store).await {
2333 send_event(
2334 tx,
2335 hooks,
2336 AgentEvent::error(format!("Failed to append tool results: {e}"), false),
2337 )
2338 .await;
2339 return InternalTurnResult::Error(e);
2340 }
2341
2342 send_event(
2344 tx,
2345 hooks,
2346 AgentEvent::TurnComplete {
2347 turn: ctx.turn,
2348 usage: turn_usage.clone(),
2349 },
2350 )
2351 .await;
2352
2353 if response.stop_reason == Some(StopReason::EndTurn) {
2355 info!(turn = ctx.turn, "Agent completed (end_turn)");
2356 return InternalTurnResult::Done;
2357 }
2358
2359 InternalTurnResult::Continue { turn_usage }
2360}
2361
/// Converts a millisecond count to `u64`, clamping values above `u64::MAX` to `u64::MAX`.
// The cast is only reached for values that fit in `u64`, so it cannot truncate.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    const LIMIT: u128 = u64::MAX as u128;
    match millis {
        0..=LIMIT => millis as u64,
        _ => u64::MAX,
    }
}
2371
2372fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
2377 let base_delay = config
2379 .base_delay_ms
2380 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
2381
2382 let max_jitter = config.base_delay_ms.min(1000);
2384 let jitter = if max_jitter > 0 {
2385 u64::from(
2386 std::time::SystemTime::now()
2387 .duration_since(std::time::UNIX_EPOCH)
2388 .unwrap_or_default()
2389 .subsec_nanos(),
2390 ) % max_jitter
2391 } else {
2392 0
2393 };
2394
2395 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
2396 Duration::from_millis(delay_ms)
2397}
2398
2399type ExtractedContent = (
2401 Option<String>,
2402 Option<String>,
2403 Vec<(String, String, serde_json::Value)>,
2404);
2405
2406fn extract_content(response: &ChatResponse) -> ExtractedContent {
2408 let mut thinking_parts = Vec::new();
2409 let mut text_parts = Vec::new();
2410 let mut tool_uses = Vec::new();
2411
2412 for block in &response.content {
2413 match block {
2414 ContentBlock::Text { text } => {
2415 text_parts.push(text.clone());
2416 }
2417 ContentBlock::Thinking { thinking } => {
2418 thinking_parts.push(thinking.clone());
2419 }
2420 ContentBlock::ToolUse {
2421 id, name, input, ..
2422 } => {
2423 tool_uses.push((id.clone(), name.clone(), input.clone()));
2424 }
2425 ContentBlock::ToolResult { .. } => {
2426 }
2428 }
2429 }
2430
2431 let thinking = if thinking_parts.is_empty() {
2432 None
2433 } else {
2434 Some(thinking_parts.join("\n"))
2435 };
2436
2437 let text = if text_parts.is_empty() {
2438 None
2439 } else {
2440 Some(text_parts.join("\n"))
2441 };
2442
2443 (thinking, text, tool_uses)
2444}
2445
2446async fn send_event<H>(tx: &mpsc::Sender<AgentEvent>, hooks: &Arc<H>, event: AgentEvent)
2447where
2448 H: AgentHooks,
2449{
2450 hooks.on_event(&event).await;
2451 let _ = tx.send(event).await;
2452}
2453
2454fn build_assistant_message(response: &ChatResponse) -> Message {
2455 let mut blocks = Vec::new();
2456
2457 for block in &response.content {
2458 match block {
2459 ContentBlock::Text { text } => {
2460 blocks.push(ContentBlock::Text { text: text.clone() });
2461 }
2462 ContentBlock::Thinking { .. } | ContentBlock::ToolResult { .. } => {
2463 }
2466 ContentBlock::ToolUse {
2467 id,
2468 name,
2469 input,
2470 thought_signature,
2471 } => {
2472 blocks.push(ContentBlock::ToolUse {
2473 id: id.clone(),
2474 name: name.clone(),
2475 input: input.clone(),
2476 thought_signature: thought_signature.clone(),
2477 });
2478 }
2479 }
2480 }
2481
2482 Message {
2483 role: Role::Assistant,
2484 content: Content::Blocks(blocks),
2485 }
2486}
2487
2488#[cfg(test)]
2489mod tests {
2490 use super::*;
2491 use crate::hooks::AllowAllHooks;
2492 use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
2493 use crate::stores::InMemoryStore;
2494 use crate::tools::{Tool, ToolContext, ToolRegistry};
2495 use crate::types::{AgentConfig, AgentInput, ToolResult, ToolTier};
2496 use anyhow::Result;
2497 use async_trait::async_trait;
2498 use serde_json::json;
2499 use std::sync::RwLock;
2500 use std::sync::atomic::{AtomicUsize, Ordering};
2501
2502 struct MockProvider {
2507 responses: RwLock<Vec<ChatOutcome>>,
2508 call_count: AtomicUsize,
2509 }
2510
2511 impl MockProvider {
2512 fn new(responses: Vec<ChatOutcome>) -> Self {
2513 Self {
2514 responses: RwLock::new(responses),
2515 call_count: AtomicUsize::new(0),
2516 }
2517 }
2518
2519 fn text_response(text: &str) -> ChatOutcome {
2520 ChatOutcome::Success(ChatResponse {
2521 id: "msg_1".to_string(),
2522 content: vec![ContentBlock::Text {
2523 text: text.to_string(),
2524 }],
2525 model: "mock-model".to_string(),
2526 stop_reason: Some(StopReason::EndTurn),
2527 usage: Usage {
2528 input_tokens: 10,
2529 output_tokens: 20,
2530 },
2531 })
2532 }
2533
2534 fn tool_use_response(
2535 tool_id: &str,
2536 tool_name: &str,
2537 input: serde_json::Value,
2538 ) -> ChatOutcome {
2539 ChatOutcome::Success(ChatResponse {
2540 id: "msg_1".to_string(),
2541 content: vec![ContentBlock::ToolUse {
2542 id: tool_id.to_string(),
2543 name: tool_name.to_string(),
2544 input,
2545 thought_signature: None,
2546 }],
2547 model: "mock-model".to_string(),
2548 stop_reason: Some(StopReason::ToolUse),
2549 usage: Usage {
2550 input_tokens: 10,
2551 output_tokens: 20,
2552 },
2553 })
2554 }
2555 }
2556
2557 #[async_trait]
2558 impl LlmProvider for MockProvider {
2559 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
2560 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
2561 let responses = self.responses.read().unwrap();
2562 if idx < responses.len() {
2563 Ok(responses[idx].clone())
2564 } else {
2565 Ok(Self::text_response("Done"))
2567 }
2568 }
2569
2570 fn model(&self) -> &'static str {
2571 "mock-model"
2572 }
2573
2574 fn provider(&self) -> &'static str {
2575 "mock"
2576 }
2577 }
2578
2579 impl Clone for ChatOutcome {
2581 fn clone(&self) -> Self {
2582 match self {
2583 Self::Success(r) => Self::Success(r.clone()),
2584 Self::RateLimited => Self::RateLimited,
2585 Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
2586 Self::ServerError(s) => Self::ServerError(s.clone()),
2587 }
2588 }
2589 }
2590
2591 struct EchoTool;
2596
2597 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
2599 #[serde(rename_all = "snake_case")]
2600 enum TestToolName {
2601 Echo,
2602 }
2603
2604 impl crate::tools::ToolName for TestToolName {}
2605
2606 impl Tool<()> for EchoTool {
2607 type Name = TestToolName;
2608
2609 fn name(&self) -> TestToolName {
2610 TestToolName::Echo
2611 }
2612
2613 fn display_name(&self) -> &'static str {
2614 "Echo"
2615 }
2616
2617 fn description(&self) -> &'static str {
2618 "Echo the input message"
2619 }
2620
2621 fn input_schema(&self) -> serde_json::Value {
2622 json!({
2623 "type": "object",
2624 "properties": {
2625 "message": { "type": "string" }
2626 },
2627 "required": ["message"]
2628 })
2629 }
2630
2631 fn tier(&self) -> ToolTier {
2632 ToolTier::Observe
2633 }
2634
2635 async fn execute(
2636 &self,
2637 _ctx: &ToolContext<()>,
2638 input: serde_json::Value,
2639 ) -> Result<ToolResult> {
2640 let message = input
2641 .get("message")
2642 .and_then(|v| v.as_str())
2643 .unwrap_or("no message");
2644 Ok(ToolResult::success(format!("Echo: {message}")))
2645 }
2646 }
2647
2648 #[test]
2653 fn test_builder_creates_agent_loop() {
2654 let provider = MockProvider::new(vec![]);
2655 let agent = builder::<()>().provider(provider).build();
2656
2657 assert_eq!(agent.config.max_turns, 10);
2658 assert_eq!(agent.config.max_tokens, 4096);
2659 }
2660
2661 #[test]
2662 fn test_builder_with_custom_config() {
2663 let provider = MockProvider::new(vec![]);
2664 let config = AgentConfig {
2665 max_turns: 5,
2666 max_tokens: 2048,
2667 system_prompt: "Custom prompt".to_string(),
2668 model: "custom-model".to_string(),
2669 ..Default::default()
2670 };
2671
2672 let agent = builder::<()>().provider(provider).config(config).build();
2673
2674 assert_eq!(agent.config.max_turns, 5);
2675 assert_eq!(agent.config.max_tokens, 2048);
2676 assert_eq!(agent.config.system_prompt, "Custom prompt");
2677 }
2678
2679 #[test]
2680 fn test_builder_with_tools() {
2681 let provider = MockProvider::new(vec![]);
2682 let mut tools = ToolRegistry::new();
2683 tools.register(EchoTool);
2684
2685 let agent = builder::<()>().provider(provider).tools(tools).build();
2686
2687 assert_eq!(agent.tools.len(), 1);
2688 }
2689
2690 #[test]
2691 fn test_builder_with_custom_stores() {
2692 let provider = MockProvider::new(vec![]);
2693 let message_store = InMemoryStore::new();
2694 let state_store = InMemoryStore::new();
2695
2696 let agent = builder::<()>()
2697 .provider(provider)
2698 .hooks(AllowAllHooks)
2699 .message_store(message_store)
2700 .state_store(state_store)
2701 .build_with_stores();
2702
2703 assert_eq!(agent.config.max_turns, 10);
2705 }
2706
2707 #[tokio::test]
2712 async fn test_simple_text_response() -> anyhow::Result<()> {
2713 let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);
2714
2715 let agent = builder::<()>().provider(provider).build();
2716
2717 let thread_id = ThreadId::new();
2718 let tool_ctx = ToolContext::new(());
2719 let (mut rx, _final_state) =
2720 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2721
2722 let mut events = Vec::new();
2723 while let Some(event) = rx.recv().await {
2724 events.push(event);
2725 }
2726
2727 assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
2729 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2730
2731 Ok(())
2732 }
2733
2734 #[tokio::test]
2735 async fn test_tool_execution() -> anyhow::Result<()> {
2736 let provider = MockProvider::new(vec![
2737 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
2739 MockProvider::text_response("Tool executed successfully"),
2741 ]);
2742
2743 let mut tools = ToolRegistry::new();
2744 tools.register(EchoTool);
2745
2746 let agent = builder::<()>().provider(provider).tools(tools).build();
2747
2748 let thread_id = ThreadId::new();
2749 let tool_ctx = ToolContext::new(());
2750 let (mut rx, _final_state) = agent.run(
2751 thread_id,
2752 AgentInput::Text("Run echo".to_string()),
2753 tool_ctx,
2754 );
2755
2756 let mut events = Vec::new();
2757 while let Some(event) = rx.recv().await {
2758 events.push(event);
2759 }
2760
2761 assert!(
2763 events
2764 .iter()
2765 .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
2766 );
2767 assert!(
2768 events
2769 .iter()
2770 .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
2771 );
2772
2773 Ok(())
2774 }
2775
2776 #[tokio::test]
2777 async fn test_max_turns_limit() -> anyhow::Result<()> {
2778 let provider = MockProvider::new(vec![
2780 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
2781 MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
2782 MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
2783 MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
2784 ]);
2785
2786 let mut tools = ToolRegistry::new();
2787 tools.register(EchoTool);
2788
2789 let config = AgentConfig {
2790 max_turns: 2,
2791 ..Default::default()
2792 };
2793
2794 let agent = builder::<()>()
2795 .provider(provider)
2796 .tools(tools)
2797 .config(config)
2798 .build();
2799
2800 let thread_id = ThreadId::new();
2801 let tool_ctx = ToolContext::new(());
2802 let (mut rx, _final_state) =
2803 agent.run(thread_id, AgentInput::Text("Loop".to_string()), tool_ctx);
2804
2805 let mut events = Vec::new();
2806 while let Some(event) = rx.recv().await {
2807 events.push(event);
2808 }
2809
2810 assert!(events.iter().any(|e| {
2812 matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
2813 }));
2814
2815 Ok(())
2816 }
2817
2818 #[tokio::test]
2819 async fn test_unknown_tool_handling() -> anyhow::Result<()> {
2820 let provider = MockProvider::new(vec![
2821 MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
2823 MockProvider::text_response("I couldn't find that tool."),
2825 ]);
2826
2827 let tools = ToolRegistry::new();
2829
2830 let agent = builder::<()>().provider(provider).tools(tools).build();
2831
2832 let thread_id = ThreadId::new();
2833 let tool_ctx = ToolContext::new(());
2834 let (mut rx, _final_state) = agent.run(
2835 thread_id,
2836 AgentInput::Text("Call unknown".to_string()),
2837 tool_ctx,
2838 );
2839
2840 let mut events = Vec::new();
2841 while let Some(event) = rx.recv().await {
2842 events.push(event);
2843 }
2844
2845 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2848
2849 assert!(
2851 events.iter().any(|e| {
2852 matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
2853 })
2854 );
2855
2856 Ok(())
2857 }
2858
2859 #[tokio::test]
2860 async fn test_rate_limit_handling() -> anyhow::Result<()> {
2861 let provider = MockProvider::new(vec![
2863 ChatOutcome::RateLimited,
2864 ChatOutcome::RateLimited,
2865 ChatOutcome::RateLimited,
2866 ChatOutcome::RateLimited,
2867 ChatOutcome::RateLimited,
2868 ChatOutcome::RateLimited, ]);
2870
2871 let config = AgentConfig {
2873 retry: crate::types::RetryConfig::fast(),
2874 ..Default::default()
2875 };
2876
2877 let agent = builder::<()>().provider(provider).config(config).build();
2878
2879 let thread_id = ThreadId::new();
2880 let tool_ctx = ToolContext::new(());
2881 let (mut rx, _final_state) =
2882 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2883
2884 let mut events = Vec::new();
2885 while let Some(event) = rx.recv().await {
2886 events.push(event);
2887 }
2888
2889 assert!(events.iter().any(|e| {
2891 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
2892 }));
2893
2894 assert!(
2896 events
2897 .iter()
2898 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
2899 );
2900
2901 Ok(())
2902 }
2903
2904 #[tokio::test]
2905 async fn test_rate_limit_recovery() -> anyhow::Result<()> {
2906 let provider = MockProvider::new(vec![
2908 ChatOutcome::RateLimited,
2909 MockProvider::text_response("Recovered after rate limit"),
2910 ]);
2911
2912 let config = AgentConfig {
2914 retry: crate::types::RetryConfig::fast(),
2915 ..Default::default()
2916 };
2917
2918 let agent = builder::<()>().provider(provider).config(config).build();
2919
2920 let thread_id = ThreadId::new();
2921 let tool_ctx = ToolContext::new(());
2922 let (mut rx, _final_state) =
2923 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2924
2925 let mut events = Vec::new();
2926 while let Some(event) = rx.recv().await {
2927 events.push(event);
2928 }
2929
2930 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2932
2933 assert!(
2935 events
2936 .iter()
2937 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
2938 );
2939
2940 Ok(())
2941 }
2942
2943 #[tokio::test]
2944 async fn test_server_error_handling() -> anyhow::Result<()> {
2945 let provider = MockProvider::new(vec![
2947 ChatOutcome::ServerError("Internal error".to_string()),
2948 ChatOutcome::ServerError("Internal error".to_string()),
2949 ChatOutcome::ServerError("Internal error".to_string()),
2950 ChatOutcome::ServerError("Internal error".to_string()),
2951 ChatOutcome::ServerError("Internal error".to_string()),
2952 ChatOutcome::ServerError("Internal error".to_string()), ]);
2954
2955 let config = AgentConfig {
2957 retry: crate::types::RetryConfig::fast(),
2958 ..Default::default()
2959 };
2960
2961 let agent = builder::<()>().provider(provider).config(config).build();
2962
2963 let thread_id = ThreadId::new();
2964 let tool_ctx = ToolContext::new(());
2965 let (mut rx, _final_state) =
2966 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2967
2968 let mut events = Vec::new();
2969 while let Some(event) = rx.recv().await {
2970 events.push(event);
2971 }
2972
2973 assert!(events.iter().any(|e| {
2975 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
2976 }));
2977
2978 assert!(
2980 events
2981 .iter()
2982 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
2983 );
2984
2985 Ok(())
2986 }
2987
2988 #[tokio::test]
2989 async fn test_server_error_recovery() -> anyhow::Result<()> {
2990 let provider = MockProvider::new(vec![
2992 ChatOutcome::ServerError("Temporary error".to_string()),
2993 MockProvider::text_response("Recovered after server error"),
2994 ]);
2995
2996 let config = AgentConfig {
2998 retry: crate::types::RetryConfig::fast(),
2999 ..Default::default()
3000 };
3001
3002 let agent = builder::<()>().provider(provider).config(config).build();
3003
3004 let thread_id = ThreadId::new();
3005 let tool_ctx = ToolContext::new(());
3006 let (mut rx, _final_state) =
3007 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3008
3009 let mut events = Vec::new();
3010 while let Some(event) = rx.recv().await {
3011 events.push(event);
3012 }
3013
3014 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3016
3017 assert!(
3019 events
3020 .iter()
3021 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3022 );
3023
3024 Ok(())
3025 }
3026
/// `extract_content` on a response with a single text block yields that text,
/// no thinking, and no tool uses.
#[test]
fn test_extract_content_text_only() {
    let content = vec![ContentBlock::Text {
        text: String::from("Hello"),
    }];
    let response = ChatResponse {
        id: String::from("msg_1"),
        content,
        model: String::from("test"),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert!(thinking.is_none());
    assert_eq!(text.as_deref(), Some("Hello"));
    assert!(tool_uses.is_empty());
}
3051
/// `extract_content` on a response with a single tool-use block yields no
/// thinking, no text, and exactly one tool use whose second tuple field is
/// the tool name.
#[test]
fn test_extract_content_tool_use() {
    let content = vec![ContentBlock::ToolUse {
        id: String::from("tool_1"),
        name: String::from("test_tool"),
        input: json!({"key": "value"}),
        thought_signature: None,
    }];
    let response = ChatResponse {
        id: String::from("msg_1"),
        content,
        model: String::from("test"),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert!(thinking.is_none());
    assert!(text.is_none());
    assert_eq!(tool_uses.len(), 1);
    assert_eq!(tool_uses[0].1, "test_tool");
}
3076
/// `extract_content` on a response mixing a text block and a tool-use block
/// yields the text, no thinking, and exactly one tool use.
#[test]
fn test_extract_content_mixed() {
    let text_block = ContentBlock::Text {
        text: String::from("Let me help"),
    };
    let tool_block = ContentBlock::ToolUse {
        id: String::from("tool_1"),
        name: String::from("helper"),
        input: json!({}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: String::from("msg_1"),
        content: vec![text_block, tool_block],
        model: String::from("test"),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert!(thinking.is_none());
    assert_eq!(text.as_deref(), Some("Let me help"));
    assert_eq!(tool_uses.len(), 1);
}
3105
/// `millis_to_u64` passes through values that fit in a `u64` and returns
/// `u64::MAX` for the first value that does not.
#[test]
fn test_millis_to_u64() {
    // (input, expected) pairs; the last entry is one past `u64::MAX`.
    let cases: [(u128, u64); 4] = [
        (0, 0),
        (1000, 1000),
        (u128::from(u64::MAX), u64::MAX),
        (u128::from(u64::MAX) + 1, u64::MAX),
    ];

    for &(input, expected) in &cases {
        assert_eq!(millis_to_u64(input), expected);
    }
}
3113
/// `build_assistant_message` turns a response with a text block and a
/// tool-use block into an assistant-role message carrying two content blocks.
#[test]
fn test_build_assistant_message() {
    let text_block = ContentBlock::Text {
        text: String::from("Response text"),
    };
    let tool_block = ContentBlock::ToolUse {
        id: String::from("tool_1"),
        name: String::from("echo"),
        input: json!({"message": "test"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: String::from("msg_1"),
        content: vec![text_block, tool_block],
        model: String::from("test"),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let msg = build_assistant_message(&response);
    assert_eq!(msg.role, Role::Assistant);

    // Any content shape other than `Blocks` is a test failure.
    match msg.content {
        Content::Blocks(blocks) => assert_eq!(blocks.len(), 2),
        _ => panic!("Expected Content::Blocks"),
    }
}
3146}