use anyhow::{Result, anyhow};
use rusqlite::Connection;
use serde::Deserialize;
use serde_json::Value;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
#[cfg(unix)]
use tokio::net::UnixListener;
use tracing::{debug, info, warn};
use crate::bridge::HookRouter;
use crate::bridge::toolbox::Toolbox;
use crate::session::MessageLog;
/// A hook request decoded from the JSON object a client sends over the
/// notify socket/pipe.
///
/// The enum is internally tagged: the object's `"method"` string selects the
/// variant (via the `rename` on each variant) and the remaining keys populate
/// that variant's fields. Fields carrying `#[serde(default)]` may be omitted
/// by the sender: an omitted `String` becomes `""`, an omitted `Option`
/// becomes `None`, and an omitted `bool` becomes `false`.
///
/// NOTE(review): what each request *does* is decided by `HookRouter::dispatch`,
/// which is not defined in this file — the variant docs below describe the
/// wire shape only.
#[derive(Debug, Deserialize)]
#[serde(tag = "method")]
pub(crate) enum HookRequest {
/// `"pre-agent/roots-sync"` — carries no fields.
#[serde(rename = "pre-agent/roots-sync")]
PreAgentRootsSync {},
/// `"pre-tool/enforce-editing"`.
#[serde(rename = "pre-tool/enforce-editing")]
PreToolEnforceEditing {
/// Required; deserialization fails when the key is missing.
tool_name: String,
/// Optional. Still accepted so older senders keep working; the
/// `dead_code` allowance indicates nothing reads it any more.
#[serde(default)]
#[allow(dead_code, reason = "kept for IPC backwards compatibility")]
file_path: Option<String>,
/// Optional; empty string when omitted.
#[serde(default)]
agent_id: String,
/// Optional; `None` when omitted.
#[serde(default)]
session_id: Option<String>,
},
/// `"post-tool/diagnostics"`.
#[serde(rename = "post-tool/diagnostics")]
PostToolDiagnostics {
/// Required; deserialization fails when the key is missing.
file: String,
/// Optional; `None` when omitted.
#[serde(default)]
tool: Option<String>,
/// Optional; empty string when omitted.
#[serde(default)]
agent_id: String,
/// Optional; `None` when omitted.
#[serde(default)]
session_id: Option<String>,
},
/// `"post-agent/require-release"`.
#[serde(rename = "post-agent/require-release")]
PostAgentRequireRelease {
/// Optional; empty string when omitted.
#[serde(default)]
agent_id: String,
/// Optional; `false` when omitted.
#[serde(default)]
stop_hook_active: bool,
},
/// `"post-tool/done-editing"`.
#[serde(rename = "post-tool/done-editing")]
PostToolDoneEditing {
/// Optional; empty string when omitted.
#[serde(default)]
agent_id: String,
/// Optional; `None` when omitted.
#[serde(default)]
session_id: Option<String>,
},
/// `"session-start/clear-editing"`.
#[serde(rename = "session-start/clear-editing")]
SessionStartClearEditing {
/// Optional; `None` when omitted.
#[serde(default)]
session_id: Option<String>,
},
}
/// Outcome of a hook request in its JSON wire form.
///
/// Serde's default externally tagged representation is used with snake_case
/// variant names, so each value serializes to a single-key object such as
/// `{"content":"[clean]"}` or `{"cleared":3}`.
///
/// NOTE(review): how a client reacts to each variant is defined outside this
/// file; the variant names are the only hint available here — confirm against
/// the hook client before relying on any particular meaning.
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum HookResult {
/// Serialized under the `"content"` key with a string payload.
Content(String),
/// Serialized under the `"error"` key with a string payload.
Error(String),
/// Serialized under the `"deny"` key with a string payload.
Deny(String),
/// Serialized under the `"block"` key with a string payload.
Block(String),
/// Serialized under the `"courtesy"` key with a string payload.
Courtesy(String),
/// Serialized under the `"cleared"` key with an unsigned integer payload
/// (presumably a count of cleared items — confirm with the producer).
Cleared(usize),
}
/// Local IPC server that reads hook requests from a Unix socket (or a Windows
/// named pipe), hands them to a [`HookRouter`], and writes the result back.
pub struct HookServer {
/// Dispatches decoded requests; also supplies the client name recorded
/// with every log entry.
router: Arc<HookRouter>,
/// Records the raw request and the response of every handled connection.
message_log: Arc<MessageLog>,
}
impl HookServer {
#[must_use]
pub fn new(
toolbox: Arc<Toolbox>,
refresh_roots: Arc<AtomicBool>,
message_log: Arc<MessageLog>,
conn: Arc<Mutex<Connection>>,
session_id: String,
client_name: String,
) -> Self {
let router = Arc::new(HookRouter::new(
toolbox,
refresh_roots,
conn,
session_id,
client_name,
));
Self {
router,
message_log,
}
}
#[cfg(unix)]
pub fn start(self, socket_path: &std::path::Path) -> Result<tokio::task::JoinHandle<()>> {
let _ = std::fs::remove_file(socket_path);
let listener = UnixListener::bind(socket_path).map_err(|e| {
anyhow!(
"Failed to bind notify socket {}: {e}",
socket_path.display()
)
})?;
info!("Notify socket listening on {}", socket_path.display());
let server = Arc::new(self);
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, _)) => {
let server = server.clone();
tokio::spawn(async move {
if let Err(e) = server.handle_connection(stream).await {
debug!("Notify connection error: {e}");
}
});
}
Err(e) => {
warn!("Notify socket accept error: {e}");
}
}
}
});
Ok(handle)
}
#[cfg(windows)]
pub fn start(self, pipe_path: &std::path::Path) -> Result<tokio::task::JoinHandle<()>> {
use tokio::net::windows::named_pipe::ServerOptions;
let pipe_name = pipe_path.to_string_lossy().to_string();
let mut server = ServerOptions::new()
.first_pipe_instance(true)
.create(&pipe_name)
.map_err(|e| anyhow!("Failed to create notify pipe {pipe_name}: {e}"))?;
info!("Notify pipe listening on {pipe_name}");
let server_arc = Arc::new(self);
let handle = tokio::spawn(async move {
loop {
if let Err(e) = server.connect().await {
warn!("Notify pipe connect error: {e}");
continue;
}
let connected = server;
server = match ServerOptions::new().create(&pipe_name) {
Ok(s) => s,
Err(e) => {
warn!("Notify pipe create error: {e}");
break;
}
};
let srv = server_arc.clone();
tokio::spawn(async move {
if let Err(e) = srv.handle_connection(connected).await {
debug!("Notify connection error: {e}");
}
});
}
});
Ok(handle)
}
async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) -> Result<()> {
let (reader, mut writer) = tokio::io::split(stream);
let mut buf_reader = BufReader::new(reader);
let mut line = String::new();
buf_reader.read_line(&mut line).await?;
let raw: Value =
serde_json::from_str(line.trim()).map_err(|e| anyhow!("Invalid request: {e}"))?;
let method = raw
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let entry_id = self.message_log.log(
"hook",
&method,
"catenary",
&self.router.client_name,
None,
None,
&raw,
);
let request: HookRequest =
serde_json::from_value(raw).map_err(|e| anyhow!("Invalid hook request: {e}"))?;
let result = self.router.dispatch(request, entry_id).await;
let response = result
.as_ref()
.map(|r| serde_json::to_string(r).unwrap_or_default())
.unwrap_or_default();
self.message_log.log(
"hook",
&method,
"catenary",
&self.router.client_name,
Some(entry_id),
None,
&serde_json::from_str::<Value>(&response).unwrap_or_default(),
);
writer.write_all(response.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.shutdown().await?;
Ok(())
}
}
#[cfg(test)]
#[allow(
    clippy::expect_used,
    reason = "tests use expect for readable assertions"
)]
mod tests {
    use super::*;

