use super::permission::JcliConfig;
use super::storage::ChatMessage;
use crate::util::log::{write_error_log, write_info_log};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Write;
use std::process::Command;
use std::sync::{Arc, Mutex};
/// Lifecycle points at which hooks may run.
///
/// On the wire (serde, `as_str`, `parse`) each variant is its snake_case name,
/// e.g. `PreSendMessage` <-> `"pre_send_message"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    /// Wire name `"pre_send_message"`.
    PreSendMessage,
    /// Wire name `"post_send_message"`.
    PostSendMessage,
    /// Wire name `"pre_llm_request"`.
    PreLlmRequest,
    /// Wire name `"post_llm_response"`.
    PostLlmResponse,
    /// Wire name `"pre_tool_execution"`.
    PreToolExecution,
    /// Wire name `"post_tool_execution"`.
    PostToolExecution,
    /// Wire name `"session_start"`.
    SessionStart,
    /// Wire name `"session_end"`.
    SessionEnd,
}
/// Parses the exact (case-sensitive) snake_case wire name of an event.
///
/// The lookup goes through `HookEvent::all()` and `HookEvent::as_str()`, so the
/// name table is written down only once.
///
/// # Errors
/// `Err(())` for any string that is not one of the event wire names.
impl std::str::FromStr for HookEvent {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        HookEvent::all()
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
            .ok_or(())
    }
}
impl HookEvent {
    /// Snake_case wire name of this event (identical to its serde representation).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::PreSendMessage => "pre_send_message",
            Self::PostSendMessage => "post_send_message",
            Self::PreLlmRequest => "pre_llm_request",
            Self::PostLlmResponse => "post_llm_response",
            Self::PreToolExecution => "pre_tool_execution",
            Self::PostToolExecution => "post_tool_execution",
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
        }
    }

    /// Every event, in declaration order.
    pub fn all() -> &'static [HookEvent] {
        static EVERY_EVENT: [HookEvent; 8] = [
            HookEvent::PreSendMessage,
            HookEvent::PostSendMessage,
            HookEvent::PreLlmRequest,
            HookEvent::PostLlmResponse,
            HookEvent::PreToolExecution,
            HookEvent::PostToolExecution,
            HookEvent::SessionStart,
            HookEvent::SessionEnd,
        ];
        &EVERY_EVENT
    }

    /// Parses a wire name; `None` when the name is not a known event.
    pub fn parse(s: &str) -> Option<HookEvent> {
        <HookEvent as std::str::FromStr>::from_str(s).ok()
    }
}
/// One hook entry as written in a `hooks.yaml` file (or registered for the session).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HookDef {
    /// Command line; executed through `sh -c` when the hook fires.
    pub command: String,
    /// Time limit in seconds; `default_timeout()` (10) when omitted from the YAML.
    #[serde(default = "default_timeout")]
    pub timeout: u64,
}
/// Serde default for `HookDef::timeout`, in seconds.
fn default_timeout() -> u64 {
    const DEFAULT_HOOK_TIMEOUT_SECS: u64 = 10;
    DEFAULT_HOOK_TIMEOUT_SECS
}
/// A runnable hook: an external shell command or an in-process Rust callback.
///
/// `Debug` is not derived because builtin handlers are closures; it is implemented
/// by hand instead.
#[derive(Clone)]
pub enum HookKind {
    /// External command (from a `hooks.yaml` file or a session registration).
    Shell(ShellHook),
    /// Rust callback registered through `HookManager::register_builtin`.
    Builtin(BuiltinHook),
}
/// A hook implemented as an external command.
#[derive(Clone, Debug)]
pub struct ShellHook {
    /// Command line handed to `sh -c`.
    pub command: String,
    /// Seconds to wait before the process is killed and the hook reported as timed out.
    pub timeout: u64,
}
impl From<HookDef> for ShellHook {
fn from(def: HookDef) -> Self {
ShellHook {
command: def.command,
timeout: def.timeout,
}
}
}
impl From<HookDef> for HookKind {
fn from(def: HookDef) -> Self {
HookKind::Shell(ShellHook::from(def))
}
}
/// In-process hook callback.
///
/// Receives the current [`HookContext`]. Returning `None` means "nothing to change" and
/// is treated like an empty [`HookResult`]. The `Arc` lets hooks (and therefore a whole
/// `HookManager`) be cloned without copying the closure, and the `Send + Sync` bounds let
/// it run on another thread.
pub type BuiltinHookFn = Arc<dyn Fn(&HookContext) -> Option<HookResult> + Send + Sync>;

