1use crate::context::{ContextProvider, ContextQuery, ContextResult};
13use crate::hitl::ConfirmationProvider;
14use crate::hooks::{
15 ErrorType, GenerateEndEvent, GenerateStartEvent, HookEvent, HookExecutor, HookResult,
16 OnErrorEvent, PostResponseEvent, PostToolUseEvent, PrePromptEvent, PreToolUseEvent,
17 TokenUsageInfo, ToolCallInfo, ToolResultData,
18};
19use crate::llm::{LlmClient, LlmResponse, Message, TokenUsage, ToolDefinition};
20use crate::permissions::{PermissionChecker, PermissionDecision};
21use crate::planning::{AgentGoal, ExecutionPlan, TaskStatus};
22use crate::prompts::SystemPromptSlots;
23use crate::queue::SessionCommand;
24use crate::session_lane_queue::SessionLaneQueue;
25use crate::tool_search::ToolIndex;
26use crate::tools::{ToolContext, ToolExecutor, ToolStreamEvent};
27use anyhow::{Context, Result};
28use async_trait::async_trait;
29use futures::future::join_all;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{mpsc, RwLock};
35
/// Default upper bound on tool-calling rounds; used as `max_tool_rounds` by
/// `AgentConfig::default()`.
const MAX_TOOL_ROUNDS: usize = 50;
38
/// Configuration for an [`AgentLoop`]: prompt construction, advertised tools,
/// optional safety/permission/hook collaborators, and loop limits.
///
/// `Debug` is implemented by hand because several fields hold `Arc<dyn ...>`
/// trait objects.
#[derive(Clone)]
pub struct AgentConfig {
    /// Slots rendered into the base system prompt via `SystemPromptSlots::build`.
    pub prompt_slots: SystemPromptSlots,
    /// Tool definitions advertised to the LLM. When `tool_index` is set, only the
    /// subset matched by the index is sent on each call.
    pub tools: Vec<ToolDefinition>,
    /// Maximum tool-calling rounds per run (default `MAX_TOOL_ROUNDS`). Enforced
    /// by the execution loop outside this excerpt — TODO confirm.
    pub max_tool_rounds: usize,
    /// Optional security provider.
    pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
    /// Optional permission checker for tool calls.
    pub permission_checker: Option<Arc<dyn PermissionChecker>>,
    /// Optional human-in-the-loop confirmation provider.
    pub confirmation_manager: Option<Arc<dyn ConfirmationProvider>>,
    /// Context providers, queried concurrently before a run and notified when a
    /// turn completes.
    pub context_providers: Vec<Arc<dyn ContextProvider>>,
    /// When true, `execute_with_session` routes through the planning executor.
    pub planning_enabled: bool,
    /// Enables goal tracking (default off).
    pub goal_tracking: bool,
    /// Optional hook executor used to fire lifecycle hooks (pre/post tool use,
    /// generate start/end, pre-prompt, post-response, on-error).
    pub hook_engine: Option<Arc<dyn HookExecutor>>,
    /// Skill registry. Its instruction skills' allowed-tools lists are enforced
    /// for queued tool calls. The default loads the built-in skills.
    pub skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
    /// Parse-retry budget (default 2); presumably for malformed tool-call
    /// arguments — TODO confirm.
    pub max_parse_retries: u32,
    /// Per-tool execution timeout in milliseconds for direct execution; `None`
    /// means no timeout.
    pub tool_timeout_ms: Option<u64>,
    /// Circuit-breaker threshold (default 3); consumer not visible here.
    pub circuit_breaker_threshold: u32,
    /// Threshold for repeated identical tool calls (default 3); consumer not
    /// visible here.
    pub duplicate_tool_call_threshold: u32,
    /// Enables automatic context compaction (default off).
    pub auto_compact: bool,
    /// Compaction trigger (default 0.80); presumably a fraction of
    /// `max_context_tokens` — TODO confirm.
    pub auto_compact_threshold: f32,
    /// Context window size in tokens (default 200_000).
    pub max_context_tokens: usize,
    /// Optional secondary LLM client; its role is not visible in this excerpt.
    pub llm_client: Option<Arc<dyn LlmClient>>,
    /// Optional agent memory store.
    pub memory: Option<Arc<crate::memory::AgentMemory>>,
    /// Enables continuation turns when a reply looks incomplete (default on).
    pub continuation_enabled: bool,
    /// Maximum continuation turns (default 3).
    pub max_continuation_turns: u32,
    /// Optional tool search index that narrows `tools` per LLM call.
    pub tool_index: Option<ToolIndex>,
}
121
122impl std::fmt::Debug for AgentConfig {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("AgentConfig")
125 .field("prompt_slots", &self.prompt_slots)
126 .field("tools", &self.tools)
127 .field("max_tool_rounds", &self.max_tool_rounds)
128 .field("security_provider", &self.security_provider.is_some())
129 .field("permission_checker", &self.permission_checker.is_some())
130 .field("confirmation_manager", &self.confirmation_manager.is_some())
131 .field("context_providers", &self.context_providers.len())
132 .field("planning_enabled", &self.planning_enabled)
133 .field("goal_tracking", &self.goal_tracking)
134 .field("hook_engine", &self.hook_engine.is_some())
135 .field(
136 "skill_registry",
137 &self.skill_registry.as_ref().map(|r| r.len()),
138 )
139 .field("max_parse_retries", &self.max_parse_retries)
140 .field("tool_timeout_ms", &self.tool_timeout_ms)
141 .field("circuit_breaker_threshold", &self.circuit_breaker_threshold)
142 .field(
143 "duplicate_tool_call_threshold",
144 &self.duplicate_tool_call_threshold,
145 )
146 .field("auto_compact", &self.auto_compact)
147 .field("auto_compact_threshold", &self.auto_compact_threshold)
148 .field("max_context_tokens", &self.max_context_tokens)
149 .field("continuation_enabled", &self.continuation_enabled)
150 .field("max_continuation_turns", &self.max_continuation_turns)
151 .field("memory", &self.memory.is_some())
152 .field("tool_index", &self.tool_index.as_ref().map(|i| i.len()))
153 .finish()
154 }
155}
156
157impl Default for AgentConfig {
158 fn default() -> Self {
159 Self {
160 prompt_slots: SystemPromptSlots::default(),
161 tools: Vec::new(), max_tool_rounds: MAX_TOOL_ROUNDS,
163 security_provider: None,
164 permission_checker: None,
165 confirmation_manager: None,
166 context_providers: Vec::new(),
167 planning_enabled: false,
168 goal_tracking: false,
169 hook_engine: None,
170 skill_registry: Some(Arc::new(crate::skills::SkillRegistry::with_builtins())),
171 max_parse_retries: 2,
172 tool_timeout_ms: None,
173 circuit_breaker_threshold: 3,
174 duplicate_tool_call_threshold: 3,
175 auto_compact: false,
176 auto_compact_threshold: 0.80,
177 max_context_tokens: 200_000,
178 llm_client: None,
179 memory: None,
180 continuation_enabled: true,
181 max_continuation_turns: 3,
182 tool_index: None,
183 }
184 }
185}
186
/// Events emitted while an agent runs, for streaming to UIs and other consumers.
///
/// Serialized as internally tagged JSON: a `"type"` field carries the name given
/// by each variant's `#[serde(rename = "...")]`, next to the variant's fields.
/// The enum is `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum AgentEvent {
    /// The agent started handling `prompt`.
    #[serde(rename = "agent_start")]
    Start { prompt: String },

    /// A new turn began; `turn` is its number.
    #[serde(rename = "turn_start")]
    TurnStart { turn: usize },

    /// A chunk of assistant text from a streaming LLM response.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },

    /// The streaming LLM response started a tool call with this id and tool name.
    #[serde(rename = "tool_start")]
    ToolStart { id: String, name: String },

    /// A chunk of tool-call input as it is streamed from the LLM.
    #[serde(rename = "tool_input_delta")]
    ToolInputDelta { delta: String },

    /// A tool call finished with `output` and `exit_code`. `metadata` is left out
    /// of the JSON when `None`.
    #[serde(rename = "tool_end")]
    ToolEnd {
        id: String,
        name: String,
        output: String,
        exit_code: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },

    /// A chunk of output streamed by a running tool, tagged with the tool call's
    /// id and tool name.
    #[serde(rename = "tool_output_delta")]
    ToolOutputDelta {
        id: String,
        name: String,
        delta: String,
    },

    /// A turn finished; `usage` is its token usage.
    #[serde(rename = "turn_end")]
    TurnEnd { turn: usize, usage: TokenUsage },

    /// The agent finished with final `text` and token `usage`. `meta` (LLM
    /// response metadata) is left out of the JSON when `None`.
    #[serde(rename = "agent_end")]
    End {
        text: String,
        usage: TokenUsage,
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<crate::llm::LlmResponseMeta>,
    },

    /// An error surfaced to the consumer.
    #[serde(rename = "error")]
    Error { message: String },

    /// A tool call is waiting for human confirmation. `timeout_ms` is presumably
    /// how long the agent waits for a decision — TODO confirm.
    #[serde(rename = "confirmation_required")]
    ConfirmationRequired {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        timeout_ms: u64,
    },

    /// A confirmation decision (`approved`, optional `reason`) arrived for `tool_id`.
    #[serde(rename = "confirmation_received")]
    ConfirmationReceived {
        tool_id: String,
        approved: bool,
        reason: Option<String>,
    },

    /// No confirmation arrived in time for `tool_id`. `action_taken` describes how
    /// the agent proceeded.
    #[serde(rename = "confirmation_timeout")]
    ConfirmationTimeout {
        tool_id: String,
        action_taken: String,
    },

    /// A command was handed off for external processing on `lane` and is pending.
    /// `timeout_ms` presumably bounds the wait — TODO confirm.
    #[serde(rename = "external_task_pending")]
    ExternalTaskPending {
        task_id: String,
        session_id: String,
        lane: crate::hitl::SessionLane,
        command_type: String,
        payload: serde_json::Value,
        timeout_ms: u64,
    },

    /// An externally processed task finished, successfully or not.
    #[serde(rename = "external_task_completed")]
    ExternalTaskCompleted {
        task_id: String,
        session_id: String,
        success: bool,
    },

    /// A tool call was denied permission; `reason` explains why.
    #[serde(rename = "permission_denied")]
    PermissionDenied {
        tool_id: String,
        tool_name: String,
        args: serde_json::Value,
        reason: String,
    },

    /// Context providers, listed by name, are being queried.
    #[serde(rename = "context_resolving")]
    ContextResolving { providers: Vec<String> },

    /// Context resolution finished with `total_items` items totalling
    /// `total_tokens` tokens.
    #[serde(rename = "context_resolved")]
    ContextResolved {
        total_items: usize,
        total_tokens: usize,
    },

    /// A queued command was dead-lettered after `attempts` attempts, with `error`.
    #[serde(rename = "command_dead_lettered")]
    CommandDeadLettered {
        command_id: String,
        command_type: String,
        lane: String,
        error: String,
        attempts: u32,
    },

    /// A queued command is being retried (attempt number `attempt`) after
    /// `delay_ms` milliseconds.
    #[serde(rename = "command_retry")]
    CommandRetry {
        command_id: String,
        command_type: String,
        lane: String,
        attempt: u32,
        delay_ms: u64,
    },

    /// An alert raised by the command queue.
    #[serde(rename = "queue_alert")]
    QueueAlert {
        level: String,
        alert_type: String,
        message: String,
    },

    /// The task list for `session_id` changed.
    #[serde(rename = "task_updated")]
    TaskUpdated {
        session_id: String,
        tasks: Vec<crate::planning::Task>,
    },

    /// A memory item was stored with the given type, importance and tags.
    #[serde(rename = "memory_stored")]
    MemoryStored {
        memory_id: String,
        memory_type: String,
        importance: f32,
        tags: Vec<String>,
    },

    /// A memory item was recalled with a relevance score.
    #[serde(rename = "memory_recalled")]
    MemoryRecalled {
        memory_id: String,
        content: String,
        relevance: f32,
    },

    /// A memory search ran with an optional text `query` and `tags`, and returned
    /// `result_count` results.
    #[serde(rename = "memories_searched")]
    MemoriesSearched {
        query: Option<String>,
        tags: Vec<String>,
        result_count: usize,
    },

    /// Memories in `tier` were cleared; `count` is how many.
    #[serde(rename = "memory_cleared")]
    MemoryCleared {
        tier: String,
        count: u64,
    },

    /// A subagent started on `task_id` in session `session_id`, spawned from
    /// `parent_session_id`.
    #[serde(rename = "subagent_start")]
    SubagentStart {
        task_id: String,
        session_id: String,
        parent_session_id: String,
        agent: String,
        description: String,
    },

    /// Progress update from a subagent: a `status` string plus free-form `metadata`.
    #[serde(rename = "subagent_progress")]
    SubagentProgress {
        task_id: String,
        session_id: String,
        status: String,
        metadata: serde_json::Value,
    },

    /// A subagent finished with `output` and a success flag.
    #[serde(rename = "subagent_end")]
    SubagentEnd {
        task_id: String,
        session_id: String,
        agent: String,
        output: String,
        success: bool,
    },

    /// Planning began for `prompt`.
    #[serde(rename = "planning_start")]
    PlanningStart { prompt: String },

    /// Planning produced `plan` with `estimated_steps` steps.
    #[serde(rename = "planning_end")]
    PlanningEnd {
        plan: ExecutionPlan,
        estimated_steps: usize,
    },

    /// A plan step started (`step_number` of `total_steps`).
    #[serde(rename = "step_start")]
    StepStart {
        step_id: String,
        description: String,
        step_number: usize,
        total_steps: usize,
    },

    /// A plan step ended with `status`.
    #[serde(rename = "step_end")]
    StepEnd {
        step_id: String,
        status: TaskStatus,
        step_number: usize,
        total_steps: usize,
    },

    /// A goal was extracted from the request.
    #[serde(rename = "goal_extracted")]
    GoalExtracted { goal: AgentGoal },

    /// Progress toward `goal`. `progress` is presumably a fraction in 0.0..=1.0 —
    /// TODO confirm.
    #[serde(rename = "goal_progress")]
    GoalProgress {
        goal: String,
        progress: f32,
        completed_steps: usize,
        total_steps: usize,
    },

    /// `goal` was achieved after `total_steps` steps in `duration_ms` milliseconds.
    #[serde(rename = "goal_achieved")]
    GoalAchieved {
        goal: String,
        total_steps: usize,
        duration_ms: i64,
    },

    /// The session history was compacted from `before_messages` to
    /// `after_messages` messages. `percent_before` is presumably context-window
    /// usage before compaction — TODO confirm its scale.
    #[serde(rename = "context_compacted")]
    ContextCompacted {
        session_id: String,
        before_messages: usize,
        after_messages: usize,
        percent_before: f32,
    },

    /// Persisting session state failed during `operation` with `error`.
    #[serde(rename = "persistence_failed")]
    PersistenceFailed {
        session_id: String,
        operation: String,
        error: String,
    },

    /// Answer to a side ("by the way") question, with the tokens it used.
    /// NOTE(review): semantics inferred from the name — confirm with the producer.
    #[serde(rename = "btw_answer")]
    BtwAnswer {
        question: String,
        answer: String,
        usage: TokenUsage,
    },
}
527
/// Outcome of an agent run.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final response text.
    pub text: String,
    /// Conversation messages produced by the run (presumably the full history
    /// including this run — TODO confirm).
    pub messages: Vec<Message>,
    /// Token usage accumulated by the run.
    pub usage: TokenUsage,
    /// Number of tool calls made during the run.
    pub tool_calls_count: usize,
}
536
/// A single tool invocation packaged as a [`SessionCommand`] so it can be
/// scheduled through the [`SessionLaneQueue`].
pub struct ToolCommand {
    /// Executor that actually runs the tool.
    tool_executor: Arc<ToolExecutor>,
    /// Tool to invoke; also reported as the command type.
    tool_name: String,
    /// JSON arguments for the tool; also reported as the command payload.
    tool_args: Value,
    /// Execution context passed to the executor.
    tool_context: ToolContext,
    /// When present, instruction skills' allowed-tools lists are enforced before
    /// the tool runs.
    skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
}
551
552impl ToolCommand {
553 pub fn new(
555 tool_executor: Arc<ToolExecutor>,
556 tool_name: String,
557 tool_args: Value,
558 tool_context: ToolContext,
559 skill_registry: Option<Arc<crate::skills::SkillRegistry>>,
560 ) -> Self {
561 Self {
562 tool_executor,
563 tool_name,
564 tool_args,
565 tool_context,
566 skill_registry,
567 }
568 }
569}
570
571#[async_trait]
572impl SessionCommand for ToolCommand {
573 async fn execute(&self) -> Result<Value> {
574 if let Some(registry) = &self.skill_registry {
576 let instruction_skills = registry.by_kind(crate::skills::SkillKind::Instruction);
577
578 let has_restrictions = instruction_skills.iter().any(|s| s.allowed_tools.is_some());
580
581 if has_restrictions {
582 let mut allowed = false;
583
584 for skill in &instruction_skills {
585 if skill.is_tool_allowed(&self.tool_name) {
586 allowed = true;
587 break;
588 }
589 }
590
591 if !allowed {
592 return Err(anyhow::anyhow!(
593 "Tool '{}' is not allowed by any active skill. Active skills restrict tools to their allowed-tools lists.",
594 self.tool_name
595 ));
596 }
597 }
598 }
599
600 let result = self
602 .tool_executor
603 .execute_with_context(&self.tool_name, &self.tool_args, &self.tool_context)
604 .await?;
605 Ok(serde_json::json!({
606 "output": result.output,
607 "exit_code": result.exit_code,
608 "metadata": result.metadata,
609 }))
610 }
611
612 fn command_type(&self) -> &str {
613 &self.tool_name
614 }
615
616 fn payload(&self) -> Value {
617 self.tool_args.clone()
618 }
619}
620
/// Drives the LLM ↔ tool loop for a session.
///
/// The loop builds prompts, calls the LLM, executes requested tools (directly or
/// via the command queue), and emits [`AgentEvent`]s.
#[derive(Clone)]
pub struct AgentLoop {
    /// Client used for every LLM call made by the loop.
    llm_client: Arc<dyn LlmClient>,
    /// Executor for tool calls; also the source of live tool definitions.
    tool_executor: Arc<ToolExecutor>,
    /// Base tool context (workspace, event channel, ...), cloned per tool call.
    tool_context: ToolContext,
    /// Loop configuration.
    config: AgentConfig,
    /// Optional shared tool-metrics collector (set via `with_tool_metrics`).
    tool_metrics: Option<Arc<RwLock<crate::telemetry::ToolMetrics>>>,
    /// When set, tool calls are submitted through this queue instead of being
    /// executed directly (see `execute_tool_queued_or_direct`).
    command_queue: Option<Arc<SessionLaneQueue>>,
}
637
638impl AgentLoop {
639 pub fn new(
640 llm_client: Arc<dyn LlmClient>,
641 tool_executor: Arc<ToolExecutor>,
642 tool_context: ToolContext,
643 config: AgentConfig,
644 ) -> Self {
645 Self {
646 llm_client,
647 tool_executor,
648 tool_context,
649 config,
650 tool_metrics: None,
651 command_queue: None,
652 }
653 }
654
655 pub fn with_tool_metrics(
657 mut self,
658 metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
659 ) -> Self {
660 self.tool_metrics = Some(metrics);
661 self
662 }
663
664 pub fn with_queue(mut self, queue: Arc<SessionLaneQueue>) -> Self {
669 self.command_queue = Some(queue);
670 self
671 }
672
673 async fn execute_tool_timed(
679 &self,
680 name: &str,
681 args: &serde_json::Value,
682 ctx: &ToolContext,
683 ) -> anyhow::Result<crate::tools::ToolResult> {
684 let fut = self.tool_executor.execute_with_context(name, args, ctx);
685 if let Some(timeout_ms) = self.config.tool_timeout_ms {
686 match tokio::time::timeout(Duration::from_millis(timeout_ms), fut).await {
687 Ok(result) => result,
688 Err(_) => Err(anyhow::anyhow!(
689 "Tool '{}' timed out after {}ms",
690 name,
691 timeout_ms
692 )),
693 }
694 } else {
695 fut.await
696 }
697 }
698
699 fn tool_result_to_tuple(
701 result: anyhow::Result<crate::tools::ToolResult>,
702 ) -> (
703 String,
704 i32,
705 bool,
706 Option<serde_json::Value>,
707 Vec<crate::llm::Attachment>,
708 ) {
709 match result {
710 Ok(r) => (
711 r.output,
712 r.exit_code,
713 r.exit_code != 0,
714 r.metadata,
715 r.images,
716 ),
717 Err(e) => {
718 let msg = e.to_string();
719 let hint = if Self::is_transient_error(&msg) {
721 " [transient — you may retry this tool call]"
722 } else {
723 " [permanent — do not retry without changing the arguments]"
724 };
725 (
726 format!("Tool execution error: {}{}", msg, hint),
727 1,
728 true,
729 None,
730 Vec::new(),
731 )
732 }
733 }
734 }
735
736 fn detect_project_hint(workspace: &std::path::Path) -> String {
740 struct Marker {
741 file: &'static str,
742 lang: &'static str,
743 tip: &'static str,
744 }
745
746 let markers = [
747 Marker {
748 file: "Cargo.toml",
749 lang: "Rust",
750 tip: "Use `cargo build`, `cargo test`, `cargo clippy`, and `cargo fmt`. \
751 Prefer `anyhow` / `thiserror` for error handling. \
752 Follow the Microsoft Rust Guidelines (no panics in library code, \
753 async-first with Tokio).",
754 },
755 Marker {
756 file: "package.json",
757 lang: "Node.js / TypeScript",
758 tip: "Check `package.json` for the package manager (npm/yarn/pnpm/bun) \
759 and available scripts. Prefer TypeScript with strict mode. \
760 Use ESM imports unless the project is CommonJS.",
761 },
762 Marker {
763 file: "pyproject.toml",
764 lang: "Python",
765 tip: "Use the package manager declared in `pyproject.toml` \
766 (uv, poetry, hatch, etc.). Prefer type hints and async/await for I/O.",
767 },
768 Marker {
769 file: "setup.py",
770 lang: "Python",
771 tip: "Legacy Python project. Prefer type hints and async/await for I/O.",
772 },
773 Marker {
774 file: "requirements.txt",
775 lang: "Python",
776 tip: "Python project with pip-style dependencies. \
777 Prefer type hints and async/await for I/O.",
778 },
779 Marker {
780 file: "go.mod",
781 lang: "Go",
782 tip: "Use `go build ./...` and `go test ./...`. \
783 Follow standard Go project layout. Use `gofmt` for formatting.",
784 },
785 Marker {
786 file: "pom.xml",
787 lang: "Java / Maven",
788 tip: "Use `mvn compile`, `mvn test`, `mvn package`. \
789 Follow standard Maven project structure.",
790 },
791 Marker {
792 file: "build.gradle",
793 lang: "Java / Gradle",
794 tip: "Use `./gradlew build` and `./gradlew test`. \
795 Follow standard Gradle project structure.",
796 },
797 Marker {
798 file: "build.gradle.kts",
799 lang: "Kotlin / Gradle",
800 tip: "Use `./gradlew build` and `./gradlew test`. \
801 Prefer Kotlin coroutines for async work.",
802 },
803 Marker {
804 file: "CMakeLists.txt",
805 lang: "C / C++",
806 tip: "Use `cmake -B build && cmake --build build`. \
807 Check for `compile_commands.json` for IDE tooling.",
808 },
809 Marker {
810 file: "Makefile",
811 lang: "C / C++ (or generic)",
812 tip: "Use `make` or `make <target>`. \
813 Check available targets with `make help` or by reading the Makefile.",
814 },
815 ];
816
817 let is_dotnet = workspace.join("*.csproj").exists() || {
819 std::fs::read_dir(workspace)
821 .map(|entries| {
822 entries.flatten().any(|e| {
823 let name = e.file_name();
824 let s = name.to_string_lossy();
825 s.ends_with(".csproj") || s.ends_with(".sln")
826 })
827 })
828 .unwrap_or(false)
829 };
830
831 if is_dotnet {
832 return "## Project Context\n\nThis is a **C# / .NET** project. \
833 Use `dotnet build`, `dotnet test`, and `dotnet run`. \
834 Follow C# coding conventions and async/await patterns."
835 .to_string();
836 }
837
838 for marker in &markers {
839 if workspace.join(marker.file).exists() {
840 return format!(
841 "## Project Context\n\nThis is a **{}** project. {}",
842 marker.lang, marker.tip
843 );
844 }
845 }
846
847 String::new()
848 }
849
850 fn is_transient_error(msg: &str) -> bool {
853 let lower = msg.to_lowercase();
854 lower.contains("timeout")
855 || lower.contains("timed out")
856 || lower.contains("connection refused")
857 || lower.contains("connection reset")
858 || lower.contains("broken pipe")
859 || lower.contains("temporarily unavailable")
860 || lower.contains("resource temporarily unavailable")
861 || lower.contains("os error 11") || lower.contains("os error 35") || lower.contains("rate limit")
864 || lower.contains("too many requests")
865 || lower.contains("service unavailable")
866 || lower.contains("network unreachable")
867 }
868
869 fn is_parallel_safe_write(name: &str, _args: &serde_json::Value) -> bool {
872 matches!(
873 name,
874 "write_file" | "edit_file" | "create_file" | "append_to_file" | "replace_in_file"
875 )
876 }
877
878 fn extract_write_path(args: &serde_json::Value) -> Option<String> {
880 args.get("path")
883 .and_then(|v| v.as_str())
884 .map(|s| s.to_string())
885 }
886
887 async fn execute_tool_queued_or_direct(
889 &self,
890 name: &str,
891 args: &serde_json::Value,
892 ctx: &ToolContext,
893 ) -> anyhow::Result<crate::tools::ToolResult> {
894 if let Some(ref queue) = self.command_queue {
895 let command = ToolCommand::new(
896 Arc::clone(&self.tool_executor),
897 name.to_string(),
898 args.clone(),
899 ctx.clone(),
900 self.config.skill_registry.clone(),
901 );
902 let rx = queue.submit_by_tool(name, Box::new(command)).await;
903 match rx.await {
904 Ok(Ok(value)) => {
905 let output = value["output"]
906 .as_str()
907 .ok_or_else(|| {
908 anyhow::anyhow!(
909 "Queue result missing 'output' field for tool '{}'",
910 name
911 )
912 })?
913 .to_string();
914 let exit_code = value["exit_code"].as_i64().unwrap_or(0) as i32;
915 return Ok(crate::tools::ToolResult {
916 name: name.to_string(),
917 output,
918 exit_code,
919 metadata: None,
920 images: Vec::new(),
921 });
922 }
923 Ok(Err(e)) => {
924 tracing::warn!(
925 "Queue execution failed for tool '{}', falling back to direct: {}",
926 name,
927 e
928 );
929 }
930 Err(_) => {
931 tracing::warn!(
932 "Queue channel closed for tool '{}', falling back to direct",
933 name
934 );
935 }
936 }
937 }
938 self.execute_tool_timed(name, args, ctx).await
939 }
940
941 async fn call_llm(
952 &self,
953 messages: &[Message],
954 system: Option<&str>,
955 event_tx: &Option<mpsc::Sender<AgentEvent>>,
956 cancel_token: &tokio_util::sync::CancellationToken,
957 ) -> anyhow::Result<LlmResponse> {
958 let tools = if let Some(ref index) = self.config.tool_index {
960 let query = messages
961 .iter()
962 .rev()
963 .find(|m| m.role == "user")
964 .and_then(|m| {
965 m.content.iter().find_map(|b| match b {
966 crate::llm::ContentBlock::Text { text } => Some(text.as_str()),
967 _ => None,
968 })
969 })
970 .unwrap_or("");
971 let matches = index.search(query, index.len());
972 let matched_names: std::collections::HashSet<&str> =
973 matches.iter().map(|m| m.name.as_str()).collect();
974 self.config
975 .tools
976 .iter()
977 .filter(|t| matched_names.contains(t.name.as_str()))
978 .cloned()
979 .collect::<Vec<_>>()
980 } else {
981 self.config.tools.clone()
982 };
983
984 if event_tx.is_some() {
985 let mut stream_rx = match self
986 .llm_client
987 .complete_streaming(messages, system, &tools)
988 .await
989 {
990 Ok(rx) => rx,
991 Err(stream_error) => {
992 tracing::warn!(
993 error = %stream_error,
994 "LLM streaming setup failed; falling back to non-streaming completion"
995 );
996 return self
997 .llm_client
998 .complete(messages, system, &tools)
999 .await
1000 .with_context(|| {
1001 format!(
1002 "LLM streaming call failed ({stream_error}); non-streaming fallback also failed"
1003 )
1004 });
1005 }
1006 };
1007
1008 let mut final_response: Option<LlmResponse> = None;
1009 loop {
1010 tokio::select! {
1011 _ = cancel_token.cancelled() => {
1012 tracing::info!("🛑 LLM streaming cancelled by CancellationToken");
1013 anyhow::bail!("Operation cancelled by user");
1014 }
1015 event = stream_rx.recv() => {
1016 match event {
1017 Some(crate::llm::StreamEvent::TextDelta(text)) => {
1018 if let Some(tx) = event_tx {
1019 tx.send(AgentEvent::TextDelta { text }).await.ok();
1020 }
1021 }
1022 Some(crate::llm::StreamEvent::ToolUseStart { id, name }) => {
1023 if let Some(tx) = event_tx {
1024 tx.send(AgentEvent::ToolStart { id, name }).await.ok();
1025 }
1026 }
1027 Some(crate::llm::StreamEvent::ToolUseInputDelta(delta)) => {
1028 if let Some(tx) = event_tx {
1029 tx.send(AgentEvent::ToolInputDelta { delta }).await.ok();
1030 }
1031 }
1032 Some(crate::llm::StreamEvent::Done(resp)) => {
1033 final_response = Some(resp);
1034 break;
1035 }
1036 None => break,
1037 }
1038 }
1039 }
1040 }
1041 final_response.context("Stream ended without final response")
1042 } else {
1043 self.llm_client
1044 .complete(messages, system, &tools)
1045 .await
1046 .context("LLM call failed")
1047 }
1048 }
1049
1050 fn streaming_tool_context(
1059 &self,
1060 event_tx: &Option<mpsc::Sender<AgentEvent>>,
1061 tool_id: &str,
1062 tool_name: &str,
1063 ) -> ToolContext {
1064 let mut ctx = self.tool_context.clone();
1065 if let Some(agent_tx) = event_tx {
1066 let (tool_tx, mut tool_rx) = mpsc::channel::<ToolStreamEvent>(64);
1067 ctx.event_tx = Some(tool_tx);
1068
1069 let agent_tx = agent_tx.clone();
1070 let tool_id = tool_id.to_string();
1071 let tool_name = tool_name.to_string();
1072 tokio::spawn(async move {
1073 while let Some(event) = tool_rx.recv().await {
1074 match event {
1075 ToolStreamEvent::OutputDelta(delta) => {
1076 agent_tx
1077 .send(AgentEvent::ToolOutputDelta {
1078 id: tool_id.clone(),
1079 name: tool_name.clone(),
1080 delta,
1081 })
1082 .await
1083 .ok();
1084 }
1085 }
1086 }
1087 });
1088 }
1089 ctx
1090 }
1091
1092 async fn resolve_context(&self, prompt: &str, session_id: Option<&str>) -> Vec<ContextResult> {
1096 if self.config.context_providers.is_empty() {
1097 return Vec::new();
1098 }
1099
1100 let query = ContextQuery::new(prompt).with_session_id(session_id.unwrap_or(""));
1101
1102 let futures = self
1103 .config
1104 .context_providers
1105 .iter()
1106 .map(|p| p.query(&query));
1107 let outcomes = join_all(futures).await;
1108
1109 outcomes
1110 .into_iter()
1111 .enumerate()
1112 .filter_map(|(i, r)| match r {
1113 Ok(result) if !result.is_empty() => Some(result),
1114 Ok(_) => None,
1115 Err(e) => {
1116 tracing::warn!(
1117 "Context provider '{}' failed: {}",
1118 self.config.context_providers[i].name(),
1119 e
1120 );
1121 None
1122 }
1123 })
1124 .collect()
1125 }
1126
1127 fn looks_incomplete(text: &str) -> bool {
1135 let t = text.trim();
1136 if t.is_empty() {
1137 return true;
1138 }
1139 if t.len() < 80 && !t.contains('\n') {
1141 let ends_continuation =
1144 t.ends_with(':') || t.ends_with("...") || t.ends_with('…') || t.ends_with(',');
1145 if ends_continuation {
1146 return true;
1147 }
1148 }
1149 let incomplete_phrases = [
1151 "i'll ",
1152 "i will ",
1153 "let me ",
1154 "i need to ",
1155 "i should ",
1156 "next, i",
1157 "first, i",
1158 "now i",
1159 "i'll start",
1160 "i'll begin",
1161 "i'll now",
1162 "let's start",
1163 "let's begin",
1164 "to do this",
1165 "i'm going to",
1166 ];
1167 let lower = t.to_lowercase();
1168 for phrase in &incomplete_phrases {
1169 if lower.contains(phrase) {
1170 return true;
1171 }
1172 }
1173 false
1174 }
1175
1176 fn system_prompt(&self) -> String {
1178 self.config.prompt_slots.build()
1179 }
1180
1181 fn build_augmented_system_prompt(&self, context_results: &[ContextResult]) -> Option<String> {
1183 let base = self.system_prompt();
1184
1185 let live_tools = self.tool_executor.definitions();
1187 let mcp_tools: Vec<&ToolDefinition> = live_tools
1188 .iter()
1189 .filter(|t| t.name.starts_with("mcp__"))
1190 .collect();
1191
1192 let mcp_section = if mcp_tools.is_empty() {
1193 String::new()
1194 } else {
1195 let mut lines = vec![
1196 "## MCP Tools".to_string(),
1197 String::new(),
1198 "The following MCP (Model Context Protocol) tools are available. Use them when the task requires external capabilities beyond the built-in tools:".to_string(),
1199 String::new(),
1200 ];
1201 for tool in &mcp_tools {
1202 let display = format!("- `{}` — {}", tool.name, tool.description);
1203 lines.push(display);
1204 }
1205 lines.join("\n")
1206 };
1207
1208 let parts: Vec<&str> = [base.as_str(), mcp_section.as_str()]
1209 .iter()
1210 .filter(|s| !s.is_empty())
1211 .copied()
1212 .collect();
1213
1214 let project_hint = if self.config.prompt_slots.guidelines.is_none() {
1217 Self::detect_project_hint(&self.tool_context.workspace)
1218 } else {
1219 String::new()
1220 };
1221
1222 if context_results.is_empty() {
1223 if project_hint.is_empty() {
1224 return Some(parts.join("\n\n"));
1225 }
1226 return Some(format!("{}\n\n{}", parts.join("\n\n"), project_hint));
1227 }
1228
1229 let context_xml: String = context_results
1231 .iter()
1232 .map(|r| r.to_xml())
1233 .collect::<Vec<_>>()
1234 .join("\n\n");
1235
1236 if project_hint.is_empty() {
1237 Some(format!("{}\n\n{}", parts.join("\n\n"), context_xml))
1238 } else {
1239 Some(format!(
1240 "{}\n\n{}\n\n{}",
1241 parts.join("\n\n"),
1242 project_hint,
1243 context_xml
1244 ))
1245 }
1246 }
1247
1248 async fn notify_turn_complete(&self, session_id: &str, prompt: &str, response: &str) {
1250 let futures = self
1251 .config
1252 .context_providers
1253 .iter()
1254 .map(|p| p.on_turn_complete(session_id, prompt, response));
1255 let outcomes = join_all(futures).await;
1256
1257 for (i, result) in outcomes.into_iter().enumerate() {
1258 if let Err(e) = result {
1259 tracing::warn!(
1260 "Context provider '{}' on_turn_complete failed: {}",
1261 self.config.context_providers[i].name(),
1262 e
1263 );
1264 }
1265 }
1266 }
1267
1268 async fn fire_pre_tool_use(
1271 &self,
1272 session_id: &str,
1273 tool_name: &str,
1274 args: &serde_json::Value,
1275 recent_tools: Vec<String>,
1276 ) -> Option<HookResult> {
1277 if let Some(he) = &self.config.hook_engine {
1278 let event = HookEvent::PreToolUse(PreToolUseEvent {
1279 session_id: session_id.to_string(),
1280 tool: tool_name.to_string(),
1281 args: args.clone(),
1282 working_directory: self.tool_context.workspace.to_string_lossy().to_string(),
1283 recent_tools,
1284 });
1285 let result = he.fire(&event).await;
1286 if result.is_block() {
1287 return Some(result);
1288 }
1289 }
1290 None
1291 }
1292
1293 async fn fire_post_tool_use(
1295 &self,
1296 session_id: &str,
1297 tool_name: &str,
1298 args: &serde_json::Value,
1299 output: &str,
1300 success: bool,
1301 duration_ms: u64,
1302 ) {
1303 if let Some(he) = &self.config.hook_engine {
1304 let event = HookEvent::PostToolUse(PostToolUseEvent {
1305 session_id: session_id.to_string(),
1306 tool: tool_name.to_string(),
1307 args: args.clone(),
1308 result: ToolResultData {
1309 success,
1310 output: output.to_string(),
1311 exit_code: if success { Some(0) } else { Some(1) },
1312 duration_ms,
1313 },
1314 });
1315 let he = Arc::clone(he);
1316 tokio::spawn(async move {
1317 let _ = he.fire(&event).await;
1318 });
1319 }
1320 }
1321
1322 async fn fire_generate_start(
1324 &self,
1325 session_id: &str,
1326 prompt: &str,
1327 system_prompt: &Option<String>,
1328 ) {
1329 if let Some(he) = &self.config.hook_engine {
1330 let event = HookEvent::GenerateStart(GenerateStartEvent {
1331 session_id: session_id.to_string(),
1332 prompt: prompt.to_string(),
1333 system_prompt: system_prompt.clone(),
1334 model_provider: String::new(),
1335 model_name: String::new(),
1336 available_tools: self.config.tools.iter().map(|t| t.name.clone()).collect(),
1337 });
1338 let _ = he.fire(&event).await;
1339 }
1340 }
1341
1342 async fn fire_generate_end(
1344 &self,
1345 session_id: &str,
1346 prompt: &str,
1347 response: &LlmResponse,
1348 duration_ms: u64,
1349 ) {
1350 if let Some(he) = &self.config.hook_engine {
1351 let tool_calls: Vec<ToolCallInfo> = response
1352 .tool_calls()
1353 .iter()
1354 .map(|tc| ToolCallInfo {
1355 name: tc.name.clone(),
1356 args: tc.args.clone(),
1357 })
1358 .collect();
1359
1360 let event = HookEvent::GenerateEnd(GenerateEndEvent {
1361 session_id: session_id.to_string(),
1362 prompt: prompt.to_string(),
1363 response_text: response.text().to_string(),
1364 tool_calls,
1365 usage: TokenUsageInfo {
1366 prompt_tokens: response.usage.prompt_tokens as i32,
1367 completion_tokens: response.usage.completion_tokens as i32,
1368 total_tokens: response.usage.total_tokens as i32,
1369 },
1370 duration_ms,
1371 });
1372 let _ = he.fire(&event).await;
1373 }
1374 }
1375
1376 async fn fire_pre_prompt(
1379 &self,
1380 session_id: &str,
1381 prompt: &str,
1382 system_prompt: &Option<String>,
1383 message_count: usize,
1384 ) -> Option<String> {
1385 if let Some(he) = &self.config.hook_engine {
1386 let event = HookEvent::PrePrompt(PrePromptEvent {
1387 session_id: session_id.to_string(),
1388 prompt: prompt.to_string(),
1389 system_prompt: system_prompt.clone(),
1390 message_count,
1391 });
1392 let result = he.fire(&event).await;
1393 if let HookResult::Continue(Some(modified)) = result {
1394 if let Some(new_prompt) = modified.get("prompt").and_then(|v| v.as_str()) {
1396 return Some(new_prompt.to_string());
1397 }
1398 }
1399 }
1400 None
1401 }
1402
1403 async fn fire_post_response(
1405 &self,
1406 session_id: &str,
1407 response_text: &str,
1408 tool_calls_count: usize,
1409 usage: &TokenUsage,
1410 duration_ms: u64,
1411 ) {
1412 if let Some(he) = &self.config.hook_engine {
1413 let event = HookEvent::PostResponse(PostResponseEvent {
1414 session_id: session_id.to_string(),
1415 response_text: response_text.to_string(),
1416 tool_calls_count,
1417 usage: TokenUsageInfo {
1418 prompt_tokens: usage.prompt_tokens as i32,
1419 completion_tokens: usage.completion_tokens as i32,
1420 total_tokens: usage.total_tokens as i32,
1421 },
1422 duration_ms,
1423 });
1424 let he = Arc::clone(he);
1425 tokio::spawn(async move {
1426 let _ = he.fire(&event).await;
1427 });
1428 }
1429 }
1430
1431 async fn fire_on_error(
1433 &self,
1434 session_id: &str,
1435 error_type: ErrorType,
1436 error_message: &str,
1437 context: serde_json::Value,
1438 ) {
1439 if let Some(he) = &self.config.hook_engine {
1440 let event = HookEvent::OnError(OnErrorEvent {
1441 session_id: session_id.to_string(),
1442 error_type,
1443 error_message: error_message.to_string(),
1444 context,
1445 });
1446 let he = Arc::clone(he);
1447 tokio::spawn(async move {
1448 let _ = he.fire(&event).await;
1449 });
1450 }
1451 }
1452
1453 pub async fn execute(
1459 &self,
1460 history: &[Message],
1461 prompt: &str,
1462 event_tx: Option<mpsc::Sender<AgentEvent>>,
1463 ) -> Result<AgentResult> {
1464 self.execute_with_session(history, prompt, None, event_tx, None)
1465 .await
1466 }
1467
1468 pub async fn execute_from_messages(
1474 &self,
1475 messages: Vec<Message>,
1476 session_id: Option<&str>,
1477 event_tx: Option<mpsc::Sender<AgentEvent>>,
1478 cancel_token: Option<&tokio_util::sync::CancellationToken>,
1479 ) -> Result<AgentResult> {
1480 let default_token = tokio_util::sync::CancellationToken::new();
1481 let token = cancel_token.unwrap_or(&default_token);
1482 tracing::info!(
1483 a3s.session.id = session_id.unwrap_or("none"),
1484 a3s.agent.max_turns = self.config.max_tool_rounds,
1485 "a3s.agent.execute_from_messages started"
1486 );
1487
1488 let effective_prompt = messages
1492 .iter()
1493 .rev()
1494 .find(|m| m.role == "user")
1495 .map(|m| m.text())
1496 .unwrap_or_default();
1497
1498 let result = self
1499 .execute_loop_inner(
1500 &messages,
1501 "",
1502 &effective_prompt,
1503 session_id,
1504 event_tx,
1505 token,
1506 )
1507 .await;
1508
1509 match &result {
1510 Ok(r) => tracing::info!(
1511 a3s.agent.tool_calls_count = r.tool_calls_count,
1512 a3s.llm.total_tokens = r.usage.total_tokens,
1513 "a3s.agent.execute_from_messages completed"
1514 ),
1515 Err(e) => tracing::warn!(
1516 error = %e,
1517 "a3s.agent.execute_from_messages failed"
1518 ),
1519 }
1520
1521 result
1522 }
1523
1524 pub async fn execute_with_session(
1529 &self,
1530 history: &[Message],
1531 prompt: &str,
1532 session_id: Option<&str>,
1533 event_tx: Option<mpsc::Sender<AgentEvent>>,
1534 cancel_token: Option<&tokio_util::sync::CancellationToken>,
1535 ) -> Result<AgentResult> {
1536 let default_token = tokio_util::sync::CancellationToken::new();
1537 let token = cancel_token.unwrap_or(&default_token);
1538 tracing::info!(
1539 a3s.session.id = session_id.unwrap_or("none"),
1540 a3s.agent.max_turns = self.config.max_tool_rounds,
1541 "a3s.agent.execute started"
1542 );
1543
1544 let result = if self.config.planning_enabled {
1546 self.execute_with_planning(history, prompt, event_tx).await
1547 } else {
1548 self.execute_loop(history, prompt, session_id, event_tx, token)
1549 .await
1550 };
1551
1552 match &result {
1553 Ok(r) => {
1554 tracing::info!(
1555 a3s.agent.tool_calls_count = r.tool_calls_count,
1556 a3s.llm.total_tokens = r.usage.total_tokens,
1557 "a3s.agent.execute completed"
1558 );
1559 self.fire_post_response(
1561 session_id.unwrap_or(""),
1562 &r.text,
1563 r.tool_calls_count,
1564 &r.usage,
1565 0, )
1567 .await;
1568 }
1569 Err(e) => {
1570 tracing::warn!(
1571 error = %e,
1572 "a3s.agent.execute failed"
1573 );
1574 self.fire_on_error(
1576 session_id.unwrap_or(""),
1577 ErrorType::Other,
1578 &e.to_string(),
1579 serde_json::json!({"phase": "execute"}),
1580 )
1581 .await;
1582 }
1583 }
1584
1585 result
1586 }
1587
1588 async fn execute_loop(
1594 &self,
1595 history: &[Message],
1596 prompt: &str,
1597 session_id: Option<&str>,
1598 event_tx: Option<mpsc::Sender<AgentEvent>>,
1599 cancel_token: &tokio_util::sync::CancellationToken,
1600 ) -> Result<AgentResult> {
1601 self.execute_loop_inner(history, prompt, prompt, session_id, event_tx, cancel_token)
1604 .await
1605 }
1606
1607 async fn execute_loop_inner(
1612 &self,
1613 history: &[Message],
1614 msg_prompt: &str,
1615 effective_prompt: &str,
1616 session_id: Option<&str>,
1617 event_tx: Option<mpsc::Sender<AgentEvent>>,
1618 cancel_token: &tokio_util::sync::CancellationToken,
1619 ) -> Result<AgentResult> {
1620 let mut messages = history.to_vec();
1621 let mut total_usage = TokenUsage::default();
1622 let mut tool_calls_count = 0;
1623 let mut turn = 0;
1624 let mut parse_error_count: u32 = 0;
1626 let mut continuation_count: u32 = 0;
1628 let mut recent_tool_signatures: Vec<String> = Vec::new();
1629
1630 if let Some(tx) = &event_tx {
1632 tx.send(AgentEvent::Start {
1633 prompt: effective_prompt.to_string(),
1634 })
1635 .await
1636 .ok();
1637 }
1638
1639 let _queue_forward_handle =
1641 if let (Some(ref queue), Some(ref tx)) = (&self.command_queue, &event_tx) {
1642 let mut rx = queue.subscribe();
1643 let tx = tx.clone();
1644 Some(tokio::spawn(async move {
1645 while let Ok(event) = rx.recv().await {
1646 if tx.send(event).await.is_err() {
1647 break;
1648 }
1649 }
1650 }))
1651 } else {
1652 None
1653 };
1654
1655 let built_system_prompt = Some(self.system_prompt());
1657 let hooked_prompt = if let Some(modified) = self
1658 .fire_pre_prompt(
1659 session_id.unwrap_or(""),
1660 effective_prompt,
1661 &built_system_prompt,
1662 messages.len(),
1663 )
1664 .await
1665 {
1666 modified
1667 } else {
1668 effective_prompt.to_string()
1669 };
1670 let effective_prompt = hooked_prompt.as_str();
1671
1672 if let Some(ref sp) = self.config.security_provider {
1674 sp.taint_input(effective_prompt);
1675 }
1676
1677 let system_with_memory = if let Some(ref memory) = self.config.memory {
1679 match memory.recall_similar(effective_prompt, 5).await {
1680 Ok(items) if !items.is_empty() => {
1681 if let Some(tx) = &event_tx {
1682 for item in &items {
1683 tx.send(AgentEvent::MemoryRecalled {
1684 memory_id: item.id.clone(),
1685 content: item.content.clone(),
1686 relevance: item.relevance_score(),
1687 })
1688 .await
1689 .ok();
1690 }
1691 tx.send(AgentEvent::MemoriesSearched {
1692 query: Some(effective_prompt.to_string()),
1693 tags: Vec::new(),
1694 result_count: items.len(),
1695 })
1696 .await
1697 .ok();
1698 }
1699 let memory_context = items
1700 .iter()
1701 .map(|i| format!("- {}", i.content))
1702 .collect::<Vec<_>>()
1703 .join(
1704 "
1705",
1706 );
1707 let base = self.system_prompt();
1708 Some(format!(
1709 "{}
1710
1711## Relevant past experience
1712{}",
1713 base, memory_context
1714 ))
1715 }
1716 _ => Some(self.system_prompt()),
1717 }
1718 } else {
1719 Some(self.system_prompt())
1720 };
1721
1722 let augmented_system = if !self.config.context_providers.is_empty() {
1724 if let Some(tx) = &event_tx {
1726 let provider_names: Vec<String> = self
1727 .config
1728 .context_providers
1729 .iter()
1730 .map(|p| p.name().to_string())
1731 .collect();
1732 tx.send(AgentEvent::ContextResolving {
1733 providers: provider_names,
1734 })
1735 .await
1736 .ok();
1737 }
1738
1739 tracing::info!(
1740 a3s.context.providers = self.config.context_providers.len() as i64,
1741 "Context resolution started"
1742 );
1743 let context_results = self.resolve_context(effective_prompt, session_id).await;
1744
1745 if let Some(tx) = &event_tx {
1747 let total_items: usize = context_results.iter().map(|r| r.items.len()).sum();
1748 let total_tokens: usize = context_results.iter().map(|r| r.total_tokens).sum();
1749
1750 tracing::info!(
1751 context_items = total_items,
1752 context_tokens = total_tokens,
1753 "Context resolution completed"
1754 );
1755
1756 tx.send(AgentEvent::ContextResolved {
1757 total_items,
1758 total_tokens,
1759 })
1760 .await
1761 .ok();
1762 }
1763
1764 self.build_augmented_system_prompt(&context_results)
1765 } else {
1766 Some(self.system_prompt())
1767 };
1768
1769 let base_prompt = self.system_prompt();
1771 let augmented_system = match (augmented_system, system_with_memory) {
1772 (Some(ctx), Some(mem)) if ctx != mem => Some(ctx.replacen(&base_prompt, &mem, 1)),
1773 (Some(ctx), _) => Some(ctx),
1774 (None, mem) => mem,
1775 };
1776
1777 if !msg_prompt.is_empty() {
1779 messages.push(Message::user(msg_prompt));
1780 }
1781
1782 loop {
1783 turn += 1;
1784
1785 if turn > self.config.max_tool_rounds {
1786 let error = format!("Max tool rounds ({}) exceeded", self.config.max_tool_rounds);
1787 if let Some(tx) = &event_tx {
1788 tx.send(AgentEvent::Error {
1789 message: error.clone(),
1790 })
1791 .await
1792 .ok();
1793 }
1794 anyhow::bail!(error);
1795 }
1796
1797 if let Some(tx) = &event_tx {
1799 tx.send(AgentEvent::TurnStart { turn }).await.ok();
1800 }
1801
1802 tracing::info!(
1803 turn = turn,
1804 max_turns = self.config.max_tool_rounds,
1805 "Agent turn started"
1806 );
1807
1808 tracing::info!(
1810 a3s.llm.streaming = event_tx.is_some(),
1811 "LLM completion started"
1812 );
1813
1814 self.fire_generate_start(
1816 session_id.unwrap_or(""),
1817 effective_prompt,
1818 &augmented_system,
1819 )
1820 .await;
1821
1822 let llm_start = std::time::Instant::now();
1823 let response = {
1827 let threshold = self.config.circuit_breaker_threshold.max(1);
1828 let mut attempt = 0u32;
1829 loop {
1830 attempt += 1;
1831 let result = self
1832 .call_llm(
1833 &messages,
1834 augmented_system.as_deref(),
1835 &event_tx,
1836 cancel_token,
1837 )
1838 .await;
1839 match result {
1840 Ok(r) => {
1841 break r;
1842 }
1843 Err(e) if cancel_token.is_cancelled() => {
1845 anyhow::bail!(e);
1846 }
1847 Err(e) if attempt < threshold && (event_tx.is_none() || attempt == 1) => {
1849 tracing::warn!(
1850 turn = turn,
1851 attempt = attempt,
1852 threshold = threshold,
1853 error = %e,
1854 "LLM call failed, will retry"
1855 );
1856 tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
1857 }
1858 Err(e) => {
1860 let msg = if attempt > 1 {
1861 format!(
1862 "LLM circuit breaker triggered: failed after {} attempt(s): {}",
1863 attempt, e
1864 )
1865 } else {
1866 format!("LLM call failed: {}", e)
1867 };
1868 tracing::error!(turn = turn, attempt = attempt, "{}", msg);
1869 self.fire_on_error(
1871 session_id.unwrap_or(""),
1872 ErrorType::LlmFailure,
1873 &msg,
1874 serde_json::json!({"turn": turn, "attempt": attempt}),
1875 )
1876 .await;
1877 if let Some(tx) = &event_tx {
1878 tx.send(AgentEvent::Error {
1879 message: msg.clone(),
1880 })
1881 .await
1882 .ok();
1883 }
1884 anyhow::bail!(msg);
1885 }
1886 }
1887 }
1888 };
1889
1890 total_usage.prompt_tokens += response.usage.prompt_tokens;
1892 total_usage.completion_tokens += response.usage.completion_tokens;
1893 total_usage.total_tokens += response.usage.total_tokens;
1894 let llm_duration = llm_start.elapsed();
1896 tracing::info!(
1897 turn = turn,
1898 streaming = event_tx.is_some(),
1899 prompt_tokens = response.usage.prompt_tokens,
1900 completion_tokens = response.usage.completion_tokens,
1901 total_tokens = response.usage.total_tokens,
1902 stop_reason = response.stop_reason.as_deref().unwrap_or("unknown"),
1903 duration_ms = llm_duration.as_millis() as u64,
1904 "LLM completion finished"
1905 );
1906
1907 self.fire_generate_end(
1909 session_id.unwrap_or(""),
1910 effective_prompt,
1911 &response,
1912 llm_duration.as_millis() as u64,
1913 )
1914 .await;
1915
1916 crate::telemetry::record_llm_usage(
1918 response.usage.prompt_tokens,
1919 response.usage.completion_tokens,
1920 response.usage.total_tokens,
1921 response.stop_reason.as_deref(),
1922 );
1923 tracing::info!(
1925 turn = turn,
1926 a3s.llm.total_tokens = response.usage.total_tokens,
1927 "Turn token usage"
1928 );
1929
1930 messages.push(response.message.clone());
1932
1933 let tool_calls = response.tool_calls();
1935
1936 if let Some(tx) = &event_tx {
1938 tx.send(AgentEvent::TurnEnd {
1939 turn,
1940 usage: response.usage.clone(),
1941 })
1942 .await
1943 .ok();
1944 }
1945
1946 if self.config.auto_compact {
1948 let used = response.usage.prompt_tokens;
1949 let max = self.config.max_context_tokens;
1950 let threshold = self.config.auto_compact_threshold;
1951
1952 if crate::session::compaction::should_auto_compact(used, max, threshold) {
1953 let before_len = messages.len();
1954 let percent_before = used as f32 / max as f32;
1955
1956 tracing::info!(
1957 used_tokens = used,
1958 max_tokens = max,
1959 percent = percent_before,
1960 threshold = threshold,
1961 "Auto-compact triggered"
1962 );
1963
1964 if let Some(pruned) = crate::session::compaction::prune_tool_outputs(&messages)
1966 {
1967 messages = pruned;
1968 tracing::info!("Tool output pruning applied");
1969 }
1970
1971 if let Ok(Some(compacted)) = crate::session::compaction::compact_messages(
1973 session_id.unwrap_or(""),
1974 &messages,
1975 &self.llm_client,
1976 )
1977 .await
1978 {
1979 messages = compacted;
1980 }
1981
1982 if let Some(tx) = &event_tx {
1984 tx.send(AgentEvent::ContextCompacted {
1985 session_id: session_id.unwrap_or("").to_string(),
1986 before_messages: before_len,
1987 after_messages: messages.len(),
1988 percent_before,
1989 })
1990 .await
1991 .ok();
1992 }
1993 }
1994 }
1995
1996 if tool_calls.is_empty() {
1997 let final_text = response.text();
2000
2001 if self.config.continuation_enabled
2002 && continuation_count < self.config.max_continuation_turns
2003 && turn < self.config.max_tool_rounds && Self::looks_incomplete(&final_text)
2005 {
2006 continuation_count += 1;
2007 tracing::info!(
2008 turn = turn,
2009 continuation = continuation_count,
2010 max_continuation = self.config.max_continuation_turns,
2011 "Injecting continuation message — response looks incomplete"
2012 );
2013 messages.push(Message::user(crate::prompts::CONTINUATION));
2015 continue;
2016 }
2017
2018 let final_text = if let Some(ref sp) = self.config.security_provider {
2020 sp.sanitize_output(&final_text)
2021 } else {
2022 final_text
2023 };
2024
2025 tracing::info!(
2027 tool_calls_count = tool_calls_count,
2028 total_prompt_tokens = total_usage.prompt_tokens,
2029 total_completion_tokens = total_usage.completion_tokens,
2030 total_tokens = total_usage.total_tokens,
2031 turns = turn,
2032 "Agent execution completed"
2033 );
2034
2035 if let Some(tx) = &event_tx {
2036 tx.send(AgentEvent::End {
2037 text: final_text.clone(),
2038 usage: total_usage.clone(),
2039 meta: response.meta.clone(),
2040 })
2041 .await
2042 .ok();
2043 }
2044
2045 if let Some(sid) = session_id {
2047 self.notify_turn_complete(sid, effective_prompt, &final_text)
2048 .await;
2049 }
2050
2051 return Ok(AgentResult {
2052 text: final_text,
2053 messages,
2054 usage: total_usage,
2055 tool_calls_count,
2056 });
2057 }
2058
2059 let tool_calls = if self.config.hook_engine.is_none()
2063 && self.config.confirmation_manager.is_none()
2064 && tool_calls.len() > 1
2065 && tool_calls
2066 .iter()
2067 .all(|tc| Self::is_parallel_safe_write(&tc.name, &tc.args))
2068 && {
2069 let paths: Vec<_> = tool_calls
2071 .iter()
2072 .filter_map(|tc| Self::extract_write_path(&tc.args))
2073 .collect();
2074 paths.len() == tool_calls.len()
2075 && paths.iter().collect::<std::collections::HashSet<_>>().len()
2076 == paths.len()
2077 } {
2078 tracing::info!(
2079 count = tool_calls.len(),
2080 "Parallel write batch: executing {} independent file writes concurrently",
2081 tool_calls.len()
2082 );
2083
2084 let futures: Vec<_> = tool_calls
2085 .iter()
2086 .map(|tc| {
2087 let ctx = self.tool_context.clone();
2088 let executor = Arc::clone(&self.tool_executor);
2089 let name = tc.name.clone();
2090 let args = tc.args.clone();
2091 async move { executor.execute_with_context(&name, &args, &ctx).await }
2092 })
2093 .collect();
2094
2095 let results = join_all(futures).await;
2096
2097 for (tc, result) in tool_calls.iter().zip(results) {
2099 tool_calls_count += 1;
2100 let (output, exit_code, is_error, metadata, images) =
2101 Self::tool_result_to_tuple(result);
2102
2103 let output = if let Some(ref sp) = self.config.security_provider {
2104 sp.sanitize_output(&output)
2105 } else {
2106 output
2107 };
2108
2109 if let Some(tx) = &event_tx {
2110 tx.send(AgentEvent::ToolEnd {
2111 id: tc.id.clone(),
2112 name: tc.name.clone(),
2113 output: output.clone(),
2114 exit_code,
2115 metadata,
2116 })
2117 .await
2118 .ok();
2119 }
2120
2121 if images.is_empty() {
2122 messages.push(Message::tool_result(&tc.id, &output, is_error));
2123 } else {
2124 messages.push(Message::tool_result_with_images(
2125 &tc.id, &output, &images, is_error,
2126 ));
2127 }
2128 }
2129
2130 continue;
2132 } else {
2133 tool_calls
2134 };
2135
2136 for tool_call in tool_calls {
2137 tool_calls_count += 1;
2138
2139 let tool_start = std::time::Instant::now();
2140
2141 tracing::info!(
2142 tool_name = tool_call.name.as_str(),
2143 tool_id = tool_call.id.as_str(),
2144 "Tool execution started"
2145 );
2146
2147 if let Some(parse_error) =
2153 tool_call.args.get("__parse_error").and_then(|v| v.as_str())
2154 {
2155 parse_error_count += 1;
2156 let error_msg = format!("Error: {}", parse_error);
2157 tracing::warn!(
2158 tool = tool_call.name.as_str(),
2159 parse_error_count = parse_error_count,
2160 max_parse_retries = self.config.max_parse_retries,
2161 "Malformed tool arguments from LLM"
2162 );
2163
2164 if let Some(tx) = &event_tx {
2165 tx.send(AgentEvent::ToolEnd {
2166 id: tool_call.id.clone(),
2167 name: tool_call.name.clone(),
2168 output: error_msg.clone(),
2169 exit_code: 1,
2170 metadata: None,
2171 })
2172 .await
2173 .ok();
2174 }
2175
2176 messages.push(Message::tool_result(&tool_call.id, &error_msg, true));
2177
2178 if parse_error_count > self.config.max_parse_retries {
2179 let msg = format!(
2180 "LLM produced malformed tool arguments {} time(s) in a row \
2181 (max_parse_retries={}); giving up",
2182 parse_error_count, self.config.max_parse_retries
2183 );
2184 tracing::error!("{}", msg);
2185 if let Some(tx) = &event_tx {
2186 tx.send(AgentEvent::Error {
2187 message: msg.clone(),
2188 })
2189 .await
2190 .ok();
2191 }
2192 anyhow::bail!(msg);
2193 }
2194 continue;
2195 }
2196
2197 parse_error_count = 0;
2199
2200 if let Some(ref registry) = self.config.skill_registry {
2202 let instruction_skills =
2203 registry.by_kind(crate::skills::SkillKind::Instruction);
2204 let has_restrictions =
2205 instruction_skills.iter().any(|s| s.allowed_tools.is_some());
2206 if has_restrictions {
2207 let allowed = instruction_skills
2208 .iter()
2209 .any(|s| s.is_tool_allowed(&tool_call.name));
2210 if !allowed {
2211 let msg = format!(
2212 "Tool '{}' is not allowed by any active skill.",
2213 tool_call.name
2214 );
2215 tracing::info!(
2216 tool_name = tool_call.name.as_str(),
2217 "Tool blocked by skill registry"
2218 );
2219 if let Some(tx) = &event_tx {
2220 tx.send(AgentEvent::PermissionDenied {
2221 tool_id: tool_call.id.clone(),
2222 tool_name: tool_call.name.clone(),
2223 args: tool_call.args.clone(),
2224 reason: msg.clone(),
2225 })
2226 .await
2227 .ok();
2228 }
2229 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2230 continue;
2231 }
2232 }
2233 }
2234
2235 if let Some(HookResult::Block(reason)) = self
2237 .fire_pre_tool_use(
2238 session_id.unwrap_or(""),
2239 &tool_call.name,
2240 &tool_call.args,
2241 recent_tool_signatures.clone(),
2242 )
2243 .await
2244 {
2245 let msg = format!("Tool '{}' blocked by hook: {}", tool_call.name, reason);
2246 tracing::info!(
2247 tool_name = tool_call.name.as_str(),
2248 "Tool blocked by PreToolUse hook"
2249 );
2250
2251 if let Some(tx) = &event_tx {
2252 tx.send(AgentEvent::PermissionDenied {
2253 tool_id: tool_call.id.clone(),
2254 tool_name: tool_call.name.clone(),
2255 args: tool_call.args.clone(),
2256 reason: reason.clone(),
2257 })
2258 .await
2259 .ok();
2260 }
2261
2262 messages.push(Message::tool_result(&tool_call.id, &msg, true));
2263 continue;
2264 }
2265
2266 let permission_decision = if let Some(checker) = &self.config.permission_checker {
2268 checker.check(&tool_call.name, &tool_call.args)
2269 } else {
2270 PermissionDecision::Ask
2272 };
2273
2274 let (output, exit_code, is_error, metadata, images) = match permission_decision {
2275 PermissionDecision::Deny => {
2276 tracing::info!(
2277 tool_name = tool_call.name.as_str(),
2278 permission = "deny",
2279 "Tool permission denied"
2280 );
2281 let denial_msg = format!(
2283 "Permission denied: Tool '{}' is blocked by permission policy.",
2284 tool_call.name
2285 );
2286
2287 if let Some(tx) = &event_tx {
2289 tx.send(AgentEvent::PermissionDenied {
2290 tool_id: tool_call.id.clone(),
2291 tool_name: tool_call.name.clone(),
2292 args: tool_call.args.clone(),
2293 reason: "Blocked by deny rule in permission policy".to_string(),
2294 })
2295 .await
2296 .ok();
2297 }
2298
2299 (denial_msg, 1, true, None, Vec::new())
2300 }
2301 PermissionDecision::Allow => {
2302 tracing::info!(
2303 tool_name = tool_call.name.as_str(),
2304 permission = "allow",
2305 "Tool permission: allow"
2306 );
2307 let stream_ctx =
2309 self.streaming_tool_context(&event_tx, &tool_call.id, &tool_call.name);
2310 let result = self
2311 .execute_tool_queued_or_direct(
2312 &tool_call.name,
2313 &tool_call.args,
2314 &stream_ctx,
2315 )
2316 .await;
2317
2318 Self::tool_result_to_tuple(result)
2319 }
2320 PermissionDecision::Ask => {
2321 tracing::info!(
2322 tool_name = tool_call.name.as_str(),
2323 permission = "ask",
2324 "Tool permission: ask"
2325 );
2326 if let Some(cm) = &self.config.confirmation_manager {
2328 if !cm.requires_confirmation(&tool_call.name).await {
2330 let stream_ctx = self.streaming_tool_context(
2331 &event_tx,
2332 &tool_call.id,
2333 &tool_call.name,
2334 );
2335 let result = self
2336 .execute_tool_queued_or_direct(
2337 &tool_call.name,
2338 &tool_call.args,
2339 &stream_ctx,
2340 )
2341 .await;
2342
2343 let (output, exit_code, is_error, metadata, images) =
2344 Self::tool_result_to_tuple(result);
2345
2346 if images.is_empty() {
2348 messages.push(Message::tool_result(
2349 &tool_call.id,
2350 &output,
2351 is_error,
2352 ));
2353 } else {
2354 messages.push(Message::tool_result_with_images(
2355 &tool_call.id,
2356 &output,
2357 &images,
2358 is_error,
2359 ));
2360 }
2361
2362 let tool_duration = tool_start.elapsed();
2364 crate::telemetry::record_tool_result(exit_code, tool_duration);
2365
2366 if let Some(tx) = &event_tx {
2368 tx.send(AgentEvent::ToolEnd {
2369 id: tool_call.id.clone(),
2370 name: tool_call.name.clone(),
2371 output: output.clone(),
2372 exit_code,
2373 metadata,
2374 })
2375 .await
2376 .ok();
2377 }
2378
2379 self.fire_post_tool_use(
2381 session_id.unwrap_or(""),
2382 &tool_call.name,
2383 &tool_call.args,
2384 &output,
2385 exit_code == 0,
2386 tool_duration.as_millis() as u64,
2387 )
2388 .await;
2389
2390 continue; }
2392
2393 let policy = cm.policy().await;
2395 let timeout_ms = policy.default_timeout_ms;
2396 let timeout_action = policy.timeout_action;
2397
2398 let rx = cm
2400 .request_confirmation(
2401 &tool_call.id,
2402 &tool_call.name,
2403 &tool_call.args,
2404 )
2405 .await;
2406
2407 if let Some(tx) = &event_tx {
2411 tx.send(AgentEvent::ConfirmationRequired {
2412 tool_id: tool_call.id.clone(),
2413 tool_name: tool_call.name.clone(),
2414 args: tool_call.args.clone(),
2415 timeout_ms,
2416 })
2417 .await
2418 .ok();
2419 }
2420
2421 let confirmation_result =
2423 tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await;
2424
2425 match confirmation_result {
2426 Ok(Ok(response)) => {
2427 if let Some(tx) = &event_tx {
2429 tx.send(AgentEvent::ConfirmationReceived {
2430 tool_id: tool_call.id.clone(),
2431 approved: response.approved,
2432 reason: response.reason.clone(),
2433 })
2434 .await
2435 .ok();
2436 }
2437 if response.approved {
2438 let stream_ctx = self.streaming_tool_context(
2439 &event_tx,
2440 &tool_call.id,
2441 &tool_call.name,
2442 );
2443 let result = self
2444 .execute_tool_queued_or_direct(
2445 &tool_call.name,
2446 &tool_call.args,
2447 &stream_ctx,
2448 )
2449 .await;
2450
2451 Self::tool_result_to_tuple(result)
2452 } else {
2453 let rejection_msg = format!(
2454 "Tool '{}' execution was REJECTED by the user. Reason: {}. \
2455 DO NOT retry this tool call unless the user explicitly asks you to.",
2456 tool_call.name,
2457 response.reason.unwrap_or_else(|| "No reason provided".to_string())
2458 );
2459 (rejection_msg, 1, true, None, Vec::new())
2460 }
2461 }
2462 Ok(Err(_)) => {
2463 if let Some(tx) = &event_tx {
2465 tx.send(AgentEvent::ConfirmationTimeout {
2466 tool_id: tool_call.id.clone(),
2467 action_taken: "rejected".to_string(),
2468 })
2469 .await
2470 .ok();
2471 }
2472 let msg = format!(
2473 "Tool '{}' confirmation failed: confirmation channel closed",
2474 tool_call.name
2475 );
2476 (msg, 1, true, None, Vec::new())
2477 }
2478 Err(_) => {
2479 cm.check_timeouts().await;
2480
2481 if let Some(tx) = &event_tx {
2483 tx.send(AgentEvent::ConfirmationTimeout {
2484 tool_id: tool_call.id.clone(),
2485 action_taken: match timeout_action {
2486 crate::hitl::TimeoutAction::Reject => {
2487 "rejected".to_string()
2488 }
2489 crate::hitl::TimeoutAction::AutoApprove => {
2490 "auto_approved".to_string()
2491 }
2492 },
2493 })
2494 .await
2495 .ok();
2496 }
2497
2498 match timeout_action {
2499 crate::hitl::TimeoutAction::Reject => {
2500 let msg = format!(
2501 "Tool '{}' execution was REJECTED: user confirmation timed out after {}ms. \
2502 DO NOT retry this tool call — the user did not approve it. \
2503 Inform the user that the operation requires their approval and ask them to try again.",
2504 tool_call.name, timeout_ms
2505 );
2506 (msg, 1, true, None, Vec::new())
2507 }
2508 crate::hitl::TimeoutAction::AutoApprove => {
2509 let stream_ctx = self.streaming_tool_context(
2510 &event_tx,
2511 &tool_call.id,
2512 &tool_call.name,
2513 );
2514 let result = self
2515 .execute_tool_queued_or_direct(
2516 &tool_call.name,
2517 &tool_call.args,
2518 &stream_ctx,
2519 )
2520 .await;
2521
2522 Self::tool_result_to_tuple(result)
2523 }
2524 }
2525 }
2526 }
2527 } else {
2528 let msg = format!(
2530 "Tool '{}' requires confirmation but no HITL confirmation manager is configured. \
2531 Configure a confirmation policy to enable tool execution.",
2532 tool_call.name
2533 );
2534 tracing::warn!(
2535 tool_name = tool_call.name.as_str(),
2536 "Tool requires confirmation but no HITL manager configured"
2537 );
2538 (msg, 1, true, None, Vec::new())
2539 }
2540 }
2541 };
2542
2543 let tool_duration = tool_start.elapsed();
2544 crate::telemetry::record_tool_result(exit_code, tool_duration);
2545
2546 let output = if let Some(ref sp) = self.config.security_provider {
2548 sp.sanitize_output(&output)
2549 } else {
2550 output
2551 };
2552
2553 recent_tool_signatures.push(format!(
2554 "{}:{} => {}",
2555 tool_call.name,
2556 serde_json::to_string(&tool_call.args).unwrap_or_default(),
2557 if is_error { "error" } else { "ok" }
2558 ));
2559 if recent_tool_signatures.len() > 8 {
2560 let overflow = recent_tool_signatures.len() - 8;
2561 recent_tool_signatures.drain(0..overflow);
2562 }
2563
2564 self.fire_post_tool_use(
2566 session_id.unwrap_or(""),
2567 &tool_call.name,
2568 &tool_call.args,
2569 &output,
2570 exit_code == 0,
2571 tool_duration.as_millis() as u64,
2572 )
2573 .await;
2574
2575 if let Some(ref memory) = self.config.memory {
2577 let tools_used = [tool_call.name.clone()];
2578 let remember_result = if exit_code == 0 {
2579 memory
2580 .remember_success(effective_prompt, &tools_used, &output)
2581 .await
2582 } else {
2583 memory
2584 .remember_failure(effective_prompt, &output, &tools_used)
2585 .await
2586 };
2587 match remember_result {
2588 Ok(()) => {
2589 if let Some(tx) = &event_tx {
2590 let item_type = if exit_code == 0 { "success" } else { "failure" };
2591 tx.send(AgentEvent::MemoryStored {
2592 memory_id: uuid::Uuid::new_v4().to_string(),
2593 memory_type: item_type.to_string(),
2594 importance: if exit_code == 0 { 0.8 } else { 0.9 },
2595 tags: vec![item_type.to_string(), tool_call.name.clone()],
2596 })
2597 .await
2598 .ok();
2599 }
2600 }
2601 Err(e) => {
2602 tracing::warn!("Failed to store memory after tool execution: {}", e);
2603 }
2604 }
2605 }
2606
2607 if let Some(tx) = &event_tx {
2609 tx.send(AgentEvent::ToolEnd {
2610 id: tool_call.id.clone(),
2611 name: tool_call.name.clone(),
2612 output: output.clone(),
2613 exit_code,
2614 metadata,
2615 })
2616 .await
2617 .ok();
2618 }
2619
2620 if images.is_empty() {
2622 messages.push(Message::tool_result(&tool_call.id, &output, is_error));
2623 } else {
2624 messages.push(Message::tool_result_with_images(
2625 &tool_call.id,
2626 &output,
2627 &images,
2628 is_error,
2629 ));
2630 }
2631 }
2632 }
2633 }
2634
2635 pub async fn execute_streaming(
2637 &self,
2638 history: &[Message],
2639 prompt: &str,
2640 ) -> Result<(
2641 mpsc::Receiver<AgentEvent>,
2642 tokio::task::JoinHandle<Result<AgentResult>>,
2643 tokio_util::sync::CancellationToken,
2644 )> {
2645 let (tx, rx) = mpsc::channel(100);
2646 let cancel_token = tokio_util::sync::CancellationToken::new();
2647
2648 let llm_client = self.llm_client.clone();
2649 let tool_executor = self.tool_executor.clone();
2650 let tool_context = self.tool_context.clone();
2651 let config = self.config.clone();
2652 let tool_metrics = self.tool_metrics.clone();
2653 let command_queue = self.command_queue.clone();
2654 let history = history.to_vec();
2655 let prompt = prompt.to_string();
2656 let token_clone = cancel_token.clone();
2657
2658 let handle = tokio::spawn(async move {
2659 let mut agent = AgentLoop::new(llm_client, tool_executor, tool_context, config);
2660 if let Some(metrics) = tool_metrics {
2661 agent = agent.with_tool_metrics(metrics);
2662 }
2663 if let Some(queue) = command_queue {
2664 agent = agent.with_queue(queue);
2665 }
2666 agent
2667 .execute_with_session(&history, &prompt, None, Some(tx), Some(&token_clone))
2668 .await
2669 });
2670
2671 Ok((rx, handle, cancel_token))
2672 }
2673
2674 pub async fn plan(&self, prompt: &str, _context: Option<&str>) -> Result<ExecutionPlan> {
2679 use crate::planning::LlmPlanner;
2680
2681 match LlmPlanner::create_plan(&self.llm_client, prompt).await {
2682 Ok(plan) => Ok(plan),
2683 Err(e) => {
2684 tracing::warn!("LLM plan creation failed, using fallback: {}", e);
2685 Ok(LlmPlanner::fallback_plan(prompt))
2686 }
2687 }
2688 }
2689
2690 pub async fn execute_with_planning(
2692 &self,
2693 history: &[Message],
2694 prompt: &str,
2695 event_tx: Option<mpsc::Sender<AgentEvent>>,
2696 ) -> Result<AgentResult> {
2697 if let Some(tx) = &event_tx {
2699 tx.send(AgentEvent::PlanningStart {
2700 prompt: prompt.to_string(),
2701 })
2702 .await
2703 .ok();
2704 }
2705
2706 let goal = if self.config.goal_tracking {
2708 let g = self.extract_goal(prompt).await?;
2709 if let Some(tx) = &event_tx {
2710 tx.send(AgentEvent::GoalExtracted { goal: g.clone() })
2711 .await
2712 .ok();
2713 }
2714 Some(g)
2715 } else {
2716 None
2717 };
2718
2719 let plan = self.plan(prompt, None).await?;
2721
2722 if let Some(tx) = &event_tx {
2724 tx.send(AgentEvent::PlanningEnd {
2725 estimated_steps: plan.steps.len(),
2726 plan: plan.clone(),
2727 })
2728 .await
2729 .ok();
2730 }
2731
2732 let plan_start = std::time::Instant::now();
2733
2734 let result = self.execute_plan(history, &plan, event_tx.clone()).await?;
2736
2737 if self.config.goal_tracking {
2739 if let Some(ref g) = goal {
2740 let achieved = self.check_goal_achievement(g, &result.text).await?;
2741 if achieved {
2742 if let Some(tx) = &event_tx {
2743 tx.send(AgentEvent::GoalAchieved {
2744 goal: g.description.clone(),
2745 total_steps: result.messages.len(),
2746 duration_ms: plan_start.elapsed().as_millis() as i64,
2747 })
2748 .await
2749 .ok();
2750 }
2751 }
2752 }
2753 }
2754
2755 Ok(result)
2756 }
2757
2758 async fn execute_plan(
2765 &self,
2766 history: &[Message],
2767 plan: &ExecutionPlan,
2768 event_tx: Option<mpsc::Sender<AgentEvent>>,
2769 ) -> Result<AgentResult> {
2770 let mut plan = plan.clone();
2771 let mut current_history = history.to_vec();
2772 let mut total_usage = TokenUsage::default();
2773 let mut tool_calls_count = 0;
2774 let total_steps = plan.steps.len();
2775
2776 let steps_text = plan
2778 .steps
2779 .iter()
2780 .enumerate()
2781 .map(|(i, step)| format!("{}. {}", i + 1, step.content))
2782 .collect::<Vec<_>>()
2783 .join("\n");
2784 current_history.push(Message::user(&crate::prompts::render(
2785 crate::prompts::PLAN_EXECUTE_GOAL,
2786 &[("goal", &plan.goal), ("steps", &steps_text)],
2787 )));
2788
2789 loop {
2790 let ready: Vec<String> = plan
2791 .get_ready_steps()
2792 .iter()
2793 .map(|s| s.id.clone())
2794 .collect();
2795
2796 if ready.is_empty() {
2797 if plan.has_deadlock() {
2799 tracing::warn!(
2800 "Plan deadlock detected: {} pending steps with unresolvable dependencies",
2801 plan.pending_count()
2802 );
2803 }
2804 break;
2805 }
2806
2807 if ready.len() == 1 {
2808 let step_id = &ready[0];
2810 let step = plan
2811 .steps
2812 .iter()
2813 .find(|s| s.id == *step_id)
2814 .ok_or_else(|| anyhow::anyhow!("step '{}' not found in plan", step_id))?
2815 .clone();
2816 let step_number = plan
2817 .steps
2818 .iter()
2819 .position(|s| s.id == *step_id)
2820 .unwrap_or(0)
2821 + 1;
2822
2823 if let Some(tx) = &event_tx {
2825 tx.send(AgentEvent::StepStart {
2826 step_id: step.id.clone(),
2827 description: step.content.clone(),
2828 step_number,
2829 total_steps,
2830 })
2831 .await
2832 .ok();
2833 }
2834
2835 plan.mark_status(&step.id, TaskStatus::InProgress);
2836
2837 let step_prompt = crate::prompts::render(
2838 crate::prompts::PLAN_EXECUTE_STEP,
2839 &[
2840 ("step_num", &step_number.to_string()),
2841 ("description", &step.content),
2842 ],
2843 );
2844
2845 match self
2846 .execute_loop(
2847 ¤t_history,
2848 &step_prompt,
2849 None,
2850 event_tx.clone(),
2851 &tokio_util::sync::CancellationToken::new(),
2852 )
2853 .await
2854 {
2855 Ok(result) => {
2856 current_history = result.messages.clone();
2857 total_usage.prompt_tokens += result.usage.prompt_tokens;
2858 total_usage.completion_tokens += result.usage.completion_tokens;
2859 total_usage.total_tokens += result.usage.total_tokens;
2860 tool_calls_count += result.tool_calls_count;
2861 plan.mark_status(&step.id, TaskStatus::Completed);
2862
2863 if let Some(tx) = &event_tx {
2864 tx.send(AgentEvent::StepEnd {
2865 step_id: step.id.clone(),
2866 status: TaskStatus::Completed,
2867 step_number,
2868 total_steps,
2869 })
2870 .await
2871 .ok();
2872 }
2873 }
2874 Err(e) => {
2875 tracing::error!("Plan step '{}' failed: {}", step.id, e);
2876 plan.mark_status(&step.id, TaskStatus::Failed);
2877
2878 if let Some(tx) = &event_tx {
2879 tx.send(AgentEvent::StepEnd {
2880 step_id: step.id.clone(),
2881 status: TaskStatus::Failed,
2882 step_number,
2883 total_steps,
2884 })
2885 .await
2886 .ok();
2887 }
2888 }
2889 }
2890 } else {
2891 let ready_steps: Vec<_> = ready
2898 .iter()
2899 .filter_map(|id| {
2900 let step = plan.steps.iter().find(|s| s.id == *id)?.clone();
2901 let step_number =
2902 plan.steps.iter().position(|s| s.id == *id).unwrap_or(0) + 1;
2903 Some((step, step_number))
2904 })
2905 .collect();
2906
2907 for (step, step_number) in &ready_steps {
2909 plan.mark_status(&step.id, TaskStatus::InProgress);
2910 if let Some(tx) = &event_tx {
2911 tx.send(AgentEvent::StepStart {
2912 step_id: step.id.clone(),
2913 description: step.content.clone(),
2914 step_number: *step_number,
2915 total_steps,
2916 })
2917 .await
2918 .ok();
2919 }
2920 }
2921
2922 let mut join_set = tokio::task::JoinSet::new();
2924 for (step, step_number) in &ready_steps {
2925 let base_history = current_history.clone();
2926 let agent_clone = self.clone();
2927 let tx = event_tx.clone();
2928 let step_clone = step.clone();
2929 let sn = *step_number;
2930
2931 join_set.spawn(async move {
2932 let prompt = crate::prompts::render(
2933 crate::prompts::PLAN_EXECUTE_STEP,
2934 &[
2935 ("step_num", &sn.to_string()),
2936 ("description", &step_clone.content),
2937 ],
2938 );
2939 let result = agent_clone
2940 .execute_loop(
2941 &base_history,
2942 &prompt,
2943 None,
2944 tx,
2945 &tokio_util::sync::CancellationToken::new(),
2946 )
2947 .await;
2948 (step_clone.id, sn, result)
2949 });
2950 }
2951
2952 let mut parallel_summaries = Vec::new();
2954 while let Some(join_result) = join_set.join_next().await {
2955 match join_result {
2956 Ok((step_id, step_number, step_result)) => match step_result {
2957 Ok(result) => {
2958 total_usage.prompt_tokens += result.usage.prompt_tokens;
2959 total_usage.completion_tokens += result.usage.completion_tokens;
2960 total_usage.total_tokens += result.usage.total_tokens;
2961 tool_calls_count += result.tool_calls_count;
2962 plan.mark_status(&step_id, TaskStatus::Completed);
2963
2964 parallel_summaries.push(format!(
2966 "- Step {} ({}): {}",
2967 step_number, step_id, result.text
2968 ));
2969
2970 if let Some(tx) = &event_tx {
2971 tx.send(AgentEvent::StepEnd {
2972 step_id,
2973 status: TaskStatus::Completed,
2974 step_number,
2975 total_steps,
2976 })
2977 .await
2978 .ok();
2979 }
2980 }
2981 Err(e) => {
2982 tracing::error!("Plan step '{}' failed: {}", step_id, e);
2983 plan.mark_status(&step_id, TaskStatus::Failed);
2984
2985 if let Some(tx) = &event_tx {
2986 tx.send(AgentEvent::StepEnd {
2987 step_id,
2988 status: TaskStatus::Failed,
2989 step_number,
2990 total_steps,
2991 })
2992 .await
2993 .ok();
2994 }
2995 }
2996 },
2997 Err(e) => {
2998 tracing::error!("JoinSet task panicked: {}", e);
2999 }
3000 }
3001 }
3002
3003 if !parallel_summaries.is_empty() {
3005 parallel_summaries.sort(); let results_text = parallel_summaries.join("\n");
3007 current_history.push(Message::user(&crate::prompts::render(
3008 crate::prompts::PLAN_PARALLEL_RESULTS,
3009 &[("results", &results_text)],
3010 )));
3011 }
3012 }
3013
3014 if self.config.goal_tracking {
3016 let completed = plan
3017 .steps
3018 .iter()
3019 .filter(|s| s.status == TaskStatus::Completed)
3020 .count();
3021 if let Some(tx) = &event_tx {
3022 tx.send(AgentEvent::GoalProgress {
3023 goal: plan.goal.clone(),
3024 progress: plan.progress(),
3025 completed_steps: completed,
3026 total_steps,
3027 })
3028 .await
3029 .ok();
3030 }
3031 }
3032 }
3033
3034 let final_text = current_history
3036 .last()
3037 .map(|m| {
3038 m.content
3039 .iter()
3040 .filter_map(|block| {
3041 if let crate::llm::ContentBlock::Text { text } = block {
3042 Some(text.as_str())
3043 } else {
3044 None
3045 }
3046 })
3047 .collect::<Vec<_>>()
3048 .join("\n")
3049 })
3050 .unwrap_or_default();
3051
3052 Ok(AgentResult {
3053 text: final_text,
3054 messages: current_history,
3055 usage: total_usage,
3056 tool_calls_count,
3057 })
3058 }
3059
3060 pub async fn extract_goal(&self, prompt: &str) -> Result<AgentGoal> {
3065 use crate::planning::LlmPlanner;
3066
3067 match LlmPlanner::extract_goal(&self.llm_client, prompt).await {
3068 Ok(goal) => Ok(goal),
3069 Err(e) => {
3070 tracing::warn!("LLM goal extraction failed, using fallback: {}", e);
3071 Ok(LlmPlanner::fallback_goal(prompt))
3072 }
3073 }
3074 }
3075
3076 pub async fn check_goal_achievement(
3081 &self,
3082 goal: &AgentGoal,
3083 current_state: &str,
3084 ) -> Result<bool> {
3085 use crate::planning::LlmPlanner;
3086
3087 match LlmPlanner::check_achievement(&self.llm_client, goal, current_state).await {
3088 Ok(result) => Ok(result.achieved),
3089 Err(e) => {
3090 tracing::warn!("LLM achievement check failed, using fallback: {}", e);
3091 let result = LlmPlanner::fallback_check_achievement(goal, current_state);
3092 Ok(result.achieved)
3093 }
3094 }
3095 }
3096}
3097
3098#[cfg(test)]
3099mod tests {
3100 use super::*;
3101 use crate::llm::{ContentBlock, StreamEvent};
3102 use crate::permissions::PermissionPolicy;
3103 use crate::tools::ToolExecutor;
3104 use std::path::PathBuf;
3105 use std::sync::atomic::{AtomicUsize, Ordering};
3106
3107 fn test_tool_context() -> ToolContext {
3109 ToolContext::new(PathBuf::from("/tmp"))
3110 }
3111
3112 #[test]
3113 fn test_agent_config_default() {
3114 let config = AgentConfig::default();
3115 assert!(config.prompt_slots.is_empty());
3116 assert!(config.tools.is_empty()); assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
3118 assert!(config.permission_checker.is_none());
3119 assert!(config.context_providers.is_empty());
3120 let registry = config
3122 .skill_registry
3123 .expect("skill_registry must be Some by default");
3124 assert!(registry.len() >= 7, "expected at least 7 built-in skills");
3125 assert!(registry.get("code-search").is_some());
3126 assert!(registry.get("find-bugs").is_some());
3127 }
3128
    /// Scripted `LlmClient` test double that replays a fixed list of responses in order.
    pub(crate) struct MockLlmClient {
        /// Remaining canned responses; each `complete`/`complete_streaming` call removes the
        /// first entry and errors once the list is empty.
        responses: std::sync::Mutex<Vec<LlmResponse>>,
        /// Number of `complete`/`complete_streaming` calls made so far.
        pub(crate) call_count: AtomicUsize,
    }
3140
3141 impl MockLlmClient {
3142 pub(crate) fn new(responses: Vec<LlmResponse>) -> Self {
3143 Self {
3144 responses: std::sync::Mutex::new(responses),
3145 call_count: AtomicUsize::new(0),
3146 }
3147 }
3148
3149 pub(crate) fn text_response(text: &str) -> LlmResponse {
3151 LlmResponse {
3152 message: Message {
3153 role: "assistant".to_string(),
3154 content: vec![ContentBlock::Text {
3155 text: text.to_string(),
3156 }],
3157 reasoning_content: None,
3158 },
3159 usage: TokenUsage {
3160 prompt_tokens: 10,
3161 completion_tokens: 5,
3162 total_tokens: 15,
3163 cache_read_tokens: None,
3164 cache_write_tokens: None,
3165 },
3166 stop_reason: Some("end_turn".to_string()),
3167 meta: None,
3168 }
3169 }
3170
3171 pub(crate) fn tool_call_response(
3173 tool_id: &str,
3174 tool_name: &str,
3175 args: serde_json::Value,
3176 ) -> LlmResponse {
3177 LlmResponse {
3178 message: Message {
3179 role: "assistant".to_string(),
3180 content: vec![ContentBlock::ToolUse {
3181 id: tool_id.to_string(),
3182 name: tool_name.to_string(),
3183 input: args,
3184 }],
3185 reasoning_content: None,
3186 },
3187 usage: TokenUsage {
3188 prompt_tokens: 10,
3189 completion_tokens: 5,
3190 total_tokens: 15,
3191 cache_read_tokens: None,
3192 cache_write_tokens: None,
3193 },
3194 stop_reason: Some("tool_use".to_string()),
3195 meta: None,
3196 }
3197 }
3198 }
3199
3200 #[async_trait::async_trait]
3201 impl LlmClient for MockLlmClient {
3202 async fn complete(
3203 &self,
3204 _messages: &[Message],
3205 _system: Option<&str>,
3206 _tools: &[ToolDefinition],
3207 ) -> Result<LlmResponse> {
3208 self.call_count.fetch_add(1, Ordering::SeqCst);
3209 let mut responses = self.responses.lock().unwrap();
3210 if responses.is_empty() {
3211 anyhow::bail!("No more mock responses available");
3212 }
3213 Ok(responses.remove(0))
3214 }
3215
3216 async fn complete_streaming(
3217 &self,
3218 _messages: &[Message],
3219 _system: Option<&str>,
3220 _tools: &[ToolDefinition],
3221 ) -> Result<mpsc::Receiver<StreamEvent>> {
3222 self.call_count.fetch_add(1, Ordering::SeqCst);
3223 let mut responses = self.responses.lock().unwrap();
3224 if responses.is_empty() {
3225 anyhow::bail!("No more mock responses available");
3226 }
3227 let response = responses.remove(0);
3228
3229 let (tx, rx) = mpsc::channel(10);
3230 tokio::spawn(async move {
3231 for block in &response.message.content {
3233 if let ContentBlock::Text { text } = block {
3234 tx.send(StreamEvent::TextDelta(text.clone())).await.ok();
3235 }
3236 }
3237 tx.send(StreamEvent::Done(response)).await.ok();
3238 });
3239
3240 Ok(rx)
3241 }
3242 }
3243
3244 #[tokio::test]
3249 async fn test_agent_simple_response() {
3250 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3251 "Hello, I'm an AI assistant.",
3252 )]));
3253
3254 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3255 let config = AgentConfig::default();
3256
3257 let agent = AgentLoop::new(
3258 mock_client.clone(),
3259 tool_executor,
3260 test_tool_context(),
3261 config,
3262 );
3263 let result = agent.execute(&[], "Hello", None).await.unwrap();
3264
3265 assert_eq!(result.text, "Hello, I'm an AI assistant.");
3266 assert_eq!(result.tool_calls_count, 0);
3267 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
3268 }
3269
3270 #[tokio::test]
3271 async fn test_agent_with_tool_call() {
3272 let mock_client = Arc::new(MockLlmClient::new(vec![
3273 MockLlmClient::tool_call_response(
3275 "tool-1",
3276 "bash",
3277 serde_json::json!({"command": "echo hello"}),
3278 ),
3279 MockLlmClient::text_response("The command output was: hello"),
3281 ]));
3282
3283 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3284 let config = AgentConfig::default();
3285
3286 let agent = AgentLoop::new(
3287 mock_client.clone(),
3288 tool_executor,
3289 test_tool_context(),
3290 config,
3291 );
3292 let result = agent.execute(&[], "Run echo hello", None).await.unwrap();
3293
3294 assert_eq!(result.text, "The command output was: hello");
3295 assert_eq!(result.tool_calls_count, 1);
3296 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
3297 }
3298
3299 #[tokio::test]
3300 async fn test_agent_permission_deny() {
3301 let mock_client = Arc::new(MockLlmClient::new(vec![
3302 MockLlmClient::tool_call_response(
3304 "tool-1",
3305 "bash",
3306 serde_json::json!({"command": "rm -rf /tmp/test"}),
3307 ),
3308 MockLlmClient::text_response(
3310 "I cannot execute that command due to permission restrictions.",
3311 ),
3312 ]));
3313
3314 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3315
3316 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3318
3319 let config = AgentConfig {
3320 permission_checker: Some(Arc::new(permission_policy)),
3321 ..Default::default()
3322 };
3323
3324 let (tx, mut rx) = mpsc::channel(100);
3325 let agent = AgentLoop::new(
3326 mock_client.clone(),
3327 tool_executor,
3328 test_tool_context(),
3329 config,
3330 );
3331 let result = agent.execute(&[], "Delete files", Some(tx)).await.unwrap();
3332
3333 let mut found_permission_denied = false;
3335 while let Ok(event) = rx.try_recv() {
3336 if let AgentEvent::PermissionDenied { tool_name, .. } = event {
3337 assert_eq!(tool_name, "bash");
3338 found_permission_denied = true;
3339 }
3340 }
3341 assert!(
3342 found_permission_denied,
3343 "Should have received PermissionDenied event"
3344 );
3345
3346 assert_eq!(result.tool_calls_count, 1);
3347 }
3348
3349 #[tokio::test]
3350 async fn test_agent_permission_allow() {
3351 let mock_client = Arc::new(MockLlmClient::new(vec![
3352 MockLlmClient::tool_call_response(
3354 "tool-1",
3355 "bash",
3356 serde_json::json!({"command": "echo hello"}),
3357 ),
3358 MockLlmClient::text_response("Done!"),
3360 ]));
3361
3362 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3363
3364 let permission_policy = PermissionPolicy::new()
3366 .allow("bash(echo:*)")
3367 .deny("bash(rm:*)");
3368
3369 let config = AgentConfig {
3370 permission_checker: Some(Arc::new(permission_policy)),
3371 ..Default::default()
3372 };
3373
3374 let agent = AgentLoop::new(
3375 mock_client.clone(),
3376 tool_executor,
3377 test_tool_context(),
3378 config,
3379 );
3380 let result = agent.execute(&[], "Echo hello", None).await.unwrap();
3381
3382 assert_eq!(result.text, "Done!");
3383 assert_eq!(result.tool_calls_count, 1);
3384 }
3385
3386 #[tokio::test]
3387 async fn test_agent_streaming_events() {
3388 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
3389 "Hello!",
3390 )]));
3391
3392 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3393 let config = AgentConfig::default();
3394
3395 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3396 let (mut rx, handle, _cancel_token) = agent.execute_streaming(&[], "Hi").await.unwrap();
3397
3398 let mut events = Vec::new();
3400 while let Some(event) = rx.recv().await {
3401 events.push(event);
3402 }
3403
3404 let result = handle.await.unwrap().unwrap();
3405 assert_eq!(result.text, "Hello!");
3406
3407 assert!(events.iter().any(|e| matches!(e, AgentEvent::Start { .. })));
3409 assert!(events.iter().any(|e| matches!(e, AgentEvent::End { .. })));
3410 }
3411
3412 #[tokio::test]
3413 async fn test_agent_max_tool_rounds() {
3414 let responses: Vec<LlmResponse> = (0..100)
3416 .map(|i| {
3417 MockLlmClient::tool_call_response(
3418 &format!("tool-{}", i),
3419 "bash",
3420 serde_json::json!({"command": "echo loop"}),
3421 )
3422 })
3423 .collect();
3424
3425 let mock_client = Arc::new(MockLlmClient::new(responses));
3426 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3427
3428 let config = AgentConfig {
3429 max_tool_rounds: 3,
3430 ..Default::default()
3431 };
3432
3433 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3434 let result = agent.execute(&[], "Loop forever", None).await;
3435
3436 assert!(result.is_err());
3438 assert!(result.unwrap_err().to_string().contains("Max tool rounds"));
3439 }
3440
3441 #[tokio::test]
3442 async fn test_agent_no_permission_policy_defaults_to_ask() {
3443 let mock_client = Arc::new(MockLlmClient::new(vec![
3446 MockLlmClient::tool_call_response(
3447 "tool-1",
3448 "bash",
3449 serde_json::json!({"command": "rm -rf /tmp/test"}),
3450 ),
3451 MockLlmClient::text_response("Denied!"),
3452 ]));
3453
3454 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3455 let config = AgentConfig {
3456 permission_checker: None, ..Default::default()
3459 };
3460
3461 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3462 let result = agent.execute(&[], "Delete", None).await.unwrap();
3463
3464 assert_eq!(result.text, "Denied!");
3466 assert_eq!(result.tool_calls_count, 1);
3467 }
3468
3469 #[tokio::test]
3470 async fn test_agent_permission_ask_without_cm_denies() {
3471 let mock_client = Arc::new(MockLlmClient::new(vec![
3474 MockLlmClient::tool_call_response(
3475 "tool-1",
3476 "bash",
3477 serde_json::json!({"command": "echo test"}),
3478 ),
3479 MockLlmClient::text_response("Denied!"),
3480 ]));
3481
3482 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3483
3484 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3488 permission_checker: Some(Arc::new(permission_policy)),
3489 ..Default::default()
3491 };
3492
3493 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3494 let result = agent.execute(&[], "Echo", None).await.unwrap();
3495
3496 assert_eq!(result.text, "Denied!");
3498 assert!(result.tool_calls_count >= 1);
3500 }
3501
    /// HITL approval path: HITL is enabled and an empty permission policy is used, so the bash
    /// call is expected to wait for confirmation. A background task approves `tool-1` after
    /// ~50 ms, after which the tool runs and the second scripted reply is returned.
    #[tokio::test]
    async fn test_agent_hitl_approved() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "echo hello"}),
            ),
            MockLlmClient::text_response("Command executed!"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Receiver bound to a named variable (not `_`) so it stays alive for the whole test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();
        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        // Approve the pending tool call after a short delay, while `execute` is waiting on it.
        let cm_clone = confirmation_manager.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            cm_clone.confirm("tool-1", true, None).await.ok();
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Run echo", None).await.unwrap();

        assert_eq!(result.text, "Command executed!");
        assert_eq!(result.tool_calls_count, 1);
    }
