1use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35 StopReason, StreamAccumulator, StreamDelta, Usage,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore, ToolExecutionStore};
39use crate::tools::{ErasedAsyncTool, ErasedToolStatus, ToolContext, ToolRegistry};
40use crate::types::{
41 AgentConfig, AgentContinuation, AgentError, AgentInput, AgentRunState, AgentState,
42 ExecutionStatus, PendingToolCallInfo, RetryConfig, ThreadId, TokenUsage, ToolExecution,
43 ToolOutcome, ToolResult, TurnOutcome,
44};
45use futures::StreamExt;
46use log::{debug, error, info, warn};
47use std::sync::Arc;
48use std::time::{Duration, Instant};
49use tokio::sync::{mpsc, oneshot};
50use tokio::time::sleep;
51
/// Result of executing one internal turn, as consumed by `run_loop`.
enum InternalTurnResult {
    /// The turn finished and another turn should run; `run_loop` checkpoints
    /// state before looping. Carries the usage of the finished turn.
    Continue { turn_usage: TokenUsage },
    /// The agent is finished; `run_loop` stops looping and emits the final events.
    Done,
    /// A tool call requires user confirmation; the run is suspended and
    /// `continuation` holds everything needed to resume it.
    AwaitingConfirmation {
        /// Id of the tool call awaiting confirmation.
        tool_call_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        /// Human-readable reason supplied by the `pre_tool_use` hook.
        description: String,
        continuation: Box<AgentContinuation>,
    },
    /// The turn failed; `run_loop` returns this error as the run state.
    Error(AgentError),
}
72
/// Mutable bookkeeping threaded (by `&mut`) through successive turns of `run_loop`.
struct TurnContext {
    thread_id: ThreadId,
    /// Turn counter reported in the final `done` event.
    turn: usize,
    /// Token usage accumulated across turns.
    total_usage: TokenUsage,
    /// Agent state persisted through the state store after turns.
    state: AgentState,
    /// When the run started; used to compute the run duration for the `done` event.
    start_time: Instant,
}
83
/// Payload of an `AgentInput::Resume`, unpacked by `initialize_from_input`.
struct ResumeData {
    /// Snapshot of the suspended run.
    continuation: Box<AgentContinuation>,
    /// Id of the tool call the user is answering; must match the awaited call.
    tool_call_id: String,
    /// Whether the user approved the tool call.
    confirmed: bool,
    /// Optional explanation when `confirmed` is `false`.
    rejection_reason: Option<String>,
}
91
/// Starting point of a run, produced by `initialize_from_input`.
struct InitializedState {
    /// Turn counter to start from (0 for new text input).
    turn: usize,
    /// Usage to start accumulating from.
    total_usage: TokenUsage,
    state: AgentState,
    /// Present only when resuming after a confirmation prompt.
    resume_data: Option<ResumeData>,
}
99
/// Outcome of attempting a single tool call via `execute_tool_call`.
enum ToolExecutionOutcome {
    /// The call produced a result (success, failure, block, or unknown tool).
    Completed { tool_id: String, result: ToolResult },
    /// The `pre_tool_use` hook demanded user confirmation; the tool has not run.
    RequiresConfirmation {
        tool_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        /// Reason supplied by the hook, shown to the user.
        description: String,
    },
}
113
/// Builder for [`AgentLoop`].
///
/// The type parameters `P`, `H`, `M` and `S` start out as `()` (see
/// `AgentLoopBuilder::new`) and are replaced by the concrete types passed to
/// `provider`, `hooks`, `message_store` and `state_store`.
pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
    provider: Option<P>,
    tools: Option<ToolRegistry<Ctx>>,
    hooks: Option<H>,
    message_store: Option<M>,
    state_store: Option<S>,
    config: Option<AgentConfig>,
    /// `None` disables context compaction.
    compaction_config: Option<CompactionConfig>,
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
135
136impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
137 #[must_use]
139 pub fn new() -> Self {
140 Self {
141 provider: None,
142 tools: None,
143 hooks: None,
144 message_store: None,
145 state_store: None,
146 config: None,
147 compaction_config: None,
148 execution_store: None,
149 }
150 }
151}
152
153impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
160 #[must_use]
162 pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
163 AgentLoopBuilder {
164 provider: Some(provider),
165 tools: self.tools,
166 hooks: self.hooks,
167 message_store: self.message_store,
168 state_store: self.state_store,
169 config: self.config,
170 compaction_config: self.compaction_config,
171 execution_store: self.execution_store,
172 }
173 }
174
175 #[must_use]
177 pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
178 self.tools = Some(tools);
179 self
180 }
181
182 #[must_use]
184 pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
185 AgentLoopBuilder {
186 provider: self.provider,
187 tools: self.tools,
188 hooks: Some(hooks),
189 message_store: self.message_store,
190 state_store: self.state_store,
191 config: self.config,
192 compaction_config: self.compaction_config,
193 execution_store: self.execution_store,
194 }
195 }
196
197 #[must_use]
199 pub fn message_store<M2: MessageStore>(
200 self,
201 message_store: M2,
202 ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
203 AgentLoopBuilder {
204 provider: self.provider,
205 tools: self.tools,
206 hooks: self.hooks,
207 message_store: Some(message_store),
208 state_store: self.state_store,
209 config: self.config,
210 compaction_config: self.compaction_config,
211 execution_store: self.execution_store,
212 }
213 }
214
215 #[must_use]
217 pub fn state_store<S2: StateStore>(
218 self,
219 state_store: S2,
220 ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
221 AgentLoopBuilder {
222 provider: self.provider,
223 tools: self.tools,
224 hooks: self.hooks,
225 message_store: self.message_store,
226 state_store: Some(state_store),
227 config: self.config,
228 compaction_config: self.compaction_config,
229 execution_store: self.execution_store,
230 }
231 }
232
233 #[must_use]
251 pub fn execution_store(mut self, store: impl ToolExecutionStore + 'static) -> Self {
252 self.execution_store = Some(Arc::new(store));
253 self
254 }
255
256 #[must_use]
258 pub fn config(mut self, config: AgentConfig) -> Self {
259 self.config = Some(config);
260 self
261 }
262
263 #[must_use]
279 pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
280 self.compaction_config = Some(config);
281 self
282 }
283
284 #[must_use]
291 pub fn with_auto_compaction(self) -> Self {
292 self.with_compaction(CompactionConfig::default())
293 }
294
295 #[must_use]
313 pub fn with_skill(mut self, skill: Skill) -> Self
314 where
315 Ctx: Send + Sync + 'static,
316 {
317 if let Some(ref mut tools) = self.tools {
319 tools.filter(|name| skill.is_tool_allowed(name));
320 }
321
322 let mut config = self.config.take().unwrap_or_default();
324 if config.system_prompt.is_empty() {
325 config.system_prompt = skill.system_prompt;
326 } else {
327 config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
328 }
329 self.config = Some(config);
330
331 self
332 }
333}
334
335impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
336where
337 Ctx: Send + Sync + 'static,
338 P: LlmProvider + 'static,
339{
340 #[must_use]
352 pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
353 let provider = self.provider.expect("provider is required");
354 let tools = self.tools.unwrap_or_default();
355 let config = self.config.unwrap_or_default();
356
357 AgentLoop {
358 provider: Arc::new(provider),
359 tools: Arc::new(tools),
360 hooks: Arc::new(DefaultHooks),
361 message_store: Arc::new(InMemoryStore::new()),
362 state_store: Arc::new(InMemoryStore::new()),
363 config,
364 compaction_config: self.compaction_config,
365 execution_store: self.execution_store,
366 }
367 }
368}
369
370impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
371where
372 Ctx: Send + Sync + 'static,
373 P: LlmProvider + 'static,
374 H: AgentHooks + 'static,
375 M: MessageStore + 'static,
376 S: StateStore + 'static,
377{
378 #[must_use]
388 pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
389 let provider = self.provider.expect("provider is required");
390 let tools = self.tools.unwrap_or_default();
391 let hooks = self
392 .hooks
393 .expect("hooks is required when using build_with_stores");
394 let message_store = self
395 .message_store
396 .expect("message_store is required when using build_with_stores");
397 let state_store = self
398 .state_store
399 .expect("state_store is required when using build_with_stores");
400 let config = self.config.unwrap_or_default();
401
402 AgentLoop {
403 provider: Arc::new(provider),
404 tools: Arc::new(tools),
405 hooks: Arc::new(hooks),
406 message_store: Arc::new(message_store),
407 state_store: Arc::new(state_store),
408 config,
409 compaction_config: self.compaction_config,
410 execution_store: self.execution_store,
411 }
412 }
413}
414
/// The agent runtime: an LLM provider, a tool registry, hooks and persistence stores.
///
/// Components are held in `Arc`s so `run` / `run_turn` can hand cheap clones
/// to the task they spawn.
pub struct AgentLoop<Ctx, P, H, M, S>
where
    P: LlmProvider,
    H: AgentHooks,
    M: MessageStore,
    S: StateStore,
{
    provider: Arc<P>,
    tools: Arc<ToolRegistry<Ctx>>,
    hooks: Arc<H>,
    message_store: Arc<M>,
    state_store: Arc<S>,
    config: AgentConfig,
    /// `None` disables context compaction.
    compaction_config: Option<CompactionConfig>,
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
476
477#[must_use]
479pub fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
480 AgentLoopBuilder::new()
481}
482
483impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
484where
485 Ctx: Send + Sync + 'static,
486 P: LlmProvider + 'static,
487 H: AgentHooks + 'static,
488 M: MessageStore + 'static,
489 S: StateStore + 'static,
490{
491 #[must_use]
493 pub fn new(
494 provider: P,
495 tools: ToolRegistry<Ctx>,
496 hooks: H,
497 message_store: M,
498 state_store: S,
499 config: AgentConfig,
500 ) -> Self {
501 Self {
502 provider: Arc::new(provider),
503 tools: Arc::new(tools),
504 hooks: Arc::new(hooks),
505 message_store: Arc::new(message_store),
506 state_store: Arc::new(state_store),
507 config,
508 compaction_config: None,
509 execution_store: None,
510 }
511 }
512
513 #[must_use]
515 pub fn with_compaction(
516 provider: P,
517 tools: ToolRegistry<Ctx>,
518 hooks: H,
519 message_store: M,
520 state_store: S,
521 config: AgentConfig,
522 compaction_config: CompactionConfig,
523 ) -> Self {
524 Self {
525 provider: Arc::new(provider),
526 tools: Arc::new(tools),
527 hooks: Arc::new(hooks),
528 message_store: Arc::new(message_store),
529 state_store: Arc::new(state_store),
530 config,
531 compaction_config: Some(compaction_config),
532 execution_store: None,
533 }
534 }
535
536 pub fn run(
586 &self,
587 thread_id: ThreadId,
588 input: AgentInput,
589 tool_context: ToolContext<Ctx>,
590 ) -> (mpsc::Receiver<AgentEvent>, oneshot::Receiver<AgentRunState>)
591 where
592 Ctx: Clone,
593 {
594 let (event_tx, event_rx) = mpsc::channel(100);
595 let (state_tx, state_rx) = oneshot::channel();
596
597 let provider = Arc::clone(&self.provider);
598 let tools = Arc::clone(&self.tools);
599 let hooks = Arc::clone(&self.hooks);
600 let message_store = Arc::clone(&self.message_store);
601 let state_store = Arc::clone(&self.state_store);
602 let config = self.config.clone();
603 let compaction_config = self.compaction_config.clone();
604 let execution_store = self.execution_store.clone();
605
606 tokio::spawn(async move {
607 let result = run_loop(
608 event_tx,
609 thread_id,
610 input,
611 tool_context,
612 provider,
613 tools,
614 hooks,
615 message_store,
616 state_store,
617 config,
618 compaction_config,
619 execution_store,
620 )
621 .await;
622
623 let _ = state_tx.send(result);
624 });
625
626 (event_rx, state_rx)
627 }
628
629 pub fn run_turn(
685 &self,
686 thread_id: ThreadId,
687 input: AgentInput,
688 tool_context: ToolContext<Ctx>,
689 ) -> (mpsc::Receiver<AgentEvent>, oneshot::Receiver<TurnOutcome>)
690 where
691 Ctx: Clone,
692 {
693 let (event_tx, event_rx) = mpsc::channel(100);
694 let (outcome_tx, outcome_rx) = oneshot::channel();
695
696 let provider = Arc::clone(&self.provider);
697 let tools = Arc::clone(&self.tools);
698 let hooks = Arc::clone(&self.hooks);
699 let message_store = Arc::clone(&self.message_store);
700 let state_store = Arc::clone(&self.state_store);
701 let config = self.config.clone();
702 let compaction_config = self.compaction_config.clone();
703 let execution_store = self.execution_store.clone();
704
705 tokio::spawn(async move {
706 let result = run_single_turn(TurnParameters {
707 tx: event_tx,
708 thread_id,
709 input,
710 tool_context,
711 provider,
712 tools,
713 hooks,
714 message_store,
715 state_store,
716 config,
717 compaction_config,
718 execution_store,
719 })
720 .await;
721
722 let _ = outcome_tx.send(result);
723 });
724
725 (event_rx, outcome_rx)
726 }
727}
728
729async fn initialize_from_input<M, S>(
740 input: AgentInput,
741 thread_id: &ThreadId,
742 message_store: &Arc<M>,
743 state_store: &Arc<S>,
744) -> Result<InitializedState, AgentError>
745where
746 M: MessageStore,
747 S: StateStore,
748{
749 match input {
750 AgentInput::Text(user_message) => {
751 let state = match state_store.load(thread_id).await {
753 Ok(Some(s)) => s,
754 Ok(None) => AgentState::new(thread_id.clone()),
755 Err(e) => {
756 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
757 }
758 };
759
760 let user_msg = Message::user(&user_message);
762 if let Err(e) = message_store.append(thread_id, user_msg).await {
763 return Err(AgentError::new(
764 format!("Failed to append message: {e}"),
765 false,
766 ));
767 }
768
769 Ok(InitializedState {
770 turn: 0,
771 total_usage: TokenUsage::default(),
772 state,
773 resume_data: None,
774 })
775 }
776 AgentInput::Resume {
777 continuation,
778 tool_call_id,
779 confirmed,
780 rejection_reason,
781 } => {
782 if continuation.thread_id != *thread_id {
784 return Err(AgentError::new(
785 format!(
786 "Thread ID mismatch: continuation is for {}, but resuming on {}",
787 continuation.thread_id, thread_id
788 ),
789 false,
790 ));
791 }
792
793 Ok(InitializedState {
794 turn: continuation.turn,
795 total_usage: continuation.total_usage.clone(),
796 state: continuation.state.clone(),
797 resume_data: Some(ResumeData {
798 continuation,
799 tool_call_id,
800 confirmed,
801 rejection_reason,
802 }),
803 })
804 }
805 AgentInput::Continue => {
806 let state = match state_store.load(thread_id).await {
808 Ok(Some(s)) => s,
809 Ok(None) => {
810 return Err(AgentError::new(
811 "Cannot continue: no state found for thread",
812 false,
813 ));
814 }
815 Err(e) => {
816 return Err(AgentError::new(format!("Failed to load state: {e}"), false));
817 }
818 };
819
820 Ok(InitializedState {
822 turn: state.turn_count,
823 total_usage: state.total_usage.clone(),
824 state,
825 resume_data: None,
826 })
827 }
828 }
829}
830
831#[allow(clippy::too_many_lines)]
840async fn execute_tool_call<Ctx, H>(
841 pending: &PendingToolCallInfo,
842 tool_context: &ToolContext<Ctx>,
843 tools: &ToolRegistry<Ctx>,
844 hooks: &Arc<H>,
845 tx: &mpsc::Sender<AgentEvent>,
846) -> ToolExecutionOutcome
847where
848 Ctx: Send + Sync + Clone + 'static,
849 H: AgentHooks,
850{
851 if let Some(async_tool) = tools.get_async(&pending.name) {
853 let tier = async_tool.tier();
854
855 let _ = tx
857 .send(AgentEvent::tool_call_start(
858 &pending.id,
859 &pending.name,
860 &pending.display_name,
861 pending.input.clone(),
862 tier,
863 ))
864 .await;
865
866 let decision = hooks
868 .pre_tool_use(&pending.name, &pending.input, tier)
869 .await;
870
871 return match decision {
872 ToolDecision::Allow => {
873 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
874
875 hooks.post_tool_use(&pending.name, &result).await;
876
877 send_event(
878 tx,
879 hooks,
880 AgentEvent::tool_call_end(
881 &pending.id,
882 &pending.name,
883 &pending.display_name,
884 result.clone(),
885 ),
886 )
887 .await;
888
889 ToolExecutionOutcome::Completed {
890 tool_id: pending.id.clone(),
891 result,
892 }
893 }
894 ToolDecision::Block(reason) => {
895 let result = ToolResult::error(format!("Blocked: {reason}"));
896 send_event(
897 tx,
898 hooks,
899 AgentEvent::tool_call_end(
900 &pending.id,
901 &pending.name,
902 &pending.display_name,
903 result.clone(),
904 ),
905 )
906 .await;
907 ToolExecutionOutcome::Completed {
908 tool_id: pending.id.clone(),
909 result,
910 }
911 }
912 ToolDecision::RequiresConfirmation(description) => {
913 send_event(
914 tx,
915 hooks,
916 AgentEvent::ToolRequiresConfirmation {
917 id: pending.id.clone(),
918 name: pending.name.clone(),
919 input: pending.input.clone(),
920 description: description.clone(),
921 },
922 )
923 .await;
924
925 ToolExecutionOutcome::RequiresConfirmation {
926 tool_id: pending.id.clone(),
927 tool_name: pending.name.clone(),
928 display_name: pending.display_name.clone(),
929 input: pending.input.clone(),
930 description,
931 }
932 }
933 };
934 }
935
936 let Some(tool) = tools.get(&pending.name) else {
938 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
939 return ToolExecutionOutcome::Completed {
940 tool_id: pending.id.clone(),
941 result,
942 };
943 };
944
945 let tier = tool.tier();
946
947 let _ = tx
949 .send(AgentEvent::tool_call_start(
950 &pending.id,
951 &pending.name,
952 &pending.display_name,
953 pending.input.clone(),
954 tier,
955 ))
956 .await;
957
958 let decision = hooks
960 .pre_tool_use(&pending.name, &pending.input, tier)
961 .await;
962
963 match decision {
964 ToolDecision::Allow => {
965 let tool_start = Instant::now();
966 let result = match tool.execute(tool_context, pending.input.clone()).await {
967 Ok(mut r) => {
968 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
969 r
970 }
971 Err(e) => ToolResult::error(format!("Tool error: {e}"))
972 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
973 };
974
975 hooks.post_tool_use(&pending.name, &result).await;
976
977 send_event(
978 tx,
979 hooks,
980 AgentEvent::tool_call_end(
981 &pending.id,
982 &pending.name,
983 &pending.display_name,
984 result.clone(),
985 ),
986 )
987 .await;
988
989 ToolExecutionOutcome::Completed {
990 tool_id: pending.id.clone(),
991 result,
992 }
993 }
994 ToolDecision::Block(reason) => {
995 let result = ToolResult::error(format!("Blocked: {reason}"));
996 send_event(
997 tx,
998 hooks,
999 AgentEvent::tool_call_end(
1000 &pending.id,
1001 &pending.name,
1002 &pending.display_name,
1003 result.clone(),
1004 ),
1005 )
1006 .await;
1007 ToolExecutionOutcome::Completed {
1008 tool_id: pending.id.clone(),
1009 result,
1010 }
1011 }
1012 ToolDecision::RequiresConfirmation(description) => {
1013 send_event(
1014 tx,
1015 hooks,
1016 AgentEvent::ToolRequiresConfirmation {
1017 id: pending.id.clone(),
1018 name: pending.name.clone(),
1019 input: pending.input.clone(),
1020 description: description.clone(),
1021 },
1022 )
1023 .await;
1024
1025 ToolExecutionOutcome::RequiresConfirmation {
1026 tool_id: pending.id.clone(),
1027 tool_name: pending.name.clone(),
1028 display_name: pending.display_name.clone(),
1029 input: pending.input.clone(),
1030 description,
1031 }
1032 }
1033 }
1034}
1035
1036async fn execute_async_tool<Ctx>(
1042 pending: &PendingToolCallInfo,
1043 tool: &Arc<dyn ErasedAsyncTool<Ctx>>,
1044 tool_context: &ToolContext<Ctx>,
1045 tx: &mpsc::Sender<AgentEvent>,
1046) -> ToolResult
1047where
1048 Ctx: Send + Sync + Clone,
1049{
1050 let tool_start = Instant::now();
1051
1052 let outcome = match tool.execute(tool_context, pending.input.clone()).await {
1054 Ok(o) => o,
1055 Err(e) => {
1056 return ToolResult::error(format!("Tool error: {e}"))
1057 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()));
1058 }
1059 };
1060
1061 match outcome {
1062 ToolOutcome::Success(mut result) | ToolOutcome::Failed(mut result) => {
1064 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1065 result
1066 }
1067
1068 ToolOutcome::InProgress {
1070 operation_id,
1071 message,
1072 } => {
1073 let _ = tx
1075 .send(AgentEvent::tool_progress(
1076 &pending.id,
1077 &pending.name,
1078 &pending.display_name,
1079 "started",
1080 &message,
1081 None,
1082 ))
1083 .await;
1084
1085 let mut stream = tool.check_status_stream(tool_context, &operation_id);
1087
1088 while let Some(status) = stream.next().await {
1089 match status {
1090 ErasedToolStatus::Progress {
1091 stage,
1092 message,
1093 data,
1094 } => {
1095 let _ = tx
1096 .send(AgentEvent::tool_progress(
1097 &pending.id,
1098 &pending.name,
1099 &pending.display_name,
1100 stage,
1101 message,
1102 data,
1103 ))
1104 .await;
1105 }
1106 ErasedToolStatus::Completed(mut result)
1107 | ErasedToolStatus::Failed(mut result) => {
1108 result.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1109 return result;
1110 }
1111 }
1112 }
1113
1114 ToolResult::error("Async tool stream ended without completion")
1116 .with_duration(millis_to_u64(tool_start.elapsed().as_millis()))
1117 }
1118 }
1119}
1120
1121async fn execute_confirmed_tool<Ctx, H>(
1126 awaiting_tool: &PendingToolCallInfo,
1127 confirmed: bool,
1128 rejection_reason: Option<String>,
1129 tool_context: &ToolContext<Ctx>,
1130 tools: &ToolRegistry<Ctx>,
1131 hooks: &Arc<H>,
1132 tx: &mpsc::Sender<AgentEvent>,
1133) -> ToolResult
1134where
1135 Ctx: Send + Sync + Clone + 'static,
1136 H: AgentHooks,
1137{
1138 if confirmed {
1139 if let Some(async_tool) = tools.get_async(&awaiting_tool.name) {
1141 let result = execute_async_tool(awaiting_tool, async_tool, tool_context, tx).await;
1142
1143 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1144
1145 let _ = tx
1146 .send(AgentEvent::tool_call_end(
1147 &awaiting_tool.id,
1148 &awaiting_tool.name,
1149 &awaiting_tool.display_name,
1150 result.clone(),
1151 ))
1152 .await;
1153
1154 return result;
1155 }
1156
1157 if let Some(tool) = tools.get(&awaiting_tool.name) {
1159 let tool_start = Instant::now();
1160 let result = match tool
1161 .execute(tool_context, awaiting_tool.input.clone())
1162 .await
1163 {
1164 Ok(mut r) => {
1165 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
1166 r
1167 }
1168 Err(e) => ToolResult::error(format!("Tool error: {e}"))
1169 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
1170 };
1171
1172 hooks.post_tool_use(&awaiting_tool.name, &result).await;
1173
1174 let _ = tx
1175 .send(AgentEvent::tool_call_end(
1176 &awaiting_tool.id,
1177 &awaiting_tool.name,
1178 &awaiting_tool.display_name,
1179 result.clone(),
1180 ))
1181 .await;
1182
1183 result
1184 } else {
1185 ToolResult::error(format!("Unknown tool: {}", awaiting_tool.name))
1186 }
1187 } else {
1188 let reason = rejection_reason.unwrap_or_else(|| "User rejected".to_string());
1189 let result = ToolResult::error(format!("Rejected: {reason}"));
1190 send_event(
1191 tx,
1192 hooks,
1193 AgentEvent::tool_call_end(
1194 &awaiting_tool.id,
1195 &awaiting_tool.name,
1196 &awaiting_tool.display_name,
1197 result.clone(),
1198 ),
1199 )
1200 .await;
1201 result
1202 }
1203}
1204
1205async fn append_tool_results<M>(
1207 tool_results: &[(String, ToolResult)],
1208 thread_id: &ThreadId,
1209 message_store: &Arc<M>,
1210) -> Result<(), AgentError>
1211where
1212 M: MessageStore,
1213{
1214 for (tool_id, result) in tool_results {
1215 let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
1216 if let Err(e) = message_store.append(thread_id, tool_result_msg).await {
1217 return Err(AgentError::new(
1218 format!("Failed to append tool result: {e}"),
1219 false,
1220 ));
1221 }
1222 }
1223 Ok(())
1224}
1225
1226async fn call_llm_with_retry<P, H>(
1228 provider: &Arc<P>,
1229 request: ChatRequest,
1230 config: &AgentConfig,
1231 tx: &mpsc::Sender<AgentEvent>,
1232 hooks: &Arc<H>,
1233) -> Result<ChatResponse, AgentError>
1234where
1235 P: LlmProvider,
1236 H: AgentHooks,
1237{
1238 let max_retries = config.retry.max_retries;
1239 let mut attempt = 0u32;
1240
1241 loop {
1242 let outcome = match provider.chat(request.clone()).await {
1243 Ok(o) => o,
1244 Err(e) => {
1245 return Err(AgentError::new(format!("LLM error: {e}"), false));
1246 }
1247 };
1248
1249 match outcome {
1250 ChatOutcome::Success(response) => return Ok(response),
1251 ChatOutcome::RateLimited => {
1252 attempt += 1;
1253 if attempt > max_retries {
1254 error!("Rate limited by LLM provider after {max_retries} retries");
1255 let error_msg = format!("Rate limited after {max_retries} retries");
1256 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1257 return Err(AgentError::new(error_msg, true));
1258 }
1259 let delay = calculate_backoff_delay(attempt, &config.retry);
1260 warn!(
1261 "Rate limited, retrying after backoff (attempt={}, delay_ms={})",
1262 attempt,
1263 delay.as_millis()
1264 );
1265
1266 sleep(delay).await;
1267 }
1268 ChatOutcome::InvalidRequest(msg) => {
1269 error!("Invalid request to LLM: {msg}");
1270 return Err(AgentError::new(format!("Invalid request: {msg}"), false));
1271 }
1272 ChatOutcome::ServerError(msg) => {
1273 attempt += 1;
1274 if attempt > max_retries {
1275 error!("LLM server error after {max_retries} retries: {msg}");
1276 let error_msg = format!("Server error after {max_retries} retries: {msg}");
1277 send_event(tx, hooks, AgentEvent::error(&error_msg, true)).await;
1278 return Err(AgentError::new(error_msg, true));
1279 }
1280 let delay = calculate_backoff_delay(attempt, &config.retry);
1281 warn!(
1282 "Server error, retrying after backoff (attempt={attempt}, delay_ms={}, error={msg})",
1283 delay.as_millis()
1284 );
1285
1286 sleep(delay).await;
1287 }
1288 }
1289 }
1290}
1291
1292async fn call_llm_streaming<P, H>(
1298 provider: &Arc<P>,
1299 request: ChatRequest,
1300 config: &AgentConfig,
1301 tx: &mpsc::Sender<AgentEvent>,
1302 hooks: &Arc<H>,
1303 message_id: &str,
1304 thinking_id: &str,
1305) -> Result<ChatResponse, AgentError>
1306where
1307 P: LlmProvider,
1308 H: AgentHooks,
1309{
1310 let max_retries = config.retry.max_retries;
1311 let mut attempt = 0u32;
1312
1313 loop {
1314 let result = process_stream(provider, &request, tx, hooks, message_id, thinking_id).await;
1315
1316 match result {
1317 Ok(response) => return Ok(response),
1318 Err(StreamError::Recoverable(msg)) => {
1319 attempt += 1;
1320 if attempt > max_retries {
1321 error!("Streaming error after {max_retries} retries: {msg}");
1322 let err_msg = format!("Streaming error after {max_retries} retries: {msg}");
1323 send_event(tx, hooks, AgentEvent::error(&err_msg, true)).await;
1324 return Err(AgentError::new(err_msg, true));
1325 }
1326 let delay = calculate_backoff_delay(attempt, &config.retry);
1327 warn!(
1328 "Streaming error, retrying (attempt={attempt}, delay_ms={}, error={msg})",
1329 delay.as_millis()
1330 );
1331
1332 sleep(delay).await;
1333 }
1334 Err(StreamError::Fatal(msg)) => {
1335 error!("Streaming error (non-recoverable): {msg}");
1336 return Err(AgentError::new(format!("Streaming error: {msg}"), false));
1337 }
1338 }
1339 }
1340}
1341
/// Failure of a single streaming attempt in `process_stream`.
enum StreamError {
    /// Retried with backoff by `call_llm_streaming`.
    Recoverable(String),
    /// Not retried; surfaced immediately as a non-recoverable error.
    Fatal(String),
}
1347
1348async fn process_stream<P, H>(
1350 provider: &Arc<P>,
1351 request: &ChatRequest,
1352 tx: &mpsc::Sender<AgentEvent>,
1353 hooks: &Arc<H>,
1354 message_id: &str,
1355 thinking_id: &str,
1356) -> Result<ChatResponse, StreamError>
1357where
1358 P: LlmProvider,
1359 H: AgentHooks,
1360{
1361 let mut stream = std::pin::pin!(provider.chat_stream(request.clone()));
1362 let mut accumulator = StreamAccumulator::new();
1363 let mut delta_count: u64 = 0;
1364
1365 log::debug!("Starting to consume LLM stream");
1366
1367 let mut channel_closed = false;
1369
1370 while let Some(result) = stream.next().await {
1371 if delta_count > 0 && delta_count.is_multiple_of(50) {
1373 log::debug!("Stream progress: delta_count={delta_count}");
1374 }
1375
1376 match result {
1377 Ok(delta) => {
1378 delta_count += 1;
1379 accumulator.apply(&delta);
1380 match &delta {
1381 StreamDelta::TextDelta { delta, .. } => {
1382 if !channel_closed {
1384 if tx.is_closed() {
1385 log::warn!(
1386 "Event channel closed by receiver at delta_count={delta_count} - consumer may have disconnected"
1387 );
1388 channel_closed = true;
1389 } else {
1390 send_event(
1391 tx,
1392 hooks,
1393 AgentEvent::text_delta(message_id, delta.clone()),
1394 )
1395 .await;
1396 }
1397 }
1398 }
1399 StreamDelta::ThinkingDelta { delta, .. } => {
1400 if !channel_closed {
1401 if tx.is_closed() {
1402 log::warn!(
1403 "Event channel closed by receiver at delta_count={delta_count}"
1404 );
1405 channel_closed = true;
1406 } else {
1407 send_event(
1408 tx,
1409 hooks,
1410 AgentEvent::thinking(thinking_id, delta.clone()),
1411 )
1412 .await;
1413 }
1414 }
1415 }
1416 StreamDelta::Error {
1417 message,
1418 recoverable,
1419 } => {
1420 log::warn!(
1421 "Stream error received delta_count={delta_count} message={message} recoverable={recoverable}"
1422 );
1423 return if *recoverable {
1424 Err(StreamError::Recoverable(message.clone()))
1425 } else {
1426 Err(StreamError::Fatal(message.clone()))
1427 };
1428 }
1429 StreamDelta::Done { .. }
1431 | StreamDelta::Usage(_)
1432 | StreamDelta::ToolUseStart { .. }
1433 | StreamDelta::ToolInputDelta { .. } => {}
1434 }
1435 }
1436 Err(e) => {
1437 log::error!("Stream iteration error delta_count={delta_count} error={e}");
1438 return Err(StreamError::Recoverable(format!("Stream error: {e}")));
1439 }
1440 }
1441 }
1442
1443 log::debug!("Stream while loop exited normally at delta_count={delta_count}");
1444
1445 let usage = accumulator.usage().cloned().unwrap_or(Usage {
1446 input_tokens: 0,
1447 output_tokens: 0,
1448 });
1449 let stop_reason = accumulator.stop_reason().copied();
1450 let content_blocks = accumulator.into_content_blocks();
1451
1452 log::debug!(
1453 "LLM stream completed successfully delta_count={delta_count} stop_reason={stop_reason:?} content_block_count={} input_tokens={} output_tokens={}",
1454 content_blocks.len(),
1455 usage.input_tokens,
1456 usage.output_tokens
1457 );
1458
1459 Ok(ChatResponse {
1460 id: String::new(),
1461 content: content_blocks,
1462 model: provider.model().to_string(),
1463 stop_reason,
1464 usage,
1465 })
1466}
1467
1468#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1473async fn run_loop<Ctx, P, H, M, S>(
1474 tx: mpsc::Sender<AgentEvent>,
1475 thread_id: ThreadId,
1476 input: AgentInput,
1477 tool_context: ToolContext<Ctx>,
1478 provider: Arc<P>,
1479 tools: Arc<ToolRegistry<Ctx>>,
1480 hooks: Arc<H>,
1481 message_store: Arc<M>,
1482 state_store: Arc<S>,
1483 config: AgentConfig,
1484 compaction_config: Option<CompactionConfig>,
1485 execution_store: Option<Arc<dyn ToolExecutionStore>>,
1486) -> AgentRunState
1487where
1488 Ctx: Send + Sync + Clone + 'static,
1489 P: LlmProvider,
1490 H: AgentHooks,
1491 M: MessageStore,
1492 S: StateStore,
1493{
1494 let tool_context = tool_context.with_event_tx(tx.clone());
1496 let start_time = Instant::now();
1497
1498 let init_state =
1500 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1501 Ok(s) => s,
1502 Err(e) => return AgentRunState::Error(e),
1503 };
1504
1505 let InitializedState {
1506 turn,
1507 total_usage,
1508 state,
1509 resume_data,
1510 } = init_state;
1511
1512 if let Some(resume) = resume_data {
1513 let ResumeData {
1514 continuation: cont,
1515 tool_call_id,
1516 confirmed,
1517 rejection_reason,
1518 } = resume;
1519 let mut tool_results = cont.completed_results.clone();
1520 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1521
1522 if awaiting_tool.id != tool_call_id {
1523 let message = format!(
1524 "Tool call ID mismatch: expected {}, got {}",
1525 awaiting_tool.id, tool_call_id
1526 );
1527 let recoverable = false;
1528 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1529 return AgentRunState::Error(AgentError::new(&message, recoverable));
1530 }
1531
1532 let result = execute_confirmed_tool(
1533 awaiting_tool,
1534 confirmed,
1535 rejection_reason,
1536 &tool_context,
1537 &tools,
1538 &hooks,
1539 &tx,
1540 )
1541 .await;
1542 tool_results.push((awaiting_tool.id.clone(), result));
1543
1544 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1545 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1546 ToolExecutionOutcome::Completed { tool_id, result } => {
1547 tool_results.push((tool_id, result));
1548 }
1549 ToolExecutionOutcome::RequiresConfirmation {
1550 tool_id,
1551 tool_name,
1552 display_name,
1553 input,
1554 description,
1555 } => {
1556 let pending_idx = cont
1557 .pending_tool_calls
1558 .iter()
1559 .position(|p| p.id == tool_id)
1560 .unwrap_or(0);
1561
1562 let new_continuation = AgentContinuation {
1563 thread_id: thread_id.clone(),
1564 turn,
1565 total_usage: total_usage.clone(),
1566 turn_usage: cont.turn_usage.clone(),
1567 pending_tool_calls: cont.pending_tool_calls.clone(),
1568 awaiting_index: pending_idx,
1569 completed_results: tool_results,
1570 state: state.clone(),
1571 };
1572
1573 return AgentRunState::AwaitingConfirmation {
1574 tool_call_id: tool_id,
1575 tool_name,
1576 display_name,
1577 input,
1578 description,
1579 continuation: Box::new(new_continuation),
1580 };
1581 }
1582 }
1583 }
1584
1585 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1586 send_event(
1587 &tx,
1588 &hooks,
1589 AgentEvent::Error {
1590 message: e.message.clone(),
1591 recoverable: e.recoverable,
1592 },
1593 )
1594 .await;
1595 return AgentRunState::Error(e);
1596 }
1597
1598 send_event(
1599 &tx,
1600 &hooks,
1601 AgentEvent::TurnComplete {
1602 turn,
1603 usage: cont.turn_usage.clone(),
1604 },
1605 )
1606 .await;
1607 }
1608
1609 let mut ctx = TurnContext {
1610 thread_id: thread_id.clone(),
1611 turn,
1612 total_usage,
1613 state,
1614 start_time,
1615 };
1616
1617 loop {
1618 let result = execute_turn(
1619 &tx,
1620 &mut ctx,
1621 &tool_context,
1622 &provider,
1623 &tools,
1624 &hooks,
1625 &message_store,
1626 &config,
1627 compaction_config.as_ref(),
1628 execution_store.as_ref(),
1629 )
1630 .await;
1631
1632 match result {
1633 InternalTurnResult::Continue { .. } => {
1634 if let Err(e) = state_store.save(&ctx.state).await {
1635 warn!("Failed to save state checkpoint: {e}");
1636 }
1637 }
1638 InternalTurnResult::Done => {
1639 break;
1640 }
1641 InternalTurnResult::AwaitingConfirmation {
1642 tool_call_id,
1643 tool_name,
1644 display_name,
1645 input,
1646 description,
1647 continuation,
1648 } => {
1649 return AgentRunState::AwaitingConfirmation {
1650 tool_call_id,
1651 tool_name,
1652 display_name,
1653 input,
1654 description,
1655 continuation,
1656 };
1657 }
1658 InternalTurnResult::Error(e) => {
1659 return AgentRunState::Error(e);
1660 }
1661 }
1662 }
1663
1664 if let Err(e) = state_store.save(&ctx.state).await {
1665 warn!("Failed to save final state: {e}");
1666 }
1667
1668 let duration = ctx.start_time.elapsed();
1669 send_event(
1670 &tx,
1671 &hooks,
1672 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1673 )
1674 .await;
1675
1676 AgentRunState::Done {
1677 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1678 input_tokens: u64::from(ctx.total_usage.input_tokens),
1679 output_tokens: u64::from(ctx.total_usage.output_tokens),
1680 }
1681}
1682
/// Everything `run_single_turn` needs to execute one agent turn, bundled so the
/// function takes a single argument.
struct TurnParameters<Ctx, P, H, M, S> {
    /// Channel on which `AgentEvent`s for this turn are emitted.
    tx: mpsc::Sender<AgentEvent>,
    /// Thread whose history and state the turn reads and updates.
    thread_id: ThreadId,
    /// Input for this turn; turned into initial state by `initialize_from_input`
    /// (and may carry resume data for a pending tool confirmation).
    input: AgentInput,
    /// Context handed to tools when they execute.
    tool_context: ToolContext<Ctx>,
    /// LLM provider used for the turn's chat request.
    provider: Arc<P>,
    /// Registry of tools the model may call.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks notified of events and consulted before tool execution.
    hooks: Arc<H>,
    /// Store holding the thread's message history.
    message_store: Arc<M>,
    /// Store used to checkpoint `AgentState`.
    state_store: Arc<S>,
    /// Agent configuration (limits, system prompt, streaming, ...).
    config: AgentConfig,
    /// When set, history compaction is considered before the LLM call.
    compaction_config: Option<CompactionConfig>,
    /// When set, tool executions are recorded and completed results reused.
    execution_store: Option<Arc<dyn ToolExecutionStore>>,
}
1697
1698async fn run_single_turn<Ctx, P, H, M, S>(
1704 TurnParameters {
1705 tx,
1706 thread_id,
1707 input,
1708 tool_context,
1709 provider,
1710 tools,
1711 hooks,
1712 message_store,
1713 state_store,
1714 config,
1715 compaction_config,
1716 execution_store,
1717 }: TurnParameters<Ctx, P, H, M, S>,
1718) -> TurnOutcome
1719where
1720 Ctx: Send + Sync + Clone + 'static,
1721 P: LlmProvider,
1722 H: AgentHooks,
1723 M: MessageStore,
1724 S: StateStore,
1725{
1726 let tool_context = tool_context.with_event_tx(tx.clone());
1727 let start_time = Instant::now();
1728
1729 let init_state =
1730 match initialize_from_input(input, &thread_id, &message_store, &state_store).await {
1731 Ok(s) => s,
1732 Err(e) => {
1733 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1734 return TurnOutcome::Error(e);
1735 }
1736 };
1737
1738 let InitializedState {
1739 turn,
1740 total_usage,
1741 state,
1742 resume_data,
1743 } = init_state;
1744
1745 if let Some(resume_data_val) = resume_data {
1746 return handle_resume_case(ResumeCaseParameters {
1747 resume_data: resume_data_val,
1748 turn,
1749 total_usage,
1750 state,
1751 thread_id,
1752 tool_context,
1753 tools,
1754 hooks,
1755 tx,
1756 message_store,
1757 state_store,
1758 })
1759 .await;
1760 }
1761
1762 let mut ctx = TurnContext {
1763 thread_id: thread_id.clone(),
1764 turn,
1765 total_usage,
1766 state,
1767 start_time,
1768 };
1769
1770 let result = execute_turn(
1771 &tx,
1772 &mut ctx,
1773 &tool_context,
1774 &provider,
1775 &tools,
1776 &hooks,
1777 &message_store,
1778 &config,
1779 compaction_config.as_ref(),
1780 execution_store.as_ref(),
1781 )
1782 .await;
1783
1784 match result {
1786 InternalTurnResult::Continue { turn_usage } => {
1787 if let Err(e) = state_store.save(&ctx.state).await {
1789 warn!("Failed to save state checkpoint: {e}");
1790 }
1791
1792 TurnOutcome::NeedsMoreTurns {
1793 turn: ctx.turn,
1794 turn_usage,
1795 total_usage: ctx.total_usage,
1796 }
1797 }
1798 InternalTurnResult::Done => {
1799 if let Err(e) = state_store.save(&ctx.state).await {
1801 warn!("Failed to save final state: {e}");
1802 }
1803
1804 let duration = ctx.start_time.elapsed();
1806 send_event(
1807 &tx,
1808 &hooks,
1809 AgentEvent::done(thread_id, ctx.turn, ctx.total_usage.clone(), duration),
1810 )
1811 .await;
1812
1813 TurnOutcome::Done {
1814 total_turns: u32::try_from(ctx.turn).unwrap_or(u32::MAX),
1815 input_tokens: u64::from(ctx.total_usage.input_tokens),
1816 output_tokens: u64::from(ctx.total_usage.output_tokens),
1817 }
1818 }
1819 InternalTurnResult::AwaitingConfirmation {
1820 tool_call_id,
1821 tool_name,
1822 display_name,
1823 input,
1824 description,
1825 continuation,
1826 } => TurnOutcome::AwaitingConfirmation {
1827 tool_call_id,
1828 tool_name,
1829 display_name,
1830 input,
1831 description,
1832 continuation,
1833 },
1834 InternalTurnResult::Error(e) => TurnOutcome::Error(e),
1835 }
1836}
1837
/// Inputs for `handle_resume_case`: the confirmation answer plus the per-turn state and
/// shared handles needed to finish a turn that was paused for tool confirmation.
struct ResumeCaseParameters<Ctx, H, M, S> {
    /// The continuation and the caller's confirm/reject decision.
    resume_data: ResumeData,
    /// Turn number the paused turn belongs to.
    turn: usize,
    /// Accumulated token usage so far.
    total_usage: TokenUsage,
    /// Agent state to checkpoint once the turn completes.
    state: AgentState,
    /// Thread the turn belongs to.
    thread_id: ThreadId,
    /// Context handed to tools when they execute.
    tool_context: ToolContext<Ctx>,
    /// Registry used to execute the remaining tool calls.
    tools: Arc<ToolRegistry<Ctx>>,
    /// Hooks notified of events and consulted before tool execution.
    hooks: Arc<H>,
    /// Channel on which `AgentEvent`s are emitted.
    tx: mpsc::Sender<AgentEvent>,
    /// Store the tool results are appended to.
    message_store: Arc<M>,
    /// Store used to checkpoint the updated state.
    state_store: Arc<S>,
}
1851
1852async fn handle_resume_case<Ctx, H, M, S>(
1853 ResumeCaseParameters {
1854 resume_data,
1855 turn,
1856 total_usage,
1857 state,
1858 thread_id,
1859 tool_context,
1860 tools,
1861 hooks,
1862 tx,
1863 message_store,
1864 state_store,
1865 }: ResumeCaseParameters<Ctx, H, M, S>,
1866) -> TurnOutcome
1867where
1868 Ctx: Send + Sync + Clone + 'static,
1869 H: AgentHooks,
1870 M: MessageStore,
1871 S: StateStore,
1872{
1873 let ResumeData {
1874 continuation: cont,
1875 tool_call_id,
1876 confirmed,
1877 rejection_reason,
1878 } = resume_data;
1879 let mut tool_results = cont.completed_results.clone();
1881 let awaiting_tool = &cont.pending_tool_calls[cont.awaiting_index];
1882
1883 if awaiting_tool.id != tool_call_id {
1885 let message = format!(
1886 "Tool call ID mismatch: expected {}, got {}",
1887 awaiting_tool.id, tool_call_id
1888 );
1889 let recoverable = false;
1890 send_event(&tx, &hooks, AgentEvent::error(&message, recoverable)).await;
1891 return TurnOutcome::Error(AgentError::new(&message, recoverable));
1892 }
1893
1894 let result = execute_confirmed_tool(
1895 awaiting_tool,
1896 confirmed,
1897 rejection_reason,
1898 &tool_context,
1899 &tools,
1900 &hooks,
1901 &tx,
1902 )
1903 .await;
1904 tool_results.push((awaiting_tool.id.clone(), result));
1905
1906 for pending in cont.pending_tool_calls.iter().skip(cont.awaiting_index + 1) {
1907 match execute_tool_call(pending, &tool_context, &tools, &hooks, &tx).await {
1908 ToolExecutionOutcome::Completed { tool_id, result } => {
1909 tool_results.push((tool_id, result));
1910 }
1911 ToolExecutionOutcome::RequiresConfirmation {
1912 tool_id,
1913 tool_name,
1914 display_name,
1915 input,
1916 description,
1917 } => {
1918 let pending_idx = cont
1919 .pending_tool_calls
1920 .iter()
1921 .position(|p| p.id == tool_id)
1922 .unwrap_or(0);
1923
1924 let new_continuation = AgentContinuation {
1925 thread_id: thread_id.clone(),
1926 turn,
1927 total_usage: total_usage.clone(),
1928 turn_usage: cont.turn_usage.clone(),
1929 pending_tool_calls: cont.pending_tool_calls.clone(),
1930 awaiting_index: pending_idx,
1931 completed_results: tool_results,
1932 state: state.clone(),
1933 };
1934
1935 return TurnOutcome::AwaitingConfirmation {
1936 tool_call_id: tool_id,
1937 tool_name,
1938 display_name,
1939 input,
1940 description,
1941 continuation: Box::new(new_continuation),
1942 };
1943 }
1944 }
1945 }
1946
1947 if let Err(e) = append_tool_results(&tool_results, &thread_id, &message_store).await {
1948 send_event(&tx, &hooks, AgentEvent::error(&e.message, e.recoverable)).await;
1949 return TurnOutcome::Error(e);
1950 }
1951
1952 send_event(
1953 &tx,
1954 &hooks,
1955 AgentEvent::TurnComplete {
1956 turn,
1957 usage: cont.turn_usage.clone(),
1958 },
1959 )
1960 .await;
1961
1962 let mut updated_state = state;
1963 updated_state.turn_count = turn;
1964 if let Err(e) = state_store.save(&updated_state).await {
1965 warn!("Failed to save state checkpoint: {e}");
1966 }
1967
1968 TurnOutcome::NeedsMoreTurns {
1969 turn,
1970 turn_usage: cont.turn_usage.clone(),
1971 total_usage,
1972 }
1973}
1974
1975async fn try_get_cached_result(
1984 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
1985 tool_call_id: &str,
1986) -> Option<ToolResult> {
1987 let store = execution_store?;
1988 let execution = store.get_execution(tool_call_id).await.ok()??;
1989
1990 match execution.status {
1991 ExecutionStatus::Completed => execution.result,
1992 ExecutionStatus::InFlight => {
1993 warn!(
1996 "Found in-flight execution from previous attempt, re-executing (tool_call_id={}, tool_name={})",
1997 tool_call_id, execution.tool_name
1998 );
1999 None
2000 }
2001 }
2002}
2003
2004async fn record_execution_start(
2006 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2007 pending: &PendingToolCallInfo,
2008 thread_id: &ThreadId,
2009 started_at: time::OffsetDateTime,
2010) {
2011 if let Some(store) = execution_store {
2012 let execution = ToolExecution::new_in_flight(
2013 &pending.id,
2014 thread_id.clone(),
2015 &pending.name,
2016 &pending.display_name,
2017 pending.input.clone(),
2018 started_at,
2019 );
2020 if let Err(e) = store.record_execution(execution).await {
2021 warn!(
2022 "Failed to record execution start (tool_call_id={}, error={})",
2023 pending.id, e
2024 );
2025 }
2026 }
2027}
2028
2029async fn record_execution_complete(
2031 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2032 pending: &PendingToolCallInfo,
2033 thread_id: &ThreadId,
2034 result: &ToolResult,
2035 started_at: time::OffsetDateTime,
2036) {
2037 if let Some(store) = execution_store {
2038 let mut execution = ToolExecution::new_in_flight(
2039 &pending.id,
2040 thread_id.clone(),
2041 &pending.name,
2042 &pending.display_name,
2043 pending.input.clone(),
2044 started_at,
2045 );
2046 execution.complete(result.clone());
2047 if let Err(e) = store.update_execution(execution).await {
2048 warn!(
2049 "Failed to record execution completion (tool_call_id={}, error={})",
2050 pending.id, e
2051 );
2052 }
2053 }
2054}
2055
2056#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
2061async fn execute_turn<Ctx, P, H, M>(
2062 tx: &mpsc::Sender<AgentEvent>,
2063 ctx: &mut TurnContext,
2064 tool_context: &ToolContext<Ctx>,
2065 provider: &Arc<P>,
2066 tools: &Arc<ToolRegistry<Ctx>>,
2067 hooks: &Arc<H>,
2068 message_store: &Arc<M>,
2069 config: &AgentConfig,
2070 compaction_config: Option<&CompactionConfig>,
2071 execution_store: Option<&Arc<dyn ToolExecutionStore>>,
2072) -> InternalTurnResult
2073where
2074 Ctx: Send + Sync + Clone + 'static,
2075 P: LlmProvider,
2076 H: AgentHooks,
2077 M: MessageStore,
2078{
2079 ctx.turn += 1;
2080 ctx.state.turn_count = ctx.turn;
2081
2082 if ctx.turn > config.max_turns {
2083 warn!(
2084 "Max turns reached (turn={}, max={})",
2085 ctx.turn, config.max_turns
2086 );
2087 send_event(
2088 tx,
2089 hooks,
2090 AgentEvent::error(
2091 format!("Maximum turns ({}) reached", config.max_turns),
2092 true,
2093 ),
2094 )
2095 .await;
2096 return InternalTurnResult::Error(AgentError::new(
2097 format!("Maximum turns ({}) reached", config.max_turns),
2098 true,
2099 ));
2100 }
2101
2102 send_event(
2104 tx,
2105 hooks,
2106 AgentEvent::start(ctx.thread_id.clone(), ctx.turn),
2107 )
2108 .await;
2109
2110 let mut messages = match message_store.get_history(&ctx.thread_id).await {
2112 Ok(m) => m,
2113 Err(e) => {
2114 send_event(
2115 tx,
2116 hooks,
2117 AgentEvent::error(format!("Failed to get history: {e}"), false),
2118 )
2119 .await;
2120 return InternalTurnResult::Error(AgentError::new(
2121 format!("Failed to get history: {e}"),
2122 false,
2123 ));
2124 }
2125 };
2126
2127 if let Some(compact_config) = compaction_config {
2129 let compactor = LlmContextCompactor::new(Arc::clone(provider), compact_config.clone());
2130 if compactor.needs_compaction(&messages) {
2131 debug!(
2132 "Context compaction triggered (turn={}, message_count={})",
2133 ctx.turn,
2134 messages.len()
2135 );
2136
2137 match compactor.compact_history(messages.clone()).await {
2138 Ok(result) => {
2139 if let Err(e) = message_store
2140 .replace_history(&ctx.thread_id, result.messages.clone())
2141 .await
2142 {
2143 warn!("Failed to replace history after compaction: {e}");
2144 } else {
2145 send_event(
2146 tx,
2147 hooks,
2148 AgentEvent::context_compacted(
2149 result.original_count,
2150 result.new_count,
2151 result.original_tokens,
2152 result.new_tokens,
2153 ),
2154 )
2155 .await;
2156
2157 info!(
2158 "Context compacted successfully (original_count={}, new_count={}, original_tokens={}, new_tokens={})",
2159 result.original_count,
2160 result.new_count,
2161 result.original_tokens,
2162 result.new_tokens
2163 );
2164
2165 messages = result.messages;
2166 }
2167 }
2168 Err(e) => {
2169 warn!("Context compaction failed, continuing with full history: {e}");
2170 }
2171 }
2172 }
2173 }
2174
2175 let llm_tools = if tools.is_empty() {
2177 None
2178 } else {
2179 Some(tools.to_llm_tools())
2180 };
2181
2182 let request = ChatRequest {
2183 system: config.system_prompt.clone(),
2184 messages,
2185 tools: llm_tools,
2186 max_tokens: config.max_tokens,
2187 thinking: config.thinking.clone(),
2188 };
2189
2190 debug!(
2192 "Calling LLM (turn={}, streaming={})",
2193 ctx.turn, config.streaming
2194 );
2195 let message_id = uuid::Uuid::new_v4().to_string();
2196 let thinking_id = uuid::Uuid::new_v4().to_string();
2197 let response = if config.streaming {
2198 match call_llm_streaming(
2200 provider,
2201 request,
2202 config,
2203 tx,
2204 hooks,
2205 &message_id,
2206 &thinking_id,
2207 )
2208 .await
2209 {
2210 Ok(r) => r,
2211 Err(e) => {
2212 return InternalTurnResult::Error(e);
2213 }
2214 }
2215 } else {
2216 match call_llm_with_retry(provider, request, config, tx, hooks).await {
2218 Ok(r) => r,
2219 Err(e) => {
2220 return InternalTurnResult::Error(e);
2221 }
2222 }
2223 };
2224
2225 let turn_usage = TokenUsage {
2227 input_tokens: response.usage.input_tokens,
2228 output_tokens: response.usage.output_tokens,
2229 };
2230 ctx.total_usage.add(&turn_usage);
2231 ctx.state.total_usage = ctx.total_usage.clone();
2232
2233 let (thinking_content, text_content, tool_uses) = extract_content(&response);
2235
2236 if !config.streaming
2238 && let Some(thinking) = &thinking_content
2239 {
2240 send_event(
2241 tx,
2242 hooks,
2243 AgentEvent::thinking(thinking_id, thinking.clone()),
2244 )
2245 .await;
2246 }
2247
2248 if let Some(text) = &text_content {
2251 send_event(tx, hooks, AgentEvent::text(message_id, text.clone())).await;
2252 }
2253
2254 if tool_uses.is_empty() {
2256 info!("Agent completed (no tool use) (turn={})", ctx.turn);
2257 return InternalTurnResult::Done;
2258 }
2259
2260 let assistant_msg = build_assistant_message(&response);
2262 if let Err(e) = message_store.append(&ctx.thread_id, assistant_msg).await {
2263 send_event(
2264 tx,
2265 hooks,
2266 AgentEvent::error(format!("Failed to append assistant message: {e}"), false),
2267 )
2268 .await;
2269 return InternalTurnResult::Error(AgentError::new(
2270 format!("Failed to append assistant message: {e}"),
2271 false,
2272 ));
2273 }
2274
2275 let pending_tool_calls: Vec<PendingToolCallInfo> = tool_uses
2277 .iter()
2278 .map(|(id, name, input)| {
2279 let display_name = tools
2280 .get(name)
2281 .map(|t| t.display_name().to_string())
2282 .or_else(|| tools.get_async(name).map(|t| t.display_name().to_string()))
2283 .unwrap_or_default();
2284 PendingToolCallInfo {
2285 id: id.clone(),
2286 name: name.clone(),
2287 display_name,
2288 input: input.clone(),
2289 }
2290 })
2291 .collect();
2292
2293 let mut tool_results = Vec::new();
2295 for (idx, pending) in pending_tool_calls.iter().enumerate() {
2296 if let Some(cached_result) = try_get_cached_result(execution_store, &pending.id).await {
2298 debug!(
2299 "Using cached result from previous execution (tool_call_id={}, tool_name={})",
2300 pending.id, pending.name
2301 );
2302 tool_results.push((pending.id.clone(), cached_result));
2303 continue;
2304 }
2305
2306 if let Some(async_tool) = tools.get_async(&pending.name) {
2308 let tier = async_tool.tier();
2309
2310 send_event(
2312 tx,
2313 hooks,
2314 AgentEvent::tool_call_start(
2315 &pending.id,
2316 &pending.name,
2317 &pending.display_name,
2318 pending.input.clone(),
2319 tier,
2320 ),
2321 )
2322 .await;
2323
2324 let decision = hooks
2326 .pre_tool_use(&pending.name, &pending.input, tier)
2327 .await;
2328
2329 match decision {
2330 ToolDecision::Allow => {
2331 let started_at = time::OffsetDateTime::now_utc();
2333 record_execution_start(execution_store, pending, &ctx.thread_id, started_at)
2334 .await;
2335
2336 let result = execute_async_tool(pending, async_tool, tool_context, tx).await;
2337
2338 record_execution_complete(
2340 execution_store,
2341 pending,
2342 &ctx.thread_id,
2343 &result,
2344 started_at,
2345 )
2346 .await;
2347
2348 hooks.post_tool_use(&pending.name, &result).await;
2349
2350 send_event(
2351 tx,
2352 hooks,
2353 AgentEvent::tool_call_end(
2354 &pending.id,
2355 &pending.name,
2356 &pending.display_name,
2357 result.clone(),
2358 ),
2359 )
2360 .await;
2361
2362 tool_results.push((pending.id.clone(), result));
2363 }
2364 ToolDecision::Block(reason) => {
2365 let result = ToolResult::error(format!("Blocked: {reason}"));
2366 send_event(
2367 tx,
2368 hooks,
2369 AgentEvent::tool_call_end(
2370 &pending.id,
2371 &pending.name,
2372 &pending.display_name,
2373 result.clone(),
2374 ),
2375 )
2376 .await;
2377 tool_results.push((pending.id.clone(), result));
2378 }
2379 ToolDecision::RequiresConfirmation(description) => {
2380 send_event(
2382 tx,
2383 hooks,
2384 AgentEvent::ToolRequiresConfirmation {
2385 id: pending.id.clone(),
2386 name: pending.name.clone(),
2387 input: pending.input.clone(),
2388 description: description.clone(),
2389 },
2390 )
2391 .await;
2392
2393 let continuation = AgentContinuation {
2394 thread_id: ctx.thread_id.clone(),
2395 turn: ctx.turn,
2396 total_usage: ctx.total_usage.clone(),
2397 turn_usage: turn_usage.clone(),
2398 pending_tool_calls: pending_tool_calls.clone(),
2399 awaiting_index: idx,
2400 completed_results: tool_results,
2401 state: ctx.state.clone(),
2402 };
2403
2404 return InternalTurnResult::AwaitingConfirmation {
2405 tool_call_id: pending.id.clone(),
2406 tool_name: pending.name.clone(),
2407 display_name: pending.display_name.clone(),
2408 input: pending.input.clone(),
2409 description,
2410 continuation: Box::new(continuation),
2411 };
2412 }
2413 }
2414 continue;
2415 }
2416
2417 let Some(tool) = tools.get(&pending.name) else {
2419 let result = ToolResult::error(format!("Unknown tool: {}", pending.name));
2420 tool_results.push((pending.id.clone(), result));
2421 continue;
2422 };
2423
2424 let tier = tool.tier();
2425
2426 send_event(
2428 tx,
2429 hooks,
2430 AgentEvent::tool_call_start(
2431 &pending.id,
2432 &pending.name,
2433 &pending.display_name,
2434 pending.input.clone(),
2435 tier,
2436 ),
2437 )
2438 .await;
2439
2440 let decision = hooks
2442 .pre_tool_use(&pending.name, &pending.input, tier)
2443 .await;
2444
2445 match decision {
2446 ToolDecision::Allow => {
2447 let started_at = time::OffsetDateTime::now_utc();
2449 record_execution_start(execution_store, pending, &ctx.thread_id, started_at).await;
2450
2451 let tool_start = Instant::now();
2452 let result = match tool.execute(tool_context, pending.input.clone()).await {
2453 Ok(mut r) => {
2454 r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
2455 r
2456 }
2457 Err(e) => ToolResult::error(format!("Tool error: {e}"))
2458 .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
2459 };
2460
2461 record_execution_complete(
2463 execution_store,
2464 pending,
2465 &ctx.thread_id,
2466 &result,
2467 started_at,
2468 )
2469 .await;
2470
2471 hooks.post_tool_use(&pending.name, &result).await;
2472
2473 send_event(
2474 tx,
2475 hooks,
2476 AgentEvent::tool_call_end(
2477 &pending.id,
2478 &pending.name,
2479 &pending.display_name,
2480 result.clone(),
2481 ),
2482 )
2483 .await;
2484
2485 tool_results.push((pending.id.clone(), result));
2486 }
2487 ToolDecision::Block(reason) => {
2488 let result = ToolResult::error(format!("Blocked: {reason}"));
2489 send_event(
2490 tx,
2491 hooks,
2492 AgentEvent::tool_call_end(
2493 &pending.id,
2494 &pending.name,
2495 &pending.display_name,
2496 result.clone(),
2497 ),
2498 )
2499 .await;
2500 tool_results.push((pending.id.clone(), result));
2501 }
2502 ToolDecision::RequiresConfirmation(description) => {
2503 send_event(
2505 tx,
2506 hooks,
2507 AgentEvent::ToolRequiresConfirmation {
2508 id: pending.id.clone(),
2509 name: pending.name.clone(),
2510 input: pending.input.clone(),
2511 description: description.clone(),
2512 },
2513 )
2514 .await;
2515
2516 let continuation = AgentContinuation {
2517 thread_id: ctx.thread_id.clone(),
2518 turn: ctx.turn,
2519 total_usage: ctx.total_usage.clone(),
2520 turn_usage: turn_usage.clone(),
2521 pending_tool_calls: pending_tool_calls.clone(),
2522 awaiting_index: idx,
2523 completed_results: tool_results,
2524 state: ctx.state.clone(),
2525 };
2526
2527 return InternalTurnResult::AwaitingConfirmation {
2528 tool_call_id: pending.id.clone(),
2529 tool_name: pending.name.clone(),
2530 display_name: pending.display_name.clone(),
2531 input: pending.input.clone(),
2532 description,
2533 continuation: Box::new(continuation),
2534 };
2535 }
2536 }
2537 }
2538
2539 if let Err(e) = append_tool_results(&tool_results, &ctx.thread_id, message_store).await {
2541 send_event(
2542 tx,
2543 hooks,
2544 AgentEvent::error(format!("Failed to append tool results: {e}"), false),
2545 )
2546 .await;
2547 return InternalTurnResult::Error(e);
2548 }
2549
2550 send_event(
2552 tx,
2553 hooks,
2554 AgentEvent::TurnComplete {
2555 turn: ctx.turn,
2556 usage: turn_usage.clone(),
2557 },
2558 )
2559 .await;
2560
2561 if response.stop_reason == Some(StopReason::EndTurn) {
2563 info!("Agent completed (end_turn) (turn={})", ctx.turn);
2564 return InternalTurnResult::Done;
2565 }
2566
2567 InternalTurnResult::Continue { turn_usage }
2568}
2569
/// Converts a millisecond count (as produced by `Duration::as_millis`) to `u64`,
/// clamping anything that does not fit to `u64::MAX`.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    const LIMIT: u128 = u64::MAX as u128;
    match millis {
        m if m > LIMIT => u64::MAX,
        // Cannot truncate: every value above `LIMIT` was handled by the arm above.
        m => m as u64,
    }
}
2579
2580fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
2585 let base_delay = config
2587 .base_delay_ms
2588 .saturating_mul(1u64 << (attempt.saturating_sub(1)));
2589
2590 let max_jitter = config.base_delay_ms.min(1000);
2592 let jitter = if max_jitter > 0 {
2593 u64::from(
2594 std::time::SystemTime::now()
2595 .duration_since(std::time::UNIX_EPOCH)
2596 .unwrap_or_default()
2597 .subsec_nanos(),
2598 ) % max_jitter
2599 } else {
2600 0
2601 };
2602
2603 let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
2604 Duration::from_millis(delay_ms)
2605}
2606
/// Output of `extract_content`: `(thinking, text, tool_uses)`.
type ExtractedContent = (
    // Joined thinking text, or `None` when the response had no thinking blocks.
    Option<String>,
    // Joined visible text, or `None` when the response had no text blocks.
    Option<String>,
    // `(id, name, input)` for each tool-use block, in response order.
    Vec<(String, String, serde_json::Value)>,
);
2613
2614fn extract_content(response: &ChatResponse) -> ExtractedContent {
2616 let mut thinking_parts = Vec::new();
2617 let mut text_parts = Vec::new();
2618 let mut tool_uses = Vec::new();
2619
2620 for block in &response.content {
2621 match block {
2622 ContentBlock::Text { text } => {
2623 text_parts.push(text.clone());
2624 }
2625 ContentBlock::Thinking { thinking } => {
2626 thinking_parts.push(thinking.clone());
2627 }
2628 ContentBlock::ToolUse {
2629 id, name, input, ..
2630 } => {
2631 let input = if input.is_null() {
2632 serde_json::json!({})
2633 } else {
2634 input.clone()
2635 };
2636 tool_uses.push((id.clone(), name.clone(), input.clone()));
2637 }
2638 ContentBlock::ToolResult { .. } => {
2639 }
2641 }
2642 }
2643
2644 let thinking = if thinking_parts.is_empty() {
2645 None
2646 } else {
2647 Some(thinking_parts.join("\n"))
2648 };
2649
2650 let text = if text_parts.is_empty() {
2651 None
2652 } else {
2653 Some(text_parts.join("\n"))
2654 };
2655
2656 (thinking, text, tool_uses)
2657}
2658
2659async fn send_event<H>(tx: &mpsc::Sender<AgentEvent>, hooks: &Arc<H>, event: AgentEvent)
2673where
2674 H: AgentHooks,
2675{
2676 hooks.on_event(&event).await;
2677
2678 match tx.try_send(event) {
2680 Ok(()) => {}
2681 Err(mpsc::error::TrySendError::Full(event)) => {
2682 log::debug!("Event channel full, waiting for consumer...");
2684 match tokio::time::timeout(std::time::Duration::from_secs(30), tx.send(event)).await {
2686 Ok(Ok(())) => {}
2687 Ok(Err(_)) => {
2688 log::warn!("Event channel closed while sending - consumer disconnected");
2689 }
2690 Err(_) => {
2691 log::error!("Timeout waiting to send event - consumer may be deadlocked");
2692 }
2693 }
2694 }
2695 Err(mpsc::error::TrySendError::Closed(_)) => {
2696 log::debug!("Event channel closed - consumer disconnected");
2697 }
2698 }
2699}
2700
2701fn build_assistant_message(response: &ChatResponse) -> Message {
2702 let mut blocks = Vec::new();
2703
2704 for block in &response.content {
2705 match block {
2706 ContentBlock::Text { text } => {
2707 blocks.push(ContentBlock::Text { text: text.clone() });
2708 }
2709 ContentBlock::Thinking { .. } | ContentBlock::ToolResult { .. } => {
2710 }
2713 ContentBlock::ToolUse {
2714 id,
2715 name,
2716 input,
2717 thought_signature,
2718 } => {
2719 blocks.push(ContentBlock::ToolUse {
2720 id: id.clone(),
2721 name: name.clone(),
2722 input: input.clone(),
2723 thought_signature: thought_signature.clone(),
2724 });
2725 }
2726 }
2727 }
2728
2729 Message {
2730 role: Role::Assistant,
2731 content: Content::Blocks(blocks),
2732 }
2733}
2734
2735#[cfg(test)]
2736mod tests {
2737 use super::*;
2738 use crate::hooks::AllowAllHooks;
2739 use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
2740 use crate::stores::InMemoryStore;
2741 use crate::tools::{Tool, ToolContext, ToolRegistry};
2742 use crate::types::{AgentConfig, AgentInput, ToolResult, ToolTier};
2743 use anyhow::Result;
2744 use async_trait::async_trait;
2745 use serde_json::json;
2746 use std::sync::RwLock;
2747 use std::sync::atomic::{AtomicUsize, Ordering};
2748
2749 struct MockProvider {
2754 responses: RwLock<Vec<ChatOutcome>>,
2755 call_count: AtomicUsize,
2756 }
2757
2758 impl MockProvider {
2759 fn new(responses: Vec<ChatOutcome>) -> Self {
2760 Self {
2761 responses: RwLock::new(responses),
2762 call_count: AtomicUsize::new(0),
2763 }
2764 }
2765
2766 fn text_response(text: &str) -> ChatOutcome {
2767 ChatOutcome::Success(ChatResponse {
2768 id: "msg_1".to_string(),
2769 content: vec![ContentBlock::Text {
2770 text: text.to_string(),
2771 }],
2772 model: "mock-model".to_string(),
2773 stop_reason: Some(StopReason::EndTurn),
2774 usage: Usage {
2775 input_tokens: 10,
2776 output_tokens: 20,
2777 },
2778 })
2779 }
2780
2781 fn tool_use_response(
2782 tool_id: &str,
2783 tool_name: &str,
2784 input: serde_json::Value,
2785 ) -> ChatOutcome {
2786 ChatOutcome::Success(ChatResponse {
2787 id: "msg_1".to_string(),
2788 content: vec![ContentBlock::ToolUse {
2789 id: tool_id.to_string(),
2790 name: tool_name.to_string(),
2791 input,
2792 thought_signature: None,
2793 }],
2794 model: "mock-model".to_string(),
2795 stop_reason: Some(StopReason::ToolUse),
2796 usage: Usage {
2797 input_tokens: 10,
2798 output_tokens: 20,
2799 },
2800 })
2801 }
2802 }
2803
2804 #[async_trait]
2805 impl LlmProvider for MockProvider {
2806 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
2807 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
2808 let responses = self.responses.read().unwrap();
2809 if idx < responses.len() {
2810 Ok(responses[idx].clone())
2811 } else {
2812 Ok(Self::text_response("Done"))
2814 }
2815 }
2816
2817 fn model(&self) -> &'static str {
2818 "mock-model"
2819 }
2820
2821 fn provider(&self) -> &'static str {
2822 "mock"
2823 }
2824 }
2825
2826 impl Clone for ChatOutcome {
2828 fn clone(&self) -> Self {
2829 match self {
2830 Self::Success(r) => Self::Success(r.clone()),
2831 Self::RateLimited => Self::RateLimited,
2832 Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
2833 Self::ServerError(s) => Self::ServerError(s.clone()),
2834 }
2835 }
2836 }
2837
2838 struct EchoTool;
2843
2844 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
2846 #[serde(rename_all = "snake_case")]
2847 enum TestToolName {
2848 Echo,
2849 }
2850
2851 impl crate::tools::ToolName for TestToolName {}
2852
2853 impl Tool<()> for EchoTool {
2854 type Name = TestToolName;
2855
2856 fn name(&self) -> TestToolName {
2857 TestToolName::Echo
2858 }
2859
2860 fn display_name(&self) -> &'static str {
2861 "Echo"
2862 }
2863
2864 fn description(&self) -> &'static str {
2865 "Echo the input message"
2866 }
2867
2868 fn input_schema(&self) -> serde_json::Value {
2869 json!({
2870 "type": "object",
2871 "properties": {
2872 "message": { "type": "string" }
2873 },
2874 "required": ["message"]
2875 })
2876 }
2877
2878 fn tier(&self) -> ToolTier {
2879 ToolTier::Observe
2880 }
2881
2882 async fn execute(
2883 &self,
2884 _ctx: &ToolContext<()>,
2885 input: serde_json::Value,
2886 ) -> Result<ToolResult> {
2887 let message = input
2888 .get("message")
2889 .and_then(|v| v.as_str())
2890 .unwrap_or("no message");
2891 Ok(ToolResult::success(format!("Echo: {message}")))
2892 }
2893 }
2894
2895 #[test]
2900 fn test_builder_creates_agent_loop() {
2901 let provider = MockProvider::new(vec![]);
2902 let agent = builder::<()>().provider(provider).build();
2903
2904 assert_eq!(agent.config.max_turns, 10);
2905 assert_eq!(agent.config.max_tokens, 4096);
2906 }
2907
2908 #[test]
2909 fn test_builder_with_custom_config() {
2910 let provider = MockProvider::new(vec![]);
2911 let config = AgentConfig {
2912 max_turns: 5,
2913 max_tokens: 2048,
2914 system_prompt: "Custom prompt".to_string(),
2915 model: "custom-model".to_string(),
2916 ..Default::default()
2917 };
2918
2919 let agent = builder::<()>().provider(provider).config(config).build();
2920
2921 assert_eq!(agent.config.max_turns, 5);
2922 assert_eq!(agent.config.max_tokens, 2048);
2923 assert_eq!(agent.config.system_prompt, "Custom prompt");
2924 }
2925
2926 #[test]
2927 fn test_builder_with_tools() {
2928 let provider = MockProvider::new(vec![]);
2929 let mut tools = ToolRegistry::new();
2930 tools.register(EchoTool);
2931
2932 let agent = builder::<()>().provider(provider).tools(tools).build();
2933
2934 assert_eq!(agent.tools.len(), 1);
2935 }
2936
2937 #[test]
2938 fn test_builder_with_custom_stores() {
2939 let provider = MockProvider::new(vec![]);
2940 let message_store = InMemoryStore::new();
2941 let state_store = InMemoryStore::new();
2942
2943 let agent = builder::<()>()
2944 .provider(provider)
2945 .hooks(AllowAllHooks)
2946 .message_store(message_store)
2947 .state_store(state_store)
2948 .build_with_stores();
2949
2950 assert_eq!(agent.config.max_turns, 10);
2952 }
2953
2954 #[tokio::test]
2959 async fn test_simple_text_response() -> anyhow::Result<()> {
2960 let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);
2961
2962 let agent = builder::<()>().provider(provider).build();
2963
2964 let thread_id = ThreadId::new();
2965 let tool_ctx = ToolContext::new(());
2966 let (mut rx, _final_state) =
2967 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
2968
2969 let mut events = Vec::new();
2970 while let Some(event) = rx.recv().await {
2971 events.push(event);
2972 }
2973
2974 assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
2976 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
2977
2978 Ok(())
2979 }
2980
2981 #[tokio::test]
2982 async fn test_tool_execution() -> anyhow::Result<()> {
2983 let provider = MockProvider::new(vec![
2984 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
2986 MockProvider::text_response("Tool executed successfully"),
2988 ]);
2989
2990 let mut tools = ToolRegistry::new();
2991 tools.register(EchoTool);
2992
2993 let agent = builder::<()>().provider(provider).tools(tools).build();
2994
2995 let thread_id = ThreadId::new();
2996 let tool_ctx = ToolContext::new(());
2997 let (mut rx, _final_state) = agent.run(
2998 thread_id,
2999 AgentInput::Text("Run echo".to_string()),
3000 tool_ctx,
3001 );
3002
3003 let mut events = Vec::new();
3004 while let Some(event) = rx.recv().await {
3005 events.push(event);
3006 }
3007
3008 assert!(
3010 events
3011 .iter()
3012 .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
3013 );
3014 assert!(
3015 events
3016 .iter()
3017 .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
3018 );
3019
3020 Ok(())
3021 }
3022
3023 #[tokio::test]
3024 async fn test_max_turns_limit() -> anyhow::Result<()> {
3025 let provider = MockProvider::new(vec![
3027 MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
3028 MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
3029 MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
3030 MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
3031 ]);
3032
3033 let mut tools = ToolRegistry::new();
3034 tools.register(EchoTool);
3035
3036 let config = AgentConfig {
3037 max_turns: 2,
3038 ..Default::default()
3039 };
3040
3041 let agent = builder::<()>()
3042 .provider(provider)
3043 .tools(tools)
3044 .config(config)
3045 .build();
3046
3047 let thread_id = ThreadId::new();
3048 let tool_ctx = ToolContext::new(());
3049 let (mut rx, _final_state) =
3050 agent.run(thread_id, AgentInput::Text("Loop".to_string()), tool_ctx);
3051
3052 let mut events = Vec::new();
3053 while let Some(event) = rx.recv().await {
3054 events.push(event);
3055 }
3056
3057 assert!(events.iter().any(|e| {
3059 matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
3060 }));
3061
3062 Ok(())
3063 }
3064
3065 #[tokio::test]
3066 async fn test_unknown_tool_handling() -> anyhow::Result<()> {
3067 let provider = MockProvider::new(vec![
3068 MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
3070 MockProvider::text_response("I couldn't find that tool."),
3072 ]);
3073
3074 let tools = ToolRegistry::new();
3076
3077 let agent = builder::<()>().provider(provider).tools(tools).build();
3078
3079 let thread_id = ThreadId::new();
3080 let tool_ctx = ToolContext::new(());
3081 let (mut rx, _final_state) = agent.run(
3082 thread_id,
3083 AgentInput::Text("Call unknown".to_string()),
3084 tool_ctx,
3085 );
3086
3087 let mut events = Vec::new();
3088 while let Some(event) = rx.recv().await {
3089 events.push(event);
3090 }
3091
3092 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3095
3096 assert!(events.iter().any(|e| {
3098 matches!(e, AgentEvent::Text { message_id, text } if text.contains("couldn't find"))
3099 }));
3100
3101 Ok(())
3102 }
3103
3104 #[tokio::test]
3105 async fn test_rate_limit_handling() -> anyhow::Result<()> {
3106 let provider = MockProvider::new(vec![
3108 ChatOutcome::RateLimited,
3109 ChatOutcome::RateLimited,
3110 ChatOutcome::RateLimited,
3111 ChatOutcome::RateLimited,
3112 ChatOutcome::RateLimited,
3113 ChatOutcome::RateLimited, ]);
3115
3116 let config = AgentConfig {
3118 retry: crate::types::RetryConfig::fast(),
3119 ..Default::default()
3120 };
3121
3122 let agent = builder::<()>().provider(provider).config(config).build();
3123
3124 let thread_id = ThreadId::new();
3125 let tool_ctx = ToolContext::new(());
3126 let (mut rx, _final_state) =
3127 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3128
3129 let mut events = Vec::new();
3130 while let Some(event) = rx.recv().await {
3131 events.push(event);
3132 }
3133
3134 assert!(events.iter().any(|e| {
3136 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
3137 }));
3138
3139 Ok(())
3140 }
3141
3142 #[tokio::test]
3143 async fn test_rate_limit_recovery() -> anyhow::Result<()> {
3144 let provider = MockProvider::new(vec![
3146 ChatOutcome::RateLimited,
3147 MockProvider::text_response("Recovered after rate limit"),
3148 ]);
3149
3150 let config = AgentConfig {
3152 retry: crate::types::RetryConfig::fast(),
3153 ..Default::default()
3154 };
3155
3156 let agent = builder::<()>().provider(provider).config(config).build();
3157
3158 let thread_id = ThreadId::new();
3159 let tool_ctx = ToolContext::new(());
3160 let (mut rx, _final_state) =
3161 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3162
3163 let mut events = Vec::new();
3164 while let Some(event) = rx.recv().await {
3165 events.push(event);
3166 }
3167
3168 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3170
3171 Ok(())
3172 }
3173
3174 #[tokio::test]
3175 async fn test_server_error_handling() -> anyhow::Result<()> {
3176 let provider = MockProvider::new(vec![
3178 ChatOutcome::ServerError("Internal error".to_string()),
3179 ChatOutcome::ServerError("Internal error".to_string()),
3180 ChatOutcome::ServerError("Internal error".to_string()),
3181 ChatOutcome::ServerError("Internal error".to_string()),
3182 ChatOutcome::ServerError("Internal error".to_string()),
3183 ChatOutcome::ServerError("Internal error".to_string()), ]);
3185
3186 let config = AgentConfig {
3188 retry: crate::types::RetryConfig::fast(),
3189 ..Default::default()
3190 };
3191
3192 let agent = builder::<()>().provider(provider).config(config).build();
3193
3194 let thread_id = ThreadId::new();
3195 let tool_ctx = ToolContext::new(());
3196 let (mut rx, _final_state) =
3197 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3198
3199 let mut events = Vec::new();
3200 while let Some(event) = rx.recv().await {
3201 events.push(event);
3202 }
3203
3204 assert!(events.iter().any(|e| {
3206 matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
3207 }));
3208
3209 Ok(())
3210 }
3211
3212 #[tokio::test]
3213 async fn test_server_error_recovery() -> anyhow::Result<()> {
3214 let provider = MockProvider::new(vec![
3216 ChatOutcome::ServerError("Temporary error".to_string()),
3217 MockProvider::text_response("Recovered after server error"),
3218 ]);
3219
3220 let config = AgentConfig {
3222 retry: crate::types::RetryConfig::fast(),
3223 ..Default::default()
3224 };
3225
3226 let agent = builder::<()>().provider(provider).config(config).build();
3227
3228 let thread_id = ThreadId::new();
3229 let tool_ctx = ToolContext::new(());
3230 let (mut rx, _final_state) =
3231 agent.run(thread_id, AgentInput::Text("Hi".to_string()), tool_ctx);
3232
3233 let mut events = Vec::new();
3234 while let Some(event) = rx.recv().await {
3235 events.push(event);
3236 }
3237
3238 assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));
3240
3241 Ok(())
3242 }
3243
#[test]
fn test_extract_content_text_only() {
    // A response made of a single text block yields only text.
    let response = ChatResponse {
        id: "msg_1".to_string(),
        model: "test".to_string(),
        content: vec![ContentBlock::Text {
            text: "Hello".to_string(),
        }],
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert_eq!(text.as_deref(), Some("Hello"));
    assert!(thinking.is_none());
    assert!(tool_uses.is_empty());
}
3268
#[test]
fn test_extract_content_tool_use() {
    // A response made of a single tool-use block yields exactly one tool use
    // and no text or thinking.
    let tool_block = ContentBlock::ToolUse {
        id: "tool_1".to_string(),
        name: "test_tool".to_string(),
        input: json!({"key": "value"}),
        thought_signature: None,
    };
    let response = ChatResponse {
        id: "msg_1".to_string(),
        model: "test".to_string(),
        content: vec![tool_block],
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert!(thinking.is_none());
    assert!(text.is_none());
    assert_eq!(tool_uses.len(), 1);
    // Tuple field 1 holds the tool name.
    assert_eq!(tool_uses[0].1, "test_tool");
}
3293
#[test]
fn test_extract_content_mixed() {
    // Text followed by a tool call: both parts should be extracted.
    let blocks = vec![
        ContentBlock::Text {
            text: "Let me help".to_string(),
        },
        ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "helper".to_string(),
            input: json!({}),
            thought_signature: None,
        },
    ];
    let response = ChatResponse {
        id: "msg_1".to_string(),
        model: "test".to_string(),
        content: blocks,
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let (thinking, text, tool_uses) = extract_content(&response);

    assert_eq!(text.as_deref(), Some("Let me help"));
    assert_eq!(tool_uses.len(), 1);
    assert!(thinking.is_none());
}
3322
#[test]
fn test_millis_to_u64() {
    // Values that fit in u64 pass through unchanged; larger values clamp to u64::MAX.
    let max = u128::from(u64::MAX);
    let cases = [(0, 0), (1000, 1000), (max, u64::MAX), (max + 1, u64::MAX)];

    for (input, expected) in cases {
        assert_eq!(millis_to_u64(input), expected);
    }
}
3330
#[test]
fn test_build_assistant_message() {
    // A text block plus a tool call should become an assistant message whose
    // content is a two-element block list.
    let blocks = vec![
        ContentBlock::Text {
            text: "Response text".to_string(),
        },
        ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "echo".to_string(),
            input: json!({"message": "test"}),
            thought_signature: None,
        },
    ];
    let response = ChatResponse {
        id: "msg_1".to_string(),
        model: "test".to_string(),
        content: blocks,
        stop_reason: None,
        usage: Usage {
            input_tokens: 0,
            output_tokens: 0,
        },
    };

    let msg = build_assistant_message(&response);
    assert_eq!(msg.role, Role::Assistant);

    match msg.content {
        Content::Blocks(blocks) => assert_eq!(blocks.len(), 2),
        _ => panic!("Expected Content::Blocks"),
    }
}
3363}