starpod_hooks/
callback.rs1use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use crate::error;
9use crate::input::HookInput;
10use crate::output::HookOutput;
11
12pub type HookCallback = Arc<
32 dyn Fn(
33 HookInput,
34 Option<String>,
35 tokio_util::sync::CancellationToken,
36 ) -> Pin<Box<dyn Future<Output = error::Result<HookOutput>> + Send>>
37 + Send
38 + Sync,
39>;
40
41pub fn hook_fn<F, Fut>(f: F) -> HookCallback
53where
54 F: Fn(HookInput, Option<String>, tokio_util::sync::CancellationToken) -> Fut
55 + Send
56 + Sync
57 + 'static,
58 Fut: Future<Output = error::Result<HookOutput>> + Send + 'static,
59{
60 Arc::new(move |input, tool_use_id, cancel| Box::pin(f(input, tool_use_id, cancel)))
61}
62
63#[derive(Clone)]
87pub struct HookCallbackMatcher {
88 pub name: Option<String>,
90
91 pub matcher: Option<String>,
94
95 pub hooks: Vec<HookCallback>,
97
98 pub timeout: Option<u64>,
100
101 pub requires: Option<crate::eligibility::HookRequirements>,
103}
104
105impl HookCallbackMatcher {
106 pub fn new(hooks: Vec<HookCallback>) -> Self {
107 Self {
108 name: None,
109 matcher: None,
110 hooks,
111 timeout: None,
112 requires: None,
113 }
114 }
115
116 pub fn with_name(mut self, name: impl Into<String>) -> Self {
117 self.name = Some(name.into());
118 self
119 }
120
121 pub fn with_matcher(mut self, matcher: impl Into<String>) -> Self {
122 self.matcher = Some(matcher.into());
123 self
124 }
125
126 pub fn with_timeout(mut self, timeout: u64) -> Self {
127 self.timeout = Some(timeout);
128 self
129 }
130
131 pub fn with_requirements(mut self, requires: crate::eligibility::HookRequirements) -> Self {
132 self.requires = Some(requires);
133 self
134 }
135
136 pub fn matches(&self, target: &str) -> error::Result<bool> {
141 match &self.matcher {
142 None => Ok(true),
143 Some(pattern) => {
144 let re = regex::Regex::new(pattern)?;
145 Ok(re.is_match(target))
146 }
147 }
148 }
149}
150
151impl fmt::Debug for HookCallbackMatcher {
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 f.debug_struct("HookCallbackMatcher")
154 .field("name", &self.name)
155 .field("matcher", &self.matcher)
156 .field("hooks_count", &self.hooks.len())
157 .field("timeout", &self.timeout)
158 .field("requires", &self.requires)
159 .finish()
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// A callback that ignores its arguments and succeeds with the default output.
    fn noop_hook() -> HookCallback {
        hook_fn(|_input, _id, _cancel| async move { Ok(HookOutput::default()) })
    }

    /// A minimal `UserPromptSubmit` input for invoking callbacks directly.
    fn sample_input() -> HookInput {
        HookInput::UserPromptSubmit {
            base: crate::input::BaseHookInput {
                session_id: "test".into(),
                transcript_path: String::new(),
                cwd: "/tmp".into(),
                permission_mode: None,
                agent_id: None,
                agent_type: None,
            },
            prompt: "hello".into(),
        }
    }

    #[test]
    fn new_leaves_optional_settings_unset() {
        let m = HookCallbackMatcher::new(vec![noop_hook(), noop_hook()]);
        assert!(m.name.is_none());
        assert!(m.matcher.is_none());
        assert_eq!(m.hooks.len(), 2);
        assert!(m.timeout.is_none());
        assert!(m.requires.is_none());
    }

    #[test]
    fn matcher_no_pattern_matches_everything() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]);
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("anything").unwrap());
        assert!(m.matches("").unwrap());
    }

    #[test]
    fn matcher_empty_pattern_matches_everything() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("");
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("").unwrap());
    }

    #[test]
    fn matcher_regex_filters() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("Bash|Write");
        assert!(m.matches("Bash").unwrap());
        assert!(m.matches("Write").unwrap());
        assert!(!m.matches("Read").unwrap());
        assert!(!m.matches("Edit").unwrap());
    }

    #[test]
    fn matcher_invalid_regex_returns_error() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_matcher("[invalid");
        assert!(m.matches("test").is_err());
    }

    #[test]
    fn matcher_with_timeout() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_timeout(30);
        assert_eq!(m.timeout, Some(30));
    }

    #[test]
    fn matcher_with_name() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_name("my-hook");
        assert_eq!(m.name.as_deref(), Some("my-hook"));
    }

    #[test]
    fn matcher_with_requirements() {
        use crate::eligibility::HookRequirements;
        let req = HookRequirements {
            bins: vec!["sh".into()],
            ..Default::default()
        };
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_requirements(req);
        let requires = m
            .requires
            .expect("with_requirements should set `requires`");
        assert_eq!(requires.bins, vec!["sh"]);
    }

    #[test]
    fn matcher_builder_chaining() {
        use crate::eligibility::HookRequirements;
        let m = HookCallbackMatcher::new(vec![noop_hook()])
            .with_name("lint")
            .with_matcher("Write|Edit")
            .with_timeout(10)
            .with_requirements(HookRequirements {
                os: vec!["macos".into()],
                ..Default::default()
            });

        assert_eq!(m.name.as_deref(), Some("lint"));
        assert_eq!(m.matcher.as_deref(), Some("Write|Edit"));
        assert_eq!(m.timeout, Some(10));
        let requires = m
            .requires
            .expect("with_requirements should set `requires`");
        assert_eq!(requires.os, vec!["macos"]);
    }

    #[test]
    fn clone_shares_hook_callbacks() {
        let m = HookCallbackMatcher::new(vec![noop_hook()]).with_name("shared");
        let copy = m.clone();
        assert_eq!(copy.name.as_deref(), Some("shared"));
        assert_eq!(copy.hooks.len(), 1);
        // `HookCallback` is an `Arc`, so cloning bumps the refcount instead of
        // duplicating the closure.
        assert_eq!(Arc::strong_count(&m.hooks[0]), 2);
    }

    #[test]
    fn matcher_debug_shows_hook_count() {
        let m = HookCallbackMatcher::new(vec![noop_hook(), noop_hook()]).with_matcher("test");
        let debug = format!("{:?}", m);
        assert!(debug.contains("hooks_count: 2"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn matcher_debug_includes_name_and_requires() {
        use crate::eligibility::HookRequirements;
        let m = HookCallbackMatcher::new(vec![noop_hook()])
            .with_name("my-hook")
            .with_requirements(HookRequirements::default());
        let debug = format!("{:?}", m);
        assert!(
            debug.contains("my-hook"),
            "debug should contain name: {}",
            debug
        );
        assert!(
            debug.contains("requires"),
            "debug should contain requires: {}",
            debug
        );
    }

    #[tokio::test]
    async fn hook_fn_creates_callable_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = Arc::clone(&called);

        let hook = hook_fn(move |_input, _id, _cancel| {
            let called = Arc::clone(&called_clone);
            async move {
                called.store(true, Ordering::SeqCst);
                Ok(HookOutput::default())
            }
        });

        let cancel = tokio_util::sync::CancellationToken::new();
        let result = hook(sample_input(), None, cancel).await;
        assert!(result.is_ok());
        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn hook_fn_callback_can_be_invoked_repeatedly() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let hook = hook_fn(move |_input, _id, _cancel| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(HookOutput::default())
            }
        });

        for _ in 0..3 {
            let cancel = tokio_util::sync::CancellationToken::new();
            assert!(hook(sample_input(), None, cancel).await.is_ok());
        }
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn hook_fn_forwards_tool_use_id_and_cancellation() {
        let seen = Arc::new(std::sync::Mutex::new(None));
        let sink = Arc::clone(&seen);

        let hook = hook_fn(move |_input, id, cancel| {
            let sink = Arc::clone(&sink);
            async move {
                *sink.lock().unwrap() = Some((id, cancel.is_cancelled()));
                Ok(HookOutput::default())
            }
        });

        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = hook(sample_input(), Some("tool-1".to_string()), cancel).await;
        assert!(result.is_ok());
        assert_eq!(
            *seen.lock().unwrap(),
            Some((Some("tool-1".to_string()), true))
        );
    }
}