Skip to main content

enact_runner/
loop_driver.rs

1//! Core agent loop driver
2//!
3//! `AgentRunner` is the "missing middle" — the robust loop that sits between
4//! the apps (CLI, API) and the `enact-core` kernel. It implements:
5//!
6//! - Multi-iteration tool call loop (ported from zeroclaw's `run_tool_call_loop`)
7//! - Multi-format tool call parsing (JSON, XML, Markdown)
8//! - Auto context compaction when history grows too large
9//! - Retry with exponential backoff for transient errors
10//! - Checkpoint saving at configured intervals
11//! - Stream event emission throughout
12//! - LLM call tracing and token/cost accounting (observability)
13//!
14//! This fulfills the mandate from `28-TS-RUST-INTEGRATION.md`:
15//! "Rust handles retries/fallbacks" and `27-LONG-RUNNING-EXECUTIONS.md`:
16//! "Unified long-running agentic loop".
17
18use crate::approval::ApprovalChecker;
19use crate::commands::{
20    dispatch as dispatch_command, load_commands_for_run, parse_slash_invocation,
21};
22use crate::compaction::{self, HistoryMessage};
23use crate::config::RunnerConfig;
24use crate::hooks::HookRegistry;
25use crate::parser;
26use crate::retry::RetryHandler;
27
28use enact_config::{HookDecision, HookEvent, HookHandler};
29use enact_core::callable::Callable;
30use enact_core::graph::CheckpointStore;
31use enact_core::kernel::cost::{
32    CostCalculator, ModelPricing, TokenUsage as LlmTokenUsage, UsageAccumulator,
33};
34use enact_core::kernel::{ExecutionError, StepId, StepType};
35use enact_core::runner::Runner;
36use enact_core::streaming::StreamEvent;
37use enact_core::tool::Tool;
38use enact_skills::{
39    find_matching_skills, find_skill_by_name, load_skill_body, load_skill_metas,
40    load_skill_resources, SkillResourceLimits,
41};
42
43use std::path::Path;
44use std::sync::Arc;
45use std::time::Instant;
46
/// Outcome of a runner loop execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopOutcome {
    /// The agent produced a final response (no more tool calls).
    Completed(String),
    /// The loop hit the maximum iteration limit.
    MaxIterationsReached {
        /// The last LLM response the loop processed (empty if none was processed).
        last_output: String,
        /// The configured iteration limit that was reached.
        iterations: usize,
    },
    /// The execution was cancelled externally.
    Cancelled,
    /// The execution exceeded the configured timeout.
    TimedOut {
        /// Whole seconds elapsed when the timeout was detected.
        elapsed_secs: u64,
    },
}

impl LoopOutcome {
    /// Whether the outcome represents a successful completion.
    ///
    /// Returns `true` only for [`LoopOutcome::Completed`].
    #[must_use]
    pub fn is_completed(&self) -> bool {
        matches!(self, LoopOutcome::Completed(_))
    }

