Skip to main content

astrid_hooks/
hook.rs

1//! Hook definitions and types.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use uuid::Uuid;
7
8pub use crate::hook_event::HookEvent;
9
10/// Handler implementation for a hook.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case", tag = "type")]
13pub enum HookHandler {
14    /// Execute a shell command.
15    Command {
16        /// The command to execute.
17        command: String,
18        /// Arguments to pass to the command.
19        #[serde(default)]
20        args: Vec<String>,
21        /// Environment variables to set.
22        #[serde(default)]
23        env: HashMap<String, String>,
24        /// Working directory for the command.
25        #[serde(default)]
26        working_dir: Option<String>,
27    },
28    /// Call an HTTP webhook.
29    Http {
30        /// The URL to call.
31        url: String,
32        /// HTTP method (GET, POST, etc.).
33        #[serde(default = "default_http_method")]
34        method: String,
35        /// Headers to include.
36        #[serde(default)]
37        headers: HashMap<String, String>,
38        /// Request body template.
39        #[serde(default)]
40        body_template: Option<String>,
41    },
42    /// Execute a WASM module via Extism.
43    Wasm {
44        /// Path to the WASM module.
45        module_path: String,
46        /// Function to call in the module.
47        #[serde(default = "default_wasm_function")]
48        function: String,
49    },
50    /// Invoke an LLM-based agent handler (stubbed).
51    Agent {
52        /// Agent prompt template.
53        prompt_template: String,
54        /// Model to use.
55        #[serde(default)]
56        model: Option<String>,
57        /// Maximum tokens for response.
58        #[serde(default)]
59        max_tokens: Option<u32>,
60    },
61}
62
/// Serde default for the `method` field of [`HookHandler::Http`].
fn default_http_method() -> String {
    String::from("POST")
}
66
/// Serde default for the `function` field of [`HookHandler::Wasm`].
fn default_wasm_function() -> String {
    "handle".to_owned()
}
70
71#[allow(dead_code)]
72impl HookHandler {
73    /// Create a new command handler.
74    #[must_use]
75    pub(crate) fn command(command: impl Into<String>) -> Self {
76        Self::Command {
77            command: command.into(),
78            args: Vec::new(),
79            env: HashMap::new(),
80            working_dir: None,
81        }
82    }
83
84    /// Create a new HTTP webhook handler.
85    #[must_use]
86    pub(crate) fn http(url: impl Into<String>) -> Self {
87        Self::Http {
88            url: url.into(),
89            method: "POST".to_string(),
90            headers: HashMap::new(),
91            body_template: None,
92        }
93    }
94
95    /// Create a new WASM handler.
96    #[must_use]
97    pub(crate) fn wasm(module_path: impl Into<String>) -> Self {
98        Self::Wasm {
99            module_path: module_path.into(),
100            function: "handle".to_string(),
101        }
102    }
103
104    /// Create a new agent handler (stubbed).
105    #[must_use]
106    pub(crate) fn agent(prompt_template: impl Into<String>) -> Self {
107        Self::Agent {
108            prompt_template: prompt_template.into(),
109            model: None,
110            max_tokens: None,
111        }
112    }
113
114    /// Check if this handler is stubbed (not yet implemented).
115    #[must_use]
116    pub(crate) fn is_stubbed(&self) -> bool {
117        matches!(self, Self::Agent { .. })
118    }
119}
120
121/// Action to take when a hook fails.
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
123#[serde(rename_all = "snake_case")]
124pub enum FailAction {
125    /// Log a warning and continue.
126    #[default]
127    Warn,
128    /// Block the operation that triggered the hook.
129    Block,
130    /// Silently ignore the failure.
131    Ignore,
132}
133
134impl fmt::Display for FailAction {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        match self {
137            Self::Warn => write!(f, "warn"),
138            Self::Block => write!(f, "block"),
139            Self::Ignore => write!(f, "ignore"),
140        }
141    }
142}
143
144/// A hook definition.
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct Hook {
147    /// Unique identifier for this hook.
148    pub id: Uuid,
149    /// Human-readable name.
150    #[serde(default)]
151    pub name: Option<String>,
152    /// Description of what this hook does.
153    #[serde(default)]
154    pub description: Option<String>,
155    /// Event that triggers this hook.
156    pub event: HookEvent,
157    /// Optional matcher pattern (glob or regex).
158    #[serde(default)]
159    pub matcher: Option<HookMatcher>,
160    /// Handler implementation.
161    pub handler: HookHandler,
162    /// Timeout in seconds.
163    #[serde(default = "default_timeout")]
164    pub timeout_secs: u64,
165    /// Action to take on failure.
166    #[serde(default)]
167    pub fail_action: FailAction,
168    /// Run asynchronously (don't wait for completion).
169    #[serde(default)]
170    pub async_mode: bool,
171    /// Whether the hook is enabled.
172    #[serde(default = "default_enabled")]
173    pub enabled: bool,
174    /// Priority (lower runs first).
175    #[serde(default = "default_priority")]
176    pub priority: i32,
177}
178
/// Serde default for [`Hook::timeout_secs`]: thirty seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}

