1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
/// Default cap on agent turns; used as `max_tool_rounds` in `AgentConfig::default()`.
const MAX_TOOL_ROUNDS: usize = 50;
38
/// Configuration for an [`AgentLoop`].
///
/// Field notes describe only what the visible part of this file does with each
/// value; fields marked "NOTE(review)" are consumed outside the visible code.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots from which the system prompt is built (via `build()`).
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions offered to the LLM (filtered by `tool_index` when set).
    pub tools: Vec<ToolDefinition>,
    /// The agent loop fails with "Max tool rounds exceeded" once the turn
    /// counter goes past this value.
    pub max_tool_rounds: usize,
    /// When set, `taint_input` is called with the effective prompt before the loop starts.
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// NOTE(review): not used in the visible code — presumably gates tool calls; confirm.
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// NOTE(review): not used in the visible code — presumably human-in-the-loop confirmation; confirm.
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Providers queried concurrently for extra context that is appended to the
    /// system prompt; also notified when a turn completes.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When `true`, `execute_with_session` routes through `execute_with_planning`.
    pub planning_enabled: bool,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub goal_tracking: bool,
    /// Optional hook executor; all `fire_*` helpers are no-ops when this is `None`.
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Optional skill registry; handed to `ToolCommand`, which enforces the
    /// instruction skills' allowed-tools restrictions.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// NOTE(review): not used in the visible code — presumably bounds retries on
    /// malformed tool-call arguments; confirm.
    pub max_parse_retries: u32,
    /// Per-tool execution timeout in milliseconds for direct execution; `None` disables it.
    pub tool_timeout_ms: Option<u64>,
    /// Maximum LLM call attempts per turn (clamped to at least 1 by the loop).
    pub circuit_breaker_threshold: u32,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub auto_compact: bool,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub auto_compact_threshold: f32,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub max_context_tokens: usize,
    /// NOTE(review): not used in the visible code (the loop uses its own
    /// `llm_client` field); confirm what this client is for.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional memory store; up to 5 similar items are recalled and appended to
    /// the system prompt before the loop starts.
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub continuation_enabled: bool,
    /// NOTE(review): not used in the visible code; confirm semantics at use site.
    pub max_continuation_turns: u32,
    /// Optional tool index; when set, `call_llm` sends only the tools whose
    /// names match a search on the latest user message text.
    pub tool_index: Option<ToolIndex>,
}
115
116impl std::fmt::Debug for AgentConfig {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("AgentConfig")
119 .field("prompt_slots", &self.prompt_slots)
120 .field("tools", &self.tools)
121 .field("max_tool_rounds", &self.max_tool_rounds)
122 .field("security_provider", &self.security_provider.is_some())
123 .field("permission_checker", &self.permission_checker.is_some())
124 .field("confirmation_manager", &self.confirmation_manager.is_some())
125 .field("context_providers", &self.context_providers.len())
126 .field("planning_enabled", &self.planning_enabled)
127 .field("goal_tracking", &self.goal_tracking)
128 .field("hook_engine", &self.hook_engine.is_some())
129 .field(
130 "skill_registry",
131 &self.skill_registry.as_ref().map(|r| r.len()),
132 )
133 .field("max_parse_retries", &self.max_parse_retries)
134 .field("tool_timeout_ms", &self.tool_timeout_ms)
135 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
136 .field("auto_compact", &self.auto_compact)
137 .field("auto_compact_threshold", &self.auto_compact_threshold)
138 .field("max_context_tokens", &self.max_context_tokens)
139 .field("continuation_enabled", &self.continuation_enabled)
140 .field("max_continuation_turns", &self.max_continuation_turns)
141 .field("memory", &self.memory.is_some())
142 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
143 .finish()
144 }
145}
146
147impl Default for AgentConfig {
148 fn default() -> Self {
149 Self {
150 prompt_slots: SystemPromptSlots::default(),
151 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
153 security_provider: None,
154 permission_checker: None,
155 confirmation_manager: None,
156 context_providers: Vec::new(),
157 planning_enabled: false,
158 goal_tracking: false,
159 hook_engine: None,
160 skill_registry: None,
161 max_parse_retries: 2,
162 tool_timeout_ms: None,
163 circuit_breaker_threshold: 3,
164 auto_compact: false,
165 auto_compact_threshold: 0.80,
166 max_context_tokens: 200_000,
167 llm_client: None,
168 memory: None,
169 continuation_enabled: true,
170 max_continuation_turns: 3,
171 tool_index: None,
172 }
173 }
174}
175
/// Events streamed to observers of an agent run.
///
/// Serialized as an internally tagged enum: each value carries a `"type"`
/// field holding the variant's `rename` string (e.g. `"agent_start"`).
///
/// Variants documented with "NOTE(review)" are not emitted anywhere in the
/// visible part of this file; their descriptions are inferred from names and
/// should be confirmed at the emit sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// Sent once at the beginning of the agent loop with the effective prompt.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// Sent at the start of each loop turn; `turn` starts at 1.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// Forwarded text chunk from the streaming LLM response.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// Forwarded when the streaming LLM response begins a tool call.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// Forwarded chunk of tool-call input from the streaming LLM response.
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    /// NOTE(review): presumably sent when a tool finishes; confirm at emit site.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
    },

    /// Incremental tool output, relayed from a tool's own event channel and
    /// labelled with the tool call's id and name.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// NOTE(review): presumably sent at the end of a turn; confirm at emit site.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// NOTE(review): presumably sent when the run finishes; confirm at emit site.
    #[serde(rename = "agent_end")]
    End { text: String, usage: TokenUsage },

    /// Sent before the loop aborts, e.g. when the round limit is exceeded or
    /// the LLM call fails after its retries.
    #[serde(rename = "error")]
    Error { message: String },

    // --- Human-in-the-loop confirmation ---
    // NOTE(review): none of these are emitted in the visible code.
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    // --- External task handling ---
    // NOTE(review): not emitted in the visible code.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// NOTE(review): presumably sent when a permission check rejects a tool
    /// call; confirm at emit site.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// Sent before context providers are queried; lists the provider names.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Sent after context providers returned; totals are summed over all
    /// non-empty results.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    // --- Command queue ---
    // NOTE(review): not constructed in the visible code; the agent loop relays
    // whatever events the command queue's subscription yields, so these are
    // presumably produced by the queue.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// NOTE(review): not emitted in the visible code.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// NOTE(review): not emitted in the visible code.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// Sent once per memory item recalled for the current prompt.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// Sent after a memory recall that returned at least one item.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// NOTE(review): not emitted in the visible code.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    // --- Sub-agents ---
    // NOTE(review): not emitted in the visible code.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    // --- Planning and goal tracking ---
    // NOTE(review): not emitted in the visible code.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// NOTE(review): not emitted in the visible code.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// NOTE(review): not emitted in the visible code.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },
}
495
/// Outcome of a completed agent run, as returned by the `execute*` methods.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final response text (also passed to the post-response hook).
    pub text: String,
    /// NOTE(review): presumably the full conversation after the run; the code
    /// that fills this is outside the visible part of the file.
    pub messages: Vec<Message>,
    /// Token usage for the run.
    pub usage: TokenUsage,
    /// Number of tool calls made during the run.
    pub tool_calls_count: usize,
}
504
/// A single tool invocation packaged as a `SessionCommand` so it can be
/// submitted to the session lane queue.
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Name of the tool to run; also reported as the command type.
    tool_name: String,
    /// JSON arguments for the tool; also reported as the command payload.
    tool_args: Value,
    /// Execution context handed to the tool executor.
    tool_context: ToolContext,
    /// When present, instruction skills in this registry may restrict which
    /// tools are allowed to run.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
