use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::extensions::permissions::Permission;
/// Lifecycle point at which an extension hook can be invoked.
///
/// Serialized in `snake_case` (e.g. `BeforeToolCall` <-> `"before_tool_call"`),
/// the same spelling produced by `HookKind::as_str`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookKind {
    /// Before a tool call; events carry the tool name and input.
    BeforeToolCall,
    /// After a tool call; events carry the tool name, input and output.
    AfterToolCall,
    /// Before a message; events carry the message text.
    BeforeMessage,
    /// When a message completes; events carry the message text and extra data.
    OnMessageComplete,
    /// When a session is compacted; events carry the summary and session ids.
    OnCompaction,
    /// When a session starts; events carry the session id.
    OnSessionStart,
    /// When a session ends; events carry the session id and optional transcript.
    OnSessionEnd,
}
impl HookKind {
    /// Returns the canonical `snake_case` name of this hook kind.
    ///
    /// This matches the serde representation given by
    /// `#[serde(rename_all = "snake_case")]` on the enum.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::BeforeToolCall => "before_tool_call",
            Self::AfterToolCall => "after_tool_call",
            Self::BeforeMessage => "before_message",
            Self::OnMessageComplete => "on_message_complete",
            Self::OnCompaction => "on_compaction",
            Self::OnSessionStart => "on_session_start",
            Self::OnSessionEnd => "on_session_end",
        }
    }

    /// Parses a canonical name as produced by [`HookKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; unknown names yield `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "before_tool_call" => Some(Self::BeforeToolCall),
            "after_tool_call" => Some(Self::AfterToolCall),
            "before_message" => Some(Self::BeforeMessage),
            "on_message_complete" => Some(Self::OnMessageComplete),
            "on_compaction" => Some(Self::OnCompaction),
            "on_session_start" => Some(Self::OnSessionStart),
            "on_session_end" => Some(Self::OnSessionEnd),
            _ => None,
        }
    }

    /// Names of the [`HookResult`] actions (the serialized `"action"` tag) that
    /// a hook of this kind may return.
    ///
    /// This is the single source of truth used by [`HookKind::allows_result`].
    pub fn allowed_action_names(&self) -> &'static [&'static str] {
        match self {
            Self::BeforeToolCall => &["continue", "block", "confirm", "modify"],
            Self::AfterToolCall => &["continue"],
            Self::BeforeMessage => &["continue", "inject"],
            Self::OnMessageComplete
            | Self::OnCompaction
            | Self::OnSessionStart
            | Self::OnSessionEnd => &["continue"],
        }
    }

    /// Returns `true` only for the two tool-call kinds.
    ///
    /// NOTE(review): presumably controls whether a hook registration may be
    /// filtered by tool name — confirm with callers.
    pub fn allows_tool_filter(&self) -> bool {
        matches!(self, Self::BeforeToolCall | Self::AfterToolCall)
    }

    /// Returns whether `result` is a permitted outcome for a hook of this kind.
    ///
    /// `Continue` is accepted for every kind. Other actions are accepted only
    /// when listed in [`HookKind::allowed_action_names`].
    pub fn allows_result(&self, result: &HookResult) -> bool {
        // Exhaustive on purpose, with no `_` arm: adding a `HookResult` variant
        // must fail to compile here until it is given an action name. The names
        // mirror the enum's `#[serde(tag = "action", rename_all = "snake_case")]`.
        let action = match result {
            HookResult::Continue => "continue",
            HookResult::Block { .. } => "block",
            HookResult::Inject { .. } => "inject",
            HookResult::Confirm { .. } => "confirm",
            HookResult::Modify { .. } => "modify",
        };
        self.allowed_action_names().contains(&action)
    }

    /// Returns the [`Permission`] this hook kind requires.
    ///
    /// The groupings are:
    /// - tool-call hooks use `ToolsIntercept`;
    /// - message and compaction hooks use `LlmContent`;
    /// - session start/end hooks use `SessionLifecycle`.
    pub fn required_permission(&self) -> Permission {
        match self {
            Self::BeforeToolCall | Self::AfterToolCall => Permission::ToolsIntercept,
            Self::BeforeMessage | Self::OnMessageComplete | Self::OnCompaction => {
                Permission::LlmContent
            }
            Self::OnSessionStart | Self::OnSessionEnd => Permission::SessionLifecycle,
        }
    }
}
/// Payload passed to a hook.
///
/// Which optional fields are populated depends on `kind`. The associated
/// constructors (`before_tool_call`, `after_tool_call`, ...) set only the fields
/// relevant to their kind and leave the rest as `None` / `Value::Null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookEvent {
    /// Lifecycle point this event was raised for.
    pub kind: HookKind,
    /// Tool name; set by the tool-call constructors.
    pub tool_name: Option<String>,
    /// Optional runtime name of the tool, set via `with_runtime_name`.
    /// NOTE(review): presumably the name the tool is registered under at
    /// runtime when it differs from `tool_name` — confirm with callers.
    #[serde(default)]
    pub tool_runtime_name: Option<String>,
    /// Tool input arguments; set by the tool-call constructors.
    pub tool_input: Option<Value>,
    /// Tool output; set by `after_tool_call`, which may truncate it.
    pub tool_output: Option<String>,
    /// Message text, or the summary for compaction events.
    pub message: Option<String>,
    /// Session id; set for session lifecycle and compaction events.
    pub session_id: Option<String>,
    /// Optional transcript, supplied to `on_session_end`.
    #[serde(default)]
    pub transcript: Option<Vec<Value>>,
    /// Extra kind-specific data; `Value::Null` when unused.
    // `default` lets payloads that omit `data` deserialize to `Value::Null`,
    // which is what the constructors produce, instead of failing with
    // "missing field `data`".
    #[serde(default)]
    pub data: Value,
}
impl HookEvent {
pub fn before_tool_call(tool_name: &str, input: Value) -> Self {
Self {
kind: HookKind::BeforeToolCall,
tool_name: Some(tool_name.to_string()),
tool_input: Some(input),
tool_output: None,
message: None,
session_id: None,
tool_runtime_name: None,
transcript: None,
data: Value::Null,
}
}
pub fn after_tool_call(tool_name: &str, input: Value, output: String) -> Self {
const MAX_HOOK_OUTPUT: usize = 32 * 1024; let truncated_output = if output.len() > MAX_HOOK_OUTPUT {
let boundary = output
.char_indices()
.map(|(idx, _)| idx)
.take_while(|idx| *idx <= MAX_HOOK_OUTPUT)
.last()
.unwrap_or(0);
format!(
"{}…[truncated, {} total bytes]",
&output[..boundary],
output.len()
)
} else {
output
};
Self {
kind: HookKind::AfterToolCall,
tool_name: Some(tool_name.to_string()),
tool_input: Some(input),
tool_output: Some(truncated_output),
message: None,
session_id: None,
tool_runtime_name: None,
transcript: None,
data: Value::Null,
}
}
pub fn before_message(message: &str) -> Self {
Self {
kind: HookKind::BeforeMessage,
tool_name: None,
tool_input: None,
tool_output: None,
message: Some(message.to_string()),
session_id: None,
tool_runtime_name: None,
transcript: None,
data: Value::Null,
}
}
pub fn on_message_complete(message: &str, data: Value) -> Self {
Self {
kind: HookKind::OnMessageComplete,
tool_name: None,
tool_input: None,
tool_output: None,
message: Some(message.to_string()),
session_id: None,
tool_runtime_name: None,
transcript: None,
data,
}
}
pub fn on_compaction(
old_session_id: &str,
new_session_id: &str,
summary: &str,
message_count: usize,
mut data: Value,
) -> Self {
if !data.is_object() {
data = Value::Object(Default::default());
}
if let Some(object) = data.as_object_mut() {
object.insert("old_session_id".to_string(), Value::String(old_session_id.to_string()));
object.insert("new_session_id".to_string(), Value::String(new_session_id.to_string()));
object.insert("message_count".to_string(), Value::Number(message_count.into()));
}
Self {
kind: HookKind::OnCompaction,
tool_name: None,
tool_input: None,
tool_output: None,
message: Some(summary.to_string()),
session_id: Some(new_session_id.to_string()),
tool_runtime_name: None,
transcript: None,
data,
}
}
pub fn on_session_start(session_id: &str) -> Self {
Self {
kind: HookKind::OnSessionStart,
tool_name: None,
tool_input: None,
tool_output: None,
message: None,
session_id: Some(session_id.to_string()),
tool_runtime_name: None,
transcript: None,
data: Value::Null,
}
}
pub fn on_session_end(session_id: &str, transcript: Option<Vec<Value>>) -> Self {
Self {
kind: HookKind::OnSessionEnd,
tool_name: None,
tool_input: None,
tool_output: None,
message: None,
session_id: Some(session_id.to_string()),
tool_runtime_name: None,
transcript,
data: Value::Null,
}
}
}
/// Outcome returned by a hook.
///
/// Serialized as an internally tagged object whose `"action"` field holds the
/// `snake_case` variant name, e.g. `{"action":"confirm","message":"..."}`.
/// Which actions a given hook kind may return is checked by
/// `HookKind::allows_result`.
// `Default` is derived (with `#[default]` on `Continue`) rather than hand-written;
// the manual impl is what clippy's `derivable_impls` lint flags.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
pub enum HookResult {
    /// `"continue"` — the default result; accepted for every hook kind.
    #[default]
    Continue,
    /// `"block"`, with a `reason`.
    Block { reason: String },
    /// `"inject"`, with `content`.
    Inject { content: String },
    /// `"confirm"`, with a `message`.
    Confirm { message: String },
    /// `"modify"`, with a replacement tool `input`.
    Modify { input: Value },
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every `HookKind` variant, for exhaustive table-style checks.
    const ALL_KINDS: [HookKind; 7] = [
        HookKind::BeforeToolCall,
        HookKind::AfterToolCall,
        HookKind::BeforeMessage,
        HookKind::OnMessageComplete,
        HookKind::OnCompaction,
        HookKind::OnSessionStart,
        HookKind::OnSessionEnd,
    ];

