#[allow(unused_imports)]
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use wait_timeout::ChildExt;
/// Lifecycle points at which configured hooks can run.
///
/// Serialized in `snake_case` (e.g. `ToolCallBefore` <-> `"tool_call_before"`), which is
/// the same spelling returned by [`HookEvent::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    SessionStart,
    SessionEnd,
    MessageSubmit,
    ToolCallBefore,
    ToolCallAfter,
    ModeChange,
    OnError,
    ShellEnv,
}

impl HookEvent {
    /// The event's `snake_case` name, identical to its serde representation.
    #[allow(dead_code)]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
            Self::MessageSubmit => "message_submit",
            Self::ToolCallBefore => "tool_call_before",
            Self::ToolCallAfter => "tool_call_after",
            Self::ModeChange => "mode_change",
            Self::OnError => "on_error",
            Self::ShellEnv => "shell_env",
        }
    }
}
/// Predicate that decides whether a hook fires for a given `HookContext`.
///
/// Serialized as an internally tagged object, e.g.
/// `{"type": "tool_name", "name": "exec_shell"}`. A hook without a condition behaves
/// like [`HookCondition::Always`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HookCondition {
    /// Always matches (the default).
    #[default]
    Always,
    /// Matches when the tool name equals `name` exactly.
    ToolName { name: String },
    /// Matches when the tool's coarse category (`shell`, `file_write`, `safe`, `other`)
    /// equals `category`.
    ToolCategory { category: String },
    /// Matches when the current mode equals `mode`, compared after lowercasing both.
    Mode { mode: String },
    /// Matches when the tool's exit code equals `code`.
    ExitCode { code: i32 },
    /// Matches when every nested condition matches (true when the list is empty).
    All { conditions: Vec<HookCondition> },
    /// Matches when at least one nested condition matches (false when the list is empty).
    Any { conditions: Vec<HookCondition> },
}
/// A single configured hook: a shell command bound to a [`HookEvent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Event that triggers this hook.
    pub event: HookEvent,
    /// Command line run through the platform shell (`sh -c` / `cmd /C`).
    pub command: String,
    /// Optional predicate; `None` means the hook always fires.
    #[serde(default)]
    pub condition: Option<HookCondition>,
    /// Timeout in seconds for synchronous runs (serde default: 30).
    /// `HooksConfig::default_timeout_secs`, when set, takes precedence over this value.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// Launch without waiting; the hook's output is ignored.
    #[serde(default)]
    pub background: bool,
    /// Whether later hooks still run when this one fails (serde default: `true`).
    #[serde(default = "default_continue_on_error")]
    pub continue_on_error: bool,
    /// Optional label used in logs and results.
    #[serde(default)]
    pub name: Option<String>,
}

/// Serde default for [`Hook::timeout_secs`].
fn default_timeout() -> u64 {
    30
}

/// Serde default for [`Hook::continue_on_error`].
fn default_continue_on_error() -> bool {
    true
}

impl Hook {
    /// Creates a hook for `event` running `command`, with the same defaults that
    /// deserialization applies: no condition, default timeout, foreground,
    /// continue on error, unnamed.
    #[allow(dead_code)]
    pub fn new(event: HookEvent, command: &str) -> Self {
        Self {
            event,
            command: command.to_string(),
            condition: None,
            // Reuse the serde default helpers so hooks built in code and hooks loaded
            // from configuration cannot drift apart.
            timeout_secs: default_timeout(),
            background: false,
            continue_on_error: default_continue_on_error(),
            name: None,
        }
    }

    /// Sets the condition that must hold for the hook to fire.
    #[allow(dead_code)]
    pub fn with_condition(mut self, condition: HookCondition) -> Self {
        self.condition = Some(condition);
        self
    }

    /// Sets the timeout in seconds.
    #[allow(dead_code)]
    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs;
        self
    }

    /// Marks the hook to run in the background.
    #[allow(dead_code)]
    pub fn background(mut self) -> Self {
        self.background = true;
        self
    }

    /// Sets the hook's label.
    #[allow(dead_code)]
    pub fn with_name(mut self, name: &str) -> Self {
        self.name = Some(name.to_string());
        self
    }
}
/// Top-level hooks configuration.
///
/// NOTE(review): when deserialized, a missing `enabled` key defaults to `true`
/// (`default_enabled`), but the derived `Default` impl yields `enabled: false`.
/// Confirm that a programmatic `HooksConfig::default()` is meant to be disabled.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HooksConfig {
    /// Configured hooks, in the order they are run.
    #[serde(default)]
    pub hooks: Vec<Hook>,
    /// Master switch; when `false`, no hook is returned or run.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// When set, used instead of every hook's own `timeout_secs`.
    #[serde(default)]
    pub default_timeout_secs: Option<u64>,
    /// Working directory for hook processes; the executor's default is used otherwise.
    #[serde(default)]
    pub working_dir: Option<PathBuf>,
}

/// Serde default for [`HooksConfig::enabled`].
fn default_enabled() -> bool {
    true
}