519
520impl ToolCommand {
521 pub fn new(
523 tool_executor: Arc<ToolExecutor>,
524 tool_name: String,
525 tool_args: Value,
526 tool_context: ToolContext,
527 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
528 ) -> Self {
529 Self {
530 tool_executor,
531 tool_name,
532 tool_args,
533 tool_context,
534 skill_registry,
535 }
536 }
537}
538
539#[async_trait]
540impl SessionCommand for ToolCommand {
541 async fn execute(&self) -> Result<Value> {
542 if let Some(registry) = &self.skill_registry {
544 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
545
546 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
548
549 if has_restrictions {
550 let mut allowed = false;
551
552 for skill in &instruction_skills {
553 if skill.is_tool_allowed(&self.tool_name) {
554 allowed = true;
555 break;
556 }
557 }
558
559 if !allowed {
560 return Err(anyhow::anyhow!(
561 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
562 self.tool_name
563 ));
564 }
565 }
566 }
567
568 let result = self
570 .tool_executor
571 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
572 .await?;
573 Ok(serde_json::json!({
574 "output": result.output,
575 "exit_code": result.exit_code,
576 "metadata": result.metadata,
577 }))
578 }
579
580 fn command_type(&self) -> &str {
581 &self.tool_name
582 }
583
584 fn payload(&self) -> Value {
585 self.tool_args.clone()
586 }
587}
588
/// Drives the LLM/tool loop for one agent: calls the model, runs requested
/// tools, fires hooks, and streams [`AgentEvent`]s to an optional listener.
#[derive(Clone)]
pub struct AgentLoop {
    /// Client used for all LLM completions (streaming and non-streaming).
    llm_client: Arc<dyn LlmClient>,
    /// Executor used to run tools.
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context; cloned per tool call when streaming output is wired up.
    tool_context: ToolContext,
    /// Behavioural configuration.
    config: AgentConfig,
    /// Optional shared tool metrics, set via `with_tool_metrics`.
    /// NOTE(review): not read in the visible code.
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// Optional lane queue; when set, tool calls are submitted through it first.
    command_queue: Option<Arc<SessionLaneQueue>>,
}
605
606impl AgentLoop {
607 pub fn new(
608 llm_client: Arc<dyn LlmClient>,
609 tool_executor: Arc<ToolExecutor>,
610 tool_context: ToolContext,
611 config: AgentConfig,
612 ) -> Self {
613 Self {
614 llm_client,
615 tool_executor,
616 tool_context,
617 config,
618 tool_metrics: None,
619 command_queue: None,
620 }
621 }
622
623 pub fn with_tool_metrics(
625 mut self,
626 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
627 ) -> Self {
628 self.tool_metrics = Some(metrics);
629 self
630 }
631
632 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
637 self.command_queue = Some(queue);
638 self
639 }
640
641 async fn execute_tool_timed(
647 &self,
648 name: &str,
649 args: &serde_json::Value,
650 ctx: &ToolContext,
651 ) -> anyhow::Result<crate::tools::ToolResult> {
652 let fut = self.tool_executor.execute_with_context(name, args, ctx);
653 if let Some(timeout_ms) = self.config.tool_timeout_ms {
654 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
655 Ok(result) => result,
656 Err(_) => Err(anyhow::anyhow!(
657 "Tool '{}' timed out after {}ms",
658 name,
659 timeout_ms
660 )),
661 }
662 } else {
663 fut.await
664 }
665 }
666
667 fn tool_result_to_tuple(
669 result: anyhow::Result<crate::tools::ToolResult>,
670 ) -> (
671 String,
672 i32,
673 bool,
674 Option<serde_json::Value>,
675 Vec<crate::llm::Attachment>,
676 ) {
677 match result {
678 Ok(r) => (
679 r.output,
680 r.exit_code,
681 r.exit_code != 0,
682 r.metadata,
683 r.images,
684 ),
685 Err(e) => (
686 format!("Tool execution error: {}", e),
687 1,
688 true,
689 None,
690 Vec::new(),
691 ),
692 }
693 }
694
695 async fn execute_tool_queued_or_direct(
697 &self,
698 name: &str,
699 args: &serde_json::Value,
700 ctx: &ToolContext,
701 ) -> anyhow::Result<crate::tools::ToolResult> {
702 if let Some(ref queue) = self.command_queue {
703 let command = ToolCommand::new(
704 Arc::clone(&self.tool_executor),
705 name.to_string(),
706 args.clone(),
707 ctx.clone(),
708 self.config.skill_registry.clone(),
709 );
710 let rx = queue.submit_by_tool(name, Box::new(command)).await;
711 match rx.await {
712 Ok(Ok(value)) => {
713 let output = value["output"]
714 .as_str()
715 .ok_or_else(|| {
716 anyhow::anyhow!(
717 "Queue result missing 'output' field for tool '{}'",
718 name
719 )
720 })?
721 .to_string();
722 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
723 return Ok(crate::tools::ToolResult {
724 name: name.to_string(),
725 output,
726 exit_code,
727 metadata: None,
728 images: Vec::new(),
729 });
730 }
731 Ok(Err(e)) => {
732 tracing::warn!(
733 "Queue execution failed for tool '{}', falling back to direct: {}",
734 name,
735 e
736 );
737 }
738 Err(_) => {
739 tracing::warn!(
740 "Queue channel closed for tool '{}', falling back to direct",
741 name
742 );
743 }
744 }
745 }
746 self.execute_tool_timed(name, args, ctx).await
747 }
748
749 async fn call_llm(
760 &self,
761 messages: &[Message],
762 system: Option<&str>,
763 event_tx: &Option<mpsc::Sender<AgentEvent>>,
764 ) -> anyhow::Result<LlmResponse> {
765 let tools = if let Some(ref index) = self.config.tool_index {
767 let query = messages
768 .iter()
769 .rev()
770 .find(|m| m.role == "user")
771 .and_then(|m| {
772 m.content.iter().find_map(|b| match b {
773 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
774 _ => None,
775 })
776 })
777 .unwrap_or("");
778 let matches = index.search(query, index.len());
779 let matched_names: std::collections::HashSet<&str> =
780 matches.iter().map(|m| m.name.as_str()).collect();
781 self.config
782 .tools
783 .iter()
784 .filter(|t| matched_names.contains(t.name.as_str()))
785 .cloned()
786 .collect::<Vec<_>>()
787 } else {
788 self.config.tools.clone()
789 };
790
791 if event_tx.is_some() {
792 let mut stream_rx = self
793 .llm_client
794 .complete_streaming(messages, system, &tools)
795 .await
796 .context("LLM streaming call failed")?;
797
798 let mut final_response: Option<LlmResponse> = None;
799 while let Some(event) = stream_rx.recv().await {
800 match event {
801 crate::llm::StreamEvent::TextDelta(text) => {
802 if let Some(tx) = event_tx {
803 tx.send(AgentEvent::TextDelta { text }).await.ok();
804 }
805 }
806 crate::llm::StreamEvent::ToolUseStart { id, name } => {
807 if let Some(tx) = event_tx {
808 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
809 }
810 }
811 crate::llm::StreamEvent::ToolUseInputDelta(delta) => {
812 if let Some(tx) = event_tx {
813 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
814 }
815 }
816 crate::llm::StreamEvent::Done(resp) => {
817 final_response = Some(resp);
818 }
819 }
820 }
821 final_response.context("Stream ended without final response")
822 } else {
823 self.llm_client
824 .complete(messages, system, &tools)
825 .await
826 .context("LLM call failed")
827 }
828 }
829
830 fn streaming_tool_context(
839 &self,
840 event_tx: &Option<mpsc::Sender<AgentEvent>>,
841 tool_id: &str,
842 tool_name: &str,
843 ) -> ToolContext {
844 let mut ctx = self.tool_context.clone();
845 if let Some(agent_tx) = event_tx {
846 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
847 ctx.event_tx = Some(tool_tx);
848
849 let agent_tx = agent_tx.clone();
850 let tool_id = tool_id.to_string();
851 let tool_name = tool_name.to_string();
852 tokio::spawn(async move {
853 while let Some(event) = tool_rx.recv().await {
854 match event {
855 ToolStreamEvent::OutputDelta(delta) => {
856 agent_tx
857 .send(AgentEvent::ToolOutputDelta {
858 id: tool_id.clone(),
859 name: tool_name.clone(),
860 delta,
861 })
862 .await
863 .ok();
864 }
865 }
866 }
867 });
868 }
869 ctx
870 }
871
872 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
876 if self.config.context_providers.is_empty() {
877 return Vec::new();
878 }
879
880 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
881
882 let futures = self
883 .config
884 .context_providers
885 .iter()
886 .map(|p| p.query(&query));
887 let outcomes = join_all(futures).await;
888
889 outcomes
890 .into_iter()
891 .enumerate()
892 .filter_map(|(i, r)| match r {
893 Ok(result) if !result.is_empty() => Some(result),
894 Ok(_) => None,
895 Err(e) => {
896 tracing::warn!(
897 "Context provider '{}' failed: {}",
898 self.config.context_providers[i].name(),
899 e
900 );
901 None
902 }
903 })
904 .collect()
905 }
906
907 fn looks_incomplete(text: &str) -> bool {
915 let t = text.trim();
916 if t.is_empty() {
917 return true;
918 }
919 if t.len() < 80 && !t.contains('\n') {
921 let ends_continuation =
924 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
925 if ends_continuation {
926 return true;
927 }
928 }
929 let incomplete_phrases = [
931 "i'll ",
932 "i will ",
933 "let me ",
934 "i need to ",
935 "i should ",
936 "next, i",
937 "first, i",
938 "now i",
939 "i'll start",
940 "i'll begin",
941 "i'll now",
942 "let's start",
943 "let's begin",
944 "to do this",
945 "i'm going to",
946 ];
947 let lower = t.to_lowercase();
948 for phrase in &incomplete_phrases {
949 if lower.contains(phrase) {
950 return true;
951 }
952 }
953 false
954 }
955
956 fn system_prompt(&self) -> String {
958 self.config.prompt_slots.build()
959 }
960
961 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
963 let base = self.system_prompt();
964 if context_results.is_empty() {
965 return Some(base);
966 }
967
968 let context_xml: String = context_results
970 .iter()
971 .map(|r| r.to_xml())
972 .collect::<Vec<_>>()
973 .join("\n\n");
974
975 Some(format!("{}\n\n{}", base, context_xml))
976 }
977
978 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
980 let futures = self
981 .config
982 .context_providers
983 .iter()
984 .map(|p| p.on_turn_complete(session_id, prompt, response));
985 let outcomes = join_all(futures).await;
986
987 for (i, result) in outcomes.into_iter().enumerate() {
988 if let Err(e) = result {
989 tracing::warn!(
990 "Context provider '{}' on_turn_complete failed: {}",
991 self.config.context_providers[i].name(),
992 e
993 );
994 }
995 }
996 }
997
998 async fn fire_pre_tool_use(
1001 &self,
1002 session_id: &str,
1003 tool_name: &str,
1004 args: &serde_json::Value,
1005 ) -> Option<HookResult> {
1006 if let Some(he) = &self.config.hook_engine {
1007 let event = HookEvent::PreToolUse(PreToolUseEvent {
1008 session_id: session_id.to_string(),
1009 tool: tool_name.to_string(),
1010 args: args.clone(),
1011 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1012 recent_tools: Vec::new(),
1013 });
1014 let result = he.fire(&event).await;
1015 if result.is_block() {
1016 return Some(result);
1017 }
1018 }
1019 None
1020 }
1021
1022 async fn fire_post_tool_use(
1024 &self,
1025 session_id: &str,
1026 tool_name: &str,
1027 args: &serde_json::Value,
1028 output: &str,
1029 success: bool,
1030 duration_ms: u64,
1031 ) {
1032 if let Some(he) = &self.config.hook_engine {
1033 let event = HookEvent::PostToolUse(PostToolUseEvent {
1034 session_id: session_id.to_string(),
1035 tool: tool_name.to_string(),
1036 args: args.clone(),
1037 result: ToolResultData {
1038 success,
1039 output: output.to_string(),
1040 exit_code: if success { Some(0) } else { Some(1) },
1041 duration_ms,
1042 },
1043 });
1044 let he = Arc::clone(he);
1045 tokio::spawn(async move {
1046 let _ = he.fire(&event).await;
1047 });
1048 }
1049 }
1050
1051 async fn fire_generate_start(
1053 &self,
1054 session_id: &str,
1055 prompt: &str,
1056 system_prompt: &Option<String>,
1057 ) {
1058 if let Some(he) = &self.config.hook_engine {
1059 let event = HookEvent::GenerateStart(GenerateStartEvent {
1060 session_id: session_id.to_string(),
1061 prompt: prompt.to_string(),
1062 system_prompt: system_prompt.clone(),
1063 model_provider: String::new(),
1064 model_name: String::new(),
1065 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1066 });
1067 let _ = he.fire(&event).await;
1068 }
1069 }
1070
1071 async fn fire_generate_end(
1073 &self,
1074 session_id: &str,
1075 prompt: &str,
1076 response: &LlmResponse,
1077 duration_ms: u64,
1078 ) {
1079 if let Some(he) = &self.config.hook_engine {
1080 let tool_calls: Vec<ToolCallInfo> = response
1081 .tool_calls()
1082 .iter()
1083 .map(|tc| ToolCallInfo {
1084 name: tc.name.clone(),
1085 args: tc.args.clone(),
1086 })
1087 .collect();
1088
1089 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1090 session_id: session_id.to_string(),
1091 prompt: prompt.to_string(),
1092 response_text: response.text().to_string(),
1093 tool_calls,
1094 usage: TokenUsageInfo {
1095 prompt_tokens: response.usage.prompt_tokens as i32,
1096 completion_tokens: response.usage.completion_tokens as i32,
1097 total_tokens: response.usage.total_tokens as i32,
1098 },
1099 duration_ms,
1100 });
1101 let _ = he.fire(&event).await;
1102 }
1103 }
1104
1105 async fn fire_pre_prompt(
1108 &self,
1109 session_id: &str,
1110 prompt: &str,
1111 system_prompt: &Option<String>,
1112 message_count: usize,
1113 ) -> Option<String> {
1114 if let Some(he) = &self.config.hook_engine {
1115 let event = HookEvent::PrePrompt(PrePromptEvent {
1116 session_id: session_id.to_string(),
1117 prompt: prompt.to_string(),
1118 system_prompt: system_prompt.clone(),
1119 message_count,
1120 });
1121 let result = he.fire(&event).await;
1122 if let HookResult::Continue(Some(modified)) = result {
1123 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1125 return Some(new_prompt.to_string());
1126 }
1127 }
1128 }
1129 None
1130 }
1131
1132 async fn fire_post_response(
1134 &self,
1135 session_id: &str,
1136 response_text: &str,
1137 tool_calls_count: usize,
1138 usage: &TokenUsage,
1139 duration_ms: u64,
1140 ) {
1141 if let Some(he) = &self.config.hook_engine {
1142 let event = HookEvent::PostResponse(PostResponseEvent {
1143 session_id: session_id.to_string(),
1144 response_text: response_text.to_string(),
1145 tool_calls_count,
1146 usage: TokenUsageInfo {
1147 prompt_tokens: usage.prompt_tokens as i32,
1148 completion_tokens: usage.completion_tokens as i32,
1149 total_tokens: usage.total_tokens as i32,
1150 },
1151 duration_ms,
1152 });
1153 let he = Arc::clone(he);
1154 tokio::spawn(async move {
1155 let _ = he.fire(&event).await;
1156 });
1157 }
1158 }
1159
1160 async fn fire_on_error(
1162 &self,
1163 session_id: &str,
1164 error_type: ErrorType,
1165 error_message: &str,
1166 context: serde_json::Value,
1167 ) {
1168 if let Some(he) = &self.config.hook_engine {
1169 let event = HookEvent::OnError(OnErrorEvent {
1170 session_id: session_id.to_string(),
1171 error_type,
1172 error_message: error_message.to_string(),
1173 context,
1174 });
1175 let he = Arc::clone(he);
1176 tokio::spawn(async move {
1177 let _ = he.fire(&event).await;
1178 });
1179 }
1180 }
1181
1182 pub async fn execute(
1188 &self,
1189 history: &[Message],
1190 prompt: &str,
1191 event_tx: Option<mpsc::Sender<AgentEvent>>,
1192 ) -> Result<AgentResult> {
1193 self.execute_with_session(history, prompt, None, event_tx)
1194 .await
1195 }
1196
1197 pub async fn execute_from_messages(
1203 &self,
1204 messages: Vec<Message>,
1205 session_id: Option<&str>,
1206 event_tx: Option<mpsc::Sender<AgentEvent>>,
1207 ) -> Result<AgentResult> {
1208 tracing::info!(
1209 a3s.session.id = session_id.unwrap_or("none"),
1210 a3s.agent.max_turns = self.config.max_tool_rounds,
1211 "a3s.agent.execute_from_messages started"
1212 );
1213
1214 let effective_prompt = messages
1218 .iter()
1219 .rev()
1220 .find(|m| m.role == "user")
1221 .map(|m| m.text())
1222 .unwrap_or_default();
1223
1224 let result = self
1225 .execute_loop_inner(&messages, "", &effective_prompt, session_id, event_tx)
1226 .await;
1227
1228 match &result {
1229 Ok(r) => tracing::info!(
1230 a3s.agent.tool_calls_count = r.tool_calls_count,
1231 a3s.llm.total_tokens = r.usage.total_tokens,
1232 "a3s.agent.execute_from_messages completed"
1233 ),
1234 Err(e) => tracing::warn!(
1235 error = %e,
1236 "a3s.agent.execute_from_messages failed"
1237 ),
1238 }
1239
1240 result
1241 }
1242
1243 pub async fn execute_with_session(
1248 &self,
1249 history: &[Message],
1250 prompt: &str,
1251 session_id: Option<&str>,
1252 event_tx: Option<mpsc::Sender<AgentEvent>>,
1253 ) -> Result<AgentResult> {
1254 tracing::info!(
1255 a3s.session.id = session_id.unwrap_or("none"),
1256 a3s.agent.max_turns = self.config.max_tool_rounds,
1257 "a3s.agent.execute started"
1258 );
1259
1260 let result = if self.config.planning_enabled {
1262 self.execute_with_planning(history, prompt, event_tx).await
1263 } else {
1264 self.execute_loop(history, prompt, session_id, event_tx)
1265 .await
1266 };
1267
1268 match &result {
1269 Ok(r) => {
1270 tracing::info!(
1271 a3s.agent.tool_calls_count = r.tool_calls_count,
1272 a3s.llm.total_tokens = r.usage.total_tokens,
1273 "a3s.agent.execute completed"
1274 );
1275 self.fire_post_response(
1277 session_id.unwrap_or(""),
1278 &r.text,
1279 r.tool_calls_count,
1280 &r.usage,
1281 0, )
1283 .await;
1284 }
1285 Err(e) => {
1286 tracing::warn!(
1287 error = %e,
1288 "a3s.agent.execute failed"
1289 );
1290 self.fire_on_error(
1292 session_id.unwrap_or(""),
1293 ErrorType::Other,
1294 &e.to_string(),
1295 serde_json::json!({"phase": "execute"}),
1296 )
1297 .await;
1298 }
1299 }
1300
1301 result
1302 }
1303
1304 async fn execute_loop(
1310 &self,
1311 history: &[Message],
1312 prompt: &str,
1313 session_id: Option<&str>,
1314 event_tx: Option<mpsc::Sender<AgentEvent>>,
1315 ) -> Result<AgentResult> {
1316 self.execute_loop_inner(history, prompt, prompt, session_id, event_tx)
1319 .await
1320 }
1321
1322 async fn execute_loop_inner(
1327 &self,
1328 history: &[Message],
1329 msg_prompt: &str,
1330 effective_prompt: &str,
1331 session_id: Option<&str>,
1332 event_tx: Option<mpsc::Sender<AgentEvent>>,
1333 ) -> Result<AgentResult> {
1334 let mut messages = history.to_vec();
1335 let mut total_usage = TokenUsage::default();
1336 let mut tool_calls_count = 0;
1337 let mut turn = 0;
1338 let mut parse_error_count: u32 = 0;
1340 let mut continuation_count: u32 = 0;
1342
1343 if let Some(tx) = &event_tx {
1345 tx.send(AgentEvent::Start {
1346 prompt: effective_prompt.to_string(),
1347 })
1348 .await
1349 .ok();
1350 }
1351
1352 let _queue_forward_handle =
1354 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1355 let mut rx = queue.subscribe();
1356 let tx = tx.clone();
1357 Some(tokio::spawn(async move {
1358 while let Ok(event) = rx.recv().await {
1359 if tx.send(event).await.is_err() {
1360 break;
1361 }
1362 }
1363 }))
1364 } else {
1365 None
1366 };
1367
1368 let built_system_prompt = Some(self.system_prompt());
1370 let hooked_prompt = if let Some(modified) = self
1371 .fire_pre_prompt(
1372 session_id.unwrap_or(""),
1373 effective_prompt,
1374 &built_system_prompt,
1375 messages.len(),
1376 )
1377 .await
1378 {
1379 modified
1380 } else {
1381 effective_prompt.to_string()
1382 };
1383 let effective_prompt = hooked_prompt.as_str();
1384
1385 if let Some(ref sp) = self.config.security_provider {
1387 sp.taint_input(effective_prompt);
1388 }
1389
1390 let system_with_memory = if let Some(ref memory) = self.config.memory {
1392 match memory.recall_similar(effective_prompt, 5).await {
1393 Ok(items) if !items.is_empty() => {
1394 if let Some(tx) = &event_tx {
1395 for item in &items {
1396 tx.send(AgentEvent::MemoryRecalled {
1397 memory_id: item.id.clone(),
1398 content: item.content.clone(),
1399 relevance: item.relevance_score(),
1400 })
1401 .await
1402 .ok();
1403 }
1404 tx.send(AgentEvent::MemoriesSearched {
1405 query: Some(effective_prompt.to_string()),
1406 tags: Vec::new(),
1407 result_count: items.len(),
1408 })
1409 .await
1410 .ok();
1411 }
1412 let memory_context = items
1413 .iter()
1414 .map(|i| format!("- {}", i.content))
1415 .collect::<Vec<_>>()
1416 .join(
1417 "
1418",
1419 );
1420 let base = self.system_prompt();
1421 Some(format!(
1422 "{}
1423
1424## Relevant past experience
1425{}",
1426 base, memory_context
1427 ))
1428 }
1429 _ => Some(self.system_prompt()),
1430 }
1431 } else {
1432 Some(self.system_prompt())
1433 };
1434
1435 let augmented_system = if !self.config.context_providers.is_empty() {
1437 if let Some(tx) = &event_tx {
1439 let provider_names: Vec<String> = self
1440 .config
1441 .context_providers
1442 .iter()
1443 .map(|p| p.name().to_string())
1444 .collect();
1445 tx.send(AgentEvent::ContextResolving {
1446 providers: provider_names,
1447 })
1448 .await
1449 .ok();
1450 }
1451
1452 tracing::info!(
1453 a3s.context.providers = self.config.context_providers.len() as i64,
1454 "Context resolution started"
1455 );
1456 let context_results = self.resolve_context(effective_prompt, session_id).await;
1457
1458 if let Some(tx) = &event_tx {
1460 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1461 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1462
1463 tracing::info!(
1464 context_items = total_items,
1465 context_tokens = total_tokens,
1466 "Context resolution completed"
1467 );
1468
1469 tx.send(AgentEvent::ContextResolved {
1470 total_items,
1471 total_tokens,
1472 })
1473 .await
1474 .ok();
1475 }
1476
1477 self.build_augmented_system_prompt(&context_results)
1478 } else {
1479 Some(self.system_prompt())
1480 };
1481
1482 let base_prompt = self.system_prompt();
1484 let augmented_system = match (augmented_system, system_with_memory) {
1485 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1486 (Some(ctx), _) => Some(ctx),
1487 (None, mem) => mem,
1488 };
1489
1490 if !msg_prompt.is_empty() {
1492 messages.push(Message::user(msg_prompt));
1493 }
1494
1495 loop {
1496 turn += 1;
1497
1498 if turn > self.config.max_tool_rounds {
1499 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1500 if let Some(tx) = &event_tx {
1501 tx.send(AgentEvent::Error {
1502 message: error.clone(),
1503 })
1504 .await
1505 .ok();
1506 }
1507 anyhow::bail!(error);
1508 }
1509
1510 if let Some(tx) = &event_tx {
1512 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1513 }
1514
1515 tracing::info!(
1516 turn = turn,
1517 max_turns = self.config.max_tool_rounds,
1518 "Agent turn started"
1519 );
1520
1521 tracing::info!(
1523 a3s.llm.streaming = event_tx.is_some(),
1524 "LLM completion started"
1525 );
1526
1527 self.fire_generate_start(
1529 session_id.unwrap_or(""),
1530 effective_prompt,
1531 &augmented_system,
1532 )
1533 .await;
1534
1535 let llm_start = std::time::Instant::now();
1536 let response = {
1540 let threshold = self.config.circuit_breaker_threshold.max(1);
1541 let mut attempt = 0u32;
1542 loop {
1543 attempt += 1;
1544 let result = self
1545 .call_llm(&messages, augmented_system.as_deref(), &event_tx)
1546 .await;
1547 match result {
1548 Ok(r) => {
1549 break r;
1550 }
1551 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1553 tracing::warn!(
1554 turn = turn,
1555 attempt = attempt,
1556 threshold = threshold,
1557 error = %e,
1558 "LLM call failed, will retry"
1559 );
1560 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1561 }
1562 Err(e) => {
1564 let msg = if attempt > 1 {
1565 format!(
1566 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1567 attempt, e
1568 )
1569 } else {
1570 format!("LLM call failed: {}", e)
1571 };
1572 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1573 self.fire_on_error(
1575 session_id.unwrap_or(""),
1576 ErrorType::LlmFailure,
1577 &msg,
1578 serde_json::json!({"turn": turn, "attempt": attempt}),
1579 )
1580 .await;
1581 if let Some(tx) = &event_tx {
1582 tx.send(AgentEvent::Error {
1583 message: msg.clone(),
1584 })
1585 .await
1586 .ok();
1587 }
1588 anyhow::bail!(msg);
1589 }
1590 }
1591 }
1592 };
1593
1594 total_usage.prompt_tokens += response.usage.prompt_tokens;
1596 total_usage.completion_tokens += response.usage.completion_tokens;
1597 total_usage.total_tokens += response.usage.total_tokens;
1598
1599 let llm_duration = llm_start.elapsed();
1601 tracing::info!(
1602 turn = turn,
1603 streaming = event_tx.is_some(),
1604 prompt_tokens = response.usage.prompt_tokens,
1605 completion_tokens = response.usage.completion_tokens,
1606 total_tokens = response.usage.total_tokens,
1607 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1608 duration_ms = llm_duration.as_millis() as u64,
1609 "LLM completion finished"
1610 );
1611
1612 self.fire_generate_end(
1614 session_id.unwrap_or(""),
1615 effective_prompt,
1616 &response,
1617 llm_duration.as_millis() as u64,
1618 )
1619 .await;
1620
1621 crate::telemetry::record_llm_usage(
1623 response.usage.prompt_tokens,
1624 response.usage.completion_tokens,
1625 response.usage.total_tokens,
1626 response.stop_reason.as_deref(),
1627 );
1628 tracing::info!(
1630 turn = turn,
1631 a3s.llm.total_tokens = response.usage.total_tokens,
1632 "Turn token usage"
1633 );
1634
1635 messages.push(response.message.clone());
1637
1638 let tool_calls = response.tool_calls();
1640
1641 if let Some(tx) = &event_tx {
1643 tx.send(AgentEvent::TurnEnd {
1644 turn,
1645 usage: response.usage.clone(),
1646 })
1647 .await
1648 .ok();
1649 }
1650
1651 if self.config.auto_compact {
1653 let used = response.usage.prompt_tokens;
1654 let max = self.config.max_context_tokens;
1655 let threshold = self.config.auto_compact_threshold;
1656
1657 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1658 let before_len = messages.len();
1659 let percent_before = used as f32 / max as f32;
1660
1661 tracing::info!(
1662 used_tokens = used,
1663 max_tokens = max,
1664 percent = percent_before,
1665 threshold = threshold,
1666 "Auto-compact triggered"
1667 );
1668
1669 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1671 {
1672 messages = pruned;
1673 tracing::info!("Tool output pruning applied");
1674 }
1675
1676 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1678 session_id.unwrap_or(""),
1679 &messages,
1680 &self.llm_client,
1681 )
1682 .await
1683 {
1684 messages = compacted;
1685 }
1686
1687 if let Some(tx) = &event_tx {
1689 tx.send(AgentEvent::ContextCompacted {
1690 session_id: session_id.unwrap_or("").to_string(),
1691 before_messages: before_len,
1692 after_messages: messages.len(),
1693 percent_before,
1694 })
1695 .await
1696 .ok();
1697 }
1698 }
1699 }
1700
1701 if tool_calls.is_empty() {
1702 let final_text = response.text();
1705
1706 if self.config.continuation_enabled
1707 && continuation_count < self.config.max_continuation_turns
1708 && Self::looks_incomplete(&final_text)
1709 {
1710 continuation_count += 1;
1711 tracing::info!(
1712 turn = turn,
1713 continuation = continuation_count,
1714 max_continuation = self.config.max_continuation_turns,
1715 "Injecting continuation message — response looks incomplete"
1716 );
1717 messages.push(Message::user(crate::prompts::CONTINUATION));
1719 continue;
1720 }
1721
1722 let final_text = if let Some(ref sp) = self.config.security_provider {
1724 sp.sanitize_output(&final_text)
1725 } else {
1726 final_text
1727 };
1728
1729 tracing::info!(
1731 tool_calls_count = tool_calls_count,
1732 total_prompt_tokens = total_usage.prompt_tokens,
1733 total_completion_tokens = total_usage.completion_tokens,
1734 total_tokens = total_usage.total_tokens,
1735 turns = turn,
1736 "Agent execution completed"
1737 );
1738
1739 if let Some(tx) = &event_tx {
1740 tx.send(AgentEvent::End {
1741 text: final_text.clone(),
1742 usage: total_usage.clone(),
1743 })
1744 .await
1745 .ok();
1746 }
1747
1748 if let Some(sid) = session_id {
1750 self.notify_turn_complete(sid, effective_prompt, &final_text)
1751 .await;
1752 }
1753
1754 return Ok(AgentResult {
1755 text: final_text,
1756 messages,
1757 usage: total_usage,
1758 tool_calls_count,
1759 });
1760 }
1761
1762 for tool_call in tool_calls {
1764 tool_calls_count += 1;
1765
1766 let tool_start = std::time::Instant::now();
1767
1768 tracing::info!(
1769 tool_name = tool_call.name.as_str(),
1770 tool_id = tool_call.id.as_str(),
1771 "Tool execution started"
1772 );
1773
1774 if let Some(parse_error) =
1780 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
1781 {
1782 parse_error_count += 1;
1783 let error_msg = format!("Error: {}", parse_error);
1784 tracing::warn!(
1785 tool = tool_call.name.as_str(),
1786 parse_error_count = parse_error_count,
1787 max_parse_retries = self.config.max_parse_retries,
1788 "Malformed tool arguments from LLM"
1789 );
1790
1791 if let Some(tx) = &event_tx {
1792 tx.send(AgentEvent::ToolEnd {
1793 id: tool_call.id.clone(),
1794 name: tool_call.name.clone(),
1795 output: error_msg.clone(),
1796 exit_code: 1,
1797 })
1798 .await
1799 .ok();
1800 }
1801
1802 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
1803
1804 if parse_error_count > self.config.max_parse_retries {
1805 let msg = format!(
1806 "LLM produced malformed tool arguments {} time(s) in a row \
1807 (max_parse_retries={}); giving up",
1808 parse_error_count, self.config.max_parse_retries
1809 );
1810 tracing::error!("{}", msg);
1811 if let Some(tx) = &event_tx {
1812 tx.send(AgentEvent::Error {
1813 message: msg.clone(),
1814 })
1815 .await
1816 .ok();
1817 }
1818 anyhow::bail!(msg);
1819 }
1820 continue;
1821 }
1822
1823 parse_error_count = 0;
1825
1826 if let Some(ref registry) = self.config.skill_registry {
1828 let instruction_skills =
1829 registry.by_kind(crate::skills::SkillKind::Instruction);
1830 let has_restrictions =
1831 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
1832 if has_restrictions {
1833 let allowed = instruction_skills
1834 .iter()
1835 .any(|s| s.is_tool_allowed(&tool_call.name));
1836 if !allowed {
1837 let msg = format!(
1838 "Tool '{}' is not allowed by any active skill.",
1839 tool_call.name
1840 );
1841 tracing::info!(
1842 tool_name = tool_call.name.as_str(),
1843 "Tool blocked by skill registry"
1844 );
1845 if let Some(tx) = &event_tx {
1846 tx.send(AgentEvent::PermissionDenied {
1847 tool_id: tool_call.id.clone(),
1848 tool_name: tool_call.name.clone(),
1849 args: tool_call.args.clone(),
1850 reason: msg.clone(),
1851 })
1852 .await
1853 .ok();
1854 }
1855 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1856 continue;
1857 }
1858 }
1859 }
1860
1861 if let Some(HookResult::Block(reason)) = self
1863 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
1864 .await
1865 {
1866 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
1867 tracing::info!(
1868 tool_name = tool_call.name.as_str(),
1869 "Tool blocked by PreToolUse hook"
1870 );
1871
1872 if let Some(tx) = &event_tx {
1873 tx.send(AgentEvent::PermissionDenied {
1874 tool_id: tool_call.id.clone(),
1875 tool_name: tool_call.name.clone(),
1876 args: tool_call.args.clone(),
1877 reason: reason.clone(),
1878 })
1879 .await
1880 .ok();
1881 }
1882
1883 messages.push(Message::tool_result(&tool_call.id, &msg, true));
1884 continue;
1885 }
1886
1887 let permission_decision = if let Some(checker) = &self.config.permission_checker {
1889 checker.check(&tool_call.name, &tool_call.args)
1890 } else {
1891 PermissionDecision::Ask
1893 };
1894
1895 let (output, exit_code, is_error, _metadata, images) = match permission_decision {
1896 PermissionDecision::Deny => {
1897 tracing::info!(
1898 tool_name = tool_call.name.as_str(),
1899 permission = "deny",
1900 "Tool permission denied"
1901 );
1902 let denial_msg = format!(
1904 "Permission denied: Tool '{}' is blocked by permission policy.",
1905 tool_call.name
1906 );
1907
1908 if let Some(tx) = &event_tx {
1910 tx.send(AgentEvent::PermissionDenied {
1911 tool_id: tool_call.id.clone(),
1912 tool_name: tool_call.name.clone(),
1913 args: tool_call.args.clone(),
1914 reason: "Blocked by deny rule in permission policy".to_string(),
1915 })
1916 .await
1917 .ok();
1918 }
1919
1920 (denial_msg, 1, true, None, Vec::new())
1921 }
1922 PermissionDecision::Allow => {
1923 tracing::info!(
1924 tool_name = tool_call.name.as_str(),
1925 permission = "allow",
1926 "Tool permission: allow"
1927 );
1928 let stream_ctx =
1930 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
1931 let result = self
1932 .execute_tool_queued_or_direct(
1933 &tool_call.name,
1934 &tool_call.args,
1935 &stream_ctx,
1936 )
1937 .await;
1938
1939 Self::tool_result_to_tuple(result)
1940 }
1941 PermissionDecision::Ask => {
1942 tracing::info!(
1943 tool_name = tool_call.name.as_str(),
1944 permission = "ask",
1945 "Tool permission: ask"
1946 );
1947 if let Some(cm) = &self.config.confirmation_manager {
1949 if !cm.requires_confirmation(&tool_call.name).await {
1951 let stream_ctx = self.streaming_tool_context(
1952 &event_tx,
1953 &tool_call.id,
1954 &tool_call.name,
1955 );
1956 let result = self
1957 .execute_tool_queued_or_direct(
1958 &tool_call.name,
1959 &tool_call.args,
1960 &stream_ctx,
1961 )
1962 .await;
1963
1964 let (output, exit_code, is_error, _metadata, images) =
1965 Self::tool_result_to_tuple(result);
1966
1967 if images.is_empty() {
1969 messages.push(Message::tool_result(
1970 &tool_call.id,
1971 &output,
1972 is_error,
1973 ));
1974 } else {
1975 messages.push(Message::tool_result_with_images(
1976 &tool_call.id,
1977 &output,
1978 &images,
1979 is_error,
1980 ));
1981 }
1982
1983 let tool_duration = tool_start.elapsed();
1985 crate::telemetry::record_tool_result(exit_code, tool_duration);
1986
1987 if let Some(tx) = &event_tx {
1989 tx.send(AgentEvent::ToolEnd {
1990 id: tool_call.id.clone(),
1991 name: tool_call.name.clone(),
1992 output: output.clone(),
1993 exit_code,
1994 })
1995 .await
1996 .ok();
1997 }
1998
1999 self.fire_post_tool_use(
2001 session_id.unwrap_or(""),
2002 &tool_call.name,
2003 &tool_call.args,
2004 &output,
2005 exit_code == 0,
2006 tool_duration.as_millis() as u64,
2007 )
2008 .await;
2009
2010 continue; }
2012
2013 let policy = cm.policy().await;
2015 let timeout_ms = policy.default_timeout_ms;
2016 let timeout_action = policy.timeout_action;
2017
2018 let rx = cm
2020 .request_confirmation(
2021 &tool_call.id,
2022 &tool_call.name,
2023 &tool_call.args,
2024 )
2025 .await;
2026
2027 let confirmation_result =
2029 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2030
2031 match confirmation_result {
2032 Ok(Ok(response)) => {
2033 if response.approved {
2034 let stream_ctx = self.streaming_tool_context(
2035 &event_tx,
2036 &tool_call.id,
2037 &tool_call.name,
2038 );
2039 let result = self
2040 .execute_tool_queued_or_direct(
2041 &tool_call.name,
2042 &tool_call.args,
2043 &stream_ctx,
2044 )
2045 .await;
2046
2047 Self::tool_result_to_tuple(result)
2048 } else {
2049 let rejection_msg = format!(
2050 "Tool '{}' execution was rejected by user. Reason: {}",
2051 tool_call.name,
2052 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2053 );
2054 (rejection_msg, 1, true, None, Vec::new())
2055 }
2056 }
2057 Ok(Err(_)) => {
2058 let msg = format!(
2059 "Tool '{}' confirmation failed: confirmation channel closed",
2060 tool_call.name
2061 );
2062 (msg, 1, true, None, Vec::new())
2063 }
2064 Err(_) => {
2065 cm.check_timeouts().await;
2066
2067 match timeout_action {
2068 crate::hitl::TimeoutAction::Reject => {
2069 let msg = format!(
2070 "Tool '{}' execution timed out waiting for confirmation ({}ms). Execution rejected.",
2071 tool_call.name, timeout_ms
2072 );
2073 (msg, 1, true, None, Vec::new())
2074 }
2075 crate::hitl::TimeoutAction::AutoApprove => {
2076 let stream_ctx = self.streaming_tool_context(
2077 &event_tx,
2078 &tool_call.id,
2079 &tool_call.name,
2080 );
2081 let result = self
2082 .execute_tool_queued_or_direct(
2083 &tool_call.name,
2084 &tool_call.args,
2085 &stream_ctx,
2086 )
2087 .await;
2088
2089 Self::tool_result_to_tuple(result)
2090 }
2091 }
2092 }
2093 }
2094 } else {
2095 let msg = format!(
2097 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2098 Configure a confirmation policy to enable tool execution.",
2099 tool_call.name
2100 );
2101 tracing::warn!(
2102 tool_name = tool_call.name.as_str(),
2103 "Tool requires confirmation but no HITL manager configured"
2104 );
2105 (msg, 1, true, None, Vec::new())
2106 }
2107 }
2108 };
2109
2110 let tool_duration = tool_start.elapsed();
2111 crate::telemetry::record_tool_result(exit_code, tool_duration);
2112
2113 let output = if let Some(ref sp) = self.config.security_provider {
2115 sp.sanitize_output(&output)
2116 } else {
2117 output
2118 };
2119
2120 self.fire_post_tool_use(
2122 session_id.unwrap_or(""),
2123 &tool_call.name,
2124 &tool_call.args,
2125 &output,
2126 exit_code == 0,
2127 tool_duration.as_millis() as u64,
2128 )
2129 .await;
2130
2131 if let Some(ref memory) = self.config.memory {
2133 let tools_used = [tool_call.name.clone()];
2134 let remember_result = if exit_code == 0 {
2135 memory
2136 .remember_success(effective_prompt, &tools_used, &output)
2137 .await
2138 } else {
2139 memory
2140 .remember_failure(effective_prompt, &output, &tools_used)
2141 .await
2142 };
2143 match remember_result {
2144 Ok(()) => {
2145 if let Some(tx) = &event_tx {
2146 let item_type = if exit_code == 0 { "success" } else { "failure" };
2147 tx.send(AgentEvent::MemoryStored {
2148 memory_id: uuid::Uuid::new_v4().to_string(),
2149 memory_type: item_type.to_string(),
2150 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2151 tags: vec![item_type.to_string(), tool_call.name.clone()],
2152 })
2153 .await
2154 .ok();
2155 }
2156 }
2157 Err(e) => {
2158 tracing::warn!("Failed to store memory after tool execution: {}", e);
2159 }
2160 }
2161 }
2162
2163 if let Some(tx) = &event_tx {
2165 tx.send(AgentEvent::ToolEnd {
2166 id: tool_call.id.clone(),
2167 name: tool_call.name.clone(),
2168 output: output.clone(),
2169 exit_code,
2170 })
2171 .await
2172 .ok();
2173 }
2174
2175 if images.is_empty() {
2177 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2178 } else {
2179 messages.push(Message::tool_result_with_images(
2180 &tool_call.id,
2181 &output,
2182 &images,
2183 is_error,
2184 ));
2185 }
2186 }
2187 }
2188 }
2189
2190 pub async fn execute_streaming(
2192 &self,
2193 history: &[Message],
2194 prompt: &str,
2195 ) -> Result<(
2196 mpsc::Receiver<AgentEvent>,
2197 tokio::task::JoinHandle<Result<AgentResult>>,
2198 )> {
2199 let (tx, rx) = mpsc::channel(100);
2200
2201 let llm_client = self.llm_client.clone();
2202 let tool_executor = self.tool_executor.clone();
2203 let tool_context = self.tool_context.clone();
2204 let config = self.config.clone();
2205 let tool_metrics = self.tool_metrics.clone();
2206 let command_queue = self.command_queue.clone();
2207 let history = history.to_vec();
2208 let prompt = prompt.to_string();
2209
2210 let handle = tokio::spawn(async move {
2211 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2212 if let Some(metrics) = tool_metrics {
2213 agent = agent.with_tool_metrics(metrics);
2214 }
2215 if let Some(queue) = command_queue {
2216 agent = agent.with_queue(queue);
2217 }
2218 agent.execute(&history, &prompt, Some(tx)).await
2219 });
2220
2221 Ok((rx, handle))
2222 }
2223
2224 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2229 use crate::planning::LlmPlanner;
2230
2231 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2232 Ok(plan) => Ok(plan),
2233 Err(e) => {
2234 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2235 Ok(LlmPlanner::fallback_plan(prompt))
2236 }
2237 }
2238 }
2239
2240 pub async fn execute_with_planning(
2242 &self,
2243 history: &[Message],
2244 prompt: &str,
2245 event_tx: Option<mpsc::Sender<AgentEvent>>,
2246 ) -> Result<AgentResult> {
2247 if let Some(tx) = &event_tx {
2249 tx.send(AgentEvent::PlanningStart {
2250 prompt: prompt.to_string(),
2251 })
2252 .await
2253 .ok();
2254 }
2255
2256 let goal = if self.config.goal_tracking {
2258 let g = self.extract_goal(prompt).await?;
2259 if let Some(tx) = &event_tx {
2260 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2261 .await
2262 .ok();
2263 }
2264 Some(g)
2265 } else {
2266 None
2267 };
2268
2269 let plan = self.plan(prompt, None).await?;
2271
2272 if let Some(tx) = &event_tx {
2274 tx.send(AgentEvent::PlanningEnd {
2275 estimated_steps: plan.steps.len(),
2276 plan: plan.clone(),
2277 })
2278 .await
2279 .ok();
2280 }
2281
2282 let plan_start = std::time::Instant::now();
2283
2284 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2286
2287 if self.config.goal_tracking {
2289 if let Some(ref g) = goal {
2290 let achieved = self.check_goal_achievement(g, &result.text).await?;
2291 if achieved {
2292 if let Some(tx) = &event_tx {
2293 tx.send(AgentEvent::GoalAchieved {
2294 goal: g.description.clone(),
2295 total_steps: result.messages.len(),
2296 duration_ms: plan_start.elapsed().as_millis() as i64,
2297 })
2298 .await
2299 .ok();
2300 }
2301 }
2302 }
2303 }
2304
2305 Ok(result)
2306 }
2307
2308 async fn execute_plan(
2315 &self,
2316 history: &[Message],
2317 plan: &ExecutionPlan,
2318 event_tx: Option<mpsc::Sender<AgentEvent>>,
2319 ) -> Result<AgentResult> {
2320 let mut plan = plan.clone();
2321 let mut current_history = history.to_vec();
2322 let mut total_usage = TokenUsage::default();
2323 let mut tool_calls_count = 0;
2324 let total_steps = plan.steps.len();
2325
2326 let steps_text = plan
2328 .steps
2329 .iter()
2330 .enumerate()
2331 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2332 .collect::<Vec<_>>()
2333 .join("\n");
2334 current_history.push(Message::user(&crate::prompts::render(
2335 crate::prompts::PLAN_EXECUTE_GOAL,
2336 &[("goal", &plan.goal), ("steps", &steps_text)],
2337 )));
2338
2339 loop {
2340 let ready: Vec<String> = plan
2341 .get_ready_steps()
2342 .iter()
2343 .map(|s| s.id.clone())
2344 .collect();
2345
2346 if ready.is_empty() {
2347 if plan.has_deadlock() {
2349 tracing::warn!(
2350 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2351 plan.pending_count()
2352 );
2353 }
2354 break;
2355 }
2356
2357 if ready.len() == 1 {
2358 let step_id = &ready[0];
2360 let step = plan
2361 .steps
2362 .iter()
2363 .find(|s| s.id == *step_id)
2364 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2365 .clone();
2366 let step_number = plan
2367 .steps
2368 .iter()
2369 .position(|s| s.id == *step_id)
2370 .unwrap_or(0)
2371 + 1;
2372
2373 if let Some(tx) = &event_tx {
2375 tx.send(AgentEvent::StepStart {
2376 step_id: step.id.clone(),
2377 description: step.content.clone(),
2378 step_number,
2379 total_steps,
2380 })
2381 .await
2382 .ok();
2383 }
2384
2385 plan.mark_status(&step.id, TaskStatus::InProgress);
2386
2387 let step_prompt = crate::prompts::render(
2388 crate::prompts::PLAN_EXECUTE_STEP,
2389 &[
2390 ("step_num", &step_number.to_string()),
2391 ("description", &step.content),
2392 ],
2393 );
2394
2395 match self
2396 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
2397 .await
2398 {
2399 Ok(result) => {
2400 current_history = result.messages.clone();
2401 total_usage.prompt_tokens += result.usage.prompt_tokens;
2402 total_usage.completion_tokens += result.usage.completion_tokens;
2403 total_usage.total_tokens += result.usage.total_tokens;
2404 tool_calls_count += result.tool_calls_count;
2405 plan.mark_status(&step.id, TaskStatus::Completed);
2406
2407 if let Some(tx) = &event_tx {
2408 tx.send(AgentEvent::StepEnd {
2409 step_id: step.id.clone(),
2410 status: TaskStatus::Completed,
2411 step_number,
2412 total_steps,
2413 })
2414 .await
2415 .ok();
2416 }
2417 }
2418 Err(e) => {
2419 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2420 plan.mark_status(&step.id, TaskStatus::Failed);
2421
2422 if let Some(tx) = &event_tx {
2423 tx.send(AgentEvent::StepEnd {
2424 step_id: step.id.clone(),
2425 status: TaskStatus::Failed,
2426 step_number,
2427 total_steps,
2428 })
2429 .await
2430 .ok();
2431 }
2432 }
2433 }
2434 } else {
2435 let ready_steps: Vec<_> = ready
2442 .iter()
2443 .filter_map(|id| {
2444 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2445 let step_number =
2446 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2447 Some((step, step_number))
2448 })
2449 .collect();
2450
2451 for (step, step_number) in &ready_steps {
2453 plan.mark_status(&step.id, TaskStatus::InProgress);
2454 if let Some(tx) = &event_tx {
2455 tx.send(AgentEvent::StepStart {
2456 step_id: step.id.clone(),
2457 description: step.content.clone(),
2458 step_number: *step_number,
2459 total_steps,
2460 })
2461 .await
2462 .ok();
2463 }
2464 }
2465
2466 let mut join_set = tokio::task::JoinSet::new();
2468 for (step, step_number) in &ready_steps {
2469 let base_history = current_history.clone();
2470 let agent_clone = self.clone();
2471 let tx = event_tx.clone();
2472 let step_clone = step.clone();
2473 let sn = *step_number;
2474
2475 join_set.spawn(async move {
2476 let prompt = crate::prompts::render(
2477 crate::prompts::PLAN_EXECUTE_STEP,
2478 &[
2479 ("step_num", &sn.to_string()),
2480 ("description", &step_clone.content),
2481 ],
2482 );
2483 let result = agent_clone
2484 .execute_loop(&base_history, &prompt, None, tx)
2485 .await;
2486 (step_clone.id, sn, result)
2487 });
2488 }
2489
2490 let mut parallel_summaries = Vec::new();
2492 while let Some(join_result) = join_set.join_next().await {
2493 match join_result {
2494 Ok((step_id, step_number, step_result)) => match step_result {
2495 Ok(result) => {
2496 total_usage.prompt_tokens += result.usage.prompt_tokens;
2497 total_usage.completion_tokens += result.usage.completion_tokens;
2498 total_usage.total_tokens += result.usage.total_tokens;
2499 tool_calls_count += result.tool_calls_count;
2500 plan.mark_status(&step_id, TaskStatus::Completed);
2501
2502 parallel_summaries.push(format!(
2504 "- Step {} ({}): {}",
2505 step_number, step_id, result.text
2506 ));
2507
2508 if let Some(tx) = &event_tx {
2509 tx.send(AgentEvent::StepEnd {
2510 step_id,
2511 status: TaskStatus::Completed,
2512 step_number,
2513 total_steps,
2514 })
2515 .await
2516 .ok();
2517 }
2518 }
2519 Err(e) => {
2520 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2521 plan.mark_status(&step_id, TaskStatus::Failed);
2522
2523 if let Some(tx) = &event_tx {
2524 tx.send(AgentEvent::StepEnd {
2525 step_id,
2526 status: TaskStatus::Failed,
2527 step_number,
2528 total_steps,
2529 })
2530 .await
2531 .ok();
2532 }
2533 }
2534 },
2535 Err(e) => {
2536 tracing::error!("JoinSet task panicked: {}", e);
2537 }
2538 }
2539 }
2540
2541 if !parallel_summaries.is_empty() {
2543 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
2545 current_history.push(Message::user(&crate::prompts::render(
2546 crate::prompts::PLAN_PARALLEL_RESULTS,
2547 &[("results", &results_text)],
2548 )));
2549 }
2550 }
2551
2552 if self.config.goal_tracking {
2554 let completed = plan
2555 .steps
2556 .iter()
2557 .filter(|s| s.status == TaskStatus::Completed)
2558 .count();
2559 if let Some(tx) = &event_tx {
2560 tx.send(AgentEvent::GoalProgress {
2561 goal: plan.goal.clone(),
2562 progress: plan.progress(),
2563 completed_steps: completed,
2564 total_steps,
2565 })
2566 .await
2567 .ok();
2568 }
2569 }
2570 }
2571
2572 let final_text = current_history
2574 .last()
2575 .map(|m| {
2576 m.content
2577 .iter()
2578 .filter_map(|block| {
2579 if let crate::llm::ContentBlock::Text { text } = block {
2580 Some(text.as_str())
2581 } else {
2582 None
2583 }
2584 })
2585 .collect::<Vec<_>>()
2586 .join("\n")
2587 })
2588 .unwrap_or_default();
2589
2590 Ok(AgentResult {
2591 text: final_text,
2592 messages: current_history,
2593 usage: total_usage,
2594 tool_calls_count,
2595 })
2596 }
2597
2598 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2603 use crate::planning::LlmPlanner;
2604
2605 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2606 Ok(goal) => Ok(goal),
2607 Err(e) => {
2608 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2609 Ok(LlmPlanner::fallback_goal(prompt))
2610 }
2611 }
2612 }
2613
2614 pub async fn check_goal_achievement(
2619 &self,
2620 goal: &AgentGoal,
2621 current_state: &str,
2622 ) -> Result<bool> {
2623 use crate::planning::LlmPlanner;
2624
2625 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2626 Ok(result) => Ok(result.achieved),
2627 Err(e) => {
2628 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2629 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2630 Ok(result.achieved)
2631 }
2632 }
2633 }
2634}
2635
2636#[cfg(test)]
2637mod tests {
2638 use super::*;
2639 use crate::llm::{ContentBlock, StreamEvent};
2640 use crate::permissions::PermissionPolicy;
2641 use crate::tools::ToolExecutor;
2642 use std::path::PathBuf;
2643 use std::sync::atomic::{AtomicUsize, Ordering};
2644
2645 fn test_tool_context() -> ToolContext {
2647 ToolContext::new(PathBuf::from("/tmp"))
2648 }
2649
2650 #[test]
2651 fn test_agent_config_default() {
2652 let config = AgentConfig::default();
2653 assert!(config.prompt_slots.is_empty());
2654 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
2656 assert!(config.permission_checker.is_none());
2657 assert!(config.context_providers.is_empty());
2658 }
2659
2660 pub(crate) struct MockLlmClient {
2666 responses: std::sync::Mutex<Vec<LlmResponse>>,
2668 pub(crate) call_count: AtomicUsize,
2670 }
2671
2672 impl MockLlmClient {
2673 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
2674 Self {
2675 responses: std::sync::Mutex::new(responses),
2676 call_count: AtomicUsize::new(0),
2677 }
2678 }
2679
2680 pub(crate) fn text_response(text: &str) -> LlmResponse {
2682 LlmResponse {
2683 message: Message {
2684 role: "assistant".to_string(),
2685 content: vec![ContentBlock::Text {
2686 text: text.to_string(),
2687 }],
2688 reasoning_content: None,
2689 },
2690 usage: TokenUsage {
2691 prompt_tokens: 10,
2692 completion_tokens: 5,
2693 total_tokens: 15,
2694 cache_read_tokens: None,
2695 cache_write_tokens: None,
2696 },
2697 stop_reason: Some("end_turn".to_string()),
2698 }
2699 }
2700
2701 pub(crate) fn tool_call_response(
2703 tool_id: &str,
2704 tool_name: &str,
2705 args: serde_json::Value,
2706 ) -> LlmResponse {
2707 LlmResponse {
2708 message: Message {
2709 role: "assistant".to_string(),
2710 content: vec![ContentBlock::ToolUse {
2711 id: tool_id.to_string(),
2712 name: tool_name.to_string(),
2713 input: args,
2714 }],
2715 reasoning_content: None,
2716 },
2717 usage: TokenUsage {
2718 prompt_tokens: 10,
2719 completion_tokens: 5,
2720 total_tokens: 15,
2721 cache_read_tokens: None,
2722 cache_write_tokens: None,
2723 },
2724 stop_reason: Some("tool_use".to_string()),
2725 }
2726 }
2727 }
2728
2729 #[async_trait::async_trait]
2730 impl LlmClient for MockLlmClient {
2731 async fn complete(
2732 &self,
2733 _messages: &[Message],
2734 _system: Option<&str>,
2735 _tools: &[ToolDefinition],
2736 ) -> Result<LlmResponse> {
2737 self.call_count.fetch_add(1, Ordering::SeqCst);
2738 let mut responses = self.responses.lock().unwrap();
2739 if responses.is_empty() {
2740 anyhow::bail!("No more mock responses available");
2741 }
2742 Ok(responses.remove(0))
2743 }
2744
2745 async fn complete_streaming(
2746 &self,
2747 _messages: &[Message],
2748 _system: Option<&str>,
2749 _tools: &[ToolDefinition],
2750 ) -> Result<mpsc::Receiver<StreamEvent>> {
2751 self.call_count.fetch_add(1, Ordering::SeqCst);
2752 let mut responses = self.responses.lock().unwrap();
2753 if responses.is_empty() {
2754 anyhow::bail!("No more mock responses available");
2755 }
2756 let response = responses.remove(0);
2757
2758 let (tx, rx) = mpsc::channel(10);
2759 tokio::spawn(async move {
2760 for block in &response.message.content {
2762 if let ContentBlock::Text { text } = block {
2763 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
2764 }
2765 }
2766 tx.send(StreamEvent::Done(response)).await.ok();
2767 });
2768
2769 Ok(rx)
2770 }
2771 }
2772
2773 #[tokio::test]
2778 async fn test_agent_simple_response() {
2779 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2780 "Hello, I'm an AI assistant.",
2781 )]));
2782
2783 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2784 let config = AgentConfig::default();
2785
2786 let agent = AgentLoop::new(
2787 mock_client.clone(),
2788 tool_executor,
2789 test_tool_context(),
2790 config,
2791 );
2792 let result = agent.execute(&[], "Hello", None).await.unwrap();
2793
2794 assert_eq!(result.text, "Hello, I'm an AI assistant.");
2795 assert_eq!(result.tool_calls_count, 0);
2796 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
2797 }
2798
2799 #[tokio::test]
2800 async fn test_agent_with_tool_call() {
2801 let mock_client = Arc::new(MockLlmClient::new(vec![
2802 MockLlmClient::tool_call_response(
2804 "tool-1",
2805 "bash",
2806 serde_json::json!({"command": "echo hello"}),
2807 ),
2808 MockLlmClient::text_response("The command output was: hello"),
2810 ]));
2811
2812 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2813 let config = AgentConfig::default();
2814
2815 let agent = AgentLoop::new(
2816 mock_client.clone(),
2817 tool_executor,
2818 test_tool_context(),
2819 config,
2820 );
2821 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
2822
2823 assert_eq!(result.text, "The command output was: hello");
2824 assert_eq!(result.tool_calls_count, 1);
2825 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
2826 }
2827
2828 #[tokio::test]
2829 async fn test_agent_permission_deny() {
2830 let mock_client = Arc::new(MockLlmClient::new(vec![
2831 MockLlmClient::tool_call_response(
2833 "tool-1",
2834 "bash",
2835 serde_json::json!({"command": "rm -rf /tmp/test"}),
2836 ),
2837 MockLlmClient::text_response(
2839 "I cannot execute that command due to permission restrictions.",
2840 ),
2841 ]));
2842
2843 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2844
2845 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
2847
2848 let config = AgentConfig {
2849 permission_checker: Some(Arc::new(permission_policy)),
2850 ..Default::default()
2851 };
2852
2853 let (tx, mut rx) = mpsc::channel(100);
2854 let agent = AgentLoop::new(
2855 mock_client.clone(),
2856 tool_executor,
2857 test_tool_context(),
2858 config,
2859 );
2860 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
2861
2862 let mut found_permission_denied = false;
2864 while let Ok(event) = rx.try_recv() {
2865 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
2866 assert_eq!(tool_name, "bash");
2867 found_permission_denied = true;
2868 }
2869 }
2870 assert!(
2871 found_permission_denied,
2872 "Should have received PermissionDenied event"
2873 );
2874
2875 assert_eq!(result.tool_calls_count, 1);
2876 }
2877
2878 #[tokio::test]
2879 async fn test_agent_permission_allow() {
2880 let mock_client = Arc::new(MockLlmClient::new(vec![
2881 MockLlmClient::tool_call_response(
2883 "tool-1",
2884 "bash",
2885 serde_json::json!({"command": "echo hello"}),
2886 ),
2887 MockLlmClient::text_response("Done!"),
2889 ]));
2890
2891 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2892
2893 let permission_policy = PermissionPolicy::new()
2895 .allow("bash(echo:*)")
2896 .deny("bash(rm:*)");
2897
2898 let config = AgentConfig {
2899 permission_checker: Some(Arc::new(permission_policy)),
2900 ..Default::default()
2901 };
2902
2903 let agent = AgentLoop::new(
2904 mock_client.clone(),
2905 tool_executor,
2906 test_tool_context(),
2907 config,
2908 );
2909 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
2910
2911 assert_eq!(result.text, "Done!");
2912 assert_eq!(result.tool_calls_count, 1);
2913 }
2914
2915 #[tokio::test]
2916 async fn test_agent_streaming_events() {
2917 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
2918 "Hello!",
2919 )]));
2920
2921 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2922 let config = AgentConfig::default();
2923
2924 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2925 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
2926
2927 let mut events = Vec::new();
2929 while let Some(event) = rx.recv().await {
2930 events.push(event);
2931 }
2932
2933 let result = handle.await.unwrap().unwrap();
2934 assert_eq!(result.text, "Hello!");
2935
2936 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
2938 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
2939 }
2940
2941 #[tokio::test]
2942 async fn test_agent_max_tool_rounds() {
2943 let responses: Vec<LlmResponse> = (0..100)
2945 .map(|i| {
2946 MockLlmClient::tool_call_response(
2947 &format!("tool-{}", i),
2948 "bash",
2949 serde_json::json!({"command": "echo loop"}),
2950 )
2951 })
2952 .collect();
2953
2954 let mock_client = Arc::new(MockLlmClient::new(responses));
2955 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2956
2957 let config = AgentConfig {
2958 max_tool_rounds: 3,
2959 ..Default::default()
2960 };
2961
2962 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2963 let result = agent.execute(&[], "Loop forever", None).await;
2964
2965 assert!(result.is_err());
2967 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
2968 }
2969
2970 #[tokio::test]
2971 async fn test_agent_no_permission_policy_defaults_to_ask() {
2972 let mock_client = Arc::new(MockLlmClient::new(vec![
2975 MockLlmClient::tool_call_response(
2976 "tool-1",
2977 "bash",
2978 serde_json::json!({"command": "rm -rf /tmp/test"}),
2979 ),
2980 MockLlmClient::text_response("Denied!"),
2981 ]));
2982
2983 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
2984 let config = AgentConfig {
2985 permission_checker: None, ..Default::default()
2988 };
2989
2990 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
2991 let result = agent.execute(&[], "Delete", None).await.unwrap();
2992
2993 assert_eq!(result.text, "Denied!");
2995 assert_eq!(result.tool_calls_count, 1);
2996 }
2997
2998 #[tokio::test]
2999 async fn test_agent_permission_ask_without_cm_denies() {
3000 let mock_client = Arc::new(MockLlmClient::new(vec![
3003 MockLlmClient::tool_call_response(
3004 "tool-1",
3005 "bash",
3006 serde_json::json!({"command": "echo test"}),
3007 ),
3008 MockLlmClient::text_response("Denied!"),
3009 ]));
3010
3011 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3012
3013 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3017 permission_checker: Some(Arc::new(permission_policy)),
3018 ..Default::default()
3020 };
3021
3022 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3023 let result = agent.execute(&[], "Echo", None).await.unwrap();
3024
3025 assert_eq!(result.text, "Denied!");
3027 assert!(result.tool_calls_count >= 1);
3029 }
3030
3031 #[tokio::test]
3036 async fn test_agent_hitl_approved() {
3037 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3038 use tokio::sync::broadcast;
3039
3040 let mock_client = Arc::new(MockLlmClient::new(vec![
3041 MockLlmClient::tool_call_response(
3042 "tool-1",
3043 "bash",
3044 serde_json::json!({"command": "echo hello"}),
3045 ),
3046 MockLlmClient::text_response("Command executed!"),
3047 ]));
3048
3049 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3050
3051 let (event_tx, _event_rx) = broadcast::channel(100);
3053 let hitl_policy = ConfirmationPolicy {
3054 enabled: true,
3055 ..Default::default()
3056 };
3057 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3058
3059 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3063 permission_checker: Some(Arc::new(permission_policy)),
3064 confirmation_manager: Some(confirmation_manager.clone()),
3065 ..Default::default()
3066 };
3067
3068 let cm_clone = confirmation_manager.clone();
3070 tokio::spawn(async move {
3071 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3073 cm_clone.confirm("tool-1", true, None).await.ok();
3075 });
3076
3077 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3078 let result = agent.execute(&[], "Run echo", None).await.unwrap();
3079
3080 assert_eq!(result.text, "Command executed!");
3081 assert_eq!(result.tool_calls_count, 1);
3082 }
3083
3084 #[tokio::test]
3085 async fn test_agent_hitl_rejected() {
3086 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3087 use tokio::sync::broadcast;
3088
3089 let mock_client = Arc::new(MockLlmClient::new(vec![
3090 MockLlmClient::tool_call_response(
3091 "tool-1",
3092 "bash",
3093 serde_json::json!({"command": "rm -rf /"}),
3094 ),
3095 MockLlmClient::text_response("Understood, I won't do that."),
3096 ]));
3097
3098 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3099
3100 let (event_tx, _event_rx) = broadcast::channel(100);
3102 let hitl_policy = ConfirmationPolicy {
3103 enabled: true,
3104 ..Default::default()
3105 };
3106 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3107
3108 let permission_policy = PermissionPolicy::new();
3110
3111 let config = AgentConfig {
3112 permission_checker: Some(Arc::new(permission_policy)),
3113 confirmation_manager: Some(confirmation_manager.clone()),
3114 ..Default::default()
3115 };
3116
3117 let cm_clone = confirmation_manager.clone();
3119 tokio::spawn(async move {
3120 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3121 cm_clone
3122 .confirm("tool-1", false, Some("Too dangerous".to_string()))
3123 .await
3124 .ok();
3125 });
3126
3127 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3128 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
3129
3130 assert_eq!(result.text, "Understood, I won't do that.");
3132 }
3133
3134 #[tokio::test]
3135 async fn test_agent_hitl_timeout_reject() {
3136 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3137 use tokio::sync::broadcast;
3138
3139 let mock_client = Arc::new(MockLlmClient::new(vec![
3140 MockLlmClient::tool_call_response(
3141 "tool-1",
3142 "bash",
3143 serde_json::json!({"command": "echo test"}),
3144 ),
3145 MockLlmClient::text_response("Timed out, I understand."),
3146 ]));
3147
3148 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3149
3150 let (event_tx, _event_rx) = broadcast::channel(100);
3152 let hitl_policy = ConfirmationPolicy {
3153 enabled: true,
3154 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
3156 ..Default::default()
3157 };
3158 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3159
3160 let permission_policy = PermissionPolicy::new();
3161
3162 let config = AgentConfig {
3163 permission_checker: Some(Arc::new(permission_policy)),
3164 confirmation_manager: Some(confirmation_manager),
3165 ..Default::default()
3166 };
3167
3168 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3170 let result = agent.execute(&[], "Echo", None).await.unwrap();
3171
3172 assert_eq!(result.text, "Timed out, I understand.");
3174 }
3175
3176 #[tokio::test]
3177 async fn test_agent_hitl_timeout_auto_approve() {
3178 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3179 use tokio::sync::broadcast;
3180
3181 let mock_client = Arc::new(MockLlmClient::new(vec![
3182 MockLlmClient::tool_call_response(
3183 "tool-1",
3184 "bash",
3185 serde_json::json!({"command": "echo hello"}),
3186 ),
3187 MockLlmClient::text_response("Auto-approved and executed!"),
3188 ]));
3189
3190 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3191
3192 let (event_tx, _event_rx) = broadcast::channel(100);
3194 let hitl_policy = ConfirmationPolicy {
3195 enabled: true,
3196 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
3198 ..Default::default()
3199 };
3200 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3201
3202 let permission_policy = PermissionPolicy::new();
3203
3204 let config = AgentConfig {
3205 permission_checker: Some(Arc::new(permission_policy)),
3206 confirmation_manager: Some(confirmation_manager),
3207 ..Default::default()
3208 };
3209
3210 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3212 let result = agent.execute(&[], "Echo", None).await.unwrap();
3213
3214 assert_eq!(result.text, "Auto-approved and executed!");
3216 assert_eq!(result.tool_calls_count, 1);
3217 }
3218
3219 #[tokio::test]
3220 async fn test_agent_hitl_confirmation_events() {
3221 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3222 use tokio::sync::broadcast;
3223
3224 let mock_client = Arc::new(MockLlmClient::new(vec![
3225 MockLlmClient::tool_call_response(
3226 "tool-1",
3227 "bash",
3228 serde_json::json!({"command": "echo test"}),
3229 ),
3230 MockLlmClient::text_response("Done!"),
3231 ]));
3232
3233 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3234
3235 let (event_tx, mut event_rx) = broadcast::channel(100);
3237 let hitl_policy = ConfirmationPolicy {
3238 enabled: true,
3239 default_timeout_ms: 5000, ..Default::default()
3241 };
3242 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3243
3244 let permission_policy = PermissionPolicy::new();
3245
3246 let config = AgentConfig {
3247 permission_checker: Some(Arc::new(permission_policy)),
3248 confirmation_manager: Some(confirmation_manager.clone()),
3249 ..Default::default()
3250 };
3251
3252 let cm_clone = confirmation_manager.clone();
3254 let event_handle = tokio::spawn(async move {
3255 let mut events = Vec::new();
3256 while let Ok(event) = event_rx.recv().await {
3258 events.push(event.clone());
3259 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3260 cm_clone.confirm(&tool_id, true, None).await.ok();
3262 if let Ok(recv_event) = event_rx.recv().await {
3264 events.push(recv_event);
3265 }
3266 break;
3267 }
3268 }
3269 events
3270 });
3271
3272 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3273 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3274
3275 let events = event_handle.await.unwrap();
3277 assert!(
3278 events
3279 .iter()
3280 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3281 "Should have ConfirmationRequired event"
3282 );
3283 assert!(
3284 events
3285 .iter()
3286 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3287 "Should have ConfirmationReceived event with approved=true"
3288 );
3289 }
3290
3291 #[tokio::test]
3292 async fn test_agent_hitl_disabled_auto_executes() {
3293 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3295 use tokio::sync::broadcast;
3296
3297 let mock_client = Arc::new(MockLlmClient::new(vec![
3298 MockLlmClient::tool_call_response(
3299 "tool-1",
3300 "bash",
3301 serde_json::json!({"command": "echo auto"}),
3302 ),
3303 MockLlmClient::text_response("Auto executed!"),
3304 ]));
3305
3306 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3307
3308 let (event_tx, _event_rx) = broadcast::channel(100);
3310 let hitl_policy = ConfirmationPolicy {
3311 enabled: false, ..Default::default()
3313 };
3314 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3315
3316 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3319 permission_checker: Some(Arc::new(permission_policy)),
3320 confirmation_manager: Some(confirmation_manager),
3321 ..Default::default()
3322 };
3323
3324 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3325 let result = agent.execute(&[], "Echo", None).await.unwrap();
3326
3327 assert_eq!(result.text, "Auto executed!");
3329 assert_eq!(result.tool_calls_count, 1);
3330 }
3331
3332 #[tokio::test]
3333 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3334 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3336 use tokio::sync::broadcast;
3337
3338 let mock_client = Arc::new(MockLlmClient::new(vec![
3339 MockLlmClient::tool_call_response(
3340 "tool-1",
3341 "bash",
3342 serde_json::json!({"command": "rm -rf /"}),
3343 ),
3344 MockLlmClient::text_response("Blocked by permission."),
3345 ]));
3346
3347 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3348
3349 let (event_tx, mut event_rx) = broadcast::channel(100);
3351 let hitl_policy = ConfirmationPolicy {
3352 enabled: true,
3353 ..Default::default()
3354 };
3355 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3356
3357 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3359
3360 let config = AgentConfig {
3361 permission_checker: Some(Arc::new(permission_policy)),
3362 confirmation_manager: Some(confirmation_manager),
3363 ..Default::default()
3364 };
3365
3366 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3367 let result = agent.execute(&[], "Delete", None).await.unwrap();
3368
3369 assert_eq!(result.text, "Blocked by permission.");
3371
3372 let mut found_confirmation = false;
3374 while let Ok(event) = event_rx.try_recv() {
3375 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3376 found_confirmation = true;
3377 }
3378 }
3379 assert!(
3380 !found_confirmation,
3381 "HITL should not be triggered when permission is Deny"
3382 );
3383 }
3384
3385 #[tokio::test]
3386 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3387 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3390 use tokio::sync::broadcast;
3391
3392 let mock_client = Arc::new(MockLlmClient::new(vec![
3393 MockLlmClient::tool_call_response(
3394 "tool-1",
3395 "bash",
3396 serde_json::json!({"command": "echo hello"}),
3397 ),
3398 MockLlmClient::text_response("Allowed!"),
3399 ]));
3400
3401 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3402
3403 let (event_tx, mut event_rx) = broadcast::channel(100);
3405 let hitl_policy = ConfirmationPolicy {
3406 enabled: true,
3407 ..Default::default()
3408 };
3409 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3410
3411 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3413
3414 let config = AgentConfig {
3415 permission_checker: Some(Arc::new(permission_policy)),
3416 confirmation_manager: Some(confirmation_manager.clone()),
3417 ..Default::default()
3418 };
3419
3420 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3421 let result = agent.execute(&[], "Echo", None).await.unwrap();
3422
3423 assert_eq!(result.text, "Allowed!");
3425
3426 let mut found_confirmation = false;
3428 while let Ok(event) = event_rx.try_recv() {
3429 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3430 found_confirmation = true;
3431 }
3432 }
3433 assert!(
3434 !found_confirmation,
3435 "Permission Allow should skip HITL confirmation"
3436 );
3437 }
3438
3439 #[tokio::test]
3440 async fn test_agent_hitl_multiple_tool_calls() {
3441 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3443 use tokio::sync::broadcast;
3444
3445 let mock_client = Arc::new(MockLlmClient::new(vec![
3446 LlmResponse {
3448 message: Message {
3449 role: "assistant".to_string(),
3450 content: vec![
3451 ContentBlock::ToolUse {
3452 id: "tool-1".to_string(),
3453 name: "bash".to_string(),
3454 input: serde_json::json!({"command": "echo first"}),
3455 },
3456 ContentBlock::ToolUse {
3457 id: "tool-2".to_string(),
3458 name: "bash".to_string(),
3459 input: serde_json::json!({"command": "echo second"}),
3460 },
3461 ],
3462 reasoning_content: None,
3463 },
3464 usage: TokenUsage {
3465 prompt_tokens: 10,
3466 completion_tokens: 5,
3467 total_tokens: 15,
3468 cache_read_tokens: None,
3469 cache_write_tokens: None,
3470 },
3471 stop_reason: Some("tool_use".to_string()),
3472 },
3473 MockLlmClient::text_response("Both executed!"),
3474 ]));
3475
3476 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3477
3478 let (event_tx, _event_rx) = broadcast::channel(100);
3480 let hitl_policy = ConfirmationPolicy {
3481 enabled: true,
3482 default_timeout_ms: 5000,
3483 ..Default::default()
3484 };
3485 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3486
3487 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3490 permission_checker: Some(Arc::new(permission_policy)),
3491 confirmation_manager: Some(confirmation_manager.clone()),
3492 ..Default::default()
3493 };
3494
3495 let cm_clone = confirmation_manager.clone();
3497 tokio::spawn(async move {
3498 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3499 cm_clone.confirm("tool-1", true, None).await.ok();
3500 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3501 cm_clone.confirm("tool-2", true, None).await.ok();
3502 });
3503
3504 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3505 let result = agent.execute(&[], "Run both", None).await.unwrap();
3506
3507 assert_eq!(result.text, "Both executed!");
3508 assert_eq!(result.tool_calls_count, 2);
3509 }
3510
3511 #[tokio::test]
3512 async fn test_agent_hitl_partial_approval() {
3513 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3515 use tokio::sync::broadcast;
3516
3517 let mock_client = Arc::new(MockLlmClient::new(vec![
3518 LlmResponse {
3520 message: Message {
3521 role: "assistant".to_string(),
3522 content: vec![
3523 ContentBlock::ToolUse {
3524 id: "tool-1".to_string(),
3525 name: "bash".to_string(),
3526 input: serde_json::json!({"command": "echo safe"}),
3527 },
3528 ContentBlock::ToolUse {
3529 id: "tool-2".to_string(),
3530 name: "bash".to_string(),
3531 input: serde_json::json!({"command": "rm -rf /"}),
3532 },
3533 ],
3534 reasoning_content: None,
3535 },
3536 usage: TokenUsage {
3537 prompt_tokens: 10,
3538 completion_tokens: 5,
3539 total_tokens: 15,
3540 cache_read_tokens: None,
3541 cache_write_tokens: None,
3542 },
3543 stop_reason: Some("tool_use".to_string()),
3544 },
3545 MockLlmClient::text_response("First worked, second rejected."),
3546 ]));
3547
3548 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3549
3550 let (event_tx, _event_rx) = broadcast::channel(100);
3551 let hitl_policy = ConfirmationPolicy {
3552 enabled: true,
3553 default_timeout_ms: 5000,
3554 ..Default::default()
3555 };
3556 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3557
3558 let permission_policy = PermissionPolicy::new();
3559
3560 let config = AgentConfig {
3561 permission_checker: Some(Arc::new(permission_policy)),
3562 confirmation_manager: Some(confirmation_manager.clone()),
3563 ..Default::default()
3564 };
3565
3566 let cm_clone = confirmation_manager.clone();
3568 tokio::spawn(async move {
3569 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3570 cm_clone.confirm("tool-1", true, None).await.ok();
3571 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3572 cm_clone
3573 .confirm("tool-2", false, Some("Dangerous".to_string()))
3574 .await
3575 .ok();
3576 });
3577
3578 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3579 let result = agent.execute(&[], "Run both", None).await.unwrap();
3580
3581 assert_eq!(result.text, "First worked, second rejected.");
3582 assert_eq!(result.tool_calls_count, 2);
3583 }
3584
3585 #[tokio::test]
3586 async fn test_agent_hitl_yolo_mode_auto_approves() {
3587 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3589 use tokio::sync::broadcast;
3590
3591 let mock_client = Arc::new(MockLlmClient::new(vec![
3592 MockLlmClient::tool_call_response(
3593 "tool-1",
3594 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3596 ),
3597 MockLlmClient::text_response("File read!"),
3598 ]));
3599
3600 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3601
3602 let (event_tx, mut event_rx) = broadcast::channel(100);
3604 let mut yolo_lanes = std::collections::HashSet::new();
3605 yolo_lanes.insert(SessionLane::Query);
3606 let hitl_policy = ConfirmationPolicy {
3607 enabled: true,
3608 yolo_lanes, ..Default::default()
3610 };
3611 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3612
3613 let permission_policy = PermissionPolicy::new();
3614
3615 let config = AgentConfig {
3616 permission_checker: Some(Arc::new(permission_policy)),
3617 confirmation_manager: Some(confirmation_manager),
3618 ..Default::default()
3619 };
3620
3621 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3622 let result = agent.execute(&[], "Read file", None).await.unwrap();
3623
3624 assert_eq!(result.text, "File read!");
3626
3627 let mut found_confirmation = false;
3629 while let Ok(event) = event_rx.try_recv() {
3630 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3631 found_confirmation = true;
3632 }
3633 }
3634 assert!(
3635 !found_confirmation,
3636 "YOLO mode should not trigger confirmation"
3637 );
3638 }
3639
3640 #[tokio::test]
3641 async fn test_agent_config_with_all_options() {
3642 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3643 use tokio::sync::broadcast;
3644
3645 let (event_tx, _) = broadcast::channel(100);
3646 let hitl_policy = ConfirmationPolicy::default();
3647 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3648
3649 let permission_policy = PermissionPolicy::new().allow("bash(*)");
3650
3651 let config = AgentConfig {
3652 prompt_slots: SystemPromptSlots {
3653 extra: Some("Test system prompt".to_string()),
3654 ..Default::default()
3655 },
3656 tools: vec![],
3657 max_tool_rounds: 10,
3658 permission_checker: Some(Arc::new(permission_policy)),
3659 confirmation_manager: Some(confirmation_manager),
3660 context_providers: vec![],
3661 planning_enabled: false,
3662 goal_tracking: false,
3663 hook_engine: None,
3664 skill_registry: None,
3665 ..AgentConfig::default()
3666 };
3667
3668 assert!(config.prompt_slots.build().contains("Test system prompt"));
3669 assert_eq!(config.max_tool_rounds, 10);
3670 assert!(config.permission_checker.is_some());
3671 assert!(config.confirmation_manager.is_some());
3672 assert!(config.context_providers.is_empty());
3673
3674 let debug_str = format!("{:?}", config);
3676 assert!(debug_str.contains("AgentConfig"));
3677 assert!(debug_str.contains("permission_checker: true"));
3678 assert!(debug_str.contains("confirmation_manager: true"));
3679 assert!(debug_str.contains("context_providers: 0"));
3680 }
3681
3682 use crate::context::{ContextItem, ContextType};
3687
3688 struct MockContextProvider {
3690 name: String,
3691 items: Vec<ContextItem>,
3692 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
3693 }
3694
3695 impl MockContextProvider {
3696 fn new(name: &str) -> Self {
3697 Self {
3698 name: name.to_string(),
3699 items: Vec::new(),
3700 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
3701 }
3702 }
3703
3704 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
3705 self.items = items;
3706 self
3707 }
3708 }
3709
3710 #[async_trait::async_trait]
3711 impl ContextProvider for MockContextProvider {
3712 fn name(&self) -> &str {
3713 &self.name
3714 }
3715
3716 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
3717 let mut result = ContextResult::new(&self.name);
3718 for item in &self.items {
3719 result.add_item(item.clone());
3720 }
3721 Ok(result)
3722 }
3723
3724 async fn on_turn_complete(
3725 &self,
3726 session_id: &str,
3727 prompt: &str,
3728 response: &str,
3729 ) -> anyhow::Result<()> {
3730 let mut calls = self.on_turn_calls.write().await;
3731 calls.push((
3732 session_id.to_string(),
3733 prompt.to_string(),
3734 response.to_string(),
3735 ));
3736 Ok(())
3737 }
3738 }
3739
3740 #[tokio::test]
3741 async fn test_agent_with_context_provider() {
3742 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3743 "Response using context",
3744 )]));
3745
3746 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3747
3748 let provider =
3749 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
3750 "ctx-1",
3751 ContextType::Resource,
3752 "Relevant context here",
3753 )
3754 .with_source("test://docs/example")]);
3755
3756 let config = AgentConfig {
3757 prompt_slots: SystemPromptSlots {
3758 extra: Some("You are helpful.".to_string()),
3759 ..Default::default()
3760 },
3761 context_providers: vec![Arc::new(provider)],
3762 ..Default::default()
3763 };
3764
3765 let agent = AgentLoop::new(
3766 mock_client.clone(),
3767 tool_executor,
3768 test_tool_context(),
3769 config,
3770 );
3771 let result = agent.execute(&[], "What is X?", None).await.unwrap();
3772
3773 assert_eq!(result.text, "Response using context");
3774 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3775 }
3776
3777 #[tokio::test]
3778 async fn test_agent_context_provider_events() {
3779 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3780 "Answer",
3781 )]));
3782
3783 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3784
3785 let provider =
3786 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
3787 "item-1",
3788 ContextType::Memory,
3789 "Memory content",
3790 )
3791 .with_token_count(50)]);
3792
3793 let config = AgentConfig {
3794 context_providers: vec![Arc::new(provider)],
3795 ..Default::default()
3796 };
3797
3798 let (tx, mut rx) = mpsc::channel(100);
3799 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3800 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
3801
3802 let mut events = Vec::new();
3804 while let Ok(event) = rx.try_recv() {
3805 events.push(event);
3806 }
3807
3808 assert!(
3810 events
3811 .iter()
3812 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3813 "Should have ContextResolving event"
3814 );
3815 assert!(
3816 events
3817 .iter()
3818 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
3819 "Should have ContextResolved event"
3820 );
3821
3822 for event in &events {
3824 if let AgentEvent::ContextResolved {
3825 total_items,
3826 total_tokens,
3827 } = event
3828 {
3829 assert_eq!(*total_items, 1);
3830 assert_eq!(*total_tokens, 50);
3831 }
3832 }
3833 }
3834
3835 #[tokio::test]
3836 async fn test_agent_multiple_context_providers() {
3837 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3838 "Combined response",
3839 )]));
3840
3841 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3842
3843 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
3844 "p1-1",
3845 ContextType::Resource,
3846 "Resource from P1",
3847 )
3848 .with_token_count(100)]);
3849
3850 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
3851 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
3852 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
3853 ]);
3854
3855 let config = AgentConfig {
3856 prompt_slots: SystemPromptSlots {
3857 extra: Some("Base system prompt.".to_string()),
3858 ..Default::default()
3859 },
3860 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
3861 ..Default::default()
3862 };
3863
3864 let (tx, mut rx) = mpsc::channel(100);
3865 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3866 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
3867
3868 assert_eq!(result.text, "Combined response");
3869
3870 while let Ok(event) = rx.try_recv() {
3872 if let AgentEvent::ContextResolved {
3873 total_items,
3874 total_tokens,
3875 } = event
3876 {
3877 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
3880 }
3881 }
3882
3883 #[tokio::test]
3884 async fn test_agent_no_context_providers() {
3885 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3886 "No context",
3887 )]));
3888
3889 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3890
3891 let config = AgentConfig::default();
3893
3894 let (tx, mut rx) = mpsc::channel(100);
3895 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3896 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
3897
3898 assert_eq!(result.text, "No context");
3899
3900 let mut events = Vec::new();
3902 while let Ok(event) = rx.try_recv() {
3903 events.push(event);
3904 }
3905
3906 assert!(
3907 !events
3908 .iter()
3909 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
3910 "Should NOT have ContextResolving event"
3911 );
3912 }
3913
3914 #[tokio::test]
3915 async fn test_agent_context_on_turn_complete() {
3916 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3917 "Final response",
3918 )]));
3919
3920 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3921
3922 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3923 let on_turn_calls = provider.on_turn_calls.clone();
3924
3925 let config = AgentConfig {
3926 context_providers: vec![provider],
3927 ..Default::default()
3928 };
3929
3930 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3931
3932 let result = agent
3934 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
3935 .await
3936 .unwrap();
3937
3938 assert_eq!(result.text, "Final response");
3939
3940 let calls = on_turn_calls.read().await;
3942 assert_eq!(calls.len(), 1);
3943 assert_eq!(calls[0].0, "sess-123");
3944 assert_eq!(calls[0].1, "User prompt");
3945 assert_eq!(calls[0].2, "Final response");
3946 }
3947
3948 #[tokio::test]
3949 async fn test_agent_context_on_turn_complete_no_session() {
3950 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3951 "Response",
3952 )]));
3953
3954 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3955
3956 let provider = Arc::new(MockContextProvider::new("memory-provider"));
3957 let on_turn_calls = provider.on_turn_calls.clone();
3958
3959 let config = AgentConfig {
3960 context_providers: vec![provider],
3961 ..Default::default()
3962 };
3963
3964 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3965
3966 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
3968
3969 let calls = on_turn_calls.read().await;
3971 assert!(calls.is_empty());
3972 }
3973
3974 #[tokio::test]
3975 async fn test_agent_build_augmented_system_prompt() {
3976 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
3977
3978 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3979
3980 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
3981 "doc-1",
3982 ContextType::Resource,
3983 "Auth uses JWT tokens.",
3984 )
3985 .with_source("viking://docs/auth")]);
3986
3987 let config = AgentConfig {
3988 prompt_slots: SystemPromptSlots {
3989 extra: Some("You are helpful.".to_string()),
3990 ..Default::default()
3991 },
3992 context_providers: vec![Arc::new(provider)],
3993 ..Default::default()
3994 };
3995
3996 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3997
3998 let context_results = agent.resolve_context("test", None).await;
4000 let augmented = agent.build_augmented_system_prompt(&context_results);
4001
4002 let augmented_str = augmented.unwrap();
4003 assert!(augmented_str.contains("You are helpful."));
4004 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4005 assert!(augmented_str.contains("Auth uses JWT tokens."));
4006 }
4007
4008 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4014 let mut events = Vec::new();
4015 while let Ok(event) = rx.try_recv() {
4016 events.push(event);
4017 }
4018 while let Some(event) = rx.recv().await {
4020 events.push(event);
4021 }
4022 events
4023 }
4024
4025 #[tokio::test]
4026 async fn test_agent_multi_turn_tool_chain() {
4027 let mock_client = Arc::new(MockLlmClient::new(vec![
4029 MockLlmClient::tool_call_response(
4031 "t1",
4032 "bash",
4033 serde_json::json!({"command": "echo step1"}),
4034 ),
4035 MockLlmClient::tool_call_response(
4037 "t2",
4038 "bash",
4039 serde_json::json!({"command": "echo step2"}),
4040 ),
4041 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4043 ]));
4044
4045 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4046 let config = AgentConfig::default();
4047
4048 let agent = AgentLoop::new(
4049 mock_client.clone(),
4050 tool_executor,
4051 test_tool_context(),
4052 config,
4053 );
4054 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4055
4056 assert_eq!(result.text, "Completed both steps: step1 then step2");
4057 assert_eq!(result.tool_calls_count, 2);
4058 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4059
4060 assert_eq!(result.messages[0].role, "user");
4062 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4068 }
4069
4070 #[tokio::test]
4071 async fn test_agent_conversation_history_preserved() {
4072 let existing_history = vec![
4074 Message::user("What is Rust?"),
4075 Message {
4076 role: "assistant".to_string(),
4077 content: vec![ContentBlock::Text {
4078 text: "Rust is a systems programming language.".to_string(),
4079 }],
4080 reasoning_content: None,
4081 },
4082 ];
4083
4084 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4085 "Rust was created by Graydon Hoare at Mozilla.",
4086 )]));
4087
4088 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4089 let agent = AgentLoop::new(
4090 mock_client.clone(),
4091 tool_executor,
4092 test_tool_context(),
4093 AgentConfig::default(),
4094 );
4095
4096 let result = agent
4097 .execute(&existing_history, "Who created it?", None)
4098 .await
4099 .unwrap();
4100
4101 assert_eq!(result.messages.len(), 4);
4103 assert_eq!(result.messages[0].text(), "What is Rust?");
4104 assert_eq!(
4105 result.messages[1].text(),
4106 "Rust is a systems programming language."
4107 );
4108 assert_eq!(result.messages[2].text(), "Who created it?");
4109 assert_eq!(
4110 result.messages[3].text(),
4111 "Rust was created by Graydon Hoare at Mozilla."
4112 );
4113 }
4114
4115 #[tokio::test]
4116 async fn test_agent_event_stream_completeness() {
4117 let mock_client = Arc::new(MockLlmClient::new(vec![
4119 MockLlmClient::tool_call_response(
4120 "t1",
4121 "bash",
4122 serde_json::json!({"command": "echo hi"}),
4123 ),
4124 MockLlmClient::text_response("Done"),
4125 ]));
4126
4127 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4128 let agent = AgentLoop::new(
4129 mock_client,
4130 tool_executor,
4131 test_tool_context(),
4132 AgentConfig::default(),
4133 );
4134
4135 let (tx, rx) = mpsc::channel(100);
4136 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4137 assert_eq!(result.text, "Done");
4138
4139 let events = collect_events(rx).await;
4140
4141 let event_types: Vec<&str> = events
4143 .iter()
4144 .map(|e| match e {
4145 AgentEvent::Start { .. } => "Start",
4146 AgentEvent::TurnStart { .. } => "TurnStart",
4147 AgentEvent::TurnEnd { .. } => "TurnEnd",
4148 AgentEvent::ToolEnd { .. } => "ToolEnd",
4149 AgentEvent::End { .. } => "End",
4150 _ => "Other",
4151 })
4152 .collect();
4153
4154 assert_eq!(event_types.first(), Some(&"Start"));
4156 assert_eq!(event_types.last(), Some(&"End"));
4157
4158 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4160 assert_eq!(turn_starts, 2);
4161
4162 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4164 assert_eq!(tool_ends, 1);
4165 }
4166
4167 #[tokio::test]
4168 async fn test_agent_multiple_tools_single_turn() {
4169 let mock_client = Arc::new(MockLlmClient::new(vec![
4171 LlmResponse {
4172 message: Message {
4173 role: "assistant".to_string(),
4174 content: vec![
4175 ContentBlock::ToolUse {
4176 id: "t1".to_string(),
4177 name: "bash".to_string(),
4178 input: serde_json::json!({"command": "echo first"}),
4179 },
4180 ContentBlock::ToolUse {
4181 id: "t2".to_string(),
4182 name: "bash".to_string(),
4183 input: serde_json::json!({"command": "echo second"}),
4184 },
4185 ],
4186 reasoning_content: None,
4187 },
4188 usage: TokenUsage {
4189 prompt_tokens: 10,
4190 completion_tokens: 5,
4191 total_tokens: 15,
4192 cache_read_tokens: None,
4193 cache_write_tokens: None,
4194 },
4195 stop_reason: Some("tool_use".to_string()),
4196 },
4197 MockLlmClient::text_response("Both commands ran"),
4198 ]));
4199
4200 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4201 let agent = AgentLoop::new(
4202 mock_client.clone(),
4203 tool_executor,
4204 test_tool_context(),
4205 AgentConfig::default(),
4206 );
4207
4208 let result = agent.execute(&[], "Run both", None).await.unwrap();
4209
4210 assert_eq!(result.text, "Both commands ran");
4211 assert_eq!(result.tool_calls_count, 2);
4212 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4216 assert_eq!(result.messages[1].role, "assistant");
4217 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4220 }
4221
4222 #[tokio::test]
4223 async fn test_agent_token_usage_accumulation() {
4224 let mock_client = Arc::new(MockLlmClient::new(vec![
4226 MockLlmClient::tool_call_response(
4227 "t1",
4228 "bash",
4229 serde_json::json!({"command": "echo x"}),
4230 ),
4231 MockLlmClient::text_response("Done"),
4232 ]));
4233
4234 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4235 let agent = AgentLoop::new(
4236 mock_client,
4237 tool_executor,
4238 test_tool_context(),
4239 AgentConfig::default(),
4240 );
4241
4242 let result = agent.execute(&[], "test", None).await.unwrap();
4243
4244 assert_eq!(result.usage.prompt_tokens, 20);
4247 assert_eq!(result.usage.completion_tokens, 10);
4248 assert_eq!(result.usage.total_tokens, 30);
4249 }
4250
4251 #[tokio::test]
4252 async fn test_agent_system_prompt_passed() {
4253 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4255 "I am a coding assistant.",
4256 )]));
4257
4258 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4259 let config = AgentConfig {
4260 prompt_slots: SystemPromptSlots {
4261 extra: Some("You are a coding assistant.".to_string()),
4262 ..Default::default()
4263 },
4264 ..Default::default()
4265 };
4266
4267 let agent = AgentLoop::new(
4268 mock_client.clone(),
4269 tool_executor,
4270 test_tool_context(),
4271 config,
4272 );
4273 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4274
4275 assert_eq!(result.text, "I am a coding assistant.");
4276 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4277 }
4278
4279 #[tokio::test]
4280 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4281 let mut responses = Vec::new();
4283 for i in 0..15 {
4284 responses.push(MockLlmClient::tool_call_response(
4285 &format!("t{}", i),
4286 "bash",
4287 serde_json::json!({"command": format!("echo round{}", i)}),
4288 ));
4289 }
4290
4291 let mock_client = Arc::new(MockLlmClient::new(responses));
4292 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4293 let config = AgentConfig {
4294 max_tool_rounds: 5,
4295 ..Default::default()
4296 };
4297
4298 let agent = AgentLoop::new(
4299 mock_client.clone(),
4300 tool_executor,
4301 test_tool_context(),
4302 config,
4303 );
4304 let result = agent.execute(&[], "Loop forever", None).await;
4305
4306 assert!(result.is_err());
4307 let err = result.unwrap_err().to_string();
4308 assert!(err.contains("Max tool rounds (5) exceeded"));
4309 }
4310
4311 #[tokio::test]
4312 async fn test_agent_end_event_contains_final_text() {
4313 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4314 "Final answer here",
4315 )]));
4316
4317 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4318 let agent = AgentLoop::new(
4319 mock_client,
4320 tool_executor,
4321 test_tool_context(),
4322 AgentConfig::default(),
4323 );
4324
4325 let (tx, rx) = mpsc::channel(100);
4326 agent.execute(&[], "test", Some(tx)).await.unwrap();
4327
4328 let events = collect_events(rx).await;
4329 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4330 assert!(end_event.is_some());
4331
4332 if let AgentEvent::End { text, usage } = end_event.unwrap() {
4333 assert_eq!(text, "Final answer here");
4334 assert_eq!(usage.total_tokens, 15);
4335 }
4336 }
4337}
4338
4339#[cfg(test)]
4340mod extra_agent_tests {
4341 use super::*;
4342 use crate::agent::tests::MockLlmClient;
4343 use crate::queue::SessionQueueConfig;
4344 use crate::tools::ToolExecutor;
4345 use std::path::PathBuf;
4346 use std::sync::atomic::{AtomicUsize, Ordering};
4347
4348 fn test_tool_context() -> ToolContext {
4349 ToolContext::new(PathBuf::from("/tmp"))
4350 }
4351
/// The `Debug` rendering of `AgentConfig` must name the type and include
/// the `planning_enabled` field.
#[test]
fn test_agent_config_debug() {
    // Start from the defaults and override the fields of interest.
    let mut config = AgentConfig::default();
    config.prompt_slots = SystemPromptSlots {
        extra: Some("You are helpful".to_string()),
        ..Default::default()
    };
    config.tools = Vec::new();
    config.max_tool_rounds = 10;
    config.permission_checker = None;
    config.confirmation_manager = None;
    config.context_providers = Vec::new();
    config.planning_enabled = true;
    config.goal_tracking = false;
    config.hook_engine = None;
    config.skill_registry = None;

    let rendered = format!("{:?}", config);
    assert!(rendered.contains("AgentConfig"));
    assert!(rendered.contains("planning_enabled"));
}
4378
/// `AgentConfig::default()` must use `MAX_TOOL_ROUNDS`, keep planning and
/// goal tracking off, and start with no context providers.
#[test]
fn test_agent_config_default_values() {
    let defaults = AgentConfig::default();
    let AgentConfig {
        max_tool_rounds,
        planning_enabled,
        goal_tracking,
        context_providers,
        ..
    } = &defaults;

    assert_eq!(*max_tool_rounds, MAX_TOOL_ROUNDS);
    assert!(!*planning_enabled);
    assert!(!*goal_tracking);
    assert!(context_providers.is_empty());
}
4387
/// `AgentEvent::Start` serializes with the `agent_start` tag and its prompt.
#[test]
fn test_agent_event_serialize_start() {
    let serialized = serde_json::to_string(&AgentEvent::Start {
        prompt: "Hello".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("agent_start"));
    assert!(serialized.contains("Hello"));
}
4401
/// `AgentEvent::TextDelta` serializes with the `text_delta` tag.
#[test]
fn test_agent_event_serialize_text_delta() {
    let serialized = serde_json::to_string(&AgentEvent::TextDelta {
        text: "chunk".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("text_delta"));
}
4410
/// `AgentEvent::ToolStart` serializes with the `tool_start` tag and the tool name.
#[test]
fn test_agent_event_serialize_tool_start() {
    let serialized = serde_json::to_string(&AgentEvent::ToolStart {
        id: "t1".to_string(),
        name: "bash".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("tool_start"));
    assert!(serialized.contains("bash"));
}
4421
/// `AgentEvent::ToolEnd` serializes with the `tool_end` tag.
#[test]
fn test_agent_event_serialize_tool_end() {
    let serialized = serde_json::to_string(&AgentEvent::ToolEnd {
        id: "t1".to_string(),
        name: "bash".to_string(),
        output: "hello".to_string(),
        exit_code: 0,
    })
    .unwrap();

    assert!(serialized.contains("tool_end"));
}
4433
/// `AgentEvent::Error` serializes with `error` and its message present.
#[test]
fn test_agent_event_serialize_error() {
    let serialized = serde_json::to_string(&AgentEvent::Error {
        message: "oops".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("error"));
    assert!(serialized.contains("oops"));
}
4443
/// `AgentEvent::ConfirmationRequired` serializes with the
/// `confirmation_required` tag.
#[test]
fn test_agent_event_serialize_confirmation_required() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationRequired {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({"cmd": "rm"}),
        timeout_ms: 30000,
    })
    .unwrap();

    assert!(serialized.contains("confirmation_required"));
}
4455
/// `AgentEvent::ConfirmationReceived` serializes with the
/// `confirmation_received` tag.
#[test]
fn test_agent_event_serialize_confirmation_received() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationReceived {
        tool_id: "t1".to_string(),
        approved: true,
        reason: Some("safe".to_string()),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_received"));
}
4466
/// `AgentEvent::ConfirmationTimeout` serializes with the
/// `confirmation_timeout` tag.
#[test]
fn test_agent_event_serialize_confirmation_timeout() {
    let serialized = serde_json::to_string(&AgentEvent::ConfirmationTimeout {
        tool_id: "t1".to_string(),
        action_taken: "rejected".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("confirmation_timeout"));
}
4476
/// `AgentEvent::ExternalTaskPending` serializes with the
/// `external_task_pending` tag.
#[test]
fn test_agent_event_serialize_external_task_pending() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskPending {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        lane: crate::hitl::SessionLane::Execute,
        command_type: "bash".to_string(),
        payload: serde_json::json!({}),
        timeout_ms: 60000,
    })
    .unwrap();

    assert!(serialized.contains("external_task_pending"));
}
4490
/// `AgentEvent::ExternalTaskCompleted` serializes with the
/// `external_task_completed` tag.
#[test]
fn test_agent_event_serialize_external_task_completed() {
    let serialized = serde_json::to_string(&AgentEvent::ExternalTaskCompleted {
        task_id: "task-1".to_string(),
        session_id: "sess-1".to_string(),
        success: false,
    })
    .unwrap();

    assert!(serialized.contains("external_task_completed"));
}
4501
/// `AgentEvent::PermissionDenied` serializes with the `permission_denied` tag.
#[test]
fn test_agent_event_serialize_permission_denied() {
    let serialized = serde_json::to_string(&AgentEvent::PermissionDenied {
        tool_id: "t1".to_string(),
        tool_name: "bash".to_string(),
        args: serde_json::json!({}),
        reason: "denied".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("permission_denied"));
}
4513
/// `AgentEvent::ContextCompacted` serializes with the `context_compacted` tag.
#[test]
fn test_agent_event_serialize_context_compacted() {
    let serialized = serde_json::to_string(&AgentEvent::ContextCompacted {
        session_id: "sess-1".to_string(),
        before_messages: 100,
        after_messages: 20,
        percent_before: 0.85,
    })
    .unwrap();

    assert!(serialized.contains("context_compacted"));
}
4525
/// `AgentEvent::TurnStart` serializes with the `turn_start` tag.
#[test]
fn test_agent_event_serialize_turn_start() {
    let serialized = serde_json::to_string(&AgentEvent::TurnStart { turn: 3 }).unwrap();

    assert!(serialized.contains("turn_start"));
}
4532
/// `AgentEvent::TurnEnd` serializes with the `turn_end` tag.
#[test]
fn test_agent_event_serialize_turn_end() {
    let serialized = serde_json::to_string(&AgentEvent::TurnEnd {
        turn: 3,
        usage: TokenUsage::default(),
    })
    .unwrap();

    assert!(serialized.contains("turn_end"));
}
4542
/// `AgentEvent::End` serializes with the `agent_end` tag.
#[test]
fn test_agent_event_serialize_end() {
    let final_usage = TokenUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        cache_read_tokens: None,
        cache_write_tokens: None,
    };
    let serialized = serde_json::to_string(&AgentEvent::End {
        text: "Done".to_string(),
        usage: final_usage,
    })
    .unwrap();

    assert!(serialized.contains("agent_end"));
}
4558
/// An `AgentResult` built by hand exposes the values it was given.
#[test]
fn test_agent_result_fields() {
    let sample = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };

    let AgentResult {
        text,
        messages,
        tool_calls_count,
        ..
    } = &sample;
    assert_eq!(text, "output");
    assert_eq!(messages.len(), 1);
    assert_eq!(*tool_calls_count, 3);
}
4575
/// `AgentEvent::ContextResolving` serializes with the `context_resolving`
/// tag and the provider names.
#[test]
fn test_agent_event_serialize_context_resolving() {
    let serialized = serde_json::to_string(&AgentEvent::ContextResolving {
        providers: vec!["provider1".to_string(), "provider2".to_string()],
    })
    .unwrap();

    assert!(serialized.contains("context_resolving"));
    assert!(serialized.contains("provider1"));
}
4589
/// `AgentEvent::ContextResolved` serializes with the `context_resolved`
/// tag and its token total.
#[test]
fn test_agent_event_serialize_context_resolved() {
    let serialized = serde_json::to_string(&AgentEvent::ContextResolved {
        total_items: 5,
        total_tokens: 1000,
    })
    .unwrap();

    assert!(serialized.contains("context_resolved"));
    assert!(serialized.contains("1000"));
}
4600
/// `AgentEvent::CommandDeadLettered` serializes with the
/// `command_dead_lettered` tag and the command id.
#[test]
fn test_agent_event_serialize_command_dead_lettered() {
    let serialized = serde_json::to_string(&AgentEvent::CommandDeadLettered {
        command_id: "cmd-1".to_string(),
        command_type: "bash".to_string(),
        lane: "execute".to_string(),
        error: "timeout".to_string(),
        attempts: 3,
    })
    .unwrap();

    assert!(serialized.contains("command_dead_lettered"));
    assert!(serialized.contains("cmd-1"));
}
4614
/// `AgentEvent::CommandRetry` serializes with the `command_retry` tag and
/// the command id.
#[test]
fn test_agent_event_serialize_command_retry() {
    let serialized = serde_json::to_string(&AgentEvent::CommandRetry {
        command_id: "cmd-2".to_string(),
        command_type: "read".to_string(),
        lane: "query".to_string(),
        attempt: 2,
        delay_ms: 1000,
    })
    .unwrap();

    assert!(serialized.contains("command_retry"));
    assert!(serialized.contains("cmd-2"));
}
4628
/// `AgentEvent::QueueAlert` serializes with the `queue_alert` tag and its level.
#[test]
fn test_agent_event_serialize_queue_alert() {
    let serialized = serde_json::to_string(&AgentEvent::QueueAlert {
        level: "warning".to_string(),
        alert_type: "depth".to_string(),
        message: "Queue depth exceeded".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("queue_alert"));
    assert!(serialized.contains("warning"));
}
4640
/// `AgentEvent::TaskUpdated` serializes with the `task_updated` tag and
/// the session id.
#[test]
fn test_agent_event_serialize_task_updated() {
    let serialized = serde_json::to_string(&AgentEvent::TaskUpdated {
        session_id: "sess-1".to_string(),
        tasks: vec![],
    })
    .unwrap();

    assert!(serialized.contains("task_updated"));
    assert!(serialized.contains("sess-1"));
}
4651
/// `AgentEvent::MemoryStored` serializes with the `memory_stored` tag and
/// the memory id.
#[test]
fn test_agent_event_serialize_memory_stored() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryStored {
        memory_id: "mem-1".to_string(),
        memory_type: "conversation".to_string(),
        importance: 0.8,
        tags: vec!["important".to_string()],
    })
    .unwrap();

    assert!(serialized.contains("memory_stored"));
    assert!(serialized.contains("mem-1"));
}
4664
/// `AgentEvent::MemoryRecalled` serializes with the `memory_recalled` tag
/// and the memory id.
#[test]
fn test_agent_event_serialize_memory_recalled() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryRecalled {
        memory_id: "mem-2".to_string(),
        content: "Previous conversation".to_string(),
        relevance: 0.9,
    })
    .unwrap();

    assert!(serialized.contains("memory_recalled"));
    assert!(serialized.contains("mem-2"));
}
4676
/// `AgentEvent::MemoriesSearched` serializes with the `memories_searched`
/// tag and the query text.
#[test]
fn test_agent_event_serialize_memories_searched() {
    let serialized = serde_json::to_string(&AgentEvent::MemoriesSearched {
        query: Some("search term".to_string()),
        tags: vec!["tag1".to_string()],
        result_count: 5,
    })
    .unwrap();

    assert!(serialized.contains("memories_searched"));
    assert!(serialized.contains("search term"));
}
4688
/// `AgentEvent::MemoryCleared` serializes with the `memory_cleared` tag
/// and the tier name.
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryCleared {
        tier: "short_term".to_string(),
        count: 10,
    })
    .unwrap();

    assert!(serialized.contains("memory_cleared"));
    assert!(serialized.contains("short_term"));
}
4699
/// `AgentEvent::SubagentStart` serializes with the `subagent_start` tag
/// and the agent name.
#[test]
fn test_agent_event_serialize_subagent_start() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentStart {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        parent_session_id: "parent-sess".to_string(),
        agent: "explore".to_string(),
        description: "Explore codebase".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("subagent_start"));
    assert!(serialized.contains("explore"));
}
4713
/// `AgentEvent::SubagentProgress` serializes with the `subagent_progress`
/// tag and its status.
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentProgress {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        status: "processing".to_string(),
        metadata: serde_json::json!({"progress": 50}),
    })
    .unwrap();

    assert!(serialized.contains("subagent_progress"));
    assert!(serialized.contains("processing"));
}
4726
/// `AgentEvent::SubagentEnd` serializes with the `subagent_end` tag and
/// its output text.
#[test]
fn test_agent_event_serialize_subagent_end() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentEnd {
        task_id: "task-1".to_string(),
        session_id: "child-sess".to_string(),
        agent: "explore".to_string(),
        output: "Found 10 files".to_string(),
        success: true,
    })
    .unwrap();

    assert!(serialized.contains("subagent_end"));
    assert!(serialized.contains("Found 10 files"));
}
4740
/// `AgentEvent::PlanningStart` serializes with the `planning_start` tag
/// and its prompt.
#[test]
fn test_agent_event_serialize_planning_start() {
    let serialized = serde_json::to_string(&AgentEvent::PlanningStart {
        prompt: "Build a web app".to_string(),
    })
    .unwrap();

    assert!(serialized.contains("planning_start"));
    assert!(serialized.contains("Build a web app"));
}
4750
/// `AgentEvent::PlanningEnd` serializes with the `planning_end` tag and
/// the `estimated_steps` field name.
#[test]
fn test_agent_event_serialize_planning_end() {
    use crate::planning::{Complexity, ExecutionPlan};

    let serialized = serde_json::to_string(&AgentEvent::PlanningEnd {
        plan: ExecutionPlan::new("Test goal".to_string(), Complexity::Simple),
        estimated_steps: 3,
    })
    .unwrap();

    assert!(serialized.contains("planning_end"));
    assert!(serialized.contains("estimated_steps"));
}
4763
/// `AgentEvent::StepStart` serializes with the `step_start` tag and the
/// step description.
#[test]
fn test_agent_event_serialize_step_start() {
    let serialized = serde_json::to_string(&AgentEvent::StepStart {
        step_id: "step-1".to_string(),
        description: "Initialize project".to_string(),
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    assert!(serialized.contains("step_start"));
    assert!(serialized.contains("Initialize project"));
}
4776
/// `AgentEvent::StepEnd` serializes with the `step_end` tag and the step id.
#[test]
fn test_agent_event_serialize_step_end() {
    let serialized = serde_json::to_string(&AgentEvent::StepEnd {
        step_id: "step-1".to_string(),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    assert!(serialized.contains("step_end"));
    assert!(serialized.contains("step-1"));
}
4789
/// `AgentEvent::GoalExtracted` serializes with the `goal_extracted` tag.
#[test]
fn test_agent_event_serialize_goal_extracted() {
    use crate::planning::AgentGoal;

    let serialized = serde_json::to_string(&AgentEvent::GoalExtracted {
        goal: AgentGoal::new("Complete the task".to_string()),
    })
    .unwrap();

    assert!(serialized.contains("goal_extracted"));
}
4798
/// `AgentEvent::GoalProgress` serializes with the `goal_progress` tag and
/// its progress value.
#[test]
fn test_agent_event_serialize_goal_progress() {
    let serialized = serde_json::to_string(&AgentEvent::GoalProgress {
        goal: "Build app".to_string(),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    })
    .unwrap();

    assert!(serialized.contains("goal_progress"));
    assert!(serialized.contains("0.5"));
}
4811
/// `AgentEvent::GoalAchieved` serializes with the `goal_achieved` tag and
/// its duration.
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let serialized = serde_json::to_string(&AgentEvent::GoalAchieved {
        goal: "Build app".to_string(),
        total_steps: 4,
        duration_ms: 5000,
    })
    .unwrap();

    assert!(serialized.contains("goal_achieved"));
    assert!(serialized.contains("5000"));
}
4823
4824 #[tokio::test]
4825 async fn test_extract_goal_with_json_response() {
4826 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4828 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
4829 )]));
4830 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4831 let agent = AgentLoop::new(
4832 mock_client,
4833 tool_executor,
4834 test_tool_context(),
4835 AgentConfig::default(),
4836 );
4837
4838 let goal = agent.extract_goal("Build a web app").await.unwrap();
4839 assert_eq!(goal.description, "Build web app");
4840 assert_eq!(goal.success_criteria.len(), 2);
4841 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
4842 }
4843
4844 #[tokio::test]
4845 async fn test_extract_goal_fallback_on_non_json() {
4846 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4848 "Some non-JSON response",
4849 )]));
4850 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4851 let agent = AgentLoop::new(
4852 mock_client,
4853 tool_executor,
4854 test_tool_context(),
4855 AgentConfig::default(),
4856 );
4857
4858 let goal = agent.extract_goal("Do something").await.unwrap();
4859 assert_eq!(goal.description, "Do something");
4861 assert_eq!(goal.success_criteria.len(), 2);
4863 }
4864
4865 #[tokio::test]
4866 async fn test_check_goal_achievement_json_yes() {
4867 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4868 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
4869 )]));
4870 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4871 let agent = AgentLoop::new(
4872 mock_client,
4873 tool_executor,
4874 test_tool_context(),
4875 AgentConfig::default(),
4876 );
4877
4878 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4879 let achieved = agent
4880 .check_goal_achievement(&goal, "All done")
4881 .await
4882 .unwrap();
4883 assert!(achieved);
4884 }
4885
4886 #[tokio::test]
4887 async fn test_check_goal_achievement_fallback_not_done() {
4888 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4890 "invalid json",
4891 )]));
4892 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4893 let agent = AgentLoop::new(
4894 mock_client,
4895 tool_executor,
4896 test_tool_context(),
4897 AgentConfig::default(),
4898 );
4899
4900 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
4901 let achieved = agent
4903 .check_goal_achievement(&goal, "still working")
4904 .await
4905 .unwrap();
4906 assert!(!achieved);
4907 }
4908
/// With no context results, the augmented prompt must still contain the
/// configured `extra` slot text.
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new(String::from("/tmp")));

    let prompt_slots = SystemPromptSlots {
        extra: Some("Base prompt".to_string()),
        ..Default::default()
    };
    let agent = AgentLoop::new(
        llm,
        executor,
        test_tool_context(),
        AgentConfig {
            prompt_slots,
            ..Default::default()
        },
    );

    let prompt = agent.build_augmented_system_prompt(&[]);
    assert!(prompt.unwrap().contains("Base prompt"));
}
4929
/// With the default configuration and no context, a prompt is still
/// produced and it contains the "Core Behaviour" text.
#[test]
fn test_build_augmented_system_prompt_no_custom_slots() {
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new(String::from("/tmp")));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let prompt = agent.build_augmented_system_prompt(&[]);

    assert!(prompt.is_some());
    assert!(prompt.unwrap().contains("Core Behaviour"));
}
4946
/// With default prompt slots and one context result, the augmented prompt
/// must include a `<context` element and the item's content.
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    let llm = Arc::new(MockLlmClient::new(vec![]));
    let executor = Arc::new(ToolExecutor::new(String::from("/tmp")));
    let agent = AgentLoop::new(llm, executor, test_tool_context(), AgentConfig::default());

    let single_result = ContextResult {
        provider: "test".to_string(),
        items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
        total_tokens: 10,
        truncated: false,
    };
    let resolved = vec![single_result];

    let prompt = agent.build_augmented_system_prompt(&resolved);
    assert!(prompt.is_some());

    let rendered = prompt.unwrap();
    assert!(rendered.contains("<context"));
    assert!(rendered.contains("Content"));
}
4973
/// Cloning an `AgentResult` preserves its text and tool-call count.
#[test]
fn test_agent_result_clone() {
    let original = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };

    let duplicate = original.clone();

    assert_eq!(duplicate.text, original.text);
    assert_eq!(duplicate.tool_calls_count, original.tool_calls_count);
}
4990
/// The `Debug` rendering of `AgentResult` names the type and shows its text.
#[test]
fn test_agent_result_debug() {
    let sample = AgentResult {
        text: "output".to_string(),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };

    let rendered = format!("{:?}", sample);

    assert!(rendered.contains("AgentResult"));
    assert!(rendered.contains("output"));
}
5003
5004 #[tokio::test]
5013 async fn test_tool_command_command_type() {
5014 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5015 let cmd = ToolCommand {
5016 tool_executor: executor,
5017 tool_name: "read".to_string(),
5018 tool_args: serde_json::json!({"file": "test.rs"}),
5019 skill_registry: None,
5020 tool_context: test_tool_context(),
5021 };
5022 assert_eq!(cmd.command_type(), "read");
5023 }
5024
5025 #[tokio::test]
5026 async fn test_tool_command_payload() {
5027 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5028 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5029 let cmd = ToolCommand {
5030 tool_executor: executor,
5031 tool_name: "read".to_string(),
5032 tool_args: args.clone(),
5033 skill_registry: None,
5034 tool_context: test_tool_context(),
5035 };
5036 assert_eq!(cmd.payload(), args);
5037 }
5038
5039 #[tokio::test(flavor = "multi_thread")]
5044 async fn test_agent_loop_with_queue() {
5045 use tokio::sync::broadcast;
5046
5047 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5048 "Hello",
5049 )]));
5050 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5051 let config = AgentConfig::default();
5052
5053 let (event_tx, _) = broadcast::channel(100);
5054 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5055 .await
5056 .unwrap();
5057
5058 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5059 .with_queue(Arc::new(queue));
5060
5061 assert!(agent.command_queue.is_some());
5062 }
5063
5064 #[tokio::test]
5065 async fn test_agent_loop_without_queue() {
5066 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5067 "Hello",
5068 )]));
5069 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5070 let config = AgentConfig::default();
5071
5072 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5073
5074 assert!(agent.command_queue.is_none());
5075 }
5076
5077 #[tokio::test]
5082 async fn test_execute_plan_parallel_independent() {
5083 use crate::planning::{Complexity, ExecutionPlan, Task};
5084
5085 let mock_client = Arc::new(MockLlmClient::new(vec![
5088 MockLlmClient::text_response("Step 1 done"),
5089 MockLlmClient::text_response("Step 2 done"),
5090 MockLlmClient::text_response("Step 3 done"),
5091 ]));
5092
5093 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5094 let config = AgentConfig::default();
5095 let agent = AgentLoop::new(
5096 mock_client.clone(),
5097 tool_executor,
5098 test_tool_context(),
5099 config,
5100 );
5101
5102 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5103 plan.add_step(Task::new("s1", "First step"));
5104 plan.add_step(Task::new("s2", "Second step"));
5105 plan.add_step(Task::new("s3", "Third step"));
5106
5107 let (tx, mut rx) = mpsc::channel(100);
5108 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5109
5110 assert_eq!(result.usage.total_tokens, 45);
5112
5113 let mut step_starts = Vec::new();
5115 let mut step_ends = Vec::new();
5116 rx.close();
5117 while let Some(event) = rx.recv().await {
5118 match event {
5119 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5120 AgentEvent::StepEnd {
5121 step_id, status, ..
5122 } => {
5123 assert_eq!(status, TaskStatus::Completed);
5124 step_ends.push(step_id);
5125 }
5126 _ => {}
5127 }
5128 }
5129 assert_eq!(step_starts.len(), 3);
5130 assert_eq!(step_ends.len(), 3);
5131 }
5132
5133 #[tokio::test]
5134 async fn test_execute_plan_respects_dependencies() {
5135 use crate::planning::{Complexity, ExecutionPlan, Task};
5136
5137 let mock_client = Arc::new(MockLlmClient::new(vec![
5140 MockLlmClient::text_response("Step 1 done"),
5141 MockLlmClient::text_response("Step 2 done"),
5142 MockLlmClient::text_response("Step 3 done"),
5143 ]));
5144
5145 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5146 let config = AgentConfig::default();
5147 let agent = AgentLoop::new(
5148 mock_client.clone(),
5149 tool_executor,
5150 test_tool_context(),
5151 config,
5152 );
5153
5154 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5155 plan.add_step(Task::new("s1", "Independent A"));
5156 plan.add_step(Task::new("s2", "Independent B"));
5157 plan.add_step(
5158 Task::new("s3", "Depends on A+B")
5159 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5160 );
5161
5162 let (tx, mut rx) = mpsc::channel(100);
5163 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5164
5165 assert_eq!(result.usage.total_tokens, 45);
5167
5168 let mut events = Vec::new();
5170 rx.close();
5171 while let Some(event) = rx.recv().await {
5172 match &event {
5173 AgentEvent::StepStart { step_id, .. } => {
5174 events.push(format!("start:{}", step_id));
5175 }
5176 AgentEvent::StepEnd { step_id, .. } => {
5177 events.push(format!("end:{}", step_id));
5178 }
5179 _ => {}
5180 }
5181 }
5182
5183 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5185 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5186 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5187 assert!(
5188 s3_start > s1_end,
5189 "s3 started before s1 ended: {:?}",
5190 events
5191 );
5192 assert!(
5193 s3_start > s2_end,
5194 "s3 started before s2 ended: {:?}",
5195 events
5196 );
5197
5198 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5200 }
5201
5202 #[tokio::test]
5203 async fn test_execute_plan_handles_step_failure() {
5204 use crate::planning::{Complexity, ExecutionPlan, Task};
5205
5206 let mock_client = Arc::new(MockLlmClient::new(vec![
5216 MockLlmClient::text_response("s1 done"),
5218 MockLlmClient::text_response("s3 done"),
5219 ]));
5222
5223 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5224 let config = AgentConfig::default();
5225 let agent = AgentLoop::new(
5226 mock_client.clone(),
5227 tool_executor,
5228 test_tool_context(),
5229 config,
5230 );
5231
5232 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5233 plan.add_step(Task::new("s1", "Independent step"));
5234 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5235 plan.add_step(Task::new("s3", "Another independent"));
5236 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5237
5238 let (tx, mut rx) = mpsc::channel(100);
5239 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5240
5241 let mut completed_steps = Vec::new();
5244 let mut failed_steps = Vec::new();
5245 rx.close();
5246 while let Some(event) = rx.recv().await {
5247 if let AgentEvent::StepEnd {
5248 step_id, status, ..
5249 } = event
5250 {
5251 match status {
5252 TaskStatus::Completed => completed_steps.push(step_id),
5253 TaskStatus::Failed => failed_steps.push(step_id),
5254 _ => {}
5255 }
5256 }
5257 }
5258
5259 assert!(
5260 completed_steps.contains(&"s1".to_string()),
5261 "s1 should complete"
5262 );
5263 assert!(
5264 completed_steps.contains(&"s3".to_string()),
5265 "s3 should complete"
5266 );
5267 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5268 assert!(
5270 !completed_steps.contains(&"s4".to_string()),
5271 "s4 should not complete"
5272 );
5273 assert!(
5274 !failed_steps.contains(&"s4".to_string()),
5275 "s4 should not fail (never started)"
5276 );
5277 }
5278
/// Pins the default resilience settings: two parse retries, no tool timeout, and a
/// circuit-breaker threshold of three.
#[test]
fn test_agent_config_resilience_defaults() {
    let defaults = AgentConfig::default();

    let observed = (
        defaults.max_parse_retries,
        defaults.tool_timeout_ms,
        defaults.circuit_breaker_threshold,
    );

    assert_eq!(observed, (2, None, 3));
}
5290
5291 #[tokio::test]
5293 async fn test_parse_error_recovery_bails_after_threshold() {
5294 let mock_client = Arc::new(MockLlmClient::new(vec![
5296 MockLlmClient::tool_call_response(
5297 "c1",
5298 "bash",
5299 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5300 ),
5301 MockLlmClient::tool_call_response(
5302 "c2",
5303 "bash",
5304 serde_json::json!({"__parse_error": "missing closing brace"}),
5305 ),
5306 MockLlmClient::tool_call_response(
5307 "c3",
5308 "bash",
5309 serde_json::json!({"__parse_error": "still broken"}),
5310 ),
5311 MockLlmClient::text_response("Done"), ]));
5313
5314 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5315 let config = AgentConfig {
5316 max_parse_retries: 2,
5317 ..AgentConfig::default()
5318 };
5319 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5320 let result = agent.execute(&[], "Do something", None).await;
5321 assert!(result.is_err(), "should bail after parse error threshold");
5322 let err = result.unwrap_err().to_string();
5323 assert!(
5324 err.contains("malformed tool arguments"),
5325 "error should mention malformed tool arguments, got: {}",
5326 err
5327 );
5328 }
5329
5330 #[tokio::test]
5332 async fn test_parse_error_counter_resets_on_success() {
5333 let mock_client = Arc::new(MockLlmClient::new(vec![
5337 MockLlmClient::tool_call_response(
5338 "c1",
5339 "bash",
5340 serde_json::json!({"__parse_error": "bad args"}),
5341 ),
5342 MockLlmClient::tool_call_response(
5343 "c2",
5344 "bash",
5345 serde_json::json!({"__parse_error": "bad args again"}),
5346 ),
5347 MockLlmClient::tool_call_response(
5349 "c3",
5350 "bash",
5351 serde_json::json!({"command": "echo ok"}),
5352 ),
5353 MockLlmClient::text_response("All done"),
5354 ]));
5355
5356 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5357 let config = AgentConfig {
5358 max_parse_retries: 2,
5359 ..AgentConfig::default()
5360 };
5361 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5362 let result = agent.execute(&[], "Do something", None).await;
5363 assert!(
5364 result.is_ok(),
5365 "should not bail — counter reset after successful tool, got: {:?}",
5366 result.err()
5367 );
5368 assert_eq!(result.unwrap().text, "All done");
5369 }
5370
5371 #[tokio::test]
5373 async fn test_tool_timeout_produces_error_result() {
5374 let mock_client = Arc::new(MockLlmClient::new(vec![
5375 MockLlmClient::tool_call_response(
5376 "t1",
5377 "bash",
5378 serde_json::json!({"command": "sleep 10"}),
5379 ),
5380 MockLlmClient::text_response("The command timed out."),
5381 ]));
5382
5383 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5384 let config = AgentConfig {
5385 tool_timeout_ms: Some(50),
5387 ..AgentConfig::default()
5388 };
5389 let agent = AgentLoop::new(
5390 mock_client.clone(),
5391 tool_executor,
5392 test_tool_context(),
5393 config,
5394 );
5395 let result = agent.execute(&[], "Run sleep", None).await;
5396 assert!(
5397 result.is_ok(),
5398 "session should continue after tool timeout: {:?}",
5399 result.err()
5400 );
5401 assert_eq!(result.unwrap().text, "The command timed out.");
5402 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5404 }
5405
5406 #[tokio::test]
5408 async fn test_tool_within_timeout_succeeds() {
5409 let mock_client = Arc::new(MockLlmClient::new(vec![
5410 MockLlmClient::tool_call_response(
5411 "t1",
5412 "bash",
5413 serde_json::json!({"command": "echo fast"}),
5414 ),
5415 MockLlmClient::text_response("Command succeeded."),
5416 ]));
5417
5418 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5419 let config = AgentConfig {
5420 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5422 };
5423 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5424 let result = agent.execute(&[], "Run something fast", None).await;
5425 assert!(
5426 result.is_ok(),
5427 "fast tool should succeed: {:?}",
5428 result.err()
5429 );
5430 assert_eq!(result.unwrap().text, "Command succeeded.");
5431 }
5432
5433 #[tokio::test]
5435 async fn test_circuit_breaker_retries_non_streaming() {
5436 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5439
5440 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5441 let config = AgentConfig {
5442 circuit_breaker_threshold: 2,
5443 ..AgentConfig::default()
5444 };
5445 let agent = AgentLoop::new(
5446 mock_client.clone(),
5447 tool_executor,
5448 test_tool_context(),
5449 config,
5450 );
5451 let result = agent.execute(&[], "Hello", None).await;
5452 assert!(result.is_err(), "should fail when LLM always errors");
5453 let err = result.unwrap_err().to_string();
5454 assert!(
5455 err.contains("circuit breaker"),
5456 "error should mention circuit breaker, got: {}",
5457 err
5458 );
5459 assert_eq!(
5460 mock_client.call_count.load(Ordering::SeqCst),
5461 2,
5462 "should make exactly threshold=2 LLM calls"
5463 );
5464 }
5465
5466 #[tokio::test]
5468 async fn test_circuit_breaker_threshold_one_no_retry() {
5469 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5470
5471 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5472 let config = AgentConfig {
5473 circuit_breaker_threshold: 1,
5474 ..AgentConfig::default()
5475 };
5476 let agent = AgentLoop::new(
5477 mock_client.clone(),
5478 tool_executor,
5479 test_tool_context(),
5480 config,
5481 );
5482 let result = agent.execute(&[], "Hello", None).await;
5483 assert!(result.is_err());
5484 assert_eq!(
5485 mock_client.call_count.load(Ordering::SeqCst),
5486 1,
5487 "with threshold=1 exactly one attempt should be made"
5488 );
5489 }
5490
5491 #[tokio::test]
5493 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
5494 struct FailOnceThenSucceed {
5496 inner: MockLlmClient,
5497 failed_once: std::sync::atomic::AtomicBool,
5498 call_count: AtomicUsize,
5499 }
5500
5501 #[async_trait::async_trait]
5502 impl LlmClient for FailOnceThenSucceed {
5503 async fn complete(
5504 &self,
5505 messages: &[Message],
5506 system: Option<&str>,
5507 tools: &[ToolDefinition],
5508 ) -> Result<LlmResponse> {
5509 self.call_count.fetch_add(1, Ordering::SeqCst);
5510 let already_failed = self
5511 .failed_once
5512 .swap(true, std::sync::atomic::Ordering::SeqCst);
5513 if !already_failed {
5514 anyhow::bail!("transient network error");
5515 }
5516 self.inner.complete(messages, system, tools).await
5517 }
5518
5519 async fn complete_streaming(
5520 &self,
5521 messages: &[Message],
5522 system: Option<&str>,
5523 tools: &[ToolDefinition],
5524 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
5525 self.inner.complete_streaming(messages, system, tools).await
5526 }
5527 }
5528
5529 let mock = Arc::new(FailOnceThenSucceed {
5530 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
5531 failed_once: std::sync::atomic::AtomicBool::new(false),
5532 call_count: AtomicUsize::new(0),
5533 });
5534
5535 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5536 let config = AgentConfig {
5537 circuit_breaker_threshold: 3,
5538 ..AgentConfig::default()
5539 };
5540 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
5541 let result = agent.execute(&[], "Hello", None).await;
5542 assert!(
5543 result.is_ok(),
5544 "should succeed when LLM recovers within threshold: {:?}",
5545 result.err()
5546 );
5547 assert_eq!(result.unwrap().text, "Recovered!");
5548 assert_eq!(
5549 mock.call_count.load(Ordering::SeqCst),
5550 2,
5551 "should have made exactly 2 calls (1 fail + 1 success)"
5552 );
5553 }
5554
/// Empty and blank responses are classified as incomplete.
#[test]
fn test_looks_incomplete_empty() {
    for blank in ["", " "] {
        assert!(
            AgentLoop::looks_incomplete(blank),
            "expected {:?} to be classified as incomplete",
            blank
        );
    }
}
5562
/// Responses that end with a colon are classified as incomplete.
#[test]
fn test_looks_incomplete_trailing_colon() {
    for sample in ["Let me check the file:", "Next steps:"] {
        assert!(
            AgentLoop::looks_incomplete(sample),
            "expected {:?} to be classified as incomplete",
            sample
        );
    }
}
5568
/// Responses ending in an ellipsis — three ASCII dots or the single `…` character —
/// are classified as incomplete.
#[test]
fn test_looks_incomplete_ellipsis() {
    for sample in ["Working on it...", "Processing…"] {
        assert!(
            AgentLoop::looks_incomplete(sample),
            "expected {:?} to be classified as incomplete",
            sample
        );
    }
}
5574
/// Sentences that announce an upcoming action ("I'll…", "Let me…", "I will…",
/// "I need to…") are classified as incomplete.
#[test]
fn test_looks_incomplete_intent_phrases() {
    let announcements = [
        "I'll start by reading the file.",
        "Let me check the configuration.",
        "I will now run the tests.",
        "I need to update the Cargo.toml.",
    ];

    for sample in announcements {
        assert!(
            AgentLoop::looks_incomplete(sample),
            "expected {:?} to be classified as incomplete",
            sample
        );
    }
}
5588
/// Finished-sounding answers, including very short ones, are not classified as incomplete.
#[test]
fn test_looks_complete_final_answer() {
    let final_answers = [
        "The tests pass. All changes have been applied successfully.",
        "Done. I've updated the three files and verified the build succeeds.",
        "42",
        "Yes.",
    ];

    for sample in final_answers {
        assert!(
            !AgentLoop::looks_incomplete(sample),
            "expected {:?} to be classified as complete",
            sample
        );
    }
}
5601
/// A multi-line summary whose first line ends with a colon but which finishes with a
/// bullet list is not classified as incomplete.
#[test]
fn test_looks_incomplete_multiline_complete() {
    let summary = concat!(
        "Here is the summary:\n",
        "\n",
        "- Fixed the bug in agent.rs\n",
        "- All tests pass\n",
        "- Build succeeds",
    );

    assert!(!AgentLoop::looks_incomplete(summary));
}
5607}