Skip to main content

systemprompt_models/services/
hooks.rs

1use std::fmt;
2use std::str::FromStr;
3
4use anyhow::anyhow;
5use serde::{Deserialize, Serialize};
6use systemprompt_identifiers::HookId;
7
/// File name used for a hook's configuration file (`config.yaml`).
///
/// NOTE(review): presumably the file that is deserialized into
/// `DiskHookConfig` — confirm against the loader that reads it.
pub const HOOK_CONFIG_FILENAME: &str = "config.yaml";
9
/// Serde default helper that always yields `true`.
///
/// Used through `#[serde(default = "default_true")]` so that a boolean field
/// missing from the input deserializes as `true` instead of `bool::default()`.
const fn default_true() -> bool {
    true
}
13
/// Serde default helper for version fields: returns the string `"1.0.0"`.
fn default_version() -> String {
    String::from("1.0.0")
}
17
/// Serde default helper for matcher fields: returns the string `"*"`.
///
/// NOTE(review): `"*"` presumably acts as a match-everything pattern — the
/// matching logic is not in this file, so confirm with the consumer.
fn default_matcher() -> String {
    String::from("*")
}
21
22fn default_hook_id() -> HookId {
23    HookId::new("")
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
27#[serde(rename_all = "PascalCase")]
28pub enum HookEvent {
29    PreToolUse,
30    PostToolUse,
31    PostToolUseFailure,
32    SessionStart,
33    SessionEnd,
34    UserPromptSubmit,
35    Notification,
36    Stop,
37    SubagentStart,
38    SubagentStop,
39}
40
41impl HookEvent {
42    pub const ALL_VARIANTS: &'static [Self] = &[
43        Self::PreToolUse,
44        Self::PostToolUse,
45        Self::PostToolUseFailure,
46        Self::SessionStart,
47        Self::SessionEnd,
48        Self::UserPromptSubmit,
49        Self::Notification,
50        Self::Stop,
51        Self::SubagentStart,
52        Self::SubagentStop,
53    ];
54
55    pub const fn as_str(&self) -> &'static str {
56        match self {
57            Self::PreToolUse => "PreToolUse",
58            Self::PostToolUse => "PostToolUse",
59            Self::PostToolUseFailure => "PostToolUseFailure",
60            Self::SessionStart => "SessionStart",
61            Self::SessionEnd => "SessionEnd",
62            Self::UserPromptSubmit => "UserPromptSubmit",
63            Self::Notification => "Notification",
64            Self::Stop => "Stop",
65            Self::SubagentStart => "SubagentStart",
66            Self::SubagentStop => "SubagentStop",
67        }
68    }
69}
70
71impl fmt::Display for HookEvent {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        write!(f, "{}", self.as_str())
74    }
75}
76
77impl FromStr for HookEvent {
78    type Err = anyhow::Error;
79
80    fn from_str(s: &str) -> anyhow::Result<Self> {
81        match s {
82            "PreToolUse" => Ok(Self::PreToolUse),
83            "PostToolUse" => Ok(Self::PostToolUse),
84            "PostToolUseFailure" => Ok(Self::PostToolUseFailure),
85            "SessionStart" => Ok(Self::SessionStart),
86            "SessionEnd" => Ok(Self::SessionEnd),
87            "UserPromptSubmit" => Ok(Self::UserPromptSubmit),
88            "Notification" => Ok(Self::Notification),
89            "Stop" => Ok(Self::Stop),
90            "SubagentStart" => Ok(Self::SubagentStart),
91            "SubagentStop" => Ok(Self::SubagentStop),
92            _ => Err(anyhow!("Invalid hook event: {s}")),
93        }
94    }
95}
96
97#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
98#[serde(rename_all = "lowercase")]
99pub enum HookCategory {
100    System,
101    #[default]
102    Custom,
103}
104
105impl HookCategory {
106    pub const fn as_str(&self) -> &'static str {
107        match self {
108            Self::System => "system",
109            Self::Custom => "custom",
110        }
111    }
112}
113
114impl fmt::Display for HookCategory {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        write!(f, "{}", self.as_str())
117    }
118}
119
120impl FromStr for HookCategory {
121    type Err = anyhow::Error;
122
123    fn from_str(s: &str) -> anyhow::Result<Self> {
124        match s {
125            "system" => Ok(Self::System),
126            "custom" => Ok(Self::Custom),
127            _ => Err(anyhow!("Invalid hook category: {s}")),
128        }
129    }
130}
131
132#[derive(Debug, Clone, Deserialize)]
133pub struct DiskHookConfig {
134    #[serde(default = "default_hook_id")]
135    pub id: HookId,
136    #[serde(default)]
137    pub name: String,
138    #[serde(default)]
139    pub description: String,
140    #[serde(default = "default_version")]
141    pub version: String,
142    #[serde(default = "default_true")]
143    pub enabled: bool,
144    pub event: HookEvent,
145    #[serde(default = "default_matcher")]
146    pub matcher: String,
147    #[serde(default)]
148    pub command: String,
149    #[serde(default, rename = "async")]
150    pub is_async: bool,
151    #[serde(default)]
152    pub category: HookCategory,
153    #[serde(default)]
154    pub tags: Vec<String>,
155    #[serde(default)]
156    pub visible_to: Vec<String>,
157}
158
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160#[serde(rename_all = "PascalCase")]
161pub struct HookEventsConfig {
162    #[serde(default, skip_serializing_if = "Vec::is_empty")]
163    pub pre_tool_use: Vec<HookMatcher>,
164    #[serde(default, skip_serializing_if = "Vec::is_empty")]
165    pub post_tool_use: Vec<HookMatcher>,
166    #[serde(default, skip_serializing_if = "Vec::is_empty")]
167    pub post_tool_use_failure: Vec<HookMatcher>,
168    #[serde(default, skip_serializing_if = "Vec::is_empty")]
169    pub session_start: Vec<HookMatcher>,
170    #[serde(default, skip_serializing_if = "Vec::is_empty")]
171    pub session_end: Vec<HookMatcher>,
172    #[serde(default, skip_serializing_if = "Vec::is_empty")]
173    pub user_prompt_submit: Vec<HookMatcher>,
174    #[serde(default, skip_serializing_if = "Vec::is_empty")]
175    pub notification: Vec<HookMatcher>,
176    #[serde(default, skip_serializing_if = "Vec::is_empty")]
177    pub stop: Vec<HookMatcher>,
178    #[serde(default, skip_serializing_if = "Vec::is_empty")]
179    pub subagent_start: Vec<HookMatcher>,
180    #[serde(default, skip_serializing_if = "Vec::is_empty")]
181    pub subagent_stop: Vec<HookMatcher>,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct HookMatcher {
186    pub matcher: String,
187    pub hooks: Vec<HookAction>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct HookAction {
192    #[serde(rename = "type")]
193    pub hook_type: HookType,
194    #[serde(skip_serializing_if = "Option::is_none")]
195    pub command: Option<String>,
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub prompt: Option<String>,
198    #[serde(default, rename = "async")]
199    pub r#async: bool,
200    #[serde(skip_serializing_if = "Option::is_none")]
201    pub timeout: Option<u32>,
202    #[serde(skip_serializing_if = "Option::is_none", rename = "statusMessage")]
203    pub status_message: Option<String>,
204}
205
206#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
207#[serde(rename_all = "lowercase")]
208pub enum HookType {
209    Command,
210    Prompt,
211    Agent,
212}
213
214impl HookEventsConfig {
215    pub fn is_empty(&self) -> bool {
216        self.pre_tool_use.is_empty()
217            && self.post_tool_use.is_empty()
218            && self.post_tool_use_failure.is_empty()
219            && self.session_start.is_empty()
220            && self.session_end.is_empty()
221            && self.user_prompt_submit.is_empty()
222            && self.notification.is_empty()
223            && self.stop.is_empty()
224            && self.subagent_start.is_empty()
225            && self.subagent_stop.is_empty()
226    }
227
228    pub fn matchers_for_event(&self, event: HookEvent) -> &[HookMatcher] {
229        match event {
230            HookEvent::PreToolUse => &self.pre_tool_use,
231            HookEvent::PostToolUse => &self.post_tool_use,
232            HookEvent::PostToolUseFailure => &self.post_tool_use_failure,
233            HookEvent::SessionStart => &self.session_start,
234            HookEvent::SessionEnd => &self.session_end,
235            HookEvent::UserPromptSubmit => &self.user_prompt_submit,
236            HookEvent::Notification => &self.notification,
237            HookEvent::Stop => &self.stop,
238            HookEvent::SubagentStart => &self.subagent_start,
239            HookEvent::SubagentStop => &self.subagent_stop,
240        }
241    }
242
243    pub fn validate(&self) -> anyhow::Result<()> {
244        for event in HookEvent::ALL_VARIANTS {
245            for matcher in self.matchers_for_event(*event) {
246                for action in &matcher.hooks {
247                    match action.hook_type {
248                        HookType::Command => {
249                            if action.command.is_none() {
250                                anyhow::bail!(
251                                    "Hook matcher '{}': command hook requires a 'command' field",
252                                    matcher.matcher
253                                );
254                            }
255                        },
256                        HookType::Prompt => {
257                            if action.prompt.is_none() {
258                                anyhow::bail!(
259                                    "Hook matcher '{}': prompt hook requires a 'prompt' field",
260                                    matcher.matcher
261                                );
262                            }
263                        },
264                        HookType::Agent => {},
265                    }
266                }
267            }
268        }
269
270        Ok(())
271    }
272}