1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult, PostToolUseEvent,
16 PreToolUseEvent, TokenUsageInfo, ToolCallInfo, ToolResultData,
17};
18use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolCall, ToolDefinition};
19use crate::permissions::{PermissionChecker, PermissionDecision};
20use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
21use crate::queue::{SessionCommand, SessionLane};
22use crate::session_lane_queue::SessionLaneQueue;
23use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
24use anyhow::{Context, Result};
25use async_trait::async_trait;
26use futures::future::join_all;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use std::sync::Arc;
30use std::time::Duration;
31use tokio::sync::{mpsc, RwLock};
32
/// Default upper bound on LLM turns for one execution; exceeding it aborts
/// the loop with a "Max tool rounds (...) exceeded" error.
const MAX_TOOL_ROUNDS: usize = 50;
35
36#[derive(Clone)]
38pub struct AgentConfig {
39 pub system_prompt: Option<String>,
40 pub tools: Vec<ToolDefinition>,
41 pub max_tool_rounds: usize,
42 pub permission_checker: Option<Arc<dyn PermissionChecker>>,
44 pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
46 pub context_providers: Vec<Arc<dyn ContextProvider>>,
48 pub planning_enabled: bool,
50 pub goal_tracking: bool,
52 pub hook_engine: Option<Arc<dyn HookExecutor>>,
54 pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
56}
57
58impl std::fmt::Debug for AgentConfig {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("AgentConfig")
61 .field("system_prompt", &self.system_prompt)
62 .field("tools", &self.tools)
63 .field("max_tool_rounds", &self.max_tool_rounds)
64 .field("permission_checker", &self.permission_checker.is_some())
65 .field("confirmation_manager", &self.confirmation_manager.is_some())
66 .field("context_providers", &self.context_providers.len())
67 .field("planning_enabled", &self.planning_enabled)
68 .field("goal_tracking", &self.goal_tracking)
69 .field("hook_engine", &self.hook_engine.is_some())
70 .field("skill_registry", &self.skill_registry.as_ref().map(|r| r.len()))
71 .finish()
72 }
73}
74
75impl Default for AgentConfig {
76 fn default() -> Self {
77 Self {
78 system_prompt: None,
79 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
81 permission_checker: None,
82 confirmation_manager: None,
83 context_providers: Vec::new(),
84 planning_enabled: false,
85 goal_tracking: false,
86 hook_engine: None,
87 skill_registry: None,
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(tag = "type")]
99#[non_exhaustive]
100pub enum AgentEvent {
101 #[serde(rename = "agent_start")]
103 Start { prompt: String },
104
105 #[serde(rename = "turn_start")]
107 TurnStart { turn: usize },
108
109 #[serde(rename = "text_delta")]
111 TextDelta { text: String },
112
113 #[serde(rename = "tool_start")]
115 ToolStart { id: String, name: String },
116
117 #[serde(rename = "tool_end")]
119 ToolEnd {
120 id: String,
121 name: String,
122 output: String,
123 exit_code: i32,
124 },
125
126 #[serde(rename = "tool_output_delta")]
128 ToolOutputDelta {
129 id: String,
130 name: String,
131 delta: String,
132 },
133
134 #[serde(rename = "turn_end")]
136 TurnEnd { turn: usize, usage: TokenUsage },
137
138 #[serde(rename = "agent_end")]
140 End { text: String, usage: TokenUsage },
141
142 #[serde(rename = "error")]
144 Error { message: String },
145
146 #[serde(rename = "confirmation_required")]
148 ConfirmationRequired {
149 tool_id: String,
150 tool_name: String,
151 args: serde_json::Value,
152 timeout_ms: u64,
153 },
154
155 #[serde(rename = "confirmation_received")]
157 ConfirmationReceived {
158 tool_id: String,
159 approved: bool,
160 reason: Option<String>,
161 },
162
163 #[serde(rename = "confirmation_timeout")]
165 ConfirmationTimeout {
166 tool_id: String,
167 action_taken: String, },
169
170 #[serde(rename = "external_task_pending")]
172 ExternalTaskPending {
173 task_id: String,
174 session_id: String,
175 lane: crate::hitl::SessionLane,
176 command_type: String,
177 payload: serde_json::Value,
178 timeout_ms: u64,
179 },
180
181 #[serde(rename = "external_task_completed")]
183 ExternalTaskCompleted {
184 task_id: String,
185 session_id: String,
186 success: bool,
187 },
188
189 #[serde(rename = "permission_denied")]
191 PermissionDenied {
192 tool_id: String,
193 tool_name: String,
194 args: serde_json::Value,
195 reason: String,
196 },
197
198 #[serde(rename = "context_resolving")]
200 ContextResolving { providers: Vec<String> },
201
202 #[serde(rename = "context_resolved")]
204 ContextResolved {
205 total_items: usize,
206 total_tokens: usize,
207 },
208
209 #[serde(rename = "command_dead_lettered")]
214 CommandDeadLettered {
215 command_id: String,
216 command_type: String,
217 lane: String,
218 error: String,
219 attempts: u32,
220 },
221
222 #[serde(rename = "command_retry")]
224 CommandRetry {
225 command_id: String,
226 command_type: String,
227 lane: String,
228 attempt: u32,
229 delay_ms: u64,
230 },
231
232 #[serde(rename = "queue_alert")]
234 QueueAlert {
235 level: String,
236 alert_type: String,
237 message: String,
238 },
239
240 #[serde(rename = "task_updated")]
245 TaskUpdated {
246 session_id: String,
247 tasks: Vec<crate::planning::Task>,
248 },
249
250 #[serde(rename = "memory_stored")]
255 MemoryStored {
256 memory_id: String,
257 memory_type: String,
258 importance: f32,
259 tags: Vec<String>,
260 },
261
262 #[serde(rename = "memory_recalled")]
264 MemoryRecalled {
265 memory_id: String,
266 content: String,
267 relevance: f32,
268 },
269
270 #[serde(rename = "memories_searched")]
272 MemoriesSearched {
273 query: Option<String>,
274 tags: Vec<String>,
275 result_count: usize,
276 },
277
278 #[serde(rename = "memory_cleared")]
280 MemoryCleared {
281 tier: String, count: u64,
283 },
284
285 #[serde(rename = "subagent_start")]
290 SubagentStart {
291 task_id: String,
293 session_id: String,
295 parent_session_id: String,
297 agent: String,
299 description: String,
301 },
302
303 #[serde(rename = "subagent_progress")]
305 SubagentProgress {
306 task_id: String,
308 session_id: String,
310 status: String,
312 metadata: serde_json::Value,
314 },
315
316 #[serde(rename = "subagent_end")]
318 SubagentEnd {
319 task_id: String,
321 session_id: String,
323 agent: String,
325 output: String,
327 success: bool,
329 },
330
331 #[serde(rename = "planning_start")]
336 PlanningStart { prompt: String },
337
338 #[serde(rename = "planning_end")]
340 PlanningEnd {
341 plan: ExecutionPlan,
342 estimated_steps: usize,
343 },
344
345 #[serde(rename = "step_start")]
347 StepStart {
348 step_id: String,
349 description: String,
350 step_number: usize,
351 total_steps: usize,
352 },
353
354 #[serde(rename = "step_end")]
356 StepEnd {
357 step_id: String,
358 status: TaskStatus,
359 step_number: usize,
360 total_steps: usize,
361 },
362
363 #[serde(rename = "goal_extracted")]
365 GoalExtracted { goal: AgentGoal },
366
367 #[serde(rename = "goal_progress")]
369 GoalProgress {
370 goal: String,
371 progress: f32,
372 completed_steps: usize,
373 total_steps: usize,
374 },
375
376 #[serde(rename = "goal_achieved")]
378 GoalAchieved {
379 goal: String,
380 total_steps: usize,
381 duration_ms: i64,
382 },
383
384 #[serde(rename = "context_compacted")]
389 ContextCompacted {
390 session_id: String,
391 before_messages: usize,
392 after_messages: usize,
393 percent_before: f32,
394 },
395
396 #[serde(rename = "persistence_failed")]
401 PersistenceFailed {
402 session_id: String,
403 operation: String,
404 error: String,
405 },
406}
407
408#[derive(Debug, Clone)]
410pub struct AgentResult {
411 pub text: String,
412 pub messages: Vec<Message>,
413 pub usage: TokenUsage,
414 pub tool_calls_count: usize,
415}
416
417pub struct ToolCommand {
425 tool_executor: Arc<ToolExecutor>,
426 tool_name: String,
427 tool_args: Value,
428 tool_context: ToolContext,
429 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
430}
431
432impl ToolCommand {
433 pub fn new(
435 tool_executor: Arc<ToolExecutor>,
436 tool_name: String,
437 tool_args: Value,
438 tool_context: ToolContext,
439 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
440 ) -> Self {
441 Self {
442 tool_executor,
443 tool_name,
444 tool_args,
445 tool_context,
446 skill_registry,
447 }
448 }
449}
450
451#[async_trait]
452impl SessionCommand for ToolCommand {
453 async fn execute(&self) -> Result<Value> {
454 if let Some(registry) = &self.skill_registry {
456 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
457
458 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
460
461 if has_restrictions {
462 let mut allowed = false;
463
464 for skill in &instruction_skills {
465 if skill.is_tool_allowed(&self.tool_name) {
466 allowed = true;
467 break;
468 }
469 }
470
471 if !allowed {
472 return Err(anyhow::anyhow!(
473 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
474 self.tool_name
475 ));
476 }
477 }
478 }
479
480 let result = self
482 .tool_executor
483 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
484 .await?;
485 Ok(serde_json::json!({
486 "output": result.output,
487 "exit_code": result.exit_code,
488 "metadata": result.metadata,
489 }))
490 }
491
492 fn command_type(&self) -> &str {
493 &self.tool_name
494 }
495
496 fn payload(&self) -> Value {
497 self.tool_args.clone()
498 }
499}
500
501pub fn partition_by_lane(tool_calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
511 let mut query_tools = Vec::new();
512 let mut sequential_tools = Vec::new();
513
514 for tc in tool_calls {
515 match SessionLane::from_tool_name(&tc.name) {
516 SessionLane::Query => query_tools.push(tc.clone()),
517 _ => sequential_tools.push(tc.clone()),
518 }
519 }
520
521 (query_tools, sequential_tools)
522}
523
524#[derive(Clone)]
526pub struct AgentLoop {
527 llm_client: Arc<dyn LlmClient>,
528 tool_executor: Arc<ToolExecutor>,
529 tool_context: ToolContext,
530 config: AgentConfig,
531 tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
533 command_queue: Option<Arc<SessionLaneQueue>>,
535}
536
537impl AgentLoop {
538 pub fn new(
539 llm_client: Arc<dyn LlmClient>,
540 tool_executor: Arc<ToolExecutor>,
541 tool_context: ToolContext,
542 config: AgentConfig,
543 ) -> Self {
544 Self {
545 llm_client,
546 tool_executor,
547 tool_context,
548 config,
549 tool_metrics: None,
550 command_queue: None,
551 }
552 }
553
554 pub fn with_tool_metrics(
556 mut self,
557 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
558 ) -> Self {
559 self.tool_metrics = Some(metrics);
560 self
561 }
562
563 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
569 self.command_queue = Some(queue);
570 self
571 }
572
573 fn streaming_tool_context(
582 &self,
583 event_tx: &Option<mpsc::Sender<AgentEvent>>,
584 tool_id: &str,
585 tool_name: &str,
586 ) -> ToolContext {
587 let mut ctx = self.tool_context.clone();
588 if let Some(agent_tx) = event_tx {
589 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
590 ctx.event_tx = Some(tool_tx);
591
592 let agent_tx = agent_tx.clone();
593 let tool_id = tool_id.to_string();
594 let tool_name = tool_name.to_string();
595 tokio::spawn(async move {
596 while let Some(event) = tool_rx.recv().await {
597 match event {
598 ToolStreamEvent::OutputDelta(delta) => {
599 agent_tx
600 .send(AgentEvent::ToolOutputDelta {
601 id: tool_id.clone(),
602 name: tool_name.clone(),
603 delta,
604 })
605 .await
606 .ok();
607 }
608 }
609 }
610 });
611 }
612 ctx
613 }
614
615 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
619 if self.config.context_providers.is_empty() {
620 return Vec::new();
621 }
622
623 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
624
625 let futures = self
626 .config
627 .context_providers
628 .iter()
629 .map(|p| p.query(&query));
630 let outcomes = join_all(futures).await;
631
632 outcomes
633 .into_iter()
634 .enumerate()
635 .filter_map(|(i, r)| match r {
636 Ok(result) if !result.is_empty() => Some(result),
637 Ok(_) => None,
638 Err(e) => {
639 tracing::warn!(
640 "Context provider '{}' failed: {}",
641 self.config.context_providers[i].name(),
642 e
643 );
644 None
645 }
646 })
647 .collect()
648 }
649
650 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
652 if context_results.is_empty() {
653 return self.config.system_prompt.clone();
654 }
655
656 let context_xml: String = context_results
658 .iter()
659 .map(|r| r.to_xml())
660 .collect::<Vec<_>>()
661 .join("\n\n");
662
663 match &self.config.system_prompt {
665 Some(system) => Some(format!("{}\n\n{}", system, context_xml)),
666 None => Some(context_xml),
667 }
668 }
669
670 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
672 let futures = self
673 .config
674 .context_providers
675 .iter()
676 .map(|p| p.on_turn_complete(session_id, prompt, response));
677 let outcomes = join_all(futures).await;
678
679 for (i, result) in outcomes.into_iter().enumerate() {
680 if let Err(e) = result {
681 tracing::warn!(
682 "Context provider '{}' on_turn_complete failed: {}",
683 self.config.context_providers[i].name(),
684 e
685 );
686 }
687 }
688 }
689
690 async fn fire_pre_tool_use(
693 &self,
694 session_id: &str,
695 tool_name: &str,
696 args: &serde_json::Value,
697 ) -> Option<HookResult> {
698 if let Some(he) = &self.config.hook_engine {
699 let event = HookEvent::PreToolUse(PreToolUseEvent {
700 session_id: session_id.to_string(),
701 tool: tool_name.to_string(),
702 args: args.clone(),
703 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
704 recent_tools: Vec::new(),
705 });
706 let result = he.fire(&event).await;
707 if result.is_block() {
708 return Some(result);
709 }
710 }
711 None
712 }
713
714 async fn fire_post_tool_use(
716 &self,
717 session_id: &str,
718 tool_name: &str,
719 args: &serde_json::Value,
720 output: &str,
721 success: bool,
722 duration_ms: u64,
723 ) {
724 if let Some(he) = &self.config.hook_engine {
725 let event = HookEvent::PostToolUse(PostToolUseEvent {
726 session_id: session_id.to_string(),
727 tool: tool_name.to_string(),
728 args: args.clone(),
729 result: ToolResultData {
730 success,
731 output: output.to_string(),
732 exit_code: if success { Some(0) } else { Some(1) },
733 duration_ms,
734 },
735 });
736 let _ = he.fire(&event).await;
737 }
738 }
739
740 async fn fire_generate_start(
742 &self,
743 session_id: &str,
744 prompt: &str,
745 system_prompt: &Option<String>,
746 ) {
747 if let Some(he) = &self.config.hook_engine {
748 let event = HookEvent::GenerateStart(GenerateStartEvent {
749 session_id: session_id.to_string(),
750 prompt: prompt.to_string(),
751 system_prompt: system_prompt.clone(),
752 model_provider: String::new(),
753 model_name: String::new(),
754 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
755 });
756 let _ = he.fire(&event).await;
757 }
758 }
759
760 async fn fire_generate_end(
762 &self,
763 session_id: &str,
764 prompt: &str,
765 response: &LlmResponse,
766 duration_ms: u64,
767 ) {
768 if let Some(he) = &self.config.hook_engine {
769 let tool_calls: Vec<ToolCallInfo> = response
770 .tool_calls()
771 .iter()
772 .map(|tc| ToolCallInfo {
773 name: tc.name.clone(),
774 args: tc.args.clone(),
775 })
776 .collect();
777
778 let event = HookEvent::GenerateEnd(GenerateEndEvent {
779 session_id: session_id.to_string(),
780 prompt: prompt.to_string(),
781 response_text: response.text().to_string(),
782 tool_calls,
783 usage: TokenUsageInfo {
784 prompt_tokens: response.usage.prompt_tokens as i32,
785 completion_tokens: response.usage.completion_tokens as i32,
786 total_tokens: response.usage.total_tokens as i32,
787 },
788 duration_ms,
789 });
790 let _ = he.fire(&event).await;
791 }
792 }
793
794 pub async fn execute(
800 &self,
801 history: &[Message],
802 prompt: &str,
803 event_tx: Option<mpsc::Sender<AgentEvent>>,
804 ) -> Result<AgentResult> {
805 self.execute_with_session(history, prompt, None, event_tx)
806 .await
807 }
808
809 pub async fn execute_with_session(
814 &self,
815 history: &[Message],
816 prompt: &str,
817 session_id: Option<&str>,
818 event_tx: Option<mpsc::Sender<AgentEvent>>,
819 ) -> Result<AgentResult> {
820 tracing::info!(
821 a3s.session.id = session_id.unwrap_or("none"),
822 a3s.agent.max_turns = self.config.max_tool_rounds,
823 "a3s.agent.execute started"
824 );
825
826 let result = if self.config.planning_enabled {
828 self.execute_with_planning(history, prompt, event_tx).await
829 } else {
830 self.execute_loop(history, prompt, session_id, event_tx)
831 .await
832 };
833
834 match &result {
835 Ok(r) => tracing::info!(
836 a3s.agent.tool_calls_count = r.tool_calls_count,
837 a3s.llm.total_tokens = r.usage.total_tokens,
838 "a3s.agent.execute completed"
839 ),
840 Err(e) => tracing::warn!(
841 error = %e,
842 "a3s.agent.execute failed"
843 ),
844 }
845
846 result
847 }
848
849 async fn execute_loop(
855 &self,
856 history: &[Message],
857 prompt: &str,
858 session_id: Option<&str>,
859 event_tx: Option<mpsc::Sender<AgentEvent>>,
860 ) -> Result<AgentResult> {
861 let mut messages = history.to_vec();
862 let mut total_usage = TokenUsage::default();
863 let mut tool_calls_count = 0;
864 let mut turn = 0;
865
866 if let Some(tx) = &event_tx {
868 tx.send(AgentEvent::Start {
869 prompt: prompt.to_string(),
870 })
871 .await
872 .ok();
873 }
874
875 let mut augmented_system = if !self.config.context_providers.is_empty() {
877 if let Some(tx) = &event_tx {
879 let provider_names: Vec<String> = self
880 .config
881 .context_providers
882 .iter()
883 .map(|p| p.name().to_string())
884 .collect();
885 tx.send(AgentEvent::ContextResolving {
886 providers: provider_names,
887 })
888 .await
889 .ok();
890 }
891
892 tracing::info!(
893 a3s.context.providers = self.config.context_providers.len() as i64,
894 "Context resolution started"
895 );
896 let context_results = self.resolve_context(prompt, session_id).await;
897
898 if let Some(tx) = &event_tx {
900 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
901 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
902
903 tracing::info!(
904 context_items = total_items,
905 context_tokens = total_tokens,
906 "Context resolution completed"
907 );
908
909 tx.send(AgentEvent::ContextResolved {
910 total_items,
911 total_tokens,
912 })
913 .await
914 .ok();
915 }
916
917 self.build_augmented_system_prompt(&context_results)
918 } else {
919 self.config.system_prompt.clone()
920 };
921
922 messages.push(Message::user(prompt));
924
925 loop {
926 turn += 1;
927
928 if turn > self.config.max_tool_rounds {
929 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
930 if let Some(tx) = &event_tx {
931 tx.send(AgentEvent::Error {
932 message: error.clone(),
933 })
934 .await
935 .ok();
936 }
937 anyhow::bail!(error);
938 }
939
940 if let Some(tx) = &event_tx {
942 tx.send(AgentEvent::TurnStart { turn }).await.ok();
943 }
944
945 tracing::info!(
946 turn = turn,
947 max_turns = self.config.max_tool_rounds,
948 "Agent turn started"
949 );
950
951 tracing::info!(
953 a3s.llm.streaming = event_tx.is_some(),
954 "LLM completion started"
955 );
956
957 self.fire_generate_start(session_id.unwrap_or(""), prompt, &augmented_system)
959 .await;
960
961 let llm_start = std::time::Instant::now();
962 let response = if event_tx.is_some() {
963 let mut stream_rx = self
965 .llm_client
966 .complete_streaming(&messages, augmented_system.as_deref(), &self.config.tools)
967 .await
968 .context("LLM streaming call failed")?;
969
970 let mut final_response: Option<LlmResponse> = None;
971
972 while let Some(event) = stream_rx.recv().await {
973 match event {
974 crate::llm::StreamEvent::TextDelta(text) => {
975 if let Some(tx) = &event_tx {
976 tx.send(AgentEvent::TextDelta { text }).await.ok();
977 }
978 }
979 crate::llm::StreamEvent::ToolUseStart { id, name } => {
980 if let Some(tx) = &event_tx {
981 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
982 }
983 }
984 crate::llm::StreamEvent::ToolUseInputDelta(_) => {
985 }
987 crate::llm::StreamEvent::Done(resp) => {
988 final_response = Some(resp);
989 }
990 }
991 }
992
993 final_response.context("Stream ended without final response")?
994 } else {
995 self.llm_client
997 .complete(&messages, augmented_system.as_deref(), &self.config.tools)
998 .await
999 .context("LLM call failed")?
1000 };
1001
1002 total_usage.prompt_tokens += response.usage.prompt_tokens;
1004 total_usage.completion_tokens += response.usage.completion_tokens;
1005 total_usage.total_tokens += response.usage.total_tokens;
1006
1007 let llm_duration = llm_start.elapsed();
1009 tracing::info!(
1010 turn = turn,
1011 streaming = event_tx.is_some(),
1012 prompt_tokens = response.usage.prompt_tokens,
1013 completion_tokens = response.usage.completion_tokens,
1014 total_tokens = response.usage.total_tokens,
1015 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1016 duration_ms = llm_duration.as_millis() as u64,
1017 "LLM completion finished"
1018 );
1019
1020 self.fire_generate_end(
1022 session_id.unwrap_or(""),
1023 prompt,
1024 &response,
1025 llm_duration.as_millis() as u64,
1026 )
1027 .await;
1028
1029 crate::telemetry::record_llm_usage(
1031 response.usage.prompt_tokens,
1032 response.usage.completion_tokens,
1033 response.usage.total_tokens,
1034 response.stop_reason.as_deref(),
1035 );
1036 tracing::info!(
1038 turn = turn,
1039 a3s.llm.total_tokens = response.usage.total_tokens,
1040 "Turn token usage"
1041 );
1042
1043 messages.push(response.message.clone());
1045
1046 let tool_calls = response.tool_calls();
1048
1049 if let Some(tx) = &event_tx {
1051 tx.send(AgentEvent::TurnEnd {
1052 turn,
1053 usage: response.usage.clone(),
1054 })
1055 .await
1056 .ok();
1057 }
1058
1059 if tool_calls.is_empty() {
1060 let final_text = response.text();
1062
1063 tracing::info!(
1065 tool_calls_count = tool_calls_count,
1066 total_prompt_tokens = total_usage.prompt_tokens,
1067 total_completion_tokens = total_usage.completion_tokens,
1068 total_tokens = total_usage.total_tokens,
1069 turns = turn,
1070 "Agent execution completed"
1071 );
1072
1073 if let Some(tx) = &event_tx {
1074 tx.send(AgentEvent::End {
1075 text: final_text.clone(),
1076 usage: total_usage.clone(),
1077 })
1078 .await
1079 .ok();
1080 }
1081
1082 if let Some(sid) = session_id {
1084 self.notify_turn_complete(sid, prompt, &final_text).await;
1085 }
1086
1087 return Ok(AgentResult {
1088 text: final_text,
1089 messages,
1090 usage: total_usage,
1091 tool_calls_count,
1092 });
1093 }
1094
1095 let (mut query_tools, sequential_tools) = if self.command_queue.is_some() {
1097 partition_by_lane(&tool_calls)
1098 } else {
1099 (Vec::new(), tool_calls.clone())
1100 };
1101
1102 let use_parallel = self.should_use_parallel_execution(&query_tools);
1104
1105 if !query_tools.is_empty() && use_parallel {
1107 if let Some(queue) = &self.command_queue {
1108 let parallel_count = self
1109 .execute_query_tools_parallel(
1110 &query_tools,
1111 queue,
1112 &mut messages,
1113 &event_tx,
1114 &mut augmented_system,
1115 session_id,
1116 )
1117 .await;
1118 tool_calls_count += parallel_count;
1119 query_tools.clear(); }
1121 }
1122
1123 let mut all_sequential_tools = query_tools;
1125 all_sequential_tools.extend(sequential_tools);
1126
1127 for tool_call in all_sequential_tools {
1129 tool_calls_count += 1;
1130
1131 let tool_start = std::time::Instant::now();
1132
1133 tracing::info!(
1134 tool_name = tool_call.name.as_str(),
1135 tool_id = tool_call.id.as_str(),
1136 "Tool execution started"
1137 );
1138
1139 if let Some(parse_error) =
1145 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1146 {
1147 let error_msg = format!("Error: {}", parse_error);
1148 tracing::warn!(
1149 tool = tool_call.name.as_str(),
1150 "Malformed tool arguments from LLM"
1151 );
1152
1153 if let Some(tx) = &event_tx {
1154 tx.send(AgentEvent::ToolEnd {
1155 id: tool_call.id.clone(),
1156 name: tool_call.name.clone(),
1157 output: error_msg.clone(),
1158 exit_code: 1,
1159 })
1160 .await
1161 .ok();
1162 }
1163
1164 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1165 continue;
1166 }
1167
1168 if let Some(HookResult::Block(reason)) = self
1170 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1171 .await
1172 {
1173 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1174 tracing::info!(
1175 tool_name = tool_call.name.as_str(),
1176 "Tool blocked by PreToolUse hook"
1177 );
1178
1179 if let Some(tx) = &event_tx {
1180 tx.send(AgentEvent::PermissionDenied {
1181 tool_id: tool_call.id.clone(),
1182 tool_name: tool_call.name.clone(),
1183 args: tool_call.args.clone(),
1184 reason: reason.clone(),
1185 })
1186 .await
1187 .ok();
1188 }
1189
1190 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1191 continue;
1192 }
1193
1194 let permission_decision = if let Some(checker) = &self.config.permission_checker
1196 {
1197 checker.check(&tool_call.name, &tool_call.args)
1198 } else {
1199 PermissionDecision::Ask
1201 };
1202
1203 let (output, exit_code, is_error, _metadata) = match permission_decision {
1204 PermissionDecision::Deny => {
1205 tracing::info!(
1206 tool_name = tool_call.name.as_str(),
1207 permission = "deny",
1208 "Tool permission denied"
1209 );
1210 let denial_msg = format!(
1212 "Permission denied: Tool '{}' is blocked by permission policy.",
1213 tool_call.name
1214 );
1215
1216 if let Some(tx) = &event_tx {
1218 tx.send(AgentEvent::PermissionDenied {
1219 tool_id: tool_call.id.clone(),
1220 tool_name: tool_call.name.clone(),
1221 args: tool_call.args.clone(),
1222 reason: "Blocked by deny rule in permission policy".to_string(),
1223 })
1224 .await
1225 .ok();
1226 }
1227
1228 (denial_msg, 1, true, None)
1229 }
1230 PermissionDecision::Allow => {
1231 tracing::info!(
1232 tool_name = tool_call.name.as_str(),
1233 permission = "allow",
1234 "Tool permission: allow"
1235 );
1236 let stream_ctx =
1238 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1239 let result = self
1240 .tool_executor
1241 .execute_with_context(&tool_call.name, &tool_call.args, &stream_ctx)
1242 .await;
1243
1244 match result {
1245 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1246 Err(e) => (format!("Tool execution error: {}", e), 1, true, None),
1247 }
1248 }
1249 PermissionDecision::Ask => {
1250 tracing::info!(
1251 tool_name = tool_call.name.as_str(),
1252 permission = "ask",
1253 "Tool permission: ask"
1254 );
1255 if let Some(cm) = &self.config.confirmation_manager {
1257 if !cm.requires_confirmation(&tool_call.name).await {
1259 let stream_ctx = self.streaming_tool_context(
1260 &event_tx,
1261 &tool_call.id,
1262 &tool_call.name,
1263 );
1264 let result = self
1265 .tool_executor
1266 .execute_with_context(
1267 &tool_call.name,
1268 &tool_call.args,
1269 &stream_ctx,
1270 )
1271 .await;
1272
1273 let (output, exit_code, is_error, _metadata) = match result {
1274 Ok(r) => (r.output, r.exit_code, r.exit_code != 0, r.metadata),
1275 Err(e) => {
1276 (format!("Tool execution error: {}", e), 1, true, None)
1277 }
1278 };
1279
1280 messages.push(Message::tool_result(
1282 &tool_call.id,
1283 &output,
1284 is_error,
1285 ));
1286
1287 let tool_duration = tool_start.elapsed();
1289 crate::telemetry::record_tool_result(exit_code, tool_duration);
1290
1291 self.fire_post_tool_use(
1293 session_id.unwrap_or(""),
1294 &tool_call.name,
1295 &tool_call.args,
1296 &output,
1297 exit_code == 0,
1298 tool_duration.as_millis() as u64,
1299 )
1300 .await;
1301
1302 continue; }
1304
1305 let policy = cm.policy().await;
1307 let timeout_ms = policy.default_timeout_ms;
1308 let timeout_action = policy.timeout_action;
1309
1310 let rx = cm
1312 .request_confirmation(
1313 &tool_call.id,
1314 &tool_call.name,
1315 &tool_call.args,
1316 )
1317 .await;
1318
1319 let confirmation_result =
1321 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
1322
1323 match confirmation_result {
1324 Ok(Ok(response)) => {
1325 if response.approved {
1326 let stream_ctx = self.streaming_tool_context(
1327 &event_tx,
1328 &tool_call.id,
1329 &tool_call.name,
1330 );
1331 let result = self
1332 .tool_executor
1333 .execute_with_context(
1334 &tool_call.name,
1335 &tool_call.args,
1336 &stream_ctx,
1337 )
1338 .await;
1339
1340 match result {
1341 Ok(r) => (
1342 r.output,
1343 r.exit_code,
1344 r.exit_code != 0,
1345 r.metadata,
1346 ),
1347 Err(e) => (
1348 format!("Tool execution error: {}", e),
1349 1,
1350 true,
1351 None,
1352 ),
1353 }
1354 } else {
1355 let rejection_msg = format!(
1356 "Tool '{}' execution was rejected by user. Reason: {}",
1357 tool_call.name,
1358 response.reason.unwrap_or_else(|| "No reason provided".to_string())
1359 );
1360 (rejection_msg, 1, true, None)
1361 }
1362 }
1363 Ok(Err(_)) => {
1364 let msg = format!(
1365 "Tool '{}' confirmation failed: confirmation channel closed",
1366 tool_call.name
1367 );
1368 (msg, 1, true, None)
1369 }
1370 Err(_) => {
1371 cm.check_timeouts().await;
1372
1373 match timeout_action {
1374 crate::hitl::TimeoutAction::Reject => {
1375 let msg = format!(
1376 "Tool '{}' execution timed out waiting for confirmation ({}ms). Execution rejected.",
1377 tool_call.name, timeout_ms
1378 );
1379 (msg, 1, true, None)
1380 }
1381 crate::hitl::TimeoutAction::AutoApprove => {
1382 let stream_ctx = self.streaming_tool_context(
1383 &event_tx,
1384 &tool_call.id,
1385 &tool_call.name,
1386 );
1387 let result = self
1388 .tool_executor
1389 .execute_with_context(
1390 &tool_call.name,
1391 &tool_call.args,
1392 &stream_ctx,
1393 )
1394 .await;
1395
1396 match result {
1397 Ok(r) => (
1398 r.output,
1399 r.exit_code,
1400 r.exit_code != 0,
1401 r.metadata,
1402 ),
1403 Err(e) => (
1404 format!("Tool execution error: {}", e),
1405 1,
1406 true,
1407 None,
1408 ),
1409 }
1410 }
1411 }
1412 }
1413 }
1414 } else {
1415 let msg = format!(
1417 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
1418 Configure a confirmation policy to enable tool execution.",
1419 tool_call.name
1420 );
1421 tracing::warn!(
1422 tool_name = tool_call.name.as_str(),
1423 "Tool requires confirmation but no HITL manager configured"
1424 );
1425 (msg, 1, true, None)
1426 }
1427 }
1428 };
1429
1430 let tool_duration = tool_start.elapsed();
1431 crate::telemetry::record_tool_result(exit_code, tool_duration);
1432
1433 self.fire_post_tool_use(
1435 session_id.unwrap_or(""),
1436 &tool_call.name,
1437 &tool_call.args,
1438 &output,
1439 exit_code == 0,
1440 tool_duration.as_millis() as u64,
1441 )
1442 .await;
1443
1444 if let Some(tx) = &event_tx {
1446 tx.send(AgentEvent::ToolEnd {
1447 id: tool_call.id.clone(),
1448 name: tool_call.name.clone(),
1449 output: output.clone(),
1450 exit_code,
1451 })
1452 .await
1453 .ok();
1454 }
1455
1456 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1458 }
1459 }
1460 }
1461
1462 fn should_use_parallel_execution(&self, query_tools: &[ToolCall]) -> bool {
1469 if self.command_queue.is_none() {
1471 tracing::debug!("Parallel execution bypassed: no queue configured");
1472 return false;
1473 }
1474
1475 if query_tools.len() < 8 {
1478 tracing::info!(
1479 tool_count = query_tools.len(),
1480 "Parallel execution bypassed: too few tools (< 8)"
1481 );
1482 return false;
1483 }
1484
1485 let all_fast_ops = query_tools.iter().all(|t| {
1487 matches!(t.name.as_str(), "glob" | "ls" | "list_files")
1488 });
1489
1490 if all_fast_ops {
1492 tracing::info!(
1493 tool_count = query_tools.len(),
1494 "Parallel execution bypassed: all fast operations (glob/ls/list_files)"
1495 );
1496 return false;
1497 }
1498
1499 tracing::info!(
1501 tool_count = query_tools.len(),
1502 "Using parallel execution for Query-lane tools"
1503 );
1504 true
1505 }
1506
1507 async fn execute_query_tools_parallel(
1513 &self,
1514 query_tools: &[ToolCall],
1515 queue: &SessionLaneQueue,
1516 messages: &mut Vec<Message>,
1517 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1518 _augmented_system: &mut Option<String>,
1519 session_id: Option<&str>,
1520 ) -> usize {
1521 let mut commands_to_submit = Vec::with_capacity(query_tools.len());
1523 let mut tool_calls_to_execute = Vec::with_capacity(query_tools.len());
1524
1525 for tool_call in query_tools {
1526 if let Some(parse_error) = tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1528 {
1529 let error_msg = format!("Error: {}", parse_error);
1530 if let Some(tx) = event_tx {
1531 tx.send(AgentEvent::ToolEnd {
1532 id: tool_call.id.clone(),
1533 name: tool_call.name.clone(),
1534 output: error_msg.clone(),
1535 exit_code: 1,
1536 })
1537 .await
1538 .ok();
1539 }
1540 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1541 continue;
1542 }
1543
1544 if let Some(HookResult::Block(reason)) = self
1546 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1547 .await
1548 {
1549 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1550 if let Some(tx) = event_tx {
1551 tx.send(AgentEvent::PermissionDenied {
1552 tool_id: tool_call.id.clone(),
1553 tool_name: tool_call.name.clone(),
1554 args: tool_call.args.clone(),
1555 reason,
1556 })
1557 .await
1558 .ok();
1559 }
1560 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1561 continue;
1562 }
1563
1564
1565 let permission_decision = if let Some(checker) = &self.config.permission_checker {
1567 checker.check(&tool_call.name, &tool_call.args)
1568 } else {
1569 PermissionDecision::Ask
1570 };
1571
1572 match permission_decision {
1573 PermissionDecision::Deny => {
1574 let denial_msg = format!(
1575 "Permission denied: Tool '{}' is blocked by permission policy.",
1576 tool_call.name
1577 );
1578 if let Some(tx) = event_tx {
1579 tx.send(AgentEvent::PermissionDenied {
1580 tool_id: tool_call.id.clone(),
1581 tool_name: tool_call.name.clone(),
1582 args: tool_call.args.clone(),
1583 reason: "Blocked by deny rule in permission policy".to_string(),
1584 })
1585 .await
1586 .ok();
1587 }
1588 messages.push(Message::tool_result(&tool_call.id, &denial_msg, true));
1589 continue;
1590 }
1591 PermissionDecision::Allow | PermissionDecision::Ask => {
1592 if permission_decision == PermissionDecision::Ask {
1597 if let Some(cm) = &self.config.confirmation_manager {
1598 if cm.requires_confirmation(&tool_call.name).await {
1599 continue;
1602 }
1603 }
1604 }
1605
1606 let cmd = ToolCommand {
1608 tool_executor: self.tool_executor.clone(),
1609 tool_name: tool_call.name.clone(),
1610 tool_args: tool_call.args.clone(),
1611 tool_context: self.tool_context.clone(),
1612 skill_registry: self.config.skill_registry.clone(),
1613 };
1614 commands_to_submit.push(Box::new(cmd) as Box<dyn crate::queue::SessionCommand>);
1615 tool_calls_to_execute.push(tool_call.clone());
1616 }
1617 }
1618 }
1619
1620 let receivers = queue.submit_batch(crate::queue::SessionLane::Query, commands_to_submit).await;
1622 let tool_starts: Vec<_> = tool_calls_to_execute.iter().map(|_| std::time::Instant::now()).collect();
1623
1624 let count = receivers.len();
1625
1626 let results = join_all(receivers).await;
1628
1629 for (i, result) in results.into_iter().enumerate() {
1630 let tool_call = &tool_calls_to_execute[i];
1631 let tool_start = &tool_starts[i];
1632 let tool_duration = tool_start.elapsed();
1633
1634 let (output, exit_code, is_error, _metadata) = match result {
1635 Ok(Ok(value)) => {
1636 let output = value["output"].as_str().unwrap_or("").to_string();
1637 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
1638 let metadata = value.get("metadata").cloned();
1639 (output, exit_code, exit_code != 0, metadata)
1640 }
1641 Ok(Err(e)) => (format!("Tool execution error: {}", e), 1, true, None),
1642 Err(_) => ("Queue channel closed".to_string(), 1, true, None),
1643 };
1644
1645 self.fire_post_tool_use(
1647 session_id.unwrap_or(""),
1648 &tool_call.name,
1649 &tool_call.args,
1650 &output,
1651 exit_code == 0,
1652 tool_duration.as_millis() as u64,
1653 )
1654 .await;
1655
1656 if let Some(tx) = event_tx {
1658 tx.send(AgentEvent::ToolEnd {
1659 id: tool_call.id.clone(),
1660 name: tool_call.name.clone(),
1661 output: output.clone(),
1662 exit_code,
1663 })
1664 .await
1665 .ok();
1666 }
1667
1668 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
1669 }
1670
1671 count
1672 }
1673
1674 pub async fn execute_streaming(
1676 &self,
1677 history: &[Message],
1678 prompt: &str,
1679 ) -> Result<(
1680 mpsc::Receiver<AgentEvent>,
1681 tokio::task::JoinHandle<Result<AgentResult>>,
1682 )> {
1683 let (tx, rx) = mpsc::channel(100);
1684
1685 let llm_client = self.llm_client.clone();
1686 let tool_executor = self.tool_executor.clone();
1687 let tool_context = self.tool_context.clone();
1688 let config = self.config.clone();
1689 let tool_metrics = self.tool_metrics.clone();
1690 let command_queue = self.command_queue.clone();
1691 let history = history.to_vec();
1692 let prompt = prompt.to_string();
1693
1694 let handle = tokio::spawn(async move {
1695 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
1696 if let Some(metrics) = tool_metrics {
1697 agent = agent.with_tool_metrics(metrics);
1698 }
1699 if let Some(queue) = command_queue {
1700 agent = agent.with_queue(queue);
1701 }
1702 agent.execute(&history, &prompt, Some(tx)).await
1703 });
1704
1705 Ok((rx, handle))
1706 }
1707
1708 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
1713 use crate::planning::LlmPlanner;
1714
1715 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
1716 Ok(plan) => Ok(plan),
1717 Err(e) => {
1718 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
1719 Ok(LlmPlanner::fallback_plan(prompt))
1720 }
1721 }
1722 }
1723
1724 pub async fn execute_with_planning(
1726 &self,
1727 history: &[Message],
1728 prompt: &str,
1729 event_tx: Option<mpsc::Sender<AgentEvent>>,
1730 ) -> Result<AgentResult> {
1731 if let Some(tx) = &event_tx {
1733 tx.send(AgentEvent::PlanningStart {
1734 prompt: prompt.to_string(),
1735 })
1736 .await
1737 .ok();
1738 }
1739
1740 let plan = self.plan(prompt, None).await?;
1742
1743 if let Some(tx) = &event_tx {
1745 tx.send(AgentEvent::PlanningEnd {
1746 estimated_steps: plan.steps.len(),
1747 plan: plan.clone(),
1748 })
1749 .await
1750 .ok();
1751 }
1752
1753 self.execute_plan(history, &plan, event_tx).await
1755 }
1756
1757 async fn execute_plan(
1764 &self,
1765 history: &[Message],
1766 plan: &ExecutionPlan,
1767 event_tx: Option<mpsc::Sender<AgentEvent>>,
1768 ) -> Result<AgentResult> {
1769 let mut plan = plan.clone();
1770 let mut current_history = history.to_vec();
1771 let mut total_usage = TokenUsage::default();
1772 let mut tool_calls_count = 0;
1773 let total_steps = plan.steps.len();
1774
1775 let steps_text = plan
1777 .steps
1778 .iter()
1779 .enumerate()
1780 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
1781 .collect::<Vec<_>>()
1782 .join("\n");
1783 current_history.push(Message::user(&crate::prompts::render(
1784 crate::prompts::PLAN_EXECUTE_GOAL,
1785 &[("goal", &plan.goal), ("steps", &steps_text)],
1786 )));
1787
1788 loop {
1789 let ready: Vec<String> = plan
1790 .get_ready_steps()
1791 .iter()
1792 .map(|s| s.id.clone())
1793 .collect();
1794
1795 if ready.is_empty() {
1796 if plan.has_deadlock() {
1798 tracing::warn!(
1799 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
1800 plan.pending_count()
1801 );
1802 }
1803 break;
1804 }
1805
1806 if ready.len() == 1 {
1807 let step_id = &ready[0];
1809 let step = plan
1810 .steps
1811 .iter()
1812 .find(|s| s.id == *step_id)
1813 .unwrap()
1814 .clone();
1815 let step_number = plan.steps.iter().position(|s| s.id == *step_id).unwrap() + 1;
1816
1817 if let Some(tx) = &event_tx {
1819 tx.send(AgentEvent::StepStart {
1820 step_id: step.id.clone(),
1821 description: step.content.clone(),
1822 step_number,
1823 total_steps,
1824 })
1825 .await
1826 .ok();
1827 }
1828
1829 plan.mark_status(&step.id, TaskStatus::InProgress);
1830
1831 let step_prompt = crate::prompts::render(
1832 crate::prompts::PLAN_EXECUTE_STEP,
1833 &[
1834 ("step_num", &step_number.to_string()),
1835 ("description", &step.content),
1836 ],
1837 );
1838
1839 match self
1840 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
1841 .await
1842 {
1843 Ok(result) => {
1844 current_history = result.messages.clone();
1845 total_usage.prompt_tokens += result.usage.prompt_tokens;
1846 total_usage.completion_tokens += result.usage.completion_tokens;
1847 total_usage.total_tokens += result.usage.total_tokens;
1848 tool_calls_count += result.tool_calls_count;
1849 plan.mark_status(&step.id, TaskStatus::Completed);
1850
1851 if let Some(tx) = &event_tx {
1852 tx.send(AgentEvent::StepEnd {
1853 step_id: step.id.clone(),
1854 status: TaskStatus::Completed,
1855 step_number,
1856 total_steps,
1857 })
1858 .await
1859 .ok();
1860 }
1861 }
1862 Err(e) => {
1863 tracing::error!("Plan step '{}' failed: {}", step.id, e);
1864 plan.mark_status(&step.id, TaskStatus::Failed);
1865
1866 if let Some(tx) = &event_tx {
1867 tx.send(AgentEvent::StepEnd {
1868 step_id: step.id.clone(),
1869 status: TaskStatus::Failed,
1870 step_number,
1871 total_steps,
1872 })
1873 .await
1874 .ok();
1875 }
1876 }
1877 }
1878 } else {
1879 let ready_steps: Vec<_> = ready
1881 .iter()
1882 .map(|id| {
1883 let step = plan.steps.iter().find(|s| s.id == *id).unwrap().clone();
1884 let step_number = plan.steps.iter().position(|s| s.id == *id).unwrap() + 1;
1885 (step, step_number)
1886 })
1887 .collect();
1888
1889 for (step, step_number) in &ready_steps {
1891 plan.mark_status(&step.id, TaskStatus::InProgress);
1892 if let Some(tx) = &event_tx {
1893 tx.send(AgentEvent::StepStart {
1894 step_id: step.id.clone(),
1895 description: step.content.clone(),
1896 step_number: *step_number,
1897 total_steps,
1898 })
1899 .await
1900 .ok();
1901 }
1902 }
1903
1904 let mut join_set = tokio::task::JoinSet::new();
1906 for (step, step_number) in &ready_steps {
1907 let base_history = current_history.clone();
1908 let agent_clone = self.clone();
1909 let tx = event_tx.clone();
1910 let step_clone = step.clone();
1911 let sn = *step_number;
1912
1913 join_set.spawn(async move {
1914 let prompt = crate::prompts::render(
1915 crate::prompts::PLAN_EXECUTE_STEP,
1916 &[
1917 ("step_num", &sn.to_string()),
1918 ("description", &step_clone.content),
1919 ],
1920 );
1921 let result = agent_clone
1922 .execute_loop(&base_history, &prompt, None, tx)
1923 .await;
1924 (step_clone.id, sn, result)
1925 });
1926 }
1927
1928 let mut parallel_summaries = Vec::new();
1930 while let Some(join_result) = join_set.join_next().await {
1931 match join_result {
1932 Ok((step_id, step_number, step_result)) => match step_result {
1933 Ok(result) => {
1934 total_usage.prompt_tokens += result.usage.prompt_tokens;
1935 total_usage.completion_tokens += result.usage.completion_tokens;
1936 total_usage.total_tokens += result.usage.total_tokens;
1937 tool_calls_count += result.tool_calls_count;
1938 plan.mark_status(&step_id, TaskStatus::Completed);
1939
1940 parallel_summaries.push(format!(
1942 "- Step {} ({}): {}",
1943 step_number, step_id, result.text
1944 ));
1945
1946 if let Some(tx) = &event_tx {
1947 tx.send(AgentEvent::StepEnd {
1948 step_id,
1949 status: TaskStatus::Completed,
1950 step_number,
1951 total_steps,
1952 })
1953 .await
1954 .ok();
1955 }
1956 }
1957 Err(e) => {
1958 tracing::error!("Plan step '{}' failed: {}", step_id, e);
1959 plan.mark_status(&step_id, TaskStatus::Failed);
1960
1961 if let Some(tx) = &event_tx {
1962 tx.send(AgentEvent::StepEnd {
1963 step_id,
1964 status: TaskStatus::Failed,
1965 step_number,
1966 total_steps,
1967 })
1968 .await
1969 .ok();
1970 }
1971 }
1972 },
1973 Err(e) => {
1974 tracing::error!("JoinSet task panicked: {}", e);
1975 }
1976 }
1977 }
1978
1979 if !parallel_summaries.is_empty() {
1981 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
1983 current_history.push(Message::user(&crate::prompts::render(
1984 crate::prompts::PLAN_PARALLEL_RESULTS,
1985 &[("results", &results_text)],
1986 )));
1987 }
1988 }
1989
1990 if self.config.goal_tracking {
1992 let completed = plan
1993 .steps
1994 .iter()
1995 .filter(|s| s.status == TaskStatus::Completed)
1996 .count();
1997 if let Some(tx) = &event_tx {
1998 tx.send(AgentEvent::GoalProgress {
1999 goal: plan.goal.clone(),
2000 progress: plan.progress(),
2001 completed_steps: completed,
2002 total_steps,
2003 })
2004 .await
2005 .ok();
2006 }
2007 }
2008 }
2009
2010 let final_text = current_history
2012 .last()
2013 .map(|m| {
2014 m.content
2015 .iter()
2016 .filter_map(|block| {
2017 if let crate::llm::ContentBlock::Text { text } = block {
2018 Some(text.as_str())
2019 } else {
2020 None
2021 }
2022 })
2023 .collect::<Vec<_>>()
2024 .join("\n")
2025 })
2026 .unwrap_or_default();
2027
2028 Ok(AgentResult {
2029 text: final_text,
2030 messages: current_history,
2031 usage: total_usage,
2032 tool_calls_count,
2033 })
2034 }
2035
2036 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2041 use crate::planning::LlmPlanner;
2042
2043 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2044 Ok(goal) => Ok(goal),
2045 Err(e) => {
2046 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2047 Ok(LlmPlanner::fallback_goal(prompt))
2048 }
2049 }
2050 }
2051
2052 pub async fn check_goal_achievement(
2057 &self,
2058 goal: &AgentGoal,
2059 current_state: &str,
2060 ) -> Result<bool> {
2061 use crate::planning::LlmPlanner;
2062
2063 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2064 Ok(result) => Ok(result.achieved),
2065 Err(e) => {
2066 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2067 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2068 Ok(result.achieved)
2069 }
2070 }
2071 }
2072}
2073
2074#[cfg(test)]
2075mod tests {
2076 use super::*;
2077 use crate::llm::{ContentBlock, StreamEvent};
2078 use crate::permissions::PermissionPolicy;
2079 use crate::tools::ToolExecutor;
2080 use std::path::PathBuf;
2081 use std::sync::atomic::{AtomicUsize, Ordering};
2082
2083 fn test_tool_context() -> ToolContext {
2085 ToolContext::new(PathBuf::from("/tmp"))
2086 }
2087
2088 #[test]
2089 fn test_agent_config_default() {
2090 let config = AgentConfig::default();
2091 assert!(config.system_prompt.is_none());
2092 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
2094 assert!(config.permission_checker.is_none());
2095 assert!(config.context_providers.is_empty());
2096 }
2097
    /// Scripted `LlmClient` for tests: hands out pre-built responses in FIFO order.
    pub(crate) struct MockLlmClient {
        /// Remaining scripted responses. Each `complete` / `complete_streaming` call removes
        /// the first one and errors once the list is empty.
        responses: std::sync::Mutex<Vec<LlmResponse>>,
        /// Number of `complete` / `complete_streaming` calls made, including calls that
        /// failed because no responses were left (the counter is bumped before the check).
        call_count: AtomicUsize,
    }
2109
2110 impl MockLlmClient {
2111 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2112 Self {
2113 responses: std::sync::Mutex::new(responses),
2114 call_count: AtomicUsize::new(0),
2115 }
2116 }
2117
2118 pub(crate) fn text_response(text: &str) -> LlmResponse {
2120 LlmResponse {
2121 message: Message {
2122 role: "assistant".to_string(),
2123 content: vec![ContentBlock::Text {
2124 text: text.to_string(),
2125 }],
2126 reasoning_content: None,
2127 },
2128 usage: TokenUsage {
2129 prompt_tokens: 10,
2130 completion_tokens: 5,
2131 total_tokens: 15,
2132 cache_read_tokens: None,
2133 cache_write_tokens: None,
2134 },
2135 stop_reason: Some("end_turn".to_string()),
2136 }
2137 }
2138
2139 pub(crate) fn tool_call_response(
2141 tool_id: &str,
2142 tool_name: &str,
2143 args: serde_json::Value,
2144 ) -> LlmResponse {
2145 LlmResponse {
2146 message: Message {
2147 role: "assistant".to_string(),
2148 content: vec![ContentBlock::ToolUse {
2149 id: tool_id.to_string(),
2150 name: tool_name.to_string(),
2151 input: args,
2152 }],
2153 reasoning_content: None,
2154 },
2155 usage: TokenUsage {
2156 prompt_tokens: 10,
2157 completion_tokens: 5,
2158 total_tokens: 15,
2159 cache_read_tokens: None,
2160 cache_write_tokens: None,
2161 },
2162 stop_reason: Some("tool_use".to_string()),
2163 }
2164 }
2165 }
2166
2167 #[async_trait::async_trait]
2168 impl LlmClient for MockLlmClient {
2169 async fn complete(
2170 &self,
2171 _messages: &[Message],
2172 _system: Option<&str>,
2173 _tools: &[ToolDefinition],
2174 ) -> Result<LlmResponse> {
2175 self.call_count.fetch_add(1, Ordering::SeqCst);
2176 let mut responses = self.responses.lock().unwrap();
2177 if responses.is_empty() {
2178 anyhow::bail!("No more mock responses available");
2179 }
2180 Ok(responses.remove(0))
2181 }
2182
2183 async fn complete_streaming(
2184 &self,
2185 _messages: &[Message],
2186 _system: Option<&str>,
2187 _tools: &[ToolDefinition],
2188 ) -> Result<mpsc::Receiver<StreamEvent>> {
2189 self.call_count.fetch_add(1, Ordering::SeqCst);
2190 let mut responses = self.responses.lock().unwrap();
2191 if responses.is_empty() {
2192 anyhow::bail!("No more mock responses available");
2193 }
2194 let response = responses.remove(0);
2195
2196 let (tx, rx) = mpsc::channel(10);
2197 tokio::spawn(async move {
2198 for block in &response.message.content {
2200 if let ContentBlock::Text { text } = block {
2201 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2202 }
2203 }
2204 tx.send(StreamEvent::Done(response)).await.ok();
2205 });
2206
2207 Ok(rx)
2208 }
2209 }
2210
2211 #[tokio::test]
2216 async fn test_agent_simple_response() {
2217 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2218 "Hello, I'm an AI assistant.",
2219 )]));
2220
2221 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2222 let config = AgentConfig::default();
2223
2224 let agent = AgentLoop::new(
2225 mock_client.clone(),
2226 tool_executor,
2227 test_tool_context(),
2228 config,
2229 );
2230 let result = agent.execute(&[], "Hello", None).await.unwrap();
2231
2232 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2233 assert_eq!(result.tool_calls_count, 0);
2234 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2235 }
2236
2237 #[tokio::test]
2238 async fn test_agent_with_tool_call() {
2239 let mock_client = Arc::new(MockLlmClient::new(vec![
2240 MockLlmClient::tool_call_response(
2242 "tool-1",
2243 "bash",
2244 serde_json::json!({"command": "echo hello"}),
2245 ),
2246 MockLlmClient::text_response("The command output was: hello"),
2248 ]));
2249
2250 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2251 let config = AgentConfig::default();
2252
2253 let agent = AgentLoop::new(
2254 mock_client.clone(),
2255 tool_executor,
2256 test_tool_context(),
2257 config,
2258 );
2259 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2260
2261 assert_eq!(result.text, "The command output was: hello");
2262 assert_eq!(result.tool_calls_count, 1);
2263 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2264 }
2265
2266 #[tokio::test]
2267 async fn test_agent_permission_deny() {
2268 let mock_client = Arc::new(MockLlmClient::new(vec![
2269 MockLlmClient::tool_call_response(
2271 "tool-1",
2272 "bash",
2273 serde_json::json!({"command": "rm -rf /tmp/test"}),
2274 ),
2275 MockLlmClient::text_response(
2277 "I cannot execute that command due to permission restrictions.",
2278 ),
2279 ]));
2280
2281 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2282
2283 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2285
2286 let config = AgentConfig {
2287 permission_checker: Some(Arc::new(permission_policy)),
2288 ..Default::default()
2289 };
2290
2291 let (tx, mut rx) = mpsc::channel(100);
2292 let agent = AgentLoop::new(
2293 mock_client.clone(),
2294 tool_executor,
2295 test_tool_context(),
2296 config,
2297 );
2298 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2299
2300 let mut found_permission_denied = false;
2302 while let Ok(event) = rx.try_recv() {
2303 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2304 assert_eq!(tool_name, "bash");
2305 found_permission_denied = true;
2306 }
2307 }
2308 assert!(
2309 found_permission_denied,
2310 "Should have received PermissionDenied event"
2311 );
2312
2313 assert_eq!(result.tool_calls_count, 1);
2314 }
2315
2316 #[tokio::test]
2317 async fn test_agent_permission_allow() {
2318 let mock_client = Arc::new(MockLlmClient::new(vec![
2319 MockLlmClient::tool_call_response(
2321 "tool-1",
2322 "bash",
2323 serde_json::json!({"command": "echo hello"}),
2324 ),
2325 MockLlmClient::text_response("Done!"),
2327 ]));
2328
2329 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2330
2331 let permission_policy = PermissionPolicy::new()
2333 .allow("bash(echo:*)")
2334 .deny("bash(rm:*)");
2335
2336 let config = AgentConfig {
2337 permission_checker: Some(Arc::new(permission_policy)),
2338 ..Default::default()
2339 };
2340
2341 let agent = AgentLoop::new(
2342 mock_client.clone(),
2343 tool_executor,
2344 test_tool_context(),
2345 config,
2346 );
2347 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2348
2349 assert_eq!(result.text, "Done!");
2350 assert_eq!(result.tool_calls_count, 1);
2351 }
2352
2353 #[tokio::test]
2354 async fn test_agent_streaming_events() {
2355 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2356 "Hello!",
2357 )]));
2358
2359 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2360 let config = AgentConfig::default();
2361
2362 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2363 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2364
2365 let mut events = Vec::new();
2367 while let Some(event) = rx.recv().await {
2368 events.push(event);
2369 }
2370
2371 let result = handle.await.unwrap().unwrap();
2372 assert_eq!(result.text, "Hello!");
2373
2374 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2376 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2377 }
2378
2379 #[tokio::test]
2380 async fn test_agent_max_tool_rounds() {
2381 let responses: Vec<LlmResponse> = (0..100)
2383 .map(|i| {
2384 MockLlmClient::tool_call_response(
2385 &format!("tool-{}", i),
2386 "bash",
2387 serde_json::json!({"command": "echo loop"}),
2388 )
2389 })
2390 .collect();
2391
2392 let mock_client = Arc::new(MockLlmClient::new(responses));
2393 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2394
2395 let config = AgentConfig {
2396 max_tool_rounds: 3,
2397 ..Default::default()
2398 };
2399
2400 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2401 let result = agent.execute(&[], "Loop forever", None).await;
2402
2403 assert!(result.is_err());
2405 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
2406 }
2407
2408 #[tokio::test]
2409 async fn test_agent_no_permission_policy_defaults_to_ask() {
2410 let mock_client = Arc::new(MockLlmClient::new(vec![
2413 MockLlmClient::tool_call_response(
2414 "tool-1",
2415 "bash",
2416 serde_json::json!({"command": "rm -rf /tmp/test"}),
2417 ),
2418 MockLlmClient::text_response("Denied!"),
2419 ]));
2420
2421 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2422 let config = AgentConfig {
2423 permission_checker: None, ..Default::default()
2426 };
2427
2428 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2429 let result = agent.execute(&[], "Delete", None).await.unwrap();
2430
2431 assert_eq!(result.text, "Denied!");
2433 assert_eq!(result.tool_calls_count, 1);
2434 }
2435
2436 #[tokio::test]
2437 async fn test_agent_permission_ask_without_cm_denies() {
2438 let mock_client = Arc::new(MockLlmClient::new(vec![
2441 MockLlmClient::tool_call_response(
2442 "tool-1",
2443 "bash",
2444 serde_json::json!({"command": "echo test"}),
2445 ),
2446 MockLlmClient::text_response("Denied!"),
2447 ]));
2448
2449 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2450
2451 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2455 permission_checker: Some(Arc::new(permission_policy)),
2456 ..Default::default()
2458 };
2459
2460 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2461 let result = agent.execute(&[], "Echo", None).await.unwrap();
2462
2463 assert_eq!(result.text, "Denied!");
2465 assert!(result.tool_calls_count >= 1);
2467 }
2468
2469 #[tokio::test]
2474 async fn test_agent_hitl_approved() {
2475 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2476 use tokio::sync::broadcast;
2477
2478 let mock_client = Arc::new(MockLlmClient::new(vec![
2479 MockLlmClient::tool_call_response(
2480 "tool-1",
2481 "bash",
2482 serde_json::json!({"command": "echo hello"}),
2483 ),
2484 MockLlmClient::text_response("Command executed!"),
2485 ]));
2486
2487 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2488
2489 let (event_tx, _event_rx) = broadcast::channel(100);
2491 let hitl_policy = ConfirmationPolicy {
2492 enabled: true,
2493 ..Default::default()
2494 };
2495 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2496
2497 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2501 permission_checker: Some(Arc::new(permission_policy)),
2502 confirmation_manager: Some(confirmation_manager.clone()),
2503 ..Default::default()
2504 };
2505
2506 let cm_clone = confirmation_manager.clone();
2508 tokio::spawn(async move {
2509 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2511 cm_clone.confirm("tool-1", true, None).await.ok();
2513 });
2514
2515 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2516 let result = agent.execute(&[], "Run echo", None).await.unwrap();
2517
2518 assert_eq!(result.text, "Command executed!");
2519 assert_eq!(result.tool_calls_count, 1);
2520 }
2521
2522 #[tokio::test]
2523 async fn test_agent_hitl_rejected() {
2524 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2525 use tokio::sync::broadcast;
2526
2527 let mock_client = Arc::new(MockLlmClient::new(vec![
2528 MockLlmClient::tool_call_response(
2529 "tool-1",
2530 "bash",
2531 serde_json::json!({"command": "rm -rf /"}),
2532 ),
2533 MockLlmClient::text_response("Understood, I won't do that."),
2534 ]));
2535
2536 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2537
2538 let (event_tx, _event_rx) = broadcast::channel(100);
2540 let hitl_policy = ConfirmationPolicy {
2541 enabled: true,
2542 ..Default::default()
2543 };
2544 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2545
2546 let permission_policy = PermissionPolicy::new();
2548
2549 let config = AgentConfig {
2550 permission_checker: Some(Arc::new(permission_policy)),
2551 confirmation_manager: Some(confirmation_manager.clone()),
2552 ..Default::default()
2553 };
2554
2555 let cm_clone = confirmation_manager.clone();
2557 tokio::spawn(async move {
2558 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2559 cm_clone
2560 .confirm("tool-1", false, Some("Too dangerous".to_string()))
2561 .await
2562 .ok();
2563 });
2564
2565 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2566 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
2567
2568 assert_eq!(result.text, "Understood, I won't do that.");
2570 }
2571
2572 #[tokio::test]
2573 async fn test_agent_hitl_timeout_reject() {
2574 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2575 use tokio::sync::broadcast;
2576
2577 let mock_client = Arc::new(MockLlmClient::new(vec![
2578 MockLlmClient::tool_call_response(
2579 "tool-1",
2580 "bash",
2581 serde_json::json!({"command": "echo test"}),
2582 ),
2583 MockLlmClient::text_response("Timed out, I understand."),
2584 ]));
2585
2586 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2587
2588 let (event_tx, _event_rx) = broadcast::channel(100);
2590 let hitl_policy = ConfirmationPolicy {
2591 enabled: true,
2592 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
2594 ..Default::default()
2595 };
2596 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2597
2598 let permission_policy = PermissionPolicy::new();
2599
2600 let config = AgentConfig {
2601 permission_checker: Some(Arc::new(permission_policy)),
2602 confirmation_manager: Some(confirmation_manager),
2603 ..Default::default()
2604 };
2605
2606 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2608 let result = agent.execute(&[], "Echo", None).await.unwrap();
2609
2610 assert_eq!(result.text, "Timed out, I understand.");
2612 }
2613
2614 #[tokio::test]
2615 async fn test_agent_hitl_timeout_auto_approve() {
2616 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
2617 use tokio::sync::broadcast;
2618
2619 let mock_client = Arc::new(MockLlmClient::new(vec![
2620 MockLlmClient::tool_call_response(
2621 "tool-1",
2622 "bash",
2623 serde_json::json!({"command": "echo hello"}),
2624 ),
2625 MockLlmClient::text_response("Auto-approved and executed!"),
2626 ]));
2627
2628 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2629
2630 let (event_tx, _event_rx) = broadcast::channel(100);
2632 let hitl_policy = ConfirmationPolicy {
2633 enabled: true,
2634 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
2636 ..Default::default()
2637 };
2638 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2639
2640 let permission_policy = PermissionPolicy::new();
2641
2642 let config = AgentConfig {
2643 permission_checker: Some(Arc::new(permission_policy)),
2644 confirmation_manager: Some(confirmation_manager),
2645 ..Default::default()
2646 };
2647
2648 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2650 let result = agent.execute(&[], "Echo", None).await.unwrap();
2651
2652 assert_eq!(result.text, "Auto-approved and executed!");
2654 assert_eq!(result.tool_calls_count, 1);
2655 }
2656
    /// Verifies that a HITL-gated tool call publishes `ConfirmationRequired` on the
    /// confirmation manager's broadcast channel. Once the call is approved, it must also
    /// publish `ConfirmationReceived` with `approved: true`.
    #[tokio::test]
    async fn test_agent_hitl_confirmation_events() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "echo test"}),
            ),
            MockLlmClient::text_response("Done!"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // The receiver is kept: this test observes the confirmation events directly.
        let (event_tx, mut event_rx) = broadcast::channel(100);
        // Generous timeout so the approval below lands well before the request expires.
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 5000, ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        let cm_clone = confirmation_manager.clone();
        // Collector task: records broadcast events and approves the first confirmation
        // request it sees. It then captures exactly one more event and stops.
        // NOTE(review): if no ConfirmationRequired ever arrives, this loop keeps waiting,
        // because the manager still holds the sender; the `event_handle.await` below would
        // then presumably hang rather than fail.
        let event_handle = tokio::spawn(async move {
            let mut events = Vec::new();
            while let Ok(event) = event_rx.recv().await {
                events.push(event.clone());
                if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
                    cm_clone.confirm(&tool_id, true, None).await.ok();
                    // Assumes the next broadcast event is the matching ConfirmationReceived.
                    if let Ok(recv_event) = event_rx.recv().await {
                        events.push(recv_event);
                    }
                    break;
                }
            }
            events
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let _result = agent.execute(&[], "Echo", None).await.unwrap();

        let events = event_handle.await.unwrap();
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
            "Should have ConfirmationRequired event"
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
            "Should have ConfirmationReceived event with approved=true"
        );
    }
