1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
/// Default cap on tool rounds; `AgentConfig::default()` uses it as the initial
/// value of `max_tool_rounds`.
const MAX_TOOL_ROUNDS: usize = 50;
38
/// Configuration for an `AgentLoop`.
///
/// Pluggable collaborators are stored as `Option<Arc<dyn ...>>`, so cloning the
/// config shares them instead of copying them. Several fields are consumed by
/// code outside the portion of the file reviewed here; their comments are
/// marked as assumptions where the usage could not be confirmed.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots from which the base system prompt is built (`prompt_slots.build()`).
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions offered to the LLM; filtered through `tool_index` when one is set.
    pub tools: Vec<ToolDefinition>,
    /// Cap on tool rounds per execution. Defaults to `MAX_TOOL_ROUNDS`.
    pub max_tool_rounds: usize,
    /// Optional security provider. NOTE(review): consumer not visible here.
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// Optional permission checker. NOTE(review): consumer not visible here.
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// Optional human-in-the-loop confirmation provider. NOTE(review): consumer not visible here.
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Context providers queried concurrently before a run and notified after each turn.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When `true`, `execute_with_session` routes through the planning path.
    pub planning_enabled: bool,
    /// Presumably enables goal extraction/progress events — TODO confirm against the loop body.
    pub goal_tracking: bool,
    /// Hook executor that receives lifecycle events (pre/post tool use, generate, errors, ...).
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Skill registry handed to queued tool commands for allowed-tool checks.
    /// Defaults to a registry created with `SkillRegistry::with_builtins()`.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// Presumably the number of retries after malformed tool arguments — TODO confirm. Default: 2.
    pub max_parse_retries: u32,
    /// Per-tool execution timeout in milliseconds; `None` means no timeout is applied.
    pub tool_timeout_ms: Option<u64>,
    /// Presumably the failure count that trips a circuit breaker — TODO confirm. Default: 3.
    pub circuit_breaker_threshold: u32,
    /// Presumably enables automatic context compaction — TODO confirm. Default: `false`.
    pub auto_compact: bool,
    /// Presumably the context-usage fraction that triggers compaction — TODO confirm. Default: 0.80.
    pub auto_compact_threshold: f32,
    /// Presumably the model context window size in tokens — TODO confirm. Default: 200_000.
    pub max_context_tokens: usize,
    /// Optional LLM client carried in the config. NOTE(review): distinct from the client
    /// passed to `AgentLoop::new`; its consumer is not visible here.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional agent memory. NOTE(review): consumer not visible here.
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// Presumably lets the loop nudge the model to continue after an incomplete-looking
    /// reply — TODO confirm. Default: `true`.
    pub continuation_enabled: bool,
    /// Presumably the maximum number of such continuation nudges — TODO confirm. Default: 3.
    pub max_continuation_turns: u32,
    /// Optional tool index; when set, `call_llm` only offers tools whose names the index
    /// returns for the latest user message.
    pub tool_index: Option<ToolIndex>,
}
115
116impl std::fmt::Debug for AgentConfig {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("AgentConfig")
119 .field("prompt_slots", &self.prompt_slots)
120 .field("tools", &self.tools)
121 .field("max_tool_rounds", &self.max_tool_rounds)
122 .field("security_provider", &self.security_provider.is_some())
123 .field("permission_checker", &self.permission_checker.is_some())
124 .field("confirmation_manager", &self.confirmation_manager.is_some())
125 .field("context_providers", &self.context_providers.len())
126 .field("planning_enabled", &self.planning_enabled)
127 .field("goal_tracking", &self.goal_tracking)
128 .field("hook_engine", &self.hook_engine.is_some())
129 .field(
130 "skill_registry",
131 &self.skill_registry.as_ref().map(|r| r.len()),
132 )
133 .field("max_parse_retries", &self.max_parse_retries)
134 .field("tool_timeout_ms", &self.tool_timeout_ms)
135 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
136 .field("auto_compact", &self.auto_compact)
137 .field("auto_compact_threshold", &self.auto_compact_threshold)
138 .field("max_context_tokens", &self.max_context_tokens)
139 .field("continuation_enabled", &self.continuation_enabled)
140 .field("max_continuation_turns", &self.max_continuation_turns)
141 .field("memory", &self.memory.is_some())
142 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
143 .finish()
144 }
145}
146
147impl Default for AgentConfig {
148 fn default() -> Self {
149 Self {
150 prompt_slots: SystemPromptSlots::default(),
151 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
153 security_provider: None,
154 permission_checker: None,
155 confirmation_manager: None,
156 context_providers: Vec::new(),
157 planning_enabled: false,
158 goal_tracking: false,
159 hook_engine: None,
160 skill_registry: Some(Arc::new(crate::skills::SkillRegistry::with_builtins())),
161 max_parse_retries: 2,
162 tool_timeout_ms: None,
163 circuit_breaker_threshold: 3,
164 auto_compact: false,
165 auto_compact_threshold: 0.80,
166 max_context_tokens: 200_000,
167 llm_client: None,
168 memory: None,
169 continuation_enabled: true,
170 max_continuation_turns: 3,
171 tool_index: None,
172 }
173 }
174}
175
/// Events emitted while an agent run is in progress.
///
/// The enum is internally tagged: every variant serializes as a JSON object
/// whose `"type"` field holds the `rename` string shown on the variant, with
/// the variant's fields alongside it. It is `#[non_exhaustive]`, so consumers
/// outside this crate must keep a wildcard arm when matching.
///
/// NOTE(review): only `Start`, `TextDelta`, `ToolStart`, `ToolInputDelta` and
/// `ToolOutputDelta` are emitted in the portion of the file reviewed here. The
/// descriptions of the remaining variants are derived from their names and
/// fields and should be confirmed against the emitting code.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// `agent_start` — sent at the start of a run with the effective prompt.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// `turn_start` — a turn with the given index begins.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// `text_delta` — a chunk of streamed LLM text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// `tool_start` — the LLM stream started a tool call with this id and name.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// `tool_input_delta` — a chunk of streamed tool-call input.
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    /// `tool_end` — a tool call finished; `metadata` is omitted from the
    /// serialized form when it is `None`.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },

    /// `tool_output_delta` — a chunk of output streamed by a running tool.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// `turn_end` — a turn finished, with its token usage.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// `agent_end` — the run finished, with final text and token usage.
    #[serde(rename = "agent_end")]
    End { text: String, usage: TokenUsage },

    /// `error` — an error message for the consumer.
    #[serde(rename = "error")]
    Error { message: String },

    /// `confirmation_required` — a tool call is waiting for confirmation.
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    /// `confirmation_received` — a confirmation decision arrived.
    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    /// `confirmation_timeout` — no decision arrived in time; `action_taken`
    /// names what was done instead.
    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    /// `external_task_pending` — a task is waiting on an external executor.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    /// `external_task_completed` — an external task finished.
    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// `permission_denied` — a tool call was refused, with the reason.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// `context_resolving` — context resolution started for the named providers.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// `context_resolved` — context resolution finished, with totals.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    /// `command_dead_lettered` — a queued command was dead-lettered after `attempts` tries.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    /// `command_retry` — a queued command will be retried after `delay_ms`.
    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    /// `queue_alert` — an alert raised by the command queue.
    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// `task_updated` — the task list of a session changed.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// `memory_stored` — a memory item was stored.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// `memory_recalled` — a memory item was recalled.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// `memories_searched` — a memory search ran, with its result count.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// `memory_cleared` — `count` items were cleared from a memory tier.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    /// `subagent_start` — a sub-agent task started.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    /// `subagent_progress` — progress report from a sub-agent task.
    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    /// `subagent_end` — a sub-agent task finished.
    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    /// `planning_start` — planning began for the given prompt.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    /// `planning_end` — planning produced an execution plan.
    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    /// `step_start` — a plan step began.
    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    /// `step_end` — a plan step ended with the given status.
    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    /// `goal_extracted` — a goal was extracted.
    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    /// `goal_progress` — progress toward a goal.
    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    /// `goal_achieved` — a goal was achieved.
    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// `context_compacted` — a session's message history was compacted.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// `persistence_failed` — a persistence operation failed for a session.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },
}
497
/// Outcome of one agent execution.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Response text of the run.
    pub text: String,
    /// Messages associated with the run. NOTE(review): presumably the full
    /// conversation including the input history — confirm against the loop body.
    pub messages: Vec<Message>,
    /// Token usage reported for the run.
    pub usage: TokenUsage,
    /// Number of tool calls counted during the run.
    pub tool_calls_count: usize,
}
506
/// A single tool invocation packaged as a `SessionCommand` so it can be
/// submitted to a session lane queue.
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Name of the tool to run; also reported as the command type.
    tool_name: String,
    /// Arguments passed to the tool; also reported as the command payload.
    tool_args: Value,
    /// Context handed to the executor alongside the arguments.
    tool_context: ToolContext,
    /// Optional skill registry consulted for allowed-tool restrictions before execution.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
