1use std::collections::HashMap;
24use std::time::Duration;
25
26use serde::Deserialize;
27use serde_json::Value;
28use tokio::io::AsyncWriteExt;
29use tokio::process::Command;
30
31use crate::plugin::{
32 BlockReason, HookAction, HookAttachment, HookContext, HookFuture, HookIssue, HookIssueClass,
33 HookPatch, HookPhase, PostHook, PreHook,
34};
35
/// A hook that runs a shell command through `sh -c`, usable as both a [`PreHook`] and a
/// [`PostHook`].
///
/// The serialized [`HookContext`] is written to the command's stdin as JSON.
///
/// - As a pre-hook: exit status `0` means "parse stdout as a directive", `2` means "block",
///   and any other status is an execution error.
/// - As a post-hook: any non-zero status is an execution error.
///
/// Stderr is discarded.
pub struct ShellCommandHook {
    /// Identifier returned by `PreHook::name` / `PostHook::name` and stamped on issues.
    name: &'static str,
    /// Command line handed verbatim to `sh -c`.
    command: String,
    /// Upper bound on one invocation (spawn, stdin write and wait combined).
    timeout: Duration,
    /// Extra environment variables added on top of the inherited environment.
    env: HashMap<String, String>,
}

impl ShellCommandHook {
    /// Timeout used by [`ShellCommandHook::new`] until overridden with
    /// [`with_timeout`](Self::with_timeout).
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates a hook named `name` that runs `command`.
    ///
    /// The hook starts with [`DEFAULT_TIMEOUT`](Self::DEFAULT_TIMEOUT) and no extra
    /// environment variables.
    #[must_use]
    pub fn new(name: &'static str, command: impl Into<String>) -> Self {
        Self {
            name,
            command: command.into(),
            timeout: Self::DEFAULT_TIMEOUT,
            env: HashMap::new(),
        }
    }