2728
2729 #[tokio::test]
2730 async fn test_agent_hitl_disabled_auto_executes() {
2731 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2733 use tokio::sync::broadcast;
2734
2735 let mock_client = Arc::new(MockLlmClient::new(vec![
2736 MockLlmClient::tool_call_response(
2737 "tool-1",
2738 "bash",
2739 serde_json::json!({"command": "echo auto"}),
2740 ),
2741 MockLlmClient::text_response("Auto executed!"),
2742 ]));
2743
2744 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2745
2746 let (event_tx, _event_rx) = broadcast::channel(100);
2748 let hitl_policy = ConfirmationPolicy {
2749 enabled: false, ..Default::default()
2751 };
2752 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2753
2754 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2757 permission_checker: Some(Arc::new(permission_policy)),
2758 confirmation_manager: Some(confirmation_manager),
2759 ..Default::default()
2760 };
2761
2762 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2763 let result = agent.execute(&[], "Echo", None).await.unwrap();
2764
2765 assert_eq!(result.text, "Auto executed!");
2767 assert_eq!(result.tool_calls_count, 1);
2768 }
2769
2770 #[tokio::test]
2771 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
2772 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2774 use tokio::sync::broadcast;
2775
2776 let mock_client = Arc::new(MockLlmClient::new(vec![
2777 MockLlmClient::tool_call_response(
2778 "tool-1",
2779 "bash",
2780 serde_json::json!({"command": "rm -rf /"}),
2781 ),
2782 MockLlmClient::text_response("Blocked by permission."),
2783 ]));
2784
2785 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2786
2787 let (event_tx, mut event_rx) = broadcast::channel(100);
2789 let hitl_policy = ConfirmationPolicy {
2790 enabled: true,
2791 ..Default::default()
2792 };
2793 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2794
2795 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2797
2798 let config = AgentConfig {
2799 permission_checker: Some(Arc::new(permission_policy)),
2800 confirmation_manager: Some(confirmation_manager),
2801 ..Default::default()
2802 };
2803
2804 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2805 let result = agent.execute(&[], "Delete", None).await.unwrap();
2806
2807 assert_eq!(result.text, "Blocked by permission.");
2809
2810 let mut found_confirmation = false;
2812 while let Ok(event) = event_rx.try_recv() {
2813 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2814 found_confirmation = true;
2815 }
2816 }
2817 assert!(
2818 !found_confirmation,
2819 "HITL should not be triggered when permission is Deny"
2820 );
2821 }
2822
2823 #[tokio::test]
2824 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
2825 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2828 use tokio::sync::broadcast;
2829
2830 let mock_client = Arc::new(MockLlmClient::new(vec![
2831 MockLlmClient::tool_call_response(
2832 "tool-1",
2833 "bash",
2834 serde_json::json!({"command": "echo hello"}),
2835 ),
2836 MockLlmClient::text_response("Allowed!"),
2837 ]));
2838
2839 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2840
2841 let (event_tx, mut event_rx) = broadcast::channel(100);
2843 let hitl_policy = ConfirmationPolicy {
2844 enabled: true,
2845 ..Default::default()
2846 };
2847 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2848
2849 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
2851
2852 let config = AgentConfig {
2853 permission_checker: Some(Arc::new(permission_policy)),
2854 confirmation_manager: Some(confirmation_manager.clone()),
2855 ..Default::default()
2856 };
2857
2858 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2859 let result = agent.execute(&[], "Echo", None).await.unwrap();
2860
2861 assert_eq!(result.text, "Allowed!");
2863
2864 let mut found_confirmation = false;
2866 while let Ok(event) = event_rx.try_recv() {
2867 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
2868 found_confirmation = true;
2869 }
2870 }
2871 assert!(
2872 !found_confirmation,
2873 "Permission Allow should skip HITL confirmation"
2874 );
2875 }
2876
2877 #[tokio::test]
2878 async fn test_agent_hitl_multiple_tool_calls() {
2879 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2881 use tokio::sync::broadcast;
2882
2883 let mock_client = Arc::new(MockLlmClient::new(vec![
2884 LlmResponse {
2886 message: Message {
2887 role: "assistant".to_string(),
2888 content: vec![
2889 ContentBlock::ToolUse {
2890 id: "tool-1".to_string(),
2891 name: "bash".to_string(),
2892 input: serde_json::json!({"command": "echo first"}),
2893 },
2894 ContentBlock::ToolUse {
2895 id: "tool-2".to_string(),
2896 name: "bash".to_string(),
2897 input: serde_json::json!({"command": "echo second"}),
2898 },
2899 ],
2900 reasoning_content: None,
2901 },
2902 usage: TokenUsage {
2903 prompt_tokens: 10,
2904 completion_tokens: 5,
2905 total_tokens: 15,
2906 cache_read_tokens: None,
2907 cache_write_tokens: None,
2908 },
2909 stop_reason: Some("tool_use".to_string()),
2910 },
2911 MockLlmClient::text_response("Both executed!"),
2912 ]));
2913
2914 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2915
2916 let (event_tx, _event_rx) = broadcast::channel(100);
2918 let hitl_policy = ConfirmationPolicy {
2919 enabled: true,
2920 default_timeout_ms: 5000,
2921 ..Default::default()
2922 };
2923 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2924
2925 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
2928 permission_checker: Some(Arc::new(permission_policy)),
2929 confirmation_manager: Some(confirmation_manager.clone()),
2930 ..Default::default()
2931 };
2932
2933 let cm_clone = confirmation_manager.clone();
2935 tokio::spawn(async move {
2936 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2937 cm_clone.confirm("tool-1", true, None).await.ok();
2938 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2939 cm_clone.confirm("tool-2", true, None).await.ok();
2940 });
2941
2942 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2943 let result = agent.execute(&[], "Run both", None).await.unwrap();
2944
2945 assert_eq!(result.text, "Both executed!");
2946 assert_eq!(result.tool_calls_count, 2);
2947 }
2948
2949 #[tokio::test]
2950 async fn test_agent_hitl_partial_approval() {
2951 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
2953 use tokio::sync::broadcast;
2954
2955 let mock_client = Arc::new(MockLlmClient::new(vec![
2956 LlmResponse {
2958 message: Message {
2959 role: "assistant".to_string(),
2960 content: vec![
2961 ContentBlock::ToolUse {
2962 id: "tool-1".to_string(),
2963 name: "bash".to_string(),
2964 input: serde_json::json!({"command": "echo safe"}),
2965 },
2966 ContentBlock::ToolUse {
2967 id: "tool-2".to_string(),
2968 name: "bash".to_string(),
2969 input: serde_json::json!({"command": "rm -rf /"}),
2970 },
2971 ],
2972 reasoning_content: None,
2973 },
2974 usage: TokenUsage {
2975 prompt_tokens: 10,
2976 completion_tokens: 5,
2977 total_tokens: 15,
2978 cache_read_tokens: None,
2979 cache_write_tokens: None,
2980 },
2981 stop_reason: Some("tool_use".to_string()),
2982 },
2983 MockLlmClient::text_response("First worked, second rejected."),
2984 ]));
2985
2986 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2987
2988 let (event_tx, _event_rx) = broadcast::channel(100);
2989 let hitl_policy = ConfirmationPolicy {
2990 enabled: true,
2991 default_timeout_ms: 5000,
2992 ..Default::default()
2993 };
2994 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
2995
2996 let permission_policy = PermissionPolicy::new();
2997
2998 let config = AgentConfig {
2999 permission_checker: Some(Arc::new(permission_policy)),
3000 confirmation_manager: Some(confirmation_manager.clone()),
3001 ..Default::default()
3002 };
3003
3004 let cm_clone = confirmation_manager.clone();
3006 tokio::spawn(async move {
3007 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3008 cm_clone.confirm("tool-1", true, None).await.ok();
3009 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3010 cm_clone
3011 .confirm("tool-2", false, Some("Dangerous".to_string()))
3012 .await
3013 .ok();
3014 });
3015
3016 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3017 let result = agent.execute(&[], "Run both", None).await.unwrap();
3018
3019 assert_eq!(result.text, "First worked, second rejected.");
3020 assert_eq!(result.tool_calls_count, 2);
3021 }
3022
3023 #[tokio::test]
3024 async fn test_agent_hitl_yolo_mode_auto_approves() {
3025 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3027 use tokio::sync::broadcast;
3028
3029 let mock_client = Arc::new(MockLlmClient::new(vec![
3030 MockLlmClient::tool_call_response(
3031 "tool-1",
3032 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3034 ),
3035 MockLlmClient::text_response("File read!"),
3036 ]));
3037
3038 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3039
3040 let (event_tx, mut event_rx) = broadcast::channel(100);
3042 let mut yolo_lanes = std::collections::HashSet::new();
3043 yolo_lanes.insert(SessionLane::Query);
3044 let hitl_policy = ConfirmationPolicy {
3045 enabled: true,
3046 yolo_lanes, ..Default::default()
3048 };
3049 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3050
3051 let permission_policy = PermissionPolicy::new();
3052
3053 let config = AgentConfig {
3054 permission_checker: Some(Arc::new(permission_policy)),
3055 confirmation_manager: Some(confirmation_manager),
3056 ..Default::default()
3057 };
3058
3059 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3060 let result = agent.execute(&[], "Read file", None).await.unwrap();
3061
3062 assert_eq!(result.text, "File read!");
3064
3065 let mut found_confirmation = false;
3067 while let Ok(event) = event_rx.try_recv() {
3068 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3069 found_confirmation = true;
3070 }
3071 }
3072 assert!(
3073 !found_confirmation,
3074 "YOLO mode should not trigger confirmation"
3075 );
3076 }
3077
3078 #[tokio::test]
3079 async fn test_agent_config_with_all_options() {
3080 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3081 use tokio::sync::broadcast;
3082
3083 let (event_tx, _) = broadcast::channel(100);
3084 let hitl_policy = ConfirmationPolicy::default();
3085 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3086
3087 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3088
3089 let config = AgentConfig {
3090 system_prompt: Some("Test system prompt".to_string()),
3091 tools: vec![],
3092 max_tool_rounds: 10,
3093 permission_checker: Some(Arc::new(permission_policy)),
3094 confirmation_manager: Some(confirmation_manager),
3095 context_providers: vec![],
3096 planning_enabled: false,
3097 goal_tracking: false,
3098 hook_engine: None,
3099 skill_registry: None,
3100 };
3101
3102 assert_eq!(config.system_prompt, Some("Test system prompt".to_string()));
3103 assert_eq!(config.max_tool_rounds, 10);
3104 assert!(config.permission_checker.is_some());
3105 assert!(config.confirmation_manager.is_some());
3106 assert!(config.context_providers.is_empty());
3107
3108 let debug_str = format!("{:?}", config);
3110 assert!(debug_str.contains("AgentConfig"));
3111 assert!(debug_str.contains("permission_checker: true"));
3112 assert!(debug_str.contains("confirmation_manager: true"));
3113 assert!(debug_str.contains("context_providers: 0"));
3114 }
3115
3116 use crate::context::{ContextItem, ContextType};
3121
    /// In-memory `ContextProvider` used in tests to exercise the agent's context hooks.
    struct MockContextProvider {
        /// Reported by `ContextProvider::name` and passed to `ContextResult::new`.
        name: String,
        /// Items cloned into the result of every `query` call.
        items: Vec<ContextItem>,
        /// Log of `(session_id, prompt, response)` tuples received by
        /// `on_turn_complete`. Wrapped in `Arc` so a test can keep a handle
        /// after the provider is moved into the agent config.
        on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
    }