521
522impl ToolCommand {
523 pub fn new(
525 tool_executor: Arc<ToolExecutor>,
526 tool_name: String,
527 tool_args: Value,
528 tool_context: ToolContext,
529 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
530 ) -> Self {
531 Self {
532 tool_executor,
533 tool_name,
534 tool_args,
535 tool_context,
536 skill_registry,
537 }
538 }
539}
540
541#[async_trait]
542impl SessionCommand for ToolCommand {
543 async fn execute(&self) -> Result<Value> {
544 if let Some(registry) = &self.skill_registry {
546 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
547
548 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
550
551 if has_restrictions {
552 let mut allowed = false;
553
554 for skill in &instruction_skills {
555 if skill.is_tool_allowed(&self.tool_name) {
556 allowed = true;
557 break;
558 }
559 }
560
561 if !allowed {
562 return Err(anyhow::anyhow!(
563 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
564 self.tool_name
565 ));
566 }
567 }
568 }
569
570 let result = self
572 .tool_executor
573 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
574 .await?;
575 Ok(serde_json::json!({
576 "output": result.output,
577 "exit_code": result.exit_code,
578 "metadata": result.metadata,
579 }))
580 }
581
582 fn command_type(&self) -> &str {
583 &self.tool_name
584 }
585
586 fn payload(&self) -> Value {
587 self.tool_args.clone()
588 }
589}
590
/// Drives the LLM/tool interaction loop for one agent.
///
/// Cloning is cheap for the shared parts: the client, executor, metrics and
/// queue are all behind `Arc`.
#[derive(Clone)]
pub struct AgentLoop {
    /// Client used for both streaming and non-streaming completions.
    llm_client: Arc<dyn LlmClient>,
    /// Executor that runs tools and exposes their live definitions.
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context (includes the workspace path); cloned per streaming tool call.
    tool_context: ToolContext,
    /// Loop configuration.
    config: AgentConfig,
    /// Optional shared tool metrics, attached via `with_tool_metrics`.
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// Optional lane queue, attached via `with_queue`; when set, tool calls are
    /// submitted to it before falling back to direct execution.
    command_queue: Option<Arc<SessionLaneQueue>>,
}
607
608impl AgentLoop {
609 pub fn new(
610 llm_client: Arc<dyn LlmClient>,
611 tool_executor: Arc<ToolExecutor>,
612 tool_context: ToolContext,
613 config: AgentConfig,
614 ) -> Self {
615 Self {
616 llm_client,
617 tool_executor,
618 tool_context,
619 config,
620 tool_metrics: None,
621 command_queue: None,
622 }
623 }
624
625 pub fn with_tool_metrics(
627 mut self,
628 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
629 ) -> Self {
630 self.tool_metrics = Some(metrics);
631 self
632 }
633
634 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
639 self.command_queue = Some(queue);
640 self
641 }
642
643 async fn execute_tool_timed(
649 &self,
650 name: &str,
651 args: &serde_json::Value,
652 ctx: &ToolContext,
653 ) -> anyhow::Result<crate::tools::ToolResult> {
654 let fut = self.tool_executor.execute_with_context(name, args, ctx);
655 if let Some(timeout_ms) = self.config.tool_timeout_ms {
656 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
657 Ok(result) => result,
658 Err(_) => Err(anyhow::anyhow!(
659 "Tool '{}' timed out after {}ms",
660 name,
661 timeout_ms
662 )),
663 }
664 } else {
665 fut.await
666 }
667 }
668
669 fn tool_result_to_tuple(
671 result: anyhow::Result<crate::tools::ToolResult>,
672 ) -> (
673 String,
674 i32,
675 bool,
676 Option<serde_json::Value>,
677 Vec<crate::llm::Attachment>,
678 ) {
679 match result {
680 Ok(r) => (
681 r.output,
682 r.exit_code,
683 r.exit_code != 0,
684 r.metadata,
685 r.images,
686 ),
687 Err(e) => {
688 let msg = e.to_string();
689 let hint = if Self::is_transient_error(&msg) {
691 " [transient — you may retry this tool call]"
692 } else {
693 " [permanent — do not retry without changing the arguments]"
694 };
695 (
696 format!("Tool execution error: {}{}", msg, hint),
697 1,
698 true,
699 None,
700 Vec::new(),
701 )
702 }
703 }
704 }
705
706 fn detect_project_hint(workspace: &std::path::Path) -> String {
710 struct Marker {
711 file: &'static str,
712 lang: &'static str,
713 tip: &'static str,
714 }
715
716 let markers = [
717 Marker {
718 file: "Cargo.toml",
719 lang: "Rust",
720 tip: "Use `cargo build`, `cargo test`, `cargo clippy`, and `cargo fmt`. \
721 Prefer `anyhow` / `thiserror` for error handling. \
722 Follow the Microsoft Rust Guidelines (no panics in library code, \
723 async-first with Tokio).",
724 },
725 Marker {
726 file: "package.json",
727 lang: "Node.js / TypeScript",
728 tip: "Check `package.json` for the package manager (npm/yarn/pnpm/bun) \
729 and available scripts. Prefer TypeScript with strict mode. \
730 Use ESM imports unless the project is CommonJS.",
731 },
732 Marker {
733 file: "pyproject.toml",
734 lang: "Python",
735 tip: "Use the package manager declared in `pyproject.toml` \
736 (uv, poetry, hatch, etc.). Prefer type hints and async/await for I/O.",
737 },
738 Marker {
739 file: "setup.py",
740 lang: "Python",
741 tip: "Legacy Python project. Prefer type hints and async/await for I/O.",
742 },
743 Marker {
744 file: "requirements.txt",
745 lang: "Python",
746 tip: "Python project with pip-style dependencies. \
747 Prefer type hints and async/await for I/O.",
748 },
749 Marker {
750 file: "go.mod",
751 lang: "Go",
752 tip: "Use `go build ./...` and `go test ./...`. \
753 Follow standard Go project layout. Use `gofmt` for formatting.",
754 },
755 Marker {
756 file: "pom.xml",
757 lang: "Java / Maven",
758 tip: "Use `mvn compile`, `mvn test`, `mvn package`. \
759 Follow standard Maven project structure.",
760 },
761 Marker {
762 file: "build.gradle",
763 lang: "Java / Gradle",
764 tip: "Use `./gradlew build` and `./gradlew test`. \
765 Follow standard Gradle project structure.",
766 },
767 Marker {
768 file: "build.gradle.kts",
769 lang: "Kotlin / Gradle",
770 tip: "Use `./gradlew build` and `./gradlew test`. \
771 Prefer Kotlin coroutines for async work.",
772 },
773 Marker {
774 file: "CMakeLists.txt",
775 lang: "C / C++",
776 tip: "Use `cmake -B build && cmake --build build`. \
777 Check for `compile_commands.json` for IDE tooling.",
778 },
779 Marker {
780 file: "Makefile",
781 lang: "C / C++ (or generic)",
782 tip: "Use `make` or `make <target>`. \
783 Check available targets with `make help` or by reading the Makefile.",
784 },
785 ];
786
787 let is_dotnet = workspace.join("*.csproj").exists() || {
789 std::fs::read_dir(workspace)
791 .map(|entries| {
792 entries.flatten().any(|e| {
793 let name = e.file_name();
794 let s = name.to_string_lossy();
795 s.ends_with(".csproj") || s.ends_with(".sln")
796 })
797 })
798 .unwrap_or(false)
799 };
800
801 if is_dotnet {
802 return "## Project Context\n\nThis is a **C# / .NET** project. \
803 Use `dotnet build`, `dotnet test`, and `dotnet run`. \
804 Follow C# coding conventions and async/await patterns."
805 .to_string();
806 }
807
808 for marker in &markers {
809 if workspace.join(marker.file).exists() {
810 return format!(
811 "## Project Context\n\nThis is a **{}** project. {}",
812 marker.lang, marker.tip
813 );
814 }
815 }
816
817 String::new()
818 }
819
820 fn is_transient_error(msg: &str) -> bool {
823 let lower = msg.to_lowercase();
824 lower.contains("timeout")
825 || lower.contains("timed out")
826 || lower.contains("connection refused")
827 || lower.contains("connection reset")
828 || lower.contains("broken pipe")
829 || lower.contains("temporarily unavailable")
830 || lower.contains("resource temporarily unavailable")
831 || lower.contains("os error 11") || lower.contains("os error 35") || lower.contains("rate limit")
834 || lower.contains("too many requests")
835 || lower.contains("service unavailable")
836 || lower.contains("network unreachable")
837 }
838
839 fn is_parallel_safe_write(name: &str, _args: &serde_json::Value) -> bool {
842 matches!(
843 name,
844 "write_file" | "edit_file" | "create_file" | "append_to_file" | "replace_in_file"
845 )
846 }
847
848 fn extract_write_path(args: &serde_json::Value) -> Option<String> {
850 args.get("path")
853 .and_then(|v| v.as_str())
854 .map(|s| s.to_string())
855 }
856
857 async fn execute_tool_queued_or_direct(
859 &self,
860 name: &str,
861 args: &serde_json::Value,
862 ctx: &ToolContext,
863 ) -> anyhow::Result<crate::tools::ToolResult> {
864 if let Some(ref queue) = self.command_queue {
865 let command = ToolCommand::new(
866 Arc::clone(&self.tool_executor),
867 name.to_string(),
868 args.clone(),
869 ctx.clone(),
870 self.config.skill_registry.clone(),
871 );
872 let rx = queue.submit_by_tool(name, Box::new(command)).await;
873 match rx.await {
874 Ok(Ok(value)) => {
875 let output = value["output"]
876 .as_str()
877 .ok_or_else(|| {
878 anyhow::anyhow!(
879 "Queue result missing 'output' field for tool '{}'",
880 name
881 )
882 })?
883 .to_string();
884 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
885 return Ok(crate::tools::ToolResult {
886 name: name.to_string(),
887 output,
888 exit_code,
889 metadata: None,
890 images: Vec::new(),
891 });
892 }
893 Ok(Err(e)) => {
894 tracing::warn!(
895 "Queue execution failed for tool '{}', falling back to direct: {}",
896 name,
897 e
898 );
899 }
900 Err(_) => {
901 tracing::warn!(
902 "Queue channel closed for tool '{}', falling back to direct",
903 name
904 );
905 }
906 }
907 }
908 self.execute_tool_timed(name, args, ctx).await
909 }
910
911 async fn call_llm(
922 &self,
923 messages: &[Message],
924 system: Option<&str>,
925 event_tx: &Option<mpsc::Sender<AgentEvent>>,
926 ) -> anyhow::Result<LlmResponse> {
927 let tools = if let Some(ref index) = self.config.tool_index {
929 let query = messages
930 .iter()
931 .rev()
932 .find(|m| m.role == "user")
933 .and_then(|m| {
934 m.content.iter().find_map(|b| match b {
935 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
936 _ => None,
937 })
938 })
939 .unwrap_or("");
940 let matches = index.search(query, index.len());
941 let matched_names: std::collections::HashSet<&str> =
942 matches.iter().map(|m| m.name.as_str()).collect();
943 self.config
944 .tools
945 .iter()
946 .filter(|t| matched_names.contains(t.name.as_str()))
947 .cloned()
948 .collect::<Vec<_>>()
949 } else {
950 self.config.tools.clone()
951 };
952
953 if event_tx.is_some() {
954 let mut stream_rx = self
955 .llm_client
956 .complete_streaming(messages, system, &tools)
957 .await
958 .context("LLM streaming call failed")?;
959
960 let mut final_response: Option<LlmResponse> = None;
961 while let Some(event) = stream_rx.recv().await {
962 match event {
963 crate::llm::StreamEvent::TextDelta(text) => {
964 if let Some(tx) = event_tx {
965 tx.send(AgentEvent::TextDelta { text }).await.ok();
966 }
967 }
968 crate::llm::StreamEvent::ToolUseStart { id, name } => {
969 if let Some(tx) = event_tx {
970 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
971 }
972 }
973 crate::llm::StreamEvent::ToolUseInputDelta(delta) => {
974 if let Some(tx) = event_tx {
975 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
976 }
977 }
978 crate::llm::StreamEvent::Done(resp) => {
979 final_response = Some(resp);
980 }
981 }
982 }
983 final_response.context("Stream ended without final response")
984 } else {
985 self.llm_client
986 .complete(messages, system, &tools)
987 .await
988 .context("LLM call failed")
989 }
990 }
991
992 fn streaming_tool_context(
1001 &self,
1002 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1003 tool_id: &str,
1004 tool_name: &str,
1005 ) -> ToolContext {
1006 let mut ctx = self.tool_context.clone();
1007 if let Some(agent_tx) = event_tx {
1008 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
1009 ctx.event_tx = Some(tool_tx);
1010
1011 let agent_tx = agent_tx.clone();
1012 let tool_id = tool_id.to_string();
1013 let tool_name = tool_name.to_string();
1014 tokio::spawn(async move {
1015 while let Some(event) = tool_rx.recv().await {
1016 match event {
1017 ToolStreamEvent::OutputDelta(delta) => {
1018 agent_tx
1019 .send(AgentEvent::ToolOutputDelta {
1020 id: tool_id.clone(),
1021 name: tool_name.clone(),
1022 delta,
1023 })
1024 .await
1025 .ok();
1026 }
1027 }
1028 }
1029 });
1030 }
1031 ctx
1032 }
1033
1034 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
1038 if self.config.context_providers.is_empty() {
1039 return Vec::new();
1040 }
1041
1042 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
1043
1044 let futures = self
1045 .config
1046 .context_providers
1047 .iter()
1048 .map(|p| p.query(&query));
1049 let outcomes = join_all(futures).await;
1050
1051 outcomes
1052 .into_iter()
1053 .enumerate()
1054 .filter_map(|(i, r)| match r {
1055 Ok(result) if !result.is_empty() => Some(result),
1056 Ok(_) => None,
1057 Err(e) => {
1058 tracing::warn!(
1059 "Context provider '{}' failed: {}",
1060 self.config.context_providers[i].name(),
1061 e
1062 );
1063 None
1064 }
1065 })
1066 .collect()
1067 }
1068
1069 fn looks_incomplete(text: &str) -> bool {
1077 let t = text.trim();
1078 if t.is_empty() {
1079 return true;
1080 }
1081 if t.len() < 80 && !t.contains('\n') {
1083 let ends_continuation =
1086 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
1087 if ends_continuation {
1088 return true;
1089 }
1090 }
1091 let incomplete_phrases = [
1093 "i'll ",
1094 "i will ",
1095 "let me ",
1096 "i need to ",
1097 "i should ",
1098 "next, i",
1099 "first, i",
1100 "now i",
1101 "i'll start",
1102 "i'll begin",
1103 "i'll now",
1104 "let's start",
1105 "let's begin",
1106 "to do this",
1107 "i'm going to",
1108 ];
1109 let lower = t.to_lowercase();
1110 for phrase in &incomplete_phrases {
1111 if lower.contains(phrase) {
1112 return true;
1113 }
1114 }
1115 false
1116 }
1117
1118 fn system_prompt(&self) -> String {
1120 self.config.prompt_slots.build()
1121 }
1122
1123 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
1125 let base = self.system_prompt();
1126
1127 let live_tools = self.tool_executor.definitions();
1129 let mcp_tools: Vec<&ToolDefinition> = live_tools
1130 .iter()
1131 .filter(|t| t.name.starts_with("mcp__"))
1132 .collect();
1133
1134 let mcp_section = if mcp_tools.is_empty() {
1135 String::new()
1136 } else {
1137 let mut lines = vec![
1138 "## MCP Tools".to_string(),
1139 String::new(),
1140 "The following MCP (Model Context Protocol) tools are available. Use them when the task requires external capabilities beyond the built-in tools:".to_string(),
1141 String::new(),
1142 ];
1143 for tool in &mcp_tools {
1144 let display = format!("- `{}` — {}", tool.name, tool.description);
1145 lines.push(display);
1146 }
1147 lines.join("\n")
1148 };
1149
1150 let parts: Vec<&str> = [base.as_str(), mcp_section.as_str()]
1151 .iter()
1152 .filter(|s| !s.is_empty())
1153 .copied()
1154 .collect();
1155
1156 let project_hint = if self.config.prompt_slots.guidelines.is_none() {
1159 Self::detect_project_hint(&self.tool_context.workspace)
1160 } else {
1161 String::new()
1162 };
1163
1164 if context_results.is_empty() {
1165 if project_hint.is_empty() {
1166 return Some(parts.join("\n\n"));
1167 }
1168 return Some(format!("{}\n\n{}", parts.join("\n\n"), project_hint));
1169 }
1170
1171 let context_xml: String = context_results
1173 .iter()
1174 .map(|r| r.to_xml())
1175 .collect::<Vec<_>>()
1176 .join("\n\n");
1177
1178 if project_hint.is_empty() {
1179 Some(format!("{}\n\n{}", parts.join("\n\n"), context_xml))
1180 } else {
1181 Some(format!(
1182 "{}\n\n{}\n\n{}",
1183 parts.join("\n\n"),
1184 project_hint,
1185 context_xml
1186 ))
1187 }
1188 }
1189
1190 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
1192 let futures = self
1193 .config
1194 .context_providers
1195 .iter()
1196 .map(|p| p.on_turn_complete(session_id, prompt, response));
1197 let outcomes = join_all(futures).await;
1198
1199 for (i, result) in outcomes.into_iter().enumerate() {
1200 if let Err(e) = result {
1201 tracing::warn!(
1202 "Context provider '{}' on_turn_complete failed: {}",
1203 self.config.context_providers[i].name(),
1204 e
1205 );
1206 }
1207 }
1208 }
1209
1210 async fn fire_pre_tool_use(
1213 &self,
1214 session_id: &str,
1215 tool_name: &str,
1216 args: &serde_json::Value,
1217 ) -> Option<HookResult> {
1218 if let Some(he) = &self.config.hook_engine {
1219 let event = HookEvent::PreToolUse(PreToolUseEvent {
1220 session_id: session_id.to_string(),
1221 tool: tool_name.to_string(),
1222 args: args.clone(),
1223 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1224 recent_tools: Vec::new(),
1225 });
1226 let result = he.fire(&event).await;
1227 if result.is_block() {
1228 return Some(result);
1229 }
1230 }
1231 None
1232 }
1233
1234 async fn fire_post_tool_use(
1236 &self,
1237 session_id: &str,
1238 tool_name: &str,
1239 args: &serde_json::Value,
1240 output: &str,
1241 success: bool,
1242 duration_ms: u64,
1243 ) {
1244 if let Some(he) = &self.config.hook_engine {
1245 let event = HookEvent::PostToolUse(PostToolUseEvent {
1246 session_id: session_id.to_string(),
1247 tool: tool_name.to_string(),
1248 args: args.clone(),
1249 result: ToolResultData {
1250 success,
1251 output: output.to_string(),
1252 exit_code: if success { Some(0) } else { Some(1) },
1253 duration_ms,
1254 },
1255 });
1256 let he = Arc::clone(he);
1257 tokio::spawn(async move {
1258 let _ = he.fire(&event).await;
1259 });
1260 }
1261 }
1262
1263 async fn fire_generate_start(
1265 &self,
1266 session_id: &str,
1267 prompt: &str,
1268 system_prompt: &Option<String>,
1269 ) {
1270 if let Some(he) = &self.config.hook_engine {
1271 let event = HookEvent::GenerateStart(GenerateStartEvent {
1272 session_id: session_id.to_string(),
1273 prompt: prompt.to_string(),
1274 system_prompt: system_prompt.clone(),
1275 model_provider: String::new(),
1276 model_name: String::new(),
1277 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1278 });
1279 let _ = he.fire(&event).await;
1280 }
1281 }
1282
1283 async fn fire_generate_end(
1285 &self,
1286 session_id: &str,
1287 prompt: &str,
1288 response: &LlmResponse,
1289 duration_ms: u64,
1290 ) {
1291 if let Some(he) = &self.config.hook_engine {
1292 let tool_calls: Vec<ToolCallInfo> = response
1293 .tool_calls()
1294 .iter()
1295 .map(|tc| ToolCallInfo {
1296 name: tc.name.clone(),
1297 args: tc.args.clone(),
1298 })
1299 .collect();
1300
1301 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1302 session_id: session_id.to_string(),
1303 prompt: prompt.to_string(),
1304 response_text: response.text().to_string(),
1305 tool_calls,
1306 usage: TokenUsageInfo {
1307 prompt_tokens: response.usage.prompt_tokens as i32,
1308 completion_tokens: response.usage.completion_tokens as i32,
1309 total_tokens: response.usage.total_tokens as i32,
1310 },
1311 duration_ms,
1312 });
1313 let _ = he.fire(&event).await;
1314 }
1315 }
1316
1317 async fn fire_pre_prompt(
1320 &self,
1321 session_id: &str,
1322 prompt: &str,
1323 system_prompt: &Option<String>,
1324 message_count: usize,
1325 ) -> Option<String> {
1326 if let Some(he) = &self.config.hook_engine {
1327 let event = HookEvent::PrePrompt(PrePromptEvent {
1328 session_id: session_id.to_string(),
1329 prompt: prompt.to_string(),
1330 system_prompt: system_prompt.clone(),
1331 message_count,
1332 });
1333 let result = he.fire(&event).await;
1334 if let HookResult::Continue(Some(modified)) = result {
1335 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1337 return Some(new_prompt.to_string());
1338 }
1339 }
1340 }
1341 None
1342 }
1343
1344 async fn fire_post_response(
1346 &self,
1347 session_id: &str,
1348 response_text: &str,
1349 tool_calls_count: usize,
1350 usage: &TokenUsage,
1351 duration_ms: u64,
1352 ) {
1353 if let Some(he) = &self.config.hook_engine {
1354 let event = HookEvent::PostResponse(PostResponseEvent {
1355 session_id: session_id.to_string(),
1356 response_text: response_text.to_string(),
1357 tool_calls_count,
1358 usage: TokenUsageInfo {
1359 prompt_tokens: usage.prompt_tokens as i32,
1360 completion_tokens: usage.completion_tokens as i32,
1361 total_tokens: usage.total_tokens as i32,
1362 },
1363 duration_ms,
1364 });
1365 let he = Arc::clone(he);
1366 tokio::spawn(async move {
1367 let _ = he.fire(&event).await;
1368 });
1369 }
1370 }
1371
1372 async fn fire_on_error(
1374 &self,
1375 session_id: &str,
1376 error_type: ErrorType,
1377 error_message: &str,
1378 context: serde_json::Value,
1379 ) {
1380 if let Some(he) = &self.config.hook_engine {
1381 let event = HookEvent::OnError(OnErrorEvent {
1382 session_id: session_id.to_string(),
1383 error_type,
1384 error_message: error_message.to_string(),
1385 context,
1386 });
1387 let he = Arc::clone(he);
1388 tokio::spawn(async move {
1389 let _ = he.fire(&event).await;
1390 });
1391 }
1392 }
1393
1394 pub async fn execute(
1400 &self,
1401 history: &[Message],
1402 prompt: &str,
1403 event_tx: Option<mpsc::Sender<AgentEvent>>,
1404 ) -> Result<AgentResult> {
1405 self.execute_with_session(history, prompt, None, event_tx)
1406 .await
1407 }
1408
1409 pub async fn execute_from_messages(
1415 &self,
1416 messages: Vec<Message>,
1417 session_id: Option<&str>,
1418 event_tx: Option<mpsc::Sender<AgentEvent>>,
1419 ) -> Result<AgentResult> {
1420 tracing::info!(
1421 a3s.session.id = session_id.unwrap_or("none"),
1422 a3s.agent.max_turns = self.config.max_tool_rounds,
1423 "a3s.agent.execute_from_messages started"
1424 );
1425
1426 let effective_prompt = messages
1430 .iter()
1431 .rev()
1432 .find(|m| m.role == "user")
1433 .map(|m| m.text())
1434 .unwrap_or_default();
1435
1436 let result = self
1437 .execute_loop_inner(&messages, "", &effective_prompt, session_id, event_tx)
1438 .await;
1439
1440 match &result {
1441 Ok(r) => tracing::info!(
1442 a3s.agent.tool_calls_count = r.tool_calls_count,
1443 a3s.llm.total_tokens = r.usage.total_tokens,
1444 "a3s.agent.execute_from_messages completed"
1445 ),
1446 Err(e) => tracing::warn!(
1447 error = %e,
1448 "a3s.agent.execute_from_messages failed"
1449 ),
1450 }
1451
1452 result
1453 }
1454
1455 pub async fn execute_with_session(
1460 &self,
1461 history: &[Message],
1462 prompt: &str,
1463 session_id: Option<&str>,
1464 event_tx: Option<mpsc::Sender<AgentEvent>>,
1465 ) -> Result<AgentResult> {
1466 tracing::info!(
1467 a3s.session.id = session_id.unwrap_or("none"),
1468 a3s.agent.max_turns = self.config.max_tool_rounds,
1469 "a3s.agent.execute started"
1470 );
1471
1472 let result = if self.config.planning_enabled {
1474 self.execute_with_planning(history, prompt, event_tx).await
1475 } else {
1476 self.execute_loop(history, prompt, session_id, event_tx)
1477 .await
1478 };
1479
1480 match &result {
1481 Ok(r) => {
1482 tracing::info!(
1483 a3s.agent.tool_calls_count = r.tool_calls_count,
1484 a3s.llm.total_tokens = r.usage.total_tokens,
1485 "a3s.agent.execute completed"
1486 );
1487 self.fire_post_response(
1489 session_id.unwrap_or(""),
1490 &r.text,
1491 r.tool_calls_count,
1492 &r.usage,
1493 0, )
1495 .await;
1496 }
1497 Err(e) => {
1498 tracing::warn!(
1499 error = %e,
1500 "a3s.agent.execute failed"
1501 );
1502 self.fire_on_error(
1504 session_id.unwrap_or(""),
1505 ErrorType::Other,
1506 &e.to_string(),
1507 serde_json::json!({"phase": "execute"}),
1508 )
1509 .await;
1510 }
1511 }
1512
1513 result
1514 }
1515
1516 async fn execute_loop(
1522 &self,
1523 history: &[Message],
1524 prompt: &str,
1525 session_id: Option<&str>,
1526 event_tx: Option<mpsc::Sender<AgentEvent>>,
1527 ) -> Result<AgentResult> {
1528 self.execute_loop_inner(history, prompt, prompt, session_id, event_tx)
1531 .await
1532 }
1533
1534 async fn execute_loop_inner(
1539 &self,
1540 history: &[Message],
1541 msg_prompt: &str,
1542 effective_prompt: &str,
1543 session_id: Option<&str>,
1544 event_tx: Option<mpsc::Sender<AgentEvent>>,
1545 ) -> Result<AgentResult> {
1546 let mut messages = history.to_vec();
1547 let mut total_usage = TokenUsage::default();
1548 let mut tool_calls_count = 0;
1549 let mut turn = 0;
1550 let mut parse_error_count: u32 = 0;
1552 let mut continuation_count: u32 = 0;
1554
1555 if let Some(tx) = &event_tx {
1557 tx.send(AgentEvent::Start {
1558 prompt: effective_prompt.to_string(),
1559 })
1560 .await
1561 .ok();
1562 }
1563
1564 let _queue_forward_handle =
1566 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1567 let mut rx = queue.subscribe();
1568 let tx = tx.clone();
1569 Some(tokio::spawn(async move {
1570 while let Ok(event) = rx.recv().await {
1571 if tx.send(event).await.is_err() {
1572 break;
1573 }
1574 }
1575 }))
1576 } else {
1577 None
1578 };
1579
1580 let built_system_prompt = Some(self.system_prompt());
1582 let hooked_prompt = if let Some(modified) = self
1583 .fire_pre_prompt(
1584 session_id.unwrap_or(""),
1585 effective_prompt,
1586 &built_system_prompt,
1587 messages.len(),
1588 )
1589 .await
1590 {
1591 modified
1592 } else {
1593 effective_prompt.to_string()
1594 };
1595 let effective_prompt = hooked_prompt.as_str();
1596
1597 if let Some(ref sp) = self.config.security_provider {
1599 sp.taint_input(effective_prompt);
1600 }
1601
1602 let system_with_memory = if let Some(ref memory) = self.config.memory {
1604 match memory.recall_similar(effective_prompt, 5).await {
1605 Ok(items) if !items.is_empty() => {
1606 if let Some(tx) = &event_tx {
1607 for item in &items {
1608 tx.send(AgentEvent::MemoryRecalled {
1609 memory_id: item.id.clone(),
1610 content: item.content.clone(),
1611 relevance: item.relevance_score(),
1612 })
1613 .await
1614 .ok();
1615 }
1616 tx.send(AgentEvent::MemoriesSearched {
1617 query: Some(effective_prompt.to_string()),
1618 tags: Vec::new(),
1619 result_count: items.len(),
1620 })
1621 .await
1622 .ok();
1623 }
1624 let memory_context = items
1625 .iter()
1626 .map(|i| format!("- {}", i.content))
1627 .collect::<Vec<_>>()
1628 .join(
1629 "
1630",
1631 );
1632 let base = self.system_prompt();
1633 Some(format!(
1634 "{}
1635
1636## Relevant past experience
1637{}",
1638 base, memory_context
1639 ))
1640 }
1641 _ => Some(self.system_prompt()),
1642 }
1643 } else {
1644 Some(self.system_prompt())
1645 };
1646
1647 let augmented_system = if !self.config.context_providers.is_empty() {
1649 if let Some(tx) = &event_tx {
1651 let provider_names: Vec<String> = self
1652 .config
1653 .context_providers
1654 .iter()
1655 .map(|p| p.name().to_string())
1656 .collect();
1657 tx.send(AgentEvent::ContextResolving {
1658 providers: provider_names,
1659 })
1660 .await
1661 .ok();
1662 }
1663
1664 tracing::info!(
1665 a3s.context.providers = self.config.context_providers.len() as i64,
1666 "Context resolution started"
1667 );
1668 let context_results = self.resolve_context(effective_prompt, session_id).await;
1669
1670 if let Some(tx) = &event_tx {
1672 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1673 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1674
1675 tracing::info!(
1676 context_items = total_items,
1677 context_tokens = total_tokens,
1678 "Context resolution completed"
1679 );
1680
1681 tx.send(AgentEvent::ContextResolved {
1682 total_items,
1683 total_tokens,
1684 })
1685 .await
1686 .ok();
1687 }
1688
1689 self.build_augmented_system_prompt(&context_results)
1690 } else {
1691 Some(self.system_prompt())
1692 };
1693
1694 let base_prompt = self.system_prompt();
1696 let augmented_system = match (augmented_system, system_with_memory) {
1697 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1698 (Some(ctx), _) => Some(ctx),
1699 (None, mem) => mem,
1700 };
1701
1702 if !msg_prompt.is_empty() {
1704 messages.push(Message::user(msg_prompt));
1705 }
1706
1707 loop {
1708 turn += 1;
1709
1710 if turn > self.config.max_tool_rounds {
1711 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1712 if let Some(tx) = &event_tx {
1713 tx.send(AgentEvent::Error {
1714 message: error.clone(),
1715 })
1716 .await
1717 .ok();
1718 }
1719 anyhow::bail!(error);
1720 }
1721
1722 if let Some(tx) = &event_tx {
1724 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1725 }
1726
1727 tracing::info!(
1728 turn = turn,
1729 max_turns = self.config.max_tool_rounds,
1730 "Agent turn started"
1731 );
1732
1733 tracing::info!(
1735 a3s.llm.streaming = event_tx.is_some(),
1736 "LLM completion started"
1737 );
1738
1739 self.fire_generate_start(
1741 session_id.unwrap_or(""),
1742 effective_prompt,
1743 &augmented_system,
1744 )
1745 .await;
1746
1747 let llm_start = std::time::Instant::now();
1748 let response = {
1752 let threshold = self.config.circuit_breaker_threshold.max(1);
1753 let mut attempt = 0u32;
1754 loop {
1755 attempt += 1;
1756 let result = self
1757 .call_llm(&messages, augmented_system.as_deref(), &event_tx)
1758 .await;
1759 match result {
1760 Ok(r) => {
1761 break r;
1762 }
1763 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1765 tracing::warn!(
1766 turn = turn,
1767 attempt = attempt,
1768 threshold = threshold,
1769 error = %e,
1770 "LLM call failed, will retry"
1771 );
1772 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1773 }
1774 Err(e) => {
1776 let msg = if attempt > 1 {
1777 format!(
1778 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1779 attempt, e
1780 )
1781 } else {
1782 format!("LLM call failed: {}", e)
1783 };
1784 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1785 self.fire_on_error(
1787 session_id.unwrap_or(""),
1788 ErrorType::LlmFailure,
1789 &msg,
1790 serde_json::json!({"turn": turn, "attempt": attempt}),
1791 )
1792 .await;
1793 if let Some(tx) = &event_tx {
1794 tx.send(AgentEvent::Error {
1795 message: msg.clone(),
1796 })
1797 .await
1798 .ok();
1799 }
1800 anyhow::bail!(msg);
1801 }
1802 }
1803 }
1804 };
1805
1806 total_usage.prompt_tokens += response.usage.prompt_tokens;
1808 total_usage.completion_tokens += response.usage.completion_tokens;
1809 total_usage.total_tokens += response.usage.total_tokens;
1810
1811 let llm_duration = llm_start.elapsed();
1813 tracing::info!(
1814 turn = turn,
1815 streaming = event_tx.is_some(),
1816 prompt_tokens = response.usage.prompt_tokens,
1817 completion_tokens = response.usage.completion_tokens,
1818 total_tokens = response.usage.total_tokens,
1819 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1820 duration_ms = llm_duration.as_millis() as u64,
1821 "LLM completion finished"
1822 );
1823
1824 self.fire_generate_end(
1826 session_id.unwrap_or(""),
1827 effective_prompt,
1828 &response,
1829 llm_duration.as_millis() as u64,
1830 )
1831 .await;
1832
1833 crate::telemetry::record_llm_usage(
1835 response.usage.prompt_tokens,
1836 response.usage.completion_tokens,
1837 response.usage.total_tokens,
1838 response.stop_reason.as_deref(),
1839 );
1840 tracing::info!(
1842 turn = turn,
1843 a3s.llm.total_tokens = response.usage.total_tokens,
1844 "Turn token usage"
1845 );
1846
1847 messages.push(response.message.clone());
1849
1850 let tool_calls = response.tool_calls();
1852
1853 if let Some(tx) = &event_tx {
1855 tx.send(AgentEvent::TurnEnd {
1856 turn,
1857 usage: response.usage.clone(),
1858 })
1859 .await
1860 .ok();
1861 }
1862
1863 if self.config.auto_compact {
1865 let used = response.usage.prompt_tokens;
1866 let max = self.config.max_context_tokens;
1867 let threshold = self.config.auto_compact_threshold;
1868
1869 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1870 let before_len = messages.len();
1871 let percent_before = used as f32 / max as f32;
1872
1873 tracing::info!(
1874 used_tokens = used,
1875 max_tokens = max,
1876 percent = percent_before,
1877 threshold = threshold,
1878 "Auto-compact triggered"
1879 );
1880
1881 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1883 {
1884 messages = pruned;
1885 tracing::info!("Tool output pruning applied");
1886 }
1887
1888 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1890 session_id.unwrap_or(""),
1891 &messages,
1892 &self.llm_client,
1893 )
1894 .await
1895 {
1896 messages = compacted;
1897 }
1898
1899 if let Some(tx) = &event_tx {
1901 tx.send(AgentEvent::ContextCompacted {
1902 session_id: session_id.unwrap_or("").to_string(),
1903 before_messages: before_len,
1904 after_messages: messages.len(),
1905 percent_before,
1906 })
1907 .await
1908 .ok();
1909 }
1910 }
1911 }
1912
1913 if tool_calls.is_empty() {
1914 let final_text = response.text();
1917
1918 if self.config.continuation_enabled
1919 && continuation_count < self.config.max_continuation_turns
1920 && turn < self.config.max_tool_rounds && Self::looks_incomplete(&final_text)
1922 {
1923 continuation_count += 1;
1924 tracing::info!(
1925 turn = turn,
1926 continuation = continuation_count,
1927 max_continuation = self.config.max_continuation_turns,
1928 "Injecting continuation message — response looks incomplete"
1929 );
1930 messages.push(Message::user(crate::prompts::CONTINUATION));
1932 continue;
1933 }
1934
1935 let final_text = if let Some(ref sp) = self.config.security_provider {
1937 sp.sanitize_output(&final_text)
1938 } else {
1939 final_text
1940 };
1941
1942 tracing::info!(
1944 tool_calls_count = tool_calls_count,
1945 total_prompt_tokens = total_usage.prompt_tokens,
1946 total_completion_tokens = total_usage.completion_tokens,
1947 total_tokens = total_usage.total_tokens,
1948 turns = turn,
1949 "Agent execution completed"
1950 );
1951
1952 if let Some(tx) = &event_tx {
1953 tx.send(AgentEvent::End {
1954 text: final_text.clone(),
1955 usage: total_usage.clone(),
1956 })
1957 .await
1958 .ok();
1959 }
1960
1961 if let Some(sid) = session_id {
1963 self.notify_turn_complete(sid, effective_prompt, &final_text)
1964 .await;
1965 }
1966
1967 return Ok(AgentResult {
1968 text: final_text,
1969 messages,
1970 usage: total_usage,
1971 tool_calls_count,
1972 });
1973 }
1974
1975 let tool_calls = if self.config.hook_engine.is_none()
1979 && self.config.confirmation_manager.is_none()
1980 && tool_calls.len() > 1
1981 && tool_calls
1982 .iter()
1983 .all(|tc| Self::is_parallel_safe_write(&tc.name, &tc.args))
1984 && {
1985 let paths: Vec<_> = tool_calls
1987 .iter()
1988 .filter_map(|tc| Self::extract_write_path(&tc.args))
1989 .collect();
1990 paths.len() == tool_calls.len()
1991 && paths.iter().collect::<std::collections::HashSet<_>>().len()
1992 == paths.len()
1993 } {
1994 tracing::info!(
1995 count = tool_calls.len(),
1996 "Parallel write batch: executing {} independent file writes concurrently",
1997 tool_calls.len()
1998 );
1999
2000 let futures: Vec<_> = tool_calls
2001 .iter()
2002 .map(|tc| {
2003 let ctx = self.tool_context.clone();
2004 let executor = Arc::clone(&self.tool_executor);
2005 let name = tc.name.clone();
2006 let args = tc.args.clone();
2007 async move { executor.execute_with_context(&name, &args, &ctx).await }
2008 })
2009 .collect();
2010
2011 let results = join_all(futures).await;
2012
2013 for (tc, result) in tool_calls.iter().zip(results) {
2015 tool_calls_count += 1;
2016 let (output, exit_code, is_error, metadata, images) =
2017 Self::tool_result_to_tuple(result);
2018
2019 let output = if let Some(ref sp) = self.config.security_provider {
2020 sp.sanitize_output(&output)
2021 } else {
2022 output
2023 };
2024
2025 if let Some(tx) = &event_tx {
2026 tx.send(AgentEvent::ToolEnd {
2027 id: tc.id.clone(),
2028 name: tc.name.clone(),
2029 output: output.clone(),
2030 exit_code,
2031 metadata,
2032 })
2033 .await
2034 .ok();
2035 }
2036
2037 if images.is_empty() {
2038 messages.push(Message::tool_result(&tc.id, &output, is_error));
2039 } else {
2040 messages.push(Message::tool_result_with_images(
2041 &tc.id, &output, &images, is_error,
2042 ));
2043 }
2044 }
2045
2046 continue;
2048 } else {
2049 tool_calls
2050 };
2051
2052 for tool_call in tool_calls {
2053 tool_calls_count += 1;
2054
2055 let tool_start = std::time::Instant::now();
2056
2057 tracing::info!(
2058 tool_name = tool_call.name.as_str(),
2059 tool_id = tool_call.id.as_str(),
2060 "Tool execution started"
2061 );
2062
2063 if let Some(parse_error) =
2069 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
2070 {
2071 parse_error_count += 1;
2072 let error_msg = format!("Error: {}", parse_error);
2073 tracing::warn!(
2074 tool = tool_call.name.as_str(),
2075 parse_error_count = parse_error_count,
2076 max_parse_retries = self.config.max_parse_retries,
2077 "Malformed tool arguments from LLM"
2078 );
2079
2080 if let Some(tx) = &event_tx {
2081 tx.send(AgentEvent::ToolEnd {
2082 id: tool_call.id.clone(),
2083 name: tool_call.name.clone(),
2084 output: error_msg.clone(),
2085 exit_code: 1,
2086 metadata: None,
2087 })
2088 .await
2089 .ok();
2090 }
2091
2092 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
2093
2094 if parse_error_count > self.config.max_parse_retries {
2095 let msg = format!(
2096 "LLM produced malformed tool arguments {} time(s) in a row \
2097 (max_parse_retries={}); giving up",
2098 parse_error_count, self.config.max_parse_retries
2099 );
2100 tracing::error!("{}", msg);
2101 if let Some(tx) = &event_tx {
2102 tx.send(AgentEvent::Error {
2103 message: msg.clone(),
2104 })
2105 .await
2106 .ok();
2107 }
2108 anyhow::bail!(msg);
2109 }
2110 continue;
2111 }
2112
2113 parse_error_count = 0;
2115
2116 if let Some(ref registry) = self.config.skill_registry {
2118 let instruction_skills =
2119 registry.by_kind(crate::skills::SkillKind::Instruction);
2120 let has_restrictions =
2121 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
2122 if has_restrictions {
2123 let allowed = instruction_skills
2124 .iter()
2125 .any(|s| s.is_tool_allowed(&tool_call.name));
2126 if !allowed {
2127 let msg = format!(
2128 "Tool '{}' is not allowed by any active skill.",
2129 tool_call.name
2130 );
2131 tracing::info!(
2132 tool_name = tool_call.name.as_str(),
2133 "Tool blocked by skill registry"
2134 );
2135 if let Some(tx) = &event_tx {
2136 tx.send(AgentEvent::PermissionDenied {
2137 tool_id: tool_call.id.clone(),
2138 tool_name: tool_call.name.clone(),
2139 args: tool_call.args.clone(),
2140 reason: msg.clone(),
2141 })
2142 .await
2143 .ok();
2144 }
2145 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2146 continue;
2147 }
2148 }
2149 }
2150
2151 if let Some(HookResult::Block(reason)) = self
2153 .fire_pre_tool_use(session_id.unwrap_or(""), &tool_call.name, &tool_call.args)
2154 .await
2155 {
2156 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
2157 tracing::info!(
2158 tool_name = tool_call.name.as_str(),
2159 "Tool blocked by PreToolUse hook"
2160 );
2161
2162 if let Some(tx) = &event_tx {
2163 tx.send(AgentEvent::PermissionDenied {
2164 tool_id: tool_call.id.clone(),
2165 tool_name: tool_call.name.clone(),
2166 args: tool_call.args.clone(),
2167 reason: reason.clone(),
2168 })
2169 .await
2170 .ok();
2171 }
2172
2173 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2174 continue;
2175 }
2176
2177 let permission_decision = if let Some(checker) = &self.config.permission_checker {
2179 checker.check(&tool_call.name, &tool_call.args)
2180 } else {
2181 PermissionDecision::Ask
2183 };
2184
2185 let (output, exit_code, is_error, metadata, images) = match permission_decision {
2186 PermissionDecision::Deny => {
2187 tracing::info!(
2188 tool_name = tool_call.name.as_str(),
2189 permission = "deny",
2190 "Tool permission denied"
2191 );
2192 let denial_msg = format!(
2194 "Permission denied: Tool '{}' is blocked by permission policy.",
2195 tool_call.name
2196 );
2197
2198 if let Some(tx) = &event_tx {
2200 tx.send(AgentEvent::PermissionDenied {
2201 tool_id: tool_call.id.clone(),
2202 tool_name: tool_call.name.clone(),
2203 args: tool_call.args.clone(),
2204 reason: "Blocked by deny rule in permission policy".to_string(),
2205 })
2206 .await
2207 .ok();
2208 }
2209
2210 (denial_msg, 1, true, None, Vec::new())
2211 }
2212 PermissionDecision::Allow => {
2213 tracing::info!(
2214 tool_name = tool_call.name.as_str(),
2215 permission = "allow",
2216 "Tool permission: allow"
2217 );
2218 let stream_ctx =
2220 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
2221 let result = self
2222 .execute_tool_queued_or_direct(
2223 &tool_call.name,
2224 &tool_call.args,
2225 &stream_ctx,
2226 )
2227 .await;
2228
2229 Self::tool_result_to_tuple(result)
2230 }
2231 PermissionDecision::Ask => {
2232 tracing::info!(
2233 tool_name = tool_call.name.as_str(),
2234 permission = "ask",
2235 "Tool permission: ask"
2236 );
2237 if let Some(cm) = &self.config.confirmation_manager {
2239 if !cm.requires_confirmation(&tool_call.name).await {
2241 let stream_ctx = self.streaming_tool_context(
2242 &event_tx,
2243 &tool_call.id,
2244 &tool_call.name,
2245 );
2246 let result = self
2247 .execute_tool_queued_or_direct(
2248 &tool_call.name,
2249 &tool_call.args,
2250 &stream_ctx,
2251 )
2252 .await;
2253
2254 let (output, exit_code, is_error, metadata, images) =
2255 Self::tool_result_to_tuple(result);
2256
2257 if images.is_empty() {
2259 messages.push(Message::tool_result(
2260 &tool_call.id,
2261 &output,
2262 is_error,
2263 ));
2264 } else {
2265 messages.push(Message::tool_result_with_images(
2266 &tool_call.id,
2267 &output,
2268 &images,
2269 is_error,
2270 ));
2271 }
2272
2273 let tool_duration = tool_start.elapsed();
2275 crate::telemetry::record_tool_result(exit_code, tool_duration);
2276
2277 if let Some(tx) = &event_tx {
2279 tx.send(AgentEvent::ToolEnd {
2280 id: tool_call.id.clone(),
2281 name: tool_call.name.clone(),
2282 output: output.clone(),
2283 exit_code,
2284 metadata,
2285 })
2286 .await
2287 .ok();
2288 }
2289
2290 self.fire_post_tool_use(
2292 session_id.unwrap_or(""),
2293 &tool_call.name,
2294 &tool_call.args,
2295 &output,
2296 exit_code == 0,
2297 tool_duration.as_millis() as u64,
2298 )
2299 .await;
2300
2301 continue; }
2303
2304 let policy = cm.policy().await;
2306 let timeout_ms = policy.default_timeout_ms;
2307 let timeout_action = policy.timeout_action;
2308
2309 let rx = cm
2311 .request_confirmation(
2312 &tool_call.id,
2313 &tool_call.name,
2314 &tool_call.args,
2315 )
2316 .await;
2317
2318 if let Some(tx) = &event_tx {
2322 tx.send(AgentEvent::ConfirmationRequired {
2323 tool_id: tool_call.id.clone(),
2324 tool_name: tool_call.name.clone(),
2325 args: tool_call.args.clone(),
2326 timeout_ms,
2327 })
2328 .await
2329 .ok();
2330 }
2331
2332 let confirmation_result =
2334 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2335
2336 match confirmation_result {
2337 Ok(Ok(response)) => {
2338 if let Some(tx) = &event_tx {
2340 tx.send(AgentEvent::ConfirmationReceived {
2341 tool_id: tool_call.id.clone(),
2342 approved: response.approved,
2343 reason: response.reason.clone(),
2344 })
2345 .await
2346 .ok();
2347 }
2348 if response.approved {
2349 let stream_ctx = self.streaming_tool_context(
2350 &event_tx,
2351 &tool_call.id,
2352 &tool_call.name,
2353 );
2354 let result = self
2355 .execute_tool_queued_or_direct(
2356 &tool_call.name,
2357 &tool_call.args,
2358 &stream_ctx,
2359 )
2360 .await;
2361
2362 Self::tool_result_to_tuple(result)
2363 } else {
2364 let rejection_msg = format!(
2365 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
2366 DO NOT retry this tool call unless the user explicitly asks you to.",
2367 tool_call.name,
2368 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2369 );
2370 (rejection_msg, 1, true, None, Vec::new())
2371 }
2372 }
2373 Ok(Err(_)) => {
2374 if let Some(tx) = &event_tx {
2376 tx.send(AgentEvent::ConfirmationTimeout {
2377 tool_id: tool_call.id.clone(),
2378 action_taken: "rejected".to_string(),
2379 })
2380 .await
2381 .ok();
2382 }
2383 let msg = format!(
2384 "Tool '{}' confirmation failed: confirmation channel closed",
2385 tool_call.name
2386 );
2387 (msg, 1, true, None, Vec::new())
2388 }
2389 Err(_) => {
2390 cm.check_timeouts().await;
2391
2392 if let Some(tx) = &event_tx {
2394 tx.send(AgentEvent::ConfirmationTimeout {
2395 tool_id: tool_call.id.clone(),
2396 action_taken: match timeout_action {
2397 crate::hitl::TimeoutAction::Reject => {
2398 "rejected".to_string()
2399 }
2400 crate::hitl::TimeoutAction::AutoApprove => {
2401 "auto_approved".to_string()
2402 }
2403 },
2404 })
2405 .await
2406 .ok();
2407 }
2408
2409 match timeout_action {
2410 crate::hitl::TimeoutAction::Reject => {
2411 let msg = format!(
2412 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
2413 DO NOT retry this tool call — the user did not approve it. \
2414 Inform the user that the operation requires their approval and ask them to try again.",
2415 tool_call.name, timeout_ms
2416 );
2417 (msg, 1, true, None, Vec::new())
2418 }
2419 crate::hitl::TimeoutAction::AutoApprove => {
2420 let stream_ctx = self.streaming_tool_context(
2421 &event_tx,
2422 &tool_call.id,
2423 &tool_call.name,
2424 );
2425 let result = self
2426 .execute_tool_queued_or_direct(
2427 &tool_call.name,
2428 &tool_call.args,
2429 &stream_ctx,
2430 )
2431 .await;
2432
2433 Self::tool_result_to_tuple(result)
2434 }
2435 }
2436 }
2437 }
2438 } else {
2439 let msg = format!(
2441 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2442 Configure a confirmation policy to enable tool execution.",
2443 tool_call.name
2444 );
2445 tracing::warn!(
2446 tool_name = tool_call.name.as_str(),
2447 "Tool requires confirmation but no HITL manager configured"
2448 );
2449 (msg, 1, true, None, Vec::new())
2450 }
2451 }
2452 };
2453
2454 let tool_duration = tool_start.elapsed();
2455 crate::telemetry::record_tool_result(exit_code, tool_duration);
2456
2457 let output = if let Some(ref sp) = self.config.security_provider {
2459 sp.sanitize_output(&output)
2460 } else {
2461 output
2462 };
2463
2464 self.fire_post_tool_use(
2466 session_id.unwrap_or(""),
2467 &tool_call.name,
2468 &tool_call.args,
2469 &output,
2470 exit_code == 0,
2471 tool_duration.as_millis() as u64,
2472 )
2473 .await;
2474
2475 if let Some(ref memory) = self.config.memory {
2477 let tools_used = [tool_call.name.clone()];
2478 let remember_result = if exit_code == 0 {
2479 memory
2480 .remember_success(effective_prompt, &tools_used, &output)
2481 .await
2482 } else {
2483 memory
2484 .remember_failure(effective_prompt, &output, &tools_used)
2485 .await
2486 };
2487 match remember_result {
2488 Ok(()) => {
2489 if let Some(tx) = &event_tx {
2490 let item_type = if exit_code == 0 { "success" } else { "failure" };
2491 tx.send(AgentEvent::MemoryStored {
2492 memory_id: uuid::Uuid::new_v4().to_string(),
2493 memory_type: item_type.to_string(),
2494 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2495 tags: vec![item_type.to_string(), tool_call.name.clone()],
2496 })
2497 .await
2498 .ok();
2499 }
2500 }
2501 Err(e) => {
2502 tracing::warn!("Failed to store memory after tool execution: {}", e);
2503 }
2504 }
2505 }
2506
2507 if let Some(tx) = &event_tx {
2509 tx.send(AgentEvent::ToolEnd {
2510 id: tool_call.id.clone(),
2511 name: tool_call.name.clone(),
2512 output: output.clone(),
2513 exit_code,
2514 metadata,
2515 })
2516 .await
2517 .ok();
2518 }
2519
2520 if images.is_empty() {
2522 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2523 } else {
2524 messages.push(Message::tool_result_with_images(
2525 &tool_call.id,
2526 &output,
2527 &images,
2528 is_error,
2529 ));
2530 }
2531 }
2532 }
2533 }
2534
2535 pub async fn execute_streaming(
2537 &self,
2538 history: &[Message],
2539 prompt: &str,
2540 ) -> Result<(
2541 mpsc::Receiver<AgentEvent>,
2542 tokio::task::JoinHandle<Result<AgentResult>>,
2543 )> {
2544 let (tx, rx) = mpsc::channel(100);
2545
2546 let llm_client = self.llm_client.clone();
2547 let tool_executor = self.tool_executor.clone();
2548 let tool_context = self.tool_context.clone();
2549 let config = self.config.clone();
2550 let tool_metrics = self.tool_metrics.clone();
2551 let command_queue = self.command_queue.clone();
2552 let history = history.to_vec();
2553 let prompt = prompt.to_string();
2554
2555 let handle = tokio::spawn(async move {
2556 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2557 if let Some(metrics) = tool_metrics {
2558 agent = agent.with_tool_metrics(metrics);
2559 }
2560 if let Some(queue) = command_queue {
2561 agent = agent.with_queue(queue);
2562 }
2563 agent.execute(&history, &prompt, Some(tx)).await
2564 });
2565
2566 Ok((rx, handle))
2567 }
2568
2569 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2574 use crate::planning::LlmPlanner;
2575
2576 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2577 Ok(plan) => Ok(plan),
2578 Err(e) => {
2579 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2580 Ok(LlmPlanner::fallback_plan(prompt))
2581 }
2582 }
2583 }
2584
2585 pub async fn execute_with_planning(
2587 &self,
2588 history: &[Message],
2589 prompt: &str,
2590 event_tx: Option<mpsc::Sender<AgentEvent>>,
2591 ) -> Result<AgentResult> {
2592 if let Some(tx) = &event_tx {
2594 tx.send(AgentEvent::PlanningStart {
2595 prompt: prompt.to_string(),
2596 })
2597 .await
2598 .ok();
2599 }
2600
2601 let goal = if self.config.goal_tracking {
2603 let g = self.extract_goal(prompt).await?;
2604 if let Some(tx) = &event_tx {
2605 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2606 .await
2607 .ok();
2608 }
2609 Some(g)
2610 } else {
2611 None
2612 };
2613
2614 let plan = self.plan(prompt, None).await?;
2616
2617 if let Some(tx) = &event_tx {
2619 tx.send(AgentEvent::PlanningEnd {
2620 estimated_steps: plan.steps.len(),
2621 plan: plan.clone(),
2622 })
2623 .await
2624 .ok();
2625 }
2626
2627 let plan_start = std::time::Instant::now();
2628
2629 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2631
2632 if self.config.goal_tracking {
2634 if let Some(ref g) = goal {
2635 let achieved = self.check_goal_achievement(g, &result.text).await?;
2636 if achieved {
2637 if let Some(tx) = &event_tx {
2638 tx.send(AgentEvent::GoalAchieved {
2639 goal: g.description.clone(),
2640 total_steps: result.messages.len(),
2641 duration_ms: plan_start.elapsed().as_millis() as i64,
2642 })
2643 .await
2644 .ok();
2645 }
2646 }
2647 }
2648 }
2649
2650 Ok(result)
2651 }
2652
2653 async fn execute_plan(
2660 &self,
2661 history: &[Message],
2662 plan: &ExecutionPlan,
2663 event_tx: Option<mpsc::Sender<AgentEvent>>,
2664 ) -> Result<AgentResult> {
2665 let mut plan = plan.clone();
2666 let mut current_history = history.to_vec();
2667 let mut total_usage = TokenUsage::default();
2668 let mut tool_calls_count = 0;
2669 let total_steps = plan.steps.len();
2670
2671 let steps_text = plan
2673 .steps
2674 .iter()
2675 .enumerate()
2676 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2677 .collect::<Vec<_>>()
2678 .join("\n");
2679 current_history.push(Message::user(&crate::prompts::render(
2680 crate::prompts::PLAN_EXECUTE_GOAL,
2681 &[("goal", &plan.goal), ("steps", &steps_text)],
2682 )));
2683
2684 loop {
2685 let ready: Vec<String> = plan
2686 .get_ready_steps()
2687 .iter()
2688 .map(|s| s.id.clone())
2689 .collect();
2690
2691 if ready.is_empty() {
2692 if plan.has_deadlock() {
2694 tracing::warn!(
2695 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2696 plan.pending_count()
2697 );
2698 }
2699 break;
2700 }
2701
2702 if ready.len() == 1 {
2703 let step_id = &ready[0];
2705 let step = plan
2706 .steps
2707 .iter()
2708 .find(|s| s.id == *step_id)
2709 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2710 .clone();
2711 let step_number = plan
2712 .steps
2713 .iter()
2714 .position(|s| s.id == *step_id)
2715 .unwrap_or(0)
2716 + 1;
2717
2718 if let Some(tx) = &event_tx {
2720 tx.send(AgentEvent::StepStart {
2721 step_id: step.id.clone(),
2722 description: step.content.clone(),
2723 step_number,
2724 total_steps,
2725 })
2726 .await
2727 .ok();
2728 }
2729
2730 plan.mark_status(&step.id, TaskStatus::InProgress);
2731
2732 let step_prompt = crate::prompts::render(
2733 crate::prompts::PLAN_EXECUTE_STEP,
2734 &[
2735 ("step_num", &step_number.to_string()),
2736 ("description", &step.content),
2737 ],
2738 );
2739
2740 match self
2741 .execute_loop(¤t_history, &step_prompt, None, event_tx.clone())
2742 .await
2743 {
2744 Ok(result) => {
2745 current_history = result.messages.clone();
2746 total_usage.prompt_tokens += result.usage.prompt_tokens;
2747 total_usage.completion_tokens += result.usage.completion_tokens;
2748 total_usage.total_tokens += result.usage.total_tokens;
2749 tool_calls_count += result.tool_calls_count;
2750 plan.mark_status(&step.id, TaskStatus::Completed);
2751
2752 if let Some(tx) = &event_tx {
2753 tx.send(AgentEvent::StepEnd {
2754 step_id: step.id.clone(),
2755 status: TaskStatus::Completed,
2756 step_number,
2757 total_steps,
2758 })
2759 .await
2760 .ok();
2761 }
2762 }
2763 Err(e) => {
2764 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2765 plan.mark_status(&step.id, TaskStatus::Failed);
2766
2767 if let Some(tx) = &event_tx {
2768 tx.send(AgentEvent::StepEnd {
2769 step_id: step.id.clone(),
2770 status: TaskStatus::Failed,
2771 step_number,
2772 total_steps,
2773 })
2774 .await
2775 .ok();
2776 }
2777 }
2778 }
2779 } else {
2780 let ready_steps: Vec<_> = ready
2787 .iter()
2788 .filter_map(|id| {
2789 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2790 let step_number =
2791 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2792 Some((step, step_number))
2793 })
2794 .collect();
2795
2796 for (step, step_number) in &ready_steps {
2798 plan.mark_status(&step.id, TaskStatus::InProgress);
2799 if let Some(tx) = &event_tx {
2800 tx.send(AgentEvent::StepStart {
2801 step_id: step.id.clone(),
2802 description: step.content.clone(),
2803 step_number: *step_number,
2804 total_steps,
2805 })
2806 .await
2807 .ok();
2808 }
2809 }
2810
2811 let mut join_set = tokio::task::JoinSet::new();
2813 for (step, step_number) in &ready_steps {
2814 let base_history = current_history.clone();
2815 let agent_clone = self.clone();
2816 let tx = event_tx.clone();
2817 let step_clone = step.clone();
2818 let sn = *step_number;
2819
2820 join_set.spawn(async move {
2821 let prompt = crate::prompts::render(
2822 crate::prompts::PLAN_EXECUTE_STEP,
2823 &[
2824 ("step_num", &sn.to_string()),
2825 ("description", &step_clone.content),
2826 ],
2827 );
2828 let result = agent_clone
2829 .execute_loop(&base_history, &prompt, None, tx)
2830 .await;
2831 (step_clone.id, sn, result)
2832 });
2833 }
2834
2835 let mut parallel_summaries = Vec::new();
2837 while let Some(join_result) = join_set.join_next().await {
2838 match join_result {
2839 Ok((step_id, step_number, step_result)) => match step_result {
2840 Ok(result) => {
2841 total_usage.prompt_tokens += result.usage.prompt_tokens;
2842 total_usage.completion_tokens += result.usage.completion_tokens;
2843 total_usage.total_tokens += result.usage.total_tokens;
2844 tool_calls_count += result.tool_calls_count;
2845 plan.mark_status(&step_id, TaskStatus::Completed);
2846
2847 parallel_summaries.push(format!(
2849 "- Step {} ({}): {}",
2850 step_number, step_id, result.text
2851 ));
2852
2853 if let Some(tx) = &event_tx {
2854 tx.send(AgentEvent::StepEnd {
2855 step_id,
2856 status: TaskStatus::Completed,
2857 step_number,
2858 total_steps,
2859 })
2860 .await
2861 .ok();
2862 }
2863 }
2864 Err(e) => {
2865 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2866 plan.mark_status(&step_id, TaskStatus::Failed);
2867
2868 if let Some(tx) = &event_tx {
2869 tx.send(AgentEvent::StepEnd {
2870 step_id,
2871 status: TaskStatus::Failed,
2872 step_number,
2873 total_steps,
2874 })
2875 .await
2876 .ok();
2877 }
2878 }
2879 },
2880 Err(e) => {
2881 tracing::error!("JoinSet task panicked: {}", e);
2882 }
2883 }
2884 }
2885
2886 if !parallel_summaries.is_empty() {
2888 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
2890 current_history.push(Message::user(&crate::prompts::render(
2891 crate::prompts::PLAN_PARALLEL_RESULTS,
2892 &[("results", &results_text)],
2893 )));
2894 }
2895 }
2896
2897 if self.config.goal_tracking {
2899 let completed = plan
2900 .steps
2901 .iter()
2902 .filter(|s| s.status == TaskStatus::Completed)
2903 .count();
2904 if let Some(tx) = &event_tx {
2905 tx.send(AgentEvent::GoalProgress {
2906 goal: plan.goal.clone(),
2907 progress: plan.progress(),
2908 completed_steps: completed,
2909 total_steps,
2910 })
2911 .await
2912 .ok();
2913 }
2914 }
2915 }
2916
2917 let final_text = current_history
2919 .last()
2920 .map(|m| {
2921 m.content
2922 .iter()
2923 .filter_map(|block| {
2924 if let crate::llm::ContentBlock::Text { text } = block {
2925 Some(text.as_str())
2926 } else {
2927 None
2928 }
2929 })
2930 .collect::<Vec<_>>()
2931 .join("\n")
2932 })
2933 .unwrap_or_default();
2934
2935 Ok(AgentResult {
2936 text: final_text,
2937 messages: current_history,
2938 usage: total_usage,
2939 tool_calls_count,
2940 })
2941 }
2942
2943 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
2948 use crate::planning::LlmPlanner;
2949
2950 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
2951 Ok(goal) => Ok(goal),
2952 Err(e) => {
2953 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
2954 Ok(LlmPlanner::fallback_goal(prompt))
2955 }
2956 }
2957 }
2958
2959 pub async fn check_goal_achievement(
2964 &self,
2965 goal: &AgentGoal,
2966 current_state: &str,
2967 ) -> Result<bool> {
2968 use crate::planning::LlmPlanner;
2969
2970 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
2971 Ok(result) => Ok(result.achieved),
2972 Err(e) => {
2973 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
2974 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
2975 Ok(result.achieved)
2976 }
2977 }
2978 }
2979}
2980
2981#[cfg(test)]
2982mod tests {
2983 use super::*;
2984 use crate::llm::{ContentBlock, StreamEvent};
2985 use crate::permissions::PermissionPolicy;
2986 use crate::tools::ToolExecutor;
2987 use std::path::PathBuf;
2988 use std::sync::atomic::{AtomicUsize, Ordering};
2989
2990 fn test_tool_context() -> ToolContext {
2992 ToolContext::new(PathBuf::from("/tmp"))
2993 }
2994
2995 #[test]
2996 fn test_agent_config_default() {
2997 let config = AgentConfig::default();
2998 assert!(config.prompt_slots.is_empty());
2999 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
3001 assert!(config.permission_checker.is_none());
3002 assert!(config.context_providers.is_empty());
3003 let registry = config
3005 .skill_registry
3006 .expect("skill_registry must be Some by default");
3007 assert!(registry.len() >= 7, "expected at least 7 built-in skills");
3008 assert!(registry.get("code-search").is_some());
3009 assert!(registry.get("find-bugs").is_some());
3010 }
3011
    /// Scripted [`LlmClient`] test double that replays pre-built responses.
    pub(crate) struct MockLlmClient {
        /// Responses not yet served; each successful call removes the one at the front.
        responses: std::sync::Mutex<Vec<LlmResponse>>,
        /// Incremented once per `complete` / `complete_streaming` call, before the queue is
        /// checked, so calls that fail on an empty queue are counted too.
        pub(crate) call_count: AtomicUsize,
    }
3023
3024 impl MockLlmClient {
3025 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
3026 Self {
3027 responses: std::sync::Mutex::new(responses),
3028 call_count: AtomicUsize::new(0),
3029 }
3030 }
3031
3032 pub(crate) fn text_response(text: &str) -> LlmResponse {
3034 LlmResponse {
3035 message: Message {
3036 role: "assistant".to_string(),
3037 content: vec![ContentBlock::Text {
3038 text: text.to_string(),
3039 }],
3040 reasoning_content: None,
3041 },
3042 usage: TokenUsage {
3043 prompt_tokens: 10,
3044 completion_tokens: 5,
3045 total_tokens: 15,
3046 cache_read_tokens: None,
3047 cache_write_tokens: None,
3048 },
3049 stop_reason: Some("end_turn".to_string()),
3050 }
3051 }
3052
3053 pub(crate) fn tool_call_response(
3055 tool_id: &str,
3056 tool_name: &str,
3057 args: serde_json::Value,
3058 ) -> LlmResponse {
3059 LlmResponse {
3060 message: Message {
3061 role: "assistant".to_string(),
3062 content: vec![ContentBlock::ToolUse {
3063 id: tool_id.to_string(),
3064 name: tool_name.to_string(),
3065 input: args,
3066 }],
3067 reasoning_content: None,
3068 },
3069 usage: TokenUsage {
3070 prompt_tokens: 10,
3071 completion_tokens: 5,
3072 total_tokens: 15,
3073 cache_read_tokens: None,
3074 cache_write_tokens: None,
3075 },
3076 stop_reason: Some("tool_use".to_string()),
3077 }
3078 }
3079 }
3080
3081 #[async_trait::async_trait]
3082 impl LlmClient for MockLlmClient {
3083 async fn complete(
3084 &self,
3085 _messages: &[Message],
3086 _system: Option<&str>,
3087 _tools: &[ToolDefinition],
3088 ) -> Result<LlmResponse> {
3089 self.call_count.fetch_add(1, Ordering::SeqCst);
3090 let mut responses = self.responses.lock().unwrap();
3091 if responses.is_empty() {
3092 anyhow::bail!("No more mock responses available");
3093 }
3094 Ok(responses.remove(0))
3095 }
3096
3097 async fn complete_streaming(
3098 &self,
3099 _messages: &[Message],
3100 _system: Option<&str>,
3101 _tools: &[ToolDefinition],
3102 ) -> Result<mpsc::Receiver<StreamEvent>> {
3103 self.call_count.fetch_add(1, Ordering::SeqCst);
3104 let mut responses = self.responses.lock().unwrap();
3105 if responses.is_empty() {
3106 anyhow::bail!("No more mock responses available");
3107 }
3108 let response = responses.remove(0);
3109
3110 let (tx, rx) = mpsc::channel(10);
3111 tokio::spawn(async move {
3112 for block in &response.message.content {
3114 if let ContentBlock::Text { text } = block {
3115 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
3116 }
3117 }
3118 tx.send(StreamEvent::Done(response)).await.ok();
3119 });
3120
3121 Ok(rx)
3122 }
3123 }
3124
3125 #[tokio::test]
3130 async fn test_agent_simple_response() {
3131 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3132 "Hello, I'm an AI assistant.",
3133 )]));
3134
3135 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3136 let config = AgentConfig::default();
3137
3138 let agent = AgentLoop::new(
3139 mock_client.clone(),
3140 tool_executor,
3141 test_tool_context(),
3142 config,
3143 );
3144 let result = agent.execute(&[], "Hello", None).await.unwrap();
3145
3146 assert_eq!(result.text, "Hello, I'm an AI assistant.");
3147 assert_eq!(result.tool_calls_count, 0);
3148 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3149 }
3150
3151 #[tokio::test]
3152 async fn test_agent_with_tool_call() {
3153 let mock_client = Arc::new(MockLlmClient::new(vec![
3154 MockLlmClient::tool_call_response(
3156 "tool-1",
3157 "bash",
3158 serde_json::json!({"command": "echo hello"}),
3159 ),
3160 MockLlmClient::text_response("The command output was: hello"),
3162 ]));
3163
3164 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3165 let config = AgentConfig::default();
3166
3167 let agent = AgentLoop::new(
3168 mock_client.clone(),
3169 tool_executor,
3170 test_tool_context(),
3171 config,
3172 );
3173 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
3174
3175 assert_eq!(result.text, "The command output was: hello");
3176 assert_eq!(result.tool_calls_count, 1);
3177 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
3178 }
3179
3180 #[tokio::test]
3181 async fn test_agent_permission_deny() {
3182 let mock_client = Arc::new(MockLlmClient::new(vec![
3183 MockLlmClient::tool_call_response(
3185 "tool-1",
3186 "bash",
3187 serde_json::json!({"command": "rm -rf /tmp/test"}),
3188 ),
3189 MockLlmClient::text_response(
3191 "I cannot execute that command due to permission restrictions.",
3192 ),
3193 ]));
3194
3195 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3196
3197 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3199
3200 let config = AgentConfig {
3201 permission_checker: Some(Arc::new(permission_policy)),
3202 ..Default::default()
3203 };
3204
3205 let (tx, mut rx) = mpsc::channel(100);
3206 let agent = AgentLoop::new(
3207 mock_client.clone(),
3208 tool_executor,
3209 test_tool_context(),
3210 config,
3211 );
3212 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
3213
3214 let mut found_permission_denied = false;
3216 while let Ok(event) = rx.try_recv() {
3217 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
3218 assert_eq!(tool_name, "bash");
3219 found_permission_denied = true;
3220 }
3221 }
3222 assert!(
3223 found_permission_denied,
3224 "Should have received PermissionDenied event"
3225 );
3226
3227 assert_eq!(result.tool_calls_count, 1);
3228 }
3229
3230 #[tokio::test]
3231 async fn test_agent_permission_allow() {
3232 let mock_client = Arc::new(MockLlmClient::new(vec![
3233 MockLlmClient::tool_call_response(
3235 "tool-1",
3236 "bash",
3237 serde_json::json!({"command": "echo hello"}),
3238 ),
3239 MockLlmClient::text_response("Done!"),
3241 ]));
3242
3243 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3244
3245 let permission_policy = PermissionPolicy::new()
3247 .allow("bash(echo:*)")
3248 .deny("bash(rm:*)");
3249
3250 let config = AgentConfig {
3251 permission_checker: Some(Arc::new(permission_policy)),
3252 ..Default::default()
3253 };
3254
3255 let agent = AgentLoop::new(
3256 mock_client.clone(),
3257 tool_executor,
3258 test_tool_context(),
3259 config,
3260 );
3261 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
3262
3263 assert_eq!(result.text, "Done!");
3264 assert_eq!(result.tool_calls_count, 1);
3265 }
3266
3267 #[tokio::test]
3268 async fn test_agent_streaming_events() {
3269 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3270 "Hello!",
3271 )]));
3272
3273 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3274 let config = AgentConfig::default();
3275
3276 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3277 let (mut rx, handle) = agent.execute_streaming(&[], "Hi").await.unwrap();
3278
3279 let mut events = Vec::new();
3281 while let Some(event) = rx.recv().await {
3282 events.push(event);
3283 }
3284
3285 let result = handle.await.unwrap().unwrap();
3286 assert_eq!(result.text, "Hello!");
3287
3288 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
3290 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
3291 }
3292
3293 #[tokio::test]
3294 async fn test_agent_max_tool_rounds() {
3295 let responses: Vec<LlmResponse> = (0..100)
3297 .map(|i| {
3298 MockLlmClient::tool_call_response(
3299 &format!("tool-{}", i),
3300 "bash",
3301 serde_json::json!({"command": "echo loop"}),
3302 )
3303 })
3304 .collect();
3305
3306 let mock_client = Arc::new(MockLlmClient::new(responses));
3307 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3308
3309 let config = AgentConfig {
3310 max_tool_rounds: 3,
3311 ..Default::default()
3312 };
3313
3314 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3315 let result = agent.execute(&[], "Loop forever", None).await;
3316
3317 assert!(result.is_err());
3319 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
3320 }
3321
3322 #[tokio::test]
3323 async fn test_agent_no_permission_policy_defaults_to_ask() {
3324 let mock_client = Arc::new(MockLlmClient::new(vec![
3327 MockLlmClient::tool_call_response(
3328 "tool-1",
3329 "bash",
3330 serde_json::json!({"command": "rm -rf /tmp/test"}),
3331 ),
3332 MockLlmClient::text_response("Denied!"),
3333 ]));
3334
3335 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3336 let config = AgentConfig {
3337 permission_checker: None, ..Default::default()
3340 };
3341
3342 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3343 let result = agent.execute(&[], "Delete", None).await.unwrap();
3344
3345 assert_eq!(result.text, "Denied!");
3347 assert_eq!(result.tool_calls_count, 1);
3348 }
3349
3350 #[tokio::test]
3351 async fn test_agent_permission_ask_without_cm_denies() {
3352 let mock_client = Arc::new(MockLlmClient::new(vec![
3355 MockLlmClient::tool_call_response(
3356 "tool-1",
3357 "bash",
3358 serde_json::json!({"command": "echo test"}),
3359 ),
3360 MockLlmClient::text_response("Denied!"),
3361 ]));
3362
3363 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3364
3365 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3369 permission_checker: Some(Arc::new(permission_policy)),
3370 ..Default::default()
3372 };
3373
3374 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3375 let result = agent.execute(&[], "Echo", None).await.unwrap();
3376
3377 assert_eq!(result.text, "Denied!");
3379 assert!(result.tool_calls_count >= 1);
3381 }
3382
3383 #[tokio::test]
3388 async fn test_agent_hitl_approved() {
3389 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3390 use tokio::sync::broadcast;
3391
3392 let mock_client = Arc::new(MockLlmClient::new(vec![
3393 MockLlmClient::tool_call_response(
3394 "tool-1",
3395 "bash",
3396 serde_json::json!({"command": "echo hello"}),
3397 ),
3398 MockLlmClient::text_response("Command executed!"),
3399 ]));
3400
3401 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3402
3403 let (event_tx, _event_rx) = broadcast::channel(100);
3405 let hitl_policy = ConfirmationPolicy {
3406 enabled: true,
3407 ..Default::default()
3408 };
3409 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3410
3411 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3415 permission_checker: Some(Arc::new(permission_policy)),
3416 confirmation_manager: Some(confirmation_manager.clone()),
3417 ..Default::default()
3418 };
3419
3420 let cm_clone = confirmation_manager.clone();
3422 tokio::spawn(async move {
3423 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3425 cm_clone.confirm("tool-1", true, None).await.ok();
3427 });
3428
3429 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3430 let result = agent.execute(&[], "Run echo", None).await.unwrap();
3431
3432 assert_eq!(result.text, "Command executed!");
3433 assert_eq!(result.tool_calls_count, 1);
3434 }
3435
3436 #[tokio::test]
3437 async fn test_agent_hitl_rejected() {
3438 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3439 use tokio::sync::broadcast;
3440
3441 let mock_client = Arc::new(MockLlmClient::new(vec![
3442 MockLlmClient::tool_call_response(
3443 "tool-1",
3444 "bash",
3445 serde_json::json!({"command": "rm -rf /"}),
3446 ),
3447 MockLlmClient::text_response("Understood, I won't do that."),
3448 ]));
3449
3450 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3451
3452 let (event_tx, _event_rx) = broadcast::channel(100);
3454 let hitl_policy = ConfirmationPolicy {
3455 enabled: true,
3456 ..Default::default()
3457 };
3458 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3459
3460 let permission_policy = PermissionPolicy::new();
3462
3463 let config = AgentConfig {
3464 permission_checker: Some(Arc::new(permission_policy)),
3465 confirmation_manager: Some(confirmation_manager.clone()),
3466 ..Default::default()
3467 };
3468
3469 let cm_clone = confirmation_manager.clone();
3471 tokio::spawn(async move {
3472 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3473 cm_clone
3474 .confirm("tool-1", false, Some("Too dangerous".to_string()))
3475 .await
3476 .ok();
3477 });
3478
3479 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3480 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
3481
3482 assert_eq!(result.text, "Understood, I won't do that.");
3484 }
3485
3486 #[tokio::test]
3487 async fn test_agent_hitl_timeout_reject() {
3488 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3489 use tokio::sync::broadcast;
3490
3491 let mock_client = Arc::new(MockLlmClient::new(vec![
3492 MockLlmClient::tool_call_response(
3493 "tool-1",
3494 "bash",
3495 serde_json::json!({"command": "echo test"}),
3496 ),
3497 MockLlmClient::text_response("Timed out, I understand."),
3498 ]));
3499
3500 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3501
3502 let (event_tx, _event_rx) = broadcast::channel(100);
3504 let hitl_policy = ConfirmationPolicy {
3505 enabled: true,
3506 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
3508 ..Default::default()
3509 };
3510 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3511
3512 let permission_policy = PermissionPolicy::new();
3513
3514 let config = AgentConfig {
3515 permission_checker: Some(Arc::new(permission_policy)),
3516 confirmation_manager: Some(confirmation_manager),
3517 ..Default::default()
3518 };
3519
3520 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3522 let result = agent.execute(&[], "Echo", None).await.unwrap();
3523
3524 assert_eq!(result.text, "Timed out, I understand.");
3526 }
3527
3528 #[tokio::test]
3529 async fn test_agent_hitl_timeout_auto_approve() {
3530 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3531 use tokio::sync::broadcast;
3532
3533 let mock_client = Arc::new(MockLlmClient::new(vec![
3534 MockLlmClient::tool_call_response(
3535 "tool-1",
3536 "bash",
3537 serde_json::json!({"command": "echo hello"}),
3538 ),
3539 MockLlmClient::text_response("Auto-approved and executed!"),
3540 ]));
3541
3542 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3543
3544 let (event_tx, _event_rx) = broadcast::channel(100);
3546 let hitl_policy = ConfirmationPolicy {
3547 enabled: true,
3548 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
3550 ..Default::default()
3551 };
3552 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3553
3554 let permission_policy = PermissionPolicy::new();
3555
3556 let config = AgentConfig {
3557 permission_checker: Some(Arc::new(permission_policy)),
3558 confirmation_manager: Some(confirmation_manager),
3559 ..Default::default()
3560 };
3561
3562 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3564 let result = agent.execute(&[], "Echo", None).await.unwrap();
3565
3566 assert_eq!(result.text, "Auto-approved and executed!");
3568 assert_eq!(result.tool_calls_count, 1);
3569 }
3570
3571 #[tokio::test]
3572 async fn test_agent_hitl_confirmation_events() {
3573 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3574 use tokio::sync::broadcast;
3575
3576 let mock_client = Arc::new(MockLlmClient::new(vec![
3577 MockLlmClient::tool_call_response(
3578 "tool-1",
3579 "bash",
3580 serde_json::json!({"command": "echo test"}),
3581 ),
3582 MockLlmClient::text_response("Done!"),
3583 ]));
3584
3585 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3586
3587 let (event_tx, mut event_rx) = broadcast::channel(100);
3589 let hitl_policy = ConfirmationPolicy {
3590 enabled: true,
3591 default_timeout_ms: 5000, ..Default::default()
3593 };
3594 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3595
3596 let permission_policy = PermissionPolicy::new();
3597
3598 let config = AgentConfig {
3599 permission_checker: Some(Arc::new(permission_policy)),
3600 confirmation_manager: Some(confirmation_manager.clone()),
3601 ..Default::default()
3602 };
3603
3604 let cm_clone = confirmation_manager.clone();
3606 let event_handle = tokio::spawn(async move {
3607 let mut events = Vec::new();
3608 while let Ok(event) = event_rx.recv().await {
3610 events.push(event.clone());
3611 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3612 cm_clone.confirm(&tool_id, true, None).await.ok();
3614 if let Ok(recv_event) = event_rx.recv().await {
3616 events.push(recv_event);
3617 }
3618 break;
3619 }
3620 }
3621 events
3622 });
3623
3624 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3625 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3626
3627 let events = event_handle.await.unwrap();
3629 assert!(
3630 events
3631 .iter()
3632 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3633 "Should have ConfirmationRequired event"
3634 );
3635 assert!(
3636 events
3637 .iter()
3638 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3639 "Should have ConfirmationReceived event with approved=true"
3640 );
3641 }
3642
3643 #[tokio::test]
3644 async fn test_agent_hitl_disabled_auto_executes() {
3645 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3647 use tokio::sync::broadcast;
3648
3649 let mock_client = Arc::new(MockLlmClient::new(vec![
3650 MockLlmClient::tool_call_response(
3651 "tool-1",
3652 "bash",
3653 serde_json::json!({"command": "echo auto"}),
3654 ),
3655 MockLlmClient::text_response("Auto executed!"),
3656 ]));
3657
3658 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3659
3660 let (event_tx, _event_rx) = broadcast::channel(100);
3662 let hitl_policy = ConfirmationPolicy {
3663 enabled: false, ..Default::default()
3665 };
3666 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3667
3668 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3671 permission_checker: Some(Arc::new(permission_policy)),
3672 confirmation_manager: Some(confirmation_manager),
3673 ..Default::default()
3674 };
3675
3676 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3677 let result = agent.execute(&[], "Echo", None).await.unwrap();
3678
3679 assert_eq!(result.text, "Auto executed!");
3681 assert_eq!(result.tool_calls_count, 1);
3682 }
3683
3684 #[tokio::test]
3685 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3686 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3688 use tokio::sync::broadcast;
3689
3690 let mock_client = Arc::new(MockLlmClient::new(vec![
3691 MockLlmClient::tool_call_response(
3692 "tool-1",
3693 "bash",
3694 serde_json::json!({"command": "rm -rf /"}),
3695 ),
3696 MockLlmClient::text_response("Blocked by permission."),
3697 ]));
3698
3699 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3700
3701 let (event_tx, mut event_rx) = broadcast::channel(100);
3703 let hitl_policy = ConfirmationPolicy {
3704 enabled: true,
3705 ..Default::default()
3706 };
3707 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3708
3709 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3711
3712 let config = AgentConfig {
3713 permission_checker: Some(Arc::new(permission_policy)),
3714 confirmation_manager: Some(confirmation_manager),
3715 ..Default::default()
3716 };
3717
3718 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3719 let result = agent.execute(&[], "Delete", None).await.unwrap();
3720
3721 assert_eq!(result.text, "Blocked by permission.");
3723
3724 let mut found_confirmation = false;
3726 while let Ok(event) = event_rx.try_recv() {
3727 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3728 found_confirmation = true;
3729 }
3730 }
3731 assert!(
3732 !found_confirmation,
3733 "HITL should not be triggered when permission is Deny"
3734 );
3735 }
3736
3737 #[tokio::test]
3738 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3739 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3742 use tokio::sync::broadcast;
3743
3744 let mock_client = Arc::new(MockLlmClient::new(vec![
3745 MockLlmClient::tool_call_response(
3746 "tool-1",
3747 "bash",
3748 serde_json::json!({"command": "echo hello"}),
3749 ),
3750 MockLlmClient::text_response("Allowed!"),
3751 ]));
3752
3753 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3754
3755 let (event_tx, mut event_rx) = broadcast::channel(100);
3757 let hitl_policy = ConfirmationPolicy {
3758 enabled: true,
3759 ..Default::default()
3760 };
3761 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3762
3763 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3765
3766 let config = AgentConfig {
3767 permission_checker: Some(Arc::new(permission_policy)),
3768 confirmation_manager: Some(confirmation_manager.clone()),
3769 ..Default::default()
3770 };
3771
3772 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3773 let result = agent.execute(&[], "Echo", None).await.unwrap();
3774
3775 assert_eq!(result.text, "Allowed!");
3777
3778 let mut found_confirmation = false;
3780 while let Ok(event) = event_rx.try_recv() {
3781 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3782 found_confirmation = true;
3783 }
3784 }
3785 assert!(
3786 !found_confirmation,
3787 "Permission Allow should skip HITL confirmation"
3788 );
3789 }
3790
3791 #[tokio::test]
3792 async fn test_agent_hitl_multiple_tool_calls() {
3793 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3795 use tokio::sync::broadcast;
3796
3797 let mock_client = Arc::new(MockLlmClient::new(vec![
3798 LlmResponse {
3800 message: Message {
3801 role: "assistant".to_string(),
3802 content: vec![
3803 ContentBlock::ToolUse {
3804 id: "tool-1".to_string(),
3805 name: "bash".to_string(),
3806 input: serde_json::json!({"command": "echo first"}),
3807 },
3808 ContentBlock::ToolUse {
3809 id: "tool-2".to_string(),
3810 name: "bash".to_string(),
3811 input: serde_json::json!({"command": "echo second"}),
3812 },
3813 ],
3814 reasoning_content: None,
3815 },
3816 usage: TokenUsage {
3817 prompt_tokens: 10,
3818 completion_tokens: 5,
3819 total_tokens: 15,
3820 cache_read_tokens: None,
3821 cache_write_tokens: None,
3822 },
3823 stop_reason: Some("tool_use".to_string()),
3824 },
3825 MockLlmClient::text_response("Both executed!"),
3826 ]));
3827
3828 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3829
3830 let (event_tx, _event_rx) = broadcast::channel(100);
3832 let hitl_policy = ConfirmationPolicy {
3833 enabled: true,
3834 default_timeout_ms: 5000,
3835 ..Default::default()
3836 };
3837 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3838
3839 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3842 permission_checker: Some(Arc::new(permission_policy)),
3843 confirmation_manager: Some(confirmation_manager.clone()),
3844 ..Default::default()
3845 };
3846
3847 let cm_clone = confirmation_manager.clone();
3849 tokio::spawn(async move {
3850 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3851 cm_clone.confirm("tool-1", true, None).await.ok();
3852 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3853 cm_clone.confirm("tool-2", true, None).await.ok();
3854 });
3855
3856 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3857 let result = agent.execute(&[], "Run both", None).await.unwrap();
3858
3859 assert_eq!(result.text, "Both executed!");
3860 assert_eq!(result.tool_calls_count, 2);
3861 }
3862
3863 #[tokio::test]
3864 async fn test_agent_hitl_partial_approval() {
3865 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3867 use tokio::sync::broadcast;
3868
3869 let mock_client = Arc::new(MockLlmClient::new(vec![
3870 LlmResponse {
3872 message: Message {
3873 role: "assistant".to_string(),
3874 content: vec![
3875 ContentBlock::ToolUse {
3876 id: "tool-1".to_string(),
3877 name: "bash".to_string(),
3878 input: serde_json::json!({"command": "echo safe"}),
3879 },
3880 ContentBlock::ToolUse {
3881 id: "tool-2".to_string(),
3882 name: "bash".to_string(),
3883 input: serde_json::json!({"command": "rm -rf /"}),
3884 },
3885 ],
3886 reasoning_content: None,
3887 },
3888 usage: TokenUsage {
3889 prompt_tokens: 10,
3890 completion_tokens: 5,
3891 total_tokens: 15,
3892 cache_read_tokens: None,
3893 cache_write_tokens: None,
3894 },
3895 stop_reason: Some("tool_use".to_string()),
3896 },
3897 MockLlmClient::text_response("First worked, second rejected."),
3898 ]));
3899
3900 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3901
3902 let (event_tx, _event_rx) = broadcast::channel(100);
3903 let hitl_policy = ConfirmationPolicy {
3904 enabled: true,
3905 default_timeout_ms: 5000,
3906 ..Default::default()
3907 };
3908 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3909
3910 let permission_policy = PermissionPolicy::new();
3911
3912 let config = AgentConfig {
3913 permission_checker: Some(Arc::new(permission_policy)),
3914 confirmation_manager: Some(confirmation_manager.clone()),
3915 ..Default::default()
3916 };
3917
3918 let cm_clone = confirmation_manager.clone();
3920 tokio::spawn(async move {
3921 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3922 cm_clone.confirm("tool-1", true, None).await.ok();
3923 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
3924 cm_clone
3925 .confirm("tool-2", false, Some("Dangerous".to_string()))
3926 .await
3927 .ok();
3928 });
3929
3930 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3931 let result = agent.execute(&[], "Run both", None).await.unwrap();
3932
3933 assert_eq!(result.text, "First worked, second rejected.");
3934 assert_eq!(result.tool_calls_count, 2);
3935 }
3936
3937 #[tokio::test]
3938 async fn test_agent_hitl_yolo_mode_auto_approves() {
3939 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
3941 use tokio::sync::broadcast;
3942
3943 let mock_client = Arc::new(MockLlmClient::new(vec![
3944 MockLlmClient::tool_call_response(
3945 "tool-1",
3946 "read", serde_json::json!({"path": "/tmp/test.txt"}),
3948 ),
3949 MockLlmClient::text_response("File read!"),
3950 ]));
3951
3952 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3953
3954 let (event_tx, mut event_rx) = broadcast::channel(100);
3956 let mut yolo_lanes = std::collections::HashSet::new();
3957 yolo_lanes.insert(SessionLane::Query);
3958 let hitl_policy = ConfirmationPolicy {
3959 enabled: true,
3960 yolo_lanes, ..Default::default()
3962 };
3963 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3964
3965 let permission_policy = PermissionPolicy::new();
3966
3967 let config = AgentConfig {
3968 permission_checker: Some(Arc::new(permission_policy)),
3969 confirmation_manager: Some(confirmation_manager),
3970 ..Default::default()
3971 };
3972
3973 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3974 let result = agent.execute(&[], "Read file", None).await.unwrap();
3975
3976 assert_eq!(result.text, "File read!");
3978
3979 let mut found_confirmation = false;
3981 while let Ok(event) = event_rx.try_recv() {
3982 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3983 found_confirmation = true;
3984 }
3985 }
3986 assert!(
3987 !found_confirmation,
3988 "YOLO mode should not trigger confirmation"
3989 );
3990 }
3991
3992 #[tokio::test]
3993 async fn test_agent_config_with_all_options() {
3994 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3995 use tokio::sync::broadcast;
3996
3997 let (event_tx, _) = broadcast::channel(100);
3998 let hitl_policy = ConfirmationPolicy::default();
3999 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4000
4001 let permission_policy = PermissionPolicy::new().allow("bash(*)");
4002
4003 let config = AgentConfig {
4004 prompt_slots: SystemPromptSlots {
4005 extra: Some("Test system prompt".to_string()),
4006 ..Default::default()
4007 },
4008 tools: vec![],
4009 max_tool_rounds: 10,
4010 permission_checker: Some(Arc::new(permission_policy)),
4011 confirmation_manager: Some(confirmation_manager),
4012 context_providers: vec![],
4013 planning_enabled: false,
4014 goal_tracking: false,
4015 hook_engine: None,
4016 skill_registry: None,
4017 ..AgentConfig::default()
4018 };
4019
4020 assert!(config.prompt_slots.build().contains("Test system prompt"));
4021 assert_eq!(config.max_tool_rounds, 10);
4022 assert!(config.permission_checker.is_some());
4023 assert!(config.confirmation_manager.is_some());
4024 assert!(config.context_providers.is_empty());
4025
4026 let debug_str = format!("{:?}", config);
4028 assert!(debug_str.contains("AgentConfig"));
4029 assert!(debug_str.contains("permission_checker: true"));
4030 assert!(debug_str.contains("confirmation_manager: true"));
4031 assert!(debug_str.contains("context_providers: 0"));
4032 }
4033
4034 use crate::context::{ContextItem, ContextType};
4039
4040 struct MockContextProvider {
4042 name: String,
4043 items: Vec<ContextItem>,
4044 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
4045 }
4046
4047 impl MockContextProvider {
4048 fn new(name: &str) -> Self {
4049 Self {
4050 name: name.to_string(),
4051 items: Vec::new(),
4052 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
4053 }
4054 }
4055
4056 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
4057 self.items = items;
4058 self
4059 }
4060 }
4061
4062 #[async_trait::async_trait]
4063 impl ContextProvider for MockContextProvider {
4064 fn name(&self) -> &str {
4065 &self.name
4066 }
4067
4068 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
4069 let mut result = ContextResult::new(&self.name);
4070 for item in &self.items {
4071 result.add_item(item.clone());
4072 }
4073 Ok(result)
4074 }
4075
4076 async fn on_turn_complete(
4077 &self,
4078 session_id: &str,
4079 prompt: &str,
4080 response: &str,
4081 ) -> anyhow::Result<()> {
4082 let mut calls = self.on_turn_calls.write().await;
4083 calls.push((
4084 session_id.to_string(),
4085 prompt.to_string(),
4086 response.to_string(),
4087 ));
4088 Ok(())
4089 }
4090 }
4091
4092 #[tokio::test]
4093 async fn test_agent_with_context_provider() {
4094 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4095 "Response using context",
4096 )]));
4097
4098 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4099
4100 let provider =
4101 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
4102 "ctx-1",
4103 ContextType::Resource,
4104 "Relevant context here",
4105 )
4106 .with_source("test://docs/example")]);
4107
4108 let config = AgentConfig {
4109 prompt_slots: SystemPromptSlots {
4110 extra: Some("You are helpful.".to_string()),
4111 ..Default::default()
4112 },
4113 context_providers: vec![Arc::new(provider)],
4114 ..Default::default()
4115 };
4116
4117 let agent = AgentLoop::new(
4118 mock_client.clone(),
4119 tool_executor,
4120 test_tool_context(),
4121 config,
4122 );
4123 let result = agent.execute(&[], "What is X?", None).await.unwrap();
4124
4125 assert_eq!(result.text, "Response using context");
4126 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4127 }
4128
4129 #[tokio::test]
4130 async fn test_agent_context_provider_events() {
4131 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4132 "Answer",
4133 )]));
4134
4135 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4136
4137 let provider =
4138 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
4139 "item-1",
4140 ContextType::Memory,
4141 "Memory content",
4142 )
4143 .with_token_count(50)]);
4144
4145 let config = AgentConfig {
4146 context_providers: vec![Arc::new(provider)],
4147 ..Default::default()
4148 };
4149
4150 let (tx, mut rx) = mpsc::channel(100);
4151 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4152 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
4153
4154 let mut events = Vec::new();
4156 while let Ok(event) = rx.try_recv() {
4157 events.push(event);
4158 }
4159
4160 assert!(
4162 events
4163 .iter()
4164 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4165 "Should have ContextResolving event"
4166 );
4167 assert!(
4168 events
4169 .iter()
4170 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
4171 "Should have ContextResolved event"
4172 );
4173
4174 for event in &events {
4176 if let AgentEvent::ContextResolved {
4177 total_items,
4178 total_tokens,
4179 } = event
4180 {
4181 assert_eq!(*total_items, 1);
4182 assert_eq!(*total_tokens, 50);
4183 }
4184 }
4185 }
4186
4187 #[tokio::test]
4188 async fn test_agent_multiple_context_providers() {
4189 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4190 "Combined response",
4191 )]));
4192
4193 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4194
4195 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
4196 "p1-1",
4197 ContextType::Resource,
4198 "Resource from P1",
4199 )
4200 .with_token_count(100)]);
4201
4202 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
4203 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
4204 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
4205 ]);
4206
4207 let config = AgentConfig {
4208 prompt_slots: SystemPromptSlots {
4209 extra: Some("Base system prompt.".to_string()),
4210 ..Default::default()
4211 },
4212 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
4213 ..Default::default()
4214 };
4215
4216 let (tx, mut rx) = mpsc::channel(100);
4217 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4218 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
4219
4220 assert_eq!(result.text, "Combined response");
4221
4222 while let Ok(event) = rx.try_recv() {
4224 if let AgentEvent::ContextResolved {
4225 total_items,
4226 total_tokens,
4227 } = event
4228 {
4229 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
4232 }
4233 }
4234
4235 #[tokio::test]
4236 async fn test_agent_no_context_providers() {
4237 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4238 "No context",
4239 )]));
4240
4241 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4242
4243 let config = AgentConfig::default();
4245
4246 let (tx, mut rx) = mpsc::channel(100);
4247 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4248 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
4249
4250 assert_eq!(result.text, "No context");
4251
4252 let mut events = Vec::new();
4254 while let Ok(event) = rx.try_recv() {
4255 events.push(event);
4256 }
4257
4258 assert!(
4259 !events
4260 .iter()
4261 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4262 "Should NOT have ContextResolving event"
4263 );
4264 }
4265
4266 #[tokio::test]
4267 async fn test_agent_context_on_turn_complete() {
4268 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4269 "Final response",
4270 )]));
4271
4272 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4273
4274 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4275 let on_turn_calls = provider.on_turn_calls.clone();
4276
4277 let config = AgentConfig {
4278 context_providers: vec![provider],
4279 ..Default::default()
4280 };
4281
4282 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4283
4284 let result = agent
4286 .execute_with_session(&[], "User prompt", Some("sess-123"), None)
4287 .await
4288 .unwrap();
4289
4290 assert_eq!(result.text, "Final response");
4291
4292 let calls = on_turn_calls.read().await;
4294 assert_eq!(calls.len(), 1);
4295 assert_eq!(calls[0].0, "sess-123");
4296 assert_eq!(calls[0].1, "User prompt");
4297 assert_eq!(calls[0].2, "Final response");
4298 }
4299
4300 #[tokio::test]
4301 async fn test_agent_context_on_turn_complete_no_session() {
4302 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4303 "Response",
4304 )]));
4305
4306 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4307
4308 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4309 let on_turn_calls = provider.on_turn_calls.clone();
4310
4311 let config = AgentConfig {
4312 context_providers: vec![provider],
4313 ..Default::default()
4314 };
4315
4316 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4317
4318 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
4320
4321 let calls = on_turn_calls.read().await;
4323 assert!(calls.is_empty());
4324 }
4325
4326 #[tokio::test]
4327 async fn test_agent_build_augmented_system_prompt() {
4328 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
4329
4330 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4331
4332 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
4333 "doc-1",
4334 ContextType::Resource,
4335 "Auth uses JWT tokens.",
4336 )
4337 .with_source("viking://docs/auth")]);
4338
4339 let config = AgentConfig {
4340 prompt_slots: SystemPromptSlots {
4341 extra: Some("You are helpful.".to_string()),
4342 ..Default::default()
4343 },
4344 context_providers: vec![Arc::new(provider)],
4345 ..Default::default()
4346 };
4347
4348 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4349
4350 let context_results = agent.resolve_context("test", None).await;
4352 let augmented = agent.build_augmented_system_prompt(&context_results);
4353
4354 let augmented_str = augmented.unwrap();
4355 assert!(augmented_str.contains("You are helpful."));
4356 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4357 assert!(augmented_str.contains("Auth uses JWT tokens."));
4358 }
4359
4360 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4366 let mut events = Vec::new();
4367 while let Ok(event) = rx.try_recv() {
4368 events.push(event);
4369 }
4370 while let Some(event) = rx.recv().await {
4372 events.push(event);
4373 }
4374 events
4375 }
4376
4377 #[tokio::test]
4378 async fn test_agent_multi_turn_tool_chain() {
4379 let mock_client = Arc::new(MockLlmClient::new(vec![
4381 MockLlmClient::tool_call_response(
4383 "t1",
4384 "bash",
4385 serde_json::json!({"command": "echo step1"}),
4386 ),
4387 MockLlmClient::tool_call_response(
4389 "t2",
4390 "bash",
4391 serde_json::json!({"command": "echo step2"}),
4392 ),
4393 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4395 ]));
4396
4397 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4398 let config = AgentConfig::default();
4399
4400 let agent = AgentLoop::new(
4401 mock_client.clone(),
4402 tool_executor,
4403 test_tool_context(),
4404 config,
4405 );
4406 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4407
4408 assert_eq!(result.text, "Completed both steps: step1 then step2");
4409 assert_eq!(result.tool_calls_count, 2);
4410 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4411
4412 assert_eq!(result.messages[0].role, "user");
4414 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4420 }
4421
4422 #[tokio::test]
4423 async fn test_agent_conversation_history_preserved() {
4424 let existing_history = vec![
4426 Message::user("What is Rust?"),
4427 Message {
4428 role: "assistant".to_string(),
4429 content: vec![ContentBlock::Text {
4430 text: "Rust is a systems programming language.".to_string(),
4431 }],
4432 reasoning_content: None,
4433 },
4434 ];
4435
4436 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4437 "Rust was created by Graydon Hoare at Mozilla.",
4438 )]));
4439
4440 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4441 let agent = AgentLoop::new(
4442 mock_client.clone(),
4443 tool_executor,
4444 test_tool_context(),
4445 AgentConfig::default(),
4446 );
4447
4448 let result = agent
4449 .execute(&existing_history, "Who created it?", None)
4450 .await
4451 .unwrap();
4452
4453 assert_eq!(result.messages.len(), 4);
4455 assert_eq!(result.messages[0].text(), "What is Rust?");
4456 assert_eq!(
4457 result.messages[1].text(),
4458 "Rust is a systems programming language."
4459 );
4460 assert_eq!(result.messages[2].text(), "Who created it?");
4461 assert_eq!(
4462 result.messages[3].text(),
4463 "Rust was created by Graydon Hoare at Mozilla."
4464 );
4465 }
4466
4467 #[tokio::test]
4468 async fn test_agent_event_stream_completeness() {
4469 let mock_client = Arc::new(MockLlmClient::new(vec![
4471 MockLlmClient::tool_call_response(
4472 "t1",
4473 "bash",
4474 serde_json::json!({"command": "echo hi"}),
4475 ),
4476 MockLlmClient::text_response("Done"),
4477 ]));
4478
4479 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4480 let agent = AgentLoop::new(
4481 mock_client,
4482 tool_executor,
4483 test_tool_context(),
4484 AgentConfig::default(),
4485 );
4486
4487 let (tx, rx) = mpsc::channel(100);
4488 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4489 assert_eq!(result.text, "Done");
4490
4491 let events = collect_events(rx).await;
4492
4493 let event_types: Vec<&str> = events
4495 .iter()
4496 .map(|e| match e {
4497 AgentEvent::Start { .. } => "Start",
4498 AgentEvent::TurnStart { .. } => "TurnStart",
4499 AgentEvent::TurnEnd { .. } => "TurnEnd",
4500 AgentEvent::ToolEnd { .. } => "ToolEnd",
4501 AgentEvent::End { .. } => "End",
4502 _ => "Other",
4503 })
4504 .collect();
4505
4506 assert_eq!(event_types.first(), Some(&"Start"));
4508 assert_eq!(event_types.last(), Some(&"End"));
4509
4510 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4512 assert_eq!(turn_starts, 2);
4513
4514 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4516 assert_eq!(tool_ends, 1);
4517 }
4518
4519 #[tokio::test]
4520 async fn test_agent_multiple_tools_single_turn() {
4521 let mock_client = Arc::new(MockLlmClient::new(vec![
4523 LlmResponse {
4524 message: Message {
4525 role: "assistant".to_string(),
4526 content: vec![
4527 ContentBlock::ToolUse {
4528 id: "t1".to_string(),
4529 name: "bash".to_string(),
4530 input: serde_json::json!({"command": "echo first"}),
4531 },
4532 ContentBlock::ToolUse {
4533 id: "t2".to_string(),
4534 name: "bash".to_string(),
4535 input: serde_json::json!({"command": "echo second"}),
4536 },
4537 ],
4538 reasoning_content: None,
4539 },
4540 usage: TokenUsage {
4541 prompt_tokens: 10,
4542 completion_tokens: 5,
4543 total_tokens: 15,
4544 cache_read_tokens: None,
4545 cache_write_tokens: None,
4546 },
4547 stop_reason: Some("tool_use".to_string()),
4548 },
4549 MockLlmClient::text_response("Both commands ran"),
4550 ]));
4551
4552 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4553 let agent = AgentLoop::new(
4554 mock_client.clone(),
4555 tool_executor,
4556 test_tool_context(),
4557 AgentConfig::default(),
4558 );
4559
4560 let result = agent.execute(&[], "Run both", None).await.unwrap();
4561
4562 assert_eq!(result.text, "Both commands ran");
4563 assert_eq!(result.tool_calls_count, 2);
4564 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4568 assert_eq!(result.messages[1].role, "assistant");
4569 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4572 }
4573
4574 #[tokio::test]
4575 async fn test_agent_token_usage_accumulation() {
4576 let mock_client = Arc::new(MockLlmClient::new(vec![
4578 MockLlmClient::tool_call_response(
4579 "t1",
4580 "bash",
4581 serde_json::json!({"command": "echo x"}),
4582 ),
4583 MockLlmClient::text_response("Done"),
4584 ]));
4585
4586 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4587 let agent = AgentLoop::new(
4588 mock_client,
4589 tool_executor,
4590 test_tool_context(),
4591 AgentConfig::default(),
4592 );
4593
4594 let result = agent.execute(&[], "test", None).await.unwrap();
4595
4596 assert_eq!(result.usage.prompt_tokens, 20);
4599 assert_eq!(result.usage.completion_tokens, 10);
4600 assert_eq!(result.usage.total_tokens, 30);
4601 }
4602
4603 #[tokio::test]
4604 async fn test_agent_system_prompt_passed() {
4605 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4607 "I am a coding assistant.",
4608 )]));
4609
4610 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4611 let config = AgentConfig {
4612 prompt_slots: SystemPromptSlots {
4613 extra: Some("You are a coding assistant.".to_string()),
4614 ..Default::default()
4615 },
4616 ..Default::default()
4617 };
4618
4619 let agent = AgentLoop::new(
4620 mock_client.clone(),
4621 tool_executor,
4622 test_tool_context(),
4623 config,
4624 );
4625 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4626
4627 assert_eq!(result.text, "I am a coding assistant.");
4628 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4629 }
4630
4631 #[tokio::test]
4632 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4633 let mut responses = Vec::new();
4635 for i in 0..15 {
4636 responses.push(MockLlmClient::tool_call_response(
4637 &format!("t{}", i),
4638 "bash",
4639 serde_json::json!({"command": format!("echo round{}", i)}),
4640 ));
4641 }
4642
4643 let mock_client = Arc::new(MockLlmClient::new(responses));
4644 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4645 let config = AgentConfig {
4646 max_tool_rounds: 5,
4647 ..Default::default()
4648 };
4649
4650 let agent = AgentLoop::new(
4651 mock_client.clone(),
4652 tool_executor,
4653 test_tool_context(),
4654 config,
4655 );
4656 let result = agent.execute(&[], "Loop forever", None).await;
4657
4658 assert!(result.is_err());
4659 let err = result.unwrap_err().to_string();
4660 assert!(err.contains("Max tool rounds (5) exceeded"));
4661 }
4662
4663 #[tokio::test]
4664 async fn test_agent_end_event_contains_final_text() {
4665 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4666 "Final answer here",
4667 )]));
4668
4669 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4670 let agent = AgentLoop::new(
4671 mock_client,
4672 tool_executor,
4673 test_tool_context(),
4674 AgentConfig::default(),
4675 );
4676
4677 let (tx, rx) = mpsc::channel(100);
4678 agent.execute(&[], "test", Some(tx)).await.unwrap();
4679
4680 let events = collect_events(rx).await;
4681 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4682 assert!(end_event.is_some());
4683
4684 if let AgentEvent::End { text, usage } = end_event.unwrap() {
4685 assert_eq!(text, "Final answer here");
4686 assert_eq!(usage.total_tokens, 15);
4687 }
4688 }
4689}
4690
4691#[cfg(test)]
4692mod extra_agent_tests {
4693 use super::*;
4694 use crate::agent::tests::MockLlmClient;
4695 use crate::queue::SessionQueueConfig;
4696 use crate::tools::ToolExecutor;
4697 use std::path::PathBuf;
4698 use std::sync::atomic::{AtomicUsize, Ordering};
4699
4700 fn test_tool_context() -> ToolContext {
4701 ToolContext::new(PathBuf::from("/tmp"))
4702 }
4703
4704 #[test]
4709 fn test_agent_config_debug() {
4710 let config = AgentConfig {
4711 prompt_slots: SystemPromptSlots {
4712 extra: Some("You are helpful".to_string()),
4713 ..Default::default()
4714 },
4715 tools: vec![],
4716 max_tool_rounds: 10,
4717 permission_checker: None,
4718 confirmation_manager: None,
4719 context_providers: vec![],
4720 planning_enabled: true,
4721 goal_tracking: false,
4722 hook_engine: None,
4723 skill_registry: None,
4724 ..AgentConfig::default()
4725 };
4726 let debug = format!("{:?}", config);
4727 assert!(debug.contains("AgentConfig"));
4728 assert!(debug.contains("planning_enabled"));
4729 }
4730
4731 #[test]
4732 fn test_agent_config_default_values() {
4733 let config = AgentConfig::default();
4734 assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
4735 assert!(!config.planning_enabled);
4736 assert!(!config.goal_tracking);
4737 assert!(config.context_providers.is_empty());
4738 }
4739
4740 #[test]
4745 fn test_agent_event_serialize_start() {
4746 let event = AgentEvent::Start {
4747 prompt: "Hello".to_string(),
4748 };
4749 let json = serde_json::to_string(&event).unwrap();
4750 assert!(json.contains("agent_start"));
4751 assert!(json.contains("Hello"));
4752 }
4753
4754 #[test]
4755 fn test_agent_event_serialize_text_delta() {
4756 let event = AgentEvent::TextDelta {
4757 text: "chunk".to_string(),
4758 };
4759 let json = serde_json::to_string(&event).unwrap();
4760 assert!(json.contains("text_delta"));
4761 }
4762
4763 #[test]
4764 fn test_agent_event_serialize_tool_start() {
4765 let event = AgentEvent::ToolStart {
4766 id: "t1".to_string(),
4767 name: "bash".to_string(),
4768 };
4769 let json = serde_json::to_string(&event).unwrap();
4770 assert!(json.contains("tool_start"));
4771 assert!(json.contains("bash"));
4772 }
4773
4774 #[test]
4775 fn test_agent_event_serialize_tool_end() {
4776 let event = AgentEvent::ToolEnd {
4777 id: "t1".to_string(),
4778 name: "bash".to_string(),
4779 output: "hello".to_string(),
4780 exit_code: 0,
4781 metadata: None,
4782 };
4783 let json = serde_json::to_string(&event).unwrap();
4784 assert!(json.contains("tool_end"));
4785 }
4786
4787 #[test]
4788 fn test_agent_event_tool_end_has_metadata_field() {
4789 let event = AgentEvent::ToolEnd {
4790 id: "t1".to_string(),
4791 name: "write".to_string(),
4792 output: "Wrote 5 bytes".to_string(),
4793 exit_code: 0,
4794 metadata: Some(
4795 serde_json::json!({ "before": "old", "after": "new", "file_path": "f.txt" }),
4796 ),
4797 };
4798 let json = serde_json::to_string(&event).unwrap();
4799 assert!(json.contains("\"before\""));
4800 }
4801
4802 #[test]
4803 fn test_agent_event_serialize_error() {
4804 let event = AgentEvent::Error {
4805 message: "oops".to_string(),
4806 };
4807 let json = serde_json::to_string(&event).unwrap();
4808 assert!(json.contains("error"));
4809 assert!(json.contains("oops"));
4810 }
4811
4812 #[test]
4813 fn test_agent_event_serialize_confirmation_required() {
4814 let event = AgentEvent::ConfirmationRequired {
4815 tool_id: "t1".to_string(),
4816 tool_name: "bash".to_string(),
4817 args: serde_json::json!({"cmd": "rm"}),
4818 timeout_ms: 30000,
4819 };
4820 let json = serde_json::to_string(&event).unwrap();
4821 assert!(json.contains("confirmation_required"));
4822 }
4823
4824 #[test]
4825 fn test_agent_event_serialize_confirmation_received() {
4826 let event = AgentEvent::ConfirmationReceived {
4827 tool_id: "t1".to_string(),
4828 approved: true,
4829 reason: Some("safe".to_string()),
4830 };
4831 let json = serde_json::to_string(&event).unwrap();
4832 assert!(json.contains("confirmation_received"));
4833 }
4834
4835 #[test]
4836 fn test_agent_event_serialize_confirmation_timeout() {
4837 let event = AgentEvent::ConfirmationTimeout {
4838 tool_id: "t1".to_string(),
4839 action_taken: "rejected".to_string(),
4840 };
4841 let json = serde_json::to_string(&event).unwrap();
4842 assert!(json.contains("confirmation_timeout"));
4843 }
4844
4845 #[test]
4846 fn test_agent_event_serialize_external_task_pending() {
4847 let event = AgentEvent::ExternalTaskPending {
4848 task_id: "task-1".to_string(),
4849 session_id: "sess-1".to_string(),
4850 lane: crate::hitl::SessionLane::Execute,
4851 command_type: "bash".to_string(),
4852 payload: serde_json::json!({}),
4853 timeout_ms: 60000,
4854 };
4855 let json = serde_json::to_string(&event).unwrap();
4856 assert!(json.contains("external_task_pending"));
4857 }
4858
4859 #[test]
4860 fn test_agent_event_serialize_external_task_completed() {
4861 let event = AgentEvent::ExternalTaskCompleted {
4862 task_id: "task-1".to_string(),
4863 session_id: "sess-1".to_string(),
4864 success: false,
4865 };
4866 let json = serde_json::to_string(&event).unwrap();
4867 assert!(json.contains("external_task_completed"));
4868 }
4869
4870 #[test]
4871 fn test_agent_event_serialize_permission_denied() {
4872 let event = AgentEvent::PermissionDenied {
4873 tool_id: "t1".to_string(),
4874 tool_name: "bash".to_string(),
4875 args: serde_json::json!({}),
4876 reason: "denied".to_string(),
4877 };
4878 let json = serde_json::to_string(&event).unwrap();
4879 assert!(json.contains("permission_denied"));
4880 }
4881
4882 #[test]
4883 fn test_agent_event_serialize_context_compacted() {
4884 let event = AgentEvent::ContextCompacted {
4885 session_id: "sess-1".to_string(),
4886 before_messages: 100,
4887 after_messages: 20,
4888 percent_before: 0.85,
4889 };
4890 let json = serde_json::to_string(&event).unwrap();
4891 assert!(json.contains("context_compacted"));
4892 }
4893
4894 #[test]
4895 fn test_agent_event_serialize_turn_start() {
4896 let event = AgentEvent::TurnStart { turn: 3 };
4897 let json = serde_json::to_string(&event).unwrap();
4898 assert!(json.contains("turn_start"));
4899 }
4900
4901 #[test]
4902 fn test_agent_event_serialize_turn_end() {
4903 let event = AgentEvent::TurnEnd {
4904 turn: 3,
4905 usage: TokenUsage::default(),
4906 };
4907 let json = serde_json::to_string(&event).unwrap();
4908 assert!(json.contains("turn_end"));
4909 }
4910
4911 #[test]
4912 fn test_agent_event_serialize_end() {
4913 let event = AgentEvent::End {
4914 text: "Done".to_string(),
4915 usage: TokenUsage {
4916 prompt_tokens: 100,
4917 completion_tokens: 50,
4918 total_tokens: 150,
4919 cache_read_tokens: None,
4920 cache_write_tokens: None,
4921 },
4922 };
4923 let json = serde_json::to_string(&event).unwrap();
4924 assert!(json.contains("agent_end"));
4925 }
4926
4927 #[test]
4932 fn test_agent_result_fields() {
4933 let result = AgentResult {
4934 text: "output".to_string(),
4935 messages: vec![Message::user("hello")],
4936 usage: TokenUsage::default(),
4937 tool_calls_count: 3,
4938 };
4939 assert_eq!(result.text, "output");
4940 assert_eq!(result.messages.len(), 1);
4941 assert_eq!(result.tool_calls_count, 3);
4942 }
4943
4944 #[test]
4949 fn test_agent_event_serialize_context_resolving() {
4950 let event = AgentEvent::ContextResolving {
4951 providers: vec!["provider1".to_string(), "provider2".to_string()],
4952 };
4953 let json = serde_json::to_string(&event).unwrap();
4954 assert!(json.contains("context_resolving"));
4955 assert!(json.contains("provider1"));
4956 }
4957
4958 #[test]
4959 fn test_agent_event_serialize_context_resolved() {
4960 let event = AgentEvent::ContextResolved {
4961 total_items: 5,
4962 total_tokens: 1000,
4963 };
4964 let json = serde_json::to_string(&event).unwrap();
4965 assert!(json.contains("context_resolved"));
4966 assert!(json.contains("1000"));
4967 }
4968
4969 #[test]
4970 fn test_agent_event_serialize_command_dead_lettered() {
4971 let event = AgentEvent::CommandDeadLettered {
4972 command_id: "cmd-1".to_string(),
4973 command_type: "bash".to_string(),
4974 lane: "execute".to_string(),
4975 error: "timeout".to_string(),
4976 attempts: 3,
4977 };
4978 let json = serde_json::to_string(&event).unwrap();
4979 assert!(json.contains("command_dead_lettered"));
4980 assert!(json.contains("cmd-1"));
4981 }
4982
4983 #[test]
4984 fn test_agent_event_serialize_command_retry() {
4985 let event = AgentEvent::CommandRetry {
4986 command_id: "cmd-2".to_string(),
4987 command_type: "read".to_string(),
4988 lane: "query".to_string(),
4989 attempt: 2,
4990 delay_ms: 1000,
4991 };
4992 let json = serde_json::to_string(&event).unwrap();
4993 assert!(json.contains("command_retry"));
4994 assert!(json.contains("cmd-2"));
4995 }
4996
4997 #[test]
4998 fn test_agent_event_serialize_queue_alert() {
4999 let event = AgentEvent::QueueAlert {
5000 level: "warning".to_string(),
5001 alert_type: "depth".to_string(),
5002 message: "Queue depth exceeded".to_string(),
5003 };
5004 let json = serde_json::to_string(&event).unwrap();
5005 assert!(json.contains("queue_alert"));
5006 assert!(json.contains("warning"));
5007 }
5008
5009 #[test]
5010 fn test_agent_event_serialize_task_updated() {
5011 let event = AgentEvent::TaskUpdated {
5012 session_id: "sess-1".to_string(),
5013 tasks: vec![],
5014 };
5015 let json = serde_json::to_string(&event).unwrap();
5016 assert!(json.contains("task_updated"));
5017 assert!(json.contains("sess-1"));
5018 }
5019
5020 #[test]
5021 fn test_agent_event_serialize_memory_stored() {
5022 let event = AgentEvent::MemoryStored {
5023 memory_id: "mem-1".to_string(),
5024 memory_type: "conversation".to_string(),
5025 importance: 0.8,
5026 tags: vec!["important".to_string()],
5027 };
5028 let json = serde_json::to_string(&event).unwrap();
5029 assert!(json.contains("memory_stored"));
5030 assert!(json.contains("mem-1"));
5031 }
5032
5033 #[test]
5034 fn test_agent_event_serialize_memory_recalled() {
5035 let event = AgentEvent::MemoryRecalled {
5036 memory_id: "mem-2".to_string(),
5037 content: "Previous conversation".to_string(),
5038 relevance: 0.9,
5039 };
5040 let json = serde_json::to_string(&event).unwrap();
5041 assert!(json.contains("memory_recalled"));
5042 assert!(json.contains("mem-2"));
5043 }
5044
5045 #[test]
5046 fn test_agent_event_serialize_memories_searched() {
5047 let event = AgentEvent::MemoriesSearched {
5048 query: Some("search term".to_string()),
5049 tags: vec!["tag1".to_string()],
5050 result_count: 5,
5051 };
5052 let json = serde_json::to_string(&event).unwrap();
5053 assert!(json.contains("memories_searched"));
5054 assert!(json.contains("search term"));
5055 }
5056
/// Checks that the JSON produced for `AgentEvent::MemoryCleared` contains the
/// `memory_cleared` name and the tier string.
#[test]
fn test_agent_event_serialize_memory_cleared() {
    let serialized = serde_json::to_string(&AgentEvent::MemoryCleared {
        tier: String::from("short_term"),
        count: 10,
    })
    .unwrap();

    for needle in ["memory_cleared", "short_term"] {
        assert!(serialized.contains(needle));
    }
}
5067
/// Checks that the JSON produced for `AgentEvent::SubagentStart` contains the
/// `subagent_start` name and the agent name.
#[test]
fn test_agent_event_serialize_subagent_start() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentStart {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        parent_session_id: String::from("parent-sess"),
        agent: String::from("explore"),
        description: String::from("Explore codebase"),
    })
    .unwrap();

    for needle in ["subagent_start", "explore"] {
        assert!(serialized.contains(needle));
    }
}
5081
/// Checks that the JSON produced for `AgentEvent::SubagentProgress` contains
/// the `subagent_progress` name and the status string.
#[test]
fn test_agent_event_serialize_subagent_progress() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentProgress {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        status: String::from("processing"),
        metadata: serde_json::json!({"progress": 50}),
    })
    .unwrap();

    for needle in ["subagent_progress", "processing"] {
        assert!(serialized.contains(needle));
    }
}
5094
/// Checks that the JSON produced for `AgentEvent::SubagentEnd` contains the
/// `subagent_end` name and the output text.
#[test]
fn test_agent_event_serialize_subagent_end() {
    let serialized = serde_json::to_string(&AgentEvent::SubagentEnd {
        task_id: String::from("task-1"),
        session_id: String::from("child-sess"),
        agent: String::from("explore"),
        output: String::from("Found 10 files"),
        success: true,
    })
    .unwrap();

    for needle in ["subagent_end", "Found 10 files"] {
        assert!(serialized.contains(needle));
    }
}
5108
/// Checks that the JSON produced for `AgentEvent::PlanningStart` contains the
/// `planning_start` name and the prompt text.
#[test]
fn test_agent_event_serialize_planning_start() {
    let serialized = serde_json::to_string(&AgentEvent::PlanningStart {
        prompt: String::from("Build a web app"),
    })
    .unwrap();

    for needle in ["planning_start", "Build a web app"] {
        assert!(serialized.contains(needle));
    }
}
5118
/// Checks that the JSON produced for `AgentEvent::PlanningEnd` contains the
/// `planning_end` name and the `estimated_steps` field name.
#[test]
fn test_agent_event_serialize_planning_end() {
    use crate::planning::Complexity;
    use crate::planning::ExecutionPlan;

    let serialized = serde_json::to_string(&AgentEvent::PlanningEnd {
        plan: ExecutionPlan::new(String::from("Test goal"), Complexity::Simple),
        estimated_steps: 3,
    })
    .unwrap();

    for needle in ["planning_end", "estimated_steps"] {
        assert!(serialized.contains(needle));
    }
}
5131
/// Checks that the JSON produced for `AgentEvent::StepStart` contains the
/// `step_start` name and the step description.
#[test]
fn test_agent_event_serialize_step_start() {
    let serialized = serde_json::to_string(&AgentEvent::StepStart {
        step_id: String::from("step-1"),
        description: String::from("Initialize project"),
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    for needle in ["step_start", "Initialize project"] {
        assert!(serialized.contains(needle));
    }
}
5144
/// Checks that the JSON produced for `AgentEvent::StepEnd` contains the
/// `step_end` name and the step id.
#[test]
fn test_agent_event_serialize_step_end() {
    let serialized = serde_json::to_string(&AgentEvent::StepEnd {
        step_id: String::from("step-1"),
        status: TaskStatus::Completed,
        step_number: 1,
        total_steps: 5,
    })
    .unwrap();

    for needle in ["step_end", "step-1"] {
        assert!(serialized.contains(needle));
    }
}
5157
/// Checks that the JSON produced for `AgentEvent::GoalExtracted` contains the
/// `goal_extracted` name.
#[test]
fn test_agent_event_serialize_goal_extracted() {
    use crate::planning::AgentGoal;

    let serialized = serde_json::to_string(&AgentEvent::GoalExtracted {
        goal: AgentGoal::new(String::from("Complete the task")),
    })
    .unwrap();

    assert!(serialized.contains("goal_extracted"));
}
5166
/// Checks that the JSON produced for `AgentEvent::GoalProgress` contains the
/// `goal_progress` name and the textual progress value `0.5`.
#[test]
fn test_agent_event_serialize_goal_progress() {
    let serialized = serde_json::to_string(&AgentEvent::GoalProgress {
        goal: String::from("Build app"),
        progress: 0.5,
        completed_steps: 2,
        total_steps: 4,
    })
    .unwrap();

    for needle in ["goal_progress", "0.5"] {
        assert!(serialized.contains(needle));
    }
}
5179
/// Checks that the JSON produced for `AgentEvent::GoalAchieved` contains the
/// `goal_achieved` name and the duration value `5000`.
#[test]
fn test_agent_event_serialize_goal_achieved() {
    let serialized = serde_json::to_string(&AgentEvent::GoalAchieved {
        goal: String::from("Build app"),
        total_steps: 4,
        duration_ms: 5000,
    })
    .unwrap();

    for needle in ["goal_achieved", "5000"] {
        assert!(serialized.contains(needle));
    }
}
5191
5192 #[tokio::test]
5193 async fn test_extract_goal_with_json_response() {
5194 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5196 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
5197 )]));
5198 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5199 let agent = AgentLoop::new(
5200 mock_client,
5201 tool_executor,
5202 test_tool_context(),
5203 AgentConfig::default(),
5204 );
5205
5206 let goal = agent.extract_goal("Build a web app").await.unwrap();
5207 assert_eq!(goal.description, "Build web app");
5208 assert_eq!(goal.success_criteria.len(), 2);
5209 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
5210 }
5211
5212 #[tokio::test]
5213 async fn test_extract_goal_fallback_on_non_json() {
5214 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5216 "Some non-JSON response",
5217 )]));
5218 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5219 let agent = AgentLoop::new(
5220 mock_client,
5221 tool_executor,
5222 test_tool_context(),
5223 AgentConfig::default(),
5224 );
5225
5226 let goal = agent.extract_goal("Do something").await.unwrap();
5227 assert_eq!(goal.description, "Do something");
5229 assert_eq!(goal.success_criteria.len(), 2);
5231 }
5232
5233 #[tokio::test]
5234 async fn test_check_goal_achievement_json_yes() {
5235 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5236 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
5237 )]));
5238 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5239 let agent = AgentLoop::new(
5240 mock_client,
5241 tool_executor,
5242 test_tool_context(),
5243 AgentConfig::default(),
5244 );
5245
5246 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5247 let achieved = agent
5248 .check_goal_achievement(&goal, "All done")
5249 .await
5250 .unwrap();
5251 assert!(achieved);
5252 }
5253
5254 #[tokio::test]
5255 async fn test_check_goal_achievement_fallback_not_done() {
5256 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5258 "invalid json",
5259 )]));
5260 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5261 let agent = AgentLoop::new(
5262 mock_client,
5263 tool_executor,
5264 test_tool_context(),
5265 AgentConfig::default(),
5266 );
5267
5268 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5269 let achieved = agent
5271 .check_goal_achievement(&goal, "still working")
5272 .await
5273 .unwrap();
5274 assert!(!achieved);
5275 }
5276
/// With no context results and an `extra` prompt slot set, the augmented
/// system prompt is expected to contain the text of that slot.
#[test]
fn test_build_augmented_system_prompt_empty_context() {
    let slots = SystemPromptSlots {
        extra: Some("Base prompt".to_string()),
        ..Default::default()
    };
    let config = AgentConfig {
        prompt_slots: slots,
        ..Default::default()
    };

    let tools = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let agent = AgentLoop::new(llm, tools, test_tool_context(), config);

    let prompt = agent.build_augmented_system_prompt(&[]).unwrap();
    assert!(prompt.contains("Base prompt"));
}
5297
/// With a default config and no context results, a system prompt is still
/// expected, and it should contain the text "Core Behaviour".
#[test]
fn test_build_augmented_system_prompt_no_custom_slots() {
    let tools = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let agent = AgentLoop::new(llm, tools, test_tool_context(), AgentConfig::default());

    let prompt = agent.build_augmented_system_prompt(&[]);
    assert!(prompt.is_some());

    let text = prompt.unwrap();
    assert!(text.contains("Core Behaviour"));
}
5314
/// With a default config and one context result, the augmented system prompt
/// is expected to contain a `<context` marker and the item's content.
#[test]
fn test_build_augmented_system_prompt_with_context_no_base() {
    use crate::context::{ContextItem, ContextResult, ContextType};

    let tools = Arc::new(ToolExecutor::new("/tmp".to_string()));
    let llm = Arc::new(MockLlmClient::new(vec![]));
    let agent = AgentLoop::new(llm, tools, test_tool_context(), AgentConfig::default());

    let item = ContextItem::new("id1", ContextType::Resource, "Content");
    let provided = vec![ContextResult {
        provider: "test".to_string(),
        items: vec![item],
        total_tokens: 10,
        truncated: false,
    }];

    let prompt = agent.build_augmented_system_prompt(&provided);
    assert!(prompt.is_some());

    let text = prompt.unwrap();
    for needle in ["<context", "Content"] {
        assert!(text.contains(needle));
    }
}
5341
/// A cloned `AgentResult` is expected to carry the same `text` and
/// `tool_calls_count` as the value it was cloned from.
#[test]
fn test_agent_result_clone() {
    let original = AgentResult {
        text: String::from("output"),
        messages: vec![Message::user("hello")],
        usage: TokenUsage::default(),
        tool_calls_count: 3,
    };

    let copy = original.clone();

    assert_eq!(copy.text, original.text);
    assert_eq!(copy.tool_calls_count, original.tool_calls_count);
}
5358
/// The `Debug` rendering of an `AgentResult` is expected to mention the type
/// name and the result text.
#[test]
fn test_agent_result_debug() {
    let rendered = format!(
        "{:?}",
        AgentResult {
            text: String::from("output"),
            messages: vec![Message::user("hello")],
            usage: TokenUsage::default(),
            tool_calls_count: 3,
        }
    );

    for needle in ["AgentResult", "output"] {
        assert!(rendered.contains(needle));
    }
}
5371
5372 #[tokio::test]
5381 async fn test_tool_command_command_type() {
5382 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5383 let cmd = ToolCommand {
5384 tool_executor: executor,
5385 tool_name: "read".to_string(),
5386 tool_args: serde_json::json!({"file": "test.rs"}),
5387 skill_registry: None,
5388 tool_context: test_tool_context(),
5389 };
5390 assert_eq!(cmd.command_type(), "read");
5391 }
5392
5393 #[tokio::test]
5394 async fn test_tool_command_payload() {
5395 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5396 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5397 let cmd = ToolCommand {
5398 tool_executor: executor,
5399 tool_name: "read".to_string(),
5400 tool_args: args.clone(),
5401 skill_registry: None,
5402 tool_context: test_tool_context(),
5403 };
5404 assert_eq!(cmd.payload(), args);
5405 }
5406
5407 #[tokio::test(flavor = "multi_thread")]
5412 async fn test_agent_loop_with_queue() {
5413 use tokio::sync::broadcast;
5414
5415 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5416 "Hello",
5417 )]));
5418 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5419 let config = AgentConfig::default();
5420
5421 let (event_tx, _) = broadcast::channel(100);
5422 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5423 .await
5424 .unwrap();
5425
5426 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5427 .with_queue(Arc::new(queue));
5428
5429 assert!(agent.command_queue.is_some());
5430 }
5431
5432 #[tokio::test]
5433 async fn test_agent_loop_without_queue() {
5434 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5435 "Hello",
5436 )]));
5437 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5438 let config = AgentConfig::default();
5439
5440 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5441
5442 assert!(agent.command_queue.is_none());
5443 }
5444
5445 #[tokio::test]
5450 async fn test_execute_plan_parallel_independent() {
5451 use crate::planning::{Complexity, ExecutionPlan, Task};
5452
5453 let mock_client = Arc::new(MockLlmClient::new(vec![
5456 MockLlmClient::text_response("Step 1 done"),
5457 MockLlmClient::text_response("Step 2 done"),
5458 MockLlmClient::text_response("Step 3 done"),
5459 ]));
5460
5461 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5462 let config = AgentConfig::default();
5463 let agent = AgentLoop::new(
5464 mock_client.clone(),
5465 tool_executor,
5466 test_tool_context(),
5467 config,
5468 );
5469
5470 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5471 plan.add_step(Task::new("s1", "First step"));
5472 plan.add_step(Task::new("s2", "Second step"));
5473 plan.add_step(Task::new("s3", "Third step"));
5474
5475 let (tx, mut rx) = mpsc::channel(100);
5476 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5477
5478 assert_eq!(result.usage.total_tokens, 45);
5480
5481 let mut step_starts = Vec::new();
5483 let mut step_ends = Vec::new();
5484 rx.close();
5485 while let Some(event) = rx.recv().await {
5486 match event {
5487 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5488 AgentEvent::StepEnd {
5489 step_id, status, ..
5490 } => {
5491 assert_eq!(status, TaskStatus::Completed);
5492 step_ends.push(step_id);
5493 }
5494 _ => {}
5495 }
5496 }
5497 assert_eq!(step_starts.len(), 3);
5498 assert_eq!(step_ends.len(), 3);
5499 }
5500
5501 #[tokio::test]
5502 async fn test_execute_plan_respects_dependencies() {
5503 use crate::planning::{Complexity, ExecutionPlan, Task};
5504
5505 let mock_client = Arc::new(MockLlmClient::new(vec![
5508 MockLlmClient::text_response("Step 1 done"),
5509 MockLlmClient::text_response("Step 2 done"),
5510 MockLlmClient::text_response("Step 3 done"),
5511 ]));
5512
5513 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5514 let config = AgentConfig::default();
5515 let agent = AgentLoop::new(
5516 mock_client.clone(),
5517 tool_executor,
5518 test_tool_context(),
5519 config,
5520 );
5521
5522 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5523 plan.add_step(Task::new("s1", "Independent A"));
5524 plan.add_step(Task::new("s2", "Independent B"));
5525 plan.add_step(
5526 Task::new("s3", "Depends on A+B")
5527 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5528 );
5529
5530 let (tx, mut rx) = mpsc::channel(100);
5531 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5532
5533 assert_eq!(result.usage.total_tokens, 45);
5535
5536 let mut events = Vec::new();
5538 rx.close();
5539 while let Some(event) = rx.recv().await {
5540 match &event {
5541 AgentEvent::StepStart { step_id, .. } => {
5542 events.push(format!("start:{}", step_id));
5543 }
5544 AgentEvent::StepEnd { step_id, .. } => {
5545 events.push(format!("end:{}", step_id));
5546 }
5547 _ => {}
5548 }
5549 }
5550
5551 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5553 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5554 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5555 assert!(
5556 s3_start > s1_end,
5557 "s3 started before s1 ended: {:?}",
5558 events
5559 );
5560 assert!(
5561 s3_start > s2_end,
5562 "s3 started before s2 ended: {:?}",
5563 events
5564 );
5565
5566 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5568 }
5569
5570 #[tokio::test]
5571 async fn test_execute_plan_handles_step_failure() {
5572 use crate::planning::{Complexity, ExecutionPlan, Task};
5573
5574 let mock_client = Arc::new(MockLlmClient::new(vec![
5584 MockLlmClient::text_response("s1 done"),
5586 MockLlmClient::text_response("s3 done"),
5587 ]));
5590
5591 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5592 let config = AgentConfig::default();
5593 let agent = AgentLoop::new(
5594 mock_client.clone(),
5595 tool_executor,
5596 test_tool_context(),
5597 config,
5598 );
5599
5600 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5601 plan.add_step(Task::new("s1", "Independent step"));
5602 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5603 plan.add_step(Task::new("s3", "Another independent"));
5604 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5605
5606 let (tx, mut rx) = mpsc::channel(100);
5607 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5608
5609 let mut completed_steps = Vec::new();
5612 let mut failed_steps = Vec::new();
5613 rx.close();
5614 while let Some(event) = rx.recv().await {
5615 if let AgentEvent::StepEnd {
5616 step_id, status, ..
5617 } = event
5618 {
5619 match status {
5620 TaskStatus::Completed => completed_steps.push(step_id),
5621 TaskStatus::Failed => failed_steps.push(step_id),
5622 _ => {}
5623 }
5624 }
5625 }
5626
5627 assert!(
5628 completed_steps.contains(&"s1".to_string()),
5629 "s1 should complete"
5630 );
5631 assert!(
5632 completed_steps.contains(&"s3".to_string()),
5633 "s3 should complete"
5634 );
5635 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5636 assert!(
5638 !completed_steps.contains(&"s4".to_string()),
5639 "s4 should not complete"
5640 );
5641 assert!(
5642 !failed_steps.contains(&"s4".to_string()),
5643 "s4 should not fail (never started)"
5644 );
5645 }
5646
/// Pins the default values of the resilience-related `AgentConfig` fields:
/// two parse retries, no tool timeout, and a circuit breaker threshold of 3.
#[test]
fn test_agent_config_resilience_defaults() {
    let AgentConfig {
        max_parse_retries,
        tool_timeout_ms,
        circuit_breaker_threshold,
        ..
    } = AgentConfig::default();

    assert_eq!(max_parse_retries, 2);
    assert_eq!(tool_timeout_ms, None);
    assert_eq!(circuit_breaker_threshold, 3);
}
5658
5659 #[tokio::test]
5661 async fn test_parse_error_recovery_bails_after_threshold() {
5662 let mock_client = Arc::new(MockLlmClient::new(vec![
5664 MockLlmClient::tool_call_response(
5665 "c1",
5666 "bash",
5667 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5668 ),
5669 MockLlmClient::tool_call_response(
5670 "c2",
5671 "bash",
5672 serde_json::json!({"__parse_error": "missing closing brace"}),
5673 ),
5674 MockLlmClient::tool_call_response(
5675 "c3",
5676 "bash",
5677 serde_json::json!({"__parse_error": "still broken"}),
5678 ),
5679 MockLlmClient::text_response("Done"), ]));
5681
5682 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5683 let config = AgentConfig {
5684 max_parse_retries: 2,
5685 ..AgentConfig::default()
5686 };
5687 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5688 let result = agent.execute(&[], "Do something", None).await;
5689 assert!(result.is_err(), "should bail after parse error threshold");
5690 let err = result.unwrap_err().to_string();
5691 assert!(
5692 err.contains("malformed tool arguments"),
5693 "error should mention malformed tool arguments, got: {}",
5694 err
5695 );
5696 }
5697
5698 #[tokio::test]
5700 async fn test_parse_error_counter_resets_on_success() {
5701 let mock_client = Arc::new(MockLlmClient::new(vec![
5705 MockLlmClient::tool_call_response(
5706 "c1",
5707 "bash",
5708 serde_json::json!({"__parse_error": "bad args"}),
5709 ),
5710 MockLlmClient::tool_call_response(
5711 "c2",
5712 "bash",
5713 serde_json::json!({"__parse_error": "bad args again"}),
5714 ),
5715 MockLlmClient::tool_call_response(
5717 "c3",
5718 "bash",
5719 serde_json::json!({"command": "echo ok"}),
5720 ),
5721 MockLlmClient::text_response("All done"),
5722 ]));
5723
5724 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5725 let config = AgentConfig {
5726 max_parse_retries: 2,
5727 ..AgentConfig::default()
5728 };
5729 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5730 let result = agent.execute(&[], "Do something", None).await;
5731 assert!(
5732 result.is_ok(),
5733 "should not bail — counter reset after successful tool, got: {:?}",
5734 result.err()
5735 );
5736 assert_eq!(result.unwrap().text, "All done");
5737 }
5738
5739 #[tokio::test]
5741 async fn test_tool_timeout_produces_error_result() {
5742 let mock_client = Arc::new(MockLlmClient::new(vec![
5743 MockLlmClient::tool_call_response(
5744 "t1",
5745 "bash",
5746 serde_json::json!({"command": "sleep 10"}),
5747 ),
5748 MockLlmClient::text_response("The command timed out."),
5749 ]));
5750
5751 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5752 let config = AgentConfig {
5753 tool_timeout_ms: Some(50),
5755 ..AgentConfig::default()
5756 };
5757 let agent = AgentLoop::new(
5758 mock_client.clone(),
5759 tool_executor,
5760 test_tool_context(),
5761 config,
5762 );
5763 let result = agent.execute(&[], "Run sleep", None).await;
5764 assert!(
5765 result.is_ok(),
5766 "session should continue after tool timeout: {:?}",
5767 result.err()
5768 );
5769 assert_eq!(result.unwrap().text, "The command timed out.");
5770 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5772 }
5773
5774 #[tokio::test]
5776 async fn test_tool_within_timeout_succeeds() {
5777 let mock_client = Arc::new(MockLlmClient::new(vec![
5778 MockLlmClient::tool_call_response(
5779 "t1",
5780 "bash",
5781 serde_json::json!({"command": "echo fast"}),
5782 ),
5783 MockLlmClient::text_response("Command succeeded."),
5784 ]));
5785
5786 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5787 let config = AgentConfig {
5788 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5790 };
5791 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5792 let result = agent.execute(&[], "Run something fast", None).await;
5793 assert!(
5794 result.is_ok(),
5795 "fast tool should succeed: {:?}",
5796 result.err()
5797 );
5798 assert_eq!(result.unwrap().text, "Command succeeded.");
5799 }
5800
5801 #[tokio::test]
5803 async fn test_circuit_breaker_retries_non_streaming() {
5804 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5807
5808 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5809 let config = AgentConfig {
5810 circuit_breaker_threshold: 2,
5811 ..AgentConfig::default()
5812 };
5813 let agent = AgentLoop::new(
5814 mock_client.clone(),
5815 tool_executor,
5816 test_tool_context(),
5817 config,
5818 );
5819 let result = agent.execute(&[], "Hello", None).await;
5820 assert!(result.is_err(), "should fail when LLM always errors");
5821 let err = result.unwrap_err().to_string();
5822 assert!(
5823 err.contains("circuit breaker"),
5824 "error should mention circuit breaker, got: {}",
5825 err
5826 );
5827 assert_eq!(
5828 mock_client.call_count.load(Ordering::SeqCst),
5829 2,
5830 "should make exactly threshold=2 LLM calls"
5831 );
5832 }
5833
5834 #[tokio::test]
5836 async fn test_circuit_breaker_threshold_one_no_retry() {
5837 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5838
5839 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5840 let config = AgentConfig {
5841 circuit_breaker_threshold: 1,
5842 ..AgentConfig::default()
5843 };
5844 let agent = AgentLoop::new(
5845 mock_client.clone(),
5846 tool_executor,
5847 test_tool_context(),
5848 config,
5849 );
5850 let result = agent.execute(&[], "Hello", None).await;
5851 assert!(result.is_err());
5852 assert_eq!(
5853 mock_client.call_count.load(Ordering::SeqCst),
5854 1,
5855 "with threshold=1 exactly one attempt should be made"
5856 );
5857 }
5858
5859 #[tokio::test]
5861 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
5862 struct FailOnceThenSucceed {
5864 inner: MockLlmClient,
5865 failed_once: std::sync::atomic::AtomicBool,
5866 call_count: AtomicUsize,
5867 }
5868
5869 #[async_trait::async_trait]
5870 impl LlmClient for FailOnceThenSucceed {
5871 async fn complete(
5872 &self,
5873 messages: &[Message],
5874 system: Option<&str>,
5875 tools: &[ToolDefinition],
5876 ) -> Result<LlmResponse> {
5877 self.call_count.fetch_add(1, Ordering::SeqCst);
5878 let already_failed = self
5879 .failed_once
5880 .swap(true, std::sync::atomic::Ordering::SeqCst);
5881 if !already_failed {
5882 anyhow::bail!("transient network error");
5883 }
5884 self.inner.complete(messages, system, tools).await
5885 }
5886
5887 async fn complete_streaming(
5888 &self,
5889 messages: &[Message],
5890 system: Option<&str>,
5891 tools: &[ToolDefinition],
5892 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
5893 self.inner.complete_streaming(messages, system, tools).await
5894 }
5895 }
5896
5897 let mock = Arc::new(FailOnceThenSucceed {
5898 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
5899 failed_once: std::sync::atomic::AtomicBool::new(false),
5900 call_count: AtomicUsize::new(0),
5901 });
5902
5903 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5904 let config = AgentConfig {
5905 circuit_breaker_threshold: 3,
5906 ..AgentConfig::default()
5907 };
5908 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
5909 let result = agent.execute(&[], "Hello", None).await;
5910 assert!(
5911 result.is_ok(),
5912 "should succeed when LLM recovers within threshold: {:?}",
5913 result.err()
5914 );
5915 assert_eq!(result.unwrap().text, "Recovered!");
5916 assert_eq!(
5917 mock.call_count.load(Ordering::SeqCst),
5918 2,
5919 "should have made exactly 2 calls (1 fail + 1 success)"
5920 );
5921 }
5922
/// Empty and whitespace-only text is expected to be reported as incomplete.
#[test]
fn test_looks_incomplete_empty() {
    for text in ["", " "] {
        assert!(AgentLoop::looks_incomplete(text));
    }
}
5930
/// Text ending in a colon is expected to be reported as incomplete.
#[test]
fn test_looks_incomplete_trailing_colon() {
    for text in ["Let me check the file:", "Next steps:"] {
        assert!(AgentLoop::looks_incomplete(text));
    }
}
5936
/// Text ending in an ellipsis (three dots or the single `…` character) is
/// expected to be reported as incomplete.
#[test]
fn test_looks_incomplete_ellipsis() {
    for text in ["Working on it...", "Processing…"] {
        assert!(AgentLoop::looks_incomplete(text));
    }
}
5942
/// Sentences that announce an upcoming action ("I'll …", "Let me …",
/// "I will …", "I need to …") are expected to be reported as incomplete.
#[test]
fn test_looks_incomplete_intent_phrases() {
    let announcements = [
        "I'll start by reading the file.",
        "Let me check the configuration.",
        "I will now run the tests.",
        "I need to update the Cargo.toml.",
    ];

    for text in announcements {
        assert!(AgentLoop::looks_incomplete(text));
    }
}
5956
/// Finished-sounding answers, including very short ones, are expected not to
/// be reported as incomplete.
#[test]
fn test_looks_complete_final_answer() {
    let final_answers = [
        "The tests pass. All changes have been applied successfully.",
        "Done. I've updated the three files and verified the build succeeds.",
        "42",
        "Yes.",
    ];

    for text in final_answers {
        assert!(!AgentLoop::looks_incomplete(text));
    }
}
5969
/// A multi-line summary whose first line ends in a colon but whose last line
/// is a plain bullet is expected not to be reported as incomplete.
#[test]
fn test_looks_incomplete_multiline_complete() {
    let summary = "Here is the summary:\n\n- Fixed the bug in agent.rs\n- All tests pass\n- Build succeeds";

    let flagged = AgentLoop::looks_incomplete(summary);

    assert!(!flagged);
}
5975}