use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
/// An event delivered to lifecycle hooks through [`HookRegistry::dispatch`].
///
/// Each variant has a fixed snake_case tag returned by [`LifecycleEvent::event_type`];
/// that tag is what [`EventFilter::event_types`] is compared against. Agent and
/// validation variants always carry an `agent_id`; tool and provider variants carry
/// an optional one (see [`LifecycleEvent::agent_id`]).
#[derive(Debug, Clone)]
pub enum LifecycleEvent {
    /// Tag: `"agent_started"`.
    AgentStarted {
        agent_id: String,
        task_description: String,
    },
    /// Tag: `"agent_completed"`.
    AgentCompleted {
        agent_id: String,
        iterations: u32,
        summary: String,
    },
    /// Tag: `"agent_failed"`.
    AgentFailed {
        agent_id: String,
        error: String,
        iterations: u32,
    },
    /// Tag: `"tool_before_execute"`. `args` is an arbitrary JSON value; its schema
    /// is not constrained by this type.
    ToolBeforeExecute {
        agent_id: Option<String>,
        tool_name: String,
        args: Value,
    },
    /// Tag: `"tool_after_execute"`.
    ToolAfterExecute {
        agent_id: Option<String>,
        tool_name: String,
        success: bool,
        // NOTE(review): presumably milliseconds per the `_ms` suffix; not enforced here.
        duration_ms: u64,
    },
    /// Tag: `"provider_request"`.
    ProviderRequest {
        agent_id: Option<String>,
        provider: String,
        model: String,
    },
    /// Tag: `"provider_response"`.
    ProviderResponse {
        agent_id: Option<String>,
        provider: String,
        model: String,
        input_tokens: u64,
        output_tokens: u64,
        // NOTE(review): presumably milliseconds per the `_ms` suffix; not enforced here.
        duration_ms: u64,
    },
    /// Tag: `"validation_started"`.
    ValidationStarted {
        agent_id: String,
        checks: Vec<String>,
    },
    /// Tag: `"validation_completed"`.
    ValidationCompleted {
        agent_id: String,
        passed: bool,
        issues: Vec<String>,
    },
}
impl LifecycleEvent {
    /// Returns the fixed snake_case tag identifying this event's variant.
    ///
    /// These tags are the values [`EventFilter::event_types`] is compared against.
    pub fn event_type(&self) -> &'static str {
        match self {
            // Agent lifecycle.
            Self::AgentStarted { .. } => "agent_started",
            Self::AgentCompleted { .. } => "agent_completed",
            Self::AgentFailed { .. } => "agent_failed",
            // Tool execution.
            Self::ToolBeforeExecute { .. } => "tool_before_execute",
            Self::ToolAfterExecute { .. } => "tool_after_execute",
            // Provider round-trips.
            Self::ProviderRequest { .. } => "provider_request",
            Self::ProviderResponse { .. } => "provider_response",
            // Validation.
            Self::ValidationStarted { .. } => "validation_started",
            Self::ValidationCompleted { .. } => "validation_completed",
        }
    }

    /// Returns the id of the agent associated with this event, if any.
    ///
    /// Agent and validation events always have one; tool and provider events
    /// return `None` when their optional `agent_id` is unset.
    pub fn agent_id(&self) -> Option<&str> {
        match self {
            Self::ToolBeforeExecute { agent_id: maybe_id, .. }
            | Self::ToolAfterExecute { agent_id: maybe_id, .. }
            | Self::ProviderRequest { agent_id: maybe_id, .. }
            | Self::ProviderResponse { agent_id: maybe_id, .. } => maybe_id.as_deref(),
            Self::AgentStarted { agent_id: id, .. }
            | Self::AgentCompleted { agent_id: id, .. }
            | Self::AgentFailed { agent_id: id, .. }
            | Self::ValidationStarted { agent_id: id, .. }
            | Self::ValidationCompleted { agent_id: id, .. } => Some(id.as_str()),
        }
    }

    /// Returns the tool name for tool events, or `None` for every other variant.
    pub fn tool_name(&self) -> Option<&str> {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::ToolBeforeExecute { tool_name: name, .. }
            | Self::ToolAfterExecute { tool_name: name, .. } => Some(name.as_str()),
            Self::AgentStarted { .. }
            | Self::AgentCompleted { .. }
            | Self::AgentFailed { .. }
            | Self::ProviderRequest { .. }
            | Self::ProviderResponse { .. }
            | Self::ValidationStarted { .. }
            | Self::ValidationCompleted { .. } => None,
        }
    }
}
/// Outcome returned by [`LifecycleHook::on_event`].
///
/// [`HookRegistry::dispatch`] moves on to the next matching hook on `Continue`.
/// On `Cancel` or `Modified` it stops immediately and returns that result to its caller.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Let dispatch proceed to the next matching hook.
    Continue,
    /// Stop dispatch; the result (with `reason`) is returned to the caller of `dispatch`.
    Cancel {
        reason: String,
    },
    /// Stop dispatch and return this JSON value to the caller of `dispatch`.
    // NOTE(review): how the caller applies the value is not shown in this file —
    // confirm at the call site.
    Modified(Value),
}
/// Restricts which [`LifecycleEvent`]s a hook receives.
///
/// Each set is an independent constraint, and an empty set places no restriction on
/// that dimension. An event matches only if it satisfies every non-empty set (see
/// [`EventFilter::matches`]). `Default` yields a filter that matches every event.
#[derive(Debug, Clone, Default)]
pub struct EventFilter {
    /// Allowed agent ids. When non-empty, events without an agent id are rejected.
    pub agent_ids: HashSet<String>,
    /// Allowed tags, compared against [`LifecycleEvent::event_type`].
    pub event_types: HashSet<String>,
    /// Allowed tool names. When non-empty, only events that carry a tool name are
    /// checked; events without a tool name are NOT rejected by this set.
    pub tool_names: HashSet<String>,
}
impl EventFilter {
pub fn matches(&self, event: &LifecycleEvent) -> bool {
if !self.event_types.is_empty() && !self.event_types.contains(event.event_type()) {
return false;
}
if !self.agent_ids.is_empty() {
if let Some(id) = event.agent_id() {
if !self.agent_ids.contains(id) {
return false;
}
} else {
return false;
}
}
if !self.tool_names.is_empty()
&& let Some(name) = event.tool_name()
&& !self.tool_names.contains(name)
{
return false;
}
true
}
}
/// A hook that observes [`LifecycleEvent`]s and can stop further dispatch.
///
/// Implementations must be `Send + Sync`; [`HookRegistry`] stores them as
/// `Arc<dyn LifecycleHook>`.
#[async_trait::async_trait]
pub trait LifecycleHook: Send + Sync {
    /// Name of this hook.
    fn name(&self) -> &str;
    /// Ordering key. [`HookRegistry`] keeps hooks sorted by ascending priority, so
    /// lower values run first. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
    /// Optional event filter. `None` (the default) means the hook receives every event.
    /// [`HookRegistry::dispatch`] calls this once per hook per dispatch, so keep it cheap.
    fn filter(&self) -> Option<EventFilter> {
        None
    }
    /// Handles one event. Return [`HookResult::Continue`] to let later hooks run.
    /// Any other result stops dispatch and is returned to the dispatcher's caller.
    async fn on_event(&self, event: &LifecycleEvent) -> HookResult;
}
/// Priority-ordered collection of [`LifecycleHook`]s, run by [`HookRegistry::dispatch`].
pub struct HookRegistry {
    // Kept sorted by ascending `priority()` by `register` / `register_arc`.
    hooks: Vec<Arc<dyn LifecycleHook>>,
}
impl HookRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Registers an owned hook.
    ///
    /// Equivalent to calling [`HookRegistry::register_arc`] with `Arc::new(hook)`.
    pub fn register(&mut self, hook: impl LifecycleHook + 'static) {
        self.register_arc(Arc::new(hook));
    }

    /// Registers an already-shared hook.
    ///
    /// Hooks are kept ordered by ascending [`LifecycleHook::priority`], so lower values
    /// run first in [`HookRegistry::dispatch`]. Hooks with equal priority keep their
    /// registration order. Assumes `priority()` returns a constant for each hook.
    pub fn register_arc(&mut self, hook: Arc<dyn LifecycleHook>) {
        let priority = hook.priority();
        // `hooks` is always sorted, so inserting after the last hook whose priority is
        // <= ours yields the same order as push + stable sort. This avoids re-sorting
        // (and re-querying every hook's priority) on each registration.
        let index = self
            .hooks
            .partition_point(|existing| existing.priority() <= priority);
        self.hooks.insert(index, hook);
    }

    /// Delivers `event` to every hook whose filter accepts it, in priority order.
    ///
    /// Hooks without a filter receive every event. Dispatch stops at the first hook
    /// that returns [`HookResult::Cancel`] or [`HookResult::Modified`], and that result
    /// is returned. Later hooks are not run. Returns [`HookResult::Continue`] when
    /// every matching hook continued, or when no hook matched.
    pub async fn dispatch(&self, event: &LifecycleEvent) -> HookResult {
        for hook in &self.hooks {
            // `filter()` is re-evaluated per dispatch; `None` means "accept everything".
            if hook.filter().is_some_and(|filter| !filter.matches(event)) {
                continue;
            }
            match hook.on_event(event).await {
                HookResult::Continue => {}
                result @ (HookResult::Cancel { .. } | HookResult::Modified(_)) => {
                    return result;
                }
            }
        }
        HookResult::Continue
    }

    /// Number of registered hooks.
    pub fn len(&self) -> usize {
        self.hooks.len()
    }

    /// Returns `true` if no hooks are registered.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }
}
impl Default for HookRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook that always continues; used for registration tests.
    struct CountingHook {
        name: String,
    }

    #[async_trait::async_trait]
    impl LifecycleHook for CountingHook {
        fn name(&self) -> &str {
            &self.name
        }
        async fn on_event(&self, _event: &LifecycleEvent) -> HookResult {
            HookResult::Continue
        }
    }

    /// Hook with an explicit priority; used to check registry ordering.
    struct PriorityHook {
        name: &'static str,
        priority: i32,
    }

    #[async_trait::async_trait]
    impl LifecycleHook for PriorityHook {
        fn name(&self) -> &str {
            self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn on_event(&self, _event: &LifecycleEvent) -> HookResult {
            HookResult::Continue
        }
    }

    #[test]
    fn test_registry_register() {
        let mut registry = HookRegistry::new();
        assert!(registry.is_empty());
        registry.register(CountingHook {
            name: "test".to_string(),
        });
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_orders_by_ascending_priority_stably() {
        let mut registry = HookRegistry::new();
        registry.register(PriorityHook { name: "late", priority: 10 });
        registry.register(PriorityHook { name: "early", priority: -5 });
        registry.register(PriorityHook { name: "tie_first", priority: 0 });
        registry.register_arc(Arc::new(PriorityHook {
            name: "tie_second",
            priority: 0,
        }));
        let order: Vec<&str> = registry.hooks.iter().map(|h| h.name()).collect();
        assert_eq!(order, ["early", "tie_first", "tie_second", "late"]);
    }

    #[test]
    fn test_event_filter_matches_all() {
        let filter = EventFilter::default();
        let event = LifecycleEvent::AgentStarted {
            agent_id: "a1".to_string(),
            task_description: "test".to_string(),
        };
        assert!(filter.matches(&event));
    }

    #[test]
    fn test_event_filter_by_type() {
        let filter = EventFilter {
            event_types: HashSet::from(["agent_started".to_string()]),
            ..Default::default()
        };
        let started = LifecycleEvent::AgentStarted {
            agent_id: "a1".to_string(),
            task_description: "test".to_string(),
        };
        let completed = LifecycleEvent::AgentCompleted {
            agent_id: "a1".to_string(),
            iterations: 5,
            summary: "done".to_string(),
        };
        assert!(filter.matches(&started));
        assert!(!filter.matches(&completed));
    }

    #[test]
    fn test_event_filter_by_agent_rejects_missing_agent() {
        let filter = EventFilter {
            agent_ids: HashSet::from(["a1".to_string()]),
            ..Default::default()
        };
        let mine = LifecycleEvent::AgentStarted {
            agent_id: "a1".to_string(),
            task_description: "test".to_string(),
        };
        let other = LifecycleEvent::AgentStarted {
            agent_id: "a2".to_string(),
            task_description: "test".to_string(),
        };
        let anonymous = LifecycleEvent::ProviderRequest {
            agent_id: None,
            provider: "p".to_string(),
            model: "m".to_string(),
        };
        assert!(filter.matches(&mine));
        assert!(!filter.matches(&other));
        assert!(!filter.matches(&anonymous));
    }

    #[test]
    fn test_event_filter_by_tool_ignores_non_tool_events() {
        let filter = EventFilter {
            tool_names: HashSet::from(["read_file".to_string()]),
            ..Default::default()
        };
        let read = LifecycleEvent::ToolAfterExecute {
            agent_id: None,
            tool_name: "read_file".to_string(),
            success: true,
            duration_ms: 3,
        };
        let write = LifecycleEvent::ToolBeforeExecute {
            agent_id: None,
            tool_name: "write_file".to_string(),
            args: serde_json::json!({}),
        };
        let started = LifecycleEvent::AgentStarted {
            agent_id: "a1".to_string(),
            task_description: "test".to_string(),
        };
        assert!(filter.matches(&read));
        assert!(!filter.matches(&write));
        // Events without a tool name are not constrained by `tool_names`.
        assert!(filter.matches(&started));
    }

    #[test]
    fn test_event_type_names() {
        let event = LifecycleEvent::ToolBeforeExecute {
            agent_id: Some("a1".to_string()),
            tool_name: "read_file".to_string(),
            args: serde_json::json!({}),
        };
        assert_eq!(event.event_type(), "tool_before_execute");
        assert_eq!(event.agent_id(), Some("a1"));
        assert_eq!(event.tool_name(), Some("read_file"));
    }
}