3554
    /// HITL rejection path: a background task rejects `tool-1` ("Too dangerous") after ~50 ms.
    /// The run still completes and returns the second scripted reply.
    #[tokio::test]
    async fn test_agent_hitl_rejected() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "rm -rf /"}),
            ),
            MockLlmClient::text_response("Understood, I won't do that."),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Receiver bound to a named variable (not `_`) so it stays alive for the whole test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        // Reject the pending tool call after a short delay, while `execute` is waiting on it.
        let cm_clone = confirmation_manager.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            cm_clone
                .confirm("tool-1", false, Some("Too dangerous".to_string()))
                .await
                .ok();
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Delete everything", None).await.unwrap();

        assert_eq!(result.text, "Understood, I won't do that.");
    }
3604
    /// HITL timeout with `TimeoutAction::Reject`: nobody answers, the 50 ms confirmation
    /// timeout expires, and the run completes with the second scripted reply.
    #[tokio::test]
    async fn test_agent_hitl_timeout_reject() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "echo test"}),
            ),
            MockLlmClient::text_response("Timed out, I understand."),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Receiver bound to a named variable (not `_`) so it stays alive for the whole test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            // Short timeout keeps the test fast; no one confirms, so it always fires.
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager),
            ..Default::default()
        };

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Echo", None).await.unwrap();

        assert_eq!(result.text, "Timed out, I understand.");
    }
