use std::panic::AssertUnwindSafe;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use agent_client_protocol_schema::SessionId;
use arc_swap::ArcSwap;
use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::error::BoxError;
use crate::tool::SafetyClass;
pub mod builtin;
pub mod command;
pub mod prompt;
pub mod step;
/// Time limit applied to a step handler whose entry does not configure its own
/// `timeout`.
const DEFAULT_HANDLER_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Filter deciding which steps a hook handler is invoked for.
///
/// Every populated criterion must hold for a step to match; an unset criterion
/// (`None` or an empty list) accepts any step.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct HookMatcher {
    /// Exact tool name the step must carry.
    pub tool: Option<String>,
    /// Glob pattern (compiled with `globset`) the step's tool name must match.
    pub tool_glob: Option<String>,
    /// Safety classes the step must belong to; empty means "any class".
    pub safety: Vec<SafetyClass>,
}

impl HookMatcher {
    /// Returns `true` when a step with the given `tool` name and `safety` class
    /// satisfies every criterion configured on this matcher.
    ///
    /// A step with no tool name never satisfies a `tool` or `tool_glob`
    /// criterion, and a step with no safety class never satisfies a non-empty
    /// `safety` list. Criteria are checked in declaration order, and evaluation
    /// stops at the first failure, so the glob is only compiled when the exact
    /// name check passes.
    pub fn matches_step(&self, tool: Option<&str>, safety: Option<SafetyClass>) -> bool {
        let exact_name_ok = self
            .tool
            .as_deref()
            .is_none_or(|expected| tool == Some(expected));

        exact_name_ok
            && self
                .tool_glob
                .as_deref()
                .is_none_or(|pattern| tool.is_some_and(|name| tool_name_matches(pattern, name)))
            && (self.safety.is_empty()
                || safety.is_some_and(|class| self.safety.contains(&class)))
    }
}
/// Reports whether `name` matches the glob `pattern`.
///
/// An unparsable pattern is logged at warn level and treated as matching
/// nothing. The pattern is compiled on every call.
fn tool_name_matches(pattern: &str, name: &str) -> bool {
    globset::Glob::new(pattern)
        .map(|compiled| compiled.compile_matcher().is_match(name))
        .unwrap_or_else(|err| {
            tracing::warn!(%pattern, %err, "invalid tool_glob pattern; treating as no-match");
            false
        })
}
/// Per-dispatch context handed to hook engines and step handlers.
#[non_exhaustive]
pub struct HookCtx<'a> {
    /// Session the step belongs to (reported as the `session_id` envelope field).
    pub session_id: &'a SessionId,
    /// Working directory (reported as the `cwd` envelope field).
    pub cwd: &'a Path,
    /// Cancellation token; a clone of it is forwarded to each handler.
    pub cancel: CancellationToken,
}

impl<'a> HookCtx<'a> {
    /// Assembles a context from its parts.
    ///
    /// A constructor is required because `#[non_exhaustive]` forbids struct
    /// literal construction from other crates.
    pub fn new(session_id: &'a SessionId, cwd: &'a Path, cancel: CancellationToken) -> Self {
        Self {
            cancel,
            cwd,
            session_id,
        }
    }
}
/// Errors a hook handler can report back to the engine.
///
/// NOTE(review): `HandlerFailed` both renders its cause in `Display` (`{0}`) and
/// exposes it through `Error::source`, so reporters that walk the source chain
/// will print the cause twice. It is left as-is because the dispatch logs in this
/// file format errors with `Display` only.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum HookError {
/// The handler timed out.
#[error("hook handler timed out")]
Timeout,
/// The handler failed; the underlying error is kept as the source.
#[error("hook handler failed: {0}")]
HandlerFailed(#[source] BoxError),
/// The hook's configuration is invalid; the string describes the problem.
#[error("hook configuration error: {0}")]
Configuration(String),
}
/// A hook callback invoked for steps whose matcher accepts them.
///
/// The handler receives the step's JSON envelope and returns an optional
/// verdict. As consumed by [`DefaultHookEngine`], `Ok(Some(verdict))` is passed
/// to the step's `apply_verdict`, and `Ok(None)` leaves the step untouched.
/// Errors, timeouts, and panics raised while the returned future is polled are
/// logged, and that handler is skipped.
pub trait StepHandler: Send + Sync {
/// Handles one step.
///
/// `envelope` is the step's envelope with `session_id`, `cwd`, and
/// `hook_event` filled in when the step did not supply them.
fn handle_step<'a>(
&'a self,
envelope: &'a Value,
ctx: HookCtx<'a>,
) -> BoxFuture<'a, Result<Option<Value>, HookError>>;
}
/// Dispatches hook steps to handlers and reports how the caller should continue.
pub trait HookEngine: Send + Sync {
/// Runs the hooks for `step`.
///
/// The provided implementation ignores `step` and `ctx` and always yields
/// [`step::HookControl::Proceed`]; engines that run handlers override it.
fn dispatch<'a>(
&'a self,
_step: &'a mut dyn step::HookStep,
_ctx: HookCtx<'a>,
) -> BoxFuture<'a, step::HookControl> {
Box::pin(async { step::HookControl::Proceed })
}
}
/// Engine that runs no hooks: it uses the default `dispatch`, which always
/// yields `Proceed`.
#[derive(Debug, Default)]
pub struct NoopHookEngine;
impl HookEngine for NoopHookEngine {}
/// Registered step handlers, grouped by step event name.
#[derive(Default)]
pub struct HandlerTable {
/// Handlers keyed by the event name they subscribe to; `push_step` appends,
/// so each bucket is in push order.
pub step_buckets: std::collections::HashMap<&'static str, Vec<StepHandlerEntry>>,
}
/// A registered step handler together with its matching and timeout settings.
pub struct StepHandlerEntry {
    /// Name identifying this hook in log output.
    pub name: String,
    /// Filter deciding which steps invoke `handler`.
    pub matcher: HookMatcher,
    /// The handler invoked for matching steps.
    pub handler: Arc<dyn StepHandler>,
    /// Per-handler time limit; `None` falls back to the engine's default.
    pub timeout: Option<Duration>,
}

