use axum::{
extract::{DefaultBodyLimit, Path as AxPath, State},
http::{Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::path::{Component, PathBuf};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::mpsc::UnboundedSender;
/// TCP port that [`bind_loopback`] listens on, on `127.0.0.1`.
pub const LOOPBACK_PORT: u16 = 47823;
/// Maximum request body size in bytes, installed as a `DefaultBodyLimit` by [`router`].
pub const MAX_BODY_BYTES: usize = 64 * 1024;
/// Tool names accepted in the `{tool}` segment of `/hook/{tool}/{event}`;
/// any other value is answered with `400 Bad Request`.
pub const KNOWN_TOOLS: &[&str] = &["claude", "cursor", "codex"];
/// Errors produced while binding the hook endpoint's TCP listener.
#[derive(Debug, Error)]
pub enum HookEndpointError {
    /// An underlying I/O failure. `bind_loopback` reports `AddrInUse` as
    /// [`HookEndpointError::PortInUse`] instead; `bind_test` wraps every I/O error here.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    /// The given port is already bound by another socket.
    #[error("port {0} is already in use")]
    PortInUse(u16),
}
/// A hook notification sent on the channel passed to [`router`].
///
/// The `/hook/{tool}/{event}` handler builds it from the URL path plus a validated
/// [`HookPayload`]; the per-field guarantees below describe values produced that way.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct HookEvent {
    /// Tool name from the `{tool}` path segment; the handler checks it against [`KNOWN_TOOLS`].
    pub tool: String,
    /// Event name from the `{event}` path segment, passed through without validation.
    pub event: String,
    /// Transcript path from the payload; the handler rejects paths containing `..`.
    pub transcript_path: Option<PathBuf>,
    /// Session identifier from the payload, passed through unchanged.
    pub session_id: Option<String>,
    /// Working directory from the payload; the handler keeps it only when it is
    /// absolute and contains no `..` component, otherwise it is `None`.
    pub cwd: Option<PathBuf>,
    /// Every other top-level payload key. Because it is flattened, these keys are
    /// serialized alongside the named fields rather than under an `extra` key.
    #[serde(default, flatten)]
    pub extra: serde_json::Map<String, serde_json::Value>,
}
/// JSON request body accepted by `POST /hook/{tool}/{event}`.
///
/// Every named field is optional and any unrecognised top-level keys are collected
/// into `extra`. No validation happens here; the handler filters the values.
#[derive(Debug, Deserialize)]
pub struct HookPayload {
    /// Path to the tool's transcript file, if the tool reports one.
    pub transcript_path: Option<PathBuf>,
    /// Tool-specific session identifier.
    pub session_id: Option<String>,
    /// Working directory reported by the tool.
    pub cwd: Option<PathBuf>,
    /// Remaining top-level keys, kept verbatim.
    #[serde(default, flatten)]
    pub extra: serde_json::Map<String, serde_json::Value>,
}
/// State shared by the request handlers. `Clone` is required by axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Sending half of the channel that receives every accepted [`HookEvent`].
    tx: Arc<UnboundedSender<HookEvent>>,
}
/// Builds the hook endpoint router.
///
/// Routes:
/// - `GET /health` answers `{"status": "ok"}`;
/// - `POST /hook/{tool}/{event}` validates the JSON body and sends a [`HookEvent`] on `tx`.
///
/// All routes sit behind the loopback `Host` check and a request-body limit of
/// [`MAX_BODY_BYTES`].
pub fn router(tx: UnboundedSender<HookEvent>) -> Router {
    let shared = AppState { tx: Arc::new(tx) };
    let routes = Router::new()
        .route("/health", get(health_handler))
        .route("/hook/{tool}/{event}", post(hook_handler));
    routes
        .layer(middleware::from_fn(require_loopback_host))
        .layer(DefaultBodyLimit::max(MAX_BODY_BYTES))
        .with_state(shared)
}
/// Binds a TCP listener on `127.0.0.1:`[`LOOPBACK_PORT`].
///
/// # Errors
///
/// Returns [`HookEndpointError::PortInUse`] when the port is already taken, and
/// [`HookEndpointError::Io`] for any other bind failure.
pub async fn bind_loopback() -> Result<tokio::net::TcpListener, HookEndpointError> {
    let loopback = SocketAddr::from(([127, 0, 0, 1], LOOPBACK_PORT));
    match tokio::net::TcpListener::bind(loopback).await {
        Ok(listener) => Ok(listener),
        Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
            Err(HookEndpointError::PortInUse(LOOPBACK_PORT))
        }
        Err(err) => Err(HookEndpointError::Io(err)),
    }
}
pub async fn bind_test() -> Result<(tokio::net::TcpListener, SocketAddr), HookEndpointError> {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = tokio::net::TcpListener::bind(addr).await?;
let local = listener.local_addr()?;
Ok((listener, local))
}
/// Middleware that only lets through requests whose `Host` header names the loopback
/// interface; everything else gets `403 Forbidden`.
///
/// This defends against DNS rebinding: a page on an attacker's domain that re-resolves
/// to `127.0.0.1` still sends its own domain in `Host`. A missing or non-UTF-8 `Host`
/// header is treated as empty and therefore rejected.
async fn require_loopback_host(req: Request<axum::body::Body>, next: Next) -> Response {
    let host = req
        .headers()
        .get(axum::http::header::HOST)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    if is_loopback_host(host) {
        next.run(req).await
    } else {
        StatusCode::FORBIDDEN.into_response()
    }
}

