Skip to main content

ai_agents_hooks/
lib.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use std::sync::Arc;
4use std::time::Instant;
5use tracing::{debug, error, info, warn};
6
7use ai_agents_core::{AgentError, AgentResponse};
8use ai_agents_hitl::{ApprovalRequest, ApprovalResult};
9use ai_agents_llm::{ChatMessage, LLMResponse};
10use ai_agents_memory::{MemoryBudgetEvent, MemoryCompressEvent, MemoryEvictEvent};
11use ai_agents_tools::ToolResult;
12
13#[async_trait]
14pub trait AgentHooks: Send + Sync {
15    async fn on_message_received(&self, _message: &str) {}
16
17    async fn on_llm_start(&self, _messages: &[ChatMessage]) {}
18
19    async fn on_llm_complete(&self, _response: &LLMResponse, _duration_ms: u64) {}
20
21    async fn on_tool_start(&self, _tool: &str, _args: &Value) {}
22
23    async fn on_tool_complete(&self, _tool: &str, _result: &ToolResult, _duration_ms: u64) {}
24
25    async fn on_state_transition(&self, _from: Option<&str>, _to: &str, _reason: &str) {}
26
27    async fn on_error(&self, _error: &AgentError) {}
28
29    async fn on_response(&self, _response: &AgentResponse) {}
30
31    async fn on_approval_requested(&self, _request: &ApprovalRequest) {}
32
33    async fn on_approval_result(&self, _request_id: &str, _result: &ApprovalResult) {}
34
35    async fn on_memory_compress(&self, _event: &MemoryCompressEvent) {}
36
37    async fn on_memory_evict(&self, _event: &MemoryEvictEvent) {}
38
39    async fn on_memory_budget_warning(&self, _event: &MemoryBudgetEvent) {}
40
41    /// Fired when a delegated state starts forwarding to a registry agent.
42    async fn on_delegate_start(&self, _agent_id: &str, _state: &str) {}
43
44    /// Fired when a delegated state completes.
45    async fn on_delegate_complete(&self, _agent_id: &str, _state: &str, _duration_ms: u64) {}
46
47    /// Fired when a concurrent state completes aggregation.
48    async fn on_concurrent_complete(
49        &self,
50        _agent_ids: &[String],
51        _strategy: &str,
52        _duration_ms: u64,
53    ) {
54    }
55
56    /// Fired when a group chat round completes.
57    async fn on_group_chat_round(&self, _round: u32, _speaker: &str, _content: &str) {}
58
59    /// Fired after each pipeline stage completes.
60    async fn on_pipeline_stage(&self, _stage: usize, _agent_id: &str, _duration_ms: u64) {}
61
62    /// Fired when a pipeline completes all stages.
63    async fn on_pipeline_complete(&self, _stages: usize, _duration_ms: u64) {}
64
65    /// Fired when a handoff chain starts.
66    async fn on_handoff_start(&self, _initial_agent: &str) {}
67
68    /// Fired on each agent-to-agent control transfer.
69    async fn on_handoff(&self, _from: &str, _to: &str, _reason: &str) {}
70
71    /// Fired when a persona field is mutated via evolve().
72    async fn on_persona_evolve(
73        &self,
74        _field: &str,
75        _old_value: &Value,
76        _new_value: &Value,
77        _reason: Option<&str>,
78    ) {
79    }
80
81    /// Fired when a secret's reveal conditions are satisfied for the first time.
82    async fn on_secret_revealed(&self, _content: &str) {}
83}
84
85pub struct NoopHooks;
86
87#[async_trait]
88impl AgentHooks for NoopHooks {}
89
/// An [`AgentHooks`] implementation that reports events through `tracing`.
///
/// Every log line starts with a configurable prefix (`"[Agent]"` by default).
#[derive(Debug, Clone)]
pub struct LoggingHooks {
    /// Text prepended to every log line.
    prefix: String,
}

impl LoggingHooks {
    /// Creates logging hooks with the default `"[Agent]"` prefix.
    pub fn new() -> Self {
        Self::with_prefix("[Agent]")
    }