3646
    /// HITL timeout with `TimeoutAction::AutoApprove`: nobody answers, the 50 ms timeout
    /// expires, and the tool call is counted as executed.
    #[tokio::test]
    async fn test_agent_hitl_timeout_auto_approve() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy, TimeoutAction};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            MockLlmClient::tool_call_response(
                "tool-1",
                "bash",
                serde_json::json!({"command": "echo hello"}),
            ),
            MockLlmClient::text_response("Auto-approved and executed!"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Receiver bound to a named variable (not `_`) so it stays alive for the whole test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            // Short timeout keeps the test fast; no one confirms, so it always fires.
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::AutoApprove,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager),
            ..Default::default()
        };

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Echo", None).await.unwrap();

        assert_eq!(result.text, "Auto-approved and executed!");
        assert_eq!(result.tool_calls_count, 1);
    }
3689
3690 #[tokio::test]
3691 async fn test_agent_hitl_confirmation_events() {
3692 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3693 use tokio::sync::broadcast;
3694
3695 let mock_client = Arc::new(MockLlmClient::new(vec![
3696 MockLlmClient::tool_call_response(
3697 "tool-1",
3698 "bash",
3699 serde_json::json!({"command": "echo test"}),
3700 ),
3701 MockLlmClient::text_response("Done!"),
3702 ]));
3703
3704 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3705
3706 let (event_tx, mut event_rx) = broadcast::channel(100);
3708 let hitl_policy = ConfirmationPolicy {
3709 enabled: true,
3710 default_timeout_ms: 5000, ..Default::default()
3712 };
3713 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3714
3715 let permission_policy = PermissionPolicy::new();
3716
3717 let config = AgentConfig {
3718 permission_checker: Some(Arc::new(permission_policy)),
3719 confirmation_manager: Some(confirmation_manager.clone()),
3720 ..Default::default()
3721 };
3722
3723 let cm_clone = confirmation_manager.clone();
3725 let event_handle = tokio::spawn(async move {
3726 let mut events = Vec::new();
3727 while let Ok(event) = event_rx.recv().await {
3729 events.push(event.clone());
3730 if let AgentEvent::ConfirmationRequired { tool_id, .. } = event {
3731 cm_clone.confirm(&tool_id, true, None).await.ok();
3733 if let Ok(recv_event) = event_rx.recv().await {
3735 events.push(recv_event);
3736 }
3737 break;
3738 }
3739 }
3740 events
3741 });
3742
3743 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3744 let _result = agent.execute(&[], "Echo", None).await.unwrap();
3745
3746 let events = event_handle.await.unwrap();
3748 assert!(
3749 events
3750 .iter()
3751 .any(|e| matches!(e, AgentEvent::ConfirmationRequired { .. })),
3752 "Should have ConfirmationRequired event"
3753 );
3754 assert!(
3755 events
3756 .iter()
3757 .any(|e| matches!(e, AgentEvent::ConfirmationReceived { approved: true, .. })),
3758 "Should have ConfirmationReceived event with approved=true"
3759 );
3760 }
3761
3762 #[tokio::test]
3763 async fn test_agent_hitl_disabled_auto_executes() {
3764 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3766 use tokio::sync::broadcast;
3767
3768 let mock_client = Arc::new(MockLlmClient::new(vec![
3769 MockLlmClient::tool_call_response(
3770 "tool-1",
3771 "bash",
3772 serde_json::json!({"command": "echo auto"}),
3773 ),
3774 MockLlmClient::text_response("Auto executed!"),
3775 ]));
3776
3777 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3778
3779 let (event_tx, _event_rx) = broadcast::channel(100);
3781 let hitl_policy = ConfirmationPolicy {
3782 enabled: false, ..Default::default()
3784 };
3785 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3786
3787 let permission_policy = PermissionPolicy::new(); let config = AgentConfig {
3790 permission_checker: Some(Arc::new(permission_policy)),
3791 confirmation_manager: Some(confirmation_manager),
3792 ..Default::default()
3793 };
3794
3795 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3796 let result = agent.execute(&[], "Echo", None).await.unwrap();
3797
3798 assert_eq!(result.text, "Auto executed!");
3800 assert_eq!(result.tool_calls_count, 1);
3801 }
3802
3803 #[tokio::test]
3804 async fn test_agent_hitl_with_permission_deny_skips_hitl() {
3805 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3807 use tokio::sync::broadcast;
3808
3809 let mock_client = Arc::new(MockLlmClient::new(vec![
3810 MockLlmClient::tool_call_response(
3811 "tool-1",
3812 "bash",
3813 serde_json::json!({"command": "rm -rf /"}),
3814 ),
3815 MockLlmClient::text_response("Blocked by permission."),
3816 ]));
3817
3818 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3819
3820 let (event_tx, mut event_rx) = broadcast::channel(100);
3822 let hitl_policy = ConfirmationPolicy {
3823 enabled: true,
3824 ..Default::default()
3825 };
3826 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3827
3828 let permission_policy = PermissionPolicy::new().deny("bash(rm:*)");
3830
3831 let config = AgentConfig {
3832 permission_checker: Some(Arc::new(permission_policy)),
3833 confirmation_manager: Some(confirmation_manager),
3834 ..Default::default()
3835 };
3836
3837 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3838 let result = agent.execute(&[], "Delete", None).await.unwrap();
3839
3840 assert_eq!(result.text, "Blocked by permission.");
3842
3843 let mut found_confirmation = false;
3845 while let Ok(event) = event_rx.try_recv() {
3846 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3847 found_confirmation = true;
3848 }
3849 }
3850 assert!(
3851 !found_confirmation,
3852 "HITL should not be triggered when permission is Deny"
3853 );
3854 }
3855
3856 #[tokio::test]
3857 async fn test_agent_hitl_with_permission_allow_skips_hitl() {
3858 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
3861 use tokio::sync::broadcast;
3862
3863 let mock_client = Arc::new(MockLlmClient::new(vec![
3864 MockLlmClient::tool_call_response(
3865 "tool-1",
3866 "bash",
3867 serde_json::json!({"command": "echo hello"}),
3868 ),
3869 MockLlmClient::text_response("Allowed!"),
3870 ]));
3871
3872 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
3873
3874 let (event_tx, mut event_rx) = broadcast::channel(100);
3876 let hitl_policy = ConfirmationPolicy {
3877 enabled: true,
3878 ..Default::default()
3879 };
3880 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
3881
3882 let permission_policy = PermissionPolicy::new().allow("bash(echo:*)");
3884
3885 let config = AgentConfig {
3886 permission_checker: Some(Arc::new(permission_policy)),
3887 confirmation_manager: Some(confirmation_manager.clone()),
3888 ..Default::default()
3889 };
3890
3891 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
3892 let result = agent.execute(&[], "Echo", None).await.unwrap();
3893
3894 assert_eq!(result.text, "Allowed!");
3896
3897 let mut found_confirmation = false;
3899 while let Ok(event) = event_rx.try_recv() {
3900 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
3901 found_confirmation = true;
3902 }
3903 }
3904 assert!(
3905 !found_confirmation,
3906 "Permission Allow should skip HITL confirmation"
3907 );
3908 }
3909
    /// HITL with two tool calls in one LLM response: a background task approves `tool-1` and
    /// then `tool-2`, 30 ms apart. Both calls execute, and the run ends with the second
    /// scripted reply.
    #[tokio::test]
    async fn test_agent_hitl_multiple_tool_calls() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        let mock_client = Arc::new(MockLlmClient::new(vec![
            // A single assistant message carrying two ToolUse blocks.
            LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content: vec![
                        ContentBlock::ToolUse {
                            id: "tool-1".to_string(),
                            name: "bash".to_string(),
                            input: serde_json::json!({"command": "echo first"}),
                        },
                        ContentBlock::ToolUse {
                            id: "tool-2".to_string(),
                            name: "bash".to_string(),
                            input: serde_json::json!({"command": "echo second"}),
                        },
                    ],
                    reasoning_content: None,
                },
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: Some("tool_use".to_string()),
                meta: None,
            },
            MockLlmClient::text_response("Both executed!"),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // Receiver bound to a named variable (not `_`) so it stays alive for the whole test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 5000,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();
        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        // Approve both pending calls in order, with small delays between them.
        let cm_clone = confirmation_manager.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            cm_clone.confirm("tool-1", true, None).await.ok();
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            cm_clone.confirm("tool-2", true, None).await.ok();
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Run both", None).await.unwrap();

        assert_eq!(result.text, "Both executed!");
        assert_eq!(result.tool_calls_count, 2);
    }
