1use serde::Serialize;
2use std::{collections::HashMap, fmt::Debug, fmt::Display};
3use tokio::task::JoinSet;
4use uuid::Uuid;
5
/// Failures that hook code can report.
///
/// Every variant wraps a plain `String` message. The `#[error(...)]`
/// templates below define the `Display` text, with the wrapped message
/// substituted for `{0}`.
#[derive(Debug, thiserror::Error)]
pub enum HookError {
    /// Displayed as `Database error: <message>`.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// Displayed as `Serialization error: <message>`.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// Displayed as `Hook execution failed: <message>`.
    #[error("Hook execution failed: {0}")]
    ExecutionError(String),
}
16
/// Mutable data handed to each [`Hook`] when it executes.
///
/// The four public fields are serialized; `background_tasks` is excluded
/// through `#[serde(skip)]`.
#[derive(Debug, Serialize)]
pub struct HookContext<State: Clone + Serialize> {
    /// Session identifier, if one is known. Settable via `set_session_id`.
    pub session_id: Option<Uuid>,
    /// Starts as `None` in `new`; filled in by `set_new_checkpoint_id`.
    pub new_checkpoint_id: Option<Uuid>,
    /// Assigned from `Uuid::new_v4()` when the context is built with `new`.
    pub request_id: Uuid,
    /// Caller-defined state made available to hooks.
    pub state: State,

