use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
pub use super::events::HookEvent;
/// How a hook's command is run; serialized in config as `"exec"` or `"daemon"`.
///
/// When a `[[hook]]` entry omits `mode`, `default_mode_for_event` supplies one.
/// NOTE(review): the runtime difference between the two modes is implemented
/// outside this file (presumably a per-invocation process vs. a long-lived
/// process) — confirm against the executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookMode {
/// Serialized as `"exec"`.
Exec,
/// Serialized as `"daemon"`.
Daemon,
}
/// Policy applied when a hook fails; serialized as `"open"` or `"closed"`.
///
/// Defaults to [`FailMode::Open`], both via `Default` and when the key is
/// omitted from `hooks.toml`.
/// NOTE(review): presumably "open" lets the triggering operation proceed when
/// the hook errors and "closed" blocks it — confirm against the executor.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailMode {
    /// Serialized as `"open"`; the default.
    #[default]
    Open,
    /// Serialized as `"closed"`.
    Closed,
}

/// `serde(default = ...)` hook for `fail_mode`.
///
/// Delegates to `FailMode::default()` so the default is declared in exactly
/// one place (the `#[default]` variant above).
fn default_fail_mode() -> FailMode {
    FailMode::default()
}
/// Largest `timeout_ms` (in milliseconds, i.e. 30 s) a hook may declare;
/// `validate_hook` rejects values strictly greater than this.
pub const MAX_TIMEOUT_MS: u32 = 30_000;
/// Picks the execution mode used when a hook's config omits `mode`.
///
/// `PostRecall`, `PostSearch` and `PreRecallExpand` default to
/// [`HookMode::Daemon`]; every other event defaults to [`HookMode::Exec`].
#[must_use]
pub fn default_mode_for_event(event: HookEvent) -> HookMode {
    let daemon_by_default = matches!(
        event,
        HookEvent::PostRecall | HookEvent::PostSearch | HookEvent::PreRecallExpand
    );
    if daemon_by_default {
        HookMode::Daemon
    } else {
        HookMode::Exec
    }
}
/// One `[[hook]]` entry from `hooks.toml`, with defaults already applied.
///
/// Deserialization goes through an internal raw shape (see the manual
/// `Deserialize` impl), which fills an omitted `mode` from
/// [`default_mode_for_event`] and an omitted `fail_mode` with
/// [`FailMode::Open`]. This struct itself only derives `Serialize`, so it
/// carries no deserialize-only serde attributes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct HookConfig {
    /// Lifecycle event the hook is attached to.
    pub event: HookEvent,
    /// Path of the command to run; must be non-empty when loaded from a file.
    pub command: PathBuf,
    /// Relative priority of the hook.
    /// NOTE(review): how priority orders hooks is decided by the dispatcher,
    /// not visible here — confirm direction (higher-first vs lower-first).
    pub priority: i32,
    /// Per-invocation timeout in milliseconds; at most [`MAX_TIMEOUT_MS`]
    /// when loaded through `HookConfig::load_from_str`/`load_from_file`.
    pub timeout_ms: u32,
    /// Execution mode; per-event default when omitted in the file.
    pub mode: HookMode,
    /// Whether the hook is enabled.
    pub enabled: bool,
    /// Namespace pattern the hook applies to (e.g. `"team/*"`, or `"*"` for
    /// all); must be non-blank when loaded from a file.
    pub namespace: String,
    /// Failure policy; [`FailMode::Open`] when omitted in the file.
    pub fail_mode: FailMode,
}
/// Deserialization-only mirror of `HookConfig` in which `mode` is optional.
///
/// `HookConfig`'s `Deserialize` impl parses into this shape and converts it via
/// `From`, filling an absent `mode` from `default_mode_for_event`.
#[derive(Debug, Deserialize)]
struct HookConfigRaw {
event: HookEvent,
command: PathBuf,
priority: i32,
timeout_ms: u32,
/// `None` when the key is omitted, so the per-event default can be applied.
#[serde(default)]
mode: Option<HookMode>,
enabled: bool,
namespace: String,
/// Falls back to `FailMode::Open` (via `default_fail_mode`) when omitted.
#[serde(default = "default_fail_mode")]
fail_mode: FailMode,
}
impl From<HookConfigRaw> for HookConfig {
fn from(raw: HookConfigRaw) -> Self {
let mode = raw
.mode
.unwrap_or_else(|| default_mode_for_event(raw.event));
HookConfig {
event: raw.event,
command: raw.command,
priority: raw.priority,
timeout_ms: raw.timeout_ms,
mode,
enabled: raw.enabled,
namespace: raw.namespace,
fail_mode: raw.fail_mode,
}
}
}
impl<'de> serde::Deserialize<'de> for HookConfig {
    /// Parses the intermediate `HookConfigRaw` shape, then converts it so that
    /// an absent `mode` receives its per-event default.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = HookConfigRaw::deserialize(deserializer)?;
        Ok(Self::from(raw))
    }
}
/// Top-level shape of `hooks.toml`: zero or more `[[hook]]` tables.
#[derive(Debug, Deserialize)]
struct HooksFile {
/// Entries from the `[[hook]]` array of tables; empty when the key is absent.
#[serde(default, rename = "hook")]
hooks: Vec<HookConfig>,
}
impl HookConfig {
    /// Reads `path` and parses it as a `hooks.toml` document.
    ///
    /// # Errors
    ///
    /// Returns [`HooksConfigError::Io`] when the file cannot be read; otherwise
    /// returns whatever [`HookConfig::load_from_str`] returns for its contents.
    pub fn load_from_file(path: &Path) -> Result<Vec<HookConfig>, HooksConfigError> {
        let contents = std::fs::read_to_string(path).map_err(HooksConfigError::Io)?;
        Self::load_from_str(&contents)
    }

