1#[cfg(feature = "ahp")]
13use crate::ahp::InjectedContext;
14#[cfg(feature = "ahp")]
15use crate::context::{ContextItem, ContextType};
16use crate::context::{ContextProvider, ContextQuery, ContextResult};
17use crate::hitl::ConfirmationProvider;
18use crate::hooks::{
19 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
20 IntentDetectionEvent, OnErrorEvent, PostResponseEvent, PostToolUseEvent,
21 PreContextPerceptionEvent, PrePromptEvent, PreToolUseEvent, TokenUsageInfo, ToolCallInfo,
22 ToolResultData,
23};
24use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
25use crate::permissions::{PermissionChecker, PermissionDecision};
26use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
27use crate::prompts::{AgentStyle, DetectionConfidence, PlanningMode, SystemPromptSlots};
28use crate::queue::SessionCommand;
29use crate::session_lane_queue::SessionLaneQueue;
30use crate::text::truncate_utf8;
31use crate::tool_search::ToolIndex;
32use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
33use anyhow::{Context, Result};
34use async_trait::async_trait;
35use futures::future::join_all;
36use serde::{Deserialize, Serialize};
37use serde_json::Value;
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::sync::{mpsc, RwLock};
41
/// Default cap on tool rounds; used as the default value of
/// `AgentConfig::max_tool_rounds`.
const MAX_TOOL_ROUNDS: usize = 50;
44
/// Configuration consumed by [`AgentLoop`].
///
/// `Debug` is implemented by hand further down; trait-object and closure
/// fields are summarised there as presence flags or counts.
///
/// NOTE(review): several fields are only stored in this part of the file; their
/// descriptions are hedged where the consuming code is not visible here.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots from which the system prompt is built. An explicit `style` set
    /// here short-circuits style detection in `resolve_effective_style`.
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions offered to the LLM (narrowed by `tool_index` when set).
    pub tools: Vec<ToolDefinition>,
    /// Cap on tool rounds; defaults to `MAX_TOOL_ROUNDS`.
    pub max_tool_rounds: usize,
    /// Optional security provider (consumer not visible in this section).
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// Optional permission checker (consumer not visible in this section).
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// Optional confirmation provider — presumably human-in-the-loop approval; confirm.
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Providers queried concurrently by `resolve_context`; a failing provider
    /// is logged and skipped.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// Planning mode (consumer not visible in this section).
    pub planning_mode: PlanningMode,
    /// Presumably toggles goal extraction / progress tracking — TODO confirm.
    pub goal_tracking: bool,
    /// Optional hook executor (consumer not visible in this section).
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Skill registry. It is forwarded to queued `ToolCommand`s, which use it to
    /// enforce skills' allowed-tools restrictions. `Default` fills it with the
    /// built-in skills.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// Presumably the retry budget for unparseable LLM output — TODO confirm.
    pub max_parse_retries: u32,
    /// Per-tool-call timeout in milliseconds for direct execution; `None`
    /// disables the timeout.
    pub tool_timeout_ms: Option<u64>,
    /// Presumably the failure count that trips a circuit breaker — TODO confirm.
    pub circuit_breaker_threshold: u32,
    /// Presumably the number of identical tool calls tolerated — TODO confirm.
    pub duplicate_tool_call_threshold: u32,
    /// Presumably enables automatic context compaction — TODO confirm.
    pub auto_compact: bool,
    /// Presumably the context-usage fraction that triggers compaction — TODO confirm.
    pub auto_compact_threshold: f32,
    /// Presumably the model context window size in tokens — TODO confirm.
    pub max_context_tokens: usize,
    /// Optional LLM client.
    /// NOTE(review): `AgentLoop::new` receives its own client; how this field
    /// relates to it is not visible in this section.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional agent memory (consumer not visible in this section).
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// Presumably enables continuation turns after an incomplete-looking reply — TODO confirm.
    pub continuation_enabled: bool,
    /// Presumably the cap on such continuation turns — TODO confirm.
    pub max_continuation_turns: u32,
    /// Optional tool search index. When set, `call_llm` only sends the tools
    /// whose names the index returns for the latest user text.
    pub tool_index: Option<ToolIndex>,
    /// Optional registry of sub-agent definitions (consumer not visible in this section).
    pub subagent_registry: Option<Arc<crate::subagent::AgentRegistry>>,
    /// Optional callback taking an agent definition and a string argument.
    /// NOTE(review): the meaning of the `&str` and of a `None` return is not
    /// visible in this section — confirm against the call site.
    #[allow(clippy::type_complexity)]
    pub on_subagent_launch: Option<
        Arc<
            dyn Fn(&crate::subagent::AgentDefinition, &str) -> Option<Result<AgentResult>>
                + Send
                + Sync,
        >,
    >,
}
146
147impl std::fmt::Debug for AgentConfig {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("AgentConfig")
150 .field("prompt_slots", &self.prompt_slots)
151 .field("tools", &self.tools)
152 .field("max_tool_rounds", &self.max_tool_rounds)
153 .field("security_provider", &self.security_provider.is_some())
154 .field("permission_checker", &self.permission_checker.is_some())
155 .field("confirmation_manager", &self.confirmation_manager.is_some())
156 .field("context_providers", &self.context_providers.len())
157 .field("planning_mode", &self.planning_mode)
158 .field("goal_tracking", &self.goal_tracking)
159 .field("hook_engine", &self.hook_engine.is_some())
160 .field(
161 "skill_registry",
162 &self.skill_registry.as_ref().map(|r| r.len()),
163 )
164 .field("max_parse_retries", &self.max_parse_retries)
165 .field("tool_timeout_ms", &self.tool_timeout_ms)
166 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
167 .field(
168 "duplicate_tool_call_threshold",
169 &self.duplicate_tool_call_threshold,
170 )
171 .field("auto_compact", &self.auto_compact)
172 .field("auto_compact_threshold", &self.auto_compact_threshold)
173 .field("max_context_tokens", &self.max_context_tokens)
174 .field("continuation_enabled", &self.continuation_enabled)
175 .field("max_continuation_turns", &self.max_continuation_turns)
176 .field("memory", &self.memory.is_some())
177 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
178 .field(
179 "subagent_registry",
180 &self.subagent_registry.as_ref().map(|r| r.len()),
181 )
182 .field("on_subagent_launch", &self.on_subagent_launch.is_some())
183 .finish()
184 }
185}
186
/// Baseline configuration: no tools, no optional providers, planning mode at
/// its own default, and the built-in skill registry installed.
impl Default for AgentConfig {
    fn default() -> Self {
        Self {
            prompt_slots: SystemPromptSlots::default(),
            tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
            security_provider: None,
            permission_checker: None,
            confirmation_manager: None,
            context_providers: Vec::new(),
            planning_mode: PlanningMode::default(),
            goal_tracking: false,
            hook_engine: None,
            // The only optional component that is populated by default.
            skill_registry: Some(Arc::new(crate::skills::SkillRegistry::with_builtins())),
            max_parse_retries: 2,
            // No per-tool timeout unless the caller opts in.
            tool_timeout_ms: None,
            circuit_breaker_threshold: 3,
            duplicate_tool_call_threshold: 3,
            auto_compact: false,
            auto_compact_threshold: 0.80,
            max_context_tokens: 200_000,
            llm_client: None,
            memory: None,
            continuation_enabled: true,
            max_continuation_turns: 3,
            tool_index: None,
            subagent_registry: None,
            on_subagent_launch: None,
        }
    }
}
218
/// Events published by the agent while it runs.
///
/// Serialized in serde's internally tagged form: every event is an object whose
/// `type` field holds the variant's `rename` value (e.g. `"agent_start"`).
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
///
/// NOTE(review): only `TextDelta`, `ToolStart`, `ToolInputDelta` and
/// `ToolOutputDelta` are emitted in this part of the file. Descriptions of the
/// other variants are inferred from their names and fields — confirm against
/// the emitting code before relying on details.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// Agent run started for `prompt`.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// The active agent/mode changed.
    #[serde(rename = "agent_mode_changed")]
    AgentModeChanged {
        mode: String,
        agent: String,
        description: String,
    },

    /// A turn began.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// Streamed text fragment from the LLM (forwarded by `call_llm`).
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// The LLM started a tool-use block (forwarded by `call_llm`).
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// Streamed fragment of a tool call's input (forwarded by `call_llm`).
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    /// A tool call finished.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },

    /// Streamed fragment of a running tool's output (forwarded by
    /// `streaming_tool_context`).
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// A turn ended, with its token usage.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// Agent run finished.
    #[serde(rename = "agent_end")]
    End {
        text: String,
        usage: TokenUsage,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<crate::llm::LlmResponseMeta>,
    },

    /// An error occurred.
    #[serde(rename = "error")]
    Error { message: String },

    /// A tool call is waiting for confirmation.
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    /// A confirmation decision arrived.
    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    /// A confirmation request timed out.
    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String, },

    /// A task was handed to an external handler and is pending.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    /// An external task completed.
    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// A tool call was denied by permissions.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// Context resolution started for the listed providers.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Context resolution finished.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    /// A queued command was moved to the dead-letter queue.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    /// A queued command is being retried.
    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    /// The command queue raised an alert.
    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// The task list for a session changed.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// A memory item was stored.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// A memory item was recalled.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// A memory search ran.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// Memory was cleared.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String, count: u64,
    },

    /// A sub-agent started.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    /// Progress report from a sub-agent.
    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    /// A sub-agent finished.
    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    /// Planning started for `prompt`.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    /// Planning finished with the produced plan.
    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    /// A plan step started.
    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    /// A plan step ended with `status`.
    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    /// A goal was extracted.
    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    /// Progress towards a goal.
    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    /// A goal was achieved.
    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// The conversation context was compacted.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// A persistence operation failed.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },

    /// Answer to a side ("by the way") question — semantics not visible here; confirm.
    #[serde(rename = "btw_answer")]
    BtwAnswer {
        question: String,
        answer: String,
        usage: TokenUsage,
    },
}
570
/// Outcome of an agent run: final text, message history, token usage and the
/// number of tool calls.
///
/// NOTE(review): the code that fills these fields is outside this section.
#[derive(Debug, Clone)]
pub struct AgentResult {
    pub text: String,
    pub messages: Vec<Message>,
    pub usage: TokenUsage,
    pub tool_calls_count: usize,
}
579
/// A single tool invocation packaged as a [`SessionCommand`] so it can be
/// submitted to the session lane queue.
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Tool to invoke; also reported as the command type.
    tool_name: String,
    /// Tool arguments; also reported as the command payload.
    tool_args: Value,
    /// Context handed to the executor.
    tool_context: ToolContext,
    /// When present, used to enforce skills' allowed-tools restrictions
    /// before the tool runs.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
594
impl ToolCommand {
    /// Bundles everything needed to run `tool_name` with `tool_args` later.
    ///
    /// Nothing is validated or executed here; the skill restriction check and
    /// the tool call both happen in `execute`.
    pub fn new(
        tool_executor: Arc<ToolExecutor>,
        tool_name: String,
        tool_args: Value,
        tool_context: ToolContext,
        skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    ) -> Self {
        Self {
            tool_executor,
            tool_name,
            tool_args,
            tool_context,
            skill_registry,
        }
    }
}
613
614#[async_trait]
615impl SessionCommand for ToolCommand {
616 async fn execute(&self) -> Result<Value> {
617 if let Some(registry) = &self.skill_registry {
619 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
620
621 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
623
624 if has_restrictions {
625 let mut allowed = false;
626
627 for skill in &instruction_skills {
628 if skill.is_tool_allowed(&self.tool_name) {
629 allowed = true;
630 break;
631 }
632 }
633
634 if !allowed {
635 return Err(anyhow::anyhow!(
636 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
637 self.tool_name
638 ));
639 }
640 }
641 }
642
643 let result = self
645 .tool_executor
646 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
647 .await?;
648 Ok(serde_json::json!({
649 "output": result.output,
650 "exit_code": result.exit_code,
651 "metadata": result.metadata,
652 }))
653 }
654
655 fn command_type(&self) -> &str {
656 &self.tool_name
657 }
658
659 fn payload(&self) -> Value {
660 self.tool_args.clone()
661 }
662}
663
/// Drives LLM calls and tool execution for one agent.
///
/// Built with `AgentLoop::new` and optionally extended through the `with_*`
/// builder methods; all optional collaborators start as `None`.
#[derive(Clone)]
pub struct AgentLoop {
    /// Client used for completions and for LLM-based style classification.
    llm_client: Arc<dyn LlmClient>,
    /// Executor used for tool calls (directly or via queued commands).
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context; cloned per call when streaming output is wired up.
    tool_context: ToolContext,
    config: AgentConfig,
    /// Optional tool metrics sink (consumer not visible in this section).
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// When set, tool calls are submitted to this queue first.
    command_queue: Option<Arc<SessionLaneQueue>>,
    /// When set, tool results are recorded here on a best-effort basis.
    progress_tracker: Option<Arc<tokio::sync::RwLock<crate::task::ProgressTracker>>>,
    /// When set, every tool call is mirrored as a task (spawn/start/complete/fail).
    task_manager: Option<Arc<crate::task::TaskManager>>,
}
684
/// Best-effort extraction of the "target" a prompt is talking about.
///
/// Resolution order:
/// 1. The first span enclosed in a pair of double quotes, then single quotes,
///    then backticks (an unterminated opening quote falls through to the next kind).
/// 2. Otherwise, for prompts with more than two words, the first word longer
///    than three bytes that is not a question/stop word.
///
/// Returns an empty string when nothing qualifies. `_patterns` is unused.
// The unused `'a` is kept so the function's signature stays exactly as before.
#[allow(clippy::extra_unused_lifetimes)]
fn extract_target_name_from_prompt<'a>(prompt: &str, _patterns: &[&str]) -> String {
    // Words of three bytes or fewer ("the", "how", "is", "are") are already
    // rejected by the length check, so only the longer stop words are listed.
    const STOP_WORDS: [&str; 3] = ["where", "what", "find"];

    // All quote characters are single-byte ASCII, so `start + 1` is always a
    // valid char boundary.
    for quote in ['"', '\'', '`'] {
        if let Some(start) = prompt.find(quote) {
            let rest = &prompt[start + 1..];
            if let Some(end) = rest.find(quote) {
                return rest[..end].to_string();
            }
        }
    }

    let words: Vec<&str> = prompt.split_whitespace().collect();
    if words.len() > 2 {
        // Stop words are compared case-insensitively so that a sentence-initial
        // "Where"/"What"/"Find" is skipped instead of being returned as the target.
        let candidate = words.iter().find(|word| {
            word.len() > 3
                && !STOP_WORDS
                    .iter()
                    .any(|stop| word.eq_ignore_ascii_case(stop))
        });
        if let Some(word) = candidate {
            return word.to_string();
        }
    }

    String::new()
}
728
/// Classifies a prompt into a coarse technical domain using keyword matching.
///
/// The prompt is lowercased and checked against keyword groups in a fixed
/// priority order (rust, javascript, python, go, java, docker, kubernetes,
/// database, api, security, testing); the first matching group wins and
/// `"general"` is returned when none match.
///
/// Most keywords are matched as substrings. "go" is the exception: as a bare
/// substring it matched words such as "good", "google" or "algorithm" and hid
/// every later domain, so it must appear as a standalone word (or as
/// "golang" / a `.go` file extension).
fn detect_domain_from_prompt(prompt: &str) -> String {
    let lower = prompt.to_lowercase();
    let has = |needle: &str| lower.contains(needle);

    // Whole-word test for "go"; words are delimited by any non-alphanumeric char.
    let mentions_go = has("golang")
        || has(".go")
        || lower
            .split(|c: char| !c.is_alphanumeric())
            .any(|word| word == "go");

    let domain = if has("rust") || has("cargo") || has(".rs") {
        "rust"
    } else if has("javascript")
        || has("typescript")
        || has("node")
        || has(".js")
        || has(".ts")
    {
        "javascript"
    } else if has("python") || has(".py") {
        "python"
    } else if mentions_go {
        "go"
    } else if has("java") || has(".java") {
        "java"
    } else if has("docker") || has("container") {
        "docker"
    } else if has("kubernetes") || has("k8s") {
        "kubernetes"
    } else if has("sql") || has("database") || has("postgres") || has("mysql") {
        "database"
    } else if has("api") || has("rest") || has("grpc") {
        "api"
    } else if has("auth") || has("login") || has("password") || has("token") {
        "security"
    } else if has("test") || has("spec") || has("mock") {
        "testing"
    } else {
        "general"
    };

    domain.to_string()
}
772
/// Result of intent detection for a prompt.
///
/// Consumed by `build_pre_context_perception_from_intent`, which fills in any
/// missing target hints from the prompt text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntentDetectionResult {
    /// Detected intent label.
    pub detected_intent: String,
    /// Confidence score — range not established in this section; TODO confirm.
    pub confidence: f32,
    /// Optional hints about the target; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_hints: Option<TargetHints>,
}
784
/// Optional hints describing what a prompt is about.
///
/// Every field is optional and is skipped during serialization when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetHints {
    /// Kind of target; callers in this file fall back to `"unknown"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_type: Option<String>,
    /// Name of the target; callers in this file fall back to a name extracted
    /// from the prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_name: Option<String>,
    /// Technical domain; callers in this file fall back to keyword detection.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
}
795
/// Guesses a language code from the scripts present in `prompt`.
///
/// Returns `Some("ja" | "ko" | "zh" | "ar" | "ru")` for the first script found
/// in that priority order, or `None` when none of them occurs (e.g. plain
/// Latin text).
///
/// Kana and Hangul are tested *before* CJK ideographs: Japanese text routinely
/// mixes kanji (which live in the CJK Unified Ideographs block) with kana, so
/// testing ideographs first would label such text as Chinese.
///
/// This is a script heuristic only: any Cyrillic text maps to `"ru"` and any
/// Arabic-script text to `"ar"`.
fn detect_language_hint(prompt: &str) -> Option<String> {
    let has_in = |lo: char, hi: char| prompt.chars().any(|c| (lo..=hi).contains(&c));

    // Hiragana or Katakana.
    let hint = if has_in('\u{3040}', '\u{309f}') || has_in('\u{30a0}', '\u{30ff}') {
        "ja"
    // Hangul syllables.
    } else if has_in('\u{ac00}', '\u{d7af}') {
        "ko"
    // CJK unified ideographs with no kana/hangul present.
    } else if has_in('\u{4e00}', '\u{9fff}') {
        "zh"
    // Arabic block.
    } else if has_in('\u{0600}', '\u{06ff}') {
        "ar"
    // Cyrillic block.
    } else if has_in('\u{0400}', '\u{04ff}') {
        "ru"
    } else {
        return None;
    };

    Some(hint.to_string())
}
835
836fn build_pre_context_perception_from_intent(
838 result: IntentDetectionResult,
839 prompt: &str,
840 session_id: &str,
841 workspace: &str,
842) -> PreContextPerceptionEvent {
843 let target_hints = result.target_hints;
844 PreContextPerceptionEvent {
845 session_id: session_id.to_string(),
846 intent: result.detected_intent,
847 target_type: target_hints
848 .as_ref()
849 .and_then(|h| h.target_type.clone())
850 .unwrap_or_else(|| "unknown".to_string()),
851 target_name: target_hints
852 .as_ref()
853 .and_then(|h| h.target_name.clone())
854 .unwrap_or_else(|| extract_target_name_from_prompt(prompt, &[])),
855 domain: target_hints
856 .as_ref()
857 .and_then(|h| h.domain.clone())
858 .unwrap_or_else(|| detect_domain_from_prompt(prompt)),
859 query: Some(prompt.to_string()),
860 working_directory: workspace.to_string(),
861 urgency: "normal".to_string(),
862 }
863}
864
/// Rough token estimate: one token per four bytes of UTF-8 text (integer
/// division, so inputs shorter than four bytes yield 0).
///
/// Only compiled when the `ahp` feature is enabled.
#[cfg(feature = "ahp")]
fn estimate_tokens(text: &str) -> usize {
    text.len() / 4
}
870
871impl AgentLoop {
    /// Creates an agent loop with the required collaborators.
    ///
    /// Tool metrics, command queue, progress tracker and task manager all
    /// start unset; attach them with the corresponding `with_*` methods.
    pub fn new(
        llm_client: Arc<dyn LlmClient>,
        tool_executor: Arc<ToolExecutor>,
        tool_context: ToolContext,
        config: AgentConfig,
    ) -> Self {
        Self {
            llm_client,
            tool_executor,
            tool_context,
            config,
            tool_metrics: None,
            command_queue: None,
            progress_tracker: None,
            task_manager: None,
        }
    }
889
    /// Attaches a progress tracker; tool results are recorded in it by
    /// `track_tool_result`. Consumes and returns `self` for chaining.
    pub fn with_progress_tracker(
        mut self,
        tracker: Arc<tokio::sync::RwLock<crate::task::ProgressTracker>>,
    ) -> Self {
        self.progress_tracker = Some(tracker);
        self
    }
898
    /// Attaches a task manager; each tool call is then mirrored as a task by
    /// `execute_tool_queued_or_direct`. Consumes and returns `self` for chaining.
    pub fn with_task_manager(mut self, manager: Arc<crate::task::TaskManager>) -> Self {
        self.task_manager = Some(manager);
        self
    }
904
    /// Attaches a shared tool-metrics collector. Consumes and returns `self`
    /// for chaining.
    pub fn with_tool_metrics(
        mut self,
        metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
    ) -> Self {
        self.tool_metrics = Some(metrics);
        self
    }
913
    /// Attaches a session lane queue; tool calls are then submitted to it
    /// before falling back to direct execution. Consumes and returns `self`
    /// for chaining.
    pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
        self.command_queue = Some(queue);
        self
    }
