1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
/// Default cap on the number of agent turns in one execution; used as the
/// default value of [`AgentConfig::max_tool_rounds`].
const MAX_TOOL_ROUNDS: usize = 50;
38
/// Configuration for an [`AgentLoop`].
///
/// Most collaborators are optional trait objects; see `Default` for the
/// baseline values.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots assembled into the system prompt via `SystemPromptSlots::build`.
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions offered to the LLM (narrowed by `tool_index` when set).
    pub tools: Vec<ToolDefinition>,
    /// Maximum number of turns per execution; exceeding it ends the run with
    /// an error.
    pub max_tool_rounds: usize,
    /// Optional security provider; the loop passes the effective prompt to
    /// its `taint_input`.
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// Optional permission checker — presumably consulted before tool calls
    /// (not visible in this section; confirm).
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// Optional human-in-the-loop confirmation provider — presumably used to
    /// confirm tool calls (not visible in this section; confirm).
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Context providers queried concurrently before the first turn; their
    /// results are appended to the system prompt, and they are notified via
    /// `on_turn_complete`.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When true, `execute_with_session` routes through `execute_with_planning`.
    pub planning_enabled: bool,
    /// Presumably enables goal extraction/progress events (usage not visible
    /// here; confirm).
    pub goal_tracking: bool,
    /// Hook executor fired at lifecycle points (pre-prompt, generate
    /// start/end, pre/post tool use, post-response, on-error).
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Skill registry; instruction skills declaring `allowed_tools` restrict
    /// which tools a queued `ToolCommand` may run.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// Presumably the number of retries allowed for unparseable LLM output
    /// (usage not visible here; confirm).
    pub max_parse_retries: u32,
    /// Per-call timeout for tools executed directly (`execute_tool_timed`);
    /// `None` disables the timeout.
    pub tool_timeout_ms: Option<u64>,
    /// Maximum LLM call attempts per turn (clamped to at least 1) before the
    /// turn fails. In streaming mode at most one retry is made.
    pub circuit_breaker_threshold: u32,
    /// Presumably enables automatic context compaction (usage not visible
    /// here; confirm).
    pub auto_compact: bool,
    /// Presumably the fraction of `max_context_tokens` at which compaction
    /// triggers (usage not visible here; confirm).
    pub auto_compact_threshold: f32,
    /// Presumably the model's context window size in tokens (confirm).
    pub max_context_tokens: usize,
    /// Optional secondary LLM client; `AgentLoop` also holds its own client.
    /// NOTE(review): purpose not visible in this section — confirm.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional memory; `recall_similar(prompt, 5)` results are added to the
    /// system prompt under "Relevant past experience".
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// Presumably enables re-prompting when a reply looks unfinished
    /// (see `looks_incomplete`; confirm).
    pub continuation_enabled: bool,
    /// Presumably caps the number of continuation re-prompts (confirm).
    pub max_continuation_turns: u32,
    /// When set, only tools matched by searching the latest user text are
    /// offered to the LLM.
    pub tool_index: Option<ToolIndex>,
}
115
116impl std::fmt::Debug for AgentConfig {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("AgentConfig")
119 .field("prompt_slots", &self.prompt_slots)
120 .field("tools", &self.tools)
121 .field("max_tool_rounds", &self.max_tool_rounds)
122 .field("security_provider", &self.security_provider.is_some())
123 .field("permission_checker", &self.permission_checker.is_some())
124 .field("confirmation_manager", &self.confirmation_manager.is_some())
125 .field("context_providers", &self.context_providers.len())
126 .field("planning_enabled", &self.planning_enabled)
127 .field("goal_tracking", &self.goal_tracking)
128 .field("hook_engine", &self.hook_engine.is_some())
129 .field(
130 "skill_registry",
131 &self.skill_registry.as_ref().map(|r| r.len()),
132 )
133 .field("max_parse_retries", &self.max_parse_retries)
134 .field("tool_timeout_ms", &self.tool_timeout_ms)
135 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
136 .field("auto_compact", &self.auto_compact)
137 .field("auto_compact_threshold", &self.auto_compact_threshold)
138 .field("max_context_tokens", &self.max_context_tokens)
139 .field("continuation_enabled", &self.continuation_enabled)
140 .field("max_continuation_turns", &self.max_continuation_turns)
141 .field("memory", &self.memory.is_some())
142 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
143 .finish()
144 }
145}
146
147impl Default for AgentConfig {
148 fn default() -> Self {
149 Self {
150 prompt_slots: SystemPromptSlots::default(),
151 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
153 security_provider: None,
154 permission_checker: None,
155 confirmation_manager: None,
156 context_providers: Vec::new(),
157 planning_enabled: false,
158 goal_tracking: false,
159 hook_engine: None,
160 skill_registry: None,
161 max_parse_retries: 2,
162 tool_timeout_ms: None,
163 circuit_breaker_threshold: 3,
164 auto_compact: false,
165 auto_compact_threshold: 0.80,
166 max_context_tokens: 200_000,
167 llm_client: None,
168 memory: None,
169 continuation_enabled: true,
170 max_continuation_turns: 3,
171 tool_index: None,
172 }
173 }
174}
175
/// Events the agent loop sends to an optional `mpsc::Sender<AgentEvent>`.
///
/// Serialized as an internally tagged object whose `type` field holds the
/// variant's `rename` value, e.g. `{"type": "text_delta", "text": "..."}`.
/// Many variants are produced outside this section of the file. Descriptions
/// marked "presumably" are inferred from names and should be confirmed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// A run started; `prompt` is the effective prompt for the run.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// A new turn began; `turn` counts from 1.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// A streamed text fragment from the LLM.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// The LLM started emitting a tool call with this id and name.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// A streamed fragment of the current tool call's input.
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    /// Presumably: a tool call finished with this output and exit code.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
    },

    /// Incremental output streamed by a running tool.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// Presumably: a turn finished, with that turn's token usage.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// Presumably: the run finished with final text and total usage.
    #[serde(rename = "agent_end")]
    End { text: String, usage: TokenUsage },

    /// The run failed, e.g. max tool rounds exceeded or the LLM call failed.
    #[serde(rename = "error")]
    Error { message: String },

    /// Presumably: a tool call awaits human confirmation within `timeout_ms`.
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    /// Presumably: a confirmation decision arrived for `tool_id`.
    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    /// Presumably: a confirmation timed out and `action_taken` was applied.
    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    /// Presumably: a command was handed to an external executor on `lane`.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    /// Presumably: an external task completed.
    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// Presumably: a tool call was denied by the permission checker.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// Context providers (listed by name) are about to be queried.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Context resolution finished; totals are summed over all providers.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    /// Presumably a command-queue event: a command exhausted its retries.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    /// Presumably a command-queue event: a command is being retried.
    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    /// Presumably a command-queue health alert.
    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// Presumably: the session's task list changed.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// Presumably: a memory item was stored.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// A memory item was recalled for the prompt, with its relevance score.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// A memory search ran and returned `result_count` items.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// Presumably: `count` memories were cleared from `tier`.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    /// Presumably: a subagent started for a delegated task.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    /// Presumably: progress report from a running subagent.
    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    /// Presumably: a subagent finished.
    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    /// Presumably: planning started for `prompt`.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    /// Presumably: a plan was produced.
    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    /// Presumably: a plan step started.
    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    /// Presumably: a plan step ended with `status`.
    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    /// Presumably: a goal was extracted from the prompt.
    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    /// Presumably: goal progress update.
    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    /// Presumably: the goal was achieved.
    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// Presumably: the conversation was compacted.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// Presumably: persisting session state failed.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },
}
495
/// Outcome of a completed agent run.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final response text (reported to the post-response hook).
    pub text: String,
    /// Presumably the conversation after the run, including new messages
    /// (construction not visible in this section; confirm).
    pub messages: Vec<Message>,
    /// Token usage summed over the run's LLM calls.
    pub usage: TokenUsage,
    /// Number of tool calls made during the run.
    pub tool_calls_count: usize,
}
504
/// A single tool invocation packaged as a [`SessionCommand`] for the session
/// lane queue.
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Name of the tool to run; also reported as the command type.
    tool_name: String,
    /// JSON arguments for the tool; also reported as the command payload.
    tool_args: Value,
    /// Execution context passed to the executor.
    tool_context: ToolContext,
    /// Optional skills whose `allowed_tools` lists may forbid this tool.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
519
impl ToolCommand {
    /// Builds a command that runs `tool_name` with `tool_args` in
    /// `tool_context`. Skill restrictions from `skill_registry` are checked
    /// when the command executes.
    pub fn new(
        tool_executor: Arc<ToolExecutor>,
        tool_name: String,
        tool_args: Value,
        tool_context: ToolContext,
        skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    ) -> Self {
        Self {
            tool_executor,
            tool_name,
            tool_args,
            tool_context,
            skill_registry,
        }
    }
}
538
539#[async_trait]
540impl SessionCommand for ToolCommand {
541 async fn execute(&self) -> Result<Value> {
542 if let Some(registry) = &self.skill_registry {
544 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
545
546 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
548
549 if has_restrictions {
550 let mut allowed = false;
551
552 for skill in &instruction_skills {
553 if skill.is_tool_allowed(&self.tool_name) {
554 allowed = true;
555 break;
556 }
557 }
558
559 if !allowed {
560 return Err(anyhow::anyhow!(
561 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
562 self.tool_name
563 ));
564 }
565 }
566 }
567
568 let result = self
570 .tool_executor
571 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
572 .await?;
573 Ok(serde_json::json!({
574 "output": result.output,
575 "exit_code": result.exit_code,
576 "metadata": result.metadata,
577 }))
578 }
579
580 fn command_type(&self) -> &str {
581 &self.tool_name
582 }
583
584 fn payload(&self) -> Value {
585 self.tool_args.clone()
586 }
587}
588
/// Drives an agent run: repeatedly calls the LLM and (presumably) executes the
/// tools it requests, bounded by `config.max_tool_rounds`.
#[derive(Clone)]
pub struct AgentLoop {
    /// Primary LLM client used for every turn.
    llm_client: Arc<dyn LlmClient>,
    /// Executor for tool calls (direct or wrapped in a queued `ToolCommand`).
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context; cloned per tool call.
    tool_context: ToolContext,
    /// Agent configuration.
    config: AgentConfig,
    /// Optional shared tool metrics (set via `with_tool_metrics`; not read in
    /// this section).
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// Optional lane queue that tool calls are routed through when present.
    command_queue: Option<Arc<SessionLaneQueue>>,
}
605
606impl AgentLoop {
    /// Creates an agent loop with no tool metrics and no command queue.
    /// Use `with_tool_metrics` / `with_queue` to attach them.
    pub fn new(
        llm_client: Arc<dyn LlmClient>,
        tool_executor: Arc<ToolExecutor>,
        tool_context: ToolContext,
        config: AgentConfig,
    ) -> Self {
        Self {
            llm_client,
            tool_executor,
            tool_context,
            config,
            tool_metrics: None,
            command_queue: None,
        }
    }
622
623 pub fn with_tool_metrics(
625 mut self,
626 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
627 ) -> Self {
628 self.tool_metrics = Some(metrics);
629 self
630 }
631
632 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
637 self.command_queue = Some(queue);
638 self
639 }
640
641 async fn execute_tool_timed(
647 &self,
648 name: &str,
649 args: &serde_json::Value,
650 ctx: &ToolContext,
651 ) -> anyhow::Result<crate::tools::ToolResult> {
652 let fut = self.tool_executor.execute_with_context(name, args, ctx);
653 if let Some(timeout_ms) = self.config.tool_timeout_ms {
654 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
655 Ok(result) => result,
656 Err(_) => Err(anyhow::anyhow!(
657 "Tool '{}' timed out after {}ms",
658 name,
659 timeout_ms
660 )),
661 }
662 } else {
663 fut.await
664 }
665 }
666
667 fn tool_result_to_tuple(
669 result: anyhow::Result<crate::tools::ToolResult>,
670 ) -> (
671 String,
672 i32,
673 bool,
674 Option<serde_json::Value>,
675 Vec<crate::llm::Attachment>,
676 ) {
677 match result {
678 Ok(r) => (
679 r.output,
680 r.exit_code,
681 r.exit_code != 0,
682 r.metadata,
683 r.images,
684 ),
685 Err(e) => (
686 format!("Tool execution error: {}", e),
687 1,
688 true,
689 None,
690 Vec::new(),
691 ),
692 }
693 }
694
695 async fn execute_tool_queued_or_direct(
697 &self,
698 name: &str,
699 args: &serde_json::Value,
700 ctx: &ToolContext,
701 ) -> anyhow::Result<crate::tools::ToolResult> {
702 if let Some(ref queue) = self.command_queue {
703 let command = ToolCommand::new(
704 Arc::clone(&self.tool_executor),
705 name.to_string(),
706 args.clone(),
707 ctx.clone(),
708 self.config.skill_registry.clone(),
709 );
710 let rx = queue.submit_by_tool(name, Box::new(command)).await;
711 match rx.await {
712 Ok(Ok(value)) => {
713 let output = value["output"]
714 .as_str()
715 .ok_or_else(|| {
716 anyhow::anyhow!(
717 "Queue result missing 'output' field for tool '{}'",
718 name
719 )
720 })?
721 .to_string();
722 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
723 return Ok(crate::tools::ToolResult {
724 name: name.to_string(),
725 output,
726 exit_code,
727 metadata: None,
728 images: Vec::new(),
729 });
730 }
731 Ok(Err(e)) => {
732 tracing::warn!(
733 "Queue execution failed for tool '{}', falling back to direct: {}",
734 name,
735 e
736 );
737 }
738 Err(_) => {
739 tracing::warn!(
740 "Queue channel closed for tool '{}', falling back to direct",
741 name
742 );
743 }
744 }
745 }
746 self.execute_tool_timed(name, args, ctx).await
747 }
748
749 async fn call_llm(
760 &self,
761 messages: &[Message],
762 system: Option<&str>,
763 event_tx: &Option<mpsc::Sender<AgentEvent>>,
764 ) -> anyhow::Result<LlmResponse> {
765 let tools = if let Some(ref index) = self.config.tool_index {
767 let query = messages
768 .iter()
769 .rev()
770 .find(|m| m.role == "user")
771 .and_then(|m| {
772 m.content.iter().find_map(|b| match b {
773 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
774 _ => None,
775 })
776 })
777 .unwrap_or("");
778 let matches = index.search(query, index.len());
779 let matched_names: std::collections::HashSet<&str> =
780 matches.iter().map(|m| m.name.as_str()).collect();
781 self.config
782 .tools
783 .iter()
784 .filter(|t| matched_names.contains(t.name.as_str()))
785 .cloned()
786 .collect::<Vec<_>>()
787 } else {
788 self.config.tools.clone()
789 };
790
791 if event_tx.is_some() {
792 let mut stream_rx = self
793 .llm_client
794 .complete_streaming(messages, system, &tools)
795 .await
796 .context("LLM streaming call failed")?;
797
798 let mut final_response: Option<LlmResponse> = None;
799 while let Some(event) = stream_rx.recv().await {
800 match event {
801 crate::llm::StreamEvent::TextDelta(text) => {
802 if let Some(tx) = event_tx {
803 tx.send(AgentEvent::TextDelta { text }).await.ok();
804 }
805 }
806 crate::llm::StreamEvent::ToolUseStart { id, name } => {
807 if let Some(tx) = event_tx {
808 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
809 }
810 }
811 crate::llm::StreamEvent::ToolUseInputDelta(delta) => {
812 if let Some(tx) = event_tx {
813 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
814 }
815 }
816 crate::llm::StreamEvent::Done(resp) => {
817 final_response = Some(resp);
818 }
819 }
820 }
821 final_response.context("Stream ended without final response")
822 } else {
823 self.llm_client
824 .complete(messages, system, &tools)
825 .await
826 .context("LLM call failed")
827 }
828 }
829
830 fn streaming_tool_context(
839 &self,
840 event_tx: &Option<mpsc::Sender<AgentEvent>>,
841 tool_id: &str,
842 tool_name: &str,
843 ) -> ToolContext {
844 let mut ctx = self.tool_context.clone();
845 if let Some(agent_tx) = event_tx {
846 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
847 ctx.event_tx = Some(tool_tx);
848
849 let agent_tx = agent_tx.clone();
850 let tool_id = tool_id.to_string();
851 let tool_name = tool_name.to_string();
852 tokio::spawn(async move {
853 while let Some(event) = tool_rx.recv().await {
854 match event {
855 ToolStreamEvent::OutputDelta(delta) => {
856 agent_tx
857 .send(AgentEvent::ToolOutputDelta {
858 id: tool_id.clone(),
859 name: tool_name.clone(),
860 delta,
861 })
862 .await
863 .ok();
864 }
865 }
866 }
867 });
868 }
869 ctx
870 }
871
872 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
876 if self.config.context_providers.is_empty() {
877 return Vec::new();
878 }
879
880 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
881
882 let futures = self
883 .config
884 .context_providers
885 .iter()
886 .map(|p| p.query(&query));
887 let outcomes = join_all(futures).await;
888
889 outcomes
890 .into_iter()
891 .enumerate()
892 .filter_map(|(i, r)| match r {
893 Ok(result) if !result.is_empty() => Some(result),
894 Ok(_) => None,
895 Err(e) => {
896 tracing::warn!(
897 "Context provider '{}' failed: {}",
898 self.config.context_providers[i].name(),
899 e
900 );
901 None
902 }
903 })
904 .collect()
905 }
906
907 fn looks_incomplete(text: &str) -> bool {
915 let t = text.trim();
916 if t.is_empty() {
917 return true;
918 }
919 if t.len() < 80 && !t.contains('\n') {
921 let ends_continuation =
924 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
925 if ends_continuation {
926 return true;
927 }
928 }
929 let incomplete_phrases = [
931 "i'll ",
932 "i will ",
933 "let me ",
934 "i need to ",
935 "i should ",
936 "next, i",
937 "first, i",
938 "now i",
939 "i'll start",
940 "i'll begin",
941 "i'll now",
942 "let's start",
943 "let's begin",
944 "to do this",
945 "i'm going to",
946 ];
947 let lower = t.to_lowercase();
948 for phrase in &incomplete_phrases {
949 if lower.contains(phrase) {
950 return true;
951 }
952 }
953 false
954 }
955
    /// Builds the base system prompt from the configured prompt slots.
    fn system_prompt(&self) -> String {
        self.config.prompt_slots.build()
    }
