Skip to main content

astrid_hooks/
result.rs

1//! Hook execution results and context.
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8use crate::hook::HookEvent;
9
10/// Result of hook execution.
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case", tag = "action")]
13pub enum HookResult {
14    /// Continue with the operation (no changes).
15    #[default]
16    Continue,
17    /// Continue with modified context.
18    ContinueWith {
19        /// Modified context values.
20        modifications: HashMap<String, serde_json::Value>,
21    },
22    /// Block the operation.
23    Block {
24        /// Reason for blocking.
25        reason: String,
26    },
27    /// Ask the user before proceeding.
28    Ask {
29        /// Question to ask the user.
30        question: String,
31        /// Default answer if user doesn't respond.
32        #[serde(default)]
33        default: Option<String>,
34    },
35}
36
37impl HookResult {
38    /// Create a continue result.
39    #[must_use]
40    pub(crate) fn continue_() -> Self {
41        Self::Continue
42    }
43
44    /// Create a continue-with-modifications result.
45    #[must_use]
46    pub(crate) fn continue_with(modifications: HashMap<String, serde_json::Value>) -> Self {
47        Self::ContinueWith { modifications }
48    }
49
50    /// Create a block result.
51    #[must_use]
52    pub(crate) fn block(reason: impl Into<String>) -> Self {
53        Self::Block {
54            reason: reason.into(),
55        }
56    }
57
58    /// Create an ask result.
59    #[must_use]
60    pub(crate) fn ask(question: impl Into<String>) -> Self {
61        Self::Ask {
62            question: question.into(),
63            default: None,
64        }
65    }
66
67    /// Check if this result blocks the operation.
68    #[must_use]
69    pub(crate) fn is_blocking(&self) -> bool {
70        matches!(self, Self::Block { .. })
71    }
72
73    /// Check if this result requires user interaction.
74    #[must_use]
75    pub(crate) fn requires_interaction(&self) -> bool {
76        matches!(self, Self::Ask { .. })
77    }
78}
79
80/// Context provided to hooks during execution.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub(crate) struct HookContext {
83    /// Unique identifier for this hook invocation.
84    pub invocation_id: Uuid,
85    /// The event that triggered the hook.
86    pub event: HookEvent,
87    /// Session ID if available.
88    #[serde(default)]
89    pub session_id: Option<Uuid>,
90    /// User ID if available.
91    #[serde(default)]
92    pub user_id: Option<Uuid>,
93    /// Timestamp of the event.
94    pub timestamp: DateTime<Utc>,
95    /// Event-specific data.
96    #[serde(default)]
97    pub data: HashMap<String, serde_json::Value>,
98    /// Previous hook results in the chain.
99    #[serde(default)]
100    pub previous_results: Vec<HookResult>,
101}
102
103impl HookContext {
104    /// Create a new hook context.
105    #[must_use]
106    pub(crate) fn new(event: HookEvent) -> Self {
107        Self {
108            invocation_id: Uuid::new_v4(),
109            event,
110            session_id: None,
111            user_id: None,
112            timestamp: Utc::now(),
113            data: HashMap::new(),
114            previous_results: Vec::new(),
115        }
116    }
117
118    /// Set the session ID.
119    #[must_use]
120    pub(crate) fn with_session(mut self, session_id: Uuid) -> Self {
121        self.session_id = Some(session_id);
122        self
123    }
124
125    /// Set the user ID.
126    #[must_use]
127    pub(crate) fn with_user(mut self, user_id: Uuid) -> Self {
128        self.user_id = Some(user_id);
129        self
130    }
131
132    /// Add data to the context.
133    #[must_use]
134    pub(crate) fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
135        self.data.insert(key.into(), value);
136        self
137    }
138
139    /// Add a previous hook result.
140    pub(crate) fn add_previous_result(&mut self, result: HookResult) {
141        self.previous_results.push(result);
142    }
143
144    /// Get a data value.
145    #[must_use]
146    pub(crate) fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
147        self.data.get(key)
148    }
149
150    /// Get a data value as a specific type.
151    #[must_use]
152    pub(crate) fn get_data_as<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
153        self.data
154            .get(key)
155            .and_then(|v| serde_json::from_value(v.clone()).ok())
156    }
157
158    /// Check if any previous hook blocked.
159    #[must_use]
160    pub(crate) fn was_blocked(&self) -> bool {
161        self.previous_results.iter().any(HookResult::is_blocking)
162    }
163
164    /// Convert context to JSON for passing to handlers.
165    #[must_use]
166    pub(crate) fn to_json(&self) -> serde_json::Value {
167        serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
168    }
169
170    /// Convert context to environment variables.
171    #[must_use]
172    pub(crate) fn to_env_vars(&self) -> HashMap<String, String> {
173        let mut env = HashMap::new();
174
175        env.insert("ASTRID_HOOK_ID".to_string(), self.invocation_id.to_string());
176        env.insert("ASTRID_HOOK_EVENT".to_string(), self.event.to_string());
177        env.insert(
178            "ASTRID_HOOK_TIMESTAMP".to_string(),
179            self.timestamp.to_rfc3339(),
180        );
181
182        if let Some(session_id) = &self.session_id {
183            env.insert("ASTRID_SESSION_ID".to_string(), session_id.to_string());
184        }
185
186        if let Some(user_id) = &self.user_id {
187            env.insert("ASTRID_USER_ID".to_string(), user_id.to_string());
188        }
189
190        // Add data as JSON
191        if !self.data.is_empty()
192            && let Ok(json) = serde_json::to_string(&self.data)
193        {
194            env.insert("ASTRID_HOOK_DATA".to_string(), json);
195        }
196
197        env
198    }
199}
200
201/// Execution metadata for a hook run.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub(crate) struct HookExecution {
204    /// Hook ID that was executed.
205    pub hook_id: Uuid,
206    /// Invocation ID from the context.
207    pub invocation_id: Uuid,
208    /// When execution started.
209    pub started_at: DateTime<Utc>,
210    /// When execution completed.
211    pub completed_at: DateTime<Utc>,
212    /// Duration in milliseconds.
213    pub duration_ms: u64,
214    /// Result of the execution.
215    pub result: HookExecutionResult,
216}
217
218/// Result of hook execution.
219#[derive(Debug, Clone, Serialize, Deserialize)]
220#[serde(rename_all = "snake_case", tag = "status")]
221pub(crate) enum HookExecutionResult {
222    /// Hook executed successfully.
223    Success {
224        /// The hook's result.
225        result: HookResult,
226        /// Stdout output if applicable.
227        #[serde(default)]
228        stdout: Option<String>,
229    },
230    /// Hook failed to execute.
231    Failure {
232        /// Error message.
233        error: String,
234        /// Stderr output if applicable.
235        #[serde(default)]
236        stderr: Option<String>,
237    },
238    /// Hook execution timed out.
239    Timeout {
240        /// Timeout duration in seconds.
241        timeout_secs: u64,
242    },
243    /// Hook was skipped (disabled or matcher didn't match).
244    Skipped {
245        /// Reason for skipping.
246        reason: String,
247    },
248}
249
250impl HookExecutionResult {
251    /// Check if execution was successful.
252    #[must_use]
253    pub(crate) fn is_success(&self) -> bool {
254        matches!(self, Self::Success { .. })
255    }
256
257    /// Get the hook result if successful.
258    #[must_use]
259    pub(crate) fn hook_result(&self) -> Option<&HookResult> {
260        match self {
261            Self::Success { result, .. } => Some(result),
262            _ => None,
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_result_continue() {
        let result = HookResult::continue_();
        assert!(!result.is_blocking());
        assert!(!result.requires_interaction());
    }

    #[test]
    fn test_hook_result_block() {
        let result = HookResult::block("Policy violation");
        assert!(result.is_blocking());
    }

    #[test]
    fn test_hook_result_ask() {
        let result = HookResult::ask("Are you sure?");
        assert!(result.requires_interaction());
    }

    #[test]
    fn test_hook_result_continue_with() {
        let mut modifications = HashMap::new();
        modifications.insert("path".to_string(), serde_json::json!("/tmp/x"));
        let result = HookResult::continue_with(modifications);
        assert!(!result.is_blocking());
        assert!(!result.requires_interaction());
    }

    #[test]
    fn test_hook_result_serde_tags() {
        let block = serde_json::to_value(HookResult::block("nope")).unwrap();
        assert_eq!(
            block,
            serde_json::json!({"action": "block", "reason": "nope"})
        );

        let cont = serde_json::to_value(HookResult::continue_()).unwrap();
        assert_eq!(cont, serde_json::json!({"action": "continue"}));

        // The `default` field may be omitted on input.
        let parsed: HookResult =
            serde_json::from_value(serde_json::json!({"action": "ask", "question": "ok?"}))
                .unwrap();
        assert!(matches!(
            &parsed,
            HookResult::Ask { question, default: None } if question == "ok?"
        ));
    }

    #[test]
    fn test_hook_context_creation() {
        let session_id = Uuid::new_v4();
        let user_id = Uuid::new_v4();

        let ctx = HookContext::new(HookEvent::PreToolCall)
            .with_session(session_id)
            .with_user(user_id)
            .with_data("tool_name", serde_json::json!("read_file"));

        assert_eq!(ctx.event, HookEvent::PreToolCall);
        assert_eq!(ctx.session_id, Some(session_id));
        assert_eq!(ctx.user_id, Some(user_id));
        assert!(ctx.get_data("tool_name").is_some());
    }

    #[test]
    fn test_hook_context_get_data_as() {
        let ctx = HookContext::new(HookEvent::PreToolCall)
            .with_data("count", serde_json::json!(3))
            .with_data("tool_name", serde_json::json!("read_file"));

        assert_eq!(ctx.get_data_as::<u32>("count"), Some(3));
        assert_eq!(
            ctx.get_data_as::<String>("tool_name"),
            Some("read_file".to_string())
        );
        // A type mismatch and a missing key both yield `None`.
        assert_eq!(ctx.get_data_as::<String>("count"), None);
        assert_eq!(ctx.get_data_as::<u32>("missing"), None);
    }

    #[test]
    fn test_hook_context_was_blocked() {
        let mut ctx = HookContext::new(HookEvent::PreToolCall);
        assert!(!ctx.was_blocked());

        ctx.add_previous_result(HookResult::continue_());
        assert!(!ctx.was_blocked());

        ctx.add_previous_result(HookResult::block("denied"));
        assert!(ctx.was_blocked());
    }

    #[test]
    fn test_hook_context_to_json() {
        let ctx = HookContext::new(HookEvent::PreToolCall)
            .with_data("tool_name", serde_json::json!("read_file"));
        let json = ctx.to_json();

        assert!(json.is_object());
        assert_eq!(json["data"]["tool_name"], serde_json::json!("read_file"));
        assert_eq!(json["session_id"], serde_json::Value::Null);
    }

    #[test]
    fn test_hook_context_env_vars() {
        let ctx = HookContext::new(HookEvent::SessionStart);
        let env = ctx.to_env_vars();

        assert!(env.contains_key("ASTRID_HOOK_ID"));
        assert_eq!(
            env.get("ASTRID_HOOK_EVENT"),
            Some(&"session_start".to_string())
        );
    }

    #[test]
    fn test_hook_context_env_vars_optional_fields() {
        let bare = HookContext::new(HookEvent::SessionStart).to_env_vars();
        assert!(!bare.contains_key("ASTRID_SESSION_ID"));
        assert!(!bare.contains_key("ASTRID_USER_ID"));
        assert!(!bare.contains_key("ASTRID_HOOK_DATA"));

        let session_id = Uuid::new_v4();
        let user_id = Uuid::new_v4();
        let full = HookContext::new(HookEvent::SessionStart)
            .with_session(session_id)
            .with_user(user_id)
            .with_data("k", serde_json::json!(1))
            .to_env_vars();

        assert_eq!(full.get("ASTRID_SESSION_ID"), Some(&session_id.to_string()));
        assert_eq!(full.get("ASTRID_USER_ID"), Some(&user_id.to_string()));
        let data: serde_json::Value = serde_json::from_str(&full["ASTRID_HOOK_DATA"]).unwrap();
        assert_eq!(data, serde_json::json!({"k": 1}));
    }

    #[test]
    fn test_hook_execution_result() {
        let success = HookExecutionResult::Success {
            result: HookResult::Continue,
            stdout: Some("ok".to_string()),
        };
        assert!(success.is_success());
        assert!(success.hook_result().is_some());

        let failure = HookExecutionResult::Failure {
            error: "command failed".to_string(),
            stderr: None,
        };
        assert!(!failure.is_success());
        assert!(failure.hook_result().is_none());

        let timeout = HookExecutionResult::Timeout { timeout_secs: 5 };
        assert!(!timeout.is_success());
        assert!(timeout.hook_result().is_none());

        let skipped = HookExecutionResult::Skipped {
            reason: "disabled".to_string(),
        };
        assert!(!skipped.is_success());
        assert!(skipped.hook_result().is_none());
    }

    #[test]
    fn test_hook_execution_result_serde_tag() {
        let timeout = HookExecutionResult::Timeout { timeout_secs: 5 };
        assert_eq!(
            serde_json::to_value(&timeout).unwrap(),
            serde_json::json!({"status": "timeout", "timeout_secs": 5})
        );
    }
}