1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationManager;
14use crate::hooks::{
15 GenerateEndEvent, GenerateStartEvent, HookEngine, HookEvent, HookResult, PostToolUseEvent,
16 PreToolUseEvent, TokenUsageInfo, ToolCallInfo, ToolResultData,
17};
18use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolCall, ToolDefinition};
19use crate::permissions::{PermissionDecision, PermissionPolicy};
20use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
21use crate::queue::{SessionCommand, SessionLane};
22use crate::session_lane_queue::SessionLaneQueue;
23use crate::tools::skill::Skill;
24use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
25use anyhow::{Context, Result};
26use async_trait::async_trait;
27use futures::future::join_all;
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use std::sync::Arc;
31use std::time::Duration;
32use tokio::sync::{mpsc, RwLock};
33
/// Default limit on LLM turns per execution; used as
/// `AgentConfig::default().max_tool_rounds`. The tool loop fails with a
/// "Max tool rounds (..) exceeded" error once the turn counter passes it.
const MAX_TOOL_ROUNDS: usize = 50;
36
/// Configuration for an `AgentLoop`.
///
/// `Clone` is derived. `Debug` is implemented manually, reporting handle-like
/// fields as presence flags and provider/skill lists as counts.
#[derive(Clone)]
pub struct AgentConfig {
    /// Base system prompt. Context-provider output and auto-loaded skills are
    /// appended to it at run time.
    pub system_prompt: Option<String>,
    /// Tool definitions sent to the LLM on every completion call.
    pub tools: Vec<ToolDefinition>,
    /// Maximum number of LLM turns before execution fails with a
    /// "Max tool rounds exceeded" error.
    pub max_tool_rounds: usize,
    /// Optional permission policy checked for each sequentially executed tool
    /// call. When absent, the decision defaults to `PermissionDecision::Ask`.
    pub permission_policy: Option<Arc<RwLock<PermissionPolicy>>>,
    /// Human-in-the-loop confirmation manager consulted on `Ask` decisions.
    /// Without it, `Ask` tool calls are refused with an error message.
    pub confirmation_manager: Option<Arc<ConfirmationManager>>,
    /// Providers queried concurrently before the first turn. Their results are
    /// appended to the system prompt, and they are notified when a turn completes.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When true, `execute_with_session` dispatches to the planning path
    /// instead of the plain tool loop.
    pub planning_enabled: bool,
    /// Goal-tracking switch. NOTE(review): not read in this part of the file;
    /// presumably consumed by the planning path. Confirm there.
    pub goal_tracking: bool,
    /// Skills whose `allowed_tools` restrictions gate tool calls. If any skill
    /// restricts tools, a call must be allowed by at least one restricting skill.
    pub skill_tool_filters: Vec<Skill>,
    /// Optional hook engine fired around LLM generation and tool use.
    pub hook_engine: Option<Arc<HookEngine>>,
}
61
62impl std::fmt::Debug for AgentConfig {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("AgentConfig")
65 .field("system_prompt", &self.system_prompt)
66 .field("tools", &self.tools)
67 .field("max_tool_rounds", &self.max_tool_rounds)
68 .field("permission_policy", &self.permission_policy.is_some())
69 .field("confirmation_manager", &self.confirmation_manager.is_some())
70 .field("context_providers", &self.context_providers.len())
71 .field("planning_enabled", &self.planning_enabled)
72 .field("goal_tracking", &self.goal_tracking)
73 .field("skill_tool_filters", &self.skill_tool_filters.len())
74 .field("hook_engine", &self.hook_engine.is_some())
75 .finish()
76 }
77}
78
79impl Default for AgentConfig {
80 fn default() -> Self {
81 Self {
82 system_prompt: None,
83 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
85 permission_policy: None,
86 confirmation_manager: None,
87 context_providers: Vec::new(),
88 planning_enabled: false,
89 goal_tracking: false,
90 skill_tool_filters: Vec::new(),
91 hook_engine: None,
92 }
93 }
94}
95
/// Events streamed to the optional `mpsc::Sender<AgentEvent>` while the agent runs.
///
/// Serialized with an internal `"type"` tag whose value is each variant's
/// `#[serde(rename = ...)]` string. `#[non_exhaustive]` forces downstream
/// matches to include a wildcard arm.
///
/// Only some variants are emitted by the tool loop in this file. The others
/// (queue, memory, subagent, planning, compaction, persistence) come from other
/// components; their descriptions below follow the variant and field names and
/// should be confirmed at the emitting site.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// Execution started; carries the user prompt.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// A new LLM turn started (`turn` counts from 1 in the tool loop).
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// Incremental assistant text forwarded from a streaming LLM response.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// The streaming LLM response began a tool call.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// A tool call finished, or was rejected before running.
    /// `exit_code` 0 means success.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
    },

    /// Incremental output streamed by a running tool.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// An LLM turn finished; `usage` is that single turn's token usage.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// Execution finished. `text` is the final assistant text and `usage` is
    /// the token usage summed over all turns.
    #[serde(rename = "agent_end")]
    End { text: String, usage: TokenUsage },

    /// Execution failed, e.g. because the maximum number of tool rounds was exceeded.
    #[serde(rename = "error")]
    Error { message: String },

    /// Human confirmation was requested for a tool call. `timeout_ms` is
    /// presumably the wait budget in milliseconds (per the name).
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    /// A confirmation decision arrived for `tool_id`, with an optional reason.
    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    /// Waiting for confirmation of `tool_id` timed out. `action_taken`
    /// presumably names the applied timeout action.
    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    /// A command was handed off for external execution and is awaiting completion.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    /// An externally executed task completed.
    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// A tool call was refused. In the tool loop this is sent for
    /// PreToolUse-hook blocks, skill `allowed_tools` restrictions, and
    /// permission-policy deny rules; `reason` says which.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// Context resolution started; lists the names of the queried providers.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Context resolution finished. Totals are summed over the non-empty
    /// provider results.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    /// A queued command was dead-lettered after `attempts` tries.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    /// A queued command will be retried after `delay_ms`.
    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    /// An alert raised by the command queue.
    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// The task list for `session_id` changed; carries the tasks.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// A memory entry was stored.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// A memory entry was recalled, with its relevance score.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// Memories were searched by optional query text and tags.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// `count` memories were cleared from `tier`.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    /// A subagent task started under `parent_session_id`.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    /// Progress report from a running subagent task.
    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    /// A subagent task finished with `output`.
    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    /// Planning started for `prompt`.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    /// Planning produced `plan`.
    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    /// A plan step started.
    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    /// A plan step ended with `status`.
    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    /// A goal was extracted from the request.
    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    /// Progress toward the current goal.
    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    /// The goal was achieved.
    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// Session context was compacted from `before_messages` to `after_messages`.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// A persistence `operation` for `session_id` failed with `error`.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },
}
411
/// Outcome of a successful agent execution.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final assistant text, taken from the turn that requested no tools.
    pub text: String,
    /// Conversation after execution. In the plain tool loop this is the
    /// caller's history, then the user prompt, then every assistant message
    /// and tool result appended during the run.
    pub messages: Vec<Message>,
    /// Token usage summed over all turns.
    pub usage: TokenUsage,
    /// Number of tool calls processed. In the plain tool loop this includes
    /// calls refused before execution (malformed arguments, hook block, skill
    /// restriction, permission denial).
    pub tool_calls_count: usize,
}
420
/// A single tool invocation wrapped as a `SessionCommand`, presumably so it can
/// be submitted to the session lane queue (confirm at the construction site).
pub struct ToolCommand {
    /// Executor that runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Tool to invoke; also reported as the command type.
    tool_name: String,
    /// JSON arguments for the tool; also reported as the command payload.
    tool_args: Value,
    /// Context handed to the tool when it executes.
    tool_context: ToolContext,
}
434
435#[async_trait]
436impl SessionCommand for ToolCommand {
437 async fn execute(&self) -> Result<Value> {
438 let result = self
439 .tool_executor
440 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
441 .await?;
442 Ok(serde_json::json!({
443 "output": result.output,
444 "exit_code": result.exit_code,
445 "metadata": result.metadata,
446 }))
447 }
448
449 fn command_type(&self) -> &str {
450 &self.tool_name
451 }
452
453 fn payload(&self) -> Value {
454 self.tool_args.clone()
455 }
456}
457
458pub fn partition_by_lane(tool_calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
468 let mut query_tools = Vec::new();
469 let mut sequential_tools = Vec::new();
470
471 for tc in tool_calls {
472 match SessionLane::from_tool_name(&tc.name) {
473 SessionLane::Query => query_tools.push(tc.clone()),
474 _ => sequential_tools.push(tc.clone()),
475 }
476 }
477
478 (query_tools, sequential_tools)
479}
480
/// Drives the LLM/tool loop for one agent. It calls the model, runs the tools
/// it requests (subject to hooks, skill filters, permission policy and HITL
/// confirmation), feeds the results back, and repeats until the model requests
/// no more tools or the turn limit is hit.
#[derive(Clone)]
pub struct AgentLoop {
    /// Model backend used for both streaming and non-streaming completions.
    llm_client: Arc<dyn LlmClient>,
    /// Runs tools by name; tools from Tool-kind skills are registered on it.
    tool_executor: Arc<ToolExecutor>,
    /// Base context, cloned for every tool call.
    tool_context: ToolContext,
    /// Agent configuration (prompt, tools, policies, providers, hooks).
    config: AgentConfig,
    /// Optional shared per-tool metrics (success flag and duration) recorded
    /// by the tool loop.
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// Optional lane queue. When set, query-lane tool calls are dispatched in
    /// parallel through it (see `partition_by_lane`).
    command_queue: Option<Arc<SessionLaneQueue>>,
}
493
494impl AgentLoop {
495 pub fn new(
496 llm_client: Arc<dyn LlmClient>,
497 tool_executor: Arc<ToolExecutor>,
498 tool_context: ToolContext,
499 config: AgentConfig,
500 ) -> Self {
501 Self {
502 llm_client,
503 tool_executor,
504 tool_context,
505 config,
506 tool_metrics: None,
507 command_queue: None,
508 }
509 }
510
511 pub fn with_tool_metrics(
513 mut self,
514 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
515 ) -> Self {
516 self.tool_metrics = Some(metrics);
517 self
518 }
519
520 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
526 self.command_queue = Some(queue);
527 self
528 }
529
530 fn streaming_tool_context(
539 &self,
540 event_tx: &Option<mpsc::Sender<AgentEvent>>,
541 tool_id: &str,
542 tool_name: &str,
543 ) -> ToolContext {
544 let mut ctx = self.tool_context.clone();
545 if let Some(agent_tx) = event_tx {
546 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
547 ctx.event_tx = Some(tool_tx);
548
549 let agent_tx = agent_tx.clone();
550 let tool_id = tool_id.to_string();
551 let tool_name = tool_name.to_string();
552 tokio::spawn(async move {
553 while let Some(event) = tool_rx.recv().await {
554 match event {
555 ToolStreamEvent::OutputDelta(delta) => {
556 agent_tx
557 .send(AgentEvent::ToolOutputDelta {
558 id: tool_id.clone(),
559 name: tool_name.clone(),
560 delta,
561 })
562 .await
563 .ok();
564 }
565 }
566 }
567 });
568 }
569 ctx
570 }
571
572 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
576 if self.config.context_providers.is_empty() {
577 return Vec::new();
578 }
579
580 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
581
582 let futures = self
583 .config
584 .context_providers
585 .iter()
586 .map(|p| p.query(&query));
587 let outcomes = join_all(futures).await;
588
589 outcomes
590 .into_iter()
591 .enumerate()
592 .filter_map(|(i, r)| match r {
593 Ok(result) if !result.is_empty() => Some(result),
594 Ok(_) => None,
595 Err(e) => {
596 tracing::warn!(
597 "Context provider '{}' failed: {}",
598 self.config.context_providers[i].name(),
599 e
600 );
601 None
602 }
603 })
604 .collect()
605 }
606
607 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
609 if context_results.is_empty() {
610 return self.config.system_prompt.clone();
611 }
612
613 let context_xml: String = context_results
615 .iter()
616 .map(|r| r.to_xml())
617 .collect::<Vec<_>>()
618 .join("\n\n");
619
620 match &self.config.system_prompt {
622 Some(system) => Some(format!("{}\n\n{}", system, context_xml)),
623 None => Some(context_xml),
624 }
625 }
626
627 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
629 let futures = self
630 .config
631 .context_providers
632 .iter()
633 .map(|p| p.on_turn_complete(session_id, prompt, response));
634 let outcomes = join_all(futures).await;
635
636 for (i, result) in outcomes.into_iter().enumerate() {
637 if let Err(e) = result {
638 tracing::warn!(
639 "Context provider '{}' on_turn_complete failed: {}",
640 self.config.context_providers[i].name(),
641 e
642 );
643 }
644 }
645 }
646
647 async fn fire_pre_tool_use(
650 &self,
651 session_id: &str,
652 tool_name: &str,
653 args: &serde_json::Value,
654 ) -> Option<HookResult> {
655 if let Some(he) = &self.config.hook_engine {
656 let event = HookEvent::PreToolUse(PreToolUseEvent {
657 session_id: session_id.to_string(),
658 tool: tool_name.to_string(),
659 args: args.clone(),
660 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
661 recent_tools: Vec::new(),
662 });
663 let result = he.fire(&event).await;
664 if result.is_block() {
665 return Some(result);
666 }
667 }
668 None
669 }
670
671 async fn fire_post_tool_use(
673 &self,
674 session_id: &str,
675 tool_name: &str,
676 args: &serde_json::Value,
677 output: &str,
678 success: bool,
679 duration_ms: u64,
680 ) {
681 if let Some(he) = &self.config.hook_engine {
682 let event = HookEvent::PostToolUse(PostToolUseEvent {
683 session_id: session_id.to_string(),
684 tool: tool_name.to_string(),
685 args: args.clone(),
686 result: ToolResultData {
687 success,
688 output: output.to_string(),
689 exit_code: if success { Some(0) } else { Some(1) },
690 duration_ms,
691 },
692 });
693 let _ = he.fire(&event).await;
694 }
695 }
696
697 async fn fire_generate_start(
699 &self,
700 session_id: &str,
701 prompt: &str,
702 system_prompt: &Option<String>,
703 ) {
704 if let Some(he) = &self.config.hook_engine {
705 let event = HookEvent::GenerateStart(GenerateStartEvent {
706 session_id: session_id.to_string(),
707 prompt: prompt.to_string(),
708 system_prompt: system_prompt.clone(),
709 model_provider: String::new(),
710 model_name: String::new(),
711 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
712 });
713 let _ = he.fire(&event).await;
714 }
715 }
716
717 async fn fire_generate_end(
719 &self,
720 session_id: &str,
721 prompt: &str,
722 response: &LlmResponse,
723 duration_ms: u64,
724 ) {
725 if let Some(he) = &self.config.hook_engine {
726 let tool_calls: Vec<ToolCallInfo> = response
727 .tool_calls()
728 .iter()
729 .map(|tc| ToolCallInfo {
730 name: tc.name.clone(),
731 args: tc.args.clone(),
732 })
733 .collect();
734
735 let event = HookEvent::GenerateEnd(GenerateEndEvent {
736 session_id: session_id.to_string(),
737 prompt: prompt.to_string(),
738 response_text: response.text().to_string(),
739 tool_calls,
740 usage: TokenUsageInfo {
741 prompt_tokens: response.usage.prompt_tokens as i32,
742 completion_tokens: response.usage.completion_tokens as i32,
743 total_tokens: response.usage.total_tokens as i32,
744 },
745 duration_ms,
746 });
747 let _ = he.fire(&event).await;
748 }
749 }
750
751 fn handle_post_execution_metadata(
761 metadata: &Option<serde_json::Value>,
762 augmented_system: &mut Option<String>,
763 tool_executor: Option<&ToolExecutor>,
764 ) -> Option<String> {
765 let meta = metadata.as_ref()?;
766 if meta.get("_load_skill")?.as_bool() != Some(true) {
767 return None;
768 }
769
770 let skill_content = meta.get("skill_content")?.as_str()?;
771 let skill_name = meta
772 .get("skill_name")
773 .and_then(|v| v.as_str())
774 .unwrap_or("unknown");
775
776 let skill = Skill::parse(skill_content)?;
778
779 match skill.kind {
780 crate::tools::SkillKind::Instruction => {
781 let xml_fragment = format!(
783 "\n\n<skills>\n<skill name=\"{}\">\n{}\n</skill>\n</skills>",
784 skill.name, skill.content
785 );
786
787 match augmented_system {
788 Some(existing) => existing.push_str(&xml_fragment),
789 None => *augmented_system = Some(xml_fragment.clone()),
790 }
791
792 tracing::info!(
793 skill_name = skill_name,
794 kind = "instruction",
795 "Auto-loaded instruction skill into session"
796 );
797
798 Some(xml_fragment)
799 }
800 crate::tools::SkillKind::Tool => {
801 let xml_fragment = format!(
803 "\n\n<skills>\n<skill name=\"{}\">\n{}\n</skill>\n</skills>",
804 skill.name, skill.content
805 );
806
807 match augmented_system {
808 Some(existing) => existing.push_str(&xml_fragment),
809 None => *augmented_system = Some(xml_fragment.clone()),
810 }
811
812 if let Some(executor) = tool_executor {
814 let tools = crate::tools::parse_skill_tools(skill_content);
815 for tool in tools {
816 tracing::info!(
817 skill_name = skill_name,
818 tool_name = tool.name(),
819 "Registered tool from Tool-kind skill"
820 );
821 executor.registry().register(tool);
822 }
823 }
824
825 tracing::info!(
826 skill_name = skill_name,
827 kind = "tool",
828 "Auto-loaded tool skill into session"
829 );
830
831 Some(xml_fragment)
832 }
833 crate::tools::SkillKind::Agent => {
834 tracing::info!(
835 skill_name = skill_name,
836 kind = "agent",
837 "Loaded agent skill (agent registration not yet implemented)"
838 );
839 None
840 }
841 }
842 }
843
844 pub async fn execute(
850 &self,
851 history: &[Message],
852 prompt: &str,
853 event_tx: Option<mpsc::Sender<AgentEvent>>,
854 ) -> Result<AgentResult> {
855 self.execute_with_session(history, prompt, None, event_tx)
856 .await
857 }
858
859 pub async fn execute_with_session(
864 &self,
865 history: &[Message],
866 prompt: &str,
867 session_id: Option<&str>,
868 event_tx: Option<mpsc::Sender<AgentEvent>>,
869 ) -> Result<AgentResult> {
870 tracing::info!(
871 a3s.session.id = session_id.unwrap_or("none"),
872 a3s.agent.max_turns = self.config.max_tool_rounds,
873 "a3s.agent.execute started"
874 );
875
876 let result = if self.config.planning_enabled {
878 self.execute_with_planning(history, prompt, event_tx).await
879 } else {
880 self.execute_loop(history, prompt, session_id, event_tx)
881 .await
882 };
883
884 match &result {
885 Ok(r) => tracing::info!(
886 a3s.agent.tool_calls_count = r.tool_calls_count,
887 a3s.llm.total_tokens = r.usage.total_tokens,
888 "a3s.agent.execute completed"
889 ),
890 Err(e) => tracing::warn!(
891 error = %e,
892 "a3s.agent.execute failed"
893 ),
894 }
895
896 result
897 }
898
899 async fn execute_loop(
905 &self,
906 history: &[Message],
907 prompt: &str,
908 session_id: Option<&str>,
909 event_tx: Option<mpsc::Sender<AgentEvent>>,
910 ) -> Result<AgentResult> {
911 let mut messages = history.to_vec();
912 let mut total_usage = TokenUsage::default();
913 let mut tool_calls_count = 0;
914 let mut turn = 0;
915
916 if let Some(tx) = &event_tx {
918 tx.send(AgentEvent::Start {
919 prompt: prompt.to_string(),
920 })
921 .await
922 .ok();
923 }
924
925 let mut augmented_system = if !self.config.context_providers.is_empty() {
927 if let Some(tx) = &event_tx {
929 let provider_names: Vec<String> = self
930 .config
931 .context_providers
932 .iter()
933 .map(|p| p.name().to_string())
934 .collect();
935 tx.send(AgentEvent::ContextResolving {
936 providers: provider_names,
937 })
938 .await
939 .ok();
940 }
941
942 tracing::info!(
943 a3s.context.providers = self.config.context_providers.len() as i64,
944 "Context resolution started"
945 );
946 let context_results = self.resolve_context(prompt, session_id).await;
947
948 if let Some(tx) = &event_tx {
950 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
951 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
952
953 tracing::info!(
954 context_items = total_items,
955 context_tokens = total_tokens,
956 "Context resolution completed"
957 );
958
959 tx.send(AgentEvent::ContextResolved {
960 total_items,
961 total_tokens,
962 })
963 .await
964 .ok();
965 }
966
967 self.build_augmented_system_prompt(&context_results)
968 } else {
969 self.config.system_prompt.clone()
970 };
971
972 messages.push(Message::user(prompt));
974
975 loop {
976 turn += 1;
977
978 if turn > self.config.max_tool_rounds {
979 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
980 if let Some(tx) = &event_tx {
981 tx.send(AgentEvent::Error {
982 message: error.clone(),
983 })
984 .await
985 .ok();
986 }
987 anyhow::bail!(error);
988 }
989
990 if let Some(tx) = &event_tx {
992 tx.send(AgentEvent::TurnStart { turn }).await.ok();
993 }
994
995 tracing::info!(
996 turn = turn,
997 max_turns = self.config.max_tool_rounds,
998 "Agent turn started"
999 );
1000
1001 tracing::info!(
1003 a3s.llm.streaming = event_tx.is_some(),
1004 "LLM completion started"
1005 );
1006
1007 self.fire_generate_start(session_id.unwrap_or(""), prompt, &augmented_system)
1009 .await;
1010
1011 let llm_start = std::time::Instant::now();
1012 let response = if event_tx.is_some() {
1013 let mut stream_rx = self
1015 .llm_client
1016 .complete_streaming(&messages, augmented_system.as_deref(), &self.config.tools)
1017 .await
1018 .context("LLM streaming call failed")?;
1019
1020 let mut final_response: Option<LlmResponse> = None;
1021
1022 while let Some(event) = stream_rx.recv().await {
1023 match event {
1024 crate::llm::StreamEvent::TextDelta(text) => {
1025 if let Some(tx) = &event_tx {
1026 tx.send(AgentEvent::TextDelta { text }).await.ok();
1027 }
1028 }
1029 crate::llm::StreamEvent::ToolUseStart { id, name } => {
1030 if let Some(tx) = &event_tx {
1031 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
1032 }
1033 }
1034 crate::llm::StreamEvent::ToolUseInputDelta(_) => {
1035 }
1037 crate::llm::StreamEvent::Done(resp) => {
1038 final_response = Some(resp);
1039 }
1040 }
1041 }
1042
1043 final_response.context("Stream ended without final response")?
1044 } else {
1045 self.llm_client
1047 .complete(&messages, augmented_system.as_deref(), &self.config.tools)
1048 .await
1049 .context("LLM call failed")?
1050 };
1051
1052 total_usage.prompt_tokens += response.usage.prompt_tokens;
1054 total_usage.completion_tokens += response.usage.completion_tokens;
1055 total_usage.total_tokens += response.usage.total_tokens;
1056
1057 let llm_duration = llm_start.elapsed();
1059 tracing::info!(
1060 turn = turn,
1061 streaming = event_tx.is_some(),
1062 prompt_tokens = response.usage.prompt_tokens,
1063 completion_tokens = response.usage.completion_tokens,
1064 total_tokens = response.usage.total_tokens,
1065 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1066 duration_ms = llm_duration.as_millis() as u64,
1067 "LLM completion finished"
1068 );
1069
1070 self.fire_generate_end(
1072 session_id.unwrap_or(""),
1073 prompt,
1074 &response,
1075 llm_duration.as_millis() as u64,
1076 )
1077 .await;
1078
1079 crate::telemetry::record_llm_usage(
1081 response.usage.prompt_tokens,
1082 response.usage.completion_tokens,
1083 response.usage.total_tokens,
1084 response.stop_reason.as_deref(),
1085 );
1086 tracing::info!(
1088 turn = turn,
1089 a3s.llm.total_tokens = response.usage.total_tokens,
1090 "Turn token usage"
1091 );
1092
1093 messages.push(response.message.clone());
1095
1096 let tool_calls = response.tool_calls();
1098
1099 if let Some(tx) = &event_tx {
1101 tx.send(AgentEvent::TurnEnd {
1102 turn,
1103 usage: response.usage.clone(),
1104 })
1105 .await
1106 .ok();
1107 }
1108
1109 if tool_calls.is_empty() {
1110 let final_text = response.text();
1112
1113 tracing::info!(
1115 tool_calls_count = tool_calls_count,
1116 total_prompt_tokens = total_usage.prompt_tokens,
1117 total_completion_tokens = total_usage.completion_tokens,
1118 total_tokens = total_usage.total_tokens,
1119 turns = turn,
1120 "Agent execution completed"
1121 );
1122
1123 if let Some(tx) = &event_tx {
1124 tx.send(AgentEvent::End {
1125 text: final_text.clone(),
1126 usage: total_usage.clone(),
1127 })
1128 .await
1129 .ok();
1130 }
1131
1132 if let Some(sid) = session_id {
1134 self.notify_turn_complete(sid, prompt, &final_text).await;
1135 }
1136
1137 return Ok(AgentResult {
1138 text: final_text,
1139 messages,
1140 usage: total_usage,
1141 tool_calls_count,
1142 });
1143 }
1144
1145 let (query_tools, sequential_tools) = if self.command_queue.is_some() {
1147 partition_by_lane(&tool_calls)
1148 } else {
1149 (Vec::new(), tool_calls.clone())
1150 };
1151
1152 if !query_tools.is_empty() {
1154 if let Some(queue) = &self.command_queue {
1155 let parallel_count = self
1156 .execute_query_tools_parallel(
1157 &query_tools,
1158 queue,
1159 &mut messages,
1160 &event_tx,
1161 &mut augmented_system,
1162 session_id,
1163 )
1164 .await;
1165 tool_calls_count += parallel_count;
1166 }
1167 }
1168
1169 for tool_call in sequential_tools {
1171 tool_calls_count += 1;
1172
1173 let tool_start = std::time::Instant::now();
1174
1175 tracing::info!(
1176 tool_name = tool_call.name.as_str(),
1177 tool_id = tool_call.id.as_str(),
1178 "Tool execution started"
1179 );
1180
1181 if let Some(parse_error) =
1187 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1188 {
1189 let error_msg = format!("Error: {}", parse_error);
1190 tracing::warn!(
1191 tool = tool_call.name.as_str(),
1192 "Malformed tool arguments from LLM"
1193 );
1194
1195 if let Some(tx) = &event_tx {
1196 tx.send(AgentEvent::ToolEnd {
1197 id: tool_call.id.clone(),
1198 name: tool_call.name.clone(),
1199 output: error_msg.clone(),
1200 exit_code: 1,
1201 })
1202 .await
1203 .ok();
1204 }
1205
1206 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1207 continue;
1208 }
1209
1210 if let Some(HookResult::Block(reason)) = self
1212 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1213 .await
1214 {
1215 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1216 tracing::info!(
1217 tool_name = tool_call.name.as_str(),
1218 "Tool blocked by PreToolUse hook"
1219 );
1220
1221 if let Some(tx) = &event_tx {
1222 tx.send(AgentEvent::PermissionDenied {
1223 tool_id: tool_call.id.clone(),
1224 tool_name: tool_call.name.clone(),
1225 args: tool_call.args.clone(),
1226 reason: reason.clone(),
1227 })
1228 .await
1229 .ok();
1230 }
1231
1232 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1233 continue;
1234 }
1235
1236 if !self.config.skill_tool_filters.is_empty() {
1241 let has_restrictions = self
1242 .config
1243 .skill_tool_filters
1244 .iter()
1245 .any(|s| s.allowed_tools.is_some());
1246
1247 if has_restrictions {
1248 let args_str = serde_json::to_string(&tool_call.args).unwrap_or_default();
1249 let tool_allowed = self
1250 .config
1251 .skill_tool_filters
1252 .iter()
1253 .filter(|s| s.allowed_tools.is_some())
1254 .any(|s| s.is_tool_allowed(&tool_call.name, &args_str));
1255
1256 if !tool_allowed {
1257 tracing::info!(
1258 tool_name = tool_call.name.as_str(),
1259 "Tool blocked by skill allowed_tools restriction"
1260 );
1261 let msg = format!(
1262 "Tool '{}' is not permitted by any loaded skill's allowed_tools policy.",
1263 tool_call.name
1264 );
1265
1266 if let Some(tx) = &event_tx {
1267 tx.send(AgentEvent::PermissionDenied {
1268 tool_id: tool_call.id.clone(),
1269 tool_name: tool_call.name.clone(),
1270 args: tool_call.args.clone(),
1271 reason: "Blocked by skill allowed_tools restriction"
1272 .to_string(),
1273 })
1274 .await
1275 .ok();
1276 }
1277
1278 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1279 continue;
1280 }
1281 }
1282 }
1283
1284 let permission_decision = if let Some(policy_lock) = &self.config.permission_policy
1286 {
1287 let policy = policy_lock.read().await;
1288 policy.check(&tool_call.name, &tool_call.args)
1289 } else {
1290 PermissionDecision::Ask
1292 };
1293
1294 let (output, exit_code, is_error, metadata) = match permission_decision {
1295 PermissionDecision::Deny => {
1296 tracing::info!(
1297 tool_name = tool_call.name.as_str(),
1298 permission = "deny",
1299 "Tool permission denied"
1300 );
1301 let denial_msg = format!(
1303 "Permission denied: Tool '{}' is blocked by permission policy.",
1304 tool_call.name
1305 );
1306
1307 if let Some(tx) = &event_tx {
1309 tx.send(AgentEvent::PermissionDenied {
1310 tool_id: tool_call.id.clone(),
1311 tool_name: tool_call.name.clone(),
1312 args: tool_call.args.clone(),
1313 reason: "Blocked by deny rule in permission policy".to_string(),
1314 })
1315 .await
1316 .ok();
1317 }
1318
1319 (denial_msg, 1, true, None)
1320 }
1321 PermissionDecision::Allow => {
1322 tracing::info!(
1323 tool_name = tool_call.name.as_str(),
1324 permission = "allow",
1325 "Tool permission: allow"
1326 );
1327 let stream_ctx =
1329 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1330 let result = self
1331 .tool_executor
1332 .execute_with_context(&tool_call.name, &tool_call.args, &stream_ctx)
1333 .await;
1334
1335 match result {
1336 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1337 Err(e) => (format!("Tool execution error: {}", e), 1, true, None),
1338 }
1339 }
1340 PermissionDecision::Ask => {
1341 tracing::info!(
1342 tool_name = tool_call.name.as_str(),
1343 permission = "ask",
1344 "Tool permission: ask"
1345 );
1346 if let Some(cm) = &self.config.confirmation_manager {
1348 if !cm.requires_confirmation(&tool_call.name).await {
1350 let stream_ctx = self.streaming_tool_context(
1351 &event_tx,
1352 &tool_call.id,
1353 &tool_call.name,
1354 );
1355 let result = self
1356 .tool_executor
1357 .execute_with_context(
1358 &tool_call.name,
1359 &tool_call.args,
1360 &stream_ctx,
1361 )
1362 .await;
1363
1364 let (output, exit_code, is_error, metadata) = match result {
1365 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1366 Err(e) => {
1367 (format!("Tool execution error: {}", e), 1, true, None)
1368 }
1369 };
1370
1371 Self::handle_post_execution_metadata(
1373 &metadata,
1374 &mut augmented_system,
1375 Some(&self.tool_executor),
1376 );
1377
1378 if let Some(tx) = &event_tx {
1380 tx.send(AgentEvent::ToolEnd {
1381 id: tool_call.id.clone(),
1382 name: tool_call.name.clone(),
1383 output: output.clone(),
1384 exit_code,
1385 })
1386 .await
1387 .ok();
1388 }
1389
1390 messages.push(Message::tool_result(
1392 &tool_call.id,
1393 &output,
1394 is_error,
1395 ));
1396
1397 let tool_duration = tool_start.elapsed();
1399 crate::telemetry::record_tool_result(exit_code, tool_duration);
1400
1401 self.fire_post_tool_use(
1403 session_id.unwrap_or(""),
1404 &tool_call.name,
1405 &tool_call.args,
1406 &output,
1407 exit_code == 0,
1408 tool_duration.as_millis() as u64,
1409 )
1410 .await;
1411
1412 continue; }
1414
1415 let policy = cm.policy().await;
1417 let timeout_ms = policy.default_timeout_ms;
1418 let timeout_action = policy.timeout_action;
1419
1420 let rx = cm
1422 .request_confirmation(
1423 &tool_call.id,
1424 &tool_call.name,
1425 &tool_call.args,
1426 )
1427 .await;
1428
1429 let confirmation_result =
1431 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
1432
1433 match confirmation_result {
1434 Ok(Ok(response)) => {
1435 if response.approved {
1436 let stream_ctx = self.streaming_tool_context(
1437 &event_tx,
1438 &tool_call.id,
1439 &tool_call.name,
1440 );
1441 let result = self
1442 .tool_executor
1443 .execute_with_context(
1444 &tool_call.name,
1445 &tool_call.args,
1446 &stream_ctx,
1447 )
1448 .await;
1449
1450 match result {
1451 Ok(r) => (
1452 r.output,
1453 r.exit_code,
1454 r.exit_code != 0,
1455 r.metadata,
1456 ),
1457 Err(e) => (
1458 format!("Tool execution error: {}", e),
1459 1,
1460 true,
1461 None,
1462 ),
1463 }
1464 } else {
1465 let rejection_msg = format!(
1466 "Tool '{}' execution was rejected by user. Reason: {}",
1467 tool_call.name,
1468 response.reason.unwrap_or_else(|| "No reason provided".to_string())
1469 );
1470 (rejection_msg, 1, true, None)
1471 }
1472 }
1473 Ok(Err(_)) => {
1474 let msg = format!(
1475 "Tool '{}' confirmation failed: confirmation channel closed",
1476 tool_call.name
1477 );
1478 (msg, 1, true, None)
1479 }
1480 Err(_) => {
1481 cm.check_timeouts().await;
1482
1483 match timeout_action {
1484 crate::hitl::TimeoutAction::Reject => {
1485 let msg = format!(
1486 "Tool '{}' execution timed out waiting for confirmation ({}ms). Execution rejected.",
1487 tool_call.name, timeout_ms
1488 );
1489 (msg, 1, true, None)
1490 }
1491 crate::hitl::TimeoutAction::AutoApprove => {
1492 let stream_ctx = self.streaming_tool_context(
1493 &event_tx,
1494 &tool_call.id,
1495 &tool_call.name,
1496 );
1497 let result = self
1498 .tool_executor
1499 .execute_with_context(
1500 &tool_call.name,
1501 &tool_call.args,
1502 &stream_ctx,
1503 )
1504 .await;
1505
1506 match result {
1507 Ok(r) => (
1508 r.output,
1509 r.exit_code,
1510 r.exit_code != 0,
1511 r.metadata,
1512 ),
1513 Err(e) => (
1514 format!("Tool execution error: {}", e),
1515 1,
1516 true,
1517 None,
1518 ),
1519 }
1520 }
1521 }
1522 }
1523 }
1524 } else {
1525 let msg = format!(
1527 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
1528 Configure a confirmation policy to enable tool execution.",
1529 tool_call.name
1530 );
1531 tracing::warn!(
1532 tool_name = tool_call.name.as_str(),
1533 "Tool requires confirmation but no HITL manager configured"
1534 );
1535 (msg, 1, true, None)
1536 }
1537 }
1538 };
1539
1540 Self::handle_post_execution_metadata(
1542 &metadata,
1543 &mut augmented_system,
1544 Some(&self.tool_executor),
1545 );
1546
1547 let tool_duration = tool_start.elapsed();
1549 tracing::info!(
1550 tool_name = tool_call.name.as_str(),
1551 tool_id = tool_call.id.as_str(),
1552 exit_code = exit_code,
1553 success = (exit_code == 0),
1554 duration_ms = tool_duration.as_millis() as u64,
1555 "Tool execution finished"
1556 );
1557
1558 crate::telemetry::record_tool_result(exit_code, tool_duration);
1560
1561 if let Some(ref metrics) = self.tool_metrics {
1563 metrics.write().await.record(
1564 &tool_call.name,
1565 exit_code == 0,
1566 tool_duration.as_millis() as u64,
1567 );
1568 }
1569
1570 self.fire_post_tool_use(
1572 session_id.unwrap_or(""),
1573 &tool_call.name,
1574 &tool_call.args,
1575 &output,
1576 exit_code == 0,
1577 tool_duration.as_millis() as u64,
1578 )
1579 .await;
1580
1581 if let Some(tx) = &event_tx {
1583 tx.send(AgentEvent::ToolEnd {
1584 id: tool_call.id.clone(),
1585 name: tool_call.name.clone(),
1586 output: output.clone(),
1587 exit_code,
1588 })
1589 .await
1590 .ok();
1591 }
1592
1593 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1595 }
1596 }
1597 }
1598
1599 async fn execute_query_tools_parallel(
1605 &self,
1606 query_tools: &[ToolCall],
1607 queue: &SessionLaneQueue,
1608 messages: &mut Vec<Message>,
1609 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1610 augmented_system: &mut Option<String>,
1611 session_id: Option<&str>,
1612 ) -> usize {
1613 let mut receivers = Vec::with_capacity(query_tools.len());
1614 let mut tool_starts = Vec::with_capacity(query_tools.len());
1615
1616 for tool_call in query_tools {
1617 if let Some(parse_error) = tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1619 {
1620 let error_msg = format!("Error: {}", parse_error);
1621 if let Some(tx) = event_tx {
1622 tx.send(AgentEvent::ToolEnd {
1623 id: tool_call.id.clone(),
1624 name: tool_call.name.clone(),
1625 output: error_msg.clone(),
1626 exit_code: 1,
1627 })
1628 .await
1629 .ok();
1630 }
1631 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1632 continue;
1633 }
1634
1635 if let Some(HookResult::Block(reason)) = self
1637 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1638 .await
1639 {
1640 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1641 if let Some(tx) = event_tx {
1642 tx.send(AgentEvent::PermissionDenied {
1643 tool_id: tool_call.id.clone(),
1644 tool_name: tool_call.name.clone(),
1645 args: tool_call.args.clone(),
1646 reason,
1647 })
1648 .await
1649 .ok();
1650 }
1651 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1652 continue;
1653 }
1654
1655 if !self.config.skill_tool_filters.is_empty() {
1657 let has_restrictions = self
1658 .config
1659 .skill_tool_filters
1660 .iter()
1661 .any(|s| s.allowed_tools.is_some());
1662 if has_restrictions {
1663 let args_str = serde_json::to_string(&tool_call.args).unwrap_or_default();
1664 let tool_allowed = self
1665 .config
1666 .skill_tool_filters
1667 .iter()
1668 .filter(|s| s.allowed_tools.is_some())
1669 .any(|s| s.is_tool_allowed(&tool_call.name, &args_str));
1670 if !tool_allowed {
1671 let msg = format!(
1672 "Tool '{}' is not permitted by any loaded skill's allowed_tools policy.",
1673 tool_call.name
1674 );
1675 if let Some(tx) = event_tx {
1676 tx.send(AgentEvent::PermissionDenied {
1677 tool_id: tool_call.id.clone(),
1678 tool_name: tool_call.name.clone(),
1679 args: tool_call.args.clone(),
1680 reason: "Blocked by skill allowed_tools restriction".to_string(),
1681 })
1682 .await
1683 .ok();
1684 }
1685 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1686 continue;
1687 }
1688 }
1689 }
1690
1691 let permission_decision = if let Some(policy_lock) = &self.config.permission_policy {
1693 let policy = policy_lock.read().await;
1694 policy.check(&tool_call.name, &tool_call.args)
1695 } else {
1696 PermissionDecision::Ask
1697 };
1698
1699 match permission_decision {
1700 PermissionDecision::Deny => {
1701 let denial_msg = format!(
1702 "Permission denied: Tool '{}' is blocked by permission policy.",
1703 tool_call.name
1704 );
1705 if let Some(tx) = event_tx {
1706 tx.send(AgentEvent::PermissionDenied {
1707 tool_id: tool_call.id.clone(),
1708 tool_name: tool_call.name.clone(),
1709 args: tool_call.args.clone(),
1710 reason: "Blocked by deny rule in permission policy".to_string(),
1711 })
1712 .await
1713 .ok();
1714 }
1715 messages.push(Message::tool_result(&tool_call.id, &denial_msg, true));
1716 continue;
1717 }
1718 PermissionDecision::Allow | PermissionDecision::Ask => {
1719 if permission_decision == PermissionDecision::Ask {
1724 if let Some(cm) = &self.config.confirmation_manager {
1725 if cm.requires_confirmation(&tool_call.name).await {
1726 continue;
1729 }
1730 }
1731 }
1732
1733 let cmd = ToolCommand {
1735 tool_executor: self.tool_executor.clone(),
1736 tool_name: tool_call.name.clone(),
1737 tool_args: tool_call.args.clone(),
1738 tool_context: self.tool_context.clone(),
1739 };
1740 let rx = queue.submit_by_tool(&tool_call.name, Box::new(cmd)).await;
1741 let start = std::time::Instant::now();
1742 tool_starts.push((tool_call.clone(), start));
1743 receivers.push(rx);
1744 }
1745 }
1746 }
1747
1748 let count = receivers.len();
1749
1750 let results = join_all(receivers).await;
1752
1753 for (i, result) in results.into_iter().enumerate() {
1754 let (tool_call, tool_start) = &tool_starts[i];
1755 let tool_duration = tool_start.elapsed();
1756
1757 let (output, exit_code, is_error, metadata) = match result {
1758 Ok(Ok(value)) => {
1759 let output = value["output"].as_str().unwrap_or("").to_string();
1760 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
1761 let metadata = value.get("metadata").cloned();
1762 (output, exit_code, exit_code != 0, metadata)
1763 }
1764 Ok(Err(e)) => (format!("Tool execution error: {}", e), 1, true, None),
1765 Err(_) => ("Queue channel closed".to_string(), 1, true, None),
1766 };
1767
1768 Self::handle_post_execution_metadata(
1770 &metadata,
1771 augmented_system,
1772 Some(&self.tool_executor),
1773 );
1774
1775 tracing::info!(
1777 tool_name = tool_call.name.as_str(),
1778 tool_id = tool_call.id.as_str(),
1779 exit_code = exit_code,
1780 success = (exit_code == 0),
1781 duration_ms = tool_duration.as_millis() as u64,
1782 parallel = true,
1783 "Tool execution finished (parallel)"
1784 );
1785 crate::telemetry::record_tool_result(exit_code, tool_duration);
1786
1787 if let Some(ref metrics) = self.tool_metrics {
1788 metrics.write().await.record(
1789 &tool_call.name,
1790 exit_code == 0,
1791 tool_duration.as_millis() as u64,
1792 );
1793 }
1794
1795 self.fire_post_tool_use(
1797 session_id.unwrap_or(""),
1798 &tool_call.name,
1799 &tool_call.args,
1800 &output,
1801 exit_code == 0,
1802 tool_duration.as_millis() as u64,
1803 )
1804 .await;
1805
1806 if let Some(tx) = event_tx {
1808 tx.send(AgentEvent::ToolEnd {
1809 id: tool_call.id.clone(),
1810 name: tool_call.name.clone(),
1811 output: output.clone(),
1812 exit_code,
1813 })
1814 .await
1815 .ok();
1816 }
1817
1818 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1819 }
1820
1821 count
1822 }
1823
1824 pub async fn execute_streaming(
1826 &self,
1827 history: &[Message],
1828 prompt: &str,
1829 ) -> Result<(
1830 mpsc::Receiver<AgentEvent>,
1831 tokio::task::JoinHandle<Result<AgentResult>>,
1832 )> {
1833 let (tx, rx) = mpsc::channel(100);
1834
1835 let llm_client = self.llm_client.clone();
1836 let tool_executor = self.tool_executor.clone();
1837 let tool_context = self.tool_context.clone();
1838 let config = self.config.clone();
1839 let tool_metrics = self.tool_metrics.clone();
1840 let command_queue = self.command_queue.clone();
1841 let history = history.to_vec();
1842 let prompt = prompt.to_string();
1843
1844 let handle = tokio::spawn(async move {
1845 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
1846 if let Some(metrics) = tool_metrics {
1847 agent = agent.with_tool_metrics(metrics);
1848 }
1849 if let Some(queue) = command_queue {
1850 agent = agent.with_queue(queue);
1851 }
1852 agent.execute(&history, &prompt, Some(tx)).await
1853 });
1854
1855 Ok((rx, handle))
1856 }
1857
1858 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
1863 use crate::planning::LlmPlanner;
1864
1865 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
1866 Ok(plan) => Ok(plan),
1867 Err(e) => {
1868 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
1869 Ok(LlmPlanner::fallback_plan(prompt))
1870 }
1871 }
1872 }
1873
1874 pub async fn execute_with_planning(
1876 &self,
1877 history: &[Message],
1878 prompt: &str,
1879 event_tx: Option<mpsc::Sender<AgentEvent>>,
1880 ) -> Result<AgentResult> {
1881 if let Some(tx) = &event_tx {
1883 tx.send(AgentEvent::PlanningStart {
1884 prompt: prompt.to_string(),
1885 })
1886 .await
1887 .ok();
1888 }
1889
1890 let plan = self.plan(prompt, None).await?;
1892
1893 if let Some(tx) = &event_tx {
1895 tx.send(AgentEvent::PlanningEnd {
1896 estimated_steps: plan.steps.len(),
1897 plan: plan.clone(),
1898 })
1899 .await
1900 .ok();
1901 }
1902
1903 self.execute_plan(history, &plan, event_tx).await
1905 }
1906
1907 async fn execute_plan(
1914 &self,
1915 history: &[Message],
1916 plan: &ExecutionPlan,
1917 event_tx: Option<mpsc::Sender<AgentEvent>>,
1918 ) -> Result<AgentResult> {
1919 let mut plan = plan.clone();
1920 let mut current_history = history.to_vec();
1921 let mut total_usage = TokenUsage::default();
1922 let mut tool_calls_count = 0;
1923 let total_steps = plan.steps.len();
1924
1925 let steps_text = plan
1927 .steps
1928 .iter()
1929 .enumerate()
1930 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
1931 .collect::<Vec<_>>()
1932 .join("\n");
1933 current_history.push(Message::user(&crate::prompts::render(
1934 crate::prompts::PLAN_EXECUTE_GOAL,
1935 &[("goal", &plan.goal), ("steps", &steps_text)],
1936 )));
1937
1938 loop {
1939 let ready: Vec<String> = plan
1940 .get_ready_steps()
1941 .iter()
1942 .map(|s| s.id.clone())
1943 .collect();
1944
1945 if ready.is_empty() {
1946 if plan.has_deadlock() {
1948 tracing::warn!(
1949 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
1950 plan.pending_count()
1951 );
1952 }
1953 break;
1954 }
1955
1956 if ready.len() == 1 {
1957 let step_id = &ready[0];
1959 let step = plan
1960 .steps
1961 .iter()
1962 .find(|s| s.id == *step_id)
1963 .unwrap()
1964 .clone();
1965 let step_number = plan.steps.iter().position(|s| s.id == *step_id).unwrap() + 1;
1966
1967 if let Some(tx) = &event_tx {
1969 tx.send(AgentEvent::StepStart {
1970 step_id: step.id.clone(),
1971 description: step.content.clone(),
1972 step_number,
1973 total_steps,
1974 })
1975 .await
1976 .ok();
1977 }
1978
1979 plan.mark_status(&step.id, TaskStatus::InProgress);
1980
1981 let step_prompt = crate::prompts::render(
1982 crate::prompts::PLAN_EXECUTE_STEP,
1983 &[
1984 ("step_num", &step_number.to_string()),
1985 ("description", &step.content),
1986 ],
1987 );
1988
1989 match self
1990 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
1991 .await
1992 {
1993 Ok(result) => {
1994 current_history = result.messages.clone();
1995 total_usage.prompt_tokens += result.usage.prompt_tokens;
1996 total_usage.completion_tokens += result.usage.completion_tokens;
1997 total_usage.total_tokens += result.usage.total_tokens;
1998 tool_calls_count += result.tool_calls_count;
1999 plan.mark_status(&step.id, TaskStatus::Completed);
2000
2001 if let Some(tx) = &event_tx {
2002 tx.send(AgentEvent::StepEnd {
2003 step_id: step.id.clone(),
2004 status: TaskStatus::Completed,
2005 step_number,
2006 total_steps,
2007 })
2008 .await
2009 .ok();
2010 }
2011 }
2012 Err(e) => {
2013 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2014 plan.mark_status(&step.id, TaskStatus::Failed);
2015
2016 if let Some(tx) = &event_tx {
2017 tx.send(AgentEvent::StepEnd {
2018 step_id: step.id.clone(),
2019 status: TaskStatus::Failed,
2020 step_number,
2021 total_steps,
2022 })
2023 .await
2024 .ok();
2025 }
2026 }
2027 }
2028 } else {
2029 let ready_steps: Vec<_> = ready
2031 .iter()
2032 .map(|id| {
2033 let step = plan.steps.iter().find(|s| s.id == *id).unwrap().clone();
2034 let step_number = plan.steps.iter().position(|s| s.id == *id).unwrap() + 1;
2035 (step, step_number)
2036 })
2037 .collect();
2038
2039 for (step, step_number) in &ready_steps {
2041 plan.mark_status(&step.id, TaskStatus::InProgress);
2042 if let Some(tx) = &event_tx {
2043 tx.send(AgentEvent::StepStart {
2044 step_id: step.id.clone(),
2045 description: step.content.clone(),
2046 step_number: *step_number,
2047 total_steps,
2048 })
2049 .await
2050 .ok();
2051 }
2052 }
2053
2054 let mut join_set = tokio::task::JoinSet::new();
2056 for (step, step_number) in &ready_steps {
2057 let base_history = current_history.clone();
2058 let agent_clone = self.clone();
2059 let tx = event_tx.clone();
2060 let step_clone = step.clone();
2061 let sn = *step_number;
2062
2063 join_set.spawn(async move {
2064 let prompt = crate::prompts::render(
2065 crate::prompts::PLAN_EXECUTE_STEP,
2066 &[
2067 ("step_num", &sn.to_string()),
2068 ("description", &step_clone.content),
2069 ],
2070 );
2071 let result = agent_clone
2072 .execute_loop(&base_history, &prompt, None, tx)
2073 .await;
2074 (step_clone.id, sn, result)
2075 });
2076 }
2077
2078 let mut parallel_summaries = Vec::new();
2080 while let Some(join_result) = join_set.join_next().await {
2081 match join_result {
2082 Ok((step_id, step_number, step_result)) => match step_result {
2083 Ok(result) => {
2084 total_usage.prompt_tokens += result.usage.prompt_tokens;
2085 total_usage.completion_tokens += result.usage.completion_tokens;
2086 total_usage.total_tokens += result.usage.total_tokens;
2087 tool_calls_count += result.tool_calls_count;
2088 plan.mark_status(&step_id, TaskStatus::Completed);
2089
2090 parallel_summaries.push(format!(
2092 "- Step {} ({}): {}",
2093 step_number, step_id, result.text
2094 ));
2095
2096 if let Some(tx) = &event_tx {
2097 tx.send(AgentEvent::StepEnd {
2098 step_id,
2099 status: TaskStatus::Completed,
2100 step_number,
2101 total_steps,
2102 })
2103 .await
2104 .ok();
2105 }
2106 }
2107 Err(e) => {
2108 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2109 plan.mark_status(&step_id, TaskStatus::Failed);
2110
2111 if let Some(tx) = &event_tx {
2112 tx.send(AgentEvent::StepEnd {
2113 step_id,
2114 status: TaskStatus::Failed,
2115 step_number,
2116 total_steps,
2117 })
2118 .await
2119 .ok();
2120 }
2121 }
2122 },
2123 Err(e) => {
2124 tracing::error!("JoinSet task panicked: {}", e);
2125 }
2126 }
2127 }
2128
2129 if !parallel_summaries.is_empty() {
2131 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
2133 current_history.push(Message::user(&crate::prompts::render(
2134 crate::prompts::PLAN_PARALLEL_RESULTS,
2135 &[("results", &results_text)],
2136 )));
2137 }
2138 }
2139
2140 if self.config.goal_tracking {
2142 let completed = plan
2143 .steps
2144 .iter()
2145 .filter(|s| s.status == TaskStatus::Completed)
2146 .count();
2147 if let Some(tx) = &event_tx {
2148 tx.send(AgentEvent::GoalProgress {
2149 goal: plan.goal.clone(),
2150 progress: plan.progress(),
2151 completed_steps: completed,
2152 total_steps,
2153 })
2154 .await
2155 .ok();
2156 }
2157 }
2158 }
2159
2160 let final_text = current_history
2162 .last()
2163 .map(|m| {
2164 m.content
2165 .iter()
2166 .filter_map(|block| {
2167 if let crate::llm::ContentBlock::Text { text } = block {
2168 Some(text.as_str())
2169 } else {
2170 None
2171 }
2172 })
2173 .collect::<Vec<_>>()
2174 .join("\n")
2175 })
2176 .unwrap_or_default();
2177
2178 Ok(AgentResult {
2179 text: final_text,
2180 messages: current_history,
2181 usage: total_usage,
2182 tool_calls_count,
2183 })
2184 }
2185
2186 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2191 use crate::planning::LlmPlanner;
2192
2193 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2194 Ok(goal) => Ok(goal),
2195 Err(e) => {
2196 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2197 Ok(LlmPlanner::fallback_goal(prompt))
2198 }
2199 }
2200 }
2201
2202 pub async fn check_goal_achievement(
2207 &self,
2208 goal: &AgentGoal,
2209 current_state: &str,
2210 ) -> Result<bool> {
2211 use crate::planning::LlmPlanner;
2212
2213 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2214 Ok(result) => Ok(result.achieved),
2215 Err(e) => {
2216 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2217 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2218 Ok(result.achieved)
2219 }
2220 }
2221 }
2222}
2223
2224#[cfg(test)]
2225mod tests {
2226 use super::*;
2227 use crate::llm::{ContentBlock, StreamEvent};
2228 use crate::permissions::PermissionPolicy;
2229 use crate::tools::ToolExecutor;
2230 use std::path::PathBuf;
2231 use std::sync::atomic::{AtomicUsize, Ordering};
2232
2233 fn test_tool_context() -> ToolContext {
2235 ToolContext::new(PathBuf::from("/tmp"))
2236 }
2237
2238 #[test]
2239 fn test_agent_config_default() {
2240 let config = AgentConfig::default();
2241 assert!(config.system_prompt.is_none());
2242 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
2244 assert!(config.permission_policy.is_none());
2245 assert!(config.context_providers.is_empty());
2246 }
2247
2248 pub(crate) struct MockLlmClient {
2254 responses: std::sync::Mutex<Vec<LlmResponse>>,
2256 call_count: AtomicUsize,
2258 }
2259
2260 impl MockLlmClient {
2261 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2262 Self {
2263 responses: std::sync::Mutex::new(responses),
2264 call_count: AtomicUsize::new(0),
2265 }
2266 }
2267
2268 pub(crate) fn text_response(text: &str) -> LlmResponse {
2270 LlmResponse {
2271 message: Message {
2272 role: "assistant".to_string(),
2273 content: vec![ContentBlock::Text {
2274 text: text.to_string(),
2275 }],
2276 reasoning_content: None,
2277 },
2278 usage: TokenUsage {
2279 prompt_tokens: 10,
2280 completion_tokens: 5,
2281 total_tokens: 15,
2282 cache_read_tokens: None,
2283 cache_write_tokens: None,
2284 },
2285 stop_reason: Some("end_turn".to_string()),
2286 }
2287 }
2288
2289 pub(crate) fn tool_call_response(
2291 tool_id: &str,
2292 tool_name: &str,
2293 args: serde_json::Value,
2294 ) -> LlmResponse {
2295 LlmResponse {
2296 message: Message {
2297 role: "assistant".to_string(),
2298 content: vec![ContentBlock::ToolUse {
2299 id: tool_id.to_string(),
2300 name: tool_name.to_string(),
2301 input: args,
2302 }],
2303 reasoning_content: None,
2304 },
2305 usage: TokenUsage {
2306 prompt_tokens: 10,
2307 completion_tokens: 5,
2308 total_tokens: 15,
2309 cache_read_tokens: None,
2310 cache_write_tokens: None,
2311 },
2312 stop_reason: Some("tool_use".to_string()),
2313 }
2314 }
2315 }
2316
2317 #[async_trait::async_trait]
2318 impl LlmClient for MockLlmClient {
2319 async fn complete(
2320 &self,
2321 _messages: &[Message],
2322 _system: Option<&str>,
2323 _tools: &[ToolDefinition],
2324 ) -> Result<LlmResponse> {
2325 self.call_count.fetch_add(1, Ordering::SeqCst);
2326 let mut responses = self.responses.lock().unwrap();
2327 if responses.is_empty() {
2328 anyhow::bail!("No more mock responses available");
2329 }
2330 Ok(responses.remove(0))
2331 }
2332
2333 async fn complete_streaming(
2334 &self,
2335 _messages: &[Message],
2336 _system: Option<&str>,
2337 _tools: &[ToolDefinition],
2338 ) -> Result<mpsc::Receiver<StreamEvent>> {
2339 self.call_count.fetch_add(1, Ordering::SeqCst);
2340 let mut responses = self.responses.lock().unwrap();
2341 if responses.is_empty() {
2342 anyhow::bail!("No more mock responses available");
2343 }
2344 let response = responses.remove(0);
2345
2346 let (tx, rx) = mpsc::channel(10);
2347 tokio::spawn(async move {
2348 for block in &response.message.content {
2350 if let ContentBlock::Text { text } = block {
2351 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2352 }
2353 }
2354 tx.send(StreamEvent::Done(response)).await.ok();
2355 });
2356
2357 Ok(rx)
2358 }
2359 }
2360
2361 #[tokio::test]
2366 async fn test_agent_simple_response() {
2367 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2368 "Hello, I'm an AI assistant.",
2369 )]));
2370
2371 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2372 let config = AgentConfig::default();
2373
2374 let agent = AgentLoop::new(
2375 mock_client.clone(),
2376 tool_executor,
2377 test_tool_context(),
2378 config,
2379 );
2380 let result = agent.execute(&[], "Hello", None).await.unwrap();
2381
2382 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2383 assert_eq!(result.tool_calls_count, 0);
2384 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2385 }
2386
2387 #[tokio::test]
2388 async fn test_agent_with_tool_call() {
2389 let mock_client = Arc::new(MockLlmClient::new(vec![
2390 MockLlmClient::tool_call_response(
2392 "tool-1",
2393 "bash",
2394 serde_json::json!({"command": "echo hello"}),
2395 ),
2396 MockLlmClient::text_response("The command output was: hello"),
2398 ]));
2399
2400 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2401 let config = AgentConfig::default();
2402
2403 let agent = AgentLoop::new(
2404 mock_client.clone(),
2405 tool_executor,
2406 test_tool_context(),
2407 config,
2408 );
2409 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2410
2411 assert_eq!(result.text, "The command output was: hello");
2412 assert_eq!(result.tool_calls_count, 1);
2413 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2414 }
2415
2416 #[tokio::test]
2417 async fn test_agent_permission_deny() {
2418 let mock_client = Arc::new(MockLlmClient::new(vec![
2419 MockLlmClient::tool_call_response(
2421 "tool-1",
2422 "bash",
2423 serde_json::json!({"command": "rm -rf /tmp/test"}),
2424 ),
2425 MockLlmClient::text_response(
2427 "I cannot execute that command due to permission restrictions.",
2428 ),
2429 ]));
2430
2431 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2432
2433 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2435 let policy_lock = Arc::new(RwLock::new(permission_policy));
2436
2437 let config = AgentConfig {
2438 permission_policy: Some(policy_lock),
2439 ..Default::default()
2440 };
2441
2442 let (tx, mut rx) = mpsc::channel(100);
2443 let agent = AgentLoop::new(
2444 mock_client.clone(),
2445 tool_executor,
2446 test_tool_context(),
2447 config,
2448 );
2449 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2450
2451 let mut found_permission_denied = false;
2453 while let Ok(event) = rx.try_recv() {
2454 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2455 assert_eq!(tool_name, "bash");
2456 found_permission_denied = true;
2457 }
2458 }
2459 assert!(
2460 found_permission_denied,
2461 "Should have received PermissionDenied event"
2462 );
2463
2464 assert_eq!(result.tool_calls_count, 1);
2465 }
2466
2467 #[tokio::test]
2468 async fn test_agent_permission_allow() {
2469 let mock_client = Arc::new(MockLlmClient::new(vec![
2470 MockLlmClient::tool_call_response(
2472 "tool-1",
2473 "bash",
2474 serde_json::json!({"command": "echo hello"}),
2475 ),
2476 MockLlmClient::text_response("Done!"),
2478 ]));
2479
2480 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2481
2482 let permission_policy = PermissionPolicy::new()
2484 .allow("bash(echo:*)")
2485 .deny("bash(rm:*)");
2486 let policy_lock = Arc::new(RwLock::new(permission_policy));
2487
2488 let config = AgentConfig {
2489 permission_policy: Some(policy_lock),
2490 ..Default::default()
2491 };
2492
2493 let agent = AgentLoop::new(
2494 mock_client.clone(),
2495 tool_executor,
2496 test_tool_context(),
2497 config,
2498 );
2499 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2500
2501 assert_eq!(result.text, "Done!");
2502 assert_eq!(result.tool_calls_count, 1);
2503 }
2504
2505 #[tokio::test]
2506 async fn test_agent_streaming_events() {
2507 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2508 "Hello!",
2509 )]));
2510
2511 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2512 let config = AgentConfig::default();
2513
2514 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2515 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2516
2517 let mut events = Vec::new();
2519 while let Some(event) = rx.recv().await {
2520 events.push(event);
2521 }
2522
2523 let result = handle.await.unwrap().unwrap();
2524 assert_eq!(result.text, "Hello!");
2525
2526 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2528 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2529 }
2530
2531 #[tokio::test]
2532 async fn test_agent_max_tool_rounds() {
2533 let responses: Vec<LlmResponse> = (0..100)
2535 .map(|i| {
2536 MockLlmClient::tool_call_response(
2537 &format!("tool-{}", i),
2538 "bash",
2539 serde_json::json!({"command": "echo loop"}),
2540 )
2541 })
2542 .collect();
2543
2544 let mock_client = Arc::new(MockLlmClient::new(responses));
2545 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2546
2547 let config = AgentConfig {
2548 max_tool_rounds: 3,
2549 ..Default::default()
2550 };
2551
2552 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2553 let result = agent.execute(&[], "Loop forever", None).await;
2554
2555 assert!(result.is_err());
2557 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
2558 }
2559
2560 #[tokio::test]
2561 async fn test_agent_no_permission_policy_defaults_to_ask() {
2562 let mock_client = Arc::new(MockLlmClient::new(vec![
2565 MockLlmClient::tool_call_response(
2566 "tool-1",
2567 "bash",
2568 serde_json::json!({"command": "rm -rf /tmp/test"}),
2569 ),
2570 MockLlmClient::text_response("Denied!"),
2571 ]));
2572
2573 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2574 let config = AgentConfig {
2575 permission_policy: None, ..Default::default()
2578 };
2579
2580 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2581 let result = agent.execute(&[], "Delete", None).await.unwrap();
2582
2583 assert_eq!(result.text, "Denied!");
2585 assert_eq!(result.tool_calls_count, 1);
2586 }
2587
2588 #[tokio::test]
2589 async fn test_agent_permission_ask_without_cm_denies() {
2590 let mock_client = Arc::new(MockLlmClient::new(vec![
2593 MockLlmClient::tool_call_response(
2594 "tool-1",
2595 "bash",
2596 serde_json::json!({"command": "echo test"}),
2597 ),
2598 MockLlmClient::text_response("Denied!"),
2599 ]));
2600
2601 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2602
2603 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2606
2607 let config = AgentConfig {
2608 permission_policy: Some(policy_lock),
2609 ..Default::default()
2611 };
2612
2613 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2614 let result = agent.execute(&[], "Echo", None).await.unwrap();
2615
2616 assert_eq!(result.text, "Denied!");
2618 assert!(result.tool_calls_count >= 1);
2620 }
2621
2622 #[tokio::test]
2627 async fn test_agent_hitl_approved() {
2628 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2629 use tokio::sync::broadcast;
2630
2631 let mock_client = Arc::new(MockLlmClient::new(vec![
2632 MockLlmClient::tool_call_response(
2633 "tool-1",
2634 "bash",
2635 serde_json::json!({"command": "echo hello"}),
2636 ),
2637 MockLlmClient::text_response("Command executed!"),
2638 ]));
2639
2640 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2641
2642 let (event_tx, _event_rx) = broadcast::channel(100);
2644 let hitl_policy = ConfirmationPolicy {
2645 enabled: true,
2646 ..Default::default()
2647 };
2648 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2649
2650 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2653
2654 let config = AgentConfig {
2655 permission_policy: Some(policy_lock),
2656 confirmation_manager: Some(confirmation_manager.clone()),
2657 ..Default::default()
2658 };
2659
2660 let cm_clone = confirmation_manager.clone();
2662 tokio::spawn(async move {
2663 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2665 cm_clone.confirm("tool-1", true, None).await.ok();
2667 });
2668
2669 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2670 let result = agent.execute(&[], "Run echo", None).await.unwrap();
2671
2672 assert_eq!(result.text, "Command executed!");
2673 assert_eq!(result.tool_calls_count, 1);
2674 }
2675
2676 #[tokio::test]
2677 async fn test_agent_hitl_rejected() {
2678 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2679 use tokio::sync::broadcast;
2680
2681 let mock_client = Arc::new(MockLlmClient::new(vec![
2682 MockLlmClient::tool_call_response(
2683 "tool-1",
2684 "bash",
2685 serde_json::json!({"command": "rm -rf /"}),
2686 ),
2687 MockLlmClient::text_response("Understood, I won't do that."),
2688 ]));
2689
2690 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2691
2692 let (event_tx, _event_rx) = broadcast::channel(100);
2694 let hitl_policy = ConfirmationPolicy {
2695 enabled: true,
2696 ..Default::default()
2697 };
2698 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2699
2700 let permission_policy = PermissionPolicy::new();
2702 let policy_lock = Arc::new(RwLock::new(permission_policy));
2703
2704 let config = AgentConfig {
2705 permission_policy: Some(policy_lock),
2706 confirmation_manager: Some(confirmation_manager.clone()),
2707 ..Default::default()
2708 };
2709
2710 let cm_clone = confirmation_manager.clone();
2712 tokio::spawn(async move {
2713 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2714 cm_clone
2715 .confirm("tool-1", false, Some("Too dangerous".to_string()))
2716 .await
2717 .ok();
2718 });
2719
2720 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2721 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
2722
2723 assert_eq!(result.text, "Understood, I won't do that.");
2725 }
2726
2727 #[tokio::test]
2728 async fn test_agent_hitl_timeout_reject() {
2729 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2730 use tokio::sync::broadcast;
2731
2732 let mock_client = Arc::new(MockLlmClient::new(vec![
2733 MockLlmClient::tool_call_response(
2734 "tool-1",
2735 "bash",
2736 serde_json::json!({"command": "echo test"}),
2737 ),
2738 MockLlmClient::text_response("Timed out, I understand."),
2739 ]));
2740
2741 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2742
2743 let (event_tx, _event_rx) = broadcast::channel(100);
2745 let hitl_policy = ConfirmationPolicy {
2746 enabled: true,
2747 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
2749 ..Default::default()
2750 };
2751 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2752
2753 let permission_policy = PermissionPolicy::new();
2754 let policy_lock = Arc::new(RwLock::new(permission_policy));
2755
2756 let config = AgentConfig {
2757 permission_policy: Some(policy_lock),
2758 confirmation_manager: Some(confirmation_manager),
2759 ..Default::default()
2760 };
2761
2762 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2764 let result = agent.execute(&[], "Echo", None).await.unwrap();
2765
2766 assert_eq!(result.text, "Timed out, I understand.");
2768 }
2769
2770 #[tokio::test]
2771 async fn test_agent_hitl_timeout_auto_approve() {
2772 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2773 use tokio::sync::broadcast;
2774
2775 let mock_client = Arc::new(MockLlmClient::new(vec![
2776 MockLlmClient::tool_call_response(
2777 "tool-1",
2778 "bash",
2779 serde_json::json!({"command": "echo hello"}),
2780 ),
2781 MockLlmClient::text_response("Auto-approved and executed!"),
2782 ]));
2783
2784 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2785
2786 let (event_tx, _event_rx) = broadcast::channel(100);
2788 let hitl_policy = ConfirmationPolicy {
2789 enabled: true,
2790 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
2792 ..Default::default()
2793 };
2794 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2795
2796 let permission_policy = PermissionPolicy::new();
2797 let policy_lock = Arc::new(RwLock::new(permission_policy));
2798
2799 let config = AgentConfig {
2800 permission_policy: Some(policy_lock),
2801 confirmation_manager: Some(confirmation_manager),
2802 ..Default::default()
2803 };
2804
2805 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2807 let result = agent.execute(&[], "Echo", None).await.unwrap();
2808
2809 assert_eq!(result.text, "Auto-approved and executed!");
2811 assert_eq!(result.tool_calls_count, 1);
2812 }
2813
2814 #[tokio::test]
2815 async fn test_agent_hitl_confirmation_events() {
2816 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2817 use tokio::sync::broadcast;
2818
2819 let mock_client = Arc::new(MockLlmClient::new(vec![
2820 MockLlmClient::tool_call_response(
2821 "tool-1",
2822 "bash",
2823 serde_json::json!({"command": "echo test"}),
2824 ),
2825 MockLlmClient::text_response("Done!"),
2826 ]));
2827
2828 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2829
2830 let (event_tx, mut event_rx) = broadcast::channel(100);
2832 let hitl_policy = ConfirmationPolicy {
2833 enabled: true,
2834 default_timeout_ms: 5000, ..Default::default()
2836 };
2837 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2838
2839 let permission_policy = PermissionPolicy::new();
2840 let policy_lock = Arc::new(RwLock::new(permission_policy));
2841
2842 let config = AgentConfig {
2843 permission_policy: Some(policy_lock),
2844 confirmation_manager: Some(confirmation_manager.clone()),
2845 ..Default::default()
2846 };
2847
2848 let cm_clone = confirmation_manager.clone();
2850 let event_handle = tokio::spawn(async move {
2851 let mut events = Vec::new();
2852 while let Ok(event) = event_rx.recv().await {
2854 events.push(event.clone());
2855 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
2856 cm_clone.confirm(&tool_id, true, None).await.ok();
2858 if let Ok(recv_event) = event_rx.recv().await {
2860 events.push(recv_event);
2861 }
2862 break;
2863 }
2864 }
2865 events
2866 });
2867
2868 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2869 let _result = agent.execute(&[], "Echo", None).await.unwrap();
2870
2871 let events = event_handle.await.unwrap();
2873 assert!(
2874 events
2875 .iter()
2876 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
2877 "Should have ConfirmationRequired event"
2878 );
2879 assert!(
2880 events
2881 .iter()
2882 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
2883 "Should have ConfirmationReceived event with approved=true"
2884 );
2885 }
2886
2887 #[tokio::test]
2888 async fn test_agent_hitl_disabled_auto_executes() {
2889 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2891 use tokio::sync::broadcast;
2892
2893 let mock_client = Arc::new(MockLlmClient::new(vec![
2894 MockLlmClient::tool_call_response(
2895 "tool-1",
2896 "bash",
2897 serde_json::json!({"command": "echo auto"}),
2898 ),
2899 MockLlmClient::text_response("Auto executed!"),
2900 ]));
2901
2902 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2903
2904 let (event_tx, _event_rx) = broadcast::channel(100);
2906 let hitl_policy = ConfirmationPolicy {
2907 enabled: false, ..Default::default()
2909 };
2910 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2911
2912 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
2914
2915 let config = AgentConfig {
2916 permission_policy: Some(policy_lock),
2917 confirmation_manager: Some(confirmation_manager),
2918 ..Default::default()
2919 };
2920
2921 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2922 let result = agent.execute(&[], "Echo", None).await.unwrap();
2923
2924 assert_eq!(result.text, "Auto executed!");
2926 assert_eq!(result.tool_calls_count, 1);
2927 }
2928
2929 #[tokio::test]
2930 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
2931 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2933 use tokio::sync::broadcast;
2934
2935 let mock_client = Arc::new(MockLlmClient::new(vec![
2936 MockLlmClient::tool_call_response(
2937 "tool-1",
2938 "bash",
2939 serde_json::json!({"command": "rm -rf /"}),
2940 ),
2941 MockLlmClient::text_response("Blocked by permission."),
2942 ]));
2943
2944 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2945
2946 let (event_tx, mut event_rx) = broadcast::channel(100);
2948 let hitl_policy = ConfirmationPolicy {
2949 enabled: true,
2950 ..Default::default()
2951 };
2952 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2953
2954 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2956 let policy_lock = Arc::new(RwLock::new(permission_policy));
2957
2958 let config = AgentConfig {
2959 permission_policy: Some(policy_lock),
2960 confirmation_manager: Some(confirmation_manager),
2961 ..Default::default()
2962 };
2963
2964 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2965 let result = agent.execute(&[], "Delete", None).await.unwrap();
2966
2967 assert_eq!(result.text, "Blocked by permission.");
2969
2970 let mut found_confirmation = false;
2972 while let Ok(event) = event_rx.try_recv() {
2973 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2974 found_confirmation = true;
2975 }
2976 }
2977 assert!(
2978 !found_confirmation,
2979 "HITL should not be triggered when permission is Deny"
2980 );
2981 }
2982
2983 #[tokio::test]
2984 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
2985 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2988 use tokio::sync::broadcast;
2989
2990 let mock_client = Arc::new(MockLlmClient::new(vec![
2991 MockLlmClient::tool_call_response(
2992 "tool-1",
2993 "bash",
2994 serde_json::json!({"command": "echo hello"}),
2995 ),
2996 MockLlmClient::text_response("Allowed!"),
2997 ]));
2998
2999 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3000
3001 let (event_tx, mut event_rx) = broadcast::channel(100);
3003 let hitl_policy = ConfirmationPolicy {
3004 enabled: true,
3005 ..Default::default()
3006 };
3007 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3008
3009 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3011 let policy_lock = Arc::new(RwLock::new(permission_policy));
3012
3013 let config = AgentConfig {
3014 permission_policy: Some(policy_lock),
3015 confirmation_manager: Some(confirmation_manager.clone()),
3016 ..Default::default()
3017 };
3018
3019 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3020 let result = agent.execute(&[], "Echo", None).await.unwrap();
3021
3022 assert_eq!(result.text, "Allowed!");
3024
3025 let mut found_confirmation = false;
3027 while let Ok(event) = event_rx.try_recv() {
3028 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3029 found_confirmation = true;
3030 }
3031 }
3032 assert!(
3033 !found_confirmation,
3034 "Permission Allow should skip HITL confirmation"
3035 );
3036 }
3037
3038 #[tokio::test]
3039 async fn test_agent_hitl_multiple_tool_calls() {
3040 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3042 use tokio::sync::broadcast;
3043
3044 let mock_client = Arc::new(MockLlmClient::new(vec![
3045 LlmResponse {
3047 message: Message {
3048 role: "assistant".to_string(),
3049 content: vec![
3050 ContentBlock::ToolUse {
3051 id: "tool-1".to_string(),
3052 name: "bash".to_string(),
3053 input: serde_json::json!({"command": "echo first"}),
3054 },
3055 ContentBlock::ToolUse {
3056 id: "tool-2".to_string(),
3057 name: "bash".to_string(),
3058 input: serde_json::json!({"command": "echo second"}),
3059 },
3060 ],
3061 reasoning_content: None,
3062 },
3063 usage: TokenUsage {
3064 prompt_tokens: 10,
3065 completion_tokens: 5,
3066 total_tokens: 15,
3067 cache_read_tokens: None,
3068 cache_write_tokens: None,
3069 },
3070 stop_reason: Some("tool_use".to_string()),
3071 },
3072 MockLlmClient::text_response("Both executed!"),
3073 ]));
3074
3075 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3076
3077 let (event_tx, _event_rx) = broadcast::channel(100);
3079 let hitl_policy = ConfirmationPolicy {
3080 enabled: true,
3081 default_timeout_ms: 5000,
3082 ..Default::default()
3083 };
3084 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3085
3086 let permission_policy = PermissionPolicy::new(); let policy_lock = Arc::new(RwLock::new(permission_policy));
3088
3089 let config = AgentConfig {
3090 permission_policy: Some(policy_lock),
3091 confirmation_manager: Some(confirmation_manager.clone()),
3092 ..Default::default()
3093 };
3094
3095 let cm_clone = confirmation_manager.clone();
3097 tokio::spawn(async move {
3098 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3099 cm_clone.confirm("tool-1", true, None).await.ok();
3100 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3101 cm_clone.confirm("tool-2", true, None).await.ok();
3102 });
3103
3104 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3105 let result = agent.execute(&[], "Run both", None).await.unwrap();
3106
3107 assert_eq!(result.text, "Both executed!");
3108 assert_eq!(result.tool_calls_count, 2);
3109 }
3110
3111 #[tokio::test]
3112 async fn test_agent_hitl_partial_approval() {
3113 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3115 use tokio::sync::broadcast;
3116
3117 let mock_client = Arc::new(MockLlmClient::new(vec![
3118 LlmResponse {
3120 message: Message {
3121 role: "assistant".to_string(),
3122 content: vec![
3123 ContentBlock::ToolUse {
3124 id: "tool-1".to_string(),
3125 name: "bash".to_string(),
3126 input: serde_json::json!({"command": "echo safe"}),
3127 },
3128 ContentBlock::ToolUse {
3129 id: "tool-2".to_string(),
3130 name: "bash".to_string(),
3131 input: serde_json::json!({"command": "rm -rf /"}),
3132 },
3133 ],
3134 reasoning_content: None,
3135 },
3136 usage: TokenUsage {
3137 prompt_tokens: 10,
3138 completion_tokens: 5,
3139 total_tokens: 15,
3140 cache_read_tokens: None,
3141 cache_write_tokens: None,
3142 },
3143 stop_reason: Some("tool_use".to_string()),
3144 },
3145 MockLlmClient::text_response("First worked, second rejected."),
3146 ]));
3147
3148 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3149
3150 let (event_tx, _event_rx) = broadcast::channel(100);
3151 let hitl_policy = ConfirmationPolicy {
3152 enabled: true,
3153 default_timeout_ms: 5000,
3154 ..Default::default()
3155 };
3156 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3157
3158 let permission_policy = PermissionPolicy::new();
3159 let policy_lock = Arc::new(RwLock::new(permission_policy));
3160
3161 let config = AgentConfig {
3162 permission_policy: Some(policy_lock),
3163 confirmation_manager: Some(confirmation_manager.clone()),
3164 ..Default::default()
3165 };
3166
3167 let cm_clone = confirmation_manager.clone();
3169 tokio::spawn(async move {
3170 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3171 cm_clone.confirm("tool-1", true, None).await.ok();
3172 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3173 cm_clone
3174 .confirm("tool-2", false, Some("Dangerous".to_string()))
3175 .await
3176 .ok();
3177 });
3178
3179 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3180 let result = agent.execute(&[], "Run both", None).await.unwrap();
3181
3182 assert_eq!(result.text, "First worked, second rejected.");
3183 assert_eq!(result.tool_calls_count, 2);
3184 }
3185
3186 #[tokio::test]
3187 async fn test_agent_hitl_yolo_mode_auto_approves() {
3188 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3190 use tokio::sync::broadcast;
3191
3192 let mock_client = Arc::new(MockLlmClient::new(vec![
3193 MockLlmClient::tool_call_response(
3194 "tool-1",
3195 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3197 ),
3198 MockLlmClient::text_response("File read!"),
3199 ]));
3200
3201 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3202
3203 let (event_tx, mut event_rx) = broadcast::channel(100);
3205 let mut yolo_lanes = std::collections::HashSet::new();
3206 yolo_lanes.insert(SessionLane::Query);
3207 let hitl_policy = ConfirmationPolicy {
3208 enabled: true,
3209 yolo_lanes, ..Default::default()
3211 };
3212 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3213
3214 let permission_policy = PermissionPolicy::new();
3215 let policy_lock = Arc::new(RwLock::new(permission_policy));
3216
3217 let config = AgentConfig {
3218 permission_policy: Some(policy_lock),
3219 confirmation_manager: Some(confirmation_manager),
3220 ..Default::default()
3221 };
3222
3223 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3224 let result = agent.execute(&[], "Read file", None).await.unwrap();
3225
3226 assert_eq!(result.text, "File read!");
3228
3229 let mut found_confirmation = false;
3231 while let Ok(event) = event_rx.try_recv() {
3232 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3233 found_confirmation = true;
3234 }
3235 }
3236 assert!(
3237 !found_confirmation,
3238 "YOLO mode should not trigger confirmation"
3239 );
3240 }
3241
3242 #[tokio::test]
3243 async fn test_agent_config_with_all_options() {
3244 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3245 use tokio::sync::broadcast;
3246
3247 let (event_tx, _) = broadcast::channel(100);
3248 let hitl_policy = ConfirmationPolicy::default();
3249 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3250
3251 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3252 let policy_lock = Arc::new(RwLock::new(permission_policy));
3253
3254 let config = AgentConfig {
3255 system_prompt: Some("Test system prompt".to_string()),
3256 tools: vec![],
3257 max_tool_rounds: 10,
3258 permission_policy: Some(policy_lock),
3259 confirmation_manager: Some(confirmation_manager),
3260 context_providers: vec![],
3261 planning_enabled: false,
3262 goal_tracking: false,
3263 skill_tool_filters: vec![],
3264 hook_engine: None,
3265 };
3266
3267 assert_eq!(config.system_prompt, Some("Test system prompt".to_string()));
3268 assert_eq!(config.max_tool_rounds, 10);
3269 assert!(config.permission_policy.is_some());
3270 assert!(config.confirmation_manager.is_some());
3271 assert!(config.context_providers.is_empty());
3272
3273 let debug_str = format!("{:?}", config);
3275 assert!(debug_str.contains("AgentConfig"));
3276 assert!(debug_str.contains("permission_policy: true"));
3277 assert!(debug_str.contains("confirmation_manager: true"));
3278 assert!(debug_str.contains("context_providers: 0"));
3279 }
3280
3281 use crate::context::{ContextItem, ContextType};
3286
3287 struct MockContextProvider {
3289 name: String,
3290 items: Vec<ContextItem>,
3291 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
3292 }
3293
3294 impl MockContextProvider {
3295 fn new(name: &str) -> Self {
3296 Self {
3297 name: name.to_string(),
3298 items: Vec::new(),
3299 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3300 }
3301 }
3302
3303 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3304 self.items = items;
3305 self
3306 }
3307 }
3308
3309 #[async_trait::async_trait]
3310 impl ContextProvider for MockContextProvider {
3311 fn name(&self) -> &str {
3312 &self.name
3313 }
3314
3315 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3316 let mut result = ContextResult::new(&self.name);
3317 for item in &self.items {
3318 result.add_item(item.clone());
3319 }
3320 Ok(result)
3321 }
3322
3323 async fn on_turn_complete(
3324 &self,
3325 session_id: &str,
3326 prompt: &str,
3327 response: &str,
3328 ) -> anyhow::Result<()> {
3329 let mut calls = self.on_turn_calls.write().await;
3330 calls.push((
3331 session_id.to_string(),
3332 prompt.to_string(),
3333 response.to_string(),
3334 ));
3335 Ok(())
3336 }
3337 }
3338
3339 #[tokio::test]
3340 async fn test_agent_with_context_provider() {
3341 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3342 "Response using context",
3343 )]));
3344
3345 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3346
3347 let provider =
3348 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3349 "ctx-1",
3350 ContextType::Resource,
3351 "Relevant context here",
3352 )
3353 .with_source("test://docs/example")]);
3354
3355 let config = AgentConfig {
3356 system_prompt: Some("You are helpful.".to_string()),
3357 context_providers: vec![Arc::new(provider)],
3358 ..Default::default()
3359 };
3360
3361 let agent = AgentLoop::new(
3362 mock_client.clone(),
3363 tool_executor,
3364 test_tool_context(),
3365 config,
3366 );
3367 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3368
3369 assert_eq!(result.text, "Response using context");
3370 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3371 }
3372
3373 #[tokio::test]
3374 async fn test_agent_context_provider_events() {
3375 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3376 "Answer",
3377 )]));
3378
3379 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3380
3381 let provider =
3382 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3383 "item-1",
3384 ContextType::Memory,
3385 "Memory content",
3386 )
3387 .with_token_count(50)]);
3388
3389 let config = AgentConfig {
3390 context_providers: vec![Arc::new(provider)],
3391 ..Default::default()
3392 };
3393
3394 let (tx, mut rx) = mpsc::channel(100);
3395 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3396 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3397
3398 let mut events = Vec::new();
3400 while let Ok(event) = rx.try_recv() {
3401 events.push(event);
3402 }
3403
3404 assert!(
3406 events
3407 .iter()
3408 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3409 "Should have ContextResolving event"
3410 );
3411 assert!(
3412 events
3413 .iter()
3414 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3415 "Should have ContextResolved event"
3416 );
3417
3418 for event in &events {
3420 if let AgentEvent::ContextResolved {
3421 total_items,
3422 total_tokens,
3423 } = event
3424 {
3425 assert_eq!(*total_items, 1);
3426 assert_eq!(*total_tokens, 50);
3427 }
3428 }
3429 }
3430
3431 #[tokio::test]
3432 async fn test_agent_multiple_context_providers() {
3433 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3434 "Combined response",
3435 )]));
3436
3437 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3438
3439 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3440 "p1-1",
3441 ContextType::Resource,
3442 "Resource from P1",
3443 )
3444 .with_token_count(100)]);
3445
3446 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3447 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3448 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3449 ]);
3450
3451 let config = AgentConfig {
3452 system_prompt: Some("Base system prompt.".to_string()),
3453 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3454 ..Default::default()
3455 };
3456
3457 let (tx, mut rx) = mpsc::channel(100);
3458 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3459 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3460
3461 assert_eq!(result.text, "Combined response");
3462
3463 while let Ok(event) = rx.try_recv() {
3465 if let AgentEvent::ContextResolved {
3466 total_items,
3467 total_tokens,
3468 } = event
3469 {
3470 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3473 }
3474 }
3475
3476 #[tokio::test]
3477 async fn test_agent_no_context_providers() {
3478 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3479 "No context",
3480 )]));
3481
3482 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3483
3484 let config = AgentConfig::default();
3486
3487 let (tx, mut rx) = mpsc::channel(100);
3488 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3489 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3490
3491 assert_eq!(result.text, "No context");
3492
3493 let mut events = Vec::new();
3495 while let Ok(event) = rx.try_recv() {
3496 events.push(event);
3497 }
3498
3499 assert!(
3500 !events
3501 .iter()
3502 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3503 "Should NOT have ContextResolving event"
3504 );
3505 }
3506
3507 #[tokio::test]
3508 async fn test_agent_context_on_turn_complete() {
3509 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3510 "Final response",
3511 )]));
3512
3513 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3514
3515 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3516 let on_turn_calls = provider.on_turn_calls.clone();
3517
3518 let config = AgentConfig {
3519 context_providers: vec![provider],
3520 ..Default::default()
3521 };
3522
3523 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3524
3525 let result = agent
3527 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3528 .await
3529 .unwrap();
3530
3531 assert_eq!(result.text, "Final response");
3532
3533 let calls = on_turn_calls.read().await;
3535 assert_eq!(calls.len(), 1);
3536 assert_eq!(calls[0].0, "sess-123");
3537 assert_eq!(calls[0].1, "User prompt");
3538 assert_eq!(calls[0].2, "Final response");
3539 }
3540
3541 #[tokio::test]
3542 async fn test_agent_context_on_turn_complete_no_session() {
3543 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3544 "Response",
3545 )]));
3546
3547 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3548
3549 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3550 let on_turn_calls = provider.on_turn_calls.clone();
3551
3552 let config = AgentConfig {
3553 context_providers: vec![provider],
3554 ..Default::default()
3555 };
3556
3557 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3558
3559 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
3561
3562 let calls = on_turn_calls.read().await;
3564 assert!(calls.is_empty());
3565 }
3566
3567 #[tokio::test]
3568 async fn test_agent_build_augmented_system_prompt() {
3569 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
3570
3571 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3572
3573 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
3574 "doc-1",
3575 ContextType::Resource,
3576 "Auth uses JWT tokens.",
3577 )
3578 .with_source("viking://docs/auth")]);
3579
3580 let config = AgentConfig {
3581 system_prompt: Some("You are helpful.".to_string()),
3582 context_providers: vec![Arc::new(provider)],
3583 ..Default::default()
3584 };
3585
3586 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3587
3588 let context_results = agent.resolve_context("test", None).await;
3590 let augmented = agent.build_augmented_system_prompt(&context_results);
3591
3592 let augmented_str = augmented.unwrap();
3593 assert!(augmented_str.contains("You are helpful."));
3594 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
3595 assert!(augmented_str.contains("Auth uses JWT tokens."));
3596 }
3597
3598 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
3604 let mut events = Vec::new();
3605 while let Ok(event) = rx.try_recv() {
3606 events.push(event);
3607 }
3608 while let Some(event) = rx.recv().await {
3610 events.push(event);
3611 }
3612 events
3613 }
3614
3615 #[tokio::test]
3616 async fn test_agent_multi_turn_tool_chain() {
3617 let mock_client = Arc::new(MockLlmClient::new(vec![
3619 MockLlmClient::tool_call_response(
3621 "t1",
3622 "bash",
3623 serde_json::json!({"command": "echo step1"}),
3624 ),
3625 MockLlmClient::tool_call_response(
3627 "t2",
3628 "bash",
3629 serde_json::json!({"command": "echo step2"}),
3630 ),
3631 MockLlmClient::text_response("Completed both steps: step1 then step2"),
3633 ]));
3634
3635 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3636 let config = AgentConfig::default();
3637
3638 let agent = AgentLoop::new(
3639 mock_client.clone(),
3640 tool_executor,
3641 test_tool_context(),
3642 config,
3643 );
3644 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
3645
3646 assert_eq!(result.text, "Completed both steps: step1 then step2");
3647 assert_eq!(result.tool_calls_count, 2);
3648 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
3649
3650 assert_eq!(result.messages[0].role, "user");
3652 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
3658 }
3659
3660 #[tokio::test]
3661 async fn test_agent_conversation_history_preserved() {
3662 let existing_history = vec![
3664 Message::user("What is Rust?"),
3665 Message {
3666 role: "assistant".to_string(),
3667 content: vec![ContentBlock::Text {
3668 text: "Rust is a systems programming language.".to_string(),
3669 }],
3670 reasoning_content: None,
3671 },
3672 ];
3673
3674 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3675 "Rust was created by Graydon Hoare at Mozilla.",
3676 )]));
3677
3678 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3679 let agent = AgentLoop::new(
3680 mock_client.clone(),
3681 tool_executor,
3682 test_tool_context(),
3683 AgentConfig::default(),
3684 );
3685
3686 let result = agent
3687 .execute(&existing_history, "Who created it?", None)
3688 .await
3689 .unwrap();
3690
3691 assert_eq!(result.messages.len(), 4);
3693 assert_eq!(result.messages[0].text(), "What is Rust?");
3694 assert_eq!(
3695 result.messages[1].text(),
3696 "Rust is a systems programming language."
3697 );
3698 assert_eq!(result.messages[2].text(), "Who created it?");
3699 assert_eq!(
3700 result.messages[3].text(),
3701 "Rust was created by Graydon Hoare at Mozilla."
3702 );
3703 }
3704
3705 #[tokio::test]
3706 async fn test_agent_event_stream_completeness() {
3707 let mock_client = Arc::new(MockLlmClient::new(vec![
3709 MockLlmClient::tool_call_response(
3710 "t1",
3711 "bash",
3712 serde_json::json!({"command": "echo hi"}),
3713 ),
3714 MockLlmClient::text_response("Done"),
3715 ]));
3716
3717 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3718 let agent = AgentLoop::new(
3719 mock_client,
3720 tool_executor,
3721 test_tool_context(),
3722 AgentConfig::default(),
3723 );
3724
3725 let (tx, rx) = mpsc::channel(100);
3726 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
3727 assert_eq!(result.text, "Done");
3728
3729 let events = collect_events(rx).await;
3730
3731 let event_types: Vec<&str> = events
3733 .iter()
3734 .map(|e| match e {
3735 AgentEvent::Start { .. } => "Start",
3736 AgentEvent::TurnStart { .. } => "TurnStart",
3737 AgentEvent::TurnEnd { .. } => "TurnEnd",
3738 AgentEvent::ToolEnd { .. } => "ToolEnd",
3739 AgentEvent::End { .. } => "End",
3740 _ => "Other",
3741 })
3742 .collect();
3743
3744 assert_eq!(event_types.first(), Some(&"Start"));
3746 assert_eq!(event_types.last(), Some(&"End"));
3747
3748 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
3750 assert_eq!(turn_starts, 2);
3751
3752 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
3754 assert_eq!(tool_ends, 1);
3755 }
3756
3757 #[tokio::test]
3758 async fn test_agent_multiple_tools_single_turn() {
3759 let mock_client = Arc::new(MockLlmClient::new(vec![
3761 LlmResponse {
3762 message: Message {
3763 role: "assistant".to_string(),
3764 content: vec![
3765 ContentBlock::ToolUse {
3766 id: "t1".to_string(),
3767 name: "bash".to_string(),
3768 input: serde_json::json!({"command": "echo first"}),
3769 },
3770 ContentBlock::ToolUse {
3771 id: "t2".to_string(),
3772 name: "bash".to_string(),
3773 input: serde_json::json!({"command": "echo second"}),
3774 },
3775 ],
3776 reasoning_content: None,
3777 },
3778 usage: TokenUsage {
3779 prompt_tokens: 10,
3780 completion_tokens: 5,
3781 total_tokens: 15,
3782 cache_read_tokens: None,
3783 cache_write_tokens: None,
3784 },
3785 stop_reason: Some("tool_use".to_string()),
3786 },
3787 MockLlmClient::text_response("Both commands ran"),
3788 ]));
3789
3790 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3791 let agent = AgentLoop::new(
3792 mock_client.clone(),
3793 tool_executor,
3794 test_tool_context(),
3795 AgentConfig::default(),
3796 );
3797
3798 let result = agent.execute(&[], "Run both", None).await.unwrap();
3799
3800 assert_eq!(result.text, "Both commands ran");
3801 assert_eq!(result.tool_calls_count, 2);
3802 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
3806 assert_eq!(result.messages[1].role, "assistant");
3807 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
3810 }
3811
3812 #[tokio::test]
3813 async fn test_agent_token_usage_accumulation() {
3814 let mock_client = Arc::new(MockLlmClient::new(vec![
3816 MockLlmClient::tool_call_response(
3817 "t1",
3818 "bash",
3819 serde_json::json!({"command": "echo x"}),
3820 ),
3821 MockLlmClient::text_response("Done"),
3822 ]));
3823
3824 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3825 let agent = AgentLoop::new(
3826 mock_client,
3827 tool_executor,
3828 test_tool_context(),
3829 AgentConfig::default(),
3830 );
3831
3832 let result = agent.execute(&[], "test", None).await.unwrap();
3833
3834 assert_eq!(result.usage.prompt_tokens, 20);
3837 assert_eq!(result.usage.completion_tokens, 10);
3838 assert_eq!(result.usage.total_tokens, 30);
3839 }
3840
3841 #[tokio::test]
3842 async fn test_agent_system_prompt_passed() {
3843 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3845 "I am a coding assistant.",
3846 )]));
3847
3848 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3849 let config = AgentConfig {
3850 system_prompt: Some("You are a coding assistant.".to_string()),
3851 ..Default::default()
3852 };
3853
3854 let agent = AgentLoop::new(
3855 mock_client.clone(),
3856 tool_executor,
3857 test_tool_context(),
3858 config,
3859 );
3860 let result = agent.execute(&[], "What are you?", None).await.unwrap();
3861
3862 assert_eq!(result.text, "I am a coding assistant.");
3863 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3864 }
3865
3866 #[tokio::test]
3867 async fn test_agent_max_rounds_with_persistent_tool_calls() {
3868 let mut responses = Vec::new();
3870 for i in 0..15 {
3871 responses.push(MockLlmClient::tool_call_response(
3872 &format!("t{}", i),
3873 "bash",
3874 serde_json::json!({"command": format!("echo round{}", i)}),
3875 ));
3876 }
3877
3878 let mock_client = Arc::new(MockLlmClient::new(responses));
3879 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3880 let config = AgentConfig {
3881 max_tool_rounds: 5,
3882 ..Default::default()
3883 };
3884
3885 let agent = AgentLoop::new(
3886 mock_client.clone(),
3887 tool_executor,
3888 test_tool_context(),
3889 config,
3890 );
3891 let result = agent.execute(&[], "Loop forever", None).await;
3892
3893 assert!(result.is_err());
3894 let err = result.unwrap_err().to_string();
3895 assert!(err.contains("Max tool rounds (5) exceeded"));
3896 }
3897
3898 #[tokio::test]
3899 async fn test_agent_end_event_contains_final_text() {
3900 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3901 "Final answer here",
3902 )]));
3903
3904 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3905 let agent = AgentLoop::new(
3906 mock_client,
3907 tool_executor,
3908 test_tool_context(),
3909 AgentConfig::default(),
3910 );
3911
3912 let (tx, rx) = mpsc::channel(100);
3913 agent.execute(&[], "test", Some(tx)).await.unwrap();
3914
3915 let events = collect_events(rx).await;
3916 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
3917 assert!(end_event.is_some());
3918
3919 if let AgentEvent::End { text, usage } = end_event.unwrap() {
3920 assert_eq!(text, "Final answer here");
3921 assert_eq!(usage.total_tokens, 15);
3922 }
3923 }
3924}
3925
3926#[cfg(test)]
3927mod extra_agent_tests {
3928 use super::*;
3929 use crate::agent::tests::MockLlmClient;
3930 use crate::llm::{ContentBlock, StreamEvent};
3931 use crate::queue::SessionQueueConfig;
3932 use crate::tools::ToolExecutor;
3933 use std::path::PathBuf;
3934 use std::sync::atomic::{AtomicUsize, Ordering};
3935
3936 fn test_tool_context() -> ToolContext {
3937 ToolContext::new(PathBuf::from("/tmp"))
3938 }
3939
3940 #[test]
3945 fn test_agent_config_debug() {
3946 let config = AgentConfig {
3947 system_prompt: Some("You are helpful".to_string()),
3948 tools: vec![],
3949 max_tool_rounds: 10,
3950 permission_policy: None,
3951 confirmation_manager: None,
3952 context_providers: vec![],
3953 planning_enabled: true,
3954 goal_tracking: false,
3955 skill_tool_filters: vec![],
3956 hook_engine: None,
3957 };
3958 let debug = format!("{:?}", config);
3959 assert!(debug.contains("AgentConfig"));
3960 assert!(debug.contains("planning_enabled"));
3961 }
3962
3963 #[test]
3964 fn test_agent_config_default_values() {
3965 let config = AgentConfig::default();
3966 assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
3967 assert!(!config.planning_enabled);
3968 assert!(!config.goal_tracking);
3969 assert!(config.context_providers.is_empty());
3970 assert!(config.skill_tool_filters.is_empty());
3971 }
3972
3973 #[tokio::test]
3974 async fn test_agent_skill_tool_filters_blocks_unauthorized() {
3975 use crate::tools::skill::Skill;
3978
3979 let mock_client = Arc::new(MockLlmClient::new(vec![
3980 MockLlmClient::tool_call_response(
3981 "tool-1",
3982 "bash",
3983 serde_json::json!({"command": "rm -rf /"}),
3984 ),
3985 MockLlmClient::text_response("Blocked!"),
3986 ]));
3987
3988 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3989
3990 let skill = Skill {
3992 name: "read-only".to_string(),
3993 description: "Read-only skill".to_string(),
3994 content: "Read files only".to_string(),
3995 allowed_tools: Some("read(*)".to_string()),
3996 disable_model_invocation: false,
3997 kind: crate::tools::SkillKind::Instruction,
3998 };
3999
4000 let policy = PermissionPolicy::new().allow("bash(*)");
4002 let policy_lock = Arc::new(RwLock::new(policy));
4003
4004 let (event_tx, _) = tokio::sync::broadcast::channel(10);
4006 let cm = Arc::new(crate::hitl::ConfirmationManager::new(
4007 crate::hitl::ConfirmationPolicy::default(), event_tx,
4009 ));
4010
4011 let config = AgentConfig {
4012 permission_policy: Some(policy_lock),
4013 confirmation_manager: Some(cm),
4014 skill_tool_filters: vec![skill],
4015 ..Default::default()
4016 };
4017
4018 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4019 let result = agent.execute(&[], "Delete", None).await.unwrap();
4020
4021 assert_eq!(result.text, "Blocked!");
4023 }
4024
4025 #[tokio::test]
4026 async fn test_agent_skill_tool_filters_allows_authorized() {
4027 use crate::tools::skill::Skill;
4029
4030 let mock_client = Arc::new(MockLlmClient::new(vec![
4031 MockLlmClient::tool_call_response(
4032 "tool-1",
4033 "bash",
4034 serde_json::json!({"command": "echo hello"}),
4035 ),
4036 MockLlmClient::text_response("Allowed!"),
4037 ]));
4038
4039 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4040
4041 let skill = Skill {
4043 name: "bash-skill".to_string(),
4044 description: "Bash skill".to_string(),
4045 content: "Run bash".to_string(),
4046 allowed_tools: Some("bash(*)".to_string()),
4047 disable_model_invocation: false,
4048 kind: crate::tools::SkillKind::Instruction,
4049 };
4050
4051 let policy = PermissionPolicy::new().allow("bash(*)");
4053 let policy_lock = Arc::new(RwLock::new(policy));
4054
4055 let (event_tx, _) = tokio::sync::broadcast::channel(10);
4057 let cm = Arc::new(crate::hitl::ConfirmationManager::new(
4058 crate::hitl::ConfirmationPolicy::default(), event_tx,
4060 ));
4061
4062 let config = AgentConfig {
4063 permission_policy: Some(policy_lock),
4064 confirmation_manager: Some(cm),
4065 skill_tool_filters: vec![skill],
4066 ..Default::default()
4067 };
4068
4069 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4070 let result = agent.execute(&[], "Echo", None).await.unwrap();
4071
4072 assert_eq!(result.text, "Allowed!");
4074 }
4075
4076 #[test]
4081 fn test_agent_event_serialize_start() {
4082 let event = AgentEvent::Start {
4083 prompt: "Hello".to_string(),
4084 };
4085 let json = serde_json::to_string(&event).unwrap();
4086 assert!(json.contains("agent_start"));
4087 assert!(json.contains("Hello"));
4088 }
4089
4090 #[test]
4091 fn test_agent_event_serialize_text_delta() {
4092 let event = AgentEvent::TextDelta {
4093 text: "chunk".to_string(),
4094 };
4095 let json = serde_json::to_string(&event).unwrap();
4096 assert!(json.contains("text_delta"));
4097 }
4098
4099 #[test]
4100 fn test_agent_event_serialize_tool_start() {
4101 let event = AgentEvent::ToolStart {
4102 id: "t1".to_string(),
4103 name: "bash".to_string(),
4104 };
4105 let json = serde_json::to_string(&event).unwrap();
4106 assert!(json.contains("tool_start"));
4107 assert!(json.contains("bash"));
4108 }
4109
4110 #[test]
4111 fn test_agent_event_serialize_tool_end() {
4112 let event = AgentEvent::ToolEnd {
4113 id: "t1".to_string(),
4114 name: "bash".to_string(),
4115 output: "hello".to_string(),
4116 exit_code: 0,
4117 };
4118 let json = serde_json::to_string(&event).unwrap();
4119 assert!(json.contains("tool_end"));
4120 }
4121
4122 #[test]
4123 fn test_agent_event_serialize_error() {
4124 let event = AgentEvent::Error {
4125 message: "oops".to_string(),
4126 };
4127 let json = serde_json::to_string(&event).unwrap();
4128 assert!(json.contains("error"));
4129 assert!(json.contains("oops"));
4130 }
4131
4132 #[test]
4133 fn test_agent_event_serialize_confirmation_required() {
4134 let event = AgentEvent::ConfirmationRequired {
4135 tool_id: "t1".to_string(),
4136 tool_name: "bash".to_string(),
4137 args: serde_json::json!({"cmd": "rm"}),
4138 timeout_ms: 30000,
4139 };
4140 let json = serde_json::to_string(&event).unwrap();
4141 assert!(json.contains("confirmation_required"));
4142 }
4143
4144 #[test]
4145 fn test_agent_event_serialize_confirmation_received() {
4146 let event = AgentEvent::ConfirmationReceived {
4147 tool_id: "t1".to_string(),
4148 approved: true,
4149 reason: Some("safe".to_string()),
4150 };
4151 let json = serde_json::to_string(&event).unwrap();
4152 assert!(json.contains("confirmation_received"));
4153 }
4154
4155 #[test]
4156 fn test_agent_event_serialize_confirmation_timeout() {
4157 let event = AgentEvent::ConfirmationTimeout {
4158 tool_id: "t1".to_string(),
4159 action_taken: "rejected".to_string(),
4160 };
4161 let json = serde_json::to_string(&event).unwrap();
4162 assert!(json.contains("confirmation_timeout"));
4163 }
4164
4165 #[test]
4166 fn test_agent_event_serialize_external_task_pending() {
4167 let event = AgentEvent::ExternalTaskPending {
4168 task_id: "task-1".to_string(),
4169 session_id: "sess-1".to_string(),
4170 lane: crate::hitl::SessionLane::Execute,
4171 command_type: "bash".to_string(),
4172 payload: serde_json::json!({}),
4173 timeout_ms: 60000,
4174 };
4175 let json = serde_json::to_string(&event).unwrap();
4176 assert!(json.contains("external_task_pending"));
4177 }
4178
4179 #[test]
4180 fn test_agent_event_serialize_external_task_completed() {
4181 let event = AgentEvent::ExternalTaskCompleted {
4182 task_id: "task-1".to_string(),
4183 session_id: "sess-1".to_string(),
4184 success: false,
4185 };
4186 let json = serde_json::to_string(&event).unwrap();
4187 assert!(json.contains("external_task_completed"));
4188 }
4189
4190 #[test]
4191 fn test_agent_event_serialize_permission_denied() {
4192 let event = AgentEvent::PermissionDenied {
4193 tool_id: "t1".to_string(),
4194 tool_name: "bash".to_string(),
4195 args: serde_json::json!({}),
4196 reason: "denied".to_string(),
4197 };
4198 let json = serde_json::to_string(&event).unwrap();
4199 assert!(json.contains("permission_denied"));
4200 }
4201
4202 #[test]
4203 fn test_agent_event_serialize_context_compacted() {
4204 let event = AgentEvent::ContextCompacted {
4205 session_id: "sess-1".to_string(),
4206 before_messages: 100,
4207 after_messages: 20,
4208 percent_before: 0.85,
4209 };
4210 let json = serde_json::to_string(&event).unwrap();
4211 assert!(json.contains("context_compacted"));
4212 }
4213
4214 #[test]
4215 fn test_agent_event_serialize_turn_start() {
4216 let event = AgentEvent::TurnStart { turn: 3 };
4217 let json = serde_json::to_string(&event).unwrap();
4218 assert!(json.contains("turn_start"));
4219 }
4220
4221 #[test]
4222 fn test_agent_event_serialize_turn_end() {
4223 let event = AgentEvent::TurnEnd {
4224 turn: 3,
4225 usage: TokenUsage::default(),
4226 };
4227 let json = serde_json::to_string(&event).unwrap();
4228 assert!(json.contains("turn_end"));
4229 }
4230
4231 #[test]
4232 fn test_agent_event_serialize_end() {
4233 let event = AgentEvent::End {
4234 text: "Done".to_string(),
4235 usage: TokenUsage {
4236 prompt_tokens: 100,
4237 completion_tokens: 50,
4238 total_tokens: 150,
4239 cache_read_tokens: None,
4240 cache_write_tokens: None,
4241 },
4242 };
4243 let json = serde_json::to_string(&event).unwrap();
4244 assert!(json.contains("agent_end"));
4245 }
4246
4247 #[test]
4252 fn test_agent_result_fields() {
4253 let result = AgentResult {
4254 text: "output".to_string(),
4255 messages: vec![Message::user("hello")],
4256 usage: TokenUsage::default(),
4257 tool_calls_count: 3,
4258 };
4259 assert_eq!(result.text, "output");
4260 assert_eq!(result.messages.len(), 1);
4261 assert_eq!(result.tool_calls_count, 3);
4262 }
4263
4264 #[test]
4269 fn test_agent_event_serialize_context_resolving() {
4270 let event = AgentEvent::ContextResolving {
4271 providers: vec!["provider1".to_string(), "provider2".to_string()],
4272 };
4273 let json = serde_json::to_string(&event).unwrap();
4274 assert!(json.contains("context_resolving"));
4275 assert!(json.contains("provider1"));
4276 }
4277
4278 #[test]
4279 fn test_agent_event_serialize_context_resolved() {
4280 let event = AgentEvent::ContextResolved {
4281 total_items: 5,
4282 total_tokens: 1000,
4283 };
4284 let json = serde_json::to_string(&event).unwrap();
4285 assert!(json.contains("context_resolved"));
4286 assert!(json.contains("1000"));
4287 }
4288
4289 #[test]
4290 fn test_agent_event_serialize_command_dead_lettered() {
4291 let event = AgentEvent::CommandDeadLettered {
4292 command_id: "cmd-1".to_string(),
4293 command_type: "bash".to_string(),
4294 lane: "execute".to_string(),
4295 error: "timeout".to_string(),
4296 attempts: 3,
4297 };
4298 let json = serde_json::to_string(&event).unwrap();
4299 assert!(json.contains("command_dead_lettered"));
4300 assert!(json.contains("cmd-1"));
4301 }
4302
4303 #[test]
4304 fn test_agent_event_serialize_command_retry() {
4305 let event = AgentEvent::CommandRetry {
4306 command_id: "cmd-2".to_string(),
4307 command_type: "read".to_string(),
4308 lane: "query".to_string(),
4309 attempt: 2,
4310 delay_ms: 1000,
4311 };
4312 let json = serde_json::to_string(&event).unwrap();
4313 assert!(json.contains("command_retry"));
4314 assert!(json.contains("cmd-2"));
4315 }
4316
4317 #[test]
4318 fn test_agent_event_serialize_queue_alert() {
4319 let event = AgentEvent::QueueAlert {
4320 level: "warning".to_string(),
4321 alert_type: "depth".to_string(),
4322 message: "Queue depth exceeded".to_string(),
4323 };
4324 let json = serde_json::to_string(&event).unwrap();
4325 assert!(json.contains("queue_alert"));
4326 assert!(json.contains("warning"));
4327 }
4328
4329 #[test]
4330 fn test_agent_event_serialize_task_updated() {
4331 let event = AgentEvent::TaskUpdated {
4332 session_id: "sess-1".to_string(),
4333 tasks: vec![],
4334 };
4335 let json = serde_json::to_string(&event).unwrap();
4336 assert!(json.contains("task_updated"));
4337 assert!(json.contains("sess-1"));
4338 }
4339
4340 #[test]
4341 fn test_agent_event_serialize_memory_stored() {
4342 let event = AgentEvent::MemoryStored {
4343 memory_id: "mem-1".to_string(),
4344 memory_type: "conversation".to_string(),
4345 importance: 0.8,
4346 tags: vec!["important".to_string()],
4347 };
4348 let json = serde_json::to_string(&event).unwrap();
4349 assert!(json.contains("memory_stored"));
4350 assert!(json.contains("mem-1"));
4351 }
4352
4353 #[test]
4354 fn test_agent_event_serialize_memory_recalled() {
4355 let event = AgentEvent::MemoryRecalled {
4356 memory_id: "mem-2".to_string(),
4357 content: "Previous conversation".to_string(),
4358 relevance: 0.9,
4359 };
4360 let json = serde_json::to_string(&event).unwrap();
4361 assert!(json.contains("memory_recalled"));
4362 assert!(json.contains("mem-2"));
4363 }
4364
4365 #[test]
4366 fn test_agent_event_serialize_memories_searched() {
4367 let event = AgentEvent::MemoriesSearched {
4368 query: Some("search term".to_string()),
4369 tags: vec!["tag1".to_string()],
4370 result_count: 5,
4371 };
4372 let json = serde_json::to_string(&event).unwrap();
4373 assert!(json.contains("memories_searched"));
4374 assert!(json.contains("search term"));
4375 }
4376
4377 #[test]
4378 fn test_agent_event_serialize_memory_cleared() {
4379 let event = AgentEvent::MemoryCleared {
4380 tier: "short_term".to_string(),
4381 count: 10,
4382 };
4383 let json = serde_json::to_string(&event).unwrap();
4384 assert!(json.contains("memory_cleared"));
4385 assert!(json.contains("short_term"));
4386 }
4387
4388 #[test]
4389 fn test_agent_event_serialize_subagent_start() {
4390 let event = AgentEvent::SubagentStart {
4391 task_id: "task-1".to_string(),
4392 session_id: "child-sess".to_string(),
4393 parent_session_id: "parent-sess".to_string(),
4394 agent: "explore".to_string(),
4395 description: "Explore codebase".to_string(),
4396 };
4397 let json = serde_json::to_string(&event).unwrap();
4398 assert!(json.contains("subagent_start"));
4399 assert!(json.contains("explore"));
4400 }
4401
4402 #[test]
4403 fn test_agent_event_serialize_subagent_progress() {
4404 let event = AgentEvent::SubagentProgress {
4405 task_id: "task-1".to_string(),
4406 session_id: "child-sess".to_string(),
4407 status: "processing".to_string(),
4408 metadata: serde_json::json!({"progress": 50}),
4409 };
4410 let json = serde_json::to_string(&event).unwrap();
4411 assert!(json.contains("subagent_progress"));
4412 assert!(json.contains("processing"));
4413 }
4414
4415 #[test]
4416 fn test_agent_event_serialize_subagent_end() {
4417 let event = AgentEvent::SubagentEnd {
4418 task_id: "task-1".to_string(),
4419 session_id: "child-sess".to_string(),
4420 agent: "explore".to_string(),
4421 output: "Found 10 files".to_string(),
4422 success: true,
4423 };
4424 let json = serde_json::to_string(&event).unwrap();
4425 assert!(json.contains("subagent_end"));
4426 assert!(json.contains("Found 10 files"));
4427 }
4428
4429 #[test]
4430 fn test_agent_event_serialize_planning_start() {
4431 let event = AgentEvent::PlanningStart {
4432 prompt: "Build a web app".to_string(),
4433 };
4434 let json = serde_json::to_string(&event).unwrap();
4435 assert!(json.contains("planning_start"));
4436 assert!(json.contains("Build a web app"));
4437 }
4438
4439 #[test]
4440 fn test_agent_event_serialize_planning_end() {
4441 use crate::planning::{Complexity, ExecutionPlan};
4442 let plan = ExecutionPlan::new("Test goal".to_string(), Complexity::Simple);
4443 let event = AgentEvent::PlanningEnd {
4444 plan,
4445 estimated_steps: 3,
4446 };
4447 let json = serde_json::to_string(&event).unwrap();
4448 assert!(json.contains("planning_end"));
4449 assert!(json.contains("estimated_steps"));
4450 }
4451
4452 #[test]
4453 fn test_agent_event_serialize_step_start() {
4454 let event = AgentEvent::StepStart {
4455 step_id: "step-1".to_string(),
4456 description: "Initialize project".to_string(),
4457 step_number: 1,
4458 total_steps: 5,
4459 };
4460 let json = serde_json::to_string(&event).unwrap();
4461 assert!(json.contains("step_start"));
4462 assert!(json.contains("Initialize project"));
4463 }
4464
4465 #[test]
4466 fn test_agent_event_serialize_step_end() {
4467 let event = AgentEvent::StepEnd {
4468 step_id: "step-1".to_string(),
4469 status: TaskStatus::Completed,
4470 step_number: 1,
4471 total_steps: 5,
4472 };
4473 let json = serde_json::to_string(&event).unwrap();
4474 assert!(json.contains("step_end"));
4475 assert!(json.contains("step-1"));
4476 }
4477
4478 #[test]
4479 fn test_agent_event_serialize_goal_extracted() {
4480 use crate::planning::AgentGoal;
4481 let goal = AgentGoal::new("Complete the task".to_string());
4482 let event = AgentEvent::GoalExtracted { goal };
4483 let json = serde_json::to_string(&event).unwrap();
4484 assert!(json.contains("goal_extracted"));
4485 }
4486
4487 #[test]
4488 fn test_agent_event_serialize_goal_progress() {
4489 let event = AgentEvent::GoalProgress {
4490 goal: "Build app".to_string(),
4491 progress: 0.5,
4492 completed_steps: 2,
4493 total_steps: 4,
4494 };
4495 let json = serde_json::to_string(&event).unwrap();
4496 assert!(json.contains("goal_progress"));
4497 assert!(json.contains("0.5"));
4498 }
4499
4500 #[test]
4501 fn test_agent_event_serialize_goal_achieved() {
4502 let event = AgentEvent::GoalAchieved {
4503 goal: "Build app".to_string(),
4504 total_steps: 4,
4505 duration_ms: 5000,
4506 };
4507 let json = serde_json::to_string(&event).unwrap();
4508 assert!(json.contains("goal_achieved"));
4509 assert!(json.contains("5000"));
4510 }
4511
4512 #[tokio::test]
4513 async fn test_extract_goal_with_json_response() {
4514 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4516 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4517 )]));
4518 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4519 let agent = AgentLoop::new(
4520 mock_client,
4521 tool_executor,
4522 test_tool_context(),
4523 AgentConfig::default(),
4524 );
4525
4526 let goal = agent.extract_goal("Build a web app").await.unwrap();
4527 assert_eq!(goal.description, "Build web app");
4528 assert_eq!(goal.success_criteria.len(), 2);
4529 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4530 }
4531
4532 #[tokio::test]
4533 async fn test_extract_goal_fallback_on_non_json() {
4534 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4536 "Some non-JSON response",
4537 )]));
4538 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4539 let agent = AgentLoop::new(
4540 mock_client,
4541 tool_executor,
4542 test_tool_context(),
4543 AgentConfig::default(),
4544 );
4545
4546 let goal = agent.extract_goal("Do something").await.unwrap();
4547 assert_eq!(goal.description, "Do something");
4549 assert_eq!(goal.success_criteria.len(), 2);
4551 }
4552
4553 #[tokio::test]
4554 async fn test_check_goal_achievement_json_yes() {
4555 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4556 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4557 )]));
4558 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4559 let agent = AgentLoop::new(
4560 mock_client,
4561 tool_executor,
4562 test_tool_context(),
4563 AgentConfig::default(),
4564 );
4565
4566 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4567 let achieved = agent
4568 .check_goal_achievement(&goal, "All done")
4569 .await
4570 .unwrap();
4571 assert!(achieved);
4572 }
4573
4574 #[tokio::test]
4575 async fn test_check_goal_achievement_fallback_not_done() {
4576 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4578 "invalid json",
4579 )]));
4580 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4581 let agent = AgentLoop::new(
4582 mock_client,
4583 tool_executor,
4584 test_tool_context(),
4585 AgentConfig::default(),
4586 );
4587
4588 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4589 let achieved = agent
4591 .check_goal_achievement(&goal, "still working")
4592 .await
4593 .unwrap();
4594 assert!(!achieved);
4595 }
4596
4597 #[test]
4602 fn test_build_augmented_system_prompt_empty_context() {
4603 let mock_client = Arc::new(MockLlmClient::new(vec![]));
4604 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4605 let config = AgentConfig {
4606 system_prompt: Some("Base prompt".to_string()),
4607 ..Default::default()
4608 };
4609 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4610
4611 let result = agent.build_augmented_system_prompt(&[]);
4612 assert_eq!(result, Some("Base prompt".to_string()));
4613 }
4614
4615 #[test]
4616 fn test_build_augmented_system_prompt_no_system_prompt() {
4617 let mock_client = Arc::new(MockLlmClient::new(vec![]));
4618 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4619 let agent = AgentLoop::new(
4620 mock_client,
4621 tool_executor,
4622 test_tool_context(),
4623 AgentConfig::default(),
4624 );
4625
4626 let result = agent.build_augmented_system_prompt(&[]);
4627 assert_eq!(result, None);
4628 }
4629
4630 #[test]
4631 fn test_build_augmented_system_prompt_with_context_no_base() {
4632 use crate::context::{ContextItem, ContextResult, ContextType};
4633
4634 let mock_client = Arc::new(MockLlmClient::new(vec![]));
4635 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4636 let agent = AgentLoop::new(
4637 mock_client,
4638 tool_executor,
4639 test_tool_context(),
4640 AgentConfig::default(),
4641 );
4642
4643 let context = vec![ContextResult {
4644 provider: "test".to_string(),
4645 items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
4646 total_tokens: 10,
4647 truncated: false,
4648 }];
4649
4650 let result = agent.build_augmented_system_prompt(&context);
4651 assert!(result.is_some());
4652 let text = result.unwrap();
4653 assert!(text.contains("<context"));
4654 assert!(text.contains("Content"));
4655 }
4656
4657 #[test]
4662 fn test_agent_result_clone() {
4663 let result = AgentResult {
4664 text: "output".to_string(),
4665 messages: vec![Message::user("hello")],
4666 usage: TokenUsage::default(),
4667 tool_calls_count: 3,
4668 };
4669 let cloned = result.clone();
4670 assert_eq!(cloned.text, result.text);
4671 assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
4672 }
4673
4674 #[test]
4675 fn test_agent_result_debug() {
4676 let result = AgentResult {
4677 text: "output".to_string(),
4678 messages: vec![Message::user("hello")],
4679 usage: TokenUsage::default(),
4680 tool_calls_count: 3,
4681 };
4682 let debug = format!("{:?}", result);
4683 assert!(debug.contains("AgentResult"));
4684 assert!(debug.contains("output"));
4685 }
4686
4687 #[test]
4692 fn test_handle_post_execution_metadata_no_metadata() {
4693 let mut system = Some("base prompt".to_string());
4694 let result = AgentLoop::handle_post_execution_metadata(&None, &mut system, None);
4695 assert!(result.is_none());
4696 assert_eq!(system.as_deref(), Some("base prompt"));
4697 }
4698
4699 #[test]
4700 fn test_handle_post_execution_metadata_no_load_skill_key() {
4701 let mut system = Some("base prompt".to_string());
4702 let meta = Some(serde_json::json!({"other": "value"}));
4703 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4704 assert!(result.is_none());
4705 assert_eq!(system.as_deref(), Some("base prompt"));
4706 }
4707
4708 #[test]
4709 fn test_handle_post_execution_metadata_load_skill_false() {
4710 let mut system = Some("base prompt".to_string());
4711 let meta = Some(serde_json::json!({"_load_skill": false}));
4712 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4713 assert!(result.is_none());
4714 }
4715
4716 #[test]
4717 fn test_handle_post_execution_metadata_invalid_skill_content() {
4718 let mut system = Some("base prompt".to_string());
4719 let meta = Some(serde_json::json!({
4720 "_load_skill": true,
4721 "skill_name": "bad.md",
4722 "skill_content": "not a valid skill",
4723 }));
4724 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4725 assert!(result.is_none());
4726 assert_eq!(system.as_deref(), Some("base prompt"));
4727 }
4728
4729 #[test]
4730 fn test_handle_post_execution_metadata_valid_skill() {
4731 let mut system = Some("base prompt".to_string());
4732 let skill_content =
4733 "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things.";
4734 let meta = Some(serde_json::json!({
4735 "_load_skill": true,
4736 "skill_name": "test-skill.md",
4737 "skill_content": skill_content,
4738 }));
4739 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4740 assert!(result.is_some());
4741 let xml = result.unwrap();
4742 assert!(xml.contains("<skill name=\"test-skill\">"));
4743 assert!(xml.contains("# Instructions\nDo things."));
4744
4745 let sys = system.unwrap();
4747 assert!(sys.starts_with("base prompt"));
4748 assert!(sys.contains("<skills>"));
4749 assert!(sys.contains("</skills>"));
4750 }
4751
4752 #[test]
4753 fn test_handle_post_execution_metadata_none_system_prompt() {
4754 let mut system: Option<String> = None;
4755 let skill_content = "---\nname: my-skill\ndescription: desc\n---\nContent here";
4756 let meta = Some(serde_json::json!({
4757 "_load_skill": true,
4758 "skill_name": "my-skill.md",
4759 "skill_content": skill_content,
4760 }));
4761 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4762 assert!(result.is_some());
4763
4764 let sys = system.unwrap();
4766 assert!(sys.contains("<skill name=\"my-skill\">"));
4767 assert!(sys.contains("Content here"));
4768 }
4769
4770 #[test]
4771 fn test_handle_post_execution_metadata_tool_kind_injects_xml() {
4772 let mut system = Some("base".to_string());
4773 let skill_content =
4774 "---\nname: tool-skill\nkind: tool\ndescription: A tool\n---\nTool instructions.";
4775 let meta = Some(serde_json::json!({
4776 "_load_skill": true,
4777 "skill_name": "tool-skill",
4778 "skill_content": skill_content,
4779 }));
4780 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4781 assert!(result.is_some());
4782 let xml = result.unwrap();
4783 assert!(xml.contains("<skill name=\"tool-skill\">"));
4784 assert!(xml.contains("Tool instructions."));
4785 }
4786
4787 #[test]
4788 fn test_handle_post_execution_metadata_agent_kind_returns_none() {
4789 let mut system = Some("base".to_string());
4790 let skill_content =
4791 "---\nname: agent-skill\nkind: agent\ndescription: An agent\n---\nAgent def.";
4792 let meta = Some(serde_json::json!({
4793 "_load_skill": true,
4794 "skill_name": "agent-skill",
4795 "skill_content": skill_content,
4796 }));
4797 let result = AgentLoop::handle_post_execution_metadata(&meta, &mut system, None);
4798 assert!(result.is_none());
4800 assert_eq!(system.as_deref(), Some("base"));
4802 }
4803
4804 #[test]
4809 fn test_partition_by_lane_query_tools() {
4810 let tool_calls = vec![
4811 ToolCall {
4812 id: "t1".to_string(),
4813 name: "read".to_string(),
4814 args: serde_json::json!({"file": "a.rs"}),
4815 },
4816 ToolCall {
4817 id: "t2".to_string(),
4818 name: "glob".to_string(),
4819 args: serde_json::json!({"pattern": "**/*.rs"}),
4820 },
4821 ToolCall {
4822 id: "t3".to_string(),
4823 name: "grep".to_string(),
4824 args: serde_json::json!({"pattern": "fn main"}),
4825 },
4826 ToolCall {
4827 id: "t4".to_string(),
4828 name: "ls".to_string(),
4829 args: serde_json::json!({"path": "/tmp"}),
4830 },
4831 ToolCall {
4832 id: "t5".to_string(),
4833 name: "search".to_string(),
4834 args: serde_json::json!({"query": "error"}),
4835 },
4836 ToolCall {
4837 id: "t6".to_string(),
4838 name: "list_files".to_string(),
4839 args: serde_json::json!({}),
4840 },
4841 ];
4842
4843 let (query, sequential) = partition_by_lane(&tool_calls);
4844 assert_eq!(
4845 query.len(),
4846 6,
4847 "all read-only tools should be in query lane"
4848 );
4849 assert_eq!(sequential.len(), 0);
4850 }
4851
4852 #[test]
4853 fn test_partition_by_lane_execute_tools() {
4854 let tool_calls = vec![
4855 ToolCall {
4856 id: "t1".to_string(),
4857 name: "bash".to_string(),
4858 args: serde_json::json!({"command": "ls"}),
4859 },
4860 ToolCall {
4861 id: "t2".to_string(),
4862 name: "write".to_string(),
4863 args: serde_json::json!({"file": "a.rs", "content": ""}),
4864 },
4865 ToolCall {
4866 id: "t3".to_string(),
4867 name: "edit".to_string(),
4868 args: serde_json::json!({}),
4869 },
4870 ToolCall {
4871 id: "t4".to_string(),
4872 name: "delete".to_string(),
4873 args: serde_json::json!({}),
4874 },
4875 ];
4876
4877 let (query, sequential) = partition_by_lane(&tool_calls);
4878 assert_eq!(query.len(), 0);
4879 assert_eq!(sequential.len(), 4, "all write tools should be sequential");
4880 }
4881
4882 #[test]
4883 fn test_partition_by_lane_mixed() {
4884 let tool_calls = vec![
4885 ToolCall {
4886 id: "t1".to_string(),
4887 name: "read".to_string(),
4888 args: serde_json::json!({"file": "a.rs"}),
4889 },
4890 ToolCall {
4891 id: "t2".to_string(),
4892 name: "bash".to_string(),
4893 args: serde_json::json!({"command": "cargo build"}),
4894 },
4895 ToolCall {
4896 id: "t3".to_string(),
4897 name: "glob".to_string(),
4898 args: serde_json::json!({"pattern": "*.rs"}),
4899 },
4900 ToolCall {
4901 id: "t4".to_string(),
4902 name: "write".to_string(),
4903 args: serde_json::json!({"file": "b.rs", "content": ""}),
4904 },
4905 ToolCall {
4906 id: "t5".to_string(),
4907 name: "grep".to_string(),
4908 args: serde_json::json!({"pattern": "test"}),
4909 },
4910 ];
4911
4912 let (query, sequential) = partition_by_lane(&tool_calls);
4913 assert_eq!(query.len(), 3, "read/glob/grep → Query");
4914 assert_eq!(sequential.len(), 2, "bash/write → Sequential");
4915
4916 assert_eq!(query[0].name, "read");
4918 assert_eq!(query[1].name, "glob");
4919 assert_eq!(query[2].name, "grep");
4920 assert_eq!(sequential[0].name, "bash");
4921 assert_eq!(sequential[1].name, "write");
4922 }
4923
4924 #[test]
4925 fn test_partition_by_lane_empty() {
4926 let tool_calls: Vec<ToolCall> = vec![];
4927 let (query, sequential) = partition_by_lane(&tool_calls);
4928 assert!(query.is_empty());
4929 assert!(sequential.is_empty());
4930 }
4931
4932 #[test]
4933 fn test_partition_by_lane_unknown_tool_goes_sequential() {
4934 let tool_calls = vec![ToolCall {
4936 id: "t1".to_string(),
4937 name: "custom_tool".to_string(),
4938 args: serde_json::json!({}),
4939 }];
4940
4941 let (query, sequential) = partition_by_lane(&tool_calls);
4942 assert_eq!(query.len(), 0);
4943 assert_eq!(sequential.len(), 1);
4944 }
4945
4946 #[tokio::test]
4951 async fn test_tool_command_command_type() {
4952 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4953 let cmd = ToolCommand {
4954 tool_executor: executor,
4955 tool_name: "read".to_string(),
4956 tool_args: serde_json::json!({"file": "test.rs"}),
4957 tool_context: test_tool_context(),
4958 };
4959 assert_eq!(cmd.command_type(), "read");
4960 }
4961
4962 #[tokio::test]
4963 async fn test_tool_command_payload() {
4964 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4965 let args = serde_json::json!({"file": "test.rs", "offset": 10});
4966 let cmd = ToolCommand {
4967 tool_executor: executor,
4968 tool_name: "read".to_string(),
4969 tool_args: args.clone(),
4970 tool_context: test_tool_context(),
4971 };
4972 assert_eq!(cmd.payload(), args);
4973 }
4974
4975 #[tokio::test(flavor = "multi_thread")]
4980 async fn test_agent_loop_with_queue() {
4981 use tokio::sync::broadcast;
4982
4983 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4984 "Hello",
4985 )]));
4986 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4987 let config = AgentConfig::default();
4988
4989 let (event_tx, _) = broadcast::channel(100);
4990 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
4991 .await
4992 .unwrap();
4993
4994 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
4995 .with_queue(Arc::new(queue));
4996
4997 assert!(agent.command_queue.is_some());
4998 }
4999
5000 #[tokio::test]
5001 async fn test_agent_loop_without_queue() {
5002 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5003 "Hello",
5004 )]));
5005 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5006 let config = AgentConfig::default();
5007
5008 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5009
5010 assert!(agent.command_queue.is_none());
5011 }
5012
5013 #[tokio::test]
5018 async fn test_execute_plan_parallel_independent() {
5019 use crate::planning::{Complexity, ExecutionPlan, Task};
5020
5021 let mock_client = Arc::new(MockLlmClient::new(vec![
5024 MockLlmClient::text_response("Step 1 done"),
5025 MockLlmClient::text_response("Step 2 done"),
5026 MockLlmClient::text_response("Step 3 done"),
5027 ]));
5028
5029 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5030 let config = AgentConfig::default();
5031 let agent = AgentLoop::new(
5032 mock_client.clone(),
5033 tool_executor,
5034 test_tool_context(),
5035 config,
5036 );
5037
5038 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5039 plan.add_step(Task::new("s1", "First step"));
5040 plan.add_step(Task::new("s2", "Second step"));
5041 plan.add_step(Task::new("s3", "Third step"));
5042
5043 let (tx, mut rx) = mpsc::channel(100);
5044 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5045
5046 assert_eq!(result.usage.total_tokens, 45);
5048
5049 let mut step_starts = Vec::new();
5051 let mut step_ends = Vec::new();
5052 rx.close();
5053 while let Some(event) = rx.recv().await {
5054 match event {
5055 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5056 AgentEvent::StepEnd {
5057 step_id, status, ..
5058 } => {
5059 assert_eq!(status, TaskStatus::Completed);
5060 step_ends.push(step_id);
5061 }
5062 _ => {}
5063 }
5064 }
5065 assert_eq!(step_starts.len(), 3);
5066 assert_eq!(step_ends.len(), 3);
5067 }
5068
5069 #[tokio::test]
5070 async fn test_execute_plan_respects_dependencies() {
5071 use crate::planning::{Complexity, ExecutionPlan, Task};
5072
5073 let mock_client = Arc::new(MockLlmClient::new(vec![
5076 MockLlmClient::text_response("Step 1 done"),
5077 MockLlmClient::text_response("Step 2 done"),
5078 MockLlmClient::text_response("Step 3 done"),
5079 ]));
5080
5081 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5082 let config = AgentConfig::default();
5083 let agent = AgentLoop::new(
5084 mock_client.clone(),
5085 tool_executor,
5086 test_tool_context(),
5087 config,
5088 );
5089
5090 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5091 plan.add_step(Task::new("s1", "Independent A"));
5092 plan.add_step(Task::new("s2", "Independent B"));
5093 plan.add_step(
5094 Task::new("s3", "Depends on A+B")
5095 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5096 );
5097
5098 let (tx, mut rx) = mpsc::channel(100);
5099 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5100
5101 assert_eq!(result.usage.total_tokens, 45);
5103
5104 let mut events = Vec::new();
5106 rx.close();
5107 while let Some(event) = rx.recv().await {
5108 match &event {
5109 AgentEvent::StepStart { step_id, .. } => {
5110 events.push(format!("start:{}", step_id));
5111 }
5112 AgentEvent::StepEnd { step_id, .. } => {
5113 events.push(format!("end:{}", step_id));
5114 }
5115 _ => {}
5116 }
5117 }
5118
5119 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5121 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5122 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5123 assert!(
5124 s3_start > s1_end,
5125 "s3 started before s1 ended: {:?}",
5126 events
5127 );
5128 assert!(
5129 s3_start > s2_end,
5130 "s3 started before s2 ended: {:?}",
5131 events
5132 );
5133
5134 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5136 }
5137
5138 #[tokio::test]
5139 async fn test_execute_plan_handles_step_failure() {
5140 use crate::planning::{Complexity, ExecutionPlan, Task};
5141
5142 let mock_client = Arc::new(MockLlmClient::new(vec![
5152 MockLlmClient::text_response("s1 done"),
5154 MockLlmClient::text_response("s3 done"),
5155 ]));
5158
5159 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5160 let config = AgentConfig::default();
5161 let agent = AgentLoop::new(
5162 mock_client.clone(),
5163 tool_executor,
5164 test_tool_context(),
5165 config,
5166 );
5167
5168 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5169 plan.add_step(Task::new("s1", "Independent step"));
5170 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5171 plan.add_step(Task::new("s3", "Another independent"));
5172 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5173
5174 let (tx, mut rx) = mpsc::channel(100);
5175 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5176
5177 let mut completed_steps = Vec::new();
5180 let mut failed_steps = Vec::new();
5181 rx.close();
5182 while let Some(event) = rx.recv().await {
5183 if let AgentEvent::StepEnd {
5184 step_id, status, ..
5185 } = event
5186 {
5187 match status {
5188 TaskStatus::Completed => completed_steps.push(step_id),
5189 TaskStatus::Failed => failed_steps.push(step_id),
5190 _ => {}
5191 }
5192 }
5193 }
5194
5195 assert!(
5196 completed_steps.contains(&"s1".to_string()),
5197 "s1 should complete"
5198 );
5199 assert!(
5200 completed_steps.contains(&"s3".to_string()),
5201 "s3 should complete"
5202 );
5203 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5204 assert!(
5206 !completed_steps.contains(&"s4".to_string()),
5207 "s4 should not complete"
5208 );
5209 assert!(
5210 !failed_steps.contains(&"s4".to_string()),
5211 "s4 should not fail (never started)"
5212 );
5213 }
5214}