922
923 fn track_tool_result(&self, tool_name: &str, args: &serde_json::Value, exit_code: i32) {
925 if let Some(ref tracker) = self.progress_tracker {
926 let args_summary = Self::compact_json_args(args);
927 let success = exit_code == 0;
928 if let Ok(mut guard) = tracker.try_write() {
929 guard.track_tool_call(tool_name, args_summary, success);
930 }
931 }
932 }
933
934 fn compact_json_args(args: &serde_json::Value) -> String {
936 let raw = match args {
937 serde_json::Value::Null => String::new(),
938 serde_json::Value::String(s) => s.clone(),
939 _ => serde_json::to_string(args).unwrap_or_default(),
940 };
941 let compact = raw.split_whitespace().collect::<Vec<_>>().join(" ");
942 if compact.len() > 180 {
943 format!("{}...", truncate_utf8(&compact, 180))
944 } else {
945 compact
946 }
947 }
948
949 async fn execute_tool_timed(
955 &self,
956 name: &str,
957 args: &serde_json::Value,
958 ctx: &ToolContext,
959 ) -> anyhow::Result<crate::tools::ToolResult> {
960 let fut = self.tool_executor.execute_with_context(name, args, ctx);
961 if let Some(timeout_ms) = self.config.tool_timeout_ms {
962 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
963 Ok(result) => result,
964 Err(_) => Err(anyhow::anyhow!(
965 "Tool '{}' timed out after {}ms",
966 name,
967 timeout_ms
968 )),
969 }
970 } else {
971 fut.await
972 }
973 }
974
975 fn tool_result_to_tuple(
977 result: anyhow::Result<crate::tools::ToolResult>,
978 ) -> (
979 String,
980 i32,
981 bool,
982 Option<serde_json::Value>,
983 Vec<crate::llm::Attachment>,
984 ) {
985 match result {
986 Ok(r) => (
987 r.output,
988 r.exit_code,
989 r.exit_code != 0,
990 r.metadata,
991 r.images,
992 ),
993 Err(e) => {
994 let msg = e.to_string();
995 let hint = if Self::is_transient_error(&msg) {
997 " [transient — you may retry this tool call]"
998 } else {
999 " [permanent — do not retry without changing the arguments]"
1000 };
1001 (
1002 format!("Tool execution error: {}{}", msg, hint),
1003 1,
1004 true,
1005 None,
1006 Vec::new(),
1007 )
1008 }
1009 }
1010 }
1011
1012 fn detect_project_hint(workspace: &std::path::Path) -> String {
1016 struct Marker {
1017 file: &'static str,
1018 lang: &'static str,
1019 tip: &'static str,
1020 }
1021
1022 let markers = [
1023 Marker {
1024 file: "Cargo.toml",
1025 lang: "Rust",
1026 tip: "Use `cargo build`, `cargo test`, `cargo clippy`, and `cargo fmt`. \
1027 Prefer `anyhow` / `thiserror` for error handling. \
1028 Follow the Microsoft Rust Guidelines (no panics in library code, \
1029 async-first with Tokio).",
1030 },
1031 Marker {
1032 file: "package.json",
1033 lang: "Node.js / TypeScript",
1034 tip: "Check `package.json` for the package manager (npm/yarn/pnpm/bun) \
1035 and available scripts. Prefer TypeScript with strict mode. \
1036 Use ESM imports unless the project is CommonJS.",
1037 },
1038 Marker {
1039 file: "pyproject.toml",
1040 lang: "Python",
1041 tip: "Use the package manager declared in `pyproject.toml` \
1042 (uv, poetry, hatch, etc.). Prefer type hints and async/await for I/O.",
1043 },
1044 Marker {
1045 file: "setup.py",
1046 lang: "Python",
1047 tip: "Legacy Python project. Prefer type hints and async/await for I/O.",
1048 },
1049 Marker {
1050 file: "requirements.txt",
1051 lang: "Python",
1052 tip: "Python project with pip-style dependencies. \
1053 Prefer type hints and async/await for I/O.",
1054 },
1055 Marker {
1056 file: "go.mod",
1057 lang: "Go",
1058 tip: "Use `go build ./...` and `go test ./...`. \
1059 Follow standard Go project layout. Use `gofmt` for formatting.",
1060 },
1061 Marker {
1062 file: "pom.xml",
1063 lang: "Java / Maven",
1064 tip: "Use `mvn compile`, `mvn test`, `mvn package`. \
1065 Follow standard Maven project structure.",
1066 },
1067 Marker {
1068 file: "build.gradle",
1069 lang: "Java / Gradle",
1070 tip: "Use `./gradlew build` and `./gradlew test`. \
1071 Follow standard Gradle project structure.",
1072 },
1073 Marker {
1074 file: "build.gradle.kts",
1075 lang: "Kotlin / Gradle",
1076 tip: "Use `./gradlew build` and `./gradlew test`. \
1077 Prefer Kotlin coroutines for async work.",
1078 },
1079 Marker {
1080 file: "CMakeLists.txt",
1081 lang: "C / C++",
1082 tip: "Use `cmake -B build && cmake --build build`. \
1083 Check for `compile_commands.json` for IDE tooling.",
1084 },
1085 Marker {
1086 file: "Makefile",
1087 lang: "C / C++ (or generic)",
1088 tip: "Use `make` or `make <target>`. \
1089 Check available targets with `make help` or by reading the Makefile.",
1090 },
1091 ];
1092
1093 let is_dotnet = workspace.join("*.csproj").exists() || {
1095 std::fs::read_dir(workspace)
1097 .map(|entries| {
1098 entries.flatten().any(|e| {
1099 let name = e.file_name();
1100 let s = name.to_string_lossy();
1101 s.ends_with(".csproj") || s.ends_with(".sln")
1102 })
1103 })
1104 .unwrap_or(false)
1105 };
1106
1107 if is_dotnet {
1108 return "## Project Context\n\nThis is a **C# / .NET** project. \
1109 Use `dotnet build`, `dotnet test`, and `dotnet run`. \
1110 Follow C# coding conventions and async/await patterns."
1111 .to_string();
1112 }
1113
1114 for marker in &markers {
1115 if workspace.join(marker.file).exists() {
1116 return format!(
1117 "## Project Context\n\nThis is a **{}** project. {}",
1118 marker.lang, marker.tip
1119 );
1120 }
1121 }
1122
1123 String::new()
1124 }
1125
1126 fn is_transient_error(msg: &str) -> bool {
1129 let lower = msg.to_lowercase();
1130 lower.contains("timeout")
1131 || lower.contains("timed out")
1132 || lower.contains("connection refused")
1133 || lower.contains("connection reset")
1134 || lower.contains("broken pipe")
1135 || lower.contains("temporarily unavailable")
1136 || lower.contains("resource temporarily unavailable")
1137 || lower.contains("os error 11") || lower.contains("os error 35") || lower.contains("rate limit")
1140 || lower.contains("too many requests")
1141 || lower.contains("service unavailable")
1142 || lower.contains("network unreachable")
1143 }
1144
1145 fn is_parallel_safe_write(name: &str, _args: &serde_json::Value) -> bool {
1148 matches!(
1149 name,
1150 "write_file" | "edit_file" | "create_file" | "append_to_file" | "replace_in_file"
1151 )
1152 }
1153
1154 fn extract_write_path(args: &serde_json::Value) -> Option<String> {
1156 args.get("path")
1159 .and_then(|v| v.as_str())
1160 .map(|s| s.to_string())
1161 }
1162
1163 async fn execute_tool_queued_or_direct(
1166 &self,
1167 name: &str,
1168 args: &serde_json::Value,
1169 ctx: &ToolContext,
1170 ) -> anyhow::Result<crate::tools::ToolResult> {
1171 let task_id = if let Some(ref tm) = self.task_manager {
1173 let task = crate::task::Task::tool(name, args.clone());
1174 let id = task.id;
1175 tm.spawn(task);
1176 let _ = tm.start(id);
1178 Some(id)
1179 } else {
1180 None
1181 };
1182
1183 let result = self
1184 .execute_tool_queued_or_direct_inner(name, args, ctx)
1185 .await;
1186
1187 if let Some(ref tm) = self.task_manager {
1189 if let Some(tid) = task_id {
1190 match &result {
1191 Ok(r) => {
1192 let output = serde_json::json!({
1193 "output": r.output.clone(),
1194 "exit_code": r.exit_code,
1195 });
1196 let _ = tm.complete(tid, Some(output));
1197 }
1198 Err(e) => {
1199 let _ = tm.fail(tid, e.to_string());
1200 }
1201 }
1202 }
1203 }
1204
1205 result
1206 }
1207
1208 async fn execute_tool_queued_or_direct_inner(
1210 &self,
1211 name: &str,
1212 args: &serde_json::Value,
1213 ctx: &ToolContext,
1214 ) -> anyhow::Result<crate::tools::ToolResult> {
1215 if let Some(ref queue) = self.command_queue {
1216 let command = ToolCommand::new(
1217 Arc::clone(&self.tool_executor),
1218 name.to_string(),
1219 args.clone(),
1220 ctx.clone(),
1221 self.config.skill_registry.clone(),
1222 );
1223 let rx = queue.submit_by_tool(name, Box::new(command)).await;
1224 match rx.await {
1225 Ok(Ok(value)) => {
1226 let output = value["output"]
1227 .as_str()
1228 .ok_or_else(|| {
1229 anyhow::anyhow!(
1230 "Queue result missing 'output' field for tool '{}'",
1231 name
1232 )
1233 })?
1234 .to_string();
1235 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
1236 return Ok(crate::tools::ToolResult {
1237 name: name.to_string(),
1238 output,
1239 exit_code,
1240 metadata: None,
1241 images: Vec::new(),
1242 });
1243 }
1244 Ok(Err(e)) => {
1245 tracing::warn!(
1246 "Queue execution failed for tool '{}', falling back to direct: {}",
1247 name,
1248 e
1249 );
1250 }
1251 Err(_) => {
1252 tracing::warn!(
1253 "Queue channel closed for tool '{}', falling back to direct",
1254 name
1255 );
1256 }
1257 }
1258 }
1259 self.execute_tool_timed(name, args, ctx).await
1260 }
1261
    /// Sends the conversation to the LLM and returns its final response.
    ///
    /// Tool selection: when `config.tool_index` is set, the first text block of
    /// the most recent `user` message is used as a search query and only the
    /// configured tools whose names the index returns are sent; otherwise all
    /// configured tools are sent.
    ///
    /// Streaming: when `event_tx` is `Some`, the streaming API is used and text,
    /// tool-start and tool-input deltas are forwarded as `AgentEvent`s (send
    /// failures are ignored). If the stream cannot be set up, a non-streaming
    /// completion is attempted instead. While streaming, `cancel_token` is
    /// observed and aborts the call.
    ///
    /// When `event_tx` is `None`, a plain completion is performed; the
    /// cancellation token is not consulted on that path.
    ///
    /// # Errors
    /// - `"Operation cancelled by user"` when cancelled during streaming.
    /// - `"Stream ended without final response"` when the stream closes without
    ///   a `Done` event.
    /// - The underlying client error, with added context, otherwise.
    async fn call_llm(
        &self,
        messages: &[Message],
        system: Option<&str>,
        event_tx: &Option<mpsc::Sender<AgentEvent>>,
        cancel_token: &tokio_util::sync::CancellationToken,
    ) -> anyhow::Result<LlmResponse> {
        let tools = if let Some(ref index) = self.config.tool_index {
            // Query = first text block of the latest user message ("" if none).
            let query = messages
                .iter()
                .rev()
                .find(|m| m.role == "user")
                .and_then(|m| {
                    m.content.iter().find_map(|b| match b {
                        crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                })
                .unwrap_or("");
            // Ask for up to `index.len()` results, i.e. no cap below the index size.
            let matches = index.search(query, index.len());
            let matched_names: std::collections::HashSet<&str> =
                matches.iter().map(|m| m.name.as_str()).collect();
            self.config
                .tools
                .iter()
                .filter(|t| matched_names.contains(t.name.as_str()))
                .cloned()
                .collect::<Vec<_>>()
        } else {
            self.config.tools.clone()
        };

        if event_tx.is_some() {
            let mut stream_rx = match self
                .llm_client
                .complete_streaming(messages, system, &tools)
                .await
            {
                Ok(rx) => rx,
                Err(stream_error) => {
                    // Streaming could not start: degrade to a one-shot completion
                    // and report both failures if that fails too.
                    tracing::warn!(
                        error = %stream_error,
                        "LLM streaming setup failed; falling back to non-streaming completion"
                    );
                    return self
                        .llm_client
                        .complete(messages, system, &tools)
                        .await
                        .with_context(|| {
                            format!(
                                "LLM streaming call failed ({stream_error}); non-streaming fallback also failed"
                            )
                        });
                }
            };

            let mut final_response: Option<LlmResponse> = None;
            loop {
                // Race the next stream event against cancellation.
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        tracing::info!("🛑 LLM streaming cancelled by CancellationToken");
                        anyhow::bail!("Operation cancelled by user");
                    }
                    event = stream_rx.recv() => {
                        match event {
                            Some(crate::llm::StreamEvent::TextDelta(text)) => {
                                if let Some(tx) = event_tx {
                                    tx.send(AgentEvent::TextDelta { text }).await.ok();
                                }
                            }
                            Some(crate::llm::StreamEvent::ToolUseStart { id, name }) => {
                                if let Some(tx) = event_tx {
                                    tx.send(AgentEvent::ToolStart { id, name }).await.ok();
                                }
                            }
                            Some(crate::llm::StreamEvent::ToolUseInputDelta(delta)) => {
                                if let Some(tx) = event_tx {
                                    tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
                                }
                            }
                            Some(crate::llm::StreamEvent::Done(resp)) => {
                                final_response = Some(resp);
                                break;
                            }
                            // Channel closed without `Done`; reported as an error below.
                            None => break,
                        }
                    }
                }
            }
            final_response.context("Stream ended without final response")
        } else {
            self.llm_client
                .complete(messages, system, &tools)
                .await
                .context("LLM call failed")
        }
    }
