use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
use tokio::time::{timeout, Duration};
/// Every hook event name the registry accepts.
///
/// `HookRegistry::register` and `HookRegistry::register_from_config` drop any
/// registration whose event name is not in this list. The names and their order
/// mirror the variants of `HookEvent` and the strings returned by `HookEvent::as_str`.
pub const HOOK_EVENTS: &[&str] = &[
"PreToolUse",
"PostToolUse",
"PostToolUseFailure",
"SessionStart",
"SessionEnd",
"Stop",
"SubagentStart",
"SubagentStop",
"UserPromptSubmit",
"PermissionRequest",
"PermissionDenied",
"TaskCreated",
"TaskCompleted",
"ConfigChange",
"CwdChanged",
"FileChanged",
"Notification",
"PreCompact",
"PostCompact",
"TeammateIdle",
];
/// Typed form of the hook event names listed in `HOOK_EVENTS`.
///
/// NOTE(review): `rename_all = "camelCase"` makes serde read and write these
/// variants as `"preToolUse"`, `"postToolUse"`, and so on. `HookEvent::as_str`,
/// `HookEvent::from_str` and `HOOK_EVENTS` all use PascalCase (`"PreToolUse"`).
/// Confirm which spelling the serialized form is meant to use before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum HookEvent {
PreToolUse,
PostToolUse,
PostToolUseFailure,
SessionStart,
SessionEnd,
Stop,
SubagentStart,
SubagentStop,
UserPromptSubmit,
PermissionRequest,
PermissionDenied,
TaskCreated,
TaskCompleted,
ConfigChange,
CwdChanged,
FileChanged,
Notification,
PreCompact,
PostCompact,
TeammateIdle,
}
impl HookEvent {
pub fn as_str(&self) -> &'static str {
match self {
HookEvent::PreToolUse => "PreToolUse",
HookEvent::PostToolUse => "PostToolUse",
HookEvent::PostToolUseFailure => "PostToolUseFailure",
HookEvent::SessionStart => "SessionStart",
HookEvent::SessionEnd => "SessionEnd",
HookEvent::Stop => "Stop",
HookEvent::SubagentStart => "SubagentStart",
HookEvent::SubagentStop => "SubagentStop",
HookEvent::UserPromptSubmit => "UserPromptSubmit",
HookEvent::PermissionRequest => "PermissionRequest",
HookEvent::PermissionDenied => "PermissionDenied",
HookEvent::TaskCreated => "TaskCreated",
HookEvent::TaskCompleted => "TaskCompleted",
HookEvent::ConfigChange => "ConfigChange",
HookEvent::CwdChanged => "CwdChanged",
HookEvent::FileChanged => "FileChanged",
HookEvent::Notification => "Notification",
HookEvent::PreCompact => "PreCompact",
HookEvent::PostCompact => "PostCompact",
HookEvent::TeammateIdle => "TeammateIdle",
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s {
"PreToolUse" => Some(HookEvent::PreToolUse),
"PostToolUse" => Some(HookEvent::PostToolUse),
"PostToolUseFailure" => Some(HookEvent::PostToolUseFailure),
"SessionStart" => Some(HookEvent::SessionStart),
"SessionEnd" => Some(HookEvent::SessionEnd),
"Stop" => Some(HookEvent::Stop),
"SubagentStart" => Some(HookEvent::SubagentStart),
"SubagentStop" => Some(HookEvent::SubagentStop),
"UserPromptSubmit" => Some(HookEvent::UserPromptSubmit),
"PermissionRequest" => Some(HookEvent::PermissionRequest),
"PermissionDenied" => Some(HookEvent::PermissionDenied),
"TaskCreated" => Some(HookEvent::TaskCreated),
"TaskCompleted" => Some(HookEvent::TaskCompleted),
"ConfigChange" => Some(HookEvent::ConfigChange),
"CwdChanged" => Some(HookEvent::CwdChanged),
"FileChanged" => Some(HookEvent::FileChanged),
"Notification" => Some(HookEvent::Notification),
"PreCompact" => Some(HookEvent::PreCompact),
"PostCompact" => Some(HookEvent::PostCompact),
"TeammateIdle" => Some(HookEvent::TeammateIdle),
_ => None,
}
}
}
#[derive(Debug, Clone)]
pub struct HookDefinition {
pub command: Option<String>,
pub timeout: Option<u64>,
pub matcher: Option<String>,
}
impl<'de> Deserialize<'de> for HookDefinition {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(HookDefinition {
command: None,
timeout: Some(30000),
matcher: None,
})
}
}
/// Context handed to a hook command.
///
/// The hook runner serializes this struct to JSON with camelCase keys (e.g.
/// `toolName`, `toolUseId`) and writes it to the command's stdin. Optional
/// fields that are `None` are left out of the JSON. Which optional fields are
/// populated depends on the caller and the event.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HookInput {
/// Event name. `HookRegistry::execute` overwrites it with the event being run.
pub event: String,
/// Tool involved, if any. A hook's `matcher` regex is tested against this value,
/// and it is also exported to the command as `HOOK_TOOL_NAME`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Tool arguments as arbitrary JSON (presumably only set for tool events — confirm with callers).
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_input: Option<serde_json::Value>,
/// Tool result as arbitrary JSON (presumably only set after a tool ran — confirm with callers).
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_output: Option<serde_json::Value>,
/// Identifier of the tool call; not interpreted in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_use_id: Option<String>,
/// Session identifier; also exported to the command as `HOOK_SESSION_ID`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Working directory; also exported to the command as `HOOK_CWD`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cwd: Option<String>,
/// Error text. NOTE(review): presumably set for failure events such as
/// `PostToolUseFailure`; confirm with callers.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl HookInput {
    /// Creates an input for `event` with every optional context field left unset.
    ///
    /// Callers fill in `tool_name`, `tool_input`, `cwd`, etc. as the event requires.
    pub fn new(event: &str) -> Self {
        let event = String::from(event);
        Self {
            event,
            error: None,
            cwd: None,
            session_id: None,
            tool_use_id: None,
            tool_output: None,
            tool_input: None,
            tool_name: None,
        }
    }
}
/// Result reported by a hook command.
///
/// A hook may print this as JSON on stdout, using camelCase keys. If its stdout
/// is not valid JSON for this type, the hook runner puts the trimmed text into
/// `message` instead. Every field is optional and is omitted when serialized as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HookOutput {
/// Free-form text from the hook.
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
/// Requested permission change. How it is applied is decided by the caller; not visible here.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_update: Option<PermissionUpdate>,
/// NOTE(review): presumably `Some(true)` asks the caller to block the action; confirm with consumers.
#[serde(skip_serializing_if = "Option::is_none")]
pub block: Option<bool>,
/// Notification the hook wants surfaced.
#[serde(skip_serializing_if = "Option::is_none")]
pub notification: Option<Notification>,
}
/// Permission change a hook can request through `HookOutput::permission_update`.
///
/// NOTE(review): neither field is validated here. The accepted `behavior` values
/// are defined by whoever consumes this; confirm before relying on specific strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionUpdate {
/// Name of the tool the change applies to.
pub tool: String,
/// Requested behavior for that tool, as a free-form string.
pub behavior: String,
}
/// Notification a hook can ask to have shown, via `HookOutput::notification`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Notification {
/// Short heading.
pub title: String,
/// Main text.
pub body: String,
/// Optional severity as a free-form string (valid values are not defined in this file);
/// omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub level: Option<String>,
}
/// Hook configuration: event name -> hook definitions for that event.
///
/// `HookRegistry::register_from_config` ignores keys that are not in `HOOK_EVENTS`.
pub type HookConfig = HashMap<String, Vec<HookDefinition>>;
/// Registry of shell hooks, keyed by event name.
///
/// Registration silently ignores event names that are not listed in `HOOK_EVENTS`.
#[derive(Debug, Default)]
pub struct HookRegistry {
    // Event name -> hook definitions, kept in registration order.
    hooks: HashMap<String, Vec<HookDefinition>>,
}

