1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
/// Default cap on agent turns per execution; used as the default value of
/// `AgentConfig::max_tool_rounds`.
const MAX_TOOL_ROUNDS: usize = 50;
38
/// Configuration for an [`AgentLoop`].
///
/// Cloning is cheap for the optional collaborators: they are held behind `Arc`.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots from which the system prompt is built (`prompt_slots.build()`).
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions offered to the LLM on each call.
    pub tools: Vec<ToolDefinition>,
    /// Maximum number of turns; the loop fails once the turn counter exceeds it.
    pub max_tool_rounds: usize,
    /// Optional security provider; the effective prompt is passed to its
    /// `taint_input` before the loop starts.
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// Optional permission checker.
    // NOTE(review): not consulted in the portion of the file reviewed here.
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// Optional human-in-the-loop confirmation provider.
    // NOTE(review): not consulted in the portion of the file reviewed here.
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Context providers queried with the prompt; non-empty results are
    /// appended to the system prompt as XML.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When `true`, `execute_with_session` routes through the planning path
    /// instead of the plain loop.
    pub planning_enabled: bool,
    /// Goal-tracking switch.
    // NOTE(review): semantics not visible in the reviewed portion — confirm.
    pub goal_tracking: bool,
    /// Optional hook executor that receives lifecycle hook events.
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Optional skill registry; `ToolCommand` uses it to enforce the
    /// allowed-tools lists of instruction skills.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// Retry budget for parse errors.
    // NOTE(review): presumably for malformed tool-call arguments — confirm.
    pub max_parse_retries: u32,
    /// Timeout in milliseconds applied to direct tool execution; `None`
    /// means no timeout.
    pub tool_timeout_ms: Option<u64>,
    /// Maximum LLM call attempts per turn (treated as at least 1).
    pub circuit_breaker_threshold: u32,
    /// Automatic context-compaction switch.
    // NOTE(review): semantics not visible in the reviewed portion — confirm.
    pub auto_compact: bool,
    /// Threshold for automatic compaction.
    // NOTE(review): presumably a fraction of `max_context_tokens` — confirm.
    pub auto_compact_threshold: f32,
    /// Context window size in tokens.
    pub max_context_tokens: usize,
    /// Optional LLM client carried in the configuration.
    // NOTE(review): `AgentLoop` holds its own client; the use of this one is
    // not visible in the reviewed portion.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional agent memory; similar past items are recalled and added to
    /// the system prompt.
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// Continuation switch.
    // NOTE(review): semantics not visible in the reviewed portion — confirm.
    pub continuation_enabled: bool,
    /// Upper bound on continuation turns.
    pub max_continuation_turns: u32,
    /// Optional tool index; when set, only tools matched by an index search
    /// on the latest user text are sent to the LLM.
    pub tool_index: Option<ToolIndex>,
}
115
116impl std::fmt::Debug for AgentConfig {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("AgentConfig")
119 .field("prompt_slots", &self.prompt_slots)
120 .field("tools", &self.tools)
121 .field("max_tool_rounds", &self.max_tool_rounds)
122 .field("security_provider", &self.security_provider.is_some())
123 .field("permission_checker", &self.permission_checker.is_some())
124 .field("confirmation_manager", &self.confirmation_manager.is_some())
125 .field("context_providers", &self.context_providers.len())
126 .field("planning_enabled", &self.planning_enabled)
127 .field("goal_tracking", &self.goal_tracking)
128 .field("hook_engine", &self.hook_engine.is_some())
129 .field(
130 "skill_registry",
131 &self.skill_registry.as_ref().map(|r| r.len()),
132 )
133 .field("max_parse_retries", &self.max_parse_retries)
134 .field("tool_timeout_ms", &self.tool_timeout_ms)
135 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
136 .field("auto_compact", &self.auto_compact)
137 .field("auto_compact_threshold", &self.auto_compact_threshold)
138 .field("max_context_tokens", &self.max_context_tokens)
139 .field("continuation_enabled", &self.continuation_enabled)
140 .field("max_continuation_turns", &self.max_continuation_turns)
141 .field("memory", &self.memory.is_some())
142 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
143 .finish()
144 }
145}
146
147impl Default for AgentConfig {
148 fn default() -> Self {
149 Self {
150 prompt_slots: SystemPromptSlots::default(),
151 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
153 security_provider: None,
154 permission_checker: None,
155 confirmation_manager: None,
156 context_providers: Vec::new(),
157 planning_enabled: false,
158 goal_tracking: false,
159 hook_engine: None,
160 skill_registry: Some(Arc::new(crate::skills::SkillRegistry::with_builtins())),
161 max_parse_retries: 2,
162 tool_timeout_ms: None,
163 circuit_breaker_threshold: 3,
164 auto_compact: false,
165 auto_compact_threshold: 0.80,
166 max_context_tokens: 200_000,
167 llm_client: None,
168 memory: None,
169 continuation_enabled: true,
170 max_continuation_turns: 3,
171 tool_index: None,
172 }
173 }
174}
175
/// Events streamed to observers of an agent execution.
///
/// Serialized as an internally tagged enum: each event is an object whose
/// `"type"` field holds the variant's renamed tag (for example
/// `"agent_start"`, `"text_delta"`), alongside the variant's own fields.
///
/// Variants without an individual description below are emitted by code
/// outside the portion of the file reviewed here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// Sent once at the start of the loop with the effective prompt.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// Sent at the start of each turn; the counter starts at 1.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// A chunk of streamed LLM text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// The LLM stream started a tool-use block.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// A chunk of streamed tool-input text from the LLM.
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
    },

    /// A chunk of output streamed by a running tool, tagged with the id and
    /// name of the tool call it belongs to.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    #[serde(rename = "agent_end")]
    End { text: String, usage: TokenUsage },

    /// Sent before the loop fails, e.g. when the turn limit is exceeded or
    /// the LLM call fails.
    #[serde(rename = "error")]
    Error { message: String },

    // --- Human-in-the-loop confirmation ---
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    // --- External tasks ---
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    // --- Permissions ---
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    // --- Context resolution ---
    /// Sent before context providers are queried; lists their names.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Sent after context providers returned; totals are summed over all
    /// non-empty results.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    // --- Command queue ---
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    // --- Task tracking ---
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    // --- Memory ---
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// Sent once per memory item recalled for the current prompt.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// Sent after a memory recall, with the number of items found.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    // --- Subagents ---
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    // --- Planning and goals ---
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    // --- Session maintenance ---
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },
}
495
/// Outcome of an agent execution.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final response text.
    pub text: String,
    /// Conversation messages associated with the run.
    // NOTE(review): presumably history plus everything appended during the
    // run — confirm where the result is constructed.
    pub messages: Vec<Message>,
    /// Token usage for the run.
    pub usage: TokenUsage,
    /// Number of tool calls made during the run.
    pub tool_calls_count: usize,
}
504
/// A single tool invocation packaged as a [`SessionCommand`], so it can be
/// submitted to a session lane queue instead of being executed inline.
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Name of the tool to run; also reported as the command type.
    tool_name: String,
    /// Arguments passed to the tool; also reported as the command payload.
    tool_args: Value,
    /// Context handed to the executor for this invocation.
    tool_context: ToolContext,
    /// Optional registry used to enforce skill allowed-tools restrictions
    /// before the tool runs.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