    /// Creates logging hooks whose log lines start with `prefix`.
    pub fn with_prefix(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
        }
    }
}

impl Default for LoggingHooks {
    /// Equivalent to [`LoggingHooks::new`].
    fn default() -> Self {
        Self::new()
    }
}
113
114#[async_trait]
115impl AgentHooks for LoggingHooks {
116    async fn on_message_received(&self, message: &str) {
117        let preview = if message.len() > 100 {
118            format!("{}...", &message[..100])
119        } else {
120            message.to_string()
121        };
122        info!("{} Message received: {}", self.prefix, preview);
123    }
124
125    async fn on_llm_start(&self, messages: &[ChatMessage]) {
126        debug!(
127            "{} LLM starting with {} messages",
128            self.prefix,
129            messages.len()
130        );
131    }
132
133    async fn on_llm_complete(&self, response: &LLMResponse, duration_ms: u64) {
134        info!(
135            "{} LLM complete in {}ms, tokens: {:?}",
136            self.prefix, duration_ms, response.usage
137        );
138    }
139
140    async fn on_tool_start(&self, tool: &str, args: &Value) {
141        debug!("{} Tool {} starting with args: {}", self.prefix, tool, args);
142    }
143
144    async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
145        if result.success {
146            info!(
147                "{} Tool {} completed in {}ms",
148                self.prefix, tool, duration_ms
149            );
150        } else {
151            warn!(
152                "{} Tool {} failed in {}ms: {}",
153                self.prefix, tool, duration_ms, result.output
154            );
155        }
156    }
157
158    async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
159        info!(
160            "{} State transition: {:?} -> {} ({})",
161            self.prefix, from, to, reason
162        );
163    }
164
165    async fn on_error(&self, err: &AgentError) {
166        error!("{} Error: {}", self.prefix, err);
167    }
168
169    async fn on_response(&self, response: &AgentResponse) {
170        let preview = if response.content.len() > 100 {
171            format!("{}...", &response.content[..100])
172        } else {
173            response.content.clone()
174        };
175        debug!("{} Response: {}", self.prefix, preview);
176    }
177
178    async fn on_approval_requested(&self, request: &ApprovalRequest) {
179        info!(
180            "{} Approval requested [{}]: {}",
181            self.prefix, request.id, request.message
182        );
183    }
184
185    async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
186        match result {
187            ApprovalResult::Approved => {
188                info!("{} Approval [{}]: approved", self.prefix, request_id);
189            }
190            ApprovalResult::Rejected { reason } => {
191                warn!(
192                    "{} Approval [{}]: rejected ({:?})",
193                    self.prefix, request_id, reason
194                );
195            }
196            ApprovalResult::Modified { .. } => {
197                info!(
198                    "{} Approval [{}]: approved with modifications",
199                    self.prefix, request_id
200                );
201            }
202            ApprovalResult::Timeout => {
203                warn!("{} Approval [{}]: timeout", self.prefix, request_id);
204            }
205        }
206    }
207
208    async fn on_memory_compress(&self, event: &MemoryCompressEvent) {
209        info!(
210            "{} Memory compressed: {} messages, ratio: {:.2}",
211            self.prefix, event.messages_compressed, event.compression_ratio
212        );
213    }
214
215    async fn on_memory_evict(&self, event: &MemoryEvictEvent) {
216        warn!(
217            "{} Memory evicted: {} messages, reason: {:?}",
218            self.prefix, event.messages_evicted, event.reason
219        );
220    }
221
222    async fn on_memory_budget_warning(&self, event: &MemoryBudgetEvent) {
223        warn!(
224            "{} Memory budget warning: {} at {:.1}% ({}/{} tokens)",
225            self.prefix,
226            event.component,
227            event.usage_percent,
228            event.used_tokens,
229            event.budget_tokens
230        );
231    }
232
233    async fn on_delegate_start(&self, agent_id: &str, state: &str) {
234        info!(
235            "{} Delegation started: agent={}, state={}",
236            self.prefix, agent_id, state
237        );
238    }
239
240    async fn on_delegate_complete(&self, agent_id: &str, state: &str, duration_ms: u64) {
241        info!(
242            "{} Delegation complete: agent={}, state={}, duration={}ms",
243            self.prefix, agent_id, state, duration_ms
244        );
245    }
246
247    async fn on_concurrent_complete(&self, agent_ids: &[String], strategy: &str, duration_ms: u64) {
248        info!(
249            "{} Concurrent complete: agents={:?}, strategy={}, duration={}ms",
250            self.prefix, agent_ids, strategy, duration_ms
251        );
252    }
253
254    async fn on_group_chat_round(&self, round: u32, speaker: &str, content: &str) {
255        let preview = if content.len() > 80 {
256            format!("{}...", &content[..80])
257        } else {
258            content.to_string()
259        };
260        debug!(
261            "{} Group chat round {}: {} said: {}",
262            self.prefix, round, speaker, preview
263        );
264    }
265
266    async fn on_pipeline_stage(&self, stage: usize, agent_id: &str, duration_ms: u64) {
267        info!(
268            "{} Pipeline stage {}: agent={}, duration={}ms",
269            self.prefix, stage, agent_id, duration_ms
270        );
271    }
272
273    async fn on_pipeline_complete(&self, stages: usize, duration_ms: u64) {
274        info!(
275            "{} Pipeline complete: {} stages, duration={}ms",
276            self.prefix, stages, duration_ms
277        );
278    }
279
280    async fn on_handoff_start(&self, initial_agent: &str) {
281        info!(
282            "{} Handoff chain started: initial_agent={}",
283            self.prefix, initial_agent
284        );
285    }
286
287    async fn on_handoff(&self, from: &str, to: &str, reason: &str) {
288        info!("{} Handoff: {} -> {} ({})", self.prefix, from, to, reason);
289    }
290
291    async fn on_persona_evolve(
292        &self,
293        field: &str,
294        _old_value: &Value,
295        new_value: &Value,
296        reason: Option<&str>,
297    ) {
298        info!(
299            "{} Persona evolved: field={}, new_value={}, reason={}",
300            self.prefix,
301            field,
302            new_value,
303            reason.unwrap_or("(none)")
304        );
305    }
306
307    async fn on_secret_revealed(&self, content: &str) {
308        info!("{} Secret revealed: {}", self.prefix, content);
309    }
310}
311
312pub struct CompositeHooks {
313    hooks: Vec<Arc<dyn AgentHooks>>,
314}
315
316impl CompositeHooks {
317    pub fn new() -> Self {
318        Self { hooks: Vec::new() }
319    }
320
321    pub fn add(mut self, hooks: Arc<dyn AgentHooks>) -> Self {
322        self.hooks.push(hooks);
323        self
324    }
325
326    pub fn with_hooks(hooks: Vec<Arc<dyn AgentHooks>>) -> Self {
327        Self { hooks }
328    }
329}
330
331impl Default for CompositeHooks {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337#[async_trait]
338impl AgentHooks for CompositeHooks {
339    async fn on_message_received(&self, message: &str) {
340        for hook in &self.hooks {
341            hook.on_message_received(message).await;
342        }
343    }
344
345    async fn on_llm_start(&self, messages: &[ChatMessage]) {
346        for hook in &self.hooks {
347            hook.on_llm_start(messages).await;
348        }
349    }
350
351    async fn on_llm_complete(&self, response: &LLMResponse, duration_ms: u64) {
352        for hook in &self.hooks {
353            hook.on_llm_complete(response, duration_ms).await;
354        }
355    }
356
357    async fn on_tool_start(&self, tool: &str, args: &Value) {
358        for hook in &self.hooks {
359            hook.on_tool_start(tool, args).await;
360        }
361    }
362
363    async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
364        for hook in &self.hooks {
365            hook.on_tool_complete(tool, result, duration_ms).await;
366        }
367    }
368
369    async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
370        for hook in &self.hooks {
371            hook.on_state_transition(from, to, reason).await;
372        }
373    }
374
375    async fn on_error(&self, error: &AgentError) {
376        for hook in &self.hooks {
377            hook.on_error(error).await;
378        }
379    }
380
381    async fn on_response(&self, response: &AgentResponse) {
382        for hook in &self.hooks {
383            hook.on_response(response).await;
384        }
385    }
386
387    async fn on_approval_requested(&self, request: &ApprovalRequest) {
388        for hook in &self.hooks {
389            hook.on_approval_requested(request).await;
390        }
391    }
392
393    async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
394        for hook in &self.hooks {
395            hook.on_approval_result(request_id, result).await;
396        }
397    }
398
399    async fn on_memory_compress(&self, event: &MemoryCompressEvent) {
400        for hook in &self.hooks {
401            hook.on_memory_compress(event).await;
402        }
403    }
404
405    async fn on_memory_evict(&self, event: &MemoryEvictEvent) {
406        for hook in &self.hooks {
407            hook.on_memory_evict(event).await;
408        }
409    }
410
411    async fn on_memory_budget_warning(&self, event: &MemoryBudgetEvent) {
412        for hook in &self.hooks {
413            hook.on_memory_budget_warning(event).await;
414        }
415    }
416
417    async fn on_delegate_start(&self, agent_id: &str, state: &str) {
418        for hook in &self.hooks {
419            hook.on_delegate_start(agent_id, state).await;
420        }
421    }
422
423    async fn on_delegate_complete(&self, agent_id: &str, state: &str, duration_ms: u64) {
424        for hook in &self.hooks {
425            hook.on_delegate_complete(agent_id, state, duration_ms)
426                .await;
427        }
428    }
429
430    async fn on_concurrent_complete(&self, agent_ids: &[String], strategy: &str, duration_ms: u64) {
431        for hook in &self.hooks {
432            hook.on_concurrent_complete(agent_ids, strategy, duration_ms)
433                .await;
434        }
435    }
436
437    async fn on_group_chat_round(&self, round: u32, speaker: &str, content: &str) {
438        for hook in &self.hooks {
439            hook.on_group_chat_round(round, speaker, content).await;
440        }
441    }
442
443    async fn on_pipeline_stage(&self, stage: usize, agent_id: &str, duration_ms: u64) {
444        for hook in &self.hooks {
445            hook.on_pipeline_stage(stage, agent_id, duration_ms).await;
446        }
447    }
448
449    async fn on_pipeline_complete(&self, stages: usize, duration_ms: u64) {
450        for hook in &self.hooks {
451            hook.on_pipeline_complete(stages, duration_ms).await;
452        }
453    }
454
455    async fn on_handoff_start(&self, initial_agent: &str) {
456        for hook in &self.hooks {
457            hook.on_handoff_start(initial_agent).await;
458        }
459    }
460
461    async fn on_handoff(&self, _from: &str, _to: &str, _reason: &str) {
462        for hook in &self.hooks {
463            hook.on_handoff(_from, _to, _reason).await;
464        }
465    }
466
467    async fn on_persona_evolve(
468        &self,
469        _field: &str,
470        _old_value: &Value,
471        _new_value: &Value,
472        _reason: Option<&str>,
473    ) {
474        for hook in &self.hooks {
475            hook.on_persona_evolve(_field, _old_value, _new_value, _reason)
476                .await;
477        }
478    }
479
480    async fn on_secret_revealed(&self, _content: &str) {
481        for hook in &self.hooks {
482            hook.on_secret_revealed(_content).await;
483        }
484    }
485}
486
/// Simple wall-clock stopwatch for measuring durations passed to hooks.
#[derive(Debug, Clone, Copy)]
pub struct HookTimer {
    start: Instant,
}