impl HookRegistry {
    /// Creates a registry with no hooks registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends every definition in `config` to its event's list.
    ///
    /// Entries whose key is not a known event name are skipped.
    pub fn register_from_config(&mut self, config: HookConfig) {
        let known = config
            .into_iter()
            .filter(|(event, _)| HOOK_EVENTS.contains(&event.as_str()));
        for (event, definitions) in known {
            self.hooks.entry(event).or_default().extend(definitions);
        }
    }

    /// Appends `definition` to the hooks for `event`; unknown event names are ignored.
    pub fn register(&mut self, event: &str, definition: HookDefinition) {
        if HOOK_EVENTS.contains(&event) {
            self.hooks
                .entry(event.to_owned())
                .or_default()
                .push(definition);
        }
    }

    /// Runs every hook registered for `event`, in registration order, and collects their outputs.
    ///
    /// `input.event` is overwritten with `event` before any hook runs. A hook is skipped
    /// when its matcher rejects the input's tool name (see `matcher_admits`) or when it
    /// has no command. Hooks run one after another. A hook that errors is reported on
    /// stderr and does not stop the rest, and one that yields no output contributes
    /// nothing. A hook without a `timeout` gets 30000 ms.
    pub async fn execute(&self, event: &str, mut input: HookInput) -> Vec<HookOutput> {
        let definitions = match self.hooks.get(event) {
            Some(found) => found,
            None => return Vec::new(),
        };
        input.event = event.to_string();

        let mut outputs = Vec::new();
        let runnable = definitions
            .iter()
            .filter(|def| Self::matcher_admits(def, input.tool_name.as_deref()));
        for def in runnable {
            let command = match &def.command {
                Some(command) => command,
                None => continue,
            };
            match execute_shell_hook(command, &input, def.timeout.unwrap_or(30000)).await {
                Ok(Some(output)) => outputs.push(output),
                Ok(None) => {}
                Err(e) => eprintln!("[Hook] {} hook failed: {}", event, e),
            }
        }
        outputs
    }