/// A named hook implemented in Rust, registered via `HookManager::register_builtin`.
///
/// `Clone` is derived: it clones the name and shares the handler through `Arc::clone`,
/// so no hand-written impl is needed.
#[derive(Clone)]
pub struct BuiltinHook {
    /// Identifier shown as `[builtin: <name>]` in hook listings and log messages.
    pub name: String,
    /// The callback invoked when the hook fires.
    pub handler: BuiltinHookFn,
}
impl std::fmt::Debug for HookKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HookKind::Shell(shell) => f
.debug_struct("HookKind::Shell")
.field("command", &shell.command)
.field("timeout", &shell.timeout)
.finish(),
HookKind::Builtin(builtin) => f
.debug_struct("HookKind::Builtin")
.field("name", &builtin.name)
.finish(),
}
}
}
/// Snapshot handed to every hook; shell hooks receive it as JSON on stdin.
///
/// `None` fields are left out of the JSON entirely. Which fields are filled for a given
/// event is up to the caller. NOTE(review): the call sites are not visible in this file.
#[derive(Debug, Serialize)]
pub struct HookContext {
    /// Event being fired (serialized as its snake_case name).
    pub event: HookEvent,
    /// Conversation messages, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<ChatMessage>>,
    /// System prompt, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Model name, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// User input text, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_input: Option<String>,
    /// Assistant output text, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub assistant_output: Option<String>,
    /// Tool name, when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Tool arguments (as a string), when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_arguments: Option<String>,
    /// Tool result (as a string), when supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<String>,
    /// Working directory string; always serialized.
    pub cwd: String,
}
impl Default for HookContext {
    /// An otherwise-empty `SessionStart` context whose `cwd` is the process working
    /// directory, or `"."` when that cannot be determined.
    fn default() -> Self {
        let cwd = match std::env::current_dir() {
            Ok(dir) => dir.display().to_string(),
            Err(_) => String::from("."),
        };
        Self {
            event: HookEvent::SessionStart,
            messages: None,
            system_prompt: None,
            model: None,
            user_input: None,
            assistant_output: None,
            tool_name: None,
            tool_arguments: None,
            tool_result: None,
            cwd,
        }
    }
}
/// What a hook asks to change. Parsed from a shell hook's stdout JSON, or returned
/// directly by a builtin hook.
///
/// Every field is optional in the JSON. Missing keys take the value from
/// `HookResult::default()`, i.e. `None` / `false`.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub struct HookResult {
    /// Replacement for the context's `messages`.
    pub messages: Option<Vec<ChatMessage>>,
    /// Replacement for the context's `system_prompt`.
    pub system_prompt: Option<String>,
    /// Replacement for the context's `user_input`.
    pub user_input: Option<String>,
    /// Replacement for the context's `assistant_output`.
    pub assistant_output: Option<String>,
    /// Replacement for the context's `tool_arguments`.
    pub tool_arguments: Option<String>,
    /// Replacement for the context's `tool_result`.
    pub tool_result: Option<String>,
    /// Extra messages; `HookManager::execute` concatenates these across hooks.
    pub inject_messages: Option<Vec<ChatMessage>>,
    /// When `true`, the hook chain stops and the caller receives an abort result.
    pub abort: bool,
    /// Parsed but never read in this file. NOTE(review): because of the leading
    /// underscore, the JSON key a hook must emit is literally `"_switch_model"`.
    pub _switch_model: Option<String>,
}
/// Registry of hooks keyed by event, kept in four tables according to their origin.
///
/// Execution and listing visit the tables in the order builtin -> user -> project ->
/// session. Within a table, hooks keep their registration order.
#[derive(Clone, Debug, Default)]
pub struct HookManager {
    /// Registered in code through `register_builtin`.
    builtin_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Loaded by `load` from the user-level hooks file.
    user_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Loaded by `load` from `hooks.yaml` in the project config directory.
    project_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Added at runtime through `register_session_hook`.
    session_hooks: HashMap<HookEvent, Vec<HookKind>>,
}
/// `HookEntry::source` for hooks registered in code.
const HOOK_SOURCE_BUILTIN: &'static str = "builtin";
/// `HookEntry::source` for hooks from the user-level hooks file.
const HOOK_SOURCE_USER: &'static str = "user";
/// `HookEntry::source` for hooks from the project-level `hooks.yaml`.
const HOOK_SOURCE_PROJECT: &'static str = "project";
/// `HookEntry::source` for hooks added during the current session.
const HOOK_SOURCE_SESSION: &'static str = "session";
/// One row of `HookManager::list_hooks`: a registered hook plus where it came from.
pub struct HookEntry {
    /// Event the hook is registered for.
    pub event: HookEvent,
    /// Origin table: `"builtin"`, `"user"`, `"project"` or `"session"`.
    pub source: &'static str,
    /// The shell command, or `[builtin: <name>]` for builtin hooks.
    pub label: String,
    /// Timeout in seconds for shell hooks; `None` for builtin hooks.
    pub timeout: Option<u64>,
}
impl HookManager {
    /// Builds a manager from the on-disk hook configuration.
    ///
    /// Reads, in order:
    /// 1. the user-level file at `super::storage::hooks_config_path()` into the user table;
    /// 2. `hooks.yaml` inside `JcliConfig::find_config_dir()` (when a config dir is found)
    ///    into the project table.
    ///
    /// Each file is a YAML map from event name (see [`HookEvent::as_str`]) to a list of
    /// [`HookDef`]s. Missing files are skipped. Read errors, parse errors and unknown event
    /// names are logged and otherwise ignored, so this never fails. The builtin and session
    /// tables start empty.
    pub fn load() -> Self {
        let mut manager = HookManager::default();

        let user_hooks_path = super::storage::hooks_config_path();
        if user_hooks_path.is_file() {
            match std::fs::read_to_string(&user_hooks_path) {
                Ok(content) => match Self::merge_hooks_yaml(&content, &mut manager.user_hooks) {
                    Ok(unknown_events) => {
                        for event_name in unknown_events {
                            write_error_log(
                                "HookManager::load",
                                &format!("未知 hook 事件: {}", event_name),
                            );
                        }
                        write_info_log(
                            "HookManager::load",
                            &format!("已加载用户级 hooks: {}", user_hooks_path.display()),
                        );
                    }
                    Err(e) => {
                        write_error_log(
                            "HookManager::load",
                            &format!("解析用户级 hooks.yaml 失败: {}", e),
                        );
                    }
                },
                Err(e) => {
                    write_error_log("HookManager::load", &format!("读取 hooks.yaml 失败: {}", e));
                }
            }
        }

        if let Some(config_dir) = JcliConfig::find_config_dir() {
            let hooks_path = config_dir.join("hooks.yaml");
            if hooks_path.is_file() {
                match std::fs::read_to_string(&hooks_path) {
                    Ok(content) => {
                        match Self::merge_hooks_yaml(&content, &mut manager.project_hooks) {
                            Ok(unknown_events) => {
                                for event_name in unknown_events {
                                    write_error_log(
                                        "HookManager::load",
                                        &format!(
                                            "项目级 .jcli/hooks.yaml 中未知 hook 事件: {}",
                                            event_name
                                        ),
                                    );
                                }
                                write_info_log(
                                    "HookManager::load",
                                    &format!("已加载项目级 hooks: {}", hooks_path.display()),
                                );
                            }
                            Err(e) => {
                                write_error_log(
                                    "HookManager::load",
                                    &format!("解析项目级 hooks.yaml 失败: {}", e),
                                );
                            }
                        }
                    }
                    Err(e) => {
                        write_error_log(
                            "HookManager::load",
                            &format!("读取项目级 hooks.yaml 失败: {}", e),
                        );
                    }
                }
            }
        }

        manager
    }