impl HooksConfig {
    /// Hooks registered for `event`, in configuration order; empty when disabled.
    pub fn hooks_for_event(&self, event: HookEvent) -> Vec<&Hook> {
        if self.enabled {
            self.hooks
                .iter()
                .filter(|candidate| candidate.event == event)
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Whether hooks are enabled and at least one is configured.
    #[allow(dead_code)]
    pub fn has_hooks(&self) -> bool {
        let any_configured = !self.hooks.is_empty();
        any_configured && self.enabled
    }
}
/// Information about the current session or tool call that is exposed to hooks.
///
/// Every field is optional; [`HookContext::to_env_vars`] exports only the fields that
/// are set, as `DEEPSEEK_*` environment variables.
#[derive(Debug, Clone, Default)]
pub struct HookContext {
    /// Tool being called (`DEEPSEEK_TOOL_NAME`).
    pub tool_name: Option<String>,
    /// Tool arguments as JSON text (`DEEPSEEK_TOOL_ARGS`).
    pub tool_args: Option<String>,
    /// Tool output (`DEEPSEEK_TOOL_RESULT`, truncated).
    pub tool_result: Option<String>,
    /// Tool exit code (`DEEPSEEK_TOOL_EXIT_CODE`).
    pub tool_exit_code: Option<i32>,
    /// Whether the tool succeeded (`DEEPSEEK_TOOL_SUCCESS`).
    pub tool_success: Option<bool>,
    /// Current mode (`DEEPSEEK_MODE`).
    pub mode: Option<String>,
    /// Mode before a mode change (`DEEPSEEK_PREVIOUS_MODE`).
    pub previous_mode: Option<String>,
    /// Session identifier (`DEEPSEEK_SESSION_ID`).
    pub session_id: Option<String>,
    /// User message (`DEEPSEEK_MESSAGE`, truncated).
    pub message: Option<String>,
    /// Error description (`DEEPSEEK_ERROR`).
    pub error_message: Option<String>,
    /// Workspace directory (`DEEPSEEK_WORKSPACE`).
    pub workspace: Option<PathBuf>,
    /// Model name (`DEEPSEEK_MODEL`).
    pub model: Option<String>,
    /// Total tokens used (`DEEPSEEK_TOTAL_TOKENS`).
    pub total_tokens: Option<u32>,
    /// Session cost, exported with six decimal places (`DEEPSEEK_SESSION_COST`).
    pub session_cost: Option<f64>,
}

impl HookContext {
    /// Byte budget for `DEEPSEEK_TOOL_RESULT` before it is truncated.
    const MAX_TOOL_RESULT_BYTES: usize = 10_000;
    /// Byte budget for `DEEPSEEK_MESSAGE` before it is truncated.
    const MAX_MESSAGE_BYTES: usize = 5_000;

    /// An empty context (no fields set).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the tool name.
    #[allow(dead_code)]
    pub fn with_tool_name(mut self, name: &str) -> Self {
        self.tool_name = Some(name.to_string());
        self
    }

    /// Sets the tool arguments, stored as their JSON text.
    #[allow(dead_code)]
    pub fn with_tool_args(mut self, args: &serde_json::Value) -> Self {
        self.tool_args = Some(args.to_string());
        self
    }

    /// Sets the tool output, success flag and (optional) exit code.
    #[allow(dead_code)]
    pub fn with_tool_result(mut self, result: &str, success: bool, exit_code: Option<i32>) -> Self {
        self.tool_result = Some(result.to_string());
        self.tool_success = Some(success);
        self.tool_exit_code = exit_code;
        self
    }

    /// Sets the current mode.
    #[allow(dead_code)]
    pub fn with_mode(mut self, mode: &str) -> Self {
        self.mode = Some(mode.to_string());
        self
    }

    /// Sets the previous mode.
    pub fn with_previous_mode(mut self, mode: &str) -> Self {
        self.previous_mode = Some(mode.to_string());
        self
    }

    /// Sets the workspace directory.
    #[allow(dead_code)]
    pub fn with_workspace(mut self, path: PathBuf) -> Self {
        self.workspace = Some(path);
        self
    }

    /// Sets the model name.
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = Some(model.to_string());
        self
    }

    /// Sets the session identifier.
    pub fn with_session_id(mut self, session_id: &str) -> Self {
        self.session_id = Some(session_id.to_string());
        self
    }

    /// Sets the user message.
    #[allow(dead_code)]
    pub fn with_message(mut self, message: &str) -> Self {
        self.message = Some(message.to_string());
        self
    }

    /// Sets the error description.
    #[allow(dead_code)]
    pub fn with_error(mut self, error: &str) -> Self {
        self.error_message = Some(error.to_string());
        self
    }

    /// Sets the total token count.
    pub fn with_tokens(mut self, tokens: u32) -> Self {
        self.total_tokens = Some(tokens);
        self
    }

    /// Sets the session cost.
    #[allow(dead_code)]
    pub fn with_cost(mut self, cost: f64) -> Self {
        self.session_cost = Some(cost);
        self
    }

    /// Builds the `DEEPSEEK_*` environment passed to hook processes.
    ///
    /// Only fields that are set produce a variable. `tool_result` and `message` are cut
    /// at a UTF-8 character boundary within their byte budget, and `...[truncated]` is
    /// appended. NUL characters are removed from every value: an environment variable
    /// cannot carry them, and std's `Command` refuses to spawn when an env value
    /// contains one. Without this, a single binary tool output would make every hook fail.
    pub fn to_env_vars(&self) -> HashMap<String, String> {
        let mut env = HashMap::new();
        let mut set = |key: &str, value: &str| {
            env.insert(key.to_string(), Self::env_safe(value));
        };
        if let Some(name) = &self.tool_name {
            set("DEEPSEEK_TOOL_NAME", name);
        }
        if let Some(args) = &self.tool_args {
            set("DEEPSEEK_TOOL_ARGS", args);
        }
        if let Some(result) = &self.tool_result {
            let truncated = Self::truncate_for_env(result, Self::MAX_TOOL_RESULT_BYTES);
            set("DEEPSEEK_TOOL_RESULT", &truncated);
        }
        if let Some(code) = self.tool_exit_code {
            set("DEEPSEEK_TOOL_EXIT_CODE", &code.to_string());
        }
        if let Some(success) = self.tool_success {
            set("DEEPSEEK_TOOL_SUCCESS", &success.to_string());
        }
        if let Some(mode) = &self.mode {
            set("DEEPSEEK_MODE", mode);
        }
        if let Some(prev) = &self.previous_mode {
            set("DEEPSEEK_PREVIOUS_MODE", prev);
        }
        if let Some(session_id) = &self.session_id {
            set("DEEPSEEK_SESSION_ID", session_id);
        }
        if let Some(message) = &self.message {
            let truncated = Self::truncate_for_env(message, Self::MAX_MESSAGE_BYTES);
            set("DEEPSEEK_MESSAGE", &truncated);
        }
        if let Some(error) = &self.error_message {
            set("DEEPSEEK_ERROR", error);
        }
        if let Some(ws) = &self.workspace {
            set("DEEPSEEK_WORKSPACE", &ws.display().to_string());
        }
        if let Some(model) = &self.model {
            set("DEEPSEEK_MODEL", model);
        }
        if let Some(tokens) = self.total_tokens {
            set("DEEPSEEK_TOTAL_TOKENS", &tokens.to_string());
        }
        if let Some(cost) = self.session_cost {
            set("DEEPSEEK_SESSION_COST", &format!("{cost:.6}"));
        }
        env
    }

    /// Returns `value` unchanged when it fits in `max_bytes`; otherwise returns its
    /// longest prefix ending on a char boundary at or before `max_bytes`, followed by
    /// `...[truncated]`.
    fn truncate_for_env(value: &str, max_bytes: usize) -> String {
        if value.len() <= max_bytes {
            return value.to_string();
        }
        // `value.len() > max_bytes`, so every index probed is in range. At most 3 steps
        // back are needed, and index 0 is always a boundary.
        let end = (0..=max_bytes)
            .rev()
            .find(|&i| value.is_char_boundary(i))
            .unwrap_or(0);
        format!("{}...[truncated]", &value[..end])
    }

    /// Removes NUL characters, which cannot appear in an environment variable value.
    fn env_safe(value: &str) -> String {
        value.replace('\0', "")
    }
}
/// Outcome of running (or launching) a single hook.
#[derive(Debug, Clone)]
#[allow(dead_code)] pub struct HookResult {
/// The hook's configured `name`, if any.
pub name: Option<String>,
/// Whether the process exited successfully; background hooks report `true` as soon
/// as they are launched.
pub success: bool,
/// Process exit code; `None` on timeout, spawn/wait failure, for background hooks,
/// or when the OS reports no code.
pub exit_code: Option<i32>,
/// Captured standard output (empty on the failure paths and for background hooks).
pub stdout: String,
/// Captured standard error (same caveats as `stdout`).
pub stderr: String,
/// Time measured from the start of execution to the construction of this result.
pub duration: Duration,
/// Executor-side failure description (stdin encoding, spawn, timeout, wait), if any.
pub error: Option<String>,
}
/// Final decision produced by the `message_submit` hook chain.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSubmitOutcome {
    /// Submit the original text. `warning` carries a non-fatal hook failure note.
    Unchanged { warning: Option<String> },
    /// Submit `text` instead of the original. `warning` carries a non-fatal failure note.
    Replaced {
        text: String,
        warning: Option<String>,
    },
    /// Do not submit; `reason` explains why.
    Blocked { reason: String },
}