    #[test]
    fn hook_kind_as_str_roundtrip() {
        for kind in ALL_KINDS {
            let s = kind.as_str();
            assert_eq!(
                HookKind::from_str(s),
                Some(kind),
                "round-trip failed for {s}"
            );
        }
    }

    #[test]
    fn hook_kind_from_str_unknown_returns_none() {
        assert_eq!(HookKind::from_str(""), None);
        assert_eq!(HookKind::from_str("BeforeToolCall"), None);
        assert_eq!(HookKind::from_str("on_crash"), None);
    }

    #[test]
    fn hook_kind_serde_snake_case() {
        let serialized = serde_json::to_string(&HookKind::BeforeToolCall).unwrap();
        assert_eq!(serialized, r#""before_tool_call""#);
        let back: HookKind = serde_json::from_str(r#""on_session_end""#).unwrap();
        assert_eq!(back, HookKind::OnSessionEnd);
    }

    #[test]
    fn hook_kind_as_str_matches_serde() {
        for kind in ALL_KINDS {
            let serialized = serde_json::to_value(kind).unwrap();
            assert_eq!(serialized.as_str(), Some(kind.as_str()));
        }
    }

    #[test]
    fn hook_kind_required_permission() {
        assert_eq!(
            HookKind::BeforeToolCall.required_permission(),
            Permission::ToolsIntercept
        );
        assert_eq!(
            HookKind::AfterToolCall.required_permission(),
            Permission::ToolsIntercept
        );
        assert_eq!(
            HookKind::BeforeMessage.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnMessageComplete.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnCompaction.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnSessionStart.required_permission(),
            Permission::SessionLifecycle
        );
        assert_eq!(
            HookKind::OnSessionEnd.required_permission(),
            Permission::SessionLifecycle
        );
    }

    #[test]
    fn hook_kind_allows_tool_filter_only_for_tool_hooks() {
        for kind in ALL_KINDS {
            let expected = matches!(kind, HookKind::BeforeToolCall | HookKind::AfterToolCall);
            assert_eq!(kind.allows_tool_filter(), expected, "{kind:?}");
        }
    }

    #[test]
    fn hook_kind_allows_result_matches_allowed_action_names() {
        // The two tables must never drift: `allows_result` must accept exactly
        // the actions listed by `allowed_action_names`.
        let results = [
            HookResult::Continue,
            HookResult::Block { reason: "r".to_string() },
            HookResult::Inject { content: "c".to_string() },
            HookResult::Confirm { message: "m".to_string() },
            HookResult::Modify { input: json!({}) },
        ];
        for kind in ALL_KINDS {
            for result in &results {
                let encoded = serde_json::to_value(result).unwrap();
                let action = encoded["action"].as_str().unwrap();
                assert_eq!(
                    kind.allows_result(result),
                    kind.allowed_action_names().contains(&action),
                    "{kind:?} disagrees on action {action}"
                );
            }
        }
    }

    #[test]
    fn hook_kind_continue_always_allowed() {
        for kind in ALL_KINDS {
            assert!(kind.allows_result(&HookResult::Continue), "{kind:?}");
        }
    }

    #[test]
    fn hook_event_before_tool_call() {
        let input = json!({"path": "/tmp/foo"});
        let ev = HookEvent::before_tool_call("read_file", input.clone());
        assert_eq!(ev.kind, HookKind::BeforeToolCall);
        assert_eq!(ev.tool_name.as_deref(), Some("read_file"));
        assert_eq!(ev.tool_input.as_ref(), Some(&input));
        assert!(ev.tool_output.is_none());
        assert!(ev.message.is_none());
        assert!(ev.session_id.is_none());
        assert_eq!(ev.data, Value::Null);
    }

    #[test]
    fn hook_event_after_tool_call() {
        let input = json!({"query": "select 1"});
        let ev = HookEvent::after_tool_call("sql_query", input.clone(), "1 row".to_string());
        assert_eq!(ev.kind, HookKind::AfterToolCall);
        assert_eq!(ev.tool_name.as_deref(), Some("sql_query"));
        assert_eq!(ev.tool_input.as_ref(), Some(&input));
        assert_eq!(ev.tool_output.as_deref(), Some("1 row"));
        assert!(ev.message.is_none());
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_after_tool_call_keeps_output_at_limit() {
        let output = "a".repeat(32 * 1024);
        let ev = HookEvent::after_tool_call("t", json!({}), output.clone());
        assert_eq!(ev.tool_output.as_deref(), Some(output.as_str()));
    }

    #[test]
    fn hook_event_after_tool_call_truncates_long_output() {
        let output = "a".repeat(40 * 1024);
        let ev = HookEvent::after_tool_call("t", json!({}), output);
        let expected = format!(
            "{}…[truncated, {} total bytes]",
            "a".repeat(32 * 1024),
            40 * 1024
        );
        assert_eq!(ev.tool_output.as_deref(), Some(expected.as_str()));
    }

    #[test]
    fn hook_event_after_tool_call_truncates_on_char_boundary() {
        // The leading ASCII byte puts every 2-byte 'é' on an odd offset, so byte
        // 32768 falls mid-character and the cut must back off by one byte.
        let output = format!("a{}", "é".repeat(20_000));
        let total = output.len();
        let ev = HookEvent::after_tool_call("t", json!({}), output);
        let got = ev.tool_output.unwrap();
        let (kept, marker) = got.split_once('…').unwrap();
        assert_eq!(kept.len(), 32 * 1024 - 1);
        assert_eq!(marker, format!("[truncated, {total} total bytes]"));
    }

    #[test]
    fn hook_event_before_message() {
        let ev = HookEvent::before_message("Hello, LLM");
        assert_eq!(ev.kind, HookKind::BeforeMessage);
        assert!(ev.tool_name.is_none());
        assert!(ev.tool_input.is_none());
        assert!(ev.tool_output.is_none());
        assert_eq!(ev.message.as_deref(), Some("Hello, LLM"));
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_on_message_complete() {
        let ev = HookEvent::on_message_complete("Done", json!({"content_block_count": 1}));
        assert_eq!(ev.kind, HookKind::OnMessageComplete);
        assert!(ev.tool_name.is_none());
        assert!(ev.tool_input.is_none());
        assert!(ev.tool_output.is_none());
        assert_eq!(ev.message.as_deref(), Some("Done"));
        assert_eq!(ev.data["content_block_count"], 1);
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_on_compaction() {
        let ev = HookEvent::on_compaction(
            "old-session",
            "new-session",
            "Summary",
            7,
            json!({"source": "manual"}),
        );
        assert_eq!(ev.kind, HookKind::OnCompaction);
        assert_eq!(ev.message.as_deref(), Some("Summary"));
        assert_eq!(ev.session_id.as_deref(), Some("new-session"));
        assert_eq!(ev.data["old_session_id"], "old-session");
        assert_eq!(ev.data["new_session_id"], "new-session");
        assert_eq!(ev.data["message_count"], 7);
        assert_eq!(ev.data["source"], "manual");
        assert!(ev.transcript.is_none());
    }

    #[test]
    fn hook_event_on_compaction_replaces_non_object_data() {
        let ev = HookEvent::on_compaction("a", "b", "s", 0, Value::Null);
        assert!(ev.data.is_object());
        assert_eq!(ev.data["old_session_id"], "a");
        assert_eq!(ev.data["new_session_id"], "b");
        assert_eq!(ev.data["message_count"], 0);
        assert_eq!(ev.data.as_object().unwrap().len(), 3);
    }

    #[test]
    fn hook_event_on_session_start() {
        let ev = HookEvent::on_session_start("sess-abc-123");
        assert_eq!(ev.kind, HookKind::OnSessionStart);
        assert_eq!(ev.session_id.as_deref(), Some("sess-abc-123"));
        assert!(ev.tool_name.is_none());
        assert!(ev.message.is_none());
    }

    #[test]
    fn hook_event_on_session_end() {
        let ev = HookEvent::on_session_end("sess-abc-123", None);
        assert_eq!(ev.kind, HookKind::OnSessionEnd);
        assert_eq!(ev.session_id.as_deref(), Some("sess-abc-123"));
        assert!(ev.tool_name.is_none());
        assert!(ev.message.is_none());
    }

    #[test]
    fn hook_event_with_runtime_name() {
        let ev = HookEvent::before_tool_call("bash", json!({})).with_runtime_name("mcp_bash");
        assert_eq!(ev.tool_runtime_name.as_deref(), Some("mcp_bash"));
        assert_eq!(ev.tool_name.as_deref(), Some("bash"));
        assert_eq!(ev.kind, HookKind::BeforeToolCall);
    }

    #[test]
    fn hook_event_serde_roundtrip() {
        let ev = HookEvent::before_tool_call("bash", json!({"cmd": "ls"}));
        let json = serde_json::to_string(&ev).unwrap();
        let back: HookEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.kind, ev.kind);
        assert_eq!(back.tool_name, ev.tool_name);
        assert_eq!(back.tool_input, ev.tool_input);
    }

    #[test]
    fn hook_event_deserializes_without_defaulted_fields() {
        let ev: HookEvent = serde_json::from_str(
            r#"{"kind":"on_session_start","tool_name":null,"tool_input":null,"tool_output":null,"message":null,"session_id":"s1","data":null}"#,
        )
        .unwrap();
        assert_eq!(ev.kind, HookKind::OnSessionStart);
        assert_eq!(ev.session_id.as_deref(), Some("s1"));
        assert!(ev.tool_runtime_name.is_none());
        assert!(ev.transcript.is_none());
    }

    #[test]
    fn hook_result_default_is_continue() {
        assert!(matches!(HookResult::default(), HookResult::Continue));
    }

    #[test]
    fn hook_result_block_serde() {
        let r = HookResult::Block {
            reason: "denied by policy".to_string(),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains(r#""action":"block""#));
        assert!(json.contains("denied by policy"));
        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, HookResult::Block { reason } if reason == "denied by policy"));
    }

    #[test]
    fn hook_result_confirm_serde() {
        let r = HookResult::Confirm {
            message: "Run this command?".to_string(),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, r#"{"action":"confirm","message":"Run this command?"}"#);
        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, HookResult::Confirm { message } if message == "Run this command?"));
    }

    #[test]
    fn hook_result_modify_serde() {
        let r = HookResult::Modify {
            input: json!({"command": "echo safe"}),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, r#"{"action":"modify","input":{"command":"echo safe"}}"#);
        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(back, HookResult::Modify { input } if input == json!({"command": "echo safe"}))
        );
    }

    #[test]
    fn hook_result_continue_serde() {
        let json = serde_json::to_string(&HookResult::Continue).unwrap();
        assert_eq!(json, r#"{"action":"continue"}"#);
    }
}
impl HookEvent {
    /// Builder-style setter: consumes the event and returns it with
    /// `tool_runtime_name` set to `name`. Every other field is carried over
    /// unchanged.
    pub fn with_runtime_name(self, name: &str) -> Self {
        Self {
            tool_runtime_name: Some(String::from(name)),
            ..self
        }
    }
}