1use std::fmt;
4use std::future::Future;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8
9use anyhow::Context;
10use futures::future::BoxFuture;
11use halter_protocol::{HookHandlerType, PluginId};
12use serde::de::DeserializeOwned;
13use serde_json::Value;
14
15use crate::config::HookEventName;
16use crate::merge::{HookDecision, HookOutput, HookSpecificOutput, PermissionDecision};
17
18pub type HookCallbackFuture = BoxFuture<'static, anyhow::Result<HookResponse>>;
19pub type HookCallback = Arc<dyn Fn(HookInput) -> HookCallbackFuture + Send + Sync>;
20pub type HookFunctionFactory = Arc<dyn Fn() -> HookCallback + Send + Sync>;
21
/// Ordering hint for an SDK-registered hook relative to plugin-provided hooks.
///
/// NOTE(review): the ordering itself is applied outside this file (presumably by
/// `Hooks::from_registered`) — confirm there.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum RegisteredHookPriority {
    /// Placed ahead of plugin hooks.
    BeforePlugins,
    /// Placed after plugin hooks. This is the default.
    #[default]
    AfterPlugins,
}
28
29#[derive(Debug, Clone)]
30pub struct HookInput {
31 pub event_name: HookEventName,
32 pub matcher_value: Option<String>,
33 pub payload: Value,
34}
35
36impl HookInput {
37 #[must_use]
38 pub fn field(&self, key: &str) -> Option<&Value> {
39 self.payload.get(key)
40 }
41
42 #[must_use]
43 pub fn string_field(&self, key: &str) -> Option<&str> {
44 self.field(key).and_then(Value::as_str)
45 }
46
47 #[must_use]
48 pub fn tool_name(&self) -> Option<&str> {
49 self.string_field("tool_name")
50 }
51
52 #[must_use]
53 pub fn tool_use_id(&self) -> Option<&str> {
54 self.string_field("tool_use_id")
55 }
56
57 pub fn decode<T: DeserializeOwned>(&self) -> anyhow::Result<T> {
58 serde_json::from_value(self.payload.clone()).context("failed to decode hook input")
59 }
60}
61
62#[derive(Debug, Clone, Default, PartialEq)]
63pub struct HookResponse {
64 output: HookOutput,
65}
66
67impl HookResponse {
68 #[must_use]
69 pub fn passthrough() -> Self {
70 Self::default()
71 }
72
73 #[must_use]
74 pub fn block(reason: impl Into<String>) -> Self {
75 Self {
76 output: HookOutput {
77 decision: Some(HookDecision::Block),
78 reason: Some(reason.into()),
79 ..HookOutput::default()
80 },
81 }
82 }
83
84 #[must_use]
85 pub fn stop(reason: impl Into<String>) -> Self {
86 Self {
87 output: HookOutput {
88 continue_execution: Some(false),
89 stop_reason: Some(reason.into()),
90 ..HookOutput::default()
91 },
92 }
93 }
94
95 #[must_use]
96 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
97 self.output.system_message = Some(message.into());
98 self
99 }
100
101 #[must_use]
102 pub fn with_additional_context(mut self, context: impl Into<String>) -> Self {
103 self.output
104 .hook_specific_output
105 .get_or_insert_with(HookSpecificOutput::default)
106 .additional_context = Some(context.into());
107 self
108 }
109
110 #[must_use]
111 pub fn with_updated_input(mut self, input: Value) -> Self {
112 self.output
113 .hook_specific_output
114 .get_or_insert_with(HookSpecificOutput::default)
115 .updated_input = Some(input);
116 self
117 }
118
119 #[must_use]
120 pub fn with_updated_output(mut self, output: Value) -> Self {
121 self.output
122 .hook_specific_output
123 .get_or_insert_with(HookSpecificOutput::default)
124 .updated_mcp_tool_output = Some(output);
125 self
126 }
127
128 #[must_use]
129 pub fn with_permission(
130 mut self,
131 decision: PermissionDecision,
132 reason: Option<impl Into<String>>,
133 ) -> Self {
134 let specific = self
135 .output
136 .hook_specific_output
137 .get_or_insert_with(HookSpecificOutput::default);
138 specific.permission_decision = Some(decision);
139 specific.permission_decision_reason = reason.map(Into::into);
140 self
141 }
142
143 #[must_use]
144 pub fn with_suppress_output(mut self, suppress_output: bool) -> Self {
145 self.output.suppress_output = Some(suppress_output);
146 self
147 }
148
149 #[must_use]
150 pub fn into_output(self) -> HookOutput {
151 self.output
152 }
153}
154
155impl From<HookOutput> for HookResponse {
156 fn from(output: HookOutput) -> Self {
157 Self { output }
158 }
159}
160
161pub trait IntoHookResponse {
162 fn into_hook_response(self) -> anyhow::Result<HookResponse>;
163}
164
165impl IntoHookResponse for HookResponse {
166 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
167 Ok(self)
168 }
169}
170
171impl IntoHookResponse for HookOutput {
172 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
173 Ok(HookResponse::from(self))
174 }
175}
176
177impl IntoHookResponse for anyhow::Result<HookResponse> {
178 fn into_hook_response(self) -> anyhow::Result<HookResponse> {
179 self
180 }
181}
182
183#[derive(Clone)]
184pub enum HookKind {
185 Callback(HookCallback),
186 Function(HookFunctionFactory),
187}
188
189impl fmt::Debug for HookKind {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 match self {
192 Self::Callback(_) => f.write_str("Callback(..)"),
193 Self::Function(_) => f.write_str("Function(..)"),
194 }
195 }
196}
197
198impl HookKind {
199 #[must_use]
200 pub fn handler_type(&self) -> HookHandlerType {
201 match self {
202 Self::Callback(_) => HookHandlerType::Callback,
203 Self::Function(_) => HookHandlerType::Function,
204 }
205 }
206}
207
208#[derive(Debug, Clone)]
209pub struct Hook {
210 pub event: HookEventName,
211 pub matcher: Option<String>,
212 pub timeout: Duration,
213 pub status_message: Option<String>,
214 pub if_condition: Option<String>,
215 pub once: bool,
216 pub kind: HookKind,
217}
218
219impl Hook {
220 #[must_use]
221 pub fn callback<F, Fut, R>(event: HookEventName, callback: F) -> Self
222 where
223 F: Fn(HookInput) -> Fut + Send + Sync + 'static,
224 Fut: Future<Output = R> + Send + 'static,
225 R: IntoHookResponse + 'static,
226 {
227 Self {
228 event,
229 matcher: None,
230 timeout: Duration::from_secs(30),
231 status_message: None,
232 if_condition: None,
233 once: false,
234 kind: HookKind::Callback(Arc::new(move |input| {
235 let fut = callback(input);
236 Box::pin(async move { fut.await.into_hook_response() })
237 })),
238 }
239 }
240
241 #[must_use]
242 pub fn function<Factory, F, Fut, R>(event: HookEventName, factory: Factory) -> Self
243 where
244 Factory: Fn() -> F + Send + Sync + 'static,
245 F: Fn(HookInput) -> Fut + Send + Sync + 'static,
246 Fut: Future<Output = R> + Send + 'static,
247 R: IntoHookResponse + 'static,
248 {
249 Self {
250 event,
251 matcher: None,
252 timeout: Duration::from_secs(30),
253 status_message: None,
254 if_condition: None,
255 once: false,
256 kind: HookKind::Function(Arc::new(move || {
257 let callback = factory();
258 Arc::new(move |input| {
259 let fut = callback(input);
260 Box::pin(async move { fut.await.into_hook_response() })
261 })
262 })),
263 }
264 }
265
266 #[must_use]
267 pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
268 self.matcher = Some(matcher.into());
269 self
270 }
271
272 #[must_use]
273 pub fn with_timeout(mut self, timeout: Duration) -> Self {
274 self.timeout = timeout;
275 self
276 }
277
278 #[must_use]
279 pub fn with_status_message(mut self, status_message: impl Into<String>) -> Self {
280 self.status_message = Some(status_message.into());
281 self
282 }
283
284 #[must_use]
285 pub fn with_if_condition(mut self, if_condition: impl Into<String>) -> Self {
286 self.if_condition = Some(if_condition.into());
287 self
288 }
289
290 #[must_use]
291 pub fn with_once(mut self, once: bool) -> Self {
292 self.once = once;
293 self
294 }
295}
296
297#[derive(Debug, Clone)]
298pub struct RegisteredHook {
299 pub plugin_id: PluginId,
300 pub plugin_root: PathBuf,
301 pub priority: RegisteredHookPriority,
302 pub hook: Hook,
303}
304
305#[derive(Debug, Clone, Default)]
306pub struct RegisteredHooks {
307 hooks: Vec<RegisteredHook>,
308}
309
310impl RegisteredHooks {
311 #[must_use]
312 pub fn is_empty(&self) -> bool {
313 self.hooks.is_empty()
314 }
315
316 pub fn register(&mut self, plugin_id: PluginId, priority: RegisteredHookPriority, hook: Hook) {
317 self.hooks.push(RegisteredHook {
318 plugin_id,
319 plugin_root: PathBuf::new(),
320 priority,
321 hook,
322 });
323 }
324
325 pub fn validate(&self) -> anyhow::Result<()> {
326 for hook in &self.hooks {
327 if let Some(matcher) = hook
328 .hook
329 .matcher
330 .as_deref()
331 .map(str::trim)
332 .filter(|value| !value.is_empty())
333 {
334 crate::matcher::CompiledMatcher::compile_regex(matcher).with_context(|| {
335 format!(
336 "failed to compile sdk hook matcher for plugin '{}' event '{}'",
337 hook.plugin_id,
338 hook.hook.event.canonical_name()
339 )
340 })?;
341 }
342 }
343 Ok(())
344 }
345
346 pub fn instantiate(&self) -> anyhow::Result<crate::Hooks> {
347 self.validate()?;
348 crate::Hooks::from_registered(self.hooks.clone())
349 }
350}
351
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::{ConfiguredHandlerConfig, HookDispatchRequest, Hooks};

    #[test]
    fn registered_hooks_validate_rejects_invalid_matcher() {
        let mut hooks = RegisteredHooks::default();
        hooks.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            Hook::callback(HookEventName::Stop, |_input| async {
                HookResponse::passthrough()
            })
            .with_matcher("["),
        );

        let error = hooks.validate().expect_err("invalid matcher should fail");
        assert!(
            error
                .to_string()
                .contains("failed to compile sdk hook matcher")
        );
    }

    #[test]
    fn hook_input_accessors_and_decode_read_payload() {
        let input = HookInput {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({"tool_name": "Bash", "tool_use_id": "toolu_1", "count": 3}),
        };

        assert_eq!(input.tool_name(), Some("Bash"));
        assert_eq!(input.tool_use_id(), Some("toolu_1"));
        assert_eq!(input.string_field("count"), None);
        assert_eq!(input.field("count"), Some(&json!(3)));
        assert_eq!(input.field("missing"), None);

        let decoded: BTreeMap<String, Value> = input.decode().expect("decode payload");
        assert_eq!(decoded.get("tool_name"), Some(&json!("Bash")));

        let error = input
            .decode::<Vec<String>>()
            .expect_err("an object is not a sequence");
        assert!(error.to_string().contains("failed to decode hook input"));
    }

    #[test]
    fn hook_response_stop_and_passthrough() {
        let output = HookResponse::stop("halt").into_output();
        assert_eq!(output.continue_execution, Some(false));
        assert_eq!(output.stop_reason.as_deref(), Some("halt"));
        assert_eq!(output.decision, None);

        assert_eq!(
            HookResponse::passthrough().into_output(),
            HookOutput::default()
        );
    }

    #[test]
    fn hook_response_builders_populate_output() {
        let output = HookResponse::block("blocked")
            .with_system_message("system")
            .with_additional_context("context")
            .with_updated_input(json!({"command": "echo hi"}))
            .with_updated_output(json!({"ok": true}))
            .with_permission(PermissionDecision::Deny, Some("nope"))
            .with_suppress_output(true)
            .into_output();

        assert_eq!(output.decision, Some(HookDecision::Block));
        assert_eq!(output.reason.as_deref(), Some("blocked"));
        assert_eq!(output.system_message.as_deref(), Some("system"));
        assert_eq!(output.suppress_output, Some(true));

        let specific = output.hook_specific_output.expect("hook specific output");
        assert_eq!(specific.additional_context.as_deref(), Some("context"));
        assert_eq!(specific.updated_input, Some(json!({"command": "echo hi"})));
        assert_eq!(specific.updated_mcp_tool_output, Some(json!({"ok": true})));
        assert_eq!(specific.permission_decision, Some(PermissionDecision::Deny));
        assert_eq!(specific.permission_decision_reason.as_deref(), Some("nope"));
    }

    #[tokio::test]
    async fn hook_function_factory_creates_fresh_callback_per_instantiate() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let counter = factory_calls.clone();
        let hook = Hook::function(HookEventName::Stop, move || {
            let instance = counter.fetch_add(1, Ordering::SeqCst) + 1;
            move |_input| async move {
                Ok(HookResponse::passthrough()
                    .with_system_message(format!("factory-instance-{instance}")))
            }
        });

        let mut registered = RegisteredHooks::default();
        registered.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            hook,
        );

        let first_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;
        let second_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;

        assert_eq!(factory_calls.load(Ordering::SeqCst), 2);
        assert_eq!(first_output.as_deref(), Some("factory-instance-1"));
        assert_eq!(second_output.as_deref(), Some("factory-instance-2"));
    }

    /// Dispatches a `Stop` event through `hooks`, runs the first matched function handler, and
    /// returns the `system_message` it produced.
    async fn invoke_function_handler(hooks: &Hooks) -> Option<String> {
        let prepared = hooks.prepare(HookDispatchRequest {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
            fired_hook_ids: BTreeSet::new(),
        });
        let handler = prepared
            .matched_handlers()
            .first()
            .cloned()
            .expect("function handler");

        let ConfiguredHandlerConfig::Function(callback) = handler.config else {
            panic!("expected function handler");
        };
        let response = callback(HookInput {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
        })
        .await
        .expect("callback response");

        response.into_output().system_message
    }
}