    /// Parses a `hooks.toml` document and validates every `[[hook]]` entry.
    ///
    /// Hooks are returned in file order; an empty document yields an empty list.
    ///
    /// # Errors
    ///
    /// * [`HooksConfigError::Toml`] when the text is not valid TOML or does not
    ///   match the hook schema. `line`/`column` are 1-based when the parser
    ///   reports a span, and both `0` when it does not.
    /// * [`HooksConfigError::Validation`] for the first entry that fails the
    ///   semantic checks (timeout cap, non-empty command, non-blank namespace).
    pub fn load_from_str(contents: &str) -> Result<Vec<HookConfig>, HooksConfigError> {
        let parsed: HooksFile = toml::from_str(contents).map_err(|e| {
            // `(0, 0)` marks "no span available"; `Display` then omits the position.
            let (line, col) = e
                .span()
                .map(|s| byte_offset_to_line_col(contents, s.start))
                .unwrap_or((0, 0));
            HooksConfigError::Toml {
                line,
                column: col,
                message: e.to_string(),
            }
        })?;
        for (idx, h) in parsed.hooks.iter().enumerate() {
            validate_hook(idx, h)?;
        }
        Ok(parsed.hooks)
    }

    /// Default config location: `<config dir>/ai-memory/hooks.toml`.
    ///
    /// Returns `None` when `dirs::config_dir` cannot determine a config directory.
    pub fn default_path() -> Option<PathBuf> {
        // Join each component separately so the platform separator is used
        // throughout; a single "ai-memory/hooks.toml" segment would leave a
        // literal '/' inside the path on Windows.
        dirs::config_dir().map(|p| p.join("ai-memory").join("hooks.toml"))
    }
}
/// Applies the semantic checks to the hook at position `idx` in the file.
///
/// Checks run in this order, and the first failure is returned as
/// `HooksConfigError::Validation` with `field` formatted as `hook[<idx>].<name>`:
/// 1. `timeout_ms` must not exceed [`MAX_TIMEOUT_MS`];
/// 2. `command` must be a non-empty path;
/// 3. `namespace` must not be empty or whitespace-only.
fn validate_hook(idx: usize, h: &HookConfig) -> Result<(), HooksConfigError> {
    let reject = |name: &str, reason: String| -> Result<(), HooksConfigError> {
        Err(HooksConfigError::Validation {
            field: format!("hook[{idx}].{name}"),
            reason,
        })
    };

    if h.timeout_ms > MAX_TIMEOUT_MS {
        return reject(
            "timeout_ms",
            format!("{} exceeds maximum {MAX_TIMEOUT_MS}ms", h.timeout_ms),
        );
    }
    if h.command.as_os_str().is_empty() {
        return reject("command", "must be a non-empty path".into());
    }
    if h.namespace.trim().is_empty() {
        return reject(
            "namespace",
            "must be a non-empty pattern (use \"*\" to match all)".into(),
        );
    }
    Ok(())
}
/// Converts a byte offset into `s` into a 1-based `(line, column)` pair.
///
/// Every character that starts before `offset` is counted: a `'\n'` advances
/// the line and resets the column to 1, and anything else advances the column
/// by one. Columns therefore count `char`s, not bytes. An offset past the end
/// of `s` yields the position just after the last character.
fn byte_offset_to_line_col(s: &str, offset: usize) -> (usize, usize) {
    s.char_indices()
        .take_while(|&(pos, _)| pos < offset)
        .fold((1usize, 1usize), |(line, col), (_, c)| {
            if c == '\n' {
                (line + 1, 1)
            } else {
                (line, col + 1)
            }
        })
}
/// Errors produced while loading or validating `hooks.toml`.
#[derive(Debug)]
pub enum HooksConfigError {
/// The file could not be read.
Io(std::io::Error),
/// The contents failed TOML parsing or deserialization into the hook schema.
Toml {
/// 1-based line of the error, or 0 when the parser reported no span.
line: usize,
/// 1-based column (counted in characters), or 0 when no span was reported.
column: usize,
/// The parser's error message.
message: String,
},
/// A parsed hook failed a semantic check; `field` is formatted as
/// `hook[<idx>].<field name>` and `reason` explains the failure.
Validation { field: String, reason: String },
}
impl fmt::Display for HooksConfigError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HooksConfigError::Io(e) => write!(f, "hooks.toml read error: {e}"),
HooksConfigError::Toml {
line,
column,
message,
} => {
if *line == 0 {
write!(f, "hooks.toml parse error: {message}")
} else {
write!(
f,
"hooks.toml parse error at line {line}, column {column}: {message}"
)
}
}
HooksConfigError::Validation { field, reason } => {
write!(f, "hooks.toml validation error in {field}: {reason}")
}
}
}
}
impl std::error::Error for HooksConfigError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
HooksConfigError::Io(e) => Some(e),
_ => None,
}
}
}
/// Shared hook list guarded by an async `RwLock`; on Unix, `spawn_reload_task`
/// replaces its entire contents after each successful SIGHUP reload.
pub type HookConfigSnapshot = RwLock<Vec<HookConfig>>;
/// Spawns a background task that reloads `path` into `snapshot` on every SIGHUP.
///
/// Each signal re-reads and re-validates the file with
/// [`HookConfig::load_from_file`]. A successful load replaces the whole
/// snapshot. A failed load is logged and the previous configuration is kept.
/// If the SIGHUP handler cannot be installed, the error is logged and the task
/// exits immediately; otherwise it runs until the signal stream ends.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
#[cfg(unix)]
pub fn spawn_reload_task(
    path: PathBuf,
    snapshot: Arc<HookConfigSnapshot>,
) -> tokio::task::JoinHandle<()> {
    use tokio::signal::unix::{SignalKind, signal};
    tokio::spawn(async move {
        let mut sighup = match signal(SignalKind::hangup()) {
            Ok(s) => s,
            Err(e) => {
                tracing::error!(error = %e, "hooks: failed to install SIGHUP handler");
                return;
            }
        };
        while sighup.recv().await.is_some() {
            // `load_from_file` performs synchronous std::fs I/O; run it on the
            // blocking pool so a slow filesystem cannot stall a runtime worker.
            let load_path = path.clone();
            let loaded =
                tokio::task::spawn_blocking(move || HookConfig::load_from_file(&load_path))
                    .await;
            match loaded {
                Ok(Ok(new_cfg)) => {
                    let count = new_cfg.len();
                    // The temporary write guard is dropped at the end of this
                    // statement, so readers are blocked only for the swap itself.
                    *snapshot.write().await = new_cfg;
                    tracing::info!(
                        path = %path.display(),
                        hooks = count,
                        "hooks: reloaded config on SIGHUP"
                    );
                }
                Ok(Err(e)) => {
                    tracing::error!(
                        path = %path.display(),
                        error = %e,
                        "hooks: SIGHUP reload failed; keeping previous config"
                    );
                }
                Err(join_err) => {
                    // The blocking loader panicked or was cancelled; keep serving
                    // the previous config rather than ending the reload loop.
                    tracing::error!(
                        path = %path.display(),
                        error = %join_err,
                        "hooks: SIGHUP reload task failed; keeping previous config"
                    );
                }
            }
        }
    })
}
/// Fallback for platforms without Unix signals.
///
/// Hot reload is unavailable here: the spawned task only logs a warning and
/// finishes, and the arguments are ignored.
#[cfg(not(unix))]
pub fn spawn_reload_task(
    _path: PathBuf,
    _snapshot: Arc<HookConfigSnapshot>,
) -> tokio::task::JoinHandle<()> {
    let warn_unsupported = async {
        tracing::warn!("hooks: SIGHUP reload not supported on this platform");
    };
    tokio::spawn(warn_unsupported)
}
#[cfg(test)]
mod tests {
    // Unit tests for hook config parsing, per-event defaults, validation,
    // error formatting, and snapshot swapping.
    use super::*;
    use std::io::Write;

