Skip to main content

starpod_hooks/
callback.rs

1//! Hook callback types — the function signatures and matcher configuration.
2
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use crate::error;
9use crate::input::HookInput;
10use crate::output::HookOutput;
11
12/// Type alias for hook callback functions.
13///
14/// A hook callback receives:
15/// - `input`: typed hook input data
16/// - `tool_use_id`: optional correlation ID for tool-related hooks
17/// - `cancellation`: a tokio CancellationToken for aborting
18///
19/// Returns a [`HookOutput`] that controls the agent's behavior.
20///
21/// # Example
22///
23/// ```
24/// use starpod_hooks::{hook_fn, HookInput, HookOutput};
25///
26/// let callback = hook_fn(|input, _tool_use_id, _cancel| async move {
27///     println!("Hook fired for: {}", input.event_name());
28///     Ok(HookOutput::default())
29/// });
30/// ```
31pub type HookCallback = Arc<
32    dyn Fn(
33            HookInput,
34            Option<String>,
35            tokio_util::sync::CancellationToken,
36        ) -> Pin<Box<dyn Future<Output = error::Result<HookOutput>> + Send>>
37        + Send
38        + Sync,
39>;
40
41/// Helper to create a [`HookCallback`] from an async function.
42///
43/// # Example
44///
45/// ```
46/// use starpod_hooks::{hook_fn, HookOutput};
47///
48/// let my_hook = hook_fn(|_input, _id, _cancel| async move {
49///     Ok(HookOutput::default())
50/// });
51/// ```
52pub fn hook_fn<F, Fut>(f: F) -> HookCallback
53where
54    F: Fn(HookInput, Option<String>, tokio_util::sync::CancellationToken) -> Fut
55        + Send
56        + Sync
57        + 'static,
58    Fut: Future<Output = error::Result<HookOutput>> + Send + 'static,
59{
60    Arc::new(move |input, tool_use_id, cancel| Box::pin(f(input, tool_use_id, cancel)))
61}
62
63/// Hook configuration with optional regex matcher pattern.
64///
65/// Groups one or more callbacks with a regex filter. The matcher pattern
66/// is tested against the hook's filter field (typically the tool name for
67/// tool-related hooks). If no matcher is set, the hooks run for every
68/// event of their type.
69///
70/// # Example
71///
72/// ```
73/// use starpod_hooks::{hook_fn, HookCallbackMatcher, HookOutput};
74///
75/// let matcher = HookCallbackMatcher::new(vec![
76///     hook_fn(|_input, _id, _cancel| async move {
77///         Ok(HookOutput::default())
78///     }),
79/// ])
80/// .with_matcher("Bash|Write")
81/// .with_timeout(30);
82///
83/// assert!(matcher.matches("Bash").unwrap());
84/// assert!(!matcher.matches("Read").unwrap());
85/// ```
86#[derive(Clone)]
87pub struct HookCallbackMatcher {
88    /// Human-readable name for this hook group (used by circuit breaker and logging).
89    pub name: Option<String>,
90
91    /// Regex pattern to match against the event's filter field (e.g., tool name).
92    /// If None, the hook runs for every event of its type.
93    pub matcher: Option<String>,
94
95    /// Array of callback functions to execute when the pattern matches.
96    pub hooks: Vec<HookCallback>,
97
98    /// Timeout in seconds for all hooks in this matcher.
99    pub timeout: Option<u64>,
100
101    /// Eligibility requirements (binaries, env vars, OS).
102    pub requires: Option<crate::eligibility::HookRequirements>,
103}
104
105impl HookCallbackMatcher {
106    pub fn new(hooks: Vec<HookCallback>) -> Self {
107        Self {
108            name: None,
109            matcher: None,
110            hooks,
111            timeout: None,
112            requires: None,
113        }
114    }
115
116    pub fn with_name(mut self, name: impl Into<String>) -> Self {
117        self.name = Some(name.into());
118        self
119    }
120
121    pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
122        self.matcher = Some(matcher.into());
123        self
124    }
125
126    pub fn with_timeout(mut self, timeout: u64) -> Self {
127        self.timeout = Some(timeout);
128        self
129    }
130
131    pub fn with_requirements(mut self, requires: crate::eligibility::HookRequirements) -> Self {
132        self.requires = Some(requires);
133        self
134    }
135
136    /// Check if this matcher applies to the given target string.
137    ///
138    /// Returns `Ok(true)` if no matcher is set (matches everything) or
139    /// if the regex pattern matches the target.
140    pub fn matches(&self, target: &str) -> error::Result<bool> {
141        match &self.matcher {
142            None => Ok(true),
143            Some(pattern) => {
144                let re = regex::Regex::new(pattern)?;
145                Ok(re.is_match(target))
146            }
147        }
148    }
149}
150
151impl fmt::Debug for HookCallbackMatcher {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        f.debug_struct("HookCallbackMatcher")
154            .field("name", &self.name)
155            .field("matcher", &self.matcher)
156            .field("hooks_count", &self.hooks.len())
157            .field("timeout", &self.timeout)
158            .field("requires", &self.requires)
159            .finish()
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a callback that does nothing and returns the default output.
    fn noop_hook() -> HookCallback {
        hook_fn(|_input, _id, _cancel| async move { Ok(HookOutput::default()) })
    }

    #[test]
    fn matcher_no_pattern_matches_everything() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]);
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("anything").unwrap());
        assert!(m.matches("").unwrap());
    }

    #[test]
    fn matcher_empty_pattern_matches_everything() {
        // An empty regex matches at every position, including in "".
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("");
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("").unwrap());
    }

    #[test]
    fn matcher_regex_filters() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("Bash|Write");
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("Write").unwrap());
        assert!(!m.matches("Read").unwrap());
        assert!(!m.matches("Edit").unwrap());
    }

    #[test]
    fn matcher_regex_is_unanchored_and_case_sensitive() {
        // Characterizes current behavior: the pattern may match anywhere in
        // the target, and matching is case-sensitive.
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("Bash");
        assert!(m.matches("BashOutput").unwrap());
        assert!(!m.matches("bash").unwrap());
    }

    #[test]
    fn matcher_anchored_pattern_requires_exact_match() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("^(?:Bash|Write)$");
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("Write").unwrap());
        assert!(!m.matches("BashOutput").unwrap());
    }

    #[test]
    fn matcher_invalid_regex_returns_error() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("[invalid");
        assert!(m.matches("test").is_err());
    }

    #[test]
    fn matcher_without_hooks_still_matches() {
        let m = HookCallbackMatcher::new(Vec::new());
        assert!(m.hooks.is_empty());
        assert!(m.matches("Bash").unwrap());
        assert!(format!("{:?}", m).contains("hooks_count: 0"));
    }

    #[test]
    fn matcher_with_timeout() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_timeout(30);
        assert_eq!(m.timeout, Some(30));
    }

    #[test]
    fn matcher_with_name() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_name("my-hook");
        assert_eq!(m.name.as_deref(), Some("my-hook"));
    }

    #[test]
    fn matcher_with_requirements() {
        use crate::eligibility::HookRequirements;
        let req = HookRequirements {
            bins: vec!["sh".into()],
            ..Default::default()
        };
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_requirements(req);
        assert!(m.requires.is_some());
        assert_eq!(m.requires.unwrap().bins, vec!["sh"]);
    }

    #[test]
    fn matcher_builder_chaining() {
        use crate::eligibility::HookRequirements;
        let m = HookCallbackMatcher::new(vec![noop_hook()])
            .with_name("lint")
            .with_matcher("Write|Edit")
            .with_timeout(10)
            .with_requirements(HookRequirements {
                os: vec!["macos".into()],
                ..Default::default()
            });

        assert_eq!(m.name.as_deref(), Some("lint"));
        assert_eq!(m.matcher.as_deref(), Some("Write|Edit"));
        assert_eq!(m.timeout, Some(10));
        assert!(m.requires.is_some());
    }

    #[test]
    fn matcher_debug_shows_hook_count() {
        let m = HookCallbackMatcher::new(vec![noop_hook(), noop_hook()]).with_matcher("test");
        let debug = format!("{:?}", m);
        assert!(debug.contains("hooks_count: 2"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn matcher_debug_includes_name_and_requires() {
        use crate::eligibility::HookRequirements;
        let m = HookCallbackMatcher::new(vec![noop_hook()])
            .with_name("my-hook")
            .with_requirements(HookRequirements::default());
        let debug = format!("{:?}", m);
        assert!(
            debug.contains("my-hook"),
            "debug should contain name: {}",
            debug
        );
        assert!(
            debug.contains("requires"),
            "debug should contain requires: {}",
            debug
        );
    }

    #[tokio::test]
    async fn hook_fn_creates_callable_callback() {
        let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();

        let hook = hook_fn(move |_input, _id, _cancel| {
            let called = called_clone.clone();
            async move {
                called.store(true, std::sync::atomic::Ordering::SeqCst);
                Ok(HookOutput::default())
            }
        });

        let input = HookInput::UserPromptSubmit {
            base: crate::input::BaseHookInput {
                session_id: "test".into(),
                transcript_path: String::new(),
                cwd: "/tmp".into(),
                permission_mode: None,
                agent_id: None,
                agent_type: None,
            },
            prompt: "hello".into(),
        };

        let cancel = tokio_util::sync::CancellationToken::new();
        let result = hook(input, None, cancel).await;
        assert!(result.is_ok());
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }
}