3128
3129 impl MockContextProvider {
3130 fn new(name: &str) -> Self {
3131 Self {
3132 name: name.to_string(),
3133 items: Vec::new(),
3134 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3135 }
3136 }
3137
3138 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3139 self.items = items;
3140 self
3141 }
3142 }
3143
3144 #[async_trait::async_trait]
3145 impl ContextProvider for MockContextProvider {
3146 fn name(&self) -> &str {
3147 &self.name
3148 }
3149
3150 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3151 let mut result = ContextResult::new(&self.name);
3152 for item in &self.items {
3153 result.add_item(item.clone());
3154 }
3155 Ok(result)
3156 }
3157
3158 async fn on_turn_complete(
3159 &self,
3160 session_id: &str,
3161 prompt: &str,
3162 response: &str,
3163 ) -> anyhow::Result<()> {
3164 let mut calls = self.on_turn_calls.write().await;
3165 calls.push((
3166 session_id.to_string(),
3167 prompt.to_string(),
3168 response.to_string(),
3169 ));
3170 Ok(())
3171 }
3172 }
3173
3174 #[tokio::test]
3175 async fn test_agent_with_context_provider() {
3176 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3177 "Response using context",
3178 )]));
3179
3180 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3181
3182 let provider =
3183 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3184 "ctx-1",
3185 ContextType::Resource,
3186 "Relevant context here",
3187 )
3188 .with_source("test://docs/example")]);
3189
3190 let config = AgentConfig {
3191 system_prompt: Some("You are helpful.".to_string()),
3192 context_providers: vec![Arc::new(provider)],
3193 ..Default::default()
3194 };
3195
3196 let agent = AgentLoop::new(
3197 mock_client.clone(),
3198 tool_executor,
3199 test_tool_context(),
3200 config,
3201 );
3202 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3203
3204 assert_eq!(result.text, "Response using context");
3205 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3206 }
3207
3208 #[tokio::test]
3209 async fn test_agent_context_provider_events() {
3210 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3211 "Answer",
3212 )]));
3213
3214 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3215
3216 let provider =
3217 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3218 "item-1",
3219 ContextType::Memory,
3220 "Memory content",
3221 )
3222 .with_token_count(50)]);
3223
3224 let config = AgentConfig {
3225 context_providers: vec![Arc::new(provider)],
3226 ..Default::default()
3227 };
3228
3229 let (tx, mut rx) = mpsc::channel(100);
3230 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3231 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3232
3233 let mut events = Vec::new();
3235 while let Ok(event) = rx.try_recv() {
3236 events.push(event);
3237 }
3238
3239 assert!(
3241 events
3242 .iter()
3243 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3244 "Should have ContextResolving event"
3245 );
3246 assert!(
3247 events
3248 .iter()
3249 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3250 "Should have ContextResolved event"
3251 );
3252
3253 for event in &events {
3255 if let AgentEvent::ContextResolved {
3256 total_items,
3257 total_tokens,
3258 } = event
3259 {
3260 assert_eq!(*total_items, 1);
3261 assert_eq!(*total_tokens, 50);
3262 }
3263 }
3264 }
3265
3266 #[tokio::test]
3267 async fn test_agent_multiple_context_providers() {
3268 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3269 "Combined response",
3270 )]));
3271
3272 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3273
3274 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3275 "p1-1",
3276 ContextType::Resource,
3277 "Resource from P1",
3278 )
3279 .with_token_count(100)]);
3280
3281 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3282 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3283 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3284 ]);
3285
3286 let config = AgentConfig {
3287 system_prompt: Some("Base system prompt.".to_string()),
3288 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3289 ..Default::default()
3290 };
3291
3292 let (tx, mut rx) = mpsc::channel(100);
3293 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3294 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3295
3296 assert_eq!(result.text, "Combined response");
3297
3298 while let Ok(event) = rx.try_recv() {
3300 if let AgentEvent::ContextResolved {
3301 total_items,
3302 total_tokens,
3303 } = event
3304 {
3305 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3308 }
3309 }
3310
3311 #[tokio::test]
3312 async fn test_agent_no_context_providers() {
3313 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3314 "No context",
3315 )]));
3316
3317 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3318
3319 let config = AgentConfig::default();
3321
3322 let (tx, mut rx) = mpsc::channel(100);
3323 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3324 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3325
3326 assert_eq!(result.text, "No context");
3327
3328 let mut events = Vec::new();
3330 while let Ok(event) = rx.try_recv() {
3331 events.push(event);
3332 }
3333
3334 assert!(
3335 !events
3336 .iter()
3337 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3338 "Should NOT have ContextResolving event"
3339 );
3340 }
3341
3342 #[tokio::test]
3343 async fn test_agent_context_on_turn_complete() {
3344 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3345 "Final response",
3346 )]));
3347
3348 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3349
3350 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3351 let on_turn_calls = provider.on_turn_calls.clone();
3352
3353 let config = AgentConfig {
3354 context_providers: vec![provider],
3355 ..Default::default()
3356 };
3357
3358 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3359
3360 let result = agent
3362 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3363 .await
3364 .unwrap();
3365
3366 assert_eq!(result.text, "Final response");
3367
3368 let calls = on_turn_calls.read().await;
3370 assert_eq!(calls.len(), 1);
3371 assert_eq!(calls[0].0, "sess-123");
3372 assert_eq!(calls[0].1, "User prompt");
3373 assert_eq!(calls[0].2, "Final response");
3374 }
3375
3376 #[tokio::test]
3377 async fn test_agent_context_on_turn_complete_no_session() {
3378 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3379 "Response",
3380 )]));
3381
3382 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3383
3384 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3385 let on_turn_calls = provider.on_turn_calls.clone();
3386
3387 let config = AgentConfig {
3388 context_providers: vec![provider],
3389 ..Default::default()
3390 };
3391
3392 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3393
3394 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
3396
3397 let calls = on_turn_calls.read().await;
3399 assert!(calls.is_empty());
3400 }
3401
3402 #[tokio::test]
3403 async fn test_agent_build_augmented_system_prompt() {
3404 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
3405
3406 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3407
3408 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
3409 "doc-1",
3410 ContextType::Resource,
3411 "Auth uses JWT tokens.",
3412 )
3413 .with_source("viking://docs/auth")]);
3414
3415 let config = AgentConfig {
3416 system_prompt: Some("You are helpful.".to_string()),
3417 context_providers: vec![Arc::new(provider)],
3418 ..Default::default()
3419 };
3420
3421 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3422
3423 let context_results = agent.resolve_context("test", None).await;
3425 let augmented = agent.build_augmented_system_prompt(&context_results);
3426
3427 let augmented_str = augmented.unwrap();
3428 assert!(augmented_str.contains("You are helpful."));
3429 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
3430 assert!(augmented_str.contains("Auth uses JWT tokens."));
3431 }
3432
3433 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
3439 let mut events = Vec::new();
3440 while let Ok(event) = rx.try_recv() {
3441 events.push(event);
3442 }
3443 while let Some(event) = rx.recv().await {
3445 events.push(event);
3446 }
3447 events
3448 }
3449
3450 #[tokio::test]
3451 async fn test_agent_multi_turn_tool_chain() {
3452 let mock_client = Arc::new(MockLlmClient::new(vec![
3454 MockLlmClient::tool_call_response(
3456 "t1",
3457 "bash",
3458 serde_json::json!({"command": "echo step1"}),
3459 ),
3460 MockLlmClient::tool_call_response(
3462 "t2",
3463 "bash",
3464 serde_json::json!({"command": "echo step2"}),
3465 ),
3466 MockLlmClient::text_response("Completed both steps: step1 then step2"),
3468 ]));
3469
3470 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3471 let config = AgentConfig::default();
3472
3473 let agent = AgentLoop::new(
3474 mock_client.clone(),
3475 tool_executor,
3476 test_tool_context(),
3477 config,
3478 );
3479 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
3480
3481 assert_eq!(result.text, "Completed both steps: step1 then step2");
3482 assert_eq!(result.tool_calls_count, 2);
3483 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
3484
3485 assert_eq!(result.messages[0].role, "user");
3487 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
3493 }
3494
3495 #[tokio::test]
3496 async fn test_agent_conversation_history_preserved() {
3497 let existing_history = vec![
3499 Message::user("What is Rust?"),
3500 Message {
3501 role: "assistant".to_string(),
3502 content: vec![ContentBlock::Text {
3503 text: "Rust is a systems programming language.".to_string(),
3504 }],
3505 reasoning_content: None,
3506 },
3507 ];
3508
3509 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3510 "Rust was created by Graydon Hoare at Mozilla.",
3511 )]));
3512
3513 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3514 let agent = AgentLoop::new(
3515 mock_client.clone(),
3516 tool_executor,
3517 test_tool_context(),
3518 AgentConfig::default(),
3519 );
3520
3521 let result = agent
3522 .execute(&existing_history, "Who created it?", None)
3523 .await
3524 .unwrap();
3525
3526 assert_eq!(result.messages.len(), 4);
3528 assert_eq!(result.messages[0].text(), "What is Rust?");
3529 assert_eq!(
3530 result.messages[1].text(),
3531 "Rust is a systems programming language."
3532 );
3533 assert_eq!(result.messages[2].text(), "Who created it?");
3534 assert_eq!(
3535 result.messages[3].text(),
3536 "Rust was created by Graydon Hoare at Mozilla."
3537 );
3538 }
3539
3540 #[tokio::test]
3541 async fn test_agent_event_stream_completeness() {
3542 let mock_client = Arc::new(MockLlmClient::new(vec![
3544 MockLlmClient::tool_call_response(
3545 "t1",
3546 "bash",
3547 serde_json::json!({"command": "echo hi"}),
3548 ),
3549 MockLlmClient::text_response("Done"),
3550 ]));
3551
3552 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3553 let agent = AgentLoop::new(
3554 mock_client,
3555 tool_executor,
3556 test_tool_context(),
3557 AgentConfig::default(),
3558 );
3559
3560 let (tx, rx) = mpsc::channel(100);
3561 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
3562 assert_eq!(result.text, "Done");
3563
3564 let events = collect_events(rx).await;
3565
3566 let event_types: Vec<&str> = events
3568 .iter()
3569 .map(|e| match e {
3570 AgentEvent::Start { .. } => "Start",
3571 AgentEvent::TurnStart { .. } => "TurnStart",
3572 AgentEvent::TurnEnd { .. } => "TurnEnd",
3573 AgentEvent::ToolEnd { .. } => "ToolEnd",
3574 AgentEvent::End { .. } => "End",
3575 _ => "Other",
3576 })
3577 .collect();
3578
3579 assert_eq!(event_types.first(), Some(&"Start"));
3581 assert_eq!(event_types.last(), Some(&"End"));
3582
3583 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
3585 assert_eq!(turn_starts, 2);
3586
3587 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
3589 assert_eq!(tool_ends, 1);
3590 }
3591
3592 #[tokio::test]
3593 async fn test_agent_multiple_tools_single_turn() {
3594 let mock_client = Arc::new(MockLlmClient::new(vec![
3596 LlmResponse {
3597 message: Message {
3598 role: "assistant".to_string(),
3599 content: vec![
3600 ContentBlock::ToolUse {
3601 id: "t1".to_string(),
3602 name: "bash".to_string(),
3603 input: serde_json::json!({"command": "echo first"}),
3604 },
3605 ContentBlock::ToolUse {
3606 id: "t2".to_string(),
3607 name: "bash".to_string(),
3608 input: serde_json::json!({"command": "echo second"}),
3609 },
3610 ],
3611 reasoning_content: None,
3612 },
3613 usage: TokenUsage {
3614 prompt_tokens: 10,
3615 completion_tokens: 5,
3616 total_tokens: 15,
3617 cache_read_tokens: None,
3618 cache_write_tokens: None,
3619 },
3620 stop_reason: Some("tool_use".to_string()),
3621 },
3622 MockLlmClient::text_response("Both commands ran"),
3623 ]));
3624
3625 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3626 let agent = AgentLoop::new(
3627 mock_client.clone(),
3628 tool_executor,
3629 test_tool_context(),
3630 AgentConfig::default(),
3631 );
3632
3633 let result = agent.execute(&[], "Run both", None).await.unwrap();
3634
3635 assert_eq!(result.text, "Both commands ran");
3636 assert_eq!(result.tool_calls_count, 2);
3637 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
3641 assert_eq!(result.messages[1].role, "assistant");
3642 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
3645 }
3646
3647 #[tokio::test]
3648 async fn test_agent_token_usage_accumulation() {
3649 let mock_client = Arc::new(MockLlmClient::new(vec![
3651 MockLlmClient::tool_call_response(
3652 "t1",
3653 "bash",
3654 serde_json::json!({"command": "echo x"}),
3655 ),
3656 MockLlmClient::text_response("Done"),
3657 ]));
3658
3659 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3660 let agent = AgentLoop::new(
3661 mock_client,
3662 tool_executor,
3663 test_tool_context(),
3664 AgentConfig::default(),
3665 );
3666
3667 let result = agent.execute(&[], "test", None).await.unwrap();
3668
3669 assert_eq!(result.usage.prompt_tokens, 20);
3672 assert_eq!(result.usage.completion_tokens, 10);
3673 assert_eq!(result.usage.total_tokens, 30);
3674 }
3675
3676 #[tokio::test]
3677 async fn test_agent_system_prompt_passed() {
3678 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3680 "I am a coding assistant.",
3681 )]));
3682
3683 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3684 let config = AgentConfig {
3685 system_prompt: Some("You are a coding assistant.".to_string()),
3686 ..Default::default()
3687 };
3688
3689 let agent = AgentLoop::new(
3690 mock_client.clone(),
3691 tool_executor,
3692 test_tool_context(),
3693 config,
3694 );
3695 let result = agent.execute(&[], "What are you?", None).await.unwrap();
3696
3697 assert_eq!(result.text, "I am a coding assistant.");
3698 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3699 }
3700
3701 #[tokio::test]
3702 async fn test_agent_max_rounds_with_persistent_tool_calls() {
3703 let mut responses = Vec::new();
3705 for i in 0..15 {
3706 responses.push(MockLlmClient::tool_call_response(
3707 &format!("t{}", i),
3708 "bash",
3709 serde_json::json!({"command": format!("echo round{}", i)}),
3710 ));
3711 }
3712
3713 let mock_client = Arc::new(MockLlmClient::new(responses));
3714 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3715 let config = AgentConfig {
3716 max_tool_rounds: 5,
3717 ..Default::default()
3718 };
3719
3720 let agent = AgentLoop::new(
3721 mock_client.clone(),
3722 tool_executor,
3723 test_tool_context(),
3724 config,
3725 );
3726 let result = agent.execute(&[], "Loop forever", None).await;
3727
3728 assert!(result.is_err());
3729 let err = result.unwrap_err().to_string();
3730 assert!(err.contains("Max tool rounds (5) exceeded"));
3731 }
3732
3733 #[tokio::test]
3734 async fn test_agent_end_event_contains_final_text() {
3735 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3736 "Final answer here",
3737 )]));
3738
3739 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3740 let agent = AgentLoop::new(
3741 mock_client,
3742 tool_executor,
3743 test_tool_context(),
3744 AgentConfig::default(),
3745 );
3746
3747 let (tx, rx) = mpsc::channel(100);
3748 agent.execute(&[], "test", Some(tx)).await.unwrap();
3749
3750 let events = collect_events(rx).await;
3751 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
3752 assert!(end_event.is_some());
3753
3754 if let AgentEvent::End { text, usage } = end_event.unwrap() {
3755 assert_eq!(text, "Final answer here");
3756 assert_eq!(usage.total_tokens, 15);
3757 }
3758 }
3759}
3760
#[cfg(test)]
mod extra_agent_tests {
    // Additional unit tests for the agent module: configuration, event
    // serialization, goal extraction/checking, system-prompt augmentation,
    // lane partitioning, tool commands, queue wiring and plan execution.