    /// Opens a database through `crate::db::open_and_migrate_at` inside a
    /// fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the caller can keep the directory
    /// alive for as long as the connection is in use.
    fn test_db() -> (tempfile::TempDir, std::path::PathBuf, rusqlite::Connection) {
        let tmp = tempfile::tempdir().expect("tempdir");
        let db_path = tmp.path().join("catenary").join("catenary.db");
        let db = crate::db::open_and_migrate_at(&db_path).expect("open test DB");
        (tmp, db_path, db)
    }

    /// Creates a test database containing session `s1` and returns the temp
    /// directory guard, the shared connection, and a `MessageLog` bound to
    /// that session.
    fn seeded_log() -> (tempfile::TempDir, Arc<Mutex<Connection>>, Arc<MessageLog>) {
        let (dir, _path, db) = test_db();
        let shared = Arc::new(Mutex::new(db));
        shared
            .lock()
            .expect("lock")
            .execute(
                "INSERT INTO sessions (id, pid, display_name, started_at) \
VALUES ('s1', 1, 'test', '2026-01-01T00:00:00Z')",
                [],
            )
            .expect("insert session");
        let log = Arc::new(MessageLog::new(Arc::clone(&shared), "s1".to_string()));
        (dir, shared, log)
    }

    /// Serializes `original`, checks that deserializing the JSON gives back an
    /// equal value, and returns the JSON text for further inspection.
    fn assert_round_trip(original: &HookResult) -> String {
        let json = serde_json::to_string(original).expect("serialize");
        let parsed: HookResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(&parsed, original);
        json
    }