    /// Replaces the per-invocation timeout.
    // `#[must_use]`: this consumes `self`, so ignoring the result drops the whole hook.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Adds an environment variable for the spawned command.
    ///
    /// A later call with the same `key` overwrites the earlier value.
    // `#[must_use]`: this consumes `self`, so ignoring the result drops the whole hook.
    #[must_use]
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }
}
75
/// What a finished hook process produced, as collected by `run_process`.
struct ShellOutput {
    /// Everything the process wrote to stdout, decoded lossily as UTF-8.
    stdout: String,
    /// Numeric exit status of the process; `0` means success.
    exit_code: i32,
}
84
85async fn run_process(
89 command: &str,
90 env: &HashMap<String, String>,
91 stdin_bytes: Vec<u8>,
92) -> Result<ShellOutput, String> {
93 let mut child = Command::new("sh")
94 .arg("-c")
95 .arg(command)
96 .envs(env)
97 .stdin(std::process::Stdio::piped())
98 .stdout(std::process::Stdio::piped())
99 .stderr(std::process::Stdio::null())
100 .spawn()
101 .map_err(|e| format!("spawn failed: {e}"))?;
102
103 if let Some(mut stdin) = child.stdin.take() {
104 let _ = stdin.write_all(&stdin_bytes).await;
106 }
107
108 let output = child
109 .wait_with_output()
110 .await
111 .map_err(|e| format!("wait failed: {e}"))?;
112
113 Ok(ShellOutput {
114 exit_code: output.status.code().unwrap_or(-1),
115 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
116 })
117}
118
/// JSON a pre-hook may print on stdout when it exits with status 0.
///
/// Keys are camelCase (`action`, `promptOverride`, `modelOverride`, `addAttachments`,
/// `metadataDelta`), and every key may be omitted (`{}` is accepted).
// Field order is significant: serde's derived impl also accepts a JSON array matched
// positionally, so do not reorder fields or move the per-field `default` attributes.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ShellPreOutput {
    /// Requested action. Only `"mutate"` (ASCII case-insensitive) has an effect; any
    /// other value, including a missing key (which defaults to `""`), is treated as a no-op.
    #[serde(default)]
    action: String,
    /// Replacement prompt, copied into `HookPatch::prompt_override`.
    prompt_override: Option<String>,
    /// Replacement model, copied into `HookPatch::model_override`.
    model_override: Option<String>,
    /// Attachments to add; empty when the key is absent.
    #[serde(default)]
    add_attachments: Vec<HookAttachment>,
    /// Metadata delta forwarded verbatim; JSON `null` when the key is absent.
    #[serde(default)]
    metadata_delta: Value,
}
136
/// Optional JSON a pre-hook may print on stdout when it exits with status 2 (block).
///
/// Only `message` is read; unknown keys are ignored. A missing `message` becomes `""`,
/// which `parse_block_output` treats as "fall back to the raw stdout".
#[derive(Deserialize, Default)]
struct ShellBlockOutput {
    /// Human-readable reason for blocking.
    #[serde(default)]
    message: String,
}
143
144fn parse_pre_output(
147 hook_name: &str,
148 phase: HookPhase,
149 stdout: &str,
150) -> Result<HookAction, HookIssue> {
151 let trimmed = stdout.trim();
152 if trimmed.is_empty() {
154 return Ok(HookAction::Noop);
155 }
156 let parsed: ShellPreOutput = match serde_json::from_str(trimmed) {
157 Ok(v) => v,
158 Err(e) => {
159 return Err(HookIssue {
160 hook_name: hook_name.to_owned(),
161 phase,
162 class: HookIssueClass::Execution,
163 message: format!("stdout parse error: {e}"),
164 })
165 }
166 };
167 if parsed.action.eq_ignore_ascii_case("mutate") {
168 Ok(HookAction::Mutate(HookPatch {
169 prompt_override: parsed.prompt_override,
170 model_override: parsed.model_override,
171 add_attachments: parsed.add_attachments,
172 metadata_delta: parsed.metadata_delta,
173 }))
174 } else {
175 Ok(HookAction::Noop)
176 }
177}
178
179fn parse_block_output(hook_name: &str, phase: HookPhase, stdout: &str) -> BlockReason {
182 let trimmed = stdout.trim();
183 let message = if trimmed.is_empty() {
184 "blocked by hook (no message)".to_owned()
185 } else {
186 serde_json::from_str::<ShellBlockOutput>(trimmed)
188 .map(|o| {
189 if o.message.is_empty() {
190 trimmed.to_owned()
191 } else {
192 o.message
193 }
194 })
195 .unwrap_or_else(|_| trimmed.to_owned())
196 };
197 BlockReason {
198 hook_name: hook_name.to_owned(),
199 phase,
200 message,
201 }
202}
203
204fn execution_issue(hook_name: &str, phase: HookPhase, message: impl Into<String>) -> HookIssue {
207 HookIssue {
208 hook_name: hook_name.to_owned(),
209 phase,
210 class: HookIssueClass::Execution,
211 message: message.into(),
212 }
213}
214
215fn timeout_issue(hook_name: &str, phase: HookPhase, timeout: Duration) -> HookIssue {
218 HookIssue {
219 hook_name: hook_name.to_owned(),
220 phase,
221 class: HookIssueClass::Timeout,
222 message: format!("shell hook timed out after {timeout:?}"),
223 }
224}
225
226impl PreHook for ShellCommandHook {
229 fn name(&self) -> &'static str {
230 self.name
231 }
232
233 fn call<'a>(&'a self, ctx: &'a HookContext) -> HookFuture<'a, Result<HookAction, HookIssue>> {
234 Box::pin(async move {
235 let stdin_bytes = match serde_json::to_vec(ctx) {
236 Ok(b) => b,
237 Err(e) => {
238 return Err(HookIssue {
239 hook_name: self.name.to_owned(),
240 phase: ctx.phase,
241 class: HookIssueClass::Internal,
242 message: format!("context serialize failed: {e}"),
243 })
244 }
245 };
246
247 let output = match tokio::time::timeout(
248 self.timeout,
249 run_process(&self.command, &self.env, stdin_bytes),
250 )
251 .await
252 {
253 Err(_elapsed) => return Err(timeout_issue(self.name, ctx.phase, self.timeout)),
254 Ok(Err(e)) => return Err(execution_issue(self.name, ctx.phase, e)),
255 Ok(Ok(o)) => o,
256 };
257
258 match output.exit_code {
259 0 => parse_pre_output(self.name, ctx.phase, &output.stdout),
260 2 => Ok(HookAction::Block(parse_block_output(
261 self.name,
262 ctx.phase,
263 &output.stdout,
264 ))),
265 code => Err(execution_issue(
266 self.name,
267 ctx.phase,
268 format!("exited with code {code}"),
269 )),
270 }
271 })
272 }
273}
274
275impl PostHook for ShellCommandHook {
278 fn name(&self) -> &'static str {
279 self.name
280 }
281
282 fn call<'a>(&'a self, ctx: &'a HookContext) -> HookFuture<'a, Result<(), HookIssue>> {
283 Box::pin(async move {
284 let stdin_bytes = match serde_json::to_vec(ctx) {
285 Ok(b) => b,
286 Err(e) => {
287 return Err(HookIssue {
288 hook_name: self.name.to_owned(),
289 phase: ctx.phase,
290 class: HookIssueClass::Internal,
291 message: format!("context serialize failed: {e}"),
292 })
293 }
294 };
295
296 let output = match tokio::time::timeout(
297 self.timeout,
298 run_process(&self.command, &self.env, stdin_bytes),
299 )
300 .await
301 {
302 Err(_elapsed) => return Err(timeout_issue(self.name, ctx.phase, self.timeout)),
303 Ok(Err(e)) => return Err(execution_issue(self.name, ctx.phase, e)),
304 Ok(Ok(o)) => o,
305 };
306
307 if output.exit_code == 0 {
308 Ok(())
309 } else {
310 Err(execution_issue(
311 self.name,
312 ctx.phase,
313 format!("exited with code {}", output.exit_code),
314 ))
315 }
316 })
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Phase used by the pure parser tests; the parsers only copy it into their output.
    fn phase() -> HookPhase {
        HookPhase::PreRun
    }

    // ---- parse_pre_output ----------------------------------------------------------

    #[test]
    fn empty_stdout_is_noop() {
        assert_eq!(parse_pre_output("h", phase(), ""), Ok(HookAction::Noop));
        assert_eq!(parse_pre_output("h", phase(), " "), Ok(HookAction::Noop));
    }

    #[test]
    fn empty_object_is_noop() {
        assert_eq!(parse_pre_output("h", phase(), "{}"), Ok(HookAction::Noop));
    }

    #[test]
    fn action_noop_explicit() {
        assert_eq!(
            parse_pre_output("h", phase(), r#"{"action":"noop"}"#),
            Ok(HookAction::Noop)
        );
    }

    #[test]
    fn action_mutate_model_override() {
        let out = parse_pre_output(
            "h",
            phase(),
            r#"{"action":"mutate","modelOverride":"claude-opus-4-6"}"#,
        );
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("claude-opus-4-6"));
                assert!(patch.prompt_override.is_none());
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn action_mutate_prompt_override() {
        let out = parse_pre_output(
            "h",
            phase(),
            r#"{"action":"mutate","promptOverride":"new prompt"}"#,
        );
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.prompt_override.as_deref(), Some("new prompt"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn action_match_is_case_insensitive() {
        let out = parse_pre_output("h", phase(), r#"{"action":"MuTaTe","modelOverride":"m"}"#);
        match out {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("m"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[test]
    fn unknown_action_is_noop() {
        assert_eq!(
            parse_pre_output("h", phase(), r#"{"action":"unknown"}"#),
            Ok(HookAction::Noop)
        );
    }

    #[test]
    fn invalid_json_is_execution_issue() {
        let result = parse_pre_output("h", phase(), "not-json");
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    // ---- parse_block_output --------------------------------------------------------

    #[test]
    fn block_with_json_message() {
        let r = parse_block_output("h", phase(), r#"{"message":"rm -rf blocked"}"#);
        assert_eq!(r.message, "rm -rf blocked");
        assert_eq!(r.hook_name, "h");
    }

    #[test]
    fn block_with_plain_text_message() {
        let r = parse_block_output("h", phase(), "plain text reason");
        assert_eq!(r.message, "plain text reason");
    }

    #[test]
    fn block_plain_text_is_trimmed() {
        let r = parse_block_output("h", phase(), "  reason \n");
        assert_eq!(r.message, "reason");
    }

    #[test]
    fn block_with_empty_stdout_gives_fallback() {
        let r = parse_block_output("h", phase(), "");
        assert_eq!(r.message, "blocked by hook (no message)");
    }

    #[test]
    fn block_with_json_empty_message_falls_back_to_raw() {
        let r = parse_block_output("h", phase(), r#"{"message":""}"#);
        assert_eq!(r.message, r#"{"message":""}"#);
    }

    // ---- end-to-end through `sh` ---------------------------------------------------

    /// Minimal context for driving a real hook process.
    fn ctx() -> HookContext {
        use serde_json::json;
        HookContext {
            phase: HookPhase::PreRun,
            thread_id: None,
            turn_id: None,
            cwd: Some("/tmp".to_owned()),
            model: None,
            main_status: None,
            correlation_id: "hk-1".to_owned(),
            ts_ms: 0,
            metadata: json!({}),
            tool_name: None,
            tool_input: None,
        }
    }

    #[tokio::test]
    async fn pre_hook_exit0_empty_stdout_is_noop() {
        let hook = ShellCommandHook::new("test-noop", "exit 0");
        let result = PreHook::call(&hook, &ctx()).await;
        assert_eq!(result, Ok(HookAction::Noop));
    }

    #[tokio::test]
    async fn pre_hook_exit0_invalid_json_is_execution_error() {
        let hook = ShellCommandHook::new("test-bad-json", "echo not-json");
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn pre_hook_exit2_blocks() {
        let hook = ShellCommandHook::new("test-block", r#"echo '{"message":"denied"}' ; exit 2"#);
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Block(r)) => assert_eq!(r.message, "denied"),
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_exit2_plain_text_blocks() {
        let hook = ShellCommandHook::new("test-block-text", "echo 'no way'; exit 2");
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Block(r)) => assert_eq!(r.message, "no way"),
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_exit1_is_execution_error() {
        let hook = ShellCommandHook::new("test-err", "exit 1");
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn pre_hook_killed_by_signal_is_execution_error() {
        // The shell sends SIGKILL to itself, so it terminates without an exit code.
        let hook = ShellCommandHook::new("test-signal", "kill -9 $$");
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn pre_hook_exit0_mutate_model() {
        let hook = ShellCommandHook::new(
            "test-mutate",
            r#"echo '{"action":"mutate","modelOverride":"claude-haiku-4-5-20251001"}'"#,
        );
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(
                    patch.model_override.as_deref(),
                    Some("claude-haiku-4-5-20251001")
                );
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn pre_hook_timeout_returns_timeout_issue() {
        let hook = ShellCommandHook::new("test-timeout", "sleep 60")
            .with_timeout(Duration::from_millis(50));
        let result = PreHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Timeout,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn post_hook_exit0_is_ok() {
        let hook = ShellCommandHook::new("test-post", "exit 0");
        let result = PostHook::call(&hook, &ctx()).await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn post_hook_nonzero_is_execution_error() {
        let hook = ShellCommandHook::new("test-post-err", "exit 1");
        let result = PostHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn post_hook_exit2_is_execution_error_not_block() {
        // Exit status 2 only means "block" for pre-hooks.
        let hook = ShellCommandHook::new("test-post-2", "exit 2");
        let result = PostHook::call(&hook, &ctx()).await;
        assert!(matches!(
            result,
            Err(HookIssue {
                class: HookIssueClass::Execution,
                ..
            })
        ));
    }

    #[tokio::test]
    async fn stdin_receives_hook_context_json() {
        // jq is optional on dev/CI machines, so skip instead of failing when it is absent.
        // The probe uses tokio's async `Command` to avoid blocking the runtime thread.
        let jq_available = Command::new("jq").arg("--version").output().await.is_ok();
        if !jq_available {
            return;
        }

        let hook = ShellCommandHook::new(
            "test-stdin",
            r#"phase=$(cat | jq -r '.phase'); echo "{\"action\":\"mutate\",\"modelOverride\":\"$phase\"}""#,
        );
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("PreRun"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn with_env_passes_env_to_process() {
        let hook = ShellCommandHook::new(
            "test-env",
            r#"echo "{\"action\":\"mutate\",\"modelOverride\":\"$MY_VAR\"}""#,
        )
        .with_env("MY_VAR", "injected-value");
        let result = PreHook::call(&hook, &ctx()).await;
        match result {
            Ok(HookAction::Mutate(patch)) => {
                assert_eq!(patch.model_override.as_deref(), Some("injected-value"));
            }
            other => panic!("expected Mutate, got {other:?}"),
        }
    }
}