    use super::*;
    use crate::agent::tests::MockLlmClient;
    use crate::queue::SessionQueueConfig;
    use crate::tools::ToolExecutor;
    use std::path::PathBuf;

    /// Tool context rooted at `/tmp`, shared by the tests in this module.
    fn test_tool_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
    }

    // ------------------------------------------------------------------
    // AgentConfig
    // ------------------------------------------------------------------

    #[test]
    fn test_agent_config_debug() {
        // AgentConfig has a hand-written Debug impl (trait-object fields are
        // rendered as presence flags / counts); it must still name the struct
        // and its plain fields.
        let config = AgentConfig {
            system_prompt: Some("You are helpful".to_string()),
            tools: vec![],
            max_tool_rounds: 10,
            permission_checker: None,
            confirmation_manager: None,
            context_providers: vec![],
            planning_enabled: true,
            goal_tracking: false,
            hook_engine: None,
            skill_registry: None,
        };
        let debug = format!("{:?}", config);
        assert!(debug.contains("AgentConfig"));
        assert!(debug.contains("planning_enabled"));
    }

    #[test]
    fn test_agent_config_default_values() {
        let config = AgentConfig::default();
        assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
        assert!(!config.planning_enabled);
        assert!(!config.goal_tracking);
        assert!(config.context_providers.is_empty());
    }

    // ------------------------------------------------------------------
    // AgentEvent serialization: each variant must serialize with its
    // snake_case type tag (and, where checked, its payload values).
    // ------------------------------------------------------------------

    #[test]
    fn test_agent_event_serialize_start() {
        let event = AgentEvent::Start {
            prompt: "Hello".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("agent_start"));
        assert!(json.contains("Hello"));
    }

    #[test]
    fn test_agent_event_serialize_text_delta() {
        let event = AgentEvent::TextDelta {
            text: "chunk".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("text_delta"));
        assert!(json.contains("chunk"));
    }

    #[test]
    fn test_agent_event_serialize_tool_start() {
        let event = AgentEvent::ToolStart {
            id: "t1".to_string(),
            name: "bash".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("tool_start"));
        assert!(json.contains("bash"));
    }

    #[test]
    fn test_agent_event_serialize_tool_end() {
        let event = AgentEvent::ToolEnd {
            id: "t1".to_string(),
            name: "bash".to_string(),
            output: "hello".to_string(),
            exit_code: 0,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("tool_end"));
    }

    #[test]
    fn test_agent_event_serialize_error() {
        let event = AgentEvent::Error {
            message: "oops".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("error"));
        assert!(json.contains("oops"));
    }

    #[test]
    fn test_agent_event_serialize_confirmation_required() {
        let event = AgentEvent::ConfirmationRequired {
            tool_id: "t1".to_string(),
            tool_name: "bash".to_string(),
            args: serde_json::json!({"cmd": "rm"}),
            timeout_ms: 30000,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("confirmation_required"));
    }

    #[test]
    fn test_agent_event_serialize_confirmation_received() {
        let event = AgentEvent::ConfirmationReceived {
            tool_id: "t1".to_string(),
            approved: true,
            reason: Some("safe".to_string()),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("confirmation_received"));
    }

    #[test]
    fn test_agent_event_serialize_confirmation_timeout() {
        let event = AgentEvent::ConfirmationTimeout {
            tool_id: "t1".to_string(),
            action_taken: "rejected".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("confirmation_timeout"));
    }

    #[test]
    fn test_agent_event_serialize_external_task_pending() {
        let event = AgentEvent::ExternalTaskPending {
            task_id: "task-1".to_string(),
            session_id: "sess-1".to_string(),
            lane: crate::hitl::SessionLane::Execute,
            command_type: "bash".to_string(),
            payload: serde_json::json!({}),
            timeout_ms: 60000,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("external_task_pending"));
    }

    #[test]
    fn test_agent_event_serialize_external_task_completed() {
        let event = AgentEvent::ExternalTaskCompleted {
            task_id: "task-1".to_string(),
            session_id: "sess-1".to_string(),
            success: false,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("external_task_completed"));
    }

    #[test]
    fn test_agent_event_serialize_permission_denied() {
        let event = AgentEvent::PermissionDenied {
            tool_id: "t1".to_string(),
            tool_name: "bash".to_string(),
            args: serde_json::json!({}),
            reason: "denied".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("permission_denied"));
    }

    #[test]
    fn test_agent_event_serialize_context_compacted() {
        let event = AgentEvent::ContextCompacted {
            session_id: "sess-1".to_string(),
            before_messages: 100,
            after_messages: 20,
            percent_before: 0.85,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("context_compacted"));
    }

    #[test]
    fn test_agent_event_serialize_turn_start() {
        let event = AgentEvent::TurnStart { turn: 3 };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("turn_start"));
    }

    #[test]
    fn test_agent_event_serialize_turn_end() {
        let event = AgentEvent::TurnEnd {
            turn: 3,
            usage: TokenUsage::default(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("turn_end"));
    }

    #[test]
    fn test_agent_event_serialize_end() {
        let event = AgentEvent::End {
            text: "Done".to_string(),
            usage: TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("agent_end"));
        assert!(json.contains("Done"));
    }

    // ------------------------------------------------------------------
    // AgentResult
    // ------------------------------------------------------------------

    #[test]
    fn test_agent_result_fields() {
        let result = AgentResult {
            text: "output".to_string(),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        };
        assert_eq!(result.text, "output");
        assert_eq!(result.messages.len(), 1);
        assert_eq!(result.tool_calls_count, 3);
    }

    // ------------------------------------------------------------------
    // AgentEvent serialization: context, queue, task, memory, subagent,
    // planning and goal events.
    // ------------------------------------------------------------------

    #[test]
    fn test_agent_event_serialize_context_resolving() {
        let event = AgentEvent::ContextResolving {
            providers: vec!["provider1".to_string(), "provider2".to_string()],
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("context_resolving"));
        assert!(json.contains("provider1"));
    }

    #[test]
    fn test_agent_event_serialize_context_resolved() {
        let event = AgentEvent::ContextResolved {
            total_items: 5,
            total_tokens: 1000,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("context_resolved"));
        assert!(json.contains("1000"));
    }

    #[test]
    fn test_agent_event_serialize_command_dead_lettered() {
        let event = AgentEvent::CommandDeadLettered {
            command_id: "cmd-1".to_string(),
            command_type: "bash".to_string(),
            lane: "execute".to_string(),
            error: "timeout".to_string(),
            attempts: 3,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("command_dead_lettered"));
        assert!(json.contains("cmd-1"));
    }

    #[test]
    fn test_agent_event_serialize_command_retry() {
        let event = AgentEvent::CommandRetry {
            command_id: "cmd-2".to_string(),
            command_type: "read".to_string(),
            lane: "query".to_string(),
            attempt: 2,
            delay_ms: 1000,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("command_retry"));
        assert!(json.contains("cmd-2"));
    }

    #[test]
    fn test_agent_event_serialize_queue_alert() {
        let event = AgentEvent::QueueAlert {
            level: "warning".to_string(),
            alert_type: "depth".to_string(),
            message: "Queue depth exceeded".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("queue_alert"));
        assert!(json.contains("warning"));
    }

    #[test]
    fn test_agent_event_serialize_task_updated() {
        let event = AgentEvent::TaskUpdated {
            session_id: "sess-1".to_string(),
            tasks: vec![],
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("task_updated"));
        assert!(json.contains("sess-1"));
    }

    #[test]
    fn test_agent_event_serialize_memory_stored() {
        let event = AgentEvent::MemoryStored {
            memory_id: "mem-1".to_string(),
            memory_type: "conversation".to_string(),
            importance: 0.8,
            tags: vec!["important".to_string()],
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("memory_stored"));
        assert!(json.contains("mem-1"));
    }

    #[test]
    fn test_agent_event_serialize_memory_recalled() {
        let event = AgentEvent::MemoryRecalled {
            memory_id: "mem-2".to_string(),
            content: "Previous conversation".to_string(),
            relevance: 0.9,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("memory_recalled"));
        assert!(json.contains("mem-2"));
    }

    #[test]
    fn test_agent_event_serialize_memories_searched() {
        let event = AgentEvent::MemoriesSearched {
            query: Some("search term".to_string()),
            tags: vec!["tag1".to_string()],
            result_count: 5,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("memories_searched"));
        assert!(json.contains("search term"));
    }

    #[test]
    fn test_agent_event_serialize_memory_cleared() {
        let event = AgentEvent::MemoryCleared {
            tier: "short_term".to_string(),
            count: 10,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("memory_cleared"));
        assert!(json.contains("short_term"));
    }

    #[test]
    fn test_agent_event_serialize_subagent_start() {
        let event = AgentEvent::SubagentStart {
            task_id: "task-1".to_string(),
            session_id: "child-sess".to_string(),
            parent_session_id: "parent-sess".to_string(),
            agent: "explore".to_string(),
            description: "Explore codebase".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("subagent_start"));
        assert!(json.contains("explore"));
    }

    #[test]
    fn test_agent_event_serialize_subagent_progress() {
        let event = AgentEvent::SubagentProgress {
            task_id: "task-1".to_string(),
            session_id: "child-sess".to_string(),
            status: "processing".to_string(),
            metadata: serde_json::json!({"progress": 50}),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("subagent_progress"));
        assert!(json.contains("processing"));
    }

    #[test]
    fn test_agent_event_serialize_subagent_end() {
        let event = AgentEvent::SubagentEnd {
            task_id: "task-1".to_string(),
            session_id: "child-sess".to_string(),
            agent: "explore".to_string(),
            output: "Found 10 files".to_string(),
            success: true,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("subagent_end"));
        assert!(json.contains("Found 10 files"));
    }

    #[test]
    fn test_agent_event_serialize_planning_start() {
        let event = AgentEvent::PlanningStart {
            prompt: "Build a web app".to_string(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("planning_start"));
        assert!(json.contains("Build a web app"));
    }

    #[test]
    fn test_agent_event_serialize_planning_end() {
        use crate::planning::{Complexity, ExecutionPlan};
        let plan = ExecutionPlan::new("Test goal".to_string(), Complexity::Simple);
        let event = AgentEvent::PlanningEnd {
            plan,
            estimated_steps: 3,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("planning_end"));
        assert!(json.contains("estimated_steps"));
    }

    #[test]
    fn test_agent_event_serialize_step_start() {
        let event = AgentEvent::StepStart {
            step_id: "step-1".to_string(),
            description: "Initialize project".to_string(),
            step_number: 1,
            total_steps: 5,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("step_start"));
        assert!(json.contains("Initialize project"));
    }

    #[test]
    fn test_agent_event_serialize_step_end() {
        let event = AgentEvent::StepEnd {
            step_id: "step-1".to_string(),
            status: TaskStatus::Completed,
            step_number: 1,
            total_steps: 5,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("step_end"));
        assert!(json.contains("step-1"));
    }

    #[test]
    fn test_agent_event_serialize_goal_extracted() {
        use crate::planning::AgentGoal;
        let goal = AgentGoal::new("Complete the task".to_string());
        let event = AgentEvent::GoalExtracted { goal };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("goal_extracted"));
    }

    #[test]
    fn test_agent_event_serialize_goal_progress() {
        let event = AgentEvent::GoalProgress {
            goal: "Build app".to_string(),
            progress: 0.5,
            completed_steps: 2,
            total_steps: 4,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("goal_progress"));
        assert!(json.contains("0.5"));
    }

    #[test]
    fn test_agent_event_serialize_goal_achieved() {
        let event = AgentEvent::GoalAchieved {
            goal: "Build app".to_string(),
            total_steps: 4,
            duration_ms: 5000,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("goal_achieved"));
        assert!(json.contains("5000"));
    }

    // ------------------------------------------------------------------
    // Goal extraction and achievement checking
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_extract_goal_with_json_response() {
        // A well-formed JSON reply is parsed into description + criteria.
        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let goal = agent.extract_goal("Build a web app").await.unwrap();
        assert_eq!(goal.description, "Build web app");
        assert_eq!(goal.success_criteria.len(), 2);
        assert_eq!(goal.success_criteria[0], "App runs on port 3000");
    }

    #[tokio::test]
    async fn test_extract_goal_fallback_on_non_json() {
        // A non-JSON reply does not error: the original prompt becomes the
        // goal description and two fallback criteria are produced.
        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            "Some non-JSON response",
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let goal = agent.extract_goal("Do something").await.unwrap();
        assert_eq!(goal.description, "Do something");
        assert_eq!(goal.success_criteria.len(), 2);
    }

    #[tokio::test]
    async fn test_check_goal_achievement_json_yes() {
        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let goal = crate::planning::AgentGoal::new("Test goal".to_string());
        let achieved = agent
            .check_goal_achievement(&goal, "All done")
            .await
            .unwrap();
        assert!(achieved);
    }

    #[tokio::test]
    async fn test_check_goal_achievement_fallback_not_done() {
        // An unparseable reply is treated as "not achieved" rather than an error.
        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            "invalid json",
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let goal = crate::planning::AgentGoal::new("Test goal".to_string());
        let achieved = agent
            .check_goal_achievement(&goal, "still working")
            .await
            .unwrap();
        assert!(!achieved);
    }

    // ------------------------------------------------------------------
    // build_augmented_system_prompt
    // ------------------------------------------------------------------

    #[test]
    fn test_build_augmented_system_prompt_empty_context() {
        // With no context results the base prompt is returned unchanged.
        let mock_client = Arc::new(MockLlmClient::new(vec![]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig {
            system_prompt: Some("Base prompt".to_string()),
            ..Default::default()
        };
        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);

        let result = agent.build_augmented_system_prompt(&[]);
        assert_eq!(result, Some("Base prompt".to_string()));
    }

    #[test]
    fn test_build_augmented_system_prompt_no_system_prompt() {
        // No base prompt and no context yields no system prompt at all.
        let mock_client = Arc::new(MockLlmClient::new(vec![]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let result = agent.build_augmented_system_prompt(&[]);
        assert_eq!(result, None);
    }

    #[test]
    fn test_build_augmented_system_prompt_with_context_no_base() {
        // Context alone is enough to produce a system prompt containing a
        // `<context` section with the item content.
        use crate::context::{ContextItem, ContextResult, ContextType};

        let mock_client = Arc::new(MockLlmClient::new(vec![]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let agent = AgentLoop::new(
            mock_client,
            tool_executor,
            test_tool_context(),
            AgentConfig::default(),
        );

        let context = vec![ContextResult {
            provider: "test".to_string(),
            items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
            total_tokens: 10,
            truncated: false,
        }];

        let result = agent.build_augmented_system_prompt(&context);
        assert!(result.is_some());
        let text = result.unwrap();
        assert!(text.contains("<context"));
        assert!(text.contains("Content"));
    }

    // ------------------------------------------------------------------
    // AgentResult trait impls
    // ------------------------------------------------------------------

    #[test]
    fn test_agent_result_clone() {
        let result = AgentResult {
            text: "output".to_string(),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        };
        let cloned = result.clone();
        assert_eq!(cloned.text, result.text);
        assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
    }

    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            text: "output".to_string(),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("AgentResult"));
        assert!(debug.contains("output"));
    }

    // ------------------------------------------------------------------
    // partition_by_lane: read-only tools go to the Query lane, everything
    // else (including unknown tools) stays sequential; relative order within
    // each group is preserved.
    // ------------------------------------------------------------------

    #[test]
    fn test_partition_by_lane_query_tools() {
        let tool_calls = vec![
            ToolCall {
                id: "t1".to_string(),
                name: "read".to_string(),
                args: serde_json::json!({"file": "a.rs"}),
            },
            ToolCall {
                id: "t2".to_string(),
                name: "glob".to_string(),
                args: serde_json::json!({"pattern": "**/*.rs"}),
            },
            ToolCall {
                id: "t3".to_string(),
                name: "grep".to_string(),
                args: serde_json::json!({"pattern": "fn main"}),
            },
            ToolCall {
                id: "t4".to_string(),
                name: "ls".to_string(),
                args: serde_json::json!({"path": "/tmp"}),
            },
            ToolCall {
                id: "t5".to_string(),
                name: "search".to_string(),
                args: serde_json::json!({"query": "error"}),
            },
            ToolCall {
                id: "t6".to_string(),
                name: "list_files".to_string(),
                args: serde_json::json!({}),
            },
        ];

        let (query, sequential) = partition_by_lane(&tool_calls);
        assert_eq!(
            query.len(),
            6,
            "all read-only tools should be in query lane"
        );
        assert_eq!(sequential.len(), 0);
    }

    #[test]
    fn test_partition_by_lane_execute_tools() {
        let tool_calls = vec![
            ToolCall {
                id: "t1".to_string(),
                name: "bash".to_string(),
                args: serde_json::json!({"command": "ls"}),
            },
            ToolCall {
                id: "t2".to_string(),
                name: "write".to_string(),
                args: serde_json::json!({"file": "a.rs", "content": ""}),
            },
            ToolCall {
                id: "t3".to_string(),
                name: "edit".to_string(),
                args: serde_json::json!({}),
            },
            ToolCall {
                id: "t4".to_string(),
                name: "delete".to_string(),
                args: serde_json::json!({}),
            },
        ];

        let (query, sequential) = partition_by_lane(&tool_calls);
        assert_eq!(query.len(), 0);
        assert_eq!(sequential.len(), 4, "all write tools should be sequential");
    }

    #[test]
    fn test_partition_by_lane_mixed() {
        let tool_calls = vec![
            ToolCall {
                id: "t1".to_string(),
                name: "read".to_string(),
                args: serde_json::json!({"file": "a.rs"}),
            },
            ToolCall {
                id: "t2".to_string(),
                name: "bash".to_string(),
                args: serde_json::json!({"command": "cargo build"}),
            },
            ToolCall {
                id: "t3".to_string(),
                name: "glob".to_string(),
                args: serde_json::json!({"pattern": "*.rs"}),
            },
            ToolCall {
                id: "t4".to_string(),
                name: "write".to_string(),
                args: serde_json::json!({"file": "b.rs", "content": ""}),
            },
            ToolCall {
                id: "t5".to_string(),
                name: "grep".to_string(),
                args: serde_json::json!({"pattern": "test"}),
            },
        ];

        let (query, sequential) = partition_by_lane(&tool_calls);
        assert_eq!(query.len(), 3, "read/glob/grep → Query");
        assert_eq!(sequential.len(), 2, "bash/write → Sequential");

        // Original relative order must be kept inside each partition.
        assert_eq!(query[0].name, "read");
        assert_eq!(query[1].name, "glob");
        assert_eq!(query[2].name, "grep");
        assert_eq!(sequential[0].name, "bash");
        assert_eq!(sequential[1].name, "write");
    }

    #[test]
    fn test_partition_by_lane_empty() {
        let tool_calls: Vec<ToolCall> = vec![];
        let (query, sequential) = partition_by_lane(&tool_calls);
        assert!(query.is_empty());
        assert!(sequential.is_empty());
    }

    #[test]
    fn test_partition_by_lane_unknown_tool_goes_sequential() {
        // Unknown tools are conservatively treated as non-read-only.
        let tool_calls = vec![ToolCall {
            id: "t1".to_string(),
            name: "custom_tool".to_string(),
            args: serde_json::json!({}),
        }];

        let (query, sequential) = partition_by_lane(&tool_calls);
        assert_eq!(query.len(), 0);
        assert_eq!(sequential.len(), 1);
    }

    // ------------------------------------------------------------------
    // ToolCommand (SessionCommand adapter for tool calls)
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_tool_command_command_type() {
        let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let cmd = ToolCommand {
            tool_executor: executor,
            tool_name: "read".to_string(),
            tool_args: serde_json::json!({"file": "test.rs"}),
            skill_registry: None,
            tool_context: test_tool_context(),
        };
        assert_eq!(cmd.command_type(), "read");
    }

    #[tokio::test]
    async fn test_tool_command_payload() {
        let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let args = serde_json::json!({"file": "test.rs", "offset": 10});
        let cmd = ToolCommand {
            tool_executor: executor,
            tool_name: "read".to_string(),
            tool_args: args.clone(),
            skill_registry: None,
            tool_context: test_tool_context(),
        };
        assert_eq!(cmd.payload(), args);
    }

    // ------------------------------------------------------------------
    // Command queue wiring
    // ------------------------------------------------------------------

    #[tokio::test(flavor = "multi_thread")]
    async fn test_agent_loop_with_queue() {
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            "Hello",
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig::default();

        let (event_tx, _) = broadcast::channel(100);
        let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
            .await
            .unwrap();

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
            .with_queue(Arc::new(queue));

        assert!(agent.command_queue.is_some());
    }

    #[tokio::test]
    async fn test_agent_loop_without_queue() {
        let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
            "Hello",
        )]));
        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig::default();

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);

        assert!(agent.command_queue.is_none());
    }

    // ------------------------------------------------------------------
    // execute_plan: steps without dependencies may run together, dependent
    // steps wait for their prerequisites, and a failed step blocks its
    // dependents. Each step consumes one mock response.
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_execute_plan_parallel_independent() {
        use crate::planning::{Complexity, ExecutionPlan, Task};

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::text_response("Step 1 done"),
            MockLlmClient::text_response("Step 2 done"),
            MockLlmClient::text_response("Step 3 done"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig::default();
        let agent = AgentLoop::new(
            mock_client.clone(),
            tool_executor,
            test_tool_context(),
            config,
        );

        let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
        plan.add_step(Task::new("s1", "First step"));
        plan.add_step(Task::new("s2", "Second step"));
        plan.add_step(Task::new("s3", "Third step"));

        let (tx, mut rx) = mpsc::channel(100);
        let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();

        // Usage is accumulated across all three single-response steps.
        assert_eq!(result.usage.total_tokens, 45);

        // Drain the already-buffered events; closing first guarantees the
        // loop terminates once the buffer is empty.
        let mut step_starts = Vec::new();
        let mut step_ends = Vec::new();
        rx.close();
        while let Some(event) = rx.recv().await {
            match event {
                AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
                AgentEvent::StepEnd {
                    step_id, status, ..
                } => {
                    assert_eq!(status, TaskStatus::Completed);
                    step_ends.push(step_id);
                }
                _ => {}
            }
        }
        assert_eq!(step_starts.len(), 3);
        assert_eq!(step_ends.len(), 3);
    }

    #[tokio::test]
    async fn test_execute_plan_respects_dependencies() {
        use crate::planning::{Complexity, ExecutionPlan, Task};

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::text_response("Step 1 done"),
            MockLlmClient::text_response("Step 2 done"),
            MockLlmClient::text_response("Step 3 done"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig::default();
        let agent = AgentLoop::new(
            mock_client.clone(),
            tool_executor,
            test_tool_context(),
            config,
        );

        let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
        plan.add_step(Task::new("s1", "Independent A"));
        plan.add_step(Task::new("s2", "Independent B"));
        plan.add_step(
            Task::new("s3", "Depends on A+B")
                .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
        );

        let (tx, mut rx) = mpsc::channel(100);
        let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();

        assert_eq!(result.usage.total_tokens, 45);

        // Record start/end events in arrival order to check sequencing.
        let mut events = Vec::new();
        rx.close();
        while let Some(event) = rx.recv().await {
            match &event {
                AgentEvent::StepStart { step_id, .. } => {
                    events.push(format!("start:{}", step_id));
                }
                AgentEvent::StepEnd { step_id, .. } => {
                    events.push(format!("end:{}", step_id));
                }
                _ => {}
            }
        }

        let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
        let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
        let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
        assert!(
            s3_start > s1_end,
            "s3 started before s1 ended: {:?}",
            events
        );
        assert!(
            s3_start > s2_end,
            "s3 started before s2 ended: {:?}",
            events
        );

        // Only non-emptiness is pinned: the previous form
        // `contains("Step 3 done") || !is_empty()` reduced to exactly this,
        // because the first disjunct implies the second.
        assert!(
            !result.text.is_empty(),
            "plan result text should not be empty"
        );
    }

    #[tokio::test]
    async fn test_execute_plan_handles_step_failure() {
        use crate::planning::{Complexity, ExecutionPlan, Task};

        // Only two responses are queued. s1 and s3 (both independent) are
        // expected to consume them; s2 then presumably fails because the mock
        // is exhausted, and s4 (depending on s2) must never start.
        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::text_response("s1 done"),
            MockLlmClient::text_response("s3 done"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
        let config = AgentConfig::default();
        let agent = AgentLoop::new(
            mock_client.clone(),
            tool_executor,
            test_tool_context(),
            config,
        );

        let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
        plan.add_step(Task::new("s1", "Independent step"));
        plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
        plan.add_step(Task::new("s3", "Another independent"));
        plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));

        let (tx, mut rx) = mpsc::channel(100);
        let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();

        let mut completed_steps = Vec::new();
        let mut failed_steps = Vec::new();
        rx.close();
        while let Some(event) = rx.recv().await {
            if let AgentEvent::StepEnd {
                step_id, status, ..
            } = event
            {
                match status {
                    TaskStatus::Completed => completed_steps.push(step_id),
                    TaskStatus::Failed => failed_steps.push(step_id),
                    _ => {}
                }
            }
        }

        assert!(
            completed_steps.contains(&"s1".to_string()),
            "s1 should complete"
        );
        assert!(
            completed_steps.contains(&"s3".to_string()),
            "s3 should complete"
        );
        assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
        // s4 emits no StepEnd at all because its dependency failed.
        assert!(
            !completed_steps.contains(&"s4".to_string()),
            "s4 should not complete"
        );
        assert!(
            !failed_steps.contains(&"s4".to_string()),
            "s4 should not fail (never started)"
        );
    }
}