/// Name given to entries that were not assigned an explicit name.
pub const ANONYMOUS_HOOK_NAME: &str = "anonymous";

impl StepHandlerEntry {
    /// Creates an entry named [`ANONYMOUS_HOOK_NAME`] with no custom timeout.
    pub fn new(matcher: HookMatcher, handler: Arc<dyn StepHandler>) -> Self {
        Self {
            handler,
            matcher,
            name: String::from(ANONYMOUS_HOOK_NAME),
            timeout: None,
        }
    }

    /// Renames the entry when `name` is `Some`; `None` keeps the current name.
    pub fn with_name(self, name: Option<String>) -> Self {
        Self {
            name: name.unwrap_or(self.name),
            ..self
        }
    }

    /// Sets a per-handler timeout that overrides the engine default.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
impl HandlerTable {
    /// Returns a table with no handlers registered.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Returns the handlers registered for `event_name`, in push order, or an
    /// empty slice when none are registered.
    pub fn step_handlers(&self, event_name: &str) -> &[StepHandlerEntry] {
        match self.step_buckets.get(event_name) {
            Some(bucket) => bucket.as_slice(),
            None => &[],
        }
    }

    /// Appends `entry` to the bucket for `event_name`, creating the bucket on
    /// first use.
    pub fn push_step(&mut self, event_name: &'static str, entry: StepHandlerEntry) {
        let bucket = self.step_buckets.entry(event_name).or_insert_with(Vec::new);
        bucket.push(entry);
    }
}
/// Hook engine backed by a swappable [`HandlerTable`].
///
/// Each dispatch loads the table once when it starts, so a later
/// [`DefaultHookEngine::reload`] affects only subsequent dispatches.
pub struct DefaultHookEngine {
    table: ArcSwap<HandlerTable>,
}

impl DefaultHookEngine {
    /// Creates an engine with an empty handler table.
    pub fn new() -> Self {
        let initial = Arc::new(HandlerTable::empty());
        Self {
            table: ArcSwap::new(initial),
        }
    }

    /// Replaces the handler table; dispatches already in flight keep the table
    /// they loaded.
    pub fn reload(&self, table: HandlerTable) {
        let replacement = Arc::new(table);
        self.table.store(replacement);
    }

