use super::super::permission::JcliConfig;
use super::super::storage::{ChatMessage, ModelProvider};
use crate::command::chat::constants::{
HOOK_DEFAULT_LLM_TIMEOUT_SECS, HOOK_DEFAULT_TIMEOUT_SECS, HOOK_LLM_MAX_TOKENS,
};
use crate::config::YamlConfig;
use crate::util::log::{write_error_log, write_info_log};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::{Arc, Mutex};
/// Wall-clock budget (seconds) for one `HookManager::execute` chain.
///
/// Checked before each hook and before each retry attempt; once exceeded, the remaining
/// hooks/attempts are abandoned. A hook that is already running is not interrupted by
/// this budget, only by its own timeout.
const MAX_CHAIN_DURATION_SECS: u64 = 30;
/// Lifecycle points at which hooks can run.
///
/// Serialized in snake_case (e.g. `pre_send_message`), matching `HookEvent::as_str` and
/// the `FromStr` impl, so the same names work in YAML hook definitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    PreSendMessage,
    PostSendMessage,
    PreLlmRequest,
    PostLlmResponse,
    PreToolExecution,
    PostToolExecution,
    PostToolExecutionFailure,
    Stop,
    PreMicroCompact,
    PostMicroCompact,
    PreAutoCompact,
    PostAutoCompact,
    SessionStart,
    SessionEnd,
}
impl std::str::FromStr for HookEvent {
    type Err = ();

    /// Parses the exact snake_case event name (case-sensitive).
    ///
    /// The lookup is driven by `HookEvent::all()` and `HookEvent::as_str()`, so the
    /// accepted spellings can never drift from the names produced by `as_str`.
    ///
    /// # Errors
    /// Returns `Err(())` for any string that is not a known event name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        HookEvent::all()
            .iter()
            .copied()
            .find(|event| event.as_str() == s)
            .ok_or(())
    }
}
impl HookEvent {
    /// Canonical snake_case name of the event (the same spelling serde and `FromStr` use).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::PreSendMessage => "pre_send_message",
            Self::PostSendMessage => "post_send_message",
            Self::PreLlmRequest => "pre_llm_request",
            Self::PostLlmResponse => "post_llm_response",
            Self::PreToolExecution => "pre_tool_execution",
            Self::PostToolExecution => "post_tool_execution",
            Self::PostToolExecutionFailure => "post_tool_execution_failure",
            Self::Stop => "stop",
            Self::PreMicroCompact => "pre_micro_compact",
            Self::PostMicroCompact => "post_micro_compact",
            Self::PreAutoCompact => "pre_auto_compact",
            Self::PostAutoCompact => "post_auto_compact",
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
        }
    }

    /// Every event, in declaration order. `HookManager::list_hooks` iterates this order.
    pub fn all() -> &'static [HookEvent] {
        static ALL_EVENTS: [HookEvent; 14] = [
            HookEvent::PreSendMessage,
            HookEvent::PostSendMessage,
            HookEvent::PreLlmRequest,
            HookEvent::PostLlmResponse,
            HookEvent::PreToolExecution,
            HookEvent::PostToolExecution,
            HookEvent::PostToolExecutionFailure,
            HookEvent::Stop,
            HookEvent::PreMicroCompact,
            HookEvent::PostMicroCompact,
            HookEvent::PreAutoCompact,
            HookEvent::PostAutoCompact,
            HookEvent::SessionStart,
            HookEvent::SessionEnd,
        ];
        &ALL_EVENTS
    }

    /// `Option`-returning convenience wrapper around the `FromStr` impl.
    pub fn parse(s: &str) -> Option<HookEvent> {
        <HookEvent as std::str::FromStr>::from_str(s).ok()
    }
}
/// What `HookManager::execute` does when a hook still fails after all its attempts.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OnError {
    /// Log the failure and continue with the next hook in the chain (default).
    #[default]
    Skip,
    /// Abort the chain and return a result whose action is `HookAction::Stop`.
    Abort,
}
/// Optional conditions restricting when a hook runs; an all-`None` filter matches everything.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HookFilter {
    /// Exact tool name that must equal `HookContext::tool_name`. Takes precedence over
    /// `tool_matcher` when both are set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// `|`-separated list of tool names (each entry trimmed); any exact match passes.
    /// Ignored when `tool_name` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_matcher: Option<String>,
    /// Prefix that `HookContext::model` must start with.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_prefix: Option<String>,
}
impl HookFilter {
pub fn is_empty(&self) -> bool {
self.tool_name.is_none() && self.tool_matcher.is_none() && self.model_prefix.is_none()
}
pub fn matches(&self, context: &HookContext) -> bool {
if let Some(ref expected_tool) = self.tool_name {
match &context.tool_name {
Some(actual) if actual == expected_tool => {}
Some(_) => return false,
None => return false,
}
} else if let Some(ref pattern) = self.tool_matcher {
let actual = match &context.tool_name {
Some(a) => a,
None => return false,
};
let matched = pattern.split('|').any(|p| p.trim() == actual);
if !matched {
return false;
}
}
if let Some(ref prefix) = self.model_prefix {
match &context.model {
Some(actual) if actual.starts_with(prefix.as_str()) => {}
Some(_) => return false,
None => return false,
}
}
true
}
}
/// Implementation kind of a configured hook (`type:` in YAML).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum HookType {
    /// Shell command run via `sh -c` (default when `type` is omitted).
    #[default]
    Bash,
    /// Prompt sent to a chat-completions LLM endpoint.
    Llm,
}
impl std::fmt::Display for HookType {
    /// Writes the lowercase type name (`bash` / `llm`), ignoring width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            HookType::Bash => "bash",
            HookType::Llm => "llm",
        };
        f.write_str(name)
    }
}
/// Serializable definition of a single hook (used for session hooks and their persistence).
///
/// Converted to a runnable hook with `HookDef::into_hook_kind`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookDef {
    /// Hook implementation kind; defaults to `bash`.
    #[serde(default)]
    pub r#type: HookType,
    /// Shell command; required (non-empty) for `bash` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub command: Option<String>,
    /// Prompt template; required (non-empty) for `llm` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Optional model override for `llm` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Timeout in seconds (used with `Duration::from_secs` at execution time).
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Number of retries after a failed attempt (llm hooks are bumped to at least 1).
    #[serde(default)]
    pub retry: u32,
    /// Failure strategy once retries are exhausted.
    #[serde(default)]
    pub on_error: OnError,
    /// Conditions restricting when the hook runs.
    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
    pub filter: HookFilter,
}
/// Serde default for `timeout` fields (seconds). Also used as a sentinel: an `llm` hook whose
/// timeout equals this value gets `default_llm_timeout()` instead.
fn default_timeout() -> u64 {
    HOOK_DEFAULT_TIMEOUT_SECS
}
/// Timeout (seconds) applied to `llm` hooks whose configured timeout equals `default_timeout()`.
fn default_llm_timeout() -> u64 {
    HOOK_DEFAULT_LLM_TIMEOUT_SECS
}
/// A runnable hook. `Debug` is implemented manually because `Builtin` holds a closure.
#[derive(Clone)]
pub enum HookKind {
    /// Shell command hook.
    Shell(ShellHook),
    /// LLM prompt hook.
    Llm(LlmHook),
    /// In-process Rust closure registered via `HookManager::register_builtin`.
    Builtin(BuiltinHook),
}
/// Validated shell hook: `command` is run with `sh -c`, receiving the `HookContext` JSON on stdin.
#[derive(Debug, Clone)]
pub struct ShellHook {
    /// Directory name for directory-based hooks; `None` for session hooks.
    pub name: Option<String>,
    /// Non-empty shell command line.
    pub command: String,
    /// Timeout in seconds.
    pub timeout: u64,
    /// Retries after a failed attempt.
    pub retry: u32,
    /// Strategy once retries are exhausted.
    pub on_error: OnError,
    /// Run conditions.
    pub filter: HookFilter,
    /// Hook directory; exported as `JCLI_HOOK_DIR` and prepended to `PATH` when set.
    pub dir_path: Option<PathBuf>,
}
/// Validated LLM hook: `prompt` is rendered with the context and sent to the injected provider.
#[derive(Debug, Clone)]
pub struct LlmHook {
    /// Directory name for directory-based hooks; `None` for session hooks.
    pub name: Option<String>,
    /// Non-empty prompt template (`{{placeholder}}` syntax).
    pub prompt: String,
    /// Optional model override for the provider's default model.
    pub model: Option<String>,
    /// HTTP request timeout in seconds.
    pub timeout: u64,
    /// Retries after a failed attempt (at least 1 after conversion).
    pub retry: u32,
    /// Strategy once retries are exhausted.
    pub on_error: OnError,
    /// Run conditions.
    pub filter: HookFilter,
    /// Hook directory; kept for parity with `ShellHook`, not read by the visible code.
    #[allow(dead_code)]
    pub dir_path: Option<PathBuf>,
}
impl HookDef {
pub fn into_hook_kind(self) -> Result<HookKind, String> {
match self.r#type {
HookType::Bash => {
let command = self.command.unwrap_or_default();
if command.is_empty() {
return Err("bash hook 缺少 command 字段".to_string());
}
Ok(HookKind::Shell(ShellHook {
name: None,
command,
timeout: self.timeout,
retry: self.retry,
on_error: self.on_error,
filter: self.filter,
dir_path: None,
}))
}
HookType::Llm => {
let prompt = self.prompt.unwrap_or_default();
if prompt.is_empty() {
return Err("llm hook 缺少 prompt 字段".to_string());
}
Ok(HookKind::Llm(LlmHook {
name: None,
prompt,
model: self.model,
timeout: if self.timeout == default_timeout() {
default_llm_timeout()
} else {
self.timeout
},
retry: if self.retry == 0 { 1 } else { self.retry },
on_error: self.on_error,
filter: self.filter,
dir_path: None,
}))
}
}
}
}
impl From<HookDef> for HookKind {
fn from(def: HookDef) -> Self {
def.into_hook_kind().unwrap_or_else(|e| {
write_error_log("HookDef::into_hook_kind", &e);
HookKind::Shell(ShellHook {
name: None,
command: String::new(),
timeout: 0,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
})
})
}
}
/// Contents of a hook directory's `HOOK.yaml`: one definition bound to one or more events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookDirDef {
    /// Events this hook is registered for; must be non-empty.
    pub events: Vec<HookEvent>,
    /// Hook implementation kind; defaults to `bash`.
    #[serde(default)]
    pub r#type: HookType,
    /// Shell command; required (non-empty) for `bash` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub command: Option<String>,
    /// Prompt template; required (non-empty) for `llm` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Optional model override for `llm` hooks.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Timeout in seconds.
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Retries after a failed attempt.
    #[serde(default)]
    pub retry: u32,
    /// Failure strategy once retries are exhausted.
    #[serde(default)]
    pub on_error: OnError,
    /// Run conditions.
    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
    pub filter: HookFilter,
}
impl HookDirDef {
pub fn into_hook_kinds(
self,
name: &str,
dir_path: &Path,
) -> Result<Vec<(HookEvent, HookKind)>, String> {
if self.events.is_empty() {
return Err(format!("hook '{}' 的 events 为空", name));
}
let kind = match self.r#type {
HookType::Bash => {
let command = self.command.unwrap_or_default();
if command.is_empty() {
return Err(format!("bash hook '{}' 缺少 command 字段", name));
}
HookKind::Shell(ShellHook {
name: Some(name.to_string()),
command,
timeout: self.timeout,
retry: self.retry,
on_error: self.on_error,
filter: self.filter,
dir_path: Some(dir_path.to_path_buf()),
})
}
HookType::Llm => {
let prompt = self.prompt.unwrap_or_default();
if prompt.is_empty() {
return Err(format!("llm hook '{}' 缺少 prompt 字段", name));
}
HookKind::Llm(LlmHook {
name: Some(name.to_string()),
prompt,
model: self.model,
timeout: if self.timeout == default_timeout() {
default_llm_timeout()
} else {
self.timeout
},
retry: if self.retry == 0 { 1 } else { self.retry },
on_error: self.on_error,
filter: self.filter,
dir_path: Some(dir_path.to_path_buf()),
})
}
};
Ok(self.events.into_iter().map(|e| (e, kind.clone())).collect())
}
}
/// User-level hooks directory (`<data_dir>/agent/hooks`).
///
/// Creating the directory is best effort: a failure is ignored here, and
/// `HookManager::load` skips the path when it is not a directory.
pub fn hooks_dir() -> PathBuf {
    let hooks_path = YamlConfig::data_dir().join("agent").join("hooks");
    std::fs::create_dir_all(&hooks_path).unwrap_or(());
    hooks_path
}
/// Project-level hooks directory (`<project config dir>/hooks`), if it exists.
///
/// Returns `None` when no project config directory is found or it has no `hooks` directory.
pub fn project_hooks_dir() -> Option<PathBuf> {
    JcliConfig::find_config_dir()
        .map(|config_dir| config_dir.join("hooks"))
        .filter(|hooks_path| hooks_path.is_dir())
}
/// Reads every `<dir>/<name>/HOOK.yaml` and returns `(name, definition, hook_dir)` triples.
///
/// Sub-directories are processed in sorted path order. `read_dir` order is unspecified and
/// determines execution order within a hook tier, so sorting makes it deterministic.
/// Sub-directories without `HOOK.yaml` are ignored. Unreadable or unparsable files and
/// definitions with an empty `events` list are logged and skipped. `source_name` is used
/// only in log messages.
fn load_hooks_from_dir(dir: &Path, source_name: &str) -> Vec<(String, HookDirDef, PathBuf)> {
    let mut hooks = Vec::new();
    let entries = match std::fs::read_dir(dir) {
        Ok(e) => e,
        Err(e) => {
            write_error_log(
                "load_hooks_from_dir",
                &format!("读取目录 {} 失败: {}", dir.display(), e),
            );
            return hooks;
        }
    };

    let mut hook_dirs: Vec<PathBuf> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .collect();
    hook_dirs.sort();

    for path in hook_dirs {
        let hook_yaml = path.join("HOOK.yaml");
        let hook_name = path
            .file_name()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string();
        if !hook_yaml.exists() {
            continue;
        }
        match std::fs::read_to_string(&hook_yaml) {
            Ok(content) => match serde_yaml::from_str::<HookDirDef>(&content) {
                Ok(def) => {
                    if def.events.is_empty() {
                        write_error_log(
                            "load_hooks_from_dir",
                            &format!("hook '{}' 的 events 为空,跳过", hook_name),
                        );
                        continue;
                    }
                    hooks.push((hook_name, def, path));
                }
                Err(e) => write_error_log(
                    "load_hooks_from_dir",
                    &format!("解析 {}/HOOK.yaml 失败: {}", hook_name, e),
                ),
            },
            Err(e) => write_error_log(
                "load_hooks_from_dir",
                &format!("读取 {}/HOOK.yaml 失败: {}", hook_name, e),
            ),
        }
    }
    write_info_log(
        "load_hooks_from_dir",
        &format!("从 {} 加载了 {} 个 hook", source_name, hooks.len()),
    );
    hooks
}
/// Shared, thread-safe closure backing a builtin hook; returning `None` means "no changes".
pub type BuiltinHookFn = Arc<dyn Fn(&HookContext) -> Option<HookResult> + Send + Sync>;
/// In-process hook implemented as a Rust closure.
pub struct BuiltinHook {
    /// Display name (shown by `list_hooks`).
    pub name: String,
    /// The closure invoked with the current context.
    pub handler: BuiltinHookFn,
}
impl Clone for BuiltinHook {
fn clone(&self) -> Self {
BuiltinHook {
name: self.name.clone(),
handler: Arc::clone(&self.handler),
}
}
}
impl std::fmt::Debug for HookKind {
    /// Diagnostic representation of a hook.
    ///
    /// Shell and LLM hooks now report the same configuration fields (`retry`, `on_error`,
    /// `filter` and `dir_path` were previously missing from one or the other). The LLM prompt
    /// is shown only by length, under an explicit `prompt_len` key, to keep logs short.
    /// Builtin hooks show only their name because the handler is an opaque closure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HookKind::Shell(shell) => f
                .debug_struct("HookKind::Shell")
                .field("name", &shell.name)
                .field("command", &shell.command)
                .field("timeout", &shell.timeout)
                .field("retry", &shell.retry)
                .field("on_error", &shell.on_error)
                .field("filter", &shell.filter)
                .field("dir_path", &shell.dir_path)
                .finish(),
            HookKind::Llm(llm) => f
                .debug_struct("HookKind::Llm")
                .field("name", &llm.name)
                .field("prompt_len", &llm.prompt.len())
                .field("model", &llm.model)
                .field("timeout", &llm.timeout)
                .field("retry", &llm.retry)
                .field("on_error", &llm.on_error)
                .field("filter", &llm.filter)
                .field("dir_path", &llm.dir_path)
                .finish(),
            HookKind::Builtin(builtin) => f
                .debug_struct("HookKind::Builtin")
                .field("name", &builtin.name)
                .finish(),
        }
    }
}
/// Data passed to hooks.
///
/// Shell hooks receive it as JSON on stdin, with `None` fields omitted. LLM hooks receive
/// selected fields through `{{placeholder}}` substitution. `HookManager::execute` updates
/// it between hooks with each hook's overrides.
#[derive(Debug, Serialize)]
pub struct HookContext {
    /// Event being fired.
    pub event: HookEvent,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub messages: Option<Vec<ChatMessage>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Model name; also matched against `HookFilter::model_prefix`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_input: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub assistant_output: Option<String>,
    /// Tool name; matched against `HookFilter::tool_name` / `tool_matcher`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Tool arguments (presumably a JSON string — confirm with callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_arguments: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_error: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Working directory as a display string (always serialized).
    pub cwd: String,
}
impl Default for HookContext {
    /// An empty `SessionStart` context whose `cwd` is the process working directory,
    /// or `"."` if that cannot be determined.
    fn default() -> Self {
        let cwd = match std::env::current_dir() {
            Ok(path) => path.display().to_string(),
            Err(_) => ".".to_string(),
        };
        Self {
            event: HookEvent::SessionStart,
            cwd,
            messages: None,
            system_prompt: None,
            model: None,
            user_input: None,
            assistant_output: None,
            tool_name: None,
            tool_arguments: None,
            tool_result: None,
            tool_error: None,
            session_id: None,
        }
    }
}
/// Control-flow request a hook can return (`"stop"` / `"skip"` in its JSON output).
///
/// Either one ends the current hook chain immediately (see `HookManager::execute`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookAction {
    /// Abort the pipeline.
    Stop,
    /// Skip the current step.
    Skip,
}
/// Output of a hook, parsed from shell stdout / LLM JSON, or returned by a builtin closure.
///
/// Every field is optional. `HookManager::execute` merges the results of a chain: the
/// replacement fields overwrite the context seen by later hooks, `inject_messages`
/// accumulate, and the informational fields keep the last value set.
#[derive(Debug, Deserialize, Default)]
pub struct HookResult {
    /// Replacement message list.
    #[serde(default)]
    pub messages: Option<Vec<ChatMessage>>,
    /// Replacement system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Replacement user input.
    #[serde(default)]
    pub user_input: Option<String>,
    /// Replacement assistant output.
    #[serde(default)]
    pub assistant_output: Option<String>,
    /// Replacement tool arguments.
    #[serde(default)]
    pub tool_arguments: Option<String>,
    /// Replacement tool result.
    #[serde(default)]
    pub tool_result: Option<String>,
    /// Replacement tool error.
    #[serde(default)]
    pub tool_error: Option<String>,
    /// Messages to append (accumulated across the chain).
    #[serde(default)]
    pub inject_messages: Option<Vec<ChatMessage>>,
    /// Feedback to retry with.
    #[serde(default)]
    pub retry_feedback: Option<String>,
    /// Extra text to append to the system prompt (per the LLM instruction text).
    #[serde(default)]
    pub additional_context: Option<String>,
    /// Message to show the user (per the LLM instruction text, as a toast).
    #[serde(default)]
    pub system_message: Option<String>,
    /// Optional stop/skip request.
    #[serde(default)]
    pub action: Option<HookAction>,
}
impl HookResult {
    /// `true` when the hook requested `HookAction::Stop`.
    pub fn is_stop(&self) -> bool {
        matches!(self.action, Some(HookAction::Stop))
    }

    /// `true` when the hook requested `HookAction::Skip`.
    pub fn is_skip(&self) -> bool {
        matches!(self.action, Some(HookAction::Skip))
    }

    /// `true` when the hook requested either action, i.e. the chain must end here.
    pub fn is_halt(&self) -> bool {
        matches!(self.action, Some(HookAction::Stop | HookAction::Skip))
    }
}
/// Internal result of the retry loop for one hook inside `HookManager::execute`.
#[derive(Debug)]
#[allow(dead_code, clippy::large_enum_variant)]
enum HookOutcome {
    /// The hook succeeded; its result has already been merged.
    Success(HookResult),
    /// The latest attempt failed but attempts remain.
    Retry {
        error: String,
        #[allow(dead_code)]
        attempts_left: u32,
    },
    /// All attempts failed, or the chain budget ran out mid-retry.
    Err(String),
}
/// Per-hook counters kept by `HookManager`, keyed by `hook_label`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct HookMetrics {
    /// Attempts made (each retry counts separately).
    pub executions: u64,
    /// Successful attempts.
    pub successes: u64,
    /// Failed attempts.
    pub failures: u64,
    /// Times the hook was not run because `hook_should_execute` returned false.
    pub skipped: u64,
    /// Sum of attempt durations in milliseconds.
    pub total_duration_ms: u64,
}
/// Registry and executor for all hooks.
///
/// Hooks are grouped into four tiers. For each event, `execute` runs them in tier order:
/// builtin, then user, then project, then session.
#[derive(Debug, Default)]
pub struct HookManager {
    /// In-process closures registered via `register_builtin`.
    builtin_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Hooks loaded from `hooks_dir()`.
    user_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Hooks loaded from `project_hooks_dir()`.
    project_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Hooks added at runtime for the current session (persistable via snapshots).
    session_hooks: HashMap<HookEvent, Vec<HookKind>>,
    /// Execution counters keyed by hook label; behind a `Mutex` so `execute(&self)` can update them.
    metrics: Mutex<HashMap<String, HookMetrics>>,
    /// Model provider used by LLM hooks; LLM hooks fail when this is `None`.
    provider: Option<Arc<Mutex<ModelProvider>>>,
}
impl Clone for HookManager {
fn clone(&self) -> Self {
HookManager {
builtin_hooks: self.builtin_hooks.clone(),
user_hooks: self.user_hooks.clone(),
project_hooks: self.project_hooks.clone(),
session_hooks: self.session_hooks.clone(),
metrics: Mutex::new(self.metrics.lock().map(|m| m.clone()).unwrap_or_default()),
provider: self.provider.clone(),
}
}
}
/// `HookEntry::source` label for hooks registered via `register_builtin`.
const HOOK_SOURCE_BUILTIN: &str = "builtin";
/// `HookEntry::source` label for hooks loaded from the user hooks directory.
const HOOK_SOURCE_USER: &str = "user";
/// `HookEntry::source` label for hooks loaded from the project hooks directory.
const HOOK_SOURCE_PROJECT: &str = "project";
/// `HookEntry::source` label for session-scoped hooks.
const HOOK_SOURCE_SESSION: &str = "session";
/// One row of `HookManager::list_hooks`: a registered hook plus its metadata and metrics.
pub struct HookEntry {
    /// Hook name, if any (directory name for directory hooks, closure name for builtins).
    pub name: Option<String>,
    /// Event the hook is registered for.
    pub event: HookEvent,
    /// One of the `HOOK_SOURCE_*` labels.
    pub source: &'static str,
    /// Type label from `hook_type_str`.
    pub hook_type: &'static str,
    /// Label from `hook_label`; also the metrics key.
    pub label: String,
    /// Timeout from `hook_timeout`, if applicable.
    pub timeout: Option<u64>,
    /// Failure strategy from `hook_on_error`, if applicable.
    pub on_error: Option<OnError>,
    /// Position within the event's session-hook list (usable with `remove_session_hook`);
    /// `None` for non-session hooks.
    pub session_index: Option<usize>,
    /// Filter from `hook_filter`, if any.
    pub filter: Option<HookFilter>,
    /// Metrics recorded under `label`, if the hook has run or been skipped.
    pub metrics: Option<HookMetrics>,
}
impl HookManager {
    /// Builds a manager with user-level hooks (from [`hooks_dir`]) and, when a project
    /// config directory is found, project-level hooks (from [`project_hooks_dir`]).
    ///
    /// Invalid definitions are logged and skipped. Builtin and session hooks start empty and
    /// no provider is set.
    pub fn load() -> Self {
        let mut manager = HookManager::default();
        let user_dir = hooks_dir();
        if user_dir.is_dir() {
            for (name, dir_def, dir_path) in load_hooks_from_dir(&user_dir, "用户级") {
                match dir_def.into_hook_kinds(&name, &dir_path) {
                    Ok(pairs) => {
                        for (event, kind) in pairs {
                            manager.user_hooks.entry(event).or_default().push(kind);
                        }
                    }
                    Err(e) => write_error_log("HookManager::load", &e),
                }
            }
            write_info_log(
                "HookManager::load",
                &format!("已加载用户级 hooks: {}", user_dir.display()),
            );
        }
        if let Some(proj_dir) = project_hooks_dir() {
            for (name, dir_def, dir_path) in load_hooks_from_dir(&proj_dir, "项目级") {
                match dir_def.into_hook_kinds(&name, &dir_path) {
                    Ok(pairs) => {
                        for (event, kind) in pairs {
                            manager.project_hooks.entry(event).or_default().push(kind);
                        }
                    }
                    Err(e) => write_error_log("HookManager::load", &e),
                }
            }
            write_info_log(
                "HookManager::load",
                &format!("已加载项目级 hooks: {}", proj_dir.display()),
            );
        }
        manager
    }

    /// Registers an in-process hook for `event`; builtin hooks run before all other tiers.
    pub fn register_builtin(
        &mut self,
        event: HookEvent,
        name: impl Into<String>,
        handler: impl Fn(&HookContext) -> Option<HookResult> + Send + Sync + 'static,
    ) {
        self.builtin_hooks
            .entry(event)
            .or_default()
            .push(HookKind::Builtin(BuiltinHook {
                name: name.into(),
                handler: Arc::new(handler),
            }));
    }

    /// Adds a session-scoped hook for `event`; an invalid definition is logged and dropped.
    pub fn register_session_hook(&mut self, event: HookEvent, def: HookDef) {
        match def.into_hook_kind() {
            Ok(kind) => {
                self.session_hooks.entry(event).or_default().push(kind);
            }
            Err(e) => {
                write_error_log("HookManager::register_session_hook", &e);
            }
        }
    }

    /// Converts session hooks back into persistable definitions.
    ///
    /// Builtin hooks are omitted (closures cannot be serialized). Events are visited in
    /// `HashMap` order (unspecified); hooks within one event keep their order. LLM hooks are
    /// saved with their already-adjusted timeout/retry values.
    pub fn session_hooks_snapshot(&self) -> Vec<super::super::storage::SessionHookPersist> {
        let mut result = Vec::new();
        for (event, hooks) in &self.session_hooks {
            for kind in hooks {
                match kind {
                    HookKind::Shell(sh) => {
                        result.push(super::super::storage::SessionHookPersist {
                            event: *event,
                            definition: HookDef {
                                r#type: HookType::Bash,
                                command: Some(sh.command.clone()),
                                prompt: None,
                                model: None,
                                timeout: sh.timeout,
                                retry: sh.retry,
                                on_error: sh.on_error,
                                filter: sh.filter.clone(),
                            },
                        });
                    }
                    HookKind::Llm(lh) => {
                        result.push(super::super::storage::SessionHookPersist {
                            event: *event,
                            definition: HookDef {
                                r#type: HookType::Llm,
                                command: None,
                                prompt: Some(lh.prompt.clone()),
                                model: lh.model.clone(),
                                timeout: lh.timeout,
                                retry: lh.retry,
                                on_error: lh.on_error,
                                filter: lh.filter.clone(),
                            },
                        });
                    }
                    HookKind::Builtin(_) => {
                        // Not persistable: builtin hooks are plain closures.
                    }
                }
            }
        }
        result
    }

    /// Removes every session hook.
    pub fn clear_session_hooks(&mut self) {
        self.session_hooks.clear();
    }

    /// Replaces all session hooks with `hooks` (invalid entries are logged and dropped).
    pub fn restore_session_hooks(&mut self, hooks: &[super::super::storage::SessionHookPersist]) {
        self.session_hooks.clear();
        for hook in hooks {
            self.register_session_hook(hook.event, hook.definition.clone());
        }
    }

    /// Adds an already-built hook to the session tier without validation.
    #[allow(dead_code)]
    pub fn register_session_hook_kind(&mut self, event: HookEvent, kind: HookKind) {
        self.session_hooks.entry(event).or_default().push(kind);
    }

    /// Sets the model provider used by LLM hooks.
    pub fn set_provider(&mut self, provider: Arc<Mutex<ModelProvider>>) {
        self.provider = Some(provider);
    }

    /// Removes the session hook at `index` for `event`; returns `false` if there is none.
    pub fn remove_session_hook(&mut self, event: HookEvent, index: usize) -> bool {
        if let Some(hooks) = self.session_hooks.get_mut(&event)
            && index < hooks.len()
        {
            hooks.remove(index);
            return true;
        }
        false
    }

    /// Lists every registered hook, grouped by event (in `HookEvent::all()` order) and then
    /// by tier (builtin, user, project, session), with its current metrics.
    ///
    /// The metrics lock is held for the duration of the call; if it is poisoned, entries are
    /// returned without metrics.
    pub fn list_hooks(&self) -> Vec<HookEntry> {
        let mut result = Vec::new();
        let metrics_map = self.metrics.lock().ok();
        let empty_metrics = HashMap::new();
        let metrics_ref = metrics_map.as_deref().unwrap_or(&empty_metrics);
        let make_entry = |event: HookEvent,
                          source: &'static str,
                          hook: &HookKind,
                          session_index: Option<usize>,
                          metrics: &HashMap<String, HookMetrics>| {
            let label = hook_label(hook);
            HookEntry {
                name: hook_name(hook).map(|s| s.to_string()),
                event,
                source,
                hook_type: hook_type_str(hook),
                timeout: hook_timeout(hook),
                on_error: hook_on_error(hook),
                filter: hook_filter(hook).cloned(),
                metrics: metrics.get(&label).cloned(),
                session_index,
                label,
            }
        };
        for event in HookEvent::all() {
            if let Some(hooks) = self.builtin_hooks.get(event) {
                for hook in hooks {
                    result.push(make_entry(
                        *event,
                        HOOK_SOURCE_BUILTIN,
                        hook,
                        None,
                        metrics_ref,
                    ));
                }
            }
            if let Some(hooks) = self.user_hooks.get(event) {
                for hook in hooks {
                    result.push(make_entry(
                        *event,
                        HOOK_SOURCE_USER,
                        hook,
                        None,
                        metrics_ref,
                    ));
                }
            }
            if let Some(hooks) = self.project_hooks.get(event) {
                for hook in hooks {
                    result.push(make_entry(
                        *event,
                        HOOK_SOURCE_PROJECT,
                        hook,
                        None,
                        metrics_ref,
                    ));
                }
            }
            if let Some(hooks) = self.session_hooks.get(event) {
                for (idx, hook) in hooks.iter().enumerate() {
                    result.push(make_entry(
                        *event,
                        HOOK_SOURCE_SESSION,
                        hook,
                        Some(idx),
                        metrics_ref,
                    ));
                }
            }
        }
        result
    }

    /// Reloads user and project hooks from disk; builtin/session hooks, metrics and the
    /// provider are left untouched.
    #[allow(dead_code)]
    pub fn reload(&mut self) {
        let fresh = HookManager::load();
        self.user_hooks = fresh.user_hooks;
        self.project_hooks = fresh.project_hooks;
        write_info_log("HookManager::reload", "已重新加载用户级和项目级 hooks");
    }

    /// Snapshot of the metrics map (empty if the lock is poisoned).
    #[allow(dead_code)]
    pub fn get_metrics(&self) -> HashMap<String, HookMetrics> {
        self.metrics.lock().map(|m| m.clone()).unwrap_or_default()
    }

    /// `true` if any tier has at least one hook registered for `event`.
    pub fn has_hooks_for(&self, event: HookEvent) -> bool {
        self.builtin_hooks
            .get(&event)
            .is_some_and(|h| !h.is_empty())
            || self.user_hooks.get(&event).is_some_and(|h| !h.is_empty())
            || self
                .project_hooks
                .get(&event)
                .is_some_and(|h| !h.is_empty())
            || self
                .session_hooks
                .get(&event)
                .is_some_and(|h| !h.is_empty())
    }

    /// Runs the hook chain for `event` on a new thread and discards the result.
    ///
    /// The manager mutex stays locked for the whole chain, so other users of `manager` block
    /// until it finishes. If the lock is poisoned, the hooks are silently not run.
    pub fn execute_fire_and_forget(
        manager: Arc<Mutex<HookManager>>,
        event: HookEvent,
        context: HookContext,
    ) {
        std::thread::spawn(move || {
            if let Ok(m) = manager.lock() {
                let _ = m.execute(event, context);
            }
        });
    }

    /// Runs every hook registered for `event`, in tier order (builtin, user, project,
    /// session), threading `context` through the chain.
    ///
    /// - A hook whose filter/conditions reject the context is counted as skipped.
    /// - Each hook gets `1 + retry` attempts (saturating, so a huge `retry` cannot overflow).
    /// - Replacement fields in a hook's result (messages, prompts, user input, assistant
    ///   output, tool arguments/result/error) update `context` so later hooks see them, and
    ///   are copied into the returned result. `inject_messages` accumulate; the other
    ///   informational fields keep the last value set.
    /// - A hook returning `stop`/`skip` ends the chain immediately. The returned result then
    ///   carries only that action plus its `retry_feedback`/`system_message`; earlier
    ///   modifications are discarded.
    /// - A hook that still fails after its attempts: `OnError::Skip` continues the chain,
    ///   `OnError::Abort` returns a `Stop` result.
    /// - Once the chain runs longer than `MAX_CHAIN_DURATION_SECS`, the remaining hooks
    ///   and attempts are abandoned.
    ///
    /// Returns `None` when no hook is registered or no hook changed anything.
    pub fn execute(&self, event: HookEvent, mut context: HookContext) -> Option<HookResult> {
        let mut all_hooks: Vec<&HookKind> = Vec::new();
        if let Some(hooks) = self.builtin_hooks.get(&event) {
            all_hooks.extend(hooks.iter());
        }
        if let Some(hooks) = self.user_hooks.get(&event) {
            all_hooks.extend(hooks.iter());
        }
        if let Some(hooks) = self.project_hooks.get(&event) {
            all_hooks.extend(hooks.iter());
        }
        if let Some(hooks) = self.session_hooks.get(&event) {
            all_hooks.extend(hooks.iter());
        }
        if all_hooks.is_empty() {
            return None;
        }
        write_info_log(
            "HookManager::execute",
            &format!(
                "执行 {} 个 hook (事件: {})",
                all_hooks.len(),
                event.as_str()
            ),
        );
        let mut had_modification = false;
        let mut final_result = HookResult::default();
        let chain_start = std::time::Instant::now();
        let chain_timeout = std::time::Duration::from_secs(MAX_CHAIN_DURATION_SECS);
        for hook in all_hooks {
            if chain_start.elapsed() > chain_timeout {
                write_error_log(
                    "HookManager::execute",
                    &format!(
                        "Hook 链总超时 ({}s),中止剩余 hook (事件: {})",
                        MAX_CHAIN_DURATION_SECS,
                        event.as_str()
                    ),
                );
                break;
            }
            let label = hook_label(hook);
            if !hook_should_execute(hook, &context) {
                if let Ok(mut metrics) = self.metrics.lock() {
                    let m = metrics.entry(label).or_default();
                    m.skipped += 1;
                }
                continue;
            }
            // `saturating_add`: `retry: u32::MAX` from config must not overflow (panic in
            // debug, wrap to zero attempts in release).
            let max_attempts = hook_retry_count(hook).saturating_add(1);
            let mut last_outcome = None;
            for attempt in 0..max_attempts {
                if chain_start.elapsed() > chain_timeout {
                    write_error_log(
                        "HookManager::execute",
                        &format!(
                            "Hook 链总超时,中止 {} 的重试 (事件: {})",
                            label,
                            event.as_str()
                        ),
                    );
                    last_outcome = Some(HookOutcome::Err(format!(
                        "链总超时,第 {} 次尝试中止",
                        attempt + 1
                    )));
                    break;
                }
                let hook_start = std::time::Instant::now();
                let result = execute_hook_with_provider(hook, &context, &self.provider);
                let elapsed_ms = hook_start.elapsed().as_millis() as u64;
                match result {
                    Ok(hook_result) => {
                        if let Ok(mut metrics) = self.metrics.lock() {
                            let m = metrics.entry(label.clone()).or_default();
                            m.executions += 1;
                            m.successes += 1;
                            m.total_duration_ms += elapsed_ms;
                        }
                        if hook_result.is_halt() {
                            let action_str = if hook_result.is_stop() {
                                "stop"
                            } else {
                                "skip"
                            };
                            write_info_log(
                                "HookManager::execute",
                                &format!("Hook {} ({})", action_str, label),
                            );
                            return Some(HookResult {
                                action: Some(if hook_result.is_stop() {
                                    HookAction::Stop
                                } else {
                                    HookAction::Skip
                                }),
                                retry_feedback: hook_result.retry_feedback.clone(),
                                system_message: hook_result.system_message.clone(),
                                ..Default::default()
                            });
                        }
                        if let Some(ref msgs) = hook_result.messages {
                            context.messages = Some(msgs.clone());
                            final_result.messages = context.messages.clone();
                            had_modification = true;
                        }
                        if let Some(ref sp) = hook_result.system_prompt {
                            context.system_prompt = Some(sp.clone());
                            final_result.system_prompt = context.system_prompt.clone();
                            had_modification = true;
                        }
                        if let Some(ref ui) = hook_result.user_input {
                            context.user_input = Some(ui.clone());
                            final_result.user_input = context.user_input.clone();
                            had_modification = true;
                        }
                        if let Some(ref ao) = hook_result.assistant_output {
                            context.assistant_output = Some(ao.clone());
                            final_result.assistant_output = context.assistant_output.clone();
                            had_modification = true;
                        }
                        if let Some(ref ta) = hook_result.tool_arguments {
                            context.tool_arguments = Some(ta.clone());
                            final_result.tool_arguments = context.tool_arguments.clone();
                            had_modification = true;
                        }
                        if let Some(ref tr) = hook_result.tool_result {
                            context.tool_result = Some(tr.clone());
                            final_result.tool_result = context.tool_result.clone();
                            had_modification = true;
                        }
                        if let Some(ref inject) = hook_result.inject_messages {
                            let existing =
                                final_result.inject_messages.get_or_insert_with(Vec::new);
                            existing.extend(inject.clone());
                            had_modification = true;
                        }
                        if let Some(ref rf) = hook_result.retry_feedback {
                            final_result.retry_feedback = Some(rf.clone());
                            had_modification = true;
                        }
                        if let Some(ref ac) = hook_result.additional_context {
                            final_result.additional_context = Some(ac.clone());
                            had_modification = true;
                        }
                        if let Some(ref sm) = hook_result.system_message {
                            final_result.system_message = Some(sm.clone());
                            had_modification = true;
                        }
                        if let Some(ref te) = hook_result.tool_error {
                            // Propagate into the context like every other replacement field,
                            // so later hooks in the chain see the updated tool error.
                            context.tool_error = Some(te.clone());
                            final_result.tool_error = context.tool_error.clone();
                            had_modification = true;
                        }
                        last_outcome = Some(HookOutcome::Success(hook_result));
                        break;
                    }
                    Err(e) => {
                        if let Ok(mut metrics) = self.metrics.lock() {
                            let m = metrics.entry(label.clone()).or_default();
                            m.executions += 1;
                            m.failures += 1;
                            m.total_duration_ms += elapsed_ms;
                        }
                        let attempts_left = max_attempts - attempt - 1;
                        if attempts_left > 0 {
                            write_info_log(
                                "HookManager::execute",
                                &format!(
                                    "Hook 执行失败 ({}), 第 {}/{} 次尝试, 剩余重试 {}: {}",
                                    label,
                                    attempt + 1,
                                    max_attempts,
                                    attempts_left,
                                    e
                                ),
                            );
                            last_outcome = Some(HookOutcome::Retry {
                                error: e,
                                attempts_left,
                            });
                        } else {
                            write_error_log(
                                "HookManager::execute",
                                &format!("Hook 执行失败 ({}), 重试耗尽: {}", label, e),
                            );
                            last_outcome = Some(HookOutcome::Err(e));
                            break;
                        }
                    }
                }
            }
            match last_outcome {
                Some(HookOutcome::Success(_)) => {}
                Some(HookOutcome::Retry { error, .. }) => {
                    // Defensive: the loop above always ends on Success or Err.
                    write_error_log(
                        "HookManager::execute",
                        &format!("Hook 重试未完成 ({}): {}", label, error),
                    );
                    match hook_on_error_strategy(hook) {
                        OnError::Abort => {
                            return Some(HookResult {
                                action: Some(HookAction::Stop),
                                ..Default::default()
                            });
                        }
                        OnError::Skip => {
                            continue;
                        }
                    }
                }
                Some(HookOutcome::Err(e)) => {
                    write_error_log(
                        "HookManager::execute",
                        &format!("Hook 最终失败 ({}): {}", label, e),
                    );
                    match hook_on_error_strategy(hook) {
                        OnError::Abort => {
                            return Some(HookResult {
                                action: Some(HookAction::Stop),
                                ..Default::default()
                            });
                        }
                        OnError::Skip => {
                            continue;
                        }
                    }
                }
                None => {
                    continue;
                }
            }
        }
        if had_modification {
            Some(final_result)
        } else {
            None
        }
    }
}
/// Runs a single hook attempt, dispatching on its kind.
///
/// Builtin handlers that return `None` are treated as an empty (no-change) result.
/// `provider` is only used by LLM hooks.
fn execute_hook_with_provider(
    kind: &HookKind,
    context: &HookContext,
    provider: &Option<Arc<Mutex<ModelProvider>>>,
) -> Result<HookResult, String> {
    match kind {
        HookKind::Builtin(builtin) => Ok((builtin.handler)(context).unwrap_or_default()),
        HookKind::Llm(llm) => execute_llm_hook(llm, context, provider),
        HookKind::Shell(shell) => execute_shell_hook(shell, context),
    }
}
/// Appended to every rendered LLM-hook prompt. It tells the model to answer with a JSON
/// object shaped like `HookResult`, which `execute_llm_hook` then extracts and deserializes.
/// The text is sent verbatim to the model; keep it in sync with the `HookResult` fields.
const LLM_HOOK_FORMAT_INSTRUCTION: &str = r#"
---
You are a hook function. You MUST respond with ONLY a valid JSON object matching this schema (no markdown, no explanation outside JSON):
{
"user_input": "string (optional, replace user message)",
"assistant_output": "string (optional, replace assistant output)",
"messages": [{"role":"user","content":"..."}] (optional, replace message list),
"system_prompt": "string (optional, replace system prompt)",
"tool_arguments": "string (optional, replace tool arguments JSON)",
"tool_result": "string (optional, replace tool result)",
"tool_error": "string (optional, replace tool error)",
"inject_messages": [{"role":"user","content":"..."}] (optional, append messages),
"action": "stop" or "skip" (optional, stop=abort pipeline, skip=skip current step),
"retry_feedback": "string (optional, feedback to retry with)",
"additional_context": "string (optional, append to system_prompt)",
"system_message": "string (optional, show toast to user)"
}
Return {} if no modification needed."#;
fn render_prompt_template(template: &str, context: &HookContext) -> String {
let mut result = template.to_string();
result = result.replace("{{event}}", context.event.as_str());
result = result.replace("{{cwd}}", &context.cwd);
result = result.replace(
"{{user_input}}",
context.user_input.as_deref().unwrap_or(""),
);
result = result.replace(
"{{assistant_output}}",
context.assistant_output.as_deref().unwrap_or(""),
);
result = result.replace("{{tool_name}}", context.tool_name.as_deref().unwrap_or(""));
result = result.replace(
"{{tool_arguments}}",
context.tool_arguments.as_deref().unwrap_or(""),
);
result = result.replace(
"{{tool_result}}",
context.tool_result.as_deref().unwrap_or(""),
);
result = result.replace("{{model}}", context.model.as_deref().unwrap_or(""));
result
}
/// Returns the span from the first `{` to the last `}` (inclusive), e.g. to strip markdown
/// fences or chatter around a JSON object in LLM output.
///
/// Returns `None` if either brace is missing or the last `}` precedes the first `{`.
/// The span is not validated as JSON.
fn extract_json_from_llm_output(text: &str) -> Option<&str> {
    let first_open = text.find('{')?;
    let last_close = text.rfind('}')?;
    (last_close > first_open).then(|| &text[first_open..=last_close])
}
/// Longest prefix of `text` that is at most `max_bytes` long and ends on a UTF-8 char
/// boundary, so it can be embedded in error messages. Byte slicing such as
/// `&text[..500]` panics when byte 500 falls inside a multi-byte character.
fn llm_hook_excerpt(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    let mut end = max_bytes;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Executes an LLM hook with a single chat-completions request.
///
/// The provider is cloned out of its mutex, so the lock is not held during the HTTP call.
/// `hook.model` overrides the provider's model when set. The prompt is rendered with
/// `render_prompt_template` and suffixed with `LLM_HOOK_FORMAT_INSTRUCTION`. The
/// request uses temperature 0 and `HOOK_LLM_MAX_TOKENS`, with `hook.timeout` seconds as the
/// HTTP client timeout.
///
/// An empty reply or `{}` yields an empty `HookResult`. Otherwise the first-`{`-to-last-`}`
/// span of the reply is deserialized as a `HookResult`.
///
/// # Errors
/// Returns a message when no provider is set, the provider lock is poisoned, the runtime or
/// HTTP client cannot be created, the request fails or returns a non-2xx status, or the
/// response/JSON cannot be parsed. Response excerpts in errors are capped at 500 bytes.
///
/// Note: this blocks on a private Tokio runtime, so it must not be called from inside an
/// async task.
fn execute_llm_hook(
    hook: &LlmHook,
    context: &HookContext,
    provider_opt: &Option<Arc<Mutex<ModelProvider>>>,
) -> Result<HookResult, String> {
    let provider_arc = provider_opt
        .as_ref()
        .ok_or("LLM hook 无法执行:未注入 provider")?;
    let mut provider = provider_arc
        .lock()
        .map_err(|e| format!("获取 provider 锁失败: {}", e))?
        .clone();
    if let Some(ref model) = hook.model {
        provider.model = model.clone();
    }
    let rendered = render_prompt_template(&hook.prompt, context);
    let full_prompt = format!("{}{}", rendered, LLM_HOOK_FORMAT_INSTRUCTION);
    let system_msg = "You are a hook function. Respond ONLY with the JSON object as instructed.";
    let user_msg = full_prompt.as_str();
    let url = format!(
        "{}/chat/completions",
        provider.api_base.trim_end_matches('/')
    );
    let request_body = serde_json::json!({
        "model": provider.model,
        "messages": [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg}
        ],
        "temperature": 0.0,
        "max_tokens": HOOK_LLM_MAX_TOKENS,
    });
    let request_str = serde_json::to_string(&request_body)
        .map_err(|e| format!("序列化 LLM hook 请求失败: {}", e))?;
    let timeout_secs = hook.timeout;
    // A single request does not need a multi-threaded runtime; a current-thread runtime
    // avoids spinning up a worker pool on every hook call.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| format!("创建 tokio runtime 失败: {}", e))?;
    rt.block_on(async {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("创建 HTTP client 失败: {}", e))?;
        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", provider.api_key))
            .body(request_str)
            .send()
            .await
            .map_err(|e| format!("LLM hook 请求失败: {}", e))?;
        let status = resp.status();
        let body = resp
            .text()
            .await
            .map_err(|e| format!("读取 LLM hook 响应失败: {}", e))?;
        if !status.is_success() {
            return Err(format!(
                "LLM hook API 错误: HTTP {} (body: {})",
                status,
                llm_hook_excerpt(&body, 500)
            ));
        }
        let parsed: serde_json::Value = serde_json::from_str(&body)
            .map_err(|e| format!("解析 LLM hook 响应 JSON 失败: {}", e))?;
        let content = parsed["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .trim();
        if content.is_empty() || content == "{}" {
            return Ok(HookResult::default());
        }
        let json_str = match extract_json_from_llm_output(content) {
            Some(s) => s,
            None => {
                return Err(format!(
                    "LLM hook 输出中未找到 JSON (输出: {})",
                    llm_hook_excerpt(content, 500)
                ));
            }
        };
        let hook_result: HookResult = serde_json::from_str(json_str).map_err(|e| {
            format!(
                "解析 LLM hook JSON 失败: {} (提取的 JSON: {})",
                e,
                llm_hook_excerpt(json_str, 500)
            )
        })?;
        write_info_log(
            "execute_llm_hook",
            &format!(
                "LLM hook 完成 (prompt_len={}, model={}), action={:?}",
                hook.prompt.len(),
                provider.model,
                hook_result.action
            ),
        );
        Ok(hook_result)
    })
}
/// Runs a shell hook: `sh -c <command>` in the process working directory.
///
/// The `HookContext` is written as JSON to the hook's stdin. `JCLI_HOOK_EVENT`, `JCLI_CWD`
/// and `JCLI_HOOK_DIR` are exported, and the hook directory, if any, is prepended to `PATH`.
/// A successful exit with empty stdout or `{}` yields an empty `HookResult`; any other
/// stdout must be a JSON `HookResult`. Stderr is logged.
///
/// Two safeguards are applied:
/// - stdin is fed from a separate thread, so the timeout also covers a hook that never
///   reads its input. Previously a large context could fill the pipe and block forever
///   before the timeout started.
/// - the hook runs in its own process group, and on timeout the whole group is SIGKILLed.
///   This also kills grandchildren of `sh`, which would otherwise keep the output pipes
///   open and leak the waiter thread.
///
/// # Errors
/// Returns a message on serialization/spawn failure, non-zero exit (including stderr),
/// invalid JSON output, wait failure, or timeout (`hook.timeout` seconds).
fn execute_shell_hook(hook: &ShellHook, context: &HookContext) -> Result<HookResult, String> {
    use std::os::unix::process::CommandExt;

    let context_json =
        serde_json::to_string(context).map_err(|e| format!("序列化 context 失败: {}", e))?;
    let user_cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
    let hook_dir_str = hook
        .dir_path
        .as_ref()
        .map(|p| p.display().to_string())
        .unwrap_or_default();
    let mut cmd = Command::new("sh");
    cmd.arg("-c")
        .arg(&hook.command)
        .current_dir(&user_cwd)
        .env("JCLI_HOOK_EVENT", context.event.as_str())
        .env("JCLI_CWD", user_cwd.display().to_string())
        .env("JCLI_HOOK_DIR", &hook_dir_str);
    if let Some(ref hook_dir) = hook.dir_path {
        let existing_path = std::env::var("PATH").unwrap_or_default();
        let new_path = if existing_path.is_empty() {
            hook_dir.display().to_string()
        } else {
            format!("{}:{}", hook_dir.display(), existing_path)
        };
        cmd.env("PATH", new_path);
    }
    // New process group (pgid == child pid) so a timeout can kill the whole tree.
    cmd.process_group(0);
    let mut child = cmd
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| format!("启动 hook 进程失败: {}", e))?;
    let pid = child.id();
    if let Some(mut stdin) = child.stdin.take() {
        std::thread::spawn(move || {
            // Errors (e.g. EPIPE when the hook exits without reading) are ignored, as before.
            let _ = stdin.write_all(context_json.as_bytes());
            // `stdin` is dropped here, closing the pipe so the hook sees EOF.
        });
    }
    let (tx, rx) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        let _ = tx.send(child.wait_with_output());
    });
    let timeout = std::time::Duration::from_secs(hook.timeout);
    match rx.recv_timeout(timeout) {
        Ok(Ok(output)) => {
            let stderr_str = String::from_utf8_lossy(&output.stderr).trim().to_string();
            if !stderr_str.is_empty() {
                write_info_log(
                    "execute_shell_hook",
                    &format!("Hook stderr ({}): {}", hook.command, stderr_str),
                );
            }
            if !output.status.success() {
                let mut err = format!("Hook 退出码: {:?}", output.status.code());
                if !stderr_str.is_empty() {
                    err.push_str(&format!(", stderr: {}", stderr_str));
                }
                return Err(err);
            }
            let stdout = String::from_utf8_lossy(&output.stdout);
            let stdout = stdout.trim();
            if stdout.is_empty() || stdout == "{}" {
                return Ok(HookResult::default());
            }
            let result: HookResult = serde_json::from_str(stdout)
                .map_err(|e| format!("解析 hook 输出 JSON 失败: {} (输出: {})", e, stdout))?;
            write_info_log(
                "execute_shell_hook",
                &format!(
                    "Hook 完成 (cmd: {}), action={:?}",
                    hook.command, result.action
                ),
            );
            Ok(result)
        }
        Ok(Err(e)) => Err(format!("等待 hook 进程失败: {}", e)),
        Err(_) => {
            // pgid equals the child's pid because of `process_group(0)` above.
            let _ = signal::killpg(Pid::from_raw(pid as i32), Signal::SIGKILL);
            Err(format!("Hook 超时 ({}s): {}", hook.timeout, hook.command))
        }
    }
}
fn hook_name(kind: &HookKind) -> Option<&str> {
match kind {
HookKind::Shell(shell) => shell.name.as_deref(),
HookKind::Llm(llm) => llm.name.as_deref(),
HookKind::Builtin(builtin) => Some(&builtin.name),
}
}
/// Builds a single-line, human-readable label for a hook.
///
/// - Shell: `"<name>: <command>"`, or just the command when unnamed.
/// - LLM: `"[llm: <name>] <preview>"`, or `"[llm: <preview>]"` when unnamed. The preview is the
///   first non-blank prompt line, or the whole prompt if every line is blank. It is cut to 80
///   characters and suffixed with `...` when longer.
/// - Builtin: `"[builtin: <name>]"`.
fn hook_label(kind: &HookKind) -> String {
    // The preview length is counted in characters, not bytes. Prompts are often non-ASCII,
    // and slicing at a fixed byte offset panics when it lands inside a multi-byte character.
    const PREVIEW_CHARS: usize = 80;
    match kind {
        HookKind::Shell(shell) => {
            if let Some(ref name) = shell.name {
                format!("{}: {}", name, shell.command)
            } else {
                shell.command.clone()
            }
        }
        HookKind::Llm(llm) => {
            let first_line = llm
                .prompt
                .lines()
                .find(|l| !l.trim().is_empty())
                .unwrap_or(&llm.prompt);
            // `nth(PREVIEW_CHARS)` is the first character past the limit. Its byte index is
            // always a valid char boundary to cut at.
            let prompt_preview = match first_line.char_indices().nth(PREVIEW_CHARS) {
                Some((cut, _)) => format!("{}...", &first_line[..cut]),
                None => first_line.to_string(),
            };
            if let Some(ref name) = llm.name {
                format!("[llm: {}] {}", name, prompt_preview)
            } else {
                format!("[llm: {}]", prompt_preview)
            }
        }
        HookKind::Builtin(builtin) => format!("[builtin: {}]", builtin.name),
    }
}
/// Returns the short type tag for a hook: `"builtin"`, `"llm"` or `"bash"` (shell hooks).
fn hook_type_str(kind: &HookKind) -> &'static str {
    match kind {
        HookKind::Builtin(_) => "builtin",
        HookKind::Llm(_) => "llm",
        HookKind::Shell(_) => "bash",
    }
}
fn hook_timeout(kind: &HookKind) -> Option<u64> {
match kind {
HookKind::Shell(shell) => Some(shell.timeout),
HookKind::Llm(llm) => Some(llm.timeout),
HookKind::Builtin(_) => None,
}
}
fn hook_retry_count(kind: &HookKind) -> u32 {
match kind {
HookKind::Shell(shell) => shell.retry,
HookKind::Llm(llm) => llm.retry,
HookKind::Builtin(_) => 0,
}
}
fn hook_on_error(kind: &HookKind) -> Option<OnError> {
match kind {
HookKind::Shell(shell) => Some(shell.on_error),
HookKind::Llm(llm) => Some(llm.on_error),
HookKind::Builtin(_) => None,
}
}
fn hook_on_error_strategy(kind: &HookKind) -> OnError {
match kind {
HookKind::Shell(shell) => shell.on_error,
HookKind::Llm(llm) => llm.on_error,
HookKind::Builtin(_) => OnError::Abort,
}
}
fn hook_filter(kind: &HookKind) -> Option<&HookFilter> {
match kind {
HookKind::Shell(shell) if !shell.filter.is_empty() => Some(&shell.filter),
HookKind::Llm(llm) if !llm.filter.is_empty() => Some(&llm.filter),
_ => None,
}
}
fn hook_should_execute(kind: &HookKind, context: &HookContext) -> bool {
match kind {
HookKind::Shell(shell) => shell.filter.matches(context),
HookKind::Llm(llm) => llm.filter.matches(context),
HookKind::Builtin(_) => true,
}
}
// Unit tests for hook events, hook definitions/results, shell/builtin/LLM hook execution,
// hook filters, prompt templating and HookManager chaining.
// The shell-hook tests spawn real `sh` processes.
#[cfg(test)]
mod tests {
use super::*;
// Every event returned by HookEvent::all() must survive as_str() -> parse() unchanged.
#[test]
fn test_hook_event_roundtrip() {
for event in HookEvent::all() {
let s = event.as_str();
let parsed = HookEvent::parse(s).unwrap();
assert_eq!(*event, parsed);
}
}
// Unknown event names are rejected (parse returns None).
#[test]
fn test_hook_event_from_str_invalid() {
assert!(HookEvent::parse("unknown_event").is_none());
}
// A YAML def with only `command` defaults to a bash hook with a 10s timeout.
#[test]
fn test_hook_def_default_timeout() {
let yaml = r#"command: "echo hello""#;
let def: HookDef = serde_yaml::from_str(yaml).unwrap();
assert_eq!(def.timeout, 10);
assert_eq!(def.r#type, HookType::Bash);
}
// A bash HookDef converts into HookKind::Shell, keeping command and timeout.
#[test]
fn test_hook_def_to_hook_kind_bash() {
let def = HookDef {
r#type: HookType::Bash,
command: Some("echo test".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
};
let kind = HookKind::from(def);
match kind {
HookKind::Shell(shell) => {
assert_eq!(shell.command, "echo test");
assert_eq!(shell.timeout, 5);
}
_ => panic!("应该转换为 Shell 变体"),
}
}
// An LLM HookDef converts into HookKind::Llm, keeping prompt, model and retry.
// NOTE(review): the def sets `timeout: 10` but the LlmHook is expected to have 30.
// Presumably into_hook_kind replaces the bash default timeout with the LLM default
// (HOOK_DEFAULT_LLM_TIMEOUT_SECS) — confirm against into_hook_kind.
#[test]
fn test_hook_def_to_hook_kind_llm() {
let def = HookDef {
r#type: HookType::Llm,
command: None,
prompt: Some("检查敏感信息: {{user_input}}".to_string()),
model: Some("gpt-4o".to_string()),
timeout: 10, retry: 2,
on_error: OnError::Skip,
filter: HookFilter::default(),
};
let kind = def.into_hook_kind().unwrap();
match kind {
HookKind::Llm(llm) => {
assert_eq!(llm.prompt, "检查敏感信息: {{user_input}}");
assert_eq!(llm.model.as_deref(), Some("gpt-4o"));
assert_eq!(llm.timeout, 30); assert_eq!(llm.retry, 2);
}
_ => panic!("应该转换为 Llm 变体"),
}
}
// A non-default timeout on an LLM def is kept as-is.
#[test]
fn test_hook_def_llm_explicit_timeout() {
let def = HookDef {
r#type: HookType::Llm,
command: None,
prompt: Some("test prompt".to_string()),
model: None,
timeout: 60,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
};
let kind = def.into_hook_kind().unwrap();
match kind {
HookKind::Llm(llm) => {
assert_eq!(llm.timeout, 60); }
_ => panic!("应该转换为 Llm 变体"),
}
}
// All LLM-specific fields parse from YAML when `type: llm` is given.
#[test]
fn test_hook_def_yaml_with_type() {
let yaml = r#"
type: llm
prompt: "检查敏感信息"
model: gpt-4o
timeout: 30
retry: 2
"#;
let def: HookDef = serde_yaml::from_str(yaml).unwrap();
assert_eq!(def.r#type, HookType::Llm);
assert_eq!(def.prompt.as_deref(), Some("检查敏感信息"));
assert_eq!(def.model.as_deref(), Some("gpt-4o"));
assert_eq!(def.timeout, 30);
assert_eq!(def.retry, 2);
}
// An empty JSON object is a valid, no-op HookResult.
#[test]
fn test_hook_result_empty_json() {
let result: HookResult = serde_json::from_str("{}").unwrap();
assert!(!result.is_halt());
assert!(result.messages.is_none());
assert!(result.user_input.is_none());
}
// NOTE(review): despite its name this test uses action "stop" (not "abort").
// It duplicates the first assertion of test_hook_result_with_action_stop.
#[test]
fn test_hook_result_with_abort() {
let json = r#"{"action": "stop"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert!(result.is_stop());
}
// action "stop" sets is_stop and not is_skip.
#[test]
fn test_hook_result_with_action_stop() {
let json = r#"{"action": "stop"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert!(result.is_stop());
assert!(!result.is_skip());
}
// action "skip" sets is_skip and not is_stop.
#[test]
fn test_hook_result_with_action_skip() {
let json = r#"{"action": "skip"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert!(result.is_skip());
assert!(!result.is_stop());
}
// A hook can replace the user input.
#[test]
fn test_hook_result_with_user_input() {
let json = r#"{"user_input": "[modified] hello"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert_eq!(result.user_input.as_deref(), Some("[modified] hello"));
}
// Context serialization emits the snake_case event name and set fields.
// Unset optional fields (messages, tool_name) are omitted.
#[test]
fn test_hook_context_serialization() {
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("hello".to_string()),
..Default::default()
};
let json = serde_json::to_string(&ctx).unwrap();
assert!(json.contains("pre_send_message"));
assert!(json.contains("hello"));
assert!(json.contains("user_input"));
assert!(!json.contains("messages"));
assert!(!json.contains("tool_name"));
}
// A shell hook's stdout JSON becomes its HookResult.
#[test]
fn test_execute_shell_hook_echo() {
let hook = ShellHook {
name: None,
command: r#"echo '{"user_input": "hooked"}'"#.to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
};
let result = execute_shell_hook(&hook, &ctx).unwrap();
assert_eq!(result.user_input.as_deref(), Some("hooked"));
assert!(!result.is_halt());
}
// Blank stdout is treated as a no-op result rather than a parse error.
#[test]
fn test_execute_shell_hook_empty_output() {
let hook = ShellHook {
name: None,
command: "echo ''".to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext::default();
let result = execute_shell_hook(&hook, &ctx).unwrap();
assert!(!result.is_halt());
assert!(result.user_input.is_none());
}
// A non-zero exit status is reported as Err.
#[test]
fn test_execute_shell_hook_nonzero_exit() {
let hook = ShellHook {
name: None,
command: "exit 1".to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext::default();
let result = execute_shell_hook(&hook, &ctx);
assert!(result.is_err());
}
// The hook can consume the context JSON from stdin.
// The python3 step is optional: its failure falls back to `echo ""`.
#[test]
fn test_execute_shell_hook_reads_stdin() {
let hook = ShellHook {
name: None,
command: r#"input=$(cat); event=$(echo "$input" | python3 -c "import sys,json; print(json.load(sys.stdin).get('event',''))" 2>/dev/null || echo ""); echo '{"user_input": "got_input"}'"#.to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("test".to_string()),
..Default::default()
};
let result = execute_shell_hook(&hook, &ctx).unwrap();
assert_eq!(result.user_input.as_deref(), Some("got_input"));
}
// A builtin hook's closure result is returned by execute_hook_with_provider.
#[test]
fn test_execute_builtin_hook() {
let builtin = BuiltinHook {
name: "test_hook".to_string(),
handler: Arc::new(|ctx| {
if let Some(ref input) = ctx.user_input {
Some(HookResult {
user_input: Some(format!("[hooked] {}", input)),
..Default::default()
})
} else {
None
}
}),
};
let kind = HookKind::Builtin(builtin);
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
};
let result = execute_hook_with_provider(&kind, &ctx, &None).unwrap();
assert_eq!(result.user_input.as_deref(), Some("[hooked] original"));
}
// A builtin returning None yields an empty, non-halting result.
#[test]
fn test_execute_builtin_hook_returns_none() {
let builtin = BuiltinHook {
name: "no_op".to_string(),
handler: Arc::new(|_| None),
};
let kind = HookKind::Builtin(builtin);
let ctx = HookContext::default();
let result = execute_hook_with_provider(&kind, &ctx, &None).unwrap();
assert!(!result.is_halt());
assert!(result.user_input.is_none());
}
// With no hooks registered, execute returns None.
#[test]
fn test_hook_manager_empty() {
let manager = HookManager::default();
assert!(manager.list_hooks().is_empty());
let result = manager.execute(HookEvent::PreSendMessage, HookContext::default());
assert!(result.is_none());
}
// A registered session hook is listed with source "session" and runs on execute.
#[test]
fn test_hook_manager_session_hooks() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "session_hooked"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let hooks = manager.list_hooks();
assert_eq!(hooks.len(), 1);
assert_eq!(hooks[0].source, "session");
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(result.user_input.as_deref(), Some("session_hooked"));
}
// A registered builtin hook is listed with source "builtin" and its name in the label.
#[test]
fn test_hook_manager_builtin_hooks() {
let mut manager = HookManager::default();
manager.register_builtin(HookEvent::PreSendMessage, "test_builtin", |ctx| {
if let Some(ref input) = ctx.user_input {
Some(HookResult {
user_input: Some(format!("[builtin] {}", input)),
..Default::default()
})
} else {
None
}
});
let hooks = manager.list_hooks();
assert_eq!(hooks.len(), 1);
assert_eq!(hooks[0].source, "builtin");
assert!(hooks[0].label.contains("test_builtin"));
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("hello".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(result.user_input.as_deref(), Some("[builtin] hello"));
}
// Builtin hooks run before session hooks. If the builtin ran last, the result would be
// "[builtin] session_overridden" instead of the session hook's plain output.
#[test]
fn test_hook_manager_builtin_before_session() {
let mut manager = HookManager::default();
manager.register_builtin(HookEvent::PreSendMessage, "prefix", |ctx| {
if let Some(ref input) = ctx.user_input {
Some(HookResult {
user_input: Some(format!("[builtin] {}", input)),
..Default::default()
})
} else {
None
}
});
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "session_overridden"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(result.user_input.as_deref(), Some("session_overridden"));
}
// Session hooks are removable by index.
// Removing a non-existent index returns false.
#[test]
fn test_hook_manager_remove_session_hook() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("echo test".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
assert_eq!(manager.list_hooks().len(), 1);
assert!(manager.remove_session_hook(HookEvent::PreSendMessage, 0));
assert!(manager.list_hooks().is_empty());
assert!(!manager.remove_session_hook(HookEvent::PreSendMessage, 0));
}
// Hooks run in registration order, so the last hook's user_input wins.
#[test]
fn test_hook_chain_execution() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "first"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "second"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(result.user_input.as_deref(), Some("second"));
}
// A failing hook with on_error: Abort halts the chain before later hooks run.
// NOTE(review): identical to test_on_error_abort_stops_chain below.
#[test]
fn test_hook_abort_stops_chain() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("exit 1".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Abort,
filter: HookFilter::default(),
},
);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "should_not_reach"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
..Default::default()
},
)
.unwrap();
assert!(result.is_halt());
assert!(result.user_input.is_none());
}
// Cloning a manager keeps its registered builtin hooks.
#[test]
fn test_builtin_hook_clone() {
let mut manager = HookManager::default();
manager.register_builtin(HookEvent::PreLlmRequest, "test_clone", |_| {
Some(HookResult::default())
});
let cloned = manager.clone();
assert_eq!(cloned.list_hooks().len(), 1);
}
// A failing hook with on_error: Skip is ignored and the chain continues.
#[test]
fn test_on_error_skip_continues_chain() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("exit 1".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "survived"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("original".to_string()),
..Default::default()
},
)
.unwrap();
assert!(!result.is_halt());
assert_eq!(result.user_input.as_deref(), Some("survived"));
}
// NOTE(review): duplicate of test_hook_abort_stops_chain above.
#[test]
fn test_on_error_abort_stops_chain() {
let mut manager = HookManager::default();
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("exit 1".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Abort,
filter: HookFilter::default(),
},
);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some(r#"echo '{"user_input": "should_not_reach"}'"#.to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
let result = manager
.execute(
HookEvent::PreSendMessage,
HookContext {
event: HookEvent::PreSendMessage,
..Default::default()
},
)
.unwrap();
assert!(result.is_halt());
assert!(result.user_input.is_none());
}
// on_error defaults to Skip when omitted from YAML.
#[test]
fn test_on_error_default_is_skip() {
let yaml = r#"command: "exit 1"
timeout: 5"#;
let def: HookDef = serde_yaml::from_str(yaml).unwrap();
assert_eq!(def.on_error, OnError::Skip);
}
// on_error parses both lowercase variants from YAML.
#[test]
fn test_on_error_yaml_parsing() {
let yaml_skip = r#"command: "echo test"
on_error: skip"#;
let def: HookDef = serde_yaml::from_str(yaml_skip).unwrap();
assert_eq!(def.on_error, OnError::Skip);
let yaml_abort = r#"command: "echo test"
on_error: abort"#;
let def: HookDef = serde_yaml::from_str(yaml_abort).unwrap();
assert_eq!(def.on_error, OnError::Abort);
}
// stderr output on success does not interfere with parsing stdout.
#[test]
fn test_shell_hook_stderr_captured() {
let hook = ShellHook {
name: None,
command: r#"echo '{"user_input": "ok"}'; echo "debug info" >&2"#.to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("test".to_string()),
..Default::default()
};
let result = execute_shell_hook(&hook, &ctx).unwrap();
assert_eq!(result.user_input.as_deref(), Some("ok"));
}
// On failure the error message includes the hook's stderr.
#[test]
fn test_shell_hook_stderr_in_error() {
let hook = ShellHook {
name: None,
command: r#"echo "something went wrong" >&2; exit 1"#.to_string(),
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext::default();
let result = execute_shell_hook(&hook, &ctx);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("stderr:"), "错误信息应包含 stderr: {}", err);
assert!(
err.contains("something went wrong"),
"错误信息应包含 stderr 内容: {}",
err
);
}
// list_hooks orders builtin hooks first, then session hooks. Only session entries carry
// session_index and on_error.
#[test]
fn test_hook_entry_session_index() {
let mut manager = HookManager::default();
manager.register_builtin(HookEvent::PreSendMessage, "test", |_| None);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("echo first".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("echo second".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Abort,
filter: HookFilter::default(),
},
);
let hooks = manager.list_hooks();
assert_eq!(hooks.len(), 3);
assert_eq!(hooks[0].source, "builtin");
assert!(hooks[0].session_index.is_none());
assert!(hooks[0].on_error.is_none());
assert_eq!(hooks[1].source, "session");
assert_eq!(hooks[1].session_index, Some(0));
assert_eq!(hooks[1].on_error, Some(OnError::Skip));
assert_eq!(hooks[2].source, "session");
assert_eq!(hooks[2].session_index, Some(1));
assert_eq!(hooks[2].on_error, Some(OnError::Abort));
}
// An unknown legacy field (`_switch_model`) must not break deserialization.
#[test]
fn test_switch_model_field_removed() {
let json = r#"{"user_input": "test", "_switch_model": "gpt-4"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert_eq!(result.user_input.as_deref(), Some("test"));
}
// Round-trip check for the more recently added events.
#[test]
fn test_new_hook_events_roundtrip() {
for event in [
HookEvent::Stop,
HookEvent::PreMicroCompact,
HookEvent::PostMicroCompact,
HookEvent::PreAutoCompact,
HookEvent::PostAutoCompact,
HookEvent::PostToolExecutionFailure,
] {
let s = event.as_str();
let parsed = HookEvent::parse(s).unwrap();
assert_eq!(event, parsed);
}
}
// action "stop" can carry retry_feedback.
// NOTE(review): identical to test_hook_result_action_stop_with_retry_feedback below.
#[test]
fn test_hook_result_retry_feedback() {
let json = r#"{"action": "stop", "retry_feedback": "请修正敏感信息"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert!(result.is_stop());
assert_eq!(result.retry_feedback.as_deref(), Some("请修正敏感信息"));
}
#[test]
fn test_hook_result_action_stop_with_retry_feedback() {
let json = r#"{"action": "stop", "retry_feedback": "请修正敏感信息"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert!(result.is_stop());
assert_eq!(result.retry_feedback.as_deref(), Some("请修正敏感信息"));
}
// additional_context deserializes into HookResult.
#[test]
fn test_hook_result_additional_context() {
let json = r#"{"additional_context": "必须保留宪法规则"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert_eq!(
result.additional_context.as_deref(),
Some("必须保留宪法规则")
);
}
// system_message deserializes into HookResult.
#[test]
fn test_hook_result_system_message() {
let json = r#"{"system_message": "纠查官已审查"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert_eq!(result.system_message.as_deref(), Some("纠查官已审查"));
}
// tool_error deserializes into HookResult.
#[test]
fn test_hook_result_tool_error() {
let json = r#"{"tool_error": "权限不足"}"#;
let result: HookResult = serde_json::from_str(json).unwrap();
assert_eq!(result.tool_error.as_deref(), Some("权限不足"));
}
// A None tool_error is omitted from serialized context.
#[test]
fn test_hook_context_new_fields() {
let ctx = HookContext {
event: HookEvent::PreAutoCompact,
tool_error: None,
..Default::default()
};
let json = serde_json::to_string(&ctx).unwrap();
assert!(json.contains("pre_auto_compact"));
assert!(!json.contains("tool_error"));
}
// tool_matcher matches any of the `|`-separated tool names. It never matches when the
// context has no tool_name.
#[test]
fn test_hook_filter_tool_matcher() {
let filter = HookFilter {
tool_name: None,
tool_matcher: Some("Bash|Shell".to_string()),
model_prefix: None,
};
assert!(!filter.is_empty());
let ctx = HookContext {
event: HookEvent::PreToolExecution,
tool_name: Some("Bash".to_string()),
..Default::default()
};
assert!(filter.matches(&ctx));
let ctx = HookContext {
event: HookEvent::PreToolExecution,
tool_name: Some("Shell".to_string()),
..Default::default()
};
assert!(filter.matches(&ctx));
let ctx = HookContext {
event: HookEvent::PreToolExecution,
tool_name: Some("Write".to_string()),
..Default::default()
};
assert!(!filter.matches(&ctx));
let ctx = HookContext {
event: HookEvent::PreToolExecution,
..Default::default()
};
assert!(!filter.matches(&ctx));
}
// When both are set, tool_name takes precedence: "Write" matches tool_matcher but is
// still rejected because tool_name is "Bash".
#[test]
fn test_hook_filter_tool_name_priority_over_matcher() {
let filter = HookFilter {
tool_name: Some("Bash".to_string()),
tool_matcher: Some("Write|Edit".to_string()),
model_prefix: None,
};
let ctx = HookContext {
event: HookEvent::PreToolExecution,
tool_name: Some("Write".to_string()),
..Default::default()
};
assert!(!filter.matches(&ctx));
}
// tool_matcher is read from YAML as a plain string.
#[test]
fn test_hook_filter_tool_matcher_yaml() {
let yaml = r#"tool_matcher: "Bash|Shell""#;
let filter: HookFilter = serde_yaml::from_str(yaml).unwrap();
assert_eq!(filter.tool_matcher.as_deref(), Some("Bash|Shell"));
assert!(filter.tool_name.is_none());
}
// {{placeholders}} in the prompt are replaced with context values.
#[test]
fn test_render_prompt_template() {
let template = "事件: {{event}}, 输入: {{user_input}}, 工具: {{tool_name}}";
let ctx = HookContext {
event: HookEvent::PreSendMessage,
user_input: Some("hello".to_string()),
tool_name: Some("Bash".to_string()),
..Default::default()
};
let rendered = render_prompt_template(template, &ctx);
assert!(rendered.contains("pre_send_message"));
assert!(rendered.contains("hello"));
assert!(rendered.contains("Bash"));
}
// Placeholders for unset fields render as empty strings.
#[test]
fn test_render_prompt_template_empty_fields() {
let template = "输入: {{user_input}}, 输出: {{assistant_output}}";
let ctx = HookContext::default();
let rendered = render_prompt_template(template, &ctx);
assert_eq!(rendered, "输入: , 输出: ");
}
// JSON is extracted from bare output, ```json fences, and surrounding prose.
// Output without any JSON yields None.
#[test]
fn test_extract_json_from_llm_output() {
assert_eq!(
extract_json_from_llm_output(r#"{"user_input": "test"}"#),
Some(r#"{"user_input": "test"}"#)
);
assert_eq!(
extract_json_from_llm_output("```json\n{\"user_input\": \"test\"}\n```"),
Some(r#"{"user_input": "test"}"#)
);
assert_eq!(
extract_json_from_llm_output("Here is the result: {\"action\": \"stop\"}"),
Some(r#"{"action": "stop"}"#)
);
assert_eq!(extract_json_from_llm_output("no json here"), None);
}
// `type` defaults to bash; `type: llm` selects an LLM hook.
#[test]
fn test_hook_type_yaml_parsing() {
let yaml_bash = r#"command: "echo hello""#;
let def: HookDef = serde_yaml::from_str(yaml_bash).unwrap();
assert_eq!(def.r#type, HookType::Bash);
let yaml_llm = r#"
type: llm
prompt: "check this""#;
let def: HookDef = serde_yaml::from_str(yaml_llm).unwrap();
assert_eq!(def.r#type, HookType::Llm);
assert_eq!(def.prompt.as_deref(), Some("check this"));
}
// A bash def without `command` cannot be converted.
#[test]
fn test_hook_def_bash_missing_command() {
let def = HookDef {
r#type: HookType::Bash,
command: None,
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
};
assert!(def.into_hook_kind().is_err());
}
// An LLM def without `prompt` cannot be converted.
#[test]
fn test_hook_def_llm_missing_prompt() {
let def = HookDef {
r#type: HookType::Llm,
command: None,
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
};
assert!(def.into_hook_kind().is_err());
}
// HookType's Display output is the lowercase type name.
#[test]
fn test_hook_type_display() {
assert_eq!(format!("{}", HookType::Bash), "bash");
assert_eq!(format!("{}", HookType::Llm), "llm");
}
// list_hooks reports hook_type as "builtin", "bash" or "llm".
#[test]
fn test_hook_entry_hook_type() {
let mut manager = HookManager::default();
manager.register_builtin(HookEvent::PreSendMessage, "test", |_| None);
manager.register_session_hook(
HookEvent::PreSendMessage,
HookDef {
r#type: HookType::Bash,
command: Some("echo test".to_string()),
prompt: None,
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
},
);
manager.register_session_hook_kind(
HookEvent::PreSendMessage,
HookKind::Llm(LlmHook {
name: None,
prompt: "check content".to_string(),
model: None,
timeout: 30,
retry: 1,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
}),
);
let hooks = manager.list_hooks();
assert_eq!(hooks.len(), 3);
assert_eq!(hooks[0].hook_type, "builtin");
assert_eq!(hooks[1].hook_type, "bash");
assert_eq!(hooks[2].hook_type, "llm");
}
// An LLM hook without an injected provider fails fast with an explanatory error.
#[test]
fn test_llm_hook_no_provider_returns_err() {
let hook = LlmHook {
name: None,
prompt: "test".to_string(),
model: None,
timeout: 5,
retry: 0,
on_error: OnError::Skip,
filter: HookFilter::default(),
dir_path: None,
};
let ctx = HookContext::default();
let result = execute_llm_hook(&hook, &ctx, &None);
assert!(result.is_err());
assert!(result.unwrap_err().contains("未注入 provider"));
}
}