1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult, PostToolUseEvent,
16 PreToolUseEvent, TokenUsageInfo, ToolCallInfo, ToolResultData,
17};
18use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
19use crate::permissions::{PermissionChecker, PermissionDecision};
20use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
21use crate::queue::SessionCommand;
22use crate::session_lane_queue::SessionLaneQueue;
23use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
24use anyhow::{Context, Result};
25use async_trait::async_trait;
26use futures::future::join_all;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use std::sync::Arc;
30use std::time::Duration;
31use tokio::sync::{mpsc, RwLock};
32
/// Default upper bound on LLM turns per run; used as `AgentConfig::max_tool_rounds`
/// by `AgentConfig::default()`.
const MAX_TOOL_ROUNDS: usize = 50;
35
36#[derive(Clone)]
38pub struct AgentConfig {
39 pub system_prompt: Option<String>,
40 pub tools: Vec<ToolDefinition>,
41 pub max_tool_rounds: usize,
42 pub permission_checker: Option<Arc<dyn PermissionChecker>>,
44 pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
46 pub context_providers: Vec<Arc<dyn ContextProvider>>,
48 pub planning_enabled: bool,
50 pub goal_tracking: bool,
52 pub hook_engine: Option<Arc<dyn HookExecutor>>,
54 pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
56 pub max_parse_retries: u32,
62 pub tool_timeout_ms: Option<u64>,
68 pub circuit_breaker_threshold: u32,
75}
76
77impl std::fmt::Debug for AgentConfig {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("AgentConfig")
80 .field("system_prompt", &self.system_prompt)
81 .field("tools", &self.tools)
82 .field("max_tool_rounds", &self.max_tool_rounds)
83 .field("permission_checker", &self.permission_checker.is_some())
84 .field("confirmation_manager", &self.confirmation_manager.is_some())
85 .field("context_providers", &self.context_providers.len())
86 .field("planning_enabled", &self.planning_enabled)
87 .field("goal_tracking", &self.goal_tracking)
88 .field("hook_engine", &self.hook_engine.is_some())
89 .field(
90 "skill_registry",
91 &self.skill_registry.as_ref().map(|r| r.len()),
92 )
93 .field("max_parse_retries", &self.max_parse_retries)
94 .field("tool_timeout_ms", &self.tool_timeout_ms)
95 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
96 .finish()
97 }
98}
99
100impl Default for AgentConfig {
101 fn default() -> Self {
102 Self {
103 system_prompt: None,
104 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
106 permission_checker: None,
107 confirmation_manager: None,
108 context_providers: Vec::new(),
109 planning_enabled: false,
110 goal_tracking: false,
111 hook_engine: None,
112 skill_registry: None,
113 max_parse_retries: 2,
114 tool_timeout_ms: None,
115 circuit_breaker_threshold: 3,
116 }
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
126#[serde(tag = "type")]
127#[non_exhaustive]
128pub enum AgentEvent {
129 #[serde(rename = "agent_start")]
131 Start { prompt: String },
132
133 #[serde(rename = "turn_start")]
135 TurnStart { turn: usize },
136
137 #[serde(rename = "text_delta")]
139 TextDelta { text: String },
140
141 #[serde(rename = "tool_start")]
143 ToolStart { id: String, name: String },
144
145 #[serde(rename = "tool_end")]
147 ToolEnd {
148 id: String,
149 name: String,
150 output: String,
151 exit_code: i32,
152 },
153
154 #[serde(rename = "tool_output_delta")]
156 ToolOutputDelta {
157 id: String,
158 name: String,
159 delta: String,
160 },
161
162 #[serde(rename = "turn_end")]
164 TurnEnd { turn: usize, usage: TokenUsage },
165
166 #[serde(rename = "agent_end")]
168 End { text: String, usage: TokenUsage },
169
170 #[serde(rename = "error")]
172 Error { message: String },
173
174 #[serde(rename = "confirmation_required")]
176 ConfirmationRequired {
177 tool_id: String,
178 tool_name: String,
179 args: serde_json::Value,
180 timeout_ms: u64,
181 },
182
183 #[serde(rename = "confirmation_received")]
185 ConfirmationReceived {
186 tool_id: String,
187 approved: bool,
188 reason: Option<String>,
189 },
190
191 #[serde(rename = "confirmation_timeout")]
193 ConfirmationTimeout {
194 tool_id: String,
195 action_taken: String, },
197
198 #[serde(rename = "external_task_pending")]
200 ExternalTaskPending {
201 task_id: String,
202 session_id: String,
203 lane: crate::hitl::SessionLane,
204 command_type: String,
205 payload: serde_json::Value,
206 timeout_ms: u64,
207 },
208
209 #[serde(rename = "external_task_completed")]
211 ExternalTaskCompleted {
212 task_id: String,
213 session_id: String,
214 success: bool,
215 },
216
217 #[serde(rename = "permission_denied")]
219 PermissionDenied {
220 tool_id: String,
221 tool_name: String,
222 args: serde_json::Value,
223 reason: String,
224 },
225
226 #[serde(rename = "context_resolving")]
228 ContextResolving { providers: Vec<String> },
229
230 #[serde(rename = "context_resolved")]
232 ContextResolved {
233 total_items: usize,
234 total_tokens: usize,
235 },
236
237 #[serde(rename = "command_dead_lettered")]
242 CommandDeadLettered {
243 command_id: String,
244 command_type: String,
245 lane: String,
246 error: String,
247 attempts: u32,
248 },
249
250 #[serde(rename = "command_retry")]
252 CommandRetry {
253 command_id: String,
254 command_type: String,
255 lane: String,
256 attempt: u32,
257 delay_ms: u64,
258 },
259
260 #[serde(rename = "queue_alert")]
262 QueueAlert {
263 level: String,
264 alert_type: String,
265 message: String,
266 },
267
268 #[serde(rename = "task_updated")]
273 TaskUpdated {
274 session_id: String,
275 tasks: Vec<crate::planning::Task>,
276 },
277
278 #[serde(rename = "memory_stored")]
283 MemoryStored {
284 memory_id: String,
285 memory_type: String,
286 importance: f32,
287 tags: Vec<String>,
288 },
289
290 #[serde(rename = "memory_recalled")]
292 MemoryRecalled {
293 memory_id: String,
294 content: String,
295 relevance: f32,
296 },
297
298 #[serde(rename = "memories_searched")]
300 MemoriesSearched {
301 query: Option<String>,
302 tags: Vec<String>,
303 result_count: usize,
304 },
305
306 #[serde(rename = "memory_cleared")]
308 MemoryCleared {
309 tier: String, count: u64,
311 },
312
313 #[serde(rename = "subagent_start")]
318 SubagentStart {
319 task_id: String,
321 session_id: String,
323 parent_session_id: String,
325 agent: String,
327 description: String,
329 },
330
331 #[serde(rename = "subagent_progress")]
333 SubagentProgress {
334 task_id: String,
336 session_id: String,
338 status: String,
340 metadata: serde_json::Value,
342 },
343
344 #[serde(rename = "subagent_end")]
346 SubagentEnd {
347 task_id: String,
349 session_id: String,
351 agent: String,
353 output: String,
355 success: bool,
357 },
358
359 #[serde(rename = "planning_start")]
364 PlanningStart { prompt: String },
365
366 #[serde(rename = "planning_end")]
368 PlanningEnd {
369 plan: ExecutionPlan,
370 estimated_steps: usize,
371 },
372
373 #[serde(rename = "step_start")]
375 StepStart {
376 step_id: String,
377 description: String,
378 step_number: usize,
379 total_steps: usize,
380 },
381
382 #[serde(rename = "step_end")]
384 StepEnd {
385 step_id: String,
386 status: TaskStatus,
387 step_number: usize,
388 total_steps: usize,
389 },
390
391 #[serde(rename = "goal_extracted")]
393 GoalExtracted { goal: AgentGoal },
394
395 #[serde(rename = "goal_progress")]
397 GoalProgress {
398 goal: String,
399 progress: f32,
400 completed_steps: usize,
401 total_steps: usize,
402 },
403
404 #[serde(rename = "goal_achieved")]
406 GoalAchieved {
407 goal: String,
408 total_steps: usize,
409 duration_ms: i64,
410 },
411
412 #[serde(rename = "context_compacted")]
417 ContextCompacted {
418 session_id: String,
419 before_messages: usize,
420 after_messages: usize,
421 percent_before: f32,
422 },
423
424 #[serde(rename = "persistence_failed")]
429 PersistenceFailed {
430 session_id: String,
431 operation: String,
432 error: String,
433 },
434}
435
436#[derive(Debug, Clone)]
438pub struct AgentResult {
439 pub text: String,
440 pub messages: Vec<Message>,
441 pub usage: TokenUsage,
442 pub tool_calls_count: usize,
443}
444
445pub struct ToolCommand {
453 tool_executor: Arc<ToolExecutor>,
454 tool_name: String,
455 tool_args: Value,
456 tool_context: ToolContext,
457 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
458}
459
460impl ToolCommand {
461 pub fn new(
463 tool_executor: Arc<ToolExecutor>,
464 tool_name: String,
465 tool_args: Value,
466 tool_context: ToolContext,
467 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
468 ) -> Self {
469 Self {
470 tool_executor,
471 tool_name,
472 tool_args,
473 tool_context,
474 skill_registry,
475 }
476 }
477}
478
479#[async_trait]
480impl SessionCommand for ToolCommand {
481 async fn execute(&self) -> Result<Value> {
482 if let Some(registry) = &self.skill_registry {
484 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
485
486 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
488
489 if has_restrictions {
490 let mut allowed = false;
491
492 for skill in &instruction_skills {
493 if skill.is_tool_allowed(&self.tool_name) {
494 allowed = true;
495 break;
496 }
497 }
498
499 if !allowed {
500 return Err(anyhow::anyhow!(
501 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
502 self.tool_name
503 ));
504 }
505 }
506 }
507
508 let result = self
510 .tool_executor
511 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
512 .await?;
513 Ok(serde_json::json!({
514 "output": result.output,
515 "exit_code": result.exit_code,
516 "metadata": result.metadata,
517 }))
518 }
519
520 fn command_type(&self) -> &str {
521 &self.tool_name
522 }
523
524 fn payload(&self) -> Value {
525 self.tool_args.clone()
526 }
527}
528
529#[derive(Clone)]
535pub struct AgentLoop {
536 llm_client: Arc<dyn LlmClient>,
537 tool_executor: Arc<ToolExecutor>,
538 tool_context: ToolContext,
539 config: AgentConfig,
540 tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
542 command_queue: Option<Arc<SessionLaneQueue>>,
544}
545
546impl AgentLoop {
547 pub fn new(
548 llm_client: Arc<dyn LlmClient>,
549 tool_executor: Arc<ToolExecutor>,
550 tool_context: ToolContext,
551 config: AgentConfig,
552 ) -> Self {
553 Self {
554 llm_client,
555 tool_executor,
556 tool_context,
557 config,
558 tool_metrics: None,
559 command_queue: None,
560 }
561 }
562
563 pub fn with_tool_metrics(
565 mut self,
566 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
567 ) -> Self {
568 self.tool_metrics = Some(metrics);
569 self
570 }
571
572 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
577 self.command_queue = Some(queue);
578 self
579 }
580
581 async fn execute_tool_timed(
587 &self,
588 name: &str,
589 args: &serde_json::Value,
590 ctx: &ToolContext,
591 ) -> anyhow::Result<crate::tools::ToolResult> {
592 let fut = self.tool_executor.execute_with_context(name, args, ctx);
593 if let Some(timeout_ms) = self.config.tool_timeout_ms {
594 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
595 Ok(result) => result,
596 Err(_) => Err(anyhow::anyhow!(
597 "Tool '{}' timed out after {}ms",
598 name,
599 timeout_ms
600 )),
601 }
602 } else {
603 fut.await
604 }
605 }
606
607 async fn call_llm(
615 &self,
616 messages: &[Message],
617 system: Option<&str>,
618 event_tx: &Option<mpsc::Sender<AgentEvent>>,
619 ) -> anyhow::Result<LlmResponse> {
620 if event_tx.is_some() {
621 let mut stream_rx = self
622 .llm_client
623 .complete_streaming(messages, system, &self.config.tools)
624 .await
625 .context("LLM streaming call failed")?;
626
627 let mut final_response: Option<LlmResponse> = None;
628 while let Some(event) = stream_rx.recv().await {
629 match event {
630 crate::llm::StreamEvent::TextDelta(text) => {
631 if let Some(tx) = event_tx {
632 tx.send(AgentEvent::TextDelta { text }).await.ok();
633 }
634 }
635 crate::llm::StreamEvent::ToolUseStart { id, name } => {
636 if let Some(tx) = event_tx {
637 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
638 }
639 }
640 crate::llm::StreamEvent::ToolUseInputDelta(_) => {}
641 crate::llm::StreamEvent::Done(resp) => {
642 final_response = Some(resp);
643 }
644 }
645 }
646 final_response.context("Stream ended without final response")
647 } else {
648 self.llm_client
649 .complete(messages, system, &self.config.tools)
650 .await
651 .context("LLM call failed")
652 }
653 }
654
655 fn streaming_tool_context(
664 &self,
665 event_tx: &Option<mpsc::Sender<AgentEvent>>,
666 tool_id: &str,
667 tool_name: &str,
668 ) -> ToolContext {
669 let mut ctx = self.tool_context.clone();
670 if let Some(agent_tx) = event_tx {
671 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
672 ctx.event_tx = Some(tool_tx);
673
674 let agent_tx = agent_tx.clone();
675 let tool_id = tool_id.to_string();
676 let tool_name = tool_name.to_string();
677 tokio::spawn(async move {
678 while let Some(event) = tool_rx.recv().await {
679 match event {
680 ToolStreamEvent::OutputDelta(delta) => {
681 agent_tx
682 .send(AgentEvent::ToolOutputDelta {
683 id: tool_id.clone(),
684 name: tool_name.clone(),
685 delta,
686 })
687 .await
688 .ok();
689 }
690 }
691 }
692 });
693 }
694 ctx
695 }
696
697 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
701 if self.config.context_providers.is_empty() {
702 return Vec::new();
703 }
704
705 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
706
707 let futures = self
708 .config
709 .context_providers
710 .iter()
711 .map(|p| p.query(&query));
712 let outcomes = join_all(futures).await;
713
714 outcomes
715 .into_iter()
716 .enumerate()
717 .filter_map(|(i, r)| match r {
718 Ok(result) if !result.is_empty() => Some(result),
719 Ok(_) => None,
720 Err(e) => {
721 tracing::warn!(
722 "Context provider '{}' failed: {}",
723 self.config.context_providers[i].name(),
724 e
725 );
726 None
727 }
728 })
729 .collect()
730 }
731
732 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
734 if context_results.is_empty() {
735 return self.config.system_prompt.clone();
736 }
737
738 let context_xml: String = context_results
740 .iter()
741 .map(|r| r.to_xml())
742 .collect::<Vec<_>>()
743 .join("\n\n");
744
745 match &self.config.system_prompt {
747 Some(system) => Some(format!("{}\n\n{}", system, context_xml)),
748 None => Some(context_xml),
749 }
750 }
751
752 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
754 let futures = self
755 .config
756 .context_providers
757 .iter()
758 .map(|p| p.on_turn_complete(session_id, prompt, response));
759 let outcomes = join_all(futures).await;
760
761 for (i, result) in outcomes.into_iter().enumerate() {
762 if let Err(e) = result {
763 tracing::warn!(
764 "Context provider '{}' on_turn_complete failed: {}",
765 self.config.context_providers[i].name(),
766 e
767 );
768 }
769 }
770 }
771
772 async fn fire_pre_tool_use(
775 &self,
776 session_id: &str,
777 tool_name: &str,
778 args: &serde_json::Value,
779 ) -> Option<HookResult> {
780 if let Some(he) = &self.config.hook_engine {
781 let event = HookEvent::PreToolUse(PreToolUseEvent {
782 session_id: session_id.to_string(),
783 tool: tool_name.to_string(),
784 args: args.clone(),
785 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
786 recent_tools: Vec::new(),
787 });
788 let result = he.fire(&event).await;
789 if result.is_block() {
790 return Some(result);
791 }
792 }
793 None
794 }
795
796 async fn fire_post_tool_use(
798 &self,
799 session_id: &str,
800 tool_name: &str,
801 args: &serde_json::Value,
802 output: &str,
803 success: bool,
804 duration_ms: u64,
805 ) {
806 if let Some(he) = &self.config.hook_engine {
807 let event = HookEvent::PostToolUse(PostToolUseEvent {
808 session_id: session_id.to_string(),
809 tool: tool_name.to_string(),
810 args: args.clone(),
811 result: ToolResultData {
812 success,
813 output: output.to_string(),
814 exit_code: if success { Some(0) } else { Some(1) },
815 duration_ms,
816 },
817 });
818 let _ = he.fire(&event).await;
819 }
820 }
821
822 async fn fire_generate_start(
824 &self,
825 session_id: &str,
826 prompt: &str,
827 system_prompt: &Option<String>,
828 ) {
829 if let Some(he) = &self.config.hook_engine {
830 let event = HookEvent::GenerateStart(GenerateStartEvent {
831 session_id: session_id.to_string(),
832 prompt: prompt.to_string(),
833 system_prompt: system_prompt.clone(),
834 model_provider: String::new(),
835 model_name: String::new(),
836 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
837 });
838 let _ = he.fire(&event).await;
839 }
840 }
841
842 async fn fire_generate_end(
844 &self,
845 session_id: &str,
846 prompt: &str,
847 response: &LlmResponse,
848 duration_ms: u64,
849 ) {
850 if let Some(he) = &self.config.hook_engine {
851 let tool_calls: Vec<ToolCallInfo> = response
852 .tool_calls()
853 .iter()
854 .map(|tc| ToolCallInfo {
855 name: tc.name.clone(),
856 args: tc.args.clone(),
857 })
858 .collect();
859
860 let event = HookEvent::GenerateEnd(GenerateEndEvent {
861 session_id: session_id.to_string(),
862 prompt: prompt.to_string(),
863 response_text: response.text().to_string(),
864 tool_calls,
865 usage: TokenUsageInfo {
866 prompt_tokens: response.usage.prompt_tokens as i32,
867 completion_tokens: response.usage.completion_tokens as i32,
868 total_tokens: response.usage.total_tokens as i32,
869 },
870 duration_ms,
871 });
872 let _ = he.fire(&event).await;
873 }
874 }
875
876 pub async fn execute(
882 &self,
883 history: &[Message],
884 prompt: &str,
885 event_tx: Option<mpsc::Sender<AgentEvent>>,
886 ) -> Result<AgentResult> {
887 self.execute_with_session(history, prompt, None, event_tx)
888 .await
889 }
890
891 pub async fn execute_from_messages(
897 &self,
898 messages: Vec<Message>,
899 session_id: Option<&str>,
900 event_tx: Option<mpsc::Sender<AgentEvent>>,
901 ) -> Result<AgentResult> {
902 tracing::info!(
903 a3s.session.id = session_id.unwrap_or("none"),
904 a3s.agent.max_turns = self.config.max_tool_rounds,
905 "a3s.agent.execute_from_messages started"
906 );
907
908 let result = self.execute_loop(&messages, "", session_id, event_tx).await;
910
911 match &result {
912 Ok(r) => tracing::info!(
913 a3s.agent.tool_calls_count = r.tool_calls_count,
914 a3s.llm.total_tokens = r.usage.total_tokens,
915 "a3s.agent.execute_from_messages completed"
916 ),
917 Err(e) => tracing::warn!(
918 error = %e,
919 "a3s.agent.execute_from_messages failed"
920 ),
921 }
922
923 result
924 }
925
926 pub async fn execute_with_session(
931 &self,
932 history: &[Message],
933 prompt: &str,
934 session_id: Option<&str>,
935 event_tx: Option<mpsc::Sender<AgentEvent>>,
936 ) -> Result<AgentResult> {
937 tracing::info!(
938 a3s.session.id = session_id.unwrap_or("none"),
939 a3s.agent.max_turns = self.config.max_tool_rounds,
940 "a3s.agent.execute started"
941 );
942
943 let result = if self.config.planning_enabled {
945 self.execute_with_planning(history, prompt, event_tx).await
946 } else {
947 self.execute_loop(history, prompt, session_id, event_tx)
948 .await
949 };
950
951 match &result {
952 Ok(r) => tracing::info!(
953 a3s.agent.tool_calls_count = r.tool_calls_count,
954 a3s.llm.total_tokens = r.usage.total_tokens,
955 "a3s.agent.execute completed"
956 ),
957 Err(e) => tracing::warn!(
958 error = %e,
959 "a3s.agent.execute failed"
960 ),
961 }
962
963 result
964 }
965
966 async fn execute_loop(
972 &self,
973 history: &[Message],
974 prompt: &str,
975 session_id: Option<&str>,
976 event_tx: Option<mpsc::Sender<AgentEvent>>,
977 ) -> Result<AgentResult> {
978 let mut messages = history.to_vec();
979 let mut total_usage = TokenUsage::default();
980 let mut tool_calls_count = 0;
981 let mut turn = 0;
982 let mut parse_error_count: u32 = 0;
984
985 if let Some(tx) = &event_tx {
987 tx.send(AgentEvent::Start {
988 prompt: prompt.to_string(),
989 })
990 .await
991 .ok();
992 }
993
994 let augmented_system = if !self.config.context_providers.is_empty() {
996 if let Some(tx) = &event_tx {
998 let provider_names: Vec<String> = self
999 .config
1000 .context_providers
1001 .iter()
1002 .map(|p| p.name().to_string())
1003 .collect();
1004 tx.send(AgentEvent::ContextResolving {
1005 providers: provider_names,
1006 })
1007 .await
1008 .ok();
1009 }
1010
1011 tracing::info!(
1012 a3s.context.providers = self.config.context_providers.len() as i64,
1013 "Context resolution started"
1014 );
1015 let context_results = self.resolve_context(prompt, session_id).await;
1016
1017 if let Some(tx) = &event_tx {
1019 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1020 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1021
1022 tracing::info!(
1023 context_items = total_items,
1024 context_tokens = total_tokens,
1025 "Context resolution completed"
1026 );
1027
1028 tx.send(AgentEvent::ContextResolved {
1029 total_items,
1030 total_tokens,
1031 })
1032 .await
1033 .ok();
1034 }
1035
1036 self.build_augmented_system_prompt(&context_results)
1037 } else {
1038 self.config.system_prompt.clone()
1039 };
1040
1041 if !prompt.is_empty() {
1043 messages.push(Message::user(prompt));
1044 }
1045
1046 loop {
1047 turn += 1;
1048
1049 if turn > self.config.max_tool_rounds {
1050 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1051 if let Some(tx) = &event_tx {
1052 tx.send(AgentEvent::Error {
1053 message: error.clone(),
1054 })
1055 .await
1056 .ok();
1057 }
1058 anyhow::bail!(error);
1059 }
1060
1061 if let Some(tx) = &event_tx {
1063 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1064 }
1065
1066 tracing::info!(
1067 turn = turn,
1068 max_turns = self.config.max_tool_rounds,
1069 "Agent turn started"
1070 );
1071
1072 tracing::info!(
1074 a3s.llm.streaming = event_tx.is_some(),
1075 "LLM completion started"
1076 );
1077
1078 self.fire_generate_start(session_id.unwrap_or(""), prompt, &augmented_system)
1080 .await;
1081
1082 let llm_start = std::time::Instant::now();
1083 let response = {
1087 let threshold = self.config.circuit_breaker_threshold.max(1);
1088 let mut attempt = 0u32;
1089 loop {
1090 attempt += 1;
1091 let result = self
1092 .call_llm(&messages, augmented_system.as_deref(), &event_tx)
1093 .await;
1094 match result {
1095 Ok(r) => {
1096 break r;
1097 }
1098 Err(e) if event_tx.is_none() && attempt < threshold => {
1100 tracing::warn!(
1101 turn = turn,
1102 attempt = attempt,
1103 threshold = threshold,
1104 error = %e,
1105 "LLM call failed, will retry"
1106 );
1107 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1108 }
1109 Err(e) => {
1111 let msg = if event_tx.is_none() && attempt > 1 {
1112 format!(
1113 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1114 attempt, e
1115 )
1116 } else {
1117 format!("LLM call failed: {}", e)
1118 };
1119 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1120 if let Some(tx) = &event_tx {
1121 tx.send(AgentEvent::Error {
1122 message: msg.clone(),
1123 })
1124 .await
1125 .ok();
1126 }
1127 anyhow::bail!(msg);
1128 }
1129 }
1130 }
1131 };
1132
1133 total_usage.prompt_tokens += response.usage.prompt_tokens;
1135 total_usage.completion_tokens += response.usage.completion_tokens;
1136 total_usage.total_tokens += response.usage.total_tokens;
1137
1138 let llm_duration = llm_start.elapsed();
1140 tracing::info!(
1141 turn = turn,
1142 streaming = event_tx.is_some(),
1143 prompt_tokens = response.usage.prompt_tokens,
1144 completion_tokens = response.usage.completion_tokens,
1145 total_tokens = response.usage.total_tokens,
1146 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1147 duration_ms = llm_duration.as_millis() as u64,
1148 "LLM completion finished"
1149 );
1150
1151 self.fire_generate_end(
1153 session_id.unwrap_or(""),
1154 prompt,
1155 &response,
1156 llm_duration.as_millis() as u64,
1157 )
1158 .await;
1159
1160 crate::telemetry::record_llm_usage(
1162 response.usage.prompt_tokens,
1163 response.usage.completion_tokens,
1164 response.usage.total_tokens,
1165 response.stop_reason.as_deref(),
1166 );
1167 tracing::info!(
1169 turn = turn,
1170 a3s.llm.total_tokens = response.usage.total_tokens,
1171 "Turn token usage"
1172 );
1173
1174 messages.push(response.message.clone());
1176
1177 let tool_calls = response.tool_calls();
1179
1180 if let Some(tx) = &event_tx {
1182 tx.send(AgentEvent::TurnEnd {
1183 turn,
1184 usage: response.usage.clone(),
1185 })
1186 .await
1187 .ok();
1188 }
1189
1190 if tool_calls.is_empty() {
1191 let final_text = response.text();
1193
1194 tracing::info!(
1196 tool_calls_count = tool_calls_count,
1197 total_prompt_tokens = total_usage.prompt_tokens,
1198 total_completion_tokens = total_usage.completion_tokens,
1199 total_tokens = total_usage.total_tokens,
1200 turns = turn,
1201 "Agent execution completed"
1202 );
1203
1204 if let Some(tx) = &event_tx {
1205 tx.send(AgentEvent::End {
1206 text: final_text.clone(),
1207 usage: total_usage.clone(),
1208 })
1209 .await
1210 .ok();
1211 }
1212
1213 if let Some(sid) = session_id {
1215 self.notify_turn_complete(sid, prompt, &final_text).await;
1216 }
1217
1218 return Ok(AgentResult {
1219 text: final_text,
1220 messages,
1221 usage: total_usage,
1222 tool_calls_count,
1223 });
1224 }
1225
1226 for tool_call in tool_calls {
1228 tool_calls_count += 1;
1229
1230 let tool_start = std::time::Instant::now();
1231
1232 tracing::info!(
1233 tool_name = tool_call.name.as_str(),
1234 tool_id = tool_call.id.as_str(),
1235 "Tool execution started"
1236 );
1237
1238 if let Some(parse_error) =
1244 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1245 {
1246 parse_error_count += 1;
1247 let error_msg = format!("Error: {}", parse_error);
1248 tracing::warn!(
1249 tool = tool_call.name.as_str(),
1250 parse_error_count = parse_error_count,
1251 max_parse_retries = self.config.max_parse_retries,
1252 "Malformed tool arguments from LLM"
1253 );
1254
1255 if let Some(tx) = &event_tx {
1256 tx.send(AgentEvent::ToolEnd {
1257 id: tool_call.id.clone(),
1258 name: tool_call.name.clone(),
1259 output: error_msg.clone(),
1260 exit_code: 1,
1261 })
1262 .await
1263 .ok();
1264 }
1265
1266 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1267
1268 if parse_error_count > self.config.max_parse_retries {
1269 let msg = format!(
1270 "LLM produced malformed tool arguments {} time(s) in a row \
1271 (max_parse_retries={}); giving up",
1272 parse_error_count, self.config.max_parse_retries
1273 );
1274 tracing::error!("{}", msg);
1275 if let Some(tx) = &event_tx {
1276 tx.send(AgentEvent::Error {
1277 message: msg.clone(),
1278 })
1279 .await
1280 .ok();
1281 }
1282 anyhow::bail!(msg);
1283 }
1284 continue;
1285 }
1286
1287 parse_error_count = 0;
1289
1290 if let Some(HookResult::Block(reason)) = self
1292 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1293 .await
1294 {
1295 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1296 tracing::info!(
1297 tool_name = tool_call.name.as_str(),
1298 "Tool blocked by PreToolUse hook"
1299 );
1300
1301 if let Some(tx) = &event_tx {
1302 tx.send(AgentEvent::PermissionDenied {
1303 tool_id: tool_call.id.clone(),
1304 tool_name: tool_call.name.clone(),
1305 args: tool_call.args.clone(),
1306 reason: reason.clone(),
1307 })
1308 .await
1309 .ok();
1310 }
1311
1312 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1313 continue;
1314 }
1315
1316 let permission_decision = if let Some(checker) = &self.config.permission_checker {
1318 checker.check(&tool_call.name, &tool_call.args)
1319 } else {
1320 PermissionDecision::Ask
1322 };
1323
1324 let (output, exit_code, is_error, _metadata, images) = match permission_decision {
1325 PermissionDecision::Deny => {
1326 tracing::info!(
1327 tool_name = tool_call.name.as_str(),
1328 permission = "deny",
1329 "Tool permission denied"
1330 );
1331 let denial_msg = format!(
1333 "Permission denied: Tool '{}' is blocked by permission policy.",
1334 tool_call.name
1335 );
1336
1337 if let Some(tx) = &event_tx {
1339 tx.send(AgentEvent::PermissionDenied {
1340 tool_id: tool_call.id.clone(),
1341 tool_name: tool_call.name.clone(),
1342 args: tool_call.args.clone(),
1343 reason: "Blocked by deny rule in permission policy".to_string(),
1344 })
1345 .await
1346 .ok();
1347 }
1348
1349 (denial_msg, 1, true, None, Vec::new())
1350 }
1351 PermissionDecision::Allow => {
1352 tracing::info!(
1353 tool_name = tool_call.name.as_str(),
1354 permission = "allow",
1355 "Tool permission: allow"
1356 );
1357 let stream_ctx =
1359 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1360 let result = self
1361 .execute_tool_timed(&tool_call.name, &tool_call.args, &stream_ctx)
1362 .await;
1363
1364 match result {
1365 Ok(r) => (
1366 r.output,
1367 r.exit_code,
1368 r.exit_code != 0,
1369 r.metadata,
1370 r.images,
1371 ),
1372 Err(e) => (
1373 format!("Tool execution error: {}", e),
1374 1,
1375 true,
1376 None,
1377 Vec::new(),
1378 ),
1379 }
1380 }
1381 PermissionDecision::Ask => {
1382 tracing::info!(
1383 tool_name = tool_call.name.as_str(),
1384 permission = "ask",
1385 "Tool permission: ask"
1386 );
1387 if let Some(cm) = &self.config.confirmation_manager {
1389 if !cm.requires_confirmation(&tool_call.name).await {
1391 let stream_ctx = self.streaming_tool_context(
1392 &event_tx,
1393 &tool_call.id,
1394 &tool_call.name,
1395 );
1396 let result = self
1397 .execute_tool_timed(
1398 &tool_call.name,
1399 &tool_call.args,
1400 &stream_ctx,
1401 )
1402 .await;
1403
1404 let (output, exit_code, is_error, _metadata, images) = match result
1405 {
1406 Ok(r) => (
1407 r.output,
1408 r.exit_code,
1409 r.exit_code != 0,
1410 r.metadata,
1411 r.images,
1412 ),
1413 Err(e) => (
1414 format!("Tool execution error: {}", e),
1415 1,
1416 true,
1417 None,
1418 Vec::new(),
1419 ),
1420 };
1421
1422 if images.is_empty() {
1424 messages.push(Message::tool_result(
1425 &tool_call.id,
1426 &output,
1427 is_error,
1428 ));
1429 } else {
1430 messages.push(Message::tool_result_with_images(
1431 &tool_call.id,
1432 &output,
1433 &images,
1434 is_error,
1435 ));
1436 }
1437
1438 let tool_duration = tool_start.elapsed();
1440 crate::telemetry::record_tool_result(exit_code, tool_duration);
1441
1442 self.fire_post_tool_use(
1444 session_id.unwrap_or(""),
1445 &tool_call.name,
1446 &tool_call.args,
1447 &output,
1448 exit_code == 0,
1449 tool_duration.as_millis() as u64,
1450 )
1451 .await;
1452
1453 continue; }
1455
1456 let policy = cm.policy().await;
1458 let timeout_ms = policy.default_timeout_ms;
1459 let timeout_action = policy.timeout_action;
1460
1461 let rx = cm
1463 .request_confirmation(
1464 &tool_call.id,
1465 &tool_call.name,
1466 &tool_call.args,
1467 )
1468 .await;
1469
1470 let confirmation_result =
1472 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
1473
1474 match confirmation_result {
1475 Ok(Ok(response)) => {
1476 if response.approved {
1477 let stream_ctx = self.streaming_tool_context(
1478 &event_tx,
1479 &tool_call.id,
1480 &tool_call.name,
1481 );
1482 let result = self
1483 .execute_tool_timed(
1484 &tool_call.name,
1485 &tool_call.args,
1486 &stream_ctx,
1487 )
1488 .await;
1489
1490 match result {
1491 Ok(r) => (
1492 r.output,
1493 r.exit_code,
1494 r.exit_code != 0,
1495 r.metadata,
1496 r.images,
1497 ),
1498 Err(e) => (
1499 format!("Tool execution error: {}", e),
1500 1,
1501 true,
1502 None,
1503 Vec::new(),
1504 ),
1505 }
1506 } else {
1507 let rejection_msg = format!(
1508 "Tool '{}' execution was rejected by user. Reason: {}",
1509 tool_call.name,
1510 response.reason.unwrap_or_else(|| "No reason provided".to_string())
1511 );
1512 (rejection_msg, 1, true, None, Vec::new())
1513 }
1514 }
1515 Ok(Err(_)) => {
1516 let msg = format!(
1517 "Tool '{}' confirmation failed: confirmation channel closed",
1518 tool_call.name
1519 );
1520 (msg, 1, true, None, Vec::new())
1521 }
1522 Err(_) => {
1523 cm.check_timeouts().await;
1524
1525 match timeout_action {
1526 crate::hitl::TimeoutAction::Reject => {
1527 let msg = format!(
1528 "Tool '{}' execution timed out waiting for confirmation ({}ms). Execution rejected.",
1529 tool_call.name, timeout_ms
1530 );
1531 (msg, 1, true, None, Vec::new())
1532 }
1533 crate::hitl::TimeoutAction::AutoApprove => {
1534 let stream_ctx = self.streaming_tool_context(
1535 &event_tx,
1536 &tool_call.id,
1537 &tool_call.name,
1538 );
1539 let result = self
1540 .execute_tool_timed(
1541 &tool_call.name,
1542 &tool_call.args,
1543 &stream_ctx,
1544 )
1545 .await;
1546
1547 match result {
1548 Ok(r) => (
1549 r.output,
1550 r.exit_code,
1551 r.exit_code != 0,
1552 r.metadata,
1553 r.images,
1554 ),
1555 Err(e) => (
1556 format!("Tool execution error: {}", e),
1557 1,
1558 true,
1559 None,
1560 Vec::new(),
1561 ),
1562 }
1563 }
1564 }
1565 }
1566 }
1567 } else {
1568 let msg = format!(
1570 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
1571 Configure a confirmation policy to enable tool execution.",
1572 tool_call.name
1573 );
1574 tracing::warn!(
1575 tool_name = tool_call.name.as_str(),
1576 "Tool requires confirmation but no HITL manager configured"
1577 );
1578 (msg, 1, true, None, Vec::new())
1579 }
1580 }
1581 };
1582
1583 let tool_duration = tool_start.elapsed();
1584 crate::telemetry::record_tool_result(exit_code, tool_duration);
1585
1586 self.fire_post_tool_use(
1588 session_id.unwrap_or(""),
1589 &tool_call.name,
1590 &tool_call.args,
1591 &output,
1592 exit_code == 0,
1593 tool_duration.as_millis() as u64,
1594 )
1595 .await;
1596
1597 if let Some(tx) = &event_tx {
1599 tx.send(AgentEvent::ToolEnd {
1600 id: tool_call.id.clone(),
1601 name: tool_call.name.clone(),
1602 output: output.clone(),
1603 exit_code,
1604 })
1605 .await
1606 .ok();
1607 }
1608
1609 if images.is_empty() {
1611 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1612 } else {
1613 messages.push(Message::tool_result_with_images(
1614 &tool_call.id,
1615 &output,
1616 &images,
1617 is_error,
1618 ));
1619 }
1620 }
1621 }
1622 }
1623
1624 pub async fn execute_streaming(
1626 &self,
1627 history: &[Message],
1628 prompt: &str,
1629 ) -> Result<(
1630 mpsc::Receiver<AgentEvent>,
1631 tokio::task::JoinHandle<Result<AgentResult>>,
1632 )> {
1633 let (tx, rx) = mpsc::channel(100);
1634
1635 let llm_client = self.llm_client.clone();
1636 let tool_executor = self.tool_executor.clone();
1637 let tool_context = self.tool_context.clone();
1638 let config = self.config.clone();
1639 let tool_metrics = self.tool_metrics.clone();
1640 let command_queue = self.command_queue.clone();
1641 let history = history.to_vec();
1642 let prompt = prompt.to_string();
1643
1644 let handle = tokio::spawn(async move {
1645 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
1646 if let Some(metrics) = tool_metrics {
1647 agent = agent.with_tool_metrics(metrics);
1648 }
1649 if let Some(queue) = command_queue {
1650 agent = agent.with_queue(queue);
1651 }
1652 agent.execute(&history, &prompt, Some(tx)).await
1653 });
1654
1655 Ok((rx, handle))
1656 }
1657
1658 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
1663 use crate::planning::LlmPlanner;
1664
1665 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
1666 Ok(plan) => Ok(plan),
1667 Err(e) => {
1668 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
1669 Ok(LlmPlanner::fallback_plan(prompt))
1670 }
1671 }
1672 }
1673
1674 pub async fn execute_with_planning(
1676 &self,
1677 history: &[Message],
1678 prompt: &str,
1679 event_tx: Option<mpsc::Sender<AgentEvent>>,
1680 ) -> Result<AgentResult> {
1681 if let Some(tx) = &event_tx {
1683 tx.send(AgentEvent::PlanningStart {
1684 prompt: prompt.to_string(),
1685 })
1686 .await
1687 .ok();
1688 }
1689
1690 let plan = self.plan(prompt, None).await?;
1692
1693 if let Some(tx) = &event_tx {
1695 tx.send(AgentEvent::PlanningEnd {
1696 estimated_steps: plan.steps.len(),
1697 plan: plan.clone(),
1698 })
1699 .await
1700 .ok();
1701 }
1702
1703 self.execute_plan(history, &plan, event_tx).await
1705 }
1706
1707 async fn execute_plan(
1714 &self,
1715 history: &[Message],
1716 plan: &ExecutionPlan,
1717 event_tx: Option<mpsc::Sender<AgentEvent>>,
1718 ) -> Result<AgentResult> {
1719 let mut plan = plan.clone();
1720 let mut current_history = history.to_vec();
1721 let mut total_usage = TokenUsage::default();
1722 let mut tool_calls_count = 0;
1723 let total_steps = plan.steps.len();
1724
1725 let steps_text = plan
1727 .steps
1728 .iter()
1729 .enumerate()
1730 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
1731 .collect::<Vec<_>>()
1732 .join("\n");
1733 current_history.push(Message::user(&crate::prompts::render(
1734 crate::prompts::PLAN_EXECUTE_GOAL,
1735 &[("goal", &plan.goal), ("steps", &steps_text)],
1736 )));
1737
1738 loop {
1739 let ready: Vec<String> = plan
1740 .get_ready_steps()
1741 .iter()
1742 .map(|s| s.id.clone())
1743 .collect();
1744
1745 if ready.is_empty() {
1746 if plan.has_deadlock() {
1748 tracing::warn!(
1749 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
1750 plan.pending_count()
1751 );
1752 }
1753 break;
1754 }
1755
1756 if ready.len() == 1 {
1757 let step_id = &ready[0];
1759 let step = plan
1760 .steps
1761 .iter()
1762 .find(|s| s.id == *step_id)
1763 .unwrap()
1764 .clone();
1765 let step_number = plan.steps.iter().position(|s| s.id == *step_id).unwrap() + 1;
1766
1767 if let Some(tx) = &event_tx {
1769 tx.send(AgentEvent::StepStart {
1770 step_id: step.id.clone(),
1771 description: step.content.clone(),
1772 step_number,
1773 total_steps,
1774 })
1775 .await
1776 .ok();
1777 }
1778
1779 plan.mark_status(&step.id, TaskStatus::InProgress);
1780
1781 let step_prompt = crate::prompts::render(
1782 crate::prompts::PLAN_EXECUTE_STEP,
1783 &[
1784 ("step_num", &step_number.to_string()),
1785 ("description", &step.content),
1786 ],
1787 );
1788
1789 match self
1790 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
1791 .await
1792 {
1793 Ok(result) => {
1794 current_history = result.messages.clone();
1795 total_usage.prompt_tokens += result.usage.prompt_tokens;
1796 total_usage.completion_tokens += result.usage.completion_tokens;
1797 total_usage.total_tokens += result.usage.total_tokens;
1798 tool_calls_count += result.tool_calls_count;
1799 plan.mark_status(&step.id, TaskStatus::Completed);
1800
1801 if let Some(tx) = &event_tx {
1802 tx.send(AgentEvent::StepEnd {
1803 step_id: step.id.clone(),
1804 status: TaskStatus::Completed,
1805 step_number,
1806 total_steps,
1807 })
1808 .await
1809 .ok();
1810 }
1811 }
1812 Err(e) => {
1813 tracing::error!("Plan step '{}' failed: {}", step.id, e);
1814 plan.mark_status(&step.id, TaskStatus::Failed);
1815
1816 if let Some(tx) = &event_tx {
1817 tx.send(AgentEvent::StepEnd {
1818 step_id: step.id.clone(),
1819 status: TaskStatus::Failed,
1820 step_number,
1821 total_steps,
1822 })
1823 .await
1824 .ok();
1825 }
1826 }
1827 }
1828 } else {
1829 let ready_steps: Vec<_> = ready
1831 .iter()
1832 .map(|id| {
1833 let step = plan.steps.iter().find(|s| s.id == *id).unwrap().clone();
1834 let step_number = plan.steps.iter().position(|s| s.id == *id).unwrap() + 1;
1835 (step, step_number)
1836 })
1837 .collect();
1838
1839 for (step, step_number) in &ready_steps {
1841 plan.mark_status(&step.id, TaskStatus::InProgress);
1842 if let Some(tx) = &event_tx {
1843 tx.send(AgentEvent::StepStart {
1844 step_id: step.id.clone(),
1845 description: step.content.clone(),
1846 step_number: *step_number,
1847 total_steps,
1848 })
1849 .await
1850 .ok();
1851 }
1852 }
1853
1854 let mut join_set = tokio::task::JoinSet::new();
1856 for (step, step_number) in &ready_steps {
1857 let base_history = current_history.clone();
1858 let agent_clone = self.clone();
1859 let tx = event_tx.clone();
1860 let step_clone = step.clone();
1861 let sn = *step_number;
1862
1863 join_set.spawn(async move {
1864 let prompt = crate::prompts::render(
1865 crate::prompts::PLAN_EXECUTE_STEP,
1866 &[
1867 ("step_num", &sn.to_string()),
1868 ("description", &step_clone.content),
1869 ],
1870 );
1871 let result = agent_clone
1872 .execute_loop(&base_history, &prompt, None, tx)
1873 .await;
1874 (step_clone.id, sn, result)
1875 });
1876 }
1877
1878 let mut parallel_summaries = Vec::new();
1880 while let Some(join_result) = join_set.join_next().await {
1881 match join_result {
1882 Ok((step_id, step_number, step_result)) => match step_result {
1883 Ok(result) => {
1884 total_usage.prompt_tokens += result.usage.prompt_tokens;
1885 total_usage.completion_tokens += result.usage.completion_tokens;
1886 total_usage.total_tokens += result.usage.total_tokens;
1887 tool_calls_count += result.tool_calls_count;
1888 plan.mark_status(&step_id, TaskStatus::Completed);
1889
1890 parallel_summaries.push(format!(
1892 "- Step {} ({}): {}",
1893 step_number, step_id, result.text
1894 ));
1895
1896 if let Some(tx) = &event_tx {
1897 tx.send(AgentEvent::StepEnd {
1898 step_id,
1899 status: TaskStatus::Completed,
1900 step_number,
1901 total_steps,
1902 })
1903 .await
1904 .ok();
1905 }
1906 }
1907 Err(e) => {
1908 tracing::error!("Plan step '{}' failed: {}", step_id, e);
1909 plan.mark_status(&step_id, TaskStatus::Failed);
1910
1911 if let Some(tx) = &event_tx {
1912 tx.send(AgentEvent::StepEnd {
1913 step_id,
1914 status: TaskStatus::Failed,
1915 step_number,
1916 total_steps,
1917 })
1918 .await
1919 .ok();
1920 }
1921 }
1922 },
1923 Err(e) => {
1924 tracing::error!("JoinSet task panicked: {}", e);
1925 }
1926 }
1927 }
1928
1929 if !parallel_summaries.is_empty() {
1931 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
1933 current_history.push(Message::user(&crate::prompts::render(
1934 crate::prompts::PLAN_PARALLEL_RESULTS,
1935 &[("results", &results_text)],
1936 )));
1937 }
1938 }
1939
1940 if self.config.goal_tracking {
1942 let completed = plan
1943 .steps
1944 .iter()
1945 .filter(|s| s.status == TaskStatus::Completed)
1946 .count();
1947 if let Some(tx) = &event_tx {
1948 tx.send(AgentEvent::GoalProgress {
1949 goal: plan.goal.clone(),
1950 progress: plan.progress(),
1951 completed_steps: completed,
1952 total_steps,
1953 })
1954 .await
1955 .ok();
1956 }
1957 }
1958 }
1959
1960 let final_text = current_history
1962 .last()
1963 .map(|m| {
1964 m.content
1965 .iter()
1966 .filter_map(|block| {
1967 if let crate::llm::ContentBlock::Text { text } = block {
1968 Some(text.as_str())
1969 } else {
1970 None
1971 }
1972 })
1973 .collect::<Vec<_>>()
1974 .join("\n")
1975 })
1976 .unwrap_or_default();
1977
1978 Ok(AgentResult {
1979 text: final_text,
1980 messages: current_history,
1981 usage: total_usage,
1982 tool_calls_count,
1983 })
1984 }
1985
1986 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
1991 use crate::planning::LlmPlanner;
1992
1993 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
1994 Ok(goal) => Ok(goal),
1995 Err(e) => {
1996 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
1997 Ok(LlmPlanner::fallback_goal(prompt))
1998 }
1999 }
2000 }
2001
2002 pub async fn check_goal_achievement(
2007 &self,
2008 goal: &AgentGoal,
2009 current_state: &str,
2010 ) -> Result<bool> {
2011 use crate::planning::LlmPlanner;
2012
2013 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2014 Ok(result) => Ok(result.achieved),
2015 Err(e) => {
2016 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2017 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2018 Ok(result.achieved)
2019 }
2020 }
2021 }
2022}
2023
2024#[cfg(test)]
2025mod tests {
2026 use super::*;
2027 use crate::llm::{ContentBlock, StreamEvent};
2028 use crate::permissions::PermissionPolicy;
2029 use crate::tools::ToolExecutor;
2030 use std::path::PathBuf;
2031 use std::sync::atomic::{AtomicUsize, Ordering};
2032
2033 fn test_tool_context() -> ToolContext {
2035 ToolContext::new(PathBuf::from("/tmp"))
2036 }
2037
2038 #[test]
2039 fn test_agent_config_default() {
2040 let config = AgentConfig::default();
2041 assert!(config.system_prompt.is_none());
2042 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
2044 assert!(config.permission_checker.is_none());
2045 assert!(config.context_providers.is_empty());
2046 }
2047
    /// Scripted `LlmClient` for tests: replays a fixed queue of canned responses
    /// and counts how many completion calls the agent made.
    pub(crate) struct MockLlmClient {
        // Canned responses, handed out front-to-back (each call removes index 0).
        responses: std::sync::Mutex<Vec<LlmResponse>>,
        // Incremented on every `complete` / `complete_streaming` call, even when
        // the queue is exhausted and the call fails.
        pub(crate) call_count: AtomicUsize,
    }
2059
2060 impl MockLlmClient {
2061 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2062 Self {
2063 responses: std::sync::Mutex::new(responses),
2064 call_count: AtomicUsize::new(0),
2065 }
2066 }
2067
2068 pub(crate) fn text_response(text: &str) -> LlmResponse {
2070 LlmResponse {
2071 message: Message {
2072 role: "assistant".to_string(),
2073 content: vec![ContentBlock::Text {
2074 text: text.to_string(),
2075 }],
2076 reasoning_content: None,
2077 },
2078 usage: TokenUsage {
2079 prompt_tokens: 10,
2080 completion_tokens: 5,
2081 total_tokens: 15,
2082 cache_read_tokens: None,
2083 cache_write_tokens: None,
2084 },
2085 stop_reason: Some("end_turn".to_string()),
2086 }
2087 }
2088
2089 pub(crate) fn tool_call_response(
2091 tool_id: &str,
2092 tool_name: &str,
2093 args: serde_json::Value,
2094 ) -> LlmResponse {
2095 LlmResponse {
2096 message: Message {
2097 role: "assistant".to_string(),
2098 content: vec![ContentBlock::ToolUse {
2099 id: tool_id.to_string(),
2100 name: tool_name.to_string(),
2101 input: args,
2102 }],
2103 reasoning_content: None,
2104 },
2105 usage: TokenUsage {
2106 prompt_tokens: 10,
2107 completion_tokens: 5,
2108 total_tokens: 15,
2109 cache_read_tokens: None,
2110 cache_write_tokens: None,
2111 },
2112 stop_reason: Some("tool_use".to_string()),
2113 }
2114 }
2115 }
2116
2117 #[async_trait::async_trait]
2118 impl LlmClient for MockLlmClient {
2119 async fn complete(
2120 &self,
2121 _messages: &[Message],
2122 _system: Option<&str>,
2123 _tools: &[ToolDefinition],
2124 ) -> Result<LlmResponse> {
2125 self.call_count.fetch_add(1, Ordering::SeqCst);
2126 let mut responses = self.responses.lock().unwrap();
2127 if responses.is_empty() {
2128 anyhow::bail!("No more mock responses available");
2129 }
2130 Ok(responses.remove(0))
2131 }
2132
2133 async fn complete_streaming(
2134 &self,
2135 _messages: &[Message],
2136 _system: Option<&str>,
2137 _tools: &[ToolDefinition],
2138 ) -> Result<mpsc::Receiver<StreamEvent>> {
2139 self.call_count.fetch_add(1, Ordering::SeqCst);
2140 let mut responses = self.responses.lock().unwrap();
2141 if responses.is_empty() {
2142 anyhow::bail!("No more mock responses available");
2143 }
2144 let response = responses.remove(0);
2145
2146 let (tx, rx) = mpsc::channel(10);
2147 tokio::spawn(async move {
2148 for block in &response.message.content {
2150 if let ContentBlock::Text { text } = block {
2151 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2152 }
2153 }
2154 tx.send(StreamEvent::Done(response)).await.ok();
2155 });
2156
2157 Ok(rx)
2158 }
2159 }
2160
2161 #[tokio::test]
2166 async fn test_agent_simple_response() {
2167 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2168 "Hello, I'm an AI assistant.",
2169 )]));
2170
2171 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2172 let config = AgentConfig::default();
2173
2174 let agent = AgentLoop::new(
2175 mock_client.clone(),
2176 tool_executor,
2177 test_tool_context(),
2178 config,
2179 );
2180 let result = agent.execute(&[], "Hello", None).await.unwrap();
2181
2182 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2183 assert_eq!(result.tool_calls_count, 0);
2184 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2185 }
2186
2187 #[tokio::test]
2188 async fn test_agent_with_tool_call() {
2189 let mock_client = Arc::new(MockLlmClient::new(vec![
2190 MockLlmClient::tool_call_response(
2192 "tool-1",
2193 "bash",
2194 serde_json::json!({"command": "echo hello"}),
2195 ),
2196 MockLlmClient::text_response("The command output was: hello"),
2198 ]));
2199
2200 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2201 let config = AgentConfig::default();
2202
2203 let agent = AgentLoop::new(
2204 mock_client.clone(),
2205 tool_executor,
2206 test_tool_context(),
2207 config,
2208 );
2209 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2210
2211 assert_eq!(result.text, "The command output was: hello");
2212 assert_eq!(result.tool_calls_count, 1);
2213 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2214 }
2215
2216 #[tokio::test]
2217 async fn test_agent_permission_deny() {
2218 let mock_client = Arc::new(MockLlmClient::new(vec![
2219 MockLlmClient::tool_call_response(
2221 "tool-1",
2222 "bash",
2223 serde_json::json!({"command": "rm -rf /tmp/test"}),
2224 ),
2225 MockLlmClient::text_response(
2227 "I cannot execute that command due to permission restrictions.",
2228 ),
2229 ]));
2230
2231 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2232
2233 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2235
2236 let config = AgentConfig {
2237 permission_checker: Some(Arc::new(permission_policy)),
2238 ..Default::default()
2239 };
2240
2241 let (tx, mut rx) = mpsc::channel(100);
2242 let agent = AgentLoop::new(
2243 mock_client.clone(),
2244 tool_executor,
2245 test_tool_context(),
2246 config,
2247 );
2248 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2249
2250 let mut found_permission_denied = false;
2252 while let Ok(event) = rx.try_recv() {
2253 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2254 assert_eq!(tool_name, "bash");
2255 found_permission_denied = true;
2256 }
2257 }
2258 assert!(
2259 found_permission_denied,
2260 "Should have received PermissionDenied event"
2261 );
2262
2263 assert_eq!(result.tool_calls_count, 1);
2264 }
2265
2266 #[tokio::test]
2267 async fn test_agent_permission_allow() {
2268 let mock_client = Arc::new(MockLlmClient::new(vec![
2269 MockLlmClient::tool_call_response(
2271 "tool-1",
2272 "bash",
2273 serde_json::json!({"command": "echo hello"}),
2274 ),
2275 MockLlmClient::text_response("Done!"),
2277 ]));
2278
2279 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2280
2281 let permission_policy = PermissionPolicy::new()
2283 .allow("bash(echo:*)")
2284 .deny("bash(rm:*)");
2285
2286 let config = AgentConfig {
2287 permission_checker: Some(Arc::new(permission_policy)),
2288 ..Default::default()
2289 };
2290
2291 let agent = AgentLoop::new(
2292 mock_client.clone(),
2293 tool_executor,
2294 test_tool_context(),
2295 config,
2296 );
2297 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2298
2299 assert_eq!(result.text, "Done!");
2300 assert_eq!(result.tool_calls_count, 1);
2301 }
2302
2303 #[tokio::test]
2304 async fn test_agent_streaming_events() {
2305 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2306 "Hello!",
2307 )]));
2308
2309 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2310 let config = AgentConfig::default();
2311
2312 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2313 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2314
2315 let mut events = Vec::new();
2317 while let Some(event) = rx.recv().await {
2318 events.push(event);
2319 }
2320
2321 let result = handle.await.unwrap().unwrap();
2322 assert_eq!(result.text, "Hello!");
2323
2324 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2326 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2327 }
2328
2329 #[tokio::test]
2330 async fn test_agent_max_tool_rounds() {
2331 let responses: Vec<LlmResponse> = (0..100)
2333 .map(|i| {
2334 MockLlmClient::tool_call_response(
2335 &format!("tool-{}", i),
2336 "bash",
2337 serde_json::json!({"command": "echo loop"}),
2338 )
2339 })
2340 .collect();
2341
2342 let mock_client = Arc::new(MockLlmClient::new(responses));
2343 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2344
2345 let config = AgentConfig {
2346 max_tool_rounds: 3,
2347 ..Default::default()
2348 };
2349
2350 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2351 let result = agent.execute(&[], "Loop forever", None).await;
2352
2353 assert!(result.is_err());
2355 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
2356 }
2357
2358 #[tokio::test]
2359 async fn test_agent_no_permission_policy_defaults_to_ask() {
2360 let mock_client = Arc::new(MockLlmClient::new(vec![
2363 MockLlmClient::tool_call_response(
2364 "tool-1",
2365 "bash",
2366 serde_json::json!({"command": "rm -rf /tmp/test"}),
2367 ),
2368 MockLlmClient::text_response("Denied!"),
2369 ]));
2370
2371 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2372 let config = AgentConfig {
2373 permission_checker: None, ..Default::default()
2376 };
2377
2378 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2379 let result = agent.execute(&[], "Delete", None).await.unwrap();
2380
2381 assert_eq!(result.text, "Denied!");
2383 assert_eq!(result.tool_calls_count, 1);
2384 }
2385
2386 #[tokio::test]
2387 async fn test_agent_permission_ask_without_cm_denies() {
2388 let mock_client = Arc::new(MockLlmClient::new(vec![
2391 MockLlmClient::tool_call_response(
2392 "tool-1",
2393 "bash",
2394 serde_json::json!({"command": "echo test"}),
2395 ),
2396 MockLlmClient::text_response("Denied!"),
2397 ]));
2398
2399 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2400
2401 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2405 permission_checker: Some(Arc::new(permission_policy)),
2406 ..Default::default()
2408 };
2409
2410 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2411 let result = agent.execute(&[], "Echo", None).await.unwrap();
2412
2413 assert_eq!(result.text, "Denied!");
2415 assert!(result.tool_calls_count >= 1);
2417 }
2418
2419 #[tokio::test]
2424 async fn test_agent_hitl_approved() {
2425 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2426 use tokio::sync::broadcast;
2427
2428 let mock_client = Arc::new(MockLlmClient::new(vec![
2429 MockLlmClient::tool_call_response(
2430 "tool-1",
2431 "bash",
2432 serde_json::json!({"command": "echo hello"}),
2433 ),
2434 MockLlmClient::text_response("Command executed!"),
2435 ]));
2436
2437 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2438
2439 let (event_tx, _event_rx) = broadcast::channel(100);
2441 let hitl_policy = ConfirmationPolicy {
2442 enabled: true,
2443 ..Default::default()
2444 };
2445 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2446
2447 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2451 permission_checker: Some(Arc::new(permission_policy)),
2452 confirmation_manager: Some(confirmation_manager.clone()),
2453 ..Default::default()
2454 };
2455
2456 let cm_clone = confirmation_manager.clone();
2458 tokio::spawn(async move {
2459 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2461 cm_clone.confirm("tool-1", true, None).await.ok();
2463 });
2464
2465 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2466 let result = agent.execute(&[], "Run echo", None).await.unwrap();
2467
2468 assert_eq!(result.text, "Command executed!");
2469 assert_eq!(result.tool_calls_count, 1);
2470 }
2471
2472 #[tokio::test]
2473 async fn test_agent_hitl_rejected() {
2474 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2475 use tokio::sync::broadcast;
2476
2477 let mock_client = Arc::new(MockLlmClient::new(vec![
2478 MockLlmClient::tool_call_response(
2479 "tool-1",
2480 "bash",
2481 serde_json::json!({"command": "rm -rf /"}),
2482 ),
2483 MockLlmClient::text_response("Understood, I won't do that."),
2484 ]));
2485
2486 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2487
2488 let (event_tx, _event_rx) = broadcast::channel(100);
2490 let hitl_policy = ConfirmationPolicy {
2491 enabled: true,
2492 ..Default::default()
2493 };
2494 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2495
2496 let permission_policy = PermissionPolicy::new();
2498
2499 let config = AgentConfig {
2500 permission_checker: Some(Arc::new(permission_policy)),
2501 confirmation_manager: Some(confirmation_manager.clone()),
2502 ..Default::default()
2503 };
2504
2505 let cm_clone = confirmation_manager.clone();
2507 tokio::spawn(async move {
2508 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2509 cm_clone
2510 .confirm("tool-1", false, Some("Too dangerous".to_string()))
2511 .await
2512 .ok();
2513 });
2514
2515 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2516 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
2517
2518 assert_eq!(result.text, "Understood, I won't do that.");
2520 }
2521
2522 #[tokio::test]
2523 async fn test_agent_hitl_timeout_reject() {
2524 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2525 use tokio::sync::broadcast;
2526
2527 let mock_client = Arc::new(MockLlmClient::new(vec![
2528 MockLlmClient::tool_call_response(
2529 "tool-1",
2530 "bash",
2531 serde_json::json!({"command": "echo test"}),
2532 ),
2533 MockLlmClient::text_response("Timed out, I understand."),
2534 ]));
2535
2536 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2537
2538 let (event_tx, _event_rx) = broadcast::channel(100);
2540 let hitl_policy = ConfirmationPolicy {
2541 enabled: true,
2542 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
2544 ..Default::default()
2545 };
2546 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2547
2548 let permission_policy = PermissionPolicy::new();
2549
2550 let config = AgentConfig {
2551 permission_checker: Some(Arc::new(permission_policy)),
2552 confirmation_manager: Some(confirmation_manager),
2553 ..Default::default()
2554 };
2555
2556 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2558 let result = agent.execute(&[], "Echo", None).await.unwrap();
2559
2560 assert_eq!(result.text, "Timed out, I understand.");
2562 }
2563
2564 #[tokio::test]
2565 async fn test_agent_hitl_timeout_auto_approve() {
2566 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2567 use tokio::sync::broadcast;
2568
2569 let mock_client = Arc::new(MockLlmClient::new(vec![
2570 MockLlmClient::tool_call_response(
2571 "tool-1",
2572 "bash",
2573 serde_json::json!({"command": "echo hello"}),
2574 ),
2575 MockLlmClient::text_response("Auto-approved and executed!"),
2576 ]));
2577
2578 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2579
2580 let (event_tx, _event_rx) = broadcast::channel(100);
2582 let hitl_policy = ConfirmationPolicy {
2583 enabled: true,
2584 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
2586 ..Default::default()
2587 };
2588 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2589
2590 let permission_policy = PermissionPolicy::new();
2591
2592 let config = AgentConfig {
2593 permission_checker: Some(Arc::new(permission_policy)),
2594 confirmation_manager: Some(confirmation_manager),
2595 ..Default::default()
2596 };
2597
2598 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2600 let result = agent.execute(&[], "Echo", None).await.unwrap();
2601
2602 assert_eq!(result.text, "Auto-approved and executed!");
2604 assert_eq!(result.tool_calls_count, 1);
2605 }
2606
    /// Checks that a HITL round trip surfaces both a `ConfirmationRequired` event
    /// and an approving `ConfirmationReceived` event on the confirmation manager's
    /// broadcast channel.
    #[tokio::test]
    async fn test_agent_hitl_confirmation_events() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "echo test"}),
            ),
            MockLlmClient::text_response("Done!"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Keep the receiver: this test both observes and reacts to HITL events.
        let (event_tx, mut event_rx) = broadcast::channel(100);
        // Generous 5s timeout so the responder task below answers well before it fires.
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 5000, ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        // Empty policy: no rules, so the call is routed to HITL.
        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        let cm_clone = confirmation_manager.clone();
        // Records every event up to the approval, approves using the id carried by
        // `ConfirmationRequired`, then records exactly one more event (expected by the
        // assertions below to be the approving `ConfirmationReceived`).
        let event_handle = tokio::spawn(async move {
            let mut events = Vec::new();
            while let Ok(event) = event_rx.recv().await {
                events.push(event.clone());
                if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
                    cm_clone.confirm(&tool_id, true, None).await.ok();
                    if let Ok(recv_event) = event_rx.recv().await {
                        events.push(recv_event);
                    }
                    break;
                }
            }
            events
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let _result = agent.execute(&[], "Echo", None).await.unwrap();

        let events = event_handle.await.unwrap();
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
            "Should have ConfirmationRequired event"
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
            "Should have ConfirmationReceived event with approved=true"
        );
    }