/// Serde default for [`Hook::enabled`]: hooks are on unless configured otherwise.
fn default_enabled() -> bool {
    true
}

/// Serde default for [`Hook::priority`] (lower values run first).
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
190
191#[allow(dead_code)]
192impl Hook {
193    /// Create a new hook for the given event.
194    #[must_use]
195    pub(crate) fn new(event: HookEvent) -> Self {
196        Self {
197            id: Uuid::new_v4(),
198            name: None,
199            description: None,
200            event,
201            matcher: None,
202            handler: HookHandler::command("echo"),
203            timeout_secs: 30,
204            fail_action: FailAction::Warn,
205            async_mode: false,
206            enabled: true,
207            priority: 100,
208        }
209    }
210
211    /// Set the hook's name.
212    #[must_use]
213    pub(crate) fn with_name(mut self, name: impl Into<String>) -> Self {
214        self.name = Some(name.into());
215        self
216    }
217
218    /// Set the hook's description.
219    #[must_use]
220    pub(crate) fn with_description(mut self, description: impl Into<String>) -> Self {
221        self.description = Some(description.into());
222        self
223    }
224
225    /// Set the handler for this hook.
226    #[must_use]
227    pub(crate) fn with_handler(mut self, handler: HookHandler) -> Self {
228        self.handler = handler;
229        self
230    }
231
232    /// Set a matcher pattern.
233    #[must_use]
234    pub(crate) fn with_matcher(mut self, matcher: HookMatcher) -> Self {
235        self.matcher = Some(matcher);
236        self
237    }
238
239    /// Set the timeout in seconds.
240    #[must_use]
241    pub(crate) fn with_timeout(mut self, secs: u64) -> Self {
242        self.timeout_secs = secs;
243        self
244    }
245
246    /// Set the failure action.
247    #[must_use]
248    pub(crate) fn with_fail_action(mut self, action: FailAction) -> Self {
249        self.fail_action = action;
250        self
251    }
252
253    /// Enable async mode.
254    #[must_use]
255    pub(crate) fn async_mode(mut self) -> Self {
256        self.async_mode = true;
257        self
258    }
259
260    /// Disable the hook.
261    #[must_use]
262    pub(crate) fn disabled(mut self) -> Self {
263        self.enabled = false;
264        self
265    }
266
267    /// Set the priority.
268    #[must_use]
269    pub(crate) fn with_priority(mut self, priority: i32) -> Self {
270        self.priority = priority;
271        self
272    }
273}
274
275/// Matcher for filtering when a hook should run.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277#[serde(rename_all = "snake_case", tag = "type")]
278pub enum HookMatcher {
279    /// Match using a glob pattern.
280    Glob {
281        /// The glob pattern.
282        pattern: String,
283    },
284    /// Match using a regex pattern.
285    Regex {
286        /// The regex pattern.
287        pattern: String,
288    },
289    /// Match specific tool names.
290    ToolNames {
291        /// List of tool names to match.
292        names: Vec<String>,
293    },
294    /// Match specific server names.
295    ServerNames {
296        /// List of server names to match.
297        names: Vec<String>,
298    },
299}
300
301#[allow(dead_code)]
302impl HookMatcher {
303    /// Create a glob matcher.
304    #[must_use]
305    pub(crate) fn glob(pattern: impl Into<String>) -> Self {
306        Self::Glob {
307            pattern: pattern.into(),
308        }
309    }
310
311    /// Create a regex matcher.
312    #[must_use]
313    pub(crate) fn regex(pattern: impl Into<String>) -> Self {
314        Self::Regex {
315            pattern: pattern.into(),
316        }
317    }
318
319    /// Create a tool names matcher.
320    #[must_use]
321    pub(crate) fn tools(names: Vec<String>) -> Self {
322        Self::ToolNames { names }
323    }
324
325    /// Create a server names matcher.
326    #[must_use]
327    pub(crate) fn servers(names: Vec<String>) -> Self {
328        Self::ServerNames { names }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
        assert_eq!(HookEvent::PreToolCall.to_string(), "pre_tool_call");
    }

    #[test]
    fn test_hook_creation() {
        let hook = Hook::new(HookEvent::PreToolCall)
            .with_name("log-tool-calls")
            .with_handler(HookHandler::command("echo"))
            .with_timeout(60);

        assert_eq!(hook.event, HookEvent::PreToolCall);
        assert_eq!(hook.name, Some("log-tool-calls".to_string()));
        assert_eq!(hook.timeout_secs, 60);
        assert!(hook.enabled);
    }

    #[test]
    fn test_hook_new_defaults_match_serde_defaults() {
        // `Hook::new` must agree with the serde defaults so that constructed and
        // deserialized hooks behave the same.
        let hook = Hook::new(HookEvent::SessionStart);
        assert_eq!(hook.timeout_secs, default_timeout());
        assert_eq!(hook.priority, default_priority());
        assert_eq!(hook.enabled, default_enabled());
        assert_eq!(hook.fail_action, FailAction::default());
        assert!(!hook.async_mode);
        assert!(hook.name.is_none());
        assert!(hook.description.is_none());
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_builder_setters() {
        let hook = Hook::new(HookEvent::PreToolCall)
            .with_description("logs every tool call")
            .with_matcher(HookMatcher::glob("fs_*"))
            .with_fail_action(FailAction::Block)
            .with_priority(-5)
            .async_mode()
            .disabled();

        assert_eq!(hook.description.as_deref(), Some("logs every tool call"));
        assert!(matches!(
            &hook.matcher,
            Some(HookMatcher::Glob { pattern }) if pattern == "fs_*"
        ));
        assert_eq!(hook.fail_action, FailAction::Block);
        assert_eq!(hook.priority, -5);
        assert!(hook.async_mode);
        assert!(!hook.enabled);
    }

    #[test]
    fn test_hook_handler_creation() {
        let cmd = HookHandler::command("echo");
        assert!(!cmd.is_stubbed());

        let wasm = HookHandler::wasm("/path/to/module.wasm");
        assert!(!wasm.is_stubbed());

        let agent = HookHandler::agent("Analyze this event: {{event}}");
        assert!(agent.is_stubbed());
    }

    #[test]
    fn test_hook_handler_constructor_defaults() {
        match HookHandler::command("echo") {
            HookHandler::Command { command, args, env, working_dir } => {
                assert_eq!(command, "echo");
                assert!(args.is_empty());
                assert!(env.is_empty());
                assert!(working_dir.is_none());
            }
            other => panic!("expected a command handler, got {:?}", other),
        }

        match HookHandler::http("https://example.com/hook") {
            HookHandler::Http { url, method, headers, body_template } => {
                assert_eq!(url, "https://example.com/hook");
                assert_eq!(method, default_http_method());
                assert!(headers.is_empty());
                assert!(body_template.is_none());
            }
            other => panic!("expected an http handler, got {:?}", other),
        }

        match HookHandler::wasm("/path/to/module.wasm") {
            HookHandler::Wasm { module_path, function } => {
                assert_eq!(module_path, "/path/to/module.wasm");
                assert_eq!(function, default_wasm_function());
            }
            other => panic!("expected a wasm handler, got {:?}", other),
        }
    }

    #[test]
    fn test_hook_matcher() {
        let glob = HookMatcher::glob("fs_*");
        let regex = HookMatcher::regex(r"^fs_\w+$");
        let tools = HookMatcher::tools(vec!["read_file".to_string(), "write_file".to_string()]);

        assert!(matches!(glob, HookMatcher::Glob { .. }));
        assert!(matches!(regex, HookMatcher::Regex { .. }));
        assert!(matches!(tools, HookMatcher::ToolNames { .. }));
    }

    #[test]
    fn test_hook_matcher_servers() {
        match HookMatcher::servers(vec!["filesystem".to_string()]) {
            HookMatcher::ServerNames { names } => {
                assert_eq!(names, vec!["filesystem".to_string()]);
            }
            other => panic!("expected a server-names matcher, got {:?}", other),
        }
    }

    #[test]
    fn test_fail_action_default() {
        assert_eq!(FailAction::default(), FailAction::Warn);
    }

    #[test]
    fn test_fail_action_display() {
        assert_eq!(FailAction::Warn.to_string(), "warn");
        assert_eq!(FailAction::Block.to_string(), "block");
        assert_eq!(FailAction::Ignore.to_string(), "ignore");
    }
}
382}