Skip to main content

halter_hooks/
sdk.rs

1// pattern: Functional Core
2
3use std::fmt;
4use std::future::Future;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8
9use anyhow::Context;
10use futures::future::BoxFuture;
11use halter_protocol::{HookHandlerType, PluginId};
12use serde::de::DeserializeOwned;
13use serde_json::Value;
14
15use crate::config::HookEventName;
16use crate::merge::{HookDecision, HookOutput, HookSpecificOutput, PermissionDecision};
17
18pub type HookCallbackFuture = BoxFuture<'static, anyhow::Result<HookResponse>>;
19pub type HookCallback = Arc<dyn Fn(HookInput) -> HookCallbackFuture + Send + Sync>;
20pub type HookFunctionFactory = Arc<dyn Fn() -> HookCallback + Send + Sync>;
21
/// Placement of an SDK-registered hook relative to plugin-provided hooks.
///
/// Defaults to [`RegisteredHookPriority::AfterPlugins`]. The actual ordering is applied
/// outside this module.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RegisteredHookPriority {
    /// Intended to be ordered ahead of plugin hooks.
    BeforePlugins,
    /// Intended to be ordered after plugin hooks; the default.
    #[default]
    AfterPlugins,
}
28
29#[derive(Debug, Clone)]
30pub struct HookInput {
31    pub event_name: HookEventName,
32    pub matcher_value: Option<String>,
33    pub payload: Value,
34}
35
36impl HookInput {
37    #[must_use]
38    pub fn field(&self, key: &str) -> Option<&Value> {
39        self.payload.get(key)
40    }
41
42    #[must_use]
43    pub fn string_field(&self, key: &str) -> Option<&str> {
44        self.field(key).and_then(Value::as_str)
45    }
46
47    #[must_use]
48    pub fn tool_name(&self) -> Option<&str> {
49        self.string_field("tool_name")
50    }
51
52    #[must_use]
53    pub fn tool_use_id(&self) -> Option<&str> {
54        self.string_field("tool_use_id")
55    }
56
57    pub fn decode<T: DeserializeOwned>(&self) -> anyhow::Result<T> {
58        serde_json::from_value(self.payload.clone()).context("failed to decode hook input")
59    }
60}
61
62#[derive(Debug, Clone, Default, PartialEq)]
63pub struct HookResponse {
64    output: HookOutput,
65}
66
67impl HookResponse {
68    #[must_use]
69    pub fn passthrough() -> Self {
70        Self::default()
71    }
72
73    #[must_use]
74    pub fn block(reason: impl Into<String>) -> Self {
75        Self {
76            output: HookOutput {
77                decision: Some(HookDecision::Block),
78                reason: Some(reason.into()),
79                ..HookOutput::default()
80            },
81        }
82    }
83
84    #[must_use]
85    pub fn stop(reason: impl Into<String>) -> Self {
86        Self {
87            output: HookOutput {
88                continue_execution: Some(false),
89                stop_reason: Some(reason.into()),
90                ..HookOutput::default()
91            },
92        }
93    }
94
95    #[must_use]
96    pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
97        self.output.system_message = Some(message.into());
98        self
99    }
100
101    #[must_use]
102    pub fn with_additional_context(mut self, context: impl Into<String>) -> Self {
103        self.output
104            .hook_specific_output
105            .get_or_insert_with(HookSpecificOutput::default)
106            .additional_context = Some(context.into());
107        self
108    }
109
110    #[must_use]
111    pub fn with_updated_input(mut self, input: Value) -> Self {
112        self.output
113            .hook_specific_output
114            .get_or_insert_with(HookSpecificOutput::default)
115            .updated_input = Some(input);
116        self
117    }
118
119    #[must_use]
120    pub fn with_updated_output(mut self, output: Value) -> Self {
121        self.output
122            .hook_specific_output
123            .get_or_insert_with(HookSpecificOutput::default)
124            .updated_mcp_tool_output = Some(output);
125        self
126    }
127
128    #[must_use]
129    pub fn with_permission(
130        mut self,
131        decision: PermissionDecision,
132        reason: Option<impl Into<String>>,
133    ) -> Self {
134        let specific = self
135            .output
136            .hook_specific_output
137            .get_or_insert_with(HookSpecificOutput::default);
138        specific.permission_decision = Some(decision);
139        specific.permission_decision_reason = reason.map(Into::into);
140        self
141    }
142
143    #[must_use]
144    pub fn with_suppress_output(mut self, suppress_output: bool) -> Self {
145        self.output.suppress_output = Some(suppress_output);
146        self
147    }
148
149    #[must_use]
150    pub fn into_output(self) -> HookOutput {
151        self.output
152    }
153}
154
155impl From<HookOutput> for HookResponse {
156    fn from(output: HookOutput) -> Self {
157        Self { output }
158    }
159}
160
161pub trait IntoHookResponse {
162    fn into_hook_response(self) -> anyhow::Result<HookResponse>;
163}
164
165impl IntoHookResponse for HookResponse {
166    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
167        Ok(self)
168    }
169}
170
171impl IntoHookResponse for HookOutput {
172    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
173        Ok(HookResponse::from(self))
174    }
175}
176
177impl IntoHookResponse for anyhow::Result<HookResponse> {
178    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
179        self
180    }
181}
182
183#[derive(Clone)]
184pub enum HookKind {
185    Callback(HookCallback),
186    Function(HookFunctionFactory),
187}
188
189impl fmt::Debug for HookKind {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        match self {
192            Self::Callback(_) => f.write_str("Callback(..)"),
193            Self::Function(_) => f.write_str("Function(..)"),
194        }
195    }
196}
197
198impl HookKind {
199    #[must_use]
200    pub fn handler_type(&self) -> HookHandlerType {
201        match self {
202            Self::Callback(_) => HookHandlerType::Callback,
203            Self::Function(_) => HookHandlerType::Function,
204        }
205    }
206}
207
208#[derive(Debug, Clone)]
209pub struct Hook {
210    pub event: HookEventName,
211    pub matcher: Option<String>,
212    pub timeout: Duration,
213    pub status_message: Option<String>,
214    pub if_condition: Option<String>,
215    pub once: bool,
216    pub kind: HookKind,
217}
218
219impl Hook {
220    #[must_use]
221    pub fn callback<F, Fut, R>(event: HookEventName, callback: F) -> Self
222    where
223        F: Fn(HookInput) -> Fut + Send + Sync + 'static,
224        Fut: Future<Output = R> + Send + 'static,
225        R: IntoHookResponse + 'static,
226    {
227        Self {
228            event,
229            matcher: None,
230            timeout: Duration::from_secs(30),
231            status_message: None,
232            if_condition: None,
233            once: false,
234            kind: HookKind::Callback(Arc::new(move |input| {
235                let fut = callback(input);
236                Box::pin(async move { fut.await.into_hook_response() })
237            })),
238        }
239    }
240
241    #[must_use]
242    pub fn function<Factory, F, Fut, R>(event: HookEventName, factory: Factory) -> Self
243    where
244        Factory: Fn() -> F + Send + Sync + 'static,
245        F: Fn(HookInput) -> Fut + Send + Sync + 'static,
246        Fut: Future<Output = R> + Send + 'static,
247        R: IntoHookResponse + 'static,
248    {
249        Self {
250            event,
251            matcher: None,
252            timeout: Duration::from_secs(30),
253            status_message: None,
254            if_condition: None,
255            once: false,
256            kind: HookKind::Function(Arc::new(move || {
257                let callback = factory();
258                Arc::new(move |input| {
259                    let fut = callback(input);
260                    Box::pin(async move { fut.await.into_hook_response() })
261                })
262            })),
263        }
264    }
265
266    #[must_use]
267    pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
268        self.matcher = Some(matcher.into());
269        self
270    }
271
272    #[must_use]
273    pub fn with_timeout(mut self, timeout: Duration) -> Self {
274        self.timeout = timeout;
275        self
276    }
277
278    #[must_use]
279    pub fn with_status_message(mut self, status_message: impl Into<String>) -> Self {
280        self.status_message = Some(status_message.into());
281        self
282    }
283
284    #[must_use]
285    pub fn with_if_condition(mut self, if_condition: impl Into<String>) -> Self {
286        self.if_condition = Some(if_condition.into());
287        self
288    }
289
290    #[must_use]
291    pub fn with_once(mut self, once: bool) -> Self {
292        self.once = once;
293        self
294    }
295}
296
297#[derive(Debug, Clone)]
298pub struct RegisteredHook {
299    pub plugin_id: PluginId,
300    pub plugin_root: PathBuf,
301    pub priority: RegisteredHookPriority,
302    pub hook: Hook,
303}
304
305#[derive(Debug, Clone, Default)]
306pub struct RegisteredHooks {
307    hooks: Vec<RegisteredHook>,
308}
309
310impl RegisteredHooks {
311    #[must_use]
312    pub fn is_empty(&self) -> bool {
313        self.hooks.is_empty()
314    }
315
316    pub fn register(&mut self, plugin_id: PluginId, priority: RegisteredHookPriority, hook: Hook) {
317        self.hooks.push(RegisteredHook {
318            plugin_id,
319            plugin_root: PathBuf::new(),
320            priority,
321            hook,
322        });
323    }
324
325    pub fn validate(&self) -> anyhow::Result<()> {
326        for hook in &self.hooks {
327            if let Some(matcher) = hook
328                .hook
329                .matcher
330                .as_deref()
331                .map(str::trim)
332                .filter(|value| !value.is_empty())
333            {
334                crate::matcher::CompiledMatcher::compile_regex(matcher).with_context(|| {
335                    format!(
336                        "failed to compile sdk hook matcher for plugin '{}' event '{}'",
337                        hook.plugin_id,
338                        hook.hook.event.canonical_name()
339                    )
340                })?;
341            }
342        }
343        Ok(())
344    }
345
346    pub fn instantiate(&self) -> anyhow::Result<crate::Hooks> {
347        self.validate()?;
348        crate::Hooks::from_registered(self.hooks.clone())
349    }
350}
351
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::{ConfiguredHandlerConfig, HookDispatchRequest, Hooks};

    /// Builds a `Stop` event input carrying `payload`.
    fn stop_input(payload: Value) -> HookInput {
        HookInput {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload,
        }
    }

    #[test]
    fn registered_hooks_validate_rejects_invalid_matcher() {
        let mut hooks = RegisteredHooks::default();
        hooks.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            Hook::callback(HookEventName::Stop, |_input| async {
                HookResponse::passthrough()
            })
            .with_matcher("["),
        );

        let error = hooks.validate().expect_err("invalid matcher should fail");
        assert!(
            error
                .to_string()
                .contains("failed to compile sdk hook matcher")
        );
    }

    #[test]
    fn hook_response_builders_populate_output() {
        let output = HookResponse::block("blocked")
            .with_system_message("system")
            .with_additional_context("context")
            .with_updated_input(json!({"command": "echo hi"}))
            .with_updated_output(json!({"ok": true}))
            .with_permission(PermissionDecision::Deny, Some("nope"))
            .with_suppress_output(true)
            .into_output();

        assert_eq!(output.decision, Some(HookDecision::Block));
        assert_eq!(output.reason.as_deref(), Some("blocked"));
        assert_eq!(output.system_message.as_deref(), Some("system"));
        assert_eq!(output.suppress_output, Some(true));

        let specific = output.hook_specific_output.expect("hook specific output");
        assert_eq!(specific.additional_context.as_deref(), Some("context"));
        assert_eq!(specific.updated_input, Some(json!({"command": "echo hi"})));
        assert_eq!(specific.updated_mcp_tool_output, Some(json!({"ok": true})));
        assert_eq!(specific.permission_decision, Some(PermissionDecision::Deny));
        assert_eq!(specific.permission_decision_reason.as_deref(), Some("nope"));
    }

    #[test]
    fn hook_response_passthrough_and_stop() {
        assert_eq!(
            HookResponse::passthrough().into_output(),
            HookOutput::default()
        );

        let output = HookResponse::stop("halt").into_output();
        assert_eq!(output.continue_execution, Some(false));
        assert_eq!(output.stop_reason.as_deref(), Some("halt"));
        assert_eq!(output.decision, None);
        assert_eq!(output.reason, None);
    }

    #[test]
    fn hook_input_accessors_and_decode() {
        let input = stop_input(json!({
            "tool_name": "Bash",
            "tool_use_id": "toolu_1",
            "attempt": 2
        }));

        assert_eq!(input.tool_name(), Some("Bash"));
        assert_eq!(input.tool_use_id(), Some("toolu_1"));
        assert_eq!(input.field("attempt"), Some(&json!(2)));
        // Non-string and missing fields both read as `None` through `string_field`.
        assert_eq!(input.string_field("attempt"), None);
        assert_eq!(input.string_field("missing"), None);

        let decoded: BTreeMap<String, Value> = input.decode().expect("decode payload");
        assert_eq!(decoded.get("tool_name"), Some(&json!("Bash")));

        let error = input
            .decode::<u64>()
            .expect_err("an object payload is not a number");
        assert!(error.to_string().contains("failed to decode hook input"));
    }

    #[tokio::test]
    async fn hook_callback_uses_defaults_and_accepts_hook_output() {
        let hook = Hook::callback(HookEventName::Stop, |_input| async {
            HookOutput {
                suppress_output: Some(true),
                ..HookOutput::default()
            }
        });

        assert_eq!(hook.timeout, Duration::from_secs(30));
        assert!(hook.matcher.is_none());
        assert!(!hook.once);
        assert!(matches!(
            hook.kind.handler_type(),
            HookHandlerType::Callback
        ));

        let HookKind::Callback(callback) = hook.kind else {
            panic!("expected callback hook");
        };
        let response = callback(stop_input(json!({})))
            .await
            .expect("callback response");
        assert_eq!(response.into_output().suppress_output, Some(true));
    }

    #[tokio::test]
    async fn hook_function_factory_creates_fresh_callback_per_instantiate() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let counter = factory_calls.clone();
        let hook = Hook::function(HookEventName::Stop, move || {
            let instance = counter.fetch_add(1, Ordering::SeqCst) + 1;
            move |_input| async move {
                Ok(HookResponse::passthrough()
                    .with_system_message(format!("factory-instance-{instance}")))
            }
        });

        let mut registered = RegisteredHooks::default();
        registered.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            hook,
        );

        let first_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;
        let second_output =
            invoke_function_handler(&registered.instantiate().expect("instantiate")).await;

        assert_eq!(factory_calls.load(Ordering::SeqCst), 2);
        assert_eq!(first_output.as_deref(), Some("factory-instance-1"));
        assert_eq!(second_output.as_deref(), Some("factory-instance-2"));
    }

    /// Dispatches an empty `Stop` event through `hooks`, runs the first matched function
    /// handler, and returns the `system_message` it produced.
    async fn invoke_function_handler(hooks: &Hooks) -> Option<String> {
        let prepared = hooks.prepare(HookDispatchRequest {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
            fired_hook_ids: BTreeSet::new(),
        });
        let handler = prepared
            .matched_handlers()
            .first()
            .cloned()
            .expect("function handler");

        let ConfiguredHandlerConfig::Function(callback) = handler.config else {
            panic!("expected function handler");
        };
        let response = callback(stop_input(json!({})))
            .await
            .expect("callback response");

        response.into_output().system_message
    }
}