impl HookTimer {
    /// Starts a timer at the current instant.
    #[must_use]
    pub fn start() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Returns the whole milliseconds elapsed since [`HookTimer::start`].
    ///
    /// [`std::time::Duration::as_millis`] returns a `u128`. The value is
    /// converted with saturation at `u64::MAX` rather than a silently
    /// truncating `as` cast.
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }
}
502
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Test double that appends a short textual tag for every event it observes.
    struct RecordingHooks {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingHooks {
        fn new() -> Self {
            let events = Arc::new(Mutex::new(Vec::new()));
            Self { events }
        }

        /// Returns a copy of every tag recorded so far, oldest first.
        fn events(&self) -> Vec<String> {
            self.events.lock().iter().cloned().collect()
        }

        /// Appends one tag; the lock is released before this returns.
        fn record(&self, tag: String) {
            self.events.lock().push(tag);
        }
    }

    #[async_trait]
    impl AgentHooks for RecordingHooks {
        async fn on_message_received(&self, message: &str) {
            self.record(format!("message_received:{}", message));
        }

        async fn on_llm_start(&self, messages: &[ChatMessage]) {
            self.record(format!("llm_start:{}", messages.len()));
        }

        async fn on_llm_complete(&self, _response: &LLMResponse, duration_ms: u64) {
            self.record(format!("llm_complete:{}", duration_ms));
        }

        async fn on_tool_start(&self, tool: &str, _args: &Value) {
            self.record(format!("tool_start:{}", tool));
        }

        async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
            self.record(format!(
                "tool_complete:{}:{}:{}",
                tool, result.success, duration_ms
            ));
        }

        async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
            self.record(format!("state_transition:{:?}:{}:{}", from, to, reason));
        }

        async fn on_error(&self, error: &AgentError) {
            self.record(format!("error:{}", error));
        }

        async fn on_response(&self, response: &AgentResponse) {
            self.record(format!("response:{}", response.content.len()));
        }

        async fn on_approval_requested(&self, request: &ApprovalRequest) {
            self.record(format!("approval_requested:{}", request.id));
        }

        async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
            let status = match result {
                ApprovalResult::Approved => "approved",
                ApprovalResult::Rejected { .. } => "rejected",
                ApprovalResult::Modified { .. } => "modified",
                ApprovalResult::Timeout => "timeout",
            };
            self.record(format!("approval_result:{}:{}", request_id, status));
        }
    }

    #[tokio::test]
    async fn test_noop_hooks() {
        let noop = NoopHooks;
        noop.on_message_received("test").await;
        noop.on_llm_start(&[]).await;
    }

    #[tokio::test]
    async fn test_logging_hooks() {
        let logger = LoggingHooks::new();
        logger.on_message_received("test message").await;
        logger.on_llm_start(&[ChatMessage::user("hello")]).await;
    }

    #[tokio::test]
    async fn test_recording_hooks() {
        let recorder = RecordingHooks::new();
        recorder.on_message_received("hello").await;
        recorder.on_llm_start(&[ChatMessage::user("test")]).await;

        let recorded = recorder.events();
        assert_eq!(recorded.len(), 2);
        assert!(recorded[0].contains("message_received"));
        assert!(recorded[1].contains("llm_start"));
    }

    #[tokio::test]
    async fn test_composite_hooks_with_vec() {
        let first = Arc::new(RecordingHooks::new());
        let second = Arc::new(RecordingHooks::new());
        let members: Vec<Arc<dyn AgentHooks>> = vec![
            Arc::clone(&first) as Arc<dyn AgentHooks>,
            Arc::clone(&second) as Arc<dyn AgentHooks>,
        ];
        let composite = CompositeHooks::with_hooks(members);

        let no_args = serde_json::json!({});
        composite.on_tool_start("calculator", &no_args).await;

        for recorder in [&first, &second] {
            assert_eq!(recorder.events().len(), 1);
        }
    }

    #[tokio::test]
    async fn test_hook_timer() {
        let timer = HookTimer::start();
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert!(timer.elapsed_ms() >= 10);
    }

    #[test]
    fn test_composite_hooks_default() {
        let composite = CompositeHooks::default();
        assert!(composite.hooks.is_empty());
    }
}