Skip to main content

codex_runtime/runtime/
shell_hook.rs

//! Shell-command hook adapter.
//!
//! Wraps an external process as a [`PreHook`] or [`PostHook`].
//! The command runs under `sh -c`, receives the [`HookContext`] serialized as
//! JSON on stdin, and signals its decision via exit code + stdout JSON.
//! stderr is discarded. For either hook kind, a process that does not finish
//! within the hook's timeout yields `Err(HookIssue { class: Timeout })`.
//!
//! ## Exit-code contract (PreHook)
//!
//! | exit | stdout | result |
//! |------|--------|--------|
//! | `0`  | empty, `{}`, or `{"action":"noop"}` | `HookAction::Noop` |
//! | `0`  | `{"action":"mutate", ...}` (`action` matched ASCII case-insensitively) | `HookAction::Mutate(patch)` |
//! | `0`  | any other `action` value | `HookAction::Noop` |
//! | `0`  | unparseable JSON | `Err(HookIssue { class: Execution })` |
//! | `2`  | `{"message":"..."}` or plain text | `HookAction::Block(reason)` |
//! | any other | — | `Err(HookIssue { class: Execution })` |
//!
//! ## Exit-code contract (PostHook)
//!
//! | exit | result |
//! |------|--------|
//! | `0`  | `Ok(())` (stdout is ignored) |
//! | any other | `Err(HookIssue { class: Execution })` |
22
23use std::collections::HashMap;
24use std::time::Duration;
25
26use serde::Deserialize;
27use serde_json::Value;
28use tokio::io::AsyncWriteExt;
29use tokio::process::Command;
30
31use crate::plugin::{
32    BlockReason, HookAction, HookAttachment, HookContext, HookFuture, HookIssue, HookIssueClass,
33    HookPatch, HookPhase, PostHook, PreHook,
34};
35
/// An external shell command registered as a [`PreHook`] or [`PostHook`].
///
/// Build with [`ShellCommandHook::new`], then optionally refine with
/// [`with_timeout`](ShellCommandHook::with_timeout) and
/// [`with_env`](ShellCommandHook::with_env).
///
/// Allocation: the `command` String at construction; the env map only
/// allocates once [`with_env`](ShellCommandHook::with_env) is called.
#[derive(Debug, Clone)]
pub struct ShellCommandHook {
    /// Hook name reported by `name()` and copied into every issue/block reason.
    name: &'static str,
    /// Shell command string passed to `sh -c`. Example: `"python3 /path/hook.py"`.
    command: String,
    /// Hard wall-clock limit for the subprocess. Default: 5 seconds.
    timeout: Duration,
    /// Extra environment variables injected into the subprocess, layered on top
    /// of the inherited parent environment.
    env: HashMap<String, String>,
}

impl ShellCommandHook {
    /// Subprocess timeout used by [`ShellCommandHook::new`] until overridden
    /// with [`ShellCommandHook::with_timeout`].
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Construct a hook that runs `command` via `sh -c` with
    /// [`Self::DEFAULT_TIMEOUT`] and no extra environment variables.
    ///
    /// Allocation: one String (`command`); the empty env map does not allocate.
    #[must_use]
    pub fn new(name: &'static str, command: impl Into<String>) -> Self {
        Self {
            name,
            command: command.into(),
            timeout: Self::DEFAULT_TIMEOUT,
            env: HashMap::new(),
        }
    }