    const VALID_CANONICAL: &str = r#"
[[hook]]
event = "post_store"
command = "/usr/local/bin/auto-link-detector"
priority = 100
timeout_ms = 5000
mode = "daemon"
enabled = true
namespace = "team/*"
"#;

    #[test]
    fn parses_canonical_example() {
        let hooks = HookConfig::load_from_str(VALID_CANONICAL).expect("parses");
        assert_eq!(hooks.len(), 1);
        let h = &hooks[0];
        assert_eq!(h.event, HookEvent::PostStore);
        assert_eq!(
            h.command,
            PathBuf::from("/usr/local/bin/auto-link-detector")
        );
        assert_eq!(h.priority, 100);
        assert_eq!(h.timeout_ms, 5_000);
        assert_eq!(h.mode, HookMode::Daemon);
        assert!(h.enabled);
        assert_eq!(h.namespace, "team/*");
    }

    #[test]
    fn rejects_timeout_over_cap() {
        let toml_src = r#"
[[hook]]
event = "post_recall"
command = "/bin/true"
priority = 0
timeout_ms = 60000
mode = "exec"
enabled = true
namespace = "*"
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        match err {
            HooksConfigError::Validation { field, reason } => {
                assert!(field.ends_with("timeout_ms"), "field was {field}");
                assert!(reason.contains("30000"), "reason was {reason}");
            }
            other => panic!("expected Validation, got {other:?}"),
        }
    }