    /// Parses a hooks.yaml document and appends each recognised event's hooks to `target`.
    ///
    /// Returns the event names that matched no [`HookEvent`], so the caller can log them
    /// with its own wording. If the document fails to parse, nothing is merged.
    fn merge_hooks_yaml(
        content: &str,
        target: &mut HashMap<HookEvent, Vec<HookKind>>,
    ) -> Result<Vec<String>, serde_yaml::Error> {
        let hooks_map = serde_yaml::from_str::<HashMap<String, Vec<HookDef>>>(content)?;
        let mut unknown_events = Vec::new();
        for (event_name, defs) in hooks_map {
            match HookEvent::parse(&event_name) {
                Some(event) => target
                    .entry(event)
                    .or_default()
                    .extend(defs.into_iter().map(HookKind::from)),
                None => unknown_events.push(event_name),
            }
        }
        Ok(unknown_events)
    }

    /// Registers an in-process hook for `event`. It runs before user, project and
    /// session hooks for the same event.
    pub fn register_builtin(
        &mut self,
        event: HookEvent,
        name: impl Into<String>,
        handler: impl Fn(&HookContext) -> Option<HookResult> + Send + Sync + 'static,
    ) {
        self.builtin_hooks
            .entry(event)
            .or_default()
            .push(HookKind::Builtin(BuiltinHook {
                name: name.into(),
                handler: Arc::new(handler),
            }));
    }

    /// Adds a shell hook for `event` that is kept only in this manager (nothing is written
    /// to disk here). It runs after builtin, user and project hooks.
    pub fn register_session_hook(&mut self, event: HookEvent, def: HookDef) {
        self.session_hooks
            .entry(event)
            .or_default()
            .push(HookKind::Shell(ShellHook::from(def)));
    }