    /// Get the final output text, if any.
    ///
    /// `Some` for [`LoopOutcome::Completed`] and
    /// [`LoopOutcome::MaxIterationsReached`]; `None` for outcomes that carry
    /// no output.
    #[must_use]
    pub fn output(&self) -> Option<&str> {
        match self {
            LoopOutcome::Completed(s) => Some(s),
            LoopOutcome::MaxIterationsReached { last_output, .. } => Some(last_output),
            // Listed explicitly (no `_`) so adding a variant forces a decision here.
            LoopOutcome::Cancelled | LoopOutcome::TimedOut { .. } => None,
        }
    }
}
78
79/// The robust agent runner — drives the tool call loop with retries,
80/// compaction, and multi-format parsing.
81pub struct AgentRunner<S: CheckpointStore> {
82    /// The underlying enact-core runner (provides cancel/pause/checkpoint/events)
83    runner: Runner<S>,
84    /// Configuration for iteration limits, compaction, retries
85    config: RunnerConfig,
86    /// Registered tools available to the agent
87    tools: Vec<Arc<dyn Tool>>,
88    /// Cumulative token usage and cost tracking
89    usage_accumulator: UsageAccumulator,
90    /// Optional approval checker (PreToolUse); if returns false, tool is blocked
91    approval_checker: Option<Arc<dyn ApprovalChecker>>,
92    /// Optional hook registry (SessionStart, UserPromptSubmit, PreToolUse, PostToolUse, SessionEnd)
93    hook_registry: Option<Arc<HookRegistry>>,
94}
95
96impl<S: CheckpointStore> AgentRunner<S> {
97    /// Create a new `AgentRunner` wrapping an existing `Runner`.
98    pub fn new(runner: Runner<S>, config: RunnerConfig) -> Self {
99        Self {
100            runner,
101            config,
102            tools: Vec::new(),
103            usage_accumulator: UsageAccumulator::new(),
104            approval_checker: None,
105            hook_registry: None,
106        }
107    }
108
109    /// Set the hook registry for lifecycle hooks (SessionStart, UserPromptSubmit, PreToolUse, PostToolUse, SessionEnd).
110    pub fn with_hook_registry(mut self, registry: Arc<HookRegistry>) -> Self {
111        self.hook_registry = Some(registry);
112        self
113    }
114
115    /// Set an approval checker (PreToolUse). When set, each tool call is checked; if false, the tool is blocked.
116    pub fn with_approval_checker(mut self, checker: Arc<dyn ApprovalChecker>) -> Self {
117        self.approval_checker = Some(checker);
118        self
119    }
120
121    /// Get the cumulative usage statistics
122    pub fn usage(&self) -> &UsageAccumulator {
123        &self.usage_accumulator
124    }
125
126    /// Register a tool for use in the loop.
127    pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
128        self.tools.push(Arc::new(tool));
129        self
130    }
131
132    /// Register multiple tools at once.
133    pub fn add_tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
134        self.tools.extend(tools);
135        self
136    }
137
138    /// Access the underlying runner.
139    pub fn inner(&self) -> &Runner<S> {
140        &self.runner
141    }
142
143    /// Access the underlying runner mutably.
144    pub fn inner_mut(&mut self) -> &mut Runner<S> {
145        &mut self.runner
146    }
147
148    /// Run the robust agent loop.
149    ///
150    /// This is the core method — it replaces both `AgenticLoop::run` (which was
151    /// a skeleton) and `LlmCallable::run` (which had a basic loop but no
152    /// retries, compaction, or multi-format parsing).
153    ///
154    /// # Flow
155    ///
156    /// 1. Call the callable with current input
157    /// 2. Parse the response for tool calls (JSON → XML → Markdown)
158    /// 3. If tool calls found: execute them, append results to history, loop
159    /// 4. If no tool calls: return the final response
160    /// 5. Throughout: check limits, retry on errors, compact history, checkpoint
161    pub async fn run(
162        &mut self,
163        callable: &dyn Callable,
164        input: &str,
165        project_dir: Option<&Path>,
166    ) -> anyhow::Result<LoopOutcome> {
167        let start_time = Instant::now();
168        let mut retry_handler = RetryHandler::new(self.config.retry.clone());
169
170        self.load_mcp_tools_at_session_start().await;
171
172        // ── Hooks: SessionStart ──
173        if let Some(ref reg) = self.hook_registry {
174            let mut ctx = serde_json::json!({
175                "event": "SessionStart",
176                "execution_id": self.runner.execution_id().as_str(),
177                "callable": callable.name(),
178            });
179            self.execute_hooks(reg, HookEvent::SessionStart, None, &mut ctx, callable)
180                .await;
181        }
182
183        // ── Slash command dispatch: expand /command args into content (global + project + plugin commands) ──
184        let commands = load_commands_for_run(project_dir);
185        let command_expanded = dispatch_command(input, &commands);
186        let effective_input = command_expanded
187            .clone()
188            .unwrap_or_else(|| input.to_string());
189
190        // ── Manual skill invocation: /skill-name or /plugin:skill-name (command precedence) ──
191        let first_user_content = if command_expanded.is_none() {
192            if let Some(invoked) = build_manual_skill_input(project_dir, input) {
193                invoked
194            } else {
195                let skill_context = build_skill_context(project_dir, &effective_input);
196                if skill_context.is_empty() {
197                    effective_input.clone()
198                } else {
199                    format!("{}\n\n{}", skill_context, effective_input)
200                }
201            }
202        } else {
203            let skill_context = build_skill_context(project_dir, &effective_input);
204            if skill_context.is_empty() {
205                effective_input.clone()
206            } else {
207                format!("{}\n\n{}", skill_context, effective_input)
208            }
209        };
210
211        // Build initial history
212        let mut history = vec![HistoryMessage::user(&first_user_content)];
213
214        // ── Hooks: UserPromptSubmit ──
215        if let Some(ref reg) = self.hook_registry {
216            let mut ctx = serde_json::json!({
217                "event": "UserPromptSubmit",
218                "execution_id": self.runner.execution_id().as_str(),
219                "prompt": effective_input,
220            });
221            self.execute_hooks(reg, HookEvent::UserPromptSubmit, None, &mut ctx, callable)
222                .await;
223        }
224
225        // Emit execution start
226        if self.config.emit_events {
227            self.runner
228                .emitter()
229                .emit(StreamEvent::execution_start(self.runner.execution_id()));
230        }
231
232        tracing::info!(
233            execution_id = %self.runner.execution_id(),
234            callable = callable.name(),
235            max_iterations = self.config.max_iterations,
236            "Starting robust agent loop"
237        );
238
239        let mut last_output = String::new();
240
241        for iteration in 0..self.config.max_iterations {
242            // ── Check cancellation ──
243            if self.runner.is_cancelled() {
244                tracing::info!("Execution cancelled");
245                if let Some(ref reg) = self.hook_registry {
246                    let mut ctx = serde_json::json!({
247                        "event": "SessionEnd",
248                        "execution_id": self.runner.execution_id().as_str(),
249                        "outcome": "cancelled",
250                    });
251                    self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
252                        .await;
253                }
254                return Ok(LoopOutcome::Cancelled);
255            }
256
257            // ── Check timeout ──
258            let elapsed = start_time.elapsed();
259            if elapsed > self.config.max_duration {
260                tracing::warn!(elapsed_secs = elapsed.as_secs(), "Execution timed out");
261                if let Some(ref reg) = self.hook_registry {
262                    let mut ctx = serde_json::json!({
263                        "event": "SessionEnd",
264                        "execution_id": self.runner.execution_id().as_str(),
265                        "outcome": "timed_out",
266                        "elapsed_secs": elapsed.as_secs(),
267                    });
268                    self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
269                        .await;
270                }
271                return Ok(LoopOutcome::TimedOut {
272                    elapsed_secs: elapsed.as_secs(),
273                });
274            }
275
276            // ── Wait if paused ──
277            while self.runner.is_paused() {
278                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
279                if self.runner.is_cancelled() {
280                    return Ok(LoopOutcome::Cancelled);
281                }
282            }
283
284            // ── Build input for this iteration ──
285            let current_input = self.build_iteration_input(&history);
286            let step_id = StepId::new();
287            let step_start = Instant::now();
288
289            if self.config.emit_events {
290                self.runner.emitter().emit(StreamEvent::step_start(
291                    self.runner.execution_id(),
292                    &step_id,
293                    StepType::LlmNode,
294                    format!("iteration_{}", iteration),
295                ));
296            }
297
298            tracing::debug!(iteration, "Running callable");
299
300            // ── Emit LLM call start event (observability) ──
301            if self.config.emit_events && self.config.observability.trace_llm_calls {
302                self.runner.emitter().emit(StreamEvent::llm_call_start(
303                    self.runner.execution_id(),
304                    Some(&step_id),
305                    callable.name(),
306                    self.config.observability.model_name.clone(),
307                    history.len(),
308                ));
309            }
310
311            let llm_call_start = Instant::now();
312
313            // ── Call the LLM ──
314            let response = match callable.run(&current_input).await {
315                Ok(output) => {
316                    retry_handler.reset();
317
318                    // ── Emit LLM call end event (observability) ──
319                    let llm_call_duration_ms = llm_call_start.elapsed().as_millis() as u64;
320                    if self.config.emit_events && self.config.observability.trace_llm_calls {
321                        let (prompt_tokens, completion_tokens) = match callable.last_usage() {
322                            Some(u) if !u.is_empty() => (u.prompt_tokens, u.completion_tokens),
323                            _ => {
324                                let prompt_len = current_input.len();
325                                let estimated_prompt_tokens = (prompt_len / 4) as u32;
326                                let estimated_tokens = (output.len() / 4) as u32;
327                                (estimated_prompt_tokens, estimated_tokens)
328                            }
329                        };
330                        let total_tokens = prompt_tokens + completion_tokens;
331
332                        self.runner.emitter().emit(StreamEvent::llm_call_end(
333                            self.runner.execution_id(),
334                            Some(&step_id),
335                            llm_call_duration_ms,
336                            Some(prompt_tokens),
337                            Some(completion_tokens),
338                            Some(total_tokens),
339                        ));
340
341                        // Track cumulative usage
342                        if self.config.observability.track_token_usage {
343                            let usage = LlmTokenUsage::new(prompt_tokens, completion_tokens);
344
345                            // Use pricing from config if available, otherwise fall back to model name
346                            let pricing = if let (Some(input_cost), Some(output_cost)) = (
347                                self.config.observability.cost_per_1m_input,
348                                self.config.observability.cost_per_1m_output,
349                            ) {
350                                ModelPricing::new(input_cost, output_cost)
351                            } else {
352                                let model = self
353                                    .config
354                                    .observability
355                                    .model_name
356                                    .as_deref()
357                                    .unwrap_or("default");
358                                CostCalculator::pricing_for_model(model)
359                            };
360                            let cost = CostCalculator::calculate_cost(&usage, &pricing);
361
362                            self.usage_accumulator.add(&usage, cost);
363
364                            // Emit token usage event
365                            self.runner
366                                .emitter()
367                                .emit(StreamEvent::token_usage_recorded(
368                                    self.runner.execution_id(),
369                                    Some(&step_id),
370                                    prompt_tokens,
371                                    completion_tokens,
372                                    self.usage_accumulator.total_tokens,
373                                    Some(cost),
374                                    Some(self.usage_accumulator.total_cost_usd),
375                                ));
376                        }
377                    }
378
379                    output
380                }
381                Err(e) => {
382                    let err_msg = e.to_string();
383                    let llm_call_duration_ms = llm_call_start.elapsed().as_millis() as u64;
384                    tracing::warn!(iteration, error = %err_msg, "Callable error");
385
386                    // ── Emit LLM call failed event (observability) ──
387                    if self.config.emit_events && self.config.observability.trace_llm_calls {
388                        self.runner.emitter().emit(StreamEvent::llm_call_failed(
389                            self.runner.execution_id(),
390                            Some(&step_id),
391                            &err_msg,
392                            Some(llm_call_duration_ms),
393                        ));
394                    }
395
396                    // Check if we should retry
397                    if let Some(delay) = retry_handler.should_retry(&err_msg) {
398                        if self.config.emit_events {
399                            self.runner.emitter().emit(StreamEvent::step_end(
400                                self.runner.execution_id(),
401                                &step_id,
402                                Some(format!("Retrying after error: {}", err_msg)),
403                                step_start.elapsed().as_millis() as u64,
404                            ));
405                        }
406                        tokio::time::sleep(delay).await;
407                        continue; // Retry this iteration
408                    }
409
410                    // Non-retryable or max retries exceeded
411                    let exec_error = ExecutionError::kernel_internal(err_msg);
412                    if self.config.emit_events {
413                        self.runner.emitter().emit(StreamEvent::execution_failed(
414                            self.runner.execution_id(),
415                            exec_error,
416                        ));
417                    }
418                    return Err(e);
419                }
420            };
421
422            // ── Parse for tool calls ──
423            let mut parse_result = parser::parse(&response);
424
425            if parse_result.tool_calls.is_empty() {
426                // No tool calls — this is the final response
427                last_output = if parse_result.text.is_empty() {
428                    response.clone()
429                } else {
430                    parse_result.text.clone()
431                };
432
433                let step_duration = step_start.elapsed().as_millis() as u64;
434                if self.config.emit_events {
435                    self.runner.emitter().emit(StreamEvent::step_end(
436                        self.runner.execution_id(),
437                        &step_id,
438                        Some(last_output.clone()),
439                        step_duration,
440                    ));
441                    let total_duration = start_time.elapsed().as_millis() as u64;
442                    self.runner.emitter().emit(StreamEvent::execution_end(
443                        self.runner.execution_id(),
444                        Some(last_output.clone()),
445                        total_duration,
446                    ));
447                }
448
449                tracing::info!(
450                    iteration,
451                    output_len = last_output.len(),
452                    "Agent loop completed with final response"
453                );
454
455                // Hooks: SessionEnd / Stop on normal completion
456                if let Some(ref reg) = self.hook_registry {
457                    let mut ctx = serde_json::json!({
458                        "event": "SessionEnd",
459                        "execution_id": self.runner.execution_id().as_str(),
460                        "outcome": "completed",
461                        "output": last_output,
462                    });
463                    self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
464                        .await;
465                    self.execute_hooks(reg, HookEvent::Stop, None, &mut ctx, callable)
466                        .await;
467                }
468
469                return Ok(LoopOutcome::Completed(last_output));
470            }
471
472            // ── Execute tool calls ──
473            history.push(HistoryMessage::assistant(&response));
474
475            for (tool_idx, tool_call) in parse_result.tool_calls.iter_mut().enumerate() {
476                let tool_call_id = format!("{}-tool-{}", step_id, tool_idx);
477
478                tracing::debug!(
479                    tool = %tool_call.name,
480                    format = ?tool_call.format,
481                    "Executing tool call"
482                );
483
484                // PreToolUse: hooks (allow/block/mutate)
485                if let Some(ref reg) = self.hook_registry {
486                    let ctx = serde_json::json!({
487                        "event": "PreToolUse",
488                        "execution_id": self.runner.execution_id().as_str(),
489                        "tool_name": tool_call.name,
490                        "arguments": tool_call.arguments,
491                    });
492                    match self
493                        .execute_pre_tool_hooks(reg, Some(&tool_call.name), &ctx, callable)
494                        .await
495                    {
496                        HookDecision::Allow => {}
497                        HookDecision::Mutate { arguments, reason } => {
498                            if let Some(reason) = reason {
499                                tracing::debug!(
500                                    tool = %tool_call.name,
501                                    reason = %reason,
502                                    "PreToolUse hook mutated tool arguments"
503                                );
504                            }
505                            tool_call.arguments = arguments;
506                        }
507                        HookDecision::Block { reason } => {
508                            let blocked_msg = match reason {
509                                Some(reason) if !reason.is_empty() => format!(
510                                    "Tool '{}' was blocked by a PreToolUse hook: {}",
511                                    tool_call.name, reason
512                                ),
513                                _ => {
514                                    format!(
515                                        "Tool '{}' was blocked by a PreToolUse hook.",
516                                        tool_call.name
517                                    )
518                                }
519                            };
520                            history
521                                .push(HistoryMessage::tool_result(&tool_call.name, &blocked_msg));
522                            continue;
523                        }
524                    }
525                }
526
527                // PreToolUse: approval check (human-in-the-loop)
528                if let Some(ref checker) = self.approval_checker {
529                    if self.config.emit_events {
530                        self.runner.emitter().emit(StreamEvent::permission_request(
531                            self.runner.execution_id(),
532                            &tool_call.name,
533                            tool_call.arguments.clone(),
534                            "approval_policy",
535                        ));
536                    }
537                    if !checker
538                        .allow_tool(&tool_call.name, &tool_call.arguments)
539                        .await
540                    {
541                        let blocked_msg = format!(
542                            "Tool '{}' was blocked by approval policy (user denied or policy=always_deny).",
543                            tool_call.name
544                        );
545                        history.push(HistoryMessage::tool_result(&tool_call.name, &blocked_msg));
546                        continue;
547                    }
548                }
549
550                // Emit tool input event
551                if self.config.emit_events {
552                    self.runner.emitter().emit(StreamEvent::ToolInputAvailable {
553                        tool_call_id: tool_call_id.clone(),
554                        tool_name: tool_call.name.clone(),
555                        input: tool_call.arguments.clone(),
556                    });
557                }
558
559                let tool_start = std::time::Instant::now();
560
561                // Find the tool in our registry
562                let tool = self.tools.iter().find(|t| t.name() == tool_call.name);
563
564                let tool_result = match tool {
565                    Some(t) => match t.execute(tool_call.arguments.clone()).await {
566                        Ok(result) => serde_json::to_string(&result)
567                            .unwrap_or_else(|_| format!("{:?}", result)),
568                        Err(e) => format!("Tool error: {}", e),
569                    },
570                    None => {
571                        format!(
572                            "Error: Tool '{}' not found. Available tools: {:?}",
573                            tool_call.name,
574                            self.tools.iter().map(|t| t.name()).collect::<Vec<_>>()
575                        )
576                    }
577                };
578
579                let tool_duration_ms = tool_start.elapsed().as_millis() as u64;
580
581                // Emit tool output event
582                if self.config.emit_events {
583                    self.runner
584                        .emitter()
585                        .emit(StreamEvent::ToolOutputAvailable {
586                            tool_call_id: tool_call_id.clone(),
587                            output: serde_json::json!({
588                                "result": tool_result.clone(),
589                                "duration_ms": tool_duration_ms,
590                            }),
591                        });
592                }
593
594                history.push(HistoryMessage::tool_result(&tool_call.name, &tool_result));
595
596                // PostToolUse: hooks
597                if let Some(ref reg) = self.hook_registry {
598                    let mut ctx = serde_json::json!({
599                        "event": "PostToolUse",
600                        "execution_id": self.runner.execution_id().as_str(),
601                        "tool_name": tool_call.name,
602                        "arguments": tool_call.arguments,
603                        "result": tool_result,
604                    });
605                    self.execute_hooks(
606                        reg,
607                        HookEvent::PostToolUse,
608                        Some(&tool_call.name),
609                        &mut ctx,
610                        callable,
611                    )
612                    .await;
613                }
614            }
615
616            let step_duration = step_start.elapsed().as_millis() as u64;
617            if self.config.emit_events {
618                self.runner.emitter().emit(StreamEvent::step_end(
619                    self.runner.execution_id(),
620                    &step_id,
621                    Some(format!(
622                        "Executed {} tool(s)",
623                        parse_result.tool_calls.len()
624                    )),
625                    step_duration,
626                ));
627            }
628
629            last_output = response;
630
631            // ── Auto-compact history if needed ──
632            if compaction::needs_compaction(&history, self.config.compaction_threshold) {
633                tracing::info!(
634                    history_len = history.len(),
635                    threshold = self.config.compaction_threshold,
636                    "Triggering auto-compaction"
637                );
638
639                match compaction::compact_history(
640                    &mut history,
641                    callable,
642                    self.config.compaction_keep_recent,
643                )
644                .await
645                {
646                    Ok(true) => {
647                        tracing::info!(new_len = history.len(), "History compacted successfully");
648                    }
649                    Ok(false) => {
650                        tracing::debug!("Compaction skipped (not enough messages)");
651                    }
652                    Err(e) => {
653                        // Compaction failure is non-fatal — log and continue
654                        tracing::warn!(error = %e, "History compaction failed, continuing");
655                    }
656                }
657            }
658
659            // ── Checkpoint at intervals ──
660            if let Some(interval) = self.config.checkpoint_interval {
661                if (iteration + 1) % interval == 0 {
662                    let state = enact_core::graph::NodeState::from_string(&last_output);
663                    if let Err(e) = self
664                        .runner
665                        .save_checkpoint(state, Some(callable.name()), Some(callable.name()))
666                        .await
667                    {
668                        tracing::warn!(error = %e, "Failed to save checkpoint, continuing");
669                    } else {
670                        tracing::debug!(iteration, "Checkpoint saved");
671                    }
672                }
673            }
674        }
675
676        // Max iterations reached
677        tracing::warn!(max = self.config.max_iterations, "Max iterations reached");
678
679        // Hooks: SessionEnd / Stop
680        if let Some(ref reg) = self.hook_registry {
681            let mut ctx = serde_json::json!({
682                "event": "SessionEnd",
683                "execution_id": self.runner.execution_id().as_str(),
684                "outcome": "max_iterations",
685                "output": last_output,
686            });
687            self.execute_hooks(reg, HookEvent::SessionEnd, None, &mut ctx, callable)
688                .await;
689            self.execute_hooks(reg, HookEvent::Stop, None, &mut ctx, callable)
690                .await;
691        }
692
693        if self.config.emit_events {
694            let duration = start_time.elapsed().as_millis() as u64;
695            self.runner.emitter().emit(StreamEvent::execution_end(
696                self.runner.execution_id(),
697                Some(last_output.clone()),
698                duration,
699            ));
700        }
701
702        Ok(LoopOutcome::MaxIterationsReached {
703            last_output,
704            iterations: self.config.max_iterations,
705        })
706    }
707
708    /// Build the input string for a given iteration from history.
709    ///
710    /// Concatenates all history messages into a single prompt string.
711    fn build_iteration_input(&self, history: &[HistoryMessage]) -> String {
712        history
713            .iter()
714            .map(|m| format!("{}: {}", m.role, m.content))
715            .collect::<Vec<_>>()
716            .join("\n\n")
717    }
718
719    async fn load_mcp_tools_at_session_start(&mut self) {
720        match enact_mcp::discover_mcp_tools().await {
721            Ok(mcp_tools) if !mcp_tools.is_empty() => {
722                let existing: std::collections::HashSet<String> =
723                    self.tools.iter().map(|t| t.name().to_string()).collect();
724                let mut added = 0usize;
725                for tool in mcp_tools {
726                    if !existing.contains(tool.name()) {
727                        self.tools.push(tool);
728                        added += 1;
729                    }
730                }
731                if added > 0 {
732                    tracing::debug!("Added {} MCP tool(s) at session start", added);
733                }
734            }
735            Ok(_) => {}
736            Err(e) => tracing::warn!("MCP discovery at session start failed: {}", e),
737        }
738    }
739
740    async fn execute_hooks(
741        &self,
742        reg: &HookRegistry,
743        event: HookEvent,
744        tool_name: Option<&str>,
745        ctx: &mut serde_json::Value,
746        callable: &dyn Callable,
747    ) {
748        for hook in reg.hooks_for_event(event, tool_name) {
749            if hook.async_mode {
750                // async_mode is fire-and-forget for non-blocking lifecycle events.
751                if let HookHandler::Command { script } = &hook.handler {
752                    let script = script.clone();
753                    let context = ctx.clone();
754                    let registry = reg.clone();
755                    tokio::spawn(async move {
756                        let _ = registry.run_command_handler(&script, &context).await;
757                    });
758                    continue;
759                }
760            }
761            let _ = self.execute_single_hook(reg, hook, ctx, callable).await;
762        }
763    }
764
765    async fn execute_pre_tool_hooks(
766        &self,
767        reg: &HookRegistry,
768        tool_name: Option<&str>,
769        ctx: &serde_json::Value,
770        callable: &dyn Callable,
771    ) -> HookDecision {
772        let mut latest_mutation: Option<HookDecision> = None;
773        for hook in reg.hooks_for_event(HookEvent::PreToolUse, tool_name) {
774            if hook.async_mode {
775                // PreToolUse async hooks cannot authoritatively block/mutate.
776                if let HookHandler::Command { script } = &hook.handler {
777                    let script = script.clone();
778                    let context = ctx.clone();
779                    let registry = reg.clone();
780                    tokio::spawn(async move {
781                        let _ = registry.run_command_handler(&script, &context).await;
782                    });
783                }
784                continue;
785            }
786            match self
787                .execute_single_pre_tool_hook(reg, hook, ctx, callable)
788                .await
789            {
790                HookDecision::Allow => {}
791                HookDecision::Block { reason } => return HookDecision::Block { reason },
792                HookDecision::Mutate { arguments, reason } => {
793                    latest_mutation = Some(HookDecision::Mutate { arguments, reason });
794                }
795            }
796        }
797        latest_mutation.unwrap_or(HookDecision::Allow)
798    }
799
800    async fn execute_single_hook(
801        &self,
802        reg: &HookRegistry,
803        hook: &enact_config::HookConfig,
804        ctx: &serde_json::Value,
805        callable: &dyn Callable,
806    ) -> bool {
807        match &hook.handler {
808            HookHandler::Command { script } => reg
809                .run_command_handler(script, ctx)
810                .await
811                .map(|r| r.success)
812                .unwrap_or(true),
813            HookHandler::Prompt { template } => {
814                let rendered = render_template(template, ctx);
815                callable.run(&rendered).await.is_ok()
816            }
817            HookHandler::Agent { agent_name } => {
818                let prompt = format!("Hook agent '{}' validation context: {}", agent_name, ctx);
819                callable.run(&prompt).await.is_ok()
820            }
821        }
822    }
823
824    async fn execute_single_pre_tool_hook(
825        &self,
826        reg: &HookRegistry,
827        hook: &enact_config::HookConfig,
828        ctx: &serde_json::Value,
829        callable: &dyn Callable,
830    ) -> HookDecision {
831        match &hook.handler {
832            HookHandler::Command { script } => {
833                let Ok(result) = reg.run_command_handler(script, ctx).await else {
834                    return HookDecision::Allow;
835                };
836                parse_command_hook_decision(&result.stdout, result.success)
837            }
838            HookHandler::Prompt { template } => {
839                let rendered = render_template(template, ctx);
840                let Ok(output) = callable.run(&rendered).await else {
841                    return HookDecision::Allow;
842                };
843                parse_model_hook_decision(&output)
844            }
845            HookHandler::Agent { agent_name } => {
846                let prompt = format!(
847                    "Return ONLY JSON hook decision for this context. Agent '{}': {}",
848                    agent_name, ctx
849                );
850                let Ok(output) = callable.run(&prompt).await else {
851                    return HookDecision::Allow;
852                };
853                parse_model_hook_decision(&output)
854            }
855        }
856    }
857}
858
859/// Build context string from skills that match the user prompt (progressive Level 2).
860/// Returns concatenated skill bodies to prepend to the first user message.
861fn build_skill_context(project_dir: Option<&Path>, prompt: &str) -> String {
862    let metas = load_skill_metas(project_dir);
863    if metas.is_empty() {
864        return String::new();
865    }
866    let matched = find_matching_skills(prompt, &metas);
867    let mut parts = Vec::new();
868    for meta in matched {
869        if let Ok(body) = load_skill_body(&meta) {
870            let mut section = format!("## Skill: {}\n\n{}", meta.name, body);
871            let resources =
872                load_skill_resources(&meta, &body, project_dir, SkillResourceLimits::default());
873            if !resources.is_empty() {
874                section.push_str("\n\n### Skill Resources\n");
875                for resource in resources {
876                    section.push_str(&format!(
877                        "\n#### {}\n\n```\n{}\n```\n",
878                        resource.path.display(),
879                        resource.content
880                    ));
881                }
882            }
883            parts.push(section);
884        }
885    }
886    if parts.is_empty() {
887        return String::new();
888    }
889    format!(
890        "<!-- Matched skills for this request -->\n\n{}\n\n---\n\n",
891        parts.join("\n\n---\n\n")
892    )
893}
894
895fn build_manual_skill_input(project_dir: Option<&Path>, raw_input: &str) -> Option<String> {
896    let (name, args) = parse_slash_invocation(raw_input)?;
897    let meta = find_skill_by_name(project_dir, &name)?;
898    let body = load_skill_body(&meta).ok()?;
899    let mut section = format!("## Skill: {}\n\n{}", meta.name, body);
900    let resources = load_skill_resources(&meta, &body, project_dir, SkillResourceLimits::default());
901    if !resources.is_empty() {
902        section.push_str("\n\n### Skill Resources\n");
903        for resource in resources {
904            section.push_str(&format!(
905                "\n#### {}\n\n```\n{}\n```\n",
906                resource.path.display(),
907                resource.content
908            ));
909        }
910    }
911    let user_request = if args.trim().is_empty() {
912        "Follow this skill for the current task.".to_string()
913    } else {
914        args
915    };
916    Some(format!(
917        "<!-- Manually invoked skill -->\n\n{}\n\n---\n\n{}",
918        section, user_request
919    ))
920}
921
922fn parse_command_hook_decision(stdout: &str, success: bool) -> HookDecision {
923    if let Some(parsed) = parse_hook_decision(stdout) {
924        return parsed;
925    }
926    if success {
927        HookDecision::Allow
928    } else {
929        HookDecision::Block { reason: None }
930    }
931}
932
933fn parse_model_hook_decision(output: &str) -> HookDecision {
934    parse_hook_decision(output).unwrap_or(HookDecision::Allow)
935}
936
937fn parse_hook_decision(raw: &str) -> Option<HookDecision> {
938    let trimmed = raw.trim();
939    if trimmed.is_empty() {
940        return None;
941    }
942    if let Ok(d) = serde_json::from_str::<HookDecision>(trimmed) {
943        return Some(d);
944    }
945    extract_json_object(trimmed).and_then(|s| serde_json::from_str::<HookDecision>(&s).ok())
946}
947
/// Extract a JSON object embedded in free-form text (e.g. model output).
///
/// The text is scanned for brace-balanced `{ … }` spans, ignoring braces that
/// appear inside JSON string literals (with `\` escapes honored). The result is
/// chosen in this order:
/// 1. the first balanced span that looks like a JSON object, i.e. is empty or
///    whose first inner token is a string key (`"`);
/// 2. otherwise, the first balanced span of any kind;
/// 3. otherwise (no balanced span at all), the legacy slice from the first `{`
///    to the last `}`, preserving the previous behaviour for malformed input.
///
/// Returns `None` when there is no `{` followed somewhere by a `}`.
///
/// Why balanced scanning: slicing first-`{`..last-`}` turns
/// `{"decision":"block"} (avoid {rm -rf})` into unparseable text. Callers then
/// fall back to `Allow`, silently dropping the block.
///
/// Each `{` start is scanned independently, so the cost is O(n·k) for k opening
/// braces. Inputs here are presumably short hook/model outputs.
fn extract_json_object(s: &str) -> Option<String> {
    // Byte index of the `}` closing the `{` at `open`, or None if unbalanced.
    // Scanning bytes is UTF-8 safe: every structural character is ASCII and
    // never occurs inside a multi-byte sequence.
    fn balanced_end(s: &str, open: usize) -> Option<usize> {
        let mut depth = 0usize;
        let mut in_string = false;
        let mut escaped = false;
        for (i, b) in s.bytes().enumerate().skip(open) {
            if in_string {
                if escaped {
                    escaped = false;
                } else if b == b'\\' {
                    escaped = true;
                } else if b == b'"' {
                    in_string = false;
                }
                continue;
            }
            match b {
                b'"' => in_string = true,
                b'{' => depth += 1,
                b'}' => {
                    // depth >= 1 here: the scan starts on a `{` and stops at 0.
                    depth -= 1;
                    if depth == 0 {
                        return Some(i);
                    }
                }
                _ => {}
            }
        }
        None
    }

    let mut first_balanced: Option<&str> = None;
    for (open, _) in s.match_indices('{') {
        let Some(close) = balanced_end(s, open) else {
            continue;
        };
        let candidate = &s[open..=close];
        let inner = candidate[1..candidate.len() - 1].trim_start();
        if inner.is_empty() || inner.starts_with('"') {
            return Some(candidate.to_string());
        }
        first_balanced.get_or_insert(candidate);
    }
    if let Some(candidate) = first_balanced {
        return Some(candidate.to_string());
    }

    // Legacy fallback for input with no balanced object.
    let start = s.find('{')?;
    let end = s.rfind('}')?;
    if end <= start {
        return None;
    }
    Some(s[start..=end].to_string())
}
956
957fn render_template(template: &str, ctx: &serde_json::Value) -> String {
958    let mut out = template.to_string();
959    if let Some(obj) = ctx.as_object() {
960        for (k, v) in obj {
961            let replacement = if v.is_string() {
962                v.as_str().unwrap_or_default().to_string()
963            } else {
964                v.to_string()
965            };
966            out = out.replace(&format!("{{{{{}}}}}", k), &replacement);
967        }
968    }
969    out
970}
971
/// Convenience alias: `AgentRunner` specialised to the in-memory checkpoint
/// store (`enact_core::graph::InMemoryCheckpointStore`). Construct one with
/// `DefaultAgentRunner::default_new()` or `DefaultAgentRunner::with_config(..)`.
pub type DefaultAgentRunner = AgentRunner<enact_core::graph::InMemoryCheckpointStore>;
974
975impl DefaultAgentRunner {
976    /// Create a new `AgentRunner` with default in-memory store and default config.
977    pub fn default_new() -> Self {
978        Self::new(
979            enact_core::runner::DefaultRunner::default_new(),
980            RunnerConfig::default(),
981        )
982    }
983
984    /// Create with a specific config.
985    pub fn with_config(config: RunnerConfig) -> Self {
986        Self::new(enact_core::runner::DefaultRunner::default_new(), config)
987    }
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use enact_core::callable::Callable;
    use enact_core::tool::Tool;
    use serde_json::json;
    use std::fs;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use tokio::sync::Mutex;

    /// Env var read by `HookRegistry::load_global_and_agent` in the mutation test.
    const HOOKS_ENV: &str = "ENACT_HOOKS_CONFIG_PATH";

    /// Removes an environment variable when dropped.
    ///
    /// Environment variables are process-wide. Without this guard, a panic
    /// between `set_var` and the manual cleanup would leak the variable into
    /// every later test in the same binary.
    struct EnvVarGuard(&'static str);

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn parse_command_hook_decision_falls_back_to_exit_status() {
        let allow = parse_command_hook_decision("", true);
        assert!(matches!(allow, HookDecision::Allow));

        let block = parse_command_hook_decision("", false);
        assert!(matches!(block, HookDecision::Block { .. }));
    }

    #[test]
    fn parse_model_hook_decision_handles_json_block() {
        let decision =
            parse_model_hook_decision(r#"{"decision":"block","reason":"unsafe operation"}"#);
        assert!(matches!(
            decision,
            HookDecision::Block {
                reason: Some(ref r)
            } if r == "unsafe operation"
        ));
    }

    #[test]
    fn build_manual_skill_input_loads_skill_body() {
        let project = tempfile::tempdir().unwrap();
        let skill_dir = project.path().join(".enact").join("skills").join("review");
        fs::create_dir_all(&skill_dir).unwrap();
        fs::write(
            skill_dir.join("SKILL.md"),
            "---\nname: review\ndescription: Review code\nversion: 0.1.0\n---\nAlways list findings.\n",
        )
        .unwrap();

        let injected =
            build_manual_skill_input(Some(project.path()), "/review inspect this").unwrap();
        assert!(injected.contains("Skill: review"));
        assert!(injected.contains("inspect this"));
    }

    /// Callable that returns `first` on its first call and `rest` afterwards.
    struct MockCallable {
        calls: AtomicUsize,
        first: String,
        rest: String,
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            let idx = self.calls.fetch_add(1, Ordering::SeqCst);
            if idx == 0 {
                Ok(self.first.clone())
            } else {
                Ok(self.rest.clone())
            }
        }
    }

    /// Tool that records the arguments it was invoked with.
    struct CaptureTool {
        seen: Arc<Mutex<Option<serde_json::Value>>>,
    }

    #[async_trait]
    impl Tool for CaptureTool {
        fn name(&self) -> &str {
            "capture_tool"
        }

        fn description(&self) -> &str {
            "capture"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({"type":"object"})
        }

        async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
            *self.seen.lock().await = Some(args);
            Ok(json!({"ok": true}))
        }
    }

    #[tokio::test]
    async fn pre_tool_mutation_updates_tool_arguments() {
        let temp = tempfile::tempdir().unwrap();
        let hooks_path = temp.path().join("hooks.yaml");
        fs::write(
            &hooks_path,
            r#"hooks:
  - event: PreToolUse
    matcher: "capture_tool"
    handler:
      type: command
      script: "echo '{\"decision\":\"mutate\",\"arguments\":{\"value\":\"mutated\"}}'"
"#,
        )
        .unwrap();
        std::env::set_var(HOOKS_ENV, hooks_path.to_string_lossy().as_ref());
        // Cleans up the env var even if anything below panics.
        let env_guard = EnvVarGuard(HOOKS_ENV);

        let seen = Arc::new(Mutex::new(None));
        let tool = CaptureTool {
            seen: Arc::clone(&seen),
        };
        let callable = MockCallable {
            calls: AtomicUsize::new(0),
            first: r#"{"tool_call":{"name":"capture_tool","arguments":{"value":"original"}}}"#
                .to_string(),
            rest: "done".to_string(),
        };

        let mut runner = DefaultAgentRunner::default_new()
            .add_tool(tool)
            .with_hook_registry(Arc::new(HookRegistry::load_global_and_agent(None, None)));
        let result = runner.run(&callable, "test", None).await.unwrap();
        // Remove the variable as soon as the run is done, as before.
        drop(env_guard);

        assert!(matches!(result, LoopOutcome::Completed(_)));
        let captured = seen.lock().await.clone().unwrap();
        assert_eq!(captured, json!({"value":"mutated"}));
    }

    #[tokio::test]
    async fn prompt_and_agent_hooks_dispatch_callable() {
        let reg = HookRegistry::new();
        let runner = DefaultAgentRunner::default_new();
        let callable = MockCallable {
            calls: AtomicUsize::new(0),
            first: r#"{"decision":"allow"}"#.to_string(),
            rest: r#"{"decision":"allow"}"#.to_string(),
        };
        let ctx = json!({"tool_name":"capture_tool","arguments":{"v":1}});

        let prompt = enact_config::HookConfig {
            event: HookEvent::PreToolUse,
            matcher: None,
            handler: HookHandler::Prompt {
                template: "Decide for {{tool_name}}".to_string(),
            },
            async_mode: false,
        };
        let agent = enact_config::HookConfig {
            event: HookEvent::PreToolUse,
            matcher: None,
            handler: HookHandler::Agent {
                agent_name: "reviewer".to_string(),
            },
            async_mode: false,
        };

        let p = runner
            .execute_single_pre_tool_hook(&reg, &prompt, &ctx, &callable)
            .await;
        let a = runner
            .execute_single_pre_tool_hook(&reg, &agent, &ctx, &callable)
            .await;

        assert!(matches!(p, HookDecision::Allow));
        assert!(matches!(a, HookDecision::Allow));
        assert_eq!(callable.calls.load(Ordering::SeqCst), 2);
    }
}