    // Boundary: the cap itself is allowed (validation uses a strict `>`).
    #[test]
    fn accepts_timeout_exactly_at_cap() {
        let toml_src = format!(
            "[[hook]]\nevent = \"post_store\"\ncommand = \"/bin/true\"\npriority = 0\n\
             timeout_ms = {MAX_TIMEOUT_MS}\nmode = \"exec\"\nenabled = true\nnamespace = \"*\"\n"
        );
        let hooks = HookConfig::load_from_str(&toml_src).expect("cap itself is allowed");
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].timeout_ms, MAX_TIMEOUT_MS);
    }

    #[test]
    fn invalid_toml_reports_line_number() {
        let toml_src = "\n\n[[hook]]\nevent = \"post_store\"\nmode = \n";
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        match err {
            HooksConfigError::Toml {
                line, ref message, ..
            } => {
                assert!(line > 0, "expected non-zero line, got {line}");
                let displayed = err.to_string();
                assert!(
                    displayed.contains(&format!("line {line}")),
                    "Display did not surface line: {displayed} (raw msg: {message})"
                );
            }
            other => panic!("expected Toml, got {other:?}"),
        }
    }

    // Schema errors (not just syntax errors) surface as `Toml`, naming the field.
    #[test]
    fn missing_required_field_is_toml_error() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