    /// Removes the session hook at `index` among this event's session hooks (registration
    /// order). Returns `false` when no such hook exists. Hooks from other origins are not
    /// affected.
    pub fn remove_session_hook(&mut self, event: HookEvent, index: usize) -> bool {
        if let Some(hooks) = self.session_hooks.get_mut(&event)
            && index < hooks.len()
        {
            hooks.remove(index);
            return true;
        }
        false
    }

    /// Lists every registered hook, grouped by event in [`HookEvent::all`] order. Within an
    /// event, hooks appear in execution order (builtin, user, project, session).
    pub fn list_hooks(&self) -> Vec<HookEntry> {
        let mut entries = Vec::new();
        for &event in HookEvent::all() {
            for (source, table) in self.scopes() {
                for hook in table.get(&event).into_iter().flatten() {
                    entries.push(HookEntry {
                        event,
                        source,
                        label: hook_label(hook),
                        timeout: hook_timeout(hook),
                    });
                }
            }
        }
        entries
    }

    /// Whether at least one hook, of any origin, is registered for `event`.
    pub fn has_hooks_for(&self, event: HookEvent) -> bool {
        self.scopes()
            .iter()
            .any(|(_, table)| table.get(&event).is_some_and(|hooks| !hooks.is_empty()))
    }

    /// Runs the hooks for `event` on a background thread and discards the result.
    ///
    /// The manager is cloned while the lock is held, and the lock is released before any
    /// hook runs. Shell hooks can take up to their timeout each; holding the mutex that long
    /// would block every other user of the shared manager. If the mutex is poisoned, nothing
    /// runs.
    pub fn execute_fire_and_forget(
        manager: Arc<Mutex<HookManager>>,
        event: HookEvent,
        context: HookContext,
    ) {
        std::thread::spawn(move || {
            let snapshot = match manager.lock() {
                Ok(guard) => guard.clone(),
                Err(_) => return,
            };
            let _ = snapshot.execute(event, context);
        });
    }

    /// Runs every hook registered for `event`, in the order builtin -> user -> project ->
    /// session.
    ///
    /// Each hook sees `context` as modified by the hooks before it.
    ///
    /// Returns:
    /// - `None` when no hook is registered, or when no hook changed anything.
    /// - `Some` with `abort == true` (and nothing else set) as soon as a hook aborts or
    ///   fails; the remaining hooks are skipped.
    /// - Otherwise `Some` with the aggregated overrides. For each field the last hook to
    ///   set it wins; `inject_messages` are concatenated.
    pub fn execute(&self, event: HookEvent, mut context: HookContext) -> Option<HookResult> {
        let all_hooks: Vec<&HookKind> = self
            .scopes()
            .into_iter()
            .filter_map(|(_, table)| table.get(&event))
            .flatten()
            .collect();
        if all_hooks.is_empty() {
            return None;
        }
        write_info_log(
            "HookManager::execute",
            &format!(
                "执行 {} 个 hook (事件: {})",
                all_hooks.len(),
                event.as_str()
            ),
        );

        let mut had_modification = false;
        let mut final_result = HookResult::default();
        for hook in all_hooks {
            let result = match execute_hook(hook, &context) {
                Ok(result) => result,
                Err(e) => {
                    // A failing hook is treated like an explicit abort.
                    write_error_log(
                        "HookManager::execute",
                        &format!("Hook 执行失败 ({}): {}", hook_label(hook), e),
                    );
                    return Some(HookResult {
                        abort: true,
                        ..Default::default()
                    });
                }
            };
            if result.abort {
                write_info_log(
                    "HookManager::execute",
                    &format!("Hook abort ({})", hook_label(hook)),
                );
                return Some(HookResult {
                    abort: true,
                    ..Default::default()
                });
            }

            // Each override goes into `context` (so later hooks see it) and into
            // `final_result` (so the caller receives the last value set).
            had_modification |= Self::apply_override(
                result.messages,
                &mut context.messages,
                &mut final_result.messages,
            );
            had_modification |= Self::apply_override(
                result.system_prompt,
                &mut context.system_prompt,
                &mut final_result.system_prompt,
            );
            had_modification |= Self::apply_override(
                result.user_input,
                &mut context.user_input,
                &mut final_result.user_input,
            );
            had_modification |= Self::apply_override(
                result.assistant_output,
                &mut context.assistant_output,
                &mut final_result.assistant_output,
            );
            had_modification |= Self::apply_override(
                result.tool_arguments,
                &mut context.tool_arguments,
                &mut final_result.tool_arguments,
            );
            had_modification |= Self::apply_override(
                result.tool_result,
                &mut context.tool_result,
                &mut final_result.tool_result,
            );
            if let Some(inject) = result.inject_messages {
                final_result
                    .inject_messages
                    .get_or_insert_with(Vec::new)
                    .extend(inject);
                had_modification = true;
            }
        }

        had_modification.then_some(final_result)
    }