    /// Deserializes `json` into a `HookRequest`, panicking with `label` on failure.
    fn parse_request(json: &str, label: &str) -> HookRequest {
        serde_json::from_str(json).expect(label)
    }

    #[test]
    fn hook_result_content_round_trip() {
        let json = assert_round_trip(&HookResult::Content("[clean]".into()));
        let raw: serde_json::Value = serde_json::from_str(&json).expect("parse as value");
        assert_eq!(raw["content"], "[clean]");
        assert!(raw.get("error").is_none());
    }

    #[test]
    fn hook_result_error_round_trip() {
        let json = assert_round_trip(&HookResult::Error("path resolution failed".into()));
        let raw: serde_json::Value = serde_json::from_str(&json).expect("parse as value");
        assert_eq!(raw["error"], "path resolution failed");
        assert!(raw.get("content").is_none());
    }

    #[test]
    fn hook_result_deny_round_trip() {
        assert_round_trip(&HookResult::Deny("call start_editing first".into()));
    }

    #[test]
    fn hook_result_block_round_trip() {
        assert_round_trip(&HookResult::Block("call done_editing first".into()));
    }

    #[test]
    fn hook_result_courtesy_round_trip() {
        assert_round_trip(&HookResult::Courtesy("[clean]".into()));
    }

    #[test]
    fn hook_result_cleared_round_trip() {
        assert_round_trip(&HookResult::Cleared(3));
    }