timeout_ms = 1000
enabled = true
namespace = "*"
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        match err {
            HooksConfigError::Toml { message, .. } => {
                assert!(message.contains("priority"), "message was {message}");
            }
            other => panic!("expected Toml, got {other:?}"),
        }
    }

    #[test]
    fn unknown_mode_is_rejected() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "sidecar"
enabled = true
namespace = "*"
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        assert!(
            matches!(err, HooksConfigError::Toml { .. }),
            "expected Toml, got {err:?}"
        );
    }

    #[test]
    fn multiple_hooks_same_event_preserve_order() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/first"
priority = 10
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
[[hook]]
event = "post_store"
command = "/bin/second"
priority = 5
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
[[hook]]
event = "post_store"
command = "/bin/third"
priority = 50
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses");
        assert_eq!(hooks.len(), 3);
        assert_eq!(hooks[0].command, PathBuf::from("/bin/first"));
        assert_eq!(hooks[1].command, PathBuf::from("/bin/second"));
        assert_eq!(hooks[2].command, PathBuf::from("/bin/third"));
        assert!(hooks.iter().all(|h| h.event == HookEvent::PostStore));
    }

    #[test]
    fn rejects_empty_namespace() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = ""
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        assert!(matches!(err, HooksConfigError::Validation { .. }));
    }

    #[test]
    fn rejects_empty_command() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = ""
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        match err {
            HooksConfigError::Validation { field, .. } => {
                assert!(field.ends_with("command"), "field was {field}");
            }
            other => panic!("expected Validation, got {other:?}"),
        }
    }

    #[test]
    fn empty_file_yields_zero_hooks() {
        let hooks = HookConfig::load_from_str("").expect("parses");
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_post_recall_default_mode_is_daemon() {
        let toml_src = r#"
[[hook]]
event = "post_recall"
command = "/bin/true"
priority = 0
timeout_ms = 1000
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses with no mode field");
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].event, HookEvent::PostRecall);
        assert_eq!(
            hooks[0].mode,
            HookMode::Daemon,
            "post_recall must default to daemon mode (R3-S3)"
        );
    }

    #[test]
    fn test_post_search_default_mode_is_daemon() {
        let toml_src = r#"
[[hook]]
event = "post_search"
command = "/bin/true"
priority = 0
timeout_ms = 1000
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses with no mode field");
        assert_eq!(hooks.len(), 1);
        assert_eq!(
            hooks[0].mode,
            HookMode::Daemon,
            "post_search must default to daemon mode (R3-S3)"
        );
    }

    #[test]
    fn test_pre_recall_expand_default_mode_is_daemon() {
        let toml_src = r#"
[[hook]]
event = "pre_recall_expand"
command = "/bin/true"
priority = 0
timeout_ms = 1000
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses with no mode field");
        assert_eq!(hooks[0].mode, HookMode::Daemon);
    }

    #[test]
    fn test_post_store_default_mode_is_exec() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses with no mode field");
        assert_eq!(
            hooks[0].mode,
            HookMode::Exec,
            "cold-path events still default to exec mode (no R3-S3 change)"
        );
    }

    #[test]
    fn test_explicit_mode_overrides_default() {
        let toml_src = r#"
[[hook]]
event = "post_recall"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses");
        assert_eq!(
            hooks[0].mode,
            HookMode::Exec,
            "explicit mode = \"exec\" must not be silently flipped to daemon"
        );
    }

    #[test]
    fn load_from_file_round_trip() {
        let mut tmp = tempfile::NamedTempFile::new().expect("tempfile");
        tmp.write_all(VALID_CANONICAL.as_bytes()).expect("write");
        let hooks = HookConfig::load_from_file(tmp.path()).expect("loads");
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].event, HookEvent::PostStore);
    }

    // Simulates the reload step (load + swap under the write lock) without
    // delivering a real SIGHUP to the test process.
    #[tokio::test]
    async fn sighup_reload_swaps_snapshot() {
        let mut tmp = tempfile::NamedTempFile::new().expect("tempfile");
        tmp.write_all(VALID_CANONICAL.as_bytes()).expect("write A");
        let snapshot: Arc<HookConfigSnapshot> = Arc::new(RwLock::new(
            HookConfig::load_from_file(tmp.path()).expect("load A"),
        ));
        {
            let guard = snapshot.read().await;
            assert_eq!(guard.len(), 1);
            assert_eq!(
                guard[0].command,
                PathBuf::from("/usr/local/bin/auto-link-detector")
            );
        }
        let config_b = r#"
[[hook]]
event = "pre_store"
command = "/opt/hooks/redact-pii"
priority = 200
timeout_ms = 2500
mode = "exec"
enabled = true
namespace = "*"
[[hook]]
event = "post_recall"
command = "/opt/hooks/expand-context"
priority = 50
timeout_ms = 100
mode = "daemon"
enabled = false
namespace = "team/*"
"#;
        std::fs::write(tmp.path(), config_b).expect("rewrite to B");
        let new_cfg = HookConfig::load_from_file(tmp.path()).expect("load B");
        {
            let mut guard = snapshot.write().await;
            *guard = new_cfg;
        }
        let guard = snapshot.read().await;
        assert_eq!(guard.len(), 2);
        assert_eq!(guard[0].event, HookEvent::PreStore);
        assert_eq!(guard[0].command, PathBuf::from("/opt/hooks/redact-pii"));
        assert_eq!(guard[1].event, HookEvent::PostRecall);
        assert!(!guard[1].enabled);
    }

    #[test]
    fn default_path_is_under_config_dir() {
        if let Some(p) = HookConfig::default_path() {
            let s = p.to_string_lossy();
            assert!(
                s.ends_with("ai-memory/hooks.toml") || s.ends_with("ai-memory\\hooks.toml"),
                "unexpected default path: {s}"
            );
        }
    }

    #[test]
    fn hook_event_serde_uses_snake_case() {
        let json = serde_json::to_string(&HookEvent::PreGovernanceDecision).unwrap();
        assert_eq!(json, "\"pre_governance_decision\"");
        let back: HookEvent = serde_json::from_str("\"on_index_eviction\"").unwrap();
        assert_eq!(back, HookEvent::OnIndexEviction);
    }

    #[test]
    fn hook_mode_serde_uses_snake_case() {
        let exec_json = serde_json::to_string(&HookMode::Exec).unwrap();
        let daemon_json = serde_json::to_string(&HookMode::Daemon).unwrap();
        assert_eq!(exec_json, "\"exec\"");
        assert_eq!(daemon_json, "\"daemon\"");
    }

    #[test]
    fn fail_mode_default_is_open() {
        assert_eq!(FailMode::default(), FailMode::Open);
        assert_eq!(default_fail_mode(), FailMode::Open);
    }

    #[test]
    fn fail_mode_serde_round_trip() {
        let open = serde_json::to_string(&FailMode::Open).unwrap();
        let closed = serde_json::to_string(&FailMode::Closed).unwrap();
        assert_eq!(open, "\"open\"");
        assert_eq!(closed, "\"closed\"");
        let back: FailMode = serde_json::from_str("\"closed\"").unwrap();
        assert_eq!(back, FailMode::Closed);
    }

    #[test]
    fn default_mode_for_event_matrix() {
        assert_eq!(
            default_mode_for_event(HookEvent::PostRecall),
            HookMode::Daemon
        );
        assert_eq!(
            default_mode_for_event(HookEvent::PostSearch),
            HookMode::Daemon
        );
        assert_eq!(
            default_mode_for_event(HookEvent::PreRecallExpand),
            HookMode::Daemon
        );
        assert_eq!(default_mode_for_event(HookEvent::PostStore), HookMode::Exec);
        assert_eq!(default_mode_for_event(HookEvent::PreStore), HookMode::Exec);
        assert_eq!(default_mode_for_event(HookEvent::PreDelete), HookMode::Exec);
    }

    #[test]
    fn fail_mode_closed_is_parsed() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