960
961 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
963 let base = self.system_prompt();
964 if context_results.is_empty() {
965 return Some(base);
966 }
967
968 let context_xml: String = context_results
970 .iter()
971 .map(|r| r.to_xml())
972 .collect::<Vec<_>>()
973 .join("\n\n");
974
975 Some(format!("{}\n\n{}", base, context_xml))
976 }
977
978 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
980 let futures = self
981 .config
982 .context_providers
983 .iter()
984 .map(|p| p.on_turn_complete(session_id, prompt, response));
985 let outcomes = join_all(futures).await;
986
987 for (i, result) in outcomes.into_iter().enumerate() {
988 if let Err(e) = result {
989 tracing::warn!(
990 "Context provider '{}' on_turn_complete failed: {}",
991 self.config.context_providers[i].name(),
992 e
993 );
994 }
995 }
996 }
997
998 async fn fire_pre_tool_use(
1001 &self,
1002 session_id: &str,
1003 tool_name: &str,
1004 args: &serde_json::Value,
1005 ) -> Option<HookResult> {
1006 if let Some(he) = &self.config.hook_engine {
1007 let event = HookEvent::PreToolUse(PreToolUseEvent {
1008 session_id: session_id.to_string(),
1009 tool: tool_name.to_string(),
1010 args: args.clone(),
1011 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1012 recent_tools: Vec::new(),
1013 });
1014 let result = he.fire(&event).await;
1015 if result.is_block() {
1016 return Some(result);
1017 }
1018 }
1019 None
1020 }
1021
1022 async fn fire_post_tool_use(
1024 &self,
1025 session_id: &str,
1026 tool_name: &str,
1027 args: &serde_json::Value,
1028 output: &str,
1029 success: bool,
1030 duration_ms: u64,
1031 ) {
1032 if let Some(he) = &self.config.hook_engine {
1033 let event = HookEvent::PostToolUse(PostToolUseEvent {
1034 session_id: session_id.to_string(),
1035 tool: tool_name.to_string(),
1036 args: args.clone(),
1037 result: ToolResultData {
1038 success,
1039 output: output.to_string(),
1040 exit_code: if success { Some(0) } else { Some(1) },
1041 duration_ms,
1042 },
1043 });
1044 let he = Arc::clone(he);
1045 tokio::spawn(async move {
1046 let _ = he.fire(&event).await;
1047 });
1048 }
1049 }
1050
1051 async fn fire_generate_start(
1053 &self,
1054 session_id: &str,
1055 prompt: &str,
1056 system_prompt: &Option<String>,
1057 ) {
1058 if let Some(he) = &self.config.hook_engine {
1059 let event = HookEvent::GenerateStart(GenerateStartEvent {
1060 session_id: session_id.to_string(),
1061 prompt: prompt.to_string(),
1062 system_prompt: system_prompt.clone(),
1063 model_provider: String::new(),
1064 model_name: String::new(),
1065 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1066 });
1067 let _ = he.fire(&event).await;
1068 }
1069 }
1070
1071 async fn fire_generate_end(
1073 &self,
1074 session_id: &str,
1075 prompt: &str,
1076 response: &LlmResponse,
1077 duration_ms: u64,
1078 ) {
1079 if let Some(he) = &self.config.hook_engine {
1080 let tool_calls: Vec<ToolCallInfo> = response
1081 .tool_calls()
1082 .iter()
1083 .map(|tc| ToolCallInfo {
1084 name: tc.name.clone(),
1085 args: tc.args.clone(),
1086 })
1087 .collect();
1088
1089 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1090 session_id: session_id.to_string(),
1091 prompt: prompt.to_string(),
1092 response_text: response.text().to_string(),
1093 tool_calls,
1094 usage: TokenUsageInfo {
1095 prompt_tokens: response.usage.prompt_tokens as i32,
1096 completion_tokens: response.usage.completion_tokens as i32,
1097 total_tokens: response.usage.total_tokens as i32,
1098 },
1099 duration_ms,
1100 });
1101 let _ = he.fire(&event).await;
1102 }
1103 }
1104
1105 async fn fire_pre_prompt(
1108 &self,
1109 session_id: &str,
1110 prompt: &str,
1111 system_prompt: &Option<String>,
1112 message_count: usize,
1113 ) -> Option<String> {
1114 if let Some(he) = &self.config.hook_engine {
1115 let event = HookEvent::PrePrompt(PrePromptEvent {
1116 session_id: session_id.to_string(),
1117 prompt: prompt.to_string(),
1118 system_prompt: system_prompt.clone(),
1119 message_count,
1120 });
1121 let result = he.fire(&event).await;
1122 if let HookResult::Continue(Some(modified)) = result {
1123 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1125 return Some(new_prompt.to_string());
1126 }
1127 }
1128 }
1129 None
1130 }
1131
1132 async fn fire_post_response(
1134 &self,
1135 session_id: &str,
1136 response_text: &str,
1137 tool_calls_count: usize,
1138 usage: &TokenUsage,
1139 duration_ms: u64,
1140 ) {
1141 if let Some(he) = &self.config.hook_engine {
1142 let event = HookEvent::PostResponse(PostResponseEvent {
1143 session_id: session_id.to_string(),
1144 response_text: response_text.to_string(),
1145 tool_calls_count,
1146 usage: TokenUsageInfo {
1147 prompt_tokens: usage.prompt_tokens as i32,
1148 completion_tokens: usage.completion_tokens as i32,
1149 total_tokens: usage.total_tokens as i32,
1150 },
1151 duration_ms,
1152 });
1153 let he = Arc::clone(he);
1154 tokio::spawn(async move {
1155 let _ = he.fire(&event).await;
1156 });
1157 }
1158 }
1159
1160 async fn fire_on_error(
1162 &self,
1163 session_id: &str,
1164 error_type: ErrorType,
1165 error_message: &str,
1166 context: serde_json::Value,
1167 ) {
1168 if let Some(he) = &self.config.hook_engine {
1169 let event = HookEvent::OnError(OnErrorEvent {
1170 session_id: session_id.to_string(),
1171 error_type,
1172 error_message: error_message.to_string(),
1173 context,
1174 });
1175 let he = Arc::clone(he);
1176 tokio::spawn(async move {
1177 let _ = he.fire(&event).await;
1178 });
1179 }
1180 }
1181
    /// Runs the agent for `prompt` on top of `history` without a session id.
    /// Equivalent to `execute_with_session(history, prompt, None, event_tx)`.
    pub async fn execute(
        &self,
        history: &[Message],
        prompt: &str,
        event_tx: Option<mpsc::Sender<AgentEvent>>,
    ) -> Result<AgentResult> {
        self.execute_with_session(history, prompt, None, event_tx)
            .await
    }
1196
1197 pub async fn execute_from_messages(
1203 &self,
1204 messages: Vec<Message>,
1205 session_id: Option<&str>,
1206 event_tx: Option<mpsc::Sender<AgentEvent>>,
1207 ) -> Result<AgentResult> {
1208 tracing::info!(
1209 a3s.session.id = session_id.unwrap_or("none"),
1210 a3s.agent.max_turns = self.config.max_tool_rounds,
1211 "a3s.agent.execute_from_messages started"
1212 );
1213
1214 let effective_prompt = messages
1218 .iter()
1219 .rev()
1220 .find(|m| m.role == "user")
1221 .map(|m| m.text())
1222 .unwrap_or_default();
1223
1224 let result = self
1225 .execute_loop_inner(&messages, "", &effective_prompt, session_id, event_tx)
1226 .await;
1227
1228 match &result {
1229 Ok(r) => tracing::info!(
1230 a3s.agent.tool_calls_count = r.tool_calls_count,
1231 a3s.llm.total_tokens = r.usage.total_tokens,
1232 "a3s.agent.execute_from_messages completed"
1233 ),
1234 Err(e) => tracing::warn!(
1235 error = %e,
1236 "a3s.agent.execute_from_messages failed"
1237 ),
1238 }
1239
1240 result
1241 }
1242
1243 pub async fn execute_with_session(
1248 &self,
1249 history: &[Message],
1250 prompt: &str,
1251 session_id: Option<&str>,
1252 event_tx: Option<mpsc::Sender<AgentEvent>>,
1253 ) -> Result<AgentResult> {
1254 tracing::info!(
1255 a3s.session.id = session_id.unwrap_or("none"),
1256 a3s.agent.max_turns = self.config.max_tool_rounds,
1257 "a3s.agent.execute started"
1258 );
1259
1260 let result = if self.config.planning_enabled {
1262 self.execute_with_planning(history, prompt, event_tx).await
1263 } else {
1264 self.execute_loop(history, prompt, session_id, event_tx)
1265 .await
1266 };
1267
1268 match &result {
1269 Ok(r) => {
1270 tracing::info!(
1271 a3s.agent.tool_calls_count = r.tool_calls_count,
1272 a3s.llm.total_tokens = r.usage.total_tokens,
1273 "a3s.agent.execute completed"
1274 );
1275 self.fire_post_response(
1277 session_id.unwrap_or(""),
1278 &r.text,
1279 r.tool_calls_count,
1280 &r.usage,
1281 0, )
1283 .await;
1284 }
1285 Err(e) => {
1286 tracing::warn!(
1287 error = %e,
1288 "a3s.agent.execute failed"
1289 );
1290 self.fire_on_error(
1292 session_id.unwrap_or(""),
1293 ErrorType::Other,
1294 &e.to_string(),
1295 serde_json::json!({"phase": "execute"}),
1296 )
1297 .await;
1298 }
1299 }
1300
1301 result
1302 }
1303
    /// Runs the tool loop for a fresh `prompt`. The prompt is both appended to
    /// the conversation as a user message and used as the effective prompt.
    async fn execute_loop(
        &self,
        history: &[Message],
        prompt: &str,
        session_id: Option<&str>,
        event_tx: Option<mpsc::Sender<AgentEvent>>,
    ) -> Result<AgentResult> {
        self.execute_loop_inner(history, prompt, prompt, session_id, event_tx)
            .await
    }
