1use std::collections::HashMap;
76use std::hash::BuildHasher;
77use std::time::Duration;
78
79use serde::Serialize;
80use thiserror::Error;
81use tokio::io::AsyncWriteExt as _;
82use tokio::process::Command;
83use tokio::time::timeout;
84
85pub use zeph_config::{HookAction, HookDef, HookMatcher, SubagentHooks};
86
/// Outcome of running a single hook.
///
/// `updated_tool_output` is `Some` only when the hook asked to replace the tool's output
/// (a shell hook printing a JSON string at `hookSpecificOutput.updatedToolOutput`).
/// `Default` yields "no replacement".
#[derive(Default, Debug)]
pub struct HookOutput {
    /// Replacement text for the tool output, if the hook supplied one.
    pub updated_tool_output: Option<String>,
}
100
101#[derive(Debug, Default)]
105pub struct HookRunResult {
106 pub output: HookOutput,
109}
110
111#[derive(Debug, Serialize)]
119pub struct PostToolUseHookInput<'a> {
120 pub tool_name: &'a str,
122 pub tool_args: &'a serde_json::Value,
124 #[serde(skip_serializing_if = "Option::is_none")]
126 pub session_id: Option<&'a str>,
127 pub duration_ms: u64,
129 #[serde(skip_serializing_if = "Option::is_none")]
131 pub tool_output: Option<&'a str>,
132 #[serde(skip_serializing_if = "Option::is_none")]
134 pub tool_error: Option<&'a str>,
135}
136
137const HOOK_STDOUT_CAP: usize = 1024 * 1024; pub trait McpDispatch: Send + Sync {
153 fn call_tool<'a>(
155 &'a self,
156 server: &'a str,
157 tool: &'a str,
158 args: serde_json::Value,
159 ) -> std::pin::Pin<
160 Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
161 >;
162}
163
164#[non_exhaustive]
168#[derive(Debug, Error)]
169pub enum HookError {
170 #[error("hook command failed (exit code {code}): {command}")]
172 NonZeroExit { command: String, code: i32 },
173
174 #[error("hook command timed out after {timeout_secs}s: {command}")]
176 Timeout { command: String, timeout_secs: u64 },
177
178 #[error("hook I/O error for command '{command}': {source}")]
180 Io {
181 command: String,
182 #[source]
183 source: std::io::Error,
184 },
185
186 #[error(
188 "mcp_tool hook requires an MCP manager but none was provided (server={server}, tool={tool})"
189 )]
190 McpUnavailable { server: String, tool: String },
191
192 #[error("mcp_tool hook failed (server={server}, tool={tool}): {reason}")]
194 McpToolFailed {
195 server: String,
196 tool: String,
197 reason: String,
198 },
199}
200
201#[must_use]
222pub fn matching_hooks<'a>(matchers: &'a [HookMatcher], tool_name: &str) -> Vec<&'a HookDef> {
223 let mut result = Vec::new();
224 for m in matchers {
225 let matched = m
226 .matcher
227 .split('|')
228 .filter(|token| !token.is_empty())
229 .any(|token| tool_name.contains(token));
230 if matched {
231 result.extend(m.hooks.iter());
232 }
233 }
234 result
235}
236
/// Byte budget (64 KiB) for the `ZEPH_TOOL_ARGS_JSON` hook environment variable;
/// longer serialized arguments are truncated by [`make_base_hook_env`].
pub const TOOL_ARGS_JSON_LIMIT: usize = 65_536;
243
244#[must_use]
263pub fn make_base_hook_env(
264 tool_name: &str,
265 tool_input: &serde_json::Value,
266) -> HashMap<String, String> {
267 let mut env = HashMap::new();
268 env.insert("ZEPH_TOOL_NAME".to_owned(), tool_name.to_owned());
269
270 let raw = serde_json::to_string(tool_input).unwrap_or_default();
271 let args_json = if raw.len() > TOOL_ARGS_JSON_LIMIT {
272 tracing::warn!(
273 tool = tool_name,
274 len = raw.len(),
275 limit = TOOL_ARGS_JSON_LIMIT,
276 "ZEPH_TOOL_ARGS_JSON truncated for hook dispatch"
277 );
278 let limit = raw.floor_char_boundary(TOOL_ARGS_JSON_LIMIT);
279 format!("{}…", &raw[..limit])
280 } else {
281 raw
282 };
283 env.insert("ZEPH_TOOL_ARGS_JSON".to_owned(), args_json);
284
285 env
286}
287
288pub async fn fire_hooks<S: BuildHasher>(
315 hooks: &[HookDef],
316 env: &HashMap<String, String, S>,
317 mcp: Option<&dyn McpDispatch>,
318 stdin_json: Option<&[u8]>,
319) -> Result<HookRunResult, HookError> {
320 let mut run_result = HookRunResult::default();
321 for hook in hooks {
322 let effective_stdin = run_result
325 .output
326 .updated_tool_output
327 .as_deref()
328 .map(str::as_bytes)
329 .or(stdin_json);
330 let result = fire_single_hook(hook, env, mcp, effective_stdin).await;
331 match result {
332 Ok(hook_output) => {
333 if hook_output.updated_tool_output.is_some() {
334 run_result.output.updated_tool_output = hook_output.updated_tool_output;
335 }
336 }
337 Err(e) if hook.fail_closed => {
338 tracing::error!(
339 error = %e,
340 "fail-closed hook failed — aborting"
341 );
342 return Err(e);
343 }
344 Err(e) => {
345 tracing::warn!(
346 error = %e,
347 "hook failed (fail_open) — continuing"
348 );
349 }
350 }
351 }
352 Ok(run_result)
353}
354
355async fn fire_single_hook<S: BuildHasher>(
356 hook: &HookDef,
357 env: &HashMap<String, String, S>,
358 mcp: Option<&dyn McpDispatch>,
359 stdin_json: Option<&[u8]>,
360) -> Result<HookOutput, HookError> {
361 match &hook.action {
362 HookAction::Command { command } => {
363 fire_shell_hook(command, hook.timeout_secs, env, stdin_json).await
364 }
365 HookAction::McpTool { server, tool, args } => {
366 let dispatcher = mcp.ok_or_else(|| HookError::McpUnavailable {
367 server: server.clone(),
368 tool: tool.clone(),
369 })?;
370 let call_fut = dispatcher.call_tool(server, tool, args.clone());
371 match timeout(Duration::from_secs(hook.timeout_secs), call_fut).await {
372 Ok(Ok(_)) => {
373 Ok(HookOutput::default())
375 }
376 Ok(Err(reason)) => Err(HookError::McpToolFailed {
377 server: server.clone(),
378 tool: tool.clone(),
379 reason,
380 }),
381 Err(_) => Err(HookError::Timeout {
382 command: format!("mcp_tool:{server}/{tool}"),
383 timeout_secs: hook.timeout_secs,
384 }),
385 }
386 }
387 _ => Ok(HookOutput::default()),
388 }
389}
390
391async fn fire_shell_hook<S: BuildHasher>(
392 command: &str,
393 timeout_secs: u64,
394 env: &HashMap<String, String, S>,
395 stdin_json: Option<&[u8]>,
396) -> Result<HookOutput, HookError> {
397 use std::process::Stdio;
398 use tokio::io::AsyncReadExt as _;
399
400 let mut cmd = Command::new("sh");
401 cmd.arg("-c").arg(command);
402 cmd.env_clear();
404 if let Ok(path) = std::env::var("PATH") {
406 cmd.env("PATH", path);
407 }
408 for (k, v) in env {
409 cmd.env(k, v);
410 }
411 cmd.stdin(if stdin_json.is_some() {
412 Stdio::piped()
413 } else {
414 Stdio::null()
415 });
416 cmd.stdout(Stdio::piped());
418 cmd.stderr(Stdio::null());
419
420 let mut child = cmd.spawn().map_err(|e| HookError::Io {
421 command: command.to_owned(),
422 source: e,
423 })?;
424
425 if let Some(bytes) = stdin_json
428 && let Some(mut stdin_handle) = child.stdin.take()
429 && let Err(e) = stdin_handle.write_all(bytes).await
430 {
431 tracing::warn!(
432 command,
433 error = %e,
434 "failed to write stdin to hook — continuing without stdin data"
435 );
436 }
437
438 let stdout_handle = child.stdout.take();
441 match timeout(Duration::from_secs(timeout_secs), child.wait()).await {
442 Ok(Ok(status)) => {
443 let mut stdout_bytes = Vec::new();
444 if let Some(handle) = stdout_handle {
445 let mut limited = handle.take(HOOK_STDOUT_CAP as u64 + 1);
446 let _ = limited.read_to_end(&mut stdout_bytes).await;
447 }
448 if status.success() {
449 Ok(parse_hook_stdout(command, &stdout_bytes))
450 } else {
451 Err(HookError::NonZeroExit {
452 command: command.to_owned(),
453 code: status.code().unwrap_or(-1),
454 })
455 }
456 }
457 Ok(Err(e)) => Err(HookError::Io {
458 command: command.to_owned(),
459 source: e,
460 }),
461 Err(_) => {
462 let _ = child.kill().await;
464 Err(HookError::Timeout {
465 command: command.to_owned(),
466 timeout_secs,
467 })
468 }
469 }
470}
471
472fn parse_hook_stdout(command: &str, bytes: &[u8]) -> HookOutput {
477 if bytes.is_empty() {
478 return HookOutput::default();
479 }
480 if bytes.len() > HOOK_STDOUT_CAP {
481 tracing::warn!(
482 command,
483 bytes = bytes.len(),
484 cap = HOOK_STDOUT_CAP,
485 "hook stdout exceeds 1 MiB cap — treating as no substitution"
486 );
487 return HookOutput::default();
488 }
489 let Ok(text) = std::str::from_utf8(bytes) else {
490 tracing::warn!(command, "hook stdout is not valid UTF-8 — no substitution");
491 return HookOutput::default();
492 };
493 let Ok(json) = serde_json::from_str::<serde_json::Value>(text) else {
495 return HookOutput::default();
496 };
497 let updated = json
498 .get("hookSpecificOutput")
499 .and_then(|h| h.get("updatedToolOutput"));
500
501 match updated {
502 None | Some(serde_json::Value::Null) => HookOutput::default(),
503 Some(serde_json::Value::String(s)) => HookOutput {
504 updated_tool_output: Some(s.clone()),
505 },
506 Some(other) => {
507 tracing::warn!(
508 command,
509 kind = other
510 .is_object()
511 .then_some("object")
512 .or_else(|| other.is_array().then_some("array"))
513 .or_else(|| other.is_number().then_some("number"))
514 .or_else(|| other.is_boolean().then_some("boolean"))
515 .unwrap_or("unknown"),
516 "hookSpecificOutput.updatedToolOutput has unexpected type — no substitution"
517 );
518 HookOutput::default()
519 }
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shell-command hook definition.
    fn cmd_hook(command: &str, fail_closed: bool, timeout_secs: u64) -> HookDef {
        HookDef {
            action: HookAction::Command {
                command: command.to_owned(),
            },
            timeout_secs,
            fail_closed,
        }
    }

    /// Builds a matcher entry holding `hooks`.
    fn make_matcher(matcher: &str, hooks: Vec<HookDef>) -> HookMatcher {
        HookMatcher {
            matcher: matcher.to_owned(),
            hooks,
        }
    }

    #[test]
    fn matching_hooks_exact_name() {
        let hook = cmd_hook("echo hi", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
        assert!(
            matches!(&result[0].action, HookAction::Command { command } if command == "echo hi")
        );
    }

    #[test]
    fn matching_hooks_substring() {
        let hook = cmd_hook("echo sub", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "EditFile");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_pipe_separated() {
        let h1 = cmd_hook("echo e", false, 30);
        let h2 = cmd_hook("echo w", false, 30);
        let matchers = vec![
            make_matcher("Edit|Write", vec![h1]),
            make_matcher("Shell", vec![h2]),
        ];
        let result_edit = matching_hooks(&matchers, "Edit");
        assert_eq!(result_edit.len(), 1);

        let result_shell = matching_hooks(&matchers, "Shell");
        assert_eq!(result_shell.len(), 1);

        let result_none = matching_hooks(&matchers, "Read");
        assert!(result_none.is_empty());
    }

    #[test]
    fn matching_hooks_no_match() {
        let hook = cmd_hook("echo nope", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "Shell");
        assert!(result.is_empty());
    }

    #[test]
    fn matching_hooks_empty_token_ignored() {
        let hook = cmd_hook("echo empty", false, 30);
        let matchers = vec![make_matcher("|Edit|", vec![hook])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_multiple_matchers_both_match() {
        let h1 = cmd_hook("echo 1", false, 30);
        let h2 = cmd_hook("echo 2", false, 30);
        let matchers = vec![
            make_matcher("Shell", vec![h1]),
            make_matcher("Shell", vec![h2]),
        ];
        let result = matching_hooks(&matchers, "Shell");
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn fire_hooks_success() {
        let hooks = vec![cmd_hook("true", false, 5)];
        let env = HashMap::new();
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_fail_open_continues() {
        let hooks = vec![
            // Fails, but is fail-open, so the run must continue...
            cmd_hook("false", false, 5),
            // ...and this hook must still execute (its replacement proves it ran).
            cmd_hook(
                r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"ran"}}'"#,
                false,
                5,
            ),
        ];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(result.output.updated_tool_output.as_deref(), Some("ran"));
    }

    #[tokio::test]
    async fn fire_hooks_fail_closed_returns_err() {
        let hooks = vec![cmd_hook("false", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::NonZeroExit { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_timeout() {
        let hooks = vec![cmd_hook("sleep 10", true, 1)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::Timeout { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_env_passed() {
        let hooks = vec![cmd_hook(r#"test "$ZEPH_TEST_VAR" = "hello""#, true, 5)];
        let mut env = HashMap::new();
        env.insert("ZEPH_TEST_VAR".to_owned(), "hello".to_owned());
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_empty_list_ok() {
        let env = HashMap::new();
        assert!(fire_hooks(&[], &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_open() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: false,
        }];
        let env = HashMap::new();
        // Fail-open: the missing dispatcher is logged, not returned.
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_closed() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(matches!(result, Err(HookError::McpUnavailable { .. })));
    }

    /// Test double that counts `call_tool` invocations and always succeeds with `Null`.
    struct CountingDispatch(std::sync::Arc<std::sync::atomic::AtomicU32>);

    impl McpDispatch for CountingDispatch {
        fn call_tool<'a>(
            &'a self,
            _server: &'a str,
            _tool: &'a str,
            _args: serde_json::Value,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
        > {
            self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Box::pin(std::future::ready(Ok(serde_json::Value::Null)))
        }
    }

    #[tokio::test]
    async fn fire_hooks_mcp_dispatch_called_when_provided() {
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let dispatch = CountingDispatch(std::sync::Arc::clone(&call_count));

        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, Some(&dispatch), None).await;
        assert!(
            result.is_ok(),
            "fire_hooks should succeed with mcp dispatch"
        );
        assert_eq!(
            call_count.load(std::sync::atomic::Ordering::SeqCst),
            1,
            "MCP dispatch should have been called exactly once"
        );
    }

    #[tokio::test]
    async fn fire_hooks_stdout_replacement_json() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"replaced"}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(
            result.output.updated_tool_output.as_deref(),
            Some("replaced")
        );
    }

    #[tokio::test]
    async fn fire_hooks_stdout_empty_no_replacement() {
        let hooks = vec![cmd_hook("true", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_non_json_no_replacement() {
        let hooks = vec![cmd_hook("echo hello", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_null_updatedtooloutput_no_replacement() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":null}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdin_passed_to_hook() {
        // POSIX `grep -q` exits 0 only when the pattern occurs on stdin, so this needs no
        // interpreter beyond `sh` (the previous version depended on `python3`).
        let hooks = vec![cmd_hook(r#"grep -q '"duration_ms"'"#, true, 10)];
        let env = HashMap::new();
        let stdin = br#"{"tool_name":"Shell","tool_args":{},"duration_ms":42}"#;
        let result = fire_hooks(&hooks, &env, None, Some(stdin)).await;
        assert!(
            result.is_ok(),
            "hook should succeed when stdin has duration_ms"
        );
    }

    #[tokio::test]
    async fn fire_hooks_stdin_missing_field_fails_closed() {
        // Same hook as above, but stdin lacks the field: `grep -q` exits 1.
        let hooks = vec![cmd_hook(r#"grep -q '"duration_ms"'"#, true, 10)];
        let env = HashMap::new();
        let stdin = br#"{"tool_name":"Shell","tool_args":{}}"#;
        let result = fire_hooks(&hooks, &env, None, Some(stdin)).await;
        assert!(
            matches!(result, Err(HookError::NonZeroExit { code: 1, .. })),
            "expected NonZeroExit(1), got: {result:?}"
        );
    }

    #[tokio::test]
    async fn fire_hooks_chaining_last_replacement_wins() {
        let h1 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"first"}}'"#,
            false,
            5,
        );
        let h2 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"second"}}'"#,
            false,
            5,
        );
        let hooks = vec![h1, h2];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(result.output.updated_tool_output.as_deref(), Some("second"));
    }

    #[test]
    fn subagent_hooks_parses_from_yaml() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit|Write"
    hooks:
      - type: command
        command: "echo pre"
        timeout_secs: 10
        fail_closed: false
PostToolUse:
  - matcher: "Shell"
    hooks:
      - type: command
        command: "echo post"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use.len(), 1);
        assert_eq!(hooks.pre_tool_use[0].matcher, "Edit|Write");
        assert_eq!(hooks.pre_tool_use[0].hooks.len(), 1);
        assert!(
            matches!(&hooks.pre_tool_use[0].hooks[0].action, HookAction::Command { command } if command == "echo pre")
        );
        assert_eq!(hooks.post_tool_use.len(), 1);
    }

    #[test]
    fn subagent_hooks_defaults_timeout() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit"
    hooks:
      - type: command
        command: "echo hi"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use[0].hooks[0].timeout_secs, 30);
        assert!(!hooks.pre_tool_use[0].hooks[0].fail_closed);
    }

    #[test]
    fn subagent_hooks_empty_default() {
        let hooks = SubagentHooks::default();
        assert!(hooks.pre_tool_use.is_empty());
        assert!(hooks.post_tool_use.is_empty());
    }

    /// Regression test for #4011: a hook that writes to stdout and then outlives its
    /// timeout must still produce `HookError::Timeout` promptly instead of hanging.
    #[tokio::test]
    async fn fire_shell_hook_timeout_with_stdout_does_not_deadlock() {
        let cmd = r#"echo "some output"; sleep 60"#;
        let hooks = vec![cmd_hook(cmd, true, 1)];
        let env = HashMap::new();

        let result = tokio::time::timeout(
            // Outer guard, generous compared to the hook's own 1 s timeout.
            std::time::Duration::from_secs(5),
            fire_hooks(&hooks, &env, None, None),
        )
        .await
        .expect("fire_hooks must return within 5 s — deadlock regression #4011");

        assert!(
            matches!(result, Err(HookError::Timeout { .. })),
            "expected HookError::Timeout, got: {result:?}"
        );
    }
}