1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
36const MAX_TOOL_ROUNDS: usize = 50;
38
39#[derive(Clone)]
41pub struct AgentConfig {
42 pub prompt_slots: SystemPromptSlots,
48 pub tools: Vec<ToolDefinition>,
49 pub max_tool_rounds: usize,
50 pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
52 pub permission_checker: Option<Arc<dyn PermissionChecker>>,
54 pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
56 pub context_providers: Vec<Arc<dyn ContextProvider>>,
58 pub planning_enabled: bool,
60 pub goal_tracking: bool,
62 pub hook_engine: Option<Arc<dyn HookExecutor>>,
64 pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
66 pub max_parse_retries: u32,
72 pub tool_timeout_ms: Option<u64>,
78 pub circuit_breaker_threshold: u32,
85 pub duplicate_tool_call_threshold: u32,
91 pub auto_compact: bool,
93 pub auto_compact_threshold: f32,
96 pub max_context_tokens: usize,
99 pub llm_client: Option<Arc<dyn LlmClient>>,
101 pub memory: Option<Arc<crate::memory::AgentMemory>>,
103 pub continuation_enabled: bool,
111 pub max_continuation_turns: u32,
115 pub tool_index: Option<ToolIndex>,
120}
121
122impl std::fmt::Debug for AgentConfig {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("AgentConfig")
125 .field("prompt_slots", &self.prompt_slots)
126 .field("tools", &self.tools)
127 .field("max_tool_rounds", &self.max_tool_rounds)
128 .field("security_provider", &self.security_provider.is_some())
129 .field("permission_checker", &self.permission_checker.is_some())
130 .field("confirmation_manager", &self.confirmation_manager.is_some())
131 .field("context_providers", &self.context_providers.len())
132 .field("planning_enabled", &self.planning_enabled)
133 .field("goal_tracking", &self.goal_tracking)
134 .field("hook_engine", &self.hook_engine.is_some())
135 .field(
136 "skill_registry",
137 &self.skill_registry.as_ref().map(|r| r.len()),
138 )
139 .field("max_parse_retries", &self.max_parse_retries)
140 .field("tool_timeout_ms", &self.tool_timeout_ms)
141 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
142 .field(
143 "duplicate_tool_call_threshold",
144 &self.duplicate_tool_call_threshold,
145 )
146 .field("auto_compact", &self.auto_compact)
147 .field("auto_compact_threshold", &self.auto_compact_threshold)
148 .field("max_context_tokens", &self.max_context_tokens)
149 .field("continuation_enabled", &self.continuation_enabled)
150 .field("max_continuation_turns", &self.max_continuation_turns)
151 .field("memory", &self.memory.is_some())
152 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
153 .finish()
154 }
155}
156
157impl Default for AgentConfig {
158 fn default() -> Self {
159 Self {
160 prompt_slots: SystemPromptSlots::default(),
161 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
163 security_provider: None,
164 permission_checker: None,
165 confirmation_manager: None,
166 context_providers: Vec::new(),
167 planning_enabled: false,
168 goal_tracking: false,
169 hook_engine: None,
170 skill_registry: Some(Arc::new(crate::skills::SkillRegistry::with_builtins())),
171 max_parse_retries: 2,
172 tool_timeout_ms: None,
173 circuit_breaker_threshold: 3,
174 duplicate_tool_call_threshold: 3,
175 auto_compact: false,
176 auto_compact_threshold: 0.80,
177 max_context_tokens: 200_000,
178 llm_client: None,
179 memory: None,
180 continuation_enabled: true,
181 max_continuation_turns: 3,
182 tool_index: None,
183 }
184 }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
193#[serde(tag = "type")]
194#[non_exhaustive]
195pub enum AgentEvent {
196 #[serde(rename = "agent_start")]
198 Start { prompt: String },
199
200 #[serde(rename = "turn_start")]
202 TurnStart { turn: usize },
203
204 #[serde(rename = "text_delta")]
206 TextDelta { text: String },
207
208 #[serde(rename = "tool_start")]
210 ToolStart { id: String, name: String },
211
212 #[serde(rename = "tool_input_delta")]
214 ToolInputDelta { delta: String },
215
216 #[serde(rename = "tool_end")]
218 ToolEnd {
219 id: String,
220 name: String,
221 output: String,
222 exit_code: i32,
223 #[serde(skip_serializing_if = "Option::is_none")]
224 metadata: Option<serde_json::Value>,
225 },
226
227 #[serde(rename = "tool_output_delta")]
229 ToolOutputDelta {
230 id: String,
231 name: String,
232 delta: String,
233 },
234
235 #[serde(rename = "turn_end")]
237 TurnEnd { turn: usize, usage: TokenUsage },
238
239 #[serde(rename = "agent_end")]
241 End {
242 text: String,
243 usage: TokenUsage,
244 #[serde(skip_serializing_if = "Option::is_none")]
245 meta: Option<crate::llm::LlmResponseMeta>,
246 },
247
248 #[serde(rename = "error")]
250 Error { message: String },
251
252 #[serde(rename = "confirmation_required")]
254 ConfirmationRequired {
255 tool_id: String,
256 tool_name: String,
257 args: serde_json::Value,
258 timeout_ms: u64,
259 },
260
261 #[serde(rename = "confirmation_received")]
263 ConfirmationReceived {
264 tool_id: String,
265 approved: bool,
266 reason: Option<String>,
267 },
268
269 #[serde(rename = "confirmation_timeout")]
271 ConfirmationTimeout {
272 tool_id: String,
273 action_taken: String, },
275
276 #[serde(rename = "external_task_pending")]
278 ExternalTaskPending {
279 task_id: String,
280 session_id: String,
281 lane: crate::hitl::SessionLane,
282 command_type: String,
283 payload: serde_json::Value,
284 timeout_ms: u64,
285 },
286
287 #[serde(rename = "external_task_completed")]
289 ExternalTaskCompleted {
290 task_id: String,
291 session_id: String,
292 success: bool,
293 },
294
295 #[serde(rename = "permission_denied")]
297 PermissionDenied {
298 tool_id: String,
299 tool_name: String,
300 args: serde_json::Value,
301 reason: String,
302 },
303
304 #[serde(rename = "context_resolving")]
306 ContextResolving { providers: Vec<String> },
307
308 #[serde(rename = "context_resolved")]
310 ContextResolved {
311 total_items: usize,
312 total_tokens: usize,
313 },
314
315 #[serde(rename = "command_dead_lettered")]
320 CommandDeadLettered {
321 command_id: String,
322 command_type: String,
323 lane: String,
324 error: String,
325 attempts: u32,
326 },
327
328 #[serde(rename = "command_retry")]
330 CommandRetry {
331 command_id: String,
332 command_type: String,
333 lane: String,
334 attempt: u32,
335 delay_ms: u64,
336 },
337
338 #[serde(rename = "queue_alert")]
340 QueueAlert {
341 level: String,
342 alert_type: String,
343 message: String,
344 },
345
346 #[serde(rename = "task_updated")]
351 TaskUpdated {
352 session_id: String,
353 tasks: Vec<crate::planning::Task>,
354 },
355
356 #[serde(rename = "memory_stored")]
361 MemoryStored {
362 memory_id: String,
363 memory_type: String,
364 importance: f32,
365 tags: Vec<String>,
366 },
367
368 #[serde(rename = "memory_recalled")]
370 MemoryRecalled {
371 memory_id: String,
372 content: String,
373 relevance: f32,
374 },
375
376 #[serde(rename = "memories_searched")]
378 MemoriesSearched {
379 query: Option<String>,
380 tags: Vec<String>,
381 result_count: usize,
382 },
383
384 #[serde(rename = "memory_cleared")]
386 MemoryCleared {
387 tier: String, count: u64,
389 },
390
391 #[serde(rename = "subagent_start")]
396 SubagentStart {
397 task_id: String,
399 session_id: String,
401 parent_session_id: String,
403 agent: String,
405 description: String,
407 },
408
409 #[serde(rename = "subagent_progress")]
411 SubagentProgress {
412 task_id: String,
414 session_id: String,
416 status: String,
418 metadata: serde_json::Value,
420 },
421
422 #[serde(rename = "subagent_end")]
424 SubagentEnd {
425 task_id: String,
427 session_id: String,
429 agent: String,
431 output: String,
433 success: bool,
435 },
436
437 #[serde(rename = "planning_start")]
442 PlanningStart { prompt: String },
443
444 #[serde(rename = "planning_end")]
446 PlanningEnd {
447 plan: ExecutionPlan,
448 estimated_steps: usize,
449 },
450
451 #[serde(rename = "step_start")]
453 StepStart {
454 step_id: String,
455 description: String,
456 step_number: usize,
457 total_steps: usize,
458 },
459
460 #[serde(rename = "step_end")]
462 StepEnd {
463 step_id: String,
464 status: TaskStatus,
465 step_number: usize,
466 total_steps: usize,
467 },
468
469 #[serde(rename = "goal_extracted")]
471 GoalExtracted { goal: AgentGoal },
472
473 #[serde(rename = "goal_progress")]
475 GoalProgress {
476 goal: String,
477 progress: f32,
478 completed_steps: usize,
479 total_steps: usize,
480 },
481
482 #[serde(rename = "goal_achieved")]
484 GoalAchieved {
485 goal: String,
486 total_steps: usize,
487 duration_ms: i64,
488 },
489
490 #[serde(rename = "context_compacted")]
495 ContextCompacted {
496 session_id: String,
497 before_messages: usize,
498 after_messages: usize,
499 percent_before: f32,
500 },
501
502 #[serde(rename = "persistence_failed")]
507 PersistenceFailed {
508 session_id: String,
509 operation: String,
510 error: String,
511 },
512
513 #[serde(rename = "btw_answer")]
521 BtwAnswer {
522 question: String,
523 answer: String,
524 usage: TokenUsage,
525 },
526}
527
528#[derive(Debug, Clone)]
530pub struct AgentResult {
531 pub text: String,
532 pub messages: Vec<Message>,
533 pub usage: TokenUsage,
534 pub tool_calls_count: usize,
535}
536
537pub struct ToolCommand {
545 tool_executor: Arc<ToolExecutor>,
546 tool_name: String,
547 tool_args: Value,
548 tool_context: ToolContext,
549 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
550}
551
552impl ToolCommand {
553 pub fn new(
555 tool_executor: Arc<ToolExecutor>,
556 tool_name: String,
557 tool_args: Value,
558 tool_context: ToolContext,
559 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
560 ) -> Self {
561 Self {
562 tool_executor,
563 tool_name,
564 tool_args,
565 tool_context,
566 skill_registry,
567 }
568 }
569}
570
571#[async_trait]
572impl SessionCommand for ToolCommand {
573 async fn execute(&self) -> Result<Value> {
574 if let Some(registry) = &self.skill_registry {
576 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
577
578 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
580
581 if has_restrictions {
582 let mut allowed = false;
583
584 for skill in &instruction_skills {
585 if skill.is_tool_allowed(&self.tool_name) {
586 allowed = true;
587 break;
588 }
589 }
590
591 if !allowed {
592 return Err(anyhow::anyhow!(
593 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
594 self.tool_name
595 ));
596 }
597 }
598 }
599
600 let result = self
602 .tool_executor
603 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
604 .await?;
605 Ok(serde_json::json!({
606 "output": result.output,
607 "exit_code": result.exit_code,
608 "metadata": result.metadata,
609 }))
610 }
611
612 fn command_type(&self) -> &str {
613 &self.tool_name
614 }
615
616 fn payload(&self) -> Value {
617 self.tool_args.clone()
618 }
619}
620
621#[derive(Clone)]
627pub struct AgentLoop {
628 llm_client: Arc<dyn LlmClient>,
629 tool_executor: Arc<ToolExecutor>,
630 tool_context: ToolContext,
631 config: AgentConfig,
632 tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
634 command_queue: Option<Arc<SessionLaneQueue>>,
636}
637
638impl AgentLoop {
639 pub fn new(
640 llm_client: Arc<dyn LlmClient>,
641 tool_executor: Arc<ToolExecutor>,
642 tool_context: ToolContext,
643 config: AgentConfig,
644 ) -> Self {
645 Self {
646 llm_client,
647 tool_executor,
648 tool_context,
649 config,
650 tool_metrics: None,
651 command_queue: None,
652 }
653 }
654
655 pub fn with_tool_metrics(
657 mut self,
658 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
659 ) -> Self {
660 self.tool_metrics = Some(metrics);
661 self
662 }
663
664 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
669 self.command_queue = Some(queue);
670 self
671 }
672
673 async fn execute_tool_timed(
679 &self,
680 name: &str,
681 args: &serde_json::Value,
682 ctx: &ToolContext,
683 ) -> anyhow::Result<crate::tools::ToolResult> {
684 let fut = self.tool_executor.execute_with_context(name, args, ctx);
685 if let Some(timeout_ms) = self.config.tool_timeout_ms {
686 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
687 Ok(result) => result,
688 Err(_) => Err(anyhow::anyhow!(
689 "Tool '{}' timed out after {}ms",
690 name,
691 timeout_ms
692 )),
693 }
694 } else {
695 fut.await
696 }
697 }
698
699 fn tool_result_to_tuple(
701 result: anyhow::Result<crate::tools::ToolResult>,
702 ) -> (
703 String,
704 i32,
705 bool,
706 Option<serde_json::Value>,
707 Vec<crate::llm::Attachment>,
708 ) {
709 match result {
710 Ok(r) => (
711 r.output,
712 r.exit_code,
713 r.exit_code != 0,
714 r.metadata,
715 r.images,
716 ),
717 Err(e) => {
718 let msg = e.to_string();
719 let hint = if Self::is_transient_error(&msg) {
721 " [transient — you may retry this tool call]"
722 } else {
723 " [permanent — do not retry without changing the arguments]"
724 };
725 (
726 format!("Tool execution error: {}{}", msg, hint),
727 1,
728 true,
729 None,
730 Vec::new(),
731 )
732 }
733 }
734 }
735
736 fn detect_project_hint(workspace: &std::path::Path) -> String {
740 struct Marker {
741 file: &'static str,
742 lang: &'static str,
743 tip: &'static str,
744 }
745
746 let markers = [
747 Marker {
748 file: "Cargo.toml",
749 lang: "Rust",
750 tip: "Use `cargo build`, `cargo test`, `cargo clippy`, and `cargo fmt`. \
751 Prefer `anyhow` / `thiserror` for error handling. \
752 Follow the Microsoft Rust Guidelines (no panics in library code, \
753 async-first with Tokio).",
754 },
755 Marker {
756 file: "package.json",
757 lang: "Node.js / TypeScript",
758 tip: "Check `package.json` for the package manager (npm/yarn/pnpm/bun) \
759 and available scripts. Prefer TypeScript with strict mode. \
760 Use ESM imports unless the project is CommonJS.",
761 },
762 Marker {
763 file: "pyproject.toml",
764 lang: "Python",
765 tip: "Use the package manager declared in `pyproject.toml` \
766 (uv, poetry, hatch, etc.). Prefer type hints and async/await for I/O.",
767 },
768 Marker {
769 file: "setup.py",
770 lang: "Python",
771 tip: "Legacy Python project. Prefer type hints and async/await for I/O.",
772 },
773 Marker {
774 file: "requirements.txt",
775 lang: "Python",
776 tip: "Python project with pip-style dependencies. \
777 Prefer type hints and async/await for I/O.",
778 },
779 Marker {
780 file: "go.mod",
781 lang: "Go",
782 tip: "Use `go build ./...` and `go test ./...`. \
783 Follow standard Go project layout. Use `gofmt` for formatting.",
784 },
785 Marker {
786 file: "pom.xml",
787 lang: "Java / Maven",
788 tip: "Use `mvn compile`, `mvn test`, `mvn package`. \
789 Follow standard Maven project structure.",
790 },
791 Marker {
792 file: "build.gradle",
793 lang: "Java / Gradle",
794 tip: "Use `./gradlew build` and `./gradlew test`. \
795 Follow standard Gradle project structure.",
796 },
797 Marker {
798 file: "build.gradle.kts",
799 lang: "Kotlin / Gradle",
800 tip: "Use `./gradlew build` and `./gradlew test`. \
801 Prefer Kotlin coroutines for async work.",
802 },
803 Marker {
804 file: "CMakeLists.txt",
805 lang: "C / C++",
806 tip: "Use `cmake -B build && cmake --build build`. \
807 Check for `compile_commands.json` for IDE tooling.",
808 },
809 Marker {
810 file: "Makefile",
811 lang: "C / C++ (or generic)",
812 tip: "Use `make` or `make <target>`. \
813 Check available targets with `make help` or by reading the Makefile.",
814 },
815 ];
816
817 let is_dotnet = workspace.join("*.csproj").exists() || {
819 std::fs::read_dir(workspace)
821 .map(|entries| {
822 entries.flatten().any(|e| {
823 let name = e.file_name();
824 let s = name.to_string_lossy();
825 s.ends_with(".csproj") || s.ends_with(".sln")
826 })
827 })
828 .unwrap_or(false)
829 };
830
831 if is_dotnet {
832 return "## Project Context\n\nThis is a **C# / .NET** project. \
833 Use `dotnet build`, `dotnet test`, and `dotnet run`. \
834 Follow C# coding conventions and async/await patterns."
835 .to_string();
836 }
837
838 for marker in &markers {
839 if workspace.join(marker.file).exists() {
840 return format!(
841 "## Project Context\n\nThis is a **{}** project. {}",
842 marker.lang, marker.tip
843 );
844 }
845 }
846
847 String::new()
848 }
849
850 fn is_transient_error(msg: &str) -> bool {
853 let lower = msg.to_lowercase();
854 lower.contains("timeout")
855 || lower.contains("timed out")
856 || lower.contains("connection refused")
857 || lower.contains("connection reset")
858 || lower.contains("broken pipe")
859 || lower.contains("temporarily unavailable")
860 || lower.contains("resource temporarily unavailable")
861 || lower.contains("os error 11") || lower.contains("os error 35") || lower.contains("rate limit")
864 || lower.contains("too many requests")
865 || lower.contains("service unavailable")
866 || lower.contains("network unreachable")
867 }
868
869 fn is_parallel_safe_write(name: &str, _args: &serde_json::Value) -> bool {
872 matches!(
873 name,
874 "write_file" | "edit_file" | "create_file" | "append_to_file" | "replace_in_file"
875 )
876 }
877
878 fn extract_write_path(args: &serde_json::Value) -> Option<String> {
880 args.get("path")
883 .and_then(|v| v.as_str())
884 .map(|s| s.to_string())
885 }
886
887 async fn execute_tool_queued_or_direct(
889 &self,
890 name: &str,
891 args: &serde_json::Value,
892 ctx: &ToolContext,
893 ) -> anyhow::Result<crate::tools::ToolResult> {
894 if let Some(ref queue) = self.command_queue {
895 let command = ToolCommand::new(
896 Arc::clone(&self.tool_executor),
897 name.to_string(),
898 args.clone(),
899 ctx.clone(),
900 self.config.skill_registry.clone(),
901 );
902 let rx = queue.submit_by_tool(name, Box::new(command)).await;
903 match rx.await {
904 Ok(Ok(value)) => {
905 let output = value["output"]
906 .as_str()
907 .ok_or_else(|| {
908 anyhow::anyhow!(
909 "Queue result missing 'output' field for tool '{}'",
910 name
911 )
912 })?
913 .to_string();
914 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
915 return Ok(crate::tools::ToolResult {
916 name: name.to_string(),
917 output,
918 exit_code,
919 metadata: None,
920 images: Vec::new(),
921 });
922 }
923 Ok(Err(e)) => {
924 tracing::warn!(
925 "Queue execution failed for tool '{}', falling back to direct: {}",
926 name,
927 e
928 );
929 }
930 Err(_) => {
931 tracing::warn!(
932 "Queue channel closed for tool '{}', falling back to direct",
933 name
934 );
935 }
936 }
937 }
938 self.execute_tool_timed(name, args, ctx).await
939 }
940
941 async fn call_llm(
952 &self,
953 messages: &[Message],
954 system: Option<&str>,
955 event_tx: &Option<mpsc::Sender<AgentEvent>>,
956 cancel_token: &tokio_util::sync::CancellationToken,
957 ) -> anyhow::Result<LlmResponse> {
958 let tools = if let Some(ref index) = self.config.tool_index {
960 let query = messages
961 .iter()
962 .rev()
963 .find(|m| m.role == "user")
964 .and_then(|m| {
965 m.content.iter().find_map(|b| match b {
966 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
967 _ => None,
968 })
969 })
970 .unwrap_or("");
971 let matches = index.search(query, index.len());
972 let matched_names: std::collections::HashSet<&str> =
973 matches.iter().map(|m| m.name.as_str()).collect();
974 self.config
975 .tools
976 .iter()
977 .filter(|t| matched_names.contains(t.name.as_str()))
978 .cloned()
979 .collect::<Vec<_>>()
980 } else {
981 self.config.tools.clone()
982 };
983
984 if event_tx.is_some() {
985 let mut stream_rx = match self
986 .llm_client
987 .complete_streaming(messages, system, &tools)
988 .await
989 {
990 Ok(rx) => rx,
991 Err(stream_error) => {
992 tracing::warn!(
993 error = %stream_error,
994 "LLM streaming setup failed; falling back to non-streaming completion"
995 );
996 return self
997 .llm_client
998 .complete(messages, system, &tools)
999 .await
1000 .with_context(|| {
1001 format!(
1002 "LLM streaming call failed ({stream_error}); non-streaming fallback also failed"
1003 )
1004 });
1005 }
1006 };
1007
1008 let mut final_response: Option<LlmResponse> = None;
1009 loop {
1010 tokio::select! {
1011 _ = cancel_token.cancelled() => {
1012 tracing::info!("🛑 LLM streaming cancelled by CancellationToken");
1013 anyhow::bail!("Operation cancelled by user");
1014 }
1015 event = stream_rx.recv() => {
1016 match event {
1017 Some(crate::llm::StreamEvent::TextDelta(text)) => {
1018 if let Some(tx) = event_tx {
1019 tx.send(AgentEvent::TextDelta { text }).await.ok();
1020 }
1021 }
1022 Some(crate::llm::StreamEvent::ToolUseStart { id, name }) => {
1023 if let Some(tx) = event_tx {
1024 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
1025 }
1026 }
1027 Some(crate::llm::StreamEvent::ToolUseInputDelta(delta)) => {
1028 if let Some(tx) = event_tx {
1029 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
1030 }
1031 }
1032 Some(crate::llm::StreamEvent::Done(resp)) => {
1033 final_response = Some(resp);
1034 break;
1035 }
1036 None => break,
1037 }
1038 }
1039 }
1040 }
1041 final_response.context("Stream ended without final response")
1042 } else {
1043 self.llm_client
1044 .complete(messages, system, &tools)
1045 .await
1046 .context("LLM call failed")
1047 }
1048 }
1049
1050 fn streaming_tool_context(
1059 &self,
1060 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1061 tool_id: &str,
1062 tool_name: &str,
1063 ) -> ToolContext {
1064 let mut ctx = self.tool_context.clone();
1065 if let Some(agent_tx) = event_tx {
1066 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
1067 ctx.event_tx = Some(tool_tx);
1068
1069 let agent_tx = agent_tx.clone();
1070 let tool_id = tool_id.to_string();
1071 let tool_name = tool_name.to_string();
1072 tokio::spawn(async move {
1073 while let Some(event) = tool_rx.recv().await {
1074 match event {
1075 ToolStreamEvent::OutputDelta(delta) => {
1076 agent_tx
1077 .send(AgentEvent::ToolOutputDelta {
1078 id: tool_id.clone(),
1079 name: tool_name.clone(),
1080 delta,
1081 })
1082 .await
1083 .ok();
1084 }
1085 }
1086 }
1087 });
1088 }
1089 ctx
1090 }
1091
1092 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
1096 if self.config.context_providers.is_empty() {
1097 return Vec::new();
1098 }
1099
1100 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
1101
1102 let futures = self
1103 .config
1104 .context_providers
1105 .iter()
1106 .map(|p| p.query(&query));
1107 let outcomes = join_all(futures).await;
1108
1109 outcomes
1110 .into_iter()
1111 .enumerate()
1112 .filter_map(|(i, r)| match r {
1113 Ok(result) if !result.is_empty() => Some(result),
1114 Ok(_) => None,
1115 Err(e) => {
1116 tracing::warn!(
1117 "Context provider '{}' failed: {}",
1118 self.config.context_providers[i].name(),
1119 e
1120 );
1121 None
1122 }
1123 })
1124 .collect()
1125 }
1126
1127 fn looks_incomplete(text: &str) -> bool {
1135 let t = text.trim();
1136 if t.is_empty() {
1137 return true;
1138 }
1139 if t.len() < 80 && !t.contains('\n') {
1141 let ends_continuation =
1144 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
1145 if ends_continuation {
1146 return true;
1147 }
1148 }
1149 let incomplete_phrases = [
1151 "i'll ",
1152 "i will ",
1153 "let me ",
1154 "i need to ",
1155 "i should ",
1156 "next, i",
1157 "first, i",
1158 "now i",
1159 "i'll start",
1160 "i'll begin",
1161 "i'll now",
1162 "let's start",
1163 "let's begin",
1164 "to do this",
1165 "i'm going to",
1166 ];
1167 let lower = t.to_lowercase();
1168 for phrase in &incomplete_phrases {
1169 if lower.contains(phrase) {
1170 return true;
1171 }
1172 }
1173 false
1174 }
1175
1176 fn system_prompt(&self) -> String {
1178 self.config.prompt_slots.build()
1179 }
1180
1181 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
1183 let base = self.system_prompt();
1184
1185 let live_tools = self.tool_executor.definitions();
1187 let mcp_tools: Vec<&ToolDefinition> = live_tools
1188 .iter()
1189 .filter(|t| t.name.starts_with("mcp__"))
1190 .collect();
1191
1192 let mcp_section = if mcp_tools.is_empty() {
1193 String::new()
1194 } else {
1195 let mut lines = vec![
1196 "## MCP Tools".to_string(),
1197 String::new(),
1198 "The following MCP (Model Context Protocol) tools are available. Use them when the task requires external capabilities beyond the built-in tools:".to_string(),
1199 String::new(),
1200 ];
1201 for tool in &mcp_tools {
1202 let display = format!("- `{}` — {}", tool.name, tool.description);
1203 lines.push(display);
1204 }
1205 lines.join("\n")
1206 };
1207
1208 let parts: Vec<&str> = [base.as_str(), mcp_section.as_str()]
1209 .iter()
1210 .filter(|s| !s.is_empty())
1211 .copied()
1212 .collect();
1213
1214 let project_hint = if self.config.prompt_slots.guidelines.is_none() {
1217 Self::detect_project_hint(&self.tool_context.workspace)
1218 } else {
1219 String::new()
1220 };
1221
1222 if context_results.is_empty() {
1223 if project_hint.is_empty() {
1224 return Some(parts.join("\n\n"));
1225 }
1226 return Some(format!("{}\n\n{}", parts.join("\n\n"), project_hint));
1227 }
1228
1229 let context_xml: String = context_results
1231 .iter()
1232 .map(|r| r.to_xml())
1233 .collect::<Vec<_>>()
1234 .join("\n\n");
1235
1236 if project_hint.is_empty() {
1237 Some(format!("{}\n\n{}", parts.join("\n\n"), context_xml))
1238 } else {
1239 Some(format!(
1240 "{}\n\n{}\n\n{}",
1241 parts.join("\n\n"),
1242 project_hint,
1243 context_xml
1244 ))
1245 }
1246 }
1247
1248 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
1250 let futures = self
1251 .config
1252 .context_providers
1253 .iter()
1254 .map(|p| p.on_turn_complete(session_id, prompt, response));
1255 let outcomes = join_all(futures).await;
1256
1257 for (i, result) in outcomes.into_iter().enumerate() {
1258 if let Err(e) = result {
1259 tracing::warn!(
1260 "Context provider '{}' on_turn_complete failed: {}",
1261 self.config.context_providers[i].name(),
1262 e
1263 );
1264 }
1265 }
1266 }
1267
1268 async fn fire_pre_tool_use(
1271 &self,
1272 session_id: &str,
1273 tool_name: &str,
1274 args: &serde_json::Value,
1275 recent_tools: Vec<String>,
1276 ) -> Option<HookResult> {
1277 if let Some(he) = &self.config.hook_engine {
1278 let event = HookEvent::PreToolUse(PreToolUseEvent {
1279 session_id: session_id.to_string(),
1280 tool: tool_name.to_string(),
1281 args: args.clone(),
1282 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1283 recent_tools,
1284 });
1285 let result = he.fire(&event).await;
1286 if result.is_block() {
1287 return Some(result);
1288 }
1289 }
1290 None
1291 }
1292
1293 fn tool_call_signature(tool_name: &str, args: &serde_json::Value) -> String {
1294 let raw = match args {
1295 serde_json::Value::Null => String::new(),
1296 serde_json::Value::String(s) => s.clone(),
1297 _ => serde_json::to_string(args).unwrap_or_default(),
1298 };
1299 let compact = raw.split_whitespace().collect::<Vec<_>>().join(" ");
1300 let compact = if compact.len() > 180 {
1301 format!("{}...", &compact[..180])
1302 } else {
1303 compact
1304 };
1305 format!("{tool_name}:{compact}")
1306 }
1307
1308 async fn fire_post_tool_use(
1310 &self,
1311 session_id: &str,
1312 tool_name: &str,
1313 args: &serde_json::Value,
1314 output: &str,
1315 success: bool,
1316 duration_ms: u64,
1317 ) {
1318 if let Some(he) = &self.config.hook_engine {
1319 let event = HookEvent::PostToolUse(PostToolUseEvent {
1320 session_id: session_id.to_string(),
1321 tool: tool_name.to_string(),
1322 args: args.clone(),
1323 result: ToolResultData {
1324 success,
1325 output: output.to_string(),
1326 exit_code: if success { Some(0) } else { Some(1) },
1327 duration_ms,
1328 },
1329 });
1330 let he = Arc::clone(he);
1331 tokio::spawn(async move {
1332 let _ = he.fire(&event).await;
1333 });
1334 }
1335 }
1336
1337 async fn fire_generate_start(
1339 &self,
1340 session_id: &str,
1341 prompt: &str,
1342 system_prompt: &Option<String>,
1343 ) {
1344 if let Some(he) = &self.config.hook_engine {
1345 let event = HookEvent::GenerateStart(GenerateStartEvent {
1346 session_id: session_id.to_string(),
1347 prompt: prompt.to_string(),
1348 system_prompt: system_prompt.clone(),
1349 model_provider: String::new(),
1350 model_name: String::new(),
1351 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1352 });
1353 let _ = he.fire(&event).await;
1354 }
1355 }
1356
1357 async fn fire_generate_end(
1359 &self,
1360 session_id: &str,
1361 prompt: &str,
1362 response: &LlmResponse,
1363 duration_ms: u64,
1364 ) {
1365 if let Some(he) = &self.config.hook_engine {
1366 let tool_calls: Vec<ToolCallInfo> = response
1367 .tool_calls()
1368 .iter()
1369 .map(|tc| ToolCallInfo {
1370 name: tc.name.clone(),
1371 args: tc.args.clone(),
1372 })
1373 .collect();
1374
1375 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1376 session_id: session_id.to_string(),
1377 prompt: prompt.to_string(),
1378 response_text: response.text().to_string(),
1379 tool_calls,
1380 usage: TokenUsageInfo {
1381 prompt_tokens: response.usage.prompt_tokens as i32,
1382 completion_tokens: response.usage.completion_tokens as i32,
1383 total_tokens: response.usage.total_tokens as i32,
1384 },
1385 duration_ms,
1386 });
1387 let _ = he.fire(&event).await;
1388 }
1389 }
1390
1391 async fn fire_pre_prompt(
1394 &self,
1395 session_id: &str,
1396 prompt: &str,
1397 system_prompt: &Option<String>,
1398 message_count: usize,
1399 ) -> Option<String> {
1400 if let Some(he) = &self.config.hook_engine {
1401 let event = HookEvent::PrePrompt(PrePromptEvent {
1402 session_id: session_id.to_string(),
1403 prompt: prompt.to_string(),
1404 system_prompt: system_prompt.clone(),
1405 message_count,
1406 });
1407 let result = he.fire(&event).await;
1408 if let HookResult::Continue(Some(modified)) = result {
1409 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1411 return Some(new_prompt.to_string());
1412 }
1413 }
1414 }
1415 None
1416 }
1417
1418 async fn fire_post_response(
1420 &self,
1421 session_id: &str,
1422 response_text: &str,
1423 tool_calls_count: usize,
1424 usage: &TokenUsage,
1425 duration_ms: u64,
1426 ) {
1427 if let Some(he) = &self.config.hook_engine {
1428 let event = HookEvent::PostResponse(PostResponseEvent {
1429 session_id: session_id.to_string(),
1430 response_text: response_text.to_string(),
1431 tool_calls_count,
1432 usage: TokenUsageInfo {
1433 prompt_tokens: usage.prompt_tokens as i32,
1434 completion_tokens: usage.completion_tokens as i32,
1435 total_tokens: usage.total_tokens as i32,
1436 },
1437 duration_ms,
1438 });
1439 let he = Arc::clone(he);
1440 tokio::spawn(async move {
1441 let _ = he.fire(&event).await;
1442 });
1443 }
1444 }
1445
1446 async fn fire_on_error(
1448 &self,
1449 session_id: &str,
1450 error_type: ErrorType,
1451 error_message: &str,
1452 context: serde_json::Value,
1453 ) {
1454 if let Some(he) = &self.config.hook_engine {
1455 let event = HookEvent::OnError(OnErrorEvent {
1456 session_id: session_id.to_string(),
1457 error_type,
1458 error_message: error_message.to_string(),
1459 context,
1460 });
1461 let he = Arc::clone(he);
1462 tokio::spawn(async move {
1463 let _ = he.fire(&event).await;
1464 });
1465 }
1466 }
1467
1468 pub async fn execute(
1474 &self,
1475 history: &[Message],
1476 prompt: &str,
1477 event_tx: Option<mpsc::Sender<AgentEvent>>,
1478 ) -> Result<AgentResult> {
1479 self.execute_with_session(history, prompt, None, event_tx, None)
1480 .await
1481 }
1482
1483 pub async fn execute_from_messages(
1489 &self,
1490 messages: Vec<Message>,
1491 session_id: Option<&str>,
1492 event_tx: Option<mpsc::Sender<AgentEvent>>,
1493 cancel_token: Option<&tokio_util::sync::CancellationToken>,
1494 ) -> Result<AgentResult> {
1495 let default_token = tokio_util::sync::CancellationToken::new();
1496 let token = cancel_token.unwrap_or(&default_token);
1497 tracing::info!(
1498 a3s.session.id = session_id.unwrap_or("none"),
1499 a3s.agent.max_turns = self.config.max_tool_rounds,
1500 "a3s.agent.execute_from_messages started"
1501 );
1502
1503 let effective_prompt = messages
1507 .iter()
1508 .rev()
1509 .find(|m| m.role == "user")
1510 .map(|m| m.text())
1511 .unwrap_or_default();
1512
1513 let result = self
1514 .execute_loop_inner(
1515 &messages,
1516 "",
1517 &effective_prompt,
1518 session_id,
1519 event_tx,
1520 token,
1521 )
1522 .await;
1523
1524 match &result {
1525 Ok(r) => tracing::info!(
1526 a3s.agent.tool_calls_count = r.tool_calls_count,
1527 a3s.llm.total_tokens = r.usage.total_tokens,
1528 "a3s.agent.execute_from_messages completed"
1529 ),
1530 Err(e) => tracing::warn!(
1531 error = %e,
1532 "a3s.agent.execute_from_messages failed"
1533 ),
1534 }
1535
1536 result
1537 }
1538
1539 pub async fn execute_with_session(
1544 &self,
1545 history: &[Message],
1546 prompt: &str,
1547 session_id: Option<&str>,
1548 event_tx: Option<mpsc::Sender<AgentEvent>>,
1549 cancel_token: Option<&tokio_util::sync::CancellationToken>,
1550 ) -> Result<AgentResult> {
1551 let default_token = tokio_util::sync::CancellationToken::new();
1552 let token = cancel_token.unwrap_or(&default_token);
1553 tracing::info!(
1554 a3s.session.id = session_id.unwrap_or("none"),
1555 a3s.agent.max_turns = self.config.max_tool_rounds,
1556 "a3s.agent.execute started"
1557 );
1558
1559 let result = if self.config.planning_enabled {
1561 self.execute_with_planning(history, prompt, event_tx).await
1562 } else {
1563 self.execute_loop(history, prompt, session_id, event_tx, token)
1564 .await
1565 };
1566
1567 match &result {
1568 Ok(r) => {
1569 tracing::info!(
1570 a3s.agent.tool_calls_count = r.tool_calls_count,
1571 a3s.llm.total_tokens = r.usage.total_tokens,
1572 "a3s.agent.execute completed"
1573 );
1574 self.fire_post_response(
1576 session_id.unwrap_or(""),
1577 &r.text,
1578 r.tool_calls_count,
1579 &r.usage,
1580 0, )
1582 .await;
1583 }
1584 Err(e) => {
1585 tracing::warn!(
1586 error = %e,
1587 "a3s.agent.execute failed"
1588 );
1589 self.fire_on_error(
1591 session_id.unwrap_or(""),
1592 ErrorType::Other,
1593 &e.to_string(),
1594 serde_json::json!({"phase": "execute"}),
1595 )
1596 .await;
1597 }
1598 }
1599
1600 result
1601 }
1602
1603 async fn execute_loop(
1609 &self,
1610 history: &[Message],
1611 prompt: &str,
1612 session_id: Option<&str>,
1613 event_tx: Option<mpsc::Sender<AgentEvent>>,
1614 cancel_token: &tokio_util::sync::CancellationToken,
1615 ) -> Result<AgentResult> {
1616 self.execute_loop_inner(history, prompt, prompt, session_id, event_tx, cancel_token)
1619 .await
1620 }
1621
1622 async fn execute_loop_inner(
1627 &self,
1628 history: &[Message],
1629 msg_prompt: &str,
1630 effective_prompt: &str,
1631 session_id: Option<&str>,
1632 event_tx: Option<mpsc::Sender<AgentEvent>>,
1633 cancel_token: &tokio_util::sync::CancellationToken,
1634 ) -> Result<AgentResult> {
1635 let mut messages = history.to_vec();
1636 let mut total_usage = TokenUsage::default();
1637 let mut tool_calls_count = 0;
1638 let mut turn = 0;
1639 let mut parse_error_count: u32 = 0;
1641 let mut continuation_count: u32 = 0;
1643 let mut recent_tool_signatures: Vec<String> = Vec::new();
1645 let mut last_tool_signature: Option<String> = None;
1646 let mut duplicate_tool_call_count: u32 = 0;
1647
1648 if let Some(tx) = &event_tx {
1650 tx.send(AgentEvent::Start {
1651 prompt: effective_prompt.to_string(),
1652 })
1653 .await
1654 .ok();
1655 }
1656
1657 let _queue_forward_handle =
1659 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1660 let mut rx = queue.subscribe();
1661 let tx = tx.clone();
1662 Some(tokio::spawn(async move {
1663 while let Ok(event) = rx.recv().await {
1664 if tx.send(event).await.is_err() {
1665 break;
1666 }
1667 }
1668 }))
1669 } else {
1670 None
1671 };
1672
1673 let built_system_prompt = Some(self.system_prompt());
1675 let hooked_prompt = if let Some(modified) = self
1676 .fire_pre_prompt(
1677 session_id.unwrap_or(""),
1678 effective_prompt,
1679 &built_system_prompt,
1680 messages.len(),
1681 )
1682 .await
1683 {
1684 modified
1685 } else {
1686 effective_prompt.to_string()
1687 };
1688 let effective_prompt = hooked_prompt.as_str();
1689
1690 if let Some(ref sp) = self.config.security_provider {
1692 sp.taint_input(effective_prompt);
1693 }
1694
1695 let system_with_memory = if let Some(ref memory) = self.config.memory {
1697 match memory.recall_similar(effective_prompt, 5).await {
1698 Ok(items) if !items.is_empty() => {
1699 if let Some(tx) = &event_tx {
1700 for item in &items {
1701 tx.send(AgentEvent::MemoryRecalled {
1702 memory_id: item.id.clone(),
1703 content: item.content.clone(),
1704 relevance: item.relevance_score(),
1705 })
1706 .await
1707 .ok();
1708 }
1709 tx.send(AgentEvent::MemoriesSearched {
1710 query: Some(effective_prompt.to_string()),
1711 tags: Vec::new(),
1712 result_count: items.len(),
1713 })
1714 .await
1715 .ok();
1716 }
1717 let memory_context = items
1718 .iter()
1719 .map(|i| format!("- {}", i.content))
1720 .collect::<Vec<_>>()
1721 .join(
1722 "
1723",
1724 );
1725 let base = self.system_prompt();
1726 Some(format!(
1727 "{}
1728
1729## Relevant past experience
1730{}",
1731 base, memory_context
1732 ))
1733 }
1734 _ => Some(self.system_prompt()),
1735 }
1736 } else {
1737 Some(self.system_prompt())
1738 };
1739
1740 let augmented_system = if !self.config.context_providers.is_empty() {
1742 if let Some(tx) = &event_tx {
1744 let provider_names: Vec<String> = self
1745 .config
1746 .context_providers
1747 .iter()
1748 .map(|p| p.name().to_string())
1749 .collect();
1750 tx.send(AgentEvent::ContextResolving {
1751 providers: provider_names,
1752 })
1753 .await
1754 .ok();
1755 }
1756
1757 tracing::info!(
1758 a3s.context.providers = self.config.context_providers.len() as i64,
1759 "Context resolution started"
1760 );
1761 let context_results = self.resolve_context(effective_prompt, session_id).await;
1762
1763 if let Some(tx) = &event_tx {
1765 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1766 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1767
1768 tracing::info!(
1769 context_items = total_items,
1770 context_tokens = total_tokens,
1771 "Context resolution completed"
1772 );
1773
1774 tx.send(AgentEvent::ContextResolved {
1775 total_items,
1776 total_tokens,
1777 })
1778 .await
1779 .ok();
1780 }
1781
1782 self.build_augmented_system_prompt(&context_results)
1783 } else {
1784 Some(self.system_prompt())
1785 };
1786
1787 let base_prompt = self.system_prompt();
1789 let augmented_system = match (augmented_system, system_with_memory) {
1790 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1791 (Some(ctx), _) => Some(ctx),
1792 (None, mem) => mem,
1793 };
1794
1795 if !msg_prompt.is_empty() {
1797 messages.push(Message::user(msg_prompt));
1798 }
1799
1800 loop {
1801 turn += 1;
1802
1803 if turn > self.config.max_tool_rounds {
1804 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1805 if let Some(tx) = &event_tx {
1806 tx.send(AgentEvent::Error {
1807 message: error.clone(),
1808 })
1809 .await
1810 .ok();
1811 }
1812 anyhow::bail!(error);
1813 }
1814
1815 if let Some(tx) = &event_tx {
1817 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1818 }
1819
1820 tracing::info!(
1821 turn = turn,
1822 max_turns = self.config.max_tool_rounds,
1823 "Agent turn started"
1824 );
1825
1826 tracing::info!(
1828 a3s.llm.streaming = event_tx.is_some(),
1829 "LLM completion started"
1830 );
1831
1832 self.fire_generate_start(
1834 session_id.unwrap_or(""),
1835 effective_prompt,
1836 &augmented_system,
1837 )
1838 .await;
1839
1840 let llm_start = std::time::Instant::now();
1841 let response = {
1845 let threshold = self.config.circuit_breaker_threshold.max(1);
1846 let mut attempt = 0u32;
1847 loop {
1848 attempt += 1;
1849 let result = self
1850 .call_llm(
1851 &messages,
1852 augmented_system.as_deref(),
1853 &event_tx,
1854 cancel_token,
1855 )
1856 .await;
1857 match result {
1858 Ok(r) => {
1859 break r;
1860 }
1861 Err(e) if cancel_token.is_cancelled() => {
1863 anyhow::bail!(e);
1864 }
1865 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1867 tracing::warn!(
1868 turn = turn,
1869 attempt = attempt,
1870 threshold = threshold,
1871 error = %e,
1872 "LLM call failed, will retry"
1873 );
1874 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1875 }
1876 Err(e) => {
1878 let msg = if attempt > 1 {
1879 format!(
1880 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1881 attempt, e
1882 )
1883 } else {
1884 format!("LLM call failed: {}", e)
1885 };
1886 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1887 self.fire_on_error(
1889 session_id.unwrap_or(""),
1890 ErrorType::LlmFailure,
1891 &msg,
1892 serde_json::json!({"turn": turn, "attempt": attempt}),
1893 )
1894 .await;
1895 if let Some(tx) = &event_tx {
1896 tx.send(AgentEvent::Error {
1897 message: msg.clone(),
1898 })
1899 .await
1900 .ok();
1901 }
1902 anyhow::bail!(msg);
1903 }
1904 }
1905 }
1906 };
1907
1908 total_usage.prompt_tokens += response.usage.prompt_tokens;
1910 total_usage.completion_tokens += response.usage.completion_tokens;
1911 total_usage.total_tokens += response.usage.total_tokens;
1912 let llm_duration = llm_start.elapsed();
1914 tracing::info!(
1915 turn = turn,
1916 streaming = event_tx.is_some(),
1917 prompt_tokens = response.usage.prompt_tokens,
1918 completion_tokens = response.usage.completion_tokens,
1919 total_tokens = response.usage.total_tokens,
1920 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1921 duration_ms = llm_duration.as_millis() as u64,
1922 "LLM completion finished"
1923 );
1924
1925 self.fire_generate_end(
1927 session_id.unwrap_or(""),
1928 effective_prompt,
1929 &response,
1930 llm_duration.as_millis() as u64,
1931 )
1932 .await;
1933
1934 crate::telemetry::record_llm_usage(
1936 response.usage.prompt_tokens,
1937 response.usage.completion_tokens,
1938 response.usage.total_tokens,
1939 response.stop_reason.as_deref(),
1940 );
1941 tracing::info!(
1943 turn = turn,
1944 a3s.llm.total_tokens = response.usage.total_tokens,
1945 "Turn token usage"
1946 );
1947
1948 messages.push(response.message.clone());
1950
1951 let tool_calls = response.tool_calls();
1953
1954 if let Some(tx) = &event_tx {
1956 tx.send(AgentEvent::TurnEnd {
1957 turn,
1958 usage: response.usage.clone(),
1959 })
1960 .await
1961 .ok();
1962 }
1963
1964 if self.config.auto_compact {
1966 let used = response.usage.prompt_tokens;
1967 let max = self.config.max_context_tokens;
1968 let threshold = self.config.auto_compact_threshold;
1969
1970 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1971 let before_len = messages.len();
1972 let percent_before = used as f32 / max as f32;
1973
1974 tracing::info!(
1975 used_tokens = used,
1976 max_tokens = max,
1977 percent = percent_before,
1978 threshold = threshold,
1979 "Auto-compact triggered"
1980 );
1981
1982 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1984 {
1985 messages = pruned;
1986 tracing::info!("Tool output pruning applied");
1987 }
1988
1989 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1991 session_id.unwrap_or(""),
1992 &messages,
1993 &self.llm_client,
1994 )
1995 .await
1996 {
1997 messages = compacted;
1998 }
1999
2000 if let Some(tx) = &event_tx {
2002 tx.send(AgentEvent::ContextCompacted {
2003 session_id: session_id.unwrap_or("").to_string(),
2004 before_messages: before_len,
2005 after_messages: messages.len(),
2006 percent_before,
2007 })
2008 .await
2009 .ok();
2010 }
2011 }
2012 }
2013
2014 if tool_calls.is_empty() {
2015 let final_text = response.text();
2018
2019 if self.config.continuation_enabled
2020 && continuation_count < self.config.max_continuation_turns
2021 && turn < self.config.max_tool_rounds && Self::looks_incomplete(&final_text)
2023 {
2024 continuation_count += 1;
2025 tracing::info!(
2026 turn = turn,
2027 continuation = continuation_count,
2028 max_continuation = self.config.max_continuation_turns,
2029 "Injecting continuation message — response looks incomplete"
2030 );
2031 messages.push(Message::user(crate::prompts::CONTINUATION));
2033 continue;
2034 }
2035
2036 let final_text = if let Some(ref sp) = self.config.security_provider {
2038 sp.sanitize_output(&final_text)
2039 } else {
2040 final_text
2041 };
2042
2043 tracing::info!(
2045 tool_calls_count = tool_calls_count,
2046 total_prompt_tokens = total_usage.prompt_tokens,
2047 total_completion_tokens = total_usage.completion_tokens,
2048 total_tokens = total_usage.total_tokens,
2049 turns = turn,
2050 "Agent execution completed"
2051 );
2052
2053 if let Some(tx) = &event_tx {
2054 tx.send(AgentEvent::End {
2055 text: final_text.clone(),
2056 usage: total_usage.clone(),
2057 meta: response.meta.clone(),
2058 })
2059 .await
2060 .ok();
2061 }
2062
2063 if let Some(sid) = session_id {
2065 self.notify_turn_complete(sid, effective_prompt, &final_text)
2066 .await;
2067 }
2068
2069 return Ok(AgentResult {
2070 text: final_text,
2071 messages,
2072 usage: total_usage,
2073 tool_calls_count,
2074 });
2075 }
2076
2077 let tool_calls = if self.config.hook_engine.is_none()
2081 && self.config.confirmation_manager.is_none()
2082 && tool_calls.len() > 1
2083 && tool_calls
2084 .iter()
2085 .all(|tc| Self::is_parallel_safe_write(&tc.name, &tc.args))
2086 && {
2087 let paths: Vec<_> = tool_calls
2089 .iter()
2090 .filter_map(|tc| Self::extract_write_path(&tc.args))
2091 .collect();
2092 paths.len() == tool_calls.len()
2093 && paths.iter().collect::<std::collections::HashSet<_>>().len()
2094 == paths.len()
2095 } {
2096 tracing::info!(
2097 count = tool_calls.len(),
2098 "Parallel write batch: executing {} independent file writes concurrently",
2099 tool_calls.len()
2100 );
2101
2102 let futures: Vec<_> = tool_calls
2103 .iter()
2104 .map(|tc| {
2105 let ctx = self.tool_context.clone();
2106 let executor = Arc::clone(&self.tool_executor);
2107 let name = tc.name.clone();
2108 let args = tc.args.clone();
2109 async move { executor.execute_with_context(&name, &args, &ctx).await }
2110 })
2111 .collect();
2112
2113 let results = join_all(futures).await;
2114
2115 for (tc, result) in tool_calls.iter().zip(results) {
2117 tool_calls_count += 1;
2118 let (output, exit_code, is_error, metadata, images) =
2119 Self::tool_result_to_tuple(result);
2120
2121 let output = if let Some(ref sp) = self.config.security_provider {
2122 sp.sanitize_output(&output)
2123 } else {
2124 output
2125 };
2126
2127 if let Some(tx) = &event_tx {
2128 tx.send(AgentEvent::ToolEnd {
2129 id: tc.id.clone(),
2130 name: tc.name.clone(),
2131 output: output.clone(),
2132 exit_code,
2133 metadata,
2134 })
2135 .await
2136 .ok();
2137 }
2138
2139 if images.is_empty() {
2140 messages.push(Message::tool_result(&tc.id, &output, is_error));
2141 } else {
2142 messages.push(Message::tool_result_with_images(
2143 &tc.id, &output, &images, is_error,
2144 ));
2145 }
2146 }
2147
2148 continue;
2150 } else {
2151 tool_calls
2152 };
2153
2154 for tool_call in tool_calls {
2155 tool_calls_count += 1;
2156
2157 let tool_start = std::time::Instant::now();
2158
2159 tracing::info!(
2160 tool_name = tool_call.name.as_str(),
2161 tool_id = tool_call.id.as_str(),
2162 "Tool execution started"
2163 );
2164
2165 if let Some(parse_error) =
2171 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
2172 {
2173 parse_error_count += 1;
2174 let error_msg = format!("Error: {}", parse_error);
2175 tracing::warn!(
2176 tool = tool_call.name.as_str(),
2177 parse_error_count = parse_error_count,
2178 max_parse_retries = self.config.max_parse_retries,
2179 "Malformed tool arguments from LLM"
2180 );
2181
2182 if let Some(tx) = &event_tx {
2183 tx.send(AgentEvent::ToolEnd {
2184 id: tool_call.id.clone(),
2185 name: tool_call.name.clone(),
2186 output: error_msg.clone(),
2187 exit_code: 1,
2188 metadata: None,
2189 })
2190 .await
2191 .ok();
2192 }
2193
2194 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
2195
2196 if parse_error_count > self.config.max_parse_retries {
2197 let msg = format!(
2198 "LLM produced malformed tool arguments {} time(s) in a row \
2199 (max_parse_retries={}); giving up",
2200 parse_error_count, self.config.max_parse_retries
2201 );
2202 tracing::error!("{}", msg);
2203 if let Some(tx) = &event_tx {
2204 tx.send(AgentEvent::Error {
2205 message: msg.clone(),
2206 })
2207 .await
2208 .ok();
2209 }
2210 anyhow::bail!(msg);
2211 }
2212 continue;
2213 }
2214
2215 parse_error_count = 0;
2217
2218 if let Some(ref registry) = self.config.skill_registry {
2220 let instruction_skills =
2221 registry.by_kind(crate::skills::SkillKind::Instruction);
2222 let has_restrictions =
2223 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
2224 if has_restrictions {
2225 let allowed = instruction_skills
2226 .iter()
2227 .any(|s| s.is_tool_allowed(&tool_call.name));
2228 if !allowed {
2229 let msg = format!(
2230 "Tool '{}' is not allowed by any active skill.",
2231 tool_call.name
2232 );
2233 tracing::info!(
2234 tool_name = tool_call.name.as_str(),
2235 "Tool blocked by skill registry"
2236 );
2237 if let Some(tx) = &event_tx {
2238 tx.send(AgentEvent::PermissionDenied {
2239 tool_id: tool_call.id.clone(),
2240 tool_name: tool_call.name.clone(),
2241 args: tool_call.args.clone(),
2242 reason: msg.clone(),
2243 })
2244 .await
2245 .ok();
2246 }
2247 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2248 continue;
2249 }
2250 }
2251 }
2252
2253 let tool_signature = Self::tool_call_signature(&tool_call.name, &tool_call.args);
2254 if last_tool_signature.as_deref() == Some(tool_signature.as_str()) {
2255 duplicate_tool_call_count += 1;
2256 } else {
2257 last_tool_signature = Some(tool_signature.clone());
2258 duplicate_tool_call_count = 1;
2259 }
2260
2261 if duplicate_tool_call_count > self.config.duplicate_tool_call_threshold {
2262 let msg = format!(
2263 "Detected repeated identical tool call loop: '{}' was requested {} time(s) in a row. Stop retrying the same tool call and change strategy.",
2264 tool_call.name, duplicate_tool_call_count
2265 );
2266 tracing::error!(
2267 tool_name = tool_call.name.as_str(),
2268 duplicate_tool_call_count,
2269 threshold = self.config.duplicate_tool_call_threshold,
2270 signature = tool_signature,
2271 "{}",
2272 msg
2273 );
2274 if let Some(tx) = &event_tx {
2275 tx.send(AgentEvent::Error {
2276 message: msg.clone(),
2277 })
2278 .await
2279 .ok();
2280 }
2281 anyhow::bail!(msg);
2282 }
2283
2284 if let Some(HookResult::Block(reason)) = self
2286 .fire_pre_tool_use(
2287 session_id.unwrap_or(""),
2288 &tool_call.name,
2289 &tool_call.args,
2290 recent_tool_signatures.clone(),
2291 )
2292 .await
2293 {
2294 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
2295 tracing::info!(
2296 tool_name = tool_call.name.as_str(),
2297 "Tool blocked by PreToolUse hook"
2298 );
2299
2300 if let Some(tx) = &event_tx {
2301 tx.send(AgentEvent::PermissionDenied {
2302 tool_id: tool_call.id.clone(),
2303 tool_name: tool_call.name.clone(),
2304 args: tool_call.args.clone(),
2305 reason: reason.clone(),
2306 })
2307 .await
2308 .ok();
2309 }
2310
2311 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2312 continue;
2313 }
2314
2315 let permission_decision = if let Some(checker) = &self.config.permission_checker {
2317 checker.check(&tool_call.name, &tool_call.args)
2318 } else {
2319 PermissionDecision::Ask
2321 };
2322
2323 let (output, exit_code, is_error, metadata, images) = match permission_decision {
2324 PermissionDecision::Deny => {
2325 tracing::info!(
2326 tool_name = tool_call.name.as_str(),
2327 permission = "deny",
2328 "Tool permission denied"
2329 );
2330 let denial_msg = format!(
2332 "Permission denied: Tool '{}' is blocked by permission policy.",
2333 tool_call.name
2334 );
2335
2336 if let Some(tx) = &event_tx {
2338 tx.send(AgentEvent::PermissionDenied {
2339 tool_id: tool_call.id.clone(),
2340 tool_name: tool_call.name.clone(),
2341 args: tool_call.args.clone(),
2342 reason: "Blocked by deny rule in permission policy".to_string(),
2343 })
2344 .await
2345 .ok();
2346 }
2347
2348 (denial_msg, 1, true, None, Vec::new())
2349 }
2350 PermissionDecision::Allow => {
2351 tracing::info!(
2352 tool_name = tool_call.name.as_str(),
2353 permission = "allow",
2354 "Tool permission: allow"
2355 );
2356 let stream_ctx =
2358 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
2359 let result = self
2360 .execute_tool_queued_or_direct(
2361 &tool_call.name,
2362 &tool_call.args,
2363 &stream_ctx,
2364 )
2365 .await;
2366
2367 Self::tool_result_to_tuple(result)
2368 }
2369 PermissionDecision::Ask => {
2370 tracing::info!(
2371 tool_name = tool_call.name.as_str(),
2372 permission = "ask",
2373 "Tool permission: ask"
2374 );
2375 if let Some(cm) = &self.config.confirmation_manager {
2377 if !cm.requires_confirmation(&tool_call.name).await {
2379 let stream_ctx = self.streaming_tool_context(
2380 &event_tx,
2381 &tool_call.id,
2382 &tool_call.name,
2383 );
2384 let result = self
2385 .execute_tool_queued_or_direct(
2386 &tool_call.name,
2387 &tool_call.args,
2388 &stream_ctx,
2389 )
2390 .await;
2391
2392 let (output, exit_code, is_error, metadata, images) =
2393 Self::tool_result_to_tuple(result);
2394
2395 if images.is_empty() {
2397 messages.push(Message::tool_result(
2398 &tool_call.id,
2399 &output,
2400 is_error,
2401 ));
2402 } else {
2403 messages.push(Message::tool_result_with_images(
2404 &tool_call.id,
2405 &output,
2406 &images,
2407 is_error,
2408 ));
2409 }
2410
2411 let tool_duration = tool_start.elapsed();
2413 crate::telemetry::record_tool_result(exit_code, tool_duration);
2414
2415 if let Some(tx) = &event_tx {
2417 tx.send(AgentEvent::ToolEnd {
2418 id: tool_call.id.clone(),
2419 name: tool_call.name.clone(),
2420 output: output.clone(),
2421 exit_code,
2422 metadata,
2423 })
2424 .await
2425 .ok();
2426 }
2427
2428 self.fire_post_tool_use(
2430 session_id.unwrap_or(""),
2431 &tool_call.name,
2432 &tool_call.args,
2433 &output,
2434 exit_code == 0,
2435 tool_duration.as_millis() as u64,
2436 )
2437 .await;
2438
2439 continue; }
2441
2442 let policy = cm.policy().await;
2444 let timeout_ms = policy.default_timeout_ms;
2445 let timeout_action = policy.timeout_action;
2446
2447 let rx = cm
2449 .request_confirmation(
2450 &tool_call.id,
2451 &tool_call.name,
2452 &tool_call.args,
2453 )
2454 .await;
2455
2456 if let Some(tx) = &event_tx {
2460 tx.send(AgentEvent::ConfirmationRequired {
2461 tool_id: tool_call.id.clone(),
2462 tool_name: tool_call.name.clone(),
2463 args: tool_call.args.clone(),
2464 timeout_ms,
2465 })
2466 .await
2467 .ok();
2468 }
2469
2470 let confirmation_result =
2472 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2473
2474 match confirmation_result {
2475 Ok(Ok(response)) => {
2476 if let Some(tx) = &event_tx {
2478 tx.send(AgentEvent::ConfirmationReceived {
2479 tool_id: tool_call.id.clone(),
2480 approved: response.approved,
2481 reason: response.reason.clone(),
2482 })
2483 .await
2484 .ok();
2485 }
2486 if response.approved {
2487 let stream_ctx = self.streaming_tool_context(
2488 &event_tx,
2489 &tool_call.id,
2490 &tool_call.name,
2491 );
2492 let result = self
2493 .execute_tool_queued_or_direct(
2494 &tool_call.name,
2495 &tool_call.args,
2496 &stream_ctx,
2497 )
2498 .await;
2499
2500 Self::tool_result_to_tuple(result)
2501 } else {
2502 let rejection_msg = format!(
2503 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
2504 DO NOT retry this tool call unless the user explicitly asks you to.",
2505 tool_call.name,
2506 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2507 );
2508 (rejection_msg, 1, true, None, Vec::new())
2509 }
2510 }
2511 Ok(Err(_)) => {
2512 if let Some(tx) = &event_tx {
2514 tx.send(AgentEvent::ConfirmationTimeout {
2515 tool_id: tool_call.id.clone(),
2516 action_taken: "rejected".to_string(),
2517 })
2518 .await
2519 .ok();
2520 }
2521 let msg = format!(
2522 "Tool '{}' confirmation failed: confirmation channel closed",
2523 tool_call.name
2524 );
2525 (msg, 1, true, None, Vec::new())
2526 }
2527 Err(_) => {
2528 cm.check_timeouts().await;
2529
2530 if let Some(tx) = &event_tx {
2532 tx.send(AgentEvent::ConfirmationTimeout {
2533 tool_id: tool_call.id.clone(),
2534 action_taken: match timeout_action {
2535 crate::hitl::TimeoutAction::Reject => {
2536 "rejected".to_string()
2537 }
2538 crate::hitl::TimeoutAction::AutoApprove => {
2539 "auto_approved".to_string()
2540 }
2541 },
2542 })
2543 .await
2544 .ok();
2545 }
2546
2547 match timeout_action {
2548 crate::hitl::TimeoutAction::Reject => {
2549 let msg = format!(
2550 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
2551 DO NOT retry this tool call — the user did not approve it. \
2552 Inform the user that the operation requires their approval and ask them to try again.",
2553 tool_call.name, timeout_ms
2554 );
2555 (msg, 1, true, None, Vec::new())
2556 }
2557 crate::hitl::TimeoutAction::AutoApprove => {
2558 let stream_ctx = self.streaming_tool_context(
2559 &event_tx,
2560 &tool_call.id,
2561 &tool_call.name,
2562 );
2563 let result = self
2564 .execute_tool_queued_or_direct(
2565 &tool_call.name,
2566 &tool_call.args,
2567 &stream_ctx,
2568 )
2569 .await;
2570
2571 Self::tool_result_to_tuple(result)
2572 }
2573 }
2574 }
2575 }
2576 } else {
2577 let msg = format!(
2579 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2580 Configure a confirmation policy to enable tool execution.",
2581 tool_call.name
2582 );
2583 tracing::warn!(
2584 tool_name = tool_call.name.as_str(),
2585 "Tool requires confirmation but no HITL manager configured"
2586 );
2587 (msg, 1, true, None, Vec::new())
2588 }
2589 }
2590 };
2591
2592 let tool_duration = tool_start.elapsed();
2593 crate::telemetry::record_tool_result(exit_code, tool_duration);
2594
2595 let output = if let Some(ref sp) = self.config.security_provider {
2597 sp.sanitize_output(&output)
2598 } else {
2599 output
2600 };
2601
2602 recent_tool_signatures.push(format!(
2603 "{} => {}",
2604 tool_signature,
2605 if is_error { "error" } else { "ok" }
2606 ));
2607 if recent_tool_signatures.len() > 8 {
2608 let overflow = recent_tool_signatures.len() - 8;
2609 recent_tool_signatures.drain(0..overflow);
2610 }
2611
2612 self.fire_post_tool_use(
2614 session_id.unwrap_or(""),
2615 &tool_call.name,
2616 &tool_call.args,
2617 &output,
2618 exit_code == 0,
2619 tool_duration.as_millis() as u64,
2620 )
2621 .await;
2622
2623 if let Some(ref memory) = self.config.memory {
2625 let tools_used = [tool_call.name.clone()];
2626 let remember_result = if exit_code == 0 {
2627 memory
2628 .remember_success(effective_prompt, &tools_used, &output)
2629 .await
2630 } else {
2631 memory
2632 .remember_failure(effective_prompt, &output, &tools_used)
2633 .await
2634 };
2635 match remember_result {
2636 Ok(()) => {
2637 if let Some(tx) = &event_tx {
2638 let item_type = if exit_code == 0 { "success" } else { "failure" };
2639 tx.send(AgentEvent::MemoryStored {
2640 memory_id: uuid::Uuid::new_v4().to_string(),
2641 memory_type: item_type.to_string(),
2642 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2643 tags: vec![item_type.to_string(), tool_call.name.clone()],
2644 })
2645 .await
2646 .ok();
2647 }
2648 }
2649 Err(e) => {
2650 tracing::warn!("Failed to store memory after tool execution: {}", e);
2651 }
2652 }
2653 }
2654
2655 if let Some(tx) = &event_tx {
2657 tx.send(AgentEvent::ToolEnd {
2658 id: tool_call.id.clone(),
2659 name: tool_call.name.clone(),
2660 output: output.clone(),
2661 exit_code,
2662 metadata,
2663 })
2664 .await
2665 .ok();
2666 }
2667
2668 if images.is_empty() {
2670 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2671 } else {
2672 messages.push(Message::tool_result_with_images(
2673 &tool_call.id,
2674 &output,
2675 &images,
2676 is_error,
2677 ));
2678 }
2679 }
2680 }
2681 }
2682
2683 pub async fn execute_streaming(
2685 &self,
2686 history: &[Message],
2687 prompt: &str,
2688 ) -> Result<(
2689 mpsc::Receiver<AgentEvent>,
2690 tokio::task::JoinHandle<Result<AgentResult>>,
2691 tokio_util::sync::CancellationToken,
2692 )> {
2693 let (tx, rx) = mpsc::channel(100);
2694 let cancel_token = tokio_util::sync::CancellationToken::new();
2695
2696 let llm_client = self.llm_client.clone();
2697 let tool_executor = self.tool_executor.clone();
2698 let tool_context = self.tool_context.clone();
2699 let config = self.config.clone();
2700 let tool_metrics = self.tool_metrics.clone();
2701 let command_queue = self.command_queue.clone();
2702 let history = history.to_vec();
2703 let prompt = prompt.to_string();
2704 let token_clone = cancel_token.clone();
2705
2706 let handle = tokio::spawn(async move {
2707 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2708 if let Some(metrics) = tool_metrics {
2709 agent = agent.with_tool_metrics(metrics);
2710 }
2711 if let Some(queue) = command_queue {
2712 agent = agent.with_queue(queue);
2713 }
2714 agent
2715 .execute_with_session(&history, &prompt, None, Some(tx), Some(&token_clone))
2716 .await
2717 });
2718
2719 Ok((rx, handle, cancel_token))
2720 }
2721
2722 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2727 use crate::planning::LlmPlanner;
2728
2729 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2730 Ok(plan) => Ok(plan),
2731 Err(e) => {
2732 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2733 Ok(LlmPlanner::fallback_plan(prompt))
2734 }
2735 }
2736 }
2737
2738 pub async fn execute_with_planning(
2740 &self,
2741 history: &[Message],
2742 prompt: &str,
2743 event_tx: Option<mpsc::Sender<AgentEvent>>,
2744 ) -> Result<AgentResult> {
2745 if let Some(tx) = &event_tx {
2747 tx.send(AgentEvent::PlanningStart {
2748 prompt: prompt.to_string(),
2749 })
2750 .await
2751 .ok();
2752 }
2753
2754 let goal = if self.config.goal_tracking {
2756 let g = self.extract_goal(prompt).await?;
2757 if let Some(tx) = &event_tx {
2758 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2759 .await
2760 .ok();
2761 }
2762 Some(g)
2763 } else {
2764 None
2765 };
2766
2767 let plan = self.plan(prompt, None).await?;
2769
2770 if let Some(tx) = &event_tx {
2772 tx.send(AgentEvent::PlanningEnd {
2773 estimated_steps: plan.steps.len(),
2774 plan: plan.clone(),
2775 })
2776 .await
2777 .ok();
2778 }
2779
2780 let plan_start = std::time::Instant::now();
2781
2782 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2784
2785 if self.config.goal_tracking {
2787 if let Some(ref g) = goal {
2788 let achieved = self.check_goal_achievement(g, &result.text).await?;
2789 if achieved {
2790 if let Some(tx) = &event_tx {
2791 tx.send(AgentEvent::GoalAchieved {
2792 goal: g.description.clone(),
2793 total_steps: result.messages.len(),
2794 duration_ms: plan_start.elapsed().as_millis() as i64,
2795 })
2796 .await
2797 .ok();
2798 }
2799 }
2800 }
2801 }
2802
2803 Ok(result)
2804 }
2805
2806 async fn execute_plan(
2813 &self,
2814 history: &[Message],
2815 plan: &ExecutionPlan,
2816 event_tx: Option<mpsc::Sender<AgentEvent>>,
2817 ) -> Result<AgentResult> {
2818 let mut plan = plan.clone();
2819 let mut current_history = history.to_vec();
2820 let mut total_usage = TokenUsage::default();
2821 let mut tool_calls_count = 0;
2822 let total_steps = plan.steps.len();
2823
2824 let steps_text = plan
2826 .steps
2827 .iter()
2828 .enumerate()
2829 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2830 .collect::<Vec<_>>()
2831 .join("\n");
2832 current_history.push(Message::user(&crate::prompts::render(
2833 crate::prompts::PLAN_EXECUTE_GOAL,
2834 &[("goal", &plan.goal), ("steps", &steps_text)],
2835 )));
2836
2837 loop {
2838 let ready: Vec<String> = plan
2839 .get_ready_steps()
2840 .iter()
2841 .map(|s| s.id.clone())
2842 .collect();
2843
2844 if ready.is_empty() {
2845 if plan.has_deadlock() {
2847 tracing::warn!(
2848 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2849 plan.pending_count()
2850 );
2851 }
2852 break;
2853 }
2854
2855 if ready.len() == 1 {
2856 let step_id = &ready[0];
2858 let step = plan
2859 .steps
2860 .iter()
2861 .find(|s| s.id == *step_id)
2862 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2863 .clone();
2864 let step_number = plan
2865 .steps
2866 .iter()
2867 .position(|s| s.id == *step_id)
2868 .unwrap_or(0)
2869 + 1;
2870
2871 if let Some(tx) = &event_tx {
2873 tx.send(AgentEvent::StepStart {
2874 step_id: step.id.clone(),
2875 description: step.content.clone(),
2876 step_number,
2877 total_steps,
2878 })
2879 .await
2880 .ok();
2881 }
2882
2883 plan.mark_status(&step.id, TaskStatus::InProgress);
2884
2885 let step_prompt = crate::prompts::render(
2886 crate::prompts::PLAN_EXECUTE_STEP,
2887 &[
2888 ("step_num", &step_number.to_string()),
2889 ("description", &step.content),
2890 ],
2891 );
2892
2893 match self
2894 .execute_loop(
2895 ¤t_history,
2896 &step_prompt,
2897 None,
2898 event_tx.clone(),
2899 &tokio_util::sync::CancellationToken::new(),
2900 )
2901 .await
2902 {
2903 Ok(result) => {
2904 current_history = result.messages.clone();
2905 total_usage.prompt_tokens += result.usage.prompt_tokens;
2906 total_usage.completion_tokens += result.usage.completion_tokens;
2907 total_usage.total_tokens += result.usage.total_tokens;
2908 tool_calls_count += result.tool_calls_count;
2909 plan.mark_status(&step.id, TaskStatus::Completed);
2910
2911 if let Some(tx) = &event_tx {
2912 tx.send(AgentEvent::StepEnd {
2913 step_id: step.id.clone(),
2914 status: TaskStatus::Completed,
2915 step_number,
2916 total_steps,
2917 })
2918 .await
2919 .ok();
2920 }
2921 }
2922 Err(e) => {
2923 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2924 plan.mark_status(&step.id, TaskStatus::Failed);
2925
2926 if let Some(tx) = &event_tx {
2927 tx.send(AgentEvent::StepEnd {
2928 step_id: step.id.clone(),
2929 status: TaskStatus::Failed,
2930 step_number,
2931 total_steps,
2932 })
2933 .await
2934 .ok();
2935 }
2936 }
2937 }
2938 } else {
2939 let ready_steps: Vec<_> = ready
2946 .iter()
2947 .filter_map(|id| {
2948 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2949 let step_number =
2950 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2951 Some((step, step_number))
2952 })
2953 .collect();
2954
2955 for (step, step_number) in &ready_steps {
2957 plan.mark_status(&step.id, TaskStatus::InProgress);
2958 if let Some(tx) = &event_tx {
2959 tx.send(AgentEvent::StepStart {
2960 step_id: step.id.clone(),
2961 description: step.content.clone(),
2962 step_number: *step_number,
2963 total_steps,
2964 })
2965 .await
2966 .ok();
2967 }
2968 }
2969
2970 let mut join_set = tokio::task::JoinSet::new();
2972 for (step, step_number) in &ready_steps {
2973 let base_history = current_history.clone();
2974 let agent_clone = self.clone();
2975 let tx = event_tx.clone();
2976 let step_clone = step.clone();
2977 let sn = *step_number;
2978
2979 join_set.spawn(async move {
2980 let prompt = crate::prompts::render(
2981 crate::prompts::PLAN_EXECUTE_STEP,
2982 &[
2983 ("step_num", &sn.to_string()),
2984 ("description", &step_clone.content),
2985 ],
2986 );
2987 let result = agent_clone
2988 .execute_loop(
2989 &base_history,
2990 &prompt,
2991 None,
2992 tx,
2993 &tokio_util::sync::CancellationToken::new(),
2994 )
2995 .await;
2996 (step_clone.id, sn, result)
2997 });
2998 }
2999
3000 let mut parallel_summaries = Vec::new();
3002 while let Some(join_result) = join_set.join_next().await {
3003 match join_result {
3004 Ok((step_id, step_number, step_result)) => match step_result {
3005 Ok(result) => {
3006 total_usage.prompt_tokens += result.usage.prompt_tokens;
3007 total_usage.completion_tokens += result.usage.completion_tokens;
3008 total_usage.total_tokens += result.usage.total_tokens;
3009 tool_calls_count += result.tool_calls_count;
3010 plan.mark_status(&step_id, TaskStatus::Completed);
3011
3012 parallel_summaries.push(format!(
3014 "- Step {} ({}): {}",
3015 step_number, step_id, result.text
3016 ));
3017
3018 if let Some(tx) = &event_tx {
3019 tx.send(AgentEvent::StepEnd {
3020 step_id,
3021 status: TaskStatus::Completed,
3022 step_number,
3023 total_steps,
3024 })
3025 .await
3026 .ok();
3027 }
3028 }
3029 Err(e) => {
3030 tracing::error!("Plan step '{}' failed: {}", step_id, e);
3031 plan.mark_status(&step_id, TaskStatus::Failed);
3032
3033 if let Some(tx) = &event_tx {
3034 tx.send(AgentEvent::StepEnd {
3035 step_id,
3036 status: TaskStatus::Failed,
3037 step_number,
3038 total_steps,
3039 })
3040 .await
3041 .ok();
3042 }
3043 }
3044 },
3045 Err(e) => {
3046 tracing::error!("JoinSet task panicked: {}", e);
3047 }
3048 }
3049 }
3050
3051 if !parallel_summaries.is_empty() {
3053 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
3055 current_history.push(Message::user(&crate::prompts::render(
3056 crate::prompts::PLAN_PARALLEL_RESULTS,
3057 &[("results", &results_text)],
3058 )));
3059 }
3060 }
3061
3062 if self.config.goal_tracking {
3064 let completed = plan
3065 .steps
3066 .iter()
3067 .filter(|s| s.status == TaskStatus::Completed)
3068 .count();
3069 if let Some(tx) = &event_tx {
3070 tx.send(AgentEvent::GoalProgress {
3071 goal: plan.goal.clone(),
3072 progress: plan.progress(),
3073 completed_steps: completed,
3074 total_steps,
3075 })
3076 .await
3077 .ok();
3078 }
3079 }
3080 }
3081
3082 let final_text = current_history
3084 .last()
3085 .map(|m| {
3086 m.content
3087 .iter()
3088 .filter_map(|block| {
3089 if let crate::llm::ContentBlock::Text { text } = block {
3090 Some(text.as_str())
3091 } else {
3092 None
3093 }
3094 })
3095 .collect::<Vec<_>>()
3096 .join("\n")
3097 })
3098 .unwrap_or_default();
3099
3100 Ok(AgentResult {
3101 text: final_text,
3102 messages: current_history,
3103 usage: total_usage,
3104 tool_calls_count,
3105 })
3106 }
3107
3108 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
3113 use crate::planning::LlmPlanner;
3114
3115 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
3116 Ok(goal) => Ok(goal),
3117 Err(e) => {
3118 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
3119 Ok(LlmPlanner::fallback_goal(prompt))
3120 }
3121 }
3122 }
3123
3124 pub async fn check_goal_achievement(
3129 &self,
3130 goal: &AgentGoal,
3131 current_state: &str,
3132 ) -> Result<bool> {
3133 use crate::planning::LlmPlanner;
3134
3135 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
3136 Ok(result) => Ok(result.achieved),
3137 Err(e) => {
3138 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
3139 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
3140 Ok(result.achieved)
3141 }
3142 }
3143 }
3144}
3145
3146#[cfg(test)]
3147mod tests {
3148 use super::*;
3149 use crate::llm::{ContentBlock, StreamEvent};
3150 use crate::permissions::PermissionPolicy;
3151 use crate::tools::ToolExecutor;
3152 use std::path::PathBuf;
3153 use std::sync::atomic::{AtomicUsize, Ordering};
3154
3155 fn test_tool_context() -> ToolContext {
3157 ToolContext::new(PathBuf::from("/tmp"))
3158 }
3159
3160 #[test]
3161 fn test_agent_config_default() {
3162 let config = AgentConfig::default();
3163 assert!(config.prompt_slots.is_empty());
3164 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
3166 assert!(config.permission_checker.is_none());
3167 assert!(config.context_providers.is_empty());
3168 let registry = config
3170 .skill_registry
3171 .expect("skill_registry must be Some by default");
3172 assert!(registry.len() >= 7, "expected at least 7 built-in skills");
3173 assert!(registry.get("code-search").is_some());
3174 assert!(registry.get("find-bugs").is_some());
3175 }
3176
3177 pub(crate) struct MockLlmClient {
3183 responses: std::sync::Mutex<Vec<LlmResponse>>,
3185 pub(crate) call_count: AtomicUsize,
3187 }
3188
3189 impl MockLlmClient {
3190 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
3191 Self {
3192 responses: std::sync::Mutex::new(responses),
3193 call_count: AtomicUsize::new(0),
3194 }
3195 }
3196
3197 pub(crate) fn text_response(text: &str) -> LlmResponse {
3199 LlmResponse {
3200 message: Message {
3201 role: "assistant".to_string(),
3202 content: vec![ContentBlock::Text {
3203 text: text.to_string(),
3204 }],
3205 reasoning_content: None,
3206 },
3207 usage: TokenUsage {
3208 prompt_tokens: 10,
3209 completion_tokens: 5,
3210 total_tokens: 15,
3211 cache_read_tokens: None,
3212 cache_write_tokens: None,
3213 },
3214 stop_reason: Some("end_turn".to_string()),
3215 meta: None,
3216 }
3217 }
3218
3219 pub(crate) fn tool_call_response(
3221 tool_id: &str,
3222 tool_name: &str,
3223 args: serde_json::Value,
3224 ) -> LlmResponse {
3225 LlmResponse {
3226 message: Message {
3227 role: "assistant".to_string(),
3228 content: vec![ContentBlock::ToolUse {
3229 id: tool_id.to_string(),
3230 name: tool_name.to_string(),
3231 input: args,
3232 }],
3233 reasoning_content: None,
3234 },
3235 usage: TokenUsage {
3236 prompt_tokens: 10,
3237 completion_tokens: 5,
3238 total_tokens: 15,
3239 cache_read_tokens: None,
3240 cache_write_tokens: None,
3241 },
3242 stop_reason: Some("tool_use".to_string()),
3243 meta: None,
3244 }
3245 }
3246 }
3247
3248 #[async_trait::async_trait]
3249 impl LlmClient for MockLlmClient {
3250 async fn complete(
3251 &self,
3252 _messages: &[Message],
3253 _system: Option<&str>,
3254 _tools: &[ToolDefinition],
3255 ) -> Result<LlmResponse> {
3256 self.call_count.fetch_add(1, Ordering::SeqCst);
3257 let mut responses = self.responses.lock().unwrap();
3258 if responses.is_empty() {
3259 anyhow::bail!("No more mock responses available");
3260 }
3261 Ok(responses.remove(0))
3262 }
3263
3264 async fn complete_streaming(
3265 &self,
3266 _messages: &[Message],
3267 _system: Option<&str>,
3268 _tools: &[ToolDefinition],
3269 ) -> Result<mpsc::Receiver<StreamEvent>> {
3270 self.call_count.fetch_add(1, Ordering::SeqCst);
3271 let mut responses = self.responses.lock().unwrap();
3272 if responses.is_empty() {
3273 anyhow::bail!("No more mock responses available");
3274 }
3275 let response = responses.remove(0);
3276
3277 let (tx, rx) = mpsc::channel(10);
3278 tokio::spawn(async move {
3279 for block in &response.message.content {
3281 if let ContentBlock::Text { text } = block {
3282 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
3283 }
3284 }
3285 tx.send(StreamEvent::Done(response)).await.ok();
3286 });
3287
3288 Ok(rx)
3289 }
3290 }
3291
3292 #[tokio::test]
3297 async fn test_agent_simple_response() {
3298 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3299 "Hello, I'm an AI assistant.",
3300 )]));
3301
3302 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3303 let config = AgentConfig::default();
3304
3305 let agent = AgentLoop::new(
3306 mock_client.clone(),
3307 tool_executor,
3308 test_tool_context(),
3309 config,
3310 );
3311 let result = agent.execute(&[], "Hello", None).await.unwrap();
3312
3313 assert_eq!(result.text, "Hello, I'm an AI assistant.");
3314 assert_eq!(result.tool_calls_count, 0);
3315 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3316 }
3317
3318 #[tokio::test]
3319 async fn test_agent_with_tool_call() {
3320 let mock_client = Arc::new(MockLlmClient::new(vec![
3321 MockLlmClient::tool_call_response(
3323 "tool-1",
3324 "bash",
3325 serde_json::json!({"command": "echo hello"}),
3326 ),
3327 MockLlmClient::text_response("The command output was: hello"),
3329 ]));
3330
3331 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3332 let config = AgentConfig::default();
3333
3334 let agent = AgentLoop::new(
3335 mock_client.clone(),
3336 tool_executor,
3337 test_tool_context(),
3338 config,
3339 );
3340 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
3341
3342 assert_eq!(result.text, "The command output was: hello");
3343 assert_eq!(result.tool_calls_count, 1);
3344 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
3345 }
3346
3347 #[tokio::test]
3348 async fn test_agent_permission_deny() {
3349 let mock_client = Arc::new(MockLlmClient::new(vec![
3350 MockLlmClient::tool_call_response(
3352 "tool-1",
3353 "bash",
3354 serde_json::json!({"command": "rm -rf /tmp/test"}),
3355 ),
3356 MockLlmClient::text_response(
3358 "I cannot execute that command due to permission restrictions.",
3359 ),
3360 ]));
3361
3362 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3363
3364 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3366
3367 let config = AgentConfig {
3368 permission_checker: Some(Arc::new(permission_policy)),
3369 ..Default::default()
3370 };
3371
3372 let (tx, mut rx) = mpsc::channel(100);
3373 let agent = AgentLoop::new(
3374 mock_client.clone(),
3375 tool_executor,
3376 test_tool_context(),
3377 config,
3378 );
3379 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
3380
3381 let mut found_permission_denied = false;
3383 while let Ok(event) = rx.try_recv() {
3384 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
3385 assert_eq!(tool_name, "bash");
3386 found_permission_denied = true;
3387 }
3388 }
3389 assert!(
3390 found_permission_denied,
3391 "Should have received PermissionDenied event"
3392 );
3393
3394 assert_eq!(result.tool_calls_count, 1);
3395 }
3396
3397 #[tokio::test]
3398 async fn test_agent_permission_allow() {
3399 let mock_client = Arc::new(MockLlmClient::new(vec![
3400 MockLlmClient::tool_call_response(
3402 "tool-1",
3403 "bash",
3404 serde_json::json!({"command": "echo hello"}),
3405 ),
3406 MockLlmClient::text_response("Done!"),
3408 ]));
3409
3410 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3411
3412 let permission_policy = PermissionPolicy::new()
3414 .allow("bash(echo:*)")
3415 .deny("bash(rm:*)");
3416
3417 let config = AgentConfig {
3418 permission_checker: Some(Arc::new(permission_policy)),
3419 ..Default::default()
3420 };
3421
3422 let agent = AgentLoop::new(
3423 mock_client.clone(),
3424 tool_executor,
3425 test_tool_context(),
3426 config,
3427 );
3428 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
3429
3430 assert_eq!(result.text, "Done!");
3431 assert_eq!(result.tool_calls_count, 1);
3432 }
3433
3434 #[tokio::test]
3435 async fn test_agent_streaming_events() {
3436 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3437 "Hello!",
3438 )]));
3439
3440 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3441 let config = AgentConfig::default();
3442
3443 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3444 let (mut rx, handle, _cancel_token) = agent.execute_streaming(&[], "Hi").await.unwrap();
3445
3446 let mut events = Vec::new();
3448 while let Some(event) = rx.recv().await {
3449 events.push(event);
3450 }
3451
3452 let result = handle.await.unwrap().unwrap();
3453 assert_eq!(result.text, "Hello!");
3454
3455 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
3457 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
3458 }
3459
3460 #[tokio::test]
3461 async fn test_agent_max_tool_rounds() {
3462 let responses: Vec<LlmResponse> = (0..100)
3464 .map(|i| {
3465 MockLlmClient::tool_call_response(
3466 &format!("tool-{}", i),
3467 "bash",
3468 serde_json::json!({"command": "echo loop"}),
3469 )
3470 })
3471 .collect();
3472
3473 let mock_client = Arc::new(MockLlmClient::new(responses));
3474 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3475
3476 let config = AgentConfig {
3477 max_tool_rounds: 3,
3478 ..Default::default()
3479 };
3480
3481 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3482 let result = agent.execute(&[], "Loop forever", None).await;
3483
3484 assert!(result.is_err());
3486 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
3487 }
3488
3489 #[tokio::test]
3490 async fn test_agent_duplicate_tool_call_loop_circuit_breaker() {
3491 let responses: Vec<LlmResponse> = (0..10)
3492 .map(|i| {
3493 MockLlmClient::tool_call_response(
3494 &format!("tool-{}", i),
3495 "bash",
3496 serde_json::json!({"command": "echo repeated-loop"}),
3497 )
3498 })
3499 .collect();
3500
3501 let mock_client = Arc::new(MockLlmClient::new(responses));
3502 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3503 let config = AgentConfig {
3504 max_tool_rounds: 10,
3505 duplicate_tool_call_threshold: 2,
3506 ..Default::default()
3507 };
3508
3509 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3510 let result = agent
3511 .execute(&[], "Trigger repeated identical tool call loop", None)
3512 .await;
3513
3514 assert!(result.is_err());
3515 assert!(result
3516 .unwrap_err()
3517 .to_string()
3518 .contains("repeated identical tool call loop"));
3519 }
3520
3521 #[tokio::test]
3522 async fn test_agent_no_permission_policy_defaults_to_ask() {
3523 let mock_client = Arc::new(MockLlmClient::new(vec![
3526 MockLlmClient::tool_call_response(
3527 "tool-1",
3528 "bash",
3529 serde_json::json!({"command": "rm -rf /tmp/test"}),
3530 ),
3531 MockLlmClient::text_response("Denied!"),
3532 ]));
3533
3534 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3535 let config = AgentConfig {
3536 permission_checker: None, ..Default::default()
3539 };
3540
3541 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3542 let result = agent.execute(&[], "Delete", None).await.unwrap();
3543
3544 assert_eq!(result.text, "Denied!");
3546 assert_eq!(result.tool_calls_count, 1);
3547 }
3548
3549 #[tokio::test]
3550 async fn test_agent_permission_ask_without_cm_denies() {
3551 let mock_client = Arc::new(MockLlmClient::new(vec![
3554 MockLlmClient::tool_call_response(
3555 "tool-1",
3556 "bash",
3557 serde_json::json!({"command": "echo test"}),
3558 ),
3559 MockLlmClient::text_response("Denied!"),
3560 ]));
3561
3562 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3563
3564 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3568 permission_checker: Some(Arc::new(permission_policy)),
3569 ..Default::default()
3571 };
3572
3573 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3574 let result = agent.execute(&[], "Echo", None).await.unwrap();
3575
3576 assert_eq!(result.text, "Denied!");
3578 assert!(result.tool_calls_count >= 1);
3580 }
3581
3582 #[tokio::test]
3587 async fn test_agent_hitl_approved() {
3588 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3589 use tokio::sync::broadcast;
3590
3591 let mock_client = Arc::new(MockLlmClient::new(vec![
3592 MockLlmClient::tool_call_response(
3593 "tool-1",
3594 "bash",
3595 serde_json::json!({"command": "echo hello"}),
3596 ),
3597 MockLlmClient::text_response("Command executed!"),
3598 ]));
3599
3600 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3601
3602 let (event_tx, _event_rx) = broadcast::channel(100);
3604 let hitl_policy = ConfirmationPolicy {
3605 enabled: true,
3606 ..Default::default()
3607 };
3608 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3609
3610 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3614 permission_checker: Some(Arc::new(permission_policy)),
3615 confirmation_manager: Some(confirmation_manager.clone()),
3616 ..Default::default()
3617 };
3618
3619 let cm_clone = confirmation_manager.clone();
3621 tokio::spawn(async move {
3622 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3624 cm_clone.confirm("tool-1", true, None).await.ok();
3626 });
3627
3628 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3629 let result = agent.execute(&[], "Run echo", None).await.unwrap();
3630
3631 assert_eq!(result.text, "Command executed!");
3632 assert_eq!(result.tool_calls_count, 1);
3633 }
3634
3635 #[tokio::test]
3636 async fn test_agent_hitl_rejected() {
3637 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3638 use tokio::sync::broadcast;
3639
3640 let mock_client = Arc::new(MockLlmClient::new(vec![
3641 MockLlmClient::tool_call_response(
3642 "tool-1",
3643 "bash",
3644 serde_json::json!({"command": "rm -rf /"}),
3645 ),
3646 MockLlmClient::text_response("Understood, I won't do that."),
3647 ]));
3648
3649 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3650
3651 let (event_tx, _event_rx) = broadcast::channel(100);
3653 let hitl_policy = ConfirmationPolicy {
3654 enabled: true,
3655 ..Default::default()
3656 };
3657 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3658
3659 let permission_policy = PermissionPolicy::new();
3661
3662 let config = AgentConfig {
3663 permission_checker: Some(Arc::new(permission_policy)),
3664 confirmation_manager: Some(confirmation_manager.clone()),
3665 ..Default::default()
3666 };
3667
3668 let cm_clone = confirmation_manager.clone();
3670 tokio::spawn(async move {
3671 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
3672 cm_clone
3673 .confirm("tool-1", false, Some("Too dangerous".to_string()))
3674 .await
3675 .ok();
3676 });
3677
3678 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3679 let result = agent.execute(&[], "Delete everything", None).await.unwrap();
3680
3681 assert_eq!(result.text, "Understood, I won't do that.");
3683 }
3684
3685 #[tokio::test]
3686 async fn test_agent_hitl_timeout_reject() {
3687 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3688 use tokio::sync::broadcast;
3689
3690 let mock_client = Arc::new(MockLlmClient::new(vec![
3691 MockLlmClient::tool_call_response(
3692 "tool-1",
3693 "bash",
3694 serde_json::json!({"command": "echo test"}),
3695 ),
3696 MockLlmClient::text_response("Timed out, I understand."),
3697 ]));
3698
3699 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3700
3701 let (event_tx, _event_rx) = broadcast::channel(100);
3703 let hitl_policy = ConfirmationPolicy {
3704 enabled: true,
3705 default_timeout_ms: 50, timeout_action: TimeoutAction::Reject,
3707 ..Default::default()
3708 };
3709 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3710
3711 let permission_policy = PermissionPolicy::new();
3712
3713 let config = AgentConfig {
3714 permission_checker: Some(Arc::new(permission_policy)),
3715 confirmation_manager: Some(confirmation_manager),
3716 ..Default::default()
3717 };
3718
3719 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3721 let result = agent.execute(&[], "Echo", None).await.unwrap();
3722
3723 assert_eq!(result.text, "Timed out, I understand.");
3725 }
3726
3727 #[tokio::test]
3728 async fn test_agent_hitl_timeout_auto_approve() {
3729 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
3730 use tokio::sync::broadcast;
3731
3732 let mock_client = Arc::new(MockLlmClient::new(vec![
3733 MockLlmClient::tool_call_response(
3734 "tool-1",
3735 "bash",
3736 serde_json::json!({"command": "echo hello"}),
3737 ),
3738 MockLlmClient::text_response("Auto-approved and executed!"),
3739 ]));
3740
3741 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3742
3743 let (event_tx, _event_rx) = broadcast::channel(100);
3745 let hitl_policy = ConfirmationPolicy {
3746 enabled: true,
3747 default_timeout_ms: 50, timeout_action: TimeoutAction::AutoApprove,
3749 ..Default::default()
3750 };
3751 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3752
3753 let permission_policy = PermissionPolicy::new();
3754
3755 let config = AgentConfig {
3756 permission_checker: Some(Arc::new(permission_policy)),
3757 confirmation_manager: Some(confirmation_manager),
3758 ..Default::default()
3759 };
3760
3761 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3763 let result = agent.execute(&[], "Echo", None).await.unwrap();
3764
3765 assert_eq!(result.text, "Auto-approved and executed!");
3767 assert_eq!(result.tool_calls_count, 1);
3768 }
3769
3770 #[tokio::test]
3771 async fn test_agent_hitl_confirmation_events() {
3772 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3773 use tokio::sync::broadcast;
3774
3775 let mock_client = Arc::new(MockLlmClient::new(vec![
3776 MockLlmClient::tool_call_response(
3777 "tool-1",
3778 "bash",
3779 serde_json::json!({"command": "echo test"}),
3780 ),
3781 MockLlmClient::text_response("Done!"),
3782 ]));
3783
3784 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3785
3786 let (event_tx, mut event_rx) = broadcast::channel(100);
3788 let hitl_policy = ConfirmationPolicy {
3789 enabled: true,
3790 default_timeout_ms: 5000, ..Default::default()
3792 };
3793 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3794
3795 let permission_policy = PermissionPolicy::new();
3796
3797 let config = AgentConfig {
3798 permission_checker: Some(Arc::new(permission_policy)),
3799 confirmation_manager: Some(confirmation_manager.clone()),
3800 ..Default::default()
3801 };
3802
3803 let cm_clone = confirmation_manager.clone();
3805 let event_handle = tokio::spawn(async move {
3806 let mut events = Vec::new();
3807 while let Ok(event) = event_rx.recv().await {
3809 events.push(event.clone());
3810 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3811 cm_clone.confirm(&tool_id, true, None).await.ok();
3813 if let Ok(recv_event) = event_rx.recv().await {
3815 events.push(recv_event);
3816 }
3817 break;
3818 }
3819 }
3820 events
3821 });
3822
3823 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3824 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3825
3826 let events = event_handle.await.unwrap();
3828 assert!(
3829 events
3830 .iter()
3831 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3832 "Should have ConfirmationRequired event"
3833 );
3834 assert!(
3835 events
3836 .iter()
3837 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3838 "Should have ConfirmationReceived event with approved=true"
3839 );
3840 }
3841
3842 #[tokio::test]
3843 async fn test_agent_hitl_disabled_auto_executes() {
3844 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3846 use tokio::sync::broadcast;
3847
3848 let mock_client = Arc::new(MockLlmClient::new(vec![
3849 MockLlmClient::tool_call_response(
3850 "tool-1",
3851 "bash",
3852 serde_json::json!({"command": "echo auto"}),
3853 ),
3854 MockLlmClient::text_response("Auto executed!"),
3855 ]));
3856
3857 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3858
3859 let (event_tx, _event_rx) = broadcast::channel(100);
3861 let hitl_policy = ConfirmationPolicy {
3862 enabled: false, ..Default::default()
3864 };
3865 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3866
3867 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3870 permission_checker: Some(Arc::new(permission_policy)),
3871 confirmation_manager: Some(confirmation_manager),
3872 ..Default::default()
3873 };
3874
3875 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3876 let result = agent.execute(&[], "Echo", None).await.unwrap();
3877
3878 assert_eq!(result.text, "Auto executed!");
3880 assert_eq!(result.tool_calls_count, 1);
3881 }
3882
3883 #[tokio::test]
3884 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3885 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3887 use tokio::sync::broadcast;
3888
3889 let mock_client = Arc::new(MockLlmClient::new(vec![
3890 MockLlmClient::tool_call_response(
3891 "tool-1",
3892 "bash",
3893 serde_json::json!({"command": "rm -rf /"}),
3894 ),
3895 MockLlmClient::text_response("Blocked by permission."),
3896 ]));
3897
3898 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3899
3900 let (event_tx, mut event_rx) = broadcast::channel(100);
3902 let hitl_policy = ConfirmationPolicy {
3903 enabled: true,
3904 ..Default::default()
3905 };
3906 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3907
3908 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3910
3911 let config = AgentConfig {
3912 permission_checker: Some(Arc::new(permission_policy)),
3913 confirmation_manager: Some(confirmation_manager),
3914 ..Default::default()
3915 };
3916
3917 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3918 let result = agent.execute(&[], "Delete", None).await.unwrap();
3919
3920 assert_eq!(result.text, "Blocked by permission.");
3922
3923 let mut found_confirmation = false;
3925 while let Ok(event) = event_rx.try_recv() {
3926 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3927 found_confirmation = true;
3928 }
3929 }
3930 assert!(
3931 !found_confirmation,
3932 "HITL should not be triggered when permission is Deny"
3933 );
3934 }
3935
3936 #[tokio::test]
3937 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3938 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3941 use tokio::sync::broadcast;
3942
3943 let mock_client = Arc::new(MockLlmClient::new(vec![
3944 MockLlmClient::tool_call_response(
3945 "tool-1",
3946 "bash",
3947 serde_json::json!({"command": "echo hello"}),
3948 ),
3949 MockLlmClient::text_response("Allowed!"),
3950 ]));
3951
3952 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3953
3954 let (event_tx, mut event_rx) = broadcast::channel(100);
3956 let hitl_policy = ConfirmationPolicy {
3957 enabled: true,
3958 ..Default::default()
3959 };
3960 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3961
3962 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3964
3965 let config = AgentConfig {
3966 permission_checker: Some(Arc::new(permission_policy)),
3967 confirmation_manager: Some(confirmation_manager.clone()),
3968 ..Default::default()
3969 };
3970
3971 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3972 let result = agent.execute(&[], "Echo", None).await.unwrap();
3973
3974 assert_eq!(result.text, "Allowed!");
3976
3977 let mut found_confirmation = false;
3979 while let Ok(event) = event_rx.try_recv() {
3980 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3981 found_confirmation = true;
3982 }
3983 }
3984 assert!(
3985 !found_confirmation,
3986 "Permission Allow should skip HITL confirmation"
3987 );
3988 }
3989
3990 #[tokio::test]
3991 async fn test_agent_hitl_multiple_tool_calls() {
3992 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3994 use tokio::sync::broadcast;
3995
3996 let mock_client = Arc::new(MockLlmClient::new(vec![
3997 LlmResponse {
3999 message: Message {
4000 role: "assistant".to_string(),
4001 content: vec![
4002 ContentBlock::ToolUse {
4003 id: "tool-1".to_string(),
4004 name: "bash".to_string(),
4005 input: serde_json::json!({"command": "echo first"}),
4006 },
4007 ContentBlock::ToolUse {
4008 id: "tool-2".to_string(),
4009 name: "bash".to_string(),
4010 input: serde_json::json!({"command": "echo second"}),
4011 },
4012 ],
4013 reasoning_content: None,
4014 },
4015 usage: TokenUsage {
4016 prompt_tokens: 10,
4017 completion_tokens: 5,
4018 total_tokens: 15,
4019 cache_read_tokens: None,
4020 cache_write_tokens: None,
4021 },
4022 stop_reason: Some("tool_use".to_string()),
4023 meta: None,
4024 },
4025 MockLlmClient::text_response("Both executed!"),
4026 ]));
4027
4028 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4029
4030 let (event_tx, _event_rx) = broadcast::channel(100);
4032 let hitl_policy = ConfirmationPolicy {
4033 enabled: true,
4034 default_timeout_ms: 5000,
4035 ..Default::default()
4036 };
4037 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4038
4039 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
4042 permission_checker: Some(Arc::new(permission_policy)),
4043 confirmation_manager: Some(confirmation_manager.clone()),
4044 ..Default::default()
4045 };
4046
4047 let cm_clone = confirmation_manager.clone();
4049 tokio::spawn(async move {
4050 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4051 cm_clone.confirm("tool-1", true, None).await.ok();
4052 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4053 cm_clone.confirm("tool-2", true, None).await.ok();
4054 });
4055
4056 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4057 let result = agent.execute(&[], "Run both", None).await.unwrap();
4058
4059 assert_eq!(result.text, "Both executed!");
4060 assert_eq!(result.tool_calls_count, 2);
4061 }
4062
4063 #[tokio::test]
4064 async fn test_agent_hitl_partial_approval() {
4065 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4067 use tokio::sync::broadcast;
4068
4069 let mock_client = Arc::new(MockLlmClient::new(vec![
4070 LlmResponse {
4072 message: Message {
4073 role: "assistant".to_string(),
4074 content: vec![
4075 ContentBlock::ToolUse {
4076 id: "tool-1".to_string(),
4077 name: "bash".to_string(),
4078 input: serde_json::json!({"command": "echo safe"}),
4079 },
4080 ContentBlock::ToolUse {
4081 id: "tool-2".to_string(),
4082 name: "bash".to_string(),
4083 input: serde_json::json!({"command": "rm -rf /"}),
4084 },
4085 ],
4086 reasoning_content: None,
4087 },
4088 usage: TokenUsage {
4089 prompt_tokens: 10,
4090 completion_tokens: 5,
4091 total_tokens: 15,
4092 cache_read_tokens: None,
4093 cache_write_tokens: None,
4094 },
4095 stop_reason: Some("tool_use".to_string()),
4096 meta: None,
4097 },
4098 MockLlmClient::text_response("First worked, second rejected."),
4099 ]));
4100
4101 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4102
4103 let (event_tx, _event_rx) = broadcast::channel(100);
4104 let hitl_policy = ConfirmationPolicy {
4105 enabled: true,
4106 default_timeout_ms: 5000,
4107 ..Default::default()
4108 };
4109 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4110
4111 let permission_policy = PermissionPolicy::new();
4112
4113 let config = AgentConfig {
4114 permission_checker: Some(Arc::new(permission_policy)),
4115 confirmation_manager: Some(confirmation_manager.clone()),
4116 ..Default::default()
4117 };
4118
4119 let cm_clone = confirmation_manager.clone();
4121 tokio::spawn(async move {
4122 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4123 cm_clone.confirm("tool-1", true, None).await.ok();
4124 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
4125 cm_clone
4126 .confirm("tool-2", false, Some("Dangerous".to_string()))
4127 .await
4128 .ok();
4129 });
4130
4131 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4132 let result = agent.execute(&[], "Run both", None).await.unwrap();
4133
4134 assert_eq!(result.text, "First worked, second rejected.");
4135 assert_eq!(result.tool_calls_count, 2);
4136 }
4137
4138 #[tokio::test]
4139 async fn test_agent_hitl_yolo_mode_auto_approves() {
4140 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
4142 use tokio::sync::broadcast;
4143
4144 let mock_client = Arc::new(MockLlmClient::new(vec![
4145 MockLlmClient::tool_call_response(
4146 "tool-1",
4147 "read", serde_json::json!({"path": "/tmp/test.txt"}),
4149 ),
4150 MockLlmClient::text_response("File read!"),
4151 ]));
4152
4153 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4154
4155 let (event_tx, mut event_rx) = broadcast::channel(100);
4157 let mut yolo_lanes = std::collections::HashSet::new();
4158 yolo_lanes.insert(SessionLane::Query);
4159 let hitl_policy = ConfirmationPolicy {
4160 enabled: true,
4161 yolo_lanes, ..Default::default()
4163 };
4164 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4165
4166 let permission_policy = PermissionPolicy::new();
4167
4168 let config = AgentConfig {
4169 permission_checker: Some(Arc::new(permission_policy)),
4170 confirmation_manager: Some(confirmation_manager),
4171 ..Default::default()
4172 };
4173
4174 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4175 let result = agent.execute(&[], "Read file", None).await.unwrap();
4176
4177 assert_eq!(result.text, "File read!");
4179
4180 let mut found_confirmation = false;
4182 while let Ok(event) = event_rx.try_recv() {
4183 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
4184 found_confirmation = true;
4185 }
4186 }
4187 assert!(
4188 !found_confirmation,
4189 "YOLO mode should not trigger confirmation"
4190 );
4191 }
4192
4193 #[tokio::test]
4194 async fn test_agent_config_with_all_options() {
4195 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4196 use tokio::sync::broadcast;
4197
4198 let (event_tx, _) = broadcast::channel(100);
4199 let hitl_policy = ConfirmationPolicy::default();
4200 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4201
4202 let permission_policy = PermissionPolicy::new().allow("bash(*)");
4203
4204 let config = AgentConfig {
4205 prompt_slots: SystemPromptSlots {
4206 extra: Some("Test system prompt".to_string()),
4207 ..Default::default()
4208 },
4209 tools: vec![],
4210 max_tool_rounds: 10,
4211 permission_checker: Some(Arc::new(permission_policy)),
4212 confirmation_manager: Some(confirmation_manager),
4213 context_providers: vec![],
4214 planning_enabled: false,
4215 goal_tracking: false,
4216 hook_engine: None,
4217 skill_registry: None,
4218 ..AgentConfig::default()
4219 };
4220
4221 assert!(config.prompt_slots.build().contains("Test system prompt"));
4222 assert_eq!(config.max_tool_rounds, 10);
4223 assert!(config.permission_checker.is_some());
4224 assert!(config.confirmation_manager.is_some());
4225 assert!(config.context_providers.is_empty());
4226
4227 let debug_str = format!("{:?}", config);
4229 assert!(debug_str.contains("AgentConfig"));
4230 assert!(debug_str.contains("permission_checker: true"));
4231 assert!(debug_str.contains("confirmation_manager: true"));
4232 assert!(debug_str.contains("context_providers: 0"));
4233 }
4234
4235 use crate::context::{ContextItem, ContextType};
4240
4241 struct MockContextProvider {
4243 name: String,
4244 items: Vec<ContextItem>,
4245 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
4246 }
4247
4248 impl MockContextProvider {
4249 fn new(name: &str) -> Self {
4250 Self {
4251 name: name.to_string(),
4252 items: Vec::new(),
4253 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
4254 }
4255 }
4256
4257 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
4258 self.items = items;
4259 self
4260 }
4261 }
4262
4263 #[async_trait::async_trait]
4264 impl ContextProvider for MockContextProvider {
4265 fn name(&self) -> &str {
4266 &self.name
4267 }
4268
4269 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
4270 let mut result = ContextResult::new(&self.name);
4271 for item in &self.items {
4272 result.add_item(item.clone());
4273 }
4274 Ok(result)
4275 }
4276
4277 async fn on_turn_complete(
4278 &self,
4279 session_id: &str,
4280 prompt: &str,
4281 response: &str,
4282 ) -> anyhow::Result<()> {
4283 let mut calls = self.on_turn_calls.write().await;
4284 calls.push((
4285 session_id.to_string(),
4286 prompt.to_string(),
4287 response.to_string(),
4288 ));
4289 Ok(())
4290 }
4291 }
4292
4293 #[tokio::test]
4294 async fn test_agent_with_context_provider() {
4295 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4296 "Response using context",
4297 )]));
4298
4299 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4300
4301 let provider =
4302 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
4303 "ctx-1",
4304 ContextType::Resource,
4305 "Relevant context here",
4306 )
4307 .with_source("test://docs/example")]);
4308
4309 let config = AgentConfig {
4310 prompt_slots: SystemPromptSlots {
4311 extra: Some("You are helpful.".to_string()),
4312 ..Default::default()
4313 },
4314 context_providers: vec![Arc::new(provider)],
4315 ..Default::default()
4316 };
4317
4318 let agent = AgentLoop::new(
4319 mock_client.clone(),
4320 tool_executor,
4321 test_tool_context(),
4322 config,
4323 );
4324 let result = agent.execute(&[], "What is X?", None).await.unwrap();
4325
4326 assert_eq!(result.text, "Response using context");
4327 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4328 }
4329
4330 #[tokio::test]
4331 async fn test_agent_context_provider_events() {
4332 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4333 "Answer",
4334 )]));
4335
4336 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4337
4338 let provider =
4339 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
4340 "item-1",
4341 ContextType::Memory,
4342 "Memory content",
4343 )
4344 .with_token_count(50)]);
4345
4346 let config = AgentConfig {
4347 context_providers: vec![Arc::new(provider)],
4348 ..Default::default()
4349 };
4350
4351 let (tx, mut rx) = mpsc::channel(100);
4352 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4353 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
4354
4355 let mut events = Vec::new();
4357 while let Ok(event) = rx.try_recv() {
4358 events.push(event);
4359 }
4360
4361 assert!(
4363 events
4364 .iter()
4365 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4366 "Should have ContextResolving event"
4367 );
4368 assert!(
4369 events
4370 .iter()
4371 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
4372 "Should have ContextResolved event"
4373 );
4374
4375 for event in &events {
4377 if let AgentEvent::ContextResolved {
4378 total_items,
4379 total_tokens,
4380 } = event
4381 {
4382 assert_eq!(*total_items, 1);
4383 assert_eq!(*total_tokens, 50);
4384 }
4385 }
4386 }
4387
4388 #[tokio::test]
4389 async fn test_agent_multiple_context_providers() {
4390 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4391 "Combined response",
4392 )]));
4393
4394 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4395
4396 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
4397 "p1-1",
4398 ContextType::Resource,
4399 "Resource from P1",
4400 )
4401 .with_token_count(100)]);
4402
4403 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
4404 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
4405 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
4406 ]);
4407
4408 let config = AgentConfig {
4409 prompt_slots: SystemPromptSlots {
4410 extra: Some("Base system prompt.".to_string()),
4411 ..Default::default()
4412 },
4413 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
4414 ..Default::default()
4415 };
4416
4417 let (tx, mut rx) = mpsc::channel(100);
4418 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4419 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
4420
4421 assert_eq!(result.text, "Combined response");
4422
4423 while let Ok(event) = rx.try_recv() {
4425 if let AgentEvent::ContextResolved {
4426 total_items,
4427 total_tokens,
4428 } = event
4429 {
4430 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
4433 }
4434 }
4435
4436 #[tokio::test]
4437 async fn test_agent_no_context_providers() {
4438 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4439 "No context",
4440 )]));
4441
4442 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4443
4444 let config = AgentConfig::default();
4446
4447 let (tx, mut rx) = mpsc::channel(100);
4448 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4449 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
4450
4451 assert_eq!(result.text, "No context");
4452
4453 let mut events = Vec::new();
4455 while let Ok(event) = rx.try_recv() {
4456 events.push(event);
4457 }
4458
4459 assert!(
4460 !events
4461 .iter()
4462 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4463 "Should NOT have ContextResolving event"
4464 );
4465 }
4466
4467 #[tokio::test]
4468 async fn test_agent_context_on_turn_complete() {
4469 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4470 "Final response",
4471 )]));
4472
4473 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4474
4475 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4476 let on_turn_calls = provider.on_turn_calls.clone();
4477
4478 let config = AgentConfig {
4479 context_providers: vec![provider],
4480 ..Default::default()
4481 };
4482
4483 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4484
4485 let result = agent
4487 .execute_with_session(&[], "User prompt", Some("sess-123"), None, None)
4488 .await
4489 .unwrap();
4490
4491 assert_eq!(result.text, "Final response");
4492
4493 let calls = on_turn_calls.read().await;
4495 assert_eq!(calls.len(), 1);
4496 assert_eq!(calls[0].0, "sess-123");
4497 assert_eq!(calls[0].1, "User prompt");
4498 assert_eq!(calls[0].2, "Final response");
4499 }
4500
4501 #[tokio::test]
4502 async fn test_agent_context_on_turn_complete_no_session() {
4503 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4504 "Response",
4505 )]));
4506
4507 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4508
4509 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4510 let on_turn_calls = provider.on_turn_calls.clone();
4511
4512 let config = AgentConfig {
4513 context_providers: vec![provider],
4514 ..Default::default()
4515 };
4516
4517 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4518
4519 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
4521
4522 let calls = on_turn_calls.read().await;
4524 assert!(calls.is_empty());
4525 }
4526
4527 #[tokio::test]
4528 async fn test_agent_build_augmented_system_prompt() {
4529 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
4530
4531 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4532
4533 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
4534 "doc-1",
4535 ContextType::Resource,
4536 "Auth uses JWT tokens.",
4537 )
4538 .with_source("viking://docs/auth")]);
4539
4540 let config = AgentConfig {
4541 prompt_slots: SystemPromptSlots {
4542 extra: Some("You are helpful.".to_string()),
4543 ..Default::default()
4544 },
4545 context_providers: vec![Arc::new(provider)],
4546 ..Default::default()
4547 };
4548
4549 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4550
4551 let context_results = agent.resolve_context("test", None).await;
4553 let augmented = agent.build_augmented_system_prompt(&context_results);
4554
4555 let augmented_str = augmented.unwrap();
4556 assert!(augmented_str.contains("You are helpful."));
4557 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4558 assert!(augmented_str.contains("Auth uses JWT tokens."));
4559 }
4560
4561 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4567 let mut events = Vec::new();
4568 while let Ok(event) = rx.try_recv() {
4569 events.push(event);
4570 }
4571 while let Some(event) = rx.recv().await {
4573 events.push(event);
4574 }
4575 events
4576 }
4577
4578 #[tokio::test]
4579 async fn test_agent_multi_turn_tool_chain() {
4580 let mock_client = Arc::new(MockLlmClient::new(vec![
4582 MockLlmClient::tool_call_response(
4584 "t1",
4585 "bash",
4586 serde_json::json!({"command": "echo step1"}),
4587 ),
4588 MockLlmClient::tool_call_response(
4590 "t2",
4591 "bash",
4592 serde_json::json!({"command": "echo step2"}),
4593 ),
4594 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4596 ]));
4597
4598 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4599 let config = AgentConfig::default();
4600
4601 let agent = AgentLoop::new(
4602 mock_client.clone(),
4603 tool_executor,
4604 test_tool_context(),
4605 config,
4606 );
4607 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4608
4609 assert_eq!(result.text, "Completed both steps: step1 then step2");
4610 assert_eq!(result.tool_calls_count, 2);
4611 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4612
4613 assert_eq!(result.messages[0].role, "user");
4615 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4621 }
4622
4623 #[tokio::test]
4624 async fn test_agent_conversation_history_preserved() {
4625 let existing_history = vec![
4627 Message::user("What is Rust?"),
4628 Message {
4629 role: "assistant".to_string(),
4630 content: vec![ContentBlock::Text {
4631 text: "Rust is a systems programming language.".to_string(),
4632 }],
4633 reasoning_content: None,
4634 },
4635 ];
4636
4637 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4638 "Rust was created by Graydon Hoare at Mozilla.",
4639 )]));
4640
4641 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4642 let agent = AgentLoop::new(
4643 mock_client.clone(),
4644 tool_executor,
4645 test_tool_context(),
4646 AgentConfig::default(),
4647 );
4648
4649 let result = agent
4650 .execute(&existing_history, "Who created it?", None)
4651 .await
4652 .unwrap();
4653
4654 assert_eq!(result.messages.len(), 4);
4656 assert_eq!(result.messages[0].text(), "What is Rust?");
4657 assert_eq!(
4658 result.messages[1].text(),
4659 "Rust is a systems programming language."
4660 );
4661 assert_eq!(result.messages[2].text(), "Who created it?");
4662 assert_eq!(
4663 result.messages[3].text(),
4664 "Rust was created by Graydon Hoare at Mozilla."
4665 );
4666 }
4667
4668 #[tokio::test]
4669 async fn test_agent_event_stream_completeness() {
4670 let mock_client = Arc::new(MockLlmClient::new(vec![
4672 MockLlmClient::tool_call_response(
4673 "t1",
4674 "bash",
4675 serde_json::json!({"command": "echo hi"}),
4676 ),
4677 MockLlmClient::text_response("Done"),
4678 ]));
4679
4680 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4681 let agent = AgentLoop::new(
4682 mock_client,
4683 tool_executor,
4684 test_tool_context(),
4685 AgentConfig::default(),
4686 );
4687
4688 let (tx, rx) = mpsc::channel(100);
4689 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4690 assert_eq!(result.text, "Done");
4691
4692 let events = collect_events(rx).await;
4693
4694 let event_types: Vec<&str> = events
4696 .iter()
4697 .map(|e| match e {
4698 AgentEvent::Start { .. } => "Start",
4699 AgentEvent::TurnStart { .. } => "TurnStart",
4700 AgentEvent::TurnEnd { .. } => "TurnEnd",
4701 AgentEvent::ToolEnd { .. } => "ToolEnd",
4702 AgentEvent::End { .. } => "End",
4703 _ => "Other",
4704 })
4705 .collect();
4706
4707 assert_eq!(event_types.first(), Some(&"Start"));
4709 assert_eq!(event_types.last(), Some(&"End"));
4710
4711 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4713 assert_eq!(turn_starts, 2);
4714
4715 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4717 assert_eq!(tool_ends, 1);
4718 }
4719
4720 #[tokio::test]
4721 async fn test_agent_multiple_tools_single_turn() {
4722 let mock_client = Arc::new(MockLlmClient::new(vec![
4724 LlmResponse {
4725 message: Message {
4726 role: "assistant".to_string(),
4727 content: vec![
4728 ContentBlock::ToolUse {
4729 id: "t1".to_string(),
4730 name: "bash".to_string(),
4731 input: serde_json::json!({"command": "echo first"}),
4732 },
4733 ContentBlock::ToolUse {
4734 id: "t2".to_string(),
4735 name: "bash".to_string(),
4736 input: serde_json::json!({"command": "echo second"}),
4737 },
4738 ],
4739 reasoning_content: None,
4740 },
4741 usage: TokenUsage {
4742 prompt_tokens: 10,
4743 completion_tokens: 5,
4744 total_tokens: 15,
4745 cache_read_tokens: None,
4746 cache_write_tokens: None,
4747 },
4748 stop_reason: Some("tool_use".to_string()),
4749 meta: None,
4750 },
4751 MockLlmClient::text_response("Both commands ran"),
4752 ]));
4753
4754 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4755 let agent = AgentLoop::new(
4756 mock_client.clone(),
4757 tool_executor,
4758 test_tool_context(),
4759 AgentConfig::default(),
4760 );
4761
4762 let result = agent.execute(&[], "Run both", None).await.unwrap();
4763
4764 assert_eq!(result.text, "Both commands ran");
4765 assert_eq!(result.tool_calls_count, 2);
4766 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4770 assert_eq!(result.messages[1].role, "assistant");
4771 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4774 }
4775
4776 #[tokio::test]
4777 async fn test_agent_token_usage_accumulation() {
4778 let mock_client = Arc::new(MockLlmClient::new(vec![
4780 MockLlmClient::tool_call_response(
4781 "t1",
4782 "bash",
4783 serde_json::json!({"command": "echo x"}),
4784 ),
4785 MockLlmClient::text_response("Done"),
4786 ]));
4787
4788 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4789 let agent = AgentLoop::new(
4790 mock_client,
4791 tool_executor,
4792 test_tool_context(),
4793 AgentConfig::default(),
4794 );
4795
4796 let result = agent.execute(&[], "test", None).await.unwrap();
4797
4798 assert_eq!(result.usage.prompt_tokens, 20);
4801 assert_eq!(result.usage.completion_tokens, 10);
4802 assert_eq!(result.usage.total_tokens, 30);
4803 }
4804
4805 #[tokio::test]
4806 async fn test_agent_system_prompt_passed() {
4807 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4809 "I am a coding assistant.",
4810 )]));
4811
4812 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4813 let config = AgentConfig {
4814 prompt_slots: SystemPromptSlots {
4815 extra: Some("You are a coding assistant.".to_string()),
4816 ..Default::default()
4817 },
4818 ..Default::default()
4819 };
4820
4821 let agent = AgentLoop::new(
4822 mock_client.clone(),
4823 tool_executor,
4824 test_tool_context(),
4825 config,
4826 );
4827 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4828
4829 assert_eq!(result.text, "I am a coding assistant.");
4830 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4831 }
4832
4833 #[tokio::test]
4834 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4835 let mut responses = Vec::new();
4837 for i in 0..15 {
4838 responses.push(MockLlmClient::tool_call_response(
4839 &format!("t{}", i),
4840 "bash",
4841 serde_json::json!({"command": format!("echo round{}", i)}),
4842 ));
4843 }
4844
4845 let mock_client = Arc::new(MockLlmClient::new(responses));
4846 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4847 let config = AgentConfig {
4848 max_tool_rounds: 5,
4849 ..Default::default()
4850 };
4851
4852 let agent = AgentLoop::new(
4853 mock_client.clone(),
4854 tool_executor,
4855 test_tool_context(),
4856 config,
4857 );
4858 let result = agent.execute(&[], "Loop forever", None).await;
4859
4860 assert!(result.is_err());
4861 let err = result.unwrap_err().to_string();
4862 assert!(err.contains("Max tool rounds (5) exceeded"));
4863 }
4864
4865 #[tokio::test]
4866 async fn test_agent_end_event_contains_final_text() {
4867 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4868 "Final answer here",
4869 )]));
4870
4871 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4872 let agent = AgentLoop::new(
4873 mock_client,
4874 tool_executor,
4875 test_tool_context(),
4876 AgentConfig::default(),
4877 );
4878
4879 let (tx, rx) = mpsc::channel(100);
4880 agent.execute(&[], "test", Some(tx)).await.unwrap();
4881
4882 let events = collect_events(rx).await;
4883 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4884 assert!(end_event.is_some());
4885
4886 if let AgentEvent::End { text, usage, .. } = end_event.unwrap() {
4887 assert_eq!(text, "Final answer here");
4888 assert_eq!(usage.total_tokens, 15);
4889 }
4890 }
4891}
4892
4893#[cfg(test)]
4894mod extra_agent_tests {
4895 use super::*;
4896 use crate::agent::tests::MockLlmClient;
4897 use crate::queue::SessionQueueConfig;
4898 use crate::tools::ToolExecutor;
4899 use std::path::PathBuf;
4900 use std::sync::atomic::{AtomicUsize, Ordering};
4901
4902 fn test_tool_context() -> ToolContext {
4903 ToolContext::new(PathBuf::from("/tmp"))
4904 }
4905
4906 #[test]
4911 fn test_agent_config_debug() {
4912 let config = AgentConfig {
4913 prompt_slots: SystemPromptSlots {
4914 extra: Some("You are helpful".to_string()),
4915 ..Default::default()
4916 },
4917 tools: vec![],
4918 max_tool_rounds: 10,
4919 permission_checker: None,
4920 confirmation_manager: None,
4921 context_providers: vec![],
4922 planning_enabled: true,
4923 goal_tracking: false,
4924 hook_engine: None,
4925 skill_registry: None,
4926 ..AgentConfig::default()
4927 };
4928 let debug = format!("{:?}", config);
4929 assert!(debug.contains("AgentConfig"));
4930 assert!(debug.contains("planning_enabled"));
4931 }
4932
4933 #[test]
4934 fn test_agent_config_default_values() {
4935 let config = AgentConfig::default();
4936 assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
4937 assert!(!config.planning_enabled);
4938 assert!(!config.goal_tracking);
4939 assert!(config.context_providers.is_empty());
4940 }
4941
4942 #[test]
4947 fn test_agent_event_serialize_start() {
4948 let event = AgentEvent::Start {
4949 prompt: "Hello".to_string(),
4950 };
4951 let json = serde_json::to_string(&event).unwrap();
4952 assert!(json.contains("agent_start"));
4953 assert!(json.contains("Hello"));
4954 }
4955
4956 #[test]
4957 fn test_agent_event_serialize_text_delta() {
4958 let event = AgentEvent::TextDelta {
4959 text: "chunk".to_string(),
4960 };
4961 let json = serde_json::to_string(&event).unwrap();
4962 assert!(json.contains("text_delta"));
4963 }
4964
4965 #[test]
4966 fn test_agent_event_serialize_tool_start() {
4967 let event = AgentEvent::ToolStart {
4968 id: "t1".to_string(),
4969 name: "bash".to_string(),
4970 };
4971 let json = serde_json::to_string(&event).unwrap();
4972 assert!(json.contains("tool_start"));
4973 assert!(json.contains("bash"));
4974 }
4975
4976 #[test]
4977 fn test_agent_event_serialize_tool_end() {
4978 let event = AgentEvent::ToolEnd {
4979 id: "t1".to_string(),
4980 name: "bash".to_string(),
4981 output: "hello".to_string(),
4982 exit_code: 0,
4983 metadata: None,
4984 };
4985 let json = serde_json::to_string(&event).unwrap();
4986 assert!(json.contains("tool_end"));
4987 }
4988
4989 #[test]
4990 fn test_agent_event_tool_end_has_metadata_field() {
4991 let event = AgentEvent::ToolEnd {
4992 id: "t1".to_string(),
4993 name: "write".to_string(),
4994 output: "Wrote 5 bytes".to_string(),
4995 exit_code: 0,
4996 metadata: Some(
4997 serde_json::json!({ "before": "old", "after": "new", "file_path": "f.txt" }),
4998 ),
4999 };
5000 let json = serde_json::to_string(&event).unwrap();
5001 assert!(json.contains("\"before\""));
5002 }
5003
5004 #[test]
5005 fn test_agent_event_serialize_error() {
5006 let event = AgentEvent::Error {
5007 message: "oops".to_string(),
5008 };
5009 let json = serde_json::to_string(&event).unwrap();
5010 assert!(json.contains("error"));
5011 assert!(json.contains("oops"));
5012 }
5013
5014 #[test]
5015 fn test_agent_event_serialize_confirmation_required() {
5016 let event = AgentEvent::ConfirmationRequired {
5017 tool_id: "t1".to_string(),
5018 tool_name: "bash".to_string(),
5019 args: serde_json::json!({"cmd": "rm"}),
5020 timeout_ms: 30000,
5021 };
5022 let json = serde_json::to_string(&event).unwrap();
5023 assert!(json.contains("confirmation_required"));
5024 }
5025
5026 #[test]
5027 fn test_agent_event_serialize_confirmation_received() {
5028 let event = AgentEvent::ConfirmationReceived {
5029 tool_id: "t1".to_string(),
5030 approved: true,
5031 reason: Some("safe".to_string()),
5032 };
5033 let json = serde_json::to_string(&event).unwrap();
5034 assert!(json.contains("confirmation_received"));
5035 }
5036
5037 #[test]
5038 fn test_agent_event_serialize_confirmation_timeout() {
5039 let event = AgentEvent::ConfirmationTimeout {
5040 tool_id: "t1".to_string(),
5041 action_taken: "rejected".to_string(),
5042 };
5043 let json = serde_json::to_string(&event).unwrap();
5044 assert!(json.contains("confirmation_timeout"));
5045 }
5046
5047 #[test]
5048 fn test_agent_event_serialize_external_task_pending() {
5049 let event = AgentEvent::ExternalTaskPending {
5050 task_id: "task-1".to_string(),
5051 session_id: "sess-1".to_string(),
5052 lane: crate::hitl::SessionLane::Execute,
5053 command_type: "bash".to_string(),
5054 payload: serde_json::json!({}),
5055 timeout_ms: 60000,
5056 };
5057 let json = serde_json::to_string(&event).unwrap();
5058 assert!(json.contains("external_task_pending"));
5059 }
5060
5061 #[test]
5062 fn test_agent_event_serialize_external_task_completed() {
5063 let event = AgentEvent::ExternalTaskCompleted {
5064 task_id: "task-1".to_string(),
5065 session_id: "sess-1".to_string(),
5066 success: false,
5067 };
5068 let json = serde_json::to_string(&event).unwrap();
5069 assert!(json.contains("external_task_completed"));
5070 }
5071
5072 #[test]
5073 fn test_agent_event_serialize_permission_denied() {
5074 let event = AgentEvent::PermissionDenied {
5075 tool_id: "t1".to_string(),
5076 tool_name: "bash".to_string(),
5077 args: serde_json::json!({}),
5078 reason: "denied".to_string(),
5079 };
5080 let json = serde_json::to_string(&event).unwrap();
5081 assert!(json.contains("permission_denied"));
5082 }
5083
5084 #[test]
5085 fn test_agent_event_serialize_context_compacted() {
5086 let event = AgentEvent::ContextCompacted {
5087 session_id: "sess-1".to_string(),
5088 before_messages: 100,
5089 after_messages: 20,
5090 percent_before: 0.85,
5091 };
5092 let json = serde_json::to_string(&event).unwrap();
5093 assert!(json.contains("context_compacted"));
5094 }
5095
5096 #[test]
5097 fn test_agent_event_serialize_turn_start() {
5098 let event = AgentEvent::TurnStart { turn: 3 };
5099 let json = serde_json::to_string(&event).unwrap();
5100 assert!(json.contains("turn_start"));
5101 }
5102
5103 #[test]
5104 fn test_agent_event_serialize_turn_end() {
5105 let event = AgentEvent::TurnEnd {
5106 turn: 3,
5107 usage: TokenUsage::default(),
5108 };
5109 let json = serde_json::to_string(&event).unwrap();
5110 assert!(json.contains("turn_end"));
5111 }
5112
5113 #[test]
5114 fn test_agent_event_serialize_end() {
5115 let event = AgentEvent::End {
5116 text: "Done".to_string(),
5117 usage: TokenUsage {
5118 prompt_tokens: 100,
5119 completion_tokens: 50,
5120 total_tokens: 150,
5121 cache_read_tokens: None,
5122 cache_write_tokens: None,
5123 },
5124 meta: None,
5125 };
5126 let json = serde_json::to_string(&event).unwrap();
5127 assert!(json.contains("agent_end"));
5128 }
5129
5130 #[test]
5135 fn test_agent_result_fields() {
5136 let result = AgentResult {
5137 text: "output".to_string(),
5138 messages: vec![Message::user("hello")],
5139 usage: TokenUsage::default(),
5140 tool_calls_count: 3,
5141 };
5142 assert_eq!(result.text, "output");
5143 assert_eq!(result.messages.len(), 1);
5144 assert_eq!(result.tool_calls_count, 3);
5145 }
5146
5147 #[test]
5152 fn test_agent_event_serialize_context_resolving() {
5153 let event = AgentEvent::ContextResolving {
5154 providers: vec!["provider1".to_string(), "provider2".to_string()],
5155 };
5156 let json = serde_json::to_string(&event).unwrap();
5157 assert!(json.contains("context_resolving"));
5158 assert!(json.contains("provider1"));
5159 }
5160
5161 #[test]
5162 fn test_agent_event_serialize_context_resolved() {
5163 let event = AgentEvent::ContextResolved {
5164 total_items: 5,
5165 total_tokens: 1000,
5166 };
5167 let json = serde_json::to_string(&event).unwrap();
5168 assert!(json.contains("context_resolved"));
5169 assert!(json.contains("1000"));
5170 }
5171
5172 #[test]
5173 fn test_agent_event_serialize_command_dead_lettered() {
5174 let event = AgentEvent::CommandDeadLettered {
5175 command_id: "cmd-1".to_string(),
5176 command_type: "bash".to_string(),
5177 lane: "execute".to_string(),
5178 error: "timeout".to_string(),
5179 attempts: 3,
5180 };
5181 let json = serde_json::to_string(&event).unwrap();
5182 assert!(json.contains("command_dead_lettered"));
5183 assert!(json.contains("cmd-1"));
5184 }
5185
5186 #[test]
5187 fn test_agent_event_serialize_command_retry() {
5188 let event = AgentEvent::CommandRetry {
5189 command_id: "cmd-2".to_string(),
5190 command_type: "read".to_string(),
5191 lane: "query".to_string(),
5192 attempt: 2,
5193 delay_ms: 1000,
5194 };
5195 let json = serde_json::to_string(&event).unwrap();
5196 assert!(json.contains("command_retry"));
5197 assert!(json.contains("cmd-2"));
5198 }
5199
5200 #[test]
5201 fn test_agent_event_serialize_queue_alert() {
5202 let event = AgentEvent::QueueAlert {
5203 level: "warning".to_string(),
5204 alert_type: "depth".to_string(),
5205 message: "Queue depth exceeded".to_string(),
5206 };
5207 let json = serde_json::to_string(&event).unwrap();
5208 assert!(json.contains("queue_alert"));
5209 assert!(json.contains("warning"));
5210 }
5211
5212 #[test]
5213 fn test_agent_event_serialize_task_updated() {
5214 let event = AgentEvent::TaskUpdated {
5215 session_id: "sess-1".to_string(),
5216 tasks: vec![],
5217 };
5218 let json = serde_json::to_string(&event).unwrap();
5219 assert!(json.contains("task_updated"));
5220 assert!(json.contains("sess-1"));
5221 }
5222
5223 #[test]
5224 fn test_agent_event_serialize_memory_stored() {
5225 let event = AgentEvent::MemoryStored {
5226 memory_id: "mem-1".to_string(),
5227 memory_type: "conversation".to_string(),
5228 importance: 0.8,
5229 tags: vec!["important".to_string()],
5230 };
5231 let json = serde_json::to_string(&event).unwrap();
5232 assert!(json.contains("memory_stored"));
5233 assert!(json.contains("mem-1"));
5234 }
5235
5236 #[test]
5237 fn test_agent_event_serialize_memory_recalled() {
5238 let event = AgentEvent::MemoryRecalled {
5239 memory_id: "mem-2".to_string(),
5240 content: "Previous conversation".to_string(),
5241 relevance: 0.9,
5242 };
5243 let json = serde_json::to_string(&event).unwrap();
5244 assert!(json.contains("memory_recalled"));
5245 assert!(json.contains("mem-2"));
5246 }
5247
5248 #[test]
5249 fn test_agent_event_serialize_memories_searched() {
5250 let event = AgentEvent::MemoriesSearched {
5251 query: Some("search term".to_string()),
5252 tags: vec!["tag1".to_string()],
5253 result_count: 5,
5254 };
5255 let json = serde_json::to_string(&event).unwrap();
5256 assert!(json.contains("memories_searched"));
5257 assert!(json.contains("search term"));
5258 }
5259
5260 #[test]
5261 fn test_agent_event_serialize_memory_cleared() {
5262 let event = AgentEvent::MemoryCleared {
5263 tier: "short_term".to_string(),
5264 count: 10,
5265 };
5266 let json = serde_json::to_string(&event).unwrap();
5267 assert!(json.contains("memory_cleared"));
5268 assert!(json.contains("short_term"));
5269 }
5270
5271 #[test]
5272 fn test_agent_event_serialize_subagent_start() {
5273 let event = AgentEvent::SubagentStart {
5274 task_id: "task-1".to_string(),
5275 session_id: "child-sess".to_string(),
5276 parent_session_id: "parent-sess".to_string(),
5277 agent: "explore".to_string(),
5278 description: "Explore codebase".to_string(),
5279 };
5280 let json = serde_json::to_string(&event).unwrap();
5281 assert!(json.contains("subagent_start"));
5282 assert!(json.contains("explore"));
5283 }
5284
5285 #[test]
5286 fn test_agent_event_serialize_subagent_progress() {
5287 let event = AgentEvent::SubagentProgress {
5288 task_id: "task-1".to_string(),
5289 session_id: "child-sess".to_string(),
5290 status: "processing".to_string(),
5291 metadata: serde_json::json!({"progress": 50}),
5292 };
5293 let json = serde_json::to_string(&event).unwrap();
5294 assert!(json.contains("subagent_progress"));
5295 assert!(json.contains("processing"));
5296 }
5297
5298 #[test]
5299 fn test_agent_event_serialize_subagent_end() {
5300 let event = AgentEvent::SubagentEnd {
5301 task_id: "task-1".to_string(),
5302 session_id: "child-sess".to_string(),
5303 agent: "explore".to_string(),
5304 output: "Found 10 files".to_string(),
5305 success: true,
5306 };
5307 let json = serde_json::to_string(&event).unwrap();
5308 assert!(json.contains("subagent_end"));
5309 assert!(json.contains("Found 10 files"));
5310 }
5311
5312 #[test]
5313 fn test_agent_event_serialize_planning_start() {
5314 let event = AgentEvent::PlanningStart {
5315 prompt: "Build a web app".to_string(),
5316 };
5317 let json = serde_json::to_string(&event).unwrap();
5318 assert!(json.contains("planning_start"));
5319 assert!(json.contains("Build a web app"));
5320 }
5321
5322 #[test]
5323 fn test_agent_event_serialize_planning_end() {
5324 use crate::planning::{Complexity, ExecutionPlan};
5325 let plan = ExecutionPlan::new("Test goal".to_string(), Complexity::Simple);
5326 let event = AgentEvent::PlanningEnd {
5327 plan,
5328 estimated_steps: 3,
5329 };
5330 let json = serde_json::to_string(&event).unwrap();
5331 assert!(json.contains("planning_end"));
5332 assert!(json.contains("estimated_steps"));
5333 }
5334
5335 #[test]
5336 fn test_agent_event_serialize_step_start() {
5337 let event = AgentEvent::StepStart {
5338 step_id: "step-1".to_string(),
5339 description: "Initialize project".to_string(),
5340 step_number: 1,
5341 total_steps: 5,
5342 };
5343 let json = serde_json::to_string(&event).unwrap();
5344 assert!(json.contains("step_start"));
5345 assert!(json.contains("Initialize project"));
5346 }
5347
5348 #[test]
5349 fn test_agent_event_serialize_step_end() {
5350 let event = AgentEvent::StepEnd {
5351 step_id: "step-1".to_string(),
5352 status: TaskStatus::Completed,
5353 step_number: 1,
5354 total_steps: 5,
5355 };
5356 let json = serde_json::to_string(&event).unwrap();
5357 assert!(json.contains("step_end"));
5358 assert!(json.contains("step-1"));
5359 }
5360
5361 #[test]
5362 fn test_agent_event_serialize_goal_extracted() {
5363 use crate::planning::AgentGoal;
5364 let goal = AgentGoal::new("Complete the task".to_string());
5365 let event = AgentEvent::GoalExtracted { goal };
5366 let json = serde_json::to_string(&event).unwrap();
5367 assert!(json.contains("goal_extracted"));
5368 }
5369
5370 #[test]
5371 fn test_agent_event_serialize_goal_progress() {
5372 let event = AgentEvent::GoalProgress {
5373 goal: "Build app".to_string(),
5374 progress: 0.5,
5375 completed_steps: 2,
5376 total_steps: 4,
5377 };
5378 let json = serde_json::to_string(&event).unwrap();
5379 assert!(json.contains("goal_progress"));
5380 assert!(json.contains("0.5"));
5381 }
5382
5383 #[test]
5384 fn test_agent_event_serialize_goal_achieved() {
5385 let event = AgentEvent::GoalAchieved {
5386 goal: "Build app".to_string(),
5387 total_steps: 4,
5388 duration_ms: 5000,
5389 };
5390 let json = serde_json::to_string(&event).unwrap();
5391 assert!(json.contains("goal_achieved"));
5392 assert!(json.contains("5000"));
5393 }
5394
5395 #[tokio::test]
5396 async fn test_extract_goal_with_json_response() {
5397 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5399 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
5400 )]));
5401 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5402 let agent = AgentLoop::new(
5403 mock_client,
5404 tool_executor,
5405 test_tool_context(),
5406 AgentConfig::default(),
5407 );
5408
5409 let goal = agent.extract_goal("Build a web app").await.unwrap();
5410 assert_eq!(goal.description, "Build web app");
5411 assert_eq!(goal.success_criteria.len(), 2);
5412 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
5413 }
5414
5415 #[tokio::test]
5416 async fn test_extract_goal_fallback_on_non_json() {
5417 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5419 "Some non-JSON response",
5420 )]));
5421 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5422 let agent = AgentLoop::new(
5423 mock_client,
5424 tool_executor,
5425 test_tool_context(),
5426 AgentConfig::default(),
5427 );
5428
5429 let goal = agent.extract_goal("Do something").await.unwrap();
5430 assert_eq!(goal.description, "Do something");
5432 assert_eq!(goal.success_criteria.len(), 2);
5434 }
5435
5436 #[tokio::test]
5437 async fn test_check_goal_achievement_json_yes() {
5438 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5439 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
5440 )]));
5441 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5442 let agent = AgentLoop::new(
5443 mock_client,
5444 tool_executor,
5445 test_tool_context(),
5446 AgentConfig::default(),
5447 );
5448
5449 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5450 let achieved = agent
5451 .check_goal_achievement(&goal, "All done")
5452 .await
5453 .unwrap();
5454 assert!(achieved);
5455 }
5456
5457 #[tokio::test]
5458 async fn test_check_goal_achievement_fallback_not_done() {
5459 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5461 "invalid json",
5462 )]));
5463 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5464 let agent = AgentLoop::new(
5465 mock_client,
5466 tool_executor,
5467 test_tool_context(),
5468 AgentConfig::default(),
5469 );
5470
5471 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5472 let achieved = agent
5474 .check_goal_achievement(&goal, "still working")
5475 .await
5476 .unwrap();
5477 assert!(!achieved);
5478 }
5479
5480 #[test]
5485 fn test_build_augmented_system_prompt_empty_context() {
5486 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5487 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5488 let config = AgentConfig {
5489 prompt_slots: SystemPromptSlots {
5490 extra: Some("Base prompt".to_string()),
5491 ..Default::default()
5492 },
5493 ..Default::default()
5494 };
5495 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5496
5497 let result = agent.build_augmented_system_prompt(&[]);
5498 assert!(result.unwrap().contains("Base prompt"));
5499 }
5500
5501 #[test]
5502 fn test_build_augmented_system_prompt_no_custom_slots() {
5503 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5504 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5505 let agent = AgentLoop::new(
5506 mock_client,
5507 tool_executor,
5508 test_tool_context(),
5509 AgentConfig::default(),
5510 );
5511
5512 let result = agent.build_augmented_system_prompt(&[]);
5513 assert!(result.is_some());
5515 assert!(result.unwrap().contains("Core Behaviour"));
5516 }
5517
5518 #[test]
5519 fn test_build_augmented_system_prompt_with_context_no_base() {
5520 use crate::context::{ContextItem, ContextResult, ContextType};
5521
5522 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5523 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5524 let agent = AgentLoop::new(
5525 mock_client,
5526 tool_executor,
5527 test_tool_context(),
5528 AgentConfig::default(),
5529 );
5530
5531 let context = vec![ContextResult {
5532 provider: "test".to_string(),
5533 items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
5534 total_tokens: 10,
5535 truncated: false,
5536 }];
5537
5538 let result = agent.build_augmented_system_prompt(&context);
5539 assert!(result.is_some());
5540 let text = result.unwrap();
5541 assert!(text.contains("<context"));
5542 assert!(text.contains("Content"));
5543 }
5544
5545 #[test]
5550 fn test_agent_result_clone() {
5551 let result = AgentResult {
5552 text: "output".to_string(),
5553 messages: vec![Message::user("hello")],
5554 usage: TokenUsage::default(),
5555 tool_calls_count: 3,
5556 };
5557 let cloned = result.clone();
5558 assert_eq!(cloned.text, result.text);
5559 assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
5560 }
5561
5562 #[test]
5563 fn test_agent_result_debug() {
5564 let result = AgentResult {
5565 text: "output".to_string(),
5566 messages: vec![Message::user("hello")],
5567 usage: TokenUsage::default(),
5568 tool_calls_count: 3,
5569 };
5570 let debug = format!("{:?}", result);
5571 assert!(debug.contains("AgentResult"));
5572 assert!(debug.contains("output"));
5573 }
5574
5575 #[tokio::test]
5584 async fn test_tool_command_command_type() {
5585 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5586 let cmd = ToolCommand {
5587 tool_executor: executor,
5588 tool_name: "read".to_string(),
5589 tool_args: serde_json::json!({"file": "test.rs"}),
5590 skill_registry: None,
5591 tool_context: test_tool_context(),
5592 };
5593 assert_eq!(cmd.command_type(), "read");
5594 }
5595
5596 #[tokio::test]
5597 async fn test_tool_command_payload() {
5598 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5599 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5600 let cmd = ToolCommand {
5601 tool_executor: executor,
5602 tool_name: "read".to_string(),
5603 tool_args: args.clone(),
5604 skill_registry: None,
5605 tool_context: test_tool_context(),
5606 };
5607 assert_eq!(cmd.payload(), args);
5608 }
5609
5610 #[tokio::test(flavor = "multi_thread")]
5615 async fn test_agent_loop_with_queue() {
5616 use tokio::sync::broadcast;
5617
5618 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5619 "Hello",
5620 )]));
5621 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5622 let config = AgentConfig::default();
5623
5624 let (event_tx, _) = broadcast::channel(100);
5625 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5626 .await
5627 .unwrap();
5628
5629 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5630 .with_queue(Arc::new(queue));
5631
5632 assert!(agent.command_queue.is_some());
5633 }
5634
5635 #[tokio::test]
5636 async fn test_agent_loop_without_queue() {
5637 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5638 "Hello",
5639 )]));
5640 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5641 let config = AgentConfig::default();
5642
5643 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5644
5645 assert!(agent.command_queue.is_none());
5646 }
5647
5648 #[tokio::test]
5653 async fn test_execute_plan_parallel_independent() {
5654 use crate::planning::{Complexity, ExecutionPlan, Task};
5655
5656 let mock_client = Arc::new(MockLlmClient::new(vec![
5659 MockLlmClient::text_response("Step 1 done"),
5660 MockLlmClient::text_response("Step 2 done"),
5661 MockLlmClient::text_response("Step 3 done"),
5662 ]));
5663
5664 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5665 let config = AgentConfig::default();
5666 let agent = AgentLoop::new(
5667 mock_client.clone(),
5668 tool_executor,
5669 test_tool_context(),
5670 config,
5671 );
5672
5673 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5674 plan.add_step(Task::new("s1", "First step"));
5675 plan.add_step(Task::new("s2", "Second step"));
5676 plan.add_step(Task::new("s3", "Third step"));
5677
5678 let (tx, mut rx) = mpsc::channel(100);
5679 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5680
5681 assert_eq!(result.usage.total_tokens, 45);
5683
5684 let mut step_starts = Vec::new();
5686 let mut step_ends = Vec::new();
5687 rx.close();
5688 while let Some(event) = rx.recv().await {
5689 match event {
5690 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5691 AgentEvent::StepEnd {
5692 step_id, status, ..
5693 } => {
5694 assert_eq!(status, TaskStatus::Completed);
5695 step_ends.push(step_id);
5696 }
5697 _ => {}
5698 }
5699 }
5700 assert_eq!(step_starts.len(), 3);
5701 assert_eq!(step_ends.len(), 3);
5702 }
5703
5704 #[tokio::test]
5705 async fn test_execute_plan_respects_dependencies() {
5706 use crate::planning::{Complexity, ExecutionPlan, Task};
5707
5708 let mock_client = Arc::new(MockLlmClient::new(vec![
5711 MockLlmClient::text_response("Step 1 done"),
5712 MockLlmClient::text_response("Step 2 done"),
5713 MockLlmClient::text_response("Step 3 done"),
5714 ]));
5715
5716 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5717 let config = AgentConfig::default();
5718 let agent = AgentLoop::new(
5719 mock_client.clone(),
5720 tool_executor,
5721 test_tool_context(),
5722 config,
5723 );
5724
5725 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5726 plan.add_step(Task::new("s1", "Independent A"));
5727 plan.add_step(Task::new("s2", "Independent B"));
5728 plan.add_step(
5729 Task::new("s3", "Depends on A+B")
5730 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5731 );
5732
5733 let (tx, mut rx) = mpsc::channel(100);
5734 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5735
5736 assert_eq!(result.usage.total_tokens, 45);
5738
5739 let mut events = Vec::new();
5741 rx.close();
5742 while let Some(event) = rx.recv().await {
5743 match &event {
5744 AgentEvent::StepStart { step_id, .. } => {
5745 events.push(format!("start:{}", step_id));
5746 }
5747 AgentEvent::StepEnd { step_id, .. } => {
5748 events.push(format!("end:{}", step_id));
5749 }
5750 _ => {}
5751 }
5752 }
5753
5754 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5756 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5757 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5758 assert!(
5759 s3_start > s1_end,
5760 "s3 started before s1 ended: {:?}",
5761 events
5762 );
5763 assert!(
5764 s3_start > s2_end,
5765 "s3 started before s2 ended: {:?}",
5766 events
5767 );
5768
5769 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5771 }
5772
5773 #[tokio::test]
5774 async fn test_execute_plan_handles_step_failure() {
5775 use crate::planning::{Complexity, ExecutionPlan, Task};
5776
5777 let mock_client = Arc::new(MockLlmClient::new(vec![
5787 MockLlmClient::text_response("s1 done"),
5789 MockLlmClient::text_response("s3 done"),
5790 ]));
5793
5794 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5795 let config = AgentConfig::default();
5796 let agent = AgentLoop::new(
5797 mock_client.clone(),
5798 tool_executor,
5799 test_tool_context(),
5800 config,
5801 );
5802
5803 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5804 plan.add_step(Task::new("s1", "Independent step"));
5805 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5806 plan.add_step(Task::new("s3", "Another independent"));
5807 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5808
5809 let (tx, mut rx) = mpsc::channel(100);
5810 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5811
5812 let mut completed_steps = Vec::new();
5815 let mut failed_steps = Vec::new();
5816 rx.close();
5817 while let Some(event) = rx.recv().await {
5818 if let AgentEvent::StepEnd {
5819 step_id, status, ..
5820 } = event
5821 {
5822 match status {
5823 TaskStatus::Completed => completed_steps.push(step_id),
5824 TaskStatus::Failed => failed_steps.push(step_id),
5825 _ => {}
5826 }
5827 }
5828 }
5829
5830 assert!(
5831 completed_steps.contains(&"s1".to_string()),
5832 "s1 should complete"
5833 );
5834 assert!(
5835 completed_steps.contains(&"s3".to_string()),
5836 "s3 should complete"
5837 );
5838 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5839 assert!(
5841 !completed_steps.contains(&"s4".to_string()),
5842 "s4 should not complete"
5843 );
5844 assert!(
5845 !failed_steps.contains(&"s4".to_string()),
5846 "s4 should not fail (never started)"
5847 );
5848 }
5849
5850 #[test]
5855 fn test_agent_config_resilience_defaults() {
5856 let config = AgentConfig::default();
5857 assert_eq!(config.max_parse_retries, 2);
5858 assert_eq!(config.tool_timeout_ms, None);
5859 assert_eq!(config.circuit_breaker_threshold, 3);
5860 }
5861
5862 #[tokio::test]
5864 async fn test_parse_error_recovery_bails_after_threshold() {
5865 let mock_client = Arc::new(MockLlmClient::new(vec![
5867 MockLlmClient::tool_call_response(
5868 "c1",
5869 "bash",
5870 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5871 ),
5872 MockLlmClient::tool_call_response(
5873 "c2",
5874 "bash",
5875 serde_json::json!({"__parse_error": "missing closing brace"}),
5876 ),
5877 MockLlmClient::tool_call_response(
5878 "c3",
5879 "bash",
5880 serde_json::json!({"__parse_error": "still broken"}),
5881 ),
5882 MockLlmClient::text_response("Done"), ]));
5884
5885 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5886 let config = AgentConfig {
5887 max_parse_retries: 2,
5888 ..AgentConfig::default()
5889 };
5890 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5891 let result = agent.execute(&[], "Do something", None).await;
5892 assert!(result.is_err(), "should bail after parse error threshold");
5893 let err = result.unwrap_err().to_string();
5894 assert!(
5895 err.contains("malformed tool arguments"),
5896 "error should mention malformed tool arguments, got: {}",
5897 err
5898 );
5899 }
5900
5901 #[tokio::test]
5903 async fn test_parse_error_counter_resets_on_success() {
5904 let mock_client = Arc::new(MockLlmClient::new(vec![
5908 MockLlmClient::tool_call_response(
5909 "c1",
5910 "bash",
5911 serde_json::json!({"__parse_error": "bad args"}),
5912 ),
5913 MockLlmClient::tool_call_response(
5914 "c2",
5915 "bash",
5916 serde_json::json!({"__parse_error": "bad args again"}),
5917 ),
5918 MockLlmClient::tool_call_response(
5920 "c3",
5921 "bash",
5922 serde_json::json!({"command": "echo ok"}),
5923 ),
5924 MockLlmClient::text_response("All done"),
5925 ]));
5926
5927 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5928 let config = AgentConfig {
5929 max_parse_retries: 2,
5930 ..AgentConfig::default()
5931 };
5932 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5933 let result = agent.execute(&[], "Do something", None).await;
5934 assert!(
5935 result.is_ok(),
5936 "should not bail — counter reset after successful tool, got: {:?}",
5937 result.err()
5938 );
5939 assert_eq!(result.unwrap().text, "All done");
5940 }
5941
5942 #[tokio::test]
5944 async fn test_tool_timeout_produces_error_result() {
5945 let mock_client = Arc::new(MockLlmClient::new(vec![
5946 MockLlmClient::tool_call_response(
5947 "t1",
5948 "bash",
5949 serde_json::json!({"command": "sleep 10"}),
5950 ),
5951 MockLlmClient::text_response("The command timed out."),
5952 ]));
5953
5954 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5955 let config = AgentConfig {
5956 tool_timeout_ms: Some(50),
5958 ..AgentConfig::default()
5959 };
5960 let agent = AgentLoop::new(
5961 mock_client.clone(),
5962 tool_executor,
5963 test_tool_context(),
5964 config,
5965 );
5966 let result = agent.execute(&[], "Run sleep", None).await;
5967 assert!(
5968 result.is_ok(),
5969 "session should continue after tool timeout: {:?}",
5970 result.err()
5971 );
5972 assert_eq!(result.unwrap().text, "The command timed out.");
5973 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5975 }
5976
5977 #[tokio::test]
5979 async fn test_tool_within_timeout_succeeds() {
5980 let mock_client = Arc::new(MockLlmClient::new(vec![
5981 MockLlmClient::tool_call_response(
5982 "t1",
5983 "bash",
5984 serde_json::json!({"command": "echo fast"}),
5985 ),
5986 MockLlmClient::text_response("Command succeeded."),
5987 ]));
5988
5989 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5990 let config = AgentConfig {
5991 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5993 };
5994 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5995 let result = agent.execute(&[], "Run something fast", None).await;
5996 assert!(
5997 result.is_ok(),
5998 "fast tool should succeed: {:?}",
5999 result.err()
6000 );
6001 assert_eq!(result.unwrap().text, "Command succeeded.");
6002 }
6003
6004 #[tokio::test]
6006 async fn test_circuit_breaker_retries_non_streaming() {
6007 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6010
6011 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6012 let config = AgentConfig {
6013 circuit_breaker_threshold: 2,
6014 ..AgentConfig::default()
6015 };
6016 let agent = AgentLoop::new(
6017 mock_client.clone(),
6018 tool_executor,
6019 test_tool_context(),
6020 config,
6021 );
6022 let result = agent.execute(&[], "Hello", None).await;
6023 assert!(result.is_err(), "should fail when LLM always errors");
6024 let err = result.unwrap_err().to_string();
6025 assert!(
6026 err.contains("circuit breaker"),
6027 "error should mention circuit breaker, got: {}",
6028 err
6029 );
6030 assert_eq!(
6031 mock_client.call_count.load(Ordering::SeqCst),
6032 2,
6033 "should make exactly threshold=2 LLM calls"
6034 );
6035 }
6036
6037 #[tokio::test]
6039 async fn test_circuit_breaker_threshold_one_no_retry() {
6040 let mock_client = Arc::new(MockLlmClient::new(vec![]));
6041
6042 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6043 let config = AgentConfig {
6044 circuit_breaker_threshold: 1,
6045 ..AgentConfig::default()
6046 };
6047 let agent = AgentLoop::new(
6048 mock_client.clone(),
6049 tool_executor,
6050 test_tool_context(),
6051 config,
6052 );
6053 let result = agent.execute(&[], "Hello", None).await;
6054 assert!(result.is_err());
6055 assert_eq!(
6056 mock_client.call_count.load(Ordering::SeqCst),
6057 1,
6058 "with threshold=1 exactly one attempt should be made"
6059 );
6060 }
6061
6062 #[tokio::test]
6064 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
6065 struct FailOnceThenSucceed {
6067 inner: MockLlmClient,
6068 failed_once: std::sync::atomic::AtomicBool,
6069 call_count: AtomicUsize,
6070 }
6071
6072 #[async_trait::async_trait]
6073 impl LlmClient for FailOnceThenSucceed {
6074 async fn complete(
6075 &self,
6076 messages: &[Message],
6077 system: Option<&str>,
6078 tools: &[ToolDefinition],
6079 ) -> Result<LlmResponse> {
6080 self.call_count.fetch_add(1, Ordering::SeqCst);
6081 let already_failed = self
6082 .failed_once
6083 .swap(true, std::sync::atomic::Ordering::SeqCst);
6084 if !already_failed {
6085 anyhow::bail!("transient network error");
6086 }
6087 self.inner.complete(messages, system, tools).await
6088 }
6089
6090 async fn complete_streaming(
6091 &self,
6092 messages: &[Message],
6093 system: Option<&str>,
6094 tools: &[ToolDefinition],
6095 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
6096 self.inner.complete_streaming(messages, system, tools).await
6097 }
6098 }
6099
6100 let mock = Arc::new(FailOnceThenSucceed {
6101 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
6102 failed_once: std::sync::atomic::AtomicBool::new(false),
6103 call_count: AtomicUsize::new(0),
6104 });
6105
6106 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6107 let config = AgentConfig {
6108 circuit_breaker_threshold: 3,
6109 ..AgentConfig::default()
6110 };
6111 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
6112 let result = agent.execute(&[], "Hello", None).await;
6113 assert!(
6114 result.is_ok(),
6115 "should succeed when LLM recovers within threshold: {:?}",
6116 result.err()
6117 );
6118 assert_eq!(result.unwrap().text, "Recovered!");
6119 assert_eq!(
6120 mock.call_count.load(Ordering::SeqCst),
6121 2,
6122 "should have made exactly 2 calls (1 fail + 1 success)"
6123 );
6124 }
6125
6126 #[test]
6129 fn test_looks_incomplete_empty() {
6130 assert!(AgentLoop::looks_incomplete(""));
6131 assert!(AgentLoop::looks_incomplete(" "));
6132 }
6133
6134 #[test]
6135 fn test_looks_incomplete_trailing_colon() {
6136 assert!(AgentLoop::looks_incomplete("Let me check the file:"));
6137 assert!(AgentLoop::looks_incomplete("Next steps:"));
6138 }
6139
6140 #[test]
6141 fn test_looks_incomplete_ellipsis() {
6142 assert!(AgentLoop::looks_incomplete("Working on it..."));
6143 assert!(AgentLoop::looks_incomplete("Processing…"));
6144 }
6145
6146 #[test]
6147 fn test_looks_incomplete_intent_phrases() {
6148 assert!(AgentLoop::looks_incomplete(
6149 "I'll start by reading the file."
6150 ));
6151 assert!(AgentLoop::looks_incomplete(
6152 "Let me check the configuration."
6153 ));
6154 assert!(AgentLoop::looks_incomplete("I will now run the tests."));
6155 assert!(AgentLoop::looks_incomplete(
6156 "I need to update the Cargo.toml."
6157 ));
6158 }
6159
6160 #[test]
6161 fn test_looks_complete_final_answer() {
6162 assert!(!AgentLoop::looks_incomplete(
6164 "The tests pass. All changes have been applied successfully."
6165 ));
6166 assert!(!AgentLoop::looks_incomplete(
6167 "Done. I've updated the three files and verified the build succeeds."
6168 ));
6169 assert!(!AgentLoop::looks_incomplete("42"));
6170 assert!(!AgentLoop::looks_incomplete("Yes."));
6171 }
6172
6173 #[test]
6174 fn test_looks_incomplete_multiline_complete() {
6175 let text = "Here is the summary:\n\n- Fixed the bug in agent.rs\n- All tests pass\n- Build succeeds";
6176 assert!(!AgentLoop::looks_incomplete(text));
6177 }
6178}