impl MessageSubmitOutcome {
    /// Keep the original message, with no warning.
    pub fn unchanged() -> Self {
        MessageSubmitOutcome::Unchanged { warning: None }
    }

    /// Replace the message with `text`, with no warning.
    pub fn replaced(text: String) -> Self {
        let warning = None;
        MessageSubmitOutcome::Replaced { text, warning }
    }

    /// Returns `self` with its warning set to `warning`. A `Blocked` outcome has no
    /// warning slot and is returned as-is.
    fn with_warning(self, warning: Option<String>) -> Self {
        match self {
            blocked @ MessageSubmitOutcome::Blocked { .. } => blocked,
            MessageSubmitOutcome::Unchanged { .. } => MessageSubmitOutcome::Unchanged { warning },
            MessageSubmitOutcome::Replaced { text, .. } => {
                MessageSubmitOutcome::Replaced { text, warning }
            }
        }
    }

    /// Borrows the warning, if any; `Blocked` outcomes never carry one.
    pub fn warning(&self) -> Option<&str> {
        let slot = match self {
            MessageSubmitOutcome::Unchanged { warning } => warning,
            MessageSubmitOutcome::Replaced { warning, .. } => warning,
            MessageSubmitOutcome::Blocked { .. } => return None,
        };
        slot.as_deref()
    }
}
/// How the stdout of a synchronous `message_submit` hook was interpreted.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MessageSubmitStdout {
/// Blank stdout, or a JSON object without a `text` key: keep the current text.
Unchanged,
/// A JSON object with a non-empty string `text`: use it as the new message.
Replaced(String),
/// Stdout could not be used; the payload is a human-readable reason.
Invalid(String),
}
#[derive(Debug, Clone)]
pub struct HookExecutor {
config: HooksConfig,
default_working_dir: PathBuf,
session_id: String,
}
impl HookExecutor {
fn build_shell_command(command: &str) -> Command {
#[cfg(windows)]
{
let mut cmd = Command::new("cmd");
cmd.arg("/C").arg(command);
cmd
}
#[cfg(not(windows))]
{
let mut cmd = Command::new("sh");
cmd.arg("-c").arg(command);
cmd
}
}
pub fn new(config: HooksConfig, default_working_dir: PathBuf) -> Self {
let session_id = format!("sess_{}", &uuid::Uuid::new_v4().to_string()[..8]);
Self {
config,
default_working_dir,
session_id,
}
}
#[allow(dead_code)] pub fn disabled() -> Self {
Self {
config: HooksConfig {
enabled: false,
..Default::default()
},
default_working_dir: PathBuf::from("."),
session_id: String::new(),
}
}
#[allow(dead_code)] pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn config(&self) -> &HooksConfig {
&self.config
}
pub fn session_id(&self) -> &str {
&self.session_id
}
#[must_use]
pub fn has_hooks_for_event(&self, event: HookEvent) -> bool {
self.config.enabled && self.config.hooks.iter().any(|h| h.event == event)
}
pub fn execute_message_submit_transform(
&self,
context: &HookContext,
original_text: &str,
) -> MessageSubmitOutcome {
if !self.config.enabled {
return MessageSubmitOutcome::unchanged();
}
let hooks = self.config.hooks_for_event(HookEvent::MessageSubmit);
if hooks.is_empty() {
return MessageSubmitOutcome::unchanged();
}
let mut current_text = original_text.to_string();
let mut warning = None;
for hook in hooks {
let hook_context = context.clone().with_message(¤t_text);
if !self.matches_condition(hook, &hook_context) {
continue;
}
let env_vars = hook_context.to_env_vars();
if hook.background {
let _ = self.execute_background(hook, &env_vars);
continue;
}
let payload = message_submit_payload(&hook_context, ¤t_text);
let result = self.execute_sync_with_stdin(hook, &env_vars, &payload);
if result.exit_code == Some(2) {
return MessageSubmitOutcome::Blocked {
reason: message_submit_block_reason(
&result,
"message_submit hook blocked submission",
),
};
}
if !result.success {
let label = result.name.as_deref().unwrap_or("(unnamed)");
tracing::warn!(
target: "hooks",
hook = label,
event = "message_submit",
exit_code = ?result.exit_code,
duration_ms = result.duration.as_millis() as u64,
error = result.error.as_deref().unwrap_or(""),
stderr_head = %result.stderr.lines().next().unwrap_or(""),
"message_submit hook failed"
);
if hook.continue_on_error {
warning = message_submit_continue_warning(&result).or(warning);
continue;
}
return MessageSubmitOutcome::Blocked {
reason: message_submit_block_reason(
&result,
"message_submit hook failed and blocked submission",
),
};
}
match parse_message_submit_stdout(&result.stdout) {
MessageSubmitStdout::Unchanged => {}
MessageSubmitStdout::Replaced(text) => {
current_text = text;
}
MessageSubmitStdout::Invalid(reason) => {
tracing::warn!(
target: "hooks",
hook = result.name.as_deref().unwrap_or("(unnamed)"),
event = "message_submit",
reason = %reason,
"ignored invalid message_submit hook stdout"
);
}
}
}
if current_text == original_text {
MessageSubmitOutcome::unchanged().with_warning(warning)
} else {
MessageSubmitOutcome::replaced(current_text).with_warning(warning)
}
}
pub fn collect_shell_env(&self, context: &HookContext) -> HashMap<String, String> {
let mut merged: HashMap<String, String> = HashMap::new();
if !self.config.enabled {
return merged;
}
let hooks = self.config.hooks_for_event(HookEvent::ShellEnv);
if hooks.is_empty() {
return merged;
}
let env_vars = context.to_env_vars();
for hook in hooks {
if !self.matches_condition(hook, context) {
continue;
}
let result = self.execute_sync(hook, &env_vars);
if !result.success {
tracing::warn!(
target: "hooks",
hook = result.name.as_deref().unwrap_or("(unnamed)"),
event = "shell_env",
exit_code = ?result.exit_code,
error = result.error.as_deref().unwrap_or(""),
"shell_env hook failed; contributing no env vars"
);
continue;
}
let parsed = parse_env_lines(&result.stdout);
if parsed.is_empty() {
continue;
}
crate::audit::log_sensitive_event(
"shell_env_hook",
serde_json::json!({
"hook": result.name,
"tool": context.tool_name,
"keys": parsed.keys().cloned().collect::<Vec<_>>(),
}),
);
merged.extend(parsed);
}
merged
}
pub fn execute(&self, event: HookEvent, context: &HookContext) -> Vec<HookResult> {
if !self.config.enabled {
return Vec::new();
}
let hooks = self.config.hooks_for_event(event);
if hooks.is_empty() {
return Vec::new();
}
let env_vars = context.to_env_vars();
let mut results = Vec::new();
for hook in hooks {
if !self.matches_condition(hook, context) {
continue;
}
let result = if hook.background {
self.execute_background(hook, &env_vars)
} else {
self.execute_sync(hook, &env_vars)
};
if !result.success {
let label = result.name.as_deref().unwrap_or("(unnamed)");
tracing::warn!(
target: "hooks",
hook = label,
event = event.as_str(),
exit_code = ?result.exit_code,
duration_ms = result.duration.as_millis() as u64,
error = result.error.as_deref().unwrap_or(""),
stderr_head = %result.stderr.lines().next().unwrap_or(""),
"hook failed"
);
}
let should_continue = result.success || hook.continue_on_error;
results.push(result);
if !should_continue {
break;
}
}
results
}
#[allow(clippy::only_used_in_recursion)]
fn matches_condition(&self, hook: &Hook, context: &HookContext) -> bool {
match &hook.condition {
None | Some(HookCondition::Always) => true,
Some(HookCondition::ToolName { name }) => {
context.tool_name.as_ref().is_some_and(|n| n == name)
}
Some(HookCondition::ToolCategory { category }) => {
let tool_category = context.tool_name.as_ref().map(|name| match name.as_str() {
"exec_shell" => "shell",
"write_file" | "edit_file" | "apply_patch" => "file_write",
"read_file" | "list_dir" | "grep_files" => "safe",
_ => "other",
});
tool_category.is_some_and(|c| c == category.as_str())
}
Some(HookCondition::Mode { mode }) => context
.mode
.as_ref()
.is_some_and(|m| m.to_lowercase() == mode.to_lowercase()),
Some(HookCondition::ExitCode { code }) => context.tool_exit_code == Some(*code),
Some(HookCondition::All { conditions }) => conditions.iter().all(|c| {
self.matches_condition(
&Hook {
condition: Some(c.clone()),
..hook.clone()
},
context,
)
}),
Some(HookCondition::Any { conditions }) => conditions.iter().any(|c| {
self.matches_condition(
&Hook {
condition: Some(c.clone()),
..hook.clone()
},
context,
)
}),
}
}
fn execute_sync(&self, hook: &Hook, env_vars: &HashMap<String, String>) -> HookResult {
self.execute_sync_inner(hook, env_vars, None)
}
fn execute_sync_with_stdin(
&self,
hook: &Hook,
env_vars: &HashMap<String, String>,
stdin_json: &serde_json::Value,
) -> HookResult {
self.execute_sync_inner(hook, env_vars, Some(stdin_json))
}
fn execute_sync_inner(
&self,
hook: &Hook,
env_vars: &HashMap<String, String>,
stdin_json: Option<&serde_json::Value>,
) -> HookResult {
let started = Instant::now();
let working_dir = self
.config
.working_dir
.clone()
.unwrap_or_else(|| self.default_working_dir.clone());
let timeout_secs = self
.config
.default_timeout_secs
.unwrap_or(hook.timeout_secs);
let timeout = Duration::from_secs(timeout_secs);
let stdin_bytes = match stdin_json.map(serde_json::to_vec).transpose() {
Ok(bytes) => bytes,
Err(e) => {
return HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Failed to encode hook stdin: {e}")),
};
}
};
let mut command = Self::build_shell_command(&hook.command);
command
.current_dir(&working_dir)
.envs(env_vars)
.stdout(Stdio::piped())
.stderr(Stdio::piped());
if stdin_bytes.is_some() {
command.stdin(Stdio::piped());
}
let mut child = match command.spawn() {
Ok(child) => child,
Err(e) => {
return HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Failed to spawn hook: {e}")),
};
}
};
let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
let stderr_reader = child.stderr.take().map(spawn_pipe_reader);
let _stdin_writer = match (stdin_bytes, child.stdin.take()) {
(Some(bytes), Some(stdin)) => Some(spawn_stdin_writer(stdin, bytes)),
_ => None,
};
match child.wait_timeout(timeout) {
Ok(Some(status)) => HookResult {
name: hook.name.clone(),
success: status.success(),
exit_code: status.code(),
stdout: join_reader(stdout_reader),
stderr: join_reader(stderr_reader),
duration: started.elapsed(),
error: None,
},
Ok(None) => {
let _ = child.kill();
let _ = child.wait();
HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Hook timed out after {timeout_secs}s")),
}
}
Err(e) => {
let _ = child.kill();
let _ = child.wait();
HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Failed to wait for hook: {e}")),
}
}
}
}
fn execute_background(&self, hook: &Hook, env_vars: &HashMap<String, String>) -> HookResult {
let started = Instant::now();
let working_dir = self
.config
.working_dir
.clone()
.unwrap_or_else(|| self.default_working_dir.clone());
let cmd = hook.command.clone();
let env = env_vars.clone();
let wd = working_dir.clone();
std::thread::spawn(move || {
let _ = HookExecutor::build_shell_command(&cmd)
.current_dir(&wd)
.envs(&env)
.output();
});
HookResult {
name: hook.name.clone(),
success: true,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: None,
}
}
}
/// Drains `pipe` to EOF on a new thread and returns its contents as text.
///
/// Bytes are decoded lossily (invalid UTF-8 becomes U+FFFD), so one bad byte does not
/// discard the whole stream; `read_to_string` would return nothing in that case. Read
/// errors are ignored, and any bytes read before the error are kept.
fn spawn_pipe_reader(mut pipe: impl Read + Send + 'static) -> JoinHandle<String> {
    thread::spawn(move || {
        let mut bytes = Vec::new();
        let _ = pipe.read_to_end(&mut bytes);
        String::from_utf8_lossy(&bytes).into_owned()
    })
}
/// Waits for a reader thread and returns what it collected. A missing reader or a
/// panicked thread yields an empty string.
fn join_reader(reader: Option<JoinHandle<String>>) -> String {
    match reader.map(JoinHandle::join) {
        Some(Ok(text)) => text,
        _ => String::new(),
    }
}
/// Writes `bytes` followed by a newline to a child's stdin on a new thread.
///
/// The handle is moved into the thread and dropped when it finishes, which closes the
/// pipe. Write and flush errors are ignored, e.g. when the hook exits without reading
/// its input.
fn spawn_stdin_writer(mut stdin: std::process::ChildStdin, mut bytes: Vec<u8>) -> JoinHandle<()> {
    bytes.push(b'\n');
    thread::spawn(move || {
        stdin.write_all(&bytes).ok();
        stdin.flush().ok();
    })
}
/// JSON document written to a synchronous `message_submit` hook's stdin.
///
/// `text` is the message as rewritten by earlier hooks. The other fields mirror the
/// context and are `null` when unset.
fn message_submit_payload(context: &HookContext, text: &str) -> serde_json::Value {
    let workspace = context
        .workspace
        .as_ref()
        .map(|dir| dir.display().to_string());
    let session_id = context.session_id.as_deref();
    let mode = context.mode.as_deref();
    let model = context.model.as_deref();
    json!({
        "event": HookEvent::MessageSubmit.as_str(),
        "text": text,
        "session_id": session_id,
        "workspace": workspace,
        "mode": mode,
        "model": model,
        "total_tokens": context.total_tokens,
    })
}
/// Interprets a `message_submit` hook's stdout.
///
/// * blank output: `Unchanged`
/// * a JSON object without `text` (e.g. only a `reason`): `Unchanged`
/// * a JSON object with a non-empty string `text`: `Replaced`
/// * invalid JSON, a non-object, or an empty or non-string `text`: `Invalid`
fn parse_message_submit_stdout(stdout: &str) -> MessageSubmitStdout {
    let trimmed = stdout.trim();
    if trimmed.is_empty() {
        return MessageSubmitStdout::Unchanged;
    }
    let value = match serde_json::from_str::<serde_json::Value>(trimmed) {
        Ok(parsed) => parsed,
        Err(err) => return MessageSubmitStdout::Invalid(format!("invalid JSON: {err}")),
    };
    let object = match value {
        serde_json::Value::Object(map) => map,
        _ => return MessageSubmitStdout::Invalid("stdout JSON must be an object".to_string()),
    };
    match object.get("text") {
        None => MessageSubmitStdout::Unchanged,
        Some(serde_json::Value::String(text)) if text.is_empty() => {
            MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
        }
        Some(serde_json::Value::String(text)) => MessageSubmitStdout::Replaced(text.clone()),
        Some(_) => MessageSubmitStdout::Invalid("stdout `text` field must be a string".to_string()),
    }
}
fn message_submit_continue_warning(result: &HookResult) -> Option<String> {
message_submit_stdout_reason(&result.stdout)
.or_else(|| first_non_empty_line(&result.stderr))
.or_else(|| first_non_empty_line(&result.stdout))
.or_else(|| result.error.as_deref().and_then(first_non_empty_line))
}
fn message_submit_block_reason(result: &HookResult, fallback: &str) -> String {
if let Some(reason) = message_submit_stdout_reason(&result.stdout) {
return reason;
}
if let Some(reason) = first_non_empty_line(&result.stderr) {
return reason;
}
if let Some(reason) = first_non_empty_line(&result.stdout) {
return reason;
}
if let Some(reason) = result.error.as_deref().and_then(first_non_empty_line) {
return reason;
}
fallback.to_string()
}
/// Extracts a truncated `reason` string from JSON printed on stdout.
///
/// Returns `None` for blank or non-JSON output, for non-objects, and when `reason` is
/// missing or not a string.
fn message_submit_stdout_reason(stdout: &str) -> Option<String> {
    match serde_json::from_str::<serde_json::Value>(stdout.trim()) {
        Ok(value) => value.get("reason")?.as_str().map(truncate_hook_message),
        Err(_) => None,
    }
}
/// The first line of `text` that is non-blank after trimming, truncated for display.
fn first_non_empty_line(text: &str) -> Option<String> {
    for line in text.lines() {
        let line = line.trim();
        if !line.is_empty() {
            return Some(truncate_hook_message(line));
        }
    }
    None
}

