Skip to main content

halter_hooks/
sdk.rs

1// pattern: Functional Core
2
3use std::fmt;
4use std::future::Future;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8
9use anyhow::Context;
10use futures::future::BoxFuture;
11use halter_protocol::{HookHandlerType, PluginId};
12use serde::de::DeserializeOwned;
13use serde_json::Value;
14
15use crate::config::HookEventName;
16use crate::merge::{HookDecision, HookOutput, HookSpecificOutput, PermissionDecision};
17
18/// Boxed future returned by an SDK hook callback.
19pub type HookCallbackFuture = BoxFuture<'static, anyhow::Result<HookResponse>>;
20/// Shared callback used by SDK-registered hooks.
21pub type HookCallback = Arc<dyn Fn(HookInput) -> HookCallbackFuture + Send + Sync>;
22/// Factory for hooks that need a fresh callback instance per dispatch.
23pub type HookFunctionFactory = Arc<dyn Fn() -> HookCallback + Send + Sync>;
24
/// Where an SDK-registered hook runs relative to hooks loaded from plugin files.
///
/// Defaults to [`RegisteredHookPriority::AfterPlugins`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RegisteredHookPriority {
    /// Runs ahead of every plugin-file hook.
    BeforePlugins,
    /// Runs once plugin-file hooks have run (the default).
    #[default]
    AfterPlugins,
}
34
35#[derive(Debug, Clone)]
36/// Input passed to an SDK hook callback.
37pub struct HookInput {
38    pub event_name: HookEventName,
39    pub matcher_value: Option<String>,
40    pub payload: Value,
41}
42
43impl HookInput {
44    /// Return a raw JSON payload field.
45    #[must_use]
46    pub fn field(&self, key: &str) -> Option<&Value> {
47        self.payload.get(key)
48    }
49
50    /// Return a string payload field.
51    #[must_use]
52    pub fn string_field(&self, key: &str) -> Option<&str> {
53        self.field(key).and_then(Value::as_str)
54    }
55
56    /// Tool name for tool-related hook events.
57    #[must_use]
58    pub fn tool_name(&self) -> Option<&str> {
59        self.string_field("tool_name")
60    }
61
62    /// Tool use id for tool-related hook events.
63    #[must_use]
64    pub fn tool_use_id(&self) -> Option<&str> {
65        self.string_field("tool_use_id")
66    }
67
68    /// Decode the entire hook payload into a typed struct.
69    pub fn decode<T: DeserializeOwned>(&self) -> anyhow::Result<T> {
70        serde_json::from_value(self.payload.clone()).context("failed to decode hook input")
71    }
72}
73
74#[derive(Debug, Clone, Default, PartialEq)]
75/// Builder-style response returned by SDK hooks.
76pub struct HookResponse {
77    output: HookOutput,
78}
79
80impl HookResponse {
81    /// Return no changes and allow execution to continue.
82    #[must_use]
83    pub fn passthrough() -> Self {
84        Self::default()
85    }
86
87    /// Block the current operation with a reason.
88    #[must_use]
89    pub fn block(reason: impl Into<String>) -> Self {
90        Self {
91            output: HookOutput {
92                decision: Some(HookDecision::Block),
93                reason: Some(reason.into()),
94                ..HookOutput::default()
95            },
96        }
97    }
98
99    /// Stop the current turn with a reason.
100    #[must_use]
101    pub fn stop(reason: impl Into<String>) -> Self {
102        Self {
103            output: HookOutput {
104                continue_execution: Some(false),
105                stop_reason: Some(reason.into()),
106                ..HookOutput::default()
107            },
108        }
109    }
110
111    /// Add a system message to the merged hook outcome.
112    #[must_use]
113    pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
114        self.output.system_message = Some(message.into());
115        self
116    }
117
118    /// Add context to the next model request.
119    #[must_use]
120    pub fn with_additional_context(mut self, context: impl Into<String>) -> Self {
121        self.output
122            .hook_specific_output
123            .get_or_insert_with(HookSpecificOutput::default)
124            .additional_context = Some(context.into());
125        self
126    }
127
128    /// Replace the tool input seen by downstream execution.
129    #[must_use]
130    pub fn with_updated_input(mut self, input: Value) -> Self {
131        self.output
132            .hook_specific_output
133            .get_or_insert_with(HookSpecificOutput::default)
134            .updated_input = Some(input);
135        self
136    }
137
138    /// Replace the tool output seen by downstream execution.
139    #[must_use]
140    pub fn with_updated_output(mut self, output: Value) -> Self {
141        self.output
142            .hook_specific_output
143            .get_or_insert_with(HookSpecificOutput::default)
144            .updated_mcp_tool_output = Some(output);
145        self
146    }
147
148    /// Set a permission decision for permission-request hooks.
149    #[must_use]
150    pub fn with_permission(
151        mut self,
152        decision: PermissionDecision,
153        reason: Option<impl Into<String>>,
154    ) -> Self {
155        let specific = self
156            .output
157            .hook_specific_output
158            .get_or_insert_with(HookSpecificOutput::default);
159        specific.permission_decision = Some(decision);
160        specific.permission_decision_reason = reason.map(Into::into);
161        self
162    }
163
164    /// Suppress user-visible tool output when supported by the caller.
165    #[must_use]
166    pub fn with_suppress_output(mut self, suppress_output: bool) -> Self {
167        self.output.suppress_output = Some(suppress_output);
168        self
169    }
170
171    /// Convert into the wire-compatible hook output.
172    #[must_use]
173    pub fn into_output(self) -> HookOutput {
174        self.output
175    }
176}
177
178impl From<HookOutput> for HookResponse {
179    fn from(output: HookOutput) -> Self {
180        Self { output }
181    }
182}
183
184/// Accepted return types for SDK hook callbacks.
185pub trait IntoHookResponse {
186    /// Convert a callback result into a [`HookResponse`].
187    fn into_hook_response(self) -> anyhow::Result<HookResponse>;
188}
189
190impl IntoHookResponse for HookResponse {
191    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
192        Ok(self)
193    }
194}
195
196impl IntoHookResponse for HookOutput {
197    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
198        Ok(HookResponse::from(self))
199    }
200}
201
202impl IntoHookResponse for anyhow::Result<HookResponse> {
203    fn into_hook_response(self) -> anyhow::Result<HookResponse> {
204        self
205    }
206}
207
208#[derive(Clone)]
209/// Executable SDK hook backend.
210pub enum HookKind {
211    /// Reuse the same callback for every matching hook dispatch.
212    Callback(HookCallback),
213    /// Build a callback per dispatch.
214    Function(HookFunctionFactory),
215}
216
217impl fmt::Debug for HookKind {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        match self {
220            Self::Callback(_) => f.write_str("Callback(..)"),
221            Self::Function(_) => f.write_str("Function(..)"),
222        }
223    }
224}
225
226impl HookKind {
227    /// Hook handler type reported in run summaries.
228    #[must_use]
229    pub fn handler_type(&self) -> HookHandlerType {
230        match self {
231            Self::Callback(_) => HookHandlerType::Callback,
232            Self::Function(_) => HookHandlerType::Function,
233        }
234    }
235}
236
237#[derive(Debug, Clone)]
238/// SDK hook definition registered into a builder.
239pub struct Hook {
240    pub event: HookEventName,
241    pub matcher: Option<String>,
242    pub timeout: Duration,
243    pub status_message: Option<String>,
244    pub if_condition: Option<String>,
245    pub once: bool,
246    pub kind: HookKind,
247}
248
249impl Hook {
250    /// Create a callback hook for an event.
251    #[must_use]
252    pub fn callback<F, Fut, R>(event: HookEventName, callback: F) -> Self
253    where
254        F: Fn(HookInput) -> Fut + Send + Sync + 'static,
255        Fut: Future<Output = R> + Send + 'static,
256        R: IntoHookResponse + 'static,
257    {
258        Self {
259            event,
260            matcher: None,
261            timeout: Duration::from_secs(30),
262            status_message: None,
263            if_condition: None,
264            once: false,
265            kind: HookKind::Callback(Arc::new(move |input| {
266                let fut = callback(input);
267                Box::pin(async move { fut.await.into_hook_response() })
268            })),
269        }
270    }
271
272    /// Create a hook that builds its callback per dispatch.
273    #[must_use]
274    pub fn function<Factory, F, Fut, R>(event: HookEventName, factory: Factory) -> Self
275    where
276        Factory: Fn() -> F + Send + Sync + 'static,
277        F: Fn(HookInput) -> Fut + Send + Sync + 'static,
278        Fut: Future<Output = R> + Send + 'static,
279        R: IntoHookResponse + 'static,
280    {
281        Self {
282            event,
283            matcher: None,
284            timeout: Duration::from_secs(30),
285            status_message: None,
286            if_condition: None,
287            once: false,
288            kind: HookKind::Function(Arc::new(move || {
289                let callback = factory();
290                Arc::new(move |input| {
291                    let fut = callback(input);
292                    Box::pin(async move { fut.await.into_hook_response() })
293                })
294            })),
295        }
296    }
297
298    /// Restrict the hook to matching event values.
299    #[must_use]
300    pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
301        self.matcher = Some(matcher.into());
302        self
303    }
304
305    /// Override the hook timeout.
306    #[must_use]
307    pub fn with_timeout(mut self, timeout: Duration) -> Self {
308        self.timeout = timeout;
309        self
310    }
311
312    /// Set the status message shown while the hook runs.
313    #[must_use]
314    pub fn with_status_message(mut self, status_message: impl Into<String>) -> Self {
315        self.status_message = Some(status_message.into());
316        self
317    }
318
319    /// Set a simple hook condition expression.
320    #[must_use]
321    pub fn with_if_condition(mut self, if_condition: impl Into<String>) -> Self {
322        self.if_condition = Some(if_condition.into());
323        self
324    }
325
326    /// Run this hook only once per session when true.
327    #[must_use]
328    pub fn with_once(mut self, once: bool) -> Self {
329        self.once = once;
330        self
331    }
332}
333
334#[derive(Debug, Clone)]
335/// SDK hook with plugin identity and priority metadata.
336pub struct RegisteredHook {
337    pub plugin_id: PluginId,
338    pub plugin_root: PathBuf,
339    pub priority: RegisteredHookPriority,
340    pub hook: Hook,
341}
342
343#[derive(Debug, Clone, Default)]
344/// Collection of SDK hooks registered before runtime construction.
345pub struct RegisteredHooks {
346    hooks: Vec<RegisteredHook>,
347}
348
349impl RegisteredHooks {
350    /// Whether no SDK hooks are registered.
351    #[must_use]
352    pub fn is_empty(&self) -> bool {
353        self.hooks.is_empty()
354    }
355
356    /// Register a hook for one plugin id.
357    pub fn register(&mut self, plugin_id: PluginId, priority: RegisteredHookPriority, hook: Hook) {
358        self.hooks.push(RegisteredHook {
359            plugin_id,
360            plugin_root: PathBuf::new(),
361            priority,
362            hook,
363        });
364    }
365
366    /// Validate SDK hook matchers before runtime construction.
367    pub fn validate(&self) -> anyhow::Result<()> {
368        for hook in &self.hooks {
369            if let Some(matcher) = hook
370                .hook
371                .matcher
372                .as_deref()
373                .map(str::trim)
374                .filter(|value| !value.is_empty())
375            {
376                crate::matcher::CompiledMatcher::compile_regex(matcher).with_context(|| {
377                    format!(
378                        "failed to compile sdk hook matcher for plugin '{}' event '{}'",
379                        hook.plugin_id,
380                        hook.hook.event.canonical_name()
381                    )
382                })?;
383            }
384        }
385        Ok(())
386    }
387
388    /// Convert registered SDK hooks into a runtime hook registry.
389    pub fn instantiate(&self) -> anyhow::Result<crate::Hooks> {
390        self.validate()?;
391        crate::Hooks::from_registered(self.hooks.clone())
392    }
393}
394
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::{ConfiguredHandlerConfig, HookDispatchRequest, Hooks};

    #[test]
    fn registered_hooks_validate_rejects_invalid_matcher() {
        // "[" is an unterminated character class, so regex compilation must fail.
        let bad_hook = Hook::callback(HookEventName::Stop, |_input| async {
            HookResponse::passthrough()
        })
        .with_matcher("[");

        let mut registry = RegisteredHooks::default();
        registry.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            bad_hook,
        );

        let message = registry
            .validate()
            .expect_err("invalid matcher should fail")
            .to_string();
        assert!(message.contains("failed to compile sdk hook matcher"));
    }

    #[test]
    fn hook_response_builders_populate_output() {
        let response = HookResponse::block("blocked")
            .with_system_message("system")
            .with_additional_context("context")
            .with_updated_input(json!({"command": "echo hi"}))
            .with_updated_output(json!({"ok": true}))
            .with_permission(PermissionDecision::Deny, Some("nope"))
            .with_suppress_output(true);
        let wire = response.into_output();

        assert_eq!(wire.decision, Some(HookDecision::Block));
        assert_eq!(wire.reason.as_deref(), Some("blocked"));
        assert_eq!(wire.system_message.as_deref(), Some("system"));
        assert_eq!(wire.suppress_output, Some(true));

        let specific = wire.hook_specific_output.expect("hook specific output");
        assert_eq!(specific.additional_context.as_deref(), Some("context"));
        assert_eq!(specific.updated_input, Some(json!({"command": "echo hi"})));
        assert_eq!(specific.updated_mcp_tool_output, Some(json!({"ok": true})));
        assert_eq!(specific.permission_decision, Some(PermissionDecision::Deny));
        assert_eq!(specific.permission_decision_reason.as_deref(), Some("nope"));
    }

    #[tokio::test]
    async fn hook_function_factory_creates_fresh_callback_per_instantiate() {
        let factory_calls = Arc::new(AtomicUsize::new(0));
        let hook = {
            let counter = Arc::clone(&factory_calls);
            Hook::function(HookEventName::Stop, move || {
                // Each factory call gets a 1-based serial number baked into its callback.
                let instance = counter.fetch_add(1, Ordering::SeqCst) + 1;
                move |_input| async move {
                    Ok(HookResponse::passthrough()
                        .with_system_message(format!("factory-instance-{instance}")))
                }
            })
        };

        let mut registry = RegisteredHooks::default();
        registry.register(
            PluginId::from("plugin"),
            RegisteredHookPriority::AfterPlugins,
            hook,
        );

        // Instantiate and invoke twice, in sequence, recording each callback's message.
        let mut messages = Vec::with_capacity(2);
        for _ in 0..2 {
            let runtime = registry.instantiate().expect("instantiate");
            messages.push(invoke_function_handler(&runtime).await);
        }

        assert_eq!(factory_calls.load(Ordering::SeqCst), 2);
        assert_eq!(messages[0].as_deref(), Some("factory-instance-1"));
        assert_eq!(messages[1].as_deref(), Some("factory-instance-2"));
    }

    /// Dispatch an empty `Stop` event, run the first matched function handler, and return
    /// the system message from its response.
    async fn invoke_function_handler(hooks: &Hooks) -> Option<String> {
        let request = HookDispatchRequest {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
            fired_hook_ids: BTreeSet::new(),
        };
        let prepared = hooks.prepare(request);
        let handler = prepared
            .matched_handlers()
            .first()
            .cloned()
            .expect("function handler");

        let ConfiguredHandlerConfig::Function(callback) = handler.config else {
            panic!("expected function handler");
        };
        let input = HookInput {
            event_name: HookEventName::Stop,
            matcher_value: None,
            payload: json!({}),
        };
        let response = callback(input).await.expect("callback response");

        response.into_output().system_message
    }
}