    /// Tasks queued through `spawn_task`. Not serialized, and not carried
    /// over by `Clone` (a clone starts with an empty set).
    #[serde(skip)]
    background_tasks: JoinSet<Result<(), HookError>>,
}
27
28impl<State: Clone + Serialize> Clone for HookContext<State> {
29 fn clone(&self) -> Self {
30 Self {
31 session_id: self.session_id,
32 new_checkpoint_id: self.new_checkpoint_id,
33 request_id: self.request_id,
34 state: self.state.clone(),
35 background_tasks: JoinSet::new(),
36 }
37 }
38}
39
40impl<State: Clone + Serialize> HookContext<State> {
41 pub fn new(session_id: Option<Uuid>, state: State) -> Self {
42 Self {
43 session_id,
44 new_checkpoint_id: None,
45 request_id: Uuid::new_v4(),
46 state,
47 background_tasks: JoinSet::new(),
48 }
49 }
50
51 pub fn set_session_id(&mut self, session_id: Uuid) {
52 self.session_id = Some(session_id);
53 }
54
55 pub fn set_new_checkpoint_id(&mut self, new_checkpoint_id: Uuid) {
56 self.new_checkpoint_id = Some(new_checkpoint_id);
57 }
58}
59
60impl<State: Clone + Serialize> HookContext<State> {
61 pub fn spawn_task<F>(&mut self, task: F)
62 where
63 F: Future<Output = Result<(), HookError>> + Send + 'static,
64 {
65 self.background_tasks.spawn(task);
66 }
67}
68
/// Events that hooks can be registered for in `HookRegistry`.
///
/// The registry uses these values as map keys (hence `Eq + Hash`), and the
/// `Display` impl prints the variant name. Where each event is emitted is
/// decided by the callers of `HookRegistry::execute_hooks`, which are not part
/// of this section; the per-variant notes below are inferred from the names
/// only.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LifecycleEvent {
    // Request-level events (presumably around handling of one request).
    BeforeRequest,
    AfterRequest,

    // Inference-level events (presumably around a model inference call).
    BeforeInference,
    AfterInference,

    // Tool-call events (presumably the stages of a single tool invocation).
    ToolCallRequested,
    BeforeToolExecution,
    AfterToolExecution,
    ToolCallAborted,

    // Presumably dispatched when processing fails; confirm against callers.
    Error,
}
88
89impl Display for LifecycleEvent {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(f, "{:?}", self)
92 }
93}
94
/// Outcome a hook reports back to `HookRegistry::execute_hooks`.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug`/`Default` so
/// that results can be duplicated and compared directly (for example in
/// assertions).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Default outcome: `execute_hooks` proceeds to the next hook.
    #[default]
    Continue,
    /// `execute_hooks` stops running further hooks and returns `Skip` to its
    /// caller.
    Skip,
    /// `execute_hooks` stops running further hooks and returns the abort to
    /// its caller.
    Abort {
        /// Name attributed to the abort. When a hook leaves this as `None`,
        /// `execute_hooks` fills in that hook's `name()`.
        name: Option<String>,
        /// Human-readable explanation; included in the error text produced
        /// by `HookAction::ok`.
        reason: String,
    },
}
108
109impl HookAction {
110 pub fn ok(self) -> Result<(), String> {
112 match self {
113 HookAction::Abort { name, reason } => Err(format!(
114 "[{}:hook_abort] {}",
115 name.unwrap_or_default(),
116 reason
117 )),
118 _ => Ok(()),
119 }
120 }
121}
122
/// A unit of behaviour that `HookRegistry` runs for a [`LifecycleEvent`].
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Box<dyn Hook<State>>`.
#[async_trait::async_trait]
pub trait Hook<State: Clone + Serialize>: Send + Sync {
    /// Name of the hook. Shown in the registry's `Debug` output and used as
    /// the abort name when the hook returns `HookAction::Abort` with
    /// `name: None`.
    fn name(&self) -> &str;

    /// Ordering key used by `HookRegistry::register`, which sorts each
    /// event's hooks by ascending priority (lower values run first).
    /// Defaults to `50`.
    fn priority(&self) -> u8 {
        50
    }

    /// Runs the hook for `event` with mutable access to the shared context.
    ///
    /// Returning `Ok(HookAction::Continue)` lets `execute_hooks` proceed to
    /// the next hook; `Skip` and `Abort` stop the chain. An `Err` is
    /// propagated to the caller of `execute_hooks` unchanged.
    async fn execute(
        &self,
        ctx: &mut HookContext<State>,
        event: &LifecycleEvent,
    ) -> Result<HookAction, HookError>;
}
138
139#[derive(Default)]
140pub struct HookRegistry<State> {
141 hooks: HashMap<LifecycleEvent, Vec<Box<dyn Hook<State>>>>,
142}
143impl<State: Clone + Serialize> std::fmt::Debug for HookRegistry<State> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 let mut map = f.debug_map();
146 for (event, hooks) in &self.hooks {
147 let hook_names: Vec<&str> = hooks.iter().map(|hook| hook.name()).collect();
148 map.entry(event, &hook_names);
149 }
150 map.finish()
151 }
152}
153
154impl<State: Clone + Serialize> HookRegistry<State> {
155 pub fn register(&mut self, event: LifecycleEvent, hook: Box<dyn Hook<State>>) {
156 let hooks = self.hooks.entry(event).or_default();
157 hooks.push(hook);
158
159 hooks.sort_by_key(|h| h.priority());
161 }
162
163 pub async fn execute_hooks(
164 &self,
165 ctx: &mut HookContext<State>,
166 event: &LifecycleEvent,
167 ) -> Result<HookAction, HookError> {
168 let Some(hooks) = self.hooks.get(event) else {
169 return Ok(HookAction::Continue);
170 };
171
172 for hook in hooks {
173 match hook.execute(ctx, event).await? {
174 HookAction::Continue => continue,
175 HookAction::Skip => return Ok(HookAction::Skip),
176 HookAction::Abort { name, reason } => {
177 return Ok(HookAction::Abort {
178 name: Some(name.unwrap_or(hook.name().to_string())),
179 reason,
180 });
181 }
182 }
183 }
184
185 Ok(HookAction::Continue)
186 }
187}
188
/// Generates an `impl Hook<State> for Type` block from a closure-like body.
///
/// The macro only emits the trait impl: the type named by the first argument
/// must already be declared. `name()` returns the second argument, and the
/// block becomes the body of `execute`, with the receiver, context and event
/// bound to the identifiers given in the closure-style parameter list.
///
/// The expansion refers to `Hook`, `HookContext`, `LifecycleEvent`,
/// `HookAction`, `HookError` and the `async_trait` crate by their bare names,
/// so all of them must be resolvable at the call site.
///
/// Intended call shape (the parameter types must be written exactly as in the
/// matcher, apart from the state type):
///
/// ```ignore
/// struct LoggingHook;
///
/// define_hook!(
///     LoggingHook,
///     "logging",
///     async |&self, ctx: &mut HookContext<MyState>, event: &LifecycleEvent| {
///         Ok(HookAction::Continue)
///     }
/// );
/// ```
#[macro_export]
macro_rules! define_hook {
    ($name:ident, $hook_name:expr, async |&$self:ident, $ctx:ident: &mut HookContext<$state:ty>, $event:ident: &LifecycleEvent| $body:block) => {
        #[async_trait::async_trait]
        impl Hook<$state> for $name {
            fn name(&self) -> &str {
                $hook_name
            }
            async fn execute(
                &$self,
                $ctx: &mut HookContext<$state>,
                $event: &LifecycleEvent,
            ) -> Result<HookAction, HookError> {
                $body
            }
        }
    };
}