    /// Stores a hook-supplied replacement in both the running context and the aggregated
    /// result. Returns `true` when `value` was `Some`.
    fn apply_override<T: Clone>(
        value: Option<T>,
        context_slot: &mut Option<T>,
        result_slot: &mut Option<T>,
    ) -> bool {
        match value {
            Some(v) => {
                *context_slot = Some(v.clone());
                *result_slot = Some(v);
                true
            }
            None => false,
        }
    }

    /// All hook tables paired with their `HookEntry::source` label, in execution order.
    fn scopes(&self) -> [(&'static str, &HashMap<HookEvent, Vec<HookKind>>); 4] {
        [
            (HOOK_SOURCE_BUILTIN, &self.builtin_hooks),
            (HOOK_SOURCE_USER, &self.user_hooks),
            (HOOK_SOURCE_PROJECT, &self.project_hooks),
            (HOOK_SOURCE_SESSION, &self.session_hooks),
        ]
    }
}
/// Runs one hook of either kind against `context`.
///
/// A builtin handler's `None` ("nothing to change") becomes an empty [`HookResult`].
/// Builtin hooks never produce `Err`; shell hook errors are passed through.
fn execute_hook(kind: &HookKind, context: &HookContext) -> Result<HookResult, String> {
    match kind {
        HookKind::Builtin(builtin) => {
            let handler = &builtin.handler;
            Ok(handler(context).unwrap_or_default())
        }
        HookKind::Shell(shell) => execute_shell_hook(shell, context),
    }
}
/// Runs a shell hook through `sh -c` and parses its JSON reply.
///
/// Protocol, as implemented here:
/// - stdin receives `context` serialized as JSON, followed by EOF;
/// - `JCLI_HOOK_EVENT` (event wire name) and `JCLI_CWD` (process cwd) are set;
/// - the process runs in the current working directory of this process;
/// - empty stdout, or exactly `{}`, means "no changes". Anything else must be a JSON
///   object matching [`HookResult`].
///
/// # Errors
/// Returns a message when:
/// - the context cannot be serialized or the cwd cannot be read;
/// - the process cannot be spawned or waited on;
/// - it exits unsuccessfully (its stderr is appended when non-empty);
/// - it runs longer than `hook.timeout` seconds (it is then killed and reaped);
/// - or its stdout is not valid `HookResult` JSON.
fn execute_shell_hook(hook: &ShellHook, context: &HookContext) -> Result<HookResult, String> {
    let context_json =
        serde_json::to_string(context).map_err(|e| format!("序列化 context 失败: {}", e))?;
    let cwd = std::env::current_dir().map_err(|e| format!("获取 cwd 失败: {}", e))?;
    let mut child = Command::new("sh")
        .arg("-c")
        .arg(&hook.command)
        .current_dir(&cwd)
        .env("JCLI_HOOK_EVENT", context.event.as_str())
        .env("JCLI_CWD", cwd.display().to_string())
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| format!("启动 hook 进程失败: {}", e))?;

    // Feed stdin and drain stdout/stderr on helper threads. Doing it inline (write
    // everything, then wait for exit, then read) deadlocks as soon as either side fills
    // a pipe buffer, e.g. a large `messages` payload or a chatty hook. A blocking stdin
    // write would also escape the timeout below entirely.
    if let Some(mut stdin) = child.stdin.take() {
        std::thread::spawn(move || {
            // A hook may ignore stdin and exit early; a broken pipe here is expected.
            let _ = stdin.write_all(context_json.as_bytes());
            // Dropping `stdin` closes the pipe so the hook sees EOF.
        });
    }
    let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
    let stderr_reader = child.stderr.take().map(spawn_pipe_reader);