    /// Override the subprocess timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Inject one extra environment variable. Calling this again with the same
    /// key replaces the earlier value.
    ///
    /// Allocation: two Strings per call.
    #[must_use]
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }
}
75
76// ── Shared subprocess execution ──────────────────────────────────────────────
77
/// Raw output from a finished subprocess.
///
/// Only stdout is captured; stderr is discarded by [`run_process`].
/// Allocation: one String (stdout).
#[derive(Debug)]
struct ShellOutput {
    /// Process exit code, or `-1` when the OS reports none
    /// (on Unix, when the process was terminated by a signal).
    exit_code: i32,
    /// Captured stdout, decoded as lossy UTF-8.
    stdout: String,
}
84
85/// Spawn `sh -c <command>`, feed `stdin_bytes` to stdin, collect stdout.
86/// Allocation: stdin bytes cloned to pipe, stdout accumulated in String.
87/// Side effect: spawns an OS process.
88async fn run_process(
89    command: &str,
90    env: &HashMap<String, String>,
91    stdin_bytes: Vec<u8>,
92) -> Result<ShellOutput, String> {
93    let mut child = Command::new("sh")
94        .arg("-c")
95        .arg(command)
96        .envs(env)
97        .stdin(std::process::Stdio::piped())
98        .stdout(std::process::Stdio::piped())
99        .stderr(std::process::Stdio::null())
100        .spawn()
101        .map_err(|e| format!("spawn failed: {e}"))?;
102
103    if let Some(mut stdin) = child.stdin.take() {
104        // Ignore write errors — the process may have exited before reading all input.
105        let _ = stdin.write_all(&stdin_bytes).await;
106    }
107
108    let output = child
109        .wait_with_output()
110        .await
111        .map_err(|e| format!("wait failed: {e}"))?;
112
113    Ok(ShellOutput {
114        exit_code: output.status.code().unwrap_or(-1),
115        stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
116    })
117}
118
119// ── PreHook output parsing (pure functions) ───────────────────────────────────
120
121/// Wire shape of a `mutate` response from a shell pre-hook.
122/// Fields mirror [`HookPatch`] with camelCase naming for shell-script ergonomics.
123#[derive(Deserialize)]
124#[serde(rename_all = "camelCase")]
125struct ShellPreOutput {
126    /// `"noop"` or `"mutate"`. Anything else is treated as `"noop"`.
127    #[serde(default)]
128    action: String,
129    prompt_override: Option<String>,
130    model_override: Option<String>,
131    #[serde(default)]
132    add_attachments: Vec<HookAttachment>,
133    #[serde(default)]
134    metadata_delta: Value,
135}
136
137/// Wire shape of a block response (exit 2).
138#[derive(Deserialize, Default)]
139struct ShellBlockOutput {
140    #[serde(default)]
141    message: String,
142}
143
144/// Parse stdout from a shell pre-hook that exited with code 0.
145/// Pure: no I/O. Allocation: depends on patch contents.
146fn parse_pre_output(
147    hook_name: &str,
148    phase: HookPhase,
149    stdout: &str,
150) -> Result<HookAction, HookIssue> {
151    let trimmed = stdout.trim();
152    // Empty stdout → Noop. Avoids requiring shell scripts to emit `{}`.
153    if trimmed.is_empty() {
154        return Ok(HookAction::Noop);
155    }
156    let parsed: ShellPreOutput = match serde_json::from_str(trimmed) {
157        Ok(v) => v,
158        Err(e) => {
159            return Err(HookIssue {
160                hook_name: hook_name.to_owned(),
161                phase,
162                class: HookIssueClass::Execution,
163                message: format!("stdout parse error: {e}"),
164            })
165        }
166    };
167    if parsed.action.eq_ignore_ascii_case("mutate") {
168        Ok(HookAction::Mutate(HookPatch {
169            prompt_override: parsed.prompt_override,
170            model_override: parsed.model_override,
171            add_attachments: parsed.add_attachments,
172            metadata_delta: parsed.metadata_delta,
173        }))
174    } else {
175        Ok(HookAction::Noop)
176    }
177}
178
179/// Parse stdout from a shell pre-hook that exited with code 2.
180/// Pure: no I/O. Allocation: one String for message.
181fn parse_block_output(hook_name: &str, phase: HookPhase, stdout: &str) -> BlockReason {
182    let trimmed = stdout.trim();
183    let message = if trimmed.is_empty() {
184        "blocked by hook (no message)".to_owned()
185    } else {
186        // Try JSON `{"message":"..."}`, fall back to raw stdout as message.
187        serde_json::from_str::<ShellBlockOutput>(trimmed)
188            .map(|o| {
189                if o.message.is_empty() {
190                    trimmed.to_owned()
191                } else {
192                    o.message
193                }
194            })
195            .unwrap_or_else(|_| trimmed.to_owned())
196    };
197    BlockReason {
198        hook_name: hook_name.to_owned(),
199        phase,
200        message,
201    }
202}
203
204/// Build an `Execution` issue for subprocess failures.
205/// Pure. Allocation: one String.
206fn execution_issue(hook_name: &str, phase: HookPhase, message: impl Into<String>) -> HookIssue {
207    HookIssue {
208        hook_name: hook_name.to_owned(),
209        phase,
210        class: HookIssueClass::Execution,
211        message: message.into(),
212    }
213}
214
215/// Build a `Timeout` issue.
216/// Pure. Allocation: one String.
217fn timeout_issue(hook_name: &str, phase: HookPhase, timeout: Duration) -> HookIssue {
218    HookIssue {
219        hook_name: hook_name.to_owned(),
220        phase,
221        class: HookIssueClass::Timeout,
222        message: format!("shell hook timed out after {timeout:?}"),
223    }
224}
225
226// ── PreHook impl ─────────────────────────────────────────────────────────────
227
228impl PreHook for ShellCommandHook {
229    fn name(&self) -> &'static str {
230        self.name
231    }
232
233    fn call<'a>(&'a self, ctx: &'a HookContext) -> HookFuture<'a, Result<HookAction, HookIssue>> {
234        Box::pin(async move {
235            let stdin_bytes = match serde_json::to_vec(ctx) {
236                Ok(b) => b,
237                Err(e) => {
238                    return Err(HookIssue {
239                        hook_name: self.name.to_owned(),
240                        phase: ctx.phase,
241                        class: HookIssueClass::Internal,
242                        message: format!("context serialize failed: {e}"),
243                    })
244                }
245            };
246
247            let output = match tokio::time::timeout(
248                self.timeout,
249                run_process(&self.command, &self.env, stdin_bytes),
250            )
251            .await
252            {
253                Err(_elapsed) => return Err(timeout_issue(self.name, ctx.phase, self.timeout)),
254                Ok(Err(e)) => return Err(execution_issue(self.name, ctx.phase, e)),
255                Ok(Ok(o)) => o,
256            };
257
258            match output.exit_code {
259                0 => parse_pre_output(self.name, ctx.phase, &output.stdout),
260                2 => Ok(HookAction::Block(parse_block_output(
261                    self.name,
262                    ctx.phase,
263                    &output.stdout,
264                ))),
265                code => Err(execution_issue(
266                    self.name,
267                    ctx.phase,
268                    format!("exited with code {code}"),
269                )),
270            }
271        })
272    }
273}
274
275// ── PostHook impl ─────────────────────────────────────────────────────────────
276
277impl PostHook for ShellCommandHook {
278    fn name(&self) -> &'static str {
279        self.name
280    }
281
282    fn call<'a>(&'a self, ctx: &'a HookContext) -> HookFuture<'a, Result<(), HookIssue>> {
283        Box::pin(async move {
284            let stdin_bytes = match serde_json::to_vec(ctx) {
285                Ok(b) => b,
286                Err(e) => {
287                    return Err(HookIssue {
288                        hook_name: self.name.to_owned(),
289                        phase: ctx.phase,
290                        class: HookIssueClass::Internal,
291                        message: format!("context serialize failed: {e}"),
292                    })
293                }
294            };
295
296            let output = match tokio::time::timeout(
297                self.timeout,
298                run_process(&self.command, &self.env, stdin_bytes),
299            )
300            .await
301            {
302                Err(_elapsed) => return Err(timeout_issue(self.name, ctx.phase, self.timeout)),
303                Ok(Err(e)) => return Err(execution_issue(self.name, ctx.phase, e)),
304                Ok(Ok(o)) => o,
305            };
306
307            if output.exit_code == 0 {
308                Ok(())
309            } else {
310                Err(execution_issue(
311                    self.name,
312                    ctx.phase,
313                    format!("exited with code {}", output.exit_code),
314                ))
315            }
316        })
317    }
318}
319
320// ── Unit tests ────────────────────────────────────────────────────────────────
321
#[cfg(test)]
mod tests {
    use super::*;