1321
1322 async fn execute_loop_inner(
1327 &self,
1328 history: &[Message],
1329 msg_prompt: &str,
1330 effective_prompt: &str,
1331 session_id: Option<&str>,
1332 event_tx: Option<mpsc::Sender<AgentEvent>>,
1333 ) -> Result<AgentResult> {
1334 let mut messages = history.to_vec();
1335 let mut total_usage = TokenUsage::default();
1336 let mut tool_calls_count = 0;
1337 let mut turn = 0;
1338 let mut parse_error_count: u32 = 0;
1340 let mut continuation_count: u32 = 0;
1342
1343 if let Some(tx) = &event_tx {
1345 tx.send(AgentEvent::Start {
1346 prompt: effective_prompt.to_string(),
1347 })
1348 .await
1349 .ok();
1350 }
1351
1352 let _queue_forward_handle =
1354 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1355 let mut rx = queue.subscribe();
1356 let tx = tx.clone();
1357 Some(tokio::spawn(async move {
1358 while let Ok(event) = rx.recv().await {
1359 if tx.send(event).await.is_err() {
1360 break;
1361 }
1362 }
1363 }))
1364 } else {
1365 None
1366 };
1367
1368 let built_system_prompt = Some(self.system_prompt());
1370 let hooked_prompt = if let Some(modified) = self
1371 .fire_pre_prompt(
1372 session_id.unwrap_or(""),
1373 effective_prompt,
1374 &built_system_prompt,
1375 messages.len(),
1376 )
1377 .await
1378 {
1379 modified
1380 } else {
1381 effective_prompt.to_string()
1382 };
1383 let effective_prompt = hooked_prompt.as_str();
1384
1385 if let Some(ref sp) = self.config.security_provider {
1387 sp.taint_input(effective_prompt);
1388 }
1389
1390 let system_with_memory = if let Some(ref memory) = self.config.memory {
1392 match memory.recall_similar(effective_prompt, 5).await {
1393 Ok(items) if !items.is_empty() => {
1394 if let Some(tx) = &event_tx {
1395 for item in &items {
1396 tx.send(AgentEvent::MemoryRecalled {
1397 memory_id: item.id.clone(),
1398 content: item.content.clone(),
1399 relevance: item.relevance_score(),
1400 })
1401 .await
1402 .ok();
1403 }
1404 tx.send(AgentEvent::MemoriesSearched {
1405 query: Some(effective_prompt.to_string()),
1406 tags: Vec::new(),
1407 result_count: items.len(),
1408 })
1409 .await
1410 .ok();
1411 }
1412 let memory_context = items
1413 .iter()
1414 .map(|i| format!("- {}", i.content))
1415 .collect::<Vec<_>>()
1416 .join(
1417 "
1418",
1419 );
1420 let base = self.system_prompt();
1421 Some(format!(
1422 "{}
1423
1424## Relevant past experience
1425{}",
1426 base, memory_context
1427 ))
1428 }
1429 _ => Some(self.system_prompt()),
1430 }
1431 } else {
1432 Some(self.system_prompt())
1433 };
1434
1435 let augmented_system = if !self.config.context_providers.is_empty() {
1437 if let Some(tx) = &event_tx {
1439 let provider_names: Vec<String> = self
1440 .config
1441 .context_providers
1442 .iter()
1443 .map(|p| p.name().to_string())
1444 .collect();
1445 tx.send(AgentEvent::ContextResolving {
1446 providers: provider_names,
1447 })
1448 .await
1449 .ok();
1450 }
1451
1452 tracing::info!(
1453 a3s.context.providers = self.config.context_providers.len() as i64,
1454 "Context resolution started"
1455 );
1456 let context_results = self.resolve_context(effective_prompt, session_id).await;
1457
1458 if let Some(tx) = &event_tx {
1460 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1461 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1462
1463 tracing::info!(
1464 context_items = total_items,
1465 context_tokens = total_tokens,
1466 "Context resolution completed"
1467 );
1468
1469 tx.send(AgentEvent::ContextResolved {
1470 total_items,
1471 total_tokens,
1472 })
1473 .await
1474 .ok();
1475 }
1476
1477 self.build_augmented_system_prompt(&context_results)
1478 } else {
1479 Some(self.system_prompt())
1480 };
1481
1482 let base_prompt = self.system_prompt();
1484 let augmented_system = match (augmented_system, system_with_memory) {
1485 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1486 (Some(ctx), _) => Some(ctx),
1487 (None, mem) => mem,
1488 };
1489
1490 if !msg_prompt.is_empty() {
1492 messages.push(Message::user(msg_prompt));
1493 }
1494
1495 loop {
1496 turn += 1;
1497
1498 if turn > self.config.max_tool_rounds {
1499 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1500 if let Some(tx) = &event_tx {
1501 tx.send(AgentEvent::Error {
1502 message: error.clone(),
1503 })
1504 .await
1505 .ok();
1506 }
1507 anyhow::bail!(error);
1508 }
1509
1510 if let Some(tx) = &event_tx {
1512 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1513 }
1514
1515 tracing::info!(
1516 turn = turn,
1517 max_turns = self.config.max_tool_rounds,
1518 "Agent turn started"
1519 );
1520
1521 tracing::info!(
1523 a3s.llm.streaming = event_tx.is_some(),
1524 "LLM completion started"
1525 );
1526
1527 self.fire_generate_start(
1529 session_id.unwrap_or(""),
1530 effective_prompt,
1531 &augmented_system,
1532 )
1533 .await;
1534
1535 let llm_start = std::time::Instant::now();
1536 let response = {
1540 let threshold = self.config.circuit_breaker_threshold.max(1);
1541 let mut attempt = 0u32;
1542 loop {
1543 attempt += 1;
1544 let result = self
1545 .call_llm(&messages, augmented_system.as_deref(), &event_tx)
1546 .await;
1547 match result {
1548 Ok(r) => {
1549 break r;
1550 }
1551 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1553 tracing::warn!(
1554 turn = turn,
1555 attempt = attempt,
1556 threshold = threshold,
1557 error = %e,
1558 "LLM call failed, will retry"
1559 );
1560 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1561 }
1562 Err(e) => {
1564 let msg = if attempt > 1 {
1565 format!(
1566 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1567 attempt, e
1568 )
1569 } else {
1570 format!("LLM call failed: {}", e)
1571 };
1572 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1573 self.fire_on_error(
1575 session_id.unwrap_or(""),
1576 ErrorType::LlmFailure,
1577 &msg,
1578 serde_json::json!({"turn": turn, "attempt": attempt}),
1579 )
1580 .await;
1581 if let Some(tx) = &event_tx {
1582 tx.send(AgentEvent::Error {
1583 message: msg.clone(),
1584 })
1585 .await
1586 .ok();
1587 }
1588 anyhow::bail!(msg);
1589 }
1590 }
1591 }
1592 };
1593
1594 total_usage.prompt_tokens += response.usage.prompt_tokens;
1596 total_usage.completion_tokens += response.usage.completion_tokens;
1597 total_usage.total_tokens += response.usage.total_tokens;
1598
1599 let llm_duration = llm_start.elapsed();
1601 tracing::info!(
1602 turn = turn,
1603 streaming = event_tx.is_some(),
1604 prompt_tokens = response.usage.prompt_tokens,
1605 completion_tokens = response.usage.completion_tokens,
1606 total_tokens = response.usage.total_tokens,
1607 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1608 duration_ms = llm_duration.as_millis() as u64,
1609 "LLM completion finished"
1610 );
1611
1612 self.fire_generate_end(
1614 session_id.unwrap_or(""),
1615 effective_prompt,
1616 &response,
1617 llm_duration.as_millis() as u64,
1618 )
1619 .await;
1620
1621 crate::telemetry::record_llm_usage(
1623 response.usage.prompt_tokens,
1624 response.usage.completion_tokens,
1625 response.usage.total_tokens,
1626 response.stop_reason.as_deref(),
1627 );
1628 tracing::info!(
1630 turn = turn,
1631 a3s.llm.total_tokens = response.usage.total_tokens,
1632 "Turn token usage"
1633 );
1634
1635 messages.push(response.message.clone());
1637
1638 let tool_calls = response.tool_calls();
1640
1641 if let Some(tx) = &event_tx {
1643 tx.send(AgentEvent::TurnEnd {
1644 turn,
1645 usage: response.usage.clone(),
1646 })
1647 .await
1648 .ok();
1649 }
1650
1651 if self.config.auto_compact {
1653 let used = response.usage.prompt_tokens;
1654 let max = self.config.max_context_tokens;
1655 let threshold = self.config.auto_compact_threshold;
1656
1657 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1658 let before_len = messages.len();
1659 let percent_before = used as f32 / max as f32;
1660
1661 tracing::info!(
1662 used_tokens = used,
1663 max_tokens = max,
1664 percent = percent_before,
1665 threshold = threshold,
1666 "Auto-compact triggered"
1667 );
1668
1669 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1671 {
1672 messages = pruned;
1673 tracing::info!("Tool output pruning applied");
1674 }
1675
1676 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1678 session_id.unwrap_or(""),
1679 &messages,
1680 &self.llm_client,
1681 )
1682 .await
1683 {
1684 messages = compacted;
1685 }
1686
1687 if let Some(tx) = &event_tx {
1689 tx.send(AgentEvent::ContextCompacted {
1690 session_id: session_id.unwrap_or("").to_string(),
1691 before_messages: before_len,
1692 after_messages: messages.len(),
1693 percent_before,
1694 })
1695 .await
1696 .ok();
1697 }
1698 }
1699 }
1700
1701 if tool_calls.is_empty() {
1702 let final_text = response.text();
1705
1706 if self.config.continuation_enabled
1707 && continuation_count < self.config.max_continuation_turns
1708 && Self::looks_incomplete(&final_text)
1709 {
1710 continuation_count += 1;
1711 tracing::info!(
1712 turn = turn,
1713 continuation = continuation_count,
1714 max_continuation = self.config.max_continuation_turns,
1715 "Injecting continuation message — response looks incomplete"
1716 );
1717 messages.push(Message::user(crate::prompts::CONTINUATION));
1719 continue;
1720 }
1721
1722 let final_text = if let Some(ref sp) = self.config.security_provider {
1724 sp.sanitize_output(&final_text)
1725 } else {
1726 final_text
1727 };
1728
1729 tracing::info!(
1731 tool_calls_count = tool_calls_count,
1732 total_prompt_tokens = total_usage.prompt_tokens,
1733 total_completion_tokens = total_usage.completion_tokens,
1734 total_tokens = total_usage.total_tokens,
1735 turns = turn,
1736 "Agent execution completed"
1737 );
1738
1739 if let Some(tx) = &event_tx {
1740 tx.send(AgentEvent::End {
1741 text: final_text.clone(),
1742 usage: total_usage.clone(),
1743 })
1744 .await
1745 .ok();
1746 }
1747
1748 if let Some(sid) = session_id {
1750 self.notify_turn_complete(sid, effective_prompt, &final_text)
1751 .await;
1752 }
1753
1754 return Ok(AgentResult {
1755 text: final_text,
1756 messages,
1757 usage: total_usage,
1758 tool_calls_count,
1759 });
1760 }
1761
1762 for tool_call in tool_calls {
1764 tool_calls_count += 1;
1765
1766 let tool_start = std::time::Instant::now();
1767
1768 tracing::info!(
1769 tool_name = tool_call.name.as_str(),
1770 tool_id = tool_call.id.as_str(),
1771 "Tool execution started"
1772 );
1773
1774 if let Some(parse_error) =
1780 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1781 {
1782 parse_error_count += 1;
1783 let error_msg = format!("Error: {}", parse_error);
1784 tracing::warn!(
1785 tool = tool_call.name.as_str(),
1786 parse_error_count = parse_error_count,
1787 max_parse_retries = self.config.max_parse_retries,
1788 "Malformed tool arguments from LLM"
1789 );
1790
1791 if let Some(tx) = &event_tx {
1792 tx.send(AgentEvent::ToolEnd {
1793 id: tool_call.id.clone(),
1794 name: tool_call.name.clone(),
1795 output: error_msg.clone(),
1796 exit_code: 1,
1797 })
1798 .await
1799 .ok();
1800 }
1801
1802 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1803
1804 if parse_error_count > self.config.max_parse_retries {
1805 let msg = format!(
1806 "LLM produced malformed tool arguments {} time(s) in a row \
1807 (max_parse_retries={}); giving up",
1808 parse_error_count, self.config.max_parse_retries
1809 );
1810 tracing::error!("{}", msg);
1811 if let Some(tx) = &event_tx {
1812 tx.send(AgentEvent::Error {
1813 message: msg.clone(),
1814 })
1815 .await
1816 .ok();
1817 }
1818 anyhow::bail!(msg);
1819 }
1820 continue;
1821 }
1822
1823 parse_error_count = 0;
1825
1826 if let Some(ref registry) = self.config.skill_registry {
1828 let instruction_skills =
1829 registry.by_kind(crate::skills::SkillKind::Instruction);
1830 let has_restrictions =
1831 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
1832 if has_restrictions {
1833 let allowed = instruction_skills
1834 .iter()
1835 .any(|s| s.is_tool_allowed(&tool_call.name));
1836 if !allowed {
1837 let msg = format!(
1838 "Tool '{}' is not allowed by any active skill.",
1839 tool_call.name
1840 );
1841 tracing::info!(
1842 tool_name = tool_call.name.as_str(),
1843 "Tool blocked by skill registry"
1844 );
1845 if let Some(tx) = &event_tx {
1846 tx.send(AgentEvent::PermissionDenied {
1847 tool_id: tool_call.id.clone(),
1848 tool_name: tool_call.name.clone(),
1849 args: tool_call.args.clone(),
1850 reason: msg.clone(),
1851 })
1852 .await
1853 .ok();
1854 }
1855 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1856 continue;
1857 }
1858 }
1859 }
1860
1861 if let Some(HookResult::Block(reason)) = self
1863 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1864 .await
1865 {
1866 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1867 tracing::info!(
1868 tool_name = tool_call.name.as_str(),
1869 "Tool blocked by PreToolUse hook"
1870 );
1871
1872 if let Some(tx) = &event_tx {
1873 tx.send(AgentEvent::PermissionDenied {
1874 tool_id: tool_call.id.clone(),
1875 tool_name: tool_call.name.clone(),
1876 args: tool_call.args.clone(),
1877 reason: reason.clone(),
1878 })
1879 .await
1880 .ok();
1881 }
1882
1883 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1884 continue;
1885 }
1886
1887 let permission_decision = if let Some(checker) = &self.config.permission_checker {
1889 checker.check(&tool_call.name, &tool_call.args)
1890 } else {
1891 PermissionDecision::Ask
1893 };
1894
1895 let (output, exit_code, is_error, _metadata, images) = match permission_decision {
1896 PermissionDecision::Deny => {
1897 tracing::info!(
1898 tool_name = tool_call.name.as_str(),
1899 permission = "deny",
1900 "Tool permission denied"
1901 );
1902 let denial_msg = format!(
1904 "Permission denied: Tool '{}' is blocked by permission policy.",
1905 tool_call.name
1906 );
1907
1908 if let Some(tx) = &event_tx {
1910 tx.send(AgentEvent::PermissionDenied {
1911 tool_id: tool_call.id.clone(),
1912 tool_name: tool_call.name.clone(),
1913 args: tool_call.args.clone(),
1914 reason: "Blocked by deny rule in permission policy".to_string(),
1915 })
1916 .await
1917 .ok();
1918 }
1919
1920 (denial_msg, 1, true, None, Vec::new())
1921 }
1922 PermissionDecision::Allow => {
1923 tracing::info!(
1924 tool_name = tool_call.name.as_str(),
1925 permission = "allow",
1926 "Tool permission: allow"
1927 );
1928 let stream_ctx =
1930 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1931 let result = self
1932 .execute_tool_queued_or_direct(
1933 &tool_call.name,
1934 &tool_call.args,
1935 &stream_ctx,
1936 )
1937 .await;
1938
1939 Self::tool_result_to_tuple(result)
1940 }
1941 PermissionDecision::Ask => {
1942 tracing::info!(
1943 tool_name = tool_call.name.as_str(),
1944 permission = "ask",
1945 "Tool permission: ask"
1946 );
1947 if let Some(cm) = &self.config.confirmation_manager {
1949 if !cm.requires_confirmation(&tool_call.name).await {
1951 let stream_ctx = self.streaming_tool_context(
1952 &event_tx,
1953 &tool_call.id,
1954 &tool_call.name,
1955 );
1956 let result = self
1957 .execute_tool_queued_or_direct(
1958 &tool_call.name,
1959 &tool_call.args,
1960 &stream_ctx,
1961 )
1962 .await;
1963
1964 let (output, exit_code, is_error, _metadata, images) =
1965 Self::tool_result_to_tuple(result);
1966
1967 if images.is_empty() {
1969 messages.push(Message::tool_result(
1970 &tool_call.id,
1971 &output,
1972 is_error,
1973 ));
1974 } else {
1975 messages.push(Message::tool_result_with_images(
1976 &tool_call.id,
1977 &output,
1978 &images,
1979 is_error,
1980 ));
1981 }
1982
1983 let tool_duration = tool_start.elapsed();
1985 crate::telemetry::record_tool_result(exit_code, tool_duration);
1986
1987 if let Some(tx) = &event_tx {
1989 tx.send(AgentEvent::ToolEnd {
1990 id: tool_call.id.clone(),
1991 name: tool_call.name.clone(),
1992 output: output.clone(),
1993 exit_code,
1994 })
1995 .await
1996 .ok();
1997 }
1998
1999 self.fire_post_tool_use(
2001 session_id.unwrap_or(""),
2002 &tool_call.name,
2003 &tool_call.args,
2004 &output,
2005 exit_code == 0,
2006 tool_duration.as_millis() as u64,
2007 )
2008 .await;
2009
2010 continue; }
2012
2013 let policy = cm.policy().await;
2015 let timeout_ms = policy.default_timeout_ms;
2016 let timeout_action = policy.timeout_action;
2017
2018 let rx = cm
2020 .request_confirmation(
2021 &tool_call.id,
2022 &tool_call.name,
2023 &tool_call.args,
2024 )
2025 .await;
2026
2027 if let Some(tx) = &event_tx {
2031 tx.send(AgentEvent::ConfirmationRequired {
2032 tool_id: tool_call.id.clone(),
2033 tool_name: tool_call.name.clone(),
2034 args: tool_call.args.clone(),
2035 timeout_ms,
2036 })
2037 .await
2038 .ok();
2039 }
2040
2041 let confirmation_result =
2043 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2044
2045 match confirmation_result {
2046 Ok(Ok(response)) => {
2047 if let Some(tx) = &event_tx {
2049 tx.send(AgentEvent::ConfirmationReceived {
2050 tool_id: tool_call.id.clone(),
2051 approved: response.approved,
2052 reason: response.reason.clone(),
2053 })
2054 .await
2055 .ok();
2056 }
2057 if response.approved {
2058 let stream_ctx = self.streaming_tool_context(
2059 &event_tx,
2060 &tool_call.id,
2061 &tool_call.name,
2062 );
2063 let result = self
2064 .execute_tool_queued_or_direct(
2065 &tool_call.name,
2066 &tool_call.args,
2067 &stream_ctx,
2068 )
2069 .await;
2070
2071 Self::tool_result_to_tuple(result)
2072 } else {
2073 let rejection_msg = format!(
2074 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
2075 DO NOT retry this tool call unless the user explicitly asks you to.",
2076 tool_call.name,
2077 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2078 );
2079 (rejection_msg, 1, true, None, Vec::new())
2080 }
2081 }
2082 Ok(Err(_)) => {
2083 if let Some(tx) = &event_tx {
2085 tx.send(AgentEvent::ConfirmationTimeout {
2086 tool_id: tool_call.id.clone(),
2087 action_taken: "rejected".to_string(),
2088 })
2089 .await
2090 .ok();
2091 }
2092 let msg = format!(
2093 "Tool '{}' confirmation failed: confirmation channel closed",
2094 tool_call.name
2095 );
2096 (msg, 1, true, None, Vec::new())
2097 }
2098 Err(_) => {
2099 cm.check_timeouts().await;
2100
2101 if let Some(tx) = &event_tx {
2103 tx.send(AgentEvent::ConfirmationTimeout {
2104 tool_id: tool_call.id.clone(),
2105 action_taken: match timeout_action {
2106 crate::hitl::TimeoutAction::Reject => {
2107 "rejected".to_string()
2108 }
2109 crate::hitl::TimeoutAction::AutoApprove => {
2110 "auto_approved".to_string()
2111 }
2112 },
2113 })
2114 .await
2115 .ok();
2116 }
2117
2118 match timeout_action {
2119 crate::hitl::TimeoutAction::Reject => {
2120 let msg = format!(
2121 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
2122 DO NOT retry this tool call — the user did not approve it. \
2123 Inform the user that the operation requires their approval and ask them to try again.",
2124 tool_call.name, timeout_ms
2125 );
2126 (msg, 1, true, None, Vec::new())
2127 }
2128 crate::hitl::TimeoutAction::AutoApprove => {
2129 let stream_ctx = self.streaming_tool_context(
2130 &event_tx,
2131 &tool_call.id,
2132 &tool_call.name,
2133 );
2134 let result = self
2135 .execute_tool_queued_or_direct(
2136 &tool_call.name,
2137 &tool_call.args,
2138 &stream_ctx,
2139 )
2140 .await;
2141
2142 Self::tool_result_to_tuple(result)
2143 }
2144 }
2145 }
2146 }
2147 } else {
2148 let msg = format!(
2150 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2151 Configure a confirmation policy to enable tool execution.",
2152 tool_call.name
2153 );
2154 tracing::warn!(
2155 tool_name = tool_call.name.as_str(),
2156 "Tool requires confirmation but no HITL manager configured"
2157 );
2158 (msg, 1, true, None, Vec::new())
2159 }
2160 }
2161 };
2162
2163 let tool_duration = tool_start.elapsed();
2164 crate::telemetry::record_tool_result(exit_code, tool_duration);
2165
2166 let output = if let Some(ref sp) = self.config.security_provider {
2168 sp.sanitize_output(&output)
2169 } else {
2170 output
2171 };
2172
2173 self.fire_post_tool_use(
2175 session_id.unwrap_or(""),
2176 &tool_call.name,
2177 &tool_call.args,
2178 &output,
2179 exit_code == 0,
2180 tool_duration.as_millis() as u64,
2181 )
2182 .await;
2183
2184 if let Some(ref memory) = self.config.memory {
2186 let tools_used = [tool_call.name.clone()];
2187 let remember_result = if exit_code == 0 {
2188 memory
2189 .remember_success(effective_prompt, &tools_used, &output)
2190 .await
2191 } else {
2192 memory
2193 .remember_failure(effective_prompt, &output, &tools_used)
2194 .await
2195 };
2196 match remember_result {
2197 Ok(()) => {
2198 if let Some(tx) = &event_tx {
2199 let item_type = if exit_code == 0 { "success" } else { "failure" };
2200 tx.send(AgentEvent::MemoryStored {
2201 memory_id: uuid::Uuid::new_v4().to_string(),
2202 memory_type: item_type.to_string(),
2203 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2204 tags: vec![item_type.to_string(), tool_call.name.clone()],
2205 })
2206 .await
2207 .ok();
2208 }
2209 }
2210 Err(e) => {
2211 tracing::warn!("Failed to store memory after tool execution: {}", e);
2212 }
2213 }
2214 }
2215
2216 if let Some(tx) = &event_tx {
2218 tx.send(AgentEvent::ToolEnd {
2219 id: tool_call.id.clone(),
2220 name: tool_call.name.clone(),
2221 output: output.clone(),
2222 exit_code,
2223 })
2224 .await
2225 .ok();
2226 }
2227
2228 if images.is_empty() {
2230 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2231 } else {
2232 messages.push(Message::tool_result_with_images(
2233 &tool_call.id,
2234 &output,
2235 &images,
2236 is_error,
2237 ));
2238 }
2239 }
2240 }
2241 }
2242
2243 pub async fn execute_streaming(
2245 &self,
2246 history: &[Message],
2247 prompt: &str,
2248 ) -> Result<(
2249 mpsc::Receiver<AgentEvent>,
2250 tokio::task::JoinHandle<Result<AgentResult>>,
2251 )> {
2252 let (tx, rx) = mpsc::channel(100);
2253
2254 let llm_client = self.llm_client.clone();
2255 let tool_executor = self.tool_executor.clone();
2256 let tool_context = self.tool_context.clone();
2257 let config = self.config.clone();
2258 let tool_metrics = self.tool_metrics.clone();
2259 let command_queue = self.command_queue.clone();
2260 let history = history.to_vec();
2261 let prompt = prompt.to_string();
2262
2263 let handle = tokio::spawn(async move {
2264 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2265 if let Some(metrics) = tool_metrics {
2266 agent = agent.with_tool_metrics(metrics);
2267 }
2268 if let Some(queue) = command_queue {
2269 agent = agent.with_queue(queue);
2270 }
2271 agent.execute(&history, &prompt, Some(tx)).await
2272 });
2273
2274 Ok((rx, handle))
2275 }
2276
2277 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2282 use crate::planning::LlmPlanner;
2283
2284 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2285 Ok(plan) => Ok(plan),
2286 Err(e) => {
2287 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2288 Ok(LlmPlanner::fallback_plan(prompt))
2289 }
2290 }
2291 }
2292
2293 pub async fn execute_with_planning(
2295 &self,
2296 history: &[Message],
2297 prompt: &str,
2298 event_tx: Option<mpsc::Sender<AgentEvent>>,
2299 ) -> Result<AgentResult> {
2300 if let Some(tx) = &event_tx {
2302 tx.send(AgentEvent::PlanningStart {
2303 prompt: prompt.to_string(),
2304 })
2305 .await
2306 .ok();
2307 }
2308
2309 let goal = if self.config.goal_tracking {
2311 let g = self.extract_goal(prompt).await?;
2312 if let Some(tx) = &event_tx {
2313 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2314 .await
2315 .ok();
2316 }
2317 Some(g)
2318 } else {
2319 None
2320 };
2321
2322 let plan = self.plan(prompt, None).await?;
2324
2325 if let Some(tx) = &event_tx {
2327 tx.send(AgentEvent::PlanningEnd {
2328 estimated_steps: plan.steps.len(),
2329 plan: plan.clone(),
2330 })
2331 .await
2332 .ok();
2333 }
2334
2335 let plan_start = std::time::Instant::now();
2336
2337 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2339
2340 if self.config.goal_tracking {
2342 if let Some(ref g) = goal {
2343 let achieved = self.check_goal_achievement(g, &result.text).await?;
2344 if achieved {
2345 if let Some(tx) = &event_tx {
2346 tx.send(AgentEvent::GoalAchieved {
2347 goal: g.description.clone(),
2348 total_steps: result.messages.len(),
2349 duration_ms: plan_start.elapsed().as_millis() as i64,
2350 })
2351 .await
2352 .ok();
2353 }
2354 }
2355 }
2356 }
2357
2358 Ok(result)
2359 }
2360
2361 async fn execute_plan(
2368 &self,
2369 history: &[Message],
2370 plan: &ExecutionPlan,
2371 event_tx: Option<mpsc::Sender<AgentEvent>>,
2372 ) -> Result<AgentResult> {
2373 let mut plan = plan.clone();
2374 let mut current_history = history.to_vec();
2375 let mut total_usage = TokenUsage::default();
2376 let mut tool_calls_count = 0;
2377 let total_steps = plan.steps.len();
2378
2379 let steps_text = plan
2381 .steps
2382 .iter()
2383 .enumerate()
2384 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2385 .collect::<Vec<_>>()
2386 .join("\n");
2387 current_history.push(Message::user(&crate::prompts::render(
2388 crate::prompts::PLAN_EXECUTE_GOAL,
2389 &[("goal", &plan.goal), ("steps", &steps_text)],
2390 )));
2391
2392 loop {
2393 let ready: Vec<String> = plan
2394 .get_ready_steps()
2395 .iter()
2396 .map(|s| s.id.clone())
2397 .collect();
2398
2399 if ready.is_empty() {
2400 if plan.has_deadlock() {
2402 tracing::warn!(
2403 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2404 plan.pending_count()
2405 );
2406 }
2407 break;
2408 }
2409
2410 if ready.len() == 1 {
2411 let step_id = &ready[0];
2413 let step = plan
2414 .steps
2415 .iter()
2416 .find(|s| s.id == *step_id)
2417 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2418 .clone();
2419 let step_number = plan
2420 .steps
2421 .iter()
2422 .position(|s| s.id == *step_id)
2423 .unwrap_or(0)
2424 + 1;
2425
2426 if let Some(tx) = &event_tx {
2428 tx.send(AgentEvent::StepStart {
2429 step_id: step.id.clone(),
2430 description: step.content.clone(),
2431 step_number,
2432 total_steps,
2433 })
2434 .await
2435 .ok();
2436 }
2437
2438 plan.mark_status(&step.id, TaskStatus::InProgress);
2439
2440 let step_prompt = crate::prompts::render(
2441 crate::prompts::PLAN_EXECUTE_STEP,
2442 &[
2443 ("step_num", &step_number.to_string()),
2444 ("description", &step.content),
2445 ],
2446 );
2447
2448 match self
2449 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
2450 .await
2451 {
2452 Ok(result) => {
2453 current_history = result.messages.clone();
2454 total_usage.prompt_tokens += result.usage.prompt_tokens;
2455 total_usage.completion_tokens += result.usage.completion_tokens;
2456 total_usage.total_tokens += result.usage.total_tokens;
2457 tool_calls_count += result.tool_calls_count;
2458 plan.mark_status(&step.id, TaskStatus::Completed);
2459
2460 if let Some(tx) = &event_tx {
2461 tx.send(AgentEvent::StepEnd {
2462 step_id: step.id.clone(),
2463 status: TaskStatus::Completed,
2464 step_number,
2465 total_steps,
2466 })
2467 .await
2468 .ok();
2469 }
2470 }
2471 Err(e) => {
2472 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2473 plan.mark_status(&step.id, TaskStatus::Failed);
2474
2475 if let Some(tx) = &event_tx {
2476 tx.send(AgentEvent::StepEnd {
2477 step_id: step.id.clone(),
2478 status: TaskStatus::Failed,
2479 step_number,
2480 total_steps,
2481 })
2482 .await
2483 .ok();
2484 }
2485 }
2486 }
2487 } else {
2488 let ready_steps: Vec<_> = ready
2495 .iter()
2496 .filter_map(|id| {
2497 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2498 let step_number =
2499 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2500 Some((step, step_number))
2501 })
2502 .collect();
2503
2504 for (step, step_number) in &ready_steps {
2506 plan.mark_status(&step.id, TaskStatus::InProgress);
2507 if let Some(tx) = &event_tx {
2508 tx.send(AgentEvent::StepStart {
2509 step_id: step.id.clone(),
2510 description: step.content.clone(),
2511 step_number: *step_number,
2512 total_steps,
2513 })
2514 .await
2515 .ok();
2516 }
2517 }
2518
2519 let mut join_set = tokio::task::JoinSet::new();
2521 for (step, step_number) in &ready_steps {
2522 let base_history = current_history.clone();
2523 let agent_clone = self.clone();
2524 let tx = event_tx.clone();
2525 let step_clone = step.clone();
2526 let sn = *step_number;
2527
2528 join_set.spawn(async move {
2529 let prompt = crate::prompts::render(
2530 crate::prompts::PLAN_EXECUTE_STEP,
2531 &[
2532 ("step_num", &sn.to_string()),
2533 ("description", &step_clone.content),
2534 ],
2535 );
2536 let result = agent_clone
2537 .execute_loop(&base_history, &prompt, None, tx)
2538 .await;
2539 (step_clone.id, sn, result)
2540 });
2541 }
2542
2543 let mut parallel_summaries = Vec::new();
2545 while let Some(join_result) = join_set.join_next().await {
2546 match join_result {
2547 Ok((step_id, step_number, step_result)) => match step_result {
2548 Ok(result) => {
2549 total_usage.prompt_tokens += result.usage.prompt_tokens;
2550 total_usage.completion_tokens += result.usage.completion_tokens;
2551 total_usage.total_tokens += result.usage.total_tokens;
2552 tool_calls_count += result.tool_calls_count;
2553 plan.mark_status(&step_id, TaskStatus::Completed);
2554
2555 parallel_summaries.push(format!(
2557 "- Step {} ({}): {}",
2558 step_number, step_id, result.text
2559 ));
2560
2561 if let Some(tx) = &event_tx {
2562 tx.send(AgentEvent::StepEnd {
2563 step_id,
2564 status: TaskStatus::Completed,
2565 step_number,
2566 total_steps,
2567 })
2568 .await
2569 .ok();
2570 }
2571 }
2572 Err(e) => {
2573 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2574 plan.mark_status(&step_id, TaskStatus::Failed);
2575
2576 if let Some(tx) = &event_tx {
2577 tx.send(AgentEvent::StepEnd {
2578 step_id,
2579 status: TaskStatus::Failed,
2580 step_number,
2581 total_steps,
2582 })
2583 .await
2584 .ok();
2585 }
2586 }
2587 },
2588 Err(e) => {
2589 tracing::error!("JoinSet task panicked: {}", e);
2590 }
2591 }
2592 }
2593
2594 if !parallel_summaries.is_empty() {
2596 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
2598 current_history.push(Message::user(&crate::prompts::render(
2599 crate::prompts::PLAN_PARALLEL_RESULTS,
2600 &[("results", &results_text)],
2601 )));
2602 }
2603 }
2604
2605 if self.config.goal_tracking {
2607 let completed = plan
2608 .steps
2609 .iter()
2610 .filter(|s| s.status == TaskStatus::Completed)
2611 .count();
2612 if let Some(tx) = &event_tx {
2613 tx.send(AgentEvent::GoalProgress {
2614 goal: plan.goal.clone(),
2615 progress: plan.progress(),
2616 completed_steps: completed,
2617 total_steps,
2618 })
2619 .await
2620 .ok();
2621 }
2622 }
2623 }
2624
2625 let final_text = current_history
2627 .last()
2628 .map(|m| {
2629 m.content
2630 .iter()
2631 .filter_map(|block| {
2632 if let crate::llm::ContentBlock::Text { text } = block {
2633 Some(text.as_str())
2634 } else {
2635 None
2636 }
2637 })
2638 .collect::<Vec<_>>()
2639 .join("\n")
2640 })
2641 .unwrap_or_default();
2642
2643 Ok(AgentResult {
2644 text: final_text,
2645 messages: current_history,
2646 usage: total_usage,
2647 tool_calls_count,
2648 })
2649 }
2650
2651 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2656 use crate::planning::LlmPlanner;
2657
2658 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2659 Ok(goal) => Ok(goal),
2660 Err(e) => {
2661 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2662 Ok(LlmPlanner::fallback_goal(prompt))
2663 }
2664 }
2665 }
2666
2667 pub async fn check_goal_achievement(
2672 &self,
2673 goal: &AgentGoal,
2674 current_state: &str,
2675 ) -> Result<bool> {
2676 use crate::planning::LlmPlanner;
2677
2678 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2679 Ok(result) => Ok(result.achieved),
2680 Err(e) => {
2681 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2682 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2683 Ok(result.achieved)
2684 }
2685 }
2686 }
2687}
2688
2689#[cfg(test)]
2690mod tests {
2691 use super::*;
2692 use crate::llm::{ContentBlock, StreamEvent};
2693 use crate::permissions::PermissionPolicy;
2694 use crate::tools::ToolExecutor;
2695 use std::path::PathBuf;
2696 use std::sync::atomic::{AtomicUsize, Ordering};
2697
2698 fn test_tool_context() -> ToolContext {
2700 ToolContext::new(PathBuf::from("/tmp"))
2701 }
2702
2703 #[test]
2704 fn test_agent_config_default() {
2705 let config = AgentConfig::default();
2706 assert!(config.prompt_slots.is_empty());
2707 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
2709 assert!(config.permission_checker.is_none());
2710 assert!(config.context_providers.is_empty());
2711 }
2712
    /// Scripted `LlmClient` test double.
    ///
    /// Each call to `complete`/`complete_streaming` pops the next entry from
    /// `responses` and increments `call_count`. Once the script is exhausted,
    /// calls return an error.
    pub(crate) struct MockLlmClient {
        /// Remaining scripted responses, consumed front-to-back.
        responses: std::sync::Mutex<Vec<LlmResponse>>,
        /// Number of LLM calls made so far, including calls that failed because
        /// the script was exhausted.
        pub(crate) call_count: AtomicUsize,
    }
2724
2725 impl MockLlmClient {
2726 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2727 Self {
2728 responses: std::sync::Mutex::new(responses),
2729 call_count: AtomicUsize::new(0),
2730 }
2731 }
2732
2733 pub(crate) fn text_response(text: &str) -> LlmResponse {
2735 LlmResponse {
2736 message: Message {
2737 role: "assistant".to_string(),
2738 content: vec![ContentBlock::Text {
2739 text: text.to_string(),
2740 }],
2741 reasoning_content: None,
2742 },
2743 usage: TokenUsage {
2744 prompt_tokens: 10,
2745 completion_tokens: 5,
2746 total_tokens: 15,
2747 cache_read_tokens: None,
2748 cache_write_tokens: None,
2749 },
2750 stop_reason: Some("end_turn".to_string()),
2751 }
2752 }
2753
2754 pub(crate) fn tool_call_response(
2756 tool_id: &str,
2757 tool_name: &str,
2758 args: serde_json::Value,
2759 ) -> LlmResponse {
2760 LlmResponse {
2761 message: Message {
2762 role: "assistant".to_string(),
2763 content: vec![ContentBlock::ToolUse {
2764 id: tool_id.to_string(),
2765 name: tool_name.to_string(),
2766 input: args,
2767 }],
2768 reasoning_content: None,
2769 },
2770 usage: TokenUsage {
2771 prompt_tokens: 10,
2772 completion_tokens: 5,
2773 total_tokens: 15,
2774 cache_read_tokens: None,
2775 cache_write_tokens: None,
2776 },
2777 stop_reason: Some("tool_use".to_string()),
2778 }
2779 }
2780 }
2781
2782 #[async_trait::async_trait]
2783 impl LlmClient for MockLlmClient {
2784 async fn complete(
2785 &self,
2786 _messages: &[Message],
2787 _system: Option<&str>,
2788 _tools: &[ToolDefinition],
2789 ) -> Result<LlmResponse> {
2790 self.call_count.fetch_add(1, Ordering::SeqCst);
2791 let mut responses = self.responses.lock().unwrap();
2792 if responses.is_empty() {
2793 anyhow::bail!("No more mock responses available");
2794 }
2795 Ok(responses.remove(0))
2796 }
2797
2798 async fn complete_streaming(
2799 &self,
2800 _messages: &[Message],
2801 _system: Option<&str>,
2802 _tools: &[ToolDefinition],
2803 ) -> Result<mpsc::Receiver<StreamEvent>> {
2804 self.call_count.fetch_add(1, Ordering::SeqCst);
2805 let mut responses = self.responses.lock().unwrap();
2806 if responses.is_empty() {
2807 anyhow::bail!("No more mock responses available");
2808 }
2809 let response = responses.remove(0);
2810
2811 let (tx, rx) = mpsc::channel(10);
2812 tokio::spawn(async move {
2813 for block in &response.message.content {
2815 if let ContentBlock::Text { text } = block {
2816 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2817 }
2818 }
2819 tx.send(StreamEvent::Done(response)).await.ok();
2820 });
2821
2822 Ok(rx)
2823 }
2824 }
2825
2826 #[tokio::test]
2831 async fn test_agent_simple_response() {
2832 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2833 "Hello, I'm an AI assistant.",
2834 )]));
2835
2836 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2837 let config = AgentConfig::default();
2838
2839 let agent = AgentLoop::new(
2840 mock_client.clone(),
2841 tool_executor,
2842 test_tool_context(),
2843 config,
2844 );
2845 let result = agent.execute(&[], "Hello", None).await.unwrap();
2846
2847 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2848 assert_eq!(result.tool_calls_count, 0);
2849 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2850 }
2851
2852 #[tokio::test]
2853 async fn test_agent_with_tool_call() {
2854 let mock_client = Arc::new(MockLlmClient::new(vec![
2855 MockLlmClient::tool_call_response(
2857 "tool-1",
2858 "bash",
2859 serde_json::json!({"command": "echo hello"}),
2860 ),
2861 MockLlmClient::text_response("The command output was: hello"),
2863 ]));
2864
2865 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2866 let config = AgentConfig::default();
2867
2868 let agent = AgentLoop::new(
2869 mock_client.clone(),
2870 tool_executor,
2871 test_tool_context(),
2872 config,
2873 );
2874 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2875
2876 assert_eq!(result.text, "The command output was: hello");
2877 assert_eq!(result.tool_calls_count, 1);
2878 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2879 }
2880
2881 #[tokio::test]
2882 async fn test_agent_permission_deny() {
2883 let mock_client = Arc::new(MockLlmClient::new(vec![
2884 MockLlmClient::tool_call_response(
2886 "tool-1",
2887 "bash",
2888 serde_json::json!({"command": "rm -rf /tmp/test"}),
2889 ),
2890 MockLlmClient::text_response(
2892 "I cannot execute that command due to permission restrictions.",
2893 ),
2894 ]));
2895
2896 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2897
2898 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2900
2901 let config = AgentConfig {
2902 permission_checker: Some(Arc::new(permission_policy)),
2903 ..Default::default()
2904 };
2905
2906 let (tx, mut rx) = mpsc::channel(100);
2907 let agent = AgentLoop::new(
2908 mock_client.clone(),
2909 tool_executor,
2910 test_tool_context(),
2911 config,
2912 );
2913 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2914
2915 let mut found_permission_denied = false;
2917 while let Ok(event) = rx.try_recv() {
2918 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2919 assert_eq!(tool_name, "bash");
2920 found_permission_denied = true;
2921 }
2922 }
2923 assert!(
2924 found_permission_denied,
2925 "Should have received PermissionDenied event"
2926 );
2927
2928 assert_eq!(result.tool_calls_count, 1);
2929 }
2930
2931 #[tokio::test]
2932 async fn test_agent_permission_allow() {
2933 let mock_client = Arc::new(MockLlmClient::new(vec![
2934 MockLlmClient::tool_call_response(
2936 "tool-1",
2937 "bash",
2938 serde_json::json!({"command": "echo hello"}),
2939 ),
2940 MockLlmClient::text_response("Done!"),
2942 ]));
2943
2944 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2945
2946 let permission_policy = PermissionPolicy::new()
2948 .allow("bash(echo:*)")
2949 .deny("bash(rm:*)");
2950
2951 let config = AgentConfig {
2952 permission_checker: Some(Arc::new(permission_policy)),
2953 ..Default::default()
2954 };
2955
2956 let agent = AgentLoop::new(
2957 mock_client.clone(),
2958 tool_executor,
2959 test_tool_context(),
2960 config,
2961 );
2962 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2963
2964 assert_eq!(result.text, "Done!");
2965 assert_eq!(result.tool_calls_count, 1);
2966 }
2967
2968 #[tokio::test]
2969 async fn test_agent_streaming_events() {
2970 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2971 "Hello!",
2972 )]));
2973
2974 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2975 let config = AgentConfig::default();
2976
2977 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2978 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2979
2980 let mut events = Vec::new();
2982 while let Some(event) = rx.recv().await {
2983 events.push(event);
2984 }
2985
2986 let result = handle.await.unwrap().unwrap();
2987 assert_eq!(result.text, "Hello!");
2988
2989 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2991 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2992 }
2993
2994 #[tokio::test]
2995 async fn test_agent_max_tool_rounds() {
2996 let responses: Vec<LlmResponse> = (0..100)
2998 .map(|i| {
2999 MockLlmClient::tool_call_response(
3000 &format!("tool-{}", i),
3001 "bash",
3002 serde_json::json!({"command": "echo loop"}),
3003 )
3004 })
3005 .collect();
3006
3007 let mock_client = Arc::new(MockLlmClient::new(responses));
3008 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3009
3010 let config = AgentConfig {
3011 max_tool_rounds: 3,
3012 ..Default::default()
3013 };
3014
3015 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3016 let result = agent.execute(&[], "Loop forever", None).await;
3017
3018 assert!(result.is_err());
3020 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
3021 }
3022
3023 #[tokio::test]
3024 async fn test_agent_no_permission_policy_defaults_to_ask() {
3025 let mock_client = Arc::new(MockLlmClient::new(vec![
3028 MockLlmClient::tool_call_response(
3029 "tool-1",
3030 "bash",
3031 serde_json::json!({"command": "rm -rf /tmp/test"}),
3032 ),
3033 MockLlmClient::text_response("Denied!"),
3034 ]));
3035
3036 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3037 let config = AgentConfig {
3038 permission_checker: None, ..Default::default()
3041 };
3042
3043 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3044 let result = agent.execute(&[], "Delete", None).await.unwrap();
3045
3046 assert_eq!(result.text, "Denied!");
3048 assert_eq!(result.tool_calls_count, 1);
3049 }
3050
3051 #[tokio::test]
3052 async fn test_agent_permission_ask_without_cm_denies() {
3053 let mock_client = Arc::new(MockLlmClient::new(vec![
3056 MockLlmClient::tool_call_response(
3057 "tool-1",
3058 "bash",
3059 serde_json::json!({"command": "echo test"}),
3060 ),
3061 MockLlmClient::text_response("Denied!"),
3062 ]));
3063
3064 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3065
3066 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3070 permission_checker: Some(Arc::new(permission_policy)),
3071 ..Default::default()
3073 };
3074
3075 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3076 let result = agent.execute(&[], "Echo", None).await.unwrap();
3077
3078 assert_eq!(result.text, "Denied!");
3080 assert!(result.tool_calls_count >= 1);
3082 }
3083
3084 #[tokio::test]
3089 async fn test_agent_hitl_approved() {
3090 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3091 use tokio::sync::broadcast;
3092
3093 let mock_client = Arc::new(MockLlmClient::new(vec![
3094 MockLlmClient::tool_call_response(
3095 "tool-1",
3096 "bash",
3097 serde_json::json!({"command": "echo hello"}),
3098 ),
3099 MockLlmClient::text_response("Command executed!"),
3100 ]));
3101
3102 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3103
3104 let (event_tx, _event_rx) = broadcast::channel(100);
3106 let hitl_policy = ConfirmationPolicy {
3107 enabled: true,
3108 ..Default::default()
3109 };
3110 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3111
3112 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3116 permission_checker: Some(Arc::new(permission_policy)),
3117 confirmation_manager: Some(confirmation_manager.clone()),
3118 ..Default::default()
3119 };
3120
3121 let cm_clone = confirmation_manager.clone();
3123 tokio::spawn(async move {
3124 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3126 cm_clone.confirm("tool-1", true, None).await.ok();
3128 });
3129
3130 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3131 let result = agent.execute(&[], "Run echo", None).await.unwrap();
3132
3133 assert_eq!(result.text, "Command executed!");
3134 assert_eq!(result.tool_calls_count, 1);
3135 }
3136
3137 #[tokio::test]
3138 async fn test_agent_hitl_rejected() {
3139 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3140 use tokio::sync::broadcast;
3141
3142 let mock_client = Arc::new(MockLlmClient::new(vec![
3143 MockLlmClient::tool_call_response(
3144 "tool-1",
3145 "bash",
3146 serde_json::json!({"command": "rm -rf /"}),
3147 ),
3148 MockLlmClient::text_response("Understood, I won't do that."),
3149 ]));
3150
3151 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3152
3153 let (event_tx, _event_rx) = broadcast::channel(100);
3155 let hitl_policy = ConfirmationPolicy {
3156 enabled: true,
3157 ..Default::default()
3158 };
3159 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3160
3161 let permission_policy = PermissionPolicy::new();
3163
3164 let config = AgentConfig {
3165 permission_checker: Some(Arc::new(permission_policy)),
3166 confirmation_manager: Some(confirmation_manager.clone()),
3167 ..Default::default()
3168 };
3169
3170 let cm_clone = confirmation_manager.clone();
3172 tokio::spawn(async move {
3173 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3174 cm_clone
3175 .confirm("tool-1", false, Some("Too dangerous".to_string()))
3176 .await
3177 .ok();
3178 });
3179
3180 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3181 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
3182
3183 assert_eq!(result.text, "Understood, I won't do that.");
3185 }
3186
3187 #[tokio::test]
3188 async fn test_agent_hitl_timeout_reject() {
3189 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3190 use tokio::sync::broadcast;
3191
3192 let mock_client = Arc::new(MockLlmClient::new(vec![
3193 MockLlmClient::tool_call_response(
3194 "tool-1",
3195 "bash",
3196 serde_json::json!({"command": "echo test"}),
3197 ),
3198 MockLlmClient::text_response("Timed out, I understand."),
3199 ]));
3200
3201 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3202
3203 let (event_tx, _event_rx) = broadcast::channel(100);
3205 let hitl_policy = ConfirmationPolicy {
3206 enabled: true,
3207 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
3209 ..Default::default()
3210 };
3211 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3212
3213 let permission_policy = PermissionPolicy::new();
3214
3215 let config = AgentConfig {
3216 permission_checker: Some(Arc::new(permission_policy)),
3217 confirmation_manager: Some(confirmation_manager),
3218 ..Default::default()
3219 };
3220
3221 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3223 let result = agent.execute(&[], "Echo", None).await.unwrap();
3224
3225 assert_eq!(result.text, "Timed out, I understand.");
3227 }
3228
3229 #[tokio::test]
3230 async fn test_agent_hitl_timeout_auto_approve() {
3231 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3232 use tokio::sync::broadcast;
3233
3234 let mock_client = Arc::new(MockLlmClient::new(vec![
3235 MockLlmClient::tool_call_response(
3236 "tool-1",
3237 "bash",
3238 serde_json::json!({"command": "echo hello"}),
3239 ),
3240 MockLlmClient::text_response("Auto-approved and executed!"),
3241 ]));
3242
3243 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3244
3245 let (event_tx, _event_rx) = broadcast::channel(100);
3247 let hitl_policy = ConfirmationPolicy {
3248 enabled: true,
3249 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
3251 ..Default::default()
3252 };
3253 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3254
3255 let permission_policy = PermissionPolicy::new();
3256
3257 let config = AgentConfig {
3258 permission_checker: Some(Arc::new(permission_policy)),
3259 confirmation_manager: Some(confirmation_manager),
3260 ..Default::default()
3261 };
3262
3263 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3265 let result = agent.execute(&[], "Echo", None).await.unwrap();
3266
3267 assert_eq!(result.text, "Auto-approved and executed!");
3269 assert_eq!(result.tool_calls_count, 1);
3270 }
3271
3272 #[tokio::test]
3273 async fn test_agent_hitl_confirmation_events() {
3274 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3275 use tokio::sync::broadcast;
3276
3277 let mock_client = Arc::new(MockLlmClient::new(vec![
3278 MockLlmClient::tool_call_response(
3279 "tool-1",
3280 "bash",
3281 serde_json::json!({"command": "echo test"}),
3282 ),
3283 MockLlmClient::text_response("Done!"),
3284 ]));
3285
3286 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3287
3288 let (event_tx, mut event_rx) = broadcast::channel(100);
3290 let hitl_policy = ConfirmationPolicy {
3291 enabled: true,
3292 default_timeout_ms: 5000, ..Default::default()
3294 };
3295 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3296
3297 let permission_policy = PermissionPolicy::new();
3298
3299 let config = AgentConfig {
3300 permission_checker: Some(Arc::new(permission_policy)),
3301 confirmation_manager: Some(confirmation_manager.clone()),
3302 ..Default::default()
3303 };
3304
3305 let cm_clone = confirmation_manager.clone();
3307 let event_handle = tokio::spawn(async move {
3308 let mut events = Vec::new();
3309 while let Ok(event) = event_rx.recv().await {
3311 events.push(event.clone());
3312 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3313 cm_clone.confirm(&tool_id, true, None).await.ok();
3315 if let Ok(recv_event) = event_rx.recv().await {
3317 events.push(recv_event);
3318 }
3319 break;
3320 }
3321 }
3322 events
3323 });
3324
3325 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3326 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3327
3328 let events = event_handle.await.unwrap();
3330 assert!(
3331 events
3332 .iter()
3333 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3334 "Should have ConfirmationRequired event"
3335 );
3336 assert!(
3337 events
3338 .iter()
3339 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3340 "Should have ConfirmationReceived event with approved=true"
3341 );
3342 }
3343
3344 #[tokio::test]
3345 async fn test_agent_hitl_disabled_auto_executes() {
3346 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3348 use tokio::sync::broadcast;
3349
3350 let mock_client = Arc::new(MockLlmClient::new(vec![
3351 MockLlmClient::tool_call_response(
3352 "tool-1",
3353 "bash",
3354 serde_json::json!({"command": "echo auto"}),
3355 ),
3356 MockLlmClient::text_response("Auto executed!"),
3357 ]));
3358
3359 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3360
3361 let (event_tx, _event_rx) = broadcast::channel(100);
3363 let hitl_policy = ConfirmationPolicy {
3364 enabled: false, ..Default::default()
3366 };
3367 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3368
3369 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3372 permission_checker: Some(Arc::new(permission_policy)),
3373 confirmation_manager: Some(confirmation_manager),
3374 ..Default::default()
3375 };
3376
3377 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3378 let result = agent.execute(&[], "Echo", None).await.unwrap();
3379
3380 assert_eq!(result.text, "Auto executed!");
3382 assert_eq!(result.tool_calls_count, 1);
3383 }
3384
3385 #[tokio::test]
3386 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3387 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3389 use tokio::sync::broadcast;
3390
3391 let mock_client = Arc::new(MockLlmClient::new(vec![
3392 MockLlmClient::tool_call_response(
3393 "tool-1",
3394 "bash",
3395 serde_json::json!({"command": "rm -rf /"}),
3396 ),
3397 MockLlmClient::text_response("Blocked by permission."),
3398 ]));
3399
3400 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3401
3402 let (event_tx, mut event_rx) = broadcast::channel(100);
3404 let hitl_policy = ConfirmationPolicy {
3405 enabled: true,
3406 ..Default::default()
3407 };
3408 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3409
3410 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3412
3413 let config = AgentConfig {
3414 permission_checker: Some(Arc::new(permission_policy)),
3415 confirmation_manager: Some(confirmation_manager),
3416 ..Default::default()
3417 };
3418
3419 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3420 let result = agent.execute(&[], "Delete", None).await.unwrap();
3421
3422 assert_eq!(result.text, "Blocked by permission.");
3424
3425 let mut found_confirmation = false;
3427 while let Ok(event) = event_rx.try_recv() {
3428 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3429 found_confirmation = true;
3430 }
3431 }
3432 assert!(
3433 !found_confirmation,
3434 "HITL should not be triggered when permission is Deny"
3435 );
3436 }
3437
3438 #[tokio::test]
3439 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3440 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3443 use tokio::sync::broadcast;
3444
3445 let mock_client = Arc::new(MockLlmClient::new(vec![
3446 MockLlmClient::tool_call_response(
3447 "tool-1",
3448 "bash",
3449 serde_json::json!({"command": "echo hello"}),
3450 ),
3451 MockLlmClient::text_response("Allowed!"),
3452 ]));
3453
3454 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3455
3456 let (event_tx, mut event_rx) = broadcast::channel(100);
3458 let hitl_policy = ConfirmationPolicy {
3459 enabled: true,
3460 ..Default::default()
3461 };
3462 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3463
3464 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3466
3467 let config = AgentConfig {
3468 permission_checker: Some(Arc::new(permission_policy)),
3469 confirmation_manager: Some(confirmation_manager.clone()),
3470 ..Default::default()
3471 };
3472
3473 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3474 let result = agent.execute(&[], "Echo", None).await.unwrap();
3475
3476 assert_eq!(result.text, "Allowed!");
3478
3479 let mut found_confirmation = false;
3481 while let Ok(event) = event_rx.try_recv() {
3482 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3483 found_confirmation = true;
3484 }
3485 }
3486 assert!(
3487 !found_confirmation,
3488 "Permission Allow should skip HITL confirmation"
3489 );
3490 }
3491
3492 #[tokio::test]
3493 async fn test_agent_hitl_multiple_tool_calls() {
3494 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3496 use tokio::sync::broadcast;
3497
3498 let mock_client = Arc::new(MockLlmClient::new(vec![
3499 LlmResponse {
3501 message: Message {
3502 role: "assistant".to_string(),
3503 content: vec![
3504 ContentBlock::ToolUse {
3505 id: "tool-1".to_string(),
3506 name: "bash".to_string(),
3507 input: serde_json::json!({"command": "echo first"}),
3508 },
3509 ContentBlock::ToolUse {
3510 id: "tool-2".to_string(),
3511 name: "bash".to_string(),
3512 input: serde_json::json!({"command": "echo second"}),
3513 },
3514 ],
3515 reasoning_content: None,
3516 },
3517 usage: TokenUsage {
3518 prompt_tokens: 10,
3519 completion_tokens: 5,
3520 total_tokens: 15,
3521 cache_read_tokens: None,
3522 cache_write_tokens: None,
3523 },
3524 stop_reason: Some("tool_use".to_string()),
3525 },
3526 MockLlmClient::text_response("Both executed!"),
3527 ]));
3528
3529 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3530
3531 let (event_tx, _event_rx) = broadcast::channel(100);
3533 let hitl_policy = ConfirmationPolicy {
3534 enabled: true,
3535 default_timeout_ms: 5000,
3536 ..Default::default()
3537 };
3538 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3539
3540 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3543 permission_checker: Some(Arc::new(permission_policy)),
3544 confirmation_manager: Some(confirmation_manager.clone()),
3545 ..Default::default()
3546 };
3547
3548 let cm_clone = confirmation_manager.clone();
3550 tokio::spawn(async move {
3551 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3552 cm_clone.confirm("tool-1", true, None).await.ok();
3553 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3554 cm_clone.confirm("tool-2", true, None).await.ok();
3555 });
3556
3557 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3558 let result = agent.execute(&[], "Run both", None).await.unwrap();
3559
3560 assert_eq!(result.text, "Both executed!");
3561 assert_eq!(result.tool_calls_count, 2);
3562 }
3563
3564 #[tokio::test]
3565 async fn test_agent_hitl_partial_approval() {
3566 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3568 use tokio::sync::broadcast;
3569
3570 let mock_client = Arc::new(MockLlmClient::new(vec![
3571 LlmResponse {
3573 message: Message {
3574 role: "assistant".to_string(),
3575 content: vec![
3576 ContentBlock::ToolUse {
3577 id: "tool-1".to_string(),
3578 name: "bash".to_string(),
3579 input: serde_json::json!({"command": "echo safe"}),
3580 },
3581 ContentBlock::ToolUse {
3582 id: "tool-2".to_string(),
3583 name: "bash".to_string(),
3584 input: serde_json::json!({"command": "rm -rf /"}),
3585 },
3586 ],
3587 reasoning_content: None,
3588 },
3589 usage: TokenUsage {
3590 prompt_tokens: 10,
3591 completion_tokens: 5,
3592 total_tokens: 15,
3593 cache_read_tokens: None,
3594 cache_write_tokens: None,
3595 },
3596 stop_reason: Some("tool_use".to_string()),
3597 },
3598 MockLlmClient::text_response("First worked, second rejected."),
3599 ]));
3600
3601 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3602
3603 let (event_tx, _event_rx) = broadcast::channel(100);
3604 let hitl_policy = ConfirmationPolicy {
3605 enabled: true,
3606 default_timeout_ms: 5000,
3607 ..Default::default()
3608 };
3609 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3610
3611 let permission_policy = PermissionPolicy::new();
3612
3613 let config = AgentConfig {
3614 permission_checker: Some(Arc::new(permission_policy)),
3615 confirmation_manager: Some(confirmation_manager.clone()),
3616 ..Default::default()
3617 };
3618
3619 let cm_clone = confirmation_manager.clone();
3621 tokio::spawn(async move {
3622 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3623 cm_clone.confirm("tool-1", true, None).await.ok();
3624 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3625 cm_clone
3626 .confirm("tool-2", false, Some("Dangerous".to_string()))
3627 .await
3628 .ok();
3629 });
3630
3631 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3632 let result = agent.execute(&[], "Run both", None).await.unwrap();
3633
3634 assert_eq!(result.text, "First worked, second rejected.");
3635 assert_eq!(result.tool_calls_count, 2);
3636 }
3637
3638 #[tokio::test]
3639 async fn test_agent_hitl_yolo_mode_auto_approves() {
3640 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3642 use tokio::sync::broadcast;
3643
3644 let mock_client = Arc::new(MockLlmClient::new(vec![
3645 MockLlmClient::tool_call_response(
3646 "tool-1",
3647 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3649 ),
3650 MockLlmClient::text_response("File read!"),
3651 ]));
3652
3653 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3654
3655 let (event_tx, mut event_rx) = broadcast::channel(100);
3657 let mut yolo_lanes = std::collections::HashSet::new();
3658 yolo_lanes.insert(SessionLane::Query);
3659 let hitl_policy = ConfirmationPolicy {
3660 enabled: true,
3661 yolo_lanes, ..Default::default()
3663 };
3664 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3665
3666 let permission_policy = PermissionPolicy::new();
3667
3668 let config = AgentConfig {
3669 permission_checker: Some(Arc::new(permission_policy)),
3670 confirmation_manager: Some(confirmation_manager),
3671 ..Default::default()
3672 };
3673
3674 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3675 let result = agent.execute(&[], "Read file", None).await.unwrap();
3676
3677 assert_eq!(result.text, "File read!");
3679
3680 let mut found_confirmation = false;
3682 while let Ok(event) = event_rx.try_recv() {
3683 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3684 found_confirmation = true;
3685 }
3686 }
3687 assert!(
3688 !found_confirmation,
3689 "YOLO mode should not trigger confirmation"
3690 );
3691 }
3692
3693 #[tokio::test]
3694 async fn test_agent_config_with_all_options() {
3695 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3696 use tokio::sync::broadcast;
3697
3698 let (event_tx, _) = broadcast::channel(100);
3699 let hitl_policy = ConfirmationPolicy::default();
3700 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3701
3702 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3703
3704 let config = AgentConfig {
3705 prompt_slots: SystemPromptSlots {
3706 extra: Some("Test system prompt".to_string()),
3707 ..Default::default()
3708 },
3709 tools: vec![],
3710 max_tool_rounds: 10,
3711 permission_checker: Some(Arc::new(permission_policy)),
3712 confirmation_manager: Some(confirmation_manager),
3713 context_providers: vec![],
3714 planning_enabled: false,
3715 goal_tracking: false,
3716 hook_engine: None,
3717 skill_registry: None,
3718 ..AgentConfig::default()
3719 };
3720
3721 assert!(config.prompt_slots.build().contains("Test system prompt"));
3722 assert_eq!(config.max_tool_rounds, 10);
3723 assert!(config.permission_checker.is_some());
3724 assert!(config.confirmation_manager.is_some());
3725 assert!(config.context_providers.is_empty());
3726
3727 let debug_str = format!("{:?}", config);
3729 assert!(debug_str.contains("AgentConfig"));
3730 assert!(debug_str.contains("permission_checker: true"));
3731 assert!(debug_str.contains("confirmation_manager: true"));
3732 assert!(debug_str.contains("context_providers: 0"));
3733 }
3734
3735 use crate::context::{ContextItem, ContextType};
3740
3741 struct MockContextProvider {
3743 name: String,
3744 items: Vec<ContextItem>,
3745 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
3746 }
3747
3748 impl MockContextProvider {
3749 fn new(name: &str) -> Self {
3750 Self {
3751 name: name.to_string(),
3752 items: Vec::new(),
3753 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3754 }
3755 }
3756
3757 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3758 self.items = items;
3759 self
3760 }
3761 }
3762
3763 #[async_trait::async_trait]
3764 impl ContextProvider for MockContextProvider {
3765 fn name(&self) -> &str {
3766 &self.name
3767 }
3768
3769 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3770 let mut result = ContextResult::new(&self.name);
3771 for item in &self.items {
3772 result.add_item(item.clone());
3773 }
3774 Ok(result)
3775 }
3776
3777 async fn on_turn_complete(
3778 &self,
3779 session_id: &str,
3780 prompt: &str,
3781 response: &str,
3782 ) -> anyhow::Result<()> {
3783 let mut calls = self.on_turn_calls.write().await;
3784 calls.push((
3785 session_id.to_string(),
3786 prompt.to_string(),
3787 response.to_string(),
3788 ));
3789 Ok(())
3790 }
3791 }
3792
3793 #[tokio::test]
3794 async fn test_agent_with_context_provider() {
3795 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3796 "Response using context",
3797 )]));
3798
3799 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3800
3801 let provider =
3802 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3803 "ctx-1",
3804 ContextType::Resource,
3805 "Relevant context here",
3806 )
3807 .with_source("test://docs/example")]);
3808
3809 let config = AgentConfig {
3810 prompt_slots: SystemPromptSlots {
3811 extra: Some("You are helpful.".to_string()),
3812 ..Default::default()
3813 },
3814 context_providers: vec![Arc::new(provider)],
3815 ..Default::default()
3816 };
3817
3818 let agent = AgentLoop::new(
3819 mock_client.clone(),
3820 tool_executor,
3821 test_tool_context(),
3822 config,
3823 );
3824 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3825
3826 assert_eq!(result.text, "Response using context");
3827 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3828 }
3829
3830 #[tokio::test]
3831 async fn test_agent_context_provider_events() {
3832 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3833 "Answer",
3834 )]));
3835
3836 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3837
3838 let provider =
3839 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3840 "item-1",
3841 ContextType::Memory,
3842 "Memory content",
3843 )
3844 .with_token_count(50)]);
3845
3846 let config = AgentConfig {
3847 context_providers: vec![Arc::new(provider)],
3848 ..Default::default()
3849 };
3850
3851 let (tx, mut rx) = mpsc::channel(100);
3852 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3853 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3854
3855 let mut events = Vec::new();
3857 while let Ok(event) = rx.try_recv() {
3858 events.push(event);
3859 }
3860
3861 assert!(
3863 events
3864 .iter()
3865 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3866 "Should have ContextResolving event"
3867 );
3868 assert!(
3869 events
3870 .iter()
3871 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3872 "Should have ContextResolved event"
3873 );
3874
3875 for event in &events {
3877 if let AgentEvent::ContextResolved {
3878 total_items,
3879 total_tokens,
3880 } = event
3881 {
3882 assert_eq!(*total_items, 1);
3883 assert_eq!(*total_tokens, 50);
3884 }
3885 }
3886 }
3887
3888 #[tokio::test]
3889 async fn test_agent_multiple_context_providers() {
3890 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3891 "Combined response",
3892 )]));
3893
3894 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3895
3896 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3897 "p1-1",
3898 ContextType::Resource,
3899 "Resource from P1",
3900 )
3901 .with_token_count(100)]);
3902
3903 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3904 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3905 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3906 ]);
3907
3908 let config = AgentConfig {
3909 prompt_slots: SystemPromptSlots {
3910 extra: Some("Base system prompt.".to_string()),
3911 ..Default::default()
3912 },
3913 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3914 ..Default::default()
3915 };
3916
3917 let (tx, mut rx) = mpsc::channel(100);
3918 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3919 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3920
3921 assert_eq!(result.text, "Combined response");
3922
3923 while let Ok(event) = rx.try_recv() {
3925 if let AgentEvent::ContextResolved {
3926 total_items,
3927 total_tokens,
3928 } = event
3929 {
3930 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3933 }
3934 }
3935
3936 #[tokio::test]
3937 async fn test_agent_no_context_providers() {
3938 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3939 "No context",
3940 )]));
3941
3942 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3943
3944 let config = AgentConfig::default();
3946
3947 let (tx, mut rx) = mpsc::channel(100);
3948 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3949 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3950
3951 assert_eq!(result.text, "No context");
3952
3953 let mut events = Vec::new();
3955 while let Ok(event) = rx.try_recv() {
3956 events.push(event);
3957 }
3958
3959 assert!(
3960 !events
3961 .iter()
3962 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3963 "Should NOT have ContextResolving event"
3964 );
3965 }
3966
3967 #[tokio::test]
3968 async fn test_agent_context_on_turn_complete() {
3969 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3970 "Final response",
3971 )]));
3972
3973 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3974
3975 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3976 let on_turn_calls = provider.on_turn_calls.clone();
3977
3978 let config = AgentConfig {
3979 context_providers: vec![provider],
3980 ..Default::default()
3981 };
3982
3983 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3984
3985 let result = agent
3987 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3988 .await
3989 .unwrap();
3990
3991 assert_eq!(result.text, "Final response");
3992
3993 let calls = on_turn_calls.read().await;
3995 assert_eq!(calls.len(), 1);
3996 assert_eq!(calls[0].0, "sess-123");
3997 assert_eq!(calls[0].1, "User prompt");
3998 assert_eq!(calls[0].2, "Final response");
3999 }
4000
4001 #[tokio::test]
4002 async fn test_agent_context_on_turn_complete_no_session() {
4003 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4004 "Response",
4005 )]));
4006
4007 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4008
4009 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4010 let on_turn_calls = provider.on_turn_calls.clone();
4011
4012 let config = AgentConfig {
4013 context_providers: vec![provider],
4014 ..Default::default()
4015 };
4016
4017 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4018
4019 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
4021
4022 let calls = on_turn_calls.read().await;
4024 assert!(calls.is_empty());
4025 }
4026
4027 #[tokio::test]
4028 async fn test_agent_build_augmented_system_prompt() {
4029 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
4030
4031 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4032
4033 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
4034 "doc-1",
4035 ContextType::Resource,
4036 "Auth uses JWT tokens.",
4037 )
4038 .with_source("viking://docs/auth")]);
4039
4040 let config = AgentConfig {
4041 prompt_slots: SystemPromptSlots {
4042 extra: Some("You are helpful.".to_string()),
4043 ..Default::default()
4044 },
4045 context_providers: vec![Arc::new(provider)],
4046 ..Default::default()
4047 };
4048
4049 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4050
4051 let context_results = agent.resolve_context("test", None).await;
4053 let augmented = agent.build_augmented_system_prompt(&context_results);
4054
4055 let augmented_str = augmented.unwrap();
4056 assert!(augmented_str.contains("You are helpful."));
4057 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4058 assert!(augmented_str.contains("Auth uses JWT tokens."));
4059 }
4060
4061 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4067 let mut events = Vec::new();
4068 while let Ok(event) = rx.try_recv() {
4069 events.push(event);
4070 }
4071 while let Some(event) = rx.recv().await {
4073 events.push(event);
4074 }
4075 events
4076 }
4077
4078 #[tokio::test]
4079 async fn test_agent_multi_turn_tool_chain() {
4080 let mock_client = Arc::new(MockLlmClient::new(vec![
4082 MockLlmClient::tool_call_response(
4084 "t1",
4085 "bash",
4086 serde_json::json!({"command": "echo step1"}),
4087 ),
4088 MockLlmClient::tool_call_response(
4090 "t2",
4091 "bash",
4092 serde_json::json!({"command": "echo step2"}),
4093 ),
4094 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4096 ]));
4097
4098 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4099 let config = AgentConfig::default();
4100
4101 let agent = AgentLoop::new(
4102 mock_client.clone(),
4103 tool_executor,
4104 test_tool_context(),
4105 config,
4106 );
4107 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4108
4109 assert_eq!(result.text, "Completed both steps: step1 then step2");
4110 assert_eq!(result.tool_calls_count, 2);
4111 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4112
4113 assert_eq!(result.messages[0].role, "user");
4115 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4121 }
4122
4123 #[tokio::test]
4124 async fn test_agent_conversation_history_preserved() {
4125 let existing_history = vec![
4127 Message::user("What is Rust?"),
4128 Message {
4129 role: "assistant".to_string(),
4130 content: vec![ContentBlock::Text {
4131 text: "Rust is a systems programming language.".to_string(),
4132 }],
4133 reasoning_content: None,
4134 },
4135 ];
4136
4137 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4138 "Rust was created by Graydon Hoare at Mozilla.",
4139 )]));
4140
4141 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4142 let agent = AgentLoop::new(
4143 mock_client.clone(),
4144 tool_executor,
4145 test_tool_context(),
4146 AgentConfig::default(),
4147 );
4148
4149 let result = agent
4150 .execute(&existing_history, "Who created it?", None)
4151 .await
4152 .unwrap();
4153
4154 assert_eq!(result.messages.len(), 4);
4156 assert_eq!(result.messages[0].text(), "What is Rust?");
4157 assert_eq!(
4158 result.messages[1].text(),
4159 "Rust is a systems programming language."
4160 );
4161 assert_eq!(result.messages[2].text(), "Who created it?");
4162 assert_eq!(
4163 result.messages[3].text(),
4164 "Rust was created by Graydon Hoare at Mozilla."
4165 );
4166 }
4167
4168 #[tokio::test]
4169 async fn test_agent_event_stream_completeness() {
4170 let mock_client = Arc::new(MockLlmClient::new(vec![
4172 MockLlmClient::tool_call_response(
4173 "t1",
4174 "bash",
4175 serde_json::json!({"command": "echo hi"}),
4176 ),
4177 MockLlmClient::text_response("Done"),
4178 ]));
4179
4180 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4181 let agent = AgentLoop::new(
4182 mock_client,
4183 tool_executor,
4184 test_tool_context(),
4185 AgentConfig::default(),
4186 );
4187
4188 let (tx, rx) = mpsc::channel(100);
4189 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4190 assert_eq!(result.text, "Done");
4191
4192 let events = collect_events(rx).await;
4193
4194 let event_types: Vec<&str> = events
4196 .iter()
4197 .map(|e| match e {
4198 AgentEvent::Start { .. } => "Start",
4199 AgentEvent::TurnStart { .. } => "TurnStart",
4200 AgentEvent::TurnEnd { .. } => "TurnEnd",
4201 AgentEvent::ToolEnd { .. } => "ToolEnd",
4202 AgentEvent::End { .. } => "End",
4203 _ => "Other",
4204 })
4205 .collect();
4206
4207 assert_eq!(event_types.first(), Some(&"Start"));
4209 assert_eq!(event_types.last(), Some(&"End"));
4210
4211 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4213 assert_eq!(turn_starts, 2);
4214
4215 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4217 assert_eq!(tool_ends, 1);
4218 }
4219
4220 #[tokio::test]
4221 async fn test_agent_multiple_tools_single_turn() {
4222 let mock_client = Arc::new(MockLlmClient::new(vec![
4224 LlmResponse {
4225 message: Message {
4226 role: "assistant".to_string(),
4227 content: vec![
4228 ContentBlock::ToolUse {
4229 id: "t1".to_string(),
4230 name: "bash".to_string(),
4231 input: serde_json::json!({"command": "echo first"}),
4232 },
4233 ContentBlock::ToolUse {
4234 id: "t2".to_string(),
4235 name: "bash".to_string(),
4236 input: serde_json::json!({"command": "echo second"}),
4237 },
4238 ],
4239 reasoning_content: None,
4240 },
4241 usage: TokenUsage {
4242 prompt_tokens: 10,
4243 completion_tokens: 5,
4244 total_tokens: 15,
4245 cache_read_tokens: None,
4246 cache_write_tokens: None,
4247 },
4248 stop_reason: Some("tool_use".to_string()),
4249 },
4250 MockLlmClient::text_response("Both commands ran"),
4251 ]));
4252
4253 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4254 let agent = AgentLoop::new(
4255 mock_client.clone(),
4256 tool_executor,
4257 test_tool_context(),
4258 AgentConfig::default(),
4259 );
4260
4261 let result = agent.execute(&[], "Run both", None).await.unwrap();
4262
4263 assert_eq!(result.text, "Both commands ran");
4264 assert_eq!(result.tool_calls_count, 2);
4265 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4269 assert_eq!(result.messages[1].role, "assistant");
4270 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4273 }
4274
4275 #[tokio::test]
4276 async fn test_agent_token_usage_accumulation() {
4277 let mock_client = Arc::new(MockLlmClient::new(vec![
4279 MockLlmClient::tool_call_response(
4280 "t1",
4281 "bash",
4282 serde_json::json!({"command": "echo x"}),
4283 ),
4284 MockLlmClient::text_response("Done"),
4285 ]));
4286
4287 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4288 let agent = AgentLoop::new(
4289 mock_client,
4290 tool_executor,
4291 test_tool_context(),
4292 AgentConfig::default(),
4293 );
4294
4295 let result = agent.execute(&[], "test", None).await.unwrap();
4296
4297 assert_eq!(result.usage.prompt_tokens, 20);
4300 assert_eq!(result.usage.completion_tokens, 10);
4301 assert_eq!(result.usage.total_tokens, 30);
4302 }
4303
4304 #[tokio::test]
4305 async fn test_agent_system_prompt_passed() {
4306 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4308 "I am a coding assistant.",
4309 )]));
4310
4311 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4312 let config = AgentConfig {
4313 prompt_slots: SystemPromptSlots {
4314 extra: Some("You are a coding assistant.".to_string()),
4315 ..Default::default()
4316 },
4317 ..Default::default()
4318 };
4319
4320 let agent = AgentLoop::new(
4321 mock_client.clone(),
4322 tool_executor,
4323 test_tool_context(),
4324 config,
4325 );
4326 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4327
4328 assert_eq!(result.text, "I am a coding assistant.");
4329 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4330 }
4331
4332 #[tokio::test]
4333 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4334 let mut responses = Vec::new();
4336 for i in 0..15 {
4337 responses.push(MockLlmClient::tool_call_response(
4338 &format!("t{}", i),
4339 "bash",
4340 serde_json::json!({"command": format!("echo round{}", i)}),
4341 ));
4342 }
4343
4344 let mock_client = Arc::new(MockLlmClient::new(responses));
4345 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4346 let config = AgentConfig {
4347 max_tool_rounds: 5,
4348 ..Default::default()
4349 };
4350
4351 let agent = AgentLoop::new(
4352 mock_client.clone(),
4353 tool_executor,
4354 test_tool_context(),
4355 config,
4356 );
4357 let result = agent.execute(&[], "Loop forever", None).await;
4358
4359 assert!(result.is_err());
4360 let err = result.unwrap_err().to_string();
4361 assert!(err.contains("Max tool rounds (5) exceeded"));
4362 }
4363
4364 #[tokio::test]
4365 async fn test_agent_end_event_contains_final_text() {
4366 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4367 "Final answer here",
4368 )]));
4369
4370 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4371 let agent = AgentLoop::new(
4372 mock_client,
4373 tool_executor,
4374 test_tool_context(),
4375 AgentConfig::default(),
4376 );
4377
4378 let (tx, rx) = mpsc::channel(100);
4379 agent.execute(&[], "test", Some(tx)).await.unwrap();
4380
4381 let events = collect_events(rx).await;
4382 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4383 assert!(end_event.is_some());
4384
4385 if let AgentEvent::End { text, usage } = end_event.unwrap() {
4386 assert_eq!(text, "Final answer here");
4387 assert_eq!(usage.total_tokens, 15);
4388 }
4389 }
4390}
4391
4392#[cfg(test)]
4393mod extra_agent_tests {
4394 use super::*;
4395 use crate::agent::tests::MockLlmClient;
4396 use crate::queue::SessionQueueConfig;
4397 use crate::tools::ToolExecutor;
4398 use std::path::PathBuf;
4399 use std::sync::atomic::{AtomicUsize, Ordering};
4400
4401 fn test_tool_context() -> ToolContext {
4402 ToolContext::new(PathBuf::from("/tmp"))
4403 }
4404
#[test]
fn test_agent_config_debug() {
    // The Debug rendering should name the type and expose its field names.
    let slots = SystemPromptSlots {
        extra: Some("You are helpful".to_string()),
        ..Default::default()
    };
    let config = AgentConfig {
        prompt_slots: slots,
        tools: vec![],
        max_tool_rounds: 10,
        permission_checker: None,
        confirmation_manager: None,
        context_providers: vec![],
        planning_enabled: true,
        goal_tracking: false,
        hook_engine: None,
        skill_registry: None,
        ..AgentConfig::default()
    };

    let rendered = format!("{:?}", config);
    for expected in ["AgentConfig", "planning_enabled"] {
        assert!(rendered.contains(expected), "missing {:?} in {}", expected, rendered);
    }
}
4431
#[test]
fn test_agent_config_default_values() {
    // Defaults: standard round limit, no providers, planning and goals off.
    let defaults = AgentConfig::default();
    assert_eq!(defaults.max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(defaults.context_providers.is_empty());
    assert!(!defaults.planning_enabled);
    assert!(!defaults.goal_tracking);
}
4440
#[test]
fn test_agent_event_serialize_start() {
    let json = serde_json::to_string(&AgentEvent::Start {
        prompt: "Hello".to_string(),
    })
    .expect("Start event should serialize");
    for needle in ["agent_start", "Hello"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4454
#[test]
fn test_agent_event_serialize_text_delta() {
    let json = serde_json::to_string(&AgentEvent::TextDelta {
        text: "chunk".to_string(),
    })
    .expect("TextDelta event should serialize");
    assert!(json.contains("text_delta"), "tag missing from {}", json);
}
4463
#[test]
fn test_agent_event_serialize_tool_start() {
    let json = serde_json::to_string(&AgentEvent::ToolStart {
        id: "t1".to_string(),
        name: "bash".to_string(),
    })
    .expect("ToolStart event should serialize");
    for needle in ["tool_start", "bash"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4474
#[test]
fn test_agent_event_serialize_tool_end() {
    let json = serde_json::to_string(&AgentEvent::ToolEnd {
        id: "t1".to_string(),
        name: "bash".to_string(),
        output: "hello".to_string(),
        exit_code: 0,
    })
    .expect("ToolEnd event should serialize");
    assert!(json.contains("tool_end"), "tag missing from {}", json);
}
4486
#[test]
fn test_agent_event_serialize_error() {
    let json = serde_json::to_string(&AgentEvent::Error {
        message: "oops".to_string(),
    })
    .expect("Error event should serialize");
    for needle in ["error", "oops"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4496
#[test]
fn test_agent_event_serialize_confirmation_required() {
    let json = serde_json::to_string(&AgentEvent::ConfirmationRequired {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({"cmd": "rm"}),
        timeout_ms: 30000,
    })
    .expect("ConfirmationRequired event should serialize");
    assert!(json.contains("confirmation_required"), "tag missing from {}", json);
}
4508
#[test]
fn test_agent_event_serialize_confirmation_received() {
    let json = serde_json::to_string(&AgentEvent::ConfirmationReceived {
        tool_id: "t1".to_string(),
        approved: true,
        reason: Some("safe".to_string()),
    })
    .expect("ConfirmationReceived event should serialize");
    assert!(json.contains("confirmation_received"), "tag missing from {}", json);
}
4519
#[test]
fn test_agent_event_serialize_confirmation_timeout() {
    let json = serde_json::to_string(&AgentEvent::ConfirmationTimeout {
        tool_id: "t1".to_string(),
        action_taken: "rejected".to_string(),
    })
    .expect("ConfirmationTimeout event should serialize");
    assert!(json.contains("confirmation_timeout"), "tag missing from {}", json);
}
4529
#[test]
fn test_agent_event_serialize_external_task_pending() {
    let json = serde_json::to_string(&AgentEvent::ExternalTaskPending {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        lane: crate::hitl::SessionLane::Execute,
        command_type: "bash".to_string(),
        payload: serde_json::json!({}),
        timeout_ms: 60000,
    })
    .expect("ExternalTaskPending event should serialize");
    assert!(json.contains("external_task_pending"), "tag missing from {}", json);
}
4543
#[test]
fn test_agent_event_serialize_external_task_completed() {
    let json = serde_json::to_string(&AgentEvent::ExternalTaskCompleted {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        success: false,
    })
    .expect("ExternalTaskCompleted event should serialize");
    assert!(json.contains("external_task_completed"), "tag missing from {}", json);
}
4554
#[test]
fn test_agent_event_serialize_permission_denied() {
    let json = serde_json::to_string(&AgentEvent::PermissionDenied {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({}),
        reason: "denied".to_string(),
    })
    .expect("PermissionDenied event should serialize");
    assert!(json.contains("permission_denied"), "tag missing from {}", json);
}
4566
#[test]
fn test_agent_event_serialize_context_compacted() {
    let json = serde_json::to_string(&AgentEvent::ContextCompacted {
        session_id: "sess-1".to_string(),
        before_messages: 100,
        after_messages: 20,
        percent_before: 0.85,
    })
    .expect("ContextCompacted event should serialize");
    assert!(json.contains("context_compacted"), "tag missing from {}", json);
}
4578
#[test]
fn test_agent_event_serialize_turn_start() {
    let json = serde_json::to_string(&AgentEvent::TurnStart { turn: 3 })
        .expect("TurnStart event should serialize");
    assert!(json.contains("turn_start"), "tag missing from {}", json);
}
4585
#[test]
fn test_agent_event_serialize_turn_end() {
    let json = serde_json::to_string(&AgentEvent::TurnEnd {
        turn: 3,
        usage: TokenUsage::default(),
    })
    .expect("TurnEnd event should serialize");
    assert!(json.contains("turn_end"), "tag missing from {}", json);
}
4595
#[test]
fn test_agent_event_serialize_end() {
    let usage = TokenUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let json = serde_json::to_string(&AgentEvent::End {
        text: "Done".to_string(),
        usage,
    })
    .expect("End event should serialize");
    assert!(json.contains("agent_end"), "tag missing from {}", json);
}
4611
#[test]
fn test_agent_result_fields() {
    // Fields set at construction are readable back unchanged.
    let result = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    assert_eq!(
        (result.text.as_str(), result.messages.len(), result.tool_calls_count),
        ("output", 1, 3)
    );
}
4628
#[test]
fn test_agent_event_serialize_context_resolving() {
    let providers = vec!["provider1".to_string(), "provider2".to_string()];
    let json = serde_json::to_string(&AgentEvent::ContextResolving { providers })
        .expect("ContextResolving event should serialize");
    for needle in ["context_resolving", "provider1"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4642
#[test]
fn test_agent_event_serialize_context_resolved() {
    let json = serde_json::to_string(&AgentEvent::ContextResolved {
        total_items: 5,
        total_tokens: 1000,
    })
    .expect("ContextResolved event should serialize");
    for needle in ["context_resolved", "1000"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4653
#[test]
fn test_agent_event_serialize_command_dead_lettered() {
    let json = serde_json::to_string(&AgentEvent::CommandDeadLettered {
        command_id: "cmd-1".to_string(),
        command_type: "bash".to_string(),
        lane: "execute".to_string(),
        error: "timeout".to_string(),
        attempts: 3,
    })
    .expect("CommandDeadLettered event should serialize");
    for needle in ["command_dead_lettered", "cmd-1"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4667
#[test]
fn test_agent_event_serialize_command_retry() {
    let json = serde_json::to_string(&AgentEvent::CommandRetry {
        command_id: "cmd-2".to_string(),
        command_type: "read".to_string(),
        lane: "query".to_string(),
        attempt: 2,
        delay_ms: 1000,
    })
    .expect("CommandRetry event should serialize");
    for needle in ["command_retry", "cmd-2"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4681
#[test]
fn test_agent_event_serialize_queue_alert() {
    let json = serde_json::to_string(&AgentEvent::QueueAlert {
        level: "warning".to_string(),
        alert_type: "depth".to_string(),
        message: "Queue depth exceeded".to_string(),
    })
    .expect("QueueAlert event should serialize");
    for needle in ["queue_alert", "warning"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4693
#[test]
fn test_agent_event_serialize_task_updated() {
    let json = serde_json::to_string(&AgentEvent::TaskUpdated {
        session_id: "sess-1".to_string(),
        tasks: vec![],
    })
    .expect("TaskUpdated event should serialize");
    for needle in ["task_updated", "sess-1"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4704
#[test]
fn test_agent_event_serialize_memory_stored() {
    let json = serde_json::to_string(&AgentEvent::MemoryStored {
        memory_id: "mem-1".to_string(),
        memory_type: "conversation".to_string(),
        importance: 0.8,
        tags: vec!["important".to_string()],
    })
    .expect("MemoryStored event should serialize");
    for needle in ["memory_stored", "mem-1"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4717
#[test]
fn test_agent_event_serialize_memory_recalled() {
    let json = serde_json::to_string(&AgentEvent::MemoryRecalled {
        memory_id: "mem-2".to_string(),
        content: "Previous conversation".to_string(),
        relevance: 0.9,
    })
    .expect("MemoryRecalled event should serialize");
    for needle in ["memory_recalled", "mem-2"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4729
#[test]
fn test_agent_event_serialize_memories_searched() {
    let json = serde_json::to_string(&AgentEvent::MemoriesSearched {
        query: Some("search term".to_string()),
        tags: vec!["tag1".to_string()],
        result_count: 5,
    })
    .expect("MemoriesSearched event should serialize");
    for needle in ["memories_searched", "search term"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4741
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let json = serde_json::to_string(&AgentEvent::MemoryCleared {
        tier: "short_term".to_string(),
        count: 10,
    })
    .expect("MemoryCleared event should serialize");
    for needle in ["memory_cleared", "short_term"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4752
#[test]
fn test_agent_event_serialize_subagent_start() {
    let json = serde_json::to_string(&AgentEvent::SubagentStart {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        parent_session_id: "parent-sess".to_string(),
        agent: "explore".to_string(),
        description: "Explore codebase".to_string(),
    })
    .expect("SubagentStart event should serialize");
    for needle in ["subagent_start", "explore"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4766
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let json = serde_json::to_string(&AgentEvent::SubagentProgress {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        status: "processing".to_string(),
        metadata: serde_json::json!({"progress": 50}),
    })
    .expect("SubagentProgress event should serialize");
    for needle in ["subagent_progress", "processing"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4779
#[test]
fn test_agent_event_serialize_subagent_end() {
    let json = serde_json::to_string(&AgentEvent::SubagentEnd {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        agent: "explore".to_string(),
        output: "Found 10 files".to_string(),
        success: true,
    })
    .expect("SubagentEnd event should serialize");
    for needle in ["subagent_end", "Found 10 files"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4793
#[test]
fn test_agent_event_serialize_planning_start() {
    let json = serde_json::to_string(&AgentEvent::PlanningStart {
        prompt: "Build a web app".to_string(),
    })
    .expect("PlanningStart event should serialize");
    for needle in ["planning_start", "Build a web app"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4803
#[test]
fn test_agent_event_serialize_planning_end() {
    use crate::planning::{Complexity, ExecutionPlan};

    let json = serde_json::to_string(&AgentEvent::PlanningEnd {
        plan: ExecutionPlan::new("Test goal".to_string(), Complexity::Simple),
        estimated_steps: 3,
    })
    .expect("PlanningEnd event should serialize");
    for needle in ["planning_end", "estimated_steps"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4816
#[test]
fn test_agent_event_serialize_step_start() {
    let json = serde_json::to_string(&AgentEvent::StepStart {
        step_id: "step-1".to_string(),
        description: "Initialize project".to_string(),
        step_number: 1,
        total_steps: 5,
    })
    .expect("StepStart event should serialize");
    for needle in ["step_start", "Initialize project"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4829
#[test]
fn test_agent_event_serialize_step_end() {
    let json = serde_json::to_string(&AgentEvent::StepEnd {
        step_id: "step-1".to_string(),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    })
    .expect("StepEnd event should serialize");
    for needle in ["step_end", "step-1"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4842
#[test]
fn test_agent_event_serialize_goal_extracted() {
    use crate::planning::AgentGoal;

    let event = AgentEvent::GoalExtracted {
        goal: AgentGoal::new("Complete the task".to_string()),
    };
    let json = serde_json::to_string(&event).expect("GoalExtracted event should serialize");
    assert!(json.contains("goal_extracted"), "tag missing from {}", json);
}
4851
#[test]
fn test_agent_event_serialize_goal_progress() {
    let json = serde_json::to_string(&AgentEvent::GoalProgress {
        goal: "Build app".to_string(),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    })
    .expect("GoalProgress event should serialize");
    for needle in ["goal_progress", "0.5"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4864
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let json = serde_json::to_string(&AgentEvent::GoalAchieved {
        goal: "Build app".to_string(),
        total_steps: 4,
        duration_ms: 5000,
    })
    .expect("GoalAchieved event should serialize");
    for needle in ["goal_achieved", "5000"] {
        assert!(json.contains(needle), "expected {:?} in {}", needle, json);
    }
}
4876
4877 #[tokio::test]
4878 async fn test_extract_goal_with_json_response() {
4879 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4881 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4882 )]));
4883 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4884 let agent = AgentLoop::new(
4885 mock_client,
4886 tool_executor,
4887 test_tool_context(),
4888 AgentConfig::default(),
4889 );
4890
4891 let goal = agent.extract_goal("Build a web app").await.unwrap();
4892 assert_eq!(goal.description, "Build web app");
4893 assert_eq!(goal.success_criteria.len(), 2);
4894 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4895 }
4896
4897 #[tokio::test]
4898 async fn test_extract_goal_fallback_on_non_json() {
4899 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4901 "Some non-JSON response",
4902 )]));
4903 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4904 let agent = AgentLoop::new(
4905 mock_client,
4906 tool_executor,
4907 test_tool_context(),
4908 AgentConfig::default(),
4909 );
4910
4911 let goal = agent.extract_goal("Do something").await.unwrap();
4912 assert_eq!(goal.description, "Do something");
4914 assert_eq!(goal.success_criteria.len(), 2);
4916 }
4917
4918 #[tokio::test]
4919 async fn test_check_goal_achievement_json_yes() {
4920 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4921 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4922 )]));
4923 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4924 let agent = AgentLoop::new(
4925 mock_client,
4926 tool_executor,
4927 test_tool_context(),
4928 AgentConfig::default(),
4929 );
4930
4931 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4932 let achieved = agent
4933 .check_goal_achievement(&goal, "All done")
4934 .await
4935 .unwrap();
4936 assert!(achieved);
4937 }
4938
4939 #[tokio::test]
4940 async fn test_check_goal_achievement_fallback_not_done() {
4941 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4943 "invalid json",
4944 )]));
4945 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4946 let agent = AgentLoop::new(
4947 mock_client,
4948 tool_executor,
4949 test_tool_context(),
4950 AgentConfig::default(),
4951 );
4952
4953 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4954 let achieved = agent
4956 .check_goal_achievement(&goal, "still working")
4957 .await
4958 .unwrap();
4959 assert!(!achieved);
4960 }
4961
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    // With no resolved context, the configured `extra` slot still appears.
    let slots = SystemPromptSlots {
        extra: Some("Base prompt".to_string()),
        ..Default::default()
    };
    let config = AgentConfig {
        prompt_slots: slots,
        ..Default::default()
    };
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), config);

    let prompt = agent.build_augmented_system_prompt(&[]).unwrap();
    assert!(prompt.contains("Base prompt"));
}
4982
#[test]
fn test_build_augmented_system_prompt_no_custom_slots() {
    // Even without custom slots, the built-in default prompt is produced.
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let prompt = agent
        .build_augmented_system_prompt(&[])
        .expect("default system prompt should be built");
    assert!(prompt.contains("Core Behaviour"));
}
4999
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    // Context items are rendered even when no custom base prompt is set.
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let resolved = ContextResult {
        provider: "test".to_string(),
        items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
        total_tokens: 10,
        truncated: false,
    };

    let prompt = agent
        .build_augmented_system_prompt(&[resolved])
        .expect("system prompt should be built");
    for fragment in ["<context", "Content"] {
        assert!(prompt.contains(fragment), "missing {:?} in {}", fragment, prompt);
    }
}
5026
#[test]
fn test_agent_result_clone() {
    // A clone carries the same text and tool-call count as the original.
    let original = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    let copy = original.clone();
    assert_eq!(copy.text, original.text);
    assert_eq!(copy.tool_calls_count, original.tool_calls_count);
}
5043
#[test]
fn test_agent_result_debug() {
    // The derived Debug output should name the type and include the text field's value.
    let rendered = format!(
        "{:?}",
        AgentResult {
            text: "output".to_string(),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        }
    );
    for needle in ["AgentResult", "output"] {
        assert!(rendered.contains(needle));
    }
}
5056
5057 #[tokio::test]
5066 async fn test_tool_command_command_type() {
5067 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5068 let cmd = ToolCommand {
5069 tool_executor: executor,
5070 tool_name: "read".to_string(),
5071 tool_args: serde_json::json!({"file": "test.rs"}),
5072 skill_registry: None,
5073 tool_context: test_tool_context(),
5074 };
5075 assert_eq!(cmd.command_type(), "read");
5076 }
5077
5078 #[tokio::test]
5079 async fn test_tool_command_payload() {
5080 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5081 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5082 let cmd = ToolCommand {
5083 tool_executor: executor,
5084 tool_name: "read".to_string(),
5085 tool_args: args.clone(),
5086 skill_registry: None,
5087 tool_context: test_tool_context(),
5088 };
5089 assert_eq!(cmd.payload(), args);
5090 }
5091
5092 #[tokio::test(flavor = "multi_thread")]
5097 async fn test_agent_loop_with_queue() {
5098 use tokio::sync::broadcast;
5099
5100 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5101 "Hello",
5102 )]));
5103 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5104 let config = AgentConfig::default();
5105
5106 let (event_tx, _) = broadcast::channel(100);
5107 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5108 .await
5109 .unwrap();
5110
5111 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5112 .with_queue(Arc::new(queue));
5113
5114 assert!(agent.command_queue.is_some());
5115 }
5116
5117 #[tokio::test]
5118 async fn test_agent_loop_without_queue() {
5119 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5120 "Hello",
5121 )]));
5122 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5123 let config = AgentConfig::default();
5124
5125 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5126
5127 assert!(agent.command_queue.is_none());
5128 }
5129
5130 #[tokio::test]
5135 async fn test_execute_plan_parallel_independent() {
5136 use crate::planning::{Complexity, ExecutionPlan, Task};
5137
5138 let mock_client = Arc::new(MockLlmClient::new(vec![
5141 MockLlmClient::text_response("Step 1 done"),
5142 MockLlmClient::text_response("Step 2 done"),
5143 MockLlmClient::text_response("Step 3 done"),
5144 ]));
5145
5146 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5147 let config = AgentConfig::default();
5148 let agent = AgentLoop::new(
5149 mock_client.clone(),
5150 tool_executor,
5151 test_tool_context(),
5152 config,
5153 );
5154
5155 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5156 plan.add_step(Task::new("s1", "First step"));
5157 plan.add_step(Task::new("s2", "Second step"));
5158 plan.add_step(Task::new("s3", "Third step"));
5159
5160 let (tx, mut rx) = mpsc::channel(100);
5161 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5162
5163 assert_eq!(result.usage.total_tokens, 45);
5165
5166 let mut step_starts = Vec::new();
5168 let mut step_ends = Vec::new();
5169 rx.close();
5170 while let Some(event) = rx.recv().await {
5171 match event {
5172 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5173 AgentEvent::StepEnd {
5174 step_id, status, ..
5175 } => {
5176 assert_eq!(status, TaskStatus::Completed);
5177 step_ends.push(step_id);
5178 }
5179 _ => {}
5180 }
5181 }
5182 assert_eq!(step_starts.len(), 3);
5183 assert_eq!(step_ends.len(), 3);
5184 }
5185
5186 #[tokio::test]
5187 async fn test_execute_plan_respects_dependencies() {
5188 use crate::planning::{Complexity, ExecutionPlan, Task};
5189
5190 let mock_client = Arc::new(MockLlmClient::new(vec![
5193 MockLlmClient::text_response("Step 1 done"),
5194 MockLlmClient::text_response("Step 2 done"),
5195 MockLlmClient::text_response("Step 3 done"),
5196 ]));
5197
5198 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5199 let config = AgentConfig::default();
5200 let agent = AgentLoop::new(
5201 mock_client.clone(),
5202 tool_executor,
5203 test_tool_context(),
5204 config,
5205 );
5206
5207 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5208 plan.add_step(Task::new("s1", "Independent A"));
5209 plan.add_step(Task::new("s2", "Independent B"));
5210 plan.add_step(
5211 Task::new("s3", "Depends on A+B")
5212 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5213 );
5214
5215 let (tx, mut rx) = mpsc::channel(100);
5216 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5217
5218 assert_eq!(result.usage.total_tokens, 45);
5220
5221 let mut events = Vec::new();
5223 rx.close();
5224 while let Some(event) = rx.recv().await {
5225 match &event {
5226 AgentEvent::StepStart { step_id, .. } => {
5227 events.push(format!("start:{}", step_id));
5228 }
5229 AgentEvent::StepEnd { step_id, .. } => {
5230 events.push(format!("end:{}", step_id));
5231 }
5232 _ => {}
5233 }
5234 }
5235
5236 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5238 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5239 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5240 assert!(
5241 s3_start > s1_end,
5242 "s3 started before s1 ended: {:?}",
5243 events
5244 );
5245 assert!(
5246 s3_start > s2_end,
5247 "s3 started before s2 ended: {:?}",
5248 events
5249 );
5250
5251 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5253 }
5254
5255 #[tokio::test]
5256 async fn test_execute_plan_handles_step_failure() {
5257 use crate::planning::{Complexity, ExecutionPlan, Task};
5258
5259 let mock_client = Arc::new(MockLlmClient::new(vec![
5269 MockLlmClient::text_response("s1 done"),
5271 MockLlmClient::text_response("s3 done"),
5272 ]));
5275
5276 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5277 let config = AgentConfig::default();
5278 let agent = AgentLoop::new(
5279 mock_client.clone(),
5280 tool_executor,
5281 test_tool_context(),
5282 config,
5283 );
5284
5285 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5286 plan.add_step(Task::new("s1", "Independent step"));
5287 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5288 plan.add_step(Task::new("s3", "Another independent"));
5289 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5290
5291 let (tx, mut rx) = mpsc::channel(100);
5292 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5293
5294 let mut completed_steps = Vec::new();
5297 let mut failed_steps = Vec::new();
5298 rx.close();
5299 while let Some(event) = rx.recv().await {
5300 if let AgentEvent::StepEnd {
5301 step_id, status, ..
5302 } = event
5303 {
5304 match status {
5305 TaskStatus::Completed => completed_steps.push(step_id),
5306 TaskStatus::Failed => failed_steps.push(step_id),
5307 _ => {}
5308 }
5309 }
5310 }
5311
5312 assert!(
5313 completed_steps.contains(&"s1".to_string()),
5314 "s1 should complete"
5315 );
5316 assert!(
5317 completed_steps.contains(&"s3".to_string()),
5318 "s3 should complete"
5319 );
5320 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5321 assert!(
5323 !completed_steps.contains(&"s4".to_string()),
5324 "s4 should not complete"
5325 );
5326 assert!(
5327 !failed_steps.contains(&"s4".to_string()),
5328 "s4 should not fail (never started)"
5329 );
5330 }
5331
#[test]
fn test_agent_config_resilience_defaults() {
    // Pull out just the resilience knobs from the default config and pin their values.
    let AgentConfig {
        max_parse_retries,
        tool_timeout_ms,
        circuit_breaker_threshold,
        ..
    } = AgentConfig::default();

    assert_eq!(max_parse_retries, 2);
    assert_eq!(tool_timeout_ms, None);
    assert_eq!(circuit_breaker_threshold, 3);
}
5343
5344 #[tokio::test]
5346 async fn test_parse_error_recovery_bails_after_threshold() {
5347 let mock_client = Arc::new(MockLlmClient::new(vec![
5349 MockLlmClient::tool_call_response(
5350 "c1",
5351 "bash",
5352 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5353 ),
5354 MockLlmClient::tool_call_response(
5355 "c2",
5356 "bash",
5357 serde_json::json!({"__parse_error": "missing closing brace"}),
5358 ),
5359 MockLlmClient::tool_call_response(
5360 "c3",
5361 "bash",
5362 serde_json::json!({"__parse_error": "still broken"}),
5363 ),
5364 MockLlmClient::text_response("Done"), ]));
5366
5367 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5368 let config = AgentConfig {
5369 max_parse_retries: 2,
5370 ..AgentConfig::default()
5371 };
5372 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5373 let result = agent.execute(&[], "Do something", None).await;
5374 assert!(result.is_err(), "should bail after parse error threshold");
5375 let err = result.unwrap_err().to_string();
5376 assert!(
5377 err.contains("malformed tool arguments"),
5378 "error should mention malformed tool arguments, got: {}",
5379 err
5380 );
5381 }
5382
5383 #[tokio::test]
5385 async fn test_parse_error_counter_resets_on_success() {
5386 let mock_client = Arc::new(MockLlmClient::new(vec![
5390 MockLlmClient::tool_call_response(
5391 "c1",
5392 "bash",
5393 serde_json::json!({"__parse_error": "bad args"}),
5394 ),
5395 MockLlmClient::tool_call_response(
5396 "c2",
5397 "bash",
5398 serde_json::json!({"__parse_error": "bad args again"}),
5399 ),
5400 MockLlmClient::tool_call_response(
5402 "c3",
5403 "bash",
5404 serde_json::json!({"command": "echo ok"}),
5405 ),
5406 MockLlmClient::text_response("All done"),
5407 ]));
5408
5409 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5410 let config = AgentConfig {
5411 max_parse_retries: 2,
5412 ..AgentConfig::default()
5413 };
5414 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5415 let result = agent.execute(&[], "Do something", None).await;
5416 assert!(
5417 result.is_ok(),
5418 "should not bail — counter reset after successful tool, got: {:?}",
5419 result.err()
5420 );
5421 assert_eq!(result.unwrap().text, "All done");
5422 }
5423
5424 #[tokio::test]
5426 async fn test_tool_timeout_produces_error_result() {
5427 let mock_client = Arc::new(MockLlmClient::new(vec![
5428 MockLlmClient::tool_call_response(
5429 "t1",
5430 "bash",
5431 serde_json::json!({"command": "sleep 10"}),
5432 ),
5433 MockLlmClient::text_response("The command timed out."),
5434 ]));
5435
5436 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5437 let config = AgentConfig {
5438 tool_timeout_ms: Some(50),
5440 ..AgentConfig::default()
5441 };
5442 let agent = AgentLoop::new(
5443 mock_client.clone(),
5444 tool_executor,
5445 test_tool_context(),
5446 config,
5447 );
5448 let result = agent.execute(&[], "Run sleep", None).await;
5449 assert!(
5450 result.is_ok(),
5451 "session should continue after tool timeout: {:?}",
5452 result.err()
5453 );
5454 assert_eq!(result.unwrap().text, "The command timed out.");
5455 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5457 }
5458
5459 #[tokio::test]
5461 async fn test_tool_within_timeout_succeeds() {
5462 let mock_client = Arc::new(MockLlmClient::new(vec![
5463 MockLlmClient::tool_call_response(
5464 "t1",
5465 "bash",
5466 serde_json::json!({"command": "echo fast"}),
5467 ),
5468 MockLlmClient::text_response("Command succeeded."),
5469 ]));
5470
5471 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5472 let config = AgentConfig {
5473 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5475 };
5476 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5477 let result = agent.execute(&[], "Run something fast", None).await;
5478 assert!(
5479 result.is_ok(),
5480 "fast tool should succeed: {:?}",
5481 result.err()
5482 );
5483 assert_eq!(result.unwrap().text, "Command succeeded.");
5484 }
5485
5486 #[tokio::test]
5488 async fn test_circuit_breaker_retries_non_streaming() {
5489 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5492
5493 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5494 let config = AgentConfig {
5495 circuit_breaker_threshold: 2,
5496 ..AgentConfig::default()
5497 };
5498 let agent = AgentLoop::new(
5499 mock_client.clone(),
5500 tool_executor,
5501 test_tool_context(),
5502 config,
5503 );
5504 let result = agent.execute(&[], "Hello", None).await;
5505 assert!(result.is_err(), "should fail when LLM always errors");
5506 let err = result.unwrap_err().to_string();
5507 assert!(
5508 err.contains("circuit breaker"),
5509 "error should mention circuit breaker, got: {}",
5510 err
5511 );
5512 assert_eq!(
5513 mock_client.call_count.load(Ordering::SeqCst),
5514 2,
5515 "should make exactly threshold=2 LLM calls"
5516 );
5517 }
5518
5519 #[tokio::test]
5521 async fn test_circuit_breaker_threshold_one_no_retry() {
5522 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5523
5524 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5525 let config = AgentConfig {
5526 circuit_breaker_threshold: 1,
5527 ..AgentConfig::default()
5528 };
5529 let agent = AgentLoop::new(
5530 mock_client.clone(),
5531 tool_executor,
5532 test_tool_context(),
5533 config,
5534 );
5535 let result = agent.execute(&[], "Hello", None).await;
5536 assert!(result.is_err());
5537 assert_eq!(
5538 mock_client.call_count.load(Ordering::SeqCst),
5539 1,
5540 "with threshold=1 exactly one attempt should be made"
5541 );
5542 }
5543
5544 #[tokio::test]
5546 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
5547 struct FailOnceThenSucceed {
5549 inner: MockLlmClient,
5550 failed_once: std::sync::atomic::AtomicBool,
5551 call_count: AtomicUsize,
5552 }
5553
5554 #[async_trait::async_trait]
5555 impl LlmClient for FailOnceThenSucceed {
5556 async fn complete(
5557 &self,
5558 messages: &[Message],
5559 system: Option<&str>,
5560 tools: &[ToolDefinition],
5561 ) -> Result<LlmResponse> {
5562 self.call_count.fetch_add(1, Ordering::SeqCst);
5563 let already_failed = self
5564 .failed_once
5565 .swap(true, std::sync::atomic::Ordering::SeqCst);
5566 if !already_failed {
5567 anyhow::bail!("transient network error");
5568 }
5569 self.inner.complete(messages, system, tools).await
5570 }
5571
5572 async fn complete_streaming(
5573 &self,
5574 messages: &[Message],
5575 system: Option<&str>,
5576 tools: &[ToolDefinition],
5577 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
5578 self.inner.complete_streaming(messages, system, tools).await
5579 }
5580 }
5581
5582 let mock = Arc::new(FailOnceThenSucceed {
5583 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
5584 failed_once: std::sync::atomic::AtomicBool::new(false),
5585 call_count: AtomicUsize::new(0),
5586 });
5587
5588 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5589 let config = AgentConfig {
5590 circuit_breaker_threshold: 3,
5591 ..AgentConfig::default()
5592 };
5593 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
5594 let result = agent.execute(&[], "Hello", None).await;
5595 assert!(
5596 result.is_ok(),
5597 "should succeed when LLM recovers within threshold: {:?}",
5598 result.err()
5599 );
5600 assert_eq!(result.unwrap().text, "Recovered!");
5601 assert_eq!(
5602 mock.call_count.load(Ordering::SeqCst),
5603 2,
5604 "should have made exactly 2 calls (1 fail + 1 success)"
5605 );
5606 }
5607
#[test]
fn test_looks_incomplete_empty() {
    // Blank or whitespace-only replies never count as a finished answer.
    for text in ["", " "] {
        assert!(AgentLoop::looks_incomplete(text), "expected incomplete: {:?}", text);
    }
}
5615
#[test]
fn test_looks_incomplete_trailing_colon() {
    // A reply ending in ':' announces content that never arrived.
    for text in ["Let me check the file:", "Next steps:"] {
        assert!(AgentLoop::looks_incomplete(text), "expected incomplete: {:?}", text);
    }
}
5621
#[test]
fn test_looks_incomplete_ellipsis() {
    // Both the ASCII "..." and the Unicode "…" ellipsis mark a trailing-off reply.
    for text in ["Working on it...", "Processing…"] {
        assert!(AgentLoop::looks_incomplete(text), "expected incomplete: {:?}", text);
    }
}
5627
#[test]
fn test_looks_incomplete_intent_phrases() {
    // Announcing a next action without performing it reads as an unfinished turn.
    let announcements = [
        "I'll start by reading the file.",
        "Let me check the configuration.",
        "I will now run the tests.",
        "I need to update the Cargo.toml.",
    ];
    for text in announcements {
        assert!(AgentLoop::looks_incomplete(text), "expected incomplete: {:?}", text);
    }
}
5641
#[test]
fn test_looks_complete_final_answer() {
    // Definitive statements, including very short ones, must be treated as complete.
    let finals = [
        "The tests pass. All changes have been applied successfully.",
        "Done. I've updated the three files and verified the build succeeds.",
        "42",
        "Yes.",
    ];
    for text in finals {
        assert!(!AgentLoop::looks_incomplete(text), "expected complete: {:?}", text);
    }
}
5654
#[test]
fn test_looks_incomplete_multiline_complete() {
    // An intro line ending in ':' is fine as long as the list that follows is present.
    let summary = concat!(
        "Here is the summary:\n",
        "\n",
        "- Fixed the bug in agent.rs\n",
        "- All tests pass\n",
        "- Build succeeds",
    );
    assert!(!AgentLoop::looks_incomplete(summary));
}
5660}