use regex::Regex;
use serde::{Deserialize, Serialize};
/// Permission verdict for a hook.
///
/// Serialized as the lowercase strings `"allow"`, `"deny"` and `"ask"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PermissionDecision {
    /// Serialized as `"allow"`.
    Allow,
    /// Serialized as `"deny"`.
    Deny,
    /// Serialized as `"ask"`.
    Ask,
}
/// Reason associated with a "continue" outcome, as the type name indicates.
///
/// Unit variants serialize as `snake_case` strings (e.g. `"context_added"`).
/// `Custom(text)` uses serde's default external tagging: `{"custom": text}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContinueReason {
    Approved,
    Modified,
    ContextAdded,
    Conditional,
    /// Free-form reason text.
    Custom(String),
}
/// Reason associated with a "stop" outcome, as the type name indicates.
///
/// Unit variants serialize as `snake_case` strings (e.g. `"security_violation"`).
/// `Custom(text)` uses serde's default external tagging: `{"custom": text}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    SecurityViolation,
    ErrorDetected,
    UserRequested,
    Critical,
    /// Free-form reason text.
    Custom(String),
}
/// Criteria that decide whether a hook applies to a given [`HookContext`].
///
/// Every field is optional. A `None` field is omitted when serialized. The first four
/// fields are matching constraints, where `None` means "no constraint". `timeout` only
/// carries a per-hook timeout.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct HookMatcher {
    /// Exact tool name the context must carry.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Pattern the context's tool name must match. (De)serialized as its source string
    /// via `serde_regex`, and defaults to `None` when absent from the input.
    #[serde(skip_serializing_if = "Option::is_none", with = "serde_regex", default)]
    pub tool_name_regex: Option<Regex>,
    /// Keys that must be present in the context's tool input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required_input_fields: Option<Vec<String>>,
    /// Event types the context's `event_type` must be one of.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event_types: Option<Vec<String>>,
    /// Timeout in seconds. `timeout_or_default` falls back to 60 seconds when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<f64>,
}
impl HookMatcher {
pub fn new() -> Self {
Self::default()
}
pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
self.tool_name = Some(name.into());
self
}
pub fn with_tool_name_regex(mut self, pattern: &str) -> Self {
self.tool_name_regex = Some(Regex::new(pattern).expect("Invalid regex pattern"));
self
}
pub fn try_with_tool_name_regex(mut self, pattern: &str) -> Result<Self, regex::Error> {
self.tool_name_regex = Some(Regex::new(pattern)?);
Ok(self)
}
pub fn with_required_fields(mut self, fields: Vec<String>) -> Self {
self.required_input_fields = Some(fields);
self
}
pub fn with_event_types(mut self, events: Vec<String>) -> Self {
self.event_types = Some(events);
self
}
pub fn with_timeout(mut self, seconds: f64) -> Self {
self.timeout = Some(seconds);
self
}
pub fn with_timeout_duration(mut self, duration: std::time::Duration) -> Self {
self.timeout = Some(duration.as_secs_f64());
self
}
pub fn matches(&self, context: &HookContext) -> bool {
if let Some(ref name) = self.tool_name
&& context.tool_name.as_ref() != Some(name)
{
return false;
}
if let Some(ref regex) = self.tool_name_regex {
if let Some(ref tool_name) = context.tool_name {
if !regex.is_match(tool_name) {
return false;
}
} else {
return false;
}
}
if let Some(ref required_fields) = self.required_input_fields {
if let Some(ref input) = context.tool_input {
for field in required_fields {
if input.get(field).is_none() {
return false;
}
}
} else {
return false;
}
}
if let Some(ref event_types) = self.event_types
&& !event_types.contains(&context.event_type)
{
return false;
}
true
}
pub fn is_empty(&self) -> bool {
self.tool_name.is_none()
&& self.tool_name_regex.is_none()
&& self.required_input_fields.is_none()
&& self.event_types.is_none()
}
pub fn timeout_or_default(&self) -> f64 {
self.timeout.unwrap_or(60.0)
}
}
/// Description of a hook event, checked against a [`HookMatcher`].
///
/// Only `event_type` is always present. The remaining fields are optional and start out
/// as `None`.
#[derive(Clone, Debug, Default)]
pub struct HookContext {
    /// Name of the event, e.g. `"PreToolUse"`.
    pub event_type: String,
    /// Name of the tool involved, if any.
    pub tool_name: Option<String>,
    /// JSON input passed to the tool, if any.
    pub tool_input: Option<serde_json::Value>,
    /// JSON output produced by the tool, if any.
    pub tool_output: Option<serde_json::Value>,
    /// Identifier of the session the event belongs to, if known.
    pub session_id: Option<String>,
}
impl HookContext {
    /// Creates a context for `event_type` with every optional field unset.
    pub fn new(event_type: impl Into<String>) -> Self {
        let event_type = event_type.into();
        Self {
            event_type,
            ..Self::default()
        }
    }