    fn phase() -> HookPhase {
        HookPhase::PreRun
    }

    // ── parse_pre_output ────────────────────────────────────────────────────

    #[test]
    fn empty_stdout_is_noop() {
        assert_eq!(parse_pre_output("h", phase(), ""), Ok(HookAction::Noop));
        assert_eq!(parse_pre_output("h", phase(), "  "), Ok(HookAction::Noop));
    }

    #[test]
    fn empty_object_is_noop() {
        assert_eq!(parse_pre_output("h", phase(), "{}"), Ok(HookAction::Noop));
    }

    #[test]
    fn action_noop_explicit() {
        assert_eq!(
            parse_pre_output("h", phase(), r#"{"action":"noop"}"#),
            Ok(HookAction::Noop)
        );
    }

    #[test]
    fn action_mutate_model_override() {
        let out = parse_pre_output(
            "h",
            phase(),
            r#"{"action":"mutate","modelOverride":"claude-opus-4-6"}"#,
        );
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("claude-opus-4-6"));
                assert!(patch.prompt_override.is_none());
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn action_mutate_prompt_override() {
        let out = parse_pre_output(
            "h",
            phase(),
            r#"{"action":"mutate","promptOverride":"new prompt"}"#,
        );
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.prompt_override.as_deref(), Some("new prompt"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn action_mutate_is_case_insensitive() {
        let out = parse_pre_output("h", phase(), r#"{"action":"MUTATE","modelOverride":"m"}"#);
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("m"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn mutate_without_fields_yields_empty_patch() {
        match parse_pre_output("h", phase(), r#"{"action":"mutate"}"#) {
            Ok(HookAction::Mutate(patch)) => {
                assert!(patch.prompt_override.is_none());
                assert!(patch.model_override.is_none());
                assert!(patch.add_attachments.is_empty());
                assert!(patch.metadata_delta.is_null());
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn unknown_action_is_noop() {
        assert_eq!(
            parse_pre_output("h", phase(), r#"{"action":"unknown"}"#),
            Ok(HookAction::Noop)
        );
    }

    #[test]
    fn invalid_json_is_execution_issue() {
        let result = parse_pre_output("h", phase(), "not-json");
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[test]
    fn parse_issue_names_hook_and_phase() {
        match parse_pre_output("my-hook", phase(), "{") {
            Err(issue) => {
                assert_eq!(issue.hook_name, "my-hook");
                assert!(matches!(issue.phase, HookPhase::PreRun));
                assert!(issue.message.starts_with("stdout parse error"));
            }
            other => panic!("expected Err, got {other:?}"),
        }
    }

    // ── parse_block_output ──────────────────────────────────────────────────

    #[test]
    fn block_with_json_message() {
        let r = parse_block_output("h", phase(), r#"{"message":"rm -rf blocked"}"#);
        assert_eq!(r.message, "rm -rf blocked");
        assert_eq!(r.hook_name, "h");
    }

    #[test]
    fn block_with_plain_text_message() {
        let r = parse_block_output("h", phase(), "plain text reason");
        assert_eq!(r.message, "plain text reason");
    }

    #[test]
    fn block_with_empty_stdout_gives_fallback() {
        let r = parse_block_output("h", phase(), "");
        assert_eq!(r.message, "blocked by hook (no message)");
    }

    #[test]
    fn block_with_json_empty_message_falls_back_to_raw() {
        // `{"message":""}` → raw stdout used as message
        let r = parse_block_output("h", phase(), r#"{"message":""}"#);
        assert_eq!(r.message, r#"{"message":""}"#);
    }

    #[test]
    fn block_with_non_object_json_uses_raw_text() {
        let r = parse_block_output("h", phase(), "42");
        assert_eq!(r.message, "42");
    }

    #[test]
    fn block_message_is_trimmed() {
        let r = parse_block_output("h", phase(), "  reason \n");
        assert_eq!(r.message, "reason");
    }

    // ── issue builders ──────────────────────────────────────────────────────

    #[test]
    fn timeout_issue_reports_limit() {
        let issue = timeout_issue("h", phase(), Duration::from_millis(50));
        assert!(matches!(issue.class, HookIssueClass::Timeout));
        assert_eq!(issue.hook_name, "h");
        assert_eq!(issue.message, "shell hook timed out after 50ms");
    }

    // ── ShellCommandHook integration (requires sh) ──────────────────────────

    fn ctx() -> HookContext {
        use serde_json::json;
        HookContext {
            phase: HookPhase::PreRun,
            thread_id: None,
            turn_id: None,
            cwd: Some("/tmp".to_owned()),
            model: None,
            main_status: None,
            correlation_id: "hk-1".to_owned(),
            ts_ms: 0,
            metadata: json!({}),
            tool_name: None,
            tool_input: None,
        }
    }

    #[tokio::test]
    async fn pre_hook_exit0_empty_stdout_is_noop() {
        let hook = ShellCommandHook::new("test-noop", "exit 0");
        let result = PreHook::call(&hook, &ctx()).await;
        assert_eq!(result, Ok(HookAction::Noop));
    }

    #[tokio::test]
    async fn pre_hook_exit2_blocks() {
        let hook = ShellCommandHook::new("test-block", r#"echo '{"message":"denied"}' ; exit 2"#);
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Block(r)) => assert_eq!(r.message, "denied"),
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_exit2_plain_text_blocks() {
        let hook = ShellCommandHook::new("test-block-text", "echo 'no way'; exit 2");
        match PreHook::call(&hook, &ctx()).await {
            Ok(HookAction::Block(r)) => {
                assert_eq!(r.message, "no way");
                assert_eq!(r.hook_name, "test-block-text");
            }
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_exit1_is_execution_error() {
        let hook = ShellCommandHook::new("test-err", "exit 1");
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn pre_hook_exit0_invalid_stdout_is_execution_error() {
        let hook = ShellCommandHook::new("test-bad-stdout", "echo 'not json'");
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn pre_hook_exit0_mutate_model() {
        let hook = ShellCommandHook::new(
            "test-mutate",
            r#"echo '{"action":"mutate","modelOverride":"claude-haiku-4-5-20251001"}'"#,
        );
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(
                    patch.model_override.as_deref(),
                    Some("claude-haiku-4-5-20251001")
                );
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_timeout_returns_timeout_issue() {
        let hook = ShellCommandHook::new("test-timeout", "sleep 60")
            .with_timeout(Duration::from_millis(50));
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Timeout,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn post_hook_exit0_is_ok() {
        let hook = ShellCommandHook::new("test-post", "exit 0");
        let result = PostHook::call(&hook, &ctx()).await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn post_hook_ignores_stdout() {
        let hook = ShellCommandHook::new("test-post-stdout", "echo 'not json'; exit 0");
        let result = PostHook::call(&hook, &ctx()).await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn post_hook_nonzero_is_execution_error() {
        let hook = ShellCommandHook::new("test-post-err", "exit 1");
        let result = PostHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn post_hook_timeout_returns_timeout_issue() {
        let hook = ShellCommandHook::new("test-post-timeout", "sleep 60")
            .with_timeout(Duration::from_millis(50));
        let result = PostHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Timeout,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn stdin_receives_hook_context_json() {
        // The hook reads stdin and echoes the phase field back in model_override.
        // Uses `jq` if available; skip if not.
        if std::process::Command::new("jq")
            .arg("--version")
            .output()
            .is_err()
        {
            return; // jq not installed — skip
        }
        let hook = ShellCommandHook::new(
            "test-stdin",
            r#"phase=$(cat | jq -r '.phase'); echo "{\"action\":\"mutate\",\"modelOverride\":\"$phase\"}""#,
        );
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("PreRun"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn with_env_passes_env_to_process() {
        let hook = ShellCommandHook::new(
            "test-env",
            r#"echo "{\"action\":\"mutate\",\"modelOverride\":\"$MY_VAR\"}""#,
        )
        .with_env("MY_VAR", "injected-value");
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("injected-value"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }
}