1370
1371 fn streaming_tool_context(
1380 &self,
1381 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1382 tool_id: &str,
1383 tool_name: &str,
1384 ) -> ToolContext {
1385 let mut ctx = self.tool_context.clone();
1386 if let Some(agent_tx) = event_tx {
1387 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
1388 ctx.event_tx = Some(tool_tx);
1389
1390 let agent_tx = agent_tx.clone();
1391 let tool_id = tool_id.to_string();
1392 let tool_name = tool_name.to_string();
1393 tokio::spawn(async move {
1394 while let Some(event) = tool_rx.recv().await {
1395 match event {
1396 ToolStreamEvent::OutputDelta(delta) => {
1397 agent_tx
1398 .send(AgentEvent::ToolOutputDelta {
1399 id: tool_id.clone(),
1400 name: tool_name.clone(),
1401 delta,
1402 })
1403 .await
1404 .ok();
1405 }
1406 }
1407 }
1408 });
1409 }
1410 ctx
1411 }
1412
1413 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
1417 if self.config.context_providers.is_empty() {
1418 return Vec::new();
1419 }
1420
1421 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
1422
1423 let futures = self
1424 .config
1425 .context_providers
1426 .iter()
1427 .map(|p| p.query(&query));
1428 let outcomes = join_all(futures).await;
1429
1430 outcomes
1431 .into_iter()
1432 .enumerate()
1433 .filter_map(|(i, r)| match r {
1434 Ok(result) if !result.is_empty() => Some(result),
1435 Ok(_) => None,
1436 Err(e) => {
1437 tracing::warn!(
1438 "Context provider '{}' failed: {}",
1439 self.config.context_providers[i].name(),
1440 e
1441 );
1442 None
1443 }
1444 })
1445 .collect()
1446 }
1447
1448 fn looks_incomplete(text: &str) -> bool {
1456 let t = text.trim();
1457 if t.is_empty() {
1458 return true;
1459 }
1460 if t.len() < 80 && !t.contains('\n') {
1462 let ends_continuation =
1465 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
1466 if ends_continuation {
1467 return true;
1468 }
1469 }
1470 let incomplete_phrases = [
1472 "i'll ",
1473 "i will ",
1474 "let me ",
1475 "i need to ",
1476 "i should ",
1477 "next, i",
1478 "first, i",
1479 "now i",
1480 "i'll start",
1481 "i'll begin",
1482 "i'll now",
1483 "let's start",
1484 "let's begin",
1485 "to do this",
1486 "i'm going to",
1487 ];
1488 let lower = t.to_lowercase();
1489 for phrase in &incomplete_phrases {
1490 if lower.contains(phrase) {
1491 return true;
1492 }
1493 }
1494 false
1495 }
1496
1497 #[allow(dead_code)]
1499 fn system_prompt(&self) -> String {
1500 self.config.prompt_slots.build()
1501 }
1502
1503 fn system_prompt_for_style(&self, style: AgentStyle) -> String {
1505 let mut slots = self.config.prompt_slots.clone();
1506 slots.style = Some(style);
1507 slots.build()
1508 }
1509
1510 async fn resolve_effective_style(&self, prompt: &str) -> AgentStyle {
1511 if let Some(style) = self.config.prompt_slots.style {
1512 return style;
1513 }
1514
1515 let (style, confidence) = AgentStyle::detect_with_confidence(prompt);
1516 if confidence != DetectionConfidence::Low {
1517 return style;
1518 }
1519
1520 match AgentStyle::detect_with_llm(self.llm_client.as_ref(), prompt).await {
1521 Ok(classified_style) => {
1522 tracing::debug!(
1523 intent.classification = ?classified_style,
1524 intent.source = "llm",
1525 "Intent classified via LLM"
1526 );
1527 classified_style
1528 }
1529 Err(e) => {
1530 tracing::warn!(error = %e, "LLM intent classification failed, using keyword detection");
1531 style
1532 }
1533 }
1534 }
1535
1536 fn should_launch_subagent(
1544 &self,
1545 prompt: &str,
1546 style: AgentStyle,
1547 ) -> Option<(Arc<crate::subagent::AgentDefinition>, String)> {
1548 let registry = self.config.subagent_registry.as_ref()?;
1549
1550 if let Some(caps) = prompt.find("[subagent:") {
1552 if let Some(end_offset) = prompt[caps..].find(']') {
1554 let end_abs = caps + end_offset;
1556 let name = &prompt[caps + 10..end_abs];
1558 if let Some(agent_def) = registry.get(name) {
1559 let after_tag = prompt[end_abs + 1..].trim();
1561 let cleaned = if caps > 0 {
1562 format!("{} {}", &prompt[..caps].trim(), after_tag)
1563 } else {
1564 after_tag.to_string()
1565 };
1566 tracing::info!(subagent = %name, "Explicit subagent request detected");
1567 return Some((Arc::new(agent_def), cleaned.trim().to_string()));
1568 }
1569 }
1570 }
1571
1572 let agent_name = match style {
1574 AgentStyle::Explore => "explore",
1575 AgentStyle::Plan => "plan",
1576 AgentStyle::Verification => "verification",
1577 AgentStyle::CodeReview => "review",
1578 AgentStyle::GeneralPurpose => return None,
1579 };
1580
1581 if let Some(agent_def) = registry.get(agent_name) {
1582 tracing::info!(
1583 subagent = %agent_name,
1584 style = ?style,
1585 "Auto-detected subagent launch based on style"
1586 );
1587 return Some((Arc::new(agent_def), prompt.to_string()));
1588 }
1589
1590 None
1591 }
1592
1593 pub fn detect_context_perception_intent(
1598 &self,
1599 prompt: &str,
1600 session_id: &str,
1601 workspace: &str,
1602 ) -> Option<PreContextPerceptionEvent> {
1603 let lower = prompt.to_lowercase();
1604
1605 let intents: &[(&[&str], &str)] = &[
1607 (
1609 &[
1610 "where is",
1611 "where are",
1612 "find the file",
1613 "find all",
1614 "find files",
1615 "who wrote",
1616 "locate",
1617 "search for",
1618 "look for",
1619 "search",
1620 ],
1621 "locate",
1622 ),
1623 (
1625 &[
1626 "how does",
1627 "what does",
1628 "explain",
1629 "understand",
1630 "what is this",
1631 "how does this work",
1632 ],
1633 "understand",
1634 ),
1635 (
1637 &[
1638 "remember",
1639 "earlier",
1640 "before",
1641 "previously",
1642 "last time",
1643 "past",
1644 "previous",
1645 ],
1646 "retrieve",
1647 ),
1648 (
1650 &[
1651 "how is organized",
1652 "project structure",
1653 "what files",
1654 "show me the structure",
1655 "explore",
1656 ],
1657 "explore",
1658 ),
1659 (
1661 &[
1662 "why did",
1663 "why is",
1664 "cause",
1665 "reason",
1666 "what happened",
1667 "why does",
1668 ],
1669 "reason",
1670 ),
1671 (
1673 &["is this correct", "verify", "validate", "check if", "debug"],
1674 "validate",
1675 ),
1676 (
1678 &[
1679 "difference between",
1680 "compare",
1681 "versus",
1682 " vs ",
1683 "different from",
1684 ],
1685 "compare",
1686 ),
1687 (
1689 &[
1690 "status",
1691 "progress",
1692 "how far",
1693 "history",
1694 "what's the current",
1695 ],
1696 "track",
1697 ),
1698 ];
1699
1700 let target_type = if lower.contains("function") || lower.contains("method") {
1702 "function"
1703 } else if lower.contains("file") || lower.contains("config") {
1704 "file"
1705 } else if lower.contains("class") {
1706 "entity"
1707 } else if lower.contains("module") || lower.contains("package") {
1708 "module"
1709 } else if lower.contains("test") {
1710 "test"
1711 } else {
1712 "unknown"
1713 };
1714
1715 let matched_intent = intents
1717 .iter()
1718 .find(|(patterns, _)| patterns.iter().any(|p| lower.contains(p)));
1719
1720 matched_intent.map(|(patterns, intent)| {
1721 let target_name = extract_target_name_from_prompt(prompt, patterns);
1723
1724 PreContextPerceptionEvent {
1725 session_id: session_id.to_string(),
1726 intent: intent.to_string(),
1727 target_type: target_type.to_string(),
1728 target_name,
1729 domain: detect_domain_from_prompt(prompt),
1730 query: Some(prompt.to_string()),
1731 working_directory: workspace.to_string(),
1732 urgency: "normal".to_string(),
1733 }
1734 })
1735 }
1736
1737 async fn fire_pre_context_perception(&self, event: &PreContextPerceptionEvent) -> HookResult {
1739 if let Some(he) = &self.config.hook_engine {
1740 let hook_event = HookEvent::PreContextPerception(event.clone());
1741 he.fire(&hook_event).await
1742 } else {
1743 HookResult::continue_()
1744 }
1745 }
1746
1747 async fn fire_intent_detection(
1752 &self,
1753 prompt: &str,
1754 session_id: &str,
1755 workspace: &str,
1756 ) -> Option<IntentDetectionResult> {
1757 let event = IntentDetectionEvent {
1758 session_id: session_id.to_string(),
1759 prompt: prompt.to_string(),
1760 workspace: workspace.to_string(),
1761 language_hint: detect_language_hint(prompt),
1762 };
1763
1764 let hook_result = if let Some(he) = &self.config.hook_engine {
1765 let hook_event = HookEvent::IntentDetection(event);
1766 he.fire(&hook_event).await
1767 } else {
1768 return None;
1769 };
1770
1771 match hook_result {
1772 HookResult::Continue(Some(modified)) => {
1773 serde_json::from_value::<IntentDetectionResult>(modified).ok()
1775 }
1776 HookResult::Block(_) => {
1777 tracing::info!("AHP harness blocked intent detection");
1779 None
1780 }
1781 _ => None,
1782 }
1783 }
1784
1785 #[cfg(feature = "ahp")]
1787 fn apply_injected_context(&self, injected: InjectedContext) -> Vec<ContextResult> {
1788 let mut results = Vec::new();
1789
1790 if !injected.facts.is_empty() {
1792 let items: Vec<ContextItem> = injected
1793 .facts
1794 .into_iter()
1795 .map(|f| {
1796 let token_count = estimate_tokens(&f.content);
1797 ContextItem {
1798 id: uuid::Uuid::new_v4().to_string(),
1799 context_type: ContextType::Resource,
1800 content: f.content,
1801 token_count,
1802 relevance: f.confidence,
1803 source: Some(f.source),
1804 metadata: std::collections::HashMap::new(),
1805 }
1806 })
1807 .collect();
1808
1809 let total_tokens: usize = items.iter().map(|i| i.token_count).sum();
1810
1811 results.push(ContextResult {
1812 items,
1813 total_tokens,
1814 provider: "ahp_harness".to_string(),
1815 truncated: false,
1816 });
1817 }
1818
1819 if let Some(file_contents) = injected.file_contents {
1821 let items: Vec<ContextItem> = file_contents
1822 .into_iter()
1823 .map(|f| {
1824 let token_count = estimate_tokens(&f.snippet);
1825 ContextItem {
1826 id: uuid::Uuid::new_v4().to_string(),
1827 context_type: ContextType::Resource,
1828 content: f.snippet,
1829 token_count,
1830 relevance: f.relevance_score,
1831 source: Some(f.path),
1832 metadata: std::collections::HashMap::new(),
1833 }
1834 })
1835 .collect();
1836
1837 let total_tokens: usize = items.iter().map(|i| i.token_count).sum();
1838
1839 results.push(ContextResult {
1840 items,
1841 total_tokens,
1842 provider: "ahp_harness".to_string(),
1843 truncated: false,
1844 });
1845 }
1846
1847 if let Some(summary) = injected.project_summary {
1849 let content = format!(
1850 "Project: {}\n{}",
1851 summary.project_name, summary.structure_description
1852 );
1853 let token_count = estimate_tokens(&content);
1854
1855 results.push(ContextResult {
1856 items: vec![ContextItem {
1857 id: uuid::Uuid::new_v4().to_string(),
1858 context_type: ContextType::Resource,
1859 content,
1860 token_count,
1861 relevance: 0.9,
1862 source: Some("ahp://project-summary".to_string()),
1863 metadata: std::collections::HashMap::new(),
1864 }],
1865 total_tokens: token_count,
1866 provider: "ahp_harness".to_string(),
1867 truncated: false,
1868 });
1869 }
1870
1871 if let Some(knowledge) = injected.knowledge {
1873 let items: Vec<ContextItem> = knowledge
1874 .into_iter()
1875 .map(|k| {
1876 let token_count = estimate_tokens(&k);
1877 ContextItem {
1878 id: uuid::Uuid::new_v4().to_string(),
1879 context_type: ContextType::Resource,
1880 content: k,
1881 token_count,
1882 relevance: 0.8,
1883 source: Some("ahp://knowledge".to_string()),
1884 metadata: std::collections::HashMap::new(),
1885 }
1886 })
1887 .collect();
1888
1889 let total_tokens: usize = items.iter().map(|i| i.token_count).sum();
1890
1891 results.push(ContextResult {
1892 items,
1893 total_tokens,
1894 provider: "ahp_harness".to_string(),
1895 truncated: false,
1896 });
1897 }
1898
1899 results
1900 }
1901
1902 #[allow(dead_code)]
1904 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
1905 let base = self.system_prompt();
1906 self.build_augmented_system_prompt_with_base(&base, context_results)
1907 }
1908
1909 fn build_augmented_system_prompt_with_base(
1910 &self,
1911 base: &str,
1912 context_results: &[ContextResult],
1913 ) -> Option<String> {
1914 let base = base.to_string();
1915
1916 let live_tools = self.tool_executor.definitions();
1918 let mcp_tools: Vec<&ToolDefinition> = live_tools
1919 .iter()
1920 .filter(|t| t.name.starts_with("mcp__"))
1921 .collect();
1922
1923 let mcp_section = if mcp_tools.is_empty() {
1924 String::new()
1925 } else {
1926 let mut lines = vec![
1927 "## MCP Tools".to_string(),
1928 String::new(),
1929 "The following MCP (Model Context Protocol) tools are available. Use them when the task requires external capabilities beyond the built-in tools:".to_string(),
1930 String::new(),
1931 ];
1932 for tool in &mcp_tools {
1933 let display = format!("- `{}` — {}", tool.name, tool.description);
1934 lines.push(display);
1935 }
1936 lines.join("\n")
1937 };
1938
1939 let parts: Vec<&str> = [base.as_str(), mcp_section.as_str()]
1940 .iter()
1941 .filter(|s| !s.is_empty())
1942 .copied()
1943 .collect();
1944
1945 let project_hint = if self.config.prompt_slots.guidelines.is_none() {
1948 Self::detect_project_hint(&self.tool_context.workspace)
1949 } else {
1950 String::new()
1951 };
1952
1953 if context_results.is_empty() {
1954 if project_hint.is_empty() {
1955 return Some(parts.join("\n\n"));
1956 }
1957 return Some(format!("{}\n\n{}", parts.join("\n\n"), project_hint));
1958 }
1959
1960 let context_xml: String = context_results
1962 .iter()
1963 .map(|r| r.to_xml())
1964 .collect::<Vec<_>>()
1965 .join("\n\n");
1966
1967 if project_hint.is_empty() {
1968 Some(format!("{}\n\n{}", parts.join("\n\n"), context_xml))
1969 } else {
1970 Some(format!(
1971 "{}\n\n{}\n\n{}",
1972 parts.join("\n\n"),
1973 project_hint,
1974 context_xml
1975 ))
1976 }
1977 }
1978
1979 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
1981 let futures = self
1982 .config
1983 .context_providers
1984 .iter()
1985 .map(|p| p.on_turn_complete(session_id, prompt, response));
1986 let outcomes = join_all(futures).await;
1987
1988 for (i, result) in outcomes.into_iter().enumerate() {
1989 if let Err(e) = result {
1990 tracing::warn!(
1991 "Context provider '{}' on_turn_complete failed: {}",
1992 self.config.context_providers[i].name(),
1993 e
1994 );
1995 }
1996 }
1997 }
1998
1999 async fn fire_pre_tool_use(
2002 &self,
2003 session_id: &str,
2004 tool_name: &str,
2005 args: &serde_json::Value,
2006 recent_tools: Vec<String>,
2007 ) -> Option<HookResult> {
2008 if let Some(he) = &self.config.hook_engine {
2009 let safe_args = if args.is_null() {
2011 serde_json::Value::Object(Default::default())
2012 } else {
2013 args.clone()
2014 };
2015 let event = HookEvent::PreToolUse(PreToolUseEvent {
2016 session_id: session_id.to_string(),
2017 tool: tool_name.to_string(),
2018 args: safe_args,
2019 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
2020 recent_tools,
2021 });
2022 let result = he.fire(&event).await;
2023 if result.is_block() {
2024 return Some(result);
2025 }
2026 }
2027 None
2028 }
2029
2030 async fn fire_post_tool_use(
2032 &self,
2033 session_id: &str,
2034 tool_name: &str,
2035 args: &serde_json::Value,
2036 output: &str,
2037 success: bool,
2038 duration_ms: u64,
2039 ) {
2040 if let Some(he) = &self.config.hook_engine {
2041 let safe_args = if args.is_null() {
2043 serde_json::Value::Object(Default::default())
2044 } else {
2045 args.clone()
2046 };
2047 let event = HookEvent::PostToolUse(PostToolUseEvent {
2048 session_id: session_id.to_string(),
2049 tool: tool_name.to_string(),
2050 args: safe_args,
2051 result: ToolResultData {
2052 success,
2053 output: output.to_string(),
2054 exit_code: if success { Some(0) } else { Some(1) },
2055 duration_ms,
2056 },
2057 });
2058 let he = Arc::clone(he);
2059 tokio::spawn(async move {
2060 let _ = he.fire(&event).await;
2061 });
2062 }
2063 }
2064
2065 async fn fire_generate_start(
2067 &self,
2068 session_id: &str,
2069 prompt: &str,
2070 system_prompt: &Option<String>,
2071 ) {
2072 if let Some(he) = &self.config.hook_engine {
2073 let event = HookEvent::GenerateStart(GenerateStartEvent {
2074 session_id: session_id.to_string(),
2075 prompt: prompt.to_string(),
2076 system_prompt: system_prompt.clone(),
2077 model_provider: String::new(),
2078 model_name: String::new(),
2079 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
2080 });
2081 let _ = he.fire(&event).await;
2082 }
2083 }
2084
2085 async fn fire_generate_end(
2087 &self,
2088 session_id: &str,
2089 prompt: &str,
2090 response: &LlmResponse,
2091 duration_ms: u64,
2092 ) {
2093 if let Some(he) = &self.config.hook_engine {
2094 let tool_calls: Vec<ToolCallInfo> = response
2095 .tool_calls()
2096 .iter()
2097 .map(|tc| {
2098 let args = if tc.args.is_null() {
2099 serde_json::Value::Object(Default::default())
2100 } else {
2101 tc.args.clone()
2102 };
2103 ToolCallInfo {
2104 name: tc.name.clone(),
2105 args,
2106 }
2107 })
2108 .collect();
2109
2110 let event = HookEvent::GenerateEnd(GenerateEndEvent {
2111 session_id: session_id.to_string(),
2112 prompt: prompt.to_string(),
2113 response_text: response.text().to_string(),
2114 tool_calls,
2115 usage: TokenUsageInfo {
2116 prompt_tokens: response.usage.prompt_tokens as i32,
2117 completion_tokens: response.usage.completion_tokens as i32,
2118 total_tokens: response.usage.total_tokens as i32,
2119 },
2120 duration_ms,
2121 });
2122 let _ = he.fire(&event).await;
2123 }
2124 }
2125
2126 async fn fire_pre_prompt(
2129 &self,
2130 session_id: &str,
2131 prompt: &str,
2132 system_prompt: &Option<String>,
2133 message_count: usize,
2134 ) -> Option<String> {
2135 if let Some(he) = &self.config.hook_engine {
2136 let event = HookEvent::PrePrompt(PrePromptEvent {
2137 session_id: session_id.to_string(),
2138 prompt: prompt.to_string(),
2139 system_prompt: system_prompt.clone(),
2140 message_count,
2141 });
2142 let result = he.fire(&event).await;
2143 if let HookResult::Continue(Some(modified)) = result {
2144 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
2146 return Some(new_prompt.to_string());
2147 }
2148 }
2149 }
2150 None
2151 }
2152
2153 async fn fire_post_response(
2155 &self,
2156 session_id: &str,
2157 response_text: &str,
2158 tool_calls_count: usize,
2159 usage: &TokenUsage,
2160 duration_ms: u64,
2161 ) {
2162 if let Some(he) = &self.config.hook_engine {
2163 let event = HookEvent::PostResponse(PostResponseEvent {
2164 session_id: session_id.to_string(),
2165 response_text: response_text.to_string(),
2166 tool_calls_count,
2167 usage: TokenUsageInfo {
2168 prompt_tokens: usage.prompt_tokens as i32,
2169 completion_tokens: usage.completion_tokens as i32,
2170 total_tokens: usage.total_tokens as i32,
2171 },
2172 duration_ms,
2173 });
2174 let he = Arc::clone(he);
2175 tokio::spawn(async move {
2176 let _ = he.fire(&event).await;
2177 });
2178 }
2179 }
2180
2181 async fn fire_on_error(
2183 &self,
2184 session_id: &str,
2185 error_type: ErrorType,
2186 error_message: &str,
2187 context: serde_json::Value,
2188 ) {
2189 if let Some(he) = &self.config.hook_engine {
2190 let event = HookEvent::OnError(OnErrorEvent {
2191 session_id: session_id.to_string(),
2192 error_type,
2193 error_message: error_message.to_string(),
2194 context,
2195 });
2196 let he = Arc::clone(he);
2197 tokio::spawn(async move {
2198 let _ = he.fire(&event).await;
2199 });
2200 }
2201 }
2202
2203 pub async fn execute(
2209 &self,
2210 history: &[Message],
2211 prompt: &str,
2212 event_tx: Option<mpsc::Sender<AgentEvent>>,
2213 ) -> Result<AgentResult> {
2214 self.execute_with_session(history, prompt, None, event_tx, None)
2215 .await
2216 }
2217
2218 pub async fn execute_from_messages(
2224 &self,
2225 messages: Vec<Message>,
2226 session_id: Option<&str>,
2227 event_tx: Option<mpsc::Sender<AgentEvent>>,
2228 cancel_token: Option<&tokio_util::sync::CancellationToken>,
2229 ) -> Result<AgentResult> {
2230 let default_token = tokio_util::sync::CancellationToken::new();
2231 let token = cancel_token.unwrap_or(&default_token);
2232 tracing::info!(
2233 a3s.session.id = session_id.unwrap_or("none"),
2234 a3s.agent.max_turns = self.config.max_tool_rounds,
2235 "a3s.agent.execute_from_messages started"
2236 );
2237
2238 let effective_prompt = messages
2242 .iter()
2243 .rev()
2244 .find(|m| m.role == "user")
2245 .map(|m| m.text())
2246 .unwrap_or_default();
2247
2248 let result = self
2249 .execute_loop_inner(
2250 &messages,
2251 "",
2252 &effective_prompt,
2253 session_id,
2254 event_tx,
2255 token,
2256 true, )
2258 .await;
2259
2260 match &result {
2261 Ok(r) => tracing::info!(
2262 a3s.agent.tool_calls_count = r.tool_calls_count,
2263 a3s.llm.total_tokens = r.usage.total_tokens,
2264 "a3s.agent.execute_from_messages completed"
2265 ),
2266 Err(e) => tracing::warn!(
2267 error = %e,
2268 "a3s.agent.execute_from_messages failed"
2269 ),
2270 }
2271
2272 result
2273 }
2274
2275 pub async fn execute_with_session(
2280 &self,
2281 history: &[Message],
2282 prompt: &str,
2283 session_id: Option<&str>,
2284 event_tx: Option<mpsc::Sender<AgentEvent>>,
2285 cancel_token: Option<&tokio_util::sync::CancellationToken>,
2286 ) -> Result<AgentResult> {
2287 let default_token = tokio_util::sync::CancellationToken::new();
2288 let token = cancel_token.unwrap_or(&default_token);
2289 tracing::info!(
2290 a3s.session.id = session_id.unwrap_or("none"),
2291 a3s.agent.max_turns = self.config.max_tool_rounds,
2292 "a3s.agent.execute started"
2293 );
2294
2295 let effective_style = self.resolve_effective_style(prompt).await;
2296
2297 if let Some((subagent_def, cleaned_prompt)) =
2299 self.should_launch_subagent(prompt, effective_style)
2300 {
2301 tracing::info!(subagent = %subagent_def.name, "Subagent launch requested");
2302
2303 if let Some(ref callback) = self.config.on_subagent_launch {
2305 if let Some(result) = callback(&subagent_def, &cleaned_prompt) {
2306 tracing::info!(subagent = %subagent_def.name, "Subagent executed successfully");
2307 return result;
2308 }
2309 }
2310 tracing::debug!(subagent = %subagent_def.name, "No callback or callback returned None, continuing with normal execution");
2312 }
2313
2314 let use_planning = if self.config.planning_mode == PlanningMode::Auto {
2316 effective_style.requires_planning()
2317 } else {
2318 self.config.planning_mode.should_plan(prompt)
2320 };
2321
2322 let task_id = if let Some(ref tm) = self.task_manager {
2324 let workspace = self.tool_context.workspace.display().to_string();
2325 let task = crate::task::Task::agent("agent", &workspace, prompt);
2326 let id = task.id;
2327 tm.spawn(task);
2328 let _ = tm.start(id);
2329 Some(id)
2330 } else {
2331 None
2332 };
2333
2334 let result = if use_planning {
2335 self.execute_with_planning(history, prompt, event_tx).await
2336 } else {
2337 self.execute_loop(history, prompt, session_id, event_tx, token, true)
2338 .await
2339 };
2340
2341 if let Some(ref tm) = self.task_manager {
2343 if let Some(tid) = task_id {
2344 match &result {
2345 Ok(r) => {
2346 let output = serde_json::json!({
2347 "text": r.text,
2348 "tool_calls_count": r.tool_calls_count,
2349 "usage": r.usage,
2350 });
2351 let _ = tm.complete(tid, Some(output));
2352 }
2353 Err(e) => {
2354 let _ = tm.fail(tid, e.to_string());
2355 }
2356 }
2357 }
2358 }
2359
2360 match &result {
2361 Ok(r) => {
2362 tracing::info!(
2363 a3s.agent.tool_calls_count = r.tool_calls_count,
2364 a3s.llm.total_tokens = r.usage.total_tokens,
2365 "a3s.agent.execute completed"
2366 );
2367 self.fire_post_response(
2369 session_id.unwrap_or(""),
2370 &r.text,
2371 r.tool_calls_count,
2372 &r.usage,
2373 0, )
2375 .await;
2376 }
2377 Err(e) => {
2378 tracing::warn!(
2379 error = %e,
2380 "a3s.agent.execute failed"
2381 );
2382 self.fire_on_error(
2384 session_id.unwrap_or(""),
2385 ErrorType::Other,
2386 &e.to_string(),
2387 serde_json::json!({"phase": "execute"}),
2388 )
2389 .await;
2390 }
2391 }
2392
2393 result
2394 }
2395
2396 async fn execute_loop(
2402 &self,
2403 history: &[Message],
2404 prompt: &str,
2405 session_id: Option<&str>,
2406 event_tx: Option<mpsc::Sender<AgentEvent>>,
2407 cancel_token: &tokio_util::sync::CancellationToken,
2408 emit_end: bool,
2409 ) -> Result<AgentResult> {
2410 self.execute_loop_inner(
2413 history,
2414 prompt,
2415 prompt,
2416 session_id,
2417 event_tx,
2418 cancel_token,
2419 emit_end,
2420 )
2421 .await
2422 }
2423
2424 #[allow(clippy::too_many_arguments)]
2431 async fn execute_loop_inner(
2432 &self,
2433 history: &[Message],
2434 msg_prompt: &str,
2435 effective_prompt: &str,
2436 session_id: Option<&str>,
2437 event_tx: Option<mpsc::Sender<AgentEvent>>,
2438 cancel_token: &tokio_util::sync::CancellationToken,
2439 emit_end: bool,
2440 ) -> Result<AgentResult> {
2441 let mut messages = history.to_vec();
2442 let mut total_usage = TokenUsage::default();
2443 let mut tool_calls_count = 0;
2444 let mut turn = 0;
2445 let mut parse_error_count: u32 = 0;
2447 let mut continuation_count: u32 = 0;
2449 let mut recent_tool_signatures: Vec<String> = Vec::new();
2450 let style_prompt = if effective_prompt.is_empty() {
2451 msg_prompt
2452 } else {
2453 effective_prompt
2454 };
2455 let effective_style = self.resolve_effective_style(style_prompt).await;
2456 let effective_system_prompt = self.system_prompt_for_style(effective_style);
2457 if let Some(tx) = &event_tx {
2458 tx.send(AgentEvent::AgentModeChanged {
2459 mode: effective_style.runtime_mode().to_string(),
2460 agent: effective_style.builtin_agent_name().to_string(),
2461 description: effective_style.description().to_string(),
2462 })
2463 .await
2464 .ok();
2465 }
2466
2467 if let Some(tx) = &event_tx {
2469 tx.send(AgentEvent::Start {
2470 prompt: effective_prompt.to_string(),
2471 })
2472 .await
2473 .ok();
2474 }
2475
2476 let _queue_forward_handle =
2478 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
2479 let mut rx = queue.subscribe();
2480 let tx = tx.clone();
2481 Some(tokio::spawn(async move {
2482 while let Ok(event) = rx.recv().await {
2483 if tx.send(event).await.is_err() {
2484 break;
2485 }
2486 }
2487 }))
2488 } else {
2489 None
2490 };
2491
2492 let built_system_prompt = Some(effective_system_prompt.clone());
2494 let hooked_prompt = if let Some(modified) = self
2495 .fire_pre_prompt(
2496 session_id.unwrap_or(""),
2497 effective_prompt,
2498 &built_system_prompt,
2499 messages.len(),
2500 )
2501 .await
2502 {
2503 modified
2504 } else {
2505 effective_prompt.to_string()
2506 };
2507 let effective_prompt = hooked_prompt.as_str();
2508
2509 if let Some(ref sp) = self.config.security_provider {
2511 sp.taint_input(effective_prompt);
2512 }
2513
2514 let system_with_memory = if let Some(ref memory) = self.config.memory {
2516 match memory.recall_similar(effective_prompt, 5).await {
2517 Ok(items) if !items.is_empty() => {
2518 if let Some(tx) = &event_tx {
2519 for item in &items {
2520 tx.send(AgentEvent::MemoryRecalled {
2521 memory_id: item.id.clone(),
2522 content: item.content.clone(),
2523 relevance: item.relevance_score(),
2524 })
2525 .await
2526 .ok();
2527 }
2528 tx.send(AgentEvent::MemoriesSearched {
2529 query: Some(effective_prompt.to_string()),
2530 tags: Vec::new(),
2531 result_count: items.len(),
2532 })
2533 .await
2534 .ok();
2535 }
2536 let memory_context = items
2537 .iter()
2538 .map(|i| format!("- {}", i.content))
2539 .collect::<Vec<_>>()
2540 .join(
2541 "
2542",
2543 );
2544 let base = effective_system_prompt.clone();
2545 Some(format!(
2546 "{}
2547
2548## Relevant past experience
2549{}",
2550 base, memory_context
2551 ))
2552 }
2553 _ => Some(effective_system_prompt.clone()),
2554 }
2555 } else {
2556 Some(effective_system_prompt.clone())
2557 };
2558
2559 let workspace = self.tool_context.workspace.display().to_string();
2562 let session_id_str = session_id.unwrap_or("");
2563 let context_results = if !self.config.context_providers.is_empty() {
2564 #[allow(clippy::needless_borrow)]
2566 let harness_intent = self
2567 .fire_intent_detection(effective_prompt, &session_id_str, &workspace)
2568 .await;
2569
2570 #[allow(clippy::needless_borrow)]
2572 let perception_event = if let Some(detected) = harness_intent {
2573 tracing::info!(
2574 intent = %detected.detected_intent,
2575 confidence = %detected.confidence,
2576 "Intent detected from AHP harness"
2577 );
2578 Some(build_pre_context_perception_from_intent(
2579 detected,
2580 effective_prompt,
2581 &session_id_str,
2582 &workspace,
2583 ))
2584 } else {
2585 tracing::debug!("No intent from harness, using local keyword detection");
2587 self.detect_context_perception_intent(effective_prompt, &session_id_str, &workspace)
2588 };
2589
2590 if let Some(perception_event) = perception_event {
2591 tracing::info!(
2593 intent = %perception_event.intent,
2594 target_type = %perception_event.target_type,
2595 "Context perception intent detected, firing AHP hook"
2596 );
2597
2598 let hook_result = self.fire_pre_context_perception(&perception_event).await;
2599
2600 match hook_result {
2601 HookResult::Continue(Some(modified_context)) => {
2602 #[cfg(feature = "ahp")]
2604 {
2605 if let Ok(injected) =
2606 serde_json::from_value::<InjectedContext>(modified_context)
2607 {
2608 tracing::info!(
2609 facts = injected.facts.len(),
2610 "Using injected context from AHP harness"
2611 );
2612 self.apply_injected_context(injected)
2613 } else {
2614 tracing::warn!(
2616 "Failed to parse injected context, falling back to providers"
2617 );
2618 self.resolve_context(effective_prompt, session_id).await
2619 }
2620 }
2621 #[cfg(not(feature = "ahp"))]
2622 {
2623 let _ = modified_context; self.resolve_context(effective_prompt, session_id).await
2626 }
2627 }
2628 HookResult::Block(_) => {
2629 tracing::info!("AHP harness blocked context injection");
2631 Vec::new()
2632 }
2633 _ => {
2634 self.resolve_context(effective_prompt, session_id).await
2636 }
2637 }
2638 } else {
2639 self.resolve_context(effective_prompt, session_id).await
2641 }
2642 } else {
2643 Vec::new()
2644 };
2645
2646 if let Some(tx) = &event_tx {
2648 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
2649 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
2650
2651 tracing::info!(
2652 context_items = total_items,
2653 context_tokens = total_tokens,
2654 "Context resolution completed"
2655 );
2656
2657 tx.send(AgentEvent::ContextResolved {
2658 total_items,
2659 total_tokens,
2660 })
2661 .await
2662 .ok();
2663 }
2664
2665 let augmented_system = self
2666 .build_augmented_system_prompt_with_base(&effective_system_prompt, &context_results);
2667
2668 let base_prompt = effective_system_prompt.clone();
2670 let augmented_system = match (augmented_system, system_with_memory) {
2671 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
2672 (Some(ctx), _) => Some(ctx),
2673 (None, mem) => mem,
2674 };
2675
2676 if !msg_prompt.is_empty() {
2678 messages.push(Message::user(msg_prompt));
2679 }
2680
2681 loop {
2682 turn += 1;
2683
2684 if turn > self.config.max_tool_rounds {
2685 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
2686 if let Some(tx) = &event_tx {
2687 tx.send(AgentEvent::Error {
2688 message: error.clone(),
2689 })
2690 .await
2691 .ok();
2692 }
2693 anyhow::bail!(error);
2694 }
2695
2696 if let Some(tx) = &event_tx {
2698 tx.send(AgentEvent::TurnStart { turn }).await.ok();
2699 }
2700
2701 tracing::info!(
2702 turn = turn,
2703 max_turns = self.config.max_tool_rounds,
2704 "Agent turn started"
2705 );
2706
2707 tracing::info!(
2709 a3s.llm.streaming = event_tx.is_some(),
2710 "LLM completion started"
2711 );
2712
2713 self.fire_generate_start(
2715 session_id.unwrap_or(""),
2716 effective_prompt,
2717 &augmented_system,
2718 )
2719 .await;
2720
2721 let llm_start = std::time::Instant::now();
2722 let response = {
2726 let threshold = self.config.circuit_breaker_threshold.max(1);
2727 let mut attempt = 0u32;
2728 loop {
2729 attempt += 1;
2730 let result = self
2731 .call_llm(
2732 &messages,
2733 augmented_system.as_deref(),
2734 &event_tx,
2735 cancel_token,
2736 )
2737 .await;
2738 match result {
2739 Ok(r) => {
2740 break r;
2741 }
2742 Err(e) if cancel_token.is_cancelled() => {
2744 anyhow::bail!(e);
2745 }
2746 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
2748 tracing::warn!(
2749 turn = turn,
2750 attempt = attempt,
2751 threshold = threshold,
2752 error = %e,
2753 "LLM call failed, will retry"
2754 );
2755 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
2756 }
2757 Err(e) => {
2759 let msg = if attempt > 1 {
2760 format!(
2761 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
2762 attempt, e
2763 )
2764 } else {
2765 format!("LLM call failed: {}", e)
2766 };
2767 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
2768 self.fire_on_error(
2770 session_id.unwrap_or(""),
2771 ErrorType::LlmFailure,
2772 &msg,
2773 serde_json::json!({"turn": turn, "attempt": attempt}),
2774 )
2775 .await;
2776 if let Some(tx) = &event_tx {
2777 tx.send(AgentEvent::Error {
2778 message: msg.clone(),
2779 })
2780 .await
2781 .ok();
2782 }
2783 anyhow::bail!(msg);
2784 }
2785 }
2786 }
2787 };
2788
2789 total_usage.prompt_tokens += response.usage.prompt_tokens;
2791 total_usage.completion_tokens += response.usage.completion_tokens;
2792 total_usage.total_tokens += response.usage.total_tokens;
2793
2794 if let Some(ref tracker) = self.progress_tracker {
2796 let token_usage = crate::task::TaskTokenUsage {
2797 input_tokens: response.usage.prompt_tokens as u64,
2798 output_tokens: response.usage.completion_tokens as u64,
2799 cache_read_tokens: response.usage.cache_read_tokens.unwrap_or(0) as u64,
2800 cache_write_tokens: response.usage.cache_write_tokens.unwrap_or(0) as u64,
2801 };
2802 if let Ok(mut guard) = tracker.try_write() {
2803 guard.track_tokens(token_usage);
2804 }
2805 }
2806
2807 let llm_duration = llm_start.elapsed();
2809 tracing::info!(
2810 turn = turn,
2811 streaming = event_tx.is_some(),
2812 prompt_tokens = response.usage.prompt_tokens,
2813 completion_tokens = response.usage.completion_tokens,
2814 total_tokens = response.usage.total_tokens,
2815 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
2816 duration_ms = llm_duration.as_millis() as u64,
2817 "LLM completion finished"
2818 );
2819
2820 self.fire_generate_end(
2822 session_id.unwrap_or(""),
2823 effective_prompt,
2824 &response,
2825 llm_duration.as_millis() as u64,
2826 )
2827 .await;
2828
2829 crate::telemetry::record_llm_usage(
2831 response.usage.prompt_tokens,
2832 response.usage.completion_tokens,
2833 response.usage.total_tokens,
2834 response.stop_reason.as_deref(),
2835 );
2836 tracing::info!(
2838 turn = turn,
2839 a3s.llm.total_tokens = response.usage.total_tokens,
2840 "Turn token usage"
2841 );
2842
2843 messages.push(response.message.clone());
2845
2846 let tool_calls = response.tool_calls();
2848
2849 if let Some(tx) = &event_tx {
2851 tx.send(AgentEvent::TurnEnd {
2852 turn,
2853 usage: response.usage.clone(),
2854 })
2855 .await
2856 .ok();
2857 }
2858
2859 if self.config.auto_compact {
2861 let used = response.usage.prompt_tokens;
2862 let max = self.config.max_context_tokens;
2863 let threshold = self.config.auto_compact_threshold;
2864
2865 if crate::session::compaction::should_auto_compact(used, max, threshold) {
2866 let before_len = messages.len();
2867 let percent_before = used as f32 / max as f32;
2868
2869 tracing::info!(
2870 used_tokens = used,
2871 max_tokens = max,
2872 percent = percent_before,
2873 threshold = threshold,
2874 "Auto-compact triggered"
2875 );
2876
2877 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
2879 {
2880 messages = pruned;
2881 tracing::info!("Tool output pruning applied");
2882 }
2883
2884 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
2886 session_id.unwrap_or(""),
2887 &messages,
2888 &self.llm_client,
2889 )
2890 .await
2891 {
2892 messages = compacted;
2893 }
2894
2895 if let Some(tx) = &event_tx {
2897 tx.send(AgentEvent::ContextCompacted {
2898 session_id: session_id.unwrap_or("").to_string(),
2899 before_messages: before_len,
2900 after_messages: messages.len(),
2901 percent_before,
2902 })
2903 .await
2904 .ok();
2905 }
2906 }
2907 }
2908
2909 if tool_calls.is_empty() {
2910 let final_text = response.text();
2913
2914 if self.config.continuation_enabled
2915 && continuation_count < self.config.max_continuation_turns
2916 && turn < self.config.max_tool_rounds && Self::looks_incomplete(&final_text)
2918 {
2919 continuation_count += 1;
2920 tracing::info!(
2921 turn = turn,
2922 continuation = continuation_count,
2923 max_continuation = self.config.max_continuation_turns,
2924 "Injecting continuation message — response looks incomplete"
2925 );
2926 messages.push(Message::user(crate::prompts::CONTINUATION));
2928 continue;
2929 }
2930
2931 let final_text = if let Some(ref sp) = self.config.security_provider {
2933 sp.sanitize_output(&final_text)
2934 } else {
2935 final_text
2936 };
2937
2938 tracing::info!(
2940 tool_calls_count = tool_calls_count,
2941 total_prompt_tokens = total_usage.prompt_tokens,
2942 total_completion_tokens = total_usage.completion_tokens,
2943 total_tokens = total_usage.total_tokens,
2944 turns = turn,
2945 "Agent execution completed"
2946 );
2947
2948 if emit_end {
2949 if let Some(tx) = &event_tx {
2950 tx.send(AgentEvent::End {
2951 text: final_text.clone(),
2952 usage: total_usage.clone(),
2953 meta: response.meta.clone(),
2954 })
2955 .await
2956 .ok();
2957 }
2958 }
2959
2960 if let Some(sid) = session_id {
2962 self.notify_turn_complete(sid, effective_prompt, &final_text)
2963 .await;
2964 }
2965
2966 return Ok(AgentResult {
2967 text: final_text,
2968 messages,
2969 usage: total_usage,
2970 tool_calls_count,
2971 });
2972 }
2973
2974 let tool_calls = if self.config.hook_engine.is_none()
2978 && self.config.confirmation_manager.is_none()
2979 && tool_calls.len() > 1
2980 && tool_calls
2981 .iter()
2982 .all(|tc| Self::is_parallel_safe_write(&tc.name, &tc.args))
2983 && {
2984 let paths: Vec<_> = tool_calls
2986 .iter()
2987 .filter_map(|tc| Self::extract_write_path(&tc.args))
2988 .collect();
2989 paths.len() == tool_calls.len()
2990 && paths.iter().collect::<std::collections::HashSet<_>>().len()
2991 == paths.len()
2992 } {
2993 tracing::info!(
2994 count = tool_calls.len(),
2995 "Parallel write batch: executing {} independent file writes concurrently",
2996 tool_calls.len()
2997 );
2998
2999 let futures: Vec<_> = tool_calls
3000 .iter()
3001 .map(|tc| {
3002 let ctx = self.tool_context.clone();
3003 let executor = Arc::clone(&self.tool_executor);
3004 let name = tc.name.clone();
3005 let args = tc.args.clone();
3006 async move { executor.execute_with_context(&name, &args, &ctx).await }
3007 })
3008 .collect();
3009
3010 let results = join_all(futures).await;
3011
3012 for (tc, result) in tool_calls.iter().zip(results) {
3014 tool_calls_count += 1;
3015 let (output, exit_code, is_error, metadata, images) =
3016 Self::tool_result_to_tuple(result);
3017
3018 self.track_tool_result(&tc.name, &tc.args, exit_code);
3020
3021 let output = if let Some(ref sp) = self.config.security_provider {
3022 sp.sanitize_output(&output)
3023 } else {
3024 output
3025 };
3026
3027 if let Some(tx) = &event_tx {
3028 tx.send(AgentEvent::ToolEnd {
3029 id: tc.id.clone(),
3030 name: tc.name.clone(),
3031 output: output.clone(),
3032 exit_code,
3033 metadata,
3034 })
3035 .await
3036 .ok();
3037 }
3038
3039 if images.is_empty() {
3040 messages.push(Message::tool_result(&tc.id, &output, is_error));
3041 } else {
3042 messages.push(Message::tool_result_with_images(
3043 &tc.id, &output, &images, is_error,
3044 ));
3045 }
3046 }
3047
3048 continue;
3050 } else {
3051 tool_calls
3052 };
3053
3054 for tool_call in tool_calls {
3055 tool_calls_count += 1;
3056
3057 let tool_start = std::time::Instant::now();
3058
3059 tracing::info!(
3060 tool_name = tool_call.name.as_str(),
3061 tool_id = tool_call.id.as_str(),
3062 "Tool execution started"
3063 );
3064
3065 if let Some(parse_error) =
3071 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
3072 {
3073 parse_error_count += 1;
3074 let error_msg = format!("Error: {}", parse_error);
3075 tracing::warn!(
3076 tool = tool_call.name.as_str(),
3077 parse_error_count = parse_error_count,
3078 max_parse_retries = self.config.max_parse_retries,
3079 "Malformed tool arguments from LLM"
3080 );
3081
3082 if let Some(tx) = &event_tx {
3083 tx.send(AgentEvent::ToolEnd {
3084 id: tool_call.id.clone(),
3085 name: tool_call.name.clone(),
3086 output: error_msg.clone(),
3087 exit_code: 1,
3088 metadata: None,
3089 })
3090 .await
3091 .ok();
3092 }
3093
3094 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
3095
3096 if parse_error_count > self.config.max_parse_retries {
3097 let msg = format!(
3098 "LLM produced malformed tool arguments {} time(s) in a row \
3099 (max_parse_retries={}); giving up",
3100 parse_error_count, self.config.max_parse_retries
3101 );
3102 tracing::error!("{}", msg);
3103 if let Some(tx) = &event_tx {
3104 tx.send(AgentEvent::Error {
3105 message: msg.clone(),
3106 })
3107 .await
3108 .ok();
3109 }
3110 anyhow::bail!(msg);
3111 }
3112 continue;
3113 }
3114
3115 parse_error_count = 0;
3117
3118 if let Some(ref registry) = self.config.skill_registry {
3120 let instruction_skills =
3121 registry.by_kind(crate::skills::SkillKind::Instruction);
3122 let has_restrictions =
3123 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
3124 if has_restrictions {
3125 let allowed = instruction_skills
3126 .iter()
3127 .any(|s| s.is_tool_allowed(&tool_call.name));
3128 if !allowed {
3129 let msg = format!(
3130 "Tool '{}' is not allowed by any active skill.",
3131 tool_call.name
3132 );
3133 tracing::info!(
3134 tool_name = tool_call.name.as_str(),
3135 "Tool blocked by skill registry"
3136 );
3137 if let Some(tx) = &event_tx {
3138 tx.send(AgentEvent::PermissionDenied {
3139 tool_id: tool_call.id.clone(),
3140 tool_name: tool_call.name.clone(),
3141 args: tool_call.args.clone(),
3142 reason: msg.clone(),
3143 })
3144 .await
3145 .ok();
3146 }
3147 messages.push(Message::tool_result(&tool_call.id, &msg, true));
3148 continue;
3149 }
3150 }
3151 }
3152
3153 if let Some(HookResult::Block(reason)) = self
3155 .fire_pre_tool_use(
3156 session_id.unwrap_or(""),
3157 &tool_call.name,
3158 &tool_call.args,
3159 recent_tool_signatures.clone(),
3160 )
3161 .await
3162 {
3163 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
3164 tracing::info!(
3165 tool_name = tool_call.name.as_str(),
3166 "Tool blocked by PreToolUse hook"
3167 );
3168
3169 if let Some(tx) = &event_tx {
3170 tx.send(AgentEvent::PermissionDenied {
3171 tool_id: tool_call.id.clone(),
3172 tool_name: tool_call.name.clone(),
3173 args: tool_call.args.clone(),
3174 reason: reason.clone(),
3175 })
3176 .await
3177 .ok();
3178 }
3179
3180 messages.push(Message::tool_result(&tool_call.id, &msg, true));
3181 continue;
3182 }
3183
3184 let permission_decision = if let Some(checker) = &self.config.permission_checker {
3186 checker.check(&tool_call.name, &tool_call.args)
3187 } else {
3188 PermissionDecision::Ask
3190 };
3191
3192 let (output, exit_code, is_error, metadata, images) = match permission_decision {
3193 PermissionDecision::Deny => {
3194 tracing::info!(
3195 tool_name = tool_call.name.as_str(),
3196 permission = "deny",
3197 "Tool permission denied"
3198 );
3199 let denial_msg = format!(
3201 "Permission denied: Tool '{}' is blocked by permission policy.",
3202 tool_call.name
3203 );
3204
3205 if let Some(tx) = &event_tx {
3207 tx.send(AgentEvent::PermissionDenied {
3208 tool_id: tool_call.id.clone(),
3209 tool_name: tool_call.name.clone(),
3210 args: tool_call.args.clone(),
3211 reason: "Blocked by deny rule in permission policy".to_string(),
3212 })
3213 .await
3214 .ok();
3215 }
3216
3217 (denial_msg, 1, true, None, Vec::new())
3218 }
3219 PermissionDecision::Allow => {
3220 tracing::info!(
3221 tool_name = tool_call.name.as_str(),
3222 permission = "allow",
3223 "Tool permission: allow"
3224 );
3225 let stream_ctx =
3227 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
3228 let result = self
3229 .execute_tool_queued_or_direct(
3230 &tool_call.name,
3231 &tool_call.args,
3232 &stream_ctx,
3233 )
3234 .await;
3235
3236 let tuple = Self::tool_result_to_tuple(result);
3237 let (_, exit_code, _, _, _) = tuple;
3239 self.track_tool_result(&tool_call.name, &tool_call.args, exit_code);
3240 tuple
3241 }
3242 PermissionDecision::Ask => {
3243 tracing::info!(
3244 tool_name = tool_call.name.as_str(),
3245 permission = "ask",
3246 "Tool permission: ask"
3247 );
3248 if let Some(cm) = &self.config.confirmation_manager {
3250 if !cm.requires_confirmation(&tool_call.name).await {
3252 let stream_ctx = self.streaming_tool_context(
3253 &event_tx,
3254 &tool_call.id,
3255 &tool_call.name,
3256 );
3257 let result = self
3258 .execute_tool_queued_or_direct(
3259 &tool_call.name,
3260 &tool_call.args,
3261 &stream_ctx,
3262 )
3263 .await;
3264
3265 let (output, exit_code, is_error, metadata, images) =
3266 Self::tool_result_to_tuple(result);
3267
3268 self.track_tool_result(&tool_call.name, &tool_call.args, exit_code);
3270
3271 if images.is_empty() {
3273 messages.push(Message::tool_result(
3274 &tool_call.id,
3275 &output,
3276 is_error,
3277 ));
3278 } else {
3279 messages.push(Message::tool_result_with_images(
3280 &tool_call.id,
3281 &output,
3282 &images,
3283 is_error,
3284 ));
3285 }
3286
3287 let tool_duration = tool_start.elapsed();
3289 crate::telemetry::record_tool_result(exit_code, tool_duration);
3290
3291 if let Some(tx) = &event_tx {
3293 tx.send(AgentEvent::ToolEnd {
3294 id: tool_call.id.clone(),
3295 name: tool_call.name.clone(),
3296 output: output.clone(),
3297 exit_code,
3298 metadata,
3299 })
3300 .await
3301 .ok();
3302 }
3303
3304 self.fire_post_tool_use(
3306 session_id.unwrap_or(""),
3307 &tool_call.name,
3308 &tool_call.args,
3309 &output,
3310 exit_code == 0,
3311 tool_duration.as_millis() as u64,
3312 )
3313 .await;
3314
3315 continue; }
3317
3318 let policy = cm.policy().await;
3320 let timeout_ms = policy.default_timeout_ms;
3321 let timeout_action = policy.timeout_action;
3322
3323 let rx = cm
3325 .request_confirmation(
3326 &tool_call.id,
3327 &tool_call.name,
3328 &tool_call.args,
3329 )
3330 .await;
3331
3332 if let Some(tx) = &event_tx {
3336 tx.send(AgentEvent::ConfirmationRequired {
3337 tool_id: tool_call.id.clone(),
3338 tool_name: tool_call.name.clone(),
3339 args: tool_call.args.clone(),
3340 timeout_ms,
3341 })
3342 .await
3343 .ok();
3344 }
3345
3346 let confirmation_result =
3348 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
3349
3350 match confirmation_result {
3351 Ok(Ok(response)) => {
3352 if let Some(tx) = &event_tx {
3354 tx.send(AgentEvent::ConfirmationReceived {
3355 tool_id: tool_call.id.clone(),
3356 approved: response.approved,
3357 reason: response.reason.clone(),
3358 })
3359 .await
3360 .ok();
3361 }
3362 if response.approved {
3363 let stream_ctx = self.streaming_tool_context(
3364 &event_tx,
3365 &tool_call.id,
3366 &tool_call.name,
3367 );
3368 let result = self
3369 .execute_tool_queued_or_direct(
3370 &tool_call.name,
3371 &tool_call.args,
3372 &stream_ctx,
3373 )
3374 .await;
3375
3376 let tuple = Self::tool_result_to_tuple(result);
3377 let (_, exit_code, _, _, _) = tuple;
3379 self.track_tool_result(
3380 &tool_call.name,
3381 &tool_call.args,
3382 exit_code,
3383 );
3384 tuple
3385 } else {
3386 let rejection_msg = format!(
3387 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
3388 DO NOT retry this tool call unless the user explicitly asks you to.",
3389 tool_call.name,
3390 response.reason.unwrap_or_else(|| "No reason provided".to_string())
3391 );
3392 (rejection_msg, 1, true, None, Vec::new())
3393 }
3394 }
3395 Ok(Err(_)) => {
3396 if let Some(tx) = &event_tx {
3398 tx.send(AgentEvent::ConfirmationTimeout {
3399 tool_id: tool_call.id.clone(),
3400 action_taken: "rejected".to_string(),
3401 })
3402 .await
3403 .ok();
3404 }
3405 let msg = format!(
3406 "Tool '{}' confirmation failed: confirmation channel closed",
3407 tool_call.name
3408 );
3409 (msg, 1, true, None, Vec::new())
3410 }
3411 Err(_) => {
3412 cm.check_timeouts().await;
3413
3414 if let Some(tx) = &event_tx {
3416 tx.send(AgentEvent::ConfirmationTimeout {
3417 tool_id: tool_call.id.clone(),
3418 action_taken: match timeout_action {
3419 crate::hitl::TimeoutAction::Reject => {
3420 "rejected".to_string()
3421 }
3422 crate::hitl::TimeoutAction::AutoApprove => {
3423 "auto_approved".to_string()
3424 }
3425 },
3426 })
3427 .await
3428 .ok();
3429 }
3430
3431 match timeout_action {
3432 crate::hitl::TimeoutAction::Reject => {
3433 let msg = format!(
3434 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
3435 DO NOT retry this tool call — the user did not approve it. \
3436 Inform the user that the operation requires their approval and ask them to try again.",
3437 tool_call.name, timeout_ms
3438 );
3439 (msg, 1, true, None, Vec::new())
3440 }
3441 crate::hitl::TimeoutAction::AutoApprove => {
3442 let stream_ctx = self.streaming_tool_context(
3443 &event_tx,
3444 &tool_call.id,
3445 &tool_call.name,
3446 );
3447 let result = self
3448 .execute_tool_queued_or_direct(
3449 &tool_call.name,
3450 &tool_call.args,
3451 &stream_ctx,
3452 )
3453 .await;
3454
3455 let tuple = Self::tool_result_to_tuple(result);
3456 let (_, exit_code, _, _, _) = tuple;
3458 self.track_tool_result(
3459 &tool_call.name,
3460 &tool_call.args,
3461 exit_code,
3462 );
3463 tuple
3464 }
3465 }
3466 }
3467 }
3468 } else {
3469 let msg = format!(
3471 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
3472 Configure a confirmation policy to enable tool execution.",
3473 tool_call.name
3474 );
3475 tracing::warn!(
3476 tool_name = tool_call.name.as_str(),
3477 "Tool requires confirmation but no HITL manager configured"
3478 );
3479 (msg, 1, true, None, Vec::new())
3480 }
3481 }
3482 };
3483
3484 let tool_duration = tool_start.elapsed();
3485 crate::telemetry::record_tool_result(exit_code, tool_duration);
3486
3487 let output = if let Some(ref sp) = self.config.security_provider {
3489 sp.sanitize_output(&output)
3490 } else {
3491 output
3492 };
3493
3494 recent_tool_signatures.push(format!(
3495 "{}:{} => {}",
3496 tool_call.name,
3497 serde_json::to_string(&tool_call.args).unwrap_or_default(),
3498 if is_error { "error" } else { "ok" }
3499 ));
3500 if recent_tool_signatures.len() > 8 {
3501 let overflow = recent_tool_signatures.len() - 8;
3502 recent_tool_signatures.drain(0..overflow);
3503 }
3504
3505 self.fire_post_tool_use(
3507 session_id.unwrap_or(""),
3508 &tool_call.name,
3509 &tool_call.args,
3510 &output,
3511 exit_code == 0,
3512 tool_duration.as_millis() as u64,
3513 )
3514 .await;
3515
3516 if let Some(ref memory) = self.config.memory {
3518 let tools_used = [tool_call.name.clone()];
3519 let remember_result = if exit_code == 0 {
3520 memory
3521 .remember_success(effective_prompt, &tools_used, &output)
3522 .await
3523 } else {
3524 memory
3525 .remember_failure(effective_prompt, &output, &tools_used)
3526 .await
3527 };
3528 match remember_result {
3529 Ok(()) => {
3530 if let Some(tx) = &event_tx {
3531 let item_type = if exit_code == 0 { "success" } else { "failure" };
3532 tx.send(AgentEvent::MemoryStored {
3533 memory_id: uuid::Uuid::new_v4().to_string(),
3534 memory_type: item_type.to_string(),
3535 importance: if exit_code == 0 { 0.8 } else { 0.9 },
3536 tags: vec![item_type.to_string(), tool_call.name.clone()],
3537 })
3538 .await
3539 .ok();
3540 }
3541 }
3542 Err(e) => {
3543 tracing::warn!("Failed to store memory after tool execution: {}", e);
3544 }
3545 }
3546 }
3547
3548 if let Some(tx) = &event_tx {
3550 tx.send(AgentEvent::ToolEnd {
3551 id: tool_call.id.clone(),
3552 name: tool_call.name.clone(),
3553 output: output.clone(),
3554 exit_code,
3555 metadata,
3556 })
3557 .await
3558 .ok();
3559 }
3560
3561 if images.is_empty() {
3563 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
3564 } else {
3565 messages.push(Message::tool_result_with_images(
3566 &tool_call.id,
3567 &output,
3568 &images,
3569 is_error,
3570 ));
3571 }
3572 }
3573 }
3574 }
3575
3576 pub async fn execute_streaming(
3578 &self,
3579 history: &[Message],
3580 prompt: &str,
3581 ) -> Result<(
3582 mpsc::Receiver<AgentEvent>,
3583 tokio::task::JoinHandle<Result<AgentResult>>,
3584 tokio_util::sync::CancellationToken,
3585 )> {
3586 let (tx, rx) = mpsc::channel(100);
3587 let cancel_token = tokio_util::sync::CancellationToken::new();
3588
3589 let llm_client = self.llm_client.clone();
3590 let tool_executor = self.tool_executor.clone();
3591 let tool_context = self.tool_context.clone();
3592 let config = self.config.clone();
3593 let tool_metrics = self.tool_metrics.clone();
3594 let command_queue = self.command_queue.clone();
3595 let history = history.to_vec();
3596 let prompt = prompt.to_string();
3597 let token_clone = cancel_token.clone();
3598
3599 let handle = tokio::spawn(async move {
3600 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
3601 if let Some(metrics) = tool_metrics {
3602 agent = agent.with_tool_metrics(metrics);
3603 }
3604 if let Some(queue) = command_queue {
3605 agent = agent.with_queue(queue);
3606 }
3607 agent
3608 .execute_with_session(&history, &prompt, None, Some(tx), Some(&token_clone))
3609 .await
3610 });
3611
3612 Ok((rx, handle, cancel_token))
3613 }
3614
3615 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
3620 use crate::planning::LlmPlanner;
3621
3622 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
3623 Ok(plan) => Ok(plan),
3624 Err(e) => {
3625 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
3626 Ok(LlmPlanner::fallback_plan(prompt))
3627 }
3628 }
3629 }
3630
3631 pub async fn execute_with_planning(
3633 &self,
3634 history: &[Message],
3635 prompt: &str,
3636 event_tx: Option<mpsc::Sender<AgentEvent>>,
3637 ) -> Result<AgentResult> {
3638 if let Some(tx) = &event_tx {
3640 tx.send(AgentEvent::PlanningStart {
3641 prompt: prompt.to_string(),
3642 })
3643 .await
3644 .ok();
3645 }
3646
3647 let goal = if self.config.goal_tracking {
3649 let g = self.extract_goal(prompt).await?;
3650 if let Some(tx) = &event_tx {
3651 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
3652 .await
3653 .ok();
3654 }
3655 Some(g)
3656 } else {
3657 None
3658 };
3659
3660 let plan = self.plan(prompt, None).await?;
3662
3663 if let Some(tx) = &event_tx {
3665 tx.send(AgentEvent::PlanningEnd {
3666 estimated_steps: plan.steps.len(),
3667 plan: plan.clone(),
3668 })
3669 .await
3670 .ok();
3671 }
3672
3673 let plan_start = std::time::Instant::now();
3674
3675 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
3677
3678 if let Some(tx) = &event_tx {
3680 tx.send(AgentEvent::End {
3681 text: result.text.clone(),
3682 usage: result.usage.clone(),
3683 meta: None,
3684 })
3685 .await
3686 .ok();
3687 }
3688
3689 if self.config.goal_tracking {
3691 if let Some(ref g) = goal {
3692 let achieved = self.check_goal_achievement(g, &result.text).await?;
3693 if achieved {
3694 if let Some(tx) = &event_tx {
3695 tx.send(AgentEvent::GoalAchieved {
3696 goal: g.description.clone(),
3697 total_steps: result.messages.len(),
3698 duration_ms: plan_start.elapsed().as_millis() as i64,
3699 })
3700 .await
3701 .ok();
3702 }
3703 }
3704 }
3705 }
3706
3707 Ok(result)
3708 }
3709
3710 async fn execute_plan(
3717 &self,
3718 history: &[Message],
3719 plan: &ExecutionPlan,
3720 event_tx: Option<mpsc::Sender<AgentEvent>>,
3721 ) -> Result<AgentResult> {
3722 let mut plan = plan.clone();
3723 let mut current_history = history.to_vec();
3724 let mut total_usage = TokenUsage::default();
3725 let mut tool_calls_count = 0;
3726 let total_steps = plan.steps.len();
3727
3728 let steps_text = plan
3730 .steps
3731 .iter()
3732 .enumerate()
3733 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
3734 .collect::<Vec<_>>()
3735 .join("\n");
3736 current_history.push(Message::user(&crate::prompts::render(
3737 crate::prompts::PLAN_EXECUTE_GOAL,
3738 &[("goal", &plan.goal), ("steps", &steps_text)],
3739 )));
3740
3741 loop {
3742 let ready: Vec<String> = plan
3743 .get_ready_steps()
3744 .iter()
3745 .map(|s| s.id.clone())
3746 .collect();
3747
3748 if ready.is_empty() {
3749 if plan.has_deadlock() {
3751 tracing::warn!(
3752 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
3753 plan.pending_count()
3754 );
3755 }
3756 break;
3757 }
3758
3759 if ready.len() == 1 {
3760 let step_id = &ready[0];
3762 let step = plan
3763 .steps
3764 .iter()
3765 .find(|s| s.id == *step_id)
3766 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
3767 .clone();
3768 let step_number = plan
3769 .steps
3770 .iter()
3771 .position(|s| s.id == *step_id)
3772 .unwrap_or(0)
3773 + 1;
3774
3775 if let Some(tx) = &event_tx {
3777 tx.send(AgentEvent::StepStart {
3778 step_id: step.id.clone(),
3779 description: step.content.clone(),
3780 step_number,
3781 total_steps,
3782 })
3783 .await
3784 .ok();
3785 }
3786
3787 plan.mark_status(&step.id, TaskStatus::InProgress);
3788
3789 let step_prompt = crate::prompts::render(
3790 crate::prompts::PLAN_EXECUTE_STEP,
3791 &[
3792 ("step_num", &step_number.to_string()),
3793 ("description", &step.content),
3794 ],
3795 );
3796
3797 match self
3798 .execute_loop(
3799 ¤t_history,
3800 &step_prompt,
3801 None,
3802 event_tx.clone(),
3803 &tokio_util::sync::CancellationToken::new(),
3804 false, )
3806 .await
3807 {
3808 Ok(result) => {
3809 current_history = result.messages.clone();
3810 total_usage.prompt_tokens += result.usage.prompt_tokens;
3811 total_usage.completion_tokens += result.usage.completion_tokens;
3812 total_usage.total_tokens += result.usage.total_tokens;
3813 tool_calls_count += result.tool_calls_count;
3814 plan.mark_status(&step.id, TaskStatus::Completed);
3815
3816 if let Some(tx) = &event_tx {
3817 tx.send(AgentEvent::StepEnd {
3818 step_id: step.id.clone(),
3819 status: TaskStatus::Completed,
3820 step_number,
3821 total_steps,
3822 })
3823 .await
3824 .ok();
3825 }
3826 }
3827 Err(e) => {
3828 tracing::error!("Plan step '{}' failed: {}", step.id, e);
3829 plan.mark_status(&step.id, TaskStatus::Failed);
3830
3831 if let Some(tx) = &event_tx {
3832 tx.send(AgentEvent::StepEnd {
3833 step_id: step.id.clone(),
3834 status: TaskStatus::Failed,
3835 step_number,
3836 total_steps,
3837 })
3838 .await
3839 .ok();
3840 }
3841 }
3842 }
3843 } else {
3844 let ready_steps: Vec<_> = ready
3851 .iter()
3852 .filter_map(|id| {
3853 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
3854 let step_number =
3855 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
3856 Some((step, step_number))
3857 })
3858 .collect();
3859
3860 for (step, step_number) in &ready_steps {
3862 plan.mark_status(&step.id, TaskStatus::InProgress);
3863 if let Some(tx) = &event_tx {
3864 tx.send(AgentEvent::StepStart {
3865 step_id: step.id.clone(),
3866 description: step.content.clone(),
3867 step_number: *step_number,
3868 total_steps,
3869 })
3870 .await
3871 .ok();
3872 }
3873 }
3874
3875 let mut join_set = tokio::task::JoinSet::new();
3877 for (step, step_number) in &ready_steps {
3878 let base_history = current_history.clone();
3879 let agent_clone = self.clone();
3880 let tx = event_tx.clone();
3881 let step_clone = step.clone();
3882 let sn = *step_number;
3883
3884 join_set.spawn(async move {
3885 let prompt = crate::prompts::render(
3886 crate::prompts::PLAN_EXECUTE_STEP,
3887 &[
3888 ("step_num", &sn.to_string()),
3889 ("description", &step_clone.content),
3890 ],
3891 );
3892 let result = agent_clone
3893 .execute_loop(
3894 &base_history,
3895 &prompt,
3896 None,
3897 tx,
3898 &tokio_util::sync::CancellationToken::new(),
3899 false, )
3901 .await;
3902 (step_clone.id, sn, result)
3903 });
3904 }
3905
3906 let mut parallel_summaries = Vec::new();
3908 while let Some(join_result) = join_set.join_next().await {
3909 match join_result {
3910 Ok((step_id, step_number, step_result)) => match step_result {
3911 Ok(result) => {
3912 total_usage.prompt_tokens += result.usage.prompt_tokens;
3913 total_usage.completion_tokens += result.usage.completion_tokens;
3914 total_usage.total_tokens += result.usage.total_tokens;
3915 tool_calls_count += result.tool_calls_count;
3916 plan.mark_status(&step_id, TaskStatus::Completed);
3917
3918 parallel_summaries.push(format!(
3920 "- Step {} ({}): {}",
3921 step_number, step_id, result.text
3922 ));
3923
3924 if let Some(tx) = &event_tx {
3925 tx.send(AgentEvent::StepEnd {
3926 step_id,
3927 status: TaskStatus::Completed,
3928 step_number,
3929 total_steps,
3930 })
3931 .await
3932 .ok();
3933 }
3934 }
3935 Err(e) => {
3936 tracing::error!("Plan step '{}' failed: {}", step_id, e);
3937 plan.mark_status(&step_id, TaskStatus::Failed);
3938
3939 if let Some(tx) = &event_tx {
3940 tx.send(AgentEvent::StepEnd {
3941 step_id,
3942 status: TaskStatus::Failed,
3943 step_number,
3944 total_steps,
3945 })
3946 .await
3947 .ok();
3948 }
3949 }
3950 },
3951 Err(e) => {
3952 tracing::error!("JoinSet task panicked: {}", e);
3953 }
3954 }
3955 }
3956
3957 if !parallel_summaries.is_empty() {
3959 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
3961 current_history.push(Message::user(&crate::prompts::render(
3962 crate::prompts::PLAN_PARALLEL_RESULTS,
3963 &[("results", &results_text)],
3964 )));
3965 }
3966 }
3967
3968 if self.config.goal_tracking {
3970 let completed = plan
3971 .steps
3972 .iter()
3973 .filter(|s| s.status == TaskStatus::Completed)
3974 .count();
3975 if let Some(tx) = &event_tx {
3976 tx.send(AgentEvent::GoalProgress {
3977 goal: plan.goal.clone(),
3978 progress: plan.progress(),
3979 completed_steps: completed,
3980 total_steps,
3981 })
3982 .await
3983 .ok();
3984 }
3985 }
3986 }
3987
3988 let final_text = current_history
3990 .last()
3991 .map(|m| {
3992 m.content
3993 .iter()
3994 .filter_map(|block| {
3995 if let crate::llm::ContentBlock::Text { text } = block {
3996 Some(text.as_str())
3997 } else {
3998 None
3999 }
4000 })
4001 .collect::<Vec<_>>()
4002 .join("\n")
4003 })
4004 .unwrap_or_default();
4005
4006 Ok(AgentResult {
4007 text: final_text,
4008 messages: current_history,
4009 usage: total_usage,
4010 tool_calls_count,
4011 })
4012 }
4013
4014 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
4019 use crate::planning::LlmPlanner;
4020
4021 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
4022 Ok(goal) => Ok(goal),
4023 Err(e) => {
4024 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
4025 Ok(LlmPlanner::fallback_goal(prompt))
4026 }
4027 }
4028 }
4029
4030 pub async fn check_goal_achievement(
4035 &self,
4036 goal: &AgentGoal,
4037 current_state: &str,
4038 ) -> Result<bool> {
4039 use crate::planning::LlmPlanner;
4040
4041 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
4042 Ok(result) => Ok(result.achieved),
4043 Err(e) => {
4044 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
4045 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
4046 Ok(result.achieved)
4047 }
4048 }
4049 }
4050}
4051
4052#[cfg(test)]
4053mod tests {
4054 use super::*;
4055 use crate::llm::{ContentBlock, StreamEvent};
4056 use crate::permissions::PermissionPolicy;
4057 use crate::tools::ToolExecutor;
4058 use std::path::PathBuf;
4059 use std::sync::atomic::{AtomicUsize, Ordering};
4060
4061 fn test_tool_context() -> ToolContext {
4063 ToolContext::new(PathBuf::from("/tmp"))
4064 }
4065
4066 #[test]
4067 fn test_agent_config_default() {
4068 let config = AgentConfig::default();
4069 assert!(config.prompt_slots.is_empty());
4070 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
4072 assert!(config.permission_checker.is_none());
4073 assert!(config.context_providers.is_empty());
4074 let registry = config
4076 .skill_registry
4077 .expect("skill_registry must be Some by default");
4078 assert!(registry.len() >= 7, "expected at least 7 built-in skills");
4079 assert!(registry.get("code-search").is_some());
4080 assert!(registry.get("find-bugs").is_some());
4081 }
4082
4083 pub(crate) struct MockLlmClient {
4089 responses: std::sync::Mutex<Vec<LlmResponse>>,
4091 pub(crate) call_count: AtomicUsize,
4093 }
4094
4095 impl MockLlmClient {
4096 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
4097 Self {
4098 responses: std::sync::Mutex::new(responses),
4099 call_count: AtomicUsize::new(0),
4100 }
4101 }
4102
4103 pub(crate) fn text_response(text: &str) -> LlmResponse {
4105 LlmResponse {
4106 message: Message {
4107 role: "assistant".to_string(),
4108 content: vec![ContentBlock::Text {
4109 text: text.to_string(),
4110 }],
4111 reasoning_content: None,
4112 },
4113 usage: TokenUsage {
4114 prompt_tokens: 10,
4115 completion_tokens: 5,
4116 total_tokens: 15,
4117 cache_read_tokens: None,
4118 cache_write_tokens: None,
4119 },
4120 stop_reason: Some("end_turn".to_string()),
4121 meta: None,
4122 }
4123 }
4124
4125 pub(crate) fn tool_call_response(
4127 tool_id: &str,
4128 tool_name: &str,
4129 args: serde_json::Value,
4130 ) -> LlmResponse {
4131 LlmResponse {
4132 message: Message {
4133 role: "assistant".to_string(),
4134 content: vec![ContentBlock::ToolUse {
4135 id: tool_id.to_string(),
4136 name: tool_name.to_string(),
4137 input: args,
4138 }],
4139 reasoning_content: None,
4140 },
4141 usage: TokenUsage {
4142 prompt_tokens: 10,
4143 completion_tokens: 5,
4144 total_tokens: 15,
4145 cache_read_tokens: None,
4146 cache_write_tokens: None,
4147 },
4148 stop_reason: Some("tool_use".to_string()),
4149 meta: None,
4150 }
4151 }
4152 }
4153
4154 #[async_trait::async_trait]
4155 impl LlmClient for MockLlmClient {
4156 async fn complete(
4157 &self,
4158 _messages: &[Message],
4159 _system: Option<&str>,
4160 _tools: &[ToolDefinition],
4161 ) -> Result<LlmResponse> {
4162 self.call_count.fetch_add(1, Ordering::SeqCst);
4163 let mut responses = self.responses.lock().unwrap();
4164 if responses.is_empty() {
4165 anyhow::bail!("No more mock responses available");
4166 }
4167 Ok(responses.remove(0))
4168 }
4169
4170 async fn complete_streaming(
4171 &self,
4172 _messages: &[Message],
4173 _system: Option<&str>,
4174 _tools: &[ToolDefinition],
4175 ) -> Result<mpsc::Receiver<StreamEvent>> {
4176 self.call_count.fetch_add(1, Ordering::SeqCst);
4177 let mut responses = self.responses.lock().unwrap();
4178 if responses.is_empty() {
4179 anyhow::bail!("No more mock responses available");
4180 }
4181 let response = responses.remove(0);
4182
4183 let (tx, rx) = mpsc::channel(10);
4184 tokio::spawn(async move {
4185 for block in &response.message.content {
4187 if let ContentBlock::Text { text } = block {
4188 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
4189 }
4190 }
4191 tx.send(StreamEvent::Done(response)).await.ok();
4192 });
4193
4194 Ok(rx)
4195 }
4196 }
4197
4198 #[tokio::test]
4203 async fn test_agent_simple_response() {
4204 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4205 "Hello, I'm an AI assistant.",
4206 )]));
4207
4208 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4209 let config = AgentConfig::default();
4210
4211 let agent = AgentLoop::new(
4212 mock_client.clone(),
4213 tool_executor,
4214 test_tool_context(),
4215 config,
4216 );
4217 let result = agent.execute(&[], "Hello", None).await.unwrap();
4218
4219 assert_eq!(result.text, "Hello, I'm an AI assistant.");
4220 assert_eq!(result.tool_calls_count, 0);
4221 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4222 }
4223
4224 #[tokio::test]
4225 async fn test_agent_with_tool_call() {
4226 let mock_client = Arc::new(MockLlmClient::new(vec![
4227 MockLlmClient::tool_call_response(
4229 "tool-1",
4230 "bash",
4231 serde_json::json!({"command": "echo hello"}),
4232 ),
4233 MockLlmClient::text_response("The command output was: hello"),
4235 ]));
4236
4237 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4238 let config = AgentConfig::default();
4239
4240 let agent = AgentLoop::new(
4241 mock_client.clone(),
4242 tool_executor,
4243 test_tool_context(),
4244 config,
4245 );
4246 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
4247
4248 assert_eq!(result.text, "The command output was: hello");
4249 assert_eq!(result.tool_calls_count, 1);
4250 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
4251 }
4252
4253 #[tokio::test]
4254 async fn test_agent_permission_deny() {
4255 let mock_client = Arc::new(MockLlmClient::new(vec![
4256 MockLlmClient::tool_call_response(
4258 "tool-1",
4259 "bash",
4260 serde_json::json!({"command": "rm -rf /tmp/test"}),
4261 ),
4262 MockLlmClient::text_response(
4264 "I cannot execute that command due to permission restrictions.",
4265 ),
4266 ]));
4267
4268 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4269
4270 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
4272
4273 let config = AgentConfig {
4274 permission_checker: Some(Arc::new(permission_policy)),
4275 ..Default::default()
4276 };
4277
4278 let (tx, mut rx) = mpsc::channel(100);
4279 let agent = AgentLoop::new(
4280 mock_client.clone(),
4281 tool_executor,
4282 test_tool_context(),
4283 config,
4284 );
4285 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
4286
4287 let mut found_permission_denied = false;
4289 while let Ok(event) = rx.try_recv() {
4290 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
4291 assert_eq!(tool_name, "bash");
4292 found_permission_denied = true;
4293 }
4294 }
4295 assert!(
4296 found_permission_denied,
4297 "Should have received PermissionDenied event"
4298 );
4299
4300 assert_eq!(result.tool_calls_count, 1);
4301 }
4302
4303 #[tokio::test]
4304 async fn test_agent_permission_allow() {
4305 let mock_client = Arc::new(MockLlmClient::new(vec![
4306 MockLlmClient::tool_call_response(
4308 "tool-1",
4309 "bash",
4310 serde_json::json!({"command": "echo hello"}),
4311 ),
4312 MockLlmClient::text_response("Done!"),
4314 ]));
4315
4316 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4317
4318 let permission_policy = PermissionPolicy::new()
4320 .allow("bash(echo:*)")
4321 .deny("bash(rm:*)");
4322
4323 let config = AgentConfig {
4324 permission_checker: Some(Arc::new(permission_policy)),
4325 ..Default::default()
4326 };
4327
4328 let agent = AgentLoop::new(
4329 mock_client.clone(),
4330 tool_executor,
4331 test_tool_context(),
4332 config,
4333 );
4334 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
4335
4336 assert_eq!(result.text, "Done!");
4337 assert_eq!(result.tool_calls_count, 1);
4338 }
4339
4340 #[tokio::test]
4341 async fn test_agent_streaming_events() {
4342 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4343 "Hello!",
4344 )]));
4345
4346 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4347 let config = AgentConfig::default();
4348
4349 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4350 let (mut rx, handle, _cancel_token) = agent.execute_streaming(&[], "Hi").await.unwrap();
4351
4352 let mut events = Vec::new();
4354 while let Some(event) = rx.recv().await {
4355 events.push(event);
4356 }
4357
4358 let result = handle.await.unwrap().unwrap();
4359 assert_eq!(result.text, "Hello!");
4360
4361 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
4363 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
4364 }
4365
4366 #[tokio::test]
4367 async fn test_agent_max_tool_rounds() {
4368 let responses: Vec<LlmResponse> = (0..100)
4370 .map(|i| {
4371 MockLlmClient::tool_call_response(
4372 &format!("tool-{}", i),
4373 "bash",
4374 serde_json::json!({"command": "echo loop"}),
4375 )
4376 })
4377 .collect();
4378
4379 let mock_client = Arc::new(MockLlmClient::new(responses));
4380 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4381
4382 let config = AgentConfig {
4383 max_tool_rounds: 3,
4384 ..Default::default()
4385 };
4386
4387 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4388 let result = agent.execute(&[], "Loop forever", None).await;
4389
4390 assert!(result.is_err());
4392 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
4393 }
4394
4395 #[tokio::test]
4396 async fn test_agent_no_permission_policy_defaults_to_ask() {
4397 let mock_client = Arc::new(MockLlmClient::new(vec![
4400 MockLlmClient::tool_call_response(
4401 "tool-1",
4402 "bash",
4403 serde_json::json!({"command": "rm -rf /tmp/test"}),
4404 ),
4405 MockLlmClient::text_response("Denied!"),
4406 ]));
4407
4408 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4409 let config = AgentConfig {
4410 permission_checker: None, ..Default::default()
4413 };
4414
4415 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4416 let result = agent.execute(&[], "Delete", None).await.unwrap();
4417
4418 assert_eq!(result.text, "Denied!");
4420 assert_eq!(result.tool_calls_count, 1);
4421 }
4422
4423 #[tokio::test]
4424 async fn test_agent_permission_ask_without_cm_denies() {
4425 let mock_client = Arc::new(MockLlmClient::new(vec![
4428 MockLlmClient::tool_call_response(
4429 "tool-1",
4430 "bash",
4431 serde_json::json!({"command": "echo test"}),
4432 ),
4433 MockLlmClient::text_response("Denied!"),
4434 ]));
4435
4436 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4437
4438 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
4442 permission_checker: Some(Arc::new(permission_policy)),
4443 ..Default::default()
4445 };
4446
4447 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4448 let result = agent.execute(&[], "Echo", None).await.unwrap();
4449
4450 assert_eq!(result.text, "Denied!");
4452 assert!(result.tool_calls_count >= 1);
4454 }
4455
4456 #[tokio::test]
4461 async fn test_agent_hitl_approved() {
4462 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4463 use tokio::sync::broadcast;
4464
4465 let mock_client = Arc::new(MockLlmClient::new(vec![
4466 MockLlmClient::tool_call_response(
4467 "tool-1",
4468 "bash",
4469 serde_json::json!({"command": "echo hello"}),
4470 ),
4471 MockLlmClient::text_response("Command executed!"),
4472 ]));
4473
4474 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4475
4476 let (event_tx, _event_rx) = broadcast::channel(100);
4478 let hitl_policy = ConfirmationPolicy {
4479 enabled: true,
4480 ..Default::default()
4481 };
4482 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4483
4484 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
4488 permission_checker: Some(Arc::new(permission_policy)),
4489 confirmation_manager: Some(confirmation_manager.clone()),
4490 ..Default::default()
4491 };
4492
4493 let cm_clone = confirmation_manager.clone();
4495 tokio::spawn(async move {
4496 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
4498 cm_clone.confirm("tool-1", true, None).await.ok();
4500 });
4501
4502 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4503 let result = agent.execute(&[], "Run echo", None).await.unwrap();
4504
4505 assert_eq!(result.text, "Command executed!");
4506 assert_eq!(result.tool_calls_count, 1);
4507 }
4508
4509 #[tokio::test]
4510 async fn test_agent_hitl_rejected() {
4511 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4512 use tokio::sync::broadcast;
4513
4514 let mock_client = Arc::new(MockLlmClient::new(vec![
4515 MockLlmClient::tool_call_response(
4516 "tool-1",
4517 "bash",
4518 serde_json::json!({"command": "rm -rf /"}),
4519 ),
4520 MockLlmClient::text_response("Understood, I won't do that."),
4521 ]));
4522
4523 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4524
4525 let (event_tx, _event_rx) = broadcast::channel(100);
4527 let hitl_policy = ConfirmationPolicy {
4528 enabled: true,
4529 ..Default::default()
4530 };
4531 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4532
4533 let permission_policy = PermissionPolicy::new();
4535
4536 let config = AgentConfig {
4537 permission_checker: Some(Arc::new(permission_policy)),
4538 confirmation_manager: Some(confirmation_manager.clone()),
4539 ..Default::default()
4540 };
4541
4542 let cm_clone = confirmation_manager.clone();
4544 tokio::spawn(async move {
4545 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
4546 cm_clone
4547 .confirm("tool-1", false, Some("Too dangerous".to_string()))
4548 .await
4549 .ok();
4550 });
4551
4552 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4553 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
4554
4555 assert_eq!(result.text, "Understood, I won't do that.");
4557 }
4558
4559 #[tokio::test]
4560 async fn test_agent_hitl_timeout_reject() {
4561 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
4562 use tokio::sync::broadcast;
4563
4564 let mock_client = Arc::new(MockLlmClient::new(vec![
4565 MockLlmClient::tool_call_response(
4566 "tool-1",
4567 "bash",
4568 serde_json::json!({"command": "echo test"}),
4569 ),
4570 MockLlmClient::text_response("Timed out, I understand."),
4571 ]));
4572
4573 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4574
4575 let (event_tx, _event_rx) = broadcast::channel(100);
4577 let hitl_policy = ConfirmationPolicy {
4578 enabled: true,
4579 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
4581 ..Default::default()
4582 };
4583 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4584
4585 let permission_policy = PermissionPolicy::new();
4586
4587 let config = AgentConfig {
4588 permission_checker: Some(Arc::new(permission_policy)),
4589 confirmation_manager: Some(confirmation_manager),
4590 ..Default::default()
4591 };
4592
4593 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4595 let result = agent.execute(&[], "Echo", None).await.unwrap();
4596
4597 assert_eq!(result.text, "Timed out, I understand.");
4599 }
4600
4601 #[tokio::test]
4602 async fn test_agent_hitl_timeout_auto_approve() {
4603 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
4604 use tokio::sync::broadcast;
4605
4606 let mock_client = Arc::new(MockLlmClient::new(vec![
4607 MockLlmClient::tool_call_response(
4608 "tool-1",
4609 "bash",
4610 serde_json::json!({"command": "echo hello"}),
4611 ),
4612 MockLlmClient::text_response("Auto-approved and executed!"),
4613 ]));
4614
4615 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4616
4617 let (event_tx, _event_rx) = broadcast::channel(100);
4619 let hitl_policy = ConfirmationPolicy {
4620 enabled: true,
4621 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
4623 ..Default::default()
4624 };
4625 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4626
4627 let permission_policy = PermissionPolicy::new();
4628
4629 let config = AgentConfig {
4630 permission_checker: Some(Arc::new(permission_policy)),
4631 confirmation_manager: Some(confirmation_manager),
4632 ..Default::default()
4633 };
4634
4635 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4637 let result = agent.execute(&[], "Echo", None).await.unwrap();
4638
4639 assert_eq!(result.text, "Auto-approved and executed!");
4641 assert_eq!(result.tool_calls_count, 1);
4642 }
4643
4644 #[tokio::test]
4645 async fn test_agent_hitl_confirmation_events() {
4646 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4647 use tokio::sync::broadcast;
4648
4649 let mock_client = Arc::new(MockLlmClient::new(vec![
4650 MockLlmClient::tool_call_response(
4651 "tool-1",
4652 "bash",
4653 serde_json::json!({"command": "echo test"}),
4654 ),
4655 MockLlmClient::text_response("Done!"),
4656 ]));
4657
4658 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4659
4660 let (event_tx, mut event_rx) = broadcast::channel(100);
4662 let hitl_policy = ConfirmationPolicy {
4663 enabled: true,
4664 default_timeout_ms: 5000, ..Default::default()
4666 };
4667 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4668
4669 let permission_policy = PermissionPolicy::new();
4670
4671 let config = AgentConfig {
4672 permission_checker: Some(Arc::new(permission_policy)),
4673 confirmation_manager: Some(confirmation_manager.clone()),
4674 ..Default::default()
4675 };
4676
4677 let cm_clone = confirmation_manager.clone();
4679 let event_handle = tokio::spawn(async move {
4680 let mut events = Vec::new();
4681 while let Ok(event) = event_rx.recv().await {
4683 events.push(event.clone());
4684 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
4685 cm_clone.confirm(&tool_id, true, None).await.ok();
4687 if let Ok(recv_event) = event_rx.recv().await {
4689 events.push(recv_event);
4690 }
4691 break;
4692 }
4693 }
4694 events
4695 });
4696
4697 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4698 let _result = agent.execute(&[], "Echo", None).await.unwrap();
4699
4700 let events = event_handle.await.unwrap();
4702 assert!(
4703 events
4704 .iter()
4705 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
4706 "Should have ConfirmationRequired event"
4707 );
4708 assert!(
4709 events
4710 .iter()
4711 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
4712 "Should have ConfirmationReceived event with approved=true"
4713 );
4714 }
4715
4716 #[tokio::test]
4717 async fn test_agent_hitl_disabled_auto_executes() {
4718 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4720 use tokio::sync::broadcast;
4721
4722 let mock_client = Arc::new(MockLlmClient::new(vec![
4723 MockLlmClient::tool_call_response(
4724 "tool-1",
4725 "bash",
4726 serde_json::json!({"command": "echo auto"}),
4727 ),
4728 MockLlmClient::text_response("Auto executed!"),
4729 ]));
4730
4731 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4732
4733 let (event_tx, _event_rx) = broadcast::channel(100);
4735 let hitl_policy = ConfirmationPolicy {
4736 enabled: false, ..Default::default()
4738 };
4739 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4740
4741 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
4744 permission_checker: Some(Arc::new(permission_policy)),
4745 confirmation_manager: Some(confirmation_manager),
4746 ..Default::default()
4747 };
4748
4749 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4750 let result = agent.execute(&[], "Echo", None).await.unwrap();
4751
4752 assert_eq!(result.text, "Auto executed!");
4754 assert_eq!(result.tool_calls_count, 1);
4755 }
4756
4757 #[tokio::test]
4758 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
4759 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4761 use tokio::sync::broadcast;
4762
4763 let mock_client = Arc::new(MockLlmClient::new(vec![
4764 MockLlmClient::tool_call_response(
4765 "tool-1",
4766 "bash",
4767 serde_json::json!({"command": "rm -rf /"}),
4768 ),
4769 MockLlmClient::text_response("Blocked by permission."),
4770 ]));
4771
4772 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4773
4774 let (event_tx, mut event_rx) = broadcast::channel(100);
4776 let hitl_policy = ConfirmationPolicy {
4777 enabled: true,
4778 ..Default::default()
4779 };
4780 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4781
4782 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
4784
4785 let config = AgentConfig {
4786 permission_checker: Some(Arc::new(permission_policy)),
4787 confirmation_manager: Some(confirmation_manager),
4788 ..Default::default()
4789 };
4790
4791 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4792 let result = agent.execute(&[], "Delete", None).await.unwrap();
4793
4794 assert_eq!(result.text, "Blocked by permission.");
4796
4797 let mut found_confirmation = false;
4799 while let Ok(event) = event_rx.try_recv() {
4800 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
4801 found_confirmation = true;
4802 }
4803 }
4804 assert!(
4805 !found_confirmation,
4806 "HITL should not be triggered when permission is Deny"
4807 );
4808 }
4809
4810 #[tokio::test]
4811 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
4812 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4815 use tokio::sync::broadcast;
4816
4817 let mock_client = Arc::new(MockLlmClient::new(vec![
4818 MockLlmClient::tool_call_response(
4819 "tool-1",
4820 "bash",
4821 serde_json::json!({"command": "echo hello"}),
4822 ),
4823 MockLlmClient::text_response("Allowed!"),
4824 ]));
4825
4826 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4827
4828 let (event_tx, mut event_rx) = broadcast::channel(100);
4830 let hitl_policy = ConfirmationPolicy {
4831 enabled: true,
4832 ..Default::default()
4833 };
4834 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4835
4836 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
4838
4839 let config = AgentConfig {
4840 permission_checker: Some(Arc::new(permission_policy)),
4841 confirmation_manager: Some(confirmation_manager.clone()),
4842 ..Default::default()
4843 };
4844
4845 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4846 let result = agent.execute(&[], "Echo", None).await.unwrap();
4847
4848 assert_eq!(result.text, "Allowed!");
4850
4851 let mut found_confirmation = false;
4853 while let Ok(event) = event_rx.try_recv() {
4854 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
4855 found_confirmation = true;
4856 }
4857 }
4858 assert!(
4859 !found_confirmation,
4860 "Permission Allow should skip HITL confirmation"
4861 );
4862 }
4863
4864 #[tokio::test]
4865 async fn test_agent_hitl_multiple_tool_calls() {
4866 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4868 use tokio::sync::broadcast;
4869
4870 let mock_client = Arc::new(MockLlmClient::new(vec![
4871 LlmResponse {
4873 message: Message {
4874 role: "assistant".to_string(),
4875 content: vec![
4876 ContentBlock::ToolUse {
4877 id: "tool-1".to_string(),
4878 name: "bash".to_string(),
4879 input: serde_json::json!({"command": "echo first"}),
4880 },
4881 ContentBlock::ToolUse {
4882 id: "tool-2".to_string(),
4883 name: "bash".to_string(),
4884 input: serde_json::json!({"command": "echo second"}),
4885 },
4886 ],
4887 reasoning_content: None,
4888 },
4889 usage: TokenUsage {
4890 prompt_tokens: 10,
4891 completion_tokens: 5,
4892 total_tokens: 15,
4893 cache_read_tokens: None,
4894 cache_write_tokens: None,
4895 },
4896 stop_reason: Some("tool_use".to_string()),
4897 meta: None,
4898 },
4899 MockLlmClient::text_response("Both executed!"),
4900 ]));
4901
4902 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4903
4904 let (event_tx, _event_rx) = broadcast::channel(100);
4906 let hitl_policy = ConfirmationPolicy {
4907 enabled: true,
4908 default_timeout_ms: 5000,
4909 ..Default::default()
4910 };
4911 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4912
4913 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
4916 permission_checker: Some(Arc::new(permission_policy)),
4917 confirmation_manager: Some(confirmation_manager.clone()),
4918 ..Default::default()
4919 };
4920
4921 let cm_clone = confirmation_manager.clone();
4923 tokio::spawn(async move {
4924 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4925 cm_clone.confirm("tool-1", true, None).await.ok();
4926 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4927 cm_clone.confirm("tool-2", true, None).await.ok();
4928 });
4929
4930 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4931 let result = agent.execute(&[], "Run both", None).await.unwrap();
4932
4933 assert_eq!(result.text, "Both executed!");
4934 assert_eq!(result.tool_calls_count, 2);
4935 }
4936
4937 #[tokio::test]
4938 async fn test_agent_hitl_partial_approval() {
4939 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4941 use tokio::sync::broadcast;
4942
4943 let mock_client = Arc::new(MockLlmClient::new(vec![
4944 LlmResponse {
4946 message: Message {
4947 role: "assistant".to_string(),
4948 content: vec![
4949 ContentBlock::ToolUse {
4950 id: "tool-1".to_string(),
4951 name: "bash".to_string(),
4952 input: serde_json::json!({"command": "echo safe"}),
4953 },
4954 ContentBlock::ToolUse {
4955 id: "tool-2".to_string(),
4956 name: "bash".to_string(),
4957 input: serde_json::json!({"command": "rm -rf /"}),
4958 },
4959 ],
4960 reasoning_content: None,
4961 },
4962 usage: TokenUsage {
4963 prompt_tokens: 10,
4964 completion_tokens: 5,
4965 total_tokens: 15,
4966 cache_read_tokens: None,
4967 cache_write_tokens: None,
4968 },
4969 stop_reason: Some("tool_use".to_string()),
4970 meta: None,
4971 },
4972 MockLlmClient::text_response("First worked, second rejected."),
4973 ]));
4974
4975 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4976
4977 let (event_tx, _event_rx) = broadcast::channel(100);
4978 let hitl_policy = ConfirmationPolicy {
4979 enabled: true,
4980 default_timeout_ms: 5000,
4981 ..Default::default()
4982 };
4983 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4984
4985 let permission_policy = PermissionPolicy::new();
4986
4987 let config = AgentConfig {
4988 permission_checker: Some(Arc::new(permission_policy)),
4989 confirmation_manager: Some(confirmation_manager.clone()),
4990 ..Default::default()
4991 };
4992
4993 let cm_clone = confirmation_manager.clone();
4995 tokio::spawn(async move {
4996 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4997 cm_clone.confirm("tool-1", true, None).await.ok();
4998 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4999 cm_clone
5000 .confirm("tool-2", false, Some("Dangerous".to_string()))
5001 .await
5002 .ok();
5003 });
5004
5005 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5006 let result = agent.execute(&[], "Run both", None).await.unwrap();
5007
5008 assert_eq!(result.text, "First worked, second rejected.");
5009 assert_eq!(result.tool_calls_count, 2);
5010 }
5011
5012 #[tokio::test]
5013 async fn test_agent_hitl_yolo_mode_auto_approves() {
5014 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
5016 use tokio::sync::broadcast;
5017
5018 let mock_client = Arc::new(MockLlmClient::new(vec![
5019 MockLlmClient::tool_call_response(
5020 "tool-1",
5021 "read", serde_json::json!({"path": "/tmp/test.txt"}),
5023 ),
5024 MockLlmClient::text_response("File read!"),
5025 ]));
5026
5027 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5028
5029 let (event_tx, mut event_rx) = broadcast::channel(100);
5031 let mut yolo_lanes = std::collections::HashSet::new();
5032 yolo_lanes.insert(SessionLane::Query);
5033 let hitl_policy = ConfirmationPolicy {
5034 enabled: true,
5035 yolo_lanes, ..Default::default()
5037 };
5038 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
5039
5040 let permission_policy = PermissionPolicy::new();
5041
5042 let config = AgentConfig {
5043 permission_checker: Some(Arc::new(permission_policy)),
5044 confirmation_manager: Some(confirmation_manager),
5045 ..Default::default()
5046 };
5047
5048 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5049 let result = agent.execute(&[], "Read file", None).await.unwrap();
5050
5051 assert_eq!(result.text, "File read!");
5053
5054 let mut found_confirmation = false;
5056 while let Ok(event) = event_rx.try_recv() {
5057 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
5058 found_confirmation = true;
5059 }
5060 }
5061 assert!(
5062 !found_confirmation,
5063 "YOLO mode should not trigger confirmation"
5064 );
5065 }
5066
5067 #[tokio::test]
5068 async fn test_agent_config_with_all_options() {
5069 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
5070 use tokio::sync::broadcast;
5071
5072 let (event_tx, _) = broadcast::channel(100);
5073 let hitl_policy = ConfirmationPolicy::default();
5074 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
5075
5076 let permission_policy = PermissionPolicy::new().allow("bash(*)");
5077
5078 let config = AgentConfig {
5079 prompt_slots: SystemPromptSlots {
5080 extra: Some("Test system prompt".to_string()),
5081 ..Default::default()
5082 },
5083 tools: vec![],
5084 max_tool_rounds: 10,
5085 permission_checker: Some(Arc::new(permission_policy)),
5086 confirmation_manager: Some(confirmation_manager),
5087 context_providers: vec![],
5088 planning_mode: PlanningMode::default(),
5089 goal_tracking: false,
5090 hook_engine: None,
5091 skill_registry: None,
5092 ..AgentConfig::default()
5093 };
5094
5095 assert!(config.prompt_slots.build().contains("Test system prompt"));
5096 assert_eq!(config.max_tool_rounds, 10);
5097 assert!(config.permission_checker.is_some());
5098 assert!(config.confirmation_manager.is_some());
5099 assert!(config.context_providers.is_empty());
5100
5101 let debug_str = format!("{:?}", config);
5103 assert!(debug_str.contains("AgentConfig"));
5104 assert!(debug_str.contains("permission_checker: true"));
5105 assert!(debug_str.contains("confirmation_manager: true"));
5106 assert!(debug_str.contains("context_providers: 0"));
5107 }
5108
5109 use crate::context::{ContextItem, ContextType};
5114
5115 struct MockContextProvider {
5117 name: String,
5118 items: Vec<ContextItem>,
5119 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
5120 }
5121
5122 impl MockContextProvider {
5123 fn new(name: &str) -> Self {
5124 Self {
5125 name: name.to_string(),
5126 items: Vec::new(),
5127 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
5128 }
5129 }
5130
5131 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
5132 self.items = items;
5133 self
5134 }
5135 }
5136
5137 #[async_trait::async_trait]
5138 impl ContextProvider for MockContextProvider {
5139 fn name(&self) -> &str {
5140 &self.name
5141 }
5142
5143 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
5144 let mut result = ContextResult::new(&self.name);
5145 for item in &self.items {
5146 result.add_item(item.clone());
5147 }
5148 Ok(result)
5149 }
5150
5151 async fn on_turn_complete(
5152 &self,
5153 session_id: &str,
5154 prompt: &str,
5155 response: &str,
5156 ) -> anyhow::Result<()> {
5157 let mut calls = self.on_turn_calls.write().await;
5158 calls.push((
5159 session_id.to_string(),
5160 prompt.to_string(),
5161 response.to_string(),
5162 ));
5163 Ok(())
5164 }
5165 }
5166
5167 #[tokio::test]
5168 async fn test_agent_with_context_provider() {
5169 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5170 "Response using context",
5171 )]));
5172
5173 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5174
5175 let provider =
5176 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
5177 "ctx-1",
5178 ContextType::Resource,
5179 "Relevant context here",
5180 )
5181 .with_source("test://docs/example")]);
5182
5183 let config = AgentConfig {
5184 prompt_slots: SystemPromptSlots {
5185 extra: Some("You are helpful.".to_string()),
5186 ..Default::default()
5187 },
5188 context_providers: vec![Arc::new(provider)],
5189 ..Default::default()
5190 };
5191
5192 let agent = AgentLoop::new(
5193 mock_client.clone(),
5194 tool_executor,
5195 test_tool_context(),
5196 config,
5197 );
5198 let result = agent.execute(&[], "What is X?", None).await.unwrap();
5199
5200 assert_eq!(result.text, "Response using context");
5201 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
5202 }
5203
5204 #[tokio::test]
5205 async fn test_agent_context_provider_events() {
5206 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5207 "Answer",
5208 )]));
5209
5210 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5211
5212 let provider =
5213 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
5214 "item-1",
5215 ContextType::Memory,
5216 "Memory content",
5217 )
5218 .with_token_count(50)]);
5219
5220 let config = AgentConfig {
5221 context_providers: vec![Arc::new(provider)],
5222 ..Default::default()
5223 };
5224
5225 let (tx, mut rx) = mpsc::channel(100);
5226 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5227 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
5228
5229 let mut events = Vec::new();
5231 while let Ok(event) = rx.try_recv() {
5232 events.push(event);
5233 }
5234
5235 assert!(
5237 events
5238 .iter()
5239 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
5240 "Should have ContextResolving event"
5241 );
5242 assert!(
5243 events
5244 .iter()
5245 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
5246 "Should have ContextResolved event"
5247 );
5248
5249 for event in &events {
5251 if let AgentEvent::ContextResolved {
5252 total_items,
5253 total_tokens,
5254 } = event
5255 {
5256 assert_eq!(*total_items, 1);
5257 assert_eq!(*total_tokens, 50);
5258 }
5259 }
5260 }
5261
5262 #[tokio::test]
5263 async fn test_agent_multiple_context_providers() {
5264 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5265 "Combined response",
5266 )]));
5267
5268 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5269
5270 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
5271 "p1-1",
5272 ContextType::Resource,
5273 "Resource from P1",
5274 )
5275 .with_token_count(100)]);
5276
5277 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
5278 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
5279 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
5280 ]);
5281
5282 let config = AgentConfig {
5283 prompt_slots: SystemPromptSlots {
5284 extra: Some("Base system prompt.".to_string()),
5285 ..Default::default()
5286 },
5287 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
5288 ..Default::default()
5289 };
5290
5291 let (tx, mut rx) = mpsc::channel(100);
5292 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5293 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
5294
5295 assert_eq!(result.text, "Combined response");
5296
5297 while let Ok(event) = rx.try_recv() {
5299 if let AgentEvent::ContextResolved {
5300 total_items,
5301 total_tokens,
5302 } = event
5303 {
5304 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
5307 }
5308 }
5309
5310 #[tokio::test]
5311 async fn test_agent_no_context_providers() {
5312 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5313 "No context",
5314 )]));
5315
5316 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5317
5318 let config = AgentConfig::default();
5320
5321 let (tx, mut rx) = mpsc::channel(100);
5322 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5323 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
5324
5325 assert_eq!(result.text, "No context");
5326
5327 let mut events = Vec::new();
5329 while let Ok(event) = rx.try_recv() {
5330 events.push(event);
5331 }
5332
5333 assert!(
5334 !events
5335 .iter()
5336 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
5337 "Should NOT have ContextResolving event"
5338 );
5339 }
5340
5341 #[tokio::test]
5342 async fn test_agent_context_on_turn_complete() {
5343 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5344 "Final response",
5345 )]));
5346
5347 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5348
5349 let provider = Arc::new(MockContextProvider::new("memory-provider"));
5350 let on_turn_calls = provider.on_turn_calls.clone();
5351
5352 let config = AgentConfig {
5353 context_providers: vec![provider],
5354 ..Default::default()
5355 };
5356
5357 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5358
5359 let result = agent
5361 .execute_with_session(&[], "User prompt", Some("sess-123"), None, None)
5362 .await
5363 .unwrap();
5364
5365 assert_eq!(result.text, "Final response");
5366
5367 let calls = on_turn_calls.read().await;
5369 assert_eq!(calls.len(), 1);
5370 assert_eq!(calls[0].0, "sess-123");
5371 assert_eq!(calls[0].1, "User prompt");
5372 assert_eq!(calls[0].2, "Final response");
5373 }
5374
5375 #[tokio::test]
5376 async fn test_agent_context_on_turn_complete_no_session() {
5377 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5378 "Response",
5379 )]));
5380
5381 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5382
5383 let provider = Arc::new(MockContextProvider::new("memory-provider"));
5384 let on_turn_calls = provider.on_turn_calls.clone();
5385
5386 let config = AgentConfig {
5387 context_providers: vec![provider],
5388 ..Default::default()
5389 };
5390
5391 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5392
5393 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
5395
5396 let calls = on_turn_calls.read().await;
5398 assert!(calls.is_empty());
5399 }
5400
5401 #[tokio::test]
5402 async fn test_agent_build_augmented_system_prompt() {
5403 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
5404
5405 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5406
5407 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
5408 "doc-1",
5409 ContextType::Resource,
5410 "Auth uses JWT tokens.",
5411 )
5412 .with_source("viking://docs/auth")]);
5413
5414 let config = AgentConfig {
5415 prompt_slots: SystemPromptSlots {
5416 extra: Some("You are helpful.".to_string()),
5417 ..Default::default()
5418 },
5419 context_providers: vec![Arc::new(provider)],
5420 ..Default::default()
5421 };
5422
5423 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5424
5425 let context_results = agent.resolve_context("test", None).await;
5427 let augmented = agent.build_augmented_system_prompt(&context_results);
5428
5429 let augmented_str = augmented.unwrap();
5430 assert!(augmented_str.contains("You are helpful."));
5431 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
5432 assert!(augmented_str.contains("Auth uses JWT tokens."));
5433 }
5434
5435 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
5441 let mut events = Vec::new();
5442 while let Ok(event) = rx.try_recv() {
5443 events.push(event);
5444 }
5445 while let Some(event) = rx.recv().await {
5447 events.push(event);
5448 }
5449 events
5450 }
5451
5452 #[tokio::test]
5453 async fn test_agent_multi_turn_tool_chain() {
5454 let mock_client = Arc::new(MockLlmClient::new(vec![
5456 MockLlmClient::tool_call_response(
5458 "t1",
5459 "bash",
5460 serde_json::json!({"command": "echo step1"}),
5461 ),
5462 MockLlmClient::tool_call_response(
5464 "t2",
5465 "bash",
5466 serde_json::json!({"command": "echo step2"}),
5467 ),
5468 MockLlmClient::text_response("Completed both steps: step1 then step2"),
5470 ]));
5471
5472 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5473 let config = AgentConfig::default();
5474
5475 let agent = AgentLoop::new(
5476 mock_client.clone(),
5477 tool_executor,
5478 test_tool_context(),
5479 config,
5480 );
5481 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
5482
5483 assert_eq!(result.text, "Completed both steps: step1 then step2");
5484 assert_eq!(result.tool_calls_count, 2);
5485 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
5486
5487 assert_eq!(result.messages[0].role, "user");
5489 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
5495 }
5496
5497 #[tokio::test]
5498 async fn test_agent_conversation_history_preserved() {
5499 let existing_history = vec![
5501 Message::user("What is Rust?"),
5502 Message {
5503 role: "assistant".to_string(),
5504 content: vec![ContentBlock::Text {
5505 text: "Rust is a systems programming language.".to_string(),
5506 }],
5507 reasoning_content: None,
5508 },
5509 ];
5510
5511 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5512 "Rust was created by Graydon Hoare at Mozilla.",
5513 )]));
5514
5515 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5516 let agent = AgentLoop::new(
5517 mock_client.clone(),
5518 tool_executor,
5519 test_tool_context(),
5520 AgentConfig::default(),
5521 );
5522
5523 let result = agent
5524 .execute(&existing_history, "Who created it?", None)
5525 .await
5526 .unwrap();
5527
5528 assert_eq!(result.messages.len(), 4);
5530 assert_eq!(result.messages[0].text(), "What is Rust?");
5531 assert_eq!(
5532 result.messages[1].text(),
5533 "Rust is a systems programming language."
5534 );
5535 assert_eq!(result.messages[2].text(), "Who created it?");
5536 assert_eq!(
5537 result.messages[3].text(),
5538 "Rust was created by Graydon Hoare at Mozilla."
5539 );
5540 }
5541
5542 #[tokio::test]
5543 async fn test_agent_event_stream_completeness() {
5544 let mock_client = Arc::new(MockLlmClient::new(vec![
5546 MockLlmClient::tool_call_response(
5547 "t1",
5548 "bash",
5549 serde_json::json!({"command": "echo hi"}),
5550 ),
5551 MockLlmClient::text_response("Done"),
5552 ]));
5553
5554 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5555 let agent = AgentLoop::new(
5556 mock_client,
5557 tool_executor,
5558 test_tool_context(),
5559 AgentConfig::default(),
5560 );
5561
5562 let (tx, rx) = mpsc::channel(100);
5563 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
5564 assert_eq!(result.text, "Done");
5565
5566 let events = collect_events(rx).await;
5567
5568 let event_types: Vec<&str> = events
5570 .iter()
5571 .map(|e| match e {
5572 AgentEvent::Start { .. } => "Start",
5573 AgentEvent::TurnStart { .. } => "TurnStart",
5574 AgentEvent::TurnEnd { .. } => "TurnEnd",
5575 AgentEvent::ToolEnd { .. } => "ToolEnd",
5576 AgentEvent::End { .. } => "End",
5577 _ => "Other",
5578 })
5579 .collect();
5580
5581 assert_eq!(event_types.first(), Some(&"Start"));
5583 assert_eq!(event_types.last(), Some(&"End"));
5584
5585 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
5587 assert_eq!(turn_starts, 2);
5588
5589 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
5591 assert_eq!(tool_ends, 1);
5592 }
5593
5594 #[tokio::test]
5595 async fn test_agent_multiple_tools_single_turn() {
5596 let mock_client = Arc::new(MockLlmClient::new(vec![
5598 LlmResponse {
5599 message: Message {
5600 role: "assistant".to_string(),
5601 content: vec![
5602 ContentBlock::ToolUse {
5603 id: "t1".to_string(),
5604 name: "bash".to_string(),
5605 input: serde_json::json!({"command": "echo first"}),
5606 },
5607 ContentBlock::ToolUse {
5608 id: "t2".to_string(),
5609 name: "bash".to_string(),
5610 input: serde_json::json!({"command": "echo second"}),
5611 },
5612 ],
5613 reasoning_content: None,
5614 },
5615 usage: TokenUsage {
5616 prompt_tokens: 10,
5617 completion_tokens: 5,
5618 total_tokens: 15,
5619 cache_read_tokens: None,
5620 cache_write_tokens: None,
5621 },
5622 stop_reason: Some("tool_use".to_string()),
5623 meta: None,
5624 },
5625 MockLlmClient::text_response("Both commands ran"),
5626 ]));
5627
5628 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5629 let agent = AgentLoop::new(
5630 mock_client.clone(),
5631 tool_executor,
5632 test_tool_context(),
5633 AgentConfig::default(),
5634 );
5635
5636 let result = agent.execute(&[], "Run both", None).await.unwrap();
5637
5638 assert_eq!(result.text, "Both commands ran");
5639 assert_eq!(result.tool_calls_count, 2);
5640 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
5644 assert_eq!(result.messages[1].role, "assistant");
5645 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
5648 }
5649
5650 #[tokio::test]
5651 async fn test_agent_token_usage_accumulation() {
5652 let mock_client = Arc::new(MockLlmClient::new(vec![
5654 MockLlmClient::tool_call_response(
5655 "t1",
5656 "bash",
5657 serde_json::json!({"command": "echo x"}),
5658 ),
5659 MockLlmClient::text_response("Done"),
5660 ]));
5661
5662 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5663 let agent = AgentLoop::new(
5664 mock_client,
5665 tool_executor,
5666 test_tool_context(),
5667 AgentConfig::default(),
5668 );
5669
5670 let result = agent.execute(&[], "test", None).await.unwrap();
5671
5672 assert_eq!(result.usage.prompt_tokens, 20);
5675 assert_eq!(result.usage.completion_tokens, 10);
5676 assert_eq!(result.usage.total_tokens, 30);
5677 }
5678
5679 #[tokio::test]
5680 async fn test_agent_system_prompt_passed() {
5681 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5683 "I am a coding assistant.",
5684 )]));
5685
5686 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5687 let config = AgentConfig {
5688 prompt_slots: SystemPromptSlots {
5689 extra: Some("You are a coding assistant.".to_string()),
5690 ..Default::default()
5691 },
5692 ..Default::default()
5693 };
5694
5695 let agent = AgentLoop::new(
5696 mock_client.clone(),
5697 tool_executor,
5698 test_tool_context(),
5699 config,
5700 );
5701 let result = agent.execute(&[], "What are you?", None).await.unwrap();
5702
5703 assert_eq!(result.text, "I am a coding assistant.");
5704 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
5705 }
5706
5707 #[tokio::test]
5708 async fn test_agent_max_rounds_with_persistent_tool_calls() {
5709 let mut responses = Vec::new();
5711 for i in 0..15 {
5712 responses.push(MockLlmClient::tool_call_response(
5713 &format!("t{}", i),
5714 "bash",
5715 serde_json::json!({"command": format!("echo round{}", i)}),
5716 ));
5717 }
5718
5719 let mock_client = Arc::new(MockLlmClient::new(responses));
5720 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5721 let config = AgentConfig {
5722 max_tool_rounds: 5,
5723 ..Default::default()
5724 };
5725
5726 let agent = AgentLoop::new(
5727 mock_client.clone(),
5728 tool_executor,
5729 test_tool_context(),
5730 config,
5731 );
5732 let result = agent.execute(&[], "Loop forever", None).await;
5733
5734 assert!(result.is_err());
5735 let err = result.unwrap_err().to_string();
5736 assert!(err.contains("Max tool rounds (5) exceeded"));
5737 }
5738
5739 #[tokio::test]
5740 async fn test_agent_end_event_contains_final_text() {
5741 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5742 "Final answer here",
5743 )]));
5744
5745 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5746 let agent = AgentLoop::new(
5747 mock_client,
5748 tool_executor,
5749 test_tool_context(),
5750 AgentConfig::default(),
5751 );
5752
5753 let (tx, rx) = mpsc::channel(100);
5754 agent.execute(&[], "test", Some(tx)).await.unwrap();
5755
5756 let events = collect_events(rx).await;
5757 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
5758 assert!(end_event.is_some());
5759
5760 if let AgentEvent::End { text, usage, .. } = end_event.unwrap() {
5761 assert_eq!(text, "Final answer here");
5762 assert_eq!(usage.total_tokens, 15);
5763 }
5764 }
5765}
5766
5767#[cfg(test)]
5768mod extra_agent_tests {
5769 use super::*;
5770 use crate::agent::tests::MockLlmClient;
5771 use crate::queue::SessionQueueConfig;
5772 use crate::tools::ToolExecutor;
5773 use std::path::PathBuf;
5774 use std::sync::atomic::{AtomicUsize, Ordering};
5775
5776 fn test_tool_context() -> ToolContext {
5777 ToolContext::new(PathBuf::from("/tmp"))
5778 }
5779
5780 #[test]
5785 fn test_agent_config_debug() {
5786 let config = AgentConfig {
5787 prompt_slots: SystemPromptSlots {
5788 extra: Some("You are helpful".to_string()),
5789 ..Default::default()
5790 },
5791 tools: vec![],
5792 max_tool_rounds: 10,
5793 permission_checker: None,
5794 confirmation_manager: None,
5795 context_providers: vec![],
5796 planning_mode: PlanningMode::Enabled,
5797 goal_tracking: false,
5798 hook_engine: None,
5799 skill_registry: None,
5800 ..AgentConfig::default()
5801 };
5802 let debug = format!("{:?}", config);
5803 assert!(debug.contains("AgentConfig"));
5804 assert!(debug.contains("planning_mode"));
5805 }
5806
5807 #[test]
5808 fn test_agent_config_default_values() {
5809 let config = AgentConfig::default();
5810 assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
5811 assert_eq!(config.planning_mode, PlanningMode::Auto);
5812 assert!(!config.goal_tracking);
5813 assert!(config.context_providers.is_empty());
5814 }
5815
5816 #[test]
5821 fn test_agent_event_serialize_start() {
5822 let event = AgentEvent::Start {
5823 prompt: "Hello".to_string(),
5824 };
5825 let json = serde_json::to_string(&event).unwrap();
5826 assert!(json.contains("agent_start"));
5827 assert!(json.contains("Hello"));
5828 }
5829
5830 #[test]
5831 fn test_agent_event_serialize_text_delta() {
5832 let event = AgentEvent::TextDelta {
5833 text: "chunk".to_string(),
5834 };
5835 let json = serde_json::to_string(&event).unwrap();
5836 assert!(json.contains("text_delta"));
5837 }
5838
5839 #[test]
5840 fn test_agent_event_serialize_tool_start() {
5841 let event = AgentEvent::ToolStart {
5842 id: "t1".to_string(),
5843 name: "bash".to_string(),
5844 };
5845 let json = serde_json::to_string(&event).unwrap();
5846 assert!(json.contains("tool_start"));
5847 assert!(json.contains("bash"));
5848 }
5849
5850 #[test]
5851 fn test_agent_event_serialize_tool_end() {
5852 let event = AgentEvent::ToolEnd {
5853 id: "t1".to_string(),
5854 name: "bash".to_string(),
5855 output: "hello".to_string(),
5856 exit_code: 0,
5857 metadata: None,
5858 };
5859 let json = serde_json::to_string(&event).unwrap();
5860 assert!(json.contains("tool_end"));
5861 }
5862
5863 #[test]
5864 fn test_agent_event_tool_end_has_metadata_field() {
5865 let event = AgentEvent::ToolEnd {
5866 id: "t1".to_string(),
5867 name: "write".to_string(),
5868 output: "Wrote 5 bytes".to_string(),
5869 exit_code: 0,
5870 metadata: Some(
5871 serde_json::json!({ "before": "old", "after": "new", "file_path": "f.txt" }),
5872 ),
5873 };
5874 let json = serde_json::to_string(&event).unwrap();
5875 assert!(json.contains("\"before\""));
5876 }
5877
5878 #[test]
5879 fn test_agent_event_serialize_error() {
5880 let event = AgentEvent::Error {
5881 message: "oops".to_string(),
5882 };
5883 let json = serde_json::to_string(&event).unwrap();
5884 assert!(json.contains("error"));
5885 assert!(json.contains("oops"));
5886 }
5887
5888 #[test]
5889 fn test_agent_event_serialize_confirmation_required() {
5890 let event = AgentEvent::ConfirmationRequired {
5891 tool_id: "t1".to_string(),
5892 tool_name: "bash".to_string(),
5893 args: serde_json::json!({"cmd": "rm"}),
5894 timeout_ms: 30000,
5895 };
5896 let json = serde_json::to_string(&event).unwrap();
5897 assert!(json.contains("confirmation_required"));
5898 }
5899
5900 #[test]
5901 fn test_agent_event_serialize_confirmation_received() {
5902 let event = AgentEvent::ConfirmationReceived {
5903 tool_id: "t1".to_string(),
5904 approved: true,
5905 reason: Some("safe".to_string()),
5906 };
5907 let json = serde_json::to_string(&event).unwrap();
5908 assert!(json.contains("confirmation_received"));
5909 }
5910
5911 #[test]
5912 fn test_agent_event_serialize_confirmation_timeout() {
5913 let event = AgentEvent::ConfirmationTimeout {
5914 tool_id: "t1".to_string(),
5915 action_taken: "rejected".to_string(),
5916 };
5917 let json = serde_json::to_string(&event).unwrap();
5918 assert!(json.contains("confirmation_timeout"));
5919 }
5920
5921 #[test]
5922 fn test_agent_event_serialize_external_task_pending() {
5923 let event = AgentEvent::ExternalTaskPending {
5924 task_id: "task-1".to_string(),
5925 session_id: "sess-1".to_string(),
5926 lane: crate::hitl::SessionLane::Execute,
5927 command_type: "bash".to_string(),
5928 payload: serde_json::json!({}),
5929 timeout_ms: 60000,
5930 };
5931 let json = serde_json::to_string(&event).unwrap();
5932 assert!(json.contains("external_task_pending"));
5933 }
5934
5935 #[test]
5936 fn test_agent_event_serialize_external_task_completed() {
5937 let event = AgentEvent::ExternalTaskCompleted {
5938 task_id: "task-1".to_string(),
5939 session_id: "sess-1".to_string(),
5940 success: false,
5941 };
5942 let json = serde_json::to_string(&event).unwrap();
5943 assert!(json.contains("external_task_completed"));
5944 }
5945
5946 #[test]
5947 fn test_agent_event_serialize_permission_denied() {
5948 let event = AgentEvent::PermissionDenied {
5949 tool_id: "t1".to_string(),
5950 tool_name: "bash".to_string(),
5951 args: serde_json::json!({}),
5952 reason: "denied".to_string(),
5953 };
5954 let json = serde_json::to_string(&event).unwrap();
5955 assert!(json.contains("permission_denied"));
5956 }
5957
5958 #[test]
5959 fn test_agent_event_serialize_context_compacted() {
5960 let event = AgentEvent::ContextCompacted {
5961 session_id: "sess-1".to_string(),
5962 before_messages: 100,
5963 after_messages: 20,
5964 percent_before: 0.85,
5965 };
5966 let json = serde_json::to_string(&event).unwrap();
5967 assert!(json.contains("context_compacted"));
5968 }
5969
5970 #[test]
5971 fn test_agent_event_serialize_turn_start() {
5972 let event = AgentEvent::TurnStart { turn: 3 };
5973 let json = serde_json::to_string(&event).unwrap();
5974 assert!(json.contains("turn_start"));
5975 }
5976
5977 #[test]
5978 fn test_agent_event_serialize_turn_end() {
5979 let event = AgentEvent::TurnEnd {
5980 turn: 3,
5981 usage: TokenUsage::default(),
5982 };
5983 let json = serde_json::to_string(&event).unwrap();
5984 assert!(json.contains("turn_end"));
5985 }
5986
5987 #[test]
5988 fn test_agent_event_serialize_end() {
5989 let event = AgentEvent::End {
5990 text: "Done".to_string(),
5991 usage: TokenUsage {
5992 prompt_tokens: 100,
5993 completion_tokens: 50,
5994 total_tokens: 150,
5995 cache_read_tokens: None,
5996 cache_write_tokens: None,
5997 },
5998 meta: None,
5999 };
6000 let json = serde_json::to_string(&event).unwrap();
6001 assert!(json.contains("agent_end"));
6002 }
6003
6004 #[test]
6009 fn test_agent_result_fields() {
6010 let result = AgentResult {
6011 text: "output".to_string(),
6012 messages: vec![Message::user("hello")],
6013 usage: TokenUsage::default(),
6014 tool_calls_count: 3,
6015 };
6016 assert_eq!(result.text, "output");
6017 assert_eq!(result.messages.len(), 1);
6018 assert_eq!(result.tool_calls_count, 3);
6019 }
6020
6021 #[test]
6026 fn test_agent_event_serialize_context_resolving() {
6027 let event = AgentEvent::ContextResolving {
6028 providers: vec!["provider1".to_string(), "provider2".to_string()],
6029 };
6030 let json = serde_json::to_string(&event).unwrap();
6031 assert!(json.contains("context_resolving"));
6032 assert!(json.contains("provider1"));
6033 }
6034
6035 #[test]
6036 fn test_agent_event_serialize_context_resolved() {
6037 let event = AgentEvent::ContextResolved {
6038 total_items: 5,
6039 total_tokens: 1000,
6040 };
6041 let json = serde_json::to_string(&event).unwrap();
6042 assert!(json.contains("context_resolved"));
6043 assert!(json.contains("1000"));
6044 }
6045
6046 #[test]
6047 fn test_agent_event_serialize_command_dead_lettered() {
6048 let event = AgentEvent::CommandDeadLettered {
6049 command_id: "cmd-1".to_string(),
6050 command_type: "bash".to_string(),
6051 lane: "execute".to_string(),
6052 error: "timeout".to_string(),
6053 attempts: 3,
6054 };
6055 let json = serde_json::to_string(&event).unwrap();
6056 assert!(json.contains("command_dead_lettered"));
6057 assert!(json.contains("cmd-1"));
6058 }
6059
6060 #[test]
6061 fn test_agent_event_serialize_command_retry() {
6062 let event = AgentEvent::CommandRetry {
6063 command_id: "cmd-2".to_string(),
6064 command_type: "read".to_string(),
6065 lane: "query".to_string(),
6066 attempt: 2,
6067 delay_ms: 1000,
6068 };
6069 let json = serde_json::to_string(&event).unwrap();
6070 assert!(json.contains("command_retry"));
6071 assert!(json.contains("cmd-2"));
6072 }
6073
6074 #[test]
6075 fn test_agent_event_serialize_queue_alert() {
6076 let event = AgentEvent::QueueAlert {
6077 level: "warning".to_string(),
6078 alert_type: "depth".to_string(),
6079 message: "Queue depth exceeded".to_string(),
6080 };
6081 let json = serde_json::to_string(&event).unwrap();
6082 assert!(json.contains("queue_alert"));
6083 assert!(json.contains("warning"));
6084 }
6085
6086 #[test]
6087 fn test_agent_event_serialize_task_updated() {
6088 let event = AgentEvent::TaskUpdated {
6089 session_id: "sess-1".to_string(),
6090 tasks: vec![],
6091 };
6092 let json = serde_json::to_string(&event).unwrap();
6093 assert!(json.contains("task_updated"));
6094 assert!(json.contains("sess-1"));
6095 }
6096
6097 #[test]
6098 fn test_agent_event_serialize_memory_stored() {
6099 let event = AgentEvent::MemoryStored {
6100 memory_id: "mem-1".to_string(),
6101 memory_type: "conversation".to_string(),
6102 importance: 0.8,
6103 tags: vec!["important".to_string()],
6104 };
6105 let json = serde_json::to_string(&event).unwrap();
6106 assert!(json.contains("memory_stored"));
6107 assert!(json.contains("mem-1"));
6108 }
6109
6110 #[test]
6111 fn test_agent_event_serialize_memory_recalled() {
6112 let event = AgentEvent::MemoryRecalled {
6113 memory_id: "mem-2".to_string(),
6114 content: "Previous conversation".to_string(),
6115 relevance: 0.9,
6116 };
6117 let json = serde_json::to_string(&event).unwrap();
6118 assert!(json.contains("memory_recalled"));
6119 assert!(json.contains("mem-2"));
6120 }
6121
6122 #[test]
6123 fn test_agent_event_serialize_memories_searched() {
6124 let event = AgentEvent::MemoriesSearched {
6125 query: Some("search term".to_string()),
6126 tags: vec!["tag1".to_string()],
6127 result_count: 5,
6128 };
6129 let json = serde_json::to_string(&event).unwrap();
6130 assert!(json.contains("memories_searched"));
6131 assert!(json.contains("search term"));
6132 }
6133
6134 #[test]
6135 fn test_agent_event_serialize_memory_cleared() {
6136 let event = AgentEvent::MemoryCleared {
6137 tier: "short_term".to_string(),
6138 count: 10,
6139 };
6140 let json = serde_json::to_string(&event).unwrap();
6141 assert!(json.contains("memory_cleared"));
6142 assert!(json.contains("short_term"));
6143 }
6144
6145 #[test]
6146 fn test_agent_event_serialize_subagent_start() {
6147 let event = AgentEvent::SubagentStart {
6148 task_id: "task-1".to_string(),
6149 session_id: "child-sess".to_string(),
6150 parent_session_id: "parent-sess".to_string(),
6151 agent: "explore".to_string(),
6152 description: "Explore codebase".to_string(),
6153 };
6154 let json = serde_json::to_string(&event).unwrap();
6155 assert!(json.contains("subagent_start"));
6156 assert!(json.contains("explore"));
6157 }
6158
6159 #[test]
6160 fn test_agent_event_serialize_subagent_progress() {
6161 let event = AgentEvent::SubagentProgress {
6162 task_id: "task-1".to_string(),
6163 session_id: "child-sess".to_string(),
6164 status: "processing".to_string(),
6165 metadata: serde_json::json!({"progress": 50}),
6166 };
6167 let json = serde_json::to_string(&event).unwrap();
6168 assert!(json.contains("subagent_progress"));
6169 assert!(json.contains("processing"));
6170 }
6171
6172 #[test]
6173 fn test_agent_event_serialize_subagent_end() {
6174 let event = AgentEvent::SubagentEnd {
6175 task_id: "task-1".to_string(),
6176 session_id: "child-sess".to_string(),
6177 agent: "explore".to_string(),
6178 output: "Found 10 files".to_string(),
6179 success: true,
6180 };
6181 let json = serde_json::to_string(&event).unwrap();
6182 assert!(json.contains("subagent_end"));
6183 assert!(json.contains("Found 10 files"));
6184 }
6185
6186 #[test]
6187 fn test_agent_event_serialize_planning_start() {
6188 let event = AgentEvent::PlanningStart {
6189 prompt: "Build a web app".to_string(),
6190 };
6191 let json = serde_json::to_string(&event).unwrap();
6192 assert!(json.contains("planning_start"));
6193 assert!(json.contains("Build a web app"));
6194 }
6195
6196 #[test]
6197 fn test_agent_event_serialize_planning_end() {
6198 use crate::planning::{Complexity, ExecutionPlan};
6199 let plan = ExecutionPlan::new("Test goal".to_string(), Complexity::Simple);
6200 let event = AgentEvent::PlanningEnd {
6201 plan,
6202 estimated_steps: 3,
6203 };
6204 let json = serde_json::to_string(&event).unwrap();
6205 assert!(json.contains("planning_end"));
6206 assert!(json.contains("estimated_steps"));
6207 }
6208
6209 #[test]
6210 fn test_agent_event_serialize_step_start() {
6211 let event = AgentEvent::StepStart {
6212 step_id: "step-1".to_string(),
6213 description: "Initialize project".to_string(),
6214 step_number: 1,
6215 total_steps: 5,
6216 };
6217 let json = serde_json::to_string(&event).unwrap();
6218 assert!(json.contains("step_start"));
6219 assert!(json.contains("Initialize project"));
6220 }
6221
6222 #[test]
6223 fn test_agent_event_serialize_step_end() {
6224 let event = AgentEvent::StepEnd {
6225 step_id: "step-1".to_string(),
6226 status: TaskStatus::Completed,
6227 step_number: 1,
6228 total_steps: 5,
6229 };
6230 let json = serde_json::to_string(&event).unwrap();
6231 assert!(json.contains("step_end"));
6232 assert!(json.contains("step-1"));
6233 }
6234
6235 #[test]
6236 fn test_agent_event_serialize_goal_extracted() {
6237 use crate::planning::AgentGoal;
6238 let goal = AgentGoal::new("Complete the task".to_string());
6239 let event = AgentEvent::GoalExtracted { goal };
6240 let json = serde_json::to_string(&event).unwrap();
6241 assert!(json.contains("goal_extracted"));
6242 }
6243
6244 #[test]
6245 fn test_agent_event_serialize_goal_progress() {
6246 let event = AgentEvent::GoalProgress {
6247 goal: "Build app".to_string(),
6248 progress: 0.5,
6249 completed_steps: 2,
6250 total_steps: 4,
6251 };
6252 let json = serde_json::to_string(&event).unwrap();
6253 assert!(json.contains("goal_progress"));
6254 assert!(json.contains("0.5"));
6255 }
6256
6257 #[test]
6258 fn test_agent_event_serialize_goal_achieved() {
6259 let event = AgentEvent::GoalAchieved {
6260 goal: "Build app".to_string(),
6261 total_steps: 4,
6262 duration_ms: 5000,
6263 };
6264 let json = serde_json::to_string(&event).unwrap();
6265 assert!(json.contains("goal_achieved"));
6266 assert!(json.contains("5000"));
6267 }
6268
6269 #[tokio::test]
6270 async fn test_extract_goal_with_json_response() {
6271 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6273 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
6274 )]));
6275 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6276 let agent = AgentLoop::new(
6277 mock_client,
6278 tool_executor,
6279 test_tool_context(),
6280 AgentConfig::default(),
6281 );
6282
6283 let goal = agent.extract_goal("Build a web app").await.unwrap();
6284 assert_eq!(goal.description, "Build web app");
6285 assert_eq!(goal.success_criteria.len(), 2);
6286 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
6287 }
6288
6289 #[tokio::test]
6290 async fn test_extract_goal_fallback_on_non_json() {
6291 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6293 "Some non-JSON response",
6294 )]));
6295 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6296 let agent = AgentLoop::new(
6297 mock_client,
6298 tool_executor,
6299 test_tool_context(),
6300 AgentConfig::default(),
6301 );
6302
6303 let goal = agent.extract_goal("Do something").await.unwrap();
6304 assert_eq!(goal.description, "Do something");
6306 assert_eq!(goal.success_criteria.len(), 2);
6308 }
6309
6310 #[tokio::test]
6311 async fn test_check_goal_achievement_json_yes() {
6312 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6313 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
6314 )]));
6315 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6316 let agent = AgentLoop::new(
6317 mock_client,
6318 tool_executor,
6319 test_tool_context(),
6320 AgentConfig::default(),
6321 );
6322
6323 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
6324 let achieved = agent
6325 .check_goal_achievement(&goal, "All done")
6326 .await
6327 .unwrap();
6328 assert!(achieved);
6329 }
6330
6331 #[tokio::test]
6332 async fn test_check_goal_achievement_fallback_not_done() {
6333 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6335 "invalid json",
6336 )]));
6337 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6338 let agent = AgentLoop::new(
6339 mock_client,
6340 tool_executor,
6341 test_tool_context(),
6342 AgentConfig::default(),
6343 );
6344
6345 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
6346 let achieved = agent
6348 .check_goal_achievement(&goal, "still working")
6349 .await
6350 .unwrap();
6351 assert!(!achieved);
6352 }
6353
6354 #[test]
6359 fn test_build_augmented_system_prompt_empty_context() {
6360 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6361 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6362 let config = AgentConfig {
6363 prompt_slots: SystemPromptSlots {
6364 extra: Some("Base prompt".to_string()),
6365 ..Default::default()
6366 },
6367 ..Default::default()
6368 };
6369 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
6370
6371 let result = agent.build_augmented_system_prompt(&[]);
6372 assert!(result.unwrap().contains("Base prompt"));
6373 }
6374
6375 #[test]
6376 fn test_build_augmented_system_prompt_no_custom_slots() {
6377 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6378 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6379 let agent = AgentLoop::new(
6380 mock_client,
6381 tool_executor,
6382 test_tool_context(),
6383 AgentConfig::default(),
6384 );
6385
6386 let result = agent.build_augmented_system_prompt(&[]);
6387 assert!(result.is_some());
6389 assert!(result.unwrap().contains("Core Behaviour"));
6390 }
6391
6392 #[test]
6393 fn test_build_augmented_system_prompt_with_context_no_base() {
6394 use crate::context::{ContextItem, ContextResult, ContextType};
6395
6396 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6397 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6398 let agent = AgentLoop::new(
6399 mock_client,
6400 tool_executor,
6401 test_tool_context(),
6402 AgentConfig::default(),
6403 );
6404
6405 let context = vec![ContextResult {
6406 provider: "test".to_string(),
6407 items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
6408 total_tokens: 10,
6409 truncated: false,
6410 }];
6411
6412 let result = agent.build_augmented_system_prompt(&context);
6413 assert!(result.is_some());
6414 let text = result.unwrap();
6415 assert!(text.contains("<context"));
6416 assert!(text.contains("Content"));
6417 }
6418
6419 #[test]
6424 fn test_agent_result_clone() {
6425 let result = AgentResult {
6426 text: "output".to_string(),
6427 messages: vec![Message::user("hello")],
6428 usage: TokenUsage::default(),
6429 tool_calls_count: 3,
6430 };
6431 let cloned = result.clone();
6432 assert_eq!(cloned.text, result.text);
6433 assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
6434 }
6435
6436 #[test]
6437 fn test_agent_result_debug() {
6438 let result = AgentResult {
6439 text: "output".to_string(),
6440 messages: vec![Message::user("hello")],
6441 usage: TokenUsage::default(),
6442 tool_calls_count: 3,
6443 };
6444 let debug = format!("{:?}", result);
6445 assert!(debug.contains("AgentResult"));
6446 assert!(debug.contains("output"));
6447 }
6448
6449 #[tokio::test]
6458 async fn test_tool_command_command_type() {
6459 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6460 let cmd = ToolCommand {
6461 tool_executor: executor,
6462 tool_name: "read".to_string(),
6463 tool_args: serde_json::json!({"file": "test.rs"}),
6464 skill_registry: None,
6465 tool_context: test_tool_context(),
6466 };
6467 assert_eq!(cmd.command_type(), "read");
6468 }
6469
6470 #[tokio::test]
6471 async fn test_tool_command_payload() {
6472 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6473 let args = serde_json::json!({"file": "test.rs", "offset": 10});
6474 let cmd = ToolCommand {
6475 tool_executor: executor,
6476 tool_name: "read".to_string(),
6477 tool_args: args.clone(),
6478 skill_registry: None,
6479 tool_context: test_tool_context(),
6480 };
6481 assert_eq!(cmd.payload(), args);
6482 }
6483
6484 #[tokio::test(flavor = "multi_thread")]
6489 async fn test_agent_loop_with_queue() {
6490 use tokio::sync::broadcast;
6491
6492 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6493 "Hello",
6494 )]));
6495 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6496 let config = AgentConfig::default();
6497
6498 let (event_tx, _) = broadcast::channel(100);
6499 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
6500 .await
6501 .unwrap();
6502
6503 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
6504 .with_queue(Arc::new(queue));
6505
6506 assert!(agent.command_queue.is_some());
6507 }
6508
6509 #[tokio::test]
6510 async fn test_agent_loop_without_queue() {
6511 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
6512 "Hello",
6513 )]));
6514 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6515 let config = AgentConfig::default();
6516
6517 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
6518
6519 assert!(agent.command_queue.is_none());
6520 }
6521
6522 #[tokio::test]
6527 async fn test_execute_plan_parallel_independent() {
6528 use crate::planning::{Complexity, ExecutionPlan, Task};
6529
6530 let mock_client = Arc::new(MockLlmClient::new(vec![
6533 MockLlmClient::text_response("Step 1 done"),
6534 MockLlmClient::text_response("Step 2 done"),
6535 MockLlmClient::text_response("Step 3 done"),
6536 ]));
6537
6538 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6539 let config = AgentConfig::default();
6540 let agent = AgentLoop::new(
6541 mock_client.clone(),
6542 tool_executor,
6543 test_tool_context(),
6544 config,
6545 );
6546
6547 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
6548 plan.add_step(Task::new("s1", "First step"));
6549 plan.add_step(Task::new("s2", "Second step"));
6550 plan.add_step(Task::new("s3", "Third step"));
6551
6552 let (tx, mut rx) = mpsc::channel(100);
6553 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
6554
6555 assert_eq!(result.usage.total_tokens, 45);
6557
6558 let mut step_starts = Vec::new();
6560 let mut step_ends = Vec::new();
6561 rx.close();
6562 while let Some(event) = rx.recv().await {
6563 match event {
6564 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
6565 AgentEvent::StepEnd {
6566 step_id, status, ..
6567 } => {
6568 assert_eq!(status, TaskStatus::Completed);
6569 step_ends.push(step_id);
6570 }
6571 _ => {}
6572 }
6573 }
6574 assert_eq!(step_starts.len(), 3);
6575 assert_eq!(step_ends.len(), 3);
6576 }
6577
6578 #[tokio::test]
6579 async fn test_execute_plan_respects_dependencies() {
6580 use crate::planning::{Complexity, ExecutionPlan, Task};
6581
6582 let mock_client = Arc::new(MockLlmClient::new(vec![
6585 MockLlmClient::text_response("Step 1 done"),
6586 MockLlmClient::text_response("Step 2 done"),
6587 MockLlmClient::text_response("Step 3 done"),
6588 ]));
6589
6590 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6591 let config = AgentConfig::default();
6592 let agent = AgentLoop::new(
6593 mock_client.clone(),
6594 tool_executor,
6595 test_tool_context(),
6596 config,
6597 );
6598
6599 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
6600 plan.add_step(Task::new("s1", "Independent A"));
6601 plan.add_step(Task::new("s2", "Independent B"));
6602 plan.add_step(
6603 Task::new("s3", "Depends on A+B")
6604 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
6605 );
6606
6607 let (tx, mut rx) = mpsc::channel(100);
6608 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
6609
6610 assert_eq!(result.usage.total_tokens, 45);
6612
6613 let mut events = Vec::new();
6615 rx.close();
6616 while let Some(event) = rx.recv().await {
6617 match &event {
6618 AgentEvent::StepStart { step_id, .. } => {
6619 events.push(format!("start:{}", step_id));
6620 }
6621 AgentEvent::StepEnd { step_id, .. } => {
6622 events.push(format!("end:{}", step_id));
6623 }
6624 _ => {}
6625 }
6626 }
6627
6628 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
6630 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
6631 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
6632 assert!(
6633 s3_start > s1_end,
6634 "s3 started before s1 ended: {:?}",
6635 events
6636 );
6637 assert!(
6638 s3_start > s2_end,
6639 "s3 started before s2 ended: {:?}",
6640 events
6641 );
6642
6643 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
6645 }
6646
6647 #[tokio::test]
6648 async fn test_execute_plan_handles_step_failure() {
6649 use crate::planning::{Complexity, ExecutionPlan, Task};
6650
6651 let mock_client = Arc::new(MockLlmClient::new(vec![
6661 MockLlmClient::text_response("s1 done"),
6663 MockLlmClient::text_response("s3 done"),
6664 ]));
6667
6668 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6669 let config = AgentConfig::default();
6670 let agent = AgentLoop::new(
6671 mock_client.clone(),
6672 tool_executor,
6673 test_tool_context(),
6674 config,
6675 );
6676
6677 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
6678 plan.add_step(Task::new("s1", "Independent step"));
6679 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
6680 plan.add_step(Task::new("s3", "Another independent"));
6681 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
6682
6683 let (tx, mut rx) = mpsc::channel(100);
6684 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
6685
6686 let mut completed_steps = Vec::new();
6689 let mut failed_steps = Vec::new();
6690 rx.close();
6691 while let Some(event) = rx.recv().await {
6692 if let AgentEvent::StepEnd {
6693 step_id, status, ..
6694 } = event
6695 {
6696 match status {
6697 TaskStatus::Completed => completed_steps.push(step_id),
6698 TaskStatus::Failed => failed_steps.push(step_id),
6699 _ => {}
6700 }
6701 }
6702 }
6703
6704 assert!(
6705 completed_steps.contains(&"s1".to_string()),
6706 "s1 should complete"
6707 );
6708 assert!(
6709 completed_steps.contains(&"s3".to_string()),
6710 "s3 should complete"
6711 );
6712 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
6713 assert!(
6715 !completed_steps.contains(&"s4".to_string()),
6716 "s4 should not complete"
6717 );
6718 assert!(
6719 !failed_steps.contains(&"s4".to_string()),
6720 "s4 should not fail (never started)"
6721 );
6722 }
6723
6724 #[test]
6729 fn test_agent_config_resilience_defaults() {
6730 let config = AgentConfig::default();
6731 assert_eq!(config.max_parse_retries, 2);
6732 assert_eq!(config.tool_timeout_ms, None);
6733 assert_eq!(config.circuit_breaker_threshold, 3);
6734 }
6735
6736 #[tokio::test]
6738 async fn test_parse_error_recovery_bails_after_threshold() {
6739 let mock_client = Arc::new(MockLlmClient::new(vec![
6741 MockLlmClient::tool_call_response(
6742 "c1",
6743 "bash",
6744 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
6745 ),
6746 MockLlmClient::tool_call_response(
6747 "c2",
6748 "bash",
6749 serde_json::json!({"__parse_error": "missing closing brace"}),
6750 ),
6751 MockLlmClient::tool_call_response(
6752 "c3",
6753 "bash",
6754 serde_json::json!({"__parse_error": "still broken"}),
6755 ),
6756 MockLlmClient::text_response("Done"), ]));
6758
6759 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6760 let config = AgentConfig {
6761 max_parse_retries: 2,
6762 ..AgentConfig::default()
6763 };
6764 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
6765 let result = agent.execute(&[], "Do something", None).await;
6766 assert!(result.is_err(), "should bail after parse error threshold");
6767 let err = result.unwrap_err().to_string();
6768 assert!(
6769 err.contains("malformed tool arguments"),
6770 "error should mention malformed tool arguments, got: {}",
6771 err
6772 );
6773 }
6774
6775 #[tokio::test]
6777 async fn test_parse_error_counter_resets_on_success() {
6778 let mock_client = Arc::new(MockLlmClient::new(vec![
6782 MockLlmClient::tool_call_response(
6783 "c1",
6784 "bash",
6785 serde_json::json!({"__parse_error": "bad args"}),
6786 ),
6787 MockLlmClient::tool_call_response(
6788 "c2",
6789 "bash",
6790 serde_json::json!({"__parse_error": "bad args again"}),
6791 ),
6792 MockLlmClient::tool_call_response(
6794 "c3",
6795 "bash",
6796 serde_json::json!({"command": "echo ok"}),
6797 ),
6798 MockLlmClient::text_response("All done"),
6799 ]));
6800
6801 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6802 let config = AgentConfig {
6803 max_parse_retries: 2,
6804 ..AgentConfig::default()
6805 };
6806 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
6807 let result = agent.execute(&[], "Do something", None).await;
6808 assert!(
6809 result.is_ok(),
6810 "should not bail — counter reset after successful tool, got: {:?}",
6811 result.err()
6812 );
6813 assert_eq!(result.unwrap().text, "All done");
6814 }
6815
6816 #[tokio::test]
6818 async fn test_tool_timeout_produces_error_result() {
6819 let mock_client = Arc::new(MockLlmClient::new(vec![
6820 MockLlmClient::tool_call_response(
6821 "t1",
6822 "bash",
6823 serde_json::json!({"command": "sleep 10"}),
6824 ),
6825 MockLlmClient::text_response("The command timed out."),
6826 ]));
6827
6828 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6829 let config = AgentConfig {
6830 tool_timeout_ms: Some(50),
6832 ..AgentConfig::default()
6833 };
6834 let agent = AgentLoop::new(
6835 mock_client.clone(),
6836 tool_executor,
6837 test_tool_context(),
6838 config,
6839 );
6840 let result = agent.execute(&[], "Run sleep", None).await;
6841 assert!(
6842 result.is_ok(),
6843 "session should continue after tool timeout: {:?}",
6844 result.err()
6845 );
6846 assert_eq!(result.unwrap().text, "The command timed out.");
6847 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
6849 }
6850
6851 #[tokio::test]
6853 async fn test_tool_within_timeout_succeeds() {
6854 let mock_client = Arc::new(MockLlmClient::new(vec![
6855 MockLlmClient::tool_call_response(
6856 "t1",
6857 "bash",
6858 serde_json::json!({"command": "echo fast"}),
6859 ),
6860 MockLlmClient::text_response("Command succeeded."),
6861 ]));
6862
6863 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6864 let config = AgentConfig {
6865 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
6867 };
6868 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
6869 let result = agent.execute(&[], "Run something fast", None).await;
6870 assert!(
6871 result.is_ok(),
6872 "fast tool should succeed: {:?}",
6873 result.err()
6874 );
6875 assert_eq!(result.unwrap().text, "Command succeeded.");
6876 }
6877
6878 #[tokio::test]
6880 async fn test_circuit_breaker_retries_non_streaming() {
6881 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6884
6885 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6886 let config = AgentConfig {
6887 circuit_breaker_threshold: 2,
6888 ..AgentConfig::default()
6889 };
6890 let agent = AgentLoop::new(
6891 mock_client.clone(),
6892 tool_executor,
6893 test_tool_context(),
6894 config,
6895 );
6896 let result = agent.execute(&[], "Hello", None).await;
6897 assert!(result.is_err(), "should fail when LLM always errors");
6898 let err = result.unwrap_err().to_string();
6899 assert!(
6900 err.contains("circuit breaker"),
6901 "error should mention circuit breaker, got: {}",
6902 err
6903 );
6904 assert_eq!(
6905 mock_client.call_count.load(Ordering::SeqCst),
6906 2,
6907 "should make exactly threshold=2 LLM calls"
6908 );
6909 }
6910
6911 #[tokio::test]
6913 async fn test_circuit_breaker_threshold_one_no_retry() {
6914 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6915
6916 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6917 let config = AgentConfig {
6918 circuit_breaker_threshold: 1,
6919 ..AgentConfig::default()
6920 };
6921 let agent = AgentLoop::new(
6922 mock_client.clone(),
6923 tool_executor,
6924 test_tool_context(),
6925 config,
6926 );
6927 let result = agent.execute(&[], "Hello", None).await;
6928 assert!(result.is_err());
6929 assert_eq!(
6930 mock_client.call_count.load(Ordering::SeqCst),
6931 1,
6932 "with threshold=1 exactly one attempt should be made"
6933 );
6934 }
6935
6936 #[tokio::test]
6938 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
6939 struct FailOnceThenSucceed {
6941 inner: MockLlmClient,
6942 failed_once: std::sync::atomic::AtomicBool,
6943 call_count: AtomicUsize,
6944 }
6945
6946 #[async_trait::async_trait]
6947 impl LlmClient for FailOnceThenSucceed {
6948 async fn complete(
6949 &self,
6950 messages: &[Message],
6951 system: Option<&str>,
6952 tools: &[ToolDefinition],
6953 ) -> Result<LlmResponse> {
6954 self.call_count.fetch_add(1, Ordering::SeqCst);
6955 let already_failed = self
6956 .failed_once
6957 .swap(true, std::sync::atomic::Ordering::SeqCst);
6958 if !already_failed {
6959 anyhow::bail!("transient network error");
6960 }
6961 self.inner.complete(messages, system, tools).await
6962 }
6963
6964 async fn complete_streaming(
6965 &self,
6966 messages: &[Message],
6967 system: Option<&str>,
6968 tools: &[ToolDefinition],
6969 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
6970 self.inner.complete_streaming(messages, system, tools).await
6971 }
6972 }
6973
6974 let mock = Arc::new(FailOnceThenSucceed {
6975 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
6976 failed_once: std::sync::atomic::AtomicBool::new(false),
6977 call_count: AtomicUsize::new(0),
6978 });
6979
6980 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6981 let config = AgentConfig {
6982 circuit_breaker_threshold: 3,
6983 ..AgentConfig::default()
6984 };
6985 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
6986 let result = agent.execute(&[], "Hello", None).await;
6987 assert!(
6988 result.is_ok(),
6989 "should succeed when LLM recovers within threshold: {:?}",
6990 result.err()
6991 );
6992 assert_eq!(result.unwrap().text, "Recovered!");
6993 assert_eq!(
6994 mock.call_count.load(Ordering::SeqCst),
6995 2,
6996 "should have made exactly 2 calls (1 fail + 1 success)"
6997 );
6998 }
6999
7000 #[test]
7003 fn test_looks_incomplete_empty() {
7004 assert!(AgentLoop::looks_incomplete(""));
7005 assert!(AgentLoop::looks_incomplete(" "));
7006 }
7007
7008 #[test]
7009 fn test_looks_incomplete_trailing_colon() {
7010 assert!(AgentLoop::looks_incomplete("Let me check the file:"));
7011 assert!(AgentLoop::looks_incomplete("Next steps:"));
7012 }
7013
7014 #[test]
7015 fn test_looks_incomplete_ellipsis() {
7016 assert!(AgentLoop::looks_incomplete("Working on it..."));
7017 assert!(AgentLoop::looks_incomplete("Processing…"));
7018 }
7019
7020 #[test]
7021 fn test_looks_incomplete_intent_phrases() {
7022 assert!(AgentLoop::looks_incomplete(
7023 "I'll start by reading the file."
7024 ));
7025 assert!(AgentLoop::looks_incomplete(
7026 "Let me check the configuration."
7027 ));
7028 assert!(AgentLoop::looks_incomplete("I will now run the tests."));
7029 assert!(AgentLoop::looks_incomplete(
7030 "I need to update the Cargo.toml."
7031 ));
7032 }
7033
7034 #[test]
7035 fn test_looks_complete_final_answer() {
7036 assert!(!AgentLoop::looks_incomplete(
7038 "The tests pass. All changes have been applied successfully."
7039 ));
7040 assert!(!AgentLoop::looks_incomplete(
7041 "Done. I've updated the three files and verified the build succeeds."
7042 ));
7043 assert!(!AgentLoop::looks_incomplete("42"));
7044 assert!(!AgentLoop::looks_incomplete("Yes."));
7045 }
7046
7047 #[test]
7048 fn test_looks_incomplete_multiline_complete() {
7049 let text = "Here is the summary:\n\n- Fixed the bug in agent.rs\n- All tests pass\n- Build succeeds";
7050 assert!(!AgentLoop::looks_incomplete(text));
7051 }
7052}