2678
2679 #[tokio::test]
2680 async fn test_agent_hitl_disabled_auto_executes() {
2681 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2683 use tokio::sync::broadcast;
2684
2685 let mock_client = Arc::new(MockLlmClient::new(vec![
2686 MockLlmClient::tool_call_response(
2687 "tool-1",
2688 "bash",
2689 serde_json::json!({"command": "echo auto"}),
2690 ),
2691 MockLlmClient::text_response("Auto executed!"),
2692 ]));
2693
2694 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2695
2696 let (event_tx, _event_rx) = broadcast::channel(100);
2698 let hitl_policy = ConfirmationPolicy {
2699 enabled: false, ..Default::default()
2701 };
2702 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2703
2704 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2707 permission_checker: Some(Arc::new(permission_policy)),
2708 confirmation_manager: Some(confirmation_manager),
2709 ..Default::default()
2710 };
2711
2712 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2713 let result = agent.execute(&[], "Echo", None).await.unwrap();
2714
2715 assert_eq!(result.text, "Auto executed!");
2717 assert_eq!(result.tool_calls_count, 1);
2718 }
2719
2720 #[tokio::test]
2721 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
2722 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2724 use tokio::sync::broadcast;
2725
2726 let mock_client = Arc::new(MockLlmClient::new(vec![
2727 MockLlmClient::tool_call_response(
2728 "tool-1",
2729 "bash",
2730 serde_json::json!({"command": "rm -rf /"}),
2731 ),
2732 MockLlmClient::text_response("Blocked by permission."),
2733 ]));
2734
2735 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2736
2737 let (event_tx, mut event_rx) = broadcast::channel(100);
2739 let hitl_policy = ConfirmationPolicy {
2740 enabled: true,
2741 ..Default::default()
2742 };
2743 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2744
2745 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2747
2748 let config = AgentConfig {
2749 permission_checker: Some(Arc::new(permission_policy)),
2750 confirmation_manager: Some(confirmation_manager),
2751 ..Default::default()
2752 };
2753
2754 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2755 let result = agent.execute(&[], "Delete", None).await.unwrap();
2756
2757 assert_eq!(result.text, "Blocked by permission.");
2759
2760 let mut found_confirmation = false;
2762 while let Ok(event) = event_rx.try_recv() {
2763 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2764 found_confirmation = true;
2765 }
2766 }
2767 assert!(
2768 !found_confirmation,
2769 "HITL should not be triggered when permission is Deny"
2770 );
2771 }
2772
2773 #[tokio::test]
2774 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
2775 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2778 use tokio::sync::broadcast;
2779
2780 let mock_client = Arc::new(MockLlmClient::new(vec![
2781 MockLlmClient::tool_call_response(
2782 "tool-1",
2783 "bash",
2784 serde_json::json!({"command": "echo hello"}),
2785 ),
2786 MockLlmClient::text_response("Allowed!"),
2787 ]));
2788
2789 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2790
2791 let (event_tx, mut event_rx) = broadcast::channel(100);
2793 let hitl_policy = ConfirmationPolicy {
2794 enabled: true,
2795 ..Default::default()
2796 };
2797 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2798
2799 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
2801
2802 let config = AgentConfig {
2803 permission_checker: Some(Arc::new(permission_policy)),
2804 confirmation_manager: Some(confirmation_manager.clone()),
2805 ..Default::default()
2806 };
2807
2808 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2809 let result = agent.execute(&[], "Echo", None).await.unwrap();
2810
2811 assert_eq!(result.text, "Allowed!");
2813
2814 let mut found_confirmation = false;
2816 while let Ok(event) = event_rx.try_recv() {
2817 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2818 found_confirmation = true;
2819 }
2820 }
2821 assert!(
2822 !found_confirmation,
2823 "Permission Allow should skip HITL confirmation"
2824 );
2825 }
2826
2827 #[tokio::test]
2828 async fn test_agent_hitl_multiple_tool_calls() {
2829 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2831 use tokio::sync::broadcast;
2832
2833 let mock_client = Arc::new(MockLlmClient::new(vec![
2834 LlmResponse {
2836 message: Message {
2837 role: "assistant".to_string(),
2838 content: vec![
2839 ContentBlock::ToolUse {
2840 id: "tool-1".to_string(),
2841 name: "bash".to_string(),
2842 input: serde_json::json!({"command": "echo first"}),
2843 },
2844 ContentBlock::ToolUse {
2845 id: "tool-2".to_string(),
2846 name: "bash".to_string(),
2847 input: serde_json::json!({"command": "echo second"}),
2848 },
2849 ],
2850 reasoning_content: None,
2851 },
2852 usage: TokenUsage {
2853 prompt_tokens: 10,
2854 completion_tokens: 5,
2855 total_tokens: 15,
2856 cache_read_tokens: None,
2857 cache_write_tokens: None,
2858 },
2859 stop_reason: Some("tool_use".to_string()),
2860 },
2861 MockLlmClient::text_response("Both executed!"),
2862 ]));
2863
2864 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2865
2866 let (event_tx, _event_rx) = broadcast::channel(100);
2868 let hitl_policy = ConfirmationPolicy {
2869 enabled: true,
2870 default_timeout_ms: 5000,
2871 ..Default::default()
2872 };
2873 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2874
2875 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2878 permission_checker: Some(Arc::new(permission_policy)),
2879 confirmation_manager: Some(confirmation_manager.clone()),
2880 ..Default::default()
2881 };
2882
2883 let cm_clone = confirmation_manager.clone();
2885 tokio::spawn(async move {
2886 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2887 cm_clone.confirm("tool-1", true, None).await.ok();
2888 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2889 cm_clone.confirm("tool-2", true, None).await.ok();
2890 });
2891
2892 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2893 let result = agent.execute(&[], "Run both", None).await.unwrap();
2894
2895 assert_eq!(result.text, "Both executed!");
2896 assert_eq!(result.tool_calls_count, 2);
2897 }
2898
2899 #[tokio::test]
2900 async fn test_agent_hitl_partial_approval() {
2901 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2903 use tokio::sync::broadcast;
2904
2905 let mock_client = Arc::new(MockLlmClient::new(vec![
2906 LlmResponse {
2908 message: Message {
2909 role: "assistant".to_string(),
2910 content: vec![
2911 ContentBlock::ToolUse {
2912 id: "tool-1".to_string(),
2913 name: "bash".to_string(),
2914 input: serde_json::json!({"command": "echo safe"}),
2915 },
2916 ContentBlock::ToolUse {
2917 id: "tool-2".to_string(),
2918 name: "bash".to_string(),
2919 input: serde_json::json!({"command": "rm -rf /"}),
2920 },
2921 ],
2922 reasoning_content: None,
2923 },
2924 usage: TokenUsage {
2925 prompt_tokens: 10,
2926 completion_tokens: 5,
2927 total_tokens: 15,
2928 cache_read_tokens: None,
2929 cache_write_tokens: None,
2930 },
2931 stop_reason: Some("tool_use".to_string()),
2932 },
2933 MockLlmClient::text_response("First worked, second rejected."),
2934 ]));
2935
2936 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2937
2938 let (event_tx, _event_rx) = broadcast::channel(100);
2939 let hitl_policy = ConfirmationPolicy {
2940 enabled: true,
2941 default_timeout_ms: 5000,
2942 ..Default::default()
2943 };
2944 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2945
2946 let permission_policy = PermissionPolicy::new();
2947
2948 let config = AgentConfig {
2949 permission_checker: Some(Arc::new(permission_policy)),
2950 confirmation_manager: Some(confirmation_manager.clone()),
2951 ..Default::default()
2952 };
2953
2954 let cm_clone = confirmation_manager.clone();
2956 tokio::spawn(async move {
2957 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2958 cm_clone.confirm("tool-1", true, None).await.ok();
2959 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2960 cm_clone
2961 .confirm("tool-2", false, Some("Dangerous".to_string()))
2962 .await
2963 .ok();
2964 });
2965
2966 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2967 let result = agent.execute(&[], "Run both", None).await.unwrap();
2968
2969 assert_eq!(result.text, "First worked, second rejected.");
2970 assert_eq!(result.tool_calls_count, 2);
2971 }
2972
2973 #[tokio::test]
2974 async fn test_agent_hitl_yolo_mode_auto_approves() {
2975 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
2977 use tokio::sync::broadcast;
2978
2979 let mock_client = Arc::new(MockLlmClient::new(vec![
2980 MockLlmClient::tool_call_response(
2981 "tool-1",
2982 "read", serde_json::json!({"path": "/tmp/test.txt"}),
2984 ),
2985 MockLlmClient::text_response("File read!"),
2986 ]));
2987
2988 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2989
2990 let (event_tx, mut event_rx) = broadcast::channel(100);
2992 let mut yolo_lanes = std::collections::HashSet::new();
2993 yolo_lanes.insert(SessionLane::Query);
2994 let hitl_policy = ConfirmationPolicy {
2995 enabled: true,
2996 yolo_lanes, ..Default::default()
2998 };
2999 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3000
3001 let permission_policy = PermissionPolicy::new();
3002
3003 let config = AgentConfig {
3004 permission_checker: Some(Arc::new(permission_policy)),
3005 confirmation_manager: Some(confirmation_manager),
3006 ..Default::default()
3007 };
3008
3009 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3010 let result = agent.execute(&[], "Read file", None).await.unwrap();
3011
3012 assert_eq!(result.text, "File read!");
3014
3015 let mut found_confirmation = false;
3017 while let Ok(event) = event_rx.try_recv() {
3018 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3019 found_confirmation = true;
3020 }
3021 }
3022 assert!(
3023 !found_confirmation,
3024 "YOLO mode should not trigger confirmation"
3025 );
3026 }
3027
3028 #[tokio::test]
3029 async fn test_agent_config_with_all_options() {
3030 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3031 use tokio::sync::broadcast;
3032
3033 let (event_tx, _) = broadcast::channel(100);
3034 let hitl_policy = ConfirmationPolicy::default();
3035 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3036
3037 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3038
3039 let config = AgentConfig {
3040 system_prompt: Some("Test system prompt".to_string()),
3041 tools: vec![],
3042 max_tool_rounds: 10,
3043 permission_checker: Some(Arc::new(permission_policy)),
3044 confirmation_manager: Some(confirmation_manager),
3045 context_providers: vec![],
3046 planning_enabled: false,
3047 goal_tracking: false,
3048 hook_engine: None,
3049 skill_registry: None,
3050 ..AgentConfig::default()
3051 };
3052
3053 assert_eq!(config.system_prompt, Some("Test system prompt".to_string()));
3054 assert_eq!(config.max_tool_rounds, 10);
3055 assert!(config.permission_checker.is_some());
3056 assert!(config.confirmation_manager.is_some());
3057 assert!(config.context_providers.is_empty());
3058
3059 let debug_str = format!("{:?}", config);
3061 assert!(debug_str.contains("AgentConfig"));
3062 assert!(debug_str.contains("permission_checker: true"));
3063 assert!(debug_str.contains("confirmation_manager: true"));
3064 assert!(debug_str.contains("context_providers: 0"));
3065 }
3066
3067 use crate::context::{ContextItem, ContextType};
3072
3073 struct MockContextProvider {
3075 name: String,
3076 items: Vec<ContextItem>,
3077 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
3078 }
3079
3080 impl MockContextProvider {
3081 fn new(name: &str) -> Self {
3082 Self {
3083 name: name.to_string(),
3084 items: Vec::new(),
3085 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3086 }
3087 }
3088
3089 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3090 self.items = items;
3091 self
3092 }
3093 }
3094
3095 #[async_trait::async_trait]
3096 impl ContextProvider for MockContextProvider {
3097 fn name(&self) -> &str {
3098 &self.name
3099 }
3100
3101 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3102 let mut result = ContextResult::new(&self.name);
3103 for item in &self.items {
3104 result.add_item(item.clone());
3105 }
3106 Ok(result)
3107 }
3108
3109 async fn on_turn_complete(
3110 &self,
3111 session_id: &str,
3112 prompt: &str,
3113 response: &str,
3114 ) -> anyhow::Result<()> {
3115 let mut calls = self.on_turn_calls.write().await;
3116 calls.push((
3117 session_id.to_string(),
3118 prompt.to_string(),
3119 response.to_string(),
3120 ));
3121 Ok(())
3122 }
3123 }
3124
3125 #[tokio::test]
3126 async fn test_agent_with_context_provider() {
3127 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3128 "Response using context",
3129 )]));
3130
3131 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3132
3133 let provider =
3134 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3135 "ctx-1",
3136 ContextType::Resource,
3137 "Relevant context here",
3138 )
3139 .with_source("test://docs/example")]);
3140
3141 let config = AgentConfig {
3142 system_prompt: Some("You are helpful.".to_string()),
3143 context_providers: vec![Arc::new(provider)],
3144 ..Default::default()
3145 };
3146
3147 let agent = AgentLoop::new(
3148 mock_client.clone(),
3149 tool_executor,
3150 test_tool_context(),
3151 config,
3152 );
3153 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3154
3155 assert_eq!(result.text, "Response using context");
3156 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3157 }
3158
3159 #[tokio::test]
3160 async fn test_agent_context_provider_events() {
3161 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3162 "Answer",
3163 )]));
3164
3165 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3166
3167 let provider =
3168 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3169 "item-1",
3170 ContextType::Memory,
3171 "Memory content",
3172 )
3173 .with_token_count(50)]);
3174
3175 let config = AgentConfig {
3176 context_providers: vec![Arc::new(provider)],
3177 ..Default::default()
3178 };
3179
3180 let (tx, mut rx) = mpsc::channel(100);
3181 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3182 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3183
3184 let mut events = Vec::new();
3186 while let Ok(event) = rx.try_recv() {
3187 events.push(event);
3188 }
3189
3190 assert!(
3192 events
3193 .iter()
3194 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3195 "Should have ContextResolving event"
3196 );
3197 assert!(
3198 events
3199 .iter()
3200 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3201 "Should have ContextResolved event"
3202 );
3203
3204 for event in &events {
3206 if let AgentEvent::ContextResolved {
3207 total_items,
3208 total_tokens,
3209 } = event
3210 {
3211 assert_eq!(*total_items, 1);
3212 assert_eq!(*total_tokens, 50);
3213 }
3214 }
3215 }
3216
3217 #[tokio::test]
3218 async fn test_agent_multiple_context_providers() {
3219 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3220 "Combined response",
3221 )]));
3222
3223 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3224
3225 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3226 "p1-1",
3227 ContextType::Resource,
3228 "Resource from P1",
3229 )
3230 .with_token_count(100)]);
3231
3232 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3233 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3234 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3235 ]);
3236
3237 let config = AgentConfig {
3238 system_prompt: Some("Base system prompt.".to_string()),
3239 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3240 ..Default::default()
3241 };
3242
3243 let (tx, mut rx) = mpsc::channel(100);
3244 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3245 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3246
3247 assert_eq!(result.text, "Combined response");
3248
3249 while let Ok(event) = rx.try_recv() {
3251 if let AgentEvent::ContextResolved {
3252 total_items,
3253 total_tokens,
3254 } = event
3255 {
3256 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3259 }
3260 }
3261
3262 #[tokio::test]
3263 async fn test_agent_no_context_providers() {
3264 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3265 "No context",
3266 )]));
3267
3268 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3269
3270 let config = AgentConfig::default();
3272
3273 let (tx, mut rx) = mpsc::channel(100);
3274 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3275 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3276
3277 assert_eq!(result.text, "No context");
3278
3279 let mut events = Vec::new();
3281 while let Ok(event) = rx.try_recv() {
3282 events.push(event);
3283 }
3284
3285 assert!(
3286 !events
3287 .iter()
3288 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3289 "Should NOT have ContextResolving event"
3290 );
3291 }
3292
3293 #[tokio::test]
3294 async fn test_agent_context_on_turn_complete() {
3295 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3296 "Final response",
3297 )]));
3298
3299 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3300
3301 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3302 let on_turn_calls = provider.on_turn_calls.clone();
3303
3304 let config = AgentConfig {
3305 context_providers: vec![provider],
3306 ..Default::default()
3307 };
3308
3309 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3310
3311 let result = agent
3313 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3314 .await
3315 .unwrap();
3316
3317 assert_eq!(result.text, "Final response");
3318
3319 let calls = on_turn_calls.read().await;
3321 assert_eq!(calls.len(), 1);
3322 assert_eq!(calls[0].0, "sess-123");
3323 assert_eq!(calls[0].1, "User prompt");
3324 assert_eq!(calls[0].2, "Final response");
3325 }
3326
3327 #[tokio::test]
3328 async fn test_agent_context_on_turn_complete_no_session() {
3329 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3330 "Response",
3331 )]));
3332
3333 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3334
3335 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3336 let on_turn_calls = provider.on_turn_calls.clone();
3337
3338 let config = AgentConfig {
3339 context_providers: vec![provider],
3340 ..Default::default()
3341 };
3342
3343 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3344
3345 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
3347
3348 let calls = on_turn_calls.read().await;
3350 assert!(calls.is_empty());
3351 }
3352
3353 #[tokio::test]
3354 async fn test_agent_build_augmented_system_prompt() {
3355 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
3356
3357 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3358
3359 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
3360 "doc-1",
3361 ContextType::Resource,
3362 "Auth uses JWT tokens.",
3363 )
3364 .with_source("viking://docs/auth")]);
3365
3366 let config = AgentConfig {
3367 system_prompt: Some("You are helpful.".to_string()),
3368 context_providers: vec![Arc::new(provider)],
3369 ..Default::default()
3370 };
3371
3372 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3373
3374 let context_results = agent.resolve_context("test", None).await;
3376 let augmented = agent.build_augmented_system_prompt(&context_results);
3377
3378 let augmented_str = augmented.unwrap();
3379 assert!(augmented_str.contains("You are helpful."));
3380 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
3381 assert!(augmented_str.contains("Auth uses JWT tokens."));
3382 }
3383
3384 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
3390 let mut events = Vec::new();
3391 while let Ok(event) = rx.try_recv() {
3392 events.push(event);
3393 }
3394 while let Some(event) = rx.recv().await {
3396 events.push(event);
3397 }
3398 events
3399 }
3400
3401 #[tokio::test]
3402 async fn test_agent_multi_turn_tool_chain() {
3403 let mock_client = Arc::new(MockLlmClient::new(vec![
3405 MockLlmClient::tool_call_response(
3407 "t1",
3408 "bash",
3409 serde_json::json!({"command": "echo step1"}),
3410 ),
3411 MockLlmClient::tool_call_response(
3413 "t2",
3414 "bash",
3415 serde_json::json!({"command": "echo step2"}),
3416 ),
3417 MockLlmClient::text_response("Completed both steps: step1 then step2"),
3419 ]));
3420
3421 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3422 let config = AgentConfig::default();
3423
3424 let agent = AgentLoop::new(
3425 mock_client.clone(),
3426 tool_executor,
3427 test_tool_context(),
3428 config,
3429 );
3430 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
3431
3432 assert_eq!(result.text, "Completed both steps: step1 then step2");
3433 assert_eq!(result.tool_calls_count, 2);
3434 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
3435
3436 assert_eq!(result.messages[0].role, "user");
3438 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
3444 }
3445
3446 #[tokio::test]
3447 async fn test_agent_conversation_history_preserved() {
3448 let existing_history = vec![
3450 Message::user("What is Rust?"),
3451 Message {
3452 role: "assistant".to_string(),
3453 content: vec![ContentBlock::Text {
3454 text: "Rust is a systems programming language.".to_string(),
3455 }],
3456 reasoning_content: None,
3457 },
3458 ];
3459
3460 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3461 "Rust was created by Graydon Hoare at Mozilla.",
3462 )]));
3463
3464 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3465 let agent = AgentLoop::new(
3466 mock_client.clone(),
3467 tool_executor,
3468 test_tool_context(),
3469 AgentConfig::default(),
3470 );
3471
3472 let result = agent
3473 .execute(&existing_history, "Who created it?", None)
3474 .await
3475 .unwrap();
3476
3477 assert_eq!(result.messages.len(), 4);
3479 assert_eq!(result.messages[0].text(), "What is Rust?");
3480 assert_eq!(
3481 result.messages[1].text(),
3482 "Rust is a systems programming language."
3483 );
3484 assert_eq!(result.messages[2].text(), "Who created it?");
3485 assert_eq!(
3486 result.messages[3].text(),
3487 "Rust was created by Graydon Hoare at Mozilla."
3488 );
3489 }
3490
3491 #[tokio::test]
3492 async fn test_agent_event_stream_completeness() {
3493 let mock_client = Arc::new(MockLlmClient::new(vec![
3495 MockLlmClient::tool_call_response(
3496 "t1",
3497 "bash",
3498 serde_json::json!({"command": "echo hi"}),
3499 ),
3500 MockLlmClient::text_response("Done"),
3501 ]));
3502
3503 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3504 let agent = AgentLoop::new(
3505 mock_client,
3506 tool_executor,
3507 test_tool_context(),
3508 AgentConfig::default(),
3509 );
3510
3511 let (tx, rx) = mpsc::channel(100);
3512 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
3513 assert_eq!(result.text, "Done");
3514
3515 let events = collect_events(rx).await;
3516
3517 let event_types: Vec<&str> = events
3519 .iter()
3520 .map(|e| match e {
3521 AgentEvent::Start { .. } => "Start",
3522 AgentEvent::TurnStart { .. } => "TurnStart",
3523 AgentEvent::TurnEnd { .. } => "TurnEnd",
3524 AgentEvent::ToolEnd { .. } => "ToolEnd",
3525 AgentEvent::End { .. } => "End",
3526 _ => "Other",
3527 })
3528 .collect();
3529
3530 assert_eq!(event_types.first(), Some(&"Start"));
3532 assert_eq!(event_types.last(), Some(&"End"));
3533
3534 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
3536 assert_eq!(turn_starts, 2);
3537
3538 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
3540 assert_eq!(tool_ends, 1);
3541 }
3542
3543 #[tokio::test]
3544 async fn test_agent_multiple_tools_single_turn() {
3545 let mock_client = Arc::new(MockLlmClient::new(vec![
3547 LlmResponse {
3548 message: Message {
3549 role: "assistant".to_string(),
3550 content: vec![
3551 ContentBlock::ToolUse {
3552 id: "t1".to_string(),
3553 name: "bash".to_string(),
3554 input: serde_json::json!({"command": "echo first"}),
3555 },
3556 ContentBlock::ToolUse {
3557 id: "t2".to_string(),
3558 name: "bash".to_string(),
3559 input: serde_json::json!({"command": "echo second"}),
3560 },
3561 ],
3562 reasoning_content: None,
3563 },
3564 usage: TokenUsage {
3565 prompt_tokens: 10,
3566 completion_tokens: 5,
3567 total_tokens: 15,
3568 cache_read_tokens: None,
3569 cache_write_tokens: None,
3570 },
3571 stop_reason: Some("tool_use".to_string()),
3572 },
3573 MockLlmClient::text_response("Both commands ran"),
3574 ]));
3575
3576 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3577 let agent = AgentLoop::new(
3578 mock_client.clone(),
3579 tool_executor,
3580 test_tool_context(),
3581 AgentConfig::default(),
3582 );
3583
3584 let result = agent.execute(&[], "Run both", None).await.unwrap();
3585
3586 assert_eq!(result.text, "Both commands ran");
3587 assert_eq!(result.tool_calls_count, 2);
3588 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
3592 assert_eq!(result.messages[1].role, "assistant");
3593 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
3596 }
3597
3598 #[tokio::test]
3599 async fn test_agent_token_usage_accumulation() {
3600 let mock_client = Arc::new(MockLlmClient::new(vec![
3602 MockLlmClient::tool_call_response(
3603 "t1",
3604 "bash",
3605 serde_json::json!({"command": "echo x"}),
3606 ),
3607 MockLlmClient::text_response("Done"),
3608 ]));
3609
3610 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3611 let agent = AgentLoop::new(
3612 mock_client,
3613 tool_executor,
3614 test_tool_context(),
3615 AgentConfig::default(),
3616 );
3617
3618 let result = agent.execute(&[], "test", None).await.unwrap();
3619
3620 assert_eq!(result.usage.prompt_tokens, 20);
3623 assert_eq!(result.usage.completion_tokens, 10);
3624 assert_eq!(result.usage.total_tokens, 30);
3625 }
3626
3627 #[tokio::test]
3628 async fn test_agent_system_prompt_passed() {
3629 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3631 "I am a coding assistant.",
3632 )]));
3633
3634 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3635 let config = AgentConfig {
3636 system_prompt: Some("You are a coding assistant.".to_string()),
3637 ..Default::default()
3638 };
3639
3640 let agent = AgentLoop::new(
3641 mock_client.clone(),
3642 tool_executor,
3643 test_tool_context(),
3644 config,
3645 );
3646 let result = agent.execute(&[], "What are you?", None).await.unwrap();
3647
3648 assert_eq!(result.text, "I am a coding assistant.");
3649 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3650 }
3651
3652 #[tokio::test]
3653 async fn test_agent_max_rounds_with_persistent_tool_calls() {
3654 let mut responses = Vec::new();
3656 for i in 0..15 {
3657 responses.push(MockLlmClient::tool_call_response(
3658 &format!("t{}", i),
3659 "bash",
3660 serde_json::json!({"command": format!("echo round{}", i)}),
3661 ));
3662 }
3663
3664 let mock_client = Arc::new(MockLlmClient::new(responses));
3665 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3666 let config = AgentConfig {
3667 max_tool_rounds: 5,
3668 ..Default::default()
3669 };
3670
3671 let agent = AgentLoop::new(
3672 mock_client.clone(),
3673 tool_executor,
3674 test_tool_context(),
3675 config,
3676 );
3677 let result = agent.execute(&[], "Loop forever", None).await;
3678
3679 assert!(result.is_err());
3680 let err = result.unwrap_err().to_string();
3681 assert!(err.contains("Max tool rounds (5) exceeded"));
3682 }
3683
3684 #[tokio::test]
3685 async fn test_agent_end_event_contains_final_text() {
3686 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3687 "Final answer here",
3688 )]));
3689
3690 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3691 let agent = AgentLoop::new(
3692 mock_client,
3693 tool_executor,
3694 test_tool_context(),
3695 AgentConfig::default(),
3696 );
3697
3698 let (tx, rx) = mpsc::channel(100);
3699 agent.execute(&[], "test", Some(tx)).await.unwrap();
3700
3701 let events = collect_events(rx).await;
3702 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
3703 assert!(end_event.is_some());
3704
3705 if let AgentEvent::End { text, usage } = end_event.unwrap() {
3706 assert_eq!(text, "Final answer here");
3707 assert_eq!(usage.total_tokens, 15);
3708 }
3709 }
3710}
3711
3712#[cfg(test)]
3713mod extra_agent_tests {
3714 use super::*;
3715 use crate::agent::tests::MockLlmClient;
3716 use crate::llm::{ContentBlock, StreamEvent};
3717 use crate::queue::SessionQueueConfig;
3718 use crate::tools::ToolExecutor;
3719 use std::path::PathBuf;
3720 use std::sync::atomic::{AtomicUsize, Ordering};
3721
3722 fn test_tool_context() -> ToolContext {
3723 ToolContext::new(PathBuf::from("/tmp"))
3724 }
3725
/// The hand-written `Debug` impl names the struct and includes
/// `planning_enabled`.
#[test]
fn test_agent_config_debug() {
    let config = AgentConfig {
        system_prompt: Some("You are helpful".to_string()),
        tools: vec![],
        max_tool_rounds: 10,
        permission_checker: None,
        confirmation_manager: None,
        context_providers: vec![],
        planning_enabled: true,
        goal_tracking: false,
        hook_engine: None,
        skill_registry: None,
        ..AgentConfig::default()
    };

    let rendered = format!("{:?}", config);
    for needle in ["AgentConfig", "planning_enabled"] {
        assert!(rendered.contains(needle));
    }
}
3749
/// `AgentConfig::default()` uses `MAX_TOOL_ROUNDS`, disables planning and
/// goal tracking, and has no context providers.
#[test]
fn test_agent_config_default_values() {
    let AgentConfig {
        max_tool_rounds,
        planning_enabled,
        goal_tracking,
        context_providers,
        ..
    } = AgentConfig::default();

    assert_eq!(max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(!planning_enabled);
    assert!(!goal_tracking);
    assert!(context_providers.is_empty());
}
3758
/// `AgentEvent::Start` serializes with the `agent_start` tag and its prompt.
#[test]
fn test_agent_event_serialize_start() {
    let json = serde_json::to_string(&AgentEvent::Start {
        prompt: "Hello".to_string(),
    })
    .unwrap();
    for needle in ["agent_start", "Hello"] {
        assert!(json.contains(needle));
    }
}
3772
/// `AgentEvent::TextDelta` serializes with the `text_delta` tag.
#[test]
fn test_agent_event_serialize_text_delta() {
    let delta = AgentEvent::TextDelta {
        text: "chunk".to_string(),
    };
    let serialized = serde_json::to_string(&delta).unwrap();
    assert!(serialized.contains("text_delta"));
}
3781
/// `AgentEvent::ToolStart` serializes with the `tool_start` tag and the tool name.
#[test]
fn test_agent_event_serialize_tool_start() {
    let json = serde_json::to_string(&AgentEvent::ToolStart {
        id: "t1".to_string(),
        name: "bash".to_string(),
    })
    .unwrap();
    for needle in ["tool_start", "bash"] {
        assert!(json.contains(needle));
    }
}
3792
/// `AgentEvent::ToolEnd` serializes with the `tool_end` tag.
#[test]
fn test_agent_event_serialize_tool_end() {
    let finished = AgentEvent::ToolEnd {
        id: "t1".to_string(),
        name: "bash".to_string(),
        output: "hello".to_string(),
        exit_code: 0,
    };
    let serialized = serde_json::to_string(&finished).unwrap();
    assert!(serialized.contains("tool_end"));
}
3804
/// `AgentEvent::Error` serializes with the `error` tag and its message.
#[test]
fn test_agent_event_serialize_error() {
    let json = serde_json::to_string(&AgentEvent::Error {
        message: "oops".to_string(),
    })
    .unwrap();
    for needle in ["error", "oops"] {
        assert!(json.contains(needle));
    }
}
3814
/// `AgentEvent::ConfirmationRequired` serializes with the
/// `confirmation_required` tag.
#[test]
fn test_agent_event_serialize_confirmation_required() {
    let request = AgentEvent::ConfirmationRequired {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({"cmd": "rm"}),
        timeout_ms: 30000,
    };
    let serialized = serde_json::to_string(&request).unwrap();
    assert!(serialized.contains("confirmation_required"));
}
3826
/// `AgentEvent::ConfirmationReceived` serializes with the
/// `confirmation_received` tag.
#[test]
fn test_agent_event_serialize_confirmation_received() {
    let reply = AgentEvent::ConfirmationReceived {
        tool_id: "t1".to_string(),
        approved: true,
        reason: Some("safe".to_string()),
    };
    let serialized = serde_json::to_string(&reply).unwrap();
    assert!(serialized.contains("confirmation_received"));
}
3837
/// An expired confirmation serializes under `confirmation_timeout`.
#[test]
fn test_agent_event_serialize_confirmation_timeout() {
    let encoded = serde_json::to_string(&AgentEvent::ConfirmationTimeout {
        tool_id: String::from("t1"),
        action_taken: String::from("rejected"),
    })
    .unwrap();
    assert!(encoded.contains("confirmation_timeout"));
}
3847
/// A task handed off to an external executor serializes under `external_task_pending`.
#[test]
fn test_agent_event_serialize_external_task_pending() {
    let lane = crate::hitl::SessionLane::Execute;
    let pending = AgentEvent::ExternalTaskPending {
        task_id: "task-1".to_owned(),
        session_id: "sess-1".to_owned(),
        lane,
        command_type: "bash".to_owned(),
        payload: serde_json::json!({}),
        timeout_ms: 60000,
    };
    let encoded = serde_json::to_string(&pending).unwrap();
    assert!(encoded.contains("external_task_pending"));
}
3861
/// Completion of an external task serializes under `external_task_completed`.
#[test]
fn test_agent_event_serialize_external_task_completed() {
    let done = AgentEvent::ExternalTaskCompleted {
        task_id: String::from("task-1"),
        session_id: String::from("sess-1"),
        success: false,
    };
    let encoded = serde_json::to_string(&done).unwrap();
    assert!(encoded.contains("external_task_completed"));
}
3872
/// A permission-checker rejection serializes under `permission_denied`.
#[test]
fn test_agent_event_serialize_permission_denied() {
    let denial = AgentEvent::PermissionDenied {
        tool_id: "t1".to_owned(),
        tool_name: "bash".to_owned(),
        args: serde_json::json!({}),
        reason: "denied".to_owned(),
    };
    let encoded = serde_json::to_string(&denial).unwrap();
    assert!(encoded.contains("permission_denied"));
}
3884
/// Context compaction serializes under `context_compacted`.
#[test]
fn test_agent_event_serialize_context_compacted() {
    let compacted = AgentEvent::ContextCompacted {
        session_id: String::from("sess-1"),
        before_messages: 100,
        after_messages: 20,
        percent_before: 0.85,
    };
    let encoded = serde_json::to_string(&compacted).unwrap();
    assert!(encoded.contains("context_compacted"));
}
3896
/// The beginning of a turn serializes under `turn_start`.
#[test]
fn test_agent_event_serialize_turn_start() {
    let encoded = serde_json::to_string(&AgentEvent::TurnStart { turn: 3 }).unwrap();
    assert!(encoded.contains("turn_start"));
}
3903
/// The end of a turn, with its usage snapshot, serializes under `turn_end`.
#[test]
fn test_agent_event_serialize_turn_end() {
    let turn_end = AgentEvent::TurnEnd {
        turn: 3,
        usage: Default::default(),
    };
    let encoded = serde_json::to_string(&turn_end).unwrap();
    assert!(encoded.contains("turn_end"));
}
3913
/// The final agent event serializes under `agent_end`.
#[test]
fn test_agent_event_serialize_end() {
    let usage = TokenUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let finished = AgentEvent::End {
        text: String::from("Done"),
        usage,
    };
    let encoded = serde_json::to_string(&finished).unwrap();
    assert!(encoded.contains("agent_end"));
}
3929
/// `AgentResult` stores exactly the values it was constructed with.
#[test]
fn test_agent_result_fields() {
    let history = vec![Message::user("hello")];
    let outcome = AgentResult {
        text: String::from("output"),
        messages: history,
        usage: Default::default(),
        tool_calls_count: 3,
    };
    assert_eq!(outcome.text, "output");
    assert_eq!(outcome.messages.len(), 1);
    assert_eq!(outcome.tool_calls_count, 3);
}
3946
/// Context resolution start lists the provider names being queried.
#[test]
fn test_agent_event_serialize_context_resolving() {
    let providers: Vec<String> = ["provider1", "provider2"]
        .iter()
        .map(|p| p.to_string())
        .collect();
    let encoded = serde_json::to_string(&AgentEvent::ContextResolving { providers }).unwrap();
    for &fragment in &["context_resolving", "provider1"] {
        assert!(encoded.contains(fragment));
    }
}
3960
/// Context resolution completion reports totals, including the token count.
#[test]
fn test_agent_event_serialize_context_resolved() {
    let resolved = AgentEvent::ContextResolved {
        total_items: 5,
        total_tokens: 1000,
    };
    let encoded = serde_json::to_string(&resolved).unwrap();
    for &fragment in &["context_resolved", "1000"] {
        assert!(encoded.contains(fragment));
    }
}
3971
/// A command that exhausted its retries serializes with its tag and id.
#[test]
fn test_agent_event_serialize_command_dead_lettered() {
    let dead = AgentEvent::CommandDeadLettered {
        command_id: "cmd-1".to_owned(),
        command_type: "bash".to_owned(),
        lane: "execute".to_owned(),
        error: "timeout".to_owned(),
        attempts: 3,
    };
    let encoded = serde_json::to_string(&dead).unwrap();
    for &fragment in &["command_dead_lettered", "cmd-1"] {
        assert!(encoded.contains(fragment));
    }
}
3985
/// A command retry notice serializes with its tag and command id.
#[test]
fn test_agent_event_serialize_command_retry() {
    let retry = AgentEvent::CommandRetry {
        command_id: String::from("cmd-2"),
        command_type: String::from("read"),
        lane: String::from("query"),
        attempt: 2,
        delay_ms: 1000,
    };
    let encoded = serde_json::to_string(&retry).unwrap();
    for &fragment in &["command_retry", "cmd-2"] {
        assert!(encoded.contains(fragment));
    }
}
3999
/// Queue alerts serialize with their tag and severity level.
#[test]
fn test_agent_event_serialize_queue_alert() {
    let alert = AgentEvent::QueueAlert {
        level: "warning".to_owned(),
        alert_type: "depth".to_owned(),
        message: "Queue depth exceeded".to_owned(),
    };
    let encoded = serde_json::to_string(&alert).unwrap();
    for &fragment in &["queue_alert", "warning"] {
        assert!(encoded.contains(fragment));
    }
}
4011
/// Task list updates serialize with their tag and owning session id.
#[test]
fn test_agent_event_serialize_task_updated() {
    let update = AgentEvent::TaskUpdated {
        session_id: String::from("sess-1"),
        tasks: Vec::new(),
    };
    let encoded = serde_json::to_string(&update).unwrap();
    for &fragment in &["task_updated", "sess-1"] {
        assert!(encoded.contains(fragment));
    }
}
4022
/// Storing a memory serializes with its tag and memory id.
#[test]
fn test_agent_event_serialize_memory_stored() {
    let stored = AgentEvent::MemoryStored {
        memory_id: "mem-1".to_owned(),
        memory_type: "conversation".to_owned(),
        importance: 0.8,
        tags: vec![String::from("important")],
    };
    let encoded = serde_json::to_string(&stored).unwrap();
    for &fragment in &["memory_stored", "mem-1"] {
        assert!(encoded.contains(fragment));
    }
}
4035
/// Recalling a memory serializes with its tag and memory id.
#[test]
fn test_agent_event_serialize_memory_recalled() {
    let recalled = AgentEvent::MemoryRecalled {
        memory_id: String::from("mem-2"),
        content: String::from("Previous conversation"),
        relevance: 0.9,
    };
    let encoded = serde_json::to_string(&recalled).unwrap();
    for &fragment in &["memory_recalled", "mem-2"] {
        assert!(encoded.contains(fragment));
    }
}
4047
/// A memory search serializes with its tag and the query string.
#[test]
fn test_agent_event_serialize_memories_searched() {
    let searched = AgentEvent::MemoriesSearched {
        query: Some("search term".to_owned()),
        tags: vec!["tag1".to_owned()],
        result_count: 5,
    };
    let encoded = serde_json::to_string(&searched).unwrap();
    for &fragment in &["memories_searched", "search term"] {
        assert!(encoded.contains(fragment));
    }
}
4059
/// Clearing a memory tier serializes with its tag and tier name.
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let cleared = AgentEvent::MemoryCleared {
        tier: String::from("short_term"),
        count: 10,
    };
    let encoded = serde_json::to_string(&cleared).unwrap();
    for &fragment in &["memory_cleared", "short_term"] {
        assert!(encoded.contains(fragment));
    }
}
4070
/// A spawned subagent serializes with its tag and agent kind.
#[test]
fn test_agent_event_serialize_subagent_start() {
    let spawned = AgentEvent::SubagentStart {
        task_id: "task-1".to_owned(),
        session_id: "child-sess".to_owned(),
        parent_session_id: "parent-sess".to_owned(),
        agent: "explore".to_owned(),
        description: "Explore codebase".to_owned(),
    };
    let encoded = serde_json::to_string(&spawned).unwrap();
    for &fragment in &["subagent_start", "explore"] {
        assert!(encoded.contains(fragment));
    }
}
4084
/// Subagent progress updates serialize with their tag and status text.
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let progress_meta = serde_json::json!({"progress": 50});
    let progress = AgentEvent::SubagentProgress {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        status: String::from("processing"),
        metadata: progress_meta,
    };
    let encoded = serde_json::to_string(&progress).unwrap();
    for &fragment in &["subagent_progress", "processing"] {
        assert!(encoded.contains(fragment));
    }
}
4097
/// A finished subagent serializes with its tag and output text.
#[test]
fn test_agent_event_serialize_subagent_end() {
    let finished = AgentEvent::SubagentEnd {
        task_id: "task-1".to_owned(),
        session_id: "child-sess".to_owned(),
        agent: "explore".to_owned(),
        output: "Found 10 files".to_owned(),
        success: true,
    };
    let encoded = serde_json::to_string(&finished).unwrap();
    for &fragment in &["subagent_end", "Found 10 files"] {
        assert!(encoded.contains(fragment));
    }
}
4111
/// The start of planning serializes with its tag and the prompt being planned.
#[test]
fn test_agent_event_serialize_planning_start() {
    let encoded = serde_json::to_string(&AgentEvent::PlanningStart {
        prompt: String::from("Build a web app"),
    })
    .unwrap();
    for &fragment in &["planning_start", "Build a web app"] {
        assert!(encoded.contains(fragment));
    }
}
4121
/// The end of planning serializes with its tag and the `estimated_steps` field.
#[test]
fn test_agent_event_serialize_planning_end() {
    let plan = crate::planning::ExecutionPlan::new(
        "Test goal".to_string(),
        crate::planning::Complexity::Simple,
    );
    let planned = AgentEvent::PlanningEnd {
        plan,
        estimated_steps: 3,
    };
    let encoded = serde_json::to_string(&planned).unwrap();
    for &fragment in &["planning_end", "estimated_steps"] {
        assert!(encoded.contains(fragment));
    }
}
4134
/// Starting a plan step serializes with its tag and the step description.
#[test]
fn test_agent_event_serialize_step_start() {
    let step = AgentEvent::StepStart {
        step_id: String::from("step-1"),
        description: String::from("Initialize project"),
        step_number: 1,
        total_steps: 5,
    };
    let encoded = serde_json::to_string(&step).unwrap();
    for &fragment in &["step_start", "Initialize project"] {
        assert!(encoded.contains(fragment));
    }
}
4147
/// Finishing a plan step serializes with its tag and the step id.
#[test]
fn test_agent_event_serialize_step_end() {
    let step = AgentEvent::StepEnd {
        step_id: "step-1".to_owned(),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    };
    let encoded = serde_json::to_string(&step).unwrap();
    for &fragment in &["step_end", "step-1"] {
        assert!(encoded.contains(fragment));
    }
}
4160
/// An extracted goal serializes under `goal_extracted`.
#[test]
fn test_agent_event_serialize_goal_extracted() {
    let extracted = AgentEvent::GoalExtracted {
        goal: crate::planning::AgentGoal::new("Complete the task".to_string()),
    };
    let encoded = serde_json::to_string(&extracted).unwrap();
    assert!(encoded.contains("goal_extracted"));
}
4169
/// Goal progress serializes with its tag and the fractional progress value.
#[test]
fn test_agent_event_serialize_goal_progress() {
    let progress = AgentEvent::GoalProgress {
        goal: String::from("Build app"),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    };
    let encoded = serde_json::to_string(&progress).unwrap();
    for &fragment in &["goal_progress", "0.5"] {
        assert!(encoded.contains(fragment));
    }
}
4182
/// An achieved goal serializes with its tag and elapsed duration.
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let achieved = AgentEvent::GoalAchieved {
        goal: String::from("Build app"),
        total_steps: 4,
        duration_ms: 5000,
    };
    let encoded = serde_json::to_string(&achieved).unwrap();
    for &fragment in &["goal_achieved", "5000"] {
        assert!(encoded.contains(fragment));
    }
}
4194
4195 #[tokio::test]
4196 async fn test_extract_goal_with_json_response() {
4197 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4199 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4200 )]));
4201 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4202 let agent = AgentLoop::new(
4203 mock_client,
4204 tool_executor,
4205 test_tool_context(),
4206 AgentConfig::default(),
4207 );
4208
4209 let goal = agent.extract_goal("Build a web app").await.unwrap();
4210 assert_eq!(goal.description, "Build web app");
4211 assert_eq!(goal.success_criteria.len(), 2);
4212 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4213 }
4214
4215 #[tokio::test]
4216 async fn test_extract_goal_fallback_on_non_json() {
4217 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4219 "Some non-JSON response",
4220 )]));
4221 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4222 let agent = AgentLoop::new(
4223 mock_client,
4224 tool_executor,
4225 test_tool_context(),
4226 AgentConfig::default(),
4227 );
4228
4229 let goal = agent.extract_goal("Do something").await.unwrap();
4230 assert_eq!(goal.description, "Do something");
4232 assert_eq!(goal.success_criteria.len(), 2);
4234 }
4235
4236 #[tokio::test]
4237 async fn test_check_goal_achievement_json_yes() {
4238 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4239 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4240 )]));
4241 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4242 let agent = AgentLoop::new(
4243 mock_client,
4244 tool_executor,
4245 test_tool_context(),
4246 AgentConfig::default(),
4247 );
4248
4249 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4250 let achieved = agent
4251 .check_goal_achievement(&goal, "All done")
4252 .await
4253 .unwrap();
4254 assert!(achieved);
4255 }
4256
4257 #[tokio::test]
4258 async fn test_check_goal_achievement_fallback_not_done() {
4259 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4261 "invalid json",
4262 )]));
4263 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4264 let agent = AgentLoop::new(
4265 mock_client,
4266 tool_executor,
4267 test_tool_context(),
4268 AgentConfig::default(),
4269 );
4270
4271 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4272 let achieved = agent
4274 .check_goal_achievement(&goal, "still working")
4275 .await
4276 .unwrap();
4277 assert!(!achieved);
4278 }
4279
/// With no context results, the augmented prompt is just the base prompt.
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(
        llm,
        executor,
        test_tool_context(),
        AgentConfig {
            system_prompt: Some("Base prompt".to_string()),
            ..Default::default()
        },
    );

    let augmented = agent.build_augmented_system_prompt(&[]);
    assert_eq!(augmented.as_deref(), Some("Base prompt"));
}
4297
/// Without a base prompt or any context, no system prompt is produced.
#[test]
fn test_build_augmented_system_prompt_no_system_prompt() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    assert!(agent.build_augmented_system_prompt(&[]).is_none());
}
4312
/// Context alone, with no base prompt, still yields a prompt containing a
/// `<context` section with the item content.
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let item = ContextItem::new("id1", ContextType::Resource, "Content");
    let gathered = vec![ContextResult {
        provider: String::from("test"),
        items: vec![item],
        total_tokens: 10,
        truncated: false,
    }];

    let augmented = agent.build_augmented_system_prompt(&gathered);
    assert!(augmented.is_some());
    let prompt_text = augmented.unwrap();
    for &fragment in &["<context", "Content"] {
        assert!(prompt_text.contains(fragment));
    }
}
4339
/// Cloning an `AgentResult` preserves its text and tool-call count.
#[test]
fn test_agent_result_clone() {
    let original = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: Default::default(),
        tool_calls_count: 3,
    };
    let copy = original.clone();
    assert_eq!(copy.text, original.text);
    assert_eq!(copy.tool_calls_count, original.tool_calls_count);
}
4356
/// The `Debug` output names the type and includes the result text.
#[test]
fn test_agent_result_debug() {
    let outcome = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: Default::default(),
        tool_calls_count: 3,
    };
    let rendered = format!("{:?}", outcome);
    for &fragment in &["AgentResult", "output"] {
        assert!(rendered.contains(fragment));
    }
}
4369
4370 #[tokio::test]
4379 async fn test_tool_command_command_type() {
4380 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4381 let cmd = ToolCommand {
4382 tool_executor: executor,
4383 tool_name: "read".to_string(),
4384 tool_args: serde_json::json!({"file": "test.rs"}),
4385 skill_registry: None,
4386 tool_context: test_tool_context(),
4387 };
4388 assert_eq!(cmd.command_type(), "read");
4389 }
4390
4391 #[tokio::test]
4392 async fn test_tool_command_payload() {
4393 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4394 let args = serde_json::json!({"file": "test.rs", "offset": 10});
4395 let cmd = ToolCommand {
4396 tool_executor: executor,
4397 tool_name: "read".to_string(),
4398 tool_args: args.clone(),
4399 skill_registry: None,
4400 tool_context: test_tool_context(),
4401 };
4402 assert_eq!(cmd.payload(), args);
4403 }
4404
4405 #[tokio::test(flavor = "multi_thread")]
4410 async fn test_agent_loop_with_queue() {
4411 use tokio::sync::broadcast;
4412
4413 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4414 "Hello",
4415 )]));
4416 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4417 let config = AgentConfig::default();
4418
4419 let (event_tx, _) = broadcast::channel(100);
4420 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
4421 .await
4422 .unwrap();
4423
4424 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
4425 .with_queue(Arc::new(queue));
4426
4427 assert!(agent.command_queue.is_some());
4428 }
4429
4430 #[tokio::test]
4431 async fn test_agent_loop_without_queue() {
4432 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4433 "Hello",
4434 )]));
4435 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4436 let config = AgentConfig::default();
4437
4438 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4439
4440 assert!(agent.command_queue.is_none());
4441 }
4442
4443 #[tokio::test]
4448 async fn test_execute_plan_parallel_independent() {
4449 use crate::planning::{Complexity, ExecutionPlan, Task};
4450
4451 let mock_client = Arc::new(MockLlmClient::new(vec![
4454 MockLlmClient::text_response("Step 1 done"),
4455 MockLlmClient::text_response("Step 2 done"),
4456 MockLlmClient::text_response("Step 3 done"),
4457 ]));
4458
4459 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4460 let config = AgentConfig::default();
4461 let agent = AgentLoop::new(
4462 mock_client.clone(),
4463 tool_executor,
4464 test_tool_context(),
4465 config,
4466 );
4467
4468 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
4469 plan.add_step(Task::new("s1", "First step"));
4470 plan.add_step(Task::new("s2", "Second step"));
4471 plan.add_step(Task::new("s3", "Third step"));
4472
4473 let (tx, mut rx) = mpsc::channel(100);
4474 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
4475
4476 assert_eq!(result.usage.total_tokens, 45);
4478
4479 let mut step_starts = Vec::new();
4481 let mut step_ends = Vec::new();
4482 rx.close();
4483 while let Some(event) = rx.recv().await {
4484 match event {
4485 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
4486 AgentEvent::StepEnd {
4487 step_id, status, ..
4488 } => {
4489 assert_eq!(status, TaskStatus::Completed);
4490 step_ends.push(step_id);
4491 }
4492 _ => {}
4493 }
4494 }
4495 assert_eq!(step_starts.len(), 3);
4496 assert_eq!(step_ends.len(), 3);
4497 }
4498
4499 #[tokio::test]
4500 async fn test_execute_plan_respects_dependencies() {
4501 use crate::planning::{Complexity, ExecutionPlan, Task};
4502
4503 let mock_client = Arc::new(MockLlmClient::new(vec![
4506 MockLlmClient::text_response("Step 1 done"),
4507 MockLlmClient::text_response("Step 2 done"),
4508 MockLlmClient::text_response("Step 3 done"),
4509 ]));
4510
4511 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4512 let config = AgentConfig::default();
4513 let agent = AgentLoop::new(
4514 mock_client.clone(),
4515 tool_executor,
4516 test_tool_context(),
4517 config,
4518 );
4519
4520 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
4521 plan.add_step(Task::new("s1", "Independent A"));
4522 plan.add_step(Task::new("s2", "Independent B"));
4523 plan.add_step(
4524 Task::new("s3", "Depends on A+B")
4525 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
4526 );
4527
4528 let (tx, mut rx) = mpsc::channel(100);
4529 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
4530
4531 assert_eq!(result.usage.total_tokens, 45);
4533
4534 let mut events = Vec::new();
4536 rx.close();
4537 while let Some(event) = rx.recv().await {
4538 match &event {
4539 AgentEvent::StepStart { step_id, .. } => {
4540 events.push(format!("start:{}", step_id));
4541 }
4542 AgentEvent::StepEnd { step_id, .. } => {
4543 events.push(format!("end:{}", step_id));
4544 }
4545 _ => {}
4546 }
4547 }
4548
4549 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
4551 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
4552 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
4553 assert!(
4554 s3_start > s1_end,
4555 "s3 started before s1 ended: {:?}",
4556 events
4557 );
4558 assert!(
4559 s3_start > s2_end,
4560 "s3 started before s2 ended: {:?}",
4561 events
4562 );
4563
4564 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
4566 }
4567
4568 #[tokio::test]
4569 async fn test_execute_plan_handles_step_failure() {
4570 use crate::planning::{Complexity, ExecutionPlan, Task};
4571
4572 let mock_client = Arc::new(MockLlmClient::new(vec![
4582 MockLlmClient::text_response("s1 done"),
4584 MockLlmClient::text_response("s3 done"),
4585 ]));
4588
4589 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4590 let config = AgentConfig::default();
4591 let agent = AgentLoop::new(
4592 mock_client.clone(),
4593 tool_executor,
4594 test_tool_context(),
4595 config,
4596 );
4597
4598 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
4599 plan.add_step(Task::new("s1", "Independent step"));
4600 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
4601 plan.add_step(Task::new("s3", "Another independent"));
4602 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
4603
4604 let (tx, mut rx) = mpsc::channel(100);
4605 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
4606
4607 let mut completed_steps = Vec::new();
4610 let mut failed_steps = Vec::new();
4611 rx.close();
4612 while let Some(event) = rx.recv().await {
4613 if let AgentEvent::StepEnd {
4614 step_id, status, ..
4615 } = event
4616 {
4617 match status {
4618 TaskStatus::Completed => completed_steps.push(step_id),
4619 TaskStatus::Failed => failed_steps.push(step_id),
4620 _ => {}
4621 }
4622 }
4623 }
4624
4625 assert!(
4626 completed_steps.contains(&"s1".to_string()),
4627 "s1 should complete"
4628 );
4629 assert!(
4630 completed_steps.contains(&"s3".to_string()),
4631 "s3 should complete"
4632 );
4633 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
4634 assert!(
4636 !completed_steps.contains(&"s4".to_string()),
4637 "s4 should not complete"
4638 );
4639 assert!(
4640 !failed_steps.contains(&"s4".to_string()),
4641 "s4 should not fail (never started)"
4642 );
4643 }
4644
/// Default resilience settings: 2 parse retries, no tool timeout, and a
/// circuit-breaker threshold of 3.
#[test]
fn test_agent_config_resilience_defaults() {
    let defaults = AgentConfig::default();
    assert_eq!(defaults.max_parse_retries, 2);
    assert!(defaults.tool_timeout_ms.is_none());
    assert_eq!(defaults.circuit_breaker_threshold, 3);
}
4656
4657 #[tokio::test]
4659 async fn test_parse_error_recovery_bails_after_threshold() {
4660 let mock_client = Arc::new(MockLlmClient::new(vec![
4662 MockLlmClient::tool_call_response(
4663 "c1",
4664 "bash",
4665 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
4666 ),
4667 MockLlmClient::tool_call_response(
4668 "c2",
4669 "bash",
4670 serde_json::json!({"__parse_error": "missing closing brace"}),
4671 ),
4672 MockLlmClient::tool_call_response(
4673 "c3",
4674 "bash",
4675 serde_json::json!({"__parse_error": "still broken"}),
4676 ),
4677 MockLlmClient::text_response("Done"), ]));
4679
4680 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4681 let config = AgentConfig {
4682 max_parse_retries: 2,
4683 ..AgentConfig::default()
4684 };
4685 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4686 let result = agent.execute(&[], "Do something", None).await;
4687 assert!(result.is_err(), "should bail after parse error threshold");
4688 let err = result.unwrap_err().to_string();
4689 assert!(
4690 err.contains("malformed tool arguments"),
4691 "error should mention malformed tool arguments, got: {}",
4692 err
4693 );
4694 }
4695
4696 #[tokio::test]
4698 async fn test_parse_error_counter_resets_on_success() {
4699 let mock_client = Arc::new(MockLlmClient::new(vec![
4703 MockLlmClient::tool_call_response(
4704 "c1",
4705 "bash",
4706 serde_json::json!({"__parse_error": "bad args"}),
4707 ),
4708 MockLlmClient::tool_call_response(
4709 "c2",
4710 "bash",
4711 serde_json::json!({"__parse_error": "bad args again"}),
4712 ),
4713 MockLlmClient::tool_call_response(
4715 "c3",
4716 "bash",
4717 serde_json::json!({"command": "echo ok"}),
4718 ),
4719 MockLlmClient::text_response("All done"),
4720 ]));
4721
4722 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4723 let config = AgentConfig {
4724 max_parse_retries: 2,
4725 ..AgentConfig::default()
4726 };
4727 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4728 let result = agent.execute(&[], "Do something", None).await;
4729 assert!(
4730 result.is_ok(),
4731 "should not bail — counter reset after successful tool, got: {:?}",
4732 result.err()
4733 );
4734 assert_eq!(result.unwrap().text, "All done");
4735 }
4736
4737 #[tokio::test]
4739 async fn test_tool_timeout_produces_error_result() {
4740 let mock_client = Arc::new(MockLlmClient::new(vec![
4741 MockLlmClient::tool_call_response(
4742 "t1",
4743 "bash",
4744 serde_json::json!({"command": "sleep 10"}),
4745 ),
4746 MockLlmClient::text_response("The command timed out."),
4747 ]));
4748
4749 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4750 let config = AgentConfig {
4751 tool_timeout_ms: Some(50),
4753 ..AgentConfig::default()
4754 };
4755 let agent = AgentLoop::new(
4756 mock_client.clone(),
4757 tool_executor,
4758 test_tool_context(),
4759 config,
4760 );
4761 let result = agent.execute(&[], "Run sleep", None).await;
4762 assert!(
4763 result.is_ok(),
4764 "session should continue after tool timeout: {:?}",
4765 result.err()
4766 );
4767 assert_eq!(result.unwrap().text, "The command timed out.");
4768 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
4770 }
4771
4772 #[tokio::test]
4774 async fn test_tool_within_timeout_succeeds() {
4775 let mock_client = Arc::new(MockLlmClient::new(vec![
4776 MockLlmClient::tool_call_response(
4777 "t1",
4778 "bash",
4779 serde_json::json!({"command": "echo fast"}),
4780 ),
4781 MockLlmClient::text_response("Command succeeded."),
4782 ]));
4783
4784 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4785 let config = AgentConfig {
4786 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
4788 };
4789 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4790 let result = agent.execute(&[], "Run something fast", None).await;
4791 assert!(
4792 result.is_ok(),
4793 "fast tool should succeed: {:?}",
4794 result.err()
4795 );
4796 assert_eq!(result.unwrap().text, "Command succeeded.");
4797 }
4798
4799 #[tokio::test]
4801 async fn test_circuit_breaker_retries_non_streaming() {
4802 let mock_client = Arc::new(MockLlmClient::new(vec![]));
4805
4806 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4807 let config = AgentConfig {
4808 circuit_breaker_threshold: 2,
4809 ..AgentConfig::default()
4810 };
4811 let agent = AgentLoop::new(
4812 mock_client.clone(),
4813 tool_executor,
4814 test_tool_context(),
4815 config,
4816 );
4817 let result = agent.execute(&[], "Hello", None).await;
4818 assert!(result.is_err(), "should fail when LLM always errors");
4819 let err = result.unwrap_err().to_string();
4820 assert!(
4821 err.contains("circuit breaker"),
4822 "error should mention circuit breaker, got: {}",
4823 err
4824 );
4825 assert_eq!(
4826 mock_client.call_count.load(Ordering::SeqCst),
4827 2,
4828 "should make exactly threshold=2 LLM calls"
4829 );
4830 }
4831
4832 #[tokio::test]
4834 async fn test_circuit_breaker_threshold_one_no_retry() {
4835 let mock_client = Arc::new(MockLlmClient::new(vec![]));
4836
4837 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4838 let config = AgentConfig {
4839 circuit_breaker_threshold: 1,
4840 ..AgentConfig::default()
4841 };
4842 let agent = AgentLoop::new(
4843 mock_client.clone(),
4844 tool_executor,
4845 test_tool_context(),
4846 config,
4847 );
4848 let result = agent.execute(&[], "Hello", None).await;
4849 assert!(result.is_err());
4850 assert_eq!(
4851 mock_client.call_count.load(Ordering::SeqCst),
4852 1,
4853 "with threshold=1 exactly one attempt should be made"
4854 );
4855 }
4856
4857 #[tokio::test]
4859 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
4860 struct FailOnceThenSucceed {
4862 inner: MockLlmClient,
4863 failed_once: std::sync::atomic::AtomicBool,
4864 call_count: AtomicUsize,
4865 }
4866
4867 #[async_trait::async_trait]
4868 impl LlmClient for FailOnceThenSucceed {
4869 async fn complete(
4870 &self,
4871 messages: &[Message],
4872 system: Option<&str>,
4873 tools: &[ToolDefinition],
4874 ) -> Result<LlmResponse> {
4875 self.call_count.fetch_add(1, Ordering::SeqCst);
4876 let already_failed = self
4877 .failed_once
4878 .swap(true, std::sync::atomic::Ordering::SeqCst);
4879 if !already_failed {
4880 anyhow::bail!("transient network error");
4881 }
4882 self.inner.complete(messages, system, tools).await
4883 }
4884
4885 async fn complete_streaming(
4886 &self,
4887 messages: &[Message],
4888 system: Option<&str>,
4889 tools: &[ToolDefinition],
4890 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
4891 self.inner.complete_streaming(messages, system, tools).await
4892 }
4893 }
4894
4895 let mock = Arc::new(FailOnceThenSucceed {
4896 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
4897 failed_once: std::sync::atomic::AtomicBool::new(false),
4898 call_count: AtomicUsize::new(0),
4899 });
4900
4901 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4902 let config = AgentConfig {
4903 circuit_breaker_threshold: 3,
4904 ..AgentConfig::default()
4905 };
4906 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
4907 let result = agent.execute(&[], "Hello", None).await;
4908 assert!(
4909 result.is_ok(),
4910 "should succeed when LLM recovers within threshold: {:?}",
4911 result.err()
4912 );
4913 assert_eq!(result.unwrap().text, "Recovered!");
4914 assert_eq!(
4915 mock.call_count.load(Ordering::SeqCst),
4916 2,
4917 "should have made exactly 2 calls (1 fail + 1 success)"
4918 );
4919 }
4920}