1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationManager;
14use crate::hooks::{
15 GenerateEndEvent, GenerateStartEvent, HookEngine, HookEvent, HookResult, PostToolUseEvent,
16 PreToolUseEvent, TokenUsageInfo, ToolCallInfo, ToolResultData,
17};
18use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
19use crate::permissions::{PermissionDecision, PermissionPolicy};
20use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
21use crate::tools::skill::Skill;
22use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
23use anyhow::{Context, Result};
24use futures::future::join_all;
25use serde::{Deserialize, Serialize};
26use std::sync::Arc;
27use std::time::Duration;
28use tokio::sync::{mpsc, RwLock};
29use tracing::Instrument;
30
/// Default cap on the number of agent turns per execution; used as the default
/// value of `AgentConfig::max_tool_rounds`.
const MAX_TOOL_ROUNDS: usize = 50;
33
/// Configuration consumed by `AgentLoop`.
#[derive(Clone)]
pub struct AgentConfig {
    /// Base system prompt. Context-provider output, when any is resolved, is
    /// appended to it (or used on its own when this is `None`).
    pub system_prompt: Option<String>,
    /// Tool definitions passed to the LLM on every completion request.
    pub tools: Vec<ToolDefinition>,
    /// Maximum number of turns; the loop fails with an error once it is exceeded.
    pub max_tool_rounds: usize,
    /// Policy checked before each tool call. When absent, every call is treated
    /// as `PermissionDecision::Ask`.
    pub permission_policy: Option<Arc<RwLock<PermissionPolicy>>>,
    /// Human-in-the-loop confirmation manager used for `Ask` decisions. Without
    /// one, tool calls that resolve to `Ask` are refused.
    pub confirmation_manager: Option<Arc<ConfirmationManager>>,
    /// Providers queried at the start of an execution to augment the system
    /// prompt, and notified when the execution completes with a session id.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When `true`, `execute_with_session` delegates to the planning path.
    pub planning_enabled: bool,
    /// NOTE(review): not read in the visible part of this file; presumably
    /// toggles goal extraction/progress events. Confirm against the planner.
    pub goal_tracking: bool,
    /// Loaded skills; those that declare `allowed_tools` restrict which tool
    /// calls may run.
    pub skill_tool_filters: Vec<Skill>,
    /// Optional hook engine fired around LLM generation and tool use.
    pub hook_engine: Option<Arc<HookEngine>>,
}
58
59impl std::fmt::Debug for AgentConfig {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("AgentConfig")
62 .field("system_prompt", &self.system_prompt)
63 .field("tools", &self.tools)
64 .field("max_tool_rounds", &self.max_tool_rounds)
65 .field("permission_policy", &self.permission_policy.is_some())
66 .field("confirmation_manager", &self.confirmation_manager.is_some())
67 .field("context_providers", &self.context_providers.len())
68 .field("planning_enabled", &self.planning_enabled)
69 .field("goal_tracking", &self.goal_tracking)
70 .field("skill_tool_filters", &self.skill_tool_filters.len())
71 .field("hook_engine", &self.hook_engine.is_some())
72 .finish()
73 }
74}
75
76impl Default for AgentConfig {
77 fn default() -> Self {
78 Self {
79 system_prompt: None,
80 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
82 permission_policy: None,
83 confirmation_manager: None,
84 context_providers: Vec::new(),
85 planning_enabled: false,
86 goal_tracking: false,
87 skill_tool_filters: Vec::new(),
88 hook_engine: None,
89 }
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(tag = "type")]
100#[non_exhaustive]
101pub enum AgentEvent {
102 #[serde(rename = "agent_start")]
104 Start { prompt: String },
105
106 #[serde(rename = "turn_start")]
108 TurnStart { turn: usize },
109
110 #[serde(rename = "text_delta")]
112 TextDelta { text: String },
113
114 #[serde(rename = "tool_start")]
116 ToolStart { id: String, name: String },
117
118 #[serde(rename = "tool_end")]
120 ToolEnd {
121 id: String,
122 name: String,
123 output: String,
124 exit_code: i32,
125 },
126
127 #[serde(rename = "tool_output_delta")]
129 ToolOutputDelta {
130 id: String,
131 name: String,
132 delta: String,
133 },
134
135 #[serde(rename = "turn_end")]
137 TurnEnd { turn: usize, usage: TokenUsage },
138
139 #[serde(rename = "agent_end")]
141 End { text: String, usage: TokenUsage },
142
143 #[serde(rename = "error")]
145 Error { message: String },
146
147 #[serde(rename = "confirmation_required")]
149 ConfirmationRequired {
150 tool_id: String,
151 tool_name: String,
152 args: serde_json::Value,
153 timeout_ms: u64,
154 },
155
156 #[serde(rename = "confirmation_received")]
158 ConfirmationReceived {
159 tool_id: String,
160 approved: bool,
161 reason: Option<String>,
162 },
163
164 #[serde(rename = "confirmation_timeout")]
166 ConfirmationTimeout {
167 tool_id: String,
168 action_taken: String, },
170
171 #[serde(rename = "external_task_pending")]
173 ExternalTaskPending {
174 task_id: String,
175 session_id: String,
176 lane: crate::hitl::SessionLane,
177 command_type: String,
178 payload: serde_json::Value,
179 timeout_ms: u64,
180 },
181
182 #[serde(rename = "external_task_completed")]
184 ExternalTaskCompleted {
185 task_id: String,
186 session_id: String,
187 success: bool,
188 },
189
190 #[serde(rename = "permission_denied")]
192 PermissionDenied {
193 tool_id: String,
194 tool_name: String,
195 args: serde_json::Value,
196 reason: String,
197 },
198
199 #[serde(rename = "context_resolving")]
201 ContextResolving { providers: Vec<String> },
202
203 #[serde(rename = "context_resolved")]
205 ContextResolved {
206 total_items: usize,
207 total_tokens: usize,
208 },
209
210 #[serde(rename = "command_dead_lettered")]
215 CommandDeadLettered {
216 command_id: String,
217 command_type: String,
218 lane: String,
219 error: String,
220 attempts: u32,
221 },
222
223 #[serde(rename = "command_retry")]
225 CommandRetry {
226 command_id: String,
227 command_type: String,
228 lane: String,
229 attempt: u32,
230 delay_ms: u64,
231 },
232
233 #[serde(rename = "queue_alert")]
235 QueueAlert {
236 level: String,
237 alert_type: String,
238 message: String,
239 },
240
241 #[serde(rename = "task_updated")]
246 TaskUpdated {
247 session_id: String,
248 tasks: Vec<crate::planning::Task>,
249 },
250
251 #[serde(rename = "memory_stored")]
256 MemoryStored {
257 memory_id: String,
258 memory_type: String,
259 importance: f32,
260 tags: Vec<String>,
261 },
262
263 #[serde(rename = "memory_recalled")]
265 MemoryRecalled {
266 memory_id: String,
267 content: String,
268 relevance: f32,
269 },
270
271 #[serde(rename = "memories_searched")]
273 MemoriesSearched {
274 query: Option<String>,
275 tags: Vec<String>,
276 result_count: usize,
277 },
278
279 #[serde(rename = "memory_cleared")]
281 MemoryCleared {
282 tier: String, count: u64,
284 },
285
286 #[serde(rename = "subagent_start")]
291 SubagentStart {
292 task_id: String,
294 session_id: String,
296 parent_session_id: String,
298 agent: String,
300 description: String,
302 },
303
304 #[serde(rename = "subagent_progress")]
306 SubagentProgress {
307 task_id: String,
309 session_id: String,
311 status: String,
313 metadata: serde_json::Value,
315 },
316
317 #[serde(rename = "subagent_end")]
319 SubagentEnd {
320 task_id: String,
322 session_id: String,
324 agent: String,
326 output: String,
328 success: bool,
330 },
331
332 #[serde(rename = "planning_start")]
337 PlanningStart { prompt: String },
338
339 #[serde(rename = "planning_end")]
341 PlanningEnd {
342 plan: ExecutionPlan,
343 estimated_steps: usize,
344 },
345
346 #[serde(rename = "step_start")]
348 StepStart {
349 step_id: String,
350 description: String,
351 step_number: usize,
352 total_steps: usize,
353 },
354
355 #[serde(rename = "step_end")]
357 StepEnd {
358 step_id: String,
359 status: TaskStatus,
360 step_number: usize,
361 total_steps: usize,
362 },
363
364 #[serde(rename = "goal_extracted")]
366 GoalExtracted { goal: AgentGoal },
367
368 #[serde(rename = "goal_progress")]
370 GoalProgress {
371 goal: String,
372 progress: f32,
373 completed_steps: usize,
374 total_steps: usize,
375 },
376
377 #[serde(rename = "goal_achieved")]
379 GoalAchieved {
380 goal: String,
381 total_steps: usize,
382 duration_ms: i64,
383 },
384
385 #[serde(rename = "context_compacted")]
390 ContextCompacted {
391 session_id: String,
392 before_messages: usize,
393 after_messages: usize,
394 percent_before: f32,
395 },
396
397 #[serde(rename = "persistence_failed")]
402 PersistenceFailed {
403 session_id: String,
404 operation: String,
405 error: String,
406 },
407}
408
/// Outcome of a completed agent execution.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Text of the final LLM response (the one that requested no tools).
    pub text: String,
    /// Full conversation: the supplied history, the user prompt, and every
    /// assistant message and tool result produced during the run.
    pub messages: Vec<Message>,
    /// Token usage summed over all LLM calls of the run.
    pub usage: TokenUsage,
    /// Number of tool calls the LLM requested, including refused ones.
    pub tool_calls_count: usize,
}
417
/// Drives the prompt -> LLM -> tool -> LLM loop for one agent.
pub struct AgentLoop {
    /// Client used for LLM completions (streaming or not).
    llm_client: Arc<dyn LlmClient>,
    /// Executor that runs the tools requested by the LLM.
    tool_executor: Arc<ToolExecutor>,
    /// Template tool context; cloned for every tool invocation.
    tool_context: ToolContext,
    /// Behavioural configuration of the loop.
    config: AgentConfig,
    /// Optional per-tool metrics sink, set via `with_tool_metrics`.
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
}
427
428impl AgentLoop {
429 pub fn new(
430 llm_client: Arc<dyn LlmClient>,
431 tool_executor: Arc<ToolExecutor>,
432 tool_context: ToolContext,
433 config: AgentConfig,
434 ) -> Self {
435 Self {
436 llm_client,
437 tool_executor,
438 tool_context,
439 config,
440 tool_metrics: None,
441 }
442 }
443
444 pub fn with_tool_metrics(
446 mut self,
447 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
448 ) -> Self {
449 self.tool_metrics = Some(metrics);
450 self
451 }
452
453 fn streaming_tool_context(
462 &self,
463 event_tx: &Option<mpsc::Sender<AgentEvent>>,
464 tool_id: &str,
465 tool_name: &str,
466 ) -> ToolContext {
467 let mut ctx = self.tool_context.clone();
468 if let Some(agent_tx) = event_tx {
469 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
470 ctx.event_tx = Some(tool_tx);
471
472 let agent_tx = agent_tx.clone();
473 let tool_id = tool_id.to_string();
474 let tool_name = tool_name.to_string();
475 tokio::spawn(async move {
476 while let Some(event) = tool_rx.recv().await {
477 match event {
478 ToolStreamEvent::OutputDelta(delta) => {
479 agent_tx
480 .send(AgentEvent::ToolOutputDelta {
481 id: tool_id.clone(),
482 name: tool_name.clone(),
483 delta,
484 })
485 .await
486 .ok();
487 }
488 }
489 }
490 });
491 }
492 ctx
493 }
494
495 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
499 if self.config.context_providers.is_empty() {
500 return Vec::new();
501 }
502
503 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
504
505 let futures = self
506 .config
507 .context_providers
508 .iter()
509 .map(|p| p.query(&query));
510 let outcomes = join_all(futures).await;
511
512 outcomes
513 .into_iter()
514 .enumerate()
515 .filter_map(|(i, r)| match r {
516 Ok(result) if !result.is_empty() => Some(result),
517 Ok(_) => None,
518 Err(e) => {
519 tracing::warn!(
520 "Context provider '{}' failed: {}",
521 self.config.context_providers[i].name(),
522 e
523 );
524 None
525 }
526 })
527 .collect()
528 }
529
530 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
532 if context_results.is_empty() {
533 return self.config.system_prompt.clone();
534 }
535
536 let context_xml: String = context_results
538 .iter()
539 .map(|r| r.to_xml())
540 .collect::<Vec<_>>()
541 .join("\n\n");
542
543 match &self.config.system_prompt {
545 Some(system) => Some(format!("{}\n\n{}", system, context_xml)),
546 None => Some(context_xml),
547 }
548 }
549
550 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
552 let futures = self
553 .config
554 .context_providers
555 .iter()
556 .map(|p| p.on_turn_complete(session_id, prompt, response));
557 let outcomes = join_all(futures).await;
558
559 for (i, result) in outcomes.into_iter().enumerate() {
560 if let Err(e) = result {
561 tracing::warn!(
562 "Context provider '{}' on_turn_complete failed: {}",
563 self.config.context_providers[i].name(),
564 e
565 );
566 }
567 }
568 }
569
570 async fn fire_pre_tool_use(
573 &self,
574 session_id: &str,
575 tool_name: &str,
576 args: &serde_json::Value,
577 ) -> Option<HookResult> {
578 if let Some(he) = &self.config.hook_engine {
579 let event = HookEvent::PreToolUse(PreToolUseEvent {
580 session_id: session_id.to_string(),
581 tool: tool_name.to_string(),
582 args: args.clone(),
583 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
584 recent_tools: Vec::new(),
585 });
586 let result = he.fire(&event).await;
587 if result.is_block() {
588 return Some(result);
589 }
590 }
591 None
592 }
593
594 async fn fire_post_tool_use(
596 &self,
597 session_id: &str,
598 tool_name: &str,
599 args: &serde_json::Value,
600 output: &str,
601 success: bool,
602 duration_ms: u64,
603 ) {
604 if let Some(he) = &self.config.hook_engine {
605 let event = HookEvent::PostToolUse(PostToolUseEvent {
606 session_id: session_id.to_string(),
607 tool: tool_name.to_string(),
608 args: args.clone(),
609 result: ToolResultData {
610 success,
611 output: output.to_string(),
612 exit_code: if success { Some(0) } else { Some(1) },
613 duration_ms,
614 },
615 });
616 let _ = he.fire(&event).await;
617 }
618 }
619
620 async fn fire_generate_start(
622 &self,
623 session_id: &str,
624 prompt: &str,
625 system_prompt: &Option<String>,
626 ) {
627 if let Some(he) = &self.config.hook_engine {
628 let event = HookEvent::GenerateStart(GenerateStartEvent {
629 session_id: session_id.to_string(),
630 prompt: prompt.to_string(),
631 system_prompt: system_prompt.clone(),
632 model_provider: String::new(),
633 model_name: String::new(),
634 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
635 });
636 let _ = he.fire(&event).await;
637 }
638 }
639
640 async fn fire_generate_end(
642 &self,
643 session_id: &str,
644 prompt: &str,
645 response: &LlmResponse,
646 duration_ms: u64,
647 ) {
648 if let Some(he) = &self.config.hook_engine {
649 let tool_calls: Vec<ToolCallInfo> = response
650 .tool_calls()
651 .iter()
652 .map(|tc| ToolCallInfo {
653 name: tc.name.clone(),
654 args: tc.args.clone(),
655 })
656 .collect();
657
658 let event = HookEvent::GenerateEnd(GenerateEndEvent {
659 session_id: session_id.to_string(),
660 prompt: prompt.to_string(),
661 response_text: response.text().to_string(),
662 tool_calls,
663 usage: TokenUsageInfo {
664 prompt_tokens: response.usage.prompt_tokens as i32,
665 completion_tokens: response.usage.completion_tokens as i32,
666 total_tokens: response.usage.total_tokens as i32,
667 },
668 duration_ms,
669 });
670 let _ = he.fire(&event).await;
671 }
672 }
673
674 fn handle_post_execution_metadata(
684 metadata: &Option<serde_json::Value>,
685 augmented_system: &mut Option<String>,
686 tool_executor: Option<&ToolExecutor>,
687 ) -> Option<String> {
688 let meta = metadata.as_ref()?;
689 if meta.get("_load_skill")?.as_bool() != Some(true) {
690 return None;
691 }
692
693 let skill_content = meta.get("skill_content")?.as_str()?;
694 let skill_name = meta
695 .get("skill_name")
696 .and_then(|v| v.as_str())
697 .unwrap_or("unknown");
698
699 let skill = Skill::parse(skill_content)?;
701
702 match skill.kind {
703 crate::tools::SkillKind::Instruction => {
704 let xml_fragment = format!(
706 "\n\n<skills>\n<skill name=\"{}\">\n{}\n</skill>\n</skills>",
707 skill.name, skill.content
708 );
709
710 match augmented_system {
711 Some(existing) => existing.push_str(&xml_fragment),
712 None => *augmented_system = Some(xml_fragment.clone()),
713 }
714
715 tracing::info!(
716 skill_name = skill_name,
717 kind = "instruction",
718 "Auto-loaded instruction skill into session"
719 );
720
721 Some(xml_fragment)
722 }
723 crate::tools::SkillKind::Tool => {
724 let xml_fragment = format!(
726 "\n\n<skills>\n<skill name=\"{}\">\n{}\n</skill>\n</skills>",
727 skill.name, skill.content
728 );
729
730 match augmented_system {
731 Some(existing) => existing.push_str(&xml_fragment),
732 None => *augmented_system = Some(xml_fragment.clone()),
733 }
734
735 if let Some(executor) = tool_executor {
737 let tools = crate::tools::parse_skill_tools(skill_content);
738 for tool in tools {
739 tracing::info!(
740 skill_name = skill_name,
741 tool_name = tool.name(),
742 "Registered tool from Tool-kind skill"
743 );
744 executor.registry().register(tool);
745 }
746 }
747
748 tracing::info!(
749 skill_name = skill_name,
750 kind = "tool",
751 "Auto-loaded tool skill into session"
752 );
753
754 Some(xml_fragment)
755 }
756 crate::tools::SkillKind::Agent => {
757 tracing::info!(
758 skill_name = skill_name,
759 kind = "agent",
760 "Loaded agent skill (agent registration not yet implemented)"
761 );
762 None
763 }
764 }
765 }
766
767 pub async fn execute(
773 &self,
774 history: &[Message],
775 prompt: &str,
776 event_tx: Option<mpsc::Sender<AgentEvent>>,
777 ) -> Result<AgentResult> {
778 self.execute_with_session(history, prompt, None, event_tx)
779 .await
780 }
781
782 #[tracing::instrument(
787 name = "a3s.agent.execute",
788 skip(self, history, prompt, event_tx),
789 fields(
790 a3s.session.id = session_id.unwrap_or("none"),
791 a3s.agent.max_turns = self.config.max_tool_rounds,
792 a3s.agent.tool_calls_count = tracing::field::Empty,
793 a3s.llm.total_tokens = tracing::field::Empty,
794 )
795 )]
796 pub async fn execute_with_session(
797 &self,
798 history: &[Message],
799 prompt: &str,
800 session_id: Option<&str>,
801 event_tx: Option<mpsc::Sender<AgentEvent>>,
802 ) -> Result<AgentResult> {
803 if self.config.planning_enabled {
805 return self.execute_with_planning(history, prompt, event_tx).await;
806 }
807
808 self.execute_loop(history, prompt, session_id, event_tx)
809 .await
810 }
811
812 async fn execute_loop(
818 &self,
819 history: &[Message],
820 prompt: &str,
821 session_id: Option<&str>,
822 event_tx: Option<mpsc::Sender<AgentEvent>>,
823 ) -> Result<AgentResult> {
824 let mut messages = history.to_vec();
825 let mut total_usage = TokenUsage::default();
826 let mut tool_calls_count = 0;
827 let mut turn = 0;
828
829 if let Some(tx) = &event_tx {
831 tx.send(AgentEvent::Start {
832 prompt: prompt.to_string(),
833 })
834 .await
835 .ok();
836 }
837
838 let mut augmented_system = if !self.config.context_providers.is_empty() {
840 if let Some(tx) = &event_tx {
842 let provider_names: Vec<String> = self
843 .config
844 .context_providers
845 .iter()
846 .map(|p| p.name().to_string())
847 .collect();
848 tx.send(AgentEvent::ContextResolving {
849 providers: provider_names,
850 })
851 .await
852 .ok();
853 }
854
855 let context_results = {
856 let context_span = tracing::info_span!(
857 "a3s.agent.context_resolve",
858 a3s.context.providers = self.config.context_providers.len() as i64,
859 a3s.context.items = tracing::field::Empty,
860 a3s.context.tokens = tracing::field::Empty,
861 );
862
863 self.resolve_context(prompt, session_id)
864 .instrument(context_span)
865 .await
866 };
867
868 if let Some(tx) = &event_tx {
870 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
871 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
872
873 tracing::info!(
874 context_items = total_items,
875 context_tokens = total_tokens,
876 "Context resolution completed"
877 );
878
879 tx.send(AgentEvent::ContextResolved {
880 total_items,
881 total_tokens,
882 })
883 .await
884 .ok();
885 }
886
887 self.build_augmented_system_prompt(&context_results)
888 } else {
889 self.config.system_prompt.clone()
890 };
891
892 messages.push(Message::user(prompt));
894
895 loop {
896 turn += 1;
897
898 let turn_span = tracing::info_span!(
899 "a3s.agent.turn",
900 a3s.agent.turn_number = turn as i64,
901 a3s.llm.total_tokens = tracing::field::Empty,
902 );
903 let _turn_guard = turn_span.enter();
904
905 if turn > self.config.max_tool_rounds {
906 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
907 if let Some(tx) = &event_tx {
908 tx.send(AgentEvent::Error {
909 message: error.clone(),
910 })
911 .await
912 .ok();
913 }
914 anyhow::bail!(error);
915 }
916
917 if let Some(tx) = &event_tx {
919 tx.send(AgentEvent::TurnStart { turn }).await.ok();
920 }
921
922 tracing::info!(
923 turn = turn,
924 max_turns = self.config.max_tool_rounds,
925 "Agent turn started"
926 );
927
928 let llm_span = tracing::info_span!(
930 "a3s.llm.completion",
931 a3s.llm.streaming = event_tx.is_some(),
932 a3s.llm.prompt_tokens = tracing::field::Empty,
933 a3s.llm.completion_tokens = tracing::field::Empty,
934 a3s.llm.total_tokens = tracing::field::Empty,
935 a3s.llm.stop_reason = tracing::field::Empty,
936 );
937 let _llm_guard = llm_span.enter();
938
939 self.fire_generate_start(session_id.unwrap_or(""), prompt, &augmented_system)
941 .await;
942
943 let llm_start = std::time::Instant::now();
944 let response = if event_tx.is_some() {
945 let mut stream_rx = self
947 .llm_client
948 .complete_streaming(&messages, augmented_system.as_deref(), &self.config.tools)
949 .await
950 .context("LLM streaming call failed")?;
951
952 let mut final_response: Option<LlmResponse> = None;
953
954 while let Some(event) = stream_rx.recv().await {
955 match event {
956 crate::llm::StreamEvent::TextDelta(text) => {
957 if let Some(tx) = &event_tx {
958 tx.send(AgentEvent::TextDelta { text }).await.ok();
959 }
960 }
961 crate::llm::StreamEvent::ToolUseStart { id, name } => {
962 if let Some(tx) = &event_tx {
963 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
964 }
965 }
966 crate::llm::StreamEvent::ToolUseInputDelta(_) => {
967 }
969 crate::llm::StreamEvent::Done(resp) => {
970 final_response = Some(resp);
971 }
972 }
973 }
974
975 final_response.context("Stream ended without final response")?
976 } else {
977 self.llm_client
979 .complete(&messages, augmented_system.as_deref(), &self.config.tools)
980 .await
981 .context("LLM call failed")?
982 };
983
984 total_usage.prompt_tokens += response.usage.prompt_tokens;
986 total_usage.completion_tokens += response.usage.completion_tokens;
987 total_usage.total_tokens += response.usage.total_tokens;
988
989 let llm_duration = llm_start.elapsed();
991 tracing::info!(
992 turn = turn,
993 streaming = event_tx.is_some(),
994 prompt_tokens = response.usage.prompt_tokens,
995 completion_tokens = response.usage.completion_tokens,
996 total_tokens = response.usage.total_tokens,
997 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
998 duration_ms = llm_duration.as_millis() as u64,
999 "LLM completion finished"
1000 );
1001
1002 self.fire_generate_end(
1004 session_id.unwrap_or(""),
1005 prompt,
1006 &response,
1007 llm_duration.as_millis() as u64,
1008 )
1009 .await;
1010
1011 crate::telemetry::record_llm_usage(
1013 response.usage.prompt_tokens,
1014 response.usage.completion_tokens,
1015 response.usage.total_tokens,
1016 response.stop_reason.as_deref(),
1017 );
1018 drop(_llm_guard);
1019
1020 turn_span.record("a3s.llm.total_tokens", response.usage.total_tokens as i64);
1022
1023 messages.push(response.message.clone());
1025
1026 let tool_calls = response.tool_calls();
1028
1029 if let Some(tx) = &event_tx {
1031 tx.send(AgentEvent::TurnEnd {
1032 turn,
1033 usage: response.usage.clone(),
1034 })
1035 .await
1036 .ok();
1037 }
1038
1039 if tool_calls.is_empty() {
1040 let final_text = response.text();
1042
1043 tracing::info!(
1045 tool_calls_count = tool_calls_count,
1046 total_prompt_tokens = total_usage.prompt_tokens,
1047 total_completion_tokens = total_usage.completion_tokens,
1048 total_tokens = total_usage.total_tokens,
1049 turns = turn,
1050 "Agent execution completed"
1051 );
1052
1053 if let Some(tx) = &event_tx {
1054 tx.send(AgentEvent::End {
1055 text: final_text.clone(),
1056 usage: total_usage.clone(),
1057 })
1058 .await
1059 .ok();
1060 }
1061
1062 if let Some(sid) = session_id {
1064 self.notify_turn_complete(sid, prompt, &final_text).await;
1065 }
1066
1067 return Ok(AgentResult {
1068 text: final_text,
1069 messages,
1070 usage: total_usage,
1071 tool_calls_count,
1072 });
1073 }
1074
1075 for tool_call in tool_calls {
1077 tool_calls_count += 1;
1078
1079 let tool_span = tracing::info_span!(
1080 "a3s.tool.execute",
1081 a3s.tool.name = tool_call.name.as_str(),
1082 a3s.tool.id = tool_call.id.as_str(),
1083 a3s.tool.exit_code = tracing::field::Empty,
1084 a3s.tool.success = tracing::field::Empty,
1085 a3s.tool.duration_ms = tracing::field::Empty,
1086 a3s.tool.permission = tracing::field::Empty,
1087 );
1088 let _tool_guard = tool_span.enter();
1089
1090 let tool_start = std::time::Instant::now();
1091
1092 tracing::info!(
1093 tool_name = tool_call.name.as_str(),
1094 tool_id = tool_call.id.as_str(),
1095 "Tool execution started"
1096 );
1097
1098 if let Some(parse_error) =
1104 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1105 {
1106 let error_msg = format!("Error: {}", parse_error);
1107 tracing::warn!(
1108 tool = tool_call.name.as_str(),
1109 "Malformed tool arguments from LLM"
1110 );
1111
1112 if let Some(tx) = &event_tx {
1113 tx.send(AgentEvent::ToolEnd {
1114 id: tool_call.id.clone(),
1115 name: tool_call.name.clone(),
1116 output: error_msg.clone(),
1117 exit_code: 1,
1118 })
1119 .await
1120 .ok();
1121 }
1122
1123 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1124 continue;
1125 }
1126
1127 if let Some(HookResult::Block(reason)) = self
1129 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1130 .await
1131 {
1132 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1133 tracing::info!(
1134 tool_name = tool_call.name.as_str(),
1135 "Tool blocked by PreToolUse hook"
1136 );
1137
1138 if let Some(tx) = &event_tx {
1139 tx.send(AgentEvent::PermissionDenied {
1140 tool_id: tool_call.id.clone(),
1141 tool_name: tool_call.name.clone(),
1142 args: tool_call.args.clone(),
1143 reason: reason.clone(),
1144 })
1145 .await
1146 .ok();
1147 }
1148
1149 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1150 continue;
1151 }
1152
1153 if !self.config.skill_tool_filters.is_empty() {
1158 let has_restrictions = self
1159 .config
1160 .skill_tool_filters
1161 .iter()
1162 .any(|s| s.allowed_tools.is_some());
1163
1164 if has_restrictions {
1165 let args_str = serde_json::to_string(&tool_call.args).unwrap_or_default();
1166 let tool_allowed = self
1167 .config
1168 .skill_tool_filters
1169 .iter()
1170 .filter(|s| s.allowed_tools.is_some())
1171 .any(|s| s.is_tool_allowed(&tool_call.name, &args_str));
1172
1173 if !tool_allowed {
1174 tracing::info!(
1175 tool_name = tool_call.name.as_str(),
1176 "Tool blocked by skill allowed_tools restriction"
1177 );
1178 let msg = format!(
1179 "Tool '{}' is not permitted by any loaded skill's allowed_tools policy.",
1180 tool_call.name
1181 );
1182
1183 if let Some(tx) = &event_tx {
1184 tx.send(AgentEvent::PermissionDenied {
1185 tool_id: tool_call.id.clone(),
1186 tool_name: tool_call.name.clone(),
1187 args: tool_call.args.clone(),
1188 reason: "Blocked by skill allowed_tools restriction"
1189 .to_string(),
1190 })
1191 .await
1192 .ok();
1193 }
1194
1195 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1196 continue;
1197 }
1198 }
1199 }
1200
1201 let permission_decision = if let Some(policy_lock) = &self.config.permission_policy
1203 {
1204 let policy = policy_lock.read().await;
1205 policy.check(&tool_call.name, &tool_call.args)
1206 } else {
1207 PermissionDecision::Ask
1209 };
1210
1211 let (output, exit_code, is_error, metadata) = match permission_decision {
1212 PermissionDecision::Deny => {
1213 tracing::info!(
1214 tool_name = tool_call.name.as_str(),
1215 permission = "deny",
1216 "Tool permission denied"
1217 );
1218 tool_span.record("a3s.tool.permission", "deny");
1219 let denial_msg = format!(
1221 "Permission denied: Tool '{}' is blocked by permission policy.",
1222 tool_call.name
1223 );
1224
1225 if let Some(tx) = &event_tx {
1227 tx.send(AgentEvent::PermissionDenied {
1228 tool_id: tool_call.id.clone(),
1229 tool_name: tool_call.name.clone(),
1230 args: tool_call.args.clone(),
1231 reason: "Blocked by deny rule in permission policy".to_string(),
1232 })
1233 .await
1234 .ok();
1235 }
1236
1237 (denial_msg, 1, true, None)
1238 }
1239 PermissionDecision::Allow => {
1240 tracing::info!(
1241 tool_name = tool_call.name.as_str(),
1242 permission = "allow",
1243 "Tool permission: allow"
1244 );
1245 tool_span.record("a3s.tool.permission", "allow");
1246
1247 let stream_ctx =
1249 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1250 let result = self
1251 .tool_executor
1252 .execute_with_context(&tool_call.name, &tool_call.args, &stream_ctx)
1253 .await;
1254
1255 match result {
1256 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1257 Err(e) => (format!("Tool execution error: {}", e), 1, true, None),
1258 }
1259 }
1260 PermissionDecision::Ask => {
1261 tracing::info!(
1262 tool_name = tool_call.name.as_str(),
1263 permission = "ask",
1264 "Tool permission: ask"
1265 );
1266 tool_span.record("a3s.tool.permission", "ask");
1267
1268 if let Some(cm) = &self.config.confirmation_manager {
1270 if !cm.requires_confirmation(&tool_call.name).await {
1272 let stream_ctx = self.streaming_tool_context(
1273 &event_tx,
1274 &tool_call.id,
1275 &tool_call.name,
1276 );
1277 let result = self
1278 .tool_executor
1279 .execute_with_context(
1280 &tool_call.name,
1281 &tool_call.args,
1282 &stream_ctx,
1283 )
1284 .await;
1285
1286 let (output, exit_code, is_error, metadata) = match result {
1287 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1288 Err(e) => {
1289 (format!("Tool execution error: {}", e), 1, true, None)
1290 }
1291 };
1292
1293 Self::handle_post_execution_metadata(
1295 &metadata,
1296 &mut augmented_system,
1297 Some(&self.tool_executor),
1298 );
1299
1300 if let Some(tx) = &event_tx {
1302 tx.send(AgentEvent::ToolEnd {
1303 id: tool_call.id.clone(),
1304 name: tool_call.name.clone(),
1305 output: output.clone(),
1306 exit_code,
1307 })
1308 .await
1309 .ok();
1310 }
1311
1312 messages.push(Message::tool_result(
1314 &tool_call.id,
1315 &output,
1316 is_error,
1317 ));
1318
1319 let tool_duration = tool_start.elapsed();
1321 crate::telemetry::record_tool_result(exit_code, tool_duration);
1322
1323 self.fire_post_tool_use(
1325 session_id.unwrap_or(""),
1326 &tool_call.name,
1327 &tool_call.args,
1328 &output,
1329 exit_code == 0,
1330 tool_duration.as_millis() as u64,
1331 )
1332 .await;
1333
1334 continue; }
1336
1337 let policy = cm.policy().await;
1339 let timeout_ms = policy.default_timeout_ms;
1340 let timeout_action = policy.timeout_action;
1341
1342 let rx = cm
1344 .request_confirmation(
1345 &tool_call.id,
1346 &tool_call.name,
1347 &tool_call.args,
1348 )
1349 .await;
1350
1351 let confirmation_result =
1353 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
1354
1355 match confirmation_result {
1356 Ok(Ok(response)) => {
1357 if response.approved {
1358 let stream_ctx = self.streaming_tool_context(
1359 &event_tx,
1360 &tool_call.id,
1361 &tool_call.name,
1362 );
1363 let result = self
1364 .tool_executor
1365 .execute_with_context(
1366 &tool_call.name,
1367 &tool_call.args,
1368 &stream_ctx,
1369 )
1370 .await;
1371
1372 match result {
1373 Ok(r) => (
1374 r.output,
1375 r.exit_code,
1376 r.exit_code != 0,
1377 r.metadata,
1378 ),
1379 Err(e) => (
1380 format!("Tool execution error: {}", e),
1381 1,
1382 true,
1383 None,
1384 ),
1385 }
1386 } else {
1387 let rejection_msg = format!(
1388 "Tool '{}' execution was rejected by user. Reason: {}",
1389 tool_call.name,
1390 response.reason.unwrap_or_else(|| "No reason provided".to_string())
1391 );
1392 (rejection_msg, 1, true, None)
1393 }
1394 }
1395 Ok(Err(_)) => {
1396 let msg = format!(
1397 "Tool '{}' confirmation failed: confirmation channel closed",
1398 tool_call.name
1399 );
1400 (msg, 1, true, None)
1401 }
1402 Err(_) => {
1403 cm.check_timeouts().await;
1404
1405 match timeout_action {
1406 crate::hitl::TimeoutAction::Reject => {
1407 let msg = format!(
1408 "Tool '{}' execution timed out waiting for confirmation ({}ms). Execution rejected.",
1409 tool_call.name, timeout_ms
1410 );
1411 (msg, 1, true, None)
1412 }
1413 crate::hitl::TimeoutAction::AutoApprove => {
1414 let stream_ctx = self.streaming_tool_context(
1415 &event_tx,
1416 &tool_call.id,
1417 &tool_call.name,
1418 );
1419 let result = self
1420 .tool_executor
1421 .execute_with_context(
1422 &tool_call.name,
1423 &tool_call.args,
1424 &stream_ctx,
1425 )
1426 .await;
1427
1428 match result {
1429 Ok(r) => (
1430 r.output,
1431 r.exit_code,
1432 r.exit_code != 0,
1433 r.metadata,
1434 ),
1435 Err(e) => (
1436 format!("Tool execution error: {}", e),
1437 1,
1438 true,
1439 None,
1440 ),
1441 }
1442 }
1443 }
1444 }
1445 }
1446 } else {
1447 let msg = format!(
1449 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
1450 Configure a confirmation policy to enable tool execution.",
1451 tool_call.name
1452 );
1453 tracing::warn!(
1454 tool_name = tool_call.name.as_str(),
1455 "Tool requires confirmation but no HITL manager configured"
1456 );
1457 (msg, 1, true, None)
1458 }
1459 }
1460 };
1461
1462 Self::handle_post_execution_metadata(
1464 &metadata,
1465 &mut augmented_system,
1466 Some(&self.tool_executor),
1467 );
1468
1469 let tool_duration = tool_start.elapsed();
1471 tracing::info!(
1472 tool_name = tool_call.name.as_str(),
1473 tool_id = tool_call.id.as_str(),
1474 exit_code = exit_code,
1475 success = (exit_code == 0),
1476 duration_ms = tool_duration.as_millis() as u64,
1477 "Tool execution finished"
1478 );
1479
1480 crate::telemetry::record_tool_result(exit_code, tool_duration);
1482
1483 if let Some(ref metrics) = self.tool_metrics {
1485 metrics.write().await.record(
1486 &tool_call.name,
1487 exit_code == 0,
1488 tool_duration.as_millis() as u64,
1489 );
1490 }
1491
1492 self.fire_post_tool_use(
1494 session_id.unwrap_or(""),
1495 &tool_call.name,
1496 &tool_call.args,
1497 &output,
1498 exit_code == 0,
1499 tool_duration.as_millis() as u64,
1500 )
1501 .await;
1502
1503 if let Some(tx) = &event_tx {
1505 tx.send(AgentEvent::ToolEnd {
1506 id: tool_call.id.clone(),
1507 name: tool_call.name.clone(),
1508 output: output.clone(),
1509 exit_code,
1510 })
1511 .await
1512 .ok();
1513 }
1514
1515 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1517 }
1518 }
1519 }
1520
1521 pub async fn execute_streaming(
1523 &self,
1524 history: &[Message],
1525 prompt: &str,
1526 ) -> Result<(
1527 mpsc::Receiver<AgentEvent>,
1528 tokio::task::JoinHandle<Result<AgentResult>>,
1529 )> {
1530 let (tx, rx) = mpsc::channel(100);
1531
1532 let llm_client = self.llm_client.clone();
1533 let tool_executor = self.tool_executor.clone();
1534 let tool_context = self.tool_context.clone();
1535 let config = self.config.clone();
1536 let tool_metrics = self.tool_metrics.clone();
1537 let history = history.to_vec();
1538 let prompt = prompt.to_string();
1539
1540 let handle = tokio::spawn(async move {
1541 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
1542 if let Some(metrics) = tool_metrics {
1543 agent = agent.with_tool_metrics(metrics);
1544 }
1545 agent.execute(&history, &prompt, Some(tx)).await
1546 });
1547
1548 Ok((rx, handle))
1549 }
1550
1551 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
1556 use crate::planning::LlmPlanner;
1557
1558 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
1559 Ok(plan) => Ok(plan),
1560 Err(e) => {
1561 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
1562 Ok(LlmPlanner::fallback_plan(prompt))
1563 }
1564 }
1565 }
1566
1567 pub async fn execute_with_planning(
1569 &self,
1570 history: &[Message],
1571 prompt: &str,
1572 event_tx: Option<mpsc::Sender<AgentEvent>>,
1573 ) -> Result<AgentResult> {
1574 if let Some(tx) = &event_tx {
1576 tx.send(AgentEvent::PlanningStart {
1577 prompt: prompt.to_string(),
1578 })
1579 .await
1580 .ok();
1581 }
1582
1583 let plan = self.plan(prompt, None).await?;
1585
1586 if let Some(tx) = &event_tx {
1588 tx.send(AgentEvent::PlanningEnd {
1589 estimated_steps: plan.steps.len(),
1590 plan: plan.clone(),
1591 })
1592 .await
1593 .ok();
1594 }
1595
1596 self.execute_plan(history, &plan, event_tx).await
1598 }
1599
1600 async fn execute_plan(
1602 &self,
1603 history: &[Message],
1604 plan: &ExecutionPlan,
1605 event_tx: Option<mpsc::Sender<AgentEvent>>,
1606 ) -> Result<AgentResult> {
1607 let mut current_history = history.to_vec();
1608 let mut total_usage = TokenUsage::default();
1609 let mut tool_calls_count = 0;
1610
1611 let steps_text = plan
1613 .steps
1614 .iter()
1615 .enumerate()
1616 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
1617 .collect::<Vec<_>>()
1618 .join("\n");
1619 current_history.push(Message::user(&crate::prompts::render(
1620 crate::prompts::PLAN_EXECUTE_GOAL,
1621 &[("goal", &plan.goal), ("steps", &steps_text)],
1622 )));
1623
1624 for (step_idx, step) in plan.steps.iter().enumerate() {
1626 if let Some(tx) = &event_tx {
1628 tx.send(AgentEvent::StepStart {
1629 step_id: step.id.clone(),
1630 description: step.content.clone(),
1631 step_number: step_idx + 1,
1632 total_steps: plan.steps.len(),
1633 })
1634 .await
1635 .ok();
1636 }
1637
1638 let step_prompt = crate::prompts::render(
1640 crate::prompts::PLAN_EXECUTE_STEP,
1641 &[
1642 ("step_num", &(step_idx + 1).to_string()),
1643 ("description", &step.content),
1644 ],
1645 );
1646 let step_result = self
1647 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
1648 .await?;
1649
1650 current_history = step_result.messages.clone();
1652 total_usage.prompt_tokens += step_result.usage.prompt_tokens;
1653 total_usage.completion_tokens += step_result.usage.completion_tokens;
1654 total_usage.total_tokens += step_result.usage.total_tokens;
1655 tool_calls_count += step_result.tool_calls_count;
1656
1657 if let Some(tx) = &event_tx {
1659 tx.send(AgentEvent::StepEnd {
1660 step_id: step.id.clone(),
1661 status: TaskStatus::Completed,
1662 step_number: step_idx + 1,
1663 total_steps: plan.steps.len(),
1664 })
1665 .await
1666 .ok();
1667 }
1668
1669 if self.config.goal_tracking {
1671 if let Some(tx) = &event_tx {
1672 tx.send(AgentEvent::GoalProgress {
1673 goal: plan.goal.clone(),
1674 progress: (step_idx + 1) as f32 / plan.steps.len() as f32,
1675 completed_steps: step_idx + 1,
1676 total_steps: plan.steps.len(),
1677 })
1678 .await
1679 .ok();
1680 }
1681 }
1682 }
1683
1684 let final_text = current_history
1686 .last()
1687 .map(|m| {
1688 m.content
1689 .iter()
1690 .filter_map(|block| {
1691 if let crate::llm::ContentBlock::Text { text } = block {
1692 Some(text.as_str())
1693 } else {
1694 None
1695 }
1696 })
1697 .collect::<Vec<_>>()
1698 .join("\n")
1699 })
1700 .unwrap_or_default();
1701
1702 Ok(AgentResult {
1703 text: final_text,
1704 messages: current_history,
1705 usage: total_usage,
1706 tool_calls_count,
1707 })
1708 }
1709
1710 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
1715 use crate::planning::LlmPlanner;
1716
1717 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
1718 Ok(goal) => Ok(goal),
1719 Err(e) => {
1720 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
1721 Ok(LlmPlanner::fallback_goal(prompt))
1722 }
1723 }
1724 }
1725
1726 pub async fn check_goal_achievement(
1731 &self,
1732 goal: &AgentGoal,
1733 current_state: &str,
1734 ) -> Result<bool> {
1735 use crate::planning::LlmPlanner;
1736
1737 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
1738 Ok(result) => Ok(result.achieved),
1739 Err(e) => {
1740 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
1741 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
1742 Ok(result.achieved)
1743 }
1744 }
1745 }
1746}
1747
1748#[cfg(test)]
1749mod tests {
1750 use super::*;
1751 use crate::llm::{ContentBlock, StreamEvent};
1752 use crate::permissions::PermissionPolicy;
1753 use crate::tools::ToolExecutor;
1754 use std::path::PathBuf;
1755 use std::sync::atomic::{AtomicUsize, Ordering};
1756
1757 fn test_tool_context() -> ToolContext {
1759 ToolContext::new(PathBuf::from("/tmp"))
1760 }
1761
/// `AgentConfig::default()` has no prompt, tools, permission policy or context
/// providers, and uses `MAX_TOOL_ROUNDS` as its round limit.
#[test]
fn test_agent_config_default() {
    let defaults = AgentConfig::default();

    assert!(defaults.system_prompt.is_none());
    assert!(defaults.tools.is_empty());
    assert_eq!(defaults.max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(defaults.permission_policy.is_none());
    assert!(defaults.context_providers.is_empty());
}
1771
1772 pub(crate) struct MockLlmClient {
1778 responses: std::sync::Mutex<Vec<LlmResponse>>,
1780 call_count: AtomicUsize,
1782 }
1783
1784 impl MockLlmClient {
1785 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
1786 Self {
1787 responses: std::sync::Mutex::new(responses),
1788 call_count: AtomicUsize::new(0),
1789 }
1790 }
1791
1792 pub(crate) fn text_response(text: &str) -> LlmResponse {
1794 LlmResponse {
1795 message: Message {
1796 role: "assistant".to_string(),
1797 content: vec![ContentBlock::Text {
1798 text: text.to_string(),
1799 }],
1800 reasoning_content: None,
1801 },
1802 usage: TokenUsage {
1803 prompt_tokens: 10,
1804 completion_tokens: 5,
1805 total_tokens: 15,
1806 cache_read_tokens: None,
1807 cache_write_tokens: None,
1808 },
1809 stop_reason: Some("end_turn".to_string()),
1810 }
1811 }
1812
1813 pub(crate) fn tool_call_response(
1815 tool_id: &str,
1816 tool_name: &str,
1817 args: serde_json::Value,
1818 ) -> LlmResponse {
1819 LlmResponse {
1820 message: Message {
1821 role: "assistant".to_string(),
1822 content: vec![ContentBlock::ToolUse {
1823 id: tool_id.to_string(),
1824 name: tool_name.to_string(),
1825 input: args,
1826 }],
1827 reasoning_content: None,
1828 },
1829 usage: TokenUsage {
1830 prompt_tokens: 10,
1831 completion_tokens: 5,
1832 total_tokens: 15,
1833 cache_read_tokens: None,
1834 cache_write_tokens: None,
1835 },
1836 stop_reason: Some("tool_use".to_string()),
1837 }
1838 }
1839 }
1840
1841 #[async_trait::async_trait]
1842 impl LlmClient for MockLlmClient {
1843 async fn complete(
1844 &self,
1845 _messages: &[Message],
1846 _system: Option<&str>,
1847 _tools: &[ToolDefinition],
1848 ) -> Result<LlmResponse> {
1849 self.call_count.fetch_add(1, Ordering::SeqCst);
1850 let mut responses = self.responses.lock().unwrap();
1851 if responses.is_empty() {
1852 anyhow::bail!("No more mock responses available");
1853 }
1854 Ok(responses.remove(0))
1855 }
1856
1857 async fn complete_streaming(
1858 &self,
1859 _messages: &[Message],
1860 _system: Option<&str>,
1861 _tools: &[ToolDefinition],
1862 ) -> Result<mpsc::Receiver<StreamEvent>> {
1863 self.call_count.fetch_add(1, Ordering::SeqCst);
1864 let mut responses = self.responses.lock().unwrap();
1865 if responses.is_empty() {
1866 anyhow::bail!("No more mock responses available");
1867 }
1868 let response = responses.remove(0);
1869
1870 let (tx, rx) = mpsc::channel(10);
1871 tokio::spawn(async move {
1872 for block in &response.message.content {
1874 if let ContentBlock::Text { text } = block {
1875 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
1876 }
1877 }
1878 tx.send(StreamEvent::Done(response)).await.ok();
1879 });
1880
1881 Ok(rx)
1882 }
1883 }
1884
1885 #[tokio::test]
1890 async fn test_agent_simple_response() {
1891 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
1892 "Hello, I'm an AI assistant.",
1893 )]));
1894
1895 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
1896 let config = AgentConfig::default();
1897
1898 let agent = AgentLoop::new(
1899 mock_client.clone(),
1900 tool_executor,
1901 test_tool_context(),
1902 config,
1903 );
1904 let result = agent.execute(&[], "Hello", None).await.unwrap();
1905
1906 assert_eq!(result.text, "Hello, I'm an AI assistant.");
1907 assert_eq!(result.tool_calls_count, 0);
1908 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
1909 }
1910
1911 #[tokio::test]
1912 async fn test_agent_with_tool_call() {
1913 let mock_client = Arc::new(MockLlmClient::new(vec![
1914 MockLlmClient::tool_call_response(
1916 "tool-1",
1917 "bash",
1918 serde_json::json!({"command": "echo hello"}),
1919 ),
1920 MockLlmClient::text_response("The command output was: hello"),
1922 ]));
1923
1924 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
1925 let config = AgentConfig::default();
1926
1927 let agent = AgentLoop::new(
1928 mock_client.clone(),
1929 tool_executor,
1930 test_tool_context(),
1931 config,
1932 );
1933 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
1934
1935 assert_eq!(result.text, "The command output was: hello");
1936 assert_eq!(result.tool_calls_count, 1);
1937 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
1938 }
1939
1940 #[tokio::test]
1941 async fn test_agent_permission_deny() {
1942 let mock_client = Arc::new(MockLlmClient::new(vec![
1943 MockLlmClient::tool_call_response(
1945 "tool-1",
1946 "bash",
1947 serde_json::json!({"command": "rm -rf /tmp/test"}),
1948 ),
1949 MockLlmClient::text_response(
1951 "I cannot execute that command due to permission restrictions.",
1952 ),
1953 ]));
1954
1955 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
1956
1957 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
1959 let policy_lock = Arc::new(RwLock::new(permission_policy));
1960
1961 let config = AgentConfig {
1962 permission_policy: Some(policy_lock),
1963 ..Default::default()
1964 };
1965
1966 let (tx, mut rx) = mpsc::channel(100);
1967 let agent = AgentLoop::new(
1968 mock_client.clone(),
1969 tool_executor,
1970 test_tool_context(),
1971 config,
1972 );
1973 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
1974
1975 let mut found_permission_denied = false;
1977 while let Ok(event) = rx.try_recv() {
1978 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
1979 assert_eq!(tool_name, "bash");
1980 found_permission_denied = true;
1981 }
1982 }
1983 assert!(
1984 found_permission_denied,
1985 "Should have received PermissionDenied event"
1986 );
1987
1988 assert_eq!(result.tool_calls_count, 1);
1989 }
1990
1991 #[tokio::test]
1992 async fn test_agent_permission_allow() {
1993 let mock_client = Arc::new(MockLlmClient::new(vec![
1994 MockLlmClient::tool_call_response(
1996 "tool-1",
1997 "bash",
1998 serde_json::json!({"command": "echo hello"}),
1999 ),
2000 MockLlmClient::text_response("Done!"),
2002 ]));
2003
2004 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2005
2006 let permission_policy = PermissionPolicy::new()
2008 .allow("bash(echo:*)")
2009 .deny("bash(rm:*)");
2010 let policy_lock = Arc::new(RwLock::new(permission_policy));
2011
2012 let config = AgentConfig {
2013 permission_policy: Some(policy_lock),
2014 ..Default::default()
2015 };
2016
2017 let agent = AgentLoop::new(
2018 mock_client.clone(),
2019 tool_executor,
2020 test_tool_context(),
2021 config,
2022 );
2023 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2024
2025 assert_eq!(result.text, "Done!");
2026 assert_eq!(result.tool_calls_count, 1);
2027 }
2028
2029 #[tokio::test]
2030 async fn test_agent_streaming_events() {
2031 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2032 "Hello!",
2033 )]));
2034
2035 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2036 let config = AgentConfig::default();
2037
2038 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2039 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2040
2041 let mut events = Vec::new();
2043 while let Some(event) = rx.recv().await {
2044 events.push(event);
2045 }
2046
2047 let result = handle.await.unwrap().unwrap();
2048 assert_eq!(result.text, "Hello!");
2049
2050 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2052 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2053 }
2054
2055 #[tokio::test]
2056 async fn test_agent_max_tool_rounds() {
2057 let responses: Vec<LlmResponse> = (0..100)
2059 .map(|i| {
2060 MockLlmClient::tool_call_response(
2061 &format!("tool-{}", i),
2062 "bash",
2063 serde_json::json!({"command": "echo loop"}),
2064 )
2065 })
2066 .collect();
2067
2068 let mock_client = Arc::new(MockLlmClient::new(responses));
2069 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2070
2071 let config = AgentConfig {
2072 max_tool_rounds: 3,
2073 ..Default::default()
2074 };
2075
2076 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2077 let result = agent.execute(&[], "Loop forever", None).await;
2078
2079 assert!(result.is_err());
2081 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
2082 }
2083
2084 #[tokio::test]
2085 async fn test_agent_no_permission_policy_defaults_to_ask() {
2086 let mock_client = Arc::new(MockLlmClient::new(vec![
2089 MockLlmClient::tool_call_response(
2090 "tool-1",
2091 "bash",
2092 serde_json::json!({"command": "rm -rf /tmp/test"}),
2093 ),
2094 MockLlmClient::text_response("Denied!"),
2095 ]));
2096
2097 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2098 let config = AgentConfig {
2099 permission_policy: None, ..Default::default()
2102 };
2103
2104 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2105 let result = agent.execute(&[], "Delete", None).await.unwrap();
2106
2107 assert_eq!(result.text, "Denied!");
2109 assert_eq!(result.tool_calls_count, 1);
2110 }
2111
2112 #[tokio::test]
2113 async fn test_agent_permission_ask_without_cm_denies() {
2114 let mock_client = Arc::new(MockLlmClient::new(vec![
2117 MockLlmClient::tool_call_response(
2118 "tool-1",
2119 "bash",
2120 serde_json::json!({"command": "echo test"}),
2121 ),
2122 MockLlmClient::text_response("Denied!"),
2123 ]));
2124
2125 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2126
2127 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2130
2131 let config = AgentConfig {
2132 permission_policy: Some(policy_lock),
2133 ..Default::default()
2135 };
2136
2137 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2138 let result = agent.execute(&[], "Echo", None).await.unwrap();
2139
2140 assert_eq!(result.text, "Denied!");
2142 assert!(result.tool_calls_count >= 1);
2144 }
2145
2146 #[tokio::test]
2151 async fn test_agent_hitl_approved() {
2152 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2153 use tokio::sync::broadcast;
2154
2155 let mock_client = Arc::new(MockLlmClient::new(vec![
2156 MockLlmClient::tool_call_response(
2157 "tool-1",
2158 "bash",
2159 serde_json::json!({"command": "echo hello"}),
2160 ),
2161 MockLlmClient::text_response("Command executed!"),
2162 ]));
2163
2164 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2165
2166 let (event_tx, _event_rx) = broadcast::channel(100);
2168 let hitl_policy = ConfirmationPolicy {
2169 enabled: true,
2170 ..Default::default()
2171 };
2172 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2173
2174 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2177
2178 let config = AgentConfig {
2179 permission_policy: Some(policy_lock),
2180 confirmation_manager: Some(confirmation_manager.clone()),
2181 ..Default::default()
2182 };
2183
2184 let cm_clone = confirmation_manager.clone();
2186 tokio::spawn(async move {
2187 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2189 cm_clone.confirm("tool-1", true, None).await.ok();
2191 });
2192
2193 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2194 let result = agent.execute(&[], "Run echo", None).await.unwrap();
2195
2196 assert_eq!(result.text, "Command executed!");
2197 assert_eq!(result.tool_calls_count, 1);
2198 }
2199
2200 #[tokio::test]
2201 async fn test_agent_hitl_rejected() {
2202 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2203 use tokio::sync::broadcast;
2204
2205 let mock_client = Arc::new(MockLlmClient::new(vec![
2206 MockLlmClient::tool_call_response(
2207 "tool-1",
2208 "bash",
2209 serde_json::json!({"command": "rm -rf /"}),
2210 ),
2211 MockLlmClient::text_response("Understood, I won't do that."),
2212 ]));
2213
2214 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2215
2216 let (event_tx, _event_rx) = broadcast::channel(100);
2218 let hitl_policy = ConfirmationPolicy {
2219 enabled: true,
2220 ..Default::default()
2221 };
2222 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2223
2224 let permission_policy = PermissionPolicy::new();
2226 let policy_lock = Arc::new(RwLock::new(permission_policy));
2227
2228 let config = AgentConfig {
2229 permission_policy: Some(policy_lock),
2230 confirmation_manager: Some(confirmation_manager.clone()),
2231 ..Default::default()
2232 };
2233
2234 let cm_clone = confirmation_manager.clone();
2236 tokio::spawn(async move {
2237 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2238 cm_clone
2239 .confirm("tool-1", false, Some("Too dangerous".to_string()))
2240 .await
2241 .ok();
2242 });
2243
2244 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2245 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
2246
2247 assert_eq!(result.text, "Understood, I won't do that.");
2249 }
2250
2251 #[tokio::test]
2252 async fn test_agent_hitl_timeout_reject() {
2253 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2254 use tokio::sync::broadcast;
2255
2256 let mock_client = Arc::new(MockLlmClient::new(vec![
2257 MockLlmClient::tool_call_response(
2258 "tool-1",
2259 "bash",
2260 serde_json::json!({"command": "echo test"}),
2261 ),
2262 MockLlmClient::text_response("Timed out, I understand."),
2263 ]));
2264
2265 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2266
2267 let (event_tx, _event_rx) = broadcast::channel(100);
2269 let hitl_policy = ConfirmationPolicy {
2270 enabled: true,
2271 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
2273 ..Default::default()
2274 };
2275 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2276
2277 let permission_policy = PermissionPolicy::new();
2278 let policy_lock = Arc::new(RwLock::new(permission_policy));
2279
2280 let config = AgentConfig {
2281 permission_policy: Some(policy_lock),
2282 confirmation_manager: Some(confirmation_manager),
2283 ..Default::default()
2284 };
2285
2286 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2288 let result = agent.execute(&[], "Echo", None).await.unwrap();
2289
2290 assert_eq!(result.text, "Timed out, I understand.");
2292 }
2293
2294 #[tokio::test]
2295 async fn test_agent_hitl_timeout_auto_approve() {
2296 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2297 use tokio::sync::broadcast;
2298
2299 let mock_client = Arc::new(MockLlmClient::new(vec![
2300 MockLlmClient::tool_call_response(
2301 "tool-1",
2302 "bash",
2303 serde_json::json!({"command": "echo hello"}),
2304 ),
2305 MockLlmClient::text_response("Auto-approved and executed!"),
2306 ]));
2307
2308 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2309
2310 let (event_tx, _event_rx) = broadcast::channel(100);
2312 let hitl_policy = ConfirmationPolicy {
2313 enabled: true,
2314 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
2316 ..Default::default()
2317 };
2318 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2319
2320 let permission_policy = PermissionPolicy::new();
2321 let policy_lock = Arc::new(RwLock::new(permission_policy));
2322
2323 let config = AgentConfig {
2324 permission_policy: Some(policy_lock),
2325 confirmation_manager: Some(confirmation_manager),
2326 ..Default::default()
2327 };
2328
2329 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2331 let result = agent.execute(&[], "Echo", None).await.unwrap();
2332
2333 assert_eq!(result.text, "Auto-approved and executed!");
2335 assert_eq!(result.tool_calls_count, 1);
2336 }
2337
2338 #[tokio::test]
2339 async fn test_agent_hitl_confirmation_events() {
2340 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2341 use tokio::sync::broadcast;
2342
2343 let mock_client = Arc::new(MockLlmClient::new(vec![
2344 MockLlmClient::tool_call_response(
2345 "tool-1",
2346 "bash",
2347 serde_json::json!({"command": "echo test"}),
2348 ),
2349 MockLlmClient::text_response("Done!"),
2350 ]));
2351
2352 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2353
2354 let (event_tx, mut event_rx) = broadcast::channel(100);
2356 let hitl_policy = ConfirmationPolicy {
2357 enabled: true,
2358 default_timeout_ms: 5000, ..Default::default()
2360 };
2361 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2362
2363 let permission_policy = PermissionPolicy::new();
2364 let policy_lock = Arc::new(RwLock::new(permission_policy));
2365
2366 let config = AgentConfig {
2367 permission_policy: Some(policy_lock),
2368 confirmation_manager: Some(confirmation_manager.clone()),
2369 ..Default::default()
2370 };
2371
2372 let cm_clone = confirmation_manager.clone();
2374 let event_handle = tokio::spawn(async move {
2375 let mut events = Vec::new();
2376 while let Ok(event) = event_rx.recv().await {
2378 events.push(event.clone());
2379 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
2380 cm_clone.confirm(&tool_id, true, None).await.ok();
2382 if let Ok(recv_event) = event_rx.recv().await {
2384 events.push(recv_event);
2385 }
2386 break;
2387 }
2388 }
2389 events
2390 });
2391
2392 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2393 let _result = agent.execute(&[], "Echo", None).await.unwrap();
2394
2395 let events = event_handle.await.unwrap();
2397 assert!(
2398 events
2399 .iter()
2400 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
2401 "Should have ConfirmationRequired event"
2402 );
2403 assert!(
2404 events
2405 .iter()
2406 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
2407 "Should have ConfirmationReceived event with approved=true"
2408 );
2409 }
2410
2411 #[tokio::test]
2412 async fn test_agent_hitl_disabled_auto_executes() {
2413 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2415 use tokio::sync::broadcast;
2416
2417 let mock_client = Arc::new(MockLlmClient::new(vec![
2418 MockLlmClient::tool_call_response(
2419 "tool-1",
2420 "bash",
2421 serde_json::json!({"command": "echo auto"}),
2422 ),
2423 MockLlmClient::text_response("Auto executed!"),
2424 ]));
2425
2426 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2427
2428 let (event_tx, _event_rx) = broadcast::channel(100);
2430 let hitl_policy = ConfirmationPolicy {
2431 enabled: false, ..Default::default()
2433 };
2434 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2435
2436 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2438
2439 let config = AgentConfig {
2440 permission_policy: Some(policy_lock),
2441 confirmation_manager: Some(confirmation_manager),
2442 ..Default::default()
2443 };
2444
2445 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2446 let result = agent.execute(&[], "Echo", None).await.unwrap();
2447
2448 assert_eq!(result.text, "Auto executed!");
2450 assert_eq!(result.tool_calls_count, 1);
2451 }
2452
2453 #[tokio::test]
2454 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
2455 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2457 use tokio::sync::broadcast;
2458
2459 let mock_client = Arc::new(MockLlmClient::new(vec![
2460 MockLlmClient::tool_call_response(
2461 "tool-1",
2462 "bash",
2463 serde_json::json!({"command": "rm -rf /"}),
2464 ),
2465 MockLlmClient::text_response("Blocked by permission."),
2466 ]));
2467
2468 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2469
2470 let (event_tx, mut event_rx) = broadcast::channel(100);
2472 let hitl_policy = ConfirmationPolicy {
2473 enabled: true,
2474 ..Default::default()
2475 };
2476 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2477
2478 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2480 let policy_lock = Arc::new(RwLock::new(permission_policy));
2481
2482 let config = AgentConfig {
2483 permission_policy: Some(policy_lock),
2484 confirmation_manager: Some(confirmation_manager),
2485 ..Default::default()
2486 };
2487
2488 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2489 let result = agent.execute(&[], "Delete", None).await.unwrap();
2490
2491 assert_eq!(result.text, "Blocked by permission.");
2493
2494 let mut found_confirmation = false;
2496 while let Ok(event) = event_rx.try_recv() {
2497 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2498 found_confirmation = true;
2499 }
2500 }
2501 assert!(
2502 !found_confirmation,
2503 "HITL should not be triggered when permission is Deny"
2504 );
2505 }
2506
2507 #[tokio::test]
2508 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
2509 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2512 use tokio::sync::broadcast;
2513
2514 let mock_client = Arc::new(MockLlmClient::new(vec![
2515 MockLlmClient::tool_call_response(
2516 "tool-1",
2517 "bash",
2518 serde_json::json!({"command": "echo hello"}),
2519 ),
2520 MockLlmClient::text_response("Allowed!"),
2521 ]));
2522
2523 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2524
2525 let (event_tx, mut event_rx) = broadcast::channel(100);
2527 let hitl_policy = ConfirmationPolicy {
2528 enabled: true,
2529 ..Default::default()
2530 };
2531 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2532
2533 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
2535 let policy_lock = Arc::new(RwLock::new(permission_policy));
2536
2537 let config = AgentConfig {
2538 permission_policy: Some(policy_lock),
2539 confirmation_manager: Some(confirmation_manager.clone()),
2540 ..Default::default()
2541 };
2542
2543 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2544 let result = agent.execute(&[], "Echo", None).await.unwrap();
2545
2546 assert_eq!(result.text, "Allowed!");
2548
2549 let mut found_confirmation = false;
2551 while let Ok(event) = event_rx.try_recv() {
2552 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2553 found_confirmation = true;
2554 }
2555 }
2556 assert!(
2557 !found_confirmation,
2558 "Permission Allow should skip HITL confirmation"
2559 );
2560 }
2561
2562 #[tokio::test]
2563 async fn test_agent_hitl_multiple_tool_calls() {
2564 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2566 use tokio::sync::broadcast;
2567
2568 let mock_client = Arc::new(MockLlmClient::new(vec![
2569 LlmResponse {
2571 message: Message {
2572 role: "assistant".to_string(),
2573 content: vec![
2574 ContentBlock::ToolUse {
2575 id: "tool-1".to_string(),
2576 name: "bash".to_string(),
2577 input: serde_json::json!({"command": "echo first"}),
2578 },
2579 ContentBlock::ToolUse {
2580 id: "tool-2".to_string(),
2581 name: "bash".to_string(),
2582 input: serde_json::json!({"command": "echo second"}),
2583 },
2584 ],
2585 reasoning_content: None,
2586 },
2587 usage: TokenUsage {
2588 prompt_tokens: 10,
2589 completion_tokens: 5,
2590 total_tokens: 15,
2591 cache_read_tokens: None,
2592 cache_write_tokens: None,
2593 },
2594 stop_reason: Some("tool_use".to_string()),
2595 },
2596 MockLlmClient::text_response("Both executed!"),
2597 ]));
2598
2599 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2600
2601 let (event_tx, _event_rx) = broadcast::channel(100);
2603 let hitl_policy = ConfirmationPolicy {
2604 enabled: true,
2605 default_timeout_ms: 5000,
2606 ..Default::default()
2607 };
2608 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2609
2610 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2612
2613 let config = AgentConfig {
2614 permission_policy: Some(policy_lock),
2615 confirmation_manager: Some(confirmation_manager.clone()),
2616 ..Default::default()
2617 };
2618
2619 let cm_clone = confirmation_manager.clone();
2621 tokio::spawn(async move {
2622 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2623 cm_clone.confirm("tool-1", true, None).await.ok();
2624 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2625 cm_clone.confirm("tool-2", true, None).await.ok();
2626 });
2627
2628 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2629 let result = agent.execute(&[], "Run both", None).await.unwrap();
2630
2631 assert_eq!(result.text, "Both executed!");
2632 assert_eq!(result.tool_calls_count, 2);
2633 }
2634
2635 #[tokio::test]
2636 async fn test_agent_hitl_partial_approval() {
2637 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2639 use tokio::sync::broadcast;
2640
2641 let mock_client = Arc::new(MockLlmClient::new(vec![
2642 LlmResponse {
2644 message: Message {
2645 role: "assistant".to_string(),
2646 content: vec![
2647 ContentBlock::ToolUse {
2648 id: "tool-1".to_string(),
2649 name: "bash".to_string(),
2650 input: serde_json::json!({"command": "echo safe"}),
2651 },
2652 ContentBlock::ToolUse {
2653 id: "tool-2".to_string(),
2654 name: "bash".to_string(),
2655 input: serde_json::json!({"command": "rm -rf /"}),
2656 },
2657 ],
2658 reasoning_content: None,
2659 },
2660 usage: TokenUsage {
2661 prompt_tokens: 10,
2662 completion_tokens: 5,
2663 total_tokens: 15,
2664 cache_read_tokens: None,
2665 cache_write_tokens: None,
2666 },
2667 stop_reason: Some("tool_use".to_string()),
2668 },
2669 MockLlmClient::text_response("First worked, second rejected."),
2670 ]));
2671
2672 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2673
2674 let (event_tx, _event_rx) = broadcast::channel(100);
2675 let hitl_policy = ConfirmationPolicy {
2676 enabled: true,
2677 default_timeout_ms: 5000,
2678 ..Default::default()
2679 };
2680 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2681
2682 let permission_policy = PermissionPolicy::new();
2683 let policy_lock = Arc::new(RwLock::new(permission_policy));
2684
2685 let config = AgentConfig {
2686 permission_policy: Some(policy_lock),
2687 confirmation_manager: Some(confirmation_manager.clone()),
2688 ..Default::default()
2689 };
2690
2691 let cm_clone = confirmation_manager.clone();
2693 tokio::spawn(async move {
2694 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2695 cm_clone.confirm("tool-1", true, None).await.ok();
2696 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2697 cm_clone
2698 .confirm("tool-2", false, Some("Dangerous".to_string()))
2699 .await
2700 .ok();
2701 });
2702
2703 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2704 let result = agent.execute(&[], "Run both", None).await.unwrap();
2705
2706 assert_eq!(result.text, "First worked, second rejected.");
2707 assert_eq!(result.tool_calls_count, 2);
2708 }
2709
2710 #[tokio::test]
2711 async fn test_agent_hitl_yolo_mode_auto_approves() {
2712 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
2714 use tokio::sync::broadcast;
2715
2716 let mock_client = Arc::new(MockLlmClient::new(vec![
2717 MockLlmClient::tool_call_response(
2718 "tool-1",
2719 "read", serde_json::json!({"path": "/tmp/test.txt"}),
2721 ),
2722 MockLlmClient::text_response("File read!"),
2723 ]));
2724
2725 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2726
2727 let (event_tx, mut event_rx) = broadcast::channel(100);
2729 let mut yolo_lanes = std::collections::HashSet::new();
2730 yolo_lanes.insert(SessionLane::Query);
2731 let hitl_policy = ConfirmationPolicy {
2732 enabled: true,
2733 yolo_lanes, ..Default::default()
2735 };
2736 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2737
2738 let permission_policy = PermissionPolicy::new();
2739 let policy_lock = Arc::new(RwLock::new(permission_policy));
2740
2741 let config = AgentConfig {
2742 permission_policy: Some(policy_lock),
2743 confirmation_manager: Some(confirmation_manager),
2744 ..Default::default()
2745 };
2746
2747 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2748 let result = agent.execute(&[], "Read file", None).await.unwrap();
2749
2750 assert_eq!(result.text, "File read!");
2752
2753 let mut found_confirmation = false;
2755 while let Ok(event) = event_rx.try_recv() {
2756 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2757 found_confirmation = true;
2758 }
2759 }
2760 assert!(
2761 !found_confirmation,
2762 "YOLO mode should not trigger confirmation"
2763 );
2764 }
2765
2766 #[tokio::test]
2767 async fn test_agent_config_with_all_options() {
2768 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2769 use tokio::sync::broadcast;
2770
2771 let (event_tx, _) = broadcast::channel(100);
2772 let hitl_policy = ConfirmationPolicy::default();
2773 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2774
2775 let permission_policy = PermissionPolicy::new().allow("bash(*)");
2776 let policy_lock = Arc::new(RwLock::new(permission_policy));
2777
2778 let config = AgentConfig {
2779 system_prompt: Some("Test system prompt".to_string()),
2780 tools: vec![],
2781 max_tool_rounds: 10,
2782 permission_policy: Some(policy_lock),
2783 confirmation_manager: Some(confirmation_manager),
2784 context_providers: vec![],
2785 planning_enabled: false,
2786 goal_tracking: false,
2787 skill_tool_filters: vec![],
2788 hook_engine: None,
2789 };
2790
2791 assert_eq!(config.system_prompt, Some("Test system prompt".to_string()));
2792 assert_eq!(config.max_tool_rounds, 10);
2793 assert!(config.permission_policy.is_some());
2794 assert!(config.confirmation_manager.is_some());
2795 assert!(config.context_providers.is_empty());
2796
2797 let debug_str = format!("{:?}", config);
2799 assert!(debug_str.contains("AgentConfig"));
2800 assert!(debug_str.contains("permission_policy: true"));
2801 assert!(debug_str.contains("confirmation_manager: true"));
2802 assert!(debug_str.contains("context_providers: 0"));
2803 }
2804
2805 use crate::context::{ContextItem, ContextType};
2810
2811 struct MockContextProvider {
2813 name: String,
2814 items: Vec<ContextItem>,
2815 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
2816 }
2817
2818 impl MockContextProvider {
2819 fn new(name: &str) -> Self {
2820 Self {
2821 name: name.to_string(),
2822 items: Vec::new(),
2823 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
2824 }
2825 }
2826
2827 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
2828 self.items = items;
2829 self
2830 }
2831 }
2832
2833 #[async_trait::async_trait]
2834 impl ContextProvider for MockContextProvider {
2835 fn name(&self) -> &str {
2836 &self.name
2837 }
2838
2839 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
2840 let mut result = ContextResult::new(&self.name);
2841 for item in &self.items {
2842 result.add_item(item.clone());
2843 }
2844 Ok(result)
2845 }
2846
2847 async fn on_turn_complete(
2848 &self,
2849 session_id: &str,
2850 prompt: &str,
2851 response: &str,
2852 ) -> anyhow::Result<()> {
2853 let mut calls = self.on_turn_calls.write().await;
2854 calls.push((
2855 session_id.to_string(),
2856 prompt.to_string(),
2857 response.to_string(),
2858 ));
2859 Ok(())
2860 }
2861 }
2862
2863 #[tokio::test]
2864 async fn test_agent_with_context_provider() {
2865 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2866 "Response using context",
2867 )]));
2868
2869 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2870
2871 let provider =
2872 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
2873 "ctx-1",
2874 ContextType::Resource,
2875 "Relevant context here",
2876 )
2877 .with_source("test://docs/example")]);
2878
2879 let config = AgentConfig {
2880 system_prompt: Some("You are helpful.".to_string()),
2881 context_providers: vec![Arc::new(provider)],
2882 ..Default::default()
2883 };
2884
2885 let agent = AgentLoop::new(
2886 mock_client.clone(),
2887 tool_executor,
2888 test_tool_context(),
2889 config,
2890 );
2891 let result = agent.execute(&[], "What is X?", None).await.unwrap();
2892
2893 assert_eq!(result.text, "Response using context");
2894 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2895 }
2896
2897 #[tokio::test]
2898 async fn test_agent_context_provider_events() {
2899 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2900 "Answer",
2901 )]));
2902
2903 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2904
2905 let provider =
2906 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
2907 "item-1",
2908 ContextType::Memory,
2909 "Memory content",
2910 )
2911 .with_token_count(50)]);
2912
2913 let config = AgentConfig {
2914 context_providers: vec![Arc::new(provider)],
2915 ..Default::default()
2916 };
2917
2918 let (tx, mut rx) = mpsc::channel(100);
2919 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2920 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
2921
2922 let mut events = Vec::new();
2924 while let Ok(event) = rx.try_recv() {
2925 events.push(event);
2926 }
2927
2928 assert!(
2930 events
2931 .iter()
2932 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
2933 "Should have ContextResolving event"
2934 );
2935 assert!(
2936 events
2937 .iter()
2938 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
2939 "Should have ContextResolved event"
2940 );
2941
2942 for event in &events {
2944 if let AgentEvent::ContextResolved {
2945 total_items,
2946 total_tokens,
2947 } = event
2948 {
2949 assert_eq!(*total_items, 1);
2950 assert_eq!(*total_tokens, 50);
2951 }
2952 }
2953 }
2954
2955 #[tokio::test]
2956 async fn test_agent_multiple_context_providers() {
2957 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2958 "Combined response",
2959 )]));
2960
2961 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2962
2963 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
2964 "p1-1",
2965 ContextType::Resource,
2966 "Resource from P1",
2967 )
2968 .with_token_count(100)]);
2969
2970 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
2971 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
2972 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
2973 ]);
2974
2975 let config = AgentConfig {
2976 system_prompt: Some("Base system prompt.".to_string()),
2977 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
2978 ..Default::default()
2979 };
2980
2981 let (tx, mut rx) = mpsc::channel(100);
2982 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2983 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
2984
2985 assert_eq!(result.text, "Combined response");
2986
2987 while let Ok(event) = rx.try_recv() {
2989 if let AgentEvent::ContextResolved {
2990 total_items,
2991 total_tokens,
2992 } = event
2993 {
2994 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
2997 }
2998 }
2999
3000 #[tokio::test]
3001 async fn test_agent_no_context_providers() {
3002 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3003 "No context",
3004 )]));
3005
3006 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3007
3008 let config = AgentConfig::default();
3010
3011 let (tx, mut rx) = mpsc::channel(100);
3012 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3013 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3014
3015 assert_eq!(result.text, "No context");
3016
3017 let mut events = Vec::new();
3019 while let Ok(event) = rx.try_recv() {
3020 events.push(event);
3021 }
3022
3023 assert!(
3024 !events
3025 .iter()
3026 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3027 "Should NOT have ContextResolving event"
3028 );
3029 }
3030
3031 #[tokio::test]
3032 async fn test_agent_context_on_turn_complete() {
3033 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3034 "Final response",
3035 )]));
3036
3037 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3038
3039 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3040 let on_turn_calls = provider.on_turn_calls.clone();
3041
3042 let config = AgentConfig {
3043 context_providers: vec![provider],
3044 ..Default::default()
3045 };
3046
3047 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3048
3049 let result = agent
3051 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3052 .await
3053 .unwrap();
3054
3055 assert_eq!(result.text, "Final response");
3056
3057 let calls = on_turn_calls.read().await;
3059 assert_eq!(calls.len(), 1);
3060 assert_eq!(calls[0].0, "sess-123");
3061 assert_eq!(calls[0].1, "User prompt");
3062 assert_eq!(calls[0].2, "Final response");
3063 }
3064
3065 #[tokio::test]
3066 async fn test_agent_context_on_turn_complete_no_session() {
3067 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3068 "Response",
3069 )]));
3070
3071 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3072
3073 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3074 let on_turn_calls = provider.on_turn_calls.clone();
3075
3076 let config = AgentConfig {
3077 context_providers: vec![provider],
3078 ..Default::default()
3079 };
3080
3081 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3082
3083 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
3085
3086 let calls = on_turn_calls.read().await;
3088 assert!(calls.is_empty());
3089 }
3090
3091 #[tokio::test]
3092 async fn test_agent_build_augmented_system_prompt() {
3093 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
3094
3095 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3096
3097 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
3098 "doc-1",
3099 ContextType::Resource,
3100 "Auth uses JWT tokens.",
3101 )
3102 .with_source("viking://docs/auth")]);
3103
3104 let config = AgentConfig {
3105 system_prompt: Some("You are helpful.".to_string()),
3106 context_providers: vec![Arc::new(provider)],
3107 ..Default::default()
3108 };
3109
3110 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3111
3112 let context_results = agent.resolve_context("test", None).await;
3114 let augmented = agent.build_augmented_system_prompt(&context_results);
3115
3116 let augmented_str = augmented.unwrap();
3117 assert!(augmented_str.contains("You are helpful."));
3118 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
3119 assert!(augmented_str.contains("Auth uses JWT tokens."));
3120 }
3121
3122 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
3128 let mut events = Vec::new();
3129 while let Ok(event) = rx.try_recv() {
3130 events.push(event);
3131 }
3132 while let Some(event) = rx.recv().await {
3134 events.push(event);
3135 }
3136 events
3137 }
3138
3139 #[tokio::test]
3140 async fn test_agent_multi_turn_tool_chain() {
3141 let mock_client = Arc::new(MockLlmClient::new(vec![
3143 MockLlmClient::tool_call_response(
3145 "t1",
3146 "bash",
3147 serde_json::json!({"command": "echo step1"}),
3148 ),
3149 MockLlmClient::tool_call_response(
3151 "t2",
3152 "bash",
3153 serde_json::json!({"command": "echo step2"}),
3154 ),
3155 MockLlmClient::text_response("Completed both steps: step1 then step2"),
3157 ]));
3158
3159 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3160 let config = AgentConfig::default();
3161
3162 let agent = AgentLoop::new(
3163 mock_client.clone(),
3164 tool_executor,
3165 test_tool_context(),
3166 config,
3167 );
3168 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
3169
3170 assert_eq!(result.text, "Completed both steps: step1 then step2");
3171 assert_eq!(result.tool_calls_count, 2);
3172 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
3173
3174 assert_eq!(result.messages[0].role, "user");
3176 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
3182 }
3183
3184 #[tokio::test]
3185 async fn test_agent_conversation_history_preserved() {
3186 let existing_history = vec![
3188 Message::user("What is Rust?"),
3189 Message {
3190 role: "assistant".to_string(),
3191 content: vec![ContentBlock::Text {
3192 text: "Rust is a systems programming language.".to_string(),
3193 }],
3194 reasoning_content: None,
3195 },
3196 ];
3197
3198 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3199 "Rust was created by Graydon Hoare at Mozilla.",
3200 )]));
3201
3202 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3203 let agent = AgentLoop::new(
3204 mock_client.clone(),
3205 tool_executor,
3206 test_tool_context(),
3207 AgentConfig::default(),
3208 );
3209
3210 let result = agent
3211 .execute(&existing_history, "Who created it?", None)
3212 .await
3213 .unwrap();
3214
3215 assert_eq!(result.messages.len(), 4);
3217 assert_eq!(result.messages[0].text(), "What is Rust?");
3218 assert_eq!(
3219 result.messages[1].text(),
3220 "Rust is a systems programming language."
3221 );
3222 assert_eq!(result.messages[2].text(), "Who created it?");
3223 assert_eq!(
3224 result.messages[3].text(),
3225 "Rust was created by Graydon Hoare at Mozilla."
3226 );
3227 }
3228
3229 #[tokio::test]
3230 async fn test_agent_event_stream_completeness() {
3231 let mock_client = Arc::new(MockLlmClient::new(vec![
3233 MockLlmClient::tool_call_response(
3234 "t1",
3235 "bash",
3236 serde_json::json!({"command": "echo hi"}),
3237 ),
3238 MockLlmClient::text_response("Done"),
3239 ]));
3240
3241 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3242 let agent = AgentLoop::new(
3243 mock_client,
3244 tool_executor,
3245 test_tool_context(),
3246 AgentConfig::default(),
3247 );
3248
3249 let (tx, rx) = mpsc::channel(100);
3250 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
3251 assert_eq!(result.text, "Done");
3252
3253 let events = collect_events(rx).await;
3254
3255 let event_types: Vec<&str> = events
3257 .iter()
3258 .map(|e| match e {
3259 AgentEvent::Start { .. } => "Start",
3260 AgentEvent::TurnStart { .. } => "TurnStart",
3261 AgentEvent::TurnEnd { .. } => "TurnEnd",
3262 AgentEvent::ToolEnd { .. } => "ToolEnd",
3263 AgentEvent::End { .. } => "End",
3264 _ => "Other",
3265 })
3266 .collect();
3267
3268 assert_eq!(event_types.first(), Some(&"Start"));
3270 assert_eq!(event_types.last(), Some(&"End"));
3271
3272 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
3274 assert_eq!(turn_starts, 2);
3275
3276 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
3278 assert_eq!(tool_ends, 1);
3279 }
3280
3281 #[tokio::test]
3282 async fn test_agent_multiple_tools_single_turn() {
3283 let mock_client = Arc::new(MockLlmClient::new(vec![
3285 LlmResponse {
3286 message: Message {
3287 role: "assistant".to_string(),
3288 content: vec![
3289 ContentBlock::ToolUse {
3290 id: "t1".to_string(),
3291 name: "bash".to_string(),
3292 input: serde_json::json!({"command": "echo first"}),
3293 },
3294 ContentBlock::ToolUse {
3295 id: "t2".to_string(),
3296 name: "bash".to_string(),
3297 input: serde_json::json!({"command": "echo second"}),
3298 },
3299 ],
3300 reasoning_content: None,
3301 },
3302 usage: TokenUsage {
3303 prompt_tokens: 10,
3304 completion_tokens: 5,
3305 total_tokens: 15,
3306 cache_read_tokens: None,
3307 cache_write_tokens: None,
3308 },
3309 stop_reason: Some("tool_use".to_string()),
3310 },
3311 MockLlmClient::text_response("Both commands ran"),
3312 ]));
3313
3314 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3315 let agent = AgentLoop::new(
3316 mock_client.clone(),
3317 tool_executor,
3318 test_tool_context(),
3319 AgentConfig::default(),
3320 );
3321
3322 let result = agent.execute(&[], "Run both", None).await.unwrap();
3323
3324 assert_eq!(result.text, "Both commands ran");
3325 assert_eq!(result.tool_calls_count, 2);
3326 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
3330 assert_eq!(result.messages[1].role, "assistant");
3331 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
3334 }
3335
3336 #[tokio::test]
3337 async fn test_agent_token_usage_accumulation() {
3338 let mock_client = Arc::new(MockLlmClient::new(vec![
3340 MockLlmClient::tool_call_response(
3341 "t1",
3342 "bash",
3343 serde_json::json!({"command": "echo x"}),
3344 ),
3345 MockLlmClient::text_response("Done"),
3346 ]));
3347
3348 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3349 let agent = AgentLoop::new(
3350 mock_client,
3351 tool_executor,
3352 test_tool_context(),
3353 AgentConfig::default(),
3354 );
3355
3356 let result = agent.execute(&[], "test", None).await.unwrap();
3357
3358 assert_eq!(result.usage.prompt_tokens, 20);
3361 assert_eq!(result.usage.completion_tokens, 10);
3362 assert_eq!(result.usage.total_tokens, 30);
3363 }
3364
3365 #[tokio::test]
3366 async fn test_agent_system_prompt_passed() {
3367 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3369 "I am a coding assistant.",
3370 )]));
3371
3372 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3373 let config = AgentConfig {
3374 system_prompt: Some("You are a coding assistant.".to_string()),
3375 ..Default::default()
3376 };
3377
3378 let agent = AgentLoop::new(
3379 mock_client.clone(),
3380 tool_executor,
3381 test_tool_context(),
3382 config,
3383 );
3384 let result = agent.execute(&[], "What are you?", None).await.unwrap();
3385
3386 assert_eq!(result.text, "I am a coding assistant.");
3387 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3388 }
3389
3390 #[tokio::test]
3391 async fn test_agent_max_rounds_with_persistent_tool_calls() {
3392 let mut responses = Vec::new();
3394 for i in 0..15 {
3395 responses.push(MockLlmClient::tool_call_response(
3396 &format!("t{}", i),
3397 "bash",
3398 serde_json::json!({"command": format!("echo round{}", i)}),
3399 ));
3400 }
3401
3402 let mock_client = Arc::new(MockLlmClient::new(responses));
3403 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3404 let config = AgentConfig {
3405 max_tool_rounds: 5,
3406 ..Default::default()
3407 };
3408
3409 let agent = AgentLoop::new(
3410 mock_client.clone(),
3411 tool_executor,
3412 test_tool_context(),
3413 config,
3414 );
3415 let result = agent.execute(&[], "Loop forever", None).await;
3416
3417 assert!(result.is_err());
3418 let err = result.unwrap_err().to_string();
3419 assert!(err.contains("Max tool rounds (5) exceeded"));
3420 }
3421
3422 #[tokio::test]
3423 async fn test_agent_end_event_contains_final_text() {
3424 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3425 "Final answer here",
3426 )]));
3427
3428 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3429 let agent = AgentLoop::new(
3430 mock_client,
3431 tool_executor,
3432 test_tool_context(),
3433 AgentConfig::default(),
3434 );
3435
3436 let (tx, rx) = mpsc::channel(100);
3437 agent.execute(&[], "test", Some(tx)).await.unwrap();
3438
3439 let events = collect_events(rx).await;
3440 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
3441 assert!(end_event.is_some());
3442
3443 if let AgentEvent::End { text, usage } = end_event.unwrap() {
3444 assert_eq!(text, "Final answer here");
3445 assert_eq!(usage.total_tokens, 15);
3446 }
3447 }
3448}
3449
3450#[cfg(test)]
3451mod extra_agent_tests {
3452 use super::*;
3453 use crate::agent::tests::MockLlmClient;
3454 use crate::llm::{ContentBlock, StreamEvent};
3455 use crate::tools::ToolExecutor;
3456 use std::path::PathBuf;
3457 use std::sync::atomic::{AtomicUsize, Ordering};
3458
3459 fn test_tool_context() -> ToolContext {
3460 ToolContext::new(PathBuf::from("/tmp"))
3461 }
3462
/// The `Debug` output of an `AgentConfig` names the struct and includes the
/// `planning_enabled` field.
#[test]
fn test_agent_config_debug() {
    let rendered = format!(
        "{:?}",
        AgentConfig {
            system_prompt: Some("You are helpful".to_string()),
            tools: vec![],
            max_tool_rounds: 10,
            permission_policy: None,
            confirmation_manager: None,
            context_providers: vec![],
            planning_enabled: true,
            goal_tracking: false,
            skill_tool_filters: vec![],
            hook_engine: None,
        }
    );

    for fragment in ["AgentConfig", "planning_enabled"] {
        assert!(rendered.contains(fragment));
    }
}
3485
/// `AgentConfig::default()` uses `MAX_TOOL_ROUNDS`, leaves planning and goal
/// tracking off, and starts with no context providers or skill filters.
#[test]
fn test_agent_config_default_values() {
    let defaults = AgentConfig::default();

    assert_eq!(defaults.max_tool_rounds, MAX_TOOL_ROUNDS);

    // Both feature flags are off.
    assert!(!defaults.planning_enabled);
    assert!(!defaults.goal_tracking);

    // Both collections are empty.
    assert!(defaults.context_providers.is_empty());
    assert!(defaults.skill_tool_filters.is_empty());
}
3495
3496 #[tokio::test]
3497 async fn test_agent_skill_tool_filters_blocks_unauthorized() {
3498 use crate::tools::skill::Skill;
3501
3502 let mock_client = Arc::new(MockLlmClient::new(vec![
3503 MockLlmClient::tool_call_response(
3504 "tool-1",
3505 "bash",
3506 serde_json::json!({"command": "rm -rf /"}),
3507 ),
3508 MockLlmClient::text_response("Blocked!"),
3509 ]));
3510
3511 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3512
3513 let skill = Skill {
3515 name: "read-only".to_string(),
3516 description: "Read-only skill".to_string(),
3517 content: "Read files only".to_string(),
3518 allowed_tools: Some("read(*)".to_string()),
3519 disable_model_invocation: false,
3520 kind: crate::tools::SkillKind::Instruction,
3521 };
3522
3523 let policy = PermissionPolicy::new().allow("bash(*)");
3525 let policy_lock = Arc::new(RwLock::new(policy));
3526
3527 let (event_tx, _) = tokio::sync::broadcast::channel(10);
3529 let cm = Arc::new(crate::hitl::ConfirmationManager::new(
3530 crate::hitl::ConfirmationPolicy::default(), event_tx,
3532 ));
3533
3534 let config = AgentConfig {
3535 permission_policy: Some(policy_lock),
3536 confirmation_manager: Some(cm),
3537 skill_tool_filters: vec![skill],
3538 ..Default::default()
3539 };
3540
3541 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3542 let result = agent.execute(&[], "Delete", None).await.unwrap();
3543
3544 assert_eq!(result.text, "Blocked!");
3546 }
3547
3548 #[tokio::test]
3549 async fn test_agent_skill_tool_filters_allows_authorized() {
3550 use crate::tools::skill::Skill;
3552
3553 let mock_client = Arc::new(MockLlmClient::new(vec![
3554 MockLlmClient::tool_call_response(
3555 "tool-1",
3556 "bash",
3557 serde_json::json!({"command": "echo hello"}),
3558 ),
3559 MockLlmClient::text_response("Allowed!"),
3560 ]));
3561
3562 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3563
3564 let skill = Skill {
3566 name: "bash-skill".to_string(),
3567 description: "Bash skill".to_string(),
3568 content: "Run bash".to_string(),
3569 allowed_tools: Some("bash(*)".to_string()),
3570 disable_model_invocation: false,
3571 kind: crate::tools::SkillKind::Instruction,
3572 };
3573
3574 let policy = PermissionPolicy::new().allow("bash(*)");
3576 let policy_lock = Arc::new(RwLock::new(policy));
3577
3578 let (event_tx, _) = tokio::sync::broadcast::channel(10);
3580 let cm = Arc::new(crate::hitl::ConfirmationManager::new(
3581 crate::hitl::ConfirmationPolicy::default(), event_tx,
3583 ));
3584
3585 let config = AgentConfig {
3586 permission_policy: Some(policy_lock),
3587 confirmation_manager: Some(cm),
3588 skill_tool_filters: vec![skill],
3589 ..Default::default()
3590 };
3591
3592 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3593 let result = agent.execute(&[], "Echo", None).await.unwrap();
3594
3595 assert_eq!(result.text, "Allowed!");
3597 }
3598
/// Serializes `AgentEvent::Start` to JSON and checks the output contains
/// `agent_start` and the prompt text.
#[test]
fn test_agent_event_serialize_start() {
    let serialized = serde_json::to_string(&AgentEvent::Start {
        prompt: "Hello".to_string(),
    })
    .unwrap();

    for needle in ["agent_start", "Hello"] {
        assert!(serialized.contains(needle));
    }
}
3612
/// Serializes `AgentEvent::TextDelta` to JSON and checks the output contains
/// `text_delta`.
#[test]
fn test_agent_event_serialize_text_delta() {
    let serialized = serde_json::to_string(&AgentEvent::TextDelta {
        text: "chunk".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("text_delta"));
}
3621
/// Serializes `AgentEvent::ToolStart` to JSON and checks the output contains
/// `tool_start` and the tool name.
#[test]
fn test_agent_event_serialize_tool_start() {
    let serialized = serde_json::to_string(&AgentEvent::ToolStart {
        id: "t1".to_string(),
        name: "bash".to_string(),
    })
    .unwrap();

    for needle in ["tool_start", "bash"] {
        assert!(serialized.contains(needle));
    }
}
3632
/// Serializes `AgentEvent::ToolEnd` to JSON and checks the output contains
/// `tool_end`.
#[test]
fn test_agent_event_serialize_tool_end() {
    let serialized = serde_json::to_string(&AgentEvent::ToolEnd {
        id: "t1".to_string(),
        name: "bash".to_string(),
        output: "hello".to_string(),
        exit_code: 0,
    })
    .unwrap();

    assert!(serialized.contains("tool_end"));
}
3644
/// Serializes `AgentEvent::Error` to JSON and checks the output contains
/// `error` and the message text.
#[test]
fn test_agent_event_serialize_error() {
    let serialized = serde_json::to_string(&AgentEvent::Error {
        message: "oops".to_string(),
    })
    .unwrap();

    for needle in ["error", "oops"] {
        assert!(serialized.contains(needle));
    }
}
3654
/// Serializes `AgentEvent::ConfirmationRequired` to JSON and checks the
/// output contains `confirmation_required`.
#[test]
fn test_agent_event_serialize_confirmation_required() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationRequired {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({"cmd": "rm"}),
        timeout_ms: 30000,
    })
    .unwrap();

    assert!(serialized.contains("confirmation_required"));
}
3666
/// Serializes `AgentEvent::ConfirmationReceived` to JSON and checks the
/// output contains `confirmation_received`.
#[test]
fn test_agent_event_serialize_confirmation_received() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationReceived {
        tool_id: "t1".to_string(),
        approved: true,
        reason: Some("safe".to_string()),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_received"));
}
3677
/// Serializes `AgentEvent::ConfirmationTimeout` to JSON and checks the
/// output contains `confirmation_timeout`.
#[test]
fn test_agent_event_serialize_confirmation_timeout() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationTimeout {
        tool_id: "t1".to_string(),
        action_taken: "rejected".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_timeout"));
}
3687
/// Serializes `AgentEvent::ExternalTaskPending` to JSON and checks the
/// output contains `external_task_pending`.
#[test]
fn test_agent_event_serialize_external_task_pending() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskPending {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        lane: crate::hitl::SessionLane::Execute,
        command_type: "bash".to_string(),
        payload: serde_json::json!({}),
        timeout_ms: 60000,
    })
    .unwrap();

    assert!(serialized.contains("external_task_pending"));
}
3701
/// Serializes `AgentEvent::ExternalTaskCompleted` to JSON and checks the
/// output contains `external_task_completed`.
#[test]
fn test_agent_event_serialize_external_task_completed() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskCompleted {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        success: false,
    })
    .unwrap();

    assert!(serialized.contains("external_task_completed"));
}
3712
/// Serializes `AgentEvent::PermissionDenied` to JSON and checks the output
/// contains `permission_denied`.
#[test]
fn test_agent_event_serialize_permission_denied() {
    let serialized = serde_json::to_string(&AgentEvent::PermissionDenied {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({}),
        reason: "denied".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("permission_denied"));
}
3724
/// Serializes a `ContextCompacted` event and expects the JSON to contain the
/// string `context_compacted`.
#[test]
fn test_agent_event_serialize_context_compacted() {
    let compacted = AgentEvent::ContextCompacted {
        session_id: String::from("sess-1"),
        before_messages: 100,
        after_messages: 20,
        percent_before: 0.85,
    };
    let serialized = serde_json::to_string(&compacted).unwrap();
    assert!(serialized.contains("context_compacted"));
}
3736
/// Serializes a `TurnStart` event and expects the JSON to contain the string
/// `turn_start`.
#[test]
fn test_agent_event_serialize_turn_start() {
    let serialized = serde_json::to_string(&AgentEvent::TurnStart { turn: 3 }).unwrap();
    assert!(serialized.contains("turn_start"));
}
3743
/// Serializes a `TurnEnd` event with default token usage and expects the JSON
/// to contain the string `turn_end`.
#[test]
fn test_agent_event_serialize_turn_end() {
    let usage = TokenUsage::default();
    let turn_end = AgentEvent::TurnEnd { turn: 3, usage };
    let serialized = serde_json::to_string(&turn_end).unwrap();
    assert!(serialized.contains("turn_end"));
}
3753
/// Serializes an `End` event with explicit token counts and expects the JSON
/// to contain the string `agent_end`.
#[test]
fn test_agent_event_serialize_end() {
    let usage = TokenUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let finished = AgentEvent::End {
        text: String::from("Done"),
        usage,
    };
    let serialized = serde_json::to_string(&finished).unwrap();
    assert!(serialized.contains("agent_end"));
}
3769
/// Builds an `AgentResult` and checks that its text, message count and tool
/// call count read back as constructed.
#[test]
fn test_agent_result_fields() {
    let outcome = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    assert_eq!(outcome.text, "output");
    assert_eq!(outcome.messages.len(), 1);
    assert_eq!(outcome.tool_calls_count, 3);
}
3786
/// Serializes a `ContextResolving` event and expects the JSON to contain both
/// `context_resolving` and the first provider name.
#[test]
fn test_agent_event_serialize_context_resolving() {
    let providers = vec![String::from("provider1"), String::from("provider2")];
    let resolving = AgentEvent::ContextResolving { providers };
    let serialized = serde_json::to_string(&resolving).unwrap();
    assert!(serialized.contains("context_resolving"));
    assert!(serialized.contains("provider1"));
}
3800
/// Serializes a `ContextResolved` event and expects the JSON to contain both
/// `context_resolved` and the token total.
#[test]
fn test_agent_event_serialize_context_resolved() {
    let resolved = AgentEvent::ContextResolved {
        total_items: 5,
        total_tokens: 1000,
    };
    let serialized = serde_json::to_string(&resolved).unwrap();
    assert!(serialized.contains("context_resolved"));
    assert!(serialized.contains("1000"));
}
3811
/// Serializes a `CommandDeadLettered` event and expects the JSON to contain
/// both `command_dead_lettered` and the command id.
#[test]
fn test_agent_event_serialize_command_dead_lettered() {
    let dead_lettered = AgentEvent::CommandDeadLettered {
        command_id: String::from("cmd-1"),
        command_type: String::from("bash"),
        lane: String::from("execute"),
        error: String::from("timeout"),
        attempts: 3,
    };
    let serialized = serde_json::to_string(&dead_lettered).unwrap();
    assert!(serialized.contains("command_dead_lettered"));
    assert!(serialized.contains("cmd-1"));
}
3825
/// Serializes a `CommandRetry` event and expects the JSON to contain both
/// `command_retry` and the command id.
#[test]
fn test_agent_event_serialize_command_retry() {
    let retry = AgentEvent::CommandRetry {
        command_id: String::from("cmd-2"),
        command_type: String::from("read"),
        lane: String::from("query"),
        attempt: 2,
        delay_ms: 1000,
    };
    let serialized = serde_json::to_string(&retry).unwrap();
    assert!(serialized.contains("command_retry"));
    assert!(serialized.contains("cmd-2"));
}
3839
/// Serializes a `QueueAlert` event and expects the JSON to contain both
/// `queue_alert` and the alert level.
#[test]
fn test_agent_event_serialize_queue_alert() {
    let alert = AgentEvent::QueueAlert {
        level: String::from("warning"),
        alert_type: String::from("depth"),
        message: String::from("Queue depth exceeded"),
    };
    let serialized = serde_json::to_string(&alert).unwrap();
    assert!(serialized.contains("queue_alert"));
    assert!(serialized.contains("warning"));
}
3851
/// Serializes a `TaskUpdated` event with an empty task list and expects the
/// JSON to contain both `task_updated` and the session id.
#[test]
fn test_agent_event_serialize_task_updated() {
    let updated = AgentEvent::TaskUpdated {
        session_id: String::from("sess-1"),
        tasks: Vec::new(),
    };
    let serialized = serde_json::to_string(&updated).unwrap();
    assert!(serialized.contains("task_updated"));
    assert!(serialized.contains("sess-1"));
}
3862
/// Serializes a `MemoryStored` event and expects the JSON to contain both
/// `memory_stored` and the memory id.
#[test]
fn test_agent_event_serialize_memory_stored() {
    let stored = AgentEvent::MemoryStored {
        memory_id: String::from("mem-1"),
        memory_type: String::from("conversation"),
        importance: 0.8,
        tags: vec![String::from("important")],
    };
    let serialized = serde_json::to_string(&stored).unwrap();
    assert!(serialized.contains("memory_stored"));
    assert!(serialized.contains("mem-1"));
}
3875
/// Serializes a `MemoryRecalled` event and expects the JSON to contain both
/// `memory_recalled` and the memory id.
#[test]
fn test_agent_event_serialize_memory_recalled() {
    let recalled = AgentEvent::MemoryRecalled {
        memory_id: String::from("mem-2"),
        content: String::from("Previous conversation"),
        relevance: 0.9,
    };
    let serialized = serde_json::to_string(&recalled).unwrap();
    assert!(serialized.contains("memory_recalled"));
    assert!(serialized.contains("mem-2"));
}
3887
/// Serializes a `MemoriesSearched` event and expects the JSON to contain both
/// `memories_searched` and the query text.
#[test]
fn test_agent_event_serialize_memories_searched() {
    let searched = AgentEvent::MemoriesSearched {
        query: Some(String::from("search term")),
        tags: vec![String::from("tag1")],
        result_count: 5,
    };
    let serialized = serde_json::to_string(&searched).unwrap();
    assert!(serialized.contains("memories_searched"));
    assert!(serialized.contains("search term"));
}
3899
/// Serializes a `MemoryCleared` event and expects the JSON to contain both
/// `memory_cleared` and the tier name.
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let cleared = AgentEvent::MemoryCleared {
        tier: String::from("short_term"),
        count: 10,
    };
    let serialized = serde_json::to_string(&cleared).unwrap();
    assert!(serialized.contains("memory_cleared"));
    assert!(serialized.contains("short_term"));
}
3910
/// Serializes a `SubagentStart` event and expects the JSON to contain both
/// `subagent_start` and the agent name.
#[test]
fn test_agent_event_serialize_subagent_start() {
    let started = AgentEvent::SubagentStart {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        parent_session_id: String::from("parent-sess"),
        agent: String::from("explore"),
        description: String::from("Explore codebase"),
    };
    let serialized = serde_json::to_string(&started).unwrap();
    assert!(serialized.contains("subagent_start"));
    assert!(serialized.contains("explore"));
}
3924
/// Serializes a `SubagentProgress` event and expects the JSON to contain both
/// `subagent_progress` and the status text.
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let progress = AgentEvent::SubagentProgress {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        status: String::from("processing"),
        metadata: serde_json::json!({"progress": 50}),
    };
    let serialized = serde_json::to_string(&progress).unwrap();
    assert!(serialized.contains("subagent_progress"));
    assert!(serialized.contains("processing"));
}
3937
/// Serializes a successful `SubagentEnd` event and expects the JSON to
/// contain both `subagent_end` and the output text.
#[test]
fn test_agent_event_serialize_subagent_end() {
    let ended = AgentEvent::SubagentEnd {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        agent: String::from("explore"),
        output: String::from("Found 10 files"),
        success: true,
    };
    let serialized = serde_json::to_string(&ended).unwrap();
    assert!(serialized.contains("subagent_end"));
    assert!(serialized.contains("Found 10 files"));
}
3951
/// Serializes a `PlanningStart` event and expects the JSON to contain both
/// `planning_start` and the prompt text.
#[test]
fn test_agent_event_serialize_planning_start() {
    let prompt = String::from("Build a web app");
    let planning = AgentEvent::PlanningStart { prompt };
    let serialized = serde_json::to_string(&planning).unwrap();
    assert!(serialized.contains("planning_start"));
    assert!(serialized.contains("Build a web app"));
}
3961
/// Serializes a `PlanningEnd` event wrapping a freshly built simple plan and
/// expects the JSON to contain both `planning_end` and `estimated_steps`.
#[test]
fn test_agent_event_serialize_planning_end() {
    use crate::planning::{Complexity, ExecutionPlan};

    let planned = AgentEvent::PlanningEnd {
        plan: ExecutionPlan::new(String::from("Test goal"), Complexity::Simple),
        estimated_steps: 3,
    };
    let serialized = serde_json::to_string(&planned).unwrap();
    assert!(serialized.contains("planning_end"));
    assert!(serialized.contains("estimated_steps"));
}
3974
/// Serializes a `StepStart` event and expects the JSON to contain both
/// `step_start` and the step description.
#[test]
fn test_agent_event_serialize_step_start() {
    let step = AgentEvent::StepStart {
        step_id: String::from("step-1"),
        description: String::from("Initialize project"),
        step_number: 1,
        total_steps: 5,
    };
    let serialized = serde_json::to_string(&step).unwrap();
    assert!(serialized.contains("step_start"));
    assert!(serialized.contains("Initialize project"));
}
3987
/// Serializes a completed `StepEnd` event and expects the JSON to contain
/// both `step_end` and the step id.
#[test]
fn test_agent_event_serialize_step_end() {
    let step = AgentEvent::StepEnd {
        step_id: String::from("step-1"),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    };
    let serialized = serde_json::to_string(&step).unwrap();
    assert!(serialized.contains("step_end"));
    assert!(serialized.contains("step-1"));
}
4000
/// Serializes a `GoalExtracted` event wrapping a new goal and expects the
/// JSON to contain the string `goal_extracted`.
#[test]
fn test_agent_event_serialize_goal_extracted() {
    use crate::planning::AgentGoal;

    let extracted = AgentEvent::GoalExtracted {
        goal: AgentGoal::new(String::from("Complete the task")),
    };
    let serialized = serde_json::to_string(&extracted).unwrap();
    assert!(serialized.contains("goal_extracted"));
}
4009
/// Serializes a `GoalProgress` event and expects the JSON to contain both
/// `goal_progress` and the progress value.
#[test]
fn test_agent_event_serialize_goal_progress() {
    let progress = AgentEvent::GoalProgress {
        goal: String::from("Build app"),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    };
    let serialized = serde_json::to_string(&progress).unwrap();
    assert!(serialized.contains("goal_progress"));
    assert!(serialized.contains("0.5"));
}
4022
/// Serializes a `GoalAchieved` event and expects the JSON to contain both
/// `goal_achieved` and the duration value.
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let achieved = AgentEvent::GoalAchieved {
        goal: String::from("Build app"),
        total_steps: 4,
        duration_ms: 5000,
    };
    let serialized = serde_json::to_string(&achieved).unwrap();
    assert!(serialized.contains("goal_achieved"));
    assert!(serialized.contains("5000"));
}
4034
4035 #[tokio::test]
4036 async fn test_extract_goal_with_json_response() {
4037 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4039 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4040 )]));
4041 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4042 let agent = AgentLoop::new(
4043 mock_client,
4044 tool_executor,
4045 test_tool_context(),
4046 AgentConfig::default(),
4047 );
4048
4049 let goal = agent.extract_goal("Build a web app").await.unwrap();
4050 assert_eq!(goal.description, "Build web app");
4051 assert_eq!(goal.success_criteria.len(), 2);
4052 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4053 }
4054
4055 #[tokio::test]
4056 async fn test_extract_goal_fallback_on_non_json() {
4057 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4059 "Some non-JSON response",
4060 )]));
4061 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4062 let agent = AgentLoop::new(
4063 mock_client,
4064 tool_executor,
4065 test_tool_context(),
4066 AgentConfig::default(),
4067 );
4068
4069 let goal = agent.extract_goal("Do something").await.unwrap();
4070 assert_eq!(goal.description, "Do something");
4072 assert_eq!(goal.success_criteria.len(), 2);
4074 }
4075
4076 #[tokio::test]
4077 async fn test_check_goal_achievement_json_yes() {
4078 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4079 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4080 )]));
4081 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4082 let agent = AgentLoop::new(
4083 mock_client,
4084 tool_executor,
4085 test_tool_context(),
4086 AgentConfig::default(),
4087 );
4088
4089 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4090 let achieved = agent
4091 .check_goal_achievement(&goal, "All done")
4092 .await
4093 .unwrap();
4094 assert!(achieved);
4095 }
4096
4097 #[tokio::test]
4098 async fn test_check_goal_achievement_fallback_not_done() {
4099 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4101 "invalid json",
4102 )]));
4103 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4104 let agent = AgentLoop::new(
4105 mock_client,
4106 tool_executor,
4107 test_tool_context(),
4108 AgentConfig::default(),
4109 );
4110
4111 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4112 let achieved = agent
4114 .check_goal_achievement(&goal, "still working")
4115 .await
4116 .unwrap();
4117 assert!(!achieved);
4118 }
4119
/// With a configured base prompt and no context results, the augmented
/// system prompt is expected to equal the base prompt.
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(
        llm,
        executor,
        test_tool_context(),
        AgentConfig {
            system_prompt: Some("Base prompt".to_string()),
            ..AgentConfig::default()
        },
    );

    let augmented = agent.build_augmented_system_prompt(&[]);
    assert_eq!(augmented.as_deref(), Some("Base prompt"));
}
4137
/// With the default configuration and no context results, the augmented
/// system prompt is expected to be `None`.
#[test]
fn test_build_augmented_system_prompt_no_system_prompt() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let augmented = agent.build_augmented_system_prompt(&[]);
    assert_eq!(augmented, None);
}
4152
/// With the default configuration and a single context item, the augmented
/// system prompt is expected to exist and to contain both a `<context`
/// marker and the item's content.
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let only_item = ContextItem::new("id1", ContextType::Resource, "Content");
    let resolved = vec![ContextResult {
        provider: "test".to_string(),
        items: vec![only_item],
        total_tokens: 10,
        truncated: false,
    }];

    let augmented = agent.build_augmented_system_prompt(&resolved);
    assert!(augmented.is_some());
    let prompt = augmented.unwrap();
    assert!(prompt.contains("<context"));
    assert!(prompt.contains("Content"));
}
4179
/// Clones an `AgentResult` and checks that the copy carries the same text
/// and tool call count as the original.
#[test]
fn test_agent_result_clone() {
    let original = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    let copy = original.clone();
    assert_eq!(copy.text, original.text);
    assert_eq!(copy.tool_calls_count, original.tool_calls_count);
}
4196
/// Formats an `AgentResult` with `Debug` and expects the output to mention
/// both the type name and the result text.
#[test]
fn test_agent_result_debug() {
    let outcome = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    let rendered = format!("{:?}", outcome);
    assert!(rendered.contains("AgentResult"));
    assert!(rendered.contains("output"));
}
4209
/// With no metadata at all, the call is expected to return `None` and leave
/// the system prompt unchanged.
#[test]
fn test_handle_post_execution_metadata_no_metadata() {
    let mut prompt = Some(String::from("base prompt"));
    let injected = AgentLoop::handle_post_execution_metadata(&None, &mut prompt, None);
    assert!(injected.is_none());
    assert_eq!(prompt.as_deref(), Some("base prompt"));
}
4221
/// With metadata that lacks the `_load_skill` key, the call is expected to
/// return `None` and leave the system prompt unchanged.
#[test]
fn test_handle_post_execution_metadata_no_load_skill_key() {
    let mut prompt = Some(String::from("base prompt"));
    let metadata = Some(serde_json::json!({"other": "value"}));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_none());
    assert_eq!(prompt.as_deref(), Some("base prompt"));
}
4230
/// With `_load_skill` explicitly set to `false`, the call is expected to
/// return `None` and leave the system prompt unchanged.
#[test]
fn test_handle_post_execution_metadata_load_skill_false() {
    let mut system = Some("base prompt".to_string());
    let meta = Some(serde_json::json!({"_load_skill": false}));
    let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
    assert!(result.is_none());
    // Checking only the return value would miss a regression that mutates
    // the prompt while the flag is disabled, so pin the prompt as well.
    assert_eq!(system.as_deref(), Some("base prompt"));
}
4238
/// With `_load_skill` set but skill content lacking front matter, the call
/// is expected to return `None` and leave the system prompt unchanged.
#[test]
fn test_handle_post_execution_metadata_invalid_skill_content() {
    let mut prompt = Some(String::from("base prompt"));
    let metadata = Some(serde_json::json!({
        "_load_skill": true,
        "skill_name": "bad.md",
        "skill_content": "not a valid skill",
    }));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_none());
    assert_eq!(prompt.as_deref(), Some("base prompt"));
}
4251
/// With a well-formed skill document, the call is expected to return the
/// skill XML and to extend the system prompt with a `<skills>` section that
/// follows the original base prompt.
#[test]
fn test_handle_post_execution_metadata_valid_skill() {
    let mut prompt = Some(String::from("base prompt"));
    let skill_doc = "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things.";
    let metadata = Some(serde_json::json!({
        "_load_skill": true,
        "skill_name": "test-skill.md",
        "skill_content": skill_doc,
    }));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_some());
    let skill_xml = injected.unwrap();
    assert!(skill_xml.contains("<skill name=\"test-skill\">"));
    assert!(skill_xml.contains("# Instructions\nDo things."));

    // The base prompt must still lead, with the skills section added to it.
    let updated_prompt = prompt.unwrap();
    assert!(updated_prompt.starts_with("base prompt"));
    assert!(updated_prompt.contains("<skills>"));
    assert!(updated_prompt.contains("</skills>"));
}
4274
/// Starting without any system prompt, loading a well-formed skill is
/// expected to return a value and to create a prompt containing the skill
/// element and its content.
#[test]
fn test_handle_post_execution_metadata_none_system_prompt() {
    let mut prompt: Option<String> = None;
    let skill_doc = "---\nname: my-skill\ndescription: desc\n---\nContent here";
    let metadata = Some(serde_json::json!({
        "_load_skill": true,
        "skill_name": "my-skill.md",
        "skill_content": skill_doc,
    }));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_some());

    // A prompt must now exist even though none was supplied.
    let created_prompt = prompt.unwrap();
    assert!(created_prompt.contains("<skill name=\"my-skill\">"));
    assert!(created_prompt.contains("Content here"));
}
4292
/// For a skill whose front matter declares `kind: tool`, the call is
/// expected to return XML containing the skill element and its instructions.
#[test]
fn test_handle_post_execution_metadata_tool_kind_injects_xml() {
    let mut prompt = Some(String::from("base"));
    let skill_doc =
        "---\nname: tool-skill\nkind: tool\ndescription: A tool\n---\nTool instructions.";
    let metadata = Some(serde_json::json!({
        "_load_skill": true,
        "skill_name": "tool-skill",
        "skill_content": skill_doc,
    }));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_some());
    let skill_xml = injected.unwrap();
    assert!(skill_xml.contains("<skill name=\"tool-skill\">"));
    assert!(skill_xml.contains("Tool instructions."));
}
4309
/// For a skill whose front matter declares `kind: agent`, the call is
/// expected to return `None` and leave the system prompt unchanged.
#[test]
fn test_handle_post_execution_metadata_agent_kind_returns_none() {
    let mut prompt = Some(String::from("base"));
    let skill_doc =
        "---\nname: agent-skill\nkind: agent\ndescription: An agent\n---\nAgent def.";
    let metadata = Some(serde_json::json!({
        "_load_skill": true,
        "skill_name": "agent-skill",
        "skill_content": skill_doc,
    }));
    let injected = AgentLoop::handle_post_execution_metadata(&metadata, &mut prompt, None);
    assert!(injected.is_none());
    // The prompt must be exactly what it was before the call.
    assert_eq!(prompt.as_deref(), Some("base"));
}
4326}