3982
    /// HITL with a split decision.
    ///
    /// The scripted LLM issues two `bash` tool calls in one turn. A background task
    /// approves `tool-1` and rejects `tool-2` with reason "Dangerous". The run must
    /// still finish with the second scripted response and count both tool calls.
    #[tokio::test]
    async fn test_agent_hitl_partial_approval() {
        use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
        use tokio::sync::broadcast;

        // Scripted LLM: turn 1 requests two bash tool calls, turn 2 answers in text.
        let mock_client = Arc::new(MockLlmClient::new(vec![
            LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content: vec![
                        ContentBlock::ToolUse {
                            id: "tool-1".to_string(),
                            name: "bash".to_string(),
                            input: serde_json::json!({"command": "echo safe"}),
                        },
                        // NOTE(review): the payload meant to be rejected is a genuinely
                        // destructive command. If HITL ever failed open, it would be handed
                        // to the bash tool. Consider a harmless stand-in.
                        ContentBlock::ToolUse {
                            id: "tool-2".to_string(),
                            name: "bash".to_string(),
                            input: serde_json::json!({"command": "rm -rf /"}),
                        },
                    ],
                    reasoning_content: None,
                },
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                    cache_read_tokens: None,
                    cache_write_tokens: None,
                },
                stop_reason: Some("tool_use".to_string()),
                meta: None,
            },
            MockLlmClient::text_response("First worked, second rejected."),
        ]));

        let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));

        // HITL is enabled with a 5000 ms default timeout. The `_event_rx` binding keeps
        // the broadcast receiver alive (but unread) until the end of the test.
        let (event_tx, _event_rx) = broadcast::channel(100);
        let hitl_policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 5000,
            ..Default::default()
        };
        let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));

        let permission_policy = PermissionPolicy::new();

        let config = AgentConfig {
            permission_checker: Some(Arc::new(permission_policy)),
            confirmation_manager: Some(confirmation_manager.clone()),
            ..Default::default()
        };

        // Simulated human: approve tool-1, then reject tool-2, about 30 ms apart.
        // NOTE(review): this is timing-based. It assumes each confirmation request is
        // already pending when the matching `confirm` call lands. Errors from `confirm`
        // are discarded by `.ok()`, so an answer that arrives too early goes unnoticed;
        // the agent would presumably then wait out the 5000 ms timeout. Confirm against
        // the ConfirmationManager implementation.
        let cm_clone = confirmation_manager.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            cm_clone.confirm("tool-1", true, None).await.ok();
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            cm_clone
                .confirm("tool-2", false, Some("Dangerous".to_string()))
                .await
                .ok();
        });

        let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
        let result = agent.execute(&[], "Run both", None).await.unwrap();

        // Only the final text and the tool-call count are checked. Whether tool-2 was
        // actually rejected (rather than run) is not asserted here.
        assert_eq!(result.text, "First worked, second rejected.");
        assert_eq!(result.tool_calls_count, 2);
    }