/// Returns `true` when a `Host` header value, with or without a `:port` suffix, is
/// `127.0.0.1` or `localhost`.
///
/// Host names are case-insensitive, so e.g. `LocalHost:47823` is accepted too.
fn is_loopback_host(host: &str) -> bool {
    let name = host.split_once(':').map_or(host, |(name, _port)| name);
    name == "127.0.0.1" || name.eq_ignore_ascii_case("localhost")
}
/// Liveness probe: always answers `200 OK` with the JSON body `{"status": "ok"}`.
async fn health_handler() -> impl IntoResponse {
    let payload = serde_json::json!({ "status": "ok" });
    (StatusCode::OK, Json(payload))
}
async fn hook_handler(
AxPath((tool, event)): AxPath<(String, String)>,
State(state): State<AppState>,
Json(payload): Json<HookPayload>,
) -> Response {
if !KNOWN_TOOLS.contains(&tool.as_str()) {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "unknown tool", "allowed": KNOWN_TOOLS})),
)
.into_response();
}
let transcript_path = match payload.transcript_path {
Some(p) if path_has_parent_component(&p) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "transcript_path may not contain `..` components",
})),
)
.into_response();
}
other => other,
};
let cwd = payload
.cwd
.filter(|p| p.is_absolute() && !path_has_parent_component(p));
let evt = HookEvent {
tool,
event,
transcript_path,
session_id: payload.session_id,
cwd,
extra: payload.extra,
};
let _ = state.tx.send(evt);
(StatusCode::OK, Json(serde_json::json!({"queued": true}))).into_response()
}
/// Returns `true` if any component of `p` is a `..` (parent-directory) segment.
///
/// Only a whole `..` segment counts: a name such as `..hidden` is an ordinary component.
fn path_has_parent_component(p: &std::path::Path) -> bool {
    p.components().any(|segment| segment == Component::ParentDir)
}
// Unit tests drive the router in-process via `tower::ServiceExt::oneshot`, so no socket
// is opened except in the explicit `bind_*` tests.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Builds a router plus the receiver for the events it queues.
    fn make_router() -> (Router, tokio::sync::mpsc::UnboundedReceiver<HookEvent>) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        (router(tx), rx)
    }

    /// Builds a JSON request carrying a loopback `Host` header so it passes the host check.
    fn loopback_req(method: &str, uri: &str, body: Body) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .header("host", "127.0.0.1:47823")
            .header("content-type", "application/json")
            .body(body)
            .unwrap()
    }

    #[tokio::test]
    async fn health_returns_ok() {
        let (app, _rx) = make_router();
        let response = app
            .oneshot(loopback_req("GET", "/health", Body::empty()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
    }

    #[tokio::test]
    async fn hook_route_accepts_path_params_and_pushes_event() {
        let (app, mut rx) = make_router();
        let response = app
            .oneshot(loopback_req(
                "POST",
                "/hook/claude/SessionStart",
                Body::from(r#"{"session_id": "abc123"}"#),
            ))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.tool, "claude");
        assert_eq!(evt.event, "SessionStart");
        assert_eq!(evt.session_id.as_deref(), Some("abc123"));
    }

    #[tokio::test]
    async fn hook_route_uses_axum08_curly_braces() {
        let (app, mut rx) = make_router();
        let response = app
            .oneshot(loopback_req(
                "POST",
                "/hook/cursor/sessionStart",
                Body::from(r#"{}"#),
            ))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.tool, "cursor");
        assert_eq!(evt.event, "sessionStart");
    }

    #[tokio::test]
    async fn hook_event_pushed_to_channel() {
        let (app, mut rx) = make_router();
        app.oneshot(loopback_req(
            "POST",
            "/hook/codex/SessionStart",
            Body::from(r#"{"transcript_path": "/tmp/abc.jsonl", "session_id": "s1"}"#),
        ))
        .await
        .unwrap();
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.tool, "codex");
        assert_eq!(evt.event, "SessionStart");
        assert_eq!(evt.transcript_path, Some(PathBuf::from("/tmp/abc.jsonl")));
        assert_eq!(evt.session_id.as_deref(), Some("s1"));
    }

    #[tokio::test]
    async fn hook_event_cwd_parsed_and_extra_fields_preserved() {
        let (app, mut rx) = make_router();
        app.oneshot(loopback_req(
            "POST",
            "/hook/claude/PreCompact",
            Body::from(r#"{"cwd": "/synthetic/path/1", "version": "2.0.0"}"#),
        ))
        .await
        .unwrap();
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.cwd, Some(PathBuf::from("/synthetic/path/1")));
        assert_eq!(evt.extra.get("version").unwrap(), "2.0.0");
    }

    #[tokio::test]
    async fn hook_event_cwd_with_traversal_is_dropped() {
        let (app, mut rx) = make_router();
        app.oneshot(loopback_req(
            "POST",
            "/hook/claude/PreCompact",
            Body::from(r#"{"cwd": "/tmp/../etc"}"#),
        ))
        .await
        .unwrap();
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.cwd, None, "traversal cwd should be dropped");
    }

    #[tokio::test]
    async fn hook_event_relative_cwd_is_dropped() {
        let (app, mut rx) = make_router();
        app.oneshot(loopback_req(
            "POST",
            "/hook/claude/PreCompact",
            Body::from(r#"{"cwd": "relative/dir"}"#),
        ))
        .await
        .unwrap();
        let evt = rx.try_recv().expect("hook event must be queued");
        assert_eq!(evt.cwd, None, "relative cwd should be dropped");
    }

    #[tokio::test]
    async fn unknown_tool_is_rejected() {
        let (app, mut rx) = make_router();
        let response = app
            .oneshot(loopback_req(
                "POST",
                "/hook/evil-tool/SessionStart",
                Body::from(r#"{}"#),
            ))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(rx.try_recv().is_err(), "no event for rejected tool");
    }

    #[tokio::test]
    async fn parent_dir_in_transcript_path_is_rejected() {
        let (app, mut rx) = make_router();
        let response = app
            .oneshot(loopback_req(
                "POST",
                "/hook/claude/SessionStart",
                Body::from(r#"{"transcript_path": "/tmp/../etc/passwd"}"#),
            ))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(rx.try_recv().is_err(), "no event for traversal attempt");
    }

    #[tokio::test]
    async fn dns_rebinding_request_is_rejected() {
        let (app, mut rx) = make_router();
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/hook/claude/SessionStart")
                    .header("host", "evil.example.com")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert!(rx.try_recv().is_err(), "no event for rebinding attempt");
    }

    #[tokio::test]
    async fn body_size_limit_rejects_oversized_payload() {
        let (app, mut rx) = make_router();
        let huge = format!(
            r#"{{"session_id":"{}"}}"#,
            "a".repeat(MAX_BODY_BYTES + 1024)
        );
        let response = app
            .oneshot(loopback_req(
                "POST",
                "/hook/claude/SessionStart",
                Body::from(huge),
            ))
            .await
            .unwrap();
        // axum's `DefaultBodyLimit` rejection for `Json` is `413 Payload Too Large`.
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert!(rx.try_recv().is_err(), "no event for oversized payload");
    }

    #[tokio::test]
    async fn bind_test_returns_random_port() {
        let (a, addr_a) = bind_test().await.unwrap();
        let (b, addr_b) = bind_test().await.unwrap();
        assert_ne!(addr_a.port(), addr_b.port());
        assert!(addr_a.ip().is_loopback());
        assert!(addr_b.ip().is_loopback());
        drop(a);
        drop(b);
    }

    #[tokio::test]
    async fn port_in_use_returns_specific_error() {
        // If the real port is already taken (e.g. by a running instance), the
        // precondition cannot be set up here, so the test is skipped.
        let Ok(first) = bind_loopback().await else {
            return;
        };
        let second = bind_loopback().await;
        match second {
            Err(HookEndpointError::PortInUse(p)) => assert_eq!(p, LOOPBACK_PORT),
            Ok(_) => panic!("second bind unexpectedly succeeded"),
            Err(e) => panic!("expected PortInUse, got {e:?}"),
        }
        drop(first);
    }

    #[tokio::test]
    async fn unknown_route_returns_404() {
        let (app, _rx) = make_router();
        let response = app
            .oneshot(loopback_req("POST", "/nope", Body::empty()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn health_path_only_accepts_get() {
        let (app, _rx) = make_router();
        let response = app
            .oneshot(loopback_req("POST", "/health", Body::empty()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn parent_component_detection_matches_whole_segments_only() {
        use std::path::Path;
        assert!(path_has_parent_component(Path::new("/tmp/../etc")));
        assert!(path_has_parent_component(Path::new("../x")));
        assert!(!path_has_parent_component(Path::new("/tmp/..hidden/x")));
        assert!(!path_has_parent_component(Path::new("/tmp/abc.jsonl")));
    }

    #[test]
    fn loopback_constant_is_47823() {
        assert_eq!(LOOPBACK_PORT, 47823);
    }
}