1use std::collections::HashMap;
76use std::hash::BuildHasher;
77use std::time::Duration;
78
79use serde::Serialize;
80use thiserror::Error;
81use tokio::io::AsyncWriteExt as _;
82use tokio::process::Command;
83use tokio::time::timeout;
84
85pub use zeph_config::{HookAction, HookDef, HookMatcher, SubagentHooks};
86
/// Output extracted from a single hook invocation.
#[derive(Debug, Default)]
pub struct HookOutput {
    /// Replacement tool output taken from `hookSpecificOutput.updatedToolOutput`
    /// when a shell hook printed a JSON string at that path; `None` means
    /// "no substitution".
    pub updated_tool_output: Option<String>,
}
100
/// Aggregate result of running a list of hooks with `fire_hooks`.
#[derive(Debug, Default)]
pub struct HookRunResult {
    /// Accumulated output of the run; `updated_tool_output` holds the most recent
    /// replacement produced by any hook, or `None` if no hook substituted output.
    pub output: HookOutput,
}
110
/// JSON payload describing a completed tool call for PostToolUse hooks.
///
/// NOTE(review): presumably serialized by the caller and handed to `fire_hooks`
/// as `stdin_json`, so hooks read it from stdin — confirm at the call site.
/// Field declaration order is the serialized key order.
#[derive(Debug, Serialize)]
pub struct PostToolUseHookInput<'a> {
    /// Name of the tool that was invoked.
    pub tool_name: &'a str,
    /// Arguments the tool was invoked with.
    pub tool_args: &'a serde_json::Value,
    /// Session identifier; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<&'a str>,
    /// Tool execution time (presumably milliseconds, per the name — confirm at caller).
    pub duration_ms: u64,
    /// Tool output text; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_output: Option<&'a str>,
    /// Tool error text; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_error: Option<&'a str>,
}
136
/// Maximum number of hook stdout bytes (1 MiB) that `parse_hook_stdout` will
/// consider for output substitution; larger payloads yield no substitution.
const HOOK_STDOUT_CAP: usize = 1024 * 1024;

/// Abstraction over an MCP client used to execute `mcp_tool` hook actions.
///
/// The method returns a boxed future so the trait can be used as a trait object
/// (`&dyn McpDispatch`). The future resolves to the tool's JSON result, or to an
/// error message that the hook runner wraps in `HookError::McpToolFailed`.
pub trait McpDispatch: Send + Sync {
    /// Invokes `tool` on MCP `server` with the given JSON `args`.
    fn call_tool<'a>(
        &'a self,
        server: &'a str,
        tool: &'a str,
        args: serde_json::Value,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
    >;
}
163
/// Errors produced while running subagent hooks.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum HookError {
    /// A shell hook exited with a non-success status. `code` is `-1` when the
    /// process reported no exit code (on Unix, e.g. when terminated by a signal).
    #[error("hook command failed (exit code {code}): {command}")]
    NonZeroExit { command: String, code: i32 },

    /// A hook did not finish within its `timeout_secs`. For `mcp_tool` hooks the
    /// `command` field is rendered as `mcp_tool:{server}/{tool}`.
    #[error("hook command timed out after {timeout_secs}s: {command}")]
    Timeout { command: String, timeout_secs: u64 },

    /// Spawning or waiting on a shell hook process failed.
    #[error("hook I/O error for command '{command}': {source}")]
    Io {
        command: String,
        #[source]
        source: std::io::Error,
    },

    /// An `mcp_tool` hook was run but no [`McpDispatch`] implementation was supplied.
    #[error(
        "mcp_tool hook requires an MCP manager but none was provided (server={server}, tool={tool})"
    )]
    McpUnavailable { server: String, tool: String },

    /// The MCP dispatcher returned an error; `reason` is its error message.
    #[error("mcp_tool hook failed (server={server}, tool={tool}): {reason}")]
    McpToolFailed {
        server: String,
        tool: String,
        reason: String,
    },
}
200
201#[must_use]
222pub fn matching_hooks<'a>(matchers: &'a [HookMatcher], tool_name: &str) -> Vec<&'a HookDef> {
223 let mut result = Vec::new();
224 for m in matchers {
225 let matched = m
226 .matcher
227 .split('|')
228 .filter(|token| !token.is_empty())
229 .any(|token| tool_name.contains(token));
230 if matched {
231 result.extend(m.hooks.iter());
232 }
233 }
234 result
235}
236
237pub const TOOL_ARGS_JSON_LIMIT: usize = 64 * 1024;
243
244#[must_use]
263pub fn make_base_hook_env(
264 tool_name: &str,
265 tool_input: &serde_json::Value,
266) -> HashMap<String, String> {
267 let mut env = HashMap::new();
268 env.insert("ZEPH_TOOL_NAME".to_owned(), tool_name.to_owned());
269
270 let raw = serde_json::to_string(tool_input).unwrap_or_default();
271 let args_json = if raw.len() > TOOL_ARGS_JSON_LIMIT {
272 tracing::warn!(
273 tool = tool_name,
274 len = raw.len(),
275 limit = TOOL_ARGS_JSON_LIMIT,
276 "ZEPH_TOOL_ARGS_JSON truncated for hook dispatch"
277 );
278 let limit = raw.floor_char_boundary(TOOL_ARGS_JSON_LIMIT);
279 format!("{}…", &raw[..limit])
280 } else {
281 raw
282 };
283 env.insert("ZEPH_TOOL_ARGS_JSON".to_owned(), args_json);
284
285 env
286}
287
288#[tracing::instrument(name = "subagent.hooks.fire", skip_all, fields(hook_count = hooks.len()))]
315pub async fn fire_hooks<S: BuildHasher>(
316 hooks: &[HookDef],
317 env: &HashMap<String, String, S>,
318 mcp: Option<&dyn McpDispatch>,
319 stdin_json: Option<&[u8]>,
320) -> Result<HookRunResult, HookError> {
321 let mut run_result = HookRunResult::default();
322 for hook in hooks {
323 let effective_stdin = run_result
326 .output
327 .updated_tool_output
328 .as_deref()
329 .map(str::as_bytes)
330 .or(stdin_json);
331 let result = fire_single_hook(hook, env, mcp, effective_stdin).await;
332 match result {
333 Ok(hook_output) => {
334 if hook_output.updated_tool_output.is_some() {
335 run_result.output.updated_tool_output = hook_output.updated_tool_output;
336 }
337 }
338 Err(e) if hook.fail_closed => {
339 tracing::error!(
340 error = %e,
341 "fail-closed hook failed — aborting"
342 );
343 return Err(e);
344 }
345 Err(e) => {
346 tracing::warn!(
347 error = %e,
348 "hook failed (fail_open) — continuing"
349 );
350 }
351 }
352 }
353 Ok(run_result)
354}
355
356#[tracing::instrument(name = "subagent.hooks.single", skip_all)]
357async fn fire_single_hook<S: BuildHasher>(
358 hook: &HookDef,
359 env: &HashMap<String, String, S>,
360 mcp: Option<&dyn McpDispatch>,
361 stdin_json: Option<&[u8]>,
362) -> Result<HookOutput, HookError> {
363 match &hook.action {
364 HookAction::Command { command } => {
365 fire_shell_hook(command, hook.timeout_secs, env, stdin_json).await
366 }
367 HookAction::McpTool { server, tool, args } => {
368 let dispatcher = mcp.ok_or_else(|| HookError::McpUnavailable {
369 server: server.clone(),
370 tool: tool.clone(),
371 })?;
372 let call_fut = dispatcher.call_tool(server, tool, args.clone());
373 match timeout(Duration::from_secs(hook.timeout_secs), call_fut).await {
374 Ok(Ok(_)) => {
375 Ok(HookOutput::default())
377 }
378 Ok(Err(reason)) => Err(HookError::McpToolFailed {
379 server: server.clone(),
380 tool: tool.clone(),
381 reason,
382 }),
383 Err(_) => Err(HookError::Timeout {
384 command: format!("mcp_tool:{server}/{tool}"),
385 timeout_secs: hook.timeout_secs,
386 }),
387 }
388 }
389 _ => Ok(HookOutput::default()),
390 }
391}
392
393#[tracing::instrument(name = "subagent.hooks.shell", skip_all, fields(timeout_secs))]
394async fn fire_shell_hook<S: BuildHasher>(
395 command: &str,
396 timeout_secs: u64,
397 env: &HashMap<String, String, S>,
398 stdin_json: Option<&[u8]>,
399) -> Result<HookOutput, HookError> {
400 use std::process::Stdio;
401 use tokio::io::AsyncReadExt as _;
402
403 let mut cmd = Command::new("sh");
404 cmd.arg("-c").arg(command);
405 cmd.env_clear();
407 if let Ok(path) = std::env::var("PATH") {
409 cmd.env("PATH", path);
410 }
411 for (k, v) in env {
412 cmd.env(k, v);
413 }
414 cmd.stdin(if stdin_json.is_some() {
415 Stdio::piped()
416 } else {
417 Stdio::null()
418 });
419 cmd.stdout(Stdio::piped());
421 cmd.stderr(Stdio::null());
422
423 let mut child = cmd.spawn().map_err(|e| HookError::Io {
424 command: command.to_owned(),
425 source: e,
426 })?;
427
428 if let Some(bytes) = stdin_json
431 && let Some(mut stdin_handle) = child.stdin.take()
432 && let Err(e) = stdin_handle.write_all(bytes).await
433 {
434 tracing::warn!(
435 command,
436 error = %e,
437 "failed to write stdin to hook — continuing without stdin data"
438 );
439 }
440
441 let stdout_handle = child.stdout.take();
444 match timeout(Duration::from_secs(timeout_secs), child.wait()).await {
445 Ok(Ok(status)) => {
446 let mut stdout_bytes = Vec::new();
447 if let Some(handle) = stdout_handle {
448 let mut limited = handle.take(HOOK_STDOUT_CAP as u64 + 1);
449 let _ = limited.read_to_end(&mut stdout_bytes).await;
450 }
451 if status.success() {
452 Ok(parse_hook_stdout(command, &stdout_bytes))
453 } else {
454 Err(HookError::NonZeroExit {
455 command: command.to_owned(),
456 code: status.code().unwrap_or(-1),
457 })
458 }
459 }
460 Ok(Err(e)) => Err(HookError::Io {
461 command: command.to_owned(),
462 source: e,
463 }),
464 Err(_) => {
465 let _ = child.kill().await;
467 Err(HookError::Timeout {
468 command: command.to_owned(),
469 timeout_secs,
470 })
471 }
472 }
473}
474
475fn parse_hook_stdout(command: &str, bytes: &[u8]) -> HookOutput {
480 if bytes.is_empty() {
481 return HookOutput::default();
482 }
483 if bytes.len() > HOOK_STDOUT_CAP {
484 tracing::warn!(
485 command,
486 bytes = bytes.len(),
487 cap = HOOK_STDOUT_CAP,
488 "hook stdout exceeds 1 MiB cap — treating as no substitution"
489 );
490 return HookOutput::default();
491 }
492 let Ok(text) = std::str::from_utf8(bytes) else {
493 tracing::warn!(command, "hook stdout is not valid UTF-8 — no substitution");
494 return HookOutput::default();
495 };
496 let Ok(json) = serde_json::from_str::<serde_json::Value>(text) else {
498 return HookOutput::default();
499 };
500 let updated = json
501 .get("hookSpecificOutput")
502 .and_then(|h| h.get("updatedToolOutput"));
503
504 match updated {
505 None | Some(serde_json::Value::Null) => HookOutput::default(),
506 Some(serde_json::Value::String(s)) => HookOutput {
507 updated_tool_output: Some(s.clone()),
508 },
509 Some(other) => {
510 tracing::warn!(
511 command,
512 kind = other
513 .is_object()
514 .then_some("object")
515 .or_else(|| other.is_array().then_some("array"))
516 .or_else(|| other.is_number().then_some("number"))
517 .or_else(|| other.is_boolean().then_some("boolean"))
518 .unwrap_or("unknown"),
519 "hookSpecificOutput.updatedToolOutput has unexpected type — no substitution"
520 );
521 HookOutput::default()
522 }
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shell-command hook definition.
    fn cmd_hook(command: &str, fail_closed: bool, timeout_secs: u64) -> HookDef {
        HookDef {
            action: HookAction::Command {
                command: command.to_owned(),
            },
            timeout_secs,
            fail_closed,
        }
    }

    /// Builds a matcher applying `hooks` to tools whose name matches `matcher`.
    fn make_matcher(matcher: &str, hooks: Vec<HookDef>) -> HookMatcher {
        HookMatcher {
            matcher: matcher.to_owned(),
            hooks,
        }
    }

    // ---- matching_hooks ----

    #[test]
    fn matching_hooks_exact_name() {
        let hook = cmd_hook("echo hi", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook.clone()])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
        assert!(
            matches!(&result[0].action, HookAction::Command { command } if command == "echo hi")
        );
    }

    #[test]
    fn matching_hooks_substring() {
        let hook = cmd_hook("echo sub", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook.clone()])];
        let result = matching_hooks(&matchers, "EditFile");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_pipe_separated() {
        let h1 = cmd_hook("echo e", false, 30);
        let h2 = cmd_hook("echo w", false, 30);
        let matchers = vec![
            make_matcher("Edit|Write", vec![h1.clone()]),
            make_matcher("Shell", vec![h2.clone()]),
        ];
        let result_edit = matching_hooks(&matchers, "Edit");
        assert_eq!(result_edit.len(), 1);

        let result_shell = matching_hooks(&matchers, "Shell");
        assert_eq!(result_shell.len(), 1);

        let result_none = matching_hooks(&matchers, "Read");
        assert!(result_none.is_empty());
    }

    #[test]
    fn matching_hooks_no_match() {
        let hook = cmd_hook("echo nope", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "Shell");
        assert!(result.is_empty());
    }

    #[test]
    fn matching_hooks_empty_token_ignored() {
        // Leading/trailing `|` produce empty tokens, which must not match everything.
        let hook = cmd_hook("echo empty", false, 30);
        let matchers = vec![make_matcher("|Edit|", vec![hook])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_multiple_matchers_both_match() {
        let h1 = cmd_hook("echo 1", false, 30);
        let h2 = cmd_hook("echo 2", false, 30);
        let matchers = vec![
            make_matcher("Shell", vec![h1]),
            make_matcher("Shell", vec![h2]),
        ];
        let result = matching_hooks(&matchers, "Shell");
        assert_eq!(result.len(), 2);
    }

    // ---- fire_hooks: exit status, timeout, env ----

    #[tokio::test]
    async fn fire_hooks_success() {
        let hooks = vec![cmd_hook("true", false, 5)];
        let env = HashMap::new();
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_fail_open_continues() {
        let hooks = vec![
            cmd_hook("false", false, 5), // fails, but is fail-open
            cmd_hook("true", false, 5),  // still runs
        ];
        let env = HashMap::new();
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_fail_closed_returns_err() {
        let hooks = vec![cmd_hook("false", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::NonZeroExit { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_timeout() {
        let hooks = vec![cmd_hook("sleep 10", true, 1)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::Timeout { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_env_passed() {
        let hooks = vec![cmd_hook(r#"test "$ZEPH_TEST_VAR" = "hello""#, true, 5)];
        let mut env = HashMap::new();
        env.insert("ZEPH_TEST_VAR".to_owned(), "hello".to_owned());
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_empty_list_ok() {
        let env = HashMap::new();
        assert!(fire_hooks(&[], &env, None, None).await.is_ok());
    }

    // ---- fire_hooks: mcp_tool actions ----

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_open() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: false,
        }];
        let env = HashMap::new();
        // No dispatcher supplied, but the hook is fail-open, so the run succeeds.
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_closed() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(matches!(result, Err(HookError::McpUnavailable { .. })));
    }

    /// Test dispatcher that counts invocations and always returns `Ok(Null)`.
    struct CountingDispatch(std::sync::Arc<std::sync::atomic::AtomicU32>);

    impl McpDispatch for CountingDispatch {
        fn call_tool<'a>(
            &'a self,
            _server: &'a str,
            _tool: &'a str,
            _args: serde_json::Value,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
        > {
            self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Box::pin(std::future::ready(Ok(serde_json::Value::Null)))
        }
    }

    #[tokio::test]
    async fn fire_hooks_mcp_dispatch_called_when_provided() {
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let dispatch = CountingDispatch(std::sync::Arc::clone(&call_count));

        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, Some(&dispatch), None).await;
        assert!(
            result.is_ok(),
            "fire_hooks should succeed with mcp dispatch"
        );
        assert_eq!(
            call_count.load(std::sync::atomic::Ordering::SeqCst),
            1,
            "MCP dispatch should have been called exactly once"
        );
    }

    // ---- fire_hooks: stdout substitution and stdin ----

    #[tokio::test]
    async fn fire_hooks_stdout_replacement_json() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"replaced"}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(
            result.output.updated_tool_output.as_deref(),
            Some("replaced")
        );
    }

    #[tokio::test]
    async fn fire_hooks_stdout_empty_no_replacement() {
        let hooks = vec![cmd_hook("true", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_non_json_no_replacement() {
        let hooks = vec![cmd_hook("echo hello", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_null_updatedtooloutput_no_replacement() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":null}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdin_passed_to_hook() {
        // POSIX-only check (no python3 dependency): succeeds only if the JSON
        // payload, including `duration_ms`, actually arrives on stdin.
        let cmd = r#"grep -q '"duration_ms":42'"#;
        let hooks = vec![cmd_hook(cmd, true, 10)];
        let env = HashMap::new();
        let stdin = br#"{"tool_name":"Shell","tool_args":{},"duration_ms":42}"#;
        let result = fire_hooks(&hooks, &env, None, Some(stdin)).await;
        assert!(
            result.is_ok(),
            "hook should succeed when stdin has duration_ms"
        );
    }

    #[tokio::test]
    async fn fire_hooks_chaining_last_replacement_wins() {
        let h1 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"first"}}'"#,
            false,
            5,
        );
        let h2 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"second"}}'"#,
            false,
            5,
        );
        let hooks = vec![h1, h2];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(result.output.updated_tool_output.as_deref(), Some("second"));
    }

    // ---- SubagentHooks deserialization ----

    #[test]
    fn subagent_hooks_parses_from_yaml() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit|Write"
    hooks:
      - type: command
        command: "echo pre"
        timeout_secs: 10
        fail_closed: false
PostToolUse:
  - matcher: "Shell"
    hooks:
      - type: command
        command: "echo post"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use.len(), 1);
        assert_eq!(hooks.pre_tool_use[0].matcher, "Edit|Write");
        assert_eq!(hooks.pre_tool_use[0].hooks.len(), 1);
        assert!(
            matches!(&hooks.pre_tool_use[0].hooks[0].action, HookAction::Command { command } if command == "echo pre")
        );
        assert_eq!(hooks.post_tool_use.len(), 1);
    }

    #[test]
    fn subagent_hooks_defaults_timeout() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit"
    hooks:
      - type: command
        command: "echo hi"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use[0].hooks[0].timeout_secs, 30);
        assert!(!hooks.pre_tool_use[0].hooks[0].fail_closed);
    }

    #[test]
    fn subagent_hooks_empty_default() {
        let hooks = SubagentHooks::default();
        assert!(hooks.pre_tool_use.is_empty());
        assert!(hooks.post_tool_use.is_empty());
    }

    // ---- regressions ----

    #[tokio::test]
    async fn fire_shell_hook_timeout_with_stdout_does_not_deadlock() {
        // The hook writes some output, then keeps stdout open while sleeping far
        // past its 1 s timeout.
        let cmd = r#"echo "some output"; sleep 60"#;
        let hooks = vec![cmd_hook(cmd, true, 1)];
        let env = HashMap::new();

        let result = tokio::time::timeout(
            // Outer guard: generous compared with the hook's own 1 s timeout.
            std::time::Duration::from_secs(5),
            fire_hooks(&hooks, &env, None, None),
        )
        .await
        .expect("fire_hooks must return within 5 s — deadlock regression #4011");

        assert!(
            matches!(result, Err(HookError::Timeout { .. })),
            "expected HookError::Timeout, got: {result:?}"
        );
    }
}