1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8use crate::hook::HookEvent;
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case", tag = "action")]
13pub enum HookResult {
14 #[default]
16 Continue,
17 ContinueWith {
19 modifications: HashMap<String, serde_json::Value>,
21 },
22 Block {
24 reason: String,
26 },
27 Ask {
29 question: String,
31 #[serde(default)]
33 default: Option<String>,
34 },
35}
36
37impl HookResult {
38 #[must_use]
40 pub(crate) fn continue_() -> Self {
41 Self::Continue
42 }
43
44 #[must_use]
46 pub(crate) fn continue_with(modifications: HashMap<String, serde_json::Value>) -> Self {
47 Self::ContinueWith { modifications }
48 }
49
50 #[must_use]
52 pub(crate) fn block(reason: impl Into<String>) -> Self {
53 Self::Block {
54 reason: reason.into(),
55 }
56 }
57
58 #[must_use]
60 pub(crate) fn ask(question: impl Into<String>) -> Self {
61 Self::Ask {
62 question: question.into(),
63 default: None,
64 }
65 }
66
67 #[must_use]
69 pub(crate) fn is_blocking(&self) -> bool {
70 matches!(self, Self::Block { .. })
71 }
72
73 #[must_use]
75 pub(crate) fn requires_interaction(&self) -> bool {
76 matches!(self, Self::Ask { .. })
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub(crate) struct HookContext {
83 pub invocation_id: Uuid,
85 pub event: HookEvent,
87 #[serde(default)]
89 pub session_id: Option<Uuid>,
90 #[serde(default)]
92 pub user_id: Option<Uuid>,
93 pub timestamp: DateTime<Utc>,
95 #[serde(default)]
97 pub data: HashMap<String, serde_json::Value>,
98 #[serde(default)]
100 pub previous_results: Vec<HookResult>,
101}
102
103impl HookContext {
104 #[must_use]
106 pub(crate) fn new(event: HookEvent) -> Self {
107 Self {
108 invocation_id: Uuid::new_v4(),
109 event,
110 session_id: None,
111 user_id: None,
112 timestamp: Utc::now(),
113 data: HashMap::new(),
114 previous_results: Vec::new(),
115 }
116 }
117
118 #[must_use]
120 pub(crate) fn with_session(mut self, session_id: Uuid) -> Self {
121 self.session_id = Some(session_id);
122 self
123 }
124
125 #[must_use]
127 pub(crate) fn with_user(mut self, user_id: Uuid) -> Self {
128 self.user_id = Some(user_id);
129 self
130 }
131
132 #[must_use]
134 pub(crate) fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
135 self.data.insert(key.into(), value);
136 self
137 }
138
139 pub(crate) fn add_previous_result(&mut self, result: HookResult) {
141 self.previous_results.push(result);
142 }
143
144 #[must_use]
146 pub(crate) fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
147 self.data.get(key)
148 }
149
150 #[must_use]
152 pub(crate) fn get_data_as<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
153 self.data
154 .get(key)
155 .and_then(|v| serde_json::from_value(v.clone()).ok())
156 }
157
158 #[must_use]
160 pub(crate) fn was_blocked(&self) -> bool {
161 self.previous_results.iter().any(HookResult::is_blocking)
162 }
163
164 #[must_use]
166 pub(crate) fn to_json(&self) -> serde_json::Value {
167 serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
168 }
169
170 #[must_use]
172 pub(crate) fn to_env_vars(&self) -> HashMap<String, String> {
173 let mut env = HashMap::new();
174
175 env.insert("ASTRID_HOOK_ID".to_string(), self.invocation_id.to_string());
176 env.insert("ASTRID_HOOK_EVENT".to_string(), self.event.to_string());
177 env.insert(
178 "ASTRID_HOOK_TIMESTAMP".to_string(),
179 self.timestamp.to_rfc3339(),
180 );
181
182 if let Some(session_id) = &self.session_id {
183 env.insert("ASTRID_SESSION_ID".to_string(), session_id.to_string());
184 }
185
186 if let Some(user_id) = &self.user_id {
187 env.insert("ASTRID_USER_ID".to_string(), user_id.to_string());
188 }
189
190 if !self.data.is_empty()
192 && let Ok(json) = serde_json::to_string(&self.data)
193 {
194 env.insert("ASTRID_HOOK_DATA".to_string(), json);
195 }
196
197 env
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub(crate) struct HookExecution {
204 pub hook_id: Uuid,
206 pub invocation_id: Uuid,
208 pub started_at: DateTime<Utc>,
210 pub completed_at: DateTime<Utc>,
212 pub duration_ms: u64,
214 pub result: HookExecutionResult,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220#[serde(rename_all = "snake_case", tag = "status")]
221pub(crate) enum HookExecutionResult {
222 Success {
224 result: HookResult,
226 #[serde(default)]
228 stdout: Option<String>,
229 },
230 Failure {
232 error: String,
234 #[serde(default)]
236 stderr: Option<String>,
237 },
238 Timeout {
240 timeout_secs: u64,
242 },
243 Skipped {
245 reason: String,
247 },
248}
249
250impl HookExecutionResult {
251 #[must_use]
253 pub(crate) fn is_success(&self) -> bool {
254 matches!(self, Self::Success { .. })
255 }
256
257 #[must_use]
259 pub(crate) fn hook_result(&self) -> Option<&HookResult> {
260 match self {
261 Self::Success { result, .. } => Some(result),
262 _ => None,
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_result_continue() {
        let result = HookResult::continue_();
        assert!(!result.is_blocking());
        assert!(!result.requires_interaction());
    }

    #[test]
    fn test_hook_result_default_is_continue() {
        assert!(matches!(HookResult::default(), HookResult::Continue));
    }

    #[test]
    fn test_hook_result_continue_with() {
        let mods = HashMap::from([("k".to_string(), serde_json::json!(1))]);
        let result = HookResult::continue_with(mods);
        assert!(!result.is_blocking());
        assert!(!result.requires_interaction());
    }

    #[test]
    fn test_hook_result_block() {
        let result = HookResult::block("Policy violation");
        assert!(result.is_blocking());
        assert!(!result.requires_interaction());
    }

    #[test]
    fn test_hook_result_ask() {
        let result = HookResult::ask("Are you sure?");
        assert!(result.requires_interaction());
        assert!(!result.is_blocking());
        assert!(matches!(result, HookResult::Ask { default: None, .. }));
    }

    #[test]
    fn test_hook_result_serde_uses_action_tag() {
        let json = serde_json::to_value(HookResult::block("nope")).expect("serialize Block");
        assert_eq!(json, serde_json::json!({ "action": "block", "reason": "nope" }));

        // `default` may be omitted and then deserializes to `None`.
        let parsed: HookResult =
            serde_json::from_value(serde_json::json!({ "action": "ask", "question": "ok?" }))
                .expect("deserialize Ask");
        assert!(parsed.requires_interaction());
        assert!(matches!(parsed, HookResult::Ask { default: None, .. }));
    }

    #[test]
    fn test_hook_context_creation() {
        let session_id = Uuid::new_v4();
        let user_id = Uuid::new_v4();

        let ctx = HookContext::new(HookEvent::PreToolCall)
            .with_session(session_id)
            .with_user(user_id)
            .with_data("tool_name", serde_json::json!("read_file"));

        assert_eq!(ctx.event, HookEvent::PreToolCall);
        assert_eq!(ctx.session_id, Some(session_id));
        assert_eq!(ctx.user_id, Some(user_id));
        assert!(ctx.get_data("tool_name").is_some());
        assert!(ctx.previous_results.is_empty());
    }

    #[test]
    fn test_hook_context_get_data_as() {
        let ctx = HookContext::new(HookEvent::PreToolCall)
            .with_data("count", serde_json::json!(3))
            .with_data("name", serde_json::json!("read_file"));

        assert_eq!(ctx.get_data_as::<u32>("count"), Some(3));
        assert_eq!(
            ctx.get_data_as::<String>("name"),
            Some("read_file".to_string())
        );
        // Type mismatch and missing key both yield `None`.
        assert_eq!(ctx.get_data_as::<u32>("name"), None);
        assert_eq!(ctx.get_data_as::<u32>("missing"), None);
    }

    #[test]
    fn test_hook_context_was_blocked() {
        let mut ctx = HookContext::new(HookEvent::PreToolCall);
        assert!(!ctx.was_blocked());

        ctx.add_previous_result(HookResult::continue_());
        assert!(!ctx.was_blocked());

        ctx.add_previous_result(HookResult::block("nope"));
        assert!(ctx.was_blocked());
        assert_eq!(ctx.previous_results.len(), 2);
    }

    #[test]
    fn test_hook_context_env_vars() {
        let ctx = HookContext::new(HookEvent::SessionStart);
        let env = ctx.to_env_vars();

        assert!(env.contains_key("ASTRID_HOOK_ID"));
        assert!(env.contains_key("ASTRID_HOOK_TIMESTAMP"));
        assert_eq!(
            env.get("ASTRID_HOOK_EVENT"),
            Some(&"session_start".to_string())
        );
        // Optional variables are absent when their source is empty.
        assert!(!env.contains_key("ASTRID_SESSION_ID"));
        assert!(!env.contains_key("ASTRID_USER_ID"));
        assert!(!env.contains_key("ASTRID_HOOK_DATA"));
    }

    #[test]
    fn test_hook_context_env_vars_optional_fields() {
        let session_id = Uuid::new_v4();
        let user_id = Uuid::new_v4();
        let ctx = HookContext::new(HookEvent::SessionStart)
            .with_session(session_id)
            .with_user(user_id)
            .with_data("k", serde_json::json!(1));
        let env = ctx.to_env_vars();

        assert_eq!(env.get("ASTRID_SESSION_ID"), Some(&session_id.to_string()));
        assert_eq!(env.get("ASTRID_USER_ID"), Some(&user_id.to_string()));
        assert_eq!(
            env.get("ASTRID_HOOK_DATA").map(String::as_str),
            Some(r#"{"k":1}"#)
        );
    }

    #[test]
    fn test_hook_context_to_json_is_object() {
        let ctx = HookContext::new(HookEvent::PreToolCall);
        assert!(ctx.to_json().is_object());
    }

    #[test]
    fn test_hook_execution_result() {
        let success = HookExecutionResult::Success {
            result: HookResult::Continue,
            stdout: Some("ok".to_string()),
        };
        assert!(success.is_success());
        assert!(success.hook_result().is_some());

        let failure = HookExecutionResult::Failure {
            error: "command failed".to_string(),
            stderr: None,
        };
        assert!(!failure.is_success());
        assert!(failure.hook_result().is_none());
    }

    #[test]
    fn test_hook_execution_result_timeout_and_skipped() {
        let timeout = HookExecutionResult::Timeout { timeout_secs: 30 };
        assert!(!timeout.is_success());
        assert!(timeout.hook_result().is_none());

        let skipped = HookExecutionResult::Skipped {
            reason: "disabled".to_string(),
        };
        assert!(!skipped.is_success());
        assert!(skipped.hook_result().is_none());
    }
}