1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason, StreamAccumulator, StreamDelta, Usage,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore, ToolExecutionStore};
39use crate::tools::{ErasedAsyncTool, ErasedToolStatus, ToolContext, ToolRegistry};
40use crate::types::{
41 AgentConfig, AgentContinuation, AgentError, AgentInput, AgentRunState, AgentState,
42 ExecutionStatus, PendingToolCallInfo, RetryConfig, ThreadId, TokenUsage, ToolExecution,
43 ToolOutcome, ToolResult, TurnOutcome,
44};
45use futures::StreamExt;
46use log::{debug, error, info, warn};
47use std::sync::Arc;
48use std::time::{Duration, Instant};
49use tokio::sync::{mpsc, oneshot};
50use tokio::time::sleep;
51
52enum InternalTurnResult {
56 Continue { turn_usage: TokenUsage },
58 Done,
60 AwaitingConfirmation {
62 tool_call_id: String,
63 tool_name: String,
64 display_name: String,
65 input: serde_json::Value,
66 description: String,
67 continuation: Box<AgentContinuation>,
68 },
69 Error(AgentError),
71}
72
73struct TurnContext {
77 thread_id: ThreadId,
78 turn: usize,
79 total_usage: TokenUsage,
80 state: AgentState,
81 start_time: Instant,
82}
83
84struct ResumeData {
86 continuation: Box<AgentContinuation>,
87 tool_call_id: String,
88 confirmed: bool,
89 rejection_reason: Option<String>,
90}
91
92struct InitializedState {
94 turn: usize,
95 total_usage: TokenUsage,
96 state: AgentState,
97 resume_data: Option<ResumeData>,
98}
99
100enum ToolExecutionOutcome {
102 Completed { tool_id: String, result: ToolResult },
104 RequiresConfirmation {
106 tool_id: String,
107 tool_name: String,
108 display_name: String,
109 input: serde_json::Value,
110 description: String,
111 },
112}
113
114pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
126 provider: Option<P>,
127 tools: Option<ToolRegistry<Ctx>>,
128 hooks: Option<H>,
129 message_store: Option<M>,
130 state_store: Option<S>,
131 config: Option<AgentConfig>,
132 compaction_config: Option<CompactionConfig>,
133 execution_store: Option<Arc<dyn ToolExecutionStore>>,
134}
135
136impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
137 #[must_use]
139 pub fn new() -> Self {
140 Self {
141 provider: None,
142 tools: None,
143 hooks: None,
144 message_store: None,
145 state_store: None,
146 config: None,
147 compaction_config: None,
148 execution_store: None,
149 }
150 }
151}
152
153impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
160 #[must_use]
162 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
163 AgentLoopBuilder {
164 provider: Some(provider),
165 tools: self.tools,
166 hooks: self.hooks,
167 message_store: self.message_store,
168 state_store: self.state_store,
169 config: self.config,
170 compaction_config: self.compaction_config,
171 execution_store: self.execution_store,
172 }
173 }
174
175 #[must_use]
177 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
178 self.tools = Some(tools);
179 self
180 }
181
182 #[must_use]
184 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
185 AgentLoopBuilder {
186 provider: self.provider,
187 tools: self.tools,
188 hooks: Some(hooks),
189 message_store: self.message_store,
190 state_store: self.state_store,
191 config: self.config,
192 compaction_config: self.compaction_config,
193 execution_store: self.execution_store,
194 }
195 }
196
197 #[must_use]
199 pub fn message_store<M2: MessageStore>(
200 self,
201 message_store: M2,
202 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
203 AgentLoopBuilder {
204 provider: self.provider,
205 tools: self.tools,
206 hooks: self.hooks,
207 message_store: Some(message_store),
208 state_store: self.state_store,
209 config: self.config,
210 compaction_config: self.compaction_config,
211 execution_store: self.execution_store,
212 }
213 }
214
215 #[must_use]
217 pub fn state_store<S2: StateStore>(
218 self,
219 state_store: S2,
220 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
221 AgentLoopBuilder {
222 provider: self.provider,
223 tools: self.tools,
224 hooks: self.hooks,
225 message_store: self.message_store,
226 state_store: Some(state_store),
227 config: self.config,
228 compaction_config: self.compaction_config,
229 execution_store: self.execution_store,
230 }
231 }
232
233 #[must_use]
251 pub fn execution_store(mut self, store: impl ToolExecutionStore + 'static) -> Self {
252 self.execution_store = Some(Arc::new(store));
253 self
254 }
255
256 #[must_use]
258 pub fn config(mut self, config: AgentConfig) -> Self {
259 self.config = Some(config);
260 self
261 }
262
263 #[must_use]
279 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
280 self.compaction_config = Some(config);
281 self
282 }
283
284 #[must_use]
291 pub fn with_auto_compaction(self) -> Self {
292 self.with_compaction(CompactionConfig::default())
293 }
294
295 #[must_use]
313 pub fn with_skill(mut self, skill: Skill) -> Self
314 where
315 Ctx: Send + Sync + 'static,
316 {
317 if let Some(ref mut tools) = self.tools {
319 tools.filter(|name| skill.is_tool_allowed(name));
320 }
321
322 let mut config = self.config.take().unwrap_or_default();
324 if config.system_prompt.is_empty() {
325 config.system_prompt = skill.system_prompt;
326 } else {
327 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
328 }
329 self.config = Some(config);
330
331 self
332 }
333}
334
335impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
336where
337 Ctx: Send + Sync + 'static,
338 P: LlmProvider + 'static,
339{
340 #[must_use]
352 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
353 let provider = self.provider.expect("provider is required");
354 let tools = self.tools.unwrap_or_default();
355 let config = self.config.unwrap_or_default();
356
357 AgentLoop {
358 provider: Arc::new(provider),
359 tools: Arc::new(tools),
360 hooks: Arc::new(DefaultHooks),
361 message_store: Arc::new(InMemoryStore::new()),
362 state_store: Arc::new(InMemoryStore::new()),
363 config,
364 compaction_config: self.compaction_config,
365 execution_store: self.execution_store,
366 }
367 }
368}
369
370impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
371where
372 Ctx: Send + Sync + 'static,
373 P: LlmProvider + 'static,
374 H: AgentHooks + 'static,
375 M: MessageStore + 'static,
376 S: StateStore + 'static,
377{
378 #[must_use]
388 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
389 let provider = self.provider.expect("provider is required");
390 let tools = self.tools.unwrap_or_default();
391 let hooks = self
392 .hooks
393 .expect("hooks is required when using build_with_stores");
394 let message_store = self
395 .message_store
396 .expect("message_store is required when using build_with_stores");
397 let state_store = self
398 .state_store
399 .expect("state_store is required when using build_with_stores");
400 let config = self.config.unwrap_or_default();
401
402 AgentLoop {
403 provider: Arc::new(provider),
404 tools: Arc::new(tools),
405 hooks: Arc::new(hooks),
406 message_store: Arc::new(message_store),
407 state_store: Arc::new(state_store),
408 config,
409 compaction_config: self.compaction_config,
410 execution_store: self.execution_store,
411 }
412 }
413}
414
415pub struct AgentLoop<Ctx, P, H, M, S>
461where
462 P: LlmProvider,
463 H: AgentHooks,
464 M: MessageStore,
465 S: StateStore,
466{
467 provider: Arc<P>,
468 tools: Arc<ToolRegistry<Ctx>>,
469 hooks: Arc<H>,
470 message_store: Arc<M>,
471 state_store: Arc<S>,
472 config: AgentConfig,
473 compaction_config: Option<CompactionConfig>,
474 execution_store: Option<Arc<dyn ToolExecutionStore>>,
475}
476
477#[must_use]
479pub fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
480 AgentLoopBuilder::new()
481}
482
483impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
484where
485 Ctx: Send + Sync + 'static,
486 P: LlmProvider + 'static,
487 H: AgentHooks + 'static,
488 M: MessageStore + 'static,
489 S: StateStore + 'static,
490{
491 #[must_use]
493 pub fn new(
494 provider: P,
495 tools: ToolRegistry<Ctx>,
496 hooks: H,
497 message_store: M,
498 state_store: S,
499 config: AgentConfig,
500 ) -> Self {
501 Self {
502 provider: Arc::new(provider),
503 tools: Arc::new(tools),
504 hooks: Arc::new(hooks),
505 message_store: Arc::new(message_store),
506 state_store: Arc::new(state_store),
507 config,
508 compaction_config: None,
509 execution_store: None,
510 }
511 }
512
513 #[must_use]
515 pub fn with_compaction(
516 provider: P,
517 tools: ToolRegistry<Ctx>,
518 hooks: H,
519 message_store: M,
520 state_store: S,
521 config: AgentConfig,
522 compaction_config: CompactionConfig,
523 ) -> Self {
524 Self {
525 provider: Arc::new(provider),
526 tools: Arc::new(tools),
527 hooks: Arc::new(hooks),
528 message_store: Arc::new(message_store),
529 state_store: Arc::new(state_store),
530 config,
531 compaction_config: Some(compaction_config),
532 execution_store: None,
533 }
534 }
535
536 pub fn run(
586 &self,
587 thread_id: ThreadId,
588 input: AgentInput,
589 tool_context: ToolContext<Ctx>,
590 ) -> (mpsc::Receiver<AgentEvent>, oneshot::Receiver<AgentRunState>)
591 where
592 Ctx: Clone,
593 {
594 let (event_tx, event_rx) = mpsc::channel(100);
595 let (state_tx, state_rx) = oneshot::channel();
596
597 let provider = Arc::clone(&self.provider);
598 let tools = Arc::clone(&self.tools);
599 let hooks = Arc::clone(&self.hooks);
600 let message_store = Arc::clone(&self.message_store);
601 let state_store = Arc::clone(&self.state_store);
602 let config = self.config.clone();
603 let compaction_config = self.compaction_config.clone();
604 let execution_store = self.execution_store.clone();
605
606 tokio::spawn(async move {
607 let result = run_loop(
608 event_tx,
609 thread_id,
610 input,
611 tool_context,
612 provider,
613 tools,
614 hooks,
615 message_store,
616 state_store,
617 config,
618 compaction_config,
619 execution_store,
620 )
621 .await;
622
623 let _ = state_tx.send(result);
624 });
625
626 (event_rx, state_rx)
627 }
628
629 pub fn run_turn(
685 &self,
686 thread_id: ThreadId,
687 input: AgentInput,
688 tool_context: ToolContext<Ctx>,
689 ) -> (mpsc::Receiver<AgentEvent>, oneshot::Receiver<TurnOutcome>)
690 where
691 Ctx: Clone,
692 {
693 let (event_tx, event_rx) = mpsc::channel(100);
694 let (outcome_tx, outcome_rx) = oneshot::channel();
695
696 let provider = Arc::clone(&self.provider);
697 let tools = Arc::clone(&self.tools);
698 let hooks = Arc::clone(&self.hooks);
699 let message_store = Arc::clone(&self.message_store);
700 let state_store = Arc::clone(&self.state_store);
701 let config = self.config.clone();
702 let compaction_config = self.compaction_config.clone();
703 let execution_store = self.execution_store.clone();
704
705 tokio::spawn(async move {
706 let result = run_single_turn(TurnParameters {
707 tx: event_tx,
708 thread_id,
709 input,
710 tool_context,
711 provider,
712 tools,
713 hooks,
714 message_store,
715 state_store,
716 config,
717 compaction_config,
718 execution_store,
719 })
720 .await;
721
722 let _ = outcome_tx.send(result);
723 });
724
725 (event_rx, outcome_rx)
726 }
727}
728
729async fn initialize_from_input<M, S>(
740 input: AgentInput,
741 thread_id: &ThreadId,
742 message_store: &Arc<M>,
743 state_store: &Arc<S>,
744) -> Result<InitializedState, AgentError>
745where
746 M: MessageStore,
747 S: StateStore,
748{
749 match input {
750 AgentInput::Text(user_message) => {
751 let state = match state_store.load(thread_id).await {
753 Ok(Some(s)) => s,
754 Ok(None) => AgentState::new(thread_id.clone()),
755 Err(e) => {
756 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
757 }
758 };
759
760 let user_msg = Message::user(&user_message);
762 if let Err(e) = message_store.append(thread_id, user_msg).await {
763 return Err(AgentError::new(
764 format!("Failed to append message: {e}"),
765 false,
766 ));
767 }
768
769 Ok(InitializedState {
770 turn: 0,
771 total_usage: TokenUsage::default(),
772 state,
773 resume_data: None,
774 })
775 }
776 AgentInput::Resume {
777 continuation,
778 tool_call_id,
779 confirmed,
780 rejection_reason,
781 } => {
782 if continuation.thread_id != *thread_id {
784 return Err(AgentError::new(
785 format!(
786 "Thread ID mismatch: continuation is for {}, but resuming on {}",
787 continuation.thread_id, thread_id
788 ),
789 false,
790 ));
791 }
792
793 Ok(InitializedState {
794 turn: continuation.turn,
795 total_usage: continuation.total_usage.clone(),
796 state: continuation.state.clone(),
797 resume_data: Some(ResumeData {
798 continuation,
799 tool_call_id,
800 confirmed,
801 rejection_reason,
802 }),
803 })
804 }
805 AgentInput::Continue => {
806 let state = match state_store.load(thread_id).await {
808 Ok(Some(s)) => s,
809 Ok(None) => {
810 return Err(AgentError::new(
811 "Cannot continue: no state found for thread",
812 false,
813 ));
814 }
815 Err(e) => {
816 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
817 }
818 };
819
820 Ok(InitializedState {
822 turn: state.turn_count,
823 total_usage: state.total_usage.clone(),
824 state,
825 resume_data: None,
826 })
827 }
828 }
829}
830
831#[allow(clippy::too_many_lines)]
840async fn execute_tool_call<Ctx, H>(
841 pending: &PendingToolCallInfo,
842 tool_context: &ToolContext<Ctx>,
843 tools: &ToolRegistry<Ctx>,
844 hooks: &Arc<H>,
845 tx: &mpsc::Sender<AgentEvent>,
846) -> ToolExecutionOutcome
847where
848 Ctx: Send + Sync + Clone + 'static,
849 H: AgentHooks,
850{
851 if let Some(async_tool) = tools.get_async(&pending.name) {
853 let tier = async_tool.tier();
854
855 let _ = tx
857 .send(AgentEvent::tool_call_start(
858 &pending.id,
859 &pending.name,
860 &pending.display_name,
861 pending.input.clone(),
862 tier,
863 ))
864 .await;
865
866 let decision = hooks
868 .pre_tool_use(&pending.name, &pending.input, tier)
869 .await;
870
871 return match decision {
872 ToolDecision::Allow => {
873 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
874
875 hooks.post_tool_use(&pending.name, &result).await;
876
877 send_event(
878 tx,
879 hooks,
880 AgentEvent::tool_call_end(
881 &pending.id,
882 &pending.name,
883 &pending.display_name,
884 result.clone(),
885 ),
886 )
887 .await;
888
889 ToolExecutionOutcome::Completed {
890 tool_id: pending.id.clone(),
891 result,
892 }
893 }
894 ToolDecision::Block(reason) => {
895 let result = ToolResult::error(format!("Blocked: {reason}"));
896 send_event(
897 tx,
898 hooks,
899 AgentEvent::tool_call_end(
900 &pending.id,
901 &pending.name,
902 &pending.display_name,
903 result.clone(),
904 ),
905 )
906 .await;
907 ToolExecutionOutcome::Completed {
908 tool_id: pending.id.clone(),
909 result,
910 }
911 }
912 ToolDecision::RequiresConfirmation(description) => {
913 send_event(
914 tx,
915 hooks,
916 AgentEvent::ToolRequiresConfirmation {
917 id: pending.id.clone(),
918 name: pending.name.clone(),
919 input: pending.input.clone(),
920 description: description.clone(),
921 },
922 )
923 .await;
924
925 ToolExecutionOutcome::RequiresConfirmation {
926 tool_id: pending.id.clone(),
927 tool_name: pending.name.clone(),
928 display_name: pending.display_name.clone(),
929 input: pending.input.clone(),
930 description,
931 }
932 }
933 };
934 }
935
936 let Some(tool) = tools.get(&pending.name) else {
938 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
939 return ToolExecutionOutcome::Completed {
940 tool_id: pending.id.clone(),
941 result,
942 };
943 };
944
945 let tier = tool.tier();
946
947 let _ = tx
949 .send(AgentEvent::tool_call_start(
950 &pending.id,
951 &pending.name,
952 &pending.display_name,
953 pending.input.clone(),
954 tier,
955 ))
956 .await;
957
958 let decision = hooks
960 .pre_tool_use(&pending.name, &pending.input, tier)
961 .await;
962
963 match decision {
964 ToolDecision::Allow => {
965 let tool_start = Instant::now();
966 let result = match tool.execute(tool_context, pending.input.clone()).await {
967 Ok(mut r) => {
968 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
969 r
970 }
971 Err(e) => ToolResult::error(format!("Tool error: {e}"))
972 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
973 };
974
975 hooks.post_tool_use(&pending.name, &result).await;
976
977 send_event(
978 tx,
979 hooks,
980 AgentEvent::tool_call_end(
981 &pending.id,
982 &pending.name,
983 &pending.display_name,
984 result.clone(),
985 ),
986 )
987 .await;
988
989 ToolExecutionOutcome::Completed {
990 tool_id: pending.id.clone(),
991 result,
992 }
993 }
994 ToolDecision::Block(reason) => {
995 let result = ToolResult::error(format!("Blocked: {reason}"));
996 send_event(
997 tx,
998 hooks,
999 AgentEvent::tool_call_end(
1000 &pending.id,
1001 &pending.name,
1002 &pending.display_name,
1003 result.clone(),
1004 ),
1005 )
1006 .await;
1007 ToolExecutionOutcome::Completed {
1008 tool_id: pending.id.clone(),
1009 result,
1010 }
1011 }
1012 ToolDecision::RequiresConfirmation(description) => {
1013 send_event(
1014 tx,
1015 hooks,
1016 AgentEvent::ToolRequiresConfirmation {
1017 id: pending.id.clone(),
1018 name: pending.name.clone(),
1019 input: pending.input.clone(),
1020 description: description.clone(),
1021 },
1022 )
1023 .await;
1024
1025 ToolExecutionOutcome::RequiresConfirmation {
1026 tool_id: pending.id.clone(),
1027 tool_name: pending.name.clone(),
1028 display_name: pending.display_name.clone(),
1029 input: pending.input.clone(),
1030 description,
1031 }
1032 }
1033 }
1034}
1035
1036async fn execute_async_tool<Ctx>(
1042 pending: &PendingToolCallInfo,
1043 tool: &Arc<dyn ErasedAsyncTool<Ctx>>,
1044 tool_context: &ToolContext<Ctx>,
1045 tx: &mpsc::Sender<AgentEvent>,
1046) -> ToolResult
1047where
1048 Ctx: Send + Sync + Clone,
1049{
1050 let tool_start = Instant::now();
1051
1052 let outcome = match tool.execute(tool_context, pending.input.clone()).await {
1054 Ok(o) => o,
1055 Err(e) => {
1056 return ToolResult::error(format!("Tool error: {e}"))
1057 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()));
1058 }
1059 };
1060
1061 match outcome {
1062 ToolOutcome::Success(mut result) | ToolOutcome::Failed(mut result) => {
1064 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1065 result
1066 }
1067
1068 ToolOutcome::InProgress {
1070 operation_id,
1071 message,
1072 } => {
1073 let _ = tx
1075 .send(AgentEvent::tool_progress(
1076 &pending.id,
1077 &pending.name,
1078 &pending.display_name,
1079 "started",
1080 &message,
1081 None,
1082 ))
1083 .await;
1084
1085 let mut stream = tool.check_status_stream(tool_context, &operation_id);
1087
1088 while let Some(status) = stream.next().await {
1089 match status {
1090 ErasedToolStatus::Progress {
1091 stage,
1092 message,
1093 data,
1094 } => {
1095 let _ = tx
1096 .send(AgentEvent::tool_progress(
1097 &pending.id,
1098 &pending.name,
1099 &pending.display_name,
1100 stage,
1101 message,
1102 data,
1103 ))
1104 .await;
1105 }
1106 ErasedToolStatus::Completed(mut result)
1107 | ErasedToolStatus::Failed(mut result) => {
1108 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1109 return result;
1110 }
1111 }
1112 }
1113
1114 ToolResult::error("Async tool stream ended without completion")
1116 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()))
1117 }
1118 }
1119}
1120
1121async fn execute_confirmed_tool<Ctx, H>(
1126 awaiting_tool: &PendingToolCallInfo,
1127 confirmed: bool,
1128 rejection_reason: Option<String>,
1129 tool_context: &ToolContext<Ctx>,
1130 tools: &ToolRegistry<Ctx>,
1131 hooks: &Arc<H>,
1132 tx: &mpsc::Sender<AgentEvent>,
1133) -> ToolResult
1134where
1135 Ctx: Send + Sync + Clone + 'static,
1136 H: AgentHooks,
1137{
1138 if confirmed {
1139 if let Some(async_tool) = tools.get_async(&awaiting_tool.name) {
1141 let result = execute_async_tool(awaiting_tool, async_tool, tool_context, tx).await;
1142
1143 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1144
1145 let _ = tx
1146 .send(AgentEvent::tool_call_end(
1147 &awaiting_tool.id,
1148 &awaiting_tool.name,
1149 &awaiting_tool.display_name,
1150 result.clone(),
1151 ))
1152 .await;
1153
1154 return result;
1155 }
1156
1157 if let Some(tool) = tools.get(&awaiting_tool.name) {
1159 let tool_start = Instant::now();
1160 let result = match tool
1161 .execute(tool_context, awaiting_tool.input.clone())
1162 .await
1163 {
1164 Ok(mut r) => {
1165 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1166 r
1167 }
1168 Err(e) => ToolResult::error(format!("Tool error: {e}"))
1169 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
1170 };
1171
1172 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1173
1174 let _ = tx
1175 .send(AgentEvent::tool_call_end(
1176 &awaiting_tool.id,
1177 &awaiting_tool.name,
1178 &awaiting_tool.display_name,
1179 result.clone(),
1180 ))
1181 .await;
1182
1183 result
1184 } else {
1185 ToolResult::error(format!("Unknown tool: {}", awaiting_tool.name))
1186 }
1187 } else {
1188 let reason = rejection_reason.unwrap_or_else(|| "User rejected".to_string());
1189 let result = ToolResult::error(format!("Rejected: {reason}"));
1190 send_event(
1191 tx,
1192 hooks,
1193 AgentEvent::tool_call_end(
1194 &awaiting_tool.id,
1195 &awaiting_tool.name,
1196 &awaiting_tool.display_name,
1197 result.clone(),
1198 ),
1199 )
1200 .await;
1201 result
1202 }
1203}
1204
1205async fn append_tool_results<M>(
1207 tool_results: &[(String, ToolResult)],
1208 thread_id: &ThreadId,
1209 message_store: &Arc<M>,
1210) -> Result<(), AgentError>
1211where
1212 M: MessageStore,
1213{
1214 for (tool_id, result) in tool_results {
1215 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
1216 if let Err(e) = message_store.append(thread_id, tool_result_msg).await {
1217 return Err(AgentError::new(
1218 format!("Failed to append tool result: {e}"),
1219 false,
1220 ));
1221 }
1222 }
1223 Ok(())
1224}
1225
1226async fn call_llm_with_retry<P, H>(
1228 provider: &Arc<P>,
1229 request: ChatRequest,
1230 config: &AgentConfig,
1231 tx: &mpsc::Sender<AgentEvent>,
1232 hooks: &Arc<H>,
1233) -> Result<ChatResponse, AgentError>
1234where
1235 P: LlmProvider,
1236 H: AgentHooks,
1237{
1238 let max_retries = config.retry.max_retries;
1239 let mut attempt = 0u32;
1240
1241 loop {
1242 let outcome = match provider.chat(request.clone()).await {
1243 Ok(o) => o,
1244 Err(e) => {
1245 return Err(AgentError::new(format!("LLM error: {e}"), false));
1246 }
1247 };
1248
1249 match outcome {
1250 ChatOutcome::Success(response) => return Ok(response),
1251 ChatOutcome::RateLimited => {
1252 attempt += 1;
1253 if attempt > max_retries {
1254 error!("Rate limited by LLM provider after {max_retries} retries");
1255 let error_msg = format!("Rate limited after {max_retries} retries");
1256 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1257 return Err(AgentError::new(error_msg, true));
1258 }
1259 let delay = calculate_backoff_delay(attempt, &config.retry);
1260 warn!(
1261 "Rate limited, retrying after backoff (attempt={}, delay_ms={})",
1262 attempt,
1263 delay.as_millis()
1264 );
1265 let _ = tx
1266 .send(AgentEvent::text(format!(
1267 "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1268 delay.as_secs_f64()
1269 )))
1270 .await;
1271 sleep(delay).await;
1272 }
1273 ChatOutcome::InvalidRequest(msg) => {
1274 error!("Invalid request to LLM: {msg}");
1275 return Err(AgentError::new(format!("Invalid request: {msg}"), false));
1276 }
1277 ChatOutcome::ServerError(msg) => {
1278 attempt += 1;
1279 if attempt > max_retries {
1280 error!("LLM server error after {max_retries} retries: {msg}");
1281 let error_msg = format!("Server error after {max_retries} retries: {msg}");
1282 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1283 return Err(AgentError::new(error_msg, true));
1284 }
1285 let delay = calculate_backoff_delay(attempt, &config.retry);
1286 warn!(
1287 "Server error, retrying after backoff (attempt={attempt}, delay_ms={}, error={msg})",
1288 delay.as_millis()
1289 );
1290 send_event(
1291 tx,
1292 hooks,
1293 AgentEvent::text(format!(
1294 "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1295 delay.as_secs_f64()
1296 )),
1297 )
1298 .await;
1299 sleep(delay).await;
1300 }
1301 }
1302 }
1303}
1304
1305async fn call_llm_streaming<P, H>(
1311 provider: &Arc<P>,
1312 request: ChatRequest,
1313 config: &AgentConfig,
1314 tx: &mpsc::Sender<AgentEvent>,
1315 hooks: &Arc<H>,
1316) -> Result<ChatResponse, AgentError>
1317where
1318 P: LlmProvider,
1319 H: AgentHooks,
1320{
1321 let max_retries = config.retry.max_retries;
1322 let mut attempt = 0u32;
1323
1324 loop {
1325 let result = process_stream(provider, &request, tx, hooks).await;
1326
1327 match result {
1328 Ok(response) => return Ok(response),
1329 Err(StreamError::Recoverable(msg)) => {
1330 attempt += 1;
1331 if attempt > max_retries {
1332 error!("Streaming error after {max_retries} retries: {msg}");
1333 let err_msg = format!("Streaming error after {max_retries} retries: {msg}");
1334 send_event(tx, hooks, AgentEvent::error(&err_msg, true)).await;
1335 return Err(AgentError::new(err_msg, true));
1336 }
1337 let delay = calculate_backoff_delay(attempt, &config.retry);
1338 warn!(
1339 "Streaming error, retrying (attempt={attempt}, delay_ms={}, error={msg})",
1340 delay.as_millis()
1341 );
1342 send_event(
1343 tx,
1344 hooks,
1345 AgentEvent::text(format!(
1346 "\n[Streaming error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
1347 delay.as_secs_f64()
1348 )),
1349 )
1350 .await;
1351 sleep(delay).await;
1352 }
1353 Err(StreamError::Fatal(msg)) => {
1354 error!("Streaming error (non-recoverable): {msg}");
1355 return Err(AgentError::new(format!("Streaming error: {msg}"), false));
1356 }
1357 }
1358 }
1359}
1360
/// Failure while consuming a single LLM stream.
enum StreamError {
    /// Worth retrying; carries the error message.
    Recoverable(String),
    /// Not worth retrying; carries the error message.
    Fatal(String),
}
1366
1367async fn process_stream<P, H>(
1369 provider: &Arc<P>,
1370 request: &ChatRequest,
1371 tx: &mpsc::Sender<AgentEvent>,
1372 hooks: &Arc<H>,
1373) -> Result<ChatResponse, StreamError>
1374where
1375 P: LlmProvider,
1376 H: AgentHooks,
1377{
1378 let mut stream = std::pin::pin!(provider.chat_stream(request.clone()));
1379 let mut accumulator = StreamAccumulator::new();
1380 let mut delta_count: u64 = 0;
1381
1382 log::debug!("Starting to consume LLM stream");
1383
1384 let mut channel_closed = false;
1386
1387 while let Some(result) = stream.next().await {
1388 if delta_count > 0 && delta_count.is_multiple_of(50) {
1390 log::debug!("Stream progress: delta_count={delta_count}");
1391 }
1392
1393 match result {
1394 Ok(delta) => {
1395 delta_count += 1;
1396 accumulator.apply(&delta);
1397 match &delta {
1398 StreamDelta::TextDelta { delta, .. } => {
1399 if !channel_closed {
1401 if tx.is_closed() {
1402 log::warn!(
1403 "Event channel closed by receiver at delta_count={delta_count} - consumer may have disconnected"
1404 );
1405 channel_closed = true;
1406 } else {
1407 send_event(tx, hooks, AgentEvent::text_delta(delta.clone())).await;
1408 }
1409 }
1410 }
1411 StreamDelta::ThinkingDelta { delta, .. } => {
1412 if !channel_closed {
1413 if tx.is_closed() {
1414 log::warn!(
1415 "Event channel closed by receiver at delta_count={delta_count}"
1416 );
1417 channel_closed = true;
1418 } else {
1419 send_event(tx, hooks, AgentEvent::thinking(delta.clone())).await;
1420 }
1421 }
1422 }
1423 StreamDelta::Error {
1424 message,
1425 recoverable,
1426 } => {
1427 log::warn!(
1428 "Stream error received delta_count={delta_count} message={message} recoverable={recoverable}"
1429 );
1430 return if *recoverable {
1431 Err(StreamError::Recoverable(message.clone()))
1432 } else {
1433 Err(StreamError::Fatal(message.clone()))
1434 };
1435 }
1436 StreamDelta::Done { .. }
1438 | StreamDelta::Usage(_)
1439 | StreamDelta::ToolUseStart { .. }
1440 | StreamDelta::ToolInputDelta { .. } => {}
1441 }
1442 }
1443 Err(e) => {
1444 log::error!("Stream iteration error delta_count={delta_count} error={e}");
1445 return Err(StreamError::Recoverable(format!("Stream error: {e}")));
1446 }
1447 }
1448 }
1449
1450 log::debug!("Stream while loop exited normally at delta_count={delta_count}");
1451
1452 let usage = accumulator.usage().cloned().unwrap_or(Usage {
1453 input_tokens: 0,
1454 output_tokens: 0,
1455 });
1456 let stop_reason = accumulator.stop_reason().copied();
1457 let content_blocks = accumulator.into_content_blocks();
1458
1459 log::debug!(
1460 "LLM stream completed successfully delta_count={delta_count} stop_reason={stop_reason:?} content_block_count={} input_tokens={} output_tokens={}",
1461 content_blocks.len(),
1462 usage.input_tokens,
1463 usage.output_tokens
1464 );
1465
1466 Ok(ChatResponse {
1467 id: String::new(),
1468 content: content_blocks,
1469 model: provider.model().to_string(),
1470 stop_reason,
1471 usage,
1472 })
1473}
1474
1475#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1480async fn run_loop<Ctx, P, H, M, S>(
1481 tx: mpsc::Sender<AgentEvent>,
1482 thread_id: ThreadId,
1483 input: AgentInput,
1484 tool_context: ToolContext<Ctx>,
1485 provider: Arc<P>,
1486 tools: Arc<ToolRegistry<Ctx>>,
1487 hooks: Arc<H>,
1488 message_store: Arc<M>,
1489 state_store: Arc<S>,
1490 config: AgentConfig,
1491 compaction_config: Option<CompactionConfig>,
1492 execution_store: Option<Arc<dyn ToolExecutionStore>>,
1493) -> AgentRunState
1494where
1495 Ctx: Send + Sync + Clone + 'static,
1496 P: LlmProvider,
1497 H: AgentHooks,
1498 M: MessageStore,
1499 S: StateStore,
1500{
1501 let tool_context = tool_context.with_event_tx(tx.clone());
1503 let start_time = Instant::now();
1504
1505 let init_state =
1507 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1508 Ok(s) => s,
1509 Err(e) => return AgentRunState::Error(e),
1510 };
1511
1512 let InitializedState {
1513 turn,
1514 total_usage,
1515 state,
1516 resume_data,
1517 } = init_state;
1518
1519 if let Some(resume) = resume_data {
1520 let ResumeData {
1521 continuation: cont,
1522 tool_call_id,
1523 confirmed,
1524 rejection_reason,
1525 } = resume;
1526 let mut tool_results = cont.completed_results.clone();
1527 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1528
1529 if awaiting_tool.id != tool_call_id {
1530 let message = format!(
1531 "Tool call ID mismatch: expected {}, got {}",
1532 awaiting_tool.id, tool_call_id
1533 );
1534 let recoverable = false;
1535 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1536 return AgentRunState::Error(AgentError::new(&message, recoverable));
1537 }
1538
1539 let result = execute_confirmed_tool(
1540 awaiting_tool,
1541 confirmed,
1542 rejection_reason,
1543 &tool_context,
1544 &tools,
1545 &hooks,
1546 &tx,
1547 )
1548 .await;
1549 tool_results.push((awaiting_tool.id.clone(), result));
1550
1551 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1552 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1553 ToolExecutionOutcome::Completed { tool_id, result } => {
1554 tool_results.push((tool_id, result));
1555 }
1556 ToolExecutionOutcome::RequiresConfirmation {
1557 tool_id,
1558 tool_name,
1559 display_name,
1560 input,
1561 description,
1562 } => {
1563 let pending_idx = cont
1564 .pending_tool_calls
1565 .iter()
1566 .position(|p| p.id == tool_id)
1567 .unwrap_or(0);
1568
1569 let new_continuation = AgentContinuation {
1570 thread_id: thread_id.clone(),
1571 turn,
1572 total_usage: total_usage.clone(),
1573 turn_usage: cont.turn_usage.clone(),
1574 pending_tool_calls: cont.pending_tool_calls.clone(),
1575 awaiting_index: pending_idx,
1576 completed_results: tool_results,
1577 state: state.clone(),
1578 };
1579
1580 return AgentRunState::AwaitingConfirmation {
1581 tool_call_id: tool_id,
1582 tool_name,
1583 display_name,
1584 input,
1585 description,
1586 continuation: Box::new(new_continuation),
1587 };
1588 }
1589 }
1590 }
1591
1592 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1593 send_event(
1594 &tx,
1595 &hooks,
1596 AgentEvent::Error {
1597 message: e.message.clone(),
1598 recoverable: e.recoverable,
1599 },
1600 )
1601 .await;
1602 return AgentRunState::Error(e);
1603 }
1604
1605 send_event(
1606 &tx,
1607 &hooks,
1608 AgentEvent::TurnComplete {
1609 turn,
1610 usage: cont.turn_usage.clone(),
1611 },
1612 )
1613 .await;
1614 }
1615
1616 let mut ctx = TurnContext {
1617 thread_id: thread_id.clone(),
1618 turn,
1619 total_usage,
1620 state,
1621 start_time,
1622 };
1623
1624 loop {
1625 let result = execute_turn(
1626 &tx,
1627 &mut ctx,
1628 &tool_context,
1629 &provider,
1630 &tools,
1631 &hooks,
1632 &message_store,
1633 &config,
1634 compaction_config.as_ref(),
1635 execution_store.as_ref(),
1636 )
1637 .await;
1638
1639 match result {
1640 InternalTurnResult::Continue { .. } => {
1641 if let Err(e) = state_store.save(&ctx.state).await {
1642 warn!("Failed to save state checkpoint: {e}");
1643 }
1644 }
1645 InternalTurnResult::Done => {
1646 break;
1647 }
1648 InternalTurnResult::AwaitingConfirmation {
1649 tool_call_id,
1650 tool_name,
1651 display_name,
1652 input,
1653 description,
1654 continuation,
1655 } => {
1656 return AgentRunState::AwaitingConfirmation {
1657 tool_call_id,
1658 tool_name,
1659 display_name,
1660 input,
1661 description,
1662 continuation,
1663 };
1664 }
1665 InternalTurnResult::Error(e) => {
1666 return AgentRunState::Error(e);
1667 }
1668 }
1669 }
1670
1671 if let Err(e) = state_store.save(&ctx.state).await {
1672 warn!("Failed to save final state: {e}");
1673 }
1674
1675 let duration = ctx.start_time.elapsed();
1676 send_event(
1677 &tx,
1678 &hooks,
1679 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1680 )
1681 .await;
1682
1683 AgentRunState::Done {
1684 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1685 input_tokens: u64::from(ctx.total_usage.input_tokens),
1686 output_tokens: u64::from(ctx.total_usage.output_tokens),
1687 }
1688}
1689
1690struct TurnParameters<Ctx, P, H, M, S> {
1691 tx: mpsc::Sender<AgentEvent>,
1692 thread_id: ThreadId,
1693 input: AgentInput,
1694 tool_context: ToolContext<Ctx>,
1695 provider: Arc<P>,
1696 tools: Arc<ToolRegistry<Ctx>>,
1697 hooks: Arc<H>,
1698 message_store: Arc<M>,
1699 state_store: Arc<S>,
1700 config: AgentConfig,
1701 compaction_config: Option<CompactionConfig>,
1702 execution_store: Option<Arc<dyn ToolExecutionStore>>,
1703}
1704
1705async fn run_single_turn<Ctx, P, H, M, S>(
1711 TurnParameters {
1712 tx,
1713 thread_id,
1714 input,
1715 tool_context,
1716 provider,
1717 tools,
1718 hooks,
1719 message_store,
1720 state_store,
1721 config,
1722 compaction_config,
1723 execution_store,
1724 }: TurnParameters<Ctx, P, H, M, S>,
1725) -> TurnOutcome
1726where
1727 Ctx: Send + Sync + Clone + 'static,
1728 P: LlmProvider,
1729 H: AgentHooks,
1730 M: MessageStore,
1731 S: StateStore,
1732{
1733 let tool_context = tool_context.with_event_tx(tx.clone());
1734 let start_time = Instant::now();
1735
1736 let init_state =
1737 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1738 Ok(s) => s,
1739 Err(e) => {
1740 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1741 return TurnOutcome::Error(e);
1742 }
1743 };
1744
1745 let InitializedState {
1746 turn,
1747 total_usage,
1748 state,
1749 resume_data,
1750 } = init_state;
1751
1752 if let Some(resume_data_val) = resume_data {
1753 return handle_resume_case(ResumeCaseParameters {
1754 resume_data: resume_data_val,
1755 turn,
1756 total_usage,
1757 state,
1758 thread_id,
1759 tool_context,
1760 tools,
1761 hooks,
1762 tx,
1763 message_store,
1764 state_store,
1765 })
1766 .await;
1767 }
1768
1769 let mut ctx = TurnContext {
1770 thread_id: thread_id.clone(),
1771 turn,
1772 total_usage,
1773 state,
1774 start_time,
1775 };
1776
1777 let result = execute_turn(
1778 &tx,
1779 &mut ctx,
1780 &tool_context,
1781 &provider,
1782 &tools,
1783 &hooks,
1784 &message_store,
1785 &config,
1786 compaction_config.as_ref(),
1787 execution_store.as_ref(),
1788 )
1789 .await;
1790
1791 match result {
1793 InternalTurnResult::Continue { turn_usage } => {
1794 if let Err(e) = state_store.save(&ctx.state).await {
1796 warn!("Failed to save state checkpoint: {e}");
1797 }
1798
1799 TurnOutcome::NeedsMoreTurns {
1800 turn: ctx.turn,
1801 turn_usage,
1802 total_usage: ctx.total_usage,
1803 }
1804 }
1805 InternalTurnResult::Done => {
1806 if let Err(e) = state_store.save(&ctx.state).await {
1808 warn!("Failed to save final state: {e}");
1809 }
1810
1811 let duration = ctx.start_time.elapsed();
1813 send_event(
1814 &tx,
1815 &hooks,
1816 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1817 )
1818 .await;
1819
1820 TurnOutcome::Done {
1821 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1822 input_tokens: u64::from(ctx.total_usage.input_tokens),
1823 output_tokens: u64::from(ctx.total_usage.output_tokens),
1824 }
1825 }
1826 InternalTurnResult::AwaitingConfirmation {
1827 tool_call_id,
1828 tool_name,
1829 display_name,
1830 input,
1831 description,
1832 continuation,
1833 } => TurnOutcome::AwaitingConfirmation {
1834 tool_call_id,
1835 tool_name,
1836 display_name,
1837 input,
1838 description,
1839 continuation,
1840 },
1841 InternalTurnResult::Error(e) => TurnOutcome::Error(e),
1842 }
1843}
1844
1845struct ResumeCaseParameters<Ctx, H, M, S> {
1846 resume_data: ResumeData,
1847 turn: usize,
1848 total_usage: TokenUsage,
1849 state: AgentState,
1850 thread_id: ThreadId,
1851 tool_context: ToolContext<Ctx>,
1852 tools: Arc<ToolRegistry<Ctx>>,
1853 hooks: Arc<H>,
1854 tx: mpsc::Sender<AgentEvent>,
1855 message_store: Arc<M>,
1856 state_store: Arc<S>,
1857}
1858
1859async fn handle_resume_case<Ctx, H, M, S>(
1860 ResumeCaseParameters {
1861 resume_data,
1862 turn,
1863 total_usage,
1864 state,
1865 thread_id,
1866 tool_context,
1867 tools,
1868 hooks,
1869 tx,
1870 message_store,
1871 state_store,
1872 }: ResumeCaseParameters<Ctx, H, M, S>,
1873) -> TurnOutcome
1874where
1875 Ctx: Send + Sync + Clone + 'static,
1876 H: AgentHooks,
1877 M: MessageStore,
1878 S: StateStore,
1879{
1880 let ResumeData {
1881 continuation: cont,
1882 tool_call_id,
1883 confirmed,
1884 rejection_reason,
1885 } = resume_data;
1886 let mut tool_results = cont.completed_results.clone();
1888 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1889
1890 if awaiting_tool.id != tool_call_id {
1892 let message = format!(
1893 "Tool call ID mismatch: expected {}, got {}",
1894 awaiting_tool.id, tool_call_id
1895 );
1896 let recoverable = false;
1897 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1898 return TurnOutcome::Error(AgentError::new(&message, recoverable));
1899 }
1900
1901 let result = execute_confirmed_tool(
1902 awaiting_tool,
1903 confirmed,
1904 rejection_reason,
1905 &tool_context,
1906 &tools,
1907 &hooks,
1908 &tx,
1909 )
1910 .await;
1911 tool_results.push((awaiting_tool.id.clone(), result));
1912
1913 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1914 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1915 ToolExecutionOutcome::Completed { tool_id, result } => {
1916 tool_results.push((tool_id, result));
1917 }
1918 ToolExecutionOutcome::RequiresConfirmation {
1919 tool_id,
1920 tool_name,
1921 display_name,
1922 input,
1923 description,
1924 } => {
1925 let pending_idx = cont
1926 .pending_tool_calls
1927 .iter()
1928 .position(|p| p.id == tool_id)
1929 .unwrap_or(0);
1930
1931 let new_continuation = AgentContinuation {
1932 thread_id: thread_id.clone(),
1933 turn,
1934 total_usage: total_usage.clone(),
1935 turn_usage: cont.turn_usage.clone(),
1936 pending_tool_calls: cont.pending_tool_calls.clone(),
1937 awaiting_index: pending_idx,
1938 completed_results: tool_results,
1939 state: state.clone(),
1940 };
1941
1942 return TurnOutcome::AwaitingConfirmation {
1943 tool_call_id: tool_id,
1944 tool_name,
1945 display_name,
1946 input,
1947 description,
1948 continuation: Box::new(new_continuation),
1949 };
1950 }
1951 }
1952 }
1953
1954 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1955 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1956 return TurnOutcome::Error(e);
1957 }
1958
1959 send_event(
1960 &tx,
1961 &hooks,
1962 AgentEvent::TurnComplete {
1963 turn,
1964 usage: cont.turn_usage.clone(),
1965 },
1966 )
1967 .await;
1968
1969 let mut updated_state = state;
1970 updated_state.turn_count = turn;
1971 if let Err(e) = state_store.save(&updated_state).await {
1972 warn!("Failed to save state checkpoint: {e}");
1973 }
1974
1975 TurnOutcome::NeedsMoreTurns {
1976 turn,
1977 turn_usage: cont.turn_usage.clone(),
1978 total_usage,
1979 }
1980}
1981
1982async fn try_get_cached_result(
1991 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1992 tool_call_id: &str,
1993) -> Option<ToolResult> {
1994 let store = execution_store?;
1995 let execution = store.get_execution(tool_call_id).await.ok()??;
1996
1997 match execution.status {
1998 ExecutionStatus::Completed => execution.result,
1999 ExecutionStatus::InFlight => {
2000 warn!(
2003 "Found in-flight execution from previous attempt, re-executing (tool_call_id={}, tool_name={})",
2004 tool_call_id, execution.tool_name
2005 );
2006 None
2007 }
2008 }
2009}
2010
2011async fn record_execution_start(
2013 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2014 pending: &PendingToolCallInfo,
2015 thread_id: &ThreadId,
2016 started_at: time::OffsetDateTime,
2017) {
2018 if let Some(store) = execution_store {
2019 let execution = ToolExecution::new_in_flight(
2020 &pending.id,
2021 thread_id.clone(),
2022 &pending.name,
2023 &pending.display_name,
2024 pending.input.clone(),
2025 started_at,
2026 );
2027 if let Err(e) = store.record_execution(execution).await {
2028 warn!(
2029 "Failed to record execution start (tool_call_id={}, error={})",
2030 pending.id, e
2031 );
2032 }
2033 }
2034}
2035
2036async fn record_execution_complete(
2038 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2039 pending: &PendingToolCallInfo,
2040 thread_id: &ThreadId,
2041 result: &ToolResult,
2042 started_at: time::OffsetDateTime,
2043) {
2044 if let Some(store) = execution_store {
2045 let mut execution = ToolExecution::new_in_flight(
2046 &pending.id,
2047 thread_id.clone(),
2048 &pending.name,
2049 &pending.display_name,
2050 pending.input.clone(),
2051 started_at,
2052 );
2053 execution.complete(result.clone());
2054 if let Err(e) = store.update_execution(execution).await {
2055 warn!(
2056 "Failed to record execution completion (tool_call_id={}, error={})",
2057 pending.id, e
2058 );
2059 }
2060 }
2061}
2062
2063#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
2068async fn execute_turn<Ctx, P, H, M>(
2069 tx: &mpsc::Sender<AgentEvent>,
2070 ctx: &mut TurnContext,
2071 tool_context: &ToolContext<Ctx>,
2072 provider: &Arc<P>,
2073 tools: &Arc<ToolRegistry<Ctx>>,
2074 hooks: &Arc<H>,
2075 message_store: &Arc<M>,
2076 config: &AgentConfig,
2077 compaction_config: Option<&CompactionConfig>,
2078 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2079) -> InternalTurnResult
2080where
2081 Ctx: Send + Sync + Clone + 'static,
2082 P: LlmProvider,
2083 H: AgentHooks,
2084 M: MessageStore,
2085{
2086 ctx.turn += 1;
2087 ctx.state.turn_count = ctx.turn;
2088
2089 if ctx.turn > config.max_turns {
2090 warn!(
2091 "Max turns reached (turn={}, max={})",
2092 ctx.turn, config.max_turns
2093 );
2094 send_event(
2095 tx,
2096 hooks,
2097 AgentEvent::error(
2098 format!("Maximum turns ({}) reached", config.max_turns),
2099 true,
2100 ),
2101 )
2102 .await;
2103 return InternalTurnResult::Error(AgentError::new(
2104 format!("Maximum turns ({}) reached", config.max_turns),
2105 true,
2106 ));
2107 }
2108
2109 send_event(
2111 tx,
2112 hooks,
2113 AgentEvent::start(ctx.thread_id.clone(), ctx.turn),
2114 )
2115 .await;
2116
2117 let mut messages = match message_store.get_history(&ctx.thread_id).await {
2119 Ok(m) => m,
2120 Err(e) => {
2121 send_event(
2122 tx,
2123 hooks,
2124 AgentEvent::error(format!("Failed to get history: {e}"), false),
2125 )
2126 .await;
2127 return InternalTurnResult::Error(AgentError::new(
2128 format!("Failed to get history: {e}"),
2129 false,
2130 ));
2131 }
2132 };
2133
2134 if let Some(compact_config) = compaction_config {
2136 let compactor = LlmContextCompactor::new(Arc::clone(provider), compact_config.clone());
2137 if compactor.needs_compaction(&messages) {
2138 debug!(
2139 "Context compaction triggered (turn={}, message_count={})",
2140 ctx.turn,
2141 messages.len()
2142 );
2143
2144 match compactor.compact_history(messages.clone()).await {
2145 Ok(result) => {
2146 if let Err(e) = message_store
2147 .replace_history(&ctx.thread_id, result.messages.clone())
2148 .await
2149 {
2150 warn!("Failed to replace history after compaction: {e}");
2151 } else {
2152 send_event(
2153 tx,
2154 hooks,
2155 AgentEvent::context_compacted(
2156 result.original_count,
2157 result.new_count,
2158 result.original_tokens,
2159 result.new_tokens,
2160 ),
2161 )
2162 .await;
2163
2164 info!(
2165 "Context compacted successfully (original_count={}, new_count={}, original_tokens={}, new_tokens={})",
2166 result.original_count,
2167 result.new_count,
2168 result.original_tokens,
2169 result.new_tokens
2170 );
2171
2172 messages = result.messages;
2173 }
2174 }
2175 Err(e) => {
2176 warn!("Context compaction failed, continuing with full history: {e}");
2177 }
2178 }
2179 }
2180 }
2181
2182 let llm_tools = if tools.is_empty() {
2184 None
2185 } else {
2186 Some(tools.to_llm_tools())
2187 };
2188
2189 let request = ChatRequest {
2190 system: config.system_prompt.clone(),
2191 messages,
2192 tools: llm_tools,
2193 max_tokens: config.max_tokens,
2194 thinking: config.thinking.clone(),
2195 };
2196
2197 debug!(
2199 "Calling LLM (turn={}, streaming={})",
2200 ctx.turn, config.streaming
2201 );
2202 let response = if config.streaming {
2203 match call_llm_streaming(provider, request, config, tx, hooks).await {
2205 Ok(r) => r,
2206 Err(e) => {
2207 return InternalTurnResult::Error(e);
2208 }
2209 }
2210 } else {
2211 match call_llm_with_retry(provider, request, config, tx, hooks).await {
2213 Ok(r) => r,
2214 Err(e) => {
2215 return InternalTurnResult::Error(e);
2216 }
2217 }
2218 };
2219
2220 let turn_usage = TokenUsage {
2222 input_tokens: response.usage.input_tokens,
2223 output_tokens: response.usage.output_tokens,
2224 };
2225 ctx.total_usage.add(&turn_usage);
2226 ctx.state.total_usage = ctx.total_usage.clone();
2227
2228 let (thinking_content, text_content, tool_uses) = extract_content(&response);
2230
2231 if !config.streaming
2233 && let Some(thinking) = &thinking_content
2234 {
2235 send_event(tx, hooks, AgentEvent::thinking(thinking.clone())).await;
2236 }
2237
2238 if let Some(text) = &text_content {
2241 send_event(tx, hooks, AgentEvent::text(text.clone())).await;
2242 }
2243
2244 if tool_uses.is_empty() {
2246 info!("Agent completed (no tool use) (turn={})", ctx.turn);
2247 return InternalTurnResult::Done;
2248 }
2249
2250 let assistant_msg = build_assistant_message(&response);
2252 if let Err(e) = message_store.append(&ctx.thread_id, assistant_msg).await {
2253 send_event(
2254 tx,
2255 hooks,
2256 AgentEvent::error(format!("Failed to append assistant message: {e}"), false),
2257 )
2258 .await;
2259 return InternalTurnResult::Error(AgentError::new(
2260 format!("Failed to append assistant message: {e}"),
2261 false,
2262 ));
2263 }
2264
2265 let pending_tool_calls: Vec<PendingToolCallInfo> = tool_uses
2267 .iter()
2268 .map(|(id, name, input)| {
2269 let display_name = tools
2270 .get(name)
2271 .map(|t| t.display_name().to_string())
2272 .or_else(|| tools.get_async(name).map(|t| t.display_name().to_string()))
2273 .unwrap_or_default();
2274 PendingToolCallInfo {
2275 id: id.clone(),
2276 name: name.clone(),
2277 display_name,
2278 input: input.clone(),
2279 }
2280 })
2281 .collect();
2282
2283 let mut tool_results = Vec::new();
2285 for (idx, pending) in pending_tool_calls.iter().enumerate() {
2286 if let Some(cached_result) = try_get_cached_result(execution_store, &pending.id).await {
2288 debug!(
2289 "Using cached result from previous execution (tool_call_id={}, tool_name={})",
2290 pending.id, pending.name
2291 );
2292 tool_results.push((pending.id.clone(), cached_result));
2293 continue;
2294 }
2295
2296 if let Some(async_tool) = tools.get_async(&pending.name) {
2298 let tier = async_tool.tier();
2299
2300 send_event(
2302 tx,
2303 hooks,
2304 AgentEvent::tool_call_start(
2305 &pending.id,
2306 &pending.name,
2307 &pending.display_name,
2308 pending.input.clone(),
2309 tier,
2310 ),
2311 )
2312 .await;
2313
2314 let decision = hooks
2316 .pre_tool_use(&pending.name, &pending.input, tier)
2317 .await;
2318
2319 match decision {
2320 ToolDecision::Allow => {
2321 let started_at = time::OffsetDateTime::now_utc();
2323 record_execution_start(execution_store, pending, &ctx.thread_id, started_at)
2324 .await;
2325
2326 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
2327
2328 record_execution_complete(
2330 execution_store,
2331 pending,
2332 &ctx.thread_id,
2333 &result,
2334 started_at,
2335 )
2336 .await;
2337
2338 hooks.post_tool_use(&pending.name, &result).await;
2339
2340 send_event(
2341 tx,
2342 hooks,
2343 AgentEvent::tool_call_end(
2344 &pending.id,
2345 &pending.name,
2346 &pending.display_name,
2347 result.clone(),
2348 ),
2349 )
2350 .await;
2351
2352 tool_results.push((pending.id.clone(), result));
2353 }
2354 ToolDecision::Block(reason) => {
2355 let result = ToolResult::error(format!("Blocked: {reason}"));
2356 send_event(
2357 tx,
2358 hooks,
2359 AgentEvent::tool_call_end(
2360 &pending.id,
2361 &pending.name,
2362 &pending.display_name,
2363 result.clone(),
2364 ),
2365 )
2366 .await;
2367 tool_results.push((pending.id.clone(), result));
2368 }
2369 ToolDecision::RequiresConfirmation(description) => {
2370 send_event(
2372 tx,
2373 hooks,
2374 AgentEvent::ToolRequiresConfirmation {
2375 id: pending.id.clone(),
2376 name: pending.name.clone(),
2377 input: pending.input.clone(),
2378 description: description.clone(),
2379 },
2380 )
2381 .await;
2382
2383 let continuation = AgentContinuation {
2384 thread_id: ctx.thread_id.clone(),
2385 turn: ctx.turn,
2386 total_usage: ctx.total_usage.clone(),
2387 turn_usage: turn_usage.clone(),
2388 pending_tool_calls: pending_tool_calls.clone(),
2389 awaiting_index: idx,
2390 completed_results: tool_results,
2391 state: ctx.state.clone(),
2392 };
2393
2394 return InternalTurnResult::AwaitingConfirmation {
2395 tool_call_id: pending.id.clone(),
2396 tool_name: pending.name.clone(),
2397 display_name: pending.display_name.clone(),
2398 input: pending.input.clone(),
2399 description,
2400 continuation: Box::new(continuation),
2401 };
2402 }
2403 }
2404 continue;
2405 }
2406
2407 let Some(tool) = tools.get(&pending.name) else {
2409 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
2410 tool_results.push((pending.id.clone(), result));
2411 continue;
2412 };
2413
2414 let tier = tool.tier();
2415
2416 send_event(
2418 tx,
2419 hooks,
2420 AgentEvent::tool_call_start(
2421 &pending.id,
2422 &pending.name,
2423 &pending.display_name,
2424 pending.input.clone(),
2425 tier,
2426 ),
2427 )
2428 .await;
2429
2430 let decision = hooks
2432 .pre_tool_use(&pending.name, &pending.input, tier)
2433 .await;
2434
2435 match decision {
2436 ToolDecision::Allow => {
2437 let started_at = time::OffsetDateTime::now_utc();
2439 record_execution_start(execution_store, pending, &ctx.thread_id, started_at).await;
2440
2441 let tool_start = Instant::now();
2442 let result = match tool.execute(tool_context, pending.input.clone()).await {
2443 Ok(mut r) => {
2444 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
2445 r
2446 }
2447 Err(e) => ToolResult::error(format!("Tool error: {e}"))
2448 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
2449 };
2450
2451 record_execution_complete(
2453 execution_store,
2454 pending,
2455 &ctx.thread_id,
2456 &result,
2457 started_at,
2458 )
2459 .await;
2460
2461 hooks.post_tool_use(&pending.name, &result).await;
2462
2463 send_event(
2464 tx,
2465 hooks,
2466 AgentEvent::tool_call_end(
2467 &pending.id,
2468 &pending.name,
2469 &pending.display_name,
2470 result.clone(),
2471 ),
2472 )
2473 .await;
2474
2475 tool_results.push((pending.id.clone(), result));
2476 }
2477 ToolDecision::Block(reason) => {
2478 let result = ToolResult::error(format!("Blocked: {reason}"));
2479 send_event(
2480 tx,
2481 hooks,
2482 AgentEvent::tool_call_end(
2483 &pending.id,
2484 &pending.name,
2485 &pending.display_name,
2486 result.clone(),
2487 ),
2488 )
2489 .await;
2490 tool_results.push((pending.id.clone(), result));
2491 }
2492 ToolDecision::RequiresConfirmation(description) => {
2493 send_event(
2495 tx,
2496 hooks,
2497 AgentEvent::ToolRequiresConfirmation {
2498 id: pending.id.clone(),
2499 name: pending.name.clone(),
2500 input: pending.input.clone(),
2501 description: description.clone(),
2502 },
2503 )
2504 .await;
2505
2506 let continuation = AgentContinuation {
2507 thread_id: ctx.thread_id.clone(),
2508 turn: ctx.turn,
2509 total_usage: ctx.total_usage.clone(),
2510 turn_usage: turn_usage.clone(),
2511 pending_tool_calls: pending_tool_calls.clone(),
2512 awaiting_index: idx,
2513 completed_results: tool_results,
2514 state: ctx.state.clone(),
2515 };
2516
2517 return InternalTurnResult::AwaitingConfirmation {
2518 tool_call_id: pending.id.clone(),
2519 tool_name: pending.name.clone(),
2520 display_name: pending.display_name.clone(),
2521 input: pending.input.clone(),
2522 description,
2523 continuation: Box::new(continuation),
2524 };
2525 }
2526 }
2527 }
2528
2529 if let Err(e) = append_tool_results(&tool_results, &ctx.thread_id, message_store).await {
2531 send_event(
2532 tx,
2533 hooks,
2534 AgentEvent::error(format!("Failed to append tool results: {e}"), false),
2535 )
2536 .await;
2537 return InternalTurnResult::Error(e);
2538 }
2539
2540 send_event(
2542 tx,
2543 hooks,
2544 AgentEvent::TurnComplete {
2545 turn: ctx.turn,
2546 usage: turn_usage.clone(),
2547 },
2548 )
2549 .await;
2550
2551 if response.stop_reason == Some(StopReason::EndTurn) {
2553 info!("Agent completed (end_turn) (turn={})", ctx.turn);
2554 return InternalTurnResult::Done;
2555 }
2556
2557 InternalTurnResult::Continue { turn_usage }
2558}
2559
/// Narrows a millisecond count (as returned by `Duration::as_millis`) to `u64`.
/// Values that do not fit saturate to `u64::MAX` instead of being truncated.
// The `as u64` cast below only runs once the value is known to fit.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    const CEILING: u128 = u64::MAX as u128;
    if millis <= CEILING {
        millis as u64
    } else {
        u64::MAX
    }
}
2569
2570fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
2575 let base_delay = config
2577 .base_delay_ms
2578 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
2579
2580 let max_jitter = config.base_delay_ms.min(1000);
2582 let jitter = if max_jitter > 0 {
2583 u64::from(
2584 std::time::SystemTime::now()
2585 .duration_since(std::time::UNIX_EPOCH)
2586 .unwrap_or_default()
2587 .subsec_nanos(),
2588 ) % max_jitter
2589 } else {
2590 0
2591 };
2592
2593 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
2594 Duration::from_millis(delay_ms)
2595}
2596
2597type ExtractedContent = (
2599 Option<String>,
2600 Option<String>,
2601 Vec<(String, String, serde_json::Value)>,
2602);
2603
2604fn extract_content(response: &ChatResponse) -> ExtractedContent {
2606 let mut thinking_parts = Vec::new();
2607 let mut text_parts = Vec::new();
2608 let mut tool_uses = Vec::new();
2609
2610 for block in &response.content {
2611 match block {
2612 ContentBlock::Text { text } => {
2613 text_parts.push(text.clone());
2614 }
2615 ContentBlock::Thinking { thinking } => {
2616 thinking_parts.push(thinking.clone());
2617 }
2618 ContentBlock::ToolUse {
2619 id, name, input, ..
2620 } => {
2621 let input = if input.is_null() {
2622 serde_json::json!({})
2623 } else {
2624 input.clone()
2625 };
2626 tool_uses.push((id.clone(), name.clone(), input.clone()));
2627 }
2628 ContentBlock::ToolResult { .. } => {
2629 }
2631 }
2632 }
2633
2634 let thinking = if thinking_parts.is_empty() {
2635 None
2636 } else {
2637 Some(thinking_parts.join("\n"))
2638 };
2639
2640 let text = if text_parts.is_empty() {
2641 None
2642 } else {
2643 Some(text_parts.join("\n"))
2644 };
2645
2646 (thinking, text, tool_uses)
2647}
2648
2649async fn send_event<H>(tx: &mpsc::Sender<AgentEvent>, hooks: &Arc<H>, event: AgentEvent)
2663where
2664 H: AgentHooks,
2665{
2666 hooks.on_event(&event).await;
2667
2668 match tx.try_send(event) {
2670 Ok(()) => {}
2671 Err(mpsc::error::TrySendError::Full(event)) => {
2672 log::debug!("Event channel full, waiting for consumer...");
2674 match tokio::time::timeout(std::time::Duration::from_secs(30), tx.send(event)).await {
2676 Ok(Ok(())) => {}
2677 Ok(Err(_)) => {
2678 log::warn!("Event channel closed while sending - consumer disconnected");
2679 }
2680 Err(_) => {
2681 log::error!("Timeout waiting to send event - consumer may be deadlocked");
2682 }
2683 }
2684 }
2685 Err(mpsc::error::TrySendError::Closed(_)) => {
2686 log::debug!("Event channel closed - consumer disconnected");
2687 }
2688 }
2689}
2690
2691fn build_assistant_message(response: &ChatResponse) -> Message {
2692 let mut blocks = Vec::new();
2693
2694 for block in &response.content {
2695 match block {
2696 ContentBlock::Text { text } => {
2697 blocks.push(ContentBlock::Text { text: text.clone() });
2698 }
2699 ContentBlock::Thinking { .. } | ContentBlock::ToolResult { .. } => {
2700 }
2703 ContentBlock::ToolUse {
2704 id,
2705 name,
2706 input,
2707 thought_signature,
2708 } => {
2709 blocks.push(ContentBlock::ToolUse {
2710 id: id.clone(),
2711 name: name.clone(),
2712 input: input.clone(),
2713 thought_signature: thought_signature.clone(),
2714 });
2715 }
2716 }
2717 }
2718
2719 Message {
2720 role: Role::Assistant,
2721 content: Content::Blocks(blocks),
2722 }
2723}
2724
2725#[cfg(test)]
2726mod tests {
2727 use super::*;
2728 use crate::hooks::AllowAllHooks;
2729 use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
2730 use crate::stores::InMemoryStore;
2731 use crate::tools::{Tool, ToolContext, ToolRegistry};
2732 use crate::types::{AgentConfig, AgentInput, ToolResult, ToolTier};
2733 use anyhow::Result;
2734 use async_trait::async_trait;
2735 use serde_json::json;
2736 use std::sync::RwLock;
2737 use std::sync::atomic::{AtomicUsize, Ordering};
2738
2739 struct MockProvider {
2744 responses: RwLock<Vec<ChatOutcome>>,
2745 call_count: AtomicUsize,
2746 }
2747
2748 impl MockProvider {
2749 fn new(responses: Vec<ChatOutcome>) -> Self {
2750 Self {
2751 responses: RwLock::new(responses),
2752 call_count: AtomicUsize::new(0),
2753 }
2754 }
2755
2756 fn text_response(text: &str) -> ChatOutcome {
2757 ChatOutcome::Success(ChatResponse {
2758 id: "msg_1".to_string(),
2759 content: vec![ContentBlock::Text {
2760 text: text.to_string(),
2761 }],
2762 model: "mock-model".to_string(),
2763 stop_reason: Some(StopReason::EndTurn),
2764 usage: Usage {
2765 input_tokens: 10,
2766 output_tokens: 20,
2767 },
2768 })
2769 }
2770
2771 fn tool_use_response(
2772 tool_id: &str,
2773 tool_name: &str,
2774 input: serde_json::Value,
2775 ) -> ChatOutcome {
2776 ChatOutcome::Success(ChatResponse {
2777 id: "msg_1".to_string(),
2778 content: vec![ContentBlock::ToolUse {
2779 id: tool_id.to_string(),
2780 name: tool_name.to_string(),
2781 input,
2782 thought_signature: None,
2783 }],
2784 model: "mock-model".to_string(),
2785 stop_reason: Some(StopReason::ToolUse),
2786 usage: Usage {
2787 input_tokens: 10,
2788 output_tokens: 20,
2789 },
2790 })
2791 }
2792 }
2793
2794 #[async_trait]
2795 impl LlmProvider for MockProvider {
2796 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
2797 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
2798 let responses = self.responses.read().unwrap();
2799 if idx < responses.len() {
2800 Ok(responses[idx].clone())
2801 } else {
2802 Ok(Self::text_response("Done"))
2804 }
2805 }
2806
2807 fn model(&self) -> &'static str {
2808 "mock-model"
2809 }
2810
2811 fn provider(&self) -> &'static str {
2812 "mock"
2813 }
2814 }
2815
2816 impl Clone for ChatOutcome {
2818 fn clone(&self) -> Self {
2819 match self {
2820 Self::Success(r) => Self::Success(r.clone()),
2821 Self::RateLimited => Self::RateLimited,
2822 Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
2823 Self::ServerError(s) => Self::ServerError(s.clone()),
2824 }
2825 }
2826 }
2827
2828 struct EchoTool;
2833
2834 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
2836 #[serde(rename_all = "snake_case")]
2837 enum TestToolName {
2838 Echo,
2839 }
2840
2841 impl crate::tools::ToolName for TestToolName {}
2842
2843 impl Tool<()> for EchoTool {
2844 type Name = TestToolName;
2845
2846 fn name(&self) -> TestToolName {
2847 TestToolName::Echo
2848 }
2849
2850 fn display_name(&self) -> &'static str {
2851 "Echo"
2852 }
2853
2854 fn description(&self) -> &'static str {
2855 "Echo the input message"
2856 }
2857
2858 fn input_schema(&self) -> serde_json::Value {
2859 json!({
2860 "type": "object",
2861 "properties": {
2862 "message": { "type": "string" }
2863 },
2864 "required": ["message"]
2865 })
2866 }
2867
2868 fn tier(&self) -> ToolTier {
2869 ToolTier::Observe
2870 }
2871
2872 async fn execute(
2873 &self,
2874 _ctx: &ToolContext<()>,
2875 input: serde_json::Value,
2876 ) -> Result<ToolResult> {
2877 let message = input
2878 .get("message")
2879 .and_then(|v| v.as_str())
2880 .unwrap_or("no message");
2881 Ok(ToolResult::success(format!("Echo: {message}")))
2882 }
2883 }
2884
2885 #[test]
2890 fn test_builder_creates_agent_loop() {
2891 let provider = MockProvider::new(vec![]);
2892 let agent = builder::<()>().provider(provider).build();
2893
2894 assert_eq!(agent.config.max_turns, 10);
2895 assert_eq!(agent.config.max_tokens, 4096);
2896 }
2897
2898 #[test]
2899 fn test_builder_with_custom_config() {
2900 let provider = MockProvider::new(vec![]);
2901 let config = AgentConfig {
2902 max_turns: 5,
2903 max_tokens: 2048,
2904 system_prompt: "Custom prompt".to_string(),
2905 model: "custom-model".to_string(),
2906 ..Default::default()
2907 };
2908
2909 let agent = builder::<()>().provider(provider).config(config).build();
2910
2911 assert_eq!(agent.config.max_turns, 5);
2912 assert_eq!(agent.config.max_tokens, 2048);
2913 assert_eq!(agent.config.system_prompt, "Custom prompt");
2914 }
2915
2916 #[test]
2917 fn test_builder_with_tools() {
2918 let provider = MockProvider::new(vec![]);
2919 let mut tools = ToolRegistry::new();
2920 tools.register(EchoTool);
2921
2922 let agent = builder::<()>().provider(provider).tools(tools).build();
2923
2924 assert_eq!(agent.tools.len(), 1);
2925 }
2926
2927 #[test]
2928 fn test_builder_with_custom_stores() {
2929 let provider = MockProvider::new(vec![]);
2930 let message_store = InMemoryStore::new();
2931 let state_store = InMemoryStore::new();
2932
2933 let agent = builder::<()>()
2934 .provider(provider)
2935 .hooks(AllowAllHooks)
2936 .message_store(message_store)
2937 .state_store(state_store)
2938 .build_with_stores();
2939
2940 assert_eq!(agent.config.max_turns, 10);
2942 }
2943
2944 #[tokio::test]
2949 async fn test_simple_text_response() -> anyhow::Result<()> {
2950 let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);
2951
2952 let agent = builder::<()>().provider(provider).build();
2953
2954 let thread_id = ThreadId::new();
2955 let tool_ctx = ToolContext::new(());
2956 let (mut rx, _final_state) =
2957 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2958
2959 let mut events = Vec::new();
2960 while let Some(event) = rx.recv().await {
2961 events.push(event);
2962 }
2963
2964 assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
2966 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2967
2968 Ok(())
2969 }
2970
2971 #[tokio::test]
2972 async fn test_tool_execution() -> anyhow::Result<()> {
2973 let provider = MockProvider::new(vec![
2974 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
2976 MockProvider::text_response("Tool executed successfully"),
2978 ]);
2979
2980 let mut tools = ToolRegistry::new();
2981 tools.register(EchoTool);
2982
2983 let agent = builder::<()>().provider(provider).tools(tools).build();
2984
2985 let thread_id = ThreadId::new();
2986 let tool_ctx = ToolContext::new(());
2987 let (mut rx, _final_state) = agent.run(
2988 thread_id,
2989 AgentInput::Text("Run echo".to_string()),
2990 tool_ctx,
2991 );
2992
2993 let mut events = Vec::new();
2994 while let Some(event) = rx.recv().await {
2995 events.push(event);
2996 }
2997
2998 assert!(
3000 events
3001 .iter()
3002 .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
3003 );
3004 assert!(
3005 events
3006 .iter()
3007 .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
3008 );
3009
3010 Ok(())
3011 }
3012
3013 #[tokio::test]
3014 async fn test_max_turns_limit() -> anyhow::Result<()> {
3015 let provider = MockProvider::new(vec![
3017 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
3018 MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
3019 MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
3020 MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
3021 ]);
3022
3023 let mut tools = ToolRegistry::new();
3024 tools.register(EchoTool);
3025
3026 let config = AgentConfig {
3027 max_turns: 2,
3028 ..Default::default()
3029 };
3030
3031 let agent = builder::<()>()
3032 .provider(provider)
3033 .tools(tools)
3034 .config(config)
3035 .build();
3036
3037 let thread_id = ThreadId::new();
3038 let tool_ctx = ToolContext::new(());
3039 let (mut rx, _final_state) =
3040 agent.run(thread_id, AgentInput::Text("Loop".to_string()), tool_ctx);
3041
3042 let mut events = Vec::new();
3043 while let Some(event) = rx.recv().await {
3044 events.push(event);
3045 }
3046
3047 assert!(events.iter().any(|e| {
3049 matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
3050 }));
3051
3052 Ok(())
3053 }
3054
3055 #[tokio::test]
3056 async fn test_unknown_tool_handling() -> anyhow::Result<()> {
3057 let provider = MockProvider::new(vec![
3058 MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
3060 MockProvider::text_response("I couldn't find that tool."),
3062 ]);
3063
3064 let tools = ToolRegistry::new();
3066
3067 let agent = builder::<()>().provider(provider).tools(tools).build();
3068
3069 let thread_id = ThreadId::new();
3070 let tool_ctx = ToolContext::new(());
3071 let (mut rx, _final_state) = agent.run(
3072 thread_id,
3073 AgentInput::Text("Call unknown".to_string()),
3074 tool_ctx,
3075 );
3076
3077 let mut events = Vec::new();
3078 while let Some(event) = rx.recv().await {
3079 events.push(event);
3080 }
3081
3082 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3085
3086 assert!(
3088 events.iter().any(|e| {
3089 matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
3090 })
3091 );
3092
3093 Ok(())
3094 }
3095
3096 #[tokio::test]
3097 async fn test_rate_limit_handling() -> anyhow::Result<()> {
3098 let provider = MockProvider::new(vec![
3100 ChatOutcome::RateLimited,
3101 ChatOutcome::RateLimited,
3102 ChatOutcome::RateLimited,
3103 ChatOutcome::RateLimited,
3104 ChatOutcome::RateLimited,
3105 ChatOutcome::RateLimited, ]);
3107
3108 let config = AgentConfig {
3110 retry: crate::types::RetryConfig::fast(),
3111 ..Default::default()
3112 };
3113
3114 let agent = builder::<()>().provider(provider).config(config).build();
3115
3116 let thread_id = ThreadId::new();
3117 let tool_ctx = ToolContext::new(());
3118 let (mut rx, _final_state) =
3119 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3120
3121 let mut events = Vec::new();
3122 while let Some(event) = rx.recv().await {
3123 events.push(event);
3124 }
3125
3126 assert!(events.iter().any(|e| {
3128 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
3129 }));
3130
3131 assert!(
3133 events
3134 .iter()
3135 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3136 );
3137
3138 Ok(())
3139 }
3140
3141 #[tokio::test]
3142 async fn test_rate_limit_recovery() -> anyhow::Result<()> {
3143 let provider = MockProvider::new(vec![
3145 ChatOutcome::RateLimited,
3146 MockProvider::text_response("Recovered after rate limit"),
3147 ]);
3148
3149 let config = AgentConfig {
3151 retry: crate::types::RetryConfig::fast(),
3152 ..Default::default()
3153 };
3154
3155 let agent = builder::<()>().provider(provider).config(config).build();
3156
3157 let thread_id = ThreadId::new();
3158 let tool_ctx = ToolContext::new(());
3159 let (mut rx, _final_state) =
3160 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3161
3162 let mut events = Vec::new();
3163 while let Some(event) = rx.recv().await {
3164 events.push(event);
3165 }
3166
3167 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3169
3170 assert!(
3172 events
3173 .iter()
3174 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3175 );
3176
3177 Ok(())
3178 }
3179
3180 #[tokio::test]
3181 async fn test_server_error_handling() -> anyhow::Result<()> {
3182 let provider = MockProvider::new(vec![
3184 ChatOutcome::ServerError("Internal error".to_string()),
3185 ChatOutcome::ServerError("Internal error".to_string()),
3186 ChatOutcome::ServerError("Internal error".to_string()),
3187 ChatOutcome::ServerError("Internal error".to_string()),
3188 ChatOutcome::ServerError("Internal error".to_string()),
3189 ChatOutcome::ServerError("Internal error".to_string()), ]);
3191
3192 let config = AgentConfig {
3194 retry: crate::types::RetryConfig::fast(),
3195 ..Default::default()
3196 };
3197
3198 let agent = builder::<()>().provider(provider).config(config).build();
3199
3200 let thread_id = ThreadId::new();
3201 let tool_ctx = ToolContext::new(());
3202 let (mut rx, _final_state) =
3203 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3204
3205 let mut events = Vec::new();
3206 while let Some(event) = rx.recv().await {
3207 events.push(event);
3208 }
3209
3210 assert!(events.iter().any(|e| {
3212 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
3213 }));
3214
3215 assert!(
3217 events
3218 .iter()
3219 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3220 );
3221
3222 Ok(())
3223 }
3224
3225 #[tokio::test]
3226 async fn test_server_error_recovery() -> anyhow::Result<()> {
3227 let provider = MockProvider::new(vec![
3229 ChatOutcome::ServerError("Temporary error".to_string()),
3230 MockProvider::text_response("Recovered after server error"),
3231 ]);
3232
3233 let config = AgentConfig {
3235 retry: crate::types::RetryConfig::fast(),
3236 ..Default::default()
3237 };
3238
3239 let agent = builder::<()>().provider(provider).config(config).build();
3240
3241 let thread_id = ThreadId::new();
3242 let tool_ctx = ToolContext::new(());
3243 let (mut rx, _final_state) =
3244 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3245
3246 let mut events = Vec::new();
3247 while let Some(event) = rx.recv().await {
3248 events.push(event);
3249 }
3250
3251 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3253
3254 assert!(
3256 events
3257 .iter()
3258 .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
3259 );
3260
3261 Ok(())
3262 }
3263
/// A response made of a single text block yields that text, no thinking,
/// and no tool uses.
#[test]
fn test_extract_content_text_only() {
    let only_text = ContentBlock::Text {
        text: "Hello".to_string(),
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![only_text],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);
    assert!(thinking.is_none());
    assert_eq!(text.as_deref(), Some("Hello"));
    assert!(tool_uses.is_empty());
}
3288
/// A response made of a single tool-use block yields exactly one tool use,
/// with its name preserved, and neither text nor thinking.
#[test]
fn test_extract_content_tool_use() {
    let tool_block = ContentBlock::ToolUse {
        id: "tool_1".to_string(),
        name: "test_tool".to_string(),
        input: json!({"key": "value"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![tool_block],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);
    assert!(thinking.is_none());
    assert!(text.is_none());
    assert_eq!(tool_uses.len(), 1);

    // Tuple element `.1` carries the tool name.
    let first = &tool_uses[0];
    assert_eq!(first.1, "test_tool");
}
3313
/// A response mixing a text block and a tool-use block yields both.
///
/// The text comes back unchanged, plus a single tool use whose name is
/// preserved. No thinking is extracted.
#[test]
fn test_extract_content_mixed() {
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![
            ContentBlock::Text {
                text: "Let me help".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "helper".to_string(),
                input: json!({}),
                thought_signature: None,
            },
        ],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);
    assert!(thinking.is_none());
    assert_eq!(text, Some("Let me help".to_string()));
    assert_eq!(tool_uses.len(), 1);
    // Verify which tool was extracted, not just how many. Tuple element `.1`
    // is the tool name, as asserted in `test_extract_content_tool_use`.
    assert_eq!(tool_uses[0].1, "helper");
}
3342
/// `millis_to_u64` passes values that fit in a `u64` through unchanged and
/// clamps anything larger to `u64::MAX`.
#[test]
fn test_millis_to_u64() {
    let ceiling = u128::from(u64::MAX);
    let cases: [(u128, u64); 4] = [
        (0, 0),
        (1000, 1000),
        (ceiling, u64::MAX),
        // One past the representable range must saturate, not wrap.
        (ceiling + 1, u64::MAX),
    ];

    for (millis, expected) in cases {
        assert_eq!(millis_to_u64(millis), expected);
    }
}
3350
/// `build_assistant_message` turns a response holding one text block and
/// one tool-use block into an assistant-role message carrying both blocks.
#[test]
fn test_build_assistant_message() {
    let text_block = ContentBlock::Text {
        text: "Response text".to_string(),
    };
    let tool_block = ContentBlock::ToolUse {
        id: "tool_1".to_string(),
        name: "echo".to_string(),
        input: json!({"message": "test"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        content: vec![text_block, tool_block],
        model: "test".to_string(),
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let msg = build_assistant_message(&response);
    assert_eq!(msg.role, Role::Assistant);

    match msg.content {
        Content::Blocks(blocks) => assert_eq!(blocks.len(), 2),
        _ => panic!("Expected Content::Blocks"),
    }
}
3383}