    let timeout = std::time::Duration::from_secs(hook.timeout);
    let start = std::time::Instant::now();
    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) if start.elapsed() > timeout => {
                let _ = child.kill();
                // Reap the killed process so it does not linger as a zombie.
                let _ = child.wait();
                return Err(format!("Hook 超时 ({}s): {}", hook.timeout, hook.command));
            }
            Ok(None) => std::thread::sleep(std::time::Duration::from_millis(50)),
            Err(e) => {
                let _ = child.kill();
                let _ = child.wait();
                return Err(format!("等待 hook 进程失败: {}", e));
            }
        }
    };

    if !status.success() {
        let stderr = String::from_utf8_lossy(&join_pipe_reader(stderr_reader))
            .trim()
            .to_string();
        return Err(if stderr.is_empty() {
            format!("Hook 退出码: {:?}", status.code())
        } else {
            format!("Hook 退出码: {:?} (stderr: {})", status.code(), stderr)
        });
    }

    let stdout_bytes = join_pipe_reader(stdout_reader);
    let stdout = String::from_utf8_lossy(&stdout_bytes);
    let stdout = stdout.trim();
    if stdout.is_empty() || stdout == "{}" {
        return Ok(HookResult::default());
    }
    let result: HookResult = serde_json::from_str(stdout)
        .map_err(|e| format!("解析 hook 输出 JSON 失败: {} (输出: {})", e, stdout))?;
    write_info_log(
        "execute_shell_hook",
        &format!("Hook 完成 (cmd: {}), abort={}", hook.command, result.abort),
    );
    Ok(result)
}

/// Drains a child pipe to EOF on a background thread and returns everything read.
/// A read error ends the stream early but keeps the bytes collected so far.
fn spawn_pipe_reader<R>(mut pipe: R) -> std::thread::JoinHandle<Vec<u8>>
where
    R: std::io::Read + Send + 'static,
{
    std::thread::spawn(move || {
        let mut buf = Vec::new();
        let _ = std::io::Read::read_to_end(&mut pipe, &mut buf);
        buf
    })
}