519
impl ToolCommand {
    /// Creates a command that will run `tool_name` with `tool_args` through
    /// `tool_executor`, using `tool_context`.
    ///
    /// When `skill_registry` is `Some`, execution first checks the tool
    /// against the registry's instruction skills.
    pub fn new(
        tool_executor: Arc<ToolExecutor>,
        tool_name: String,
        tool_args: Value,
        tool_context: ToolContext,
        skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    ) -> Self {
        Self {
            tool_executor,
            tool_name,
            tool_args,
            tool_context,
            skill_registry,
        }
    }
}
538
539#[async_trait]
540impl SessionCommand for ToolCommand {
541 async fn execute(&self) -> Result<Value> {
542 if let Some(registry) = &self.skill_registry {
544 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
545
546 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
548
549 if has_restrictions {
550 let mut allowed = false;
551
552 for skill in &instruction_skills {
553 if skill.is_tool_allowed(&self.tool_name) {
554 allowed = true;
555 break;
556 }
557 }
558
559 if !allowed {
560 return Err(anyhow::anyhow!(
561 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
562 self.tool_name
563 ));
564 }
565 }
566 }
567
568 let result = self
570 .tool_executor
571 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
572 .await?;
573 Ok(serde_json::json!({
574 "output": result.output,
575 "exit_code": result.exit_code,
576 "metadata": result.metadata,
577 }))
578 }
579
580 fn command_type(&self) -> &str {
581 &self.tool_name
582 }
583
584 fn payload(&self) -> Value {
585 self.tool_args.clone()
586 }
587}
588
/// Drives the LLM/tool loop for a session.
#[derive(Clone)]
pub struct AgentLoop {
    /// Client used for LLM completions (streaming or not).
    llm_client: Arc<dyn LlmClient>,
    /// Executor used to run tools.
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context; cloned per tool call.
    tool_context: ToolContext,
    /// Loop configuration.
    config: AgentConfig,
    /// Optional shared tool metrics, set via `with_tool_metrics`.
    // NOTE(review): not read in the portion of the file reviewed here.
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// Optional lane queue; when set, tools are submitted to it first.
    command_queue: Option<Arc<SessionLaneQueue>>,
}
605
606impl AgentLoop {
/// Creates an agent loop with no tool metrics and no command queue attached.
pub fn new(
    llm_client: Arc<dyn LlmClient>,
    tool_executor: Arc<ToolExecutor>,
    tool_context: ToolContext,
    config: AgentConfig,
) -> Self {
    Self {
        llm_client,
        tool_executor,
        tool_context,
        config,
        tool_metrics: None,
        command_queue: None,
    }
}
622
/// Builder-style setter: attaches shared tool metrics and returns the loop.
pub fn with_tool_metrics(
    mut self,
    metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
) -> Self {
    self.tool_metrics = Some(metrics);
    self
}
631
/// Builder-style setter: attaches a session lane queue. Once set, tool
/// calls are submitted to the queue before falling back to direct execution.
pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
    self.command_queue = Some(queue);
    self
}
640
641 async fn execute_tool_timed(
647 &self,
648 name: &str,
649 args: &serde_json::Value,
650 ctx: &ToolContext,
651 ) -> anyhow::Result<crate::tools::ToolResult> {
652 let fut = self.tool_executor.execute_with_context(name, args, ctx);
653 if let Some(timeout_ms) = self.config.tool_timeout_ms {
654 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
655 Ok(result) => result,
656 Err(_) => Err(anyhow::anyhow!(
657 "Tool '{}' timed out after {}ms",
658 name,
659 timeout_ms
660 )),
661 }
662 } else {
663 fut.await
664 }
665 }
666
667 fn tool_result_to_tuple(
669 result: anyhow::Result<crate::tools::ToolResult>,
670 ) -> (
671 String,
672 i32,
673 bool,
674 Option<serde_json::Value>,
675 Vec<crate::llm::Attachment>,
676 ) {
677 match result {
678 Ok(r) => (
679 r.output,
680 r.exit_code,
681 r.exit_code != 0,
682 r.metadata,
683 r.images,
684 ),
685 Err(e) => (
686 format!("Tool execution error: {}", e),
687 1,
688 true,
689 None,
690 Vec::new(),
691 ),
692 }
693 }
694
695 async fn execute_tool_queued_or_direct(
697 &self,
698 name: &str,
699 args: &serde_json::Value,
700 ctx: &ToolContext,
701 ) -> anyhow::Result<crate::tools::ToolResult> {
702 if let Some(ref queue) = self.command_queue {
703 let command = ToolCommand::new(
704 Arc::clone(&self.tool_executor),
705 name.to_string(),
706 args.clone(),
707 ctx.clone(),
708 self.config.skill_registry.clone(),
709 );
710 let rx = queue.submit_by_tool(name, Box::new(command)).await;
711 match rx.await {
712 Ok(Ok(value)) => {
713 let output = value["output"]
714 .as_str()
715 .ok_or_else(|| {
716 anyhow::anyhow!(
717 "Queue result missing 'output' field for tool '{}'",
718 name
719 )
720 })?
721 .to_string();
722 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
723 return Ok(crate::tools::ToolResult {
724 name: name.to_string(),
725 output,
726 exit_code,
727 metadata: None,
728 images: Vec::new(),
729 });
730 }
731 Ok(Err(e)) => {
732 tracing::warn!(
733 "Queue execution failed for tool '{}', falling back to direct: {}",
734 name,
735 e
736 );
737 }
738 Err(_) => {
739 tracing::warn!(
740 "Queue channel closed for tool '{}', falling back to direct",
741 name
742 );
743 }
744 }
745 }
746 self.execute_tool_timed(name, args, ctx).await
747 }
748
749 async fn call_llm(
760 &self,
761 messages: &[Message],
762 system: Option<&str>,
763 event_tx: &Option<mpsc::Sender<AgentEvent>>,
764 ) -> anyhow::Result<LlmResponse> {
765 let tools = if let Some(ref index) = self.config.tool_index {
767 let query = messages
768 .iter()
769 .rev()
770 .find(|m| m.role == "user")
771 .and_then(|m| {
772 m.content.iter().find_map(|b| match b {
773 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
774 _ => None,
775 })
776 })
777 .unwrap_or("");
778 let matches = index.search(query, index.len());
779 let matched_names: std::collections::HashSet<&str> =
780 matches.iter().map(|m| m.name.as_str()).collect();
781 self.config
782 .tools
783 .iter()
784 .filter(|t| matched_names.contains(t.name.as_str()))
785 .cloned()
786 .collect::<Vec<_>>()
787 } else {
788 self.config.tools.clone()
789 };
790
791 if event_tx.is_some() {
792 let mut stream_rx = self
793 .llm_client
794 .complete_streaming(messages, system, &tools)
795 .await
796 .context("LLM streaming call failed")?;
797
798 let mut final_response: Option<LlmResponse> = None;
799 while let Some(event) = stream_rx.recv().await {
800 match event {
801 crate::llm::StreamEvent::TextDelta(text) => {
802 if let Some(tx) = event_tx {
803 tx.send(AgentEvent::TextDelta { text }).await.ok();
804 }
805 }
806 crate::llm::StreamEvent::ToolUseStart { id, name } => {
807 if let Some(tx) = event_tx {
808 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
809 }
810 }
811 crate::llm::StreamEvent::ToolUseInputDelta(delta) => {
812 if let Some(tx) = event_tx {
813 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
814 }
815 }
816 crate::llm::StreamEvent::Done(resp) => {
817 final_response = Some(resp);
818 }
819 }
820 }
821 final_response.context("Stream ended without final response")
822 } else {
823 self.llm_client
824 .complete(messages, system, &tools)
825 .await
826 .context("LLM call failed")
827 }
828 }
829
830 fn streaming_tool_context(
839 &self,
840 event_tx: &Option<mpsc::Sender<AgentEvent>>,
841 tool_id: &str,
842 tool_name: &str,
843 ) -> ToolContext {
844 let mut ctx = self.tool_context.clone();
845 if let Some(agent_tx) = event_tx {
846 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
847 ctx.event_tx = Some(tool_tx);
848
849 let agent_tx = agent_tx.clone();
850 let tool_id = tool_id.to_string();
851 let tool_name = tool_name.to_string();
852 tokio::spawn(async move {
853 while let Some(event) = tool_rx.recv().await {
854 match event {
855 ToolStreamEvent::OutputDelta(delta) => {
856 agent_tx
857 .send(AgentEvent::ToolOutputDelta {
858 id: tool_id.clone(),
859 name: tool_name.clone(),
860 delta,
861 })
862 .await
863 .ok();
864 }
865 }
866 }
867 });
868 }
869 ctx
870 }
871
872 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
876 if self.config.context_providers.is_empty() {
877 return Vec::new();
878 }
879
880 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
881
882 let futures = self
883 .config
884 .context_providers
885 .iter()
886 .map(|p| p.query(&query));
887 let outcomes = join_all(futures).await;
888
889 outcomes
890 .into_iter()
891 .enumerate()
892 .filter_map(|(i, r)| match r {
893 Ok(result) if !result.is_empty() => Some(result),
894 Ok(_) => None,
895 Err(e) => {
896 tracing::warn!(
897 "Context provider '{}' failed: {}",
898 self.config.context_providers[i].name(),
899 e
900 );
901 None
902 }
903 })
904 .collect()
905 }
906
/// Heuristically decides whether a response reads as unfinished.
///
/// Returns `true` when the trimmed text is
/// - empty, or
/// - a single line shorter than 80 bytes that ends with `:`, `...`, `…`
///   or `,`, or
/// - contains, case-insensitively, one of a fixed set of "about to act"
///   phrases such as "let me " or "i'll ".
fn looks_incomplete(text: &str) -> bool {
    /// Trailing markers that suggest the sentence was cut off.
    const TRAILING_MARKERS: &[&str] = &[":", "...", "…", ","];
    /// Lower-case phrases announcing work that has not been done yet.
    const INTENT_PHRASES: &[&str] = &[
        "i'll ",
        "i will ",
        "let me ",
        "i need to ",
        "i should ",
        "next, i",
        "first, i",
        "now i",
        "i'll start",
        "i'll begin",
        "i'll now",
        "let's start",
        "let's begin",
        "to do this",
        "i'm going to",
    ];

    let trimmed = text.trim();
    if trimmed.is_empty() {
        return true;
    }

    let short_single_line = trimmed.len() < 80 && !trimmed.contains('\n');
    if short_single_line
        && TRAILING_MARKERS
            .iter()
            .any(|marker| trimmed.ends_with(*marker))
    {
        return true;
    }

    let lowered = trimmed.to_lowercase();
    INTENT_PHRASES
        .iter()
        .any(|phrase| lowered.contains(*phrase))
}
955
/// Builds the base system prompt from the configured prompt slots.
fn system_prompt(&self) -> String {
    self.config.prompt_slots.build()
}
960
961 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
963 let base = self.system_prompt();
964 if context_results.is_empty() {
965 return Some(base);
966 }
967
968 let context_xml: String = context_results
970 .iter()
971 .map(|r| r.to_xml())
972 .collect::<Vec<_>>()
973 .join("\n\n");
974
975 Some(format!("{}\n\n{}", base, context_xml))
976 }
977
978 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
980 let futures = self
981 .config
982 .context_providers
983 .iter()
984 .map(|p| p.on_turn_complete(session_id, prompt, response));
985 let outcomes = join_all(futures).await;
986
987 for (i, result) in outcomes.into_iter().enumerate() {
988 if let Err(e) = result {
989 tracing::warn!(
990 "Context provider '{}' on_turn_complete failed: {}",
991 self.config.context_providers[i].name(),
992 e
993 );
994 }
995 }
996 }
997
998 async fn fire_pre_tool_use(
1001 &self,
1002 session_id: &str,
1003 tool_name: &str,
1004 args: &serde_json::Value,
1005 ) -> Option<HookResult> {
1006 if let Some(he) = &self.config.hook_engine {
1007 let event = HookEvent::PreToolUse(PreToolUseEvent {
1008 session_id: session_id.to_string(),
1009 tool: tool_name.to_string(),
1010 args: args.clone(),
1011 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1012 recent_tools: Vec::new(),
1013 });
1014 let result = he.fire(&event).await;
1015 if result.is_block() {
1016 return Some(result);
1017 }
1018 }
1019 None
1020 }
1021
1022 async fn fire_post_tool_use(
1024 &self,
1025 session_id: &str,
1026 tool_name: &str,
1027 args: &serde_json::Value,
1028 output: &str,
1029 success: bool,
1030 duration_ms: u64,
1031 ) {
1032 if let Some(he) = &self.config.hook_engine {
1033 let event = HookEvent::PostToolUse(PostToolUseEvent {
1034 session_id: session_id.to_string(),
1035 tool: tool_name.to_string(),
1036 args: args.clone(),
1037 result: ToolResultData {
1038 success,
1039 output: output.to_string(),
1040 exit_code: if success { Some(0) } else { Some(1) },
1041 duration_ms,
1042 },
1043 });
1044 let he = Arc::clone(he);
1045 tokio::spawn(async move {
1046 let _ = he.fire(&event).await;
1047 });
1048 }
1049 }
1050
1051 async fn fire_generate_start(
1053 &self,
1054 session_id: &str,
1055 prompt: &str,
1056 system_prompt: &Option<String>,
1057 ) {
1058 if let Some(he) = &self.config.hook_engine {
1059 let event = HookEvent::GenerateStart(GenerateStartEvent {
1060 session_id: session_id.to_string(),
1061 prompt: prompt.to_string(),
1062 system_prompt: system_prompt.clone(),
1063 model_provider: String::new(),
1064 model_name: String::new(),
1065 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1066 });
1067 let _ = he.fire(&event).await;
1068 }
1069 }
1070
1071 async fn fire_generate_end(
1073 &self,
1074 session_id: &str,
1075 prompt: &str,
1076 response: &LlmResponse,
1077 duration_ms: u64,
1078 ) {
1079 if let Some(he) = &self.config.hook_engine {
1080 let tool_calls: Vec<ToolCallInfo> = response
1081 .tool_calls()
1082 .iter()
1083 .map(|tc| ToolCallInfo {
1084 name: tc.name.clone(),
1085 args: tc.args.clone(),
1086 })
1087 .collect();
1088
1089 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1090 session_id: session_id.to_string(),
1091 prompt: prompt.to_string(),
1092 response_text: response.text().to_string(),
1093 tool_calls,
1094 usage: TokenUsageInfo {
1095 prompt_tokens: response.usage.prompt_tokens as i32,
1096 completion_tokens: response.usage.completion_tokens as i32,
1097 total_tokens: response.usage.total_tokens as i32,
1098 },
1099 duration_ms,
1100 });
1101 let _ = he.fire(&event).await;
1102 }
1103 }
1104
1105 async fn fire_pre_prompt(
1108 &self,
1109 session_id: &str,
1110 prompt: &str,
1111 system_prompt: &Option<String>,
1112 message_count: usize,
1113 ) -> Option<String> {
1114 if let Some(he) = &self.config.hook_engine {
1115 let event = HookEvent::PrePrompt(PrePromptEvent {
1116 session_id: session_id.to_string(),
1117 prompt: prompt.to_string(),
1118 system_prompt: system_prompt.clone(),
1119 message_count,
1120 });
1121 let result = he.fire(&event).await;
1122 if let HookResult::Continue(Some(modified)) = result {
1123 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1125 return Some(new_prompt.to_string());
1126 }
1127 }
1128 }
1129 None
1130 }
1131
1132 async fn fire_post_response(
1134 &self,
1135 session_id: &str,
1136 response_text: &str,
1137 tool_calls_count: usize,
1138 usage: &TokenUsage,
1139 duration_ms: u64,
1140 ) {
1141 if let Some(he) = &self.config.hook_engine {
1142 let event = HookEvent::PostResponse(PostResponseEvent {
1143 session_id: session_id.to_string(),
1144 response_text: response_text.to_string(),
1145 tool_calls_count,
1146 usage: TokenUsageInfo {
1147 prompt_tokens: usage.prompt_tokens as i32,
1148 completion_tokens: usage.completion_tokens as i32,
1149 total_tokens: usage.total_tokens as i32,
1150 },
1151 duration_ms,
1152 });
1153 let he = Arc::clone(he);
1154 tokio::spawn(async move {
1155 let _ = he.fire(&event).await;
1156 });
1157 }
1158 }
1159
1160 async fn fire_on_error(
1162 &self,
1163 session_id: &str,
1164 error_type: ErrorType,
1165 error_message: &str,
1166 context: serde_json::Value,
1167 ) {
1168 if let Some(he) = &self.config.hook_engine {
1169 let event = HookEvent::OnError(OnErrorEvent {
1170 session_id: session_id.to_string(),
1171 error_type,
1172 error_message: error_message.to_string(),
1173 context,
1174 });
1175 let he = Arc::clone(he);
1176 tokio::spawn(async move {
1177 let _ = he.fire(&event).await;
1178 });
1179 }
1180 }
1181
/// Runs the agent for `prompt` on top of `history`, without a session id.
///
/// Convenience wrapper around `execute_with_session` with `session_id`
/// set to `None`. When `event_tx` is provided, progress is reported through
/// it as `AgentEvent`s.
pub async fn execute(
    &self,
    history: &[Message],
    prompt: &str,
    event_tx: Option<mpsc::Sender<AgentEvent>>,
) -> Result<AgentResult> {
    self.execute_with_session(history, prompt, None, event_tx)
        .await
}
1196
1197 pub async fn execute_from_messages(
1203 &self,
1204 messages: Vec<Message>,
1205 session_id: Option<&str>,
1206 event_tx: Option<mpsc::Sender<AgentEvent>>,
1207 ) -> Result<AgentResult> {
1208 tracing::info!(
1209 a3s.session.id = session_id.unwrap_or("none"),
1210 a3s.agent.max_turns = self.config.max_tool_rounds,
1211 "a3s.agent.execute_from_messages started"
1212 );
1213
1214 let effective_prompt = messages
1218 .iter()
1219 .rev()
1220 .find(|m| m.role == "user")
1221 .map(|m| m.text())
1222 .unwrap_or_default();
1223
1224 let result = self
1225 .execute_loop_inner(&messages, "", &effective_prompt, session_id, event_tx)
1226 .await;
1227
1228 match &result {
1229 Ok(r) => tracing::info!(
1230 a3s.agent.tool_calls_count = r.tool_calls_count,
1231 a3s.llm.total_tokens = r.usage.total_tokens,
1232 "a3s.agent.execute_from_messages completed"
1233 ),
1234 Err(e) => tracing::warn!(
1235 error = %e,
1236 "a3s.agent.execute_from_messages failed"
1237 ),
1238 }
1239
1240 result
1241 }
1242
1243 pub async fn execute_with_session(
1248 &self,
1249 history: &[Message],
1250 prompt: &str,
1251 session_id: Option<&str>,
1252 event_tx: Option<mpsc::Sender<AgentEvent>>,
1253 ) -> Result<AgentResult> {
1254 tracing::info!(
1255 a3s.session.id = session_id.unwrap_or("none"),
1256 a3s.agent.max_turns = self.config.max_tool_rounds,
1257 "a3s.agent.execute started"
1258 );
1259
1260 let result = if self.config.planning_enabled {
1262 self.execute_with_planning(history, prompt, event_tx).await
1263 } else {
1264 self.execute_loop(history, prompt, session_id, event_tx)
1265 .await
1266 };
1267
1268 match &result {
1269 Ok(r) => {
1270 tracing::info!(
1271 a3s.agent.tool_calls_count = r.tool_calls_count,
1272 a3s.llm.total_tokens = r.usage.total_tokens,
1273 "a3s.agent.execute completed"
1274 );
1275 self.fire_post_response(
1277 session_id.unwrap_or(""),
1278 &r.text,
1279 r.tool_calls_count,
1280 &r.usage,
1281 0, )
1283 .await;
1284 }
1285 Err(e) => {
1286 tracing::warn!(
1287 error = %e,
1288 "a3s.agent.execute failed"
1289 );
1290 self.fire_on_error(
1292 session_id.unwrap_or(""),
1293 ErrorType::Other,
1294 &e.to_string(),
1295 serde_json::json!({"phase": "execute"}),
1296 )
1297 .await;
1298 }
1299 }
1300
1301 result
1302 }
1303
/// Runs the plain agent loop for `prompt`.
///
/// Delegates to `execute_loop_inner`, passing `prompt` both as the message
/// to append to the conversation and as the effective prompt.
async fn execute_loop(
    &self,
    history: &[Message],
    prompt: &str,
    session_id: Option<&str>,
    event_tx: Option<mpsc::Sender<AgentEvent>>,
) -> Result<AgentResult> {
    self.execute_loop_inner(history, prompt, prompt, session_id, event_tx)
        .await
}
1321
1322 async fn execute_loop_inner(
1327 &self,
1328 history: &[Message],
1329 msg_prompt: &str,
1330 effective_prompt: &str,
1331 session_id: Option<&str>,
1332 event_tx: Option<mpsc::Sender<AgentEvent>>,
1333 ) -> Result<AgentResult> {
1334 let mut messages = history.to_vec();
1335 let mut total_usage = TokenUsage::default();
1336 let mut tool_calls_count = 0;
1337 let mut turn = 0;
1338 let mut parse_error_count: u32 = 0;
1340 let mut continuation_count: u32 = 0;
1342
1343 if let Some(tx) = &event_tx {
1345 tx.send(AgentEvent::Start {
1346 prompt: effective_prompt.to_string(),
1347 })
1348 .await
1349 .ok();
1350 }
1351
1352 let _queue_forward_handle =
1354 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1355 let mut rx = queue.subscribe();
1356 let tx = tx.clone();
1357 Some(tokio::spawn(async move {
1358 while let Ok(event) = rx.recv().await {
1359 if tx.send(event).await.is_err() {
1360 break;
1361 }
1362 }
1363 }))
1364 } else {
1365 None
1366 };
1367
1368 let built_system_prompt = Some(self.system_prompt());
1370 let hooked_prompt = if let Some(modified) = self
1371 .fire_pre_prompt(
1372 session_id.unwrap_or(""),
1373 effective_prompt,
1374 &built_system_prompt,
1375 messages.len(),
1376 )
1377 .await
1378 {
1379 modified
1380 } else {
1381 effective_prompt.to_string()
1382 };
1383 let effective_prompt = hooked_prompt.as_str();
1384
1385 if let Some(ref sp) = self.config.security_provider {
1387 sp.taint_input(effective_prompt);
1388 }
1389
1390 let system_with_memory = if let Some(ref memory) = self.config.memory {
1392 match memory.recall_similar(effective_prompt, 5).await {
1393 Ok(items) if !items.is_empty() => {
1394 if let Some(tx) = &event_tx {
1395 for item in &items {
1396 tx.send(AgentEvent::MemoryRecalled {
1397 memory_id: item.id.clone(),
1398 content: item.content.clone(),
1399 relevance: item.relevance_score(),
1400 })
1401 .await
1402 .ok();
1403 }
1404 tx.send(AgentEvent::MemoriesSearched {
1405 query: Some(effective_prompt.to_string()),
1406 tags: Vec::new(),
1407 result_count: items.len(),
1408 })
1409 .await
1410 .ok();
1411 }
1412 let memory_context = items
1413 .iter()
1414 .map(|i| format!("- {}", i.content))
1415 .collect::<Vec<_>>()
1416 .join(
1417 "
1418",
1419 );
1420 let base = self.system_prompt();
1421 Some(format!(
1422 "{}
1423
1424## Relevant past experience
1425{}",
1426 base, memory_context
1427 ))
1428 }
1429 _ => Some(self.system_prompt()),
1430 }
1431 } else {
1432 Some(self.system_prompt())
1433 };
1434
1435 let augmented_system = if !self.config.context_providers.is_empty() {
1437 if let Some(tx) = &event_tx {
1439 let provider_names: Vec<String> = self
1440 .config
1441 .context_providers
1442 .iter()
1443 .map(|p| p.name().to_string())
1444 .collect();
1445 tx.send(AgentEvent::ContextResolving {
1446 providers: provider_names,
1447 })
1448 .await
1449 .ok();
1450 }
1451
1452 tracing::info!(
1453 a3s.context.providers = self.config.context_providers.len() as i64,
1454 "Context resolution started"
1455 );
1456 let context_results = self.resolve_context(effective_prompt, session_id).await;
1457
1458 if let Some(tx) = &event_tx {
1460 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1461 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1462
1463 tracing::info!(
1464 context_items = total_items,
1465 context_tokens = total_tokens,
1466 "Context resolution completed"
1467 );
1468
1469 tx.send(AgentEvent::ContextResolved {
1470 total_items,
1471 total_tokens,
1472 })
1473 .await
1474 .ok();
1475 }
1476
1477 self.build_augmented_system_prompt(&context_results)
1478 } else {
1479 Some(self.system_prompt())
1480 };
1481
1482 let base_prompt = self.system_prompt();
1484 let augmented_system = match (augmented_system, system_with_memory) {
1485 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1486 (Some(ctx), _) => Some(ctx),
1487 (None, mem) => mem,
1488 };
1489
1490 if !msg_prompt.is_empty() {
1492 messages.push(Message::user(msg_prompt));
1493 }
1494
1495 loop {
1496 turn += 1;
1497
1498 if turn > self.config.max_tool_rounds {
1499 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1500 if let Some(tx) = &event_tx {
1501 tx.send(AgentEvent::Error {
1502 message: error.clone(),
1503 })
1504 .await
1505 .ok();
1506 }
1507 anyhow::bail!(error);
1508 }
1509
1510 if let Some(tx) = &event_tx {
1512 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1513 }
1514
1515 tracing::info!(
1516 turn = turn,
1517 max_turns = self.config.max_tool_rounds,
1518 "Agent turn started"
1519 );
1520
1521 tracing::info!(
1523 a3s.llm.streaming = event_tx.is_some(),
1524 "LLM completion started"
1525 );
1526
1527 self.fire_generate_start(
1529 session_id.unwrap_or(""),
1530 effective_prompt,
1531 &augmented_system,
1532 )
1533 .await;
1534
1535 let llm_start = std::time::Instant::now();
1536 let response = {
1540 let threshold = self.config.circuit_breaker_threshold.max(1);
1541 let mut attempt = 0u32;
1542 loop {
1543 attempt += 1;
1544 let result = self
1545 .call_llm(&messages, augmented_system.as_deref(), &event_tx)
1546 .await;
1547 match result {
1548 Ok(r) => {
1549 break r;
1550 }
1551 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1553 tracing::warn!(
1554 turn = turn,
1555 attempt = attempt,
1556 threshold = threshold,
1557 error = %e,
1558 "LLM call failed, will retry"
1559 );
1560 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1561 }
1562 Err(e) => {
1564 let msg = if attempt > 1 {
1565 format!(
1566 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1567 attempt, e
1568 )
1569 } else {
1570 format!("LLM call failed: {}", e)
1571 };
1572 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1573 self.fire_on_error(
1575 session_id.unwrap_or(""),
1576 ErrorType::LlmFailure,
1577 &msg,
1578 serde_json::json!({"turn": turn, "attempt": attempt}),
1579 )
1580 .await;
1581 if let Some(tx) = &event_tx {
1582 tx.send(AgentEvent::Error {
1583 message: msg.clone(),
1584 })
1585 .await
1586 .ok();
1587 }
1588 anyhow::bail!(msg);
1589 }
1590 }
1591 }
1592 };
1593
1594 total_usage.prompt_tokens += response.usage.prompt_tokens;
1596 total_usage.completion_tokens += response.usage.completion_tokens;
1597 total_usage.total_tokens += response.usage.total_tokens;
1598
1599 let llm_duration = llm_start.elapsed();
1601 tracing::info!(
1602 turn = turn,
1603 streaming = event_tx.is_some(),
1604 prompt_tokens = response.usage.prompt_tokens,
1605 completion_tokens = response.usage.completion_tokens,
1606 total_tokens = response.usage.total_tokens,
1607 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1608 duration_ms = llm_duration.as_millis() as u64,
1609 "LLM completion finished"
1610 );
1611
1612 self.fire_generate_end(
1614 session_id.unwrap_or(""),
1615 effective_prompt,
1616 &response,
1617 llm_duration.as_millis() as u64,
1618 )
1619 .await;
1620
1621 crate::telemetry::record_llm_usage(
1623 response.usage.prompt_tokens,
1624 response.usage.completion_tokens,
1625 response.usage.total_tokens,
1626 response.stop_reason.as_deref(),
1627 );
1628 tracing::info!(
1630 turn = turn,
1631 a3s.llm.total_tokens = response.usage.total_tokens,
1632 "Turn token usage"
1633 );
1634
1635 messages.push(response.message.clone());
1637
1638 let tool_calls = response.tool_calls();
1640
1641 if let Some(tx) = &event_tx {
1643 tx.send(AgentEvent::TurnEnd {
1644 turn,
1645 usage: response.usage.clone(),
1646 })
1647 .await
1648 .ok();
1649 }
1650
1651 if self.config.auto_compact {
1653 let used = response.usage.prompt_tokens;
1654 let max = self.config.max_context_tokens;
1655 let threshold = self.config.auto_compact_threshold;
1656
1657 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1658 let before_len = messages.len();
1659 let percent_before = used as f32 / max as f32;
1660
1661 tracing::info!(
1662 used_tokens = used,
1663 max_tokens = max,
1664 percent = percent_before,
1665 threshold = threshold,
1666 "Auto-compact triggered"
1667 );
1668
1669 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1671 {
1672 messages = pruned;
1673 tracing::info!("Tool output pruning applied");
1674 }
1675
1676 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1678 session_id.unwrap_or(""),
1679 &messages,
1680 &self.llm_client,
1681 )
1682 .await
1683 {
1684 messages = compacted;
1685 }
1686
1687 if let Some(tx) = &event_tx {
1689 tx.send(AgentEvent::ContextCompacted {
1690 session_id: session_id.unwrap_or("").to_string(),
1691 before_messages: before_len,
1692 after_messages: messages.len(),
1693 percent_before,
1694 })
1695 .await
1696 .ok();
1697 }
1698 }
1699 }
1700
1701 if tool_calls.is_empty() {
1702 let final_text = response.text();
1705
1706 if self.config.continuation_enabled
1707 && continuation_count < self.config.max_continuation_turns
1708 && Self::looks_incomplete(&final_text)
1709 {
1710 continuation_count += 1;
1711 tracing::info!(
1712 turn = turn,
1713 continuation = continuation_count,
1714 max_continuation = self.config.max_continuation_turns,
1715 "Injecting continuation message — response looks incomplete"
1716 );
1717 messages.push(Message::user(crate::prompts::CONTINUATION));
1719 continue;
1720 }
1721
1722 let final_text = if let Some(ref sp) = self.config.security_provider {
1724 sp.sanitize_output(&final_text)
1725 } else {
1726 final_text
1727 };
1728
1729 tracing::info!(
1731 tool_calls_count = tool_calls_count,
1732 total_prompt_tokens = total_usage.prompt_tokens,
1733 total_completion_tokens = total_usage.completion_tokens,
1734 total_tokens = total_usage.total_tokens,
1735 turns = turn,
1736 "Agent execution completed"
1737 );
1738
1739 if let Some(tx) = &event_tx {
1740 tx.send(AgentEvent::End {
1741 text: final_text.clone(),
1742 usage: total_usage.clone(),
1743 })
1744 .await
1745 .ok();
1746 }
1747
1748 if let Some(sid) = session_id {
1750 self.notify_turn_complete(sid, effective_prompt, &final_text)
1751 .await;
1752 }
1753
1754 return Ok(AgentResult {
1755 text: final_text,
1756 messages,
1757 usage: total_usage,
1758 tool_calls_count,
1759 });
1760 }
1761
1762 for tool_call in tool_calls {
1764 tool_calls_count += 1;
1765
1766 let tool_start = std::time::Instant::now();
1767
1768 tracing::info!(
1769 tool_name = tool_call.name.as_str(),
1770 tool_id = tool_call.id.as_str(),
1771 "Tool execution started"
1772 );
1773
1774 if let Some(parse_error) =
1780 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1781 {
1782 parse_error_count += 1;
1783 let error_msg = format!("Error: {}", parse_error);
1784 tracing::warn!(
1785 tool = tool_call.name.as_str(),
1786 parse_error_count = parse_error_count,
1787 max_parse_retries = self.config.max_parse_retries,
1788 "Malformed tool arguments from LLM"
1789 );
1790
1791 if let Some(tx) = &event_tx {
1792 tx.send(AgentEvent::ToolEnd {
1793 id: tool_call.id.clone(),
1794 name: tool_call.name.clone(),
1795 output: error_msg.clone(),
1796 exit_code: 1,
1797 })
1798 .await
1799 .ok();
1800 }
1801
1802 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1803
1804 if parse_error_count > self.config.max_parse_retries {
1805 let msg = format!(
1806 "LLM produced malformed tool arguments {} time(s) in a row \
1807 (max_parse_retries={}); giving up",
1808 parse_error_count, self.config.max_parse_retries
1809 );
1810 tracing::error!("{}", msg);
1811 if let Some(tx) = &event_tx {
1812 tx.send(AgentEvent::Error {
1813 message: msg.clone(),
1814 })
1815 .await
1816 .ok();
1817 }
1818 anyhow::bail!(msg);
1819 }
1820 continue;
1821 }
1822
1823 parse_error_count = 0;
1825
1826 if let Some(ref registry) = self.config.skill_registry {
1828 let instruction_skills =
1829 registry.by_kind(crate::skills::SkillKind::Instruction);
1830 let has_restrictions =
1831 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
1832 if has_restrictions {
1833 let allowed = instruction_skills
1834 .iter()
1835 .any(|s| s.is_tool_allowed(&tool_call.name));
1836 if !allowed {
1837 let msg = format!(
1838 "Tool '{}' is not allowed by any active skill.",
1839 tool_call.name
1840 );
1841 tracing::info!(
1842 tool_name = tool_call.name.as_str(),
1843 "Tool blocked by skill registry"
1844 );
1845 if let Some(tx) = &event_tx {
1846 tx.send(AgentEvent::PermissionDenied {
1847 tool_id: tool_call.id.clone(),
1848 tool_name: tool_call.name.clone(),
1849 args: tool_call.args.clone(),
1850 reason: msg.clone(),
1851 })
1852 .await
1853 .ok();
1854 }
1855 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1856 continue;
1857 }
1858 }
1859 }
1860
1861 if let Some(HookResult::Block(reason)) = self
1863 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1864 .await
1865 {
1866 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1867 tracing::info!(
1868 tool_name = tool_call.name.as_str(),
1869 "Tool blocked by PreToolUse hook"
1870 );
1871
1872 if let Some(tx) = &event_tx {
1873 tx.send(AgentEvent::PermissionDenied {
1874 tool_id: tool_call.id.clone(),
1875 tool_name: tool_call.name.clone(),
1876 args: tool_call.args.clone(),
1877 reason: reason.clone(),
1878 })
1879 .await
1880 .ok();
1881 }
1882
1883 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1884 continue;
1885 }
1886
1887 let permission_decision = if let Some(checker) = &self.config.permission_checker {
1889 checker.check(&tool_call.name, &tool_call.args)
1890 } else {
1891 PermissionDecision::Ask
1893 };
1894
1895 let (output, exit_code, is_error, _metadata, images) = match permission_decision {
1896 PermissionDecision::Deny => {
1897 tracing::info!(
1898 tool_name = tool_call.name.as_str(),
1899 permission = "deny",
1900 "Tool permission denied"
1901 );
1902 let denial_msg = format!(
1904 "Permission denied: Tool '{}' is blocked by permission policy.",
1905 tool_call.name
1906 );
1907
1908 if let Some(tx) = &event_tx {
1910 tx.send(AgentEvent::PermissionDenied {
1911 tool_id: tool_call.id.clone(),
1912 tool_name: tool_call.name.clone(),
1913 args: tool_call.args.clone(),
1914 reason: "Blocked by deny rule in permission policy".to_string(),
1915 })
1916 .await
1917 .ok();
1918 }
1919
1920 (denial_msg, 1, true, None, Vec::new())
1921 }
1922 PermissionDecision::Allow => {
1923 tracing::info!(
1924 tool_name = tool_call.name.as_str(),
1925 permission = "allow",
1926 "Tool permission: allow"
1927 );
1928 let stream_ctx =
1930 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1931 let result = self
1932 .execute_tool_queued_or_direct(
1933 &tool_call.name,
1934 &tool_call.args,
1935 &stream_ctx,
1936 )
1937 .await;
1938
1939 Self::tool_result_to_tuple(result)
1940 }
1941 PermissionDecision::Ask => {
1942 tracing::info!(
1943 tool_name = tool_call.name.as_str(),
1944 permission = "ask",
1945 "Tool permission: ask"
1946 );
1947 if let Some(cm) = &self.config.confirmation_manager {
1949 if !cm.requires_confirmation(&tool_call.name).await {
1951 let stream_ctx = self.streaming_tool_context(
1952 &event_tx,
1953 &tool_call.id,
1954 &tool_call.name,
1955 );
1956 let result = self
1957 .execute_tool_queued_or_direct(
1958 &tool_call.name,
1959 &tool_call.args,
1960 &stream_ctx,
1961 )
1962 .await;
1963
1964 let (output, exit_code, is_error, _metadata, images) =
1965 Self::tool_result_to_tuple(result);
1966
1967 if images.is_empty() {
1969 messages.push(Message::tool_result(
1970 &tool_call.id,
1971 &output,
1972 is_error,
1973 ));
1974 } else {
1975 messages.push(Message::tool_result_with_images(
1976 &tool_call.id,
1977 &output,
1978 &images,
1979 is_error,
1980 ));
1981 }
1982
1983 let tool_duration = tool_start.elapsed();
1985 crate::telemetry::record_tool_result(exit_code, tool_duration);
1986
1987 if let Some(tx) = &event_tx {
1989 tx.send(AgentEvent::ToolEnd {
1990 id: tool_call.id.clone(),
1991 name: tool_call.name.clone(),
1992 output: output.clone(),
1993 exit_code,
1994 })
1995 .await
1996 .ok();
1997 }
1998
1999 self.fire_post_tool_use(
2001 session_id.unwrap_or(""),
2002 &tool_call.name,
2003 &tool_call.args,
2004 &output,
2005 exit_code == 0,
2006 tool_duration.as_millis() as u64,
2007 )
2008 .await;
2009
2010 continue; }
2012
2013 let policy = cm.policy().await;
2015 let timeout_ms = policy.default_timeout_ms;
2016 let timeout_action = policy.timeout_action;
2017
2018 let rx = cm
2020 .request_confirmation(
2021 &tool_call.id,
2022 &tool_call.name,
2023 &tool_call.args,
2024 )
2025 .await;
2026
2027 if let Some(tx) = &event_tx {
2031 tx.send(AgentEvent::ConfirmationRequired {
2032 tool_id: tool_call.id.clone(),
2033 tool_name: tool_call.name.clone(),
2034 args: tool_call.args.clone(),
2035 timeout_ms,
2036 })
2037 .await
2038 .ok();
2039 }
2040
2041 let confirmation_result =
2043 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2044
2045 match confirmation_result {
2046 Ok(Ok(response)) => {
2047 if let Some(tx) = &event_tx {
2049 tx.send(AgentEvent::ConfirmationReceived {
2050 tool_id: tool_call.id.clone(),
2051 approved: response.approved,
2052 reason: response.reason.clone(),
2053 })
2054 .await
2055 .ok();
2056 }
2057 if response.approved {
2058 let stream_ctx = self.streaming_tool_context(
2059 &event_tx,
2060 &tool_call.id,
2061 &tool_call.name,
2062 );
2063 let result = self
2064 .execute_tool_queued_or_direct(
2065 &tool_call.name,
2066 &tool_call.args,
2067 &stream_ctx,
2068 )
2069 .await;
2070
2071 Self::tool_result_to_tuple(result)
2072 } else {
2073 let rejection_msg = format!(
2074 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
2075 DO NOT retry this tool call unless the user explicitly asks you to.",
2076 tool_call.name,
2077 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2078 );
2079 (rejection_msg, 1, true, None, Vec::new())
2080 }
2081 }
2082 Ok(Err(_)) => {
2083 if let Some(tx) = &event_tx {
2085 tx.send(AgentEvent::ConfirmationTimeout {
2086 tool_id: tool_call.id.clone(),
2087 action_taken: "rejected".to_string(),
2088 })
2089 .await
2090 .ok();
2091 }
2092 let msg = format!(
2093 "Tool '{}' confirmation failed: confirmation channel closed",
2094 tool_call.name
2095 );
2096 (msg, 1, true, None, Vec::new())
2097 }
2098 Err(_) => {
2099 cm.check_timeouts().await;
2100
2101 if let Some(tx) = &event_tx {
2103 tx.send(AgentEvent::ConfirmationTimeout {
2104 tool_id: tool_call.id.clone(),
2105 action_taken: match timeout_action {
2106 crate::hitl::TimeoutAction::Reject => {
2107 "rejected".to_string()
2108 }
2109 crate::hitl::TimeoutAction::AutoApprove => {
2110 "auto_approved".to_string()
2111 }
2112 },
2113 })
2114 .await
2115 .ok();
2116 }
2117
2118 match timeout_action {
2119 crate::hitl::TimeoutAction::Reject => {
2120 let msg = format!(
2121 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
2122 DO NOT retry this tool call — the user did not approve it. \
2123 Inform the user that the operation requires their approval and ask them to try again.",
2124 tool_call.name, timeout_ms
2125 );
2126 (msg, 1, true, None, Vec::new())
2127 }
2128 crate::hitl::TimeoutAction::AutoApprove => {
2129 let stream_ctx = self.streaming_tool_context(
2130 &event_tx,
2131 &tool_call.id,
2132 &tool_call.name,
2133 );
2134 let result = self
2135 .execute_tool_queued_or_direct(
2136 &tool_call.name,
2137 &tool_call.args,
2138 &stream_ctx,
2139 )
2140 .await;
2141
2142 Self::tool_result_to_tuple(result)
2143 }
2144 }
2145 }
2146 }
2147 } else {
2148 let msg = format!(
2150 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2151 Configure a confirmation policy to enable tool execution.",
2152 tool_call.name
2153 );
2154 tracing::warn!(
2155 tool_name = tool_call.name.as_str(),
2156 "Tool requires confirmation but no HITL manager configured"
2157 );
2158 (msg, 1, true, None, Vec::new())
2159 }
2160 }
2161 };
2162
2163 let tool_duration = tool_start.elapsed();
2164 crate::telemetry::record_tool_result(exit_code, tool_duration);
2165
2166 let output = if let Some(ref sp) = self.config.security_provider {
2168 sp.sanitize_output(&output)
2169 } else {
2170 output
2171 };
2172
2173 self.fire_post_tool_use(
2175 session_id.unwrap_or(""),
2176 &tool_call.name,
2177 &tool_call.args,
2178 &output,
2179 exit_code == 0,
2180 tool_duration.as_millis() as u64,
2181 )
2182 .await;
2183
2184 if let Some(ref memory) = self.config.memory {
2186 let tools_used = [tool_call.name.clone()];
2187 let remember_result = if exit_code == 0 {
2188 memory
2189 .remember_success(effective_prompt, &tools_used, &output)
2190 .await
2191 } else {
2192 memory
2193 .remember_failure(effective_prompt, &output, &tools_used)
2194 .await
2195 };
2196 match remember_result {
2197 Ok(()) => {
2198 if let Some(tx) = &event_tx {
2199 let item_type = if exit_code == 0 { "success" } else { "failure" };
2200 tx.send(AgentEvent::MemoryStored {
2201 memory_id: uuid::Uuid::new_v4().to_string(),
2202 memory_type: item_type.to_string(),
2203 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2204 tags: vec![item_type.to_string(), tool_call.name.clone()],
2205 })
2206 .await
2207 .ok();
2208 }
2209 }
2210 Err(e) => {
2211 tracing::warn!("Failed to store memory after tool execution: {}", e);
2212 }
2213 }
2214 }
2215
2216 if let Some(tx) = &event_tx {
2218 tx.send(AgentEvent::ToolEnd {
2219 id: tool_call.id.clone(),
2220 name: tool_call.name.clone(),
2221 output: output.clone(),
2222 exit_code,
2223 })
2224 .await
2225 .ok();
2226 }
2227
2228 if images.is_empty() {
2230 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2231 } else {
2232 messages.push(Message::tool_result_with_images(
2233 &tool_call.id,
2234 &output,
2235 &images,
2236 is_error,
2237 ));
2238 }
2239 }
2240 }
2241 }
2242
2243 pub async fn execute_streaming(
2245 &self,
2246 history: &[Message],
2247 prompt: &str,
2248 ) -> Result<(
2249 mpsc::Receiver<AgentEvent>,
2250 tokio::task::JoinHandle<Result<AgentResult>>,
2251 )> {
2252 let (tx, rx) = mpsc::channel(100);
2253
2254 let llm_client = self.llm_client.clone();
2255 let tool_executor = self.tool_executor.clone();
2256 let tool_context = self.tool_context.clone();
2257 let config = self.config.clone();
2258 let tool_metrics = self.tool_metrics.clone();
2259 let command_queue = self.command_queue.clone();
2260 let history = history.to_vec();
2261 let prompt = prompt.to_string();
2262
2263 let handle = tokio::spawn(async move {
2264 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2265 if let Some(metrics) = tool_metrics {
2266 agent = agent.with_tool_metrics(metrics);
2267 }
2268 if let Some(queue) = command_queue {
2269 agent = agent.with_queue(queue);
2270 }
2271 agent.execute(&history, &prompt, Some(tx)).await
2272 });
2273
2274 Ok((rx, handle))
2275 }
2276
2277 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2282 use crate::planning::LlmPlanner;
2283
2284 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2285 Ok(plan) => Ok(plan),
2286 Err(e) => {
2287 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2288 Ok(LlmPlanner::fallback_plan(prompt))
2289 }
2290 }
2291 }
2292
2293 pub async fn execute_with_planning(
2295 &self,
2296 history: &[Message],
2297 prompt: &str,
2298 event_tx: Option<mpsc::Sender<AgentEvent>>,
2299 ) -> Result<AgentResult> {
2300 if let Some(tx) = &event_tx {
2302 tx.send(AgentEvent::PlanningStart {
2303 prompt: prompt.to_string(),
2304 })
2305 .await
2306 .ok();
2307 }
2308
2309 let goal = if self.config.goal_tracking {
2311 let g = self.extract_goal(prompt).await?;
2312 if let Some(tx) = &event_tx {
2313 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2314 .await
2315 .ok();
2316 }
2317 Some(g)
2318 } else {
2319 None
2320 };
2321
2322 let plan = self.plan(prompt, None).await?;
2324
2325 if let Some(tx) = &event_tx {
2327 tx.send(AgentEvent::PlanningEnd {
2328 estimated_steps: plan.steps.len(),
2329 plan: plan.clone(),
2330 })
2331 .await
2332 .ok();
2333 }
2334
2335 let plan_start = std::time::Instant::now();
2336
2337 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2339
2340 if self.config.goal_tracking {
2342 if let Some(ref g) = goal {
2343 let achieved = self.check_goal_achievement(g, &result.text).await?;
2344 if achieved {
2345 if let Some(tx) = &event_tx {
2346 tx.send(AgentEvent::GoalAchieved {
2347 goal: g.description.clone(),
2348 total_steps: result.messages.len(),
2349 duration_ms: plan_start.elapsed().as_millis() as i64,
2350 })
2351 .await
2352 .ok();
2353 }
2354 }
2355 }
2356 }
2357
2358 Ok(result)
2359 }
2360
2361 async fn execute_plan(
2368 &self,
2369 history: &[Message],
2370 plan: &ExecutionPlan,
2371 event_tx: Option<mpsc::Sender<AgentEvent>>,
2372 ) -> Result<AgentResult> {
2373 let mut plan = plan.clone();
2374 let mut current_history = history.to_vec();
2375 let mut total_usage = TokenUsage::default();
2376 let mut tool_calls_count = 0;
2377 let total_steps = plan.steps.len();
2378
2379 let steps_text = plan
2381 .steps
2382 .iter()
2383 .enumerate()
2384 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2385 .collect::<Vec<_>>()
2386 .join("\n");
2387 current_history.push(Message::user(&crate::prompts::render(
2388 crate::prompts::PLAN_EXECUTE_GOAL,
2389 &[("goal", &plan.goal), ("steps", &steps_text)],
2390 )));
2391
2392 loop {
2393 let ready: Vec<String> = plan
2394 .get_ready_steps()
2395 .iter()
2396 .map(|s| s.id.clone())
2397 .collect();
2398
2399 if ready.is_empty() {
2400 if plan.has_deadlock() {
2402 tracing::warn!(
2403 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2404 plan.pending_count()
2405 );
2406 }
2407 break;
2408 }
2409
2410 if ready.len() == 1 {
2411 let step_id = &ready[0];
2413 let step = plan
2414 .steps
2415 .iter()
2416 .find(|s| s.id == *step_id)
2417 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2418 .clone();
2419 let step_number = plan
2420 .steps
2421 .iter()
2422 .position(|s| s.id == *step_id)
2423 .unwrap_or(0)
2424 + 1;
2425
2426 if let Some(tx) = &event_tx {
2428 tx.send(AgentEvent::StepStart {
2429 step_id: step.id.clone(),
2430 description: step.content.clone(),
2431 step_number,
2432 total_steps,
2433 })
2434 .await
2435 .ok();
2436 }
2437
2438 plan.mark_status(&step.id, TaskStatus::InProgress);
2439
2440 let step_prompt = crate::prompts::render(
2441 crate::prompts::PLAN_EXECUTE_STEP,
2442 &[
2443 ("step_num", &step_number.to_string()),
2444 ("description", &step.content),
2445 ],
2446 );
2447
2448 match self
2449 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
2450 .await
2451 {
2452 Ok(result) => {
2453 current_history = result.messages.clone();
2454 total_usage.prompt_tokens += result.usage.prompt_tokens;
2455 total_usage.completion_tokens += result.usage.completion_tokens;
2456 total_usage.total_tokens += result.usage.total_tokens;
2457 tool_calls_count += result.tool_calls_count;
2458 plan.mark_status(&step.id, TaskStatus::Completed);
2459
2460 if let Some(tx) = &event_tx {
2461 tx.send(AgentEvent::StepEnd {
2462 step_id: step.id.clone(),
2463 status: TaskStatus::Completed,
2464 step_number,
2465 total_steps,
2466 })
2467 .await
2468 .ok();
2469 }
2470 }
2471 Err(e) => {
2472 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2473 plan.mark_status(&step.id, TaskStatus::Failed);
2474
2475 if let Some(tx) = &event_tx {
2476 tx.send(AgentEvent::StepEnd {
2477 step_id: step.id.clone(),
2478 status: TaskStatus::Failed,
2479 step_number,
2480 total_steps,
2481 })
2482 .await
2483 .ok();
2484 }
2485 }
2486 }
2487 } else {
2488 let ready_steps: Vec<_> = ready
2495 .iter()
2496 .filter_map(|id| {
2497 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2498 let step_number =
2499 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2500 Some((step, step_number))
2501 })
2502 .collect();
2503
2504 for (step, step_number) in &ready_steps {
2506 plan.mark_status(&step.id, TaskStatus::InProgress);
2507 if let Some(tx) = &event_tx {
2508 tx.send(AgentEvent::StepStart {
2509 step_id: step.id.clone(),
2510 description: step.content.clone(),
2511 step_number: *step_number,
2512 total_steps,
2513 })
2514 .await
2515 .ok();
2516 }
2517 }
2518
2519 let mut join_set = tokio::task::JoinSet::new();
2521 for (step, step_number) in &ready_steps {
2522 let base_history = current_history.clone();
2523 let agent_clone = self.clone();
2524 let tx = event_tx.clone();
2525 let step_clone = step.clone();
2526 let sn = *step_number;
2527
2528 join_set.spawn(async move {
2529 let prompt = crate::prompts::render(
2530 crate::prompts::PLAN_EXECUTE_STEP,
2531 &[
2532 ("step_num", &sn.to_string()),
2533 ("description", &step_clone.content),
2534 ],
2535 );
2536 let result = agent_clone
2537 .execute_loop(&base_history, &prompt, None, tx)
2538 .await;
2539 (step_clone.id, sn, result)
2540 });
2541 }
2542
2543 let mut parallel_summaries = Vec::new();
2545 while let Some(join_result) = join_set.join_next().await {
2546 match join_result {
2547 Ok((step_id, step_number, step_result)) => match step_result {
2548 Ok(result) => {
2549 total_usage.prompt_tokens += result.usage.prompt_tokens;
2550 total_usage.completion_tokens += result.usage.completion_tokens;
2551 total_usage.total_tokens += result.usage.total_tokens;
2552 tool_calls_count += result.tool_calls_count;
2553 plan.mark_status(&step_id, TaskStatus::Completed);
2554
2555 parallel_summaries.push(format!(
2557 "- Step {} ({}): {}",
2558 step_number, step_id, result.text
2559 ));
2560
2561 if let Some(tx) = &event_tx {
2562 tx.send(AgentEvent::StepEnd {
2563 step_id,
2564 status: TaskStatus::Completed,
2565 step_number,
2566 total_steps,
2567 })
2568 .await
2569 .ok();
2570 }
2571 }
2572 Err(e) => {
2573 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2574 plan.mark_status(&step_id, TaskStatus::Failed);
2575
2576 if let Some(tx) = &event_tx {
2577 tx.send(AgentEvent::StepEnd {
2578 step_id,
2579 status: TaskStatus::Failed,
2580 step_number,
2581 total_steps,
2582 })
2583 .await
2584 .ok();
2585 }
2586 }
2587 },
2588 Err(e) => {
2589 tracing::error!("JoinSet task panicked: {}", e);
2590 }
2591 }
2592 }
2593
2594 if !parallel_summaries.is_empty() {
2596 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
2598 current_history.push(Message::user(&crate::prompts::render(
2599 crate::prompts::PLAN_PARALLEL_RESULTS,
2600 &[("results", &results_text)],
2601 )));
2602 }
2603 }
2604
2605 if self.config.goal_tracking {
2607 let completed = plan
2608 .steps
2609 .iter()
2610 .filter(|s| s.status == TaskStatus::Completed)
2611 .count();
2612 if let Some(tx) = &event_tx {
2613 tx.send(AgentEvent::GoalProgress {
2614 goal: plan.goal.clone(),
2615 progress: plan.progress(),
2616 completed_steps: completed,
2617 total_steps,
2618 })
2619 .await
2620 .ok();
2621 }
2622 }
2623 }
2624
2625 let final_text = current_history
2627 .last()
2628 .map(|m| {
2629 m.content
2630 .iter()
2631 .filter_map(|block| {
2632 if let crate::llm::ContentBlock::Text { text } = block {
2633 Some(text.as_str())
2634 } else {
2635 None
2636 }
2637 })
2638 .collect::<Vec<_>>()
2639 .join("\n")
2640 })
2641 .unwrap_or_default();
2642
2643 Ok(AgentResult {
2644 text: final_text,
2645 messages: current_history,
2646 usage: total_usage,
2647 tool_calls_count,
2648 })
2649 }
2650
2651 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2656 use crate::planning::LlmPlanner;
2657
2658 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2659 Ok(goal) => Ok(goal),
2660 Err(e) => {
2661 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2662 Ok(LlmPlanner::fallback_goal(prompt))
2663 }
2664 }
2665 }
2666
2667 pub async fn check_goal_achievement(
2672 &self,
2673 goal: &AgentGoal,
2674 current_state: &str,
2675 ) -> Result<bool> {
2676 use crate::planning::LlmPlanner;
2677
2678 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2679 Ok(result) => Ok(result.achieved),
2680 Err(e) => {
2681 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2682 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2683 Ok(result.achieved)
2684 }
2685 }
2686 }
2687}
2688
2689#[cfg(test)]
2690mod tests {
2691 use super::*;
2692 use crate::llm::{ContentBlock, StreamEvent};
2693 use crate::permissions::PermissionPolicy;
2694 use crate::tools::ToolExecutor;
2695 use std::path::PathBuf;
2696 use std::sync::atomic::{AtomicUsize, Ordering};
2697
2698 fn test_tool_context() -> ToolContext {
2700 ToolContext::new(PathBuf::from("/tmp"))
2701 }
2702
/// `AgentConfig::default()` starts empty apart from the round limit and a
/// pre-populated skill registry.
#[test]
fn test_agent_config_default() {
    // Pull out just the fields under test; the rest are ignored.
    let AgentConfig {
        prompt_slots,
        tools,
        max_tool_rounds,
        permission_checker,
        context_providers,
        skill_registry,
        ..
    } = AgentConfig::default();

    assert!(prompt_slots.is_empty());
    assert!(tools.is_empty());
    assert_eq!(max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(permission_checker.is_none());
    assert!(context_providers.is_empty());

    // The default registry must exist and contain the built-in skills.
    let registry = skill_registry.expect("skill_registry must be Some by default");
    let skill_count = registry.len();
    assert!(skill_count >= 7, "expected at least 7 built-in skills");
    assert!(registry.get("code-search").is_some());
    assert!(registry.get("find-bugs").is_some());
}
2719
2720 pub(crate) struct MockLlmClient {
2726 responses: std::sync::Mutex<Vec<LlmResponse>>,
2728 pub(crate) call_count: AtomicUsize,
2730 }
2731
2732 impl MockLlmClient {
2733 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2734 Self {
2735 responses: std::sync::Mutex::new(responses),
2736 call_count: AtomicUsize::new(0),
2737 }
2738 }
2739
2740 pub(crate) fn text_response(text: &str) -> LlmResponse {
2742 LlmResponse {
2743 message: Message {
2744 role: "assistant".to_string(),
2745 content: vec![ContentBlock::Text {
2746 text: text.to_string(),
2747 }],
2748 reasoning_content: None,
2749 },
2750 usage: TokenUsage {
2751 prompt_tokens: 10,
2752 completion_tokens: 5,
2753 total_tokens: 15,
2754 cache_read_tokens: None,
2755 cache_write_tokens: None,
2756 },
2757 stop_reason: Some("end_turn".to_string()),
2758 }
2759 }
2760
2761 pub(crate) fn tool_call_response(
2763 tool_id: &str,
2764 tool_name: &str,
2765 args: serde_json::Value,
2766 ) -> LlmResponse {
2767 LlmResponse {
2768 message: Message {
2769 role: "assistant".to_string(),
2770 content: vec![ContentBlock::ToolUse {
2771 id: tool_id.to_string(),
2772 name: tool_name.to_string(),
2773 input: args,
2774 }],
2775 reasoning_content: None,
2776 },
2777 usage: TokenUsage {
2778 prompt_tokens: 10,
2779 completion_tokens: 5,
2780 total_tokens: 15,
2781 cache_read_tokens: None,
2782 cache_write_tokens: None,
2783 },
2784 stop_reason: Some("tool_use".to_string()),
2785 }
2786 }
2787 }
2788
2789 #[async_trait::async_trait]
2790 impl LlmClient for MockLlmClient {
2791 async fn complete(
2792 &self,
2793 _messages: &[Message],
2794 _system: Option<&str>,
2795 _tools: &[ToolDefinition],
2796 ) -> Result<LlmResponse> {
2797 self.call_count.fetch_add(1, Ordering::SeqCst);
2798 let mut responses = self.responses.lock().unwrap();
2799 if responses.is_empty() {
2800 anyhow::bail!("No more mock responses available");
2801 }
2802 Ok(responses.remove(0))
2803 }
2804
2805 async fn complete_streaming(
2806 &self,
2807 _messages: &[Message],
2808 _system: Option<&str>,
2809 _tools: &[ToolDefinition],
2810 ) -> Result<mpsc::Receiver<StreamEvent>> {
2811 self.call_count.fetch_add(1, Ordering::SeqCst);
2812 let mut responses = self.responses.lock().unwrap();
2813 if responses.is_empty() {
2814 anyhow::bail!("No more mock responses available");
2815 }
2816 let response = responses.remove(0);
2817
2818 let (tx, rx) = mpsc::channel(10);
2819 tokio::spawn(async move {
2820 for block in &response.message.content {
2822 if let ContentBlock::Text { text } = block {
2823 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2824 }
2825 }
2826 tx.send(StreamEvent::Done(response)).await.ok();
2827 });
2828
2829 Ok(rx)
2830 }
2831 }
2832
2833 #[tokio::test]
2838 async fn test_agent_simple_response() {
2839 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2840 "Hello, I'm an AI assistant.",
2841 )]));
2842
2843 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2844 let config = AgentConfig::default();
2845
2846 let agent = AgentLoop::new(
2847 mock_client.clone(),
2848 tool_executor,
2849 test_tool_context(),
2850 config,
2851 );
2852 let result = agent.execute(&[], "Hello", None).await.unwrap();
2853
2854 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2855 assert_eq!(result.tool_calls_count, 0);
2856 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2857 }
2858
2859 #[tokio::test]
2860 async fn test_agent_with_tool_call() {
2861 let mock_client = Arc::new(MockLlmClient::new(vec![
2862 MockLlmClient::tool_call_response(
2864 "tool-1",
2865 "bash",
2866 serde_json::json!({"command": "echo hello"}),
2867 ),
2868 MockLlmClient::text_response("The command output was: hello"),
2870 ]));
2871
2872 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2873 let config = AgentConfig::default();
2874
2875 let agent = AgentLoop::new(
2876 mock_client.clone(),
2877 tool_executor,
2878 test_tool_context(),
2879 config,
2880 );
2881 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2882
2883 assert_eq!(result.text, "The command output was: hello");
2884 assert_eq!(result.tool_calls_count, 1);
2885 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2886 }
2887
2888 #[tokio::test]
2889 async fn test_agent_permission_deny() {
2890 let mock_client = Arc::new(MockLlmClient::new(vec![
2891 MockLlmClient::tool_call_response(
2893 "tool-1",
2894 "bash",
2895 serde_json::json!({"command": "rm -rf /tmp/test"}),
2896 ),
2897 MockLlmClient::text_response(
2899 "I cannot execute that command due to permission restrictions.",
2900 ),
2901 ]));
2902
2903 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2904
2905 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2907
2908 let config = AgentConfig {
2909 permission_checker: Some(Arc::new(permission_policy)),
2910 ..Default::default()
2911 };
2912
2913 let (tx, mut rx) = mpsc::channel(100);
2914 let agent = AgentLoop::new(
2915 mock_client.clone(),
2916 tool_executor,
2917 test_tool_context(),
2918 config,
2919 );
2920 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2921
2922 let mut found_permission_denied = false;
2924 while let Ok(event) = rx.try_recv() {
2925 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2926 assert_eq!(tool_name, "bash");
2927 found_permission_denied = true;
2928 }
2929 }
2930 assert!(
2931 found_permission_denied,
2932 "Should have received PermissionDenied event"
2933 );
2934
2935 assert_eq!(result.tool_calls_count, 1);
2936 }
2937
2938 #[tokio::test]
2939 async fn test_agent_permission_allow() {
2940 let mock_client = Arc::new(MockLlmClient::new(vec![
2941 MockLlmClient::tool_call_response(
2943 "tool-1",
2944 "bash",
2945 serde_json::json!({"command": "echo hello"}),
2946 ),
2947 MockLlmClient::text_response("Done!"),
2949 ]));
2950
2951 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2952
2953 let permission_policy = PermissionPolicy::new()
2955 .allow("bash(echo:*)")
2956 .deny("bash(rm:*)");
2957
2958 let config = AgentConfig {
2959 permission_checker: Some(Arc::new(permission_policy)),
2960 ..Default::default()
2961 };
2962
2963 let agent = AgentLoop::new(
2964 mock_client.clone(),
2965 tool_executor,
2966 test_tool_context(),
2967 config,
2968 );
2969 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2970
2971 assert_eq!(result.text, "Done!");
2972 assert_eq!(result.tool_calls_count, 1);
2973 }
2974
2975 #[tokio::test]
2976 async fn test_agent_streaming_events() {
2977 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2978 "Hello!",
2979 )]));
2980
2981 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2982 let config = AgentConfig::default();
2983
2984 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2985 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2986
2987 let mut events = Vec::new();
2989 while let Some(event) = rx.recv().await {
2990 events.push(event);
2991 }
2992
2993 let result = handle.await.unwrap().unwrap();
2994 assert_eq!(result.text, "Hello!");
2995
2996 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2998 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2999 }
3000
3001 #[tokio::test]
3002 async fn test_agent_max_tool_rounds() {
3003 let responses: Vec<LlmResponse> = (0..100)
3005 .map(|i| {
3006 MockLlmClient::tool_call_response(
3007 &format!("tool-{}", i),
3008 "bash",
3009 serde_json::json!({"command": "echo loop"}),
3010 )
3011 })
3012 .collect();
3013
3014 let mock_client = Arc::new(MockLlmClient::new(responses));
3015 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3016
3017 let config = AgentConfig {
3018 max_tool_rounds: 3,
3019 ..Default::default()
3020 };
3021
3022 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3023 let result = agent.execute(&[], "Loop forever", None).await;
3024
3025 assert!(result.is_err());
3027 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
3028 }
3029
3030 #[tokio::test]
3031 async fn test_agent_no_permission_policy_defaults_to_ask() {
3032 let mock_client = Arc::new(MockLlmClient::new(vec![
3035 MockLlmClient::tool_call_response(
3036 "tool-1",
3037 "bash",
3038 serde_json::json!({"command": "rm -rf /tmp/test"}),
3039 ),
3040 MockLlmClient::text_response("Denied!"),
3041 ]));
3042
3043 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3044 let config = AgentConfig {
3045 permission_checker: None, ..Default::default()
3048 };
3049
3050 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3051 let result = agent.execute(&[], "Delete", None).await.unwrap();
3052
3053 assert_eq!(result.text, "Denied!");
3055 assert_eq!(result.tool_calls_count, 1);
3056 }
3057
3058 #[tokio::test]
3059 async fn test_agent_permission_ask_without_cm_denies() {
3060 let mock_client = Arc::new(MockLlmClient::new(vec![
3063 MockLlmClient::tool_call_response(
3064 "tool-1",
3065 "bash",
3066 serde_json::json!({"command": "echo test"}),
3067 ),
3068 MockLlmClient::text_response("Denied!"),
3069 ]));
3070
3071 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3072
3073 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3077 permission_checker: Some(Arc::new(permission_policy)),
3078 ..Default::default()
3080 };
3081
3082 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3083 let result = agent.execute(&[], "Echo", None).await.unwrap();
3084
3085 assert_eq!(result.text, "Denied!");
3087 assert!(result.tool_calls_count >= 1);
3089 }
3090
3091 #[tokio::test]
3096 async fn test_agent_hitl_approved() {
3097 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3098 use tokio::sync::broadcast;
3099
3100 let mock_client = Arc::new(MockLlmClient::new(vec![
3101 MockLlmClient::tool_call_response(
3102 "tool-1",
3103 "bash",
3104 serde_json::json!({"command": "echo hello"}),
3105 ),
3106 MockLlmClient::text_response("Command executed!"),
3107 ]));
3108
3109 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3110
3111 let (event_tx, _event_rx) = broadcast::channel(100);
3113 let hitl_policy = ConfirmationPolicy {
3114 enabled: true,
3115 ..Default::default()
3116 };
3117 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3118
3119 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3123 permission_checker: Some(Arc::new(permission_policy)),
3124 confirmation_manager: Some(confirmation_manager.clone()),
3125 ..Default::default()
3126 };
3127
3128 let cm_clone = confirmation_manager.clone();
3130 tokio::spawn(async move {
3131 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3133 cm_clone.confirm("tool-1", true, None).await.ok();
3135 });
3136
3137 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3138 let result = agent.execute(&[], "Run echo", None).await.unwrap();
3139
3140 assert_eq!(result.text, "Command executed!");
3141 assert_eq!(result.tool_calls_count, 1);
3142 }
3143
3144 #[tokio::test]
3145 async fn test_agent_hitl_rejected() {
3146 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3147 use tokio::sync::broadcast;
3148
3149 let mock_client = Arc::new(MockLlmClient::new(vec![
3150 MockLlmClient::tool_call_response(
3151 "tool-1",
3152 "bash",
3153 serde_json::json!({"command": "rm -rf /"}),
3154 ),
3155 MockLlmClient::text_response("Understood, I won't do that."),
3156 ]));
3157
3158 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3159
3160 let (event_tx, _event_rx) = broadcast::channel(100);
3162 let hitl_policy = ConfirmationPolicy {
3163 enabled: true,
3164 ..Default::default()
3165 };
3166 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3167
3168 let permission_policy = PermissionPolicy::new();
3170
3171 let config = AgentConfig {
3172 permission_checker: Some(Arc::new(permission_policy)),
3173 confirmation_manager: Some(confirmation_manager.clone()),
3174 ..Default::default()
3175 };
3176
3177 let cm_clone = confirmation_manager.clone();
3179 tokio::spawn(async move {
3180 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3181 cm_clone
3182 .confirm("tool-1", false, Some("Too dangerous".to_string()))
3183 .await
3184 .ok();
3185 });
3186
3187 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3188 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
3189
3190 assert_eq!(result.text, "Understood, I won't do that.");
3192 }
3193
3194 #[tokio::test]
3195 async fn test_agent_hitl_timeout_reject() {
3196 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3197 use tokio::sync::broadcast;
3198
3199 let mock_client = Arc::new(MockLlmClient::new(vec![
3200 MockLlmClient::tool_call_response(
3201 "tool-1",
3202 "bash",
3203 serde_json::json!({"command": "echo test"}),
3204 ),
3205 MockLlmClient::text_response("Timed out, I understand."),
3206 ]));
3207
3208 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3209
3210 let (event_tx, _event_rx) = broadcast::channel(100);
3212 let hitl_policy = ConfirmationPolicy {
3213 enabled: true,
3214 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
3216 ..Default::default()
3217 };
3218 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3219
3220 let permission_policy = PermissionPolicy::new();
3221
3222 let config = AgentConfig {
3223 permission_checker: Some(Arc::new(permission_policy)),
3224 confirmation_manager: Some(confirmation_manager),
3225 ..Default::default()
3226 };
3227
3228 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3230 let result = agent.execute(&[], "Echo", None).await.unwrap();
3231
3232 assert_eq!(result.text, "Timed out, I understand.");
3234 }
3235
3236 #[tokio::test]
3237 async fn test_agent_hitl_timeout_auto_approve() {
3238 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3239 use tokio::sync::broadcast;
3240
3241 let mock_client = Arc::new(MockLlmClient::new(vec![
3242 MockLlmClient::tool_call_response(
3243 "tool-1",
3244 "bash",
3245 serde_json::json!({"command": "echo hello"}),
3246 ),
3247 MockLlmClient::text_response("Auto-approved and executed!"),
3248 ]));
3249
3250 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3251
3252 let (event_tx, _event_rx) = broadcast::channel(100);
3254 let hitl_policy = ConfirmationPolicy {
3255 enabled: true,
3256 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
3258 ..Default::default()
3259 };
3260 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3261
3262 let permission_policy = PermissionPolicy::new();
3263
3264 let config = AgentConfig {
3265 permission_checker: Some(Arc::new(permission_policy)),
3266 confirmation_manager: Some(confirmation_manager),
3267 ..Default::default()
3268 };
3269
3270 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3272 let result = agent.execute(&[], "Echo", None).await.unwrap();
3273
3274 assert_eq!(result.text, "Auto-approved and executed!");
3276 assert_eq!(result.tool_calls_count, 1);
3277 }
3278
3279 #[tokio::test]
3280 async fn test_agent_hitl_confirmation_events() {
3281 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3282 use tokio::sync::broadcast;
3283
3284 let mock_client = Arc::new(MockLlmClient::new(vec![
3285 MockLlmClient::tool_call_response(
3286 "tool-1",
3287 "bash",
3288 serde_json::json!({"command": "echo test"}),
3289 ),
3290 MockLlmClient::text_response("Done!"),
3291 ]));
3292
3293 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3294
3295 let (event_tx, mut event_rx) = broadcast::channel(100);
3297 let hitl_policy = ConfirmationPolicy {
3298 enabled: true,
3299 default_timeout_ms: 5000, ..Default::default()
3301 };
3302 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3303
3304 let permission_policy = PermissionPolicy::new();
3305
3306 let config = AgentConfig {
3307 permission_checker: Some(Arc::new(permission_policy)),
3308 confirmation_manager: Some(confirmation_manager.clone()),
3309 ..Default::default()
3310 };
3311
3312 let cm_clone = confirmation_manager.clone();
3314 let event_handle = tokio::spawn(async move {
3315 let mut events = Vec::new();
3316 while let Ok(event) = event_rx.recv().await {
3318 events.push(event.clone());
3319 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3320 cm_clone.confirm(&tool_id, true, None).await.ok();
3322 if let Ok(recv_event) = event_rx.recv().await {
3324 events.push(recv_event);
3325 }
3326 break;
3327 }
3328 }
3329 events
3330 });
3331
3332 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3333 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3334
3335 let events = event_handle.await.unwrap();
3337 assert!(
3338 events
3339 .iter()
3340 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3341 "Should have ConfirmationRequired event"
3342 );
3343 assert!(
3344 events
3345 .iter()
3346 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3347 "Should have ConfirmationReceived event with approved=true"
3348 );
3349 }
3350
3351 #[tokio::test]
3352 async fn test_agent_hitl_disabled_auto_executes() {
3353 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3355 use tokio::sync::broadcast;
3356
3357 let mock_client = Arc::new(MockLlmClient::new(vec![
3358 MockLlmClient::tool_call_response(
3359 "tool-1",
3360 "bash",
3361 serde_json::json!({"command": "echo auto"}),
3362 ),
3363 MockLlmClient::text_response("Auto executed!"),
3364 ]));
3365
3366 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3367
3368 let (event_tx, _event_rx) = broadcast::channel(100);
3370 let hitl_policy = ConfirmationPolicy {
3371 enabled: false, ..Default::default()
3373 };
3374 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3375
3376 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3379 permission_checker: Some(Arc::new(permission_policy)),
3380 confirmation_manager: Some(confirmation_manager),
3381 ..Default::default()
3382 };
3383
3384 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3385 let result = agent.execute(&[], "Echo", None).await.unwrap();
3386
3387 assert_eq!(result.text, "Auto executed!");
3389 assert_eq!(result.tool_calls_count, 1);
3390 }
3391
3392 #[tokio::test]
3393 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3394 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3396 use tokio::sync::broadcast;
3397
3398 let mock_client = Arc::new(MockLlmClient::new(vec![
3399 MockLlmClient::tool_call_response(
3400 "tool-1",
3401 "bash",
3402 serde_json::json!({"command": "rm -rf /"}),
3403 ),
3404 MockLlmClient::text_response("Blocked by permission."),
3405 ]));
3406
3407 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3408
3409 let (event_tx, mut event_rx) = broadcast::channel(100);
3411 let hitl_policy = ConfirmationPolicy {
3412 enabled: true,
3413 ..Default::default()
3414 };
3415 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3416
3417 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3419
3420 let config = AgentConfig {
3421 permission_checker: Some(Arc::new(permission_policy)),
3422 confirmation_manager: Some(confirmation_manager),
3423 ..Default::default()
3424 };
3425
3426 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3427 let result = agent.execute(&[], "Delete", None).await.unwrap();
3428
3429 assert_eq!(result.text, "Blocked by permission.");
3431
3432 let mut found_confirmation = false;
3434 while let Ok(event) = event_rx.try_recv() {
3435 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3436 found_confirmation = true;
3437 }
3438 }
3439 assert!(
3440 !found_confirmation,
3441 "HITL should not be triggered when permission is Deny"
3442 );
3443 }
3444
3445 #[tokio::test]
3446 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3447 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3450 use tokio::sync::broadcast;
3451
3452 let mock_client = Arc::new(MockLlmClient::new(vec![
3453 MockLlmClient::tool_call_response(
3454 "tool-1",
3455 "bash",
3456 serde_json::json!({"command": "echo hello"}),
3457 ),
3458 MockLlmClient::text_response("Allowed!"),
3459 ]));
3460
3461 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3462
3463 let (event_tx, mut event_rx) = broadcast::channel(100);
3465 let hitl_policy = ConfirmationPolicy {
3466 enabled: true,
3467 ..Default::default()
3468 };
3469 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3470
3471 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3473
3474 let config = AgentConfig {
3475 permission_checker: Some(Arc::new(permission_policy)),
3476 confirmation_manager: Some(confirmation_manager.clone()),
3477 ..Default::default()
3478 };
3479
3480 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3481 let result = agent.execute(&[], "Echo", None).await.unwrap();
3482
3483 assert_eq!(result.text, "Allowed!");
3485
3486 let mut found_confirmation = false;
3488 while let Ok(event) = event_rx.try_recv() {
3489 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3490 found_confirmation = true;
3491 }
3492 }
3493 assert!(
3494 !found_confirmation,
3495 "Permission Allow should skip HITL confirmation"
3496 );
3497 }
3498
3499 #[tokio::test]
3500 async fn test_agent_hitl_multiple_tool_calls() {
3501 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3503 use tokio::sync::broadcast;
3504
3505 let mock_client = Arc::new(MockLlmClient::new(vec![
3506 LlmResponse {
3508 message: Message {
3509 role: "assistant".to_string(),
3510 content: vec![
3511 ContentBlock::ToolUse {
3512 id: "tool-1".to_string(),
3513 name: "bash".to_string(),
3514 input: serde_json::json!({"command": "echo first"}),
3515 },
3516 ContentBlock::ToolUse {
3517 id: "tool-2".to_string(),
3518 name: "bash".to_string(),
3519 input: serde_json::json!({"command": "echo second"}),
3520 },
3521 ],
3522 reasoning_content: None,
3523 },
3524 usage: TokenUsage {
3525 prompt_tokens: 10,
3526 completion_tokens: 5,
3527 total_tokens: 15,
3528 cache_read_tokens: None,
3529 cache_write_tokens: None,
3530 },
3531 stop_reason: Some("tool_use".to_string()),
3532 },
3533 MockLlmClient::text_response("Both executed!"),
3534 ]));
3535
3536 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3537
3538 let (event_tx, _event_rx) = broadcast::channel(100);
3540 let hitl_policy = ConfirmationPolicy {
3541 enabled: true,
3542 default_timeout_ms: 5000,
3543 ..Default::default()
3544 };
3545 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3546
3547 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3550 permission_checker: Some(Arc::new(permission_policy)),
3551 confirmation_manager: Some(confirmation_manager.clone()),
3552 ..Default::default()
3553 };
3554
3555 let cm_clone = confirmation_manager.clone();
3557 tokio::spawn(async move {
3558 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3559 cm_clone.confirm("tool-1", true, None).await.ok();
3560 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3561 cm_clone.confirm("tool-2", true, None).await.ok();
3562 });
3563
3564 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3565 let result = agent.execute(&[], "Run both", None).await.unwrap();
3566
3567 assert_eq!(result.text, "Both executed!");
3568 assert_eq!(result.tool_calls_count, 2);
3569 }
3570
3571 #[tokio::test]
3572 async fn test_agent_hitl_partial_approval() {
3573 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3575 use tokio::sync::broadcast;
3576
3577 let mock_client = Arc::new(MockLlmClient::new(vec![
3578 LlmResponse {
3580 message: Message {
3581 role: "assistant".to_string(),
3582 content: vec![
3583 ContentBlock::ToolUse {
3584 id: "tool-1".to_string(),
3585 name: "bash".to_string(),
3586 input: serde_json::json!({"command": "echo safe"}),
3587 },
3588 ContentBlock::ToolUse {
3589 id: "tool-2".to_string(),
3590 name: "bash".to_string(),
3591 input: serde_json::json!({"command": "rm -rf /"}),
3592 },
3593 ],
3594 reasoning_content: None,
3595 },
3596 usage: TokenUsage {
3597 prompt_tokens: 10,
3598 completion_tokens: 5,
3599 total_tokens: 15,
3600 cache_read_tokens: None,
3601 cache_write_tokens: None,
3602 },
3603 stop_reason: Some("tool_use".to_string()),
3604 },
3605 MockLlmClient::text_response("First worked, second rejected."),
3606 ]));
3607
3608 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3609
3610 let (event_tx, _event_rx) = broadcast::channel(100);
3611 let hitl_policy = ConfirmationPolicy {
3612 enabled: true,
3613 default_timeout_ms: 5000,
3614 ..Default::default()
3615 };
3616 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3617
3618 let permission_policy = PermissionPolicy::new();
3619
3620 let config = AgentConfig {
3621 permission_checker: Some(Arc::new(permission_policy)),
3622 confirmation_manager: Some(confirmation_manager.clone()),
3623 ..Default::default()
3624 };
3625
3626 let cm_clone = confirmation_manager.clone();
3628 tokio::spawn(async move {
3629 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3630 cm_clone.confirm("tool-1", true, None).await.ok();
3631 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3632 cm_clone
3633 .confirm("tool-2", false, Some("Dangerous".to_string()))
3634 .await
3635 .ok();
3636 });
3637
3638 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3639 let result = agent.execute(&[], "Run both", None).await.unwrap();
3640
3641 assert_eq!(result.text, "First worked, second rejected.");
3642 assert_eq!(result.tool_calls_count, 2);
3643 }
3644
3645 #[tokio::test]
3646 async fn test_agent_hitl_yolo_mode_auto_approves() {
3647 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3649 use tokio::sync::broadcast;
3650
3651 let mock_client = Arc::new(MockLlmClient::new(vec![
3652 MockLlmClient::tool_call_response(
3653 "tool-1",
3654 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3656 ),
3657 MockLlmClient::text_response("File read!"),
3658 ]));
3659
3660 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3661
3662 let (event_tx, mut event_rx) = broadcast::channel(100);
3664 let mut yolo_lanes = std::collections::HashSet::new();
3665 yolo_lanes.insert(SessionLane::Query);
3666 let hitl_policy = ConfirmationPolicy {
3667 enabled: true,
3668 yolo_lanes, ..Default::default()
3670 };
3671 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3672
3673 let permission_policy = PermissionPolicy::new();
3674
3675 let config = AgentConfig {
3676 permission_checker: Some(Arc::new(permission_policy)),
3677 confirmation_manager: Some(confirmation_manager),
3678 ..Default::default()
3679 };
3680
3681 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3682 let result = agent.execute(&[], "Read file", None).await.unwrap();
3683
3684 assert_eq!(result.text, "File read!");
3686
3687 let mut found_confirmation = false;
3689 while let Ok(event) = event_rx.try_recv() {
3690 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3691 found_confirmation = true;
3692 }
3693 }
3694 assert!(
3695 !found_confirmation,
3696 "YOLO mode should not trigger confirmation"
3697 );
3698 }
3699
3700 #[tokio::test]
3701 async fn test_agent_config_with_all_options() {
3702 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3703 use tokio::sync::broadcast;
3704
3705 let (event_tx, _) = broadcast::channel(100);
3706 let hitl_policy = ConfirmationPolicy::default();
3707 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3708
3709 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3710
3711 let config = AgentConfig {
3712 prompt_slots: SystemPromptSlots {
3713 extra: Some("Test system prompt".to_string()),
3714 ..Default::default()
3715 },
3716 tools: vec![],
3717 max_tool_rounds: 10,
3718 permission_checker: Some(Arc::new(permission_policy)),
3719 confirmation_manager: Some(confirmation_manager),
3720 context_providers: vec![],
3721 planning_enabled: false,
3722 goal_tracking: false,
3723 hook_engine: None,
3724 skill_registry: None,
3725 ..AgentConfig::default()
3726 };
3727
3728 assert!(config.prompt_slots.build().contains("Test system prompt"));
3729 assert_eq!(config.max_tool_rounds, 10);
3730 assert!(config.permission_checker.is_some());
3731 assert!(config.confirmation_manager.is_some());
3732 assert!(config.context_providers.is_empty());
3733
3734 let debug_str = format!("{:?}", config);
3736 assert!(debug_str.contains("AgentConfig"));
3737 assert!(debug_str.contains("permission_checker: true"));
3738 assert!(debug_str.contains("confirmation_manager: true"));
3739 assert!(debug_str.contains("context_providers: 0"));
3740 }
3741
3742 use crate::context::{ContextItem, ContextType};
3747
3748 struct MockContextProvider {
3750 name: String,
3751 items: Vec<ContextItem>,
3752 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
3753 }
3754
3755 impl MockContextProvider {
3756 fn new(name: &str) -> Self {
3757 Self {
3758 name: name.to_string(),
3759 items: Vec::new(),
3760 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3761 }
3762 }
3763
3764 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3765 self.items = items;
3766 self
3767 }
3768 }
3769
3770 #[async_trait::async_trait]
3771 impl ContextProvider for MockContextProvider {
3772 fn name(&self) -> &str {
3773 &self.name
3774 }
3775
3776 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3777 let mut result = ContextResult::new(&self.name);
3778 for item in &self.items {
3779 result.add_item(item.clone());
3780 }
3781 Ok(result)
3782 }
3783
3784 async fn on_turn_complete(
3785 &self,
3786 session_id: &str,
3787 prompt: &str,
3788 response: &str,
3789 ) -> anyhow::Result<()> {
3790 let mut calls = self.on_turn_calls.write().await;
3791 calls.push((
3792 session_id.to_string(),
3793 prompt.to_string(),
3794 response.to_string(),
3795 ));
3796 Ok(())
3797 }
3798 }
3799
3800 #[tokio::test]
3801 async fn test_agent_with_context_provider() {
3802 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3803 "Response using context",
3804 )]));
3805
3806 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3807
3808 let provider =
3809 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3810 "ctx-1",
3811 ContextType::Resource,
3812 "Relevant context here",
3813 )
3814 .with_source("test://docs/example")]);
3815
3816 let config = AgentConfig {
3817 prompt_slots: SystemPromptSlots {
3818 extra: Some("You are helpful.".to_string()),
3819 ..Default::default()
3820 },
3821 context_providers: vec![Arc::new(provider)],
3822 ..Default::default()
3823 };
3824
3825 let agent = AgentLoop::new(
3826 mock_client.clone(),
3827 tool_executor,
3828 test_tool_context(),
3829 config,
3830 );
3831 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3832
3833 assert_eq!(result.text, "Response using context");
3834 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3835 }
3836
3837 #[tokio::test]
3838 async fn test_agent_context_provider_events() {
3839 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3840 "Answer",
3841 )]));
3842
3843 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3844
3845 let provider =
3846 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3847 "item-1",
3848 ContextType::Memory,
3849 "Memory content",
3850 )
3851 .with_token_count(50)]);
3852
3853 let config = AgentConfig {
3854 context_providers: vec![Arc::new(provider)],
3855 ..Default::default()
3856 };
3857
3858 let (tx, mut rx) = mpsc::channel(100);
3859 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3860 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3861
3862 let mut events = Vec::new();
3864 while let Ok(event) = rx.try_recv() {
3865 events.push(event);
3866 }
3867
3868 assert!(
3870 events
3871 .iter()
3872 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3873 "Should have ContextResolving event"
3874 );
3875 assert!(
3876 events
3877 .iter()
3878 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3879 "Should have ContextResolved event"
3880 );
3881
3882 for event in &events {
3884 if let AgentEvent::ContextResolved {
3885 total_items,
3886 total_tokens,
3887 } = event
3888 {
3889 assert_eq!(*total_items, 1);
3890 assert_eq!(*total_tokens, 50);
3891 }
3892 }
3893 }
3894
3895 #[tokio::test]
3896 async fn test_agent_multiple_context_providers() {
3897 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3898 "Combined response",
3899 )]));
3900
3901 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3902
3903 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3904 "p1-1",
3905 ContextType::Resource,
3906 "Resource from P1",
3907 )
3908 .with_token_count(100)]);
3909
3910 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3911 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3912 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3913 ]);
3914
3915 let config = AgentConfig {
3916 prompt_slots: SystemPromptSlots {
3917 extra: Some("Base system prompt.".to_string()),
3918 ..Default::default()
3919 },
3920 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3921 ..Default::default()
3922 };
3923
3924 let (tx, mut rx) = mpsc::channel(100);
3925 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3926 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3927
3928 assert_eq!(result.text, "Combined response");
3929
3930 while let Ok(event) = rx.try_recv() {
3932 if let AgentEvent::ContextResolved {
3933 total_items,
3934 total_tokens,
3935 } = event
3936 {
3937 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3940 }
3941 }
3942
3943 #[tokio::test]
3944 async fn test_agent_no_context_providers() {
3945 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3946 "No context",
3947 )]));
3948
3949 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3950
3951 let config = AgentConfig::default();
3953
3954 let (tx, mut rx) = mpsc::channel(100);
3955 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3956 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3957
3958 assert_eq!(result.text, "No context");
3959
3960 let mut events = Vec::new();
3962 while let Ok(event) = rx.try_recv() {
3963 events.push(event);
3964 }
3965
3966 assert!(
3967 !events
3968 .iter()
3969 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3970 "Should NOT have ContextResolving event"
3971 );
3972 }
3973
3974 #[tokio::test]
3975 async fn test_agent_context_on_turn_complete() {
3976 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3977 "Final response",
3978 )]));
3979
3980 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3981
3982 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3983 let on_turn_calls = provider.on_turn_calls.clone();
3984
3985 let config = AgentConfig {
3986 context_providers: vec![provider],
3987 ..Default::default()
3988 };
3989
3990 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3991
3992 let result = agent
3994 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3995 .await
3996 .unwrap();
3997
3998 assert_eq!(result.text, "Final response");
3999
4000 let calls = on_turn_calls.read().await;
4002 assert_eq!(calls.len(), 1);
4003 assert_eq!(calls[0].0, "sess-123");
4004 assert_eq!(calls[0].1, "User prompt");
4005 assert_eq!(calls[0].2, "Final response");
4006 }
4007
4008 #[tokio::test]
4009 async fn test_agent_context_on_turn_complete_no_session() {
4010 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4011 "Response",
4012 )]));
4013
4014 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4015
4016 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4017 let on_turn_calls = provider.on_turn_calls.clone();
4018
4019 let config = AgentConfig {
4020 context_providers: vec![provider],
4021 ..Default::default()
4022 };
4023
4024 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4025
4026 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
4028
4029 let calls = on_turn_calls.read().await;
4031 assert!(calls.is_empty());
4032 }
4033
4034 #[tokio::test]
4035 async fn test_agent_build_augmented_system_prompt() {
4036 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
4037
4038 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4039
4040 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
4041 "doc-1",
4042 ContextType::Resource,
4043 "Auth uses JWT tokens.",
4044 )
4045 .with_source("viking://docs/auth")]);
4046
4047 let config = AgentConfig {
4048 prompt_slots: SystemPromptSlots {
4049 extra: Some("You are helpful.".to_string()),
4050 ..Default::default()
4051 },
4052 context_providers: vec![Arc::new(provider)],
4053 ..Default::default()
4054 };
4055
4056 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4057
4058 let context_results = agent.resolve_context("test", None).await;
4060 let augmented = agent.build_augmented_system_prompt(&context_results);
4061
4062 let augmented_str = augmented.unwrap();
4063 assert!(augmented_str.contains("You are helpful."));
4064 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4065 assert!(augmented_str.contains("Auth uses JWT tokens."));
4066 }
4067
4068 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4074 let mut events = Vec::new();
4075 while let Ok(event) = rx.try_recv() {
4076 events.push(event);
4077 }
4078 while let Some(event) = rx.recv().await {
4080 events.push(event);
4081 }
4082 events
4083 }
4084
4085 #[tokio::test]
4086 async fn test_agent_multi_turn_tool_chain() {
4087 let mock_client = Arc::new(MockLlmClient::new(vec![
4089 MockLlmClient::tool_call_response(
4091 "t1",
4092 "bash",
4093 serde_json::json!({"command": "echo step1"}),
4094 ),
4095 MockLlmClient::tool_call_response(
4097 "t2",
4098 "bash",
4099 serde_json::json!({"command": "echo step2"}),
4100 ),
4101 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4103 ]));
4104
4105 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4106 let config = AgentConfig::default();
4107
4108 let agent = AgentLoop::new(
4109 mock_client.clone(),
4110 tool_executor,
4111 test_tool_context(),
4112 config,
4113 );
4114 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4115
4116 assert_eq!(result.text, "Completed both steps: step1 then step2");
4117 assert_eq!(result.tool_calls_count, 2);
4118 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4119
4120 assert_eq!(result.messages[0].role, "user");
4122 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4128 }
4129
4130 #[tokio::test]
4131 async fn test_agent_conversation_history_preserved() {
4132 let existing_history = vec![
4134 Message::user("What is Rust?"),
4135 Message {
4136 role: "assistant".to_string(),
4137 content: vec![ContentBlock::Text {
4138 text: "Rust is a systems programming language.".to_string(),
4139 }],
4140 reasoning_content: None,
4141 },
4142 ];
4143
4144 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4145 "Rust was created by Graydon Hoare at Mozilla.",
4146 )]));
4147
4148 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4149 let agent = AgentLoop::new(
4150 mock_client.clone(),
4151 tool_executor,
4152 test_tool_context(),
4153 AgentConfig::default(),
4154 );
4155
4156 let result = agent
4157 .execute(&existing_history, "Who created it?", None)
4158 .await
4159 .unwrap();
4160
4161 assert_eq!(result.messages.len(), 4);
4163 assert_eq!(result.messages[0].text(), "What is Rust?");
4164 assert_eq!(
4165 result.messages[1].text(),
4166 "Rust is a systems programming language."
4167 );
4168 assert_eq!(result.messages[2].text(), "Who created it?");
4169 assert_eq!(
4170 result.messages[3].text(),
4171 "Rust was created by Graydon Hoare at Mozilla."
4172 );
4173 }
4174
4175 #[tokio::test]
4176 async fn test_agent_event_stream_completeness() {
4177 let mock_client = Arc::new(MockLlmClient::new(vec![
4179 MockLlmClient::tool_call_response(
4180 "t1",
4181 "bash",
4182 serde_json::json!({"command": "echo hi"}),
4183 ),
4184 MockLlmClient::text_response("Done"),
4185 ]));
4186
4187 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4188 let agent = AgentLoop::new(
4189 mock_client,
4190 tool_executor,
4191 test_tool_context(),
4192 AgentConfig::default(),
4193 );
4194
4195 let (tx, rx) = mpsc::channel(100);
4196 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4197 assert_eq!(result.text, "Done");
4198
4199 let events = collect_events(rx).await;
4200
4201 let event_types: Vec<&str> = events
4203 .iter()
4204 .map(|e| match e {
4205 AgentEvent::Start { .. } => "Start",
4206 AgentEvent::TurnStart { .. } => "TurnStart",
4207 AgentEvent::TurnEnd { .. } => "TurnEnd",
4208 AgentEvent::ToolEnd { .. } => "ToolEnd",
4209 AgentEvent::End { .. } => "End",
4210 _ => "Other",
4211 })
4212 .collect();
4213
4214 assert_eq!(event_types.first(), Some(&"Start"));
4216 assert_eq!(event_types.last(), Some(&"End"));
4217
4218 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4220 assert_eq!(turn_starts, 2);
4221
4222 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4224 assert_eq!(tool_ends, 1);
4225 }
4226
4227 #[tokio::test]
4228 async fn test_agent_multiple_tools_single_turn() {
4229 let mock_client = Arc::new(MockLlmClient::new(vec![
4231 LlmResponse {
4232 message: Message {
4233 role: "assistant".to_string(),
4234 content: vec![
4235 ContentBlock::ToolUse {
4236 id: "t1".to_string(),
4237 name: "bash".to_string(),
4238 input: serde_json::json!({"command": "echo first"}),
4239 },
4240 ContentBlock::ToolUse {
4241 id: "t2".to_string(),
4242 name: "bash".to_string(),
4243 input: serde_json::json!({"command": "echo second"}),
4244 },
4245 ],
4246 reasoning_content: None,
4247 },
4248 usage: TokenUsage {
4249 prompt_tokens: 10,
4250 completion_tokens: 5,
4251 total_tokens: 15,
4252 cache_read_tokens: None,
4253 cache_write_tokens: None,
4254 },
4255 stop_reason: Some("tool_use".to_string()),
4256 },
4257 MockLlmClient::text_response("Both commands ran"),
4258 ]));
4259
4260 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4261 let agent = AgentLoop::new(
4262 mock_client.clone(),
4263 tool_executor,
4264 test_tool_context(),
4265 AgentConfig::default(),
4266 );
4267
4268 let result = agent.execute(&[], "Run both", None).await.unwrap();
4269
4270 assert_eq!(result.text, "Both commands ran");
4271 assert_eq!(result.tool_calls_count, 2);
4272 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4276 assert_eq!(result.messages[1].role, "assistant");
4277 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4280 }
4281
4282 #[tokio::test]
4283 async fn test_agent_token_usage_accumulation() {
4284 let mock_client = Arc::new(MockLlmClient::new(vec![
4286 MockLlmClient::tool_call_response(
4287 "t1",
4288 "bash",
4289 serde_json::json!({"command": "echo x"}),
4290 ),
4291 MockLlmClient::text_response("Done"),
4292 ]));
4293
4294 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4295 let agent = AgentLoop::new(
4296 mock_client,
4297 tool_executor,
4298 test_tool_context(),
4299 AgentConfig::default(),
4300 );
4301
4302 let result = agent.execute(&[], "test", None).await.unwrap();
4303
4304 assert_eq!(result.usage.prompt_tokens, 20);
4307 assert_eq!(result.usage.completion_tokens, 10);
4308 assert_eq!(result.usage.total_tokens, 30);
4309 }
4310
4311 #[tokio::test]
4312 async fn test_agent_system_prompt_passed() {
4313 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4315 "I am a coding assistant.",
4316 )]));
4317
4318 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4319 let config = AgentConfig {
4320 prompt_slots: SystemPromptSlots {
4321 extra: Some("You are a coding assistant.".to_string()),
4322 ..Default::default()
4323 },
4324 ..Default::default()
4325 };
4326
4327 let agent = AgentLoop::new(
4328 mock_client.clone(),
4329 tool_executor,
4330 test_tool_context(),
4331 config,
4332 );
4333 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4334
4335 assert_eq!(result.text, "I am a coding assistant.");
4336 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4337 }
4338
4339 #[tokio::test]
4340 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4341 let mut responses = Vec::new();
4343 for i in 0..15 {
4344 responses.push(MockLlmClient::tool_call_response(
4345 &format!("t{}", i),
4346 "bash",
4347 serde_json::json!({"command": format!("echo round{}", i)}),
4348 ));
4349 }
4350
4351 let mock_client = Arc::new(MockLlmClient::new(responses));
4352 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4353 let config = AgentConfig {
4354 max_tool_rounds: 5,
4355 ..Default::default()
4356 };
4357
4358 let agent = AgentLoop::new(
4359 mock_client.clone(),
4360 tool_executor,
4361 test_tool_context(),
4362 config,
4363 );
4364 let result = agent.execute(&[], "Loop forever", None).await;
4365
4366 assert!(result.is_err());
4367 let err = result.unwrap_err().to_string();
4368 assert!(err.contains("Max tool rounds (5) exceeded"));
4369 }
4370
4371 #[tokio::test]
4372 async fn test_agent_end_event_contains_final_text() {
4373 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4374 "Final answer here",
4375 )]));
4376
4377 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4378 let agent = AgentLoop::new(
4379 mock_client,
4380 tool_executor,
4381 test_tool_context(),
4382 AgentConfig::default(),
4383 );
4384
4385 let (tx, rx) = mpsc::channel(100);
4386 agent.execute(&[], "test", Some(tx)).await.unwrap();
4387
4388 let events = collect_events(rx).await;
4389 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4390 assert!(end_event.is_some());
4391
4392 if let AgentEvent::End { text, usage } = end_event.unwrap() {
4393 assert_eq!(text, "Final answer here");
4394 assert_eq!(usage.total_tokens, 15);
4395 }
4396 }
4397}
4398
4399#[cfg(test)]
4400mod extra_agent_tests {
4401 use super::*;
4402 use crate::agent::tests::MockLlmClient;
4403 use crate::queue::SessionQueueConfig;
4404 use crate::tools::ToolExecutor;
4405 use std::path::PathBuf;
4406 use std::sync::atomic::{AtomicUsize, Ordering};
4407
4408 fn test_tool_context() -> ToolContext {
4409 ToolContext::new(PathBuf::from("/tmp"))
4410 }
4411
/// The `Debug` rendering of `AgentConfig` must mention the type name and the
/// `planning_enabled` field.
#[test]
fn test_agent_config_debug() {
    let prompt_slots = SystemPromptSlots {
        extra: Some("You are helpful".to_string()),
        ..Default::default()
    };
    let config = AgentConfig {
        max_tool_rounds: 10,
        planning_enabled: true,
        goal_tracking: false,
        tools: vec![],
        context_providers: vec![],
        permission_checker: None,
        confirmation_manager: None,
        hook_engine: None,
        skill_registry: None,
        prompt_slots,
        ..AgentConfig::default()
    };

    let rendered = format!("{:?}", config);
    for needle in ["AgentConfig", "planning_enabled"].iter() {
        assert!(rendered.contains(*needle));
    }
}
4438
/// `AgentConfig::default()` uses `MAX_TOOL_ROUNDS`, leaves planning and goal
/// tracking disabled, and has no context providers.
#[test]
fn test_agent_config_default_values() {
    let defaults = AgentConfig::default();

    assert_eq!(defaults.max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(defaults.context_providers.is_empty());
    // Both feature flags are expected to be off.
    assert!(!defaults.planning_enabled);
    assert!(!defaults.goal_tracking);
}
4447
/// The JSON for `AgentEvent::Start` must contain `agent_start` and the prompt.
#[test]
fn test_agent_event_serialize_start() {
    let serialized = serde_json::to_string(&AgentEvent::Start {
        prompt: String::from("Hello"),
    })
    .unwrap();

    for needle in ["agent_start", "Hello"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4461
/// The JSON for `AgentEvent::TextDelta` must contain `text_delta`.
#[test]
fn test_agent_event_serialize_text_delta() {
    let serialized = serde_json::to_string(&AgentEvent::TextDelta {
        text: String::from("chunk"),
    })
    .unwrap();

    assert!(serialized.contains("text_delta"));
}
4470
/// The JSON for `AgentEvent::ToolStart` must contain `tool_start` and the
/// tool name.
#[test]
fn test_agent_event_serialize_tool_start() {
    let serialized = serde_json::to_string(&AgentEvent::ToolStart {
        id: String::from("t1"),
        name: String::from("bash"),
    })
    .unwrap();

    for needle in ["tool_start", "bash"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4481
/// The JSON for `AgentEvent::ToolEnd` must contain `tool_end`.
#[test]
fn test_agent_event_serialize_tool_end() {
    let serialized = serde_json::to_string(&AgentEvent::ToolEnd {
        id: String::from("t1"),
        name: String::from("bash"),
        output: String::from("hello"),
        exit_code: 0,
    })
    .unwrap();

    assert!(serialized.contains("tool_end"));
}
4493
/// The JSON for `AgentEvent::Error` must contain `error` and the message.
#[test]
fn test_agent_event_serialize_error() {
    let serialized = serde_json::to_string(&AgentEvent::Error {
        message: String::from("oops"),
    })
    .unwrap();

    for needle in ["error", "oops"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4503
/// The JSON for `AgentEvent::ConfirmationRequired` must contain
/// `confirmation_required`.
#[test]
fn test_agent_event_serialize_confirmation_required() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationRequired {
        tool_id: String::from("t1"),
        tool_name: String::from("bash"),
        args: serde_json::json!({"cmd": "rm"}),
        timeout_ms: 30000,
    })
    .unwrap();

    assert!(serialized.contains("confirmation_required"));
}
4515
/// The JSON for `AgentEvent::ConfirmationReceived` must contain
/// `confirmation_received`.
#[test]
fn test_agent_event_serialize_confirmation_received() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationReceived {
        tool_id: String::from("t1"),
        approved: true,
        reason: Some(String::from("safe")),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_received"));
}
4526
/// The JSON for `AgentEvent::ConfirmationTimeout` must contain
/// `confirmation_timeout`.
#[test]
fn test_agent_event_serialize_confirmation_timeout() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationTimeout {
        tool_id: String::from("t1"),
        action_taken: String::from("rejected"),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_timeout"));
}
4536
/// The JSON for `AgentEvent::ExternalTaskPending` must contain
/// `external_task_pending`.
#[test]
fn test_agent_event_serialize_external_task_pending() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskPending {
        task_id: String::from("task-1"),
        session_id: String::from("sess-1"),
        lane: crate::hitl::SessionLane::Execute,
        command_type: String::from("bash"),
        payload: serde_json::json!({}),
        timeout_ms: 60000,
    })
    .unwrap();

    assert!(serialized.contains("external_task_pending"));
}
4550
/// The JSON for `AgentEvent::ExternalTaskCompleted` must contain
/// `external_task_completed`.
#[test]
fn test_agent_event_serialize_external_task_completed() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskCompleted {
        task_id: String::from("task-1"),
        session_id: String::from("sess-1"),
        success: false,
    })
    .unwrap();

    assert!(serialized.contains("external_task_completed"));
}
4561
/// The JSON for `AgentEvent::PermissionDenied` must contain
/// `permission_denied`.
#[test]
fn test_agent_event_serialize_permission_denied() {
    let serialized = serde_json::to_string(&AgentEvent::PermissionDenied {
        tool_id: String::from("t1"),
        tool_name: String::from("bash"),
        args: serde_json::json!({}),
        reason: String::from("denied"),
    })
    .unwrap();

    assert!(serialized.contains("permission_denied"));
}
4573
/// The JSON for `AgentEvent::ContextCompacted` must contain
/// `context_compacted`.
#[test]
fn test_agent_event_serialize_context_compacted() {
    let serialized = serde_json::to_string(&AgentEvent::ContextCompacted {
        session_id: String::from("sess-1"),
        before_messages: 100,
        after_messages: 20,
        percent_before: 0.85,
    })
    .unwrap();

    assert!(serialized.contains("context_compacted"));
}
4585
/// The JSON for `AgentEvent::TurnStart` must contain `turn_start`.
#[test]
fn test_agent_event_serialize_turn_start() {
    let serialized = serde_json::to_string(&AgentEvent::TurnStart { turn: 3 }).unwrap();

    assert!(serialized.contains("turn_start"));
}
4592
/// The JSON for `AgentEvent::TurnEnd` must contain `turn_end`.
#[test]
fn test_agent_event_serialize_turn_end() {
    let serialized = serde_json::to_string(&AgentEvent::TurnEnd {
        turn: 3,
        usage: TokenUsage::default(),
    })
    .unwrap();

    assert!(serialized.contains("turn_end"));
}
4602
/// The JSON for `AgentEvent::End` must contain `agent_end`.
#[test]
fn test_agent_event_serialize_end() {
    let usage = TokenUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let serialized = serde_json::to_string(&AgentEvent::End {
        text: String::from("Done"),
        usage,
    })
    .unwrap();

    assert!(serialized.contains("agent_end"));
}
4618
/// An `AgentResult` built from a struct literal exposes the values it was
/// constructed with.
#[test]
fn test_agent_result_fields() {
    let built = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };

    assert_eq!(built.tool_calls_count, 3);
    assert_eq!(built.messages.len(), 1);
    assert_eq!(built.text, "output");
}
4635
/// The JSON for `AgentEvent::ContextResolving` must contain
/// `context_resolving` and the first provider name.
#[test]
fn test_agent_event_serialize_context_resolving() {
    let providers: Vec<String> = ["provider1", "provider2"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let serialized = serde_json::to_string(&AgentEvent::ContextResolving { providers }).unwrap();

    for needle in ["context_resolving", "provider1"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4649
/// The JSON for `AgentEvent::ContextResolved` must contain
/// `context_resolved` and the token total.
#[test]
fn test_agent_event_serialize_context_resolved() {
    let serialized = serde_json::to_string(&AgentEvent::ContextResolved {
        total_items: 5,
        total_tokens: 1000,
    })
    .unwrap();

    for needle in ["context_resolved", "1000"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4660
/// The JSON for `AgentEvent::CommandDeadLettered` must contain
/// `command_dead_lettered` and the command id.
#[test]
fn test_agent_event_serialize_command_dead_lettered() {
    let serialized = serde_json::to_string(&AgentEvent::CommandDeadLettered {
        command_id: String::from("cmd-1"),
        command_type: String::from("bash"),
        lane: String::from("execute"),
        error: String::from("timeout"),
        attempts: 3,
    })
    .unwrap();

    for needle in ["command_dead_lettered", "cmd-1"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4674
/// The JSON for `AgentEvent::CommandRetry` must contain `command_retry` and
/// the command id.
#[test]
fn test_agent_event_serialize_command_retry() {
    let serialized = serde_json::to_string(&AgentEvent::CommandRetry {
        command_id: String::from("cmd-2"),
        command_type: String::from("read"),
        lane: String::from("query"),
        attempt: 2,
        delay_ms: 1000,
    })
    .unwrap();

    for needle in ["command_retry", "cmd-2"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4688
/// The JSON for `AgentEvent::QueueAlert` must contain `queue_alert` and the
/// alert level.
#[test]
fn test_agent_event_serialize_queue_alert() {
    let serialized = serde_json::to_string(&AgentEvent::QueueAlert {
        level: String::from("warning"),
        alert_type: String::from("depth"),
        message: String::from("Queue depth exceeded"),
    })
    .unwrap();

    for needle in ["queue_alert", "warning"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4700
/// The JSON for `AgentEvent::TaskUpdated` must contain `task_updated` and the
/// session id.
#[test]
fn test_agent_event_serialize_task_updated() {
    let serialized = serde_json::to_string(&AgentEvent::TaskUpdated {
        session_id: String::from("sess-1"),
        tasks: Vec::new(),
    })
    .unwrap();

    for needle in ["task_updated", "sess-1"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4711
/// The JSON for `AgentEvent::MemoryStored` must contain `memory_stored` and
/// the memory id.
#[test]
fn test_agent_event_serialize_memory_stored() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryStored {
        memory_id: String::from("mem-1"),
        memory_type: String::from("conversation"),
        importance: 0.8,
        tags: vec![String::from("important")],
    })
    .unwrap();

    for needle in ["memory_stored", "mem-1"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4724
/// The JSON for `AgentEvent::MemoryRecalled` must contain `memory_recalled`
/// and the memory id.
#[test]
fn test_agent_event_serialize_memory_recalled() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryRecalled {
        memory_id: String::from("mem-2"),
        content: String::from("Previous conversation"),
        relevance: 0.9,
    })
    .unwrap();

    for needle in ["memory_recalled", "mem-2"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4736
/// The JSON for `AgentEvent::MemoriesSearched` must contain
/// `memories_searched` and the query text.
#[test]
fn test_agent_event_serialize_memories_searched() {
    let serialized = serde_json::to_string(&AgentEvent::MemoriesSearched {
        query: Some(String::from("search term")),
        tags: vec![String::from("tag1")],
        result_count: 5,
    })
    .unwrap();

    for needle in ["memories_searched", "search term"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4748
/// The JSON for `AgentEvent::MemoryCleared` must contain `memory_cleared` and
/// the tier name.
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryCleared {
        tier: String::from("short_term"),
        count: 10,
    })
    .unwrap();

    for needle in ["memory_cleared", "short_term"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4759
/// The JSON for `AgentEvent::SubagentStart` must contain `subagent_start` and
/// the agent name.
#[test]
fn test_agent_event_serialize_subagent_start() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentStart {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        parent_session_id: String::from("parent-sess"),
        agent: String::from("explore"),
        description: String::from("Explore codebase"),
    })
    .unwrap();

    for needle in ["subagent_start", "explore"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4773
/// The JSON for `AgentEvent::SubagentProgress` must contain
/// `subagent_progress` and the status text.
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentProgress {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        status: String::from("processing"),
        metadata: serde_json::json!({"progress": 50}),
    })
    .unwrap();

    for needle in ["subagent_progress", "processing"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4786
/// The JSON for `AgentEvent::SubagentEnd` must contain `subagent_end` and the
/// output text.
#[test]
fn test_agent_event_serialize_subagent_end() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentEnd {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        agent: String::from("explore"),
        output: String::from("Found 10 files"),
        success: true,
    })
    .unwrap();

    for needle in ["subagent_end", "Found 10 files"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4800
/// The JSON for `AgentEvent::PlanningStart` must contain `planning_start` and
/// the prompt.
#[test]
fn test_agent_event_serialize_planning_start() {
    let serialized = serde_json::to_string(&AgentEvent::PlanningStart {
        prompt: String::from("Build a web app"),
    })
    .unwrap();

    for needle in ["planning_start", "Build a web app"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4810
/// The JSON for `AgentEvent::PlanningEnd` must contain `planning_end` and the
/// `estimated_steps` field name.
#[test]
fn test_agent_event_serialize_planning_end() {
    let plan = crate::planning::ExecutionPlan::new(
        "Test goal".to_string(),
        crate::planning::Complexity::Simple,
    );
    let serialized = serde_json::to_string(&AgentEvent::PlanningEnd {
        plan,
        estimated_steps: 3,
    })
    .unwrap();

    for needle in ["planning_end", "estimated_steps"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4823
/// The JSON for `AgentEvent::StepStart` must contain `step_start` and the
/// step description.
#[test]
fn test_agent_event_serialize_step_start() {
    let serialized = serde_json::to_string(&AgentEvent::StepStart {
        step_id: String::from("step-1"),
        description: String::from("Initialize project"),
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    for needle in ["step_start", "Initialize project"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4836
/// The JSON for `AgentEvent::StepEnd` must contain `step_end` and the step id.
#[test]
fn test_agent_event_serialize_step_end() {
    let serialized = serde_json::to_string(&AgentEvent::StepEnd {
        step_id: String::from("step-1"),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    for needle in ["step_end", "step-1"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4849
/// The JSON for `AgentEvent::GoalExtracted` must contain `goal_extracted`.
#[test]
fn test_agent_event_serialize_goal_extracted() {
    let serialized = serde_json::to_string(&AgentEvent::GoalExtracted {
        goal: crate::planning::AgentGoal::new("Complete the task".to_string()),
    })
    .unwrap();

    assert!(serialized.contains("goal_extracted"));
}
4858
/// The JSON for `AgentEvent::GoalProgress` must contain `goal_progress` and
/// the progress value `0.5`.
#[test]
fn test_agent_event_serialize_goal_progress() {
    let serialized = serde_json::to_string(&AgentEvent::GoalProgress {
        goal: String::from("Build app"),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    })
    .unwrap();

    for needle in ["goal_progress", "0.5"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4871
/// The JSON for `AgentEvent::GoalAchieved` must contain `goal_achieved` and
/// the duration value.
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let serialized = serde_json::to_string(&AgentEvent::GoalAchieved {
        goal: String::from("Build app"),
        total_steps: 4,
        duration_ms: 5000,
    })
    .unwrap();

    for needle in ["goal_achieved", "5000"].iter() {
        assert!(serialized.contains(*needle));
    }
}
4883
4884 #[tokio::test]
4885 async fn test_extract_goal_with_json_response() {
4886 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4888 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4889 )]));
4890 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4891 let agent = AgentLoop::new(
4892 mock_client,
4893 tool_executor,
4894 test_tool_context(),
4895 AgentConfig::default(),
4896 );
4897
4898 let goal = agent.extract_goal("Build a web app").await.unwrap();
4899 assert_eq!(goal.description, "Build web app");
4900 assert_eq!(goal.success_criteria.len(), 2);
4901 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4902 }
4903
4904 #[tokio::test]
4905 async fn test_extract_goal_fallback_on_non_json() {
4906 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4908 "Some non-JSON response",
4909 )]));
4910 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4911 let agent = AgentLoop::new(
4912 mock_client,
4913 tool_executor,
4914 test_tool_context(),
4915 AgentConfig::default(),
4916 );
4917
4918 let goal = agent.extract_goal("Do something").await.unwrap();
4919 assert_eq!(goal.description, "Do something");
4921 assert_eq!(goal.success_criteria.len(), 2);
4923 }
4924
4925 #[tokio::test]
4926 async fn test_check_goal_achievement_json_yes() {
4927 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4928 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4929 )]));
4930 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4931 let agent = AgentLoop::new(
4932 mock_client,
4933 tool_executor,
4934 test_tool_context(),
4935 AgentConfig::default(),
4936 );
4937
4938 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4939 let achieved = agent
4940 .check_goal_achievement(&goal, "All done")
4941 .await
4942 .unwrap();
4943 assert!(achieved);
4944 }
4945
4946 #[tokio::test]
4947 async fn test_check_goal_achievement_fallback_not_done() {
4948 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4950 "invalid json",
4951 )]));
4952 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4953 let agent = AgentLoop::new(
4954 mock_client,
4955 tool_executor,
4956 test_tool_context(),
4957 AgentConfig::default(),
4958 );
4959
4960 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4961 let achieved = agent
4963 .check_goal_achievement(&goal, "still working")
4964 .await
4965 .unwrap();
4966 assert!(!achieved);
4967 }
4968
/// With an extra prompt slot and no context results, the augmented system
/// prompt must still contain the extra prompt text.
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    let prompt_slots = SystemPromptSlots {
        extra: Some("Base prompt".to_string()),
        ..Default::default()
    };
    let agent = AgentLoop::new(
        Arc::new(MockLlmClient::new(vec![])),
        Arc::new(ToolExecutor::new("/tmp".to_string())),
        test_tool_context(),
        AgentConfig {
            prompt_slots,
            ..Default::default()
        },
    );

    let prompt_text = agent.build_augmented_system_prompt(&[]).unwrap();
    assert!(prompt_text.contains("Base prompt"));
}
4989
/// With the default configuration and no context results, a system prompt is
/// still produced and contains the text "Core Behaviour".
#[test]
fn test_build_augmented_system_prompt_no_custom_slots() {
    let agent = AgentLoop::new(
        Arc::new(MockLlmClient::new(vec![])),
        Arc::new(ToolExecutor::new("/tmp".to_string())),
        test_tool_context(),
        AgentConfig::default(),
    );

    let rendered = agent.build_augmented_system_prompt(&[]);

    // Fails both when no prompt is produced and when the text is missing.
    assert!(rendered.map_or(false, |text| text.contains("Core Behaviour")));
}
5006
/// With the default configuration and one context result, the augmented
/// system prompt must contain a `<context` tag and the item content.
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    let agent = AgentLoop::new(
        Arc::new(MockLlmClient::new(vec![])),
        Arc::new(ToolExecutor::new("/tmp".to_string())),
        test_tool_context(),
        AgentConfig::default(),
    );

    let single_result = ContextResult {
        provider: "test".to_string(),
        items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
        total_tokens: 10,
        truncated: false,
    };
    let context = vec![single_result];

    let rendered = agent.build_augmented_system_prompt(&context);
    assert!(rendered.is_some());

    let prompt_text = rendered.unwrap();
    for needle in ["<context", "Content"].iter() {
        assert!(prompt_text.contains(*needle));
    }
}
5033
/// Cloning an `AgentResult` must carry over every field.
///
/// `text` and `tool_calls_count` are compared directly. `messages` is
/// compared by length and `usage` by its `total_tokens` counter, so the test
/// does not require `Message` or `TokenUsage` to implement `PartialEq`.
#[test]
fn test_agent_result_clone() {
    let result = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };
    let cloned = result.clone();

    assert_eq!(cloned.text, result.text);
    assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
    // Also cover the remaining fields so a partial `Clone` impl is caught.
    assert_eq!(cloned.messages.len(), result.messages.len());
    assert_eq!(cloned.usage.total_tokens, result.usage.total_tokens);
}
5050
/// The `Debug` rendering of an `AgentResult` names the type and includes the
/// result text.
#[test]
fn test_agent_result_debug() {
    let rendered = format!(
        "{:?}",
        AgentResult {
            text: "output".to_string(),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        }
    );

    assert!(rendered.contains("AgentResult"));
    assert!(rendered.contains("output"));
}
5063
5064 #[tokio::test]
5073 async fn test_tool_command_command_type() {
5074 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5075 let cmd = ToolCommand {
5076 tool_executor: executor,
5077 tool_name: "read".to_string(),
5078 tool_args: serde_json::json!({"file": "test.rs"}),
5079 skill_registry: None,
5080 tool_context: test_tool_context(),
5081 };
5082 assert_eq!(cmd.command_type(), "read");
5083 }
5084
5085 #[tokio::test]
5086 async fn test_tool_command_payload() {
5087 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5088 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5089 let cmd = ToolCommand {
5090 tool_executor: executor,
5091 tool_name: "read".to_string(),
5092 tool_args: args.clone(),
5093 skill_registry: None,
5094 tool_context: test_tool_context(),
5095 };
5096 assert_eq!(cmd.payload(), args);
5097 }
5098
5099 #[tokio::test(flavor = "multi_thread")]
5104 async fn test_agent_loop_with_queue() {
5105 use tokio::sync::broadcast;
5106
5107 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5108 "Hello",
5109 )]));
5110 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5111 let config = AgentConfig::default();
5112
5113 let (event_tx, _) = broadcast::channel(100);
5114 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5115 .await
5116 .unwrap();
5117
5118 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5119 .with_queue(Arc::new(queue));
5120
5121 assert!(agent.command_queue.is_some());
5122 }
5123
5124 #[tokio::test]
5125 async fn test_agent_loop_without_queue() {
5126 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5127 "Hello",
5128 )]));
5129 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5130 let config = AgentConfig::default();
5131
5132 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5133
5134 assert!(agent.command_queue.is_none());
5135 }
5136
5137 #[tokio::test]
5142 async fn test_execute_plan_parallel_independent() {
5143 use crate::planning::{Complexity, ExecutionPlan, Task};
5144
5145 let mock_client = Arc::new(MockLlmClient::new(vec![
5148 MockLlmClient::text_response("Step 1 done"),
5149 MockLlmClient::text_response("Step 2 done"),
5150 MockLlmClient::text_response("Step 3 done"),
5151 ]));
5152
5153 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5154 let config = AgentConfig::default();
5155 let agent = AgentLoop::new(
5156 mock_client.clone(),
5157 tool_executor,
5158 test_tool_context(),
5159 config,
5160 );
5161
5162 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5163 plan.add_step(Task::new("s1", "First step"));
5164 plan.add_step(Task::new("s2", "Second step"));
5165 plan.add_step(Task::new("s3", "Third step"));
5166
5167 let (tx, mut rx) = mpsc::channel(100);
5168 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5169
5170 assert_eq!(result.usage.total_tokens, 45);
5172
5173 let mut step_starts = Vec::new();
5175 let mut step_ends = Vec::new();
5176 rx.close();
5177 while let Some(event) = rx.recv().await {
5178 match event {
5179 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5180 AgentEvent::StepEnd {
5181 step_id, status, ..
5182 } => {
5183 assert_eq!(status, TaskStatus::Completed);
5184 step_ends.push(step_id);
5185 }
5186 _ => {}
5187 }
5188 }
5189 assert_eq!(step_starts.len(), 3);
5190 assert_eq!(step_ends.len(), 3);
5191 }
5192
5193 #[tokio::test]
5194 async fn test_execute_plan_respects_dependencies() {
5195 use crate::planning::{Complexity, ExecutionPlan, Task};
5196
5197 let mock_client = Arc::new(MockLlmClient::new(vec![
5200 MockLlmClient::text_response("Step 1 done"),
5201 MockLlmClient::text_response("Step 2 done"),
5202 MockLlmClient::text_response("Step 3 done"),
5203 ]));
5204
5205 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5206 let config = AgentConfig::default();
5207 let agent = AgentLoop::new(
5208 mock_client.clone(),
5209 tool_executor,
5210 test_tool_context(),
5211 config,
5212 );
5213
5214 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5215 plan.add_step(Task::new("s1", "Independent A"));
5216 plan.add_step(Task::new("s2", "Independent B"));
5217 plan.add_step(
5218 Task::new("s3", "Depends on A+B")
5219 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5220 );
5221
5222 let (tx, mut rx) = mpsc::channel(100);
5223 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5224
5225 assert_eq!(result.usage.total_tokens, 45);
5227
5228 let mut events = Vec::new();
5230 rx.close();
5231 while let Some(event) = rx.recv().await {
5232 match &event {
5233 AgentEvent::StepStart { step_id, .. } => {
5234 events.push(format!("start:{}", step_id));
5235 }
5236 AgentEvent::StepEnd { step_id, .. } => {
5237 events.push(format!("end:{}", step_id));
5238 }
5239 _ => {}
5240 }
5241 }
5242
5243 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5245 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5246 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5247 assert!(
5248 s3_start > s1_end,
5249 "s3 started before s1 ended: {:?}",
5250 events
5251 );
5252 assert!(
5253 s3_start > s2_end,
5254 "s3 started before s2 ended: {:?}",
5255 events
5256 );
5257
5258 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5260 }
5261
5262 #[tokio::test]
5263 async fn test_execute_plan_handles_step_failure() {
5264 use crate::planning::{Complexity, ExecutionPlan, Task};
5265
5266 let mock_client = Arc::new(MockLlmClient::new(vec![
5276 MockLlmClient::text_response("s1 done"),
5278 MockLlmClient::text_response("s3 done"),
5279 ]));
5282
5283 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5284 let config = AgentConfig::default();
5285 let agent = AgentLoop::new(
5286 mock_client.clone(),
5287 tool_executor,
5288 test_tool_context(),
5289 config,
5290 );
5291
5292 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5293 plan.add_step(Task::new("s1", "Independent step"));
5294 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5295 plan.add_step(Task::new("s3", "Another independent"));
5296 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5297
5298 let (tx, mut rx) = mpsc::channel(100);
5299 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5300
5301 let mut completed_steps = Vec::new();
5304 let mut failed_steps = Vec::new();
5305 rx.close();
5306 while let Some(event) = rx.recv().await {
5307 if let AgentEvent::StepEnd {
5308 step_id, status, ..
5309 } = event
5310 {
5311 match status {
5312 TaskStatus::Completed => completed_steps.push(step_id),
5313 TaskStatus::Failed => failed_steps.push(step_id),
5314 _ => {}
5315 }
5316 }
5317 }
5318
5319 assert!(
5320 completed_steps.contains(&"s1".to_string()),
5321 "s1 should complete"
5322 );
5323 assert!(
5324 completed_steps.contains(&"s3".to_string()),
5325 "s3 should complete"
5326 );
5327 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5328 assert!(
5330 !completed_steps.contains(&"s4".to_string()),
5331 "s4 should not complete"
5332 );
5333 assert!(
5334 !failed_steps.contains(&"s4".to_string()),
5335 "s4 should not fail (never started)"
5336 );
5337 }
5338
/// Pins the default resilience settings of `AgentConfig`: two parse retries,
/// no tool timeout, and a circuit-breaker threshold of three.
#[test]
fn test_agent_config_resilience_defaults() {
    let AgentConfig {
        max_parse_retries,
        tool_timeout_ms,
        circuit_breaker_threshold,
        ..
    } = AgentConfig::default();

    assert_eq!(max_parse_retries, 2);
    assert_eq!(tool_timeout_ms, None);
    assert_eq!(circuit_breaker_threshold, 3);
}
5350
5351 #[tokio::test]
5353 async fn test_parse_error_recovery_bails_after_threshold() {
5354 let mock_client = Arc::new(MockLlmClient::new(vec![
5356 MockLlmClient::tool_call_response(
5357 "c1",
5358 "bash",
5359 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5360 ),
5361 MockLlmClient::tool_call_response(
5362 "c2",
5363 "bash",
5364 serde_json::json!({"__parse_error": "missing closing brace"}),
5365 ),
5366 MockLlmClient::tool_call_response(
5367 "c3",
5368 "bash",
5369 serde_json::json!({"__parse_error": "still broken"}),
5370 ),
5371 MockLlmClient::text_response("Done"), ]));
5373
5374 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5375 let config = AgentConfig {
5376 max_parse_retries: 2,
5377 ..AgentConfig::default()
5378 };
5379 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5380 let result = agent.execute(&[], "Do something", None).await;
5381 assert!(result.is_err(), "should bail after parse error threshold");
5382 let err = result.unwrap_err().to_string();
5383 assert!(
5384 err.contains("malformed tool arguments"),
5385 "error should mention malformed tool arguments, got: {}",
5386 err
5387 );
5388 }
5389
5390 #[tokio::test]
5392 async fn test_parse_error_counter_resets_on_success() {
5393 let mock_client = Arc::new(MockLlmClient::new(vec![
5397 MockLlmClient::tool_call_response(
5398 "c1",
5399 "bash",
5400 serde_json::json!({"__parse_error": "bad args"}),
5401 ),
5402 MockLlmClient::tool_call_response(
5403 "c2",
5404 "bash",
5405 serde_json::json!({"__parse_error": "bad args again"}),
5406 ),
5407 MockLlmClient::tool_call_response(
5409 "c3",
5410 "bash",
5411 serde_json::json!({"command": "echo ok"}),
5412 ),
5413 MockLlmClient::text_response("All done"),
5414 ]));
5415
5416 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5417 let config = AgentConfig {
5418 max_parse_retries: 2,
5419 ..AgentConfig::default()
5420 };
5421 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5422 let result = agent.execute(&[], "Do something", None).await;
5423 assert!(
5424 result.is_ok(),
5425 "should not bail — counter reset after successful tool, got: {:?}",
5426 result.err()
5427 );
5428 assert_eq!(result.unwrap().text, "All done");
5429 }
5430
5431 #[tokio::test]
5433 async fn test_tool_timeout_produces_error_result() {
5434 let mock_client = Arc::new(MockLlmClient::new(vec![
5435 MockLlmClient::tool_call_response(
5436 "t1",
5437 "bash",
5438 serde_json::json!({"command": "sleep 10"}),
5439 ),
5440 MockLlmClient::text_response("The command timed out."),
5441 ]));
5442
5443 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5444 let config = AgentConfig {
5445 tool_timeout_ms: Some(50),
5447 ..AgentConfig::default()
5448 };
5449 let agent = AgentLoop::new(
5450 mock_client.clone(),
5451 tool_executor,
5452 test_tool_context(),
5453 config,
5454 );
5455 let result = agent.execute(&[], "Run sleep", None).await;
5456 assert!(
5457 result.is_ok(),
5458 "session should continue after tool timeout: {:?}",
5459 result.err()
5460 );
5461 assert_eq!(result.unwrap().text, "The command timed out.");
5462 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5464 }
5465
5466 #[tokio::test]
5468 async fn test_tool_within_timeout_succeeds() {
5469 let mock_client = Arc::new(MockLlmClient::new(vec![
5470 MockLlmClient::tool_call_response(
5471 "t1",
5472 "bash",
5473 serde_json::json!({"command": "echo fast"}),
5474 ),
5475 MockLlmClient::text_response("Command succeeded."),
5476 ]));
5477
5478 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5479 let config = AgentConfig {
5480 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5482 };
5483 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5484 let result = agent.execute(&[], "Run something fast", None).await;
5485 assert!(
5486 result.is_ok(),
5487 "fast tool should succeed: {:?}",
5488 result.err()
5489 );
5490 assert_eq!(result.unwrap().text, "Command succeeded.");
5491 }
5492
5493 #[tokio::test]
5495 async fn test_circuit_breaker_retries_non_streaming() {
5496 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5499
5500 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5501 let config = AgentConfig {
5502 circuit_breaker_threshold: 2,
5503 ..AgentConfig::default()
5504 };
5505 let agent = AgentLoop::new(
5506 mock_client.clone(),
5507 tool_executor,
5508 test_tool_context(),
5509 config,
5510 );
5511 let result = agent.execute(&[], "Hello", None).await;
5512 assert!(result.is_err(), "should fail when LLM always errors");
5513 let err = result.unwrap_err().to_string();
5514 assert!(
5515 err.contains("circuit breaker"),
5516 "error should mention circuit breaker, got: {}",
5517 err
5518 );
5519 assert_eq!(
5520 mock_client.call_count.load(Ordering::SeqCst),
5521 2,
5522 "should make exactly threshold=2 LLM calls"
5523 );
5524 }
5525
5526 #[tokio::test]
5528 async fn test_circuit_breaker_threshold_one_no_retry() {
5529 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5530
5531 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5532 let config = AgentConfig {
5533 circuit_breaker_threshold: 1,
5534 ..AgentConfig::default()
5535 };
5536 let agent = AgentLoop::new(
5537 mock_client.clone(),
5538 tool_executor,
5539 test_tool_context(),
5540 config,
5541 );
5542 let result = agent.execute(&[], "Hello", None).await;
5543 assert!(result.is_err());
5544 assert_eq!(
5545 mock_client.call_count.load(Ordering::SeqCst),
5546 1,
5547 "with threshold=1 exactly one attempt should be made"
5548 );
5549 }
5550
5551 #[tokio::test]
5553 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
5554 struct FailOnceThenSucceed {
5556 inner: MockLlmClient,
5557 failed_once: std::sync::atomic::AtomicBool,
5558 call_count: AtomicUsize,
5559 }
5560
5561 #[async_trait::async_trait]
5562 impl LlmClient for FailOnceThenSucceed {
5563 async fn complete(
5564 &self,
5565 messages: &[Message],
5566 system: Option<&str>,
5567 tools: &[ToolDefinition],
5568 ) -> Result<LlmResponse> {
5569 self.call_count.fetch_add(1, Ordering::SeqCst);
5570 let already_failed = self
5571 .failed_once
5572 .swap(true, std::sync::atomic::Ordering::SeqCst);
5573 if !already_failed {
5574 anyhow::bail!("transient network error");
5575 }
5576 self.inner.complete(messages, system, tools).await
5577 }
5578
5579 async fn complete_streaming(
5580 &self,
5581 messages: &[Message],
5582 system: Option<&str>,
5583 tools: &[ToolDefinition],
5584 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
5585 self.inner.complete_streaming(messages, system, tools).await
5586 }
5587 }
5588
5589 let mock = Arc::new(FailOnceThenSucceed {
5590 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
5591 failed_once: std::sync::atomic::AtomicBool::new(false),
5592 call_count: AtomicUsize::new(0),
5593 });
5594
5595 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5596 let config = AgentConfig {
5597 circuit_breaker_threshold: 3,
5598 ..AgentConfig::default()
5599 };
5600 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
5601 let result = agent.execute(&[], "Hello", None).await;
5602 assert!(
5603 result.is_ok(),
5604 "should succeed when LLM recovers within threshold: {:?}",
5605 result.err()
5606 );
5607 assert_eq!(result.unwrap().text, "Recovered!");
5608 assert_eq!(
5609 mock.call_count.load(Ordering::SeqCst),
5610 2,
5611 "should have made exactly 2 calls (1 fail + 1 success)"
5612 );
5613 }
5614
/// Empty and whitespace-only responses are treated as incomplete.
#[test]
fn test_looks_incomplete_empty() {
    for blank in ["", " "] {
        assert!(
            AgentLoop::looks_incomplete(blank),
            "expected {:?} to look incomplete",
            blank
        );
    }
}
5622
/// Responses that end with a colon are treated as incomplete.
#[test]
fn test_looks_incomplete_trailing_colon() {
    for dangling in ["Let me check the file:", "Next steps:"] {
        assert!(
            AgentLoop::looks_incomplete(dangling),
            "expected {:?} to look incomplete",
            dangling
        );
    }
}
5628
/// Responses that end with an ellipsis (three ASCII dots or the single `…`
/// character) are treated as incomplete.
#[test]
fn test_looks_incomplete_ellipsis() {
    for trailing_off in ["Working on it...", "Processing…"] {
        assert!(
            AgentLoop::looks_incomplete(trailing_off),
            "expected {:?} to look incomplete",
            trailing_off
        );
    }
}
5634
/// Responses that only announce an intention ("I'll …", "Let me …",
/// "I will …", "I need to …") are treated as incomplete.
#[test]
fn test_looks_incomplete_intent_phrases() {
    let announcements = [
        "I'll start by reading the file.",
        "Let me check the configuration.",
        "I will now run the tests.",
        "I need to update the Cargo.toml.",
    ];
    for announcement in announcements {
        assert!(
            AgentLoop::looks_incomplete(announcement),
            "expected {:?} to look incomplete",
            announcement
        );
    }
}
5648
/// Final answers, including very short ones, are not flagged as incomplete.
#[test]
fn test_looks_complete_final_answer() {
    let final_answers = [
        "The tests pass. All changes have been applied successfully.",
        "Done. I've updated the three files and verified the build succeeds.",
        "42",
        "Yes.",
    ];
    for answer in final_answers {
        assert!(
            !AgentLoop::looks_incomplete(answer),
            "expected {:?} to look complete",
            answer
        );
    }
}
5661
/// A multi-line summary whose first line ends with a colon but whose last
/// line is a finished bullet is not flagged as incomplete.
#[test]
fn test_looks_incomplete_multiline_complete() {
    // Split per line for readability; `concat!` produces the same literal.
    let summary = concat!(
        "Here is the summary:\n",
        "\n",
        "- Fixed the bug in agent.rs\n",
        "- All tests pass\n",
        "- Build succeeds",
    );
    let flagged = AgentLoop::looks_incomplete(summary);
    assert!(!flagged);
}
5667}