    /// Returns the context with its tool name set to `name`.
    pub fn with_tool_name(self, name: impl Into<String>) -> Self {
        Self {
            tool_name: Some(name.into()),
            ..self
        }
    }

    /// Returns the context with its tool input set to `input`.
    pub fn with_tool_input(self, input: serde_json::Value) -> Self {
        Self {
            tool_input: Some(input),
            ..self
        }
    }

    /// Returns the context with its tool output set to `output`.
    pub fn with_tool_output(self, output: serde_json::Value) -> Self {
        Self {
            tool_output: Some(output),
            ..self
        }
    }

    /// Returns the context with its session id set to `id`.
    pub fn with_session_id(self, id: impl Into<String>) -> Self {
        Self {
            session_id: Some(id.into()),
            ..self
        }
    }
}
mod serde_regex {
use regex::Regex;
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(regex: &Option<Regex>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match regex {
Some(r) => serializer.serialize_some(r.as_str()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Regex>, D::Error>
where
D: Deserializer<'de>,
{
let opt: Option<String> = Option::deserialize(deserializer)?;
match opt {
Some(s) => Regex::new(&s).map(Some).map_err(serde::de::Error::custom),
None => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_permission_decision_serialization() {
        let allow = PermissionDecision::Allow;
        let json = serde_json::to_string(&allow).unwrap();
        assert_eq!(json, r#""allow""#);
        let deny = PermissionDecision::Deny;
        let json = serde_json::to_string(&deny).unwrap();
        assert_eq!(json, r#""deny""#);
        let ask = PermissionDecision::Ask;
        let json = serde_json::to_string(&ask).unwrap();
        assert_eq!(json, r#""ask""#);
    }

    #[test]
    fn test_permission_decision_round_trip() {
        for decision in [
            PermissionDecision::Allow,
            PermissionDecision::Deny,
            PermissionDecision::Ask,
        ] {
            let json = serde_json::to_string(&decision).unwrap();
            let back: PermissionDecision = serde_json::from_str(&json).unwrap();
            assert_eq!(decision, back);
        }
    }

    #[test]
    fn test_continue_reason_serialization() {
        let reason = ContinueReason::Approved;
        let json = serde_json::to_string(&reason).unwrap();
        assert_eq!(json, r#""approved""#);
        let reason = ContinueReason::ContextAdded;
        let json = serde_json::to_string(&reason).unwrap();
        assert_eq!(json, r#""context_added""#);
        // Newtype variants are externally tagged by serde; pin the exact wire format.
        let custom = ContinueReason::Custom("custom reason".to_string());
        let json = serde_json::to_string(&custom).unwrap();
        assert_eq!(json, r#"{"custom":"custom reason"}"#);
        let back: ContinueReason = serde_json::from_str(&json).unwrap();
        assert_eq!(back, custom);
    }

    #[test]
    fn test_stop_reason_serialization() {
        let reason = StopReason::SecurityViolation;
        let json = serde_json::to_string(&reason).unwrap();
        assert_eq!(json, r#""security_violation""#);
        let reason = StopReason::ErrorDetected;
        let json = serde_json::to_string(&reason).unwrap();
        assert_eq!(json, r#""error_detected""#);
        let custom = StopReason::Custom("critical error".to_string());
        let json = serde_json::to_string(&custom).unwrap();
        assert_eq!(json, r#"{"custom":"critical error"}"#);
        let back: StopReason = serde_json::from_str(&json).unwrap();
        assert_eq!(back, custom);
    }

    #[test]
    fn test_hook_matcher_empty() {
        let matcher = HookMatcher::new();
        assert!(matcher.is_empty());
        let context = HookContext::new("PreToolUse");
        assert!(matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_tool_name_exact() {
        let matcher = HookMatcher::new().with_tool_name("Bash");
        let context = HookContext::new("PreToolUse").with_tool_name("Bash");
        assert!(matcher.matches(&context));
        let context = HookContext::new("PreToolUse").with_tool_name("Write");
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_tool_name_exact_without_context_name() {
        let matcher = HookMatcher::new().with_tool_name("Bash");
        let context = HookContext::new("PreToolUse");
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_tool_name_regex() {
        let matcher = HookMatcher::new().with_tool_name_regex(r"^(Write|Edit|MultiEdit)$");
        let context = HookContext::new("PreToolUse").with_tool_name("Write");
        assert!(matcher.matches(&context));
        let context = HookContext::new("PreToolUse").with_tool_name("Edit");
        assert!(matcher.matches(&context));
        let context = HookContext::new("PreToolUse").with_tool_name("Bash");
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_invalid_regex_is_error() {
        assert!(HookMatcher::new().try_with_tool_name_regex("(").is_err());
        assert!(HookMatcher::new().try_with_tool_name_regex("^ok$").is_ok());
    }

    #[test]
    #[should_panic(expected = "Invalid regex pattern")]
    fn test_hook_matcher_invalid_regex_panics() {
        let _ = HookMatcher::new().with_tool_name_regex("(");
    }

    #[test]
    fn test_hook_matcher_required_fields() {
        let matcher = HookMatcher::new()
            .with_tool_name("Bash")
            .with_required_fields(vec!["command".to_string()]);
        let input = serde_json::json!({ "command": "echo hello" });
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Bash")
            .with_tool_input(input);
        assert!(matcher.matches(&context));
        let input = serde_json::json!({ "other": "value" });
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Bash")
            .with_tool_input(input);
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_required_fields_without_object_input() {
        let matcher = HookMatcher::new().with_required_fields(vec!["command".to_string()]);
        // No tool input at all.
        assert!(!matcher.matches(&HookContext::new("PreToolUse")));
        // Tool input that is not a JSON object has no keys.
        let context = HookContext::new("PreToolUse").with_tool_input(serde_json::json!("ls"));
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_event_types() {
        let matcher = HookMatcher::new()
            .with_event_types(vec!["PreToolUse".to_string(), "PostToolUse".to_string()]);
        let context = HookContext::new("PreToolUse");
        assert!(matcher.matches(&context));
        let context = HookContext::new("UserPromptSubmit");
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_combined() {
        let matcher = HookMatcher::new()
            .with_tool_name_regex(r"^(Write|Edit)$")
            .with_event_types(vec!["PreToolUse".to_string()])
            .with_required_fields(vec!["file_path".to_string()]);
        let input = serde_json::json!({ "file_path": "/tmp/test.txt", "content": "test" });
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Write")
            .with_tool_input(input);
        assert!(matcher.matches(&context));
        let input = serde_json::json!({ "file_path": "/tmp/test.txt" });
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Bash")
            .with_tool_input(input);
        assert!(!matcher.matches(&context));
        let input = serde_json::json!({ "file_path": "/tmp/test.txt" });
        let context = HookContext::new("PostToolUse")
            .with_tool_name("Write")
            .with_tool_input(input);
        assert!(!matcher.matches(&context));
        let input = serde_json::json!({ "content": "test" });
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Write")
            .with_tool_input(input);
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_serialization() {
        let matcher = HookMatcher::new()
            .with_tool_name("Bash")
            .with_tool_name_regex(r"^Bash$");
        let json = serde_json::to_string(&matcher).unwrap();
        let deserialized: HookMatcher = serde_json::from_str(&json).unwrap();
        assert_eq!(matcher.tool_name, deserialized.tool_name);
        assert_eq!(
            matcher.tool_name_regex.as_ref().map(|r| r.as_str()),
            deserialized.tool_name_regex.as_ref().map(|r| r.as_str())
        );
    }

    #[test]
    fn test_hook_matcher_serialization_empty_object() {
        // Every `None` field is skipped on output and defaulted on input.
        let json = serde_json::to_string(&HookMatcher::new()).unwrap();
        assert_eq!(json, "{}");
        let matcher: HookMatcher = serde_json::from_str("{}").unwrap();
        assert!(matcher.is_empty());
        assert_eq!(matcher.timeout, None);
    }

    #[test]
    fn test_hook_matcher_deserialize_invalid_regex() {
        let result = serde_json::from_str::<HookMatcher>(r#"{"tool_name_regex":"("}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_hook_context_builder() {
        let context = HookContext::new("PreToolUse")
            .with_tool_name("Bash")
            .with_tool_input(serde_json::json!({"command": "ls"}))
            .with_tool_output(serde_json::json!({"exit_code": 0}))
            .with_session_id("session-123");
        assert_eq!(context.event_type, "PreToolUse");
        assert_eq!(context.tool_name, Some("Bash".to_string()));
        assert_eq!(context.tool_input, Some(serde_json::json!({"command": "ls"})));
        assert_eq!(context.tool_output, Some(serde_json::json!({"exit_code": 0})));
        assert_eq!(context.session_id, Some("session-123".to_string()));
    }

    #[test]
    fn test_hook_matcher_no_tool_name() {
        let matcher = HookMatcher::new().with_tool_name_regex(r"^Bash$");
        let context = HookContext::new("UserPromptSubmit");
        assert!(!matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_regex_partial_match() {
        let matcher = HookMatcher::new().with_tool_name_regex(r"Write");
        let context = HookContext::new("PreToolUse").with_tool_name("MultiWrite");
        assert!(matcher.matches(&context));
        let matcher = HookMatcher::new().with_tool_name_regex(r"^Write$");
        let context = HookContext::new("PreToolUse").with_tool_name("MultiWrite");
        assert!(!matcher.matches(&context));
        let context = HookContext::new("PreToolUse").with_tool_name("Write");
        assert!(matcher.matches(&context));
    }

    #[test]
    fn test_hook_matcher_timeout() {
        let matcher = HookMatcher::new();
        assert_eq!(matcher.timeout, None);
        assert_eq!(matcher.timeout_or_default(), 60.0);
        let matcher = HookMatcher::new().with_timeout(30.0);
        assert_eq!(matcher.timeout, Some(30.0));
        assert_eq!(matcher.timeout_or_default(), 30.0);
        let matcher = HookMatcher::new().with_timeout(0.5);
        assert_eq!(matcher.timeout, Some(0.5));
        assert_eq!(matcher.timeout_or_default(), 0.5);
    }

    #[test]
    fn test_hook_matcher_timeout_duration() {
        use std::time::Duration;
        let matcher = HookMatcher::new().with_timeout_duration(Duration::from_secs(45));
        assert_eq!(matcher.timeout, Some(45.0));
        let matcher = HookMatcher::new().with_timeout_duration(Duration::from_millis(500));
        assert_eq!(matcher.timeout, Some(0.5));
    }

    #[test]
    fn test_hook_matcher_timeout_serialization() {
        let matcher = HookMatcher::new()
            .with_tool_name("Bash")
            .with_timeout(30.5);
        let json = serde_json::to_string(&matcher).unwrap();
        assert!(json.contains("\"timeout\":30.5"));
        let deserialized: HookMatcher = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.timeout, Some(30.5));
    }

    #[test]
    fn test_hook_matcher_timeout_not_affects_empty() {
        let matcher = HookMatcher::new().with_timeout(30.0);
        assert!(matcher.is_empty());
        let matcher = HookMatcher::new()
            .with_timeout(30.0)
            .with_tool_name("Bash");
        assert!(!matcher.is_empty());
    }
}