/// Waits for a `spawn_pipe_reader` thread. A missing pipe or a panicked reader yields
/// no bytes.
fn join_pipe_reader(reader: Option<std::thread::JoinHandle<Vec<u8>>>) -> Vec<u8> {
    reader
        .and_then(|handle| handle.join().ok())
        .unwrap_or_default()
}
fn hook_label(kind: &HookKind) -> String {
match kind {
HookKind::Shell(shell) => shell.command.clone(),
HookKind::Builtin(builtin) => format!("[builtin: {}]", builtin.name),
}
}
/// Timeout in seconds for shell hooks; builtin hooks have none.
fn hook_timeout(kind: &HookKind) -> Option<u64> {
    if let HookKind::Shell(shell) = kind {
        Some(shell.timeout)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for event parsing, (de)serialization, single-hook execution and the
    //! `HookManager` chain semantics. Shell-based tests assume a POSIX `sh`.
    use super::*;

    #[test]
    fn test_hook_event_roundtrip() {
        for event in HookEvent::all() {
            let s = event.as_str();
            let parsed = HookEvent::parse(s).unwrap();
            assert_eq!(*event, parsed);
        }
    }

    #[test]
    fn test_hook_event_from_str_invalid() {
        assert!(HookEvent::parse("unknown_event").is_none());
    }

    #[test]
    fn test_hook_def_default_timeout() {
        let yaml = r#"command: "echo hello""#;
        let def: HookDef = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(def.timeout, 10);
    }

    #[test]
    fn test_hook_def_to_hook_kind() {
        let def = HookDef {
            command: "echo test".to_string(),
            timeout: 5,
        };
        let kind = HookKind::from(def);
        match kind {
            HookKind::Shell(shell) => {
                assert_eq!(shell.command, "echo test");
                assert_eq!(shell.timeout, 5);
            }
            HookKind::Builtin(_) => panic!("应该转换为 Shell 变体"),
        }
    }

    #[test]
    fn test_hook_result_empty_json() {
        let result: HookResult = serde_json::from_str("{}").unwrap();
        assert!(!result.abort);
        assert!(result.messages.is_none());
        assert!(result.user_input.is_none());
    }

    #[test]
    fn test_hook_result_with_abort() {
        let json = r#"{"abort": true}"#;
        let result: HookResult = serde_json::from_str(json).unwrap();
        assert!(result.abort);
    }

    #[test]
    fn test_hook_result_with_user_input() {
        let json = r#"{"user_input": "[modified] hello"}"#;
        let result: HookResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.user_input.as_deref(), Some("[modified] hello"));
    }

    #[test]
    fn test_hook_context_serialization() {
        let ctx = HookContext {
            event: HookEvent::PreSendMessage,
            user_input: Some("hello".to_string()),
            ..Default::default()
        };
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("pre_send_message"));
        assert!(json.contains("hello"));
        assert!(json.contains("user_input"));
        // `None` fields must be omitted entirely.
        assert!(!json.contains("messages"));
        assert!(!json.contains("tool_name"));
    }

    #[test]
    fn test_execute_shell_hook_echo() {
        let hook = ShellHook {
            command: r#"echo '{"user_input": "hooked"}'"#.to_string(),
            timeout: 5,
        };
        let ctx = HookContext {
            event: HookEvent::PreSendMessage,
            user_input: Some("original".to_string()),
            ..Default::default()
        };
        let result = execute_shell_hook(&hook, &ctx).unwrap();
        assert_eq!(result.user_input.as_deref(), Some("hooked"));
        assert!(!result.abort);
    }

    #[test]
    fn test_execute_shell_hook_empty_output() {
        let hook = ShellHook {
            command: "echo ''".to_string(),
            timeout: 5,
        };
        let ctx = HookContext::default();
        let result = execute_shell_hook(&hook, &ctx).unwrap();
        assert!(!result.abort);
        assert!(result.user_input.is_none());
    }

    #[test]
    fn test_execute_shell_hook_nonzero_exit() {
        let hook = ShellHook {
            command: "exit 1".to_string(),
            timeout: 5,
        };
        let ctx = HookContext::default();
        let result = execute_shell_hook(&hook, &ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_shell_hook_reads_stdin() {
        // The hook only answers "got_input" when the JSON context really arrived on stdin
        // (serde_json emits compact JSON, hence no space after the colon).
        let hook = ShellHook {
            command: r#"input=$(cat); case "$input" in *'"user_input":"test"'*) echo '{"user_input": "got_input"}' ;; *) echo '{"user_input": "missing_input"}' ;; esac"#.to_string(),
            timeout: 5,
        };
        let ctx = HookContext {
            event: HookEvent::PreSendMessage,
            user_input: Some("test".to_string()),
            ..Default::default()
        };
        let result = execute_shell_hook(&hook, &ctx).unwrap();
        assert_eq!(result.user_input.as_deref(), Some("got_input"));
    }

    #[test]
    fn test_execute_shell_hook_sets_event_env() {
        let hook = ShellHook {
            command: r#"printf '{"user_input": "%s"}' "$JCLI_HOOK_EVENT""#.to_string(),
            timeout: 5,
        };
        let ctx = HookContext {
            event: HookEvent::PreSendMessage,
            ..Default::default()
        };
        let result = execute_shell_hook(&hook, &ctx).unwrap();
        assert_eq!(result.user_input.as_deref(), Some("pre_send_message"));
    }

    #[test]
    fn test_execute_shell_hook_timeout() {
        // A zero-second limit trips on the first poll, so this returns quickly.
        let hook = ShellHook {
            command: "sleep 5".to_string(),
            timeout: 0,
        };
        let err = execute_shell_hook(&hook, &HookContext::default()).unwrap_err();
        assert!(err.contains("sleep 5"));
    }

    #[test]
    fn test_execute_builtin_hook() {
        let builtin = BuiltinHook {
            name: "test_hook".to_string(),
            handler: Arc::new(|ctx| {
                if let Some(ref input) = ctx.user_input {
                    Some(HookResult {
                        user_input: Some(format!("[hooked] {}", input)),
                        ..Default::default()
                    })
                } else {
                    None
                }
            }),
        };
        let kind = HookKind::Builtin(builtin);
        let ctx = HookContext {
            event: HookEvent::PreSendMessage,
            user_input: Some("original".to_string()),
            ..Default::default()
        };
        let result = execute_hook(&kind, &ctx).unwrap();
        assert_eq!(result.user_input.as_deref(), Some("[hooked] original"));
    }

    #[test]
    fn test_execute_builtin_hook_returns_none() {
        let builtin = BuiltinHook {
            name: "no_op".to_string(),
            handler: Arc::new(|_| None),
        };
        let kind = HookKind::Builtin(builtin);
        let ctx = HookContext::default();
        let result = execute_hook(&kind, &ctx).unwrap();
        assert!(!result.abort);
        assert!(result.user_input.is_none());
    }

    #[test]
    fn test_hook_manager_empty() {
        let manager = HookManager::default();
        assert!(manager.list_hooks().is_empty());
        let result = manager.execute(HookEvent::PreSendMessage, HookContext::default());
        assert!(result.is_none());
    }

    #[test]
    fn test_hook_manager_session_hooks() {
        let mut manager = HookManager::default();
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "session_hooked"}'"#.to_string(),
                timeout: 5,
            },
        );
        let hooks = manager.list_hooks();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].source, "session");
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    user_input: Some("original".to_string()),
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(result.user_input.as_deref(), Some("session_hooked"));
    }

    #[test]
    fn test_hook_manager_builtin_hooks() {
        let mut manager = HookManager::default();
        manager.register_builtin(HookEvent::PreSendMessage, "test_builtin", |ctx| {
            if let Some(ref input) = ctx.user_input {
                Some(HookResult {
                    user_input: Some(format!("[builtin] {}", input)),
                    ..Default::default()
                })
            } else {
                None
            }
        });
        let hooks = manager.list_hooks();
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].source, "builtin");
        assert!(hooks[0].label.contains("test_builtin"));
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    user_input: Some("hello".to_string()),
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(result.user_input.as_deref(), Some("[builtin] hello"));
    }

    /// Builtin hooks run first, so a later session hook's value wins.
    #[test]
    fn test_hook_manager_builtin_before_session() {
        let mut manager = HookManager::default();
        manager.register_builtin(HookEvent::PreSendMessage, "prefix", |ctx| {
            if let Some(ref input) = ctx.user_input {
                Some(HookResult {
                    user_input: Some(format!("[builtin] {}", input)),
                    ..Default::default()
                })
            } else {
                None
            }
        });
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "session_overridden"}'"#.to_string(),
                timeout: 5,
            },
        );
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    user_input: Some("original".to_string()),
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(result.user_input.as_deref(), Some("session_overridden"));
    }

    #[test]
    fn test_hook_manager_remove_session_hook() {
        let mut manager = HookManager::default();
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: "echo test".to_string(),
                timeout: 5,
            },
        );
        assert_eq!(manager.list_hooks().len(), 1);
        assert!(manager.remove_session_hook(HookEvent::PreSendMessage, 0));
        assert!(manager.list_hooks().is_empty());
        assert!(!manager.remove_session_hook(HookEvent::PreSendMessage, 0));
    }

    #[test]
    fn test_list_hooks_order() {
        let mut manager = HookManager::default();
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: "echo session".to_string(),
                timeout: 5,
            },
        );
        manager.register_builtin(HookEvent::PreSendMessage, "first", |_| None);
        manager.register_session_hook(
            HookEvent::SessionStart,
            HookDef {
                command: "echo start".to_string(),
                timeout: 7,
            },
        );
        let summary: Vec<(HookEvent, &str, Option<u64>)> = manager
            .list_hooks()
            .iter()
            .map(|entry| (entry.event, entry.source, entry.timeout))
            .collect();
        assert_eq!(
            summary,
            vec![
                (HookEvent::PreSendMessage, "builtin", None),
                (HookEvent::PreSendMessage, "session", Some(5)),
                (HookEvent::SessionStart, "session", Some(7)),
            ]
        );
    }

    #[test]
    fn test_has_hooks_for() {
        let mut manager = HookManager::default();
        assert!(!manager.has_hooks_for(HookEvent::SessionStart));
        manager.register_builtin(HookEvent::SessionStart, "noop", |_| None);
        assert!(manager.has_hooks_for(HookEvent::SessionStart));
        assert!(!manager.has_hooks_for(HookEvent::SessionEnd));
    }

    #[test]
    fn test_hook_chain_execution() {
        let mut manager = HookManager::default();
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "first"}'"#.to_string(),
                timeout: 5,
            },
        );
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "second"}'"#.to_string(),
                timeout: 5,
            },
        );
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    user_input: Some("original".to_string()),
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(result.user_input.as_deref(), Some("second"));
    }

    /// A failing hook (non-zero exit) is treated as an abort and stops the chain.
    #[test]
    fn test_hook_abort_stops_chain() {
        let mut manager = HookManager::default();
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: "exit 1".to_string(),
                timeout: 5,
            },
        );
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "should_not_reach"}'"#.to_string(),
                timeout: 5,
            },
        );
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    ..Default::default()
                },
            )
            .unwrap();
        assert!(result.abort);
        assert!(result.user_input.is_none());
    }

    #[test]
    fn test_builtin_abort_stops_chain() {
        let mut manager = HookManager::default();
        manager.register_builtin(HookEvent::PreSendMessage, "blocker", |_| {
            Some(HookResult {
                abort: true,
                ..Default::default()
            })
        });
        manager.register_session_hook(
            HookEvent::PreSendMessage,
            HookDef {
                command: r#"echo '{"user_input": "should_not_reach"}'"#.to_string(),
                timeout: 5,
            },
        );
        let result = manager
            .execute(
                HookEvent::PreSendMessage,
                HookContext {
                    event: HookEvent::PreSendMessage,
                    ..Default::default()
                },
            )
            .unwrap();
        assert!(result.abort);
        assert!(result.user_input.is_none());
    }

    #[test]
    fn test_builtin_hook_clone() {
        let mut manager = HookManager::default();
        manager.register_builtin(HookEvent::PreLlmRequest, "test_clone", |_| {
            Some(HookResult::default())
        });
        let cloned = manager.clone();
        assert_eq!(cloned.list_hooks().len(), 1);
    }
}