fail_mode = "closed"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses");
        assert_eq!(hooks[0].fail_mode, FailMode::Closed);
    }

    #[test]
    fn fail_mode_omitted_defaults_to_open() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = "*"
"#;
        let hooks = HookConfig::load_from_str(toml_src).expect("parses");
        assert_eq!(hooks[0].fail_mode, FailMode::Open);
    }

    #[test]
    fn validation_error_display_surfaces_field_and_reason() {
        let err = HooksConfigError::Validation {
            field: "hook[0].timeout_ms".into(),
            reason: "exceeds maximum".into(),
        };
        let s = err.to_string();
        assert!(s.contains("hook[0].timeout_ms"));
        assert!(s.contains("exceeds maximum"));
    }

    #[test]
    fn io_error_display_and_source() {
        let io_err = std::io::Error::other("simulated read failure");
        let err = HooksConfigError::Io(io_err);
        let s = err.to_string();
        assert!(s.contains("hooks.toml read error"));
        assert!(s.contains("simulated read failure"));
        use std::error::Error;
        assert!(err.source().is_some());
    }

    #[test]
    fn toml_error_no_span_displays_without_line_marker() {
        let err = HooksConfigError::Toml {
            line: 0,
            column: 0,
            message: "no span here".into(),
        };
        let s = err.to_string();
        assert!(s.contains("no span here"));
        assert!(!s.contains("line 0"));
    }

    #[test]
    fn toml_error_with_span_displays_line_and_column() {
        let err = HooksConfigError::Toml {
            line: 7,
            column: 3,
            message: "broken".into(),
        };
        let s = err.to_string();
        assert!(s.contains("line 7"));
        assert!(s.contains("column 3"));
    }

    #[test]
    fn hooks_config_error_source_for_non_io_variants_is_none() {
        use std::error::Error;
        let v = HooksConfigError::Validation {
            field: "x".into(),
            reason: "y".into(),
        };
        assert!(v.source().is_none());
        let t = HooksConfigError::Toml {
            line: 0,
            column: 0,
            message: "z".into(),
        };
        assert!(t.source().is_none());
    }

    #[test]
    fn load_from_file_returns_io_error_for_missing_path() {
        let p = std::path::Path::new("/this/path/does/not/exist/hooks-test.toml");
        let err = HookConfig::load_from_file(p).unwrap_err();
        assert!(matches!(err, HooksConfigError::Io(_)));
    }

    #[test]
    fn rejects_whitespace_only_namespace() {
        let toml_src = r#"
[[hook]]
event = "post_store"
command = "/bin/true"
priority = 0
timeout_ms = 1000
mode = "exec"
enabled = true
namespace = " "
"#;
        let err = HookConfig::load_from_str(toml_src).unwrap_err();
        match err {
            HooksConfigError::Validation { field, .. } => {
                assert!(field.ends_with("namespace"));
            }
            other => panic!("expected Validation, got {other:?}"),
        }
    }

    #[test]
    fn byte_offset_to_line_col_handles_multiline_input() {
        let s = "first\nsecond\nthird";
        assert_eq!(byte_offset_to_line_col(s, 0), (1, 1));
        assert_eq!(byte_offset_to_line_col(s, 5), (1, 6));
        assert_eq!(byte_offset_to_line_col(s, 6), (2, 1));
        let (line, _) = byte_offset_to_line_col(s, 9_999);
        assert!(line >= 3);
    }

    // Columns count characters: 'é' is two bytes but advances the column by one.
    #[test]
    fn byte_offset_to_line_col_counts_chars_not_bytes() {
        let s = "é\nx";
        assert_eq!(byte_offset_to_line_col(s, 2), (1, 2));
        assert_eq!(byte_offset_to_line_col(s, 3), (2, 1));
    }
}