    #[test]
    fn test_hook_request_tagged_deserialization() {
        // Field-less variant.
        let req = parse_request(r#"{"method": "pre-agent/roots-sync"}"#, "roots-sync");
        assert!(matches!(req, HookRequest::PreAgentRootsSync {}));

        // Every field supplied explicitly.
        let req = parse_request(
            r#"{"method": "pre-tool/enforce-editing", "tool_name": "Edit", "file_path": "/tmp/foo.rs", "agent_id": "", "session_id": "abc123"}"#,
            "enforce-editing",
        );
        let HookRequest::PreToolEnforceEditing {
            tool_name,
            file_path,
            agent_id,
            session_id,
        } = req
        else {
            unreachable!("expected PreToolEnforceEditing");
        };
        assert_eq!(tool_name, "Edit");
        assert_eq!(file_path.as_deref(), Some("/tmp/foo.rs"));
        assert_eq!(agent_id, "");
        assert_eq!(session_id.as_deref(), Some("abc123"));

        // Optional `tool` present.
        let req = parse_request(
            r#"{"method": "post-tool/diagnostics", "file": "/tmp/test.rs", "tool": "Write"}"#,
            "diagnostics",
        );
        let HookRequest::PostToolDiagnostics { file, tool, .. } = req else {
            unreachable!("expected PostToolDiagnostics");
        };
        assert_eq!(file, "/tmp/test.rs");
        assert_eq!(tool.as_deref(), Some("Write"));

        // Optional `tool` omitted.
        let req = parse_request(
            r#"{"method": "post-tool/diagnostics", "file": "/tmp/test.rs"}"#,
            "diagnostics minimal",
        );
        let HookRequest::PostToolDiagnostics { tool, .. } = req else {
            unreachable!("expected PostToolDiagnostics");
        };
        assert!(tool.is_none());

        let req = parse_request(
            r#"{"method": "post-agent/require-release", "agent_id": "", "stop_hook_active": true}"#,
            "require-release",
        );
        let HookRequest::PostAgentRequireRelease {
            stop_hook_active, ..
        } = req
        else {
            unreachable!("expected PostAgentRequireRelease");
        };
        assert!(stop_hook_active);

        let req = parse_request(
            r#"{"method": "session-start/clear-editing", "session_id": "uuid-123"}"#,
            "clear-editing",
        );
        let HookRequest::SessionStartClearEditing { session_id } = req else {
            unreachable!("expected SessionStartClearEditing");
        };
        assert_eq!(session_id.as_deref(), Some("uuid-123"));
    }

    #[test]
    fn test_hook_log_file_request() {
        let (_dir, conn, log) = seeded_log();
        let method = "post-tool/diagnostics";

        // Request entry: no parent id.
        let request_payload = serde_json::json!({
            "method": "post-tool/diagnostics",
            "file": "/tmp/test.rs",
            "tool": "Write"
        });
        let entry_id = log.log(
            "hook",
            method,
            "catenary",
            "claude-code",
            None,
            None,
            &request_payload,
        );
        assert!(entry_id > 0);

        // Response entry: linked back to the request.
        let response_payload = serde_json::json!({"content": "[clean]"});
        let resp_id = log.log(
            "hook",
            method,
            "catenary",
            "claude-code",
            Some(entry_id),
            None,
            &response_payload,
        );
        assert!(resp_id > entry_id);

        let db = conn.lock().expect("lock");
        let (r_type, r_method): (String, String) = db
            .query_row(
                "SELECT type, method FROM messages WHERE id = ?1",
                [entry_id],
                |row| Ok((row.get(0)?, row.get(1)?)),
            )
            .expect("query request");
        assert_eq!(r_type, "hook");
        assert_eq!(r_method, "post-tool/diagnostics");

        let stored_req_id: Option<i64> = db
            .query_row(
                "SELECT request_id FROM messages WHERE id = ?1",
                [resp_id],
                |row| row.get(0),
            )
            .expect("query response");
        assert_eq!(stored_req_id, Some(entry_id));
    }

    #[test]
    fn test_hook_log_refresh_roots() {
        let (_dir, conn, log) = seeded_log();
        let method = "pre-agent/roots-sync";

        let request_payload = serde_json::json!({"method": "pre-agent/roots-sync"});
        let entry_id = log.log(
            "hook",
            method,
            "catenary",
            "host",
            None,
            None,
            &request_payload,
        );

        let response_payload = serde_json::json!("");
        log.log(
            "hook",
            method,
            "catenary",
            "host",
            Some(entry_id),
            None,
            &response_payload,
        );

        let db = conn.lock().expect("lock");
        let r_method: String = db
            .query_row(
                "SELECT method FROM messages WHERE id = ?1",
                [entry_id],
                |row| row.get(0),
            )
            .expect("query");
        assert_eq!(r_method, "pre-agent/roots-sync");
    }
}