1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason, StreamAccumulator, StreamDelta, Usage,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore, ToolExecutionStore};
39use crate::tools::{ErasedAsyncTool, ErasedToolStatus, ToolContext, ToolRegistry};
40use crate::types::{
41 AgentConfig, AgentContinuation, AgentError, AgentInput, AgentRunState, AgentState,
42 ExecutionStatus, PendingToolCallInfo, RetryConfig, ThreadId, TokenUsage, ToolExecution,
43 ToolOutcome, ToolResult, TurnOutcome,
44};
45use futures::StreamExt;
46use std::sync::Arc;
47use std::time::{Duration, Instant};
48use tokio::sync::mpsc;
49use tokio::time::sleep;
50use tracing::{debug, error, info, warn};
51
52enum InternalTurnResult {
56 Continue { turn_usage: TokenUsage },
58 Done,
60 AwaitingConfirmation {
62 tool_call_id: String,
63 tool_name: String,
64 display_name: String,
65 input: serde_json::Value,
66 description: String,
67 continuation: Box<AgentContinuation>,
68 },
69 Error(AgentError),
71}
72
73struct TurnContext {
77 thread_id: ThreadId,
78 turn: usize,
79 total_usage: TokenUsage,
80 state: AgentState,
81 start_time: Instant,
82}
83
84struct ResumeData {
86 continuation: Box<AgentContinuation>,
87 tool_call_id: String,
88 confirmed: bool,
89 rejection_reason: Option<String>,
90}
91
92struct InitializedState {
94 turn: usize,
95 total_usage: TokenUsage,
96 state: AgentState,
97 resume_data: Option<ResumeData>,
98}
99
100enum ToolExecutionOutcome {
102 Completed { tool_id: String, result: ToolResult },
104 RequiresConfirmation {
106 tool_id: String,
107 tool_name: String,
108 display_name: String,
109 input: serde_json::Value,
110 description: String,
111 },
112}
113
114pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
126 provider: Option<P>,
127 tools: Option<ToolRegistry<Ctx>>,
128 hooks: Option<H>,
129 message_store: Option<M>,
130 state_store: Option<S>,
131 config: Option<AgentConfig>,
132 compaction_config: Option<CompactionConfig>,
133 execution_store: Option<Arc<dyn ToolExecutionStore>>,
134}
135
136impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
137 #[must_use]
139 pub fn new() -> Self {
140 Self {
141 provider: None,
142 tools: None,
143 hooks: None,
144 message_store: None,
145 state_store: None,
146 config: None,
147 compaction_config: None,
148 execution_store: None,
149 }
150 }
151}
152
153impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
160 #[must_use]
162 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
163 AgentLoopBuilder {
164 provider: Some(provider),
165 tools: self.tools,
166 hooks: self.hooks,
167 message_store: self.message_store,
168 state_store: self.state_store,
169 config: self.config,
170 compaction_config: self.compaction_config,
171 execution_store: self.execution_store,
172 }
173 }
174
175 #[must_use]
177 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
178 self.tools = Some(tools);
179 self
180 }
181
182 #[must_use]
184 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
185 AgentLoopBuilder {
186 provider: self.provider,
187 tools: self.tools,
188 hooks: Some(hooks),
189 message_store: self.message_store,
190 state_store: self.state_store,
191 config: self.config,
192 compaction_config: self.compaction_config,
193 execution_store: self.execution_store,
194 }
195 }
196
197 #[must_use]
199 pub fn message_store<M2: MessageStore>(
200 self,
201 message_store: M2,
202 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
203 AgentLoopBuilder {
204 provider: self.provider,
205 tools: self.tools,
206 hooks: self.hooks,
207 message_store: Some(message_store),
208 state_store: self.state_store,
209 config: self.config,
210 compaction_config: self.compaction_config,
211 execution_store: self.execution_store,
212 }
213 }
214
215 #[must_use]
217 pub fn state_store<S2: StateStore>(
218 self,
219 state_store: S2,
220 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
221 AgentLoopBuilder {
222 provider: self.provider,
223 tools: self.tools,
224 hooks: self.hooks,
225 message_store: self.message_store,
226 state_store: Some(state_store),
227 config: self.config,
228 compaction_config: self.compaction_config,
229 execution_store: self.execution_store,
230 }
231 }
232
233 #[must_use]
251 pub fn execution_store(mut self, store: impl ToolExecutionStore + 'static) -> Self {
252 self.execution_store = Some(Arc::new(store));
253 self
254 }
255
256 #[must_use]
258 pub fn config(mut self, config: AgentConfig) -> Self {
259 self.config = Some(config);
260 self
261 }
262
263 #[must_use]
279 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
280 self.compaction_config = Some(config);
281 self
282 }
283
284 #[must_use]
291 pub fn with_auto_compaction(self) -> Self {
292 self.with_compaction(CompactionConfig::default())
293 }
294
295 #[must_use]
313 pub fn with_skill(mut self, skill: Skill) -> Self
314 where
315 Ctx: Send + Sync + 'static,
316 {
317 if let Some(ref mut tools) = self.tools {
319 tools.filter(|name| skill.is_tool_allowed(name));
320 }
321
322 let mut config = self.config.take().unwrap_or_default();
324 if config.system_prompt.is_empty() {
325 config.system_prompt = skill.system_prompt;
326 } else {
327 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
328 }
329 self.config = Some(config);
330
331 self
332 }
333}
334
335impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
336where
337 Ctx: Send + Sync + 'static,
338 P: LlmProvider + 'static,
339{
340 #[must_use]
352 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
353 let provider = self.provider.expect("provider is required");
354 let tools = self.tools.unwrap_or_default();
355 let config = self.config.unwrap_or_default();
356
357 AgentLoop {
358 provider: Arc::new(provider),
359 tools: Arc::new(tools),
360 hooks: Arc::new(DefaultHooks),
361 message_store: Arc::new(InMemoryStore::new()),
362 state_store: Arc::new(InMemoryStore::new()),
363 config,
364 compaction_config: self.compaction_config,
365 execution_store: self.execution_store,
366 }
367 }
368}
369
370impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
371where
372 Ctx: Send + Sync + 'static,
373 P: LlmProvider + 'static,
374 H: AgentHooks + 'static,
375 M: MessageStore + 'static,
376 S: StateStore + 'static,
377{
378 #[must_use]
388 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
389 let provider = self.provider.expect("provider is required");
390 let tools = self.tools.unwrap_or_default();
391 let hooks = self
392 .hooks
393 .expect("hooks is required when using build_with_stores");
394 let message_store = self
395 .message_store
396 .expect("message_store is required when using build_with_stores");
397 let state_store = self
398 .state_store
399 .expect("state_store is required when using build_with_stores");
400 let config = self.config.unwrap_or_default();
401
402 AgentLoop {
403 provider: Arc::new(provider),
404 tools: Arc::new(tools),
405 hooks: Arc::new(hooks),
406 message_store: Arc::new(message_store),
407 state_store: Arc::new(state_store),
408 config,
409 compaction_config: self.compaction_config,
410 execution_store: self.execution_store,
411 }
412 }
413}
414
415pub struct AgentLoop<Ctx, P, H, M, S>
449where
450 P: LlmProvider,
451 H: AgentHooks,
452 M: MessageStore,
453 S: StateStore,
454{
455 provider: Arc<P>,
456 tools: Arc<ToolRegistry<Ctx>>,
457 hooks: Arc<H>,
458 message_store: Arc<M>,
459 state_store: Arc<S>,
460 config: AgentConfig,
461 compaction_config: Option<CompactionConfig>,
462 execution_store: Option<Arc<dyn ToolExecutionStore>>,
463}
464
465#[must_use]
467pub fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
468 AgentLoopBuilder::new()
469}
470
471impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
472where
473 Ctx: Send + Sync + 'static,
474 P: LlmProvider + 'static,
475 H: AgentHooks + 'static,
476 M: MessageStore + 'static,
477 S: StateStore + 'static,
478{
479 #[must_use]
481 pub fn new(
482 provider: P,
483 tools: ToolRegistry<Ctx>,
484 hooks: H,
485 message_store: M,
486 state_store: S,
487 config: AgentConfig,
488 ) -> Self {
489 Self {
490 provider: Arc::new(provider),
491 tools: Arc::new(tools),
492 hooks: Arc::new(hooks),
493 message_store: Arc::new(message_store),
494 state_store: Arc::new(state_store),
495 config,
496 compaction_config: None,
497 execution_store: None,
498 }
499 }
500
501 #[must_use]
503 pub fn with_compaction(
504 provider: P,
505 tools: ToolRegistry<Ctx>,
506 hooks: H,
507 message_store: M,
508 state_store: S,
509 config: AgentConfig,
510 compaction_config: CompactionConfig,
511 ) -> Self {
512 Self {
513 provider: Arc::new(provider),
514 tools: Arc::new(tools),
515 hooks: Arc::new(hooks),
516 message_store: Arc::new(message_store),
517 state_store: Arc::new(state_store),
518 config,
519 compaction_config: Some(compaction_config),
520 execution_store: None,
521 }
522 }
523
524 pub fn run(
574 &self,
575 thread_id: ThreadId,
576 input: AgentInput,
577 tool_context: ToolContext<Ctx>,
578 ) -> (
579 mpsc::Receiver<AgentEvent>,
580 tokio::sync::oneshot::Receiver<AgentRunState>,
581 )
582 where
583 Ctx: Clone,
584 {
585 let (event_tx, event_rx) = mpsc::channel(100);
586 let (state_tx, state_rx) = tokio::sync::oneshot::channel();
587
588 let provider = Arc::clone(&self.provider);
589 let tools = Arc::clone(&self.tools);
590 let hooks = Arc::clone(&self.hooks);
591 let message_store = Arc::clone(&self.message_store);
592 let state_store = Arc::clone(&self.state_store);
593 let config = self.config.clone();
594 let compaction_config = self.compaction_config.clone();
595 let execution_store = self.execution_store.clone();
596
597 tokio::spawn(async move {
598 let result = run_loop(
599 event_tx,
600 thread_id,
601 input,
602 tool_context,
603 provider,
604 tools,
605 hooks,
606 message_store,
607 state_store,
608 config,
609 compaction_config,
610 execution_store,
611 )
612 .await;
613
614 let _ = state_tx.send(result);
615 });
616
617 (event_rx, state_rx)
618 }
619
620 pub async fn run_turn(
676 &self,
677 thread_id: ThreadId,
678 input: AgentInput,
679 tool_context: ToolContext<Ctx>,
680 ) -> (mpsc::Receiver<AgentEvent>, TurnOutcome)
681 where
682 Ctx: Clone,
683 {
684 let (event_tx, event_rx) = mpsc::channel(100);
685
686 let provider = Arc::clone(&self.provider);
687 let tools = Arc::clone(&self.tools);
688 let hooks = Arc::clone(&self.hooks);
689 let message_store = Arc::clone(&self.message_store);
690 let state_store = Arc::clone(&self.state_store);
691 let config = self.config.clone();
692 let compaction_config = self.compaction_config.clone();
693 let execution_store = self.execution_store.clone();
694
695 let result = run_single_turn(TurnParameters {
696 tx: event_tx,
697 thread_id,
698 input,
699 tool_context,
700 provider,
701 tools,
702 hooks,
703 message_store,
704 state_store,
705 config,
706 compaction_config,
707 execution_store,
708 })
709 .await;
710
711 (event_rx, result)
712 }
713}
714
715async fn initialize_from_input<M, S>(
726 input: AgentInput,
727 thread_id: &ThreadId,
728 message_store: &Arc<M>,
729 state_store: &Arc<S>,
730) -> Result<InitializedState, AgentError>
731where
732 M: MessageStore,
733 S: StateStore,
734{
735 match input {
736 AgentInput::Text(user_message) => {
737 let state = match state_store.load(thread_id).await {
739 Ok(Some(s)) => s,
740 Ok(None) => AgentState::new(thread_id.clone()),
741 Err(e) => {
742 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
743 }
744 };
745
746 let user_msg = Message::user(&user_message);
748 if let Err(e) = message_store.append(thread_id, user_msg).await {
749 return Err(AgentError::new(
750 format!("Failed to append message: {e}"),
751 false,
752 ));
753 }
754
755 Ok(InitializedState {
756 turn: 0,
757 total_usage: TokenUsage::default(),
758 state,
759 resume_data: None,
760 })
761 }
762 AgentInput::Resume {
763 continuation,
764 tool_call_id,
765 confirmed,
766 rejection_reason,
767 } => {
768 if continuation.thread_id != *thread_id {
770 return Err(AgentError::new(
771 format!(
772 "Thread ID mismatch: continuation is for {}, but resuming on {}",
773 continuation.thread_id, thread_id
774 ),
775 false,
776 ));
777 }
778
779 Ok(InitializedState {
780 turn: continuation.turn,
781 total_usage: continuation.total_usage.clone(),
782 state: continuation.state.clone(),
783 resume_data: Some(ResumeData {
784 continuation,
785 tool_call_id,
786 confirmed,
787 rejection_reason,
788 }),
789 })
790 }
791 AgentInput::Continue => {
792 let state = match state_store.load(thread_id).await {
794 Ok(Some(s)) => s,
795 Ok(None) => {
796 return Err(AgentError::new(
797 "Cannot continue: no state found for thread",
798 false,
799 ));
800 }
801 Err(e) => {
802 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
803 }
804 };
805
806 Ok(InitializedState {
808 turn: state.turn_count,
809 total_usage: state.total_usage.clone(),
810 state,
811 resume_data: None,
812 })
813 }
814 }
815}
816
817#[allow(clippy::too_many_lines)]
826async fn execute_tool_call<Ctx, H>(
827 pending: &PendingToolCallInfo,
828 tool_context: &ToolContext<Ctx>,
829 tools: &ToolRegistry<Ctx>,
830 hooks: &Arc<H>,
831 tx: &mpsc::Sender<AgentEvent>,
832) -> ToolExecutionOutcome
833where
834 Ctx: Send + Sync + Clone + 'static,
835 H: AgentHooks,
836{
837 if let Some(async_tool) = tools.get_async(&pending.name) {
839 let tier = async_tool.tier();
840
841 let _ = tx
843 .send(AgentEvent::tool_call_start(
844 &pending.id,
845 &pending.name,
846 &pending.display_name,
847 pending.input.clone(),
848 tier,
849 ))
850 .await;
851
852 let decision = hooks
854 .pre_tool_use(&pending.name, &pending.input, tier)
855 .await;
856
857 return match decision {
858 ToolDecision::Allow => {
859 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
860
861 hooks.post_tool_use(&pending.name, &result).await;
862
863 send_event(
864 tx,
865 hooks,
866 AgentEvent::tool_call_end(
867 &pending.id,
868 &pending.name,
869 &pending.display_name,
870 result.clone(),
871 ),
872 )
873 .await;
874
875 ToolExecutionOutcome::Completed {
876 tool_id: pending.id.clone(),
877 result,
878 }
879 }
880 ToolDecision::Block(reason) => {
881 let result = ToolResult::error(format!("Blocked: {reason}"));
882 send_event(
883 tx,
884 hooks,
885 AgentEvent::tool_call_end(
886 &pending.id,
887 &pending.name,
888 &pending.display_name,
889 result.clone(),
890 ),
891 )
892 .await;
893 ToolExecutionOutcome::Completed {
894 tool_id: pending.id.clone(),
895 result,
896 }
897 }
898 ToolDecision::RequiresConfirmation(description) => {
899 send_event(
900 tx,
901 hooks,
902 AgentEvent::ToolRequiresConfirmation {
903 id: pending.id.clone(),
904 name: pending.name.clone(),
905 input: pending.input.clone(),
906 description: description.clone(),
907 },
908 )
909 .await;
910
911 ToolExecutionOutcome::RequiresConfirmation {
912 tool_id: pending.id.clone(),
913 tool_name: pending.name.clone(),
914 display_name: pending.display_name.clone(),
915 input: pending.input.clone(),
916 description,
917 }
918 }
919 };
920 }
921
922 let Some(tool) = tools.get(&pending.name) else {
924 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
925 return ToolExecutionOutcome::Completed {
926 tool_id: pending.id.clone(),
927 result,
928 };
929 };
930
931 let tier = tool.tier();
932
933 let _ = tx
935 .send(AgentEvent::tool_call_start(
936 &pending.id,
937 &pending.name,
938 &pending.display_name,
939 pending.input.clone(),
940 tier,
941 ))
942 .await;
943
944 let decision = hooks
946 .pre_tool_use(&pending.name, &pending.input, tier)
947 .await;
948
949 match decision {
950 ToolDecision::Allow => {
951 let tool_start = Instant::now();
952 let result = match tool.execute(tool_context, pending.input.clone()).await {
953 Ok(mut r) => {
954 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
955 r
956 }
957 Err(e) => ToolResult::error(format!("Tool error: {e}"))
958 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
959 };
960
961 hooks.post_tool_use(&pending.name, &result).await;
962
963 send_event(
964 tx,
965 hooks,
966 AgentEvent::tool_call_end(
967 &pending.id,
968 &pending.name,
969 &pending.display_name,
970 result.clone(),
971 ),
972 )
973 .await;
974
975 ToolExecutionOutcome::Completed {
976 tool_id: pending.id.clone(),
977 result,
978 }
979 }
980 ToolDecision::Block(reason) => {
981 let result = ToolResult::error(format!("Blocked: {reason}"));
982 send_event(
983 tx,
984 hooks,
985 AgentEvent::tool_call_end(
986 &pending.id,
987 &pending.name,
988 &pending.display_name,
989 result.clone(),
990 ),
991 )
992 .await;
993 ToolExecutionOutcome::Completed {
994 tool_id: pending.id.clone(),
995 result,
996 }
997 }
998 ToolDecision::RequiresConfirmation(description) => {
999 send_event(
1000 tx,
1001 hooks,
1002 AgentEvent::ToolRequiresConfirmation {
1003 id: pending.id.clone(),
1004 name: pending.name.clone(),
1005 input: pending.input.clone(),
1006 description: description.clone(),
1007 },
1008 )
1009 .await;
1010
1011 ToolExecutionOutcome::RequiresConfirmation {
1012 tool_id: pending.id.clone(),
1013 tool_name: pending.name.clone(),
1014 display_name: pending.display_name.clone(),
1015 input: pending.input.clone(),
1016 description,
1017 }
1018 }
1019 }
1020}
1021
1022async fn execute_async_tool<Ctx>(
1028 pending: &PendingToolCallInfo,
1029 tool: &Arc<dyn ErasedAsyncTool<Ctx>>,
1030 tool_context: &ToolContext<Ctx>,
1031 tx: &mpsc::Sender<AgentEvent>,
1032) -> ToolResult
1033where
1034 Ctx: Send + Sync + Clone,
1035{
1036 let tool_start = Instant::now();
1037
1038 let outcome = match tool.execute(tool_context, pending.input.clone()).await {
1040 Ok(o) => o,
1041 Err(e) => {
1042 return ToolResult::error(format!("Tool error: {e}"))
1043 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()));
1044 }
1045 };
1046
1047 match outcome {
1048 ToolOutcome::Success(mut result) | ToolOutcome::Failed(mut result) => {
1050 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1051 result
1052 }
1053
1054 ToolOutcome::InProgress {
1056 operation_id,
1057 message,
1058 } => {
1059 let _ = tx
1061 .send(AgentEvent::tool_progress(
1062 &pending.id,
1063 &pending.name,
1064 &pending.display_name,
1065 "started",
1066 &message,
1067 None,
1068 ))
1069 .await;
1070
1071 let mut stream = tool.check_status_stream(tool_context, &operation_id);
1073
1074 while let Some(status) = stream.next().await {
1075 match status {
1076 ErasedToolStatus::Progress {
1077 stage,
1078 message,
1079 data,
1080 } => {
1081 let _ = tx
1082 .send(AgentEvent::tool_progress(
1083 &pending.id,
1084 &pending.name,
1085 &pending.display_name,
1086 stage,
1087 message,
1088 data,
1089 ))
1090 .await;
1091 }
1092 ErasedToolStatus::Completed(mut result)
1093 | ErasedToolStatus::Failed(mut result) => {
1094 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1095 return result;
1096 }
1097 }
1098 }
1099
1100 ToolResult::error("Async tool stream ended without completion")
1102 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()))
1103 }
1104 }
1105}
1106
1107async fn execute_confirmed_tool<Ctx, H>(
1112 awaiting_tool: &PendingToolCallInfo,
1113 confirmed: bool,
1114 rejection_reason: Option<String>,
1115 tool_context: &ToolContext<Ctx>,
1116 tools: &ToolRegistry<Ctx>,
1117 hooks: &Arc<H>,
1118 tx: &mpsc::Sender<AgentEvent>,
1119) -> ToolResult
1120where
1121 Ctx: Send + Sync + Clone + 'static,
1122 H: AgentHooks,
1123{
1124 if confirmed {
1125 if let Some(async_tool) = tools.get_async(&awaiting_tool.name) {
1127 let result = execute_async_tool(awaiting_tool, async_tool, tool_context, tx).await;
1128
1129 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1130
1131 let _ = tx
1132 .send(AgentEvent::tool_call_end(
1133 &awaiting_tool.id,
1134 &awaiting_tool.name,
1135 &awaiting_tool.display_name,
1136 result.clone(),
1137 ))
1138 .await;
1139
1140 return result;
1141 }
1142
1143 if let Some(tool) = tools.get(&awaiting_tool.name) {
1145 let tool_start = Instant::now();
1146 let result = match tool
1147 .execute(tool_context, awaiting_tool.input.clone())
1148 .await
1149 {
1150 Ok(mut r) => {
1151 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1152 r
1153 }
1154 Err(e) => ToolResult::error(format!("Tool error: {e}"))
1155 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
1156 };
1157
1158 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1159
1160 let _ = tx
1161 .send(AgentEvent::tool_call_end(
1162 &awaiting_tool.id,
1163 &awaiting_tool.name,
1164 &awaiting_tool.display_name,
1165 result.clone(),
1166 ))
1167 .await;
1168
1169 result
1170 } else {
1171 ToolResult::error(format!("Unknown tool: {}", awaiting_tool.name))
1172 }
1173 } else {
1174 let reason = rejection_reason.unwrap_or_else(|| "User rejected".to_string());
1175 let result = ToolResult::error(format!("Rejected: {reason}"));
1176 send_event(
1177 tx,
1178 hooks,
1179 AgentEvent::tool_call_end(
1180 &awaiting_tool.id,
1181 &awaiting_tool.name,
1182 &awaiting_tool.display_name,
1183 result.clone(),
1184 ),
1185 )
1186 .await;
1187 result
1188 }
1189}
1190
1191async fn append_tool_results<M>(
1193 tool_results: &[(String, ToolResult)],
1194 thread_id: &ThreadId,
1195 message_store: &Arc<M>,
1196) -> Result<(), AgentError>
1197where
1198 M: MessageStore,
1199{
1200 for (tool_id, result) in tool_results {
1201 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
1202 if let Err(e) = message_store.append(thread_id, tool_result_msg).await {
1203 return Err(AgentError::new(
1204 format!("Failed to append tool result: {e}"),
1205 false,
1206 ));
1207 }
1208 }
1209 Ok(())
1210}
1211
1212async fn call_llm_with_retry<P, H>(
1214 provider: &Arc<P>,
1215 request: ChatRequest,
1216 config: &AgentConfig,
1217 tx: &mpsc::Sender<AgentEvent>,
1218 hooks: &Arc<H>,
1219) -> Result<ChatResponse, AgentError>
1220where
1221 P: LlmProvider,
1222 H: AgentHooks,
1223{
1224 let max_retries = config.retry.max_retries;
1225 let mut attempt = 0u32;
1226
1227 loop {
1228 let outcome = match provider.chat(request.clone()).await {
1229 Ok(o) => o,
1230 Err(e) => {
1231 return Err(AgentError::new(format!("LLM error: {e}"), false));
1232 }
1233 };
1234
1235 match outcome {
1236 ChatOutcome::Success(response) => return Ok(response),
1237 ChatOutcome::RateLimited => {
1238 attempt += 1;
1239 if attempt > max_retries {
1240 error!("Rate limited by LLM provider after {max_retries} retries");
1241 let error_msg = format!("Rate limited after {max_retries} retries");
1242 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1243 return Err(AgentError::new(error_msg, true));
1244 }
1245 let delay = calculate_backoff_delay(attempt, &config.retry);
1246 warn!(
1247 attempt,
1248 delay_ms = delay.as_millis(),
1249 "Rate limited, retrying after backoff"
1250 );
1251 let _ = tx
1252 .send(AgentEvent::text(format!(
1253 "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1254 delay.as_secs_f64()
1255 )))
1256 .await;
1257 sleep(delay).await;
1258 }
1259 ChatOutcome::InvalidRequest(msg) => {
1260 error!(msg, "Invalid request to LLM");
1261 return Err(AgentError::new(format!("Invalid request: {msg}"), false));
1262 }
1263 ChatOutcome::ServerError(msg) => {
1264 attempt += 1;
1265 if attempt > max_retries {
1266 error!(msg, "LLM server error after {max_retries} retries");
1267 let error_msg = format!("Server error after {max_retries} retries: {msg}");
1268 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1269 return Err(AgentError::new(error_msg, true));
1270 }
1271 let delay = calculate_backoff_delay(attempt, &config.retry);
1272 warn!(
1273 attempt,
1274 delay_ms = delay.as_millis(),
1275 error = msg,
1276 "Server error, retrying after backoff"
1277 );
1278 send_event(
1279 tx,
1280 hooks,
1281 AgentEvent::text(format!(
1282 "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1283 delay.as_secs_f64()
1284 )),
1285 )
1286 .await;
1287 sleep(delay).await;
1288 }
1289 }
1290 }
1291}
1292
1293async fn call_llm_streaming<P, H>(
1299 provider: &Arc<P>,
1300 request: ChatRequest,
1301 config: &AgentConfig,
1302 tx: &mpsc::Sender<AgentEvent>,
1303 hooks: &Arc<H>,
1304) -> Result<ChatResponse, AgentError>
1305where
1306 P: LlmProvider,
1307 H: AgentHooks,
1308{
1309 let max_retries = config.retry.max_retries;
1310 let mut attempt = 0u32;
1311
1312 loop {
1313 let result = process_stream(provider, &request, tx, hooks).await;
1314
1315 match result {
1316 Ok(response) => return Ok(response),
1317 Err(StreamError::Recoverable(msg)) => {
1318 attempt += 1;
1319 if attempt > max_retries {
1320 error!("Streaming error after {max_retries} retries: {msg}");
1321 let err_msg = format!("Streaming error after {max_retries} retries: {msg}");
1322 send_event(tx, hooks, AgentEvent::error(&err_msg, true)).await;
1323 return Err(AgentError::new(err_msg, true));
1324 }
1325 let delay = calculate_backoff_delay(attempt, &config.retry);
1326 warn!(
1327 attempt,
1328 delay_ms = delay.as_millis(),
1329 error = msg,
1330 "Streaming error, retrying"
1331 );
1332 send_event(
1333 tx,
1334 hooks,
1335 AgentEvent::text(format!(
1336 "\n[Streaming error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1337 delay.as_secs_f64()
1338 )),
1339 )
1340 .await;
1341 sleep(delay).await;
1342 }
1343 Err(StreamError::Fatal(msg)) => {
1344 error!("Streaming error (non-recoverable): {msg}");
1345 return Err(AgentError::new(format!("Streaming error: {msg}"), false));
1346 }
1347 }
1348 }
1349}
1350
/// Classification of a failed streaming attempt, used by the retry logic.
enum StreamError {
    /// Worth retrying with backoff: transport errors, and errors the provider marks
    /// as recoverable.
    Recoverable(String),
    /// Not retried; the streaming call fails immediately.
    Fatal(String),
}
1356
1357async fn process_stream<P, H>(
1359 provider: &Arc<P>,
1360 request: &ChatRequest,
1361 tx: &mpsc::Sender<AgentEvent>,
1362 hooks: &Arc<H>,
1363) -> Result<ChatResponse, StreamError>
1364where
1365 P: LlmProvider,
1366 H: AgentHooks,
1367{
1368 let mut stream = std::pin::pin!(provider.chat_stream(request.clone()));
1369 let mut accumulator = StreamAccumulator::new();
1370
1371 while let Some(result) = stream.next().await {
1372 match result {
1373 Ok(delta) => {
1374 accumulator.apply(&delta);
1375 match &delta {
1376 StreamDelta::TextDelta { delta, .. } => {
1377 send_event(tx, hooks, AgentEvent::text_delta(delta.clone())).await;
1378 }
1379 StreamDelta::ThinkingDelta { delta, .. } => {
1380 send_event(tx, hooks, AgentEvent::thinking(delta.clone())).await;
1381 }
1382 StreamDelta::Error {
1383 message,
1384 recoverable,
1385 } => {
1386 return if *recoverable {
1387 Err(StreamError::Recoverable(message.clone()))
1388 } else {
1389 Err(StreamError::Fatal(message.clone()))
1390 };
1391 }
1392 StreamDelta::Done { .. }
1394 | StreamDelta::Usage(_)
1395 | StreamDelta::ToolUseStart { .. }
1396 | StreamDelta::ToolInputDelta { .. } => {}
1397 }
1398 }
1399 Err(e) => return Err(StreamError::Recoverable(format!("Stream error: {e}"))),
1400 }
1401 }
1402
1403 let usage = accumulator.usage().cloned().unwrap_or(Usage {
1404 input_tokens: 0,
1405 output_tokens: 0,
1406 });
1407 let stop_reason = accumulator.stop_reason().copied();
1408 let content_blocks = accumulator.into_content_blocks();
1409
1410 Ok(ChatResponse {
1411 id: String::new(),
1412 content: content_blocks,
1413 model: provider.model().to_string(),
1414 stop_reason,
1415 usage,
1416 })
1417}
1418
1419#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1424async fn run_loop<Ctx, P, H, M, S>(
1425 tx: mpsc::Sender<AgentEvent>,
1426 thread_id: ThreadId,
1427 input: AgentInput,
1428 tool_context: ToolContext<Ctx>,
1429 provider: Arc<P>,
1430 tools: Arc<ToolRegistry<Ctx>>,
1431 hooks: Arc<H>,
1432 message_store: Arc<M>,
1433 state_store: Arc<S>,
1434 config: AgentConfig,
1435 compaction_config: Option<CompactionConfig>,
1436 execution_store: Option<Arc<dyn ToolExecutionStore>>,
1437) -> AgentRunState
1438where
1439 Ctx: Send + Sync + Clone + 'static,
1440 P: LlmProvider,
1441 H: AgentHooks,
1442 M: MessageStore,
1443 S: StateStore,
1444{
1445 let tool_context = tool_context.with_event_tx(tx.clone());
1447 let start_time = Instant::now();
1448
1449 let init_state =
1451 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1452 Ok(s) => s,
1453 Err(e) => return AgentRunState::Error(e),
1454 };
1455
1456 let InitializedState {
1457 turn,
1458 total_usage,
1459 state,
1460 resume_data,
1461 } = init_state;
1462
1463 if let Some(resume) = resume_data {
1464 let ResumeData {
1465 continuation: cont,
1466 tool_call_id,
1467 confirmed,
1468 rejection_reason,
1469 } = resume;
1470 let mut tool_results = cont.completed_results.clone();
1471 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1472
1473 if awaiting_tool.id != tool_call_id {
1474 let message = format!(
1475 "Tool call ID mismatch: expected {}, got {}",
1476 awaiting_tool.id, tool_call_id
1477 );
1478 let recoverable = false;
1479 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1480 return AgentRunState::Error(AgentError::new(&message, recoverable));
1481 }
1482
1483 let result = execute_confirmed_tool(
1484 awaiting_tool,
1485 confirmed,
1486 rejection_reason,
1487 &tool_context,
1488 &tools,
1489 &hooks,
1490 &tx,
1491 )
1492 .await;
1493 tool_results.push((awaiting_tool.id.clone(), result));
1494
1495 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1496 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1497 ToolExecutionOutcome::Completed { tool_id, result } => {
1498 tool_results.push((tool_id, result));
1499 }
1500 ToolExecutionOutcome::RequiresConfirmation {
1501 tool_id,
1502 tool_name,
1503 display_name,
1504 input,
1505 description,
1506 } => {
1507 let pending_idx = cont
1508 .pending_tool_calls
1509 .iter()
1510 .position(|p| p.id == tool_id)
1511 .unwrap_or(0);
1512
1513 let new_continuation = AgentContinuation {
1514 thread_id: thread_id.clone(),
1515 turn,
1516 total_usage: total_usage.clone(),
1517 turn_usage: cont.turn_usage.clone(),
1518 pending_tool_calls: cont.pending_tool_calls.clone(),
1519 awaiting_index: pending_idx,
1520 completed_results: tool_results,
1521 state: state.clone(),
1522 };
1523
1524 return AgentRunState::AwaitingConfirmation {
1525 tool_call_id: tool_id,
1526 tool_name,
1527 display_name,
1528 input,
1529 description,
1530 continuation: Box::new(new_continuation),
1531 };
1532 }
1533 }
1534 }
1535
1536 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1537 send_event(
1538 &tx,
1539 &hooks,
1540 AgentEvent::Error {
1541 message: e.message.clone(),
1542 recoverable: e.recoverable,
1543 },
1544 )
1545 .await;
1546 return AgentRunState::Error(e);
1547 }
1548
1549 send_event(
1550 &tx,
1551 &hooks,
1552 AgentEvent::TurnComplete {
1553 turn,
1554 usage: cont.turn_usage.clone(),
1555 },
1556 )
1557 .await;
1558 }
1559
1560 let mut ctx = TurnContext {
1561 thread_id: thread_id.clone(),
1562 turn,
1563 total_usage,
1564 state,
1565 start_time,
1566 };
1567
1568 loop {
1569 let result = execute_turn(
1570 &tx,
1571 &mut ctx,
1572 &tool_context,
1573 &provider,
1574 &tools,
1575 &hooks,
1576 &message_store,
1577 &config,
1578 compaction_config.as_ref(),
1579 execution_store.as_ref(),
1580 )
1581 .await;
1582
1583 match result {
1584 InternalTurnResult::Continue { .. } => {
1585 if let Err(e) = state_store.save(&ctx.state).await {
1586 warn!(error = %e, "Failed to save state checkpoint");
1587 }
1588 }
1589 InternalTurnResult::Done => {
1590 break;
1591 }
1592 InternalTurnResult::AwaitingConfirmation {
1593 tool_call_id,
1594 tool_name,
1595 display_name,
1596 input,
1597 description,
1598 continuation,
1599 } => {
1600 return AgentRunState::AwaitingConfirmation {
1601 tool_call_id,
1602 tool_name,
1603 display_name,
1604 input,
1605 description,
1606 continuation,
1607 };
1608 }
1609 InternalTurnResult::Error(e) => {
1610 return AgentRunState::Error(e);
1611 }
1612 }
1613 }
1614
1615 if let Err(e) = state_store.save(&ctx.state).await {
1616 warn!(error = %e, "Failed to save final state");
1617 }
1618
1619 let duration = ctx.start_time.elapsed();
1620 send_event(
1621 &tx,
1622 &hooks,
1623 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1624 )
1625 .await;
1626
1627 AgentRunState::Done {
1628 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1629 input_tokens: u64::from(ctx.total_usage.input_tokens),
1630 output_tokens: u64::from(ctx.total_usage.output_tokens),
1631 }
1632}
1633
/// Inputs for [`run_single_turn`], bundled into one struct so the function
/// signature stays manageable. The struct is destructured immediately by
/// `run_single_turn`.
struct TurnParameters<Ctx, P, H, M, S> {
    /// Channel on which `AgentEvent`s for this turn are sent.
    tx: mpsc::Sender<AgentEvent>,
    /// Thread whose history and state the turn reads and writes.
    thread_id: ThreadId,
    /// Input for this turn; `initialize_from_input` may derive resume data from it.
    input: AgentInput,
    /// Context handed to tool executions (the event sender is attached to it).
    tool_context: ToolContext<Ctx>,
    /// LLM provider used by `execute_turn`.
    provider: Arc<P>,
    /// Registry of sync and async tools the model may call.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks that observe events and decide whether tool calls may run.
    hooks: Arc<H>,
    /// Store holding the thread's message history.
    message_store: Arc<M>,
    /// Store where agent state checkpoints are saved after a turn.
    state_store: Arc<S>,
    /// Agent configuration (turn limit, prompt, token limit, streaming, ...).
    config: AgentConfig,
    /// When set, `execute_turn` may compact the history before calling the LLM.
    compaction_config: Option<CompactionConfig>,
    /// Optional store used to reuse completed results for a tool call id.
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
1648
1649async fn run_single_turn<Ctx, P, H, M, S>(
1655 TurnParameters {
1656 tx,
1657 thread_id,
1658 input,
1659 tool_context,
1660 provider,
1661 tools,
1662 hooks,
1663 message_store,
1664 state_store,
1665 config,
1666 compaction_config,
1667 execution_store,
1668 }: TurnParameters<Ctx, P, H, M, S>,
1669) -> TurnOutcome
1670where
1671 Ctx: Send + Sync + Clone + 'static,
1672 P: LlmProvider,
1673 H: AgentHooks,
1674 M: MessageStore,
1675 S: StateStore,
1676{
1677 let tool_context = tool_context.with_event_tx(tx.clone());
1678 let start_time = Instant::now();
1679
1680 let init_state =
1681 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1682 Ok(s) => s,
1683 Err(e) => {
1684 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1685 return TurnOutcome::Error(e);
1686 }
1687 };
1688
1689 let InitializedState {
1690 turn,
1691 total_usage,
1692 state,
1693 resume_data,
1694 } = init_state;
1695
1696 if let Some(resume_data_val) = resume_data {
1697 return handle_resume_case(ResumeCaseParameters {
1698 resume_data: resume_data_val,
1699 turn,
1700 total_usage,
1701 state,
1702 thread_id,
1703 tool_context,
1704 tools,
1705 hooks,
1706 tx,
1707 message_store,
1708 state_store,
1709 })
1710 .await;
1711 }
1712
1713 let mut ctx = TurnContext {
1714 thread_id: thread_id.clone(),
1715 turn,
1716 total_usage,
1717 state,
1718 start_time,
1719 };
1720
1721 let result = execute_turn(
1722 &tx,
1723 &mut ctx,
1724 &tool_context,
1725 &provider,
1726 &tools,
1727 &hooks,
1728 &message_store,
1729 &config,
1730 compaction_config.as_ref(),
1731 execution_store.as_ref(),
1732 )
1733 .await;
1734
1735 match result {
1737 InternalTurnResult::Continue { turn_usage } => {
1738 if let Err(e) = state_store.save(&ctx.state).await {
1740 warn!(error = %e, "Failed to save state checkpoint");
1741 }
1742
1743 TurnOutcome::NeedsMoreTurns {
1744 turn: ctx.turn,
1745 turn_usage,
1746 total_usage: ctx.total_usage,
1747 }
1748 }
1749 InternalTurnResult::Done => {
1750 if let Err(e) = state_store.save(&ctx.state).await {
1752 warn!(error = %e, "Failed to save final state");
1753 }
1754
1755 let duration = ctx.start_time.elapsed();
1757 send_event(
1758 &tx,
1759 &hooks,
1760 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1761 )
1762 .await;
1763
1764 TurnOutcome::Done {
1765 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1766 input_tokens: u64::from(ctx.total_usage.input_tokens),
1767 output_tokens: u64::from(ctx.total_usage.output_tokens),
1768 }
1769 }
1770 InternalTurnResult::AwaitingConfirmation {
1771 tool_call_id,
1772 tool_name,
1773 display_name,
1774 input,
1775 description,
1776 continuation,
1777 } => TurnOutcome::AwaitingConfirmation {
1778 tool_call_id,
1779 tool_name,
1780 display_name,
1781 input,
1782 description,
1783 continuation,
1784 },
1785 InternalTurnResult::Error(e) => TurnOutcome::Error(e),
1786 }
1787}
1788
/// Inputs for [`handle_resume_case`], which finishes a turn that was paused
/// waiting for a tool-call confirmation.
struct ResumeCaseParameters<Ctx, H, M, S> {
    /// Continuation plus the user's confirm/reject decision for the awaited call.
    resume_data: ResumeData,
    /// Turn number the paused turn belongs to.
    turn: usize,
    /// Token usage accumulated so far in the run.
    total_usage: TokenUsage,
    /// Agent state loaded for the thread; checkpointed when the resume finishes.
    state: AgentState,
    /// Thread whose history receives the tool results.
    thread_id: ThreadId,
    /// Context handed to tool executions.
    tool_context: ToolContext<Ctx>,
    /// Registry used to look up and run the remaining pending tool calls.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks that observe events and decide whether tool calls may run.
    hooks: Arc<H>,
    /// Channel on which `AgentEvent`s are sent.
    tx: mpsc::Sender<AgentEvent>,
    /// Store holding the thread's message history.
    message_store: Arc<M>,
    /// Store where the updated agent state is checkpointed.
    state_store: Arc<S>,
}
1802
1803async fn handle_resume_case<Ctx, H, M, S>(
1804 ResumeCaseParameters {
1805 resume_data,
1806 turn,
1807 total_usage,
1808 state,
1809 thread_id,
1810 tool_context,
1811 tools,
1812 hooks,
1813 tx,
1814 message_store,
1815 state_store,
1816 }: ResumeCaseParameters<Ctx, H, M, S>,
1817) -> TurnOutcome
1818where
1819 Ctx: Send + Sync + Clone + 'static,
1820 H: AgentHooks,
1821 M: MessageStore,
1822 S: StateStore,
1823{
1824 let ResumeData {
1825 continuation: cont,
1826 tool_call_id,
1827 confirmed,
1828 rejection_reason,
1829 } = resume_data;
1830 let mut tool_results = cont.completed_results.clone();
1832 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1833
1834 if awaiting_tool.id != tool_call_id {
1836 let message = format!(
1837 "Tool call ID mismatch: expected {}, got {}",
1838 awaiting_tool.id, tool_call_id
1839 );
1840 let recoverable = false;
1841 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1842 return TurnOutcome::Error(AgentError::new(&message, recoverable));
1843 }
1844
1845 let result = execute_confirmed_tool(
1846 awaiting_tool,
1847 confirmed,
1848 rejection_reason,
1849 &tool_context,
1850 &tools,
1851 &hooks,
1852 &tx,
1853 )
1854 .await;
1855 tool_results.push((awaiting_tool.id.clone(), result));
1856
1857 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1858 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1859 ToolExecutionOutcome::Completed { tool_id, result } => {
1860 tool_results.push((tool_id, result));
1861 }
1862 ToolExecutionOutcome::RequiresConfirmation {
1863 tool_id,
1864 tool_name,
1865 display_name,
1866 input,
1867 description,
1868 } => {
1869 let pending_idx = cont
1870 .pending_tool_calls
1871 .iter()
1872 .position(|p| p.id == tool_id)
1873 .unwrap_or(0);
1874
1875 let new_continuation = AgentContinuation {
1876 thread_id: thread_id.clone(),
1877 turn,
1878 total_usage: total_usage.clone(),
1879 turn_usage: cont.turn_usage.clone(),
1880 pending_tool_calls: cont.pending_tool_calls.clone(),
1881 awaiting_index: pending_idx,
1882 completed_results: tool_results,
1883 state: state.clone(),
1884 };
1885
1886 return TurnOutcome::AwaitingConfirmation {
1887 tool_call_id: tool_id,
1888 tool_name,
1889 display_name,
1890 input,
1891 description,
1892 continuation: Box::new(new_continuation),
1893 };
1894 }
1895 }
1896 }
1897
1898 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1899 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1900 return TurnOutcome::Error(e);
1901 }
1902
1903 send_event(
1904 &tx,
1905 &hooks,
1906 AgentEvent::TurnComplete {
1907 turn,
1908 usage: cont.turn_usage.clone(),
1909 },
1910 )
1911 .await;
1912
1913 let mut updated_state = state;
1914 updated_state.turn_count = turn;
1915 if let Err(e) = state_store.save(&updated_state).await {
1916 warn!(error = %e, "Failed to save state checkpoint");
1917 }
1918
1919 TurnOutcome::NeedsMoreTurns {
1920 turn,
1921 turn_usage: cont.turn_usage.clone(),
1922 total_usage,
1923 }
1924}
1925
1926async fn try_get_cached_result(
1935 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1936 tool_call_id: &str,
1937) -> Option<ToolResult> {
1938 let store = execution_store?;
1939 let execution = store.get_execution(tool_call_id).await.ok()??;
1940
1941 match execution.status {
1942 ExecutionStatus::Completed => execution.result,
1943 ExecutionStatus::InFlight => {
1944 warn!(
1947 tool_call_id = tool_call_id,
1948 tool_name = execution.tool_name,
1949 "Found in-flight execution from previous attempt, re-executing"
1950 );
1951 None
1952 }
1953 }
1954}
1955
1956async fn record_execution_start(
1958 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1959 pending: &PendingToolCallInfo,
1960 thread_id: &ThreadId,
1961 started_at: time::OffsetDateTime,
1962) {
1963 if let Some(store) = execution_store {
1964 let execution = ToolExecution::new_in_flight(
1965 &pending.id,
1966 thread_id.clone(),
1967 &pending.name,
1968 &pending.display_name,
1969 pending.input.clone(),
1970 started_at,
1971 );
1972 if let Err(e) = store.record_execution(execution).await {
1973 warn!(
1974 tool_call_id = pending.id,
1975 error = %e,
1976 "Failed to record execution start"
1977 );
1978 }
1979 }
1980}
1981
1982async fn record_execution_complete(
1984 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1985 pending: &PendingToolCallInfo,
1986 thread_id: &ThreadId,
1987 result: &ToolResult,
1988 started_at: time::OffsetDateTime,
1989) {
1990 if let Some(store) = execution_store {
1991 let mut execution = ToolExecution::new_in_flight(
1992 &pending.id,
1993 thread_id.clone(),
1994 &pending.name,
1995 &pending.display_name,
1996 pending.input.clone(),
1997 started_at,
1998 );
1999 execution.complete(result.clone());
2000 if let Err(e) = store.update_execution(execution).await {
2001 warn!(
2002 tool_call_id = pending.id,
2003 error = %e,
2004 "Failed to record execution completion"
2005 );
2006 }
2007 }
2008}
2009
2010#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
2015async fn execute_turn<Ctx, P, H, M>(
2016 tx: &mpsc::Sender<AgentEvent>,
2017 ctx: &mut TurnContext,
2018 tool_context: &ToolContext<Ctx>,
2019 provider: &Arc<P>,
2020 tools: &Arc<ToolRegistry<Ctx>>,
2021 hooks: &Arc<H>,
2022 message_store: &Arc<M>,
2023 config: &AgentConfig,
2024 compaction_config: Option<&CompactionConfig>,
2025 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2026) -> InternalTurnResult
2027where
2028 Ctx: Send + Sync + Clone + 'static,
2029 P: LlmProvider,
2030 H: AgentHooks,
2031 M: MessageStore,
2032{
2033 ctx.turn += 1;
2034 ctx.state.turn_count = ctx.turn;
2035
2036 if ctx.turn > config.max_turns {
2037 warn!(turn = ctx.turn, max = config.max_turns, "Max turns reached");
2038 send_event(
2039 tx,
2040 hooks,
2041 AgentEvent::error(
2042 format!("Maximum turns ({}) reached", config.max_turns),
2043 true,
2044 ),
2045 )
2046 .await;
2047 return InternalTurnResult::Error(AgentError::new(
2048 format!("Maximum turns ({}) reached", config.max_turns),
2049 true,
2050 ));
2051 }
2052
2053 send_event(
2055 tx,
2056 hooks,
2057 AgentEvent::start(ctx.thread_id.clone(), ctx.turn),
2058 )
2059 .await;
2060
2061 let mut messages = match message_store.get_history(&ctx.thread_id).await {
2063 Ok(m) => m,
2064 Err(e) => {
2065 send_event(
2066 tx,
2067 hooks,
2068 AgentEvent::error(format!("Failed to get history: {e}"), false),
2069 )
2070 .await;
2071 return InternalTurnResult::Error(AgentError::new(
2072 format!("Failed to get history: {e}"),
2073 false,
2074 ));
2075 }
2076 };
2077
2078 if let Some(compact_config) = compaction_config {
2080 let compactor = LlmContextCompactor::new(Arc::clone(provider), compact_config.clone());
2081 if compactor.needs_compaction(&messages) {
2082 debug!(
2083 turn = ctx.turn,
2084 message_count = messages.len(),
2085 "Context compaction triggered"
2086 );
2087
2088 match compactor.compact_history(messages.clone()).await {
2089 Ok(result) => {
2090 if let Err(e) = message_store
2091 .replace_history(&ctx.thread_id, result.messages.clone())
2092 .await
2093 {
2094 warn!(error = %e, "Failed to replace history after compaction");
2095 } else {
2096 send_event(
2097 tx,
2098 hooks,
2099 AgentEvent::context_compacted(
2100 result.original_count,
2101 result.new_count,
2102 result.original_tokens,
2103 result.new_tokens,
2104 ),
2105 )
2106 .await;
2107
2108 info!(
2109 original_count = result.original_count,
2110 new_count = result.new_count,
2111 original_tokens = result.original_tokens,
2112 new_tokens = result.new_tokens,
2113 "Context compacted successfully"
2114 );
2115
2116 messages = result.messages;
2117 }
2118 }
2119 Err(e) => {
2120 warn!(error = %e, "Context compaction failed, continuing with full history");
2121 }
2122 }
2123 }
2124 }
2125
2126 let llm_tools = if tools.is_empty() {
2128 None
2129 } else {
2130 Some(tools.to_llm_tools())
2131 };
2132
2133 let request = ChatRequest {
2134 system: config.system_prompt.clone(),
2135 messages,
2136 tools: llm_tools,
2137 max_tokens: config.max_tokens,
2138 thinking: config.thinking.clone(),
2139 };
2140
2141 debug!(turn = ctx.turn, streaming = config.streaming, "Calling LLM");
2143 let response = if config.streaming {
2144 match call_llm_streaming(provider, request, config, tx, hooks).await {
2146 Ok(r) => r,
2147 Err(e) => {
2148 return InternalTurnResult::Error(e);
2149 }
2150 }
2151 } else {
2152 match call_llm_with_retry(provider, request, config, tx, hooks).await {
2154 Ok(r) => r,
2155 Err(e) => {
2156 return InternalTurnResult::Error(e);
2157 }
2158 }
2159 };
2160
2161 let turn_usage = TokenUsage {
2163 input_tokens: response.usage.input_tokens,
2164 output_tokens: response.usage.output_tokens,
2165 };
2166 ctx.total_usage.add(&turn_usage);
2167 ctx.state.total_usage = ctx.total_usage.clone();
2168
2169 let (thinking_content, text_content, tool_uses) = extract_content(&response);
2171
2172 if !config.streaming {
2174 if let Some(thinking) = &thinking_content {
2176 send_event(tx, hooks, AgentEvent::thinking(thinking.clone())).await;
2177 }
2178
2179 if let Some(text) = &text_content {
2181 send_event(tx, hooks, AgentEvent::text(text.clone())).await;
2182 }
2183 }
2184
2185 if tool_uses.is_empty() {
2187 info!(turn = ctx.turn, "Agent completed (no tool use)");
2188 return InternalTurnResult::Done;
2189 }
2190
2191 let assistant_msg = build_assistant_message(&response);
2193 if let Err(e) = message_store.append(&ctx.thread_id, assistant_msg).await {
2194 send_event(
2195 tx,
2196 hooks,
2197 AgentEvent::error(format!("Failed to append assistant message: {e}"), false),
2198 )
2199 .await;
2200 return InternalTurnResult::Error(AgentError::new(
2201 format!("Failed to append assistant message: {e}"),
2202 false,
2203 ));
2204 }
2205
2206 let pending_tool_calls: Vec<PendingToolCallInfo> = tool_uses
2208 .iter()
2209 .map(|(id, name, input)| {
2210 let display_name = tools
2211 .get(name)
2212 .map(|t| t.display_name().to_string())
2213 .or_else(|| tools.get_async(name).map(|t| t.display_name().to_string()))
2214 .unwrap_or_default();
2215 PendingToolCallInfo {
2216 id: id.clone(),
2217 name: name.clone(),
2218 display_name,
2219 input: input.clone(),
2220 }
2221 })
2222 .collect();
2223
2224 let mut tool_results = Vec::new();
2226 for (idx, pending) in pending_tool_calls.iter().enumerate() {
2227 if let Some(cached_result) = try_get_cached_result(execution_store, &pending.id).await {
2229 debug!(
2230 tool_call_id = pending.id,
2231 tool_name = pending.name,
2232 "Using cached result from previous execution"
2233 );
2234 tool_results.push((pending.id.clone(), cached_result));
2235 continue;
2236 }
2237
2238 if let Some(async_tool) = tools.get_async(&pending.name) {
2240 let tier = async_tool.tier();
2241
2242 send_event(
2244 tx,
2245 hooks,
2246 AgentEvent::tool_call_start(
2247 &pending.id,
2248 &pending.name,
2249 &pending.display_name,
2250 pending.input.clone(),
2251 tier,
2252 ),
2253 )
2254 .await;
2255
2256 let decision = hooks
2258 .pre_tool_use(&pending.name, &pending.input, tier)
2259 .await;
2260
2261 match decision {
2262 ToolDecision::Allow => {
2263 let started_at = time::OffsetDateTime::now_utc();
2265 record_execution_start(execution_store, pending, &ctx.thread_id, started_at)
2266 .await;
2267
2268 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
2269
2270 record_execution_complete(
2272 execution_store,
2273 pending,
2274 &ctx.thread_id,
2275 &result,
2276 started_at,
2277 )
2278 .await;
2279
2280 hooks.post_tool_use(&pending.name, &result).await;
2281
2282 send_event(
2283 tx,
2284 hooks,
2285 AgentEvent::tool_call_end(
2286 &pending.id,
2287 &pending.name,
2288 &pending.display_name,
2289 result.clone(),
2290 ),
2291 )
2292 .await;
2293
2294 tool_results.push((pending.id.clone(), result));
2295 }
2296 ToolDecision::Block(reason) => {
2297 let result = ToolResult::error(format!("Blocked: {reason}"));
2298 send_event(
2299 tx,
2300 hooks,
2301 AgentEvent::tool_call_end(
2302 &pending.id,
2303 &pending.name,
2304 &pending.display_name,
2305 result.clone(),
2306 ),
2307 )
2308 .await;
2309 tool_results.push((pending.id.clone(), result));
2310 }
2311 ToolDecision::RequiresConfirmation(description) => {
2312 send_event(
2314 tx,
2315 hooks,
2316 AgentEvent::ToolRequiresConfirmation {
2317 id: pending.id.clone(),
2318 name: pending.name.clone(),
2319 input: pending.input.clone(),
2320 description: description.clone(),
2321 },
2322 )
2323 .await;
2324
2325 let continuation = AgentContinuation {
2326 thread_id: ctx.thread_id.clone(),
2327 turn: ctx.turn,
2328 total_usage: ctx.total_usage.clone(),
2329 turn_usage: turn_usage.clone(),
2330 pending_tool_calls: pending_tool_calls.clone(),
2331 awaiting_index: idx,
2332 completed_results: tool_results,
2333 state: ctx.state.clone(),
2334 };
2335
2336 return InternalTurnResult::AwaitingConfirmation {
2337 tool_call_id: pending.id.clone(),
2338 tool_name: pending.name.clone(),
2339 display_name: pending.display_name.clone(),
2340 input: pending.input.clone(),
2341 description,
2342 continuation: Box::new(continuation),
2343 };
2344 }
2345 }
2346 continue;
2347 }
2348
2349 let Some(tool) = tools.get(&pending.name) else {
2351 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
2352 tool_results.push((pending.id.clone(), result));
2353 continue;
2354 };
2355
2356 let tier = tool.tier();
2357
2358 send_event(
2360 tx,
2361 hooks,
2362 AgentEvent::tool_call_start(
2363 &pending.id,
2364 &pending.name,
2365 &pending.display_name,
2366 pending.input.clone(),
2367 tier,
2368 ),
2369 )
2370 .await;
2371
2372 let decision = hooks
2374 .pre_tool_use(&pending.name, &pending.input, tier)
2375 .await;
2376
2377 match decision {
2378 ToolDecision::Allow => {
2379 let started_at = time::OffsetDateTime::now_utc();
2381 record_execution_start(execution_store, pending, &ctx.thread_id, started_at).await;
2382
2383 let tool_start = Instant::now();
2384 let result = match tool.execute(tool_context, pending.input.clone()).await {
2385 Ok(mut r) => {
2386 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
2387 r
2388 }
2389 Err(e) => ToolResult::error(format!("Tool error: {e}"))
2390 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
2391 };
2392
2393 record_execution_complete(
2395 execution_store,
2396 pending,
2397 &ctx.thread_id,
2398 &result,
2399 started_at,
2400 )
2401 .await;
2402
2403 hooks.post_tool_use(&pending.name, &result).await;
2404
2405 send_event(
2406 tx,
2407 hooks,
2408 AgentEvent::tool_call_end(
2409 &pending.id,
2410 &pending.name,
2411 &pending.display_name,
2412 result.clone(),
2413 ),
2414 )
2415 .await;
2416
2417 tool_results.push((pending.id.clone(), result));
2418 }
2419 ToolDecision::Block(reason) => {
2420 let result = ToolResult::error(format!("Blocked: {reason}"));
2421 send_event(
2422 tx,
2423 hooks,
2424 AgentEvent::tool_call_end(
2425 &pending.id,
2426 &pending.name,
2427 &pending.display_name,
2428 result.clone(),
2429 ),
2430 )
2431 .await;
2432 tool_results.push((pending.id.clone(), result));
2433 }
2434 ToolDecision::RequiresConfirmation(description) => {
2435 send_event(
2437 tx,
2438 hooks,
2439 AgentEvent::ToolRequiresConfirmation {
2440 id: pending.id.clone(),
2441 name: pending.name.clone(),
2442 input: pending.input.clone(),
2443 description: description.clone(),
2444 },
2445 )
2446 .await;
2447
2448 let continuation = AgentContinuation {
2449 thread_id: ctx.thread_id.clone(),
2450 turn: ctx.turn,
2451 total_usage: ctx.total_usage.clone(),
2452 turn_usage: turn_usage.clone(),
2453 pending_tool_calls: pending_tool_calls.clone(),
2454 awaiting_index: idx,
2455 completed_results: tool_results,
2456 state: ctx.state.clone(),
2457 };
2458
2459 return InternalTurnResult::AwaitingConfirmation {
2460 tool_call_id: pending.id.clone(),
2461 tool_name: pending.name.clone(),
2462 display_name: pending.display_name.clone(),
2463 input: pending.input.clone(),
2464 description,
2465 continuation: Box::new(continuation),
2466 };
2467 }
2468 }
2469 }
2470
2471 if let Err(e) = append_tool_results(&tool_results, &ctx.thread_id, message_store).await {
2473 send_event(
2474 tx,
2475 hooks,
2476 AgentEvent::error(format!("Failed to append tool results: {e}"), false),
2477 )
2478 .await;
2479 return InternalTurnResult::Error(e);
2480 }
2481
2482 send_event(
2484 tx,
2485 hooks,
2486 AgentEvent::TurnComplete {
2487 turn: ctx.turn,
2488 usage: turn_usage.clone(),
2489 },
2490 )
2491 .await;
2492
2493 if response.stop_reason == Some(StopReason::EndTurn) {
2495 info!(turn = ctx.turn, "Agent completed (end_turn)");
2496 return InternalTurnResult::Done;
2497 }
2498
2499 InternalTurnResult::Continue { turn_usage }
2500}
2501
/// Narrows a `u128` millisecond count (as returned by `Duration::as_millis`)
/// to `u64`, clamping values that do not fit to `u64::MAX`.
// The narrowing cast below only runs once the value is known to fit.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    const LIMIT: u128 = u64::MAX as u128;
    if millis <= LIMIT {
        millis as u64
    } else {
        u64::MAX
    }
}
2511
2512fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
2517 let base_delay = config
2519 .base_delay_ms
2520 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
2521
2522 let max_jitter = config.base_delay_ms.min(1000);
2524 let jitter = if max_jitter > 0 {
2525 u64::from(
2526 std::time::SystemTime::now()
2527 .duration_since(std::time::UNIX_EPOCH)
2528 .unwrap_or_default()
2529 .subsec_nanos(),
2530 ) % max_jitter
2531 } else {
2532 0
2533 };
2534
2535 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
2536 Duration::from_millis(delay_ms)
2537}
2538
2539type ExtractedContent = (
2541 Option<String>,
2542 Option<String>,
2543 Vec<(String, String, serde_json::Value)>,
2544);
2545
2546fn extract_content(response: &ChatResponse) -> ExtractedContent {
2548 let mut thinking_parts = Vec::new();
2549 let mut text_parts = Vec::new();
2550 let mut tool_uses = Vec::new();
2551
2552 for block in &response.content {
2553 match block {
2554 ContentBlock::Text { text } => {
2555 text_parts.push(text.clone());
2556 }
2557 ContentBlock::Thinking { thinking } => {
2558 thinking_parts.push(thinking.clone());
2559 }
2560 ContentBlock::ToolUse {
2561 id, name, input, ..
2562 } => {
2563 tool_uses.push((id.clone(), name.clone(), input.clone()));
2564 }
2565 ContentBlock::ToolResult { .. } => {
2566 }
2568 }
2569 }
2570
2571 let thinking = if thinking_parts.is_empty() {
2572 None
2573 } else {
2574 Some(thinking_parts.join("\n"))
2575 };
2576
2577 let text = if text_parts.is_empty() {
2578 None
2579 } else {
2580 Some(text_parts.join("\n"))
2581 };
2582
2583 (thinking, text, tool_uses)
2584}
2585
2586async fn send_event<H>(tx: &mpsc::Sender<AgentEvent>, hooks: &Arc<H>, event: AgentEvent)
2587where
2588 H: AgentHooks,
2589{
2590 hooks.on_event(&event).await;
2591 let _ = tx.send(event).await;
2592}
2593
2594fn build_assistant_message(response: &ChatResponse) -> Message {
2595 let mut blocks = Vec::new();
2596
2597 for block in &response.content {
2598 match block {
2599 ContentBlock::Text { text } => {
2600 blocks.push(ContentBlock::Text { text: text.clone() });
2601 }
2602 ContentBlock::Thinking { .. } | ContentBlock::ToolResult { .. } => {
2603 }
2606 ContentBlock::ToolUse {
2607 id,
2608 name,
2609 input,
2610 thought_signature,
2611 } => {
2612 blocks.push(ContentBlock::ToolUse {
2613 id: id.clone(),
2614 name: name.clone(),
2615 input: input.clone(),
2616 thought_signature: thought_signature.clone(),
2617 });
2618 }
2619 }
2620 }
2621
2622 Message {
2623 role: Role::Assistant,
2624 content: Content::Blocks(blocks),
2625 }
2626}
2627
2628#[cfg(test)]
2629mod tests {
2630 use super::*;
2631 use crate::hooks::AllowAllHooks;
2632 use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
2633 use crate::stores::InMemoryStore;
2634 use crate::tools::{Tool, ToolContext, ToolRegistry};
2635 use crate::types::{AgentConfig, AgentInput, ToolResult, ToolTier};
2636 use anyhow::Result;
2637 use async_trait::async_trait;
2638 use serde_json::json;
2639 use std::sync::RwLock;
2640 use std::sync::atomic::{AtomicUsize, Ordering};
2641
2642 struct MockProvider {
2647 responses: RwLock<Vec<ChatOutcome>>,
2648 call_count: AtomicUsize,
2649 }
2650
2651 impl MockProvider {
2652 fn new(responses: Vec<ChatOutcome>) -> Self {
2653 Self {
2654 responses: RwLock::new(responses),
2655 call_count: AtomicUsize::new(0),
2656 }
2657 }
2658
2659 fn text_response(text: &str) -> ChatOutcome {
2660 ChatOutcome::Success(ChatResponse {
2661 id: "msg_1".to_string(),
2662 content: vec![ContentBlock::Text {
2663 text: text.to_string(),
2664 }],
2665 model: "mock-model".to_string(),
2666 stop_reason: Some(StopReason::EndTurn),
2667 usage: Usage {
2668 input_tokens: 10,
2669 output_tokens: 20,
2670 },
2671 })
2672 }
2673
2674 fn tool_use_response(
2675 tool_id: &str,
2676 tool_name: &str,
2677 input: serde_json::Value,
2678 ) -> ChatOutcome {
2679 ChatOutcome::Success(ChatResponse {
2680 id: "msg_1".to_string(),
2681 content: vec![ContentBlock::ToolUse {
2682 id: tool_id.to_string(),
2683 name: tool_name.to_string(),
2684 input,
2685 thought_signature: None,
2686 }],
2687 model: "mock-model".to_string(),
2688 stop_reason: Some(StopReason::ToolUse),
2689 usage: Usage {
2690 input_tokens: 10,
2691 output_tokens: 20,
2692 },
2693 })
2694 }
2695 }
2696
2697 #[async_trait]
2698 impl LlmProvider for MockProvider {
2699 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
2700 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
2701 let responses = self.responses.read().unwrap();
2702 if idx < responses.len() {
2703 Ok(responses[idx].clone())
2704 } else {
2705 Ok(Self::text_response("Done"))
2707 }
2708 }
2709
2710 fn model(&self) -> &'static str {
2711 "mock-model"
2712 }
2713
2714 fn provider(&self) -> &'static str {
2715 "mock"
2716 }
2717 }
2718
2719 impl Clone for ChatOutcome {
2721 fn clone(&self) -> Self {
2722 match self {
2723 Self::Success(r) => Self::Success(r.clone()),
2724 Self::RateLimited => Self::RateLimited,
2725 Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
2726 Self::ServerError(s) => Self::ServerError(s.clone()),
2727 }
2728 }
2729 }
2730
2731 struct EchoTool;
2736
2737 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
2739 #[serde(rename_all = "snake_case")]
2740 enum TestToolName {
2741 Echo,
2742 }
2743
2744 impl crate::tools::ToolName for TestToolName {}
2745
2746 impl Tool<()> for EchoTool {
2747 type Name = TestToolName;
2748
2749 fn name(&self) -> TestToolName {
2750 TestToolName::Echo
2751 }
2752
2753 fn display_name(&self) -> &'static str {
2754 "Echo"
2755 }
2756
2757 fn description(&self) -> &'static str {
2758 "Echo the input message"
2759 }
2760
2761 fn input_schema(&self) -> serde_json::Value {
2762 json!({
2763 "type": "object",
2764 "properties": {
2765 "message": { "type": "string" }
2766 },
2767 "required": ["message"]
2768 })
2769 }
2770
2771 fn tier(&self) -> ToolTier {
2772 ToolTier::Observe
2773 }
2774
2775 async fn execute(
2776 &self,
2777 _ctx: &ToolContext<()>,
2778 input: serde_json::Value,
2779 ) -> Result<ToolResult> {
2780 let message = input
2781 .get("message")
2782 .and_then(|v| v.as_str())
2783 .unwrap_or("no message");
2784 Ok(ToolResult::success(format!("Echo: {message}")))
2785 }
2786 }
2787
2788 #[test]
2793 fn test_builder_creates_agent_loop() {
2794 let provider = MockProvider::new(vec![]);
2795 let agent = builder::<()>().provider(provider).build();
2796
2797 assert_eq!(agent.config.max_turns, 10);
2798 assert_eq!(agent.config.max_tokens, 4096);
2799 }
2800
2801 #[test]
2802 fn test_builder_with_custom_config() {
2803 let provider = MockProvider::new(vec![]);
2804 let config = AgentConfig {
2805 max_turns: 5,
2806 max_tokens: 2048,
2807 system_prompt: "Custom prompt".to_string(),
2808 model: "custom-model".to_string(),
2809 ..Default::default()
2810 };
2811
2812 let agent = builder::<()>().provider(provider).config(config).build();
2813
2814 assert_eq!(agent.config.max_turns, 5);
2815 assert_eq!(agent.config.max_tokens, 2048);
2816 assert_eq!(agent.config.system_prompt, "Custom prompt");
2817 }
2818
2819 #[test]
2820 fn test_builder_with_tools() {
2821 let provider = MockProvider::new(vec![]);
2822 let mut tools = ToolRegistry::new();
2823 tools.register(EchoTool);
2824
2825 let agent = builder::<()>().provider(provider).tools(tools).build();
2826
2827 assert_eq!(agent.tools.len(), 1);
2828 }
2829
2830 #[test]
2831 fn test_builder_with_custom_stores() {
2832 let provider = MockProvider::new(vec![]);
2833 let message_store = InMemoryStore::new();
2834 let state_store = InMemoryStore::new();
2835
2836 let agent = builder::<()>()
2837 .provider(provider)
2838 .hooks(AllowAllHooks)
2839 .message_store(message_store)
2840 .state_store(state_store)
2841 .build_with_stores();
2842
2843 assert_eq!(agent.config.max_turns, 10);
2845 }
2846
2847 #[tokio::test]
2852 async fn test_simple_text_response() -> anyhow::Result<()> {
2853 let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);
2854
2855 let agent = builder::<()>().provider(provider).build();
2856
2857 let thread_id = ThreadId::new();
2858 let tool_ctx = ToolContext::new(());
2859 let (mut rx, _final_state) =
2860 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2861
2862 let mut events = Vec::new();
2863 while let Some(event) = rx.recv().await {
2864 events.push(event);
2865 }
2866
2867 assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
2869 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2870
2871 Ok(())
2872 }
2873
2874 #[tokio::test]
2875 async fn test_tool_execution() -> anyhow::Result<()> {
2876 let provider = MockProvider::new(vec![
2877 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
2879 MockProvider::text_response("Tool executed successfully"),
2881 ]);
2882
2883 let mut tools = ToolRegistry::new();
2884 tools.register(EchoTool);
2885
2886 let agent = builder::<()>().provider(provider).tools(tools).build();
2887
2888 let thread_id = ThreadId::new();
2889 let tool_ctx = ToolContext::new(());
2890 let (mut rx, _final_state) = agent.run(
2891 thread_id,
2892 AgentInput::Text("Run echo".to_string()),
2893 tool_ctx,
2894 );
2895
2896 let mut events = Vec::new();
2897 while let Some(event) = rx.recv().await {
2898 events.push(event);
2899 }
2900
2901 assert!(
2903 events
2904 .iter()
2905 .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
2906 );
2907 assert!(
2908 events
2909 .iter()
2910 .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
2911 );
2912
2913 Ok(())
2914 }
2915
2916 #[tokio::test]
2917 async fn test_max_turns_limit() -> anyhow::Result<()> {
2918 let provider = MockProvider::new(vec![
2920 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
2921 MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
2922 MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
2923 MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
2924 ]);
2925
2926 let mut tools = ToolRegistry::new();
2927 tools.register(EchoTool);
2928
2929 let config = AgentConfig {
2930 max_turns: 2,
2931 ..Default::default()
2932 };
2933
2934 let agent = builder::<()>()
2935 .provider(provider)
2936 .tools(tools)
2937 .config(config)
2938 .build();
2939
2940 let thread_id = ThreadId::new();
2941 let tool_ctx = ToolContext::new(());
2942 let (mut rx, _final_state) =
2943 agent.run(thread_id, AgentInput::Text("Loop".to_string()), tool_ctx);
2944
2945 let mut events = Vec::new();
2946 while let Some(event) = rx.recv().await {
2947 events.push(event);
2948 }
2949
2950 assert!(events.iter().any(|e| {
2952 matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
2953 }));
2954
2955 Ok(())
2956 }
2957
2958 #[tokio::test]
2959 async fn test_unknown_tool_handling() -> anyhow::Result<()> {
2960 let provider = MockProvider::new(vec![
2961 MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
2963 MockProvider::text_response("I couldn't find that tool."),
2965 ]);
2966
2967 let tools = ToolRegistry::new();
2969
2970 let agent = builder::<()>().provider(provider).tools(tools).build();
2971
2972 let thread_id = ThreadId::new();
2973 let tool_ctx = ToolContext::new(());
2974 let (mut rx, _final_state) = agent.run(
2975 thread_id,
2976 AgentInput::Text("Call unknown".to_string()),
2977 tool_ctx,
2978 );
2979
2980 let mut events = Vec::new();
2981 while let Some(event) = rx.recv().await {
2982 events.push(event);
2983 }
2984
2985 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2988
2989 assert!(
2991 events.iter().any(|e| {
2992 matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
2993 })
2994 );
2995
2996 Ok(())
2997 }
2998
2999 #[tokio::test]
3000 async fn test_rate_limit_handling() -> anyhow::Result<()> {
3001 let provider = MockProvider::new(vec![
3003 ChatOutcome::RateLimited,
3004 ChatOutcome::RateLimited,
3005 ChatOutcome::RateLimited,
3006 ChatOutcome::RateLimited,
3007 ChatOutcome::RateLimited,
3008 ChatOutcome::RateLimited, ]);
3010
3011 let config = AgentConfig {
3013 retry: crate::types::RetryConfig::fast(),
3014 ..Default::default()
3015 };
3016
3017 let agent = builder::<()>().provider(provider).config(config).build();
3018
3019 let thread_id = ThreadId::new();
3020 let tool_ctx = ToolContext::new(());
3021 let (mut rx, _final_state) =
3022 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3023
3024 let mut events = Vec::new();
3025 while let Some(event) = rx.recv().await {
3026 events.push(event);
3027 }
3028
3029 assert!(events.iter().any(|e| {
3031 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
3032 }));
3033
3034 assert!(
3036 events
3037 .iter()
3038 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3039 );
3040
3041 Ok(())
3042 }
3043
3044 #[tokio::test]
3045 async fn test_rate_limit_recovery() -> anyhow::Result<()> {
3046 let provider = MockProvider::new(vec![
3048 ChatOutcome::RateLimited,
3049 MockProvider::text_response("Recovered after rate limit"),
3050 ]);
3051
3052 let config = AgentConfig {
3054 retry: crate::types::RetryConfig::fast(),
3055 ..Default::default()
3056 };
3057
3058 let agent = builder::<()>().provider(provider).config(config).build();
3059
3060 let thread_id = ThreadId::new();
3061 let tool_ctx = ToolContext::new(());
3062 let (mut rx, _final_state) =
3063 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3064
3065 let mut events = Vec::new();
3066 while let Some(event) = rx.recv().await {
3067 events.push(event);
3068 }
3069
3070 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3072
3073 assert!(
3075 events
3076 .iter()
3077 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3078 );
3079
3080 Ok(())
3081 }
3082
3083 #[tokio::test]
3084 async fn test_server_error_handling() -> anyhow::Result<()> {
3085 let provider = MockProvider::new(vec![
3087 ChatOutcome::ServerError("Internal error".to_string()),
3088 ChatOutcome::ServerError("Internal error".to_string()),
3089 ChatOutcome::ServerError("Internal error".to_string()),
3090 ChatOutcome::ServerError("Internal error".to_string()),
3091 ChatOutcome::ServerError("Internal error".to_string()),
3092 ChatOutcome::ServerError("Internal error".to_string()), ]);
3094
3095 let config = AgentConfig {
3097 retry: crate::types::RetryConfig::fast(),
3098 ..Default::default()
3099 };
3100
3101 let agent = builder::<()>().provider(provider).config(config).build();
3102
3103 let thread_id = ThreadId::new();
3104 let tool_ctx = ToolContext::new(());
3105 let (mut rx, _final_state) =
3106 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3107
3108 let mut events = Vec::new();
3109 while let Some(event) = rx.recv().await {
3110 events.push(event);
3111 }
3112
3113 assert!(events.iter().any(|e| {
3115 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
3116 }));
3117
3118 assert!(
3120 events
3121 .iter()
3122 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3123 );
3124
3125 Ok(())
3126 }
3127
3128 #[tokio::test]
3129 async fn test_server_error_recovery() -> anyhow::Result<()> {
3130 let provider = MockProvider::new(vec![
3132 ChatOutcome::ServerError("Temporary error".to_string()),
3133 MockProvider::text_response("Recovered after server error"),
3134 ]);
3135
3136 let config = AgentConfig {
3138 retry: crate::types::RetryConfig::fast(),
3139 ..Default::default()
3140 };
3141
3142 let agent = builder::<()>().provider(provider).config(config).build();
3143
3144 let thread_id = ThreadId::new();
3145 let tool_ctx = ToolContext::new(());
3146 let (mut rx, _final_state) =
3147 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3148
3149 let mut events = Vec::new();
3150 while let Some(event) = rx.recv().await {
3151 events.push(event);
3152 }
3153
3154 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3156
3157 assert!(
3159 events
3160 .iter()
3161 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3162 );
3163
3164 Ok(())
3165 }
3166
/// A response carrying a single text block yields that text, no thinking,
/// and no tool uses.
#[test]
fn test_extract_content_text_only() {
    let zero_usage = Usage {
        input_tokens: 0,
        output_tokens: 0,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![ContentBlock::Text {
            text: "Hello".to_string(),
        }],
        model: "test".to_string(),
        stop_reason: None,
        usage: zero_usage,
    };

    let (thinking_part, text_part, tool_part) = extract_content(&response);
    assert!(thinking_part.is_none());
    assert_eq!(text_part.as_deref(), Some("Hello"));
    assert!(tool_part.is_empty());
}
3191
/// A response holding only a tool-use block yields exactly one extracted
/// tool call, and no text or thinking.
#[test]
fn test_extract_content_tool_use() {
    let call_block = ContentBlock::ToolUse {
        id: "tool_1".to_string(),
        name: "test_tool".to_string(),
        input: json!({"key": "value"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![call_block],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking_part, text_part, calls) = extract_content(&response);
    assert!(thinking_part.is_none());
    assert!(text_part.is_none());
    assert_eq!(calls.len(), 1);

    // Tuple field 1 of each extracted entry holds the tool name.
    let first_call = &calls[0];
    assert_eq!(first_call.1, "test_tool");
}
3216
/// A response holding a text block followed by a tool-use block yields both
/// the text and the tool call from `extract_content`, and no thinking.
#[test]
fn test_extract_content_mixed() {
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![
            ContentBlock::Text {
                text: "Let me help".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "helper".to_string(),
                input: json!({}),
                thought_signature: None,
            },
        ],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);
    assert!(thinking.is_none());
    assert_eq!(text, Some("Let me help".to_string()));
    assert_eq!(tool_uses.len(), 1);
    // Also pin which tool was extracted, not just the count. Tuple field 1 is
    // the tool name, as `test_extract_content_tool_use` already asserts.
    assert_eq!(tool_uses[0].1, "helper");
}
3245
/// `millis_to_u64` passes in-range values through unchanged and saturates at
/// `u64::MAX` once the `u128` input exceeds it.
#[test]
fn test_millis_to_u64() {
    let ceiling = u128::from(u64::MAX);
    let cases: [(u128, u64); 4] = [
        (0, 0),
        (1000, 1000),
        (ceiling, u64::MAX),
        (ceiling + 1, u64::MAX),
    ];
    for (input, expected) in cases {
        assert_eq!(millis_to_u64(input), expected);
    }
}
3253
/// `build_assistant_message` produces an assistant-role message whose content
/// is `Content::Blocks` holding both input blocks (text + tool use).
#[test]
fn test_build_assistant_message() {
    let text_block = ContentBlock::Text {
        text: "Response text".to_string(),
    };
    let tool_block = ContentBlock::ToolUse {
        id: "tool_1".to_string(),
        name: "echo".to_string(),
        input: json!({"message": "test"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![text_block, tool_block],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let message = build_assistant_message(&response);
    assert_eq!(message.role, Role::Assistant);

    let blocks = match message.content {
        Content::Blocks(blocks) => blocks,
        _ => panic!("Expected Content::Blocks"),
    };
    assert_eq!(blocks.len(), 2);
}
3286}