// stakpak_shared/hooks/
mod.rs1use serde::Serialize;
2use std::{collections::HashMap, fmt::Debug, fmt::Display};
3use tokio::task::JoinSet;
4use uuid::Uuid;
5
6#[derive(Debug, thiserror::Error)]
8pub enum HookError {
9 #[error("Database error: {0}")]
10 DatabaseError(String),
11 #[error("Serialization error: {0}")]
12 SerializationError(String),
13 #[error("Hook execution failed: {0}")]
14 ExecutionError(String),
15}
16
17#[derive(Debug, Serialize)]
18pub struct HookContext<State: Clone + Serialize> {
19 pub session_id: Option<Uuid>,
20 pub new_checkpoint_id: Option<Uuid>,
21 pub request_id: Uuid,
22 pub state: State,
23
24 #[serde(skip)]
25 background_tasks: JoinSet<Result<(), HookError>>,
26}
27
28impl<State: Clone + Serialize> Clone for HookContext<State> {
29 fn clone(&self) -> Self {
30 Self {
31 session_id: self.session_id,
32 new_checkpoint_id: self.new_checkpoint_id,
33 request_id: self.request_id,
34 state: self.state.clone(),
35 background_tasks: JoinSet::new(),
36 }
37 }
38}
39
40impl<State: Clone + Serialize> HookContext<State> {
41 pub fn new(session_id: Option<Uuid>, state: State) -> Self {
42 Self {
43 session_id,
44 new_checkpoint_id: None,
45 request_id: Uuid::new_v4(),
46 state,
47 background_tasks: JoinSet::new(),
48 }
49 }
50
51 pub fn set_session_id(&mut self, session_id: Uuid) {
52 self.session_id = Some(session_id);
53 }
54
55 pub fn set_new_checkpoint_id(&mut self, new_checkpoint_id: Uuid) {
56 self.new_checkpoint_id = Some(new_checkpoint_id);
57 }
58}
59
60impl<State: Clone + Serialize> HookContext<State> {
61 pub fn spawn_task<F>(&mut self, task: F)
62 where
63 F: Future<Output = Result<(), HookError>> + Send + 'static,
64 {
65 self.background_tasks.spawn(task);
66 }
67}
68
/// Points in the lifecycle at which hooks can be registered.
///
/// In this module the variants serve as `HookRegistry` keys; which code emits each
/// event is decided by callers outside this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LifecycleEvent {
    // Request boundaries.
    BeforeRequest,
    AfterRequest,

    // Inference boundaries.
    BeforeInference,
    AfterInference,

    // Tool-call stages.
    ToolCallRequested,
    BeforeToolExecution,
    AfterToolExecution,
    ToolCallAborted,

    // Failure notification.
    Error,
}
88
89impl Display for LifecycleEvent {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(f, "{:?}", self)
92 }
93}
94
/// A hook's verdict on how the current lifecycle step should proceed.
///
/// In `HookRegistry::execute_hooks`, `Continue` moves on to the next hook, while
/// `Skip` and `Abort` stop the chain and are handed back to the caller.
#[derive(Debug, Default)]
pub enum HookAction {
    /// Proceed normally; this is the `Default` value.
    #[default]
    Continue,
    /// Stop running further hooks for this event. What else is skipped is up to
    /// the caller of `execute_hooks`.
    Skip,
    /// Stop running further hooks, carrying a human-readable `reason`.
    Abort { reason: String },
}

impl HookAction {
    /// Collapses the action into a `Result`, treating only aborts as errors.
    ///
    /// `Continue` and `Skip` yield `Ok(())`; `Abort { reason }` yields `Err(reason)`.
    pub fn ok(self) -> Result<(), String> {
        match self {
            HookAction::Continue | HookAction::Skip => Ok(()),
            HookAction::Abort { reason } => Err(reason),
        }
    }
}
115
116#[async_trait::async_trait]
117pub trait Hook<State: Clone + Serialize>: Send + Sync {
118 fn name(&self) -> &str;
119
120 fn priority(&self) -> u8 {
122 50
123 }
124
125 async fn execute(
126 &self,
127 ctx: &mut HookContext<State>,
128 event: &LifecycleEvent,
129 ) -> Result<HookAction, HookError>;
130}
131
132#[derive(Default)]
133pub struct HookRegistry<State> {
134 hooks: HashMap<LifecycleEvent, Vec<Box<dyn Hook<State>>>>,
135}
136impl<State: Clone + Serialize> std::fmt::Debug for HookRegistry<State> {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 let mut map = f.debug_map();
139 for (event, hooks) in &self.hooks {
140 let hook_names: Vec<&str> = hooks.iter().map(|hook| hook.name()).collect();
141 map.entry(event, &hook_names);
142 }
143 map.finish()
144 }
145}
146
147impl<State: Clone + Serialize> HookRegistry<State> {
148 pub fn register(&mut self, event: LifecycleEvent, hook: Box<dyn Hook<State>>) {
149 let hooks = self.hooks.entry(event).or_default();
150 hooks.push(hook);
151
152 hooks.sort_by_key(|h| h.priority());
154 }
155
156 pub async fn execute_hooks(
157 &self,
158 ctx: &mut HookContext<State>,
159 event: &LifecycleEvent,
160 ) -> Result<HookAction, HookError> {
161 let Some(hooks) = self.hooks.get(event) else {
162 return Ok(HookAction::Continue);
163 };
164
165 for hook in hooks {
166 match hook.execute(ctx, event).await? {
167 HookAction::Continue => continue,
168 HookAction::Skip => return Ok(HookAction::Skip),
169 HookAction::Abort { reason } => {
170 return Ok(HookAction::Abort { reason });
171 }
172 }
173 }
174
175 Ok(HookAction::Continue)
176 }
177}
178
/// Implements [`Hook`] for an existing type from a closure-shaped body.
///
/// Invocation shape:
/// `define_hook!(MyHook, "my_hook", async |&self, ctx: &mut HookContext<MyState>, event: &LifecycleEvent| { ... })`.
/// This expands to an `#[async_trait::async_trait]` impl of `Hook<MyState>` whose
/// `name()` returns the given expression and whose `execute()` runs the block.
/// `priority()` keeps the trait default.
///
/// NOTE(review): the expansion refers to `Hook`, `HookContext`, `HookAction`,
/// `HookError` and `LifecycleEvent` without a `$crate::` prefix, and it uses
/// `async_trait::async_trait`. The invoking module must therefore have those names
/// in scope and depend on `async_trait`.
#[macro_export]
macro_rules! define_hook {
    ($hook_type:ident, $hook_label:expr, async |&$receiver:ident, $ctx_arg:ident: &mut HookContext<$state_ty:ty>, $event_arg:ident: &LifecycleEvent| $hook_body:block) => {
        #[async_trait::async_trait]
        impl Hook<$state_ty> for $hook_type {
            fn name(&self) -> &str {
                $hook_label
            }

            async fn execute(
                &$receiver,
                $ctx_arg: &mut HookContext<$state_ty>,
                $event_arg: &LifecycleEvent,
            ) -> Result<HookAction, HookError> {
                $hook_body
            }
        }
    };
}