1use crate::approval::ApprovalChecker;
19use crate::commands::{
20 dispatch as dispatch_command, load_commands_for_run, parse_slash_invocation,
21};
22use crate::compaction::{self, HistoryMessage};
23use crate::config::RunnerConfig;
24use crate::hooks::HookRegistry;
25use crate::parser;
26use crate::retry::RetryHandler;
27
28use enact_config::{HookDecision, HookEvent, HookHandler};
29use enact_core::callable::Callable;
30use enact_core::graph::CheckpointStore;
31use enact_core::kernel::cost::{
32 CostCalculator, ModelPricing, TokenUsage as LlmTokenUsage, UsageAccumulator,
33};
34use enact_core::kernel::{ExecutionError, StepId, StepType};
35use enact_core::runner::Runner;
36use enact_core::streaming::StreamEvent;
37use enact_core::tool::Tool;
38use enact_skills::{
39 find_matching_skills, find_skill_by_name, load_skill_body, load_skill_metas,
40 load_skill_resources, SkillResourceLimits,
41};
42
43use std::path::Path;
44use std::sync::Arc;
45use std::time::Instant;
46
/// Terminal state reported by the agent loop.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can compare
/// outcomes directly instead of pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopOutcome {
    /// The model produced a response without tool calls; carries that final text.
    Completed(String),
    /// The loop used up its iteration budget before a final response arrived.
    MaxIterationsReached {
        /// Raw model response from the last iteration that ran.
        last_output: String,
        /// The iteration limit that was reached.
        iterations: usize,
    },
    /// The underlying runner reported cancellation.
    Cancelled,
    /// The configured maximum duration elapsed.
    TimedOut {
        /// Whole seconds elapsed when the timeout was detected.
        elapsed_secs: u64,
    },
}
62
63impl LoopOutcome {
64 pub fn is_completed(&self) -> bool {
66 matches!(self, LoopOutcome::Completed(_))
67 }
68
69 pub fn output(&self) -> Option<&str> {
71 match self {
72 LoopOutcome::Completed(s) => Some(s),
73 LoopOutcome::MaxIterationsReached { last_output, .. } => Some(last_output),
74 _ => None,
75 }
76 }
77}
78
79pub struct AgentRunner<S: CheckpointStore> {
82 runner: Runner<S>,
84 config: RunnerConfig,
86 tools: Vec<Arc<dyn Tool>>,
88 usage_accumulator: UsageAccumulator,
90 approval_checker: Option<Arc<dyn ApprovalChecker>>,
92 hook_registry: Option<Arc<HookRegistry>>,
94}
95
96impl<S: CheckpointStore> AgentRunner<S> {
97 pub fn new(runner: Runner<S>, config: RunnerConfig) -> Self {
99 Self {
100 runner,
101 config,
102 tools: Vec::new(),
103 usage_accumulator: UsageAccumulator::new(),
104 approval_checker: None,
105 hook_registry: None,
106 }
107 }
108
109 pub fn with_hook_registry(mut self, registry: Arc<HookRegistry>) -> Self {
111 self.hook_registry = Some(registry);
112 self
113 }
114
115 pub fn with_approval_checker(mut self, checker: Arc<dyn ApprovalChecker>) -> Self {
117 self.approval_checker = Some(checker);
118 self
119 }
120
121 pub fn usage(&self) -> &UsageAccumulator {
123 &self.usage_accumulator
124 }
125
126 pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
128 self.tools.push(Arc::new(tool));
129 self
130 }
131
132 pub fn add_tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
134 self.tools.extend(tools);
135 self
136 }
137
138 pub fn inner(&self) -> &Runner<S> {
140 &self.runner
141 }
142
143 pub fn inner_mut(&mut self) -> &mut Runner<S> {
145 &mut self.runner
146 }
147
148 pub async fn run(
162 &mut self,
163 callable: &dyn Callable,
164 input: &str,
165 project_dir: Option<&Path>,
166 ) -> anyhow::Result<LoopOutcome> {
167 let start_time = Instant::now();
168 let mut retry_handler = RetryHandler::new(self.config.retry.clone());
169
170 self.load_mcp_tools_at_session_start().await;
171
172 if let Some(ref reg) = self.hook_registry {
174 let mut ctx = serde_json::json!({
175 "event": "SessionStart",
176 "execution_id": self.runner.execution_id().as_str(),
177 "callable": callable.name(),
178 });
179 self.execute_hooks(reg, HookEvent::SessionStart, None, &mut ctx, callable)
180 .await;
181 }
182
183 let commands = load_commands_for_run(project_dir);
185 let command_expanded = dispatch_command(input, &commands);
186 let effective_input = command_expanded
187 .clone()
188 .unwrap_or_else(|| input.to_string());
189
190 let first_user_content = if command_expanded.is_none() {
192 if let Some(invoked) = build_manual_skill_input(project_dir, input) {
193 invoked
194 } else {
195 let skill_context = build_skill_context(project_dir, &effective_input);
196 if skill_context.is_empty() {
197 effective_input.clone()
198 } else {
199 format!("{}\n\n{}", skill_context, effective_input)
200 }
201 }
202 } else {
203 let skill_context = build_skill_context(project_dir, &effective_input);
204 if skill_context.is_empty() {
205 effective_input.clone()
206 } else {
207 format!("{}\n\n{}", skill_context, effective_input)
208 }
209 };
210
211 let mut history = vec![HistoryMessage::user(&first_user_content)];
213
214 if let Some(ref reg) = self.hook_registry {
216 let mut ctx = serde_json::json!({
217 "event": "UserPromptSubmit",
218 "execution_id": self.runner.execution_id().as_str(),
219 "prompt": effective_input,
220 });
221 self.execute_hooks(reg, HookEvent::UserPromptSubmit, None, &mut ctx, callable)
222 .await;
223 }
224
225 if self.config.emit_events {
227 self.runner
228 .emitter()
229 .emit(StreamEvent::execution_start(self.runner.execution_id()));
230 }
231
232 tracing::info!(
233 execution_id = %self.runner.execution_id(),
234 callable = callable.name(),
235 max_iterations = self.config.max_iterations,
236 "Starting robust agent loop"
237 );
238
239 let mut last_output = String::new();
240
241 for iteration in 0..self.config.max_iterations {
242 if self.runner.is_cancelled() {
244 tracing::info!("Execution cancelled");
245 if let Some(ref reg) = self.hook_registry {
246 let mut ctx = serde_json::json!({
247 "event": "SessionEnd",
248 "execution_id": self.runner.execution_id().as_str(),
249 "outcome": "cancelled",
250 });
251 self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
252 .await;
253 }
254 return Ok(LoopOutcome::Cancelled);
255 }
256
257 let elapsed = start_time.elapsed();
259 if elapsed > self.config.max_duration {
260 tracing::warn!(elapsed_secs = elapsed.as_secs(), "Execution timed out");
261 if let Some(ref reg) = self.hook_registry {
262 let mut ctx = serde_json::json!({
263 "event": "SessionEnd",
264 "execution_id": self.runner.execution_id().as_str(),
265 "outcome": "timed_out",
266 "elapsed_secs": elapsed.as_secs(),
267 });
268 self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
269 .await;
270 }
271 return Ok(LoopOutcome::TimedOut {
272 elapsed_secs: elapsed.as_secs(),
273 });
274 }
275
276 while self.runner.is_paused() {
278 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
279 if self.runner.is_cancelled() {
280 return Ok(LoopOutcome::Cancelled);
281 }
282 }
283
284 let current_input = self.build_iteration_input(&history);
286 let step_id = StepId::new();
287 let step_start = Instant::now();
288
289 if self.config.emit_events {
290 self.runner.emitter().emit(StreamEvent::step_start(
291 self.runner.execution_id(),
292 &step_id,
293 StepType::LlmNode,
294 format!("iteration_{}", iteration),
295 ));
296 }
297
298 tracing::debug!(iteration, "Running callable");
299
300 if self.config.emit_events && self.config.observability.trace_llm_calls {
302 self.runner.emitter().emit(StreamEvent::llm_call_start(
303 self.runner.execution_id(),
304 Some(&step_id),
305 callable.name(),
306 self.config.observability.model_name.clone(),
307 history.len(),
308 ));
309 }
310
311 let llm_call_start = Instant::now();
312
313 let response = match callable.run(¤t_input).await {
315 Ok(output) => {
316 retry_handler.reset();
317
318 let llm_call_duration_ms = llm_call_start.elapsed().as_millis() as u64;
320 if self.config.emit_events && self.config.observability.trace_llm_calls {
321 let (prompt_tokens, completion_tokens) = match callable.last_usage() {
322 Some(u) if !u.is_empty() => (u.prompt_tokens, u.completion_tokens),
323 _ => {
324 let prompt_len = current_input.len();
325 let estimated_prompt_tokens = (prompt_len / 4) as u32;
326 let estimated_tokens = (output.len() / 4) as u32;
327 (estimated_prompt_tokens, estimated_tokens)
328 }
329 };
330 let total_tokens = prompt_tokens + completion_tokens;
331
332 self.runner.emitter().emit(StreamEvent::llm_call_end(
333 self.runner.execution_id(),
334 Some(&step_id),
335 llm_call_duration_ms,
336 Some(prompt_tokens),
337 Some(completion_tokens),
338 Some(total_tokens),
339 ));
340
341 if self.config.observability.track_token_usage {
343 let usage = LlmTokenUsage::new(prompt_tokens, completion_tokens);
344
345 let pricing = if let (Some(input_cost), Some(output_cost)) = (
347 self.config.observability.cost_per_1m_input,
348 self.config.observability.cost_per_1m_output,
349 ) {
350 ModelPricing::new(input_cost, output_cost)
351 } else {
352 let model = self
353 .config
354 .observability
355 .model_name
356 .as_deref()
357 .unwrap_or("default");
358 CostCalculator::pricing_for_model(model)
359 };
360 let cost = CostCalculator::calculate_cost(&usage, &pricing);
361
362 self.usage_accumulator.add(&usage, cost);
363
364 self.runner
366 .emitter()
367 .emit(StreamEvent::token_usage_recorded(
368 self.runner.execution_id(),
369 Some(&step_id),
370 prompt_tokens,
371 completion_tokens,
372 self.usage_accumulator.total_tokens,
373 Some(cost),
374 Some(self.usage_accumulator.total_cost_usd),
375 ));
376 }
377 }
378
379 output
380 }
381 Err(e) => {
382 let err_msg = e.to_string();
383 let llm_call_duration_ms = llm_call_start.elapsed().as_millis() as u64;
384 tracing::warn!(iteration, error = %err_msg, "Callable error");
385
386 if self.config.emit_events && self.config.observability.trace_llm_calls {
388 self.runner.emitter().emit(StreamEvent::llm_call_failed(
389 self.runner.execution_id(),
390 Some(&step_id),
391 &err_msg,
392 Some(llm_call_duration_ms),
393 ));
394 }
395
396 if let Some(delay) = retry_handler.should_retry(&err_msg) {
398 if self.config.emit_events {
399 self.runner.emitter().emit(StreamEvent::step_end(
400 self.runner.execution_id(),
401 &step_id,
402 Some(format!("Retrying after error: {}", err_msg)),
403 step_start.elapsed().as_millis() as u64,
404 ));
405 }
406 tokio::time::sleep(delay).await;
407 continue; }
409
410 let exec_error = ExecutionError::kernel_internal(err_msg);
412 if self.config.emit_events {
413 self.runner.emitter().emit(StreamEvent::execution_failed(
414 self.runner.execution_id(),
415 exec_error,
416 ));
417 }
418 return Err(e);
419 }
420 };
421
422 let mut parse_result = parser::parse(&response);
424
425 if parse_result.tool_calls.is_empty() {
426 last_output = if parse_result.text.is_empty() {
428 response.clone()
429 } else {
430 parse_result.text.clone()
431 };
432
433 let step_duration = step_start.elapsed().as_millis() as u64;
434 if self.config.emit_events {
435 self.runner.emitter().emit(StreamEvent::step_end(
436 self.runner.execution_id(),
437 &step_id,
438 Some(last_output.clone()),
439 step_duration,
440 ));
441 let total_duration = start_time.elapsed().as_millis() as u64;
442 self.runner.emitter().emit(StreamEvent::execution_end(
443 self.runner.execution_id(),
444 Some(last_output.clone()),
445 total_duration,
446 ));
447 }
448
449 tracing::info!(
450 iteration,
451 output_len = last_output.len(),
452 "Agent loop completed with final response"
453 );
454
455 if let Some(ref reg) = self.hook_registry {
457 let mut ctx = serde_json::json!({
458 "event": "SessionEnd",
459 "execution_id": self.runner.execution_id().as_str(),
460 "outcome": "completed",
461 "output": last_output,
462 });
463 self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
464 .await;
465 self.execute_hooks(reg, HookEvent::Stop, None, &mut ctx, callable)
466 .await;
467 }
468
469 return Ok(LoopOutcome::Completed(last_output));
470 }
471
472 history.push(HistoryMessage::assistant(&response));
474
475 for (tool_idx, tool_call) in parse_result.tool_calls.iter_mut().enumerate() {
476 let tool_call_id = format!("{}-tool-{}", step_id, tool_idx);
477
478 tracing::debug!(
479 tool = %tool_call.name,
480 format = ?tool_call.format,
481 "Executing tool call"
482 );
483
484 if let Some(ref reg) = self.hook_registry {
486 let ctx = serde_json::json!({
487 "event": "PreToolUse",
488 "execution_id": self.runner.execution_id().as_str(),
489 "tool_name": tool_call.name,
490 "arguments": tool_call.arguments,
491 });
492 match self
493 .execute_pre_tool_hooks(reg, Some(&tool_call.name), &ctx, callable)
494 .await
495 {
496 HookDecision::Allow => {}
497 HookDecision::Mutate { arguments, reason } => {
498 if let Some(reason) = reason {
499 tracing::debug!(
500 tool = %tool_call.name,
501 reason = %reason,
502 "PreToolUse hook mutated tool arguments"
503 );
504 }
505 tool_call.arguments = arguments;
506 }
507 HookDecision::Block { reason } => {
508 let blocked_msg = match reason {
509 Some(reason) if !reason.is_empty() => format!(
510 "Tool '{}' was blocked by a PreToolUse hook: {}",
511 tool_call.name, reason
512 ),
513 _ => {
514 format!(
515 "Tool '{}' was blocked by a PreToolUse hook.",
516 tool_call.name
517 )
518 }
519 };
520 history
521 .push(HistoryMessage::tool_result(&tool_call.name, &blocked_msg));
522 continue;
523 }
524 }
525 }
526
527 if let Some(ref checker) = self.approval_checker {
529 if self.config.emit_events {
530 self.runner.emitter().emit(StreamEvent::permission_request(
531 self.runner.execution_id(),
532 &tool_call.name,
533 tool_call.arguments.clone(),
534 "approval_policy",
535 ));
536 }
537 if !checker
538 .allow_tool(&tool_call.name, &tool_call.arguments)
539 .await
540 {
541 let blocked_msg = format!(
542 "Tool '{}' was blocked by approval policy (user denied or policy=always_deny).",
543 tool_call.name
544 );
545 history.push(HistoryMessage::tool_result(&tool_call.name, &blocked_msg));
546 continue;
547 }
548 }
549
550 if self.config.emit_events {
552 self.runner.emitter().emit(StreamEvent::ToolInputAvailable {
553 tool_call_id: tool_call_id.clone(),
554 tool_name: tool_call.name.clone(),
555 input: tool_call.arguments.clone(),
556 });
557 }
558
559 let tool_start = std::time::Instant::now();
560
561 let tool = self.tools.iter().find(|t| t.name() == tool_call.name);
563
564 let tool_result = match tool {
565 Some(t) => match t.execute(tool_call.arguments.clone()).await {
566 Ok(result) => serde_json::to_string(&result)
567 .unwrap_or_else(|_| format!("{:?}", result)),
568 Err(e) => format!("Tool error: {}", e),
569 },
570 None => {
571 format!(
572 "Error: Tool '{}' not found. Available tools: {:?}",
573 tool_call.name,
574 self.tools.iter().map(|t| t.name()).collect::<Vec<_>>()
575 )
576 }
577 };
578
579 let tool_duration_ms = tool_start.elapsed().as_millis() as u64;
580
581 if self.config.emit_events {
583 self.runner
584 .emitter()
585 .emit(StreamEvent::ToolOutputAvailable {
586 tool_call_id: tool_call_id.clone(),
587 output: serde_json::json!({
588 "result": tool_result.clone(),
589 "duration_ms": tool_duration_ms,
590 }),
591 });
592 }
593
594 history.push(HistoryMessage::tool_result(&tool_call.name, &tool_result));
595
596 if let Some(ref reg) = self.hook_registry {
598 let mut ctx = serde_json::json!({
599 "event": "PostToolUse",
600 "execution_id": self.runner.execution_id().as_str(),
601 "tool_name": tool_call.name,
602 "arguments": tool_call.arguments,
603 "result": tool_result,
604 });
605 self.execute_hooks(
606 reg,
607 HookEvent::PostToolUse,
608 Some(&tool_call.name),
609 &mut ctx,
610 callable,
611 )
612 .await;
613 }
614 }
615
616 let step_duration = step_start.elapsed().as_millis() as u64;
617 if self.config.emit_events {
618 self.runner.emitter().emit(StreamEvent::step_end(
619 self.runner.execution_id(),
620 &step_id,
621 Some(format!(
622 "Executed {} tool(s)",
623 parse_result.tool_calls.len()
624 )),
625 step_duration,
626 ));
627 }
628
629 last_output = response;
630
631 if compaction::needs_compaction(&history, self.config.compaction_threshold) {
633 tracing::info!(
634 history_len = history.len(),
635 threshold = self.config.compaction_threshold,
636 "Triggering auto-compaction"
637 );
638
639 match compaction::compact_history(
640 &mut history,
641 callable,
642 self.config.compaction_keep_recent,
643 )
644 .await
645 {
646 Ok(true) => {
647 tracing::info!(new_len = history.len(), "History compacted successfully");
648 }
649 Ok(false) => {
650 tracing::debug!("Compaction skipped (not enough messages)");
651 }
652 Err(e) => {
653 tracing::warn!(error = %e, "History compaction failed, continuing");
655 }
656 }
657 }
658
659 if let Some(interval) = self.config.checkpoint_interval {
661 if (iteration + 1) % interval == 0 {
662 let state = enact_core::graph::NodeState::from_string(&last_output);
663 if let Err(e) = self
664 .runner
665 .save_checkpoint(state, Some(callable.name()), Some(callable.name()))
666 .await
667 {
668 tracing::warn!(error = %e, "Failed to save checkpoint, continuing");
669 } else {
670 tracing::debug!(iteration, "Checkpoint saved");
671 }
672 }
673 }
674 }
675
676 tracing::warn!(max = self.config.max_iterations, "Max iterations reached");
678
679 if let Some(ref reg) = self.hook_registry {
681 let mut ctx = serde_json::json!({
682 "event": "SessionEnd",
683 "execution_id": self.runner.execution_id().as_str(),
684 "outcome": "max_iterations",
685 "output": last_output,
686 });
687 self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
688 .await;
689 self.execute_hooks(reg, HookEvent::Stop, None, &mut ctx, callable)
690 .await;
691 }
692
693 if self.config.emit_events {
694 let duration = start_time.elapsed().as_millis() as u64;
695 self.runner.emitter().emit(StreamEvent::execution_end(
696 self.runner.execution_id(),
697 Some(last_output.clone()),
698 duration,
699 ));
700 }
701
702 Ok(LoopOutcome::MaxIterationsReached {
703 last_output,
704 iterations: self.config.max_iterations,
705 })
706 }
707
708 fn build_iteration_input(&self, history: &[HistoryMessage]) -> String {
712 history
713 .iter()
714 .map(|m| format!("{}: {}", m.role, m.content))
715 .collect::<Vec<_>>()
716 .join("\n\n")
717 }
718
719 async fn load_mcp_tools_at_session_start(&mut self) {
720 match enact_mcp::discover_mcp_tools().await {
721 Ok(mcp_tools) if !mcp_tools.is_empty() => {
722 let existing: std::collections::HashSet<String> =
723 self.tools.iter().map(|t| t.name().to_string()).collect();
724 let mut added = 0usize;
725 for tool in mcp_tools {
726 if !existing.contains(tool.name()) {
727 self.tools.push(tool);
728 added += 1;
729 }
730 }
731 if added > 0 {
732 tracing::debug!("Added {} MCP tool(s) at session start", added);
733 }
734 }
735 Ok(_) => {}
736 Err(e) => tracing::warn!("MCP discovery at session start failed: {}", e),
737 }
738 }
739
740 async fn execute_hooks(
741 &self,
742 reg: &HookRegistry,
743 event: HookEvent,
744 tool_name: Option<&str>,
745 ctx: &mut serde_json::Value,
746 callable: &dyn Callable,
747 ) {
748 for hook in reg.hooks_for_event(event, tool_name) {
749 if hook.async_mode {
750 if let HookHandler::Command { script } = &hook.handler {
752 let script = script.clone();
753 let context = ctx.clone();
754 let registry = reg.clone();
755 tokio::spawn(async move {
756 let _ = registry.run_command_handler(&script, &context).await;
757 });
758 continue;
759 }
760 }
761 let _ = self.execute_single_hook(reg, hook, ctx, callable).await;
762 }
763 }
764
765 async fn execute_pre_tool_hooks(
766 &self,
767 reg: &HookRegistry,
768 tool_name: Option<&str>,
769 ctx: &serde_json::Value,
770 callable: &dyn Callable,
771 ) -> HookDecision {
772 let mut latest_mutation: Option<HookDecision> = None;
773 for hook in reg.hooks_for_event(HookEvent::PreToolUse, tool_name) {
774 if hook.async_mode {
775 if let HookHandler::Command { script } = &hook.handler {
777 let script = script.clone();
778 let context = ctx.clone();
779 let registry = reg.clone();
780 tokio::spawn(async move {
781 let _ = registry.run_command_handler(&script, &context).await;
782 });
783 }
784 continue;
785 }
786 match self
787 .execute_single_pre_tool_hook(reg, hook, ctx, callable)
788 .await
789 {
790 HookDecision::Allow => {}
791 HookDecision::Block { reason } => return HookDecision::Block { reason },
792 HookDecision::Mutate { arguments, reason } => {
793 latest_mutation = Some(HookDecision::Mutate { arguments, reason });
794 }
795 }
796 }
797 latest_mutation.unwrap_or(HookDecision::Allow)
798 }
799
800 async fn execute_single_hook(
801 &self,
802 reg: &HookRegistry,
803 hook: &enact_config::HookConfig,
804 ctx: &serde_json::Value,
805 callable: &dyn Callable,
806 ) -> bool {
807 match &hook.handler {
808 HookHandler::Command { script } => reg
809 .run_command_handler(script, ctx)
810 .await
811 .map(|r| r.success)
812 .unwrap_or(true),
813 HookHandler::Prompt { template } => {
814 let rendered = render_template(template, ctx);
815 callable.run(&rendered).await.is_ok()
816 }
817 HookHandler::Agent { agent_name } => {
818 let prompt = format!("Hook agent '{}' validation context: {}", agent_name, ctx);
819 callable.run(&prompt).await.is_ok()
820 }
821 }
822 }
823
824 async fn execute_single_pre_tool_hook(
825 &self,
826 reg: &HookRegistry,
827 hook: &enact_config::HookConfig,
828 ctx: &serde_json::Value,
829 callable: &dyn Callable,
830 ) -> HookDecision {
831 match &hook.handler {
832 HookHandler::Command { script } => {
833 let Ok(result) = reg.run_command_handler(script, ctx).await else {
834 return HookDecision::Allow;
835 };
836 parse_command_hook_decision(&result.stdout, result.success)
837 }
838 HookHandler::Prompt { template } => {
839 let rendered = render_template(template, ctx);
840 let Ok(output) = callable.run(&rendered).await else {
841 return HookDecision::Allow;
842 };
843 parse_model_hook_decision(&output)
844 }
845 HookHandler::Agent { agent_name } => {
846 let prompt = format!(
847 "Return ONLY JSON hook decision for this context. Agent '{}': {}",
848 agent_name, ctx
849 );
850 let Ok(output) = callable.run(&prompt).await else {
851 return HookDecision::Allow;
852 };
853 parse_model_hook_decision(&output)
854 }
855 }
856 }
857}
858
859fn build_skill_context(project_dir: Option<&Path>, prompt: &str) -> String {
862 let metas = load_skill_metas(project_dir);
863 if metas.is_empty() {
864 return String::new();
865 }
866 let matched = find_matching_skills(prompt, &metas);
867 let mut parts = Vec::new();
868 for meta in matched {
869 if let Ok(body) = load_skill_body(&meta) {
870 let mut section = format!("## Skill: {}\n\n{}", meta.name, body);
871 let resources =
872 load_skill_resources(&meta, &body, project_dir, SkillResourceLimits::default());
873 if !resources.is_empty() {
874 section.push_str("\n\n### Skill Resources\n");
875 for resource in resources {
876 section.push_str(&format!(
877 "\n#### {}\n\n```\n{}\n```\n",
878 resource.path.display(),
879 resource.content
880 ));
881 }
882 }
883 parts.push(section);
884 }
885 }
886 if parts.is_empty() {
887 return String::new();
888 }
889 format!(
890 "<!-- Matched skills for this request -->\n\n{}\n\n---\n\n",
891 parts.join("\n\n---\n\n")
892 )
893}
894
895fn build_manual_skill_input(project_dir: Option<&Path>, raw_input: &str) -> Option<String> {
896 let (name, args) = parse_slash_invocation(raw_input)?;
897 let meta = find_skill_by_name(project_dir, &name)?;
898 let body = load_skill_body(&meta).ok()?;
899 let mut section = format!("## Skill: {}\n\n{}", meta.name, body);
900 let resources = load_skill_resources(&meta, &body, project_dir, SkillResourceLimits::default());
901 if !resources.is_empty() {
902 section.push_str("\n\n### Skill Resources\n");
903 for resource in resources {
904 section.push_str(&format!(
905 "\n#### {}\n\n```\n{}\n```\n",
906 resource.path.display(),
907 resource.content
908 ));
909 }
910 }
911 let user_request = if args.trim().is_empty() {
912 "Follow this skill for the current task.".to_string()
913 } else {
914 args
915 };
916 Some(format!(
917 "<!-- Manually invoked skill -->\n\n{}\n\n---\n\n{}",
918 section, user_request
919 ))
920}
921
922fn parse_command_hook_decision(stdout: &str, success: bool) -> HookDecision {
923 if let Some(parsed) = parse_hook_decision(stdout) {
924 return parsed;
925 }
926 if success {
927 HookDecision::Allow
928 } else {
929 HookDecision::Block { reason: None }
930 }
931}
932
933fn parse_model_hook_decision(output: &str) -> HookDecision {
934 parse_hook_decision(output).unwrap_or(HookDecision::Allow)
935}
936
937fn parse_hook_decision(raw: &str) -> Option<HookDecision> {
938 let trimmed = raw.trim();
939 if trimmed.is_empty() {
940 return None;
941 }
942 if let Ok(d) = serde_json::from_str::<HookDecision>(trimmed) {
943 return Some(d);
944 }
945 extract_json_object(trimmed).and_then(|s| serde_json::from_str::<HookDecision>(&s).ok())
946}
947
/// Returns the substring from the first `{` through the last `}` of `s`, or
/// `None` when either brace is missing or the last `}` precedes the first `{`.
fn extract_json_object(s: &str) -> Option<String> {
    match (s.find('{'), s.rfind('}')) {
        (Some(open), Some(close)) if open < close => Some(s[open..=close].to_owned()),
        _ => None,
    }
}
956
957fn render_template(template: &str, ctx: &serde_json::Value) -> String {
958 let mut out = template.to_string();
959 if let Some(obj) = ctx.as_object() {
960 for (k, v) in obj {
961 let replacement = if v.is_string() {
962 v.as_str().unwrap_or_default().to_string()
963 } else {
964 v.to_string()
965 };
966 out = out.replace(&format!("{{{{{}}}}}", k), &replacement);
967 }
968 }
969 out
970}
971
972pub type DefaultAgentRunner = AgentRunner<enact_core::graph::InMemoryCheckpointStore>;
974
975impl DefaultAgentRunner {
976 pub fn default_new() -> Self {
978 Self::new(
979 enact_core::runner::DefaultRunner::default_new(),
980 RunnerConfig::default(),
981 )
982 }
983
984 pub fn with_config(config: RunnerConfig) -> Self {
986 Self::new(enact_core::runner::DefaultRunner::default_new(), config)
987 }
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use enact_core::callable::Callable;
    use enact_core::tool::Tool;
    use serde_json::json;
    use std::fs;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use tokio::sync::Mutex;

    /// With empty stdout the command hook decision follows the exit status.
    #[test]
    fn parse_command_hook_decision_falls_back_to_exit_status() {
        let allow = parse_command_hook_decision("", true);
        assert!(matches!(allow, HookDecision::Allow));

        let block = parse_command_hook_decision("", false);
        assert!(matches!(block, HookDecision::Block { .. }));
    }

    /// A JSON block decision in model output is parsed with its reason.
    #[test]
    fn parse_model_hook_decision_handles_json_block() {
        let decision =
            parse_model_hook_decision(r#"{"decision":"block","reason":"unsafe operation"}"#);
        assert!(matches!(
            decision,
            HookDecision::Block {
                reason: Some(ref r)
            } if r == "unsafe operation"
        ));
    }

    /// `/review ...` loads the project skill body and keeps the user's arguments.
    #[test]
    fn build_manual_skill_input_loads_skill_body() {
        let project = tempfile::tempdir().unwrap();
        let skill_dir = project.path().join(".enact").join("skills").join("review");
        fs::create_dir_all(&skill_dir).unwrap();
        fs::write(
            skill_dir.join("SKILL.md"),
            "---\nname: review\ndescription: Review code\nversion: 0.1.0\n---\nAlways list findings.\n",
        )
        .unwrap();

        let injected =
            build_manual_skill_input(Some(project.path()), "/review inspect this").unwrap();
        assert!(injected.contains("Skill: review"));
        assert!(injected.contains("inspect this"));
    }

    /// Callable that returns `first` on the first call and `rest` afterwards.
    struct MockCallable {
        calls: AtomicUsize,
        first: String,
        rest: String,
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            let idx = self.calls.fetch_add(1, Ordering::SeqCst);
            if idx == 0 {
                Ok(self.first.clone())
            } else {
                Ok(self.rest.clone())
            }
        }
    }

    /// Tool that records the arguments it was executed with.
    struct CaptureTool {
        seen: Arc<Mutex<Option<serde_json::Value>>>,
    }

    #[async_trait]
    impl Tool for CaptureTool {
        fn name(&self) -> &str {
            "capture_tool"
        }

        fn description(&self) -> &str {
            "capture"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({"type":"object"})
        }

        async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
            *self.seen.lock().await = Some(args);
            Ok(json!({"ok": true}))
        }
    }

    /// A `mutate` decision from a PreToolUse command hook replaces the
    /// arguments the tool is executed with.
    #[tokio::test]
    async fn pre_tool_mutation_updates_tool_arguments() {
        let temp = tempfile::tempdir().unwrap();
        let hooks_path = temp.path().join("hooks.yaml");
        fs::write(
            &hooks_path,
            r#"hooks:
  - event: PreToolUse
    matcher: "capture_tool"
    handler:
      type: command
      script: "echo '{\"decision\":\"mutate\",\"arguments\":{\"value\":\"mutated\"}}'"
"#,
        )
        .unwrap();
        std::env::set_var(
            "ENACT_HOOKS_CONFIG_PATH",
            hooks_path.to_string_lossy().as_ref(),
        );

        let seen = Arc::new(Mutex::new(None));
        let tool = CaptureTool {
            seen: Arc::clone(&seen),
        };
        let callable = MockCallable {
            calls: AtomicUsize::new(0),
            first: r#"{"tool_call":{"name":"capture_tool","arguments":{"value":"original"}}}"#
                .to_string(),
            rest: "done".to_string(),
        };

        let mut runner = DefaultAgentRunner::default_new()
            .add_tool(tool)
            .with_hook_registry(Arc::new(HookRegistry::load_global_and_agent(None, None)));
        let outcome = runner.run(&callable, "test", None).await;
        // Clean up before unwrapping so a failed run cannot leak the variable
        // into other tests running in this process.
        std::env::remove_var("ENACT_HOOKS_CONFIG_PATH");
        let result = outcome.unwrap();

        assert!(matches!(result, LoopOutcome::Completed(_)));
        let captured = seen.lock().await.clone().unwrap();
        assert_eq!(captured, json!({"value":"mutated"}));
    }

    /// Prompt and agent PreToolUse hooks each invoke the callable once and
    /// honor the JSON decision it returns.
    #[tokio::test]
    async fn prompt_and_agent_hooks_dispatch_callable() {
        let reg = HookRegistry::new();
        let runner = DefaultAgentRunner::default_new();
        let callable = MockCallable {
            calls: AtomicUsize::new(0),
            first: r#"{"decision":"allow"}"#.to_string(),
            rest: r#"{"decision":"allow"}"#.to_string(),
        };
        let ctx = json!({"tool_name":"capture_tool","arguments":{"v":1}});

        let prompt = enact_config::HookConfig {
            event: HookEvent::PreToolUse,
            matcher: None,
            handler: HookHandler::Prompt {
                template: "Decide for {{tool_name}}".to_string(),
            },
            async_mode: false,
        };
        let agent = enact_config::HookConfig {
            event: HookEvent::PreToolUse,
            matcher: None,
            handler: HookHandler::Agent {
                agent_name: "reviewer".to_string(),
            },
            async_mode: false,
        };

        let p = runner
            .execute_single_pre_tool_hook(&reg, &prompt, &ctx, &callable)
            .await;
        let a = runner
            .execute_single_pre_tool_hook(&reg, &agent, &ctx, &callable)
            .await;

        assert!(matches!(p, HookDecision::Allow));
        assert!(matches!(a, HookDecision::Allow));
        assert_eq!(callable.calls.load(Ordering::SeqCst), 2);
    }
}
1169}