/// Caps `message` at 240 characters (not bytes), appending `…` when anything was cut.
fn truncate_hook_message(message: &str) -> String {
    const MAX_CHARS: usize = 240;
    match message.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => {
            let mut shortened = message[..cut].to_string();
            shortened.push('…');
            shortened
        }
        None => message.to_string(),
    }
}
/// Parses the `KEY=VALUE` lines printed by a `shell_env` hook.
///
/// Parsing rules:
/// - Blank lines and `#` comments are skipped.
/// - A leading `export` followed by any whitespace is accepted.
/// - A single pair of matching surrounding `"` or `'` quotes is stripped from the value.
///
/// A line is dropped when its key is empty or contains whitespace or NUL, or when its
/// value contains NUL. Such entries are not usable environment variables, and std's
/// `Command` refuses to spawn when an env entry contains NUL, which would break the
/// shell that later receives this environment. Later duplicates overwrite earlier ones.
fn parse_env_lines(stdout: &str) -> HashMap<String, String> {
    let mut out = HashMap::new();
    for raw in stdout.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        // Accept `export KEY=...` with any whitespace (space or tab) after `export`.
        let line = match line.strip_prefix("export") {
            Some(rest) if rest.starts_with(char::is_whitespace) => rest.trim_start(),
            _ => line,
        };
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        let key = key.trim();
        if key.is_empty() || key.contains(|c: char| c.is_whitespace() || c == '\0') {
            continue;
        }
        let value = value.trim();
        let value = value
            .strip_prefix('"')
            .and_then(|v| v.strip_suffix('"'))
            .or_else(|| value.strip_prefix('\'').and_then(|v| v.strip_suffix('\'')))
            .unwrap_or(value);
        if value.contains('\0') {
            continue;
        }
        out.insert(key.to_string(), value.to_string());
    }
    out
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use std::path::PathBuf;
#[test]
fn parse_env_lines_handles_realistic_hook_output() {
let stdout = r#"
# Aux comment line, ignored
AWS_ACCESS_KEY_ID=AKIAEXAMPLE
export GITHUB_TOKEN=ghp_examplevalue
QUOTED="value with spaces"
SINGLE='also valid'
= empty key dropped
NOEQUAL line dropped
"#;
let parsed = super::parse_env_lines(stdout);
assert_eq!(
parsed.get("AWS_ACCESS_KEY_ID"),
Some(&"AKIAEXAMPLE".to_string())
);
assert_eq!(
parsed.get("GITHUB_TOKEN"),
Some(&"ghp_examplevalue".to_string())
);
assert_eq!(parsed.get("QUOTED"), Some(&"value with spaces".to_string()));
assert_eq!(parsed.get("SINGLE"), Some(&"also valid".to_string()));
assert!(!parsed.contains_key(""));
assert!(!parsed.contains_key("NOEQUAL line dropped"));
assert_eq!(parsed.len(), 4);
}
#[test]
fn parse_env_lines_empty_when_no_assignments() {
let parsed = super::parse_env_lines("# nothing\n\n \n");
assert!(parsed.is_empty());
}
#[test]
fn parse_message_submit_stdout_replaces_text() {
assert_eq!(
super::parse_message_submit_stdout(r#"{"text":"changed"}"#),
MessageSubmitStdout::Replaced("changed".to_string())
);
}
#[test]
fn parse_message_submit_stdout_empty_is_unchanged() {
assert_eq!(
super::parse_message_submit_stdout(" \n\t "),
MessageSubmitStdout::Unchanged
);
}
#[test]
fn parse_message_submit_stdout_without_text_is_unchanged() {
assert_eq!(
super::parse_message_submit_stdout(r#"{"reason":"only used for blocks"}"#),
MessageSubmitStdout::Unchanged
);
}
#[test]
fn parse_message_submit_stdout_rejects_malformed_json() {
assert!(matches!(
super::parse_message_submit_stdout("not json"),
MessageSubmitStdout::Invalid(_)
));
}
#[test]
fn parse_message_submit_stdout_rejects_non_string_text() {
assert!(matches!(
super::parse_message_submit_stdout(r#"{"text":123}"#),
MessageSubmitStdout::Invalid(_)
));
}
#[test]
fn parse_message_submit_stdout_rejects_empty_text() {
assert_eq!(
super::parse_message_submit_stdout(r#"{"text":""}"#),
MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
);
}
#[test]
fn parse_message_submit_stdout_rejects_non_object_json() {
assert!(matches!(
super::parse_message_submit_stdout(r#"["not", "an", "object"]"#),
MessageSubmitStdout::Invalid(_)
));
assert!(matches!(
super::parse_message_submit_stdout(r#""not an object""#),
MessageSubmitStdout::Invalid(_)
));
}
#[test]
fn test_hook_event_as_str() {
assert_eq!(HookEvent::SessionStart.as_str(), "session_start");
assert_eq!(HookEvent::ToolCallAfter.as_str(), "tool_call_after");
assert_eq!(HookEvent::ModeChange.as_str(), "mode_change");
}
#[test]
fn test_hook_context_to_env_vars() {
let ctx = HookContext::new()
.with_tool_name("exec_shell")
.with_mode("agent")
.with_workspace(PathBuf::from("/tmp"));
let env = ctx.to_env_vars();
assert_eq!(
env.get("DEEPSEEK_TOOL_NAME"),
Some(&"exec_shell".to_string())
);
assert_eq!(env.get("DEEPSEEK_MODE"), Some(&"agent".to_string()));
assert_eq!(env.get("DEEPSEEK_WORKSPACE"), Some(&"/tmp".to_string()));
}
#[test]
fn test_hook_condition_always() {
let hook = Hook::new(HookEvent::SessionStart, "echo test");
let executor = HookExecutor::disabled();
let context = HookContext::new();
assert!(executor.matches_condition(&hook, &context));
}
#[test]
fn test_hook_condition_tool_name() {
let hook = Hook::new(HookEvent::ToolCallBefore, "echo test").with_condition(
HookCondition::ToolName {
name: "exec_shell".to_string(),
},
);
let executor = HookExecutor::disabled();
let context_match = HookContext::new().with_tool_name("exec_shell");
let context_no_match = HookContext::new().with_tool_name("write_file");
assert!(executor.matches_condition(&hook, &context_match));
assert!(!executor.matches_condition(&hook, &context_no_match));
}
#[test]
fn test_hook_condition_mode() {
let hook =
Hook::new(HookEvent::ModeChange, "echo test").with_condition(HookCondition::Mode {
mode: "agent".to_string(),
});
let executor = HookExecutor::disabled();
let context_match = HookContext::new().with_mode("AGENT"); let context_no_match = HookContext::new().with_mode("normal");
assert!(executor.matches_condition(&hook, &context_match));
assert!(!executor.matches_condition(&hook, &context_no_match));
}
#[test]
fn test_hooks_config_for_event() {
let config = HooksConfig {
enabled: true,
hooks: vec![
Hook::new(HookEvent::SessionStart, "echo start"),
Hook::new(HookEvent::SessionEnd, "echo end"),
Hook::new(HookEvent::SessionStart, "echo start2"),
],
..Default::default()
};
let start_hooks = config.hooks_for_event(HookEvent::SessionStart);
assert_eq!(start_hooks.len(), 2);
let end_hooks = config.hooks_for_event(HookEvent::SessionEnd);
assert_eq!(end_hooks.len(), 1);
}
#[test]
fn test_hooks_config_disabled() {
let config = HooksConfig {
enabled: false,
hooks: vec![Hook::new(HookEvent::SessionStart, "echo start")],
..Default::default()
};
let hooks = config.hooks_for_event(HookEvent::SessionStart);
assert!(hooks.is_empty());
}
#[test]
fn test_hook_builder() {
let hook = Hook::new(HookEvent::ToolCallAfter, "notify.sh")
.with_name("notify_tool")
.with_timeout(60)
.background()
.with_condition(HookCondition::ToolCategory {
category: "shell".to_string(),
});
assert_eq!(hook.name, Some("notify_tool".to_string()));
assert_eq!(hook.timeout_secs, 60);
assert!(hook.background);
assert!(matches!(
hook.condition,
Some(HookCondition::ToolCategory { .. })
));
}
#[test]
fn test_hook_timeout_enforced() {
let command = if cfg!(windows) {
"ping -n 3 127.0.0.1 > nul"
} else {
"sleep 2"
};
let hook = Hook::new(HookEvent::SessionStart, command).with_timeout(1);
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
let env_vars = HashMap::new();
let result = executor.execute_sync(&hook, &env_vars);
assert!(!result.success);
assert!(
result
.error
.as_ref()
.is_some_and(|e| e.contains("timed out"))
);
}
#[cfg(not(windows))]
#[test]
fn message_submit_stdin_write_does_not_deadlock_when_hook_writes_first() {
let dir = tempfile::tempdir().expect("tempdir");
let command = write_hook_script(
&dir,
"write_before_read.sh",
r#"#!/bin/sh
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' x
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' e >&2
payload=$(cat)
printf '\ndone:%s\n' "${#payload}"
"#,
);
let hook = Hook::new(HookEvent::MessageSubmit, &command).with_timeout(5);
let executor = HookExecutor::new(HooksConfig::default(), dir.path().to_path_buf());
let env_vars = HashMap::new();
let payload = json!({
"event": "message_submit",
"text": "x".repeat(256 * 1024),
});
let result = executor.execute_sync_with_stdin(&hook, &env_vars, &payload);
assert!(result.success, "hook should complete: {result:?}");
assert!(result.stdout.contains("done:"), "stdout was drained");
assert!(result.stderr.len() >= 256 * 1024, "stderr was drained");
}
#[test]
fn test_executor_session_id() {
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
assert!(executor.session_id().starts_with("sess_"));
assert_eq!(executor.session_id().len(), 13); }
#[cfg(not(windows))]
fn write_hook_script(dir: &tempfile::TempDir, name: &str, content: &str) -> String {
let path = dir.path().join(name);
std::fs::write(&path, content).expect("write hook script");
format!("sh {}", path.display())
}
#[cfg(not(windows))]
fn submit_context(dir: &tempfile::TempDir) -> HookContext {
HookContext::new()
.with_session_id("sess_test")
.with_workspace(dir.path().to_path_buf())
.with_mode("agent")
.with_model("deepseek-test")
.with_tokens(42)
}
#[cfg(not(windows))]
#[test]
fn message_submit_transform_applies_hooks_in_order() {
let dir = tempfile::tempdir().expect("tempdir");
let first = write_hook_script(
&dir,
"first.sh",
r#"#!/bin/sh
printf '%s\n' '{"text":"first"}'
"#,
);
let second = write_hook_script(
&dir,
"second.sh",
r#"#!/bin/sh
payload=$(cat)
case "$payload" in
*'"text":"first"'*) printf '%s\n' '{"text":"first second"}' ;;
*) printf '%s\n' '{"text":"wrong"}' ;;
esac
"#,
);
let config = HooksConfig {
enabled: true,
hooks: vec![
Hook::new(HookEvent::MessageSubmit, &first),
Hook::new(HookEvent::MessageSubmit, &second),
],
working_dir: Some(dir.path().to_path_buf()),
..HooksConfig::default()
};
let executor = HookExecutor::new(config, dir.path().to_path_buf());
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::replaced("first second".to_string())
);
}
#[cfg(not(windows))]
#[test]
fn message_submit_transform_exit_two_blocks_submission() {
let dir = tempfile::tempdir().expect("tempdir");
let command = write_hook_script(
&dir,
"block.sh",
r#"#!/bin/sh
printf '%s\n' '{"reason":"policy blocked this prompt"}'
exit 2
"#,
);
let config = HooksConfig {
enabled: true,
hooks: vec![Hook::new(HookEvent::MessageSubmit, &command)],
working_dir: Some(dir.path().to_path_buf()),
..HooksConfig::default()
};
let executor = HookExecutor::new(config, dir.path().to_path_buf());
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Blocked {
reason: "policy blocked this prompt".to_string()
}
);
}
#[cfg(not(windows))]
#[test]
fn background_message_submit_hook_is_observer_only() {
let dir = tempfile::tempdir().expect("tempdir");
let command = write_hook_script(
&dir,
"background.sh",
r#"#!/bin/sh
printf '%s\n' '{"text":"ignored"}'
"#,
);
let config = HooksConfig {
enabled: true,
hooks: vec![Hook::new(HookEvent::MessageSubmit, &command).background()],
working_dir: Some(dir.path().to_path_buf()),
..HooksConfig::default()
};
let executor = HookExecutor::new(config, dir.path().to_path_buf());
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::unchanged()
);
}
#[test]
fn message_submit_transform_without_configured_hooks_is_unchanged() {
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
assert_eq!(
executor.execute_message_submit_transform(&HookContext::new(), "original"),
MessageSubmitOutcome::unchanged()
);
}
#[cfg(not(windows))]
#[test]
fn message_submit_transform_skips_non_matching_condition() {
    // The hook is gated on mode "plan". `submit_context` presumably does not report
    // that mode, so the hook must be skipped and its replacement never applied.
    let tmp = tempfile::tempdir().expect("tempdir");
    let script = write_hook_script(
        &tmp,
        "replace.sh",
        r#"#!/bin/sh
printf '%s\n' '{"text":"should not apply"}'
"#,
    );
    let plan_only = HookCondition::Mode {
        mode: String::from("plan"),
    };
    let gated = Hook::new(HookEvent::MessageSubmit, &script).with_condition(plan_only);
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![gated],
            working_dir: Some(tmp.path().to_path_buf()),
            ..HooksConfig::default()
        },
        tmp.path().to_path_buf(),
    );
    let outcome = executor.execute_message_submit_transform(&submit_context(&tmp), "original");
    assert_eq!(outcome, MessageSubmitOutcome::unchanged());
}
#[cfg(not(windows))]
#[test]
fn message_submit_continue_on_error_true_keeps_text_and_runs_later_hooks() {
    // The first hook fails (exit 9, message on stderr) without overriding
    // `continue_on_error`. The test relies on `Hook::new` defaulting it to true.
    // The chain must continue, the second hook's replacement wins, and the
    // stderr text is surfaced as a warning.
    let tmp = tempfile::tempdir().expect("tempdir");
    let failing_script = write_hook_script(
        &tmp,
        "fail_continue.sh",
        r#"#!/bin/sh
printf '%s\n' 'soft failure' >&2
exit 9
"#,
    );
    let replacing_script = write_hook_script(
        &tmp,
        "replace_after_failure.sh",
        r#"#!/bin/sh
printf '%s\n' '{"text":"recovered"}'
"#,
    );
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![
                Hook::new(HookEvent::MessageSubmit, &failing_script),
                Hook::new(HookEvent::MessageSubmit, &replacing_script),
            ],
            working_dir: Some(tmp.path().to_path_buf()),
            ..HooksConfig::default()
        },
        tmp.path().to_path_buf(),
    );
    let expected = MessageSubmitOutcome::replaced("recovered".to_string())
        .with_warning(Some("soft failure".to_string()));
    let outcome = executor.execute_message_submit_transform(&submit_context(&tmp), "original");
    assert_eq!(outcome, expected);
}
#[cfg(not(windows))]
#[test]
fn message_submit_timeout_continue_surfaces_warning_and_runs_later_hooks() {
    // A hook that outlives its 1s timeout, with continue_on_error enabled, must
    // not abort the chain. The timeout is reported as a warning and the next
    // hook's replacement still applies.
    let tmp = tempfile::tempdir().expect("tempdir");
    let slow_script = write_hook_script(
        &tmp,
        "slow_continue.sh",
        r#"#!/bin/sh
sleep 2
"#,
    );
    let replacing_script = write_hook_script(
        &tmp,
        "replace_after_timeout.sh",
        r#"#!/bin/sh
printf '%s\n' '{"text":"after timeout"}'
"#,
    );
    let slow_hook = Hook {
        continue_on_error: true,
        ..Hook::new(HookEvent::MessageSubmit, &slow_script).with_timeout(1)
    };
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![
                slow_hook,
                Hook::new(HookEvent::MessageSubmit, &replacing_script),
            ],
            working_dir: Some(tmp.path().to_path_buf()),
            ..HooksConfig::default()
        },
        tmp.path().to_path_buf(),
    );
    let expected = MessageSubmitOutcome::replaced("after timeout".to_string())
        .with_warning(Some("Hook timed out after 1s".to_string()));
    let outcome = executor.execute_message_submit_transform(&submit_context(&tmp), "original");
    assert_eq!(outcome, expected);
}
#[cfg(not(windows))]
#[test]
fn message_submit_invalid_stdout_keeps_text_and_runs_later_hooks() {
    // Non-JSON stdout from the first hook must be ignored (no warning expected).
    // The later hook's replacement is then applied.
    let tmp = tempfile::tempdir().expect("tempdir");
    let invalid_script = write_hook_script(
        &tmp,
        "invalid_stdout.sh",
        r#"#!/bin/sh
printf '%s\n' 'not json'
"#,
    );
    let replacing_script = write_hook_script(
        &tmp,
        "replace_after_invalid.sh",
        r#"#!/bin/sh
printf '%s\n' '{"text":"valid later"}'
"#,
    );
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![
                Hook::new(HookEvent::MessageSubmit, &invalid_script),
                Hook::new(HookEvent::MessageSubmit, &replacing_script),
            ],
            working_dir: Some(tmp.path().to_path_buf()),
            ..HooksConfig::default()
        },
        tmp.path().to_path_buf(),
    );
    let expected = MessageSubmitOutcome::replaced("valid later".to_string());
    let outcome = executor.execute_message_submit_transform(&submit_context(&tmp), "original");
    assert_eq!(outcome, expected);
}
#[cfg(not(windows))]
#[test]
fn message_submit_continue_on_error_false_blocks_on_failure() {
    // With continue_on_error disabled, a failing hook (exit 7) blocks the
    // submission, and its stderr text becomes the block reason.
    let tmp = tempfile::tempdir().expect("tempdir");
    let script = write_hook_script(
        &tmp,
        "fail.sh",
        r#"#!/bin/sh
printf '%s\n' 'hard failure' >&2
exit 7
"#,
    );
    let strict_hook = Hook {
        continue_on_error: false,
        ..Hook::new(HookEvent::MessageSubmit, &script)
    };
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![strict_hook],
            working_dir: Some(tmp.path().to_path_buf()),
            ..HooksConfig::default()
        },
        tmp.path().to_path_buf(),
    );
    let expected = MessageSubmitOutcome::Blocked {
        reason: "hard failure".to_string(),
    };
    let outcome = executor.execute_message_submit_transform(&submit_context(&tmp), "original");
    assert_eq!(outcome, expected);
}
#[test]
fn has_hooks_for_event_fast_path_returns_false_for_empty_config() {
    // Exhaustive match with no wildcard arm: adding a new HookEvent variant makes
    // this fail to compile. That is the cue to add the variant to `all_events`
    // below, so the "every event" claim of this test stays true.
    let _exhaustiveness_guard = |event: HookEvent| match event {
        HookEvent::SessionStart
        | HookEvent::SessionEnd
        | HookEvent::MessageSubmit
        | HookEvent::ToolCallBefore
        | HookEvent::ToolCallAfter
        | HookEvent::ModeChange
        | HookEvent::OnError
        | HookEvent::ShellEnv => (),
    };
    let all_events = [
        HookEvent::SessionStart,
        HookEvent::SessionEnd,
        HookEvent::MessageSubmit,
        HookEvent::ToolCallBefore,
        HookEvent::ToolCallAfter,
        HookEvent::ModeChange,
        HookEvent::OnError,
        HookEvent::ShellEnv,
    ];
    let executor = HookExecutor::disabled();
    for event in all_events {
        assert!(
            !executor.has_hooks_for_event(event),
            "empty config must short-circuit for {event:?}"
        );
    }
}
#[test]
fn has_hooks_for_event_returns_false_when_globally_disabled() {
    // The global `enabled: false` switch must take precedence over a configured hook.
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: false,
            hooks: vec![Hook::new(HookEvent::ToolCallBefore, "echo blocked")],
            ..HooksConfig::default()
        },
        PathBuf::from("."),
    );
    let fires = executor.has_hooks_for_event(HookEvent::ToolCallBefore);
    assert!(
        !fires,
        "globally-disabled hooks must report no fires even when one is configured"
    );
}
#[test]
fn has_hooks_for_event_distinguishes_event_types() {
    // Only the events that actually have a configured hook should report true.
    let executor = HookExecutor::new(
        HooksConfig {
            enabled: true,
            hooks: vec![
                Hook::new(HookEvent::SessionStart, "echo start"),
                Hook::new(HookEvent::ToolCallBefore, "echo before"),
            ],
            ..HooksConfig::default()
        },
        PathBuf::from("."),
    );
    for configured in [HookEvent::SessionStart, HookEvent::ToolCallBefore] {
        assert!(
            executor.has_hooks_for_event(configured),
            "expected a hook for {configured:?}"
        );
    }
    for unconfigured in [
        HookEvent::ToolCallAfter,
        HookEvent::OnError,
        HookEvent::ModeChange,
    ] {
        assert!(
            !executor.has_hooks_for_event(unconfigured),
            "expected no hook for {unconfigured:?}"
        );
    }
}
}