1use serde_json::Value;
30use std::collections::HashSet;
31use std::sync::Arc;
32
33#[derive(Debug, Clone)]
35pub enum LifecycleEvent {
36 AgentStarted {
38 agent_id: String,
40 task_description: String,
42 },
43 AgentCompleted {
45 agent_id: String,
47 iterations: u32,
49 summary: String,
51 },
52 AgentFailed {
54 agent_id: String,
56 error: String,
58 iterations: u32,
60 },
61 ToolBeforeExecute {
63 agent_id: Option<String>,
65 tool_name: String,
67 args: Value,
69 },
70 ToolAfterExecute {
72 agent_id: Option<String>,
74 tool_name: String,
76 success: bool,
78 duration_ms: u64,
80 },
81 ProviderRequest {
83 agent_id: Option<String>,
85 provider: String,
87 model: String,
89 },
90 ProviderResponse {
92 agent_id: Option<String>,
94 provider: String,
96 model: String,
98 input_tokens: u64,
100 output_tokens: u64,
102 duration_ms: u64,
104 },
105 ValidationStarted {
107 agent_id: String,
109 checks: Vec<String>,
111 },
112 ValidationCompleted {
114 agent_id: String,
116 passed: bool,
118 issues: Vec<String>,
120 },
121}
122
123impl LifecycleEvent {
124 pub fn event_type(&self) -> &'static str {
126 match self {
127 Self::AgentStarted { .. } => "agent_started",
128 Self::AgentCompleted { .. } => "agent_completed",
129 Self::AgentFailed { .. } => "agent_failed",
130 Self::ToolBeforeExecute { .. } => "tool_before_execute",
131 Self::ToolAfterExecute { .. } => "tool_after_execute",
132 Self::ProviderRequest { .. } => "provider_request",
133 Self::ProviderResponse { .. } => "provider_response",
134 Self::ValidationStarted { .. } => "validation_started",
135 Self::ValidationCompleted { .. } => "validation_completed",
136 }
137 }
138
139 pub fn agent_id(&self) -> Option<&str> {
141 match self {
142 Self::AgentStarted { agent_id, .. }
143 | Self::AgentCompleted { agent_id, .. }
144 | Self::AgentFailed { agent_id, .. }
145 | Self::ValidationStarted { agent_id, .. }
146 | Self::ValidationCompleted { agent_id, .. } => Some(agent_id),
147 Self::ToolBeforeExecute { agent_id, .. }
148 | Self::ToolAfterExecute { agent_id, .. }
149 | Self::ProviderRequest { agent_id, .. }
150 | Self::ProviderResponse { agent_id, .. } => agent_id.as_deref(),
151 }
152 }
153
154 pub fn tool_name(&self) -> Option<&str> {
156 match self {
157 Self::ToolBeforeExecute { tool_name, .. }
158 | Self::ToolAfterExecute { tool_name, .. } => Some(tool_name),
159 _ => None,
160 }
161 }
162}
163
164#[derive(Debug, Clone)]
166pub enum HookResult {
167 Continue,
169 Cancel {
171 reason: String,
173 },
174 Modified(Value),
176}
177
/// Criteria deciding which [`LifecycleEvent`]s a hook is interested in.
///
/// Each set is an allow-list. An empty set places no constraint on its
/// dimension, so `EventFilter::default()` accepts every event.
#[derive(Clone, Debug, Default)]
pub struct EventFilter {
    /// Agent ids to accept. When non-empty, events without an agent id are rejected.
    pub agent_ids: HashSet<String>,
    /// Accepted event type names, as returned by [`LifecycleEvent::event_type`]
    /// (for example `"agent_started"`).
    pub event_types: HashSet<String>,
    /// Tool names to accept. Only constrains events that carry a tool name;
    /// events without one pass this check.
    pub tool_names: HashSet<String>,
}
188
189impl EventFilter {
190 pub fn matches(&self, event: &LifecycleEvent) -> bool {
192 if !self.event_types.is_empty() && !self.event_types.contains(event.event_type()) {
193 return false;
194 }
195 if !self.agent_ids.is_empty() {
196 if let Some(id) = event.agent_id() {
197 if !self.agent_ids.contains(id) {
198 return false;
199 }
200 } else {
201 return false;
202 }
203 }
204 if !self.tool_names.is_empty()
205 && let Some(name) = event.tool_name()
206 && !self.tool_names.contains(name)
207 {
208 return false;
209 }
210 true
211 }
212}
213
214#[async_trait::async_trait]
216pub trait LifecycleHook: Send + Sync {
217 fn name(&self) -> &str;
219
220 fn priority(&self) -> i32 {
222 0
223 }
224
225 fn filter(&self) -> Option<EventFilter> {
227 None
228 }
229
230 async fn on_event(&self, event: &LifecycleEvent) -> HookResult;
232}
233
234pub struct HookRegistry {
236 hooks: Vec<Arc<dyn LifecycleHook>>,
237}
238
239impl HookRegistry {
240 pub fn new() -> Self {
242 Self { hooks: Vec::new() }
243 }
244
245 pub fn register(&mut self, hook: impl LifecycleHook + 'static) {
247 self.hooks.push(Arc::new(hook));
248 self.hooks.sort_by_key(|h| h.priority());
249 }
250
251 pub fn register_arc(&mut self, hook: Arc<dyn LifecycleHook>) {
253 self.hooks.push(hook);
254 self.hooks.sort_by_key(|h| h.priority());
255 }
256
257 pub async fn dispatch(&self, event: &LifecycleEvent) -> HookResult {
262 for hook in &self.hooks {
263 let matches = hook.filter().map(|f| f.matches(event)).unwrap_or(true);
264
265 if !matches {
266 continue;
267 }
268
269 match hook.on_event(event).await {
270 HookResult::Continue => {}
271 result @ HookResult::Cancel { .. } => return result,
272 result @ HookResult::Modified(_) => return result,
273 }
274 }
275 HookResult::Continue
276 }
277
278 pub fn len(&self) -> usize {
280 self.hooks.len()
281 }
282
283 pub fn is_empty(&self) -> bool {
285 self.hooks.is_empty()
286 }
287}
288
289impl Default for HookRegistry {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll, Waker};

    /// Hook that always continues; only its name matters.
    struct NoopHook {
        name: String,
    }

    #[async_trait::async_trait]
    impl LifecycleHook for NoopHook {
        fn name(&self) -> &str {
            &self.name
        }
        async fn on_event(&self, _event: &LifecycleEvent) -> HookResult {
            HookResult::Continue
        }
    }

    /// Hook with configurable priority, filter and canned result that records
    /// how many times it was invoked.
    struct ScriptedHook {
        name: &'static str,
        priority: i32,
        filter: Option<EventFilter>,
        result: HookResult,
        calls: AtomicUsize,
    }

    impl ScriptedHook {
        fn new(name: &'static str, priority: i32, result: HookResult) -> Self {
            Self {
                name,
                priority,
                filter: None,
                result,
                calls: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LifecycleHook for ScriptedHook {
        fn name(&self) -> &str {
            self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        fn filter(&self) -> Option<EventFilter> {
            self.filter.clone()
        }
        async fn on_event(&self, _event: &LifecycleEvent) -> HookResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.result.clone()
        }
    }

    /// Drives a future expected to complete without suspending (true for every
    /// hook here, since none of them await anything).
    fn poll_once<F: Future>(fut: F) -> F::Output {
        let mut fut = pin!(fut);
        let mut cx = Context::from_waker(Waker::noop());
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("hook future unexpectedly pending"),
        }
    }

    fn started(agent_id: &str) -> LifecycleEvent {
        LifecycleEvent::AgentStarted {
            agent_id: agent_id.to_string(),
            task_description: "test".to_string(),
        }
    }

    fn registered_names(registry: &HookRegistry) -> Vec<&str> {
        registry.hooks.iter().map(|hook| hook.name()).collect()
    }

    #[test]
    fn test_registry_register() {
        let mut registry = HookRegistry::new();
        assert!(registry.is_empty());
        registry.register(NoopHook {
            name: "test".to_string(),
        });
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_event_filter_matches_all() {
        let filter = EventFilter::default();
        assert!(filter.matches(&started("a1")));
    }

    #[test]
    fn test_event_filter_by_type() {
        let filter = EventFilter {
            event_types: HashSet::from(["agent_started".to_string()]),
            ..Default::default()
        };
        let completed = LifecycleEvent::AgentCompleted {
            agent_id: "a1".to_string(),
            iterations: 5,
            summary: "done".to_string(),
        };
        assert!(filter.matches(&started("a1")));
        assert!(!filter.matches(&completed));
    }

    #[test]
    fn test_event_filter_by_agent_id() {
        let filter = EventFilter {
            agent_ids: HashSet::from(["a1".to_string()]),
            ..Default::default()
        };
        assert!(filter.matches(&started("a1")));
        assert!(!filter.matches(&started("a2")));
        // An event without an agent id cannot satisfy an agent-id filter.
        let anonymous = LifecycleEvent::ProviderRequest {
            agent_id: None,
            provider: "p".to_string(),
            model: "m".to_string(),
        };
        assert!(!filter.matches(&anonymous));
    }

    #[test]
    fn test_event_filter_by_tool_name_only_constrains_tool_events() {
        let filter = EventFilter {
            tool_names: HashSet::from(["read_file".to_string()]),
            ..Default::default()
        };
        let tool_event = |tool: &str| LifecycleEvent::ToolAfterExecute {
            agent_id: None,
            tool_name: tool.to_string(),
            success: true,
            duration_ms: 1,
        };
        assert!(filter.matches(&tool_event("read_file")));
        assert!(!filter.matches(&tool_event("write_file")));
        // Non-tool events carry no tool name and are let through.
        assert!(filter.matches(&started("a1")));
    }

    #[test]
    fn test_event_type_names() {
        let event = LifecycleEvent::ToolBeforeExecute {
            agent_id: Some("a1".to_string()),
            tool_name: "read_file".to_string(),
            args: serde_json::json!({}),
        };
        assert_eq!(event.event_type(), "tool_before_execute");
        assert_eq!(event.agent_id(), Some("a1"));
        assert_eq!(event.tool_name(), Some("read_file"));
    }

    #[test]
    fn test_registry_orders_by_priority_then_registration() {
        let mut registry = HookRegistry::new();
        registry.register(ScriptedHook::new("late", 5, HookResult::Continue));
        registry.register(ScriptedHook::new("first_mid", 1, HookResult::Continue));
        registry.register_arc(Arc::new(ScriptedHook::new(
            "second_mid",
            1,
            HookResult::Continue,
        )));
        registry.register(ScriptedHook::new("early", -3, HookResult::Continue));
        assert_eq!(
            registered_names(&registry),
            ["early", "first_mid", "second_mid", "late"]
        );
    }

    #[test]
    fn test_dispatch_returns_first_non_continue_result() {
        let skipped = Arc::new(ScriptedHook {
            filter: Some(EventFilter {
                event_types: HashSet::from(["agent_failed".to_string()]),
                ..Default::default()
            }),
            ..ScriptedHook::new(
                "skipped",
                0,
                HookResult::Cancel {
                    reason: "filtered".to_string(),
                },
            )
        });
        let passes = Arc::new(ScriptedHook::new("passes", 1, HookResult::Continue));
        let cancels = Arc::new(ScriptedHook::new(
            "cancels",
            2,
            HookResult::Cancel {
                reason: "stop".to_string(),
            },
        ));
        let never = Arc::new(ScriptedHook::new(
            "never",
            3,
            HookResult::Modified(serde_json::json!({})),
        ));

        let mut registry = HookRegistry::new();
        for hook in [&skipped, &never, &cancels, &passes] {
            registry.register_arc(hook.clone());
        }

        let result = poll_once(registry.dispatch(&started("a1")));
        assert!(matches!(&result, HookResult::Cancel { reason } if reason == "stop"));
        assert_eq!(skipped.call_count(), 0);
        assert_eq!(passes.call_count(), 1);
        assert_eq!(cancels.call_count(), 1);
        assert_eq!(never.call_count(), 0);
    }

    #[test]
    fn test_dispatch_on_empty_registry_continues() {
        let registry = HookRegistry::default();
        let result = poll_once(registry.dispatch(&started("a1")));
        assert!(matches!(result, HookResult::Continue));
    }
}