    /// Returns the currently installed handler table.
    #[doc(hidden)]
    pub fn snapshot(&self) -> Arc<HandlerTable> {
        self.table.load_full()
    }
}

impl Default for DefaultHookEngine {
    fn default() -> Self {
        DefaultHookEngine::new()
    }
}
impl HookEngine for DefaultHookEngine {
    /// Runs every step handler registered for the step's event whose matcher
    /// accepts it, in push order.
    ///
    /// Each matching handler gets a freshly built envelope, so it sees any
    /// changes earlier verdicts applied to the step. Each handler runs under its
    /// own timeout, or [`DEFAULT_HANDLER_TIMEOUT`] if it has none. Handler errors,
    /// panics (whether raised while building or while polling the handler's
    /// future), timeouts, and malformed verdicts are logged, and that handler is
    /// skipped. The first verdict that yields anything other than `Proceed` ends
    /// the dispatch and is returned; otherwise the result is `Proceed`.
    fn dispatch<'a>(
        &'a self,
        step: &'a mut dyn step::HookStep,
        ctx: HookCtx<'a>,
    ) -> BoxFuture<'a, step::HookControl> {
        // Load the table before building the future so a concurrent `reload`
        // cannot change the handler set partway through this dispatch.
        let table = self.table.load_full();
        Box::pin(async move {
            let entries = table.step_handlers(step.event_name());
            if entries.is_empty() {
                return step::HookControl::Proceed;
            }
            // Matcher inputs are read once, from the envelope as it was before
            // any handler ran.
            let envelope_json = with_common_header(step.to_envelope(), step.event_name(), &ctx);
            let tool = envelope_json.get("tool").and_then(Value::as_str);
            let safety = envelope_json
                .get("safety")
                .and_then(Value::as_str)
                .and_then(parse_safety);
            for entry in entries {
                if !entry.matcher.matches_step(tool, safety) {
                    continue;
                }
                // Rebuilt per handler so it reflects verdicts applied so far.
                let envelope = with_common_header(step.to_envelope(), step.event_name(), &ctx);
                let timeout = entry.timeout.unwrap_or(DEFAULT_HANDLER_TIMEOUT);
                let handler_ctx = HookCtx::new(ctx.session_id, ctx.cwd, ctx.cancel.clone());
                // `handle_step` is called *inside* the async block. Its synchronous
                // part (building the future) then runs on the first poll, which is
                // covered by `catch_unwind` and by the timeout. Calling it directly
                // as the argument to `AssertUnwindSafe(..)` would let a panic during
                // construction escape and unwind through `dispatch`.
                let fut = AssertUnwindSafe(async {
                    entry.handler.handle_step(&envelope, handler_ctx).await
                })
                .catch_unwind();
                let verdict = match tokio::time::timeout(timeout, fut).await {
                    Ok(Ok(Ok(v))) => v,
                    Ok(Ok(Err(err))) => {
                        tracing::warn!(event = %step.event_name(), hook = %entry.name, error = %err, "step hook handler error; skipped");
                        continue;
                    }
                    Ok(Err(panic)) => {
                        tracing::warn!(event = %step.event_name(), hook = %entry.name, panic = %panic_message(&panic), "step hook handler panicked; skipped");
                        continue;
                    }
                    Err(_elapsed) => {
                        tracing::warn!(event = %step.event_name(), hook = %entry.name, "step hook handler timed out; skipped");
                        continue;
                    }
                };
                // `None` means the handler had nothing to say about this step.
                let Some(verdict) = verdict else { continue };
                match step.apply_verdict(&verdict) {
                    Ok(step::HookControl::Proceed) => {}
                    // Any non-`Proceed` control short-circuits the remaining handlers.
                    Ok(control) => return control,
                    Err(err) => {
                        tracing::warn!(event = %step.event_name(), hook = %entry.name, error = %err, "step verdict malformed; skipped");
                    }
                }
            }
            step::HookControl::Proceed
        })
    }
}
/// Adds the fields every hook envelope carries (`session_id`, `cwd`,
/// `hook_event`) without overwriting values the step already supplied.
///
/// Non-object envelopes are returned unchanged. Each default value is only
/// computed when its key is absent.
fn with_common_header(envelope: Value, event_name: &str, ctx: &HookCtx<'_>) -> Value {
    match envelope {
        Value::Object(mut fields) => {
            // Listed in the same order the keys are inserted.
            let header: [(&str, &dyn Fn() -> String); 3] = [
                ("session_id", &|| ctx.session_id.0.to_string()),
                ("cwd", &|| ctx.cwd.to_string_lossy().into_owned()),
                ("hook_event", &|| event_name.to_string()),
            ];
            for (key, make_value) in header {
                fields
                    .entry(key)
                    .or_insert_with(|| Value::String(make_value()));
            }
            Value::Object(fields)
        }
        other => other,
    }
}
/// Maps a snake_case safety label from a step envelope to its [`SafetyClass`];
/// unrecognized labels yield `None`.
fn parse_safety(s: &str) -> Option<SafetyClass> {
    let class = match s {
        "read_only" => SafetyClass::ReadOnly,
        "mutating" => SafetyClass::Mutating,
        "destructive" => SafetyClass::Destructive,
        "network" => SafetyClass::Network,
        _ => return None,
    };
    Some(class)
}
/// Renders a panic payload for logging.
///
/// `panic!` payloads are usually a `&'static str` or a `String`; any other
/// payload type is reported as `<non-string panic payload>`.
///
/// The parameter is deliberately `&Box<..>` rather than `&dyn Any`. With
/// `&dyn Any`, a caller's `&boxed_payload` would unsize the `Box` itself into the
/// trait object, and every downcast below would fail.
fn panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    // Reach through the Box so the downcasts inspect the panic value itself.
    let inner: &(dyn std::any::Any + Send) = &**payload;
    inner
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| inner.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_string())
}
#[cfg(test)]
mod tests;