4057
4058 #[tokio::test]
4059 async fn test_agent_hitl_yolo_mode_auto_approves() {
4060 use crate::hitl::{ConfirmationManager, ConfirmationPolicy, SessionLane};
4062 use tokio::sync::broadcast;
4063
4064 let mock_client = Arc::new(MockLlmClient::new(vec![
4065 MockLlmClient::tool_call_response(
4066 "tool-1",
4067 "read", serde_json::json!({"path": "/tmp/test.txt"}),
4069 ),
4070 MockLlmClient::text_response("File read!"),
4071 ]));
4072
4073 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4074
4075 let (event_tx, mut event_rx) = broadcast::channel(100);
4077 let mut yolo_lanes = std::collections::HashSet::new();
4078 yolo_lanes.insert(SessionLane::Query);
4079 let hitl_policy = ConfirmationPolicy {
4080 enabled: true,
4081 yolo_lanes, ..Default::default()
4083 };
4084 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4085
4086 let permission_policy = PermissionPolicy::new();
4087
4088 let config = AgentConfig {
4089 permission_checker: Some(Arc::new(permission_policy)),
4090 confirmation_manager: Some(confirmation_manager),
4091 ..Default::default()
4092 };
4093
4094 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4095 let result = agent.execute(&[], "Read file", None).await.unwrap();
4096
4097 assert_eq!(result.text, "File read!");
4099
4100 let mut found_confirmation = false;
4102 while let Ok(event) = event_rx.try_recv() {
4103 if matches!(event, AgentEvent::ConfirmationRequired { .. }) {
4104 found_confirmation = true;
4105 }
4106 }
4107 assert!(
4108 !found_confirmation,
4109 "YOLO mode should not trigger confirmation"
4110 );
4111 }
4112
4113 #[tokio::test]
4114 async fn test_agent_config_with_all_options() {
4115 use crate::hitl::{ConfirmationManager, ConfirmationPolicy};
4116 use tokio::sync::broadcast;
4117
4118 let (event_tx, _) = broadcast::channel(100);
4119 let hitl_policy = ConfirmationPolicy::default();
4120 let confirmation_manager = Arc::new(ConfirmationManager::new(hitl_policy, event_tx));
4121
4122 let permission_policy = PermissionPolicy::new().allow("bash(*)");
4123
4124 let config = AgentConfig {
4125 prompt_slots: SystemPromptSlots {
4126 extra: Some("Test system prompt".to_string()),
4127 ..Default::default()
4128 },
4129 tools: vec![],
4130 max_tool_rounds: 10,
4131 permission_checker: Some(Arc::new(permission_policy)),
4132 confirmation_manager: Some(confirmation_manager),
4133 context_providers: vec![],
4134 planning_enabled: false,
4135 goal_tracking: false,
4136 hook_engine: None,
4137 skill_registry: None,
4138 ..AgentConfig::default()
4139 };
4140
4141 assert!(config.prompt_slots.build().contains("Test system prompt"));
4142 assert_eq!(config.max_tool_rounds, 10);
4143 assert!(config.permission_checker.is_some());
4144 assert!(config.confirmation_manager.is_some());
4145 assert!(config.context_providers.is_empty());
4146
4147 let debug_str = format!("{:?}", config);
4149 assert!(debug_str.contains("AgentConfig"));
4150 assert!(debug_str.contains("permission_checker: true"));
4151 assert!(debug_str.contains("confirmation_manager: true"));
4152 assert!(debug_str.contains("context_providers: 0"));
4153 }
4154
4155 use crate::context::{ContextItem, ContextType};
4160
4161 struct MockContextProvider {
4163 name: String,
4164 items: Vec<ContextItem>,
4165 on_turn_calls: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
4166 }
4167
4168 impl MockContextProvider {
4169 fn new(name: &str) -> Self {
4170 Self {
4171 name: name.to_string(),
4172 items: Vec::new(),
4173 on_turn_calls: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
4174 }
4175 }
4176
4177 fn with_items(mut self, items: Vec<ContextItem>) -> Self {
4178 self.items = items;
4179 self
4180 }
4181 }
4182
4183 #[async_trait::async_trait]
4184 impl ContextProvider for MockContextProvider {
4185 fn name(&self) -> &str {
4186 &self.name
4187 }
4188
4189 async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
4190 let mut result = ContextResult::new(&self.name);
4191 for item in &self.items {
4192 result.add_item(item.clone());
4193 }
4194 Ok(result)
4195 }
4196
4197 async fn on_turn_complete(
4198 &self,
4199 session_id: &str,
4200 prompt: &str,
4201 response: &str,
4202 ) -> anyhow::Result<()> {
4203 let mut calls = self.on_turn_calls.write().await;
4204 calls.push((
4205 session_id.to_string(),
4206 prompt.to_string(),
4207 response.to_string(),
4208 ));
4209 Ok(())
4210 }
4211 }
4212
4213 #[tokio::test]
4214 async fn test_agent_with_context_provider() {
4215 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4216 "Response using context",
4217 )]));
4218
4219 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4220
4221 let provider =
4222 MockContextProvider::new("test-provider").with_items(vec![ContextItem::new(
4223 "ctx-1",
4224 ContextType::Resource,
4225 "Relevant context here",
4226 )
4227 .with_source("test://docs/example")]);
4228
4229 let config = AgentConfig {
4230 prompt_slots: SystemPromptSlots {
4231 extra: Some("You are helpful.".to_string()),
4232 ..Default::default()
4233 },
4234 context_providers: vec![Arc::new(provider)],
4235 ..Default::default()
4236 };
4237
4238 let agent = AgentLoop::new(
4239 mock_client.clone(),
4240 tool_executor,
4241 test_tool_context(),
4242 config,
4243 );
4244 let result = agent.execute(&[], "What is X?", None).await.unwrap();
4245
4246 assert_eq!(result.text, "Response using context");
4247 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4248 }
4249
4250 #[tokio::test]
4251 async fn test_agent_context_provider_events() {
4252 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4253 "Answer",
4254 )]));
4255
4256 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4257
4258 let provider =
4259 MockContextProvider::new("event-provider").with_items(vec![ContextItem::new(
4260 "item-1",
4261 ContextType::Memory,
4262 "Memory content",
4263 )
4264 .with_token_count(50)]);
4265
4266 let config = AgentConfig {
4267 context_providers: vec![Arc::new(provider)],
4268 ..Default::default()
4269 };
4270
4271 let (tx, mut rx) = mpsc::channel(100);
4272 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4273 let _result = agent.execute(&[], "Test prompt", Some(tx)).await.unwrap();
4274
4275 let mut events = Vec::new();
4277 while let Ok(event) = rx.try_recv() {
4278 events.push(event);
4279 }
4280
4281 assert!(
4283 events
4284 .iter()
4285 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4286 "Should have ContextResolving event"
4287 );
4288 assert!(
4289 events
4290 .iter()
4291 .any(|e| matches!(e, AgentEvent::ContextResolved { .. })),
4292 "Should have ContextResolved event"
4293 );
4294
4295 for event in &events {
4297 if let AgentEvent::ContextResolved {
4298 total_items,
4299 total_tokens,
4300 } = event
4301 {
4302 assert_eq!(*total_items, 1);
4303 assert_eq!(*total_tokens, 50);
4304 }
4305 }
4306 }
4307
4308 #[tokio::test]
4309 async fn test_agent_multiple_context_providers() {
4310 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4311 "Combined response",
4312 )]));
4313
4314 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4315
4316 let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
4317 "p1-1",
4318 ContextType::Resource,
4319 "Resource from P1",
4320 )
4321 .with_token_count(100)]);
4322
4323 let provider2 = MockContextProvider::new("provider-2").with_items(vec![
4324 ContextItem::new("p2-1", ContextType::Memory, "Memory from P2").with_token_count(50),
4325 ContextItem::new("p2-2", ContextType::Skill, "Skill from P2").with_token_count(75),
4326 ]);
4327
4328 let config = AgentConfig {
4329 prompt_slots: SystemPromptSlots {
4330 extra: Some("Base system prompt.".to_string()),
4331 ..Default::default()
4332 },
4333 context_providers: vec![Arc::new(provider1), Arc::new(provider2)],
4334 ..Default::default()
4335 };
4336
4337 let (tx, mut rx) = mpsc::channel(100);
4338 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4339 let result = agent.execute(&[], "Query", Some(tx)).await.unwrap();
4340
4341 assert_eq!(result.text, "Combined response");
4342
4343 while let Ok(event) = rx.try_recv() {
4345 if let AgentEvent::ContextResolved {
4346 total_items,
4347 total_tokens,
4348 } = event
4349 {
4350 assert_eq!(total_items, 3); assert_eq!(total_tokens, 225); }
4353 }
4354 }
4355
4356 #[tokio::test]
4357 async fn test_agent_no_context_providers() {
4358 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4359 "No context",
4360 )]));
4361
4362 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4363
4364 let config = AgentConfig::default();
4366
4367 let (tx, mut rx) = mpsc::channel(100);
4368 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4369 let result = agent.execute(&[], "Simple prompt", Some(tx)).await.unwrap();
4370
4371 assert_eq!(result.text, "No context");
4372
4373 let mut events = Vec::new();
4375 while let Ok(event) = rx.try_recv() {
4376 events.push(event);
4377 }
4378
4379 assert!(
4380 !events
4381 .iter()
4382 .any(|e| matches!(e, AgentEvent::ContextResolving { .. })),
4383 "Should NOT have ContextResolving event"
4384 );
4385 }
4386
4387 #[tokio::test]
4388 async fn test_agent_context_on_turn_complete() {
4389 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4390 "Final response",
4391 )]));
4392
4393 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4394
4395 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4396 let on_turn_calls = provider.on_turn_calls.clone();
4397
4398 let config = AgentConfig {
4399 context_providers: vec![provider],
4400 ..Default::default()
4401 };
4402
4403 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4404
4405 let result = agent
4407 .execute_with_session(&[], "User prompt", Some("sess-123"), None, None)
4408 .await
4409 .unwrap();
4410
4411 assert_eq!(result.text, "Final response");
4412
4413 let calls = on_turn_calls.read().await;
4415 assert_eq!(calls.len(), 1);
4416 assert_eq!(calls[0].0, "sess-123");
4417 assert_eq!(calls[0].1, "User prompt");
4418 assert_eq!(calls[0].2, "Final response");
4419 }
4420
4421 #[tokio::test]
4422 async fn test_agent_context_on_turn_complete_no_session() {
4423 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4424 "Response",
4425 )]));
4426
4427 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4428
4429 let provider = Arc::new(MockContextProvider::new("memory-provider"));
4430 let on_turn_calls = provider.on_turn_calls.clone();
4431
4432 let config = AgentConfig {
4433 context_providers: vec![provider],
4434 ..Default::default()
4435 };
4436
4437 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4438
4439 let _result = agent.execute(&[], "Prompt", None).await.unwrap();
4441
4442 let calls = on_turn_calls.read().await;
4444 assert!(calls.is_empty());
4445 }
4446
4447 #[tokio::test]
4448 async fn test_agent_build_augmented_system_prompt() {
4449 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response("OK")]));
4450
4451 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4452
4453 let provider = MockContextProvider::new("test").with_items(vec![ContextItem::new(
4454 "doc-1",
4455 ContextType::Resource,
4456 "Auth uses JWT tokens.",
4457 )
4458 .with_source("viking://docs/auth")]);
4459
4460 let config = AgentConfig {
4461 prompt_slots: SystemPromptSlots {
4462 extra: Some("You are helpful.".to_string()),
4463 ..Default::default()
4464 },
4465 context_providers: vec![Arc::new(provider)],
4466 ..Default::default()
4467 };
4468
4469 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
4470
4471 let context_results = agent.resolve_context("test", None).await;
4473 let augmented = agent.build_augmented_system_prompt(&context_results);
4474
4475 let augmented_str = augmented.unwrap();
4476 assert!(augmented_str.contains("You are helpful."));
4477 assert!(augmented_str.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
4478 assert!(augmented_str.contains("Auth uses JWT tokens."));
4479 }
4480
4481 async fn collect_events(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<AgentEvent> {
4487 let mut events = Vec::new();
4488 while let Ok(event) = rx.try_recv() {
4489 events.push(event);
4490 }
4491 while let Some(event) = rx.recv().await {
4493 events.push(event);
4494 }
4495 events
4496 }
4497
4498 #[tokio::test]
4499 async fn test_agent_multi_turn_tool_chain() {
4500 let mock_client = Arc::new(MockLlmClient::new(vec![
4502 MockLlmClient::tool_call_response(
4504 "t1",
4505 "bash",
4506 serde_json::json!({"command": "echo step1"}),
4507 ),
4508 MockLlmClient::tool_call_response(
4510 "t2",
4511 "bash",
4512 serde_json::json!({"command": "echo step2"}),
4513 ),
4514 MockLlmClient::text_response("Completed both steps: step1 then step2"),
4516 ]));
4517
4518 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4519 let config = AgentConfig::default();
4520
4521 let agent = AgentLoop::new(
4522 mock_client.clone(),
4523 tool_executor,
4524 test_tool_context(),
4525 config,
4526 );
4527 let result = agent.execute(&[], "Run two steps", None).await.unwrap();
4528
4529 assert_eq!(result.text, "Completed both steps: step1 then step2");
4530 assert_eq!(result.tool_calls_count, 2);
4531 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 3);
4532
4533 assert_eq!(result.messages[0].role, "user");
4535 assert_eq!(result.messages[1].role, "assistant"); assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "assistant"); assert_eq!(result.messages[4].role, "user"); assert_eq!(result.messages[5].role, "assistant"); assert_eq!(result.messages.len(), 6);
4541 }
4542
4543 #[tokio::test]
4544 async fn test_agent_conversation_history_preserved() {
4545 let existing_history = vec![
4547 Message::user("What is Rust?"),
4548 Message {
4549 role: "assistant".to_string(),
4550 content: vec![ContentBlock::Text {
4551 text: "Rust is a systems programming language.".to_string(),
4552 }],
4553 reasoning_content: None,
4554 },
4555 ];
4556
4557 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4558 "Rust was created by Graydon Hoare at Mozilla.",
4559 )]));
4560
4561 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4562 let agent = AgentLoop::new(
4563 mock_client.clone(),
4564 tool_executor,
4565 test_tool_context(),
4566 AgentConfig::default(),
4567 );
4568
4569 let result = agent
4570 .execute(&existing_history, "Who created it?", None)
4571 .await
4572 .unwrap();
4573
4574 assert_eq!(result.messages.len(), 4);
4576 assert_eq!(result.messages[0].text(), "What is Rust?");
4577 assert_eq!(
4578 result.messages[1].text(),
4579 "Rust is a systems programming language."
4580 );
4581 assert_eq!(result.messages[2].text(), "Who created it?");
4582 assert_eq!(
4583 result.messages[3].text(),
4584 "Rust was created by Graydon Hoare at Mozilla."
4585 );
4586 }
4587
4588 #[tokio::test]
4589 async fn test_agent_event_stream_completeness() {
4590 let mock_client = Arc::new(MockLlmClient::new(vec![
4592 MockLlmClient::tool_call_response(
4593 "t1",
4594 "bash",
4595 serde_json::json!({"command": "echo hi"}),
4596 ),
4597 MockLlmClient::text_response("Done"),
4598 ]));
4599
4600 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4601 let agent = AgentLoop::new(
4602 mock_client,
4603 tool_executor,
4604 test_tool_context(),
4605 AgentConfig::default(),
4606 );
4607
4608 let (tx, rx) = mpsc::channel(100);
4609 let result = agent.execute(&[], "Say hi", Some(tx)).await.unwrap();
4610 assert_eq!(result.text, "Done");
4611
4612 let events = collect_events(rx).await;
4613
4614 let event_types: Vec<&str> = events
4616 .iter()
4617 .map(|e| match e {
4618 AgentEvent::Start { .. } => "Start",
4619 AgentEvent::TurnStart { .. } => "TurnStart",
4620 AgentEvent::TurnEnd { .. } => "TurnEnd",
4621 AgentEvent::ToolEnd { .. } => "ToolEnd",
4622 AgentEvent::End { .. } => "End",
4623 _ => "Other",
4624 })
4625 .collect();
4626
4627 assert_eq!(event_types.first(), Some(&"Start"));
4629 assert_eq!(event_types.last(), Some(&"End"));
4630
4631 let turn_starts = event_types.iter().filter(|&&t| t == "TurnStart").count();
4633 assert_eq!(turn_starts, 2);
4634
4635 let tool_ends = event_types.iter().filter(|&&t| t == "ToolEnd").count();
4637 assert_eq!(tool_ends, 1);
4638 }
4639
4640 #[tokio::test]
4641 async fn test_agent_multiple_tools_single_turn() {
4642 let mock_client = Arc::new(MockLlmClient::new(vec![
4644 LlmResponse {
4645 message: Message {
4646 role: "assistant".to_string(),
4647 content: vec![
4648 ContentBlock::ToolUse {
4649 id: "t1".to_string(),
4650 name: "bash".to_string(),
4651 input: serde_json::json!({"command": "echo first"}),
4652 },
4653 ContentBlock::ToolUse {
4654 id: "t2".to_string(),
4655 name: "bash".to_string(),
4656 input: serde_json::json!({"command": "echo second"}),
4657 },
4658 ],
4659 reasoning_content: None,
4660 },
4661 usage: TokenUsage {
4662 prompt_tokens: 10,
4663 completion_tokens: 5,
4664 total_tokens: 15,
4665 cache_read_tokens: None,
4666 cache_write_tokens: None,
4667 },
4668 stop_reason: Some("tool_use".to_string()),
4669 meta: None,
4670 },
4671 MockLlmClient::text_response("Both commands ran"),
4672 ]));
4673
4674 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4675 let agent = AgentLoop::new(
4676 mock_client.clone(),
4677 tool_executor,
4678 test_tool_context(),
4679 AgentConfig::default(),
4680 );
4681
4682 let result = agent.execute(&[], "Run both", None).await.unwrap();
4683
4684 assert_eq!(result.text, "Both commands ran");
4685 assert_eq!(result.tool_calls_count, 2);
4686 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2); assert_eq!(result.messages[0].role, "user");
4690 assert_eq!(result.messages[1].role, "assistant");
4691 assert_eq!(result.messages[2].role, "user"); assert_eq!(result.messages[3].role, "user"); assert_eq!(result.messages[4].role, "assistant");
4694 }
4695
4696 #[tokio::test]
4697 async fn test_agent_token_usage_accumulation() {
4698 let mock_client = Arc::new(MockLlmClient::new(vec![
4700 MockLlmClient::tool_call_response(
4701 "t1",
4702 "bash",
4703 serde_json::json!({"command": "echo x"}),
4704 ),
4705 MockLlmClient::text_response("Done"),
4706 ]));
4707
4708 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4709 let agent = AgentLoop::new(
4710 mock_client,
4711 tool_executor,
4712 test_tool_context(),
4713 AgentConfig::default(),
4714 );
4715
4716 let result = agent.execute(&[], "test", None).await.unwrap();
4717
4718 assert_eq!(result.usage.prompt_tokens, 20);
4721 assert_eq!(result.usage.completion_tokens, 10);
4722 assert_eq!(result.usage.total_tokens, 30);
4723 }
4724
4725 #[tokio::test]
4726 async fn test_agent_system_prompt_passed() {
4727 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4729 "I am a coding assistant.",
4730 )]));
4731
4732 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4733 let config = AgentConfig {
4734 prompt_slots: SystemPromptSlots {
4735 extra: Some("You are a coding assistant.".to_string()),
4736 ..Default::default()
4737 },
4738 ..Default::default()
4739 };
4740
4741 let agent = AgentLoop::new(
4742 mock_client.clone(),
4743 tool_executor,
4744 test_tool_context(),
4745 config,
4746 );
4747 let result = agent.execute(&[], "What are you?", None).await.unwrap();
4748
4749 assert_eq!(result.text, "I am a coding assistant.");
4750 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 1);
4751 }
4752
4753 #[tokio::test]
4754 async fn test_agent_max_rounds_with_persistent_tool_calls() {
4755 let mut responses = Vec::new();
4757 for i in 0..15 {
4758 responses.push(MockLlmClient::tool_call_response(
4759 &format!("t{}", i),
4760 "bash",
4761 serde_json::json!({"command": format!("echo round{}", i)}),
4762 ));
4763 }
4764
4765 let mock_client = Arc::new(MockLlmClient::new(responses));
4766 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4767 let config = AgentConfig {
4768 max_tool_rounds: 5,
4769 ..Default::default()
4770 };
4771
4772 let agent = AgentLoop::new(
4773 mock_client.clone(),
4774 tool_executor,
4775 test_tool_context(),
4776 config,
4777 );
4778 let result = agent.execute(&[], "Loop forever", None).await;
4779
4780 assert!(result.is_err());
4781 let err = result.unwrap_err().to_string();
4782 assert!(err.contains("Max tool rounds (5) exceeded"));
4783 }
4784
4785 #[tokio::test]
4786 async fn test_agent_end_event_contains_final_text() {
4787 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
4788 "Final answer here",
4789 )]));
4790
4791 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
4792 let agent = AgentLoop::new(
4793 mock_client,
4794 tool_executor,
4795 test_tool_context(),
4796 AgentConfig::default(),
4797 );
4798
4799 let (tx, rx) = mpsc::channel(100);
4800 agent.execute(&[], "test", Some(tx)).await.unwrap();
4801
4802 let events = collect_events(rx).await;
4803 let end_event = events.iter().find(|e| matches!(e, AgentEvent::End { .. }));
4804 assert!(end_event.is_some());
4805
4806 if let AgentEvent::End { text, usage, .. } = end_event.unwrap() {
4807 assert_eq!(text, "Final answer here");
4808 assert_eq!(usage.total_tokens, 15);
4809 }
4810 }
4811}
4812
4813#[cfg(test)]
4814mod extra_agent_tests {
4815 use super::*;
4816 use crate::agent::tests::MockLlmClient;
4817 use crate::queue::SessionQueueConfig;
4818 use crate::tools::ToolExecutor;
4819 use std::path::PathBuf;
4820 use std::sync::atomic::{AtomicUsize, Ordering};
4821
4822 fn test_tool_context() -> ToolContext {
4823 ToolContext::new(PathBuf::from("/tmp"))
4824 }
4825
4826 #[test]
4831 fn test_agent_config_debug() {
4832 let config = AgentConfig {
4833 prompt_slots: SystemPromptSlots {
4834 extra: Some("You are helpful".to_string()),
4835 ..Default::default()
4836 },
4837 tools: vec![],
4838 max_tool_rounds: 10,
4839 permission_checker: None,
4840 confirmation_manager: None,
4841 context_providers: vec![],
4842 planning_enabled: true,
4843 goal_tracking: false,
4844 hook_engine: None,
4845 skill_registry: None,
4846 ..AgentConfig::default()
4847 };
4848 let debug = format!("{:?}", config);
4849 assert!(debug.contains("AgentConfig"));
4850 assert!(debug.contains("planning_enabled"));
4851 }
4852
4853 #[test]
4854 fn test_agent_config_default_values() {
4855 let config = AgentConfig::default();
4856 assert_eq!(config.max_tool_rounds, MAX_TOOL_ROUNDS);
4857 assert!(!config.planning_enabled);
4858 assert!(!config.goal_tracking);
4859 assert!(config.context_providers.is_empty());
4860 }
4861
4862 #[test]
4867 fn test_agent_event_serialize_start() {
4868 let event = AgentEvent::Start {
4869 prompt: "Hello".to_string(),
4870 };
4871 let json = serde_json::to_string(&event).unwrap();
4872 assert!(json.contains("agent_start"));
4873 assert!(json.contains("Hello"));
4874 }
4875
4876 #[test]
4877 fn test_agent_event_serialize_text_delta() {
4878 let event = AgentEvent::TextDelta {
4879 text: "chunk".to_string(),
4880 };
4881 let json = serde_json::to_string(&event).unwrap();
4882 assert!(json.contains("text_delta"));
4883 }
4884
4885 #[test]
4886 fn test_agent_event_serialize_tool_start() {
4887 let event = AgentEvent::ToolStart {
4888 id: "t1".to_string(),
4889 name: "bash".to_string(),
4890 };
4891 let json = serde_json::to_string(&event).unwrap();
4892 assert!(json.contains("tool_start"));
4893 assert!(json.contains("bash"));
4894 }
4895
4896 #[test]
4897 fn test_agent_event_serialize_tool_end() {
4898 let event = AgentEvent::ToolEnd {
4899 id: "t1".to_string(),
4900 name: "bash".to_string(),
4901 output: "hello".to_string(),
4902 exit_code: 0,
4903 metadata: None,
4904 };
4905 let json = serde_json::to_string(&event).unwrap();
4906 assert!(json.contains("tool_end"));
4907 }
4908
4909 #[test]
4910 fn test_agent_event_tool_end_has_metadata_field() {
4911 let event = AgentEvent::ToolEnd {
4912 id: "t1".to_string(),
4913 name: "write".to_string(),
4914 output: "Wrote 5 bytes".to_string(),
4915 exit_code: 0,
4916 metadata: Some(
4917 serde_json::json!({ "before": "old", "after": "new", "file_path": "f.txt" }),
4918 ),
4919 };
4920 let json = serde_json::to_string(&event).unwrap();
4921 assert!(json.contains("\"before\""));
4922 }
4923
4924 #[test]
4925 fn test_agent_event_serialize_error() {
4926 let event = AgentEvent::Error {
4927 message: "oops".to_string(),
4928 };
4929 let json = serde_json::to_string(&event).unwrap();
4930 assert!(json.contains("error"));
4931 assert!(json.contains("oops"));
4932 }
4933
4934 #[test]
4935 fn test_agent_event_serialize_confirmation_required() {
4936 let event = AgentEvent::ConfirmationRequired {
4937 tool_id: "t1".to_string(),
4938 tool_name: "bash".to_string(),
4939 args: serde_json::json!({"cmd": "rm"}),
4940 timeout_ms: 30000,
4941 };
4942 let json = serde_json::to_string(&event).unwrap();
4943 assert!(json.contains("confirmation_required"));
4944 }
4945
4946 #[test]
4947 fn test_agent_event_serialize_confirmation_received() {
4948 let event = AgentEvent::ConfirmationReceived {
4949 tool_id: "t1".to_string(),
4950 approved: true,
4951 reason: Some("safe".to_string()),
4952 };
4953 let json = serde_json::to_string(&event).unwrap();
4954 assert!(json.contains("confirmation_received"));
4955 }
4956
4957 #[test]
4958 fn test_agent_event_serialize_confirmation_timeout() {
4959 let event = AgentEvent::ConfirmationTimeout {
4960 tool_id: "t1".to_string(),
4961 action_taken: "rejected".to_string(),
4962 };
4963 let json = serde_json::to_string(&event).unwrap();
4964 assert!(json.contains("confirmation_timeout"));
4965 }
4966
4967 #[test]
4968 fn test_agent_event_serialize_external_task_pending() {
4969 let event = AgentEvent::ExternalTaskPending {
4970 task_id: "task-1".to_string(),
4971 session_id: "sess-1".to_string(),
4972 lane: crate::hitl::SessionLane::Execute,
4973 command_type: "bash".to_string(),
4974 payload: serde_json::json!({}),
4975 timeout_ms: 60000,
4976 };
4977 let json = serde_json::to_string(&event).unwrap();
4978 assert!(json.contains("external_task_pending"));
4979 }
4980
4981 #[test]
4982 fn test_agent_event_serialize_external_task_completed() {
4983 let event = AgentEvent::ExternalTaskCompleted {
4984 task_id: "task-1".to_string(),
4985 session_id: "sess-1".to_string(),
4986 success: false,
4987 };
4988 let json = serde_json::to_string(&event).unwrap();
4989 assert!(json.contains("external_task_completed"));
4990 }
4991
4992 #[test]
4993 fn test_agent_event_serialize_permission_denied() {
4994 let event = AgentEvent::PermissionDenied {
4995 tool_id: "t1".to_string(),
4996 tool_name: "bash".to_string(),
4997 args: serde_json::json!({}),
4998 reason: "denied".to_string(),
4999 };
5000 let json = serde_json::to_string(&event).unwrap();
5001 assert!(json.contains("permission_denied"));
5002 }
5003
5004 #[test]
5005 fn test_agent_event_serialize_context_compacted() {
5006 let event = AgentEvent::ContextCompacted {
5007 session_id: "sess-1".to_string(),
5008 before_messages: 100,
5009 after_messages: 20,
5010 percent_before: 0.85,
5011 };
5012 let json = serde_json::to_string(&event).unwrap();
5013 assert!(json.contains("context_compacted"));
5014 }
5015
5016 #[test]
5017 fn test_agent_event_serialize_turn_start() {
5018 let event = AgentEvent::TurnStart { turn: 3 };
5019 let json = serde_json::to_string(&event).unwrap();
5020 assert!(json.contains("turn_start"));
5021 }
5022
5023 #[test]
5024 fn test_agent_event_serialize_turn_end() {
5025 let event = AgentEvent::TurnEnd {
5026 turn: 3,
5027 usage: TokenUsage::default(),
5028 };
5029 let json = serde_json::to_string(&event).unwrap();
5030 assert!(json.contains("turn_end"));
5031 }
5032
5033 #[test]
5034 fn test_agent_event_serialize_end() {
5035 let event = AgentEvent::End {
5036 text: "Done".to_string(),
5037 usage: TokenUsage {
5038 prompt_tokens: 100,
5039 completion_tokens: 50,
5040 total_tokens: 150,
5041 cache_read_tokens: None,
5042 cache_write_tokens: None,
5043 },
5044 meta: None,
5045 };
5046 let json = serde_json::to_string(&event).unwrap();
5047 assert!(json.contains("agent_end"));
5048 }
5049
5050 #[test]
5055 fn test_agent_result_fields() {
5056 let result = AgentResult {
5057 text: "output".to_string(),
5058 messages: vec![Message::user("hello")],
5059 usage: TokenUsage::default(),
5060 tool_calls_count: 3,
5061 };
5062 assert_eq!(result.text, "output");
5063 assert_eq!(result.messages.len(), 1);
5064 assert_eq!(result.tool_calls_count, 3);
5065 }
5066
5067 #[test]
5072 fn test_agent_event_serialize_context_resolving() {
5073 let event = AgentEvent::ContextResolving {
5074 providers: vec!["provider1".to_string(), "provider2".to_string()],
5075 };
5076 let json = serde_json::to_string(&event).unwrap();
5077 assert!(json.contains("context_resolving"));
5078 assert!(json.contains("provider1"));
5079 }
5080
5081 #[test]
5082 fn test_agent_event_serialize_context_resolved() {
5083 let event = AgentEvent::ContextResolved {
5084 total_items: 5,
5085 total_tokens: 1000,
5086 };
5087 let json = serde_json::to_string(&event).unwrap();
5088 assert!(json.contains("context_resolved"));
5089 assert!(json.contains("1000"));
5090 }
5091
5092 #[test]
5093 fn test_agent_event_serialize_command_dead_lettered() {
5094 let event = AgentEvent::CommandDeadLettered {
5095 command_id: "cmd-1".to_string(),
5096 command_type: "bash".to_string(),
5097 lane: "execute".to_string(),
5098 error: "timeout".to_string(),
5099 attempts: 3,
5100 };
5101 let json = serde_json::to_string(&event).unwrap();
5102 assert!(json.contains("command_dead_lettered"));
5103 assert!(json.contains("cmd-1"));
5104 }
5105
5106 #[test]
5107 fn test_agent_event_serialize_command_retry() {
5108 let event = AgentEvent::CommandRetry {
5109 command_id: "cmd-2".to_string(),
5110 command_type: "read".to_string(),
5111 lane: "query".to_string(),
5112 attempt: 2,
5113 delay_ms: 1000,
5114 };
5115 let json = serde_json::to_string(&event).unwrap();
5116 assert!(json.contains("command_retry"));
5117 assert!(json.contains("cmd-2"));
5118 }
5119
5120 #[test]
5121 fn test_agent_event_serialize_queue_alert() {
5122 let event = AgentEvent::QueueAlert {
5123 level: "warning".to_string(),
5124 alert_type: "depth".to_string(),
5125 message: "Queue depth exceeded".to_string(),
5126 };
5127 let json = serde_json::to_string(&event).unwrap();
5128 assert!(json.contains("queue_alert"));
5129 assert!(json.contains("warning"));
5130 }
5131
5132 #[test]
5133 fn test_agent_event_serialize_task_updated() {
5134 let event = AgentEvent::TaskUpdated {
5135 session_id: "sess-1".to_string(),
5136 tasks: vec![],
5137 };
5138 let json = serde_json::to_string(&event).unwrap();
5139 assert!(json.contains("task_updated"));
5140 assert!(json.contains("sess-1"));
5141 }
5142
5143 #[test]
5144 fn test_agent_event_serialize_memory_stored() {
5145 let event = AgentEvent::MemoryStored {
5146 memory_id: "mem-1".to_string(),
5147 memory_type: "conversation".to_string(),
5148 importance: 0.8,
5149 tags: vec!["important".to_string()],
5150 };
5151 let json = serde_json::to_string(&event).unwrap();
5152 assert!(json.contains("memory_stored"));
5153 assert!(json.contains("mem-1"));
5154 }
5155
5156 #[test]
5157 fn test_agent_event_serialize_memory_recalled() {
5158 let event = AgentEvent::MemoryRecalled {
5159 memory_id: "mem-2".to_string(),
5160 content: "Previous conversation".to_string(),
5161 relevance: 0.9,
5162 };
5163 let json = serde_json::to_string(&event).unwrap();
5164 assert!(json.contains("memory_recalled"));
5165 assert!(json.contains("mem-2"));
5166 }
5167
5168 #[test]
5169 fn test_agent_event_serialize_memories_searched() {
5170 let event = AgentEvent::MemoriesSearched {
5171 query: Some("search term".to_string()),
5172 tags: vec!["tag1".to_string()],
5173 result_count: 5,
5174 };
5175 let json = serde_json::to_string(&event).unwrap();
5176 assert!(json.contains("memories_searched"));
5177 assert!(json.contains("search term"));
5178 }
5179
5180 #[test]
5181 fn test_agent_event_serialize_memory_cleared() {
5182 let event = AgentEvent::MemoryCleared {
5183 tier: "short_term".to_string(),
5184 count: 10,
5185 };
5186 let json = serde_json::to_string(&event).unwrap();
5187 assert!(json.contains("memory_cleared"));
5188 assert!(json.contains("short_term"));
5189 }
5190
5191 #[test]
5192 fn test_agent_event_serialize_subagent_start() {
5193 let event = AgentEvent::SubagentStart {
5194 task_id: "task-1".to_string(),
5195 session_id: "child-sess".to_string(),
5196 parent_session_id: "parent-sess".to_string(),
5197 agent: "explore".to_string(),
5198 description: "Explore codebase".to_string(),
5199 };
5200 let json = serde_json::to_string(&event).unwrap();
5201 assert!(json.contains("subagent_start"));
5202 assert!(json.contains("explore"));
5203 }
5204
5205 #[test]
5206 fn test_agent_event_serialize_subagent_progress() {
5207 let event = AgentEvent::SubagentProgress {
5208 task_id: "task-1".to_string(),
5209 session_id: "child-sess".to_string(),
5210 status: "processing".to_string(),
5211 metadata: serde_json::json!({"progress": 50}),
5212 };
5213 let json = serde_json::to_string(&event).unwrap();
5214 assert!(json.contains("subagent_progress"));
5215 assert!(json.contains("processing"));
5216 }
5217
5218 #[test]
5219 fn test_agent_event_serialize_subagent_end() {
5220 let event = AgentEvent::SubagentEnd {
5221 task_id: "task-1".to_string(),
5222 session_id: "child-sess".to_string(),
5223 agent: "explore".to_string(),
5224 output: "Found 10 files".to_string(),
5225 success: true,
5226 };
5227 let json = serde_json::to_string(&event).unwrap();
5228 assert!(json.contains("subagent_end"));
5229 assert!(json.contains("Found 10 files"));
5230 }
5231
5232 #[test]
5233 fn test_agent_event_serialize_planning_start() {
5234 let event = AgentEvent::PlanningStart {
5235 prompt: "Build a web app".to_string(),
5236 };
5237 let json = serde_json::to_string(&event).unwrap();
5238 assert!(json.contains("planning_start"));
5239 assert!(json.contains("Build a web app"));
5240 }
5241
5242 #[test]
5243 fn test_agent_event_serialize_planning_end() {
5244 use crate::planning::{Complexity, ExecutionPlan};
5245 let plan = ExecutionPlan::new("Test goal".to_string(), Complexity::Simple);
5246 let event = AgentEvent::PlanningEnd {
5247 plan,
5248 estimated_steps: 3,
5249 };
5250 let json = serde_json::to_string(&event).unwrap();
5251 assert!(json.contains("planning_end"));
5252 assert!(json.contains("estimated_steps"));
5253 }
5254
5255 #[test]
5256 fn test_agent_event_serialize_step_start() {
5257 let event = AgentEvent::StepStart {
5258 step_id: "step-1".to_string(),
5259 description: "Initialize project".to_string(),
5260 step_number: 1,
5261 total_steps: 5,
5262 };
5263 let json = serde_json::to_string(&event).unwrap();
5264 assert!(json.contains("step_start"));
5265 assert!(json.contains("Initialize project"));
5266 }
5267
5268 #[test]
5269 fn test_agent_event_serialize_step_end() {
5270 let event = AgentEvent::StepEnd {
5271 step_id: "step-1".to_string(),
5272 status: TaskStatus::Completed,
5273 step_number: 1,
5274 total_steps: 5,
5275 };
5276 let json = serde_json::to_string(&event).unwrap();
5277 assert!(json.contains("step_end"));
5278 assert!(json.contains("step-1"));
5279 }
5280
5281 #[test]
5282 fn test_agent_event_serialize_goal_extracted() {
5283 use crate::planning::AgentGoal;
5284 let goal = AgentGoal::new("Complete the task".to_string());
5285 let event = AgentEvent::GoalExtracted { goal };
5286 let json = serde_json::to_string(&event).unwrap();
5287 assert!(json.contains("goal_extracted"));
5288 }
5289
5290 #[test]
5291 fn test_agent_event_serialize_goal_progress() {
5292 let event = AgentEvent::GoalProgress {
5293 goal: "Build app".to_string(),
5294 progress: 0.5,
5295 completed_steps: 2,
5296 total_steps: 4,
5297 };
5298 let json = serde_json::to_string(&event).unwrap();
5299 assert!(json.contains("goal_progress"));
5300 assert!(json.contains("0.5"));
5301 }
5302
5303 #[test]
5304 fn test_agent_event_serialize_goal_achieved() {
5305 let event = AgentEvent::GoalAchieved {
5306 goal: "Build app".to_string(),
5307 total_steps: 4,
5308 duration_ms: 5000,
5309 };
5310 let json = serde_json::to_string(&event).unwrap();
5311 assert!(json.contains("goal_achieved"));
5312 assert!(json.contains("5000"));
5313 }
5314
5315 #[tokio::test]
5316 async fn test_extract_goal_with_json_response() {
5317 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5319 r#"{"description": "Build web app", "success_criteria": ["App runs on port 3000", "Has login page"]}"#,
5320 )]));
5321 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5322 let agent = AgentLoop::new(
5323 mock_client,
5324 tool_executor,
5325 test_tool_context(),
5326 AgentConfig::default(),
5327 );
5328
5329 let goal = agent.extract_goal("Build a web app").await.unwrap();
5330 assert_eq!(goal.description, "Build web app");
5331 assert_eq!(goal.success_criteria.len(), 2);
5332 assert_eq!(goal.success_criteria[0], "App runs on port 3000");
5333 }
5334
5335 #[tokio::test]
5336 async fn test_extract_goal_fallback_on_non_json() {
5337 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5339 "Some non-JSON response",
5340 )]));
5341 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5342 let agent = AgentLoop::new(
5343 mock_client,
5344 tool_executor,
5345 test_tool_context(),
5346 AgentConfig::default(),
5347 );
5348
5349 let goal = agent.extract_goal("Do something").await.unwrap();
5350 assert_eq!(goal.description, "Do something");
5352 assert_eq!(goal.success_criteria.len(), 2);
5354 }
5355
5356 #[tokio::test]
5357 async fn test_check_goal_achievement_json_yes() {
5358 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5359 r#"{"achieved": true, "progress": 1.0, "remaining_criteria": []}"#,
5360 )]));
5361 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5362 let agent = AgentLoop::new(
5363 mock_client,
5364 tool_executor,
5365 test_tool_context(),
5366 AgentConfig::default(),
5367 );
5368
5369 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5370 let achieved = agent
5371 .check_goal_achievement(&goal, "All done")
5372 .await
5373 .unwrap();
5374 assert!(achieved);
5375 }
5376
5377 #[tokio::test]
5378 async fn test_check_goal_achievement_fallback_not_done() {
5379 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5381 "invalid json",
5382 )]));
5383 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5384 let agent = AgentLoop::new(
5385 mock_client,
5386 tool_executor,
5387 test_tool_context(),
5388 AgentConfig::default(),
5389 );
5390
5391 let goal = crate::planning::AgentGoal::new("Test goal".to_string());
5392 let achieved = agent
5394 .check_goal_achievement(&goal, "still working")
5395 .await
5396 .unwrap();
5397 assert!(!achieved);
5398 }
5399
5400 #[test]
5405 fn test_build_augmented_system_prompt_empty_context() {
5406 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5407 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5408 let config = AgentConfig {
5409 prompt_slots: SystemPromptSlots {
5410 extra: Some("Base prompt".to_string()),
5411 ..Default::default()
5412 },
5413 ..Default::default()
5414 };
5415 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5416
5417 let result = agent.build_augmented_system_prompt(&[]);
5418 assert!(result.unwrap().contains("Base prompt"));
5419 }
5420
5421 #[test]
5422 fn test_build_augmented_system_prompt_no_custom_slots() {
5423 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5424 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5425 let agent = AgentLoop::new(
5426 mock_client,
5427 tool_executor,
5428 test_tool_context(),
5429 AgentConfig::default(),
5430 );
5431
5432 let result = agent.build_augmented_system_prompt(&[]);
5433 assert!(result.is_some());
5435 assert!(result.unwrap().contains("Core Behaviour"));
5436 }
5437
5438 #[test]
5439 fn test_build_augmented_system_prompt_with_context_no_base() {
5440 use crate::context::{ContextItem, ContextResult, ContextType};
5441
5442 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5443 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5444 let agent = AgentLoop::new(
5445 mock_client,
5446 tool_executor,
5447 test_tool_context(),
5448 AgentConfig::default(),
5449 );
5450
5451 let context = vec![ContextResult {
5452 provider: "test".to_string(),
5453 items: vec![ContextItem::new("id1", ContextType::Resource, "Content")],
5454 total_tokens: 10,
5455 truncated: false,
5456 }];
5457
5458 let result = agent.build_augmented_system_prompt(&context);
5459 assert!(result.is_some());
5460 let text = result.unwrap();
5461 assert!(text.contains("<context"));
5462 assert!(text.contains("Content"));
5463 }
5464
5465 #[test]
5470 fn test_agent_result_clone() {
5471 let result = AgentResult {
5472 text: "output".to_string(),
5473 messages: vec![Message::user("hello")],
5474 usage: TokenUsage::default(),
5475 tool_calls_count: 3,
5476 };
5477 let cloned = result.clone();
5478 assert_eq!(cloned.text, result.text);
5479 assert_eq!(cloned.tool_calls_count, result.tool_calls_count);
5480 }
5481
5482 #[test]
5483 fn test_agent_result_debug() {
5484 let result = AgentResult {
5485 text: "output".to_string(),
5486 messages: vec![Message::user("hello")],
5487 usage: TokenUsage::default(),
5488 tool_calls_count: 3,
5489 };
5490 let debug = format!("{:?}", result);
5491 assert!(debug.contains("AgentResult"));
5492 assert!(debug.contains("output"));
5493 }
5494
5495 #[tokio::test]
5504 async fn test_tool_command_command_type() {
5505 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5506 let cmd = ToolCommand {
5507 tool_executor: executor,
5508 tool_name: "read".to_string(),
5509 tool_args: serde_json::json!({"file": "test.rs"}),
5510 skill_registry: None,
5511 tool_context: test_tool_context(),
5512 };
5513 assert_eq!(cmd.command_type(), "read");
5514 }
5515
5516 #[tokio::test]
5517 async fn test_tool_command_payload() {
5518 let executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5519 let args = serde_json::json!({"file": "test.rs", "offset": 10});
5520 let cmd = ToolCommand {
5521 tool_executor: executor,
5522 tool_name: "read".to_string(),
5523 tool_args: args.clone(),
5524 skill_registry: None,
5525 tool_context: test_tool_context(),
5526 };
5527 assert_eq!(cmd.payload(), args);
5528 }
5529
5530 #[tokio::test(flavor = "multi_thread")]
5535 async fn test_agent_loop_with_queue() {
5536 use tokio::sync::broadcast;
5537
5538 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5539 "Hello",
5540 )]));
5541 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5542 let config = AgentConfig::default();
5543
5544 let (event_tx, _) = broadcast::channel(100);
5545 let queue = SessionLaneQueue::new("test-session", SessionQueueConfig::default(), event_tx)
5546 .await
5547 .unwrap();
5548
5549 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config)
5550 .with_queue(Arc::new(queue));
5551
5552 assert!(agent.command_queue.is_some());
5553 }
5554
5555 #[tokio::test]
5556 async fn test_agent_loop_without_queue() {
5557 let mock_client = Arc::new(MockLlmClient::new(vec![MockLlmClient::text_response(
5558 "Hello",
5559 )]));
5560 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5561 let config = AgentConfig::default();
5562
5563 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5564
5565 assert!(agent.command_queue.is_none());
5566 }
5567
5568 #[tokio::test]
5573 async fn test_execute_plan_parallel_independent() {
5574 use crate::planning::{Complexity, ExecutionPlan, Task};
5575
5576 let mock_client = Arc::new(MockLlmClient::new(vec![
5579 MockLlmClient::text_response("Step 1 done"),
5580 MockLlmClient::text_response("Step 2 done"),
5581 MockLlmClient::text_response("Step 3 done"),
5582 ]));
5583
5584 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5585 let config = AgentConfig::default();
5586 let agent = AgentLoop::new(
5587 mock_client.clone(),
5588 tool_executor,
5589 test_tool_context(),
5590 config,
5591 );
5592
5593 let mut plan = ExecutionPlan::new("Test parallel", Complexity::Simple);
5594 plan.add_step(Task::new("s1", "First step"));
5595 plan.add_step(Task::new("s2", "Second step"));
5596 plan.add_step(Task::new("s3", "Third step"));
5597
5598 let (tx, mut rx) = mpsc::channel(100);
5599 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5600
5601 assert_eq!(result.usage.total_tokens, 45);
5603
5604 let mut step_starts = Vec::new();
5606 let mut step_ends = Vec::new();
5607 rx.close();
5608 while let Some(event) = rx.recv().await {
5609 match event {
5610 AgentEvent::StepStart { step_id, .. } => step_starts.push(step_id),
5611 AgentEvent::StepEnd {
5612 step_id, status, ..
5613 } => {
5614 assert_eq!(status, TaskStatus::Completed);
5615 step_ends.push(step_id);
5616 }
5617 _ => {}
5618 }
5619 }
5620 assert_eq!(step_starts.len(), 3);
5621 assert_eq!(step_ends.len(), 3);
5622 }
5623
5624 #[tokio::test]
5625 async fn test_execute_plan_respects_dependencies() {
5626 use crate::planning::{Complexity, ExecutionPlan, Task};
5627
5628 let mock_client = Arc::new(MockLlmClient::new(vec![
5631 MockLlmClient::text_response("Step 1 done"),
5632 MockLlmClient::text_response("Step 2 done"),
5633 MockLlmClient::text_response("Step 3 done"),
5634 ]));
5635
5636 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5637 let config = AgentConfig::default();
5638 let agent = AgentLoop::new(
5639 mock_client.clone(),
5640 tool_executor,
5641 test_tool_context(),
5642 config,
5643 );
5644
5645 let mut plan = ExecutionPlan::new("Test deps", Complexity::Medium);
5646 plan.add_step(Task::new("s1", "Independent A"));
5647 plan.add_step(Task::new("s2", "Independent B"));
5648 plan.add_step(
5649 Task::new("s3", "Depends on A+B")
5650 .with_dependencies(vec!["s1".to_string(), "s2".to_string()]),
5651 );
5652
5653 let (tx, mut rx) = mpsc::channel(100);
5654 let result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5655
5656 assert_eq!(result.usage.total_tokens, 45);
5658
5659 let mut events = Vec::new();
5661 rx.close();
5662 while let Some(event) = rx.recv().await {
5663 match &event {
5664 AgentEvent::StepStart { step_id, .. } => {
5665 events.push(format!("start:{}", step_id));
5666 }
5667 AgentEvent::StepEnd { step_id, .. } => {
5668 events.push(format!("end:{}", step_id));
5669 }
5670 _ => {}
5671 }
5672 }
5673
5674 let s1_end = events.iter().position(|e| e == "end:s1").unwrap();
5676 let s2_end = events.iter().position(|e| e == "end:s2").unwrap();
5677 let s3_start = events.iter().position(|e| e == "start:s3").unwrap();
5678 assert!(
5679 s3_start > s1_end,
5680 "s3 started before s1 ended: {:?}",
5681 events
5682 );
5683 assert!(
5684 s3_start > s2_end,
5685 "s3 started before s2 ended: {:?}",
5686 events
5687 );
5688
5689 assert!(result.text.contains("Step 3 done") || !result.text.is_empty());
5691 }
5692
5693 #[tokio::test]
5694 async fn test_execute_plan_handles_step_failure() {
5695 use crate::planning::{Complexity, ExecutionPlan, Task};
5696
5697 let mock_client = Arc::new(MockLlmClient::new(vec![
5707 MockLlmClient::text_response("s1 done"),
5709 MockLlmClient::text_response("s3 done"),
5710 ]));
5713
5714 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5715 let config = AgentConfig::default();
5716 let agent = AgentLoop::new(
5717 mock_client.clone(),
5718 tool_executor,
5719 test_tool_context(),
5720 config,
5721 );
5722
5723 let mut plan = ExecutionPlan::new("Test failure", Complexity::Medium);
5724 plan.add_step(Task::new("s1", "Independent step"));
5725 plan.add_step(Task::new("s2", "Depends on s1").with_dependencies(vec!["s1".to_string()]));
5726 plan.add_step(Task::new("s3", "Another independent"));
5727 plan.add_step(Task::new("s4", "Depends on s2").with_dependencies(vec!["s2".to_string()]));
5728
5729 let (tx, mut rx) = mpsc::channel(100);
5730 let _result = agent.execute_plan(&[], &plan, Some(tx)).await.unwrap();
5731
5732 let mut completed_steps = Vec::new();
5735 let mut failed_steps = Vec::new();
5736 rx.close();
5737 while let Some(event) = rx.recv().await {
5738 if let AgentEvent::StepEnd {
5739 step_id, status, ..
5740 } = event
5741 {
5742 match status {
5743 TaskStatus::Completed => completed_steps.push(step_id),
5744 TaskStatus::Failed => failed_steps.push(step_id),
5745 _ => {}
5746 }
5747 }
5748 }
5749
5750 assert!(
5751 completed_steps.contains(&"s1".to_string()),
5752 "s1 should complete"
5753 );
5754 assert!(
5755 completed_steps.contains(&"s3".to_string()),
5756 "s3 should complete"
5757 );
5758 assert!(failed_steps.contains(&"s2".to_string()), "s2 should fail");
5759 assert!(
5761 !completed_steps.contains(&"s4".to_string()),
5762 "s4 should not complete"
5763 );
5764 assert!(
5765 !failed_steps.contains(&"s4".to_string()),
5766 "s4 should not fail (never started)"
5767 );
5768 }
5769
5770 #[test]
5775 fn test_agent_config_resilience_defaults() {
5776 let config = AgentConfig::default();
5777 assert_eq!(config.max_parse_retries, 2);
5778 assert_eq!(config.tool_timeout_ms, None);
5779 assert_eq!(config.circuit_breaker_threshold, 3);
5780 }
5781
5782 #[tokio::test]
5784 async fn test_parse_error_recovery_bails_after_threshold() {
5785 let mock_client = Arc::new(MockLlmClient::new(vec![
5787 MockLlmClient::tool_call_response(
5788 "c1",
5789 "bash",
5790 serde_json::json!({"__parse_error": "unexpected token at position 5"}),
5791 ),
5792 MockLlmClient::tool_call_response(
5793 "c2",
5794 "bash",
5795 serde_json::json!({"__parse_error": "missing closing brace"}),
5796 ),
5797 MockLlmClient::tool_call_response(
5798 "c3",
5799 "bash",
5800 serde_json::json!({"__parse_error": "still broken"}),
5801 ),
5802 MockLlmClient::text_response("Done"), ]));
5804
5805 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5806 let config = AgentConfig {
5807 max_parse_retries: 2,
5808 ..AgentConfig::default()
5809 };
5810 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5811 let result = agent.execute(&[], "Do something", None).await;
5812 assert!(result.is_err(), "should bail after parse error threshold");
5813 let err = result.unwrap_err().to_string();
5814 assert!(
5815 err.contains("malformed tool arguments"),
5816 "error should mention malformed tool arguments, got: {}",
5817 err
5818 );
5819 }
5820
5821 #[tokio::test]
5823 async fn test_parse_error_counter_resets_on_success() {
5824 let mock_client = Arc::new(MockLlmClient::new(vec![
5828 MockLlmClient::tool_call_response(
5829 "c1",
5830 "bash",
5831 serde_json::json!({"__parse_error": "bad args"}),
5832 ),
5833 MockLlmClient::tool_call_response(
5834 "c2",
5835 "bash",
5836 serde_json::json!({"__parse_error": "bad args again"}),
5837 ),
5838 MockLlmClient::tool_call_response(
5840 "c3",
5841 "bash",
5842 serde_json::json!({"command": "echo ok"}),
5843 ),
5844 MockLlmClient::text_response("All done"),
5845 ]));
5846
5847 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5848 let config = AgentConfig {
5849 max_parse_retries: 2,
5850 ..AgentConfig::default()
5851 };
5852 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5853 let result = agent.execute(&[], "Do something", None).await;
5854 assert!(
5855 result.is_ok(),
5856 "should not bail — counter reset after successful tool, got: {:?}",
5857 result.err()
5858 );
5859 assert_eq!(result.unwrap().text, "All done");
5860 }
5861
5862 #[tokio::test]
5864 async fn test_tool_timeout_produces_error_result() {
5865 let mock_client = Arc::new(MockLlmClient::new(vec![
5866 MockLlmClient::tool_call_response(
5867 "t1",
5868 "bash",
5869 serde_json::json!({"command": "sleep 10"}),
5870 ),
5871 MockLlmClient::text_response("The command timed out."),
5872 ]));
5873
5874 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5875 let config = AgentConfig {
5876 tool_timeout_ms: Some(50),
5878 ..AgentConfig::default()
5879 };
5880 let agent = AgentLoop::new(
5881 mock_client.clone(),
5882 tool_executor,
5883 test_tool_context(),
5884 config,
5885 );
5886 let result = agent.execute(&[], "Run sleep", None).await;
5887 assert!(
5888 result.is_ok(),
5889 "session should continue after tool timeout: {:?}",
5890 result.err()
5891 );
5892 assert_eq!(result.unwrap().text, "The command timed out.");
5893 assert_eq!(mock_client.call_count.load(Ordering::SeqCst), 2);
5895 }
5896
5897 #[tokio::test]
5899 async fn test_tool_within_timeout_succeeds() {
5900 let mock_client = Arc::new(MockLlmClient::new(vec![
5901 MockLlmClient::tool_call_response(
5902 "t1",
5903 "bash",
5904 serde_json::json!({"command": "echo fast"}),
5905 ),
5906 MockLlmClient::text_response("Command succeeded."),
5907 ]));
5908
5909 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5910 let config = AgentConfig {
5911 tool_timeout_ms: Some(5_000), ..AgentConfig::default()
5913 };
5914 let agent = AgentLoop::new(mock_client, tool_executor, test_tool_context(), config);
5915 let result = agent.execute(&[], "Run something fast", None).await;
5916 assert!(
5917 result.is_ok(),
5918 "fast tool should succeed: {:?}",
5919 result.err()
5920 );
5921 assert_eq!(result.unwrap().text, "Command succeeded.");
5922 }
5923
5924 #[tokio::test]
5926 async fn test_circuit_breaker_retries_non_streaming() {
5927 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5930
5931 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5932 let config = AgentConfig {
5933 circuit_breaker_threshold: 2,
5934 ..AgentConfig::default()
5935 };
5936 let agent = AgentLoop::new(
5937 mock_client.clone(),
5938 tool_executor,
5939 test_tool_context(),
5940 config,
5941 );
5942 let result = agent.execute(&[], "Hello", None).await;
5943 assert!(result.is_err(), "should fail when LLM always errors");
5944 let err = result.unwrap_err().to_string();
5945 assert!(
5946 err.contains("circuit breaker"),
5947 "error should mention circuit breaker, got: {}",
5948 err
5949 );
5950 assert_eq!(
5951 mock_client.call_count.load(Ordering::SeqCst),
5952 2,
5953 "should make exactly threshold=2 LLM calls"
5954 );
5955 }
5956
5957 #[tokio::test]
5959 async fn test_circuit_breaker_threshold_one_no_retry() {
5960 let mock_client = Arc::new(MockLlmClient::new(vec![]));
5961
5962 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
5963 let config = AgentConfig {
5964 circuit_breaker_threshold: 1,
5965 ..AgentConfig::default()
5966 };
5967 let agent = AgentLoop::new(
5968 mock_client.clone(),
5969 tool_executor,
5970 test_tool_context(),
5971 config,
5972 );
5973 let result = agent.execute(&[], "Hello", None).await;
5974 assert!(result.is_err());
5975 assert_eq!(
5976 mock_client.call_count.load(Ordering::SeqCst),
5977 1,
5978 "with threshold=1 exactly one attempt should be made"
5979 );
5980 }
5981
5982 #[tokio::test]
5984 async fn test_circuit_breaker_succeeds_if_llm_recovers() {
5985 struct FailOnceThenSucceed {
5987 inner: MockLlmClient,
5988 failed_once: std::sync::atomic::AtomicBool,
5989 call_count: AtomicUsize,
5990 }
5991
5992 #[async_trait::async_trait]
5993 impl LlmClient for FailOnceThenSucceed {
5994 async fn complete(
5995 &self,
5996 messages: &[Message],
5997 system: Option<&str>,
5998 tools: &[ToolDefinition],
5999 ) -> Result<LlmResponse> {
6000 self.call_count.fetch_add(1, Ordering::SeqCst);
6001 let already_failed = self
6002 .failed_once
6003 .swap(true, std::sync::atomic::Ordering::SeqCst);
6004 if !already_failed {
6005 anyhow::bail!("transient network error");
6006 }
6007 self.inner.complete(messages, system, tools).await
6008 }
6009
6010 async fn complete_streaming(
6011 &self,
6012 messages: &[Message],
6013 system: Option<&str>,
6014 tools: &[ToolDefinition],
6015 ) -> Result<tokio::sync::mpsc::Receiver<crate::llm::StreamEvent>> {
6016 self.inner.complete_streaming(messages, system, tools).await
6017 }
6018 }
6019
6020 let mock = Arc::new(FailOnceThenSucceed {
6021 inner: MockLlmClient::new(vec![MockLlmClient::text_response("Recovered!")]),
6022 failed_once: std::sync::atomic::AtomicBool::new(false),
6023 call_count: AtomicUsize::new(0),
6024 });
6025
6026 let tool_executor = Arc::new(ToolExecutor::new("/tmp".to_string()));
6027 let config = AgentConfig {
6028 circuit_breaker_threshold: 3,
6029 ..AgentConfig::default()
6030 };
6031 let agent = AgentLoop::new(mock.clone(), tool_executor, test_tool_context(), config);
6032 let result = agent.execute(&[], "Hello", None).await;
6033 assert!(
6034 result.is_ok(),
6035 "should succeed when LLM recovers within threshold: {:?}",
6036 result.err()
6037 );
6038 assert_eq!(result.unwrap().text, "Recovered!");
6039 assert_eq!(
6040 mock.call_count.load(Ordering::SeqCst),
6041 2,
6042 "should have made exactly 2 calls (1 fail + 1 success)"
6043 );
6044 }
6045
6046 #[test]
6049 fn test_looks_incomplete_empty() {
6050 assert!(AgentLoop::looks_incomplete(""));
6051 assert!(AgentLoop::looks_incomplete(" "));
6052 }
6053
6054 #[test]
6055 fn test_looks_incomplete_trailing_colon() {
6056 assert!(AgentLoop::looks_incomplete("Let me check the file:"));
6057 assert!(AgentLoop::looks_incomplete("Next steps:"));
6058 }
6059
6060 #[test]
6061 fn test_looks_incomplete_ellipsis() {
6062 assert!(AgentLoop::looks_incomplete("Working on it..."));
6063 assert!(AgentLoop::looks_incomplete("Processing…"));
6064 }
6065
6066 #[test]
6067 fn test_looks_incomplete_intent_phrases() {
6068 assert!(AgentLoop::looks_incomplete(
6069 "I'll start by reading the file."
6070 ));
6071 assert!(AgentLoop::looks_incomplete(
6072 "Let me check the configuration."
6073 ));
6074 assert!(AgentLoop::looks_incomplete("I will now run the tests."));
6075 assert!(AgentLoop::looks_incomplete(
6076 "I need to update the Cargo.toml."
6077 ));
6078 }
6079
6080 #[test]
6081 fn test_looks_complete_final_answer() {
6082 assert!(!AgentLoop::looks_incomplete(
6084 "The tests pass. All changes have been applied successfully."
6085 ));
6086 assert!(!AgentLoop::looks_incomplete(
6087 "Done. I've updated the three files and verified the build succeeds."
6088 ));
6089 assert!(!AgentLoop::looks_incomplete("42"));
6090 assert!(!AgentLoop::looks_incomplete("Yes."));
6091 }
6092
6093 #[test]
6094 fn test_looks_incomplete_multiline_complete() {
6095 let text = "Here is the summary:\n\n- Fixed the bug in agent.rs\n- All tests pass\n- Build succeeds";
6096 assert!(!AgentLoop::looks_incomplete(text));
6097 }
6098}