1use std::fmt;
4use std::future::Future;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8
9use anyhow::Context;
10use futures::future::BoxFuture;
11use halter_protocol::{HookHandlerType, PluginId};
12use serde::de::DeserializeOwned;
13use serde_json::Value;
14
15use crate::config::HookEventName;
16use crate::merge::{HookDecision, HookOutput, HookSpecificOutput, PermissionDecision};
17
18pub type HookCallbackFuture = BoxFuture<'static, anyhow::Result<HookResponse>>;
20pub type HookCallback = Arc<dyn Fn(HookInput) -> HookCallbackFuture + Send + Sync>;
22pub type HookFunctionFactory = Arc<dyn Fn() -> HookCallback + Send + Sync>;
24
/// Priority attached to a hook when it is registered.
///
/// NOTE(review): the variant names suggest ordering relative to hooks
/// contributed by plugins; the ordering logic itself is not in this file —
/// confirm against the dispatcher.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RegisteredHookPriority {
    /// Presumably runs ahead of plugin hooks — confirm against the dispatcher.
    BeforePlugins,
    /// Presumably runs behind plugin hooks — confirm against the dispatcher.
    /// This is the `Default` variant.
    #[default]
    AfterPlugins,
}
34
35#[derive(Debug, Clone)]
36pub struct HookInput {
38 pub event_name: HookEventName,
39 pub matcher_value: Option<String>,
40 pub payload: Value,
41}
42
43impl HookInput {
44 #[must_use]
46 pub fn field(&self, key: &str) -> Option<&Value> {
47 self.payload.get(key)
48 }
49
50 #[must_use]
52 pub fn string_field(&self, key: &str) -> Option<&str> {
53 self.field(key).and_then(Value::as_str)
54 }
55
56 #[must_use]
58 pub fn tool_name(&self) -> Option<&str> {
59 self.string_field("tool_name")
60 }
61
62 #[must_use]
64 pub fn tool_use_id(&self) -> Option<&str> {
65 self.string_field("tool_use_id")
66 }
67
68 pub fn decode<T: DeserializeOwned>(&self) -> anyhow::Result<T> {
70 serde_json::from_value(self.payload.clone()).context("failed to decode hook input")
71 }
72}
73
74#[derive(Debug, Clone, Default, PartialEq)]
75pub struct HookResponse {
77 output: HookOutput,
78}
79
80impl HookResponse {
81 #[must_use]
83 pub fn passthrough() -> Self {
84 Self::default()
85 }
86
87 #[must_use]
89 pub fn block(reason: impl Into<String>) -> Self {
90 Self {
91 output: HookOutput {
92 decision: Some(HookDecision::Block),
93 reason: Some(reason.into()),
94 ..HookOutput::default()
95 },
96 }
97 }
98
99 #[must_use]
101 pub fn stop(reason: impl Into<String>) -> Self {
102 Self {
103 output: HookOutput {
104 continue_execution: Some(false),
105 stop_reason: Some(reason.into()),
106 ..HookOutput::default()
107 },
108 }
109 }
110
111 #[must_use]
113 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
114 self.output.system_message = Some(message.into());
115 self
116 }
117
118 #[must_use]
120 pub fn with_additional_context(mut self, context: impl Into<String>) -> Self {
121 self.output
122 .hook_specific_output
123 .get_or_insert_with(HookSpecificOutput::default)
124 .additional_context = Some(context.into());
125 self
126 }
127
128 #[must_use]
130 pub fn with_updated_input(mut self, input: Value) -> Self {
131 self.output
132 .hook_specific_output
133 .get_or_insert_with(HookSpecificOutput::default)
134 .updated_input = Some(input);
135 self
136 }
137
138 #[must_use]
140 pub fn with_updated_output(mut self, output: Value) -> Self {
141 self.output
142 .hook_specific_output
143 .get_or_insert_with(HookSpecificOutput::default)
144 .updated_mcp_tool_output = Some(output);
145 self
146 }
147
148 #[must_use]
150 pub fn with_permission(
151 mut self,
152 decision: PermissionDecision,
153 reason: Option<impl Into<String>>,
154 ) -> Self {
155 let specific = self
156 .output
157 .hook_specific_output
158 .get_or_insert_with(HookSpecificOutput::default);
159 specific.permission_decision = Some(decision);
160 specific.permission_decision_reason = reason.map(Into::into);
161 self
162 }
163
164 #[must_use]
166 pub fn with_suppress_output(mut self, suppress_output: bool) -> Self {
167 self.output.suppress_output = Some(suppress_output);
168 self
169 }
170
171 #[must_use]
173 pub fn into_output(self) -> HookOutput {
174 self.output
175 }
176}
177
178impl From<HookOutput> for HookResponse {
179 fn from(output: HookOutput) -> Self {
180 Self { output }
181 }
182}
183
184pub trait IntoHookResponse {
186 fn into_hook_response(self) -> anyhow::Result<HookResponse>;
188}
189
190impl IntoHookResponse for HookResponse {
191 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
192 Ok(self)
193 }
194}
195
196impl IntoHookResponse for HookOutput {
197 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
198 Ok(HookResponse::from(self))
199 }
200}
201
202impl IntoHookResponse for anyhow::Result<HookResponse> {
203 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
204 self
205 }
206}
207
208#[derive(Clone)]
209pub enum HookKind {
211 Callback(HookCallback),
213 Function(HookFunctionFactory),
215}
216
217impl fmt::Debug for HookKind {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 match self {
220 Self::Callback(_) => f.write_str("Callback(..)"),
221 Self::Function(_) => f.write_str("Function(..)"),
222 }
223 }
224}
225
226impl HookKind {
227 #[must_use]
229 pub fn handler_type(&self) -> HookHandlerType {
230 match self {
231 Self::Callback(_) => HookHandlerType::Callback,
232 Self::Function(_) => HookHandlerType::Function,
233 }
234 }
235}
236
237#[derive(Debug, Clone)]
238pub struct Hook {
240 pub event: HookEventName,
241 pub matcher: Option<String>,
242 pub timeout: Duration,
243 pub status_message: Option<String>,
244 pub if_condition: Option<String>,
245 pub once: bool,
246 pub kind: HookKind,
247}
248
249impl Hook {
250 #[must_use]
252 pub fn callback<F, Fut, R>(event: HookEventName, callback: F) -> Self
253 where
254 F: Fn(HookInput) -> Fut + Send + Sync + 'static,
255 Fut: Future<Output = R> + Send + 'static,
256 R: IntoHookResponse + 'static,
257 {
258 Self {
259 event,
260 matcher: None,
261 timeout: Duration::from_secs(30),
262 status_message: None,
263 if_condition: None,
264 once: false,
265 kind: HookKind::Callback(Arc::new(move |input| {
266 let fut = callback(input);
267 Box::pin(async move { fut.await.into_hook_response() })
268 })),
269 }
270 }
271
272 #[must_use]
274 pub fn function<Factory, F, Fut, R>(event: HookEventName, factory: Factory) -> Self
275 where
276 Factory: Fn() -> F + Send + Sync + 'static,
277 F: Fn(HookInput) -> Fut + Send + Sync + 'static,
278 Fut: Future<Output = R> + Send + 'static,
279 R: IntoHookResponse + 'static,
280 {
281 Self {
282 event,
283 matcher: None,
284 timeout: Duration::from_secs(30),
285 status_message: None,
286 if_condition: None,
287 once: false,
288 kind: HookKind::Function(Arc::new(move || {
289 let callback = factory();
290 Arc::new(move |input| {
291 let fut = callback(input);
292 Box::pin(async move { fut.await.into_hook_response() })
293 })
294 })),
295 }
296 }
297
298 #[must_use]
300 pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
301 self.matcher = Some(matcher.into());
302 self
303 }
304
305 #[must_use]
307 pub fn with_timeout(mut self, timeout: Duration) -> Self {
308 self.timeout = timeout;
309 self
310 }
311
312 #[must_use]
314 pub fn with_status_message(mut self, status_message: impl Into<String>) -> Self {
315 self.status_message = Some(status_message.into());
316 self
317 }
318
319 #[must_use]
321 pub fn with_if_condition(mut self, if_condition: impl Into<String>) -> Self {
322 self.if_condition = Some(if_condition.into());
323 self
324 }
325
326 #[must_use]
328 pub fn with_once(mut self, once: bool) -> Self {
329 self.once = once;
330 self
331 }
332}
333
334#[derive(Debug, Clone)]
335pub struct RegisteredHook {
337 pub plugin_id: PluginId,
338 pub plugin_root: PathBuf,
339 pub priority: RegisteredHookPriority,
340 pub hook: Hook,
341}
342
343#[derive(Debug, Clone, Default)]
344pub struct RegisteredHooks {
346 hooks: Vec<RegisteredHook>,
347}
348
349impl RegisteredHooks {
350 #[must_use]
352 pub fn is_empty(&self) -> bool {
353 self.hooks.is_empty()
354 }
355
356 pub fn register(&mut self, plugin_id: PluginId, priority: RegisteredHookPriority, hook: Hook) {
358 self.hooks.push(RegisteredHook {
359 plugin_id,
360 plugin_root: PathBuf::new(),
361 priority,
362 hook,
363 });
364 }
365
366 pub fn validate(&self) -> anyhow::Result<()> {
368 for hook in &self.hooks {
369 if let Some(matcher) = hook
370 .hook
371 .matcher
372 .as_deref()
373 .map(str::trim)
374 .filter(|value| !value.is_empty())
375 {
376 crate::matcher::CompiledMatcher::compile_regex(matcher).with_context(|| {
377 format!(
378 "failed to compile sdk hook matcher for plugin '{}' event '{}'",
379 hook.plugin_id,
380 hook.hook.event.canonical_name()
381 )
382 })?;
383 }
384 }
385 Ok(())
386 }
387
388 pub fn instantiate(&self) -> anyhow::Result<crate::Hooks> {
390 self.validate()?;
391 crate::Hooks::from_registered(self.hooks.clone())
392 }
393}
394
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::{ConfiguredHandlerConfig, HookDispatchRequest, Hooks};

    /// A matcher that is not a valid regex (`[`) must make `validate` fail
    /// with the sdk-matcher context message.
    #[test]
    fn registered_hooks_validate_rejects_invalid_matcher() {
        let mut hooks = RegisteredHooks::default();
        hooks.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            Hook::callback(HookEventName::Stop, |_input| async {
                HookResponse::passthrough()
            })
            .with_matcher("["),
        );

        let error = hooks.validate().expect_err("invalid matcher should fail");
        assert!(
            error
                .to_string()
                .contains("failed to compile sdk hook matcher")
        );
    }

    /// Every builder method must land in the corresponding output field.
    #[test]
    fn hook_response_builders_populate_output() {
        let output = HookResponse::block("blocked")
            .with_system_message("system")
            .with_additional_context("context")
            .with_updated_input(json!({"command": "echo hi"}))
            .with_updated_output(json!({"ok": true}))
            .with_permission(PermissionDecision::Deny, Some("nope"))
            .with_suppress_output(true)
            .into_output();

        assert_eq!(output.decision, Some(HookDecision::Block));
        assert_eq!(output.reason.as_deref(), Some("blocked"));
        assert_eq!(output.system_message.as_deref(), Some("system"));
        assert_eq!(output.suppress_output, Some(true));

        let specific = output.hook_specific_output.expect("hook specific output");
        assert_eq!(specific.additional_context.as_deref(), Some("context"));
        assert_eq!(specific.updated_input, Some(json!({"command": "echo hi"})));
        assert_eq!(specific.updated_mcp_tool_output, Some(json!({"ok": true})));
        assert_eq!(specific.permission_decision, Some(PermissionDecision::Deny));
        assert_eq!(specific.permission_decision_reason.as_deref(), Some("nope"));
    }

    /// Each `instantiate` call must invoke the factory once, producing a
    /// distinct callback per instantiation.
    #[tokio::test]
    async fn hook_function_factory_creates_fresh_callback_per_instantiate() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let counter = factory_calls.clone();
        let hook = Hook::function(HookEventName::Stop, move || {
            // The instance number is fixed when the factory runs, so the
            // message reveals which factory call built the callback.
            let instance = counter.fetch_add(1, Ordering::SeqCst) + 1;
            move |_input| async move {
                Ok(HookResponse::passthrough()
                    .with_system_message(format!("factory-instance-{instance}")))
            }
        });

        let mut registered = RegisteredHooks::default();
        registered.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            hook,
        );

        let first_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;
        let second_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;

        assert_eq!(factory_calls.load(Ordering::SeqCst), 2);
        assert_eq!(first_output.as_deref(), Some("factory-instance-1"));
        assert_eq!(second_output.as_deref(), Some("factory-instance-2"));
    }

    /// Dispatches a `Stop` event, runs the first matched handler (which must
    /// be a function handler) and returns the system message it produced.
    async fn invoke_function_handler(hooks: &Hooks) -> Option<String> {
        let prepared = hooks.prepare(HookDispatchRequest {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
            fired_hook_ids: BTreeSet::new(),
        });
        let handler = prepared
            .matched_handlers()
            .first()
            .cloned()
            .expect("function handler");

        let ConfiguredHandlerConfig::Function(callback) = handler.config else {
            panic!("expected function handler");
        };
        let response = callback(HookInput {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
        })
        .await
        .expect("callback response");

        response.into_output().system_message
    }
}