    /// Decides whether `def` may run for an input whose tool name is `tool_name`.
    ///
    /// The hook runs when it has no matcher or when the input has no tool name. It also
    /// runs when the matcher does not compile as a regex: for example, a bare `*` is not
    /// a valid regex and therefore admits every tool. Otherwise the regex must match
    /// somewhere in the tool name; the pattern is not anchored.
    fn matcher_admits(def: &HookDefinition, tool_name: Option<&str>) -> bool {
        match (def.matcher.as_deref(), tool_name) {
            (Some(pattern), Some(name)) => regex::Regex::new(pattern)
                .map(|re| re.is_match(name))
                .unwrap_or(true),
            _ => true,
        }
    }

    /// Reports whether at least one hook is registered for `event`.
    pub fn has_hooks(&self, event: &str) -> bool {
        self.hooks
            .get(event)
            .map_or(false, |defs| !defs.is_empty())
    }

    /// Removes every registered hook.
    pub fn clear(&mut self) {
        self.hooks.clear();
    }
}
async fn execute_shell_hook(
command: &str,
input: &HookInput,
timeout_ms: u64,
) -> Result<Option<HookOutput>, crate::error::AgentError> {
let input_json = serde_json::to_string(input)
.map_err(crate::error::AgentError::Json)?;
let cmd_str = command.to_string();
let event = input.event.clone();
let tool_name = input.tool_name.clone();
let session_id = input.session_id.clone();
let cwd = input.cwd.clone();
let result = timeout(
Duration::from_millis(timeout_ms),
tokio::task::spawn_blocking(move || {
let mut cmd = Command::new("bash");
cmd.args(["-c", &cmd_str])
.env("HOOK_EVENT", &event)
.env("HOOK_TOOL_NAME", tool_name.as_deref().unwrap_or(""))
.env("HOOK_SESSION_ID", session_id.as_deref().unwrap_or(""))
.env("HOOK_CWD", cwd.as_deref().unwrap_or(""))
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn()?;
use std::io::Write;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(input_json.as_bytes())?;
}
let output = child.wait_with_output()?;
if !output.status.success() {
return Ok(None);
}
let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
if stdout.is_empty() {
return Ok(None);
}
if let Ok(hook_output) = serde_json::from_str::<HookOutput>(&stdout) {
Ok(Some(hook_output))
} else {
Ok(Some(HookOutput {
message: Some(stdout),
permission_update: None,
block: None,
notification: None,
}))
}
}),
)
.await;
match result {
Ok(Ok(r)) => r,
Ok(Err(e)) => {
let err = std::io::Error::new(std::io::ErrorKind::Other, e.to_string());
Err(crate::error::AgentError::Io(err))
}
Err(_) => {
let err = std::io::Error::new(std::io::ErrorKind::TimedOut, "Hook timeout");
Err(crate::error::AgentError::Io(err))
}
}
}
/// Builds a `HookRegistry`, pre-populated from `config` when one is supplied.
///
/// `None` yields an empty registry. Unknown event names in `config` are skipped,
/// exactly as `HookRegistry::register_from_config` does.
pub fn create_hook_registry(config: Option<HookConfig>) -> HookRegistry {
    config.map_or_else(HookRegistry::new, |hooks| {
        let mut registry = HookRegistry::new();
        registry.register_from_config(hooks);
        registry
    })
}
// Unit tests for event-name mapping, registry bookkeeping and the early-return
// paths of `HookRegistry::execute`. None of these tests spawns a hook process.
#[cfg(test)]
mod tests {
use super::*;
// `as_str` yields the PascalCase names.
#[test]
fn test_hook_event_as_str() {
assert_eq!(HookEvent::PreToolUse.as_str(), "PreToolUse");
assert_eq!(HookEvent::PostToolUse.as_str(), "PostToolUse");
assert_eq!(HookEvent::SessionStart.as_str(), "SessionStart");
}
// `from_str` accepts a known name and rejects an unknown one.
#[test]
fn test_hook_event_from_str() {
assert_eq!(HookEvent::from_str("PreToolUse"), Some(HookEvent::PreToolUse));
assert_eq!(HookEvent::from_str("Invalid"), None);
}
// Spot-checks a few entries of the accepted-name table.
#[test]
fn test_hook_events_constant() {
assert!(HOOK_EVENTS.contains(&"PreToolUse"));
assert!(HOOK_EVENTS.contains(&"PostToolUse"));
assert!(HOOK_EVENTS.contains(&"SessionStart"));
}
#[test]
fn test_hook_registry_new() {
let registry = HookRegistry::new();
assert!(!registry.has_hooks("PreToolUse"));
}
// Registering under a known event makes `has_hooks` report it; the command is never run.
#[test]
fn test_hook_registry_register() {
let mut registry = HookRegistry::new();
registry.register(
"PreToolUse",
HookDefinition {
command: Some("echo test".to_string()),
timeout: Some(5000),
matcher: Some("Read.*".to_string()),
},
);
assert!(registry.has_hooks("PreToolUse"));
}
#[test]
fn test_hook_registry_clear() {
let mut registry = HookRegistry::new();
registry.register(
"PreToolUse",
HookDefinition {
command: Some("echo test".to_string()),
timeout: None,
matcher: None,
},
);
registry.clear();
assert!(!registry.has_hooks("PreToolUse"));
}
#[test]
fn test_hook_input_new() {
let input = HookInput::new("PreToolUse");
assert_eq!(input.event, "PreToolUse");
}
// Only checks that the message text appears in the serialized JSON.
#[test]
fn test_hook_output_serialization() {
let output = HookOutput {
message: Some("test message".to_string()),
permission_update: None,
block: Some(true),
notification: None,
};
let json = serde_json::to_string(&output).unwrap();
assert!(json.contains("test message"));
}
#[test]
fn test_create_hook_registry() {
let registry = create_hook_registry(None);
assert!(!registry.has_hooks("PreToolUse"));
}
// With nothing registered, `execute` returns before spawning anything.
#[tokio::test]
async fn test_execute_no_hooks() {
let registry = HookRegistry::new();
let input = HookInput::new("PreToolUse");
let results = registry.execute("PreToolUse", input).await;
assert!(results.is_empty());
}
// An unknown event name has no registered hooks, so the result is empty.
#[tokio::test]
async fn test_execute_with_invalid_event() {
let registry = HookRegistry::new();
let input = HookInput::new("InvalidEvent");
let results = registry.execute("InvalidEvent", input).await;
assert!(results.is_empty());
}
}