#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::Arc;
use crate::utils::hooks::hooks_settings::HookCommand;
use crate::utils::hooks::hooks_settings::HookEvent;
use crate::utils::hooks::session_hooks::{add_session_hook, remove_session_hook, OnHookSuccess};
/// Hook configuration attached to a skill, keyed by hook event name.
///
/// `register_skill_hooks` looks entries up by the names listed in
/// `HOOK_EVENT_NAMES`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct HooksSettings {
/// Matchers per event name. Flattened by serde, so the event names are
/// top-level keys of the (de)serialized object rather than nested under
/// an `events` key.
#[serde(flatten)]
pub events: HashMap<String, Vec<HookMatcher>>,
}
/// One matcher entry under an event: an optional matcher string plus the
/// hooks registered with it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HookMatcher {
/// Matcher string; omitted from serialized output when `None`.
/// `register_skill_hooks` registers a missing matcher as the empty string.
#[serde(skip_serializing_if = "Option::is_none")]
pub matcher: Option<String>,
/// Raw JSON hook definitions; each is parsed with `parse_hook_command`
/// at registration time.
pub hooks: Vec<serde_json::Value>,
}
/// Event names that `register_skill_hooks` looks up in `HooksSettings::events`.
///
/// Keys of `events` that are not listed here are never registered. Every entry
/// has a matching arm in `parse_hook_event`; keep the two lists in sync.
const HOOK_EVENT_NAMES: &[&str] = &[
"PreToolUse",
"PostToolUse",
"PostToolUseFailure",
"PermissionDenied",
"Notification",
"UserPromptSubmit",
"SessionStart",
"SessionEnd",
"Stop",
"StopFailure",
"SubagentStart",
"SubagentStop",
"PreCompact",
"PostCompact",
"PermissionRequest",
"Setup",
"TeammateIdle",
"TaskCreated",
"TaskCompleted",
"Elicitation",
"ElicitationResult",
"ConfigChange",
"WorktreeCreate",
"WorktreeRemove",
"InstructionsLoaded",
"CwdChanged",
"FileChanged",
];
/// Maps a hook event name to its [`HookEvent`] variant.
///
/// The comparison is exact and case-sensitive; any string that is not one of
/// the names below yields `None`.
fn parse_hook_event(s: &str) -> Option<HookEvent> {
    let event = match s {
        "PreToolUse" => HookEvent::PreToolUse,
        "PostToolUse" => HookEvent::PostToolUse,
        "PostToolUseFailure" => HookEvent::PostToolUseFailure,
        "PermissionDenied" => HookEvent::PermissionDenied,
        "Notification" => HookEvent::Notification,
        "UserPromptSubmit" => HookEvent::UserPromptSubmit,
        "SessionStart" => HookEvent::SessionStart,
        "SessionEnd" => HookEvent::SessionEnd,
        "Stop" => HookEvent::Stop,
        "StopFailure" => HookEvent::StopFailure,
        "SubagentStart" => HookEvent::SubagentStart,
        "SubagentStop" => HookEvent::SubagentStop,
        "PreCompact" => HookEvent::PreCompact,
        "PostCompact" => HookEvent::PostCompact,
        "PermissionRequest" => HookEvent::PermissionRequest,
        "Setup" => HookEvent::Setup,
        "TeammateIdle" => HookEvent::TeammateIdle,
        "TaskCreated" => HookEvent::TaskCreated,
        "TaskCompleted" => HookEvent::TaskCompleted,
        "Elicitation" => HookEvent::Elicitation,
        "ElicitationResult" => HookEvent::ElicitationResult,
        "ConfigChange" => HookEvent::ConfigChange,
        "WorktreeCreate" => HookEvent::WorktreeCreate,
        "WorktreeRemove" => HookEvent::WorktreeRemove,
        "InstructionsLoaded" => HookEvent::InstructionsLoaded,
        "CwdChanged" => HookEvent::CwdChanged,
        "FileChanged" => HookEvent::FileChanged,
        // Unknown name: nothing to map to.
        _ => return None,
    };
    Some(event)
}
/// Builds a [`HookCommand`] from a single JSON hook entry.
///
/// The variant is selected by the first of these keys that holds a string:
///
/// 1. `command` -> `HookCommand::Command`
/// 2. `prompt`  -> `HookCommand::Agent` when a `model` key is present (with
///    any value, including `null`), otherwise `HookCommand::Prompt`
/// 3. `url`     -> `HookCommand::Http`
///
/// Optional fields that are absent or have an unexpected JSON type become
/// `None`. In `headers`, a non-string value is stored as the empty string;
/// in `allowedEnvVars`, non-string items are dropped.
///
/// # Errors
///
/// Returns an error message when none of `command`, `prompt` or `url` holds a
/// string.
fn parse_hook_command(value: &serde_json::Value) -> Result<HookCommand, String> {
    // Field readers shared by every variant.
    let text = |key: &str| -> Option<String> {
        value
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };
    let flag = |key: &str| -> Option<bool> {
        value.get(key).and_then(serde_json::Value::as_bool)
    };
    let timeout = value.get("timeout").and_then(serde_json::Value::as_u64);

    if let Some(command) = text("command") {
        Ok(HookCommand::Command {
            command,
            shell: text("shell"),
            if_condition: text("if"),
            timeout,
            status_message: text("statusMessage"),
            once: flag("once"),
            r#async: flag("async"),
            async_rewake: flag("asyncRewake"),
        })
    } else if let Some(prompt) = text("prompt") {
        // Presence of the key alone (whatever its value) selects the agent form.
        let has_model_key = value.get("model").is_some();
        if has_model_key {
            Ok(HookCommand::Agent {
                prompt,
                model: text("model"),
                if_condition: text("if"),
                timeout,
                status_message: text("statusMessage"),
                once: flag("once"),
            })
        } else {
            Ok(HookCommand::Prompt {
                prompt,
                if_condition: text("if"),
                timeout,
                model: text("model"),
                status_message: text("statusMessage"),
                once: flag("once"),
            })
        }
    } else if let Some(url) = text("url") {
        let headers = value
            .get("headers")
            .and_then(serde_json::Value::as_object)
            .map(|entries| {
                entries
                    .iter()
                    .map(|(name, raw)| {
                        // Non-string header values collapse to an empty string.
                        (name.clone(), raw.as_str().unwrap_or_default().to_owned())
                    })
                    .collect()
            });
        let allowed_env_vars = value
            .get("allowedEnvVars")
            .and_then(serde_json::Value::as_array)
            .map(|items| {
                items
                    .iter()
                    .filter_map(|item| item.as_str().map(str::to_owned))
                    .collect()
            });
        Ok(HookCommand::Http {
            url,
            if_condition: text("if"),
            timeout,
            headers,
            allowed_env_vars,
            status_message: text("statusMessage"),
            once: flag("once"),
        })
    } else {
        Err("Could not parse hook command from JSON".to_string())
    }
}
/// Result of `parse_hook_with_once`: the parsed hook together with its
/// `once` flag resolved to a plain `bool`.
struct ParsedHookWithOnce {
// The hook as produced by `parse_hook_command`.
hook: HookCommand,
// `true` only when the JSON entry has `"once": true`; `false` otherwise.
once: bool,
}
/// Parses a JSON hook entry and resolves its `once` flag.
///
/// `once` is `true` only when the entry contains the boolean `true` under the
/// `"once"` key; a missing key or any other value gives `false`.
///
/// # Errors
///
/// Propagates the error from `parse_hook_command`.
fn parse_hook_with_once(value: &serde_json::Value) -> Result<ParsedHookWithOnce, String> {
    parse_hook_command(value).map(|hook| ParsedHookWithOnce {
        hook,
        once: matches!(value.get("once"), Some(serde_json::Value::Bool(true))),
    })
}
/// Registers every hook declared in a skill's hook settings as a session hook.
///
/// For each name in `HOOK_EVENT_NAMES` that has matchers in `hooks.events`,
/// every hook entry is parsed and handed to `add_session_hook`. A matcher that
/// is `None` is registered as the empty string. Hooks whose JSON contains
/// `"once": true` get a success callback that calls `remove_session_hook` for
/// the same event and hook. Entries that fail to parse are skipped and
/// reported through the debug log.
///
/// * `set_app_state` - state-update callback forwarded to `add_session_hook`
///   and `remove_session_hook`.
/// * `session_id` - session the hooks are registered for.
/// * `hooks` - the hook settings to register.
/// * `skill_name` - used only in debug log messages.
/// * `skill_root` - forwarded unchanged to `add_session_hook`.
pub fn register_skill_hooks(
    set_app_state: Arc<dyn Fn(&dyn Fn(&mut serde_json::Value)) + Send + Sync>,
    session_id: &str,
    hooks: &HooksSettings,
    skill_name: &str,
    skill_root: Option<&str>,
) {
    let mut registered_count = 0;
    for event_name in HOOK_EVENT_NAMES {
        let matchers = match hooks.events.get(*event_name) {
            Some(m) => m,
            None => continue,
        };
        let event = match parse_hook_event(event_name) {
            Some(e) => e,
            None => continue,
        };
        for matcher_config in matchers {
            let matcher = matcher_config.matcher.clone().unwrap_or_default();
            for hook_json in &matcher_config.hooks {
                let parsed = match parse_hook_with_once(hook_json) {
                    Ok(p) => p,
                    Err(err) => {
                        // Still best-effort: skip the entry, but leave a trace
                        // so a malformed hook definition can be diagnosed.
                        log_for_debugging(&format!(
                            "Skipping unparseable {} hook in skill '{}': {}",
                            event_name, skill_name, err
                        ));
                        continue;
                    }
                };
                let on_hook_success: Option<OnHookSuccess> = if parsed.once {
                    // The callback may outlive this call, so it owns copies of
                    // everything it needs to remove the hook later.
                    let set_app_state_inner = Arc::clone(&set_app_state);
                    let session_id_inner = session_id.to_string();
                    let event_inner = event.clone();
                    let hook_inner = parsed.hook.clone();
                    let skill_name_inner = skill_name.to_string();
                    Some(Arc::new(move |_: &crate::utils::hooks::session_hooks::SessionHookCommand, _: &crate::utils::hooks::session_hooks::AggregatedHookResult| {
                        log_for_debugging(&format!(
                            "Removing one-shot hook for event {} in skill '{}'",
                            event_inner.as_str(),
                            skill_name_inner,
                        ));
                        remove_session_hook(
                            &*set_app_state_inner,
                            &session_id_inner,
                            &event_inner,
                            &hook_inner,
                        );
                    }) as OnHookSuccess)
                } else {
                    None
                };
                add_session_hook(
                    &*set_app_state,
                    session_id,
                    &event,
                    &matcher,
                    parsed.hook,
                    on_hook_success,
                    // Already an `Option<&str>`; no intermediate `String` needed.
                    skill_root,
                );
                registered_count += 1;
            }
        }
    }
    if registered_count > 0 {
        log_for_debugging(&format!(
            "Registered {} hooks from skill '{}'",
            registered_count, skill_name
        ));
    }
}
/// Emits `msg` through the `log` crate at debug level.
fn log_for_debugging(msg: &str) {
    log::log!(log::Level::Debug, "{}", msg);
}
/// Registers the hooks of every skill in `skills` that declares any.
///
/// Skills whose `hooks` field is `None` are skipped. Each remaining skill is
/// passed to `register_skill_hooks` with no skill root. When at least one
/// skill carried hook settings, the number of such skills is written to the
/// debug log.
pub fn register_hooks_from_skills(
    set_app_state: Arc<dyn Fn(&dyn Fn(&mut serde_json::Value)) + Send + Sync>,
    session_id: &str,
    skills: &[crate::skills::loader::UnifiedSkill],
) {
    let skills_with_hooks = skills
        .iter()
        .filter_map(|skill| skill.hooks.as_ref().map(|hooks| (skill, hooks)));

    let mut skill_total = 0;
    for (skill, hooks) in skills_with_hooks {
        register_skill_hooks(
            Arc::clone(&set_app_state),
            session_id,
            hooks,
            &skill.name,
            None,
        );
        skill_total += 1;
    }

    if skill_total == 0 {
        return;
    }
    log_for_debugging(&format!(
        "Registered hooks from {} skill(s) for session {}",
        skill_total, session_id
    ));
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skills::loader::{SkillSource, UnifiedSkill};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Callback type accepted by the registration functions under test.
    type StateSetter = Arc<dyn Fn(&dyn Fn(&mut serde_json::Value)) + Send + Sync>;

    /// Returns a state setter that only counts its invocations, plus the counter.
    fn counting_setter() -> (StateSetter, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let tally = Arc::clone(&calls);
        let setter: StateSetter = Arc::new(move |_: &dyn Fn(&mut serde_json::Value)| {
            tally.fetch_add(1, Ordering::SeqCst);
        });
        (setter, calls)
    }

    /// Settings with a single `Stop` matcher (empty matcher string) holding one
    /// shell-command hook.
    fn stop_hook_settings(command: &str) -> HooksSettings {
        let mut settings = HooksSettings::default();
        settings.events.insert(
            "Stop".to_string(),
            vec![HookMatcher {
                matcher: Some(String::new()),
                hooks: vec![serde_json::json!({ "command": command })],
            }],
        );
        settings
    }

    /// Builds a skill fixture with no paths and no `user_invocable` setting.
    fn skill_fixture(
        name: &str,
        description: &str,
        source: SkillSource,
        content: &str,
        hooks: Option<HooksSettings>,
    ) -> UnifiedSkill {
        UnifiedSkill {
            name: name.to_string(),
            description: description.to_string(),
            source,
            content: content.to_string(),
            paths: None,
            user_invocable: None,
            hooks,
        }
    }

    #[test]
    fn test_parse_hook_with_once() {
        let entry = serde_json::json!({
            "command": "echo hello",
            "once": true
        });
        let parsed = parse_hook_with_once(&entry).expect("hook entry should parse");
        assert!(parsed.once);
        match parsed.hook {
            HookCommand::Command { command, .. } => assert_eq!(command, "echo hello"),
            _ => panic!("Expected Command variant"),
        }
    }

    #[test]
    fn test_parse_hook_without_once() {
        let entry = serde_json::json!({ "command": "echo hello" });
        let parsed = parse_hook_with_once(&entry).expect("hook entry should parse");
        assert!(!parsed.once);
    }

    #[test]
    fn test_register_skill_hooks_empty() {
        let (setter, calls) = counting_setter();
        register_skill_hooks(
            setter,
            "test-session",
            &HooksSettings::default(),
            "test-skill",
            None,
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_register_skill_hooks_with_hooks() {
        let settings = stop_hook_settings("echo hook-executed");
        let (setter, calls) = counting_setter();
        register_skill_hooks(setter, "test-session", &settings, "test-skill", None);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_register_hooks_from_skills_empty() {
        let no_skills: Vec<UnifiedSkill> = Vec::new();
        let (setter, calls) = counting_setter();
        register_hooks_from_skills(setter, "test-session", &no_skills);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_register_hooks_from_skills_with_hooks() {
        let skills = vec![skill_fixture(
            "test-skill",
            "Test",
            SkillSource::Project,
            "content",
            Some(stop_hook_settings("echo test")),
        )];
        let (setter, calls) = counting_setter();
        register_hooks_from_skills(setter, "test-session", &skills);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_register_hooks_from_skills_skips_no_hooks() {
        let skills = vec![
            skill_fixture("no-hooks", "No hooks", SkillSource::Bundled, "", None),
            skill_fixture("also-no-hooks", "Also no hooks", SkillSource::User, "", None),
        ];
        let (setter, calls) = counting_setter();
        register_hooks_from_skills(setter, "test-session", &skills);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}