1use async_trait::async_trait;
2use serde_json::Value;
3use std::sync::Arc;
4use std::time::Instant;
5use tracing::{debug, error, info, warn};
6
7use ai_agents_core::{AgentError, AgentResponse};
8use ai_agents_hitl::{ApprovalRequest, ApprovalResult};
9use ai_agents_llm::{ChatMessage, LLMResponse};
10use ai_agents_memory::{MemoryBudgetEvent, MemoryCompressEvent, MemoryEvictEvent};
11use ai_agents_tools::ToolResult;
12
13#[async_trait]
14pub trait AgentHooks: Send + Sync {
15 async fn on_message_received(&self, _message: &str) {}
16
17 async fn on_llm_start(&self, _messages: &[ChatMessage]) {}
18
19 async fn on_llm_complete(&self, _response: &LLMResponse, _duration_ms: u64) {}
20
21 async fn on_tool_start(&self, _tool: &str, _args: &Value) {}
22
23 async fn on_tool_complete(&self, _tool: &str, _result: &ToolResult, _duration_ms: u64) {}
24
25 async fn on_state_transition(&self, _from: Option<&str>, _to: &str, _reason: &str) {}
26
27 async fn on_error(&self, _error: &AgentError) {}
28
29 async fn on_response(&self, _response: &AgentResponse) {}
30
31 async fn on_approval_requested(&self, _request: &ApprovalRequest) {}
32
33 async fn on_approval_result(&self, _request_id: &str, _result: &ApprovalResult) {}
34
35 async fn on_memory_compress(&self, _event: &MemoryCompressEvent) {}
36
37 async fn on_memory_evict(&self, _event: &MemoryEvictEvent) {}
38
39 async fn on_memory_budget_warning(&self, _event: &MemoryBudgetEvent) {}
40
41 async fn on_delegate_start(&self, _agent_id: &str, _state: &str) {}
43
44 async fn on_delegate_complete(&self, _agent_id: &str, _state: &str, _duration_ms: u64) {}
46
47 async fn on_concurrent_complete(
49 &self,
50 _agent_ids: &[String],
51 _strategy: &str,
52 _duration_ms: u64,
53 ) {
54 }
55
56 async fn on_group_chat_round(&self, _round: u32, _speaker: &str, _content: &str) {}
58
59 async fn on_pipeline_stage(&self, _stage: usize, _agent_id: &str, _duration_ms: u64) {}
61
62 async fn on_pipeline_complete(&self, _stages: usize, _duration_ms: u64) {}
64
65 async fn on_handoff_start(&self, _initial_agent: &str) {}
67
68 async fn on_handoff(&self, _from: &str, _to: &str, _reason: &str) {}
70
71 async fn on_persona_evolve(
73 &self,
74 _field: &str,
75 _old_value: &Value,
76 _new_value: &Value,
77 _reason: Option<&str>,
78 ) {
79 }
80
81 async fn on_secret_revealed(&self, _content: &str) {}
83}
84
/// An [`AgentHooks`] implementation that ignores every event.
///
/// It is a zero-sized unit struct, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopHooks;
86
// Empty impl: every callback falls back to the trait's default body, which does nothing.
#[async_trait]
impl AgentHooks for NoopHooks {}
89
/// [`AgentHooks`] implementation that logs agent events through `tracing`.
///
/// Every log line starts with a configurable prefix (`"[Agent]"` by default).
#[derive(Debug, Clone)]
pub struct LoggingHooks {
    /// Text prepended to every log line.
    prefix: String,
}

impl LoggingHooks {
    /// Creates hooks that log with the default `"[Agent]"` prefix.
    pub fn new() -> Self {
        Self::with_prefix("[Agent]")
    }

    /// Creates hooks that log with a custom `prefix`.
    pub fn with_prefix(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
        }
    }
}

impl Default for LoggingHooks {
    /// Equivalent to [`LoggingHooks::new`].
    fn default() -> Self {
        Self::new()
    }
}
113
114#[async_trait]
115impl AgentHooks for LoggingHooks {
116 async fn on_message_received(&self, message: &str) {
117 let preview = if message.len() > 100 {
118 format!("{}...", &message[..100])
119 } else {
120 message.to_string()
121 };
122 info!("{} Message received: {}", self.prefix, preview);
123 }
124
125 async fn on_llm_start(&self, messages: &[ChatMessage]) {
126 debug!(
127 "{} LLM starting with {} messages",
128 self.prefix,
129 messages.len()
130 );
131 }
132
133 async fn on_llm_complete(&self, response: &LLMResponse, duration_ms: u64) {
134 info!(
135 "{} LLM complete in {}ms, tokens: {:?}",
136 self.prefix, duration_ms, response.usage
137 );
138 }
139
140 async fn on_tool_start(&self, tool: &str, args: &Value) {
141 debug!("{} Tool {} starting with args: {}", self.prefix, tool, args);
142 }
143
144 async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
145 if result.success {
146 info!(
147 "{} Tool {} completed in {}ms",
148 self.prefix, tool, duration_ms
149 );
150 } else {
151 warn!(
152 "{} Tool {} failed in {}ms: {}",
153 self.prefix, tool, duration_ms, result.output
154 );
155 }
156 }
157
158 async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
159 info!(
160 "{} State transition: {:?} -> {} ({})",
161 self.prefix, from, to, reason
162 );
163 }
164
165 async fn on_error(&self, err: &AgentError) {
166 error!("{} Error: {}", self.prefix, err);
167 }
168
169 async fn on_response(&self, response: &AgentResponse) {
170 let preview = if response.content.len() > 100 {
171 format!("{}...", &response.content[..100])
172 } else {
173 response.content.clone()
174 };
175 debug!("{} Response: {}", self.prefix, preview);
176 }
177
178 async fn on_approval_requested(&self, request: &ApprovalRequest) {
179 info!(
180 "{} Approval requested [{}]: {}",
181 self.prefix, request.id, request.message
182 );
183 }
184
185 async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
186 match result {
187 ApprovalResult::Approved => {
188 info!("{} Approval [{}]: approved", self.prefix, request_id);
189 }
190 ApprovalResult::Rejected { reason } => {
191 warn!(
192 "{} Approval [{}]: rejected ({:?})",
193 self.prefix, request_id, reason
194 );
195 }
196 ApprovalResult::Modified { .. } => {
197 info!(
198 "{} Approval [{}]: approved with modifications",
199 self.prefix, request_id
200 );
201 }
202 ApprovalResult::Timeout => {
203 warn!("{} Approval [{}]: timeout", self.prefix, request_id);
204 }
205 }
206 }
207
208 async fn on_memory_compress(&self, event: &MemoryCompressEvent) {
209 info!(
210 "{} Memory compressed: {} messages, ratio: {:.2}",
211 self.prefix, event.messages_compressed, event.compression_ratio
212 );
213 }
214
215 async fn on_memory_evict(&self, event: &MemoryEvictEvent) {
216 warn!(
217 "{} Memory evicted: {} messages, reason: {:?}",
218 self.prefix, event.messages_evicted, event.reason
219 );
220 }
221
222 async fn on_memory_budget_warning(&self, event: &MemoryBudgetEvent) {
223 warn!(
224 "{} Memory budget warning: {} at {:.1}% ({}/{} tokens)",
225 self.prefix,
226 event.component,
227 event.usage_percent,
228 event.used_tokens,
229 event.budget_tokens
230 );
231 }
232
233 async fn on_delegate_start(&self, agent_id: &str, state: &str) {
234 info!(
235 "{} Delegation started: agent={}, state={}",
236 self.prefix, agent_id, state
237 );
238 }
239
240 async fn on_delegate_complete(&self, agent_id: &str, state: &str, duration_ms: u64) {
241 info!(
242 "{} Delegation complete: agent={}, state={}, duration={}ms",
243 self.prefix, agent_id, state, duration_ms
244 );
245 }
246
247 async fn on_concurrent_complete(&self, agent_ids: &[String], strategy: &str, duration_ms: u64) {
248 info!(
249 "{} Concurrent complete: agents={:?}, strategy={}, duration={}ms",
250 self.prefix, agent_ids, strategy, duration_ms
251 );
252 }
253
254 async fn on_group_chat_round(&self, round: u32, speaker: &str, content: &str) {
255 let preview = if content.len() > 80 {
256 format!("{}...", &content[..80])
257 } else {
258 content.to_string()
259 };
260 debug!(
261 "{} Group chat round {}: {} said: {}",
262 self.prefix, round, speaker, preview
263 );
264 }
265
266 async fn on_pipeline_stage(&self, stage: usize, agent_id: &str, duration_ms: u64) {
267 info!(
268 "{} Pipeline stage {}: agent={}, duration={}ms",
269 self.prefix, stage, agent_id, duration_ms
270 );
271 }
272
273 async fn on_pipeline_complete(&self, stages: usize, duration_ms: u64) {
274 info!(
275 "{} Pipeline complete: {} stages, duration={}ms",
276 self.prefix, stages, duration_ms
277 );
278 }
279
280 async fn on_handoff_start(&self, initial_agent: &str) {
281 info!(
282 "{} Handoff chain started: initial_agent={}",
283 self.prefix, initial_agent
284 );
285 }
286
287 async fn on_handoff(&self, from: &str, to: &str, reason: &str) {
288 info!("{} Handoff: {} -> {} ({})", self.prefix, from, to, reason);
289 }
290
291 async fn on_persona_evolve(
292 &self,
293 field: &str,
294 _old_value: &Value,
295 new_value: &Value,
296 reason: Option<&str>,
297 ) {
298 info!(
299 "{} Persona evolved: field={}, new_value={}, reason={}",
300 self.prefix,
301 field,
302 new_value,
303 reason.unwrap_or("(none)")
304 );
305 }
306
307 async fn on_secret_revealed(&self, content: &str) {
308 info!("{} Secret revealed: {}", self.prefix, content);
309 }
310}
311
312pub struct CompositeHooks {
313 hooks: Vec<Arc<dyn AgentHooks>>,
314}
315
316impl CompositeHooks {
317 pub fn new() -> Self {
318 Self { hooks: Vec::new() }
319 }
320
321 pub fn add(mut self, hooks: Arc<dyn AgentHooks>) -> Self {
322 self.hooks.push(hooks);
323 self
324 }
325
326 pub fn with_hooks(hooks: Vec<Arc<dyn AgentHooks>>) -> Self {
327 Self { hooks }
328 }
329}
330
331impl Default for CompositeHooks {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337#[async_trait]
338impl AgentHooks for CompositeHooks {
339 async fn on_message_received(&self, message: &str) {
340 for hook in &self.hooks {
341 hook.on_message_received(message).await;
342 }
343 }
344
345 async fn on_llm_start(&self, messages: &[ChatMessage]) {
346 for hook in &self.hooks {
347 hook.on_llm_start(messages).await;
348 }
349 }
350
351 async fn on_llm_complete(&self, response: &LLMResponse, duration_ms: u64) {
352 for hook in &self.hooks {
353 hook.on_llm_complete(response, duration_ms).await;
354 }
355 }
356
357 async fn on_tool_start(&self, tool: &str, args: &Value) {
358 for hook in &self.hooks {
359 hook.on_tool_start(tool, args).await;
360 }
361 }
362
363 async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
364 for hook in &self.hooks {
365 hook.on_tool_complete(tool, result, duration_ms).await;
366 }
367 }
368
369 async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
370 for hook in &self.hooks {
371 hook.on_state_transition(from, to, reason).await;
372 }
373 }
374
375 async fn on_error(&self, error: &AgentError) {
376 for hook in &self.hooks {
377 hook.on_error(error).await;
378 }
379 }
380
381 async fn on_response(&self, response: &AgentResponse) {
382 for hook in &self.hooks {
383 hook.on_response(response).await;
384 }
385 }
386
387 async fn on_approval_requested(&self, request: &ApprovalRequest) {
388 for hook in &self.hooks {
389 hook.on_approval_requested(request).await;
390 }
391 }
392
393 async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
394 for hook in &self.hooks {
395 hook.on_approval_result(request_id, result).await;
396 }
397 }
398
399 async fn on_memory_compress(&self, event: &MemoryCompressEvent) {
400 for hook in &self.hooks {
401 hook.on_memory_compress(event).await;
402 }
403 }
404
405 async fn on_memory_evict(&self, event: &MemoryEvictEvent) {
406 for hook in &self.hooks {
407 hook.on_memory_evict(event).await;
408 }
409 }
410
411 async fn on_memory_budget_warning(&self, event: &MemoryBudgetEvent) {
412 for hook in &self.hooks {
413 hook.on_memory_budget_warning(event).await;
414 }
415 }
416
417 async fn on_delegate_start(&self, agent_id: &str, state: &str) {
418 for hook in &self.hooks {
419 hook.on_delegate_start(agent_id, state).await;
420 }
421 }
422
423 async fn on_delegate_complete(&self, agent_id: &str, state: &str, duration_ms: u64) {
424 for hook in &self.hooks {
425 hook.on_delegate_complete(agent_id, state, duration_ms)
426 .await;
427 }
428 }
429
430 async fn on_concurrent_complete(&self, agent_ids: &[String], strategy: &str, duration_ms: u64) {
431 for hook in &self.hooks {
432 hook.on_concurrent_complete(agent_ids, strategy, duration_ms)
433 .await;
434 }
435 }
436
437 async fn on_group_chat_round(&self, round: u32, speaker: &str, content: &str) {
438 for hook in &self.hooks {
439 hook.on_group_chat_round(round, speaker, content).await;
440 }
441 }
442
443 async fn on_pipeline_stage(&self, stage: usize, agent_id: &str, duration_ms: u64) {
444 for hook in &self.hooks {
445 hook.on_pipeline_stage(stage, agent_id, duration_ms).await;
446 }
447 }
448
449 async fn on_pipeline_complete(&self, stages: usize, duration_ms: u64) {
450 for hook in &self.hooks {
451 hook.on_pipeline_complete(stages, duration_ms).await;
452 }
453 }
454
455 async fn on_handoff_start(&self, initial_agent: &str) {
456 for hook in &self.hooks {
457 hook.on_handoff_start(initial_agent).await;
458 }
459 }
460
461 async fn on_handoff(&self, _from: &str, _to: &str, _reason: &str) {
462 for hook in &self.hooks {
463 hook.on_handoff(_from, _to, _reason).await;
464 }
465 }
466
467 async fn on_persona_evolve(
468 &self,
469 _field: &str,
470 _old_value: &Value,
471 _new_value: &Value,
472 _reason: Option<&str>,
473 ) {
474 for hook in &self.hooks {
475 hook.on_persona_evolve(_field, _old_value, _new_value, _reason)
476 .await;
477 }
478 }
479
480 async fn on_secret_revealed(&self, _content: &str) {
481 for hook in &self.hooks {
482 hook.on_secret_revealed(_content).await;
483 }
484 }
485}
486
/// Small stopwatch for measuring the `duration_ms` values passed to hooks.
#[derive(Debug, Clone, Copy)]
pub struct HookTimer {
    /// Monotonic instant captured by [`HookTimer::start`].
    start: Instant,
}

impl HookTimer {
    /// Starts a timer at the current instant.
    #[must_use]
    pub fn start() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Whole milliseconds elapsed since [`HookTimer::start`].
    ///
    /// `Duration::as_millis` returns a `u128`, so the value is clamped to `u64::MAX`
    /// instead of being silently truncated by an `as` cast.
    pub fn elapsed_ms(&self) -> u64 {
        let millis = self.start.elapsed().as_millis();
        // Lossless after clamping: the value is now within u64 range.
        millis.min(u128::from(u64::MAX)) as u64
    }
}
502
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Test double that records each observed event as a formatted string.
    struct RecordingHooks {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingHooks {
        fn new() -> Self {
            Self {
                events: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn events(&self) -> Vec<String> {
            self.events.lock().clone()
        }
    }

    #[async_trait]
    impl AgentHooks for RecordingHooks {
        async fn on_message_received(&self, message: &str) {
            self.events
                .lock()
                .push(format!("message_received:{}", message));
        }

        async fn on_llm_start(&self, messages: &[ChatMessage]) {
            self.events
                .lock()
                .push(format!("llm_start:{}", messages.len()));
        }

        async fn on_llm_complete(&self, _response: &LLMResponse, duration_ms: u64) {
            self.events
                .lock()
                .push(format!("llm_complete:{}", duration_ms));
        }

        async fn on_tool_start(&self, tool: &str, _args: &Value) {
            self.events.lock().push(format!("tool_start:{}", tool));
        }

        async fn on_tool_complete(&self, tool: &str, result: &ToolResult, duration_ms: u64) {
            self.events.lock().push(format!(
                "tool_complete:{}:{}:{}",
                tool, result.success, duration_ms
            ));
        }

        async fn on_state_transition(&self, from: Option<&str>, to: &str, reason: &str) {
            self.events
                .lock()
                .push(format!("state_transition:{:?}:{}:{}", from, to, reason));
        }

        async fn on_error(&self, error: &AgentError) {
            self.events.lock().push(format!("error:{}", error));
        }

        async fn on_response(&self, response: &AgentResponse) {
            self.events
                .lock()
                .push(format!("response:{}", response.content.len()));
        }

        async fn on_approval_requested(&self, request: &ApprovalRequest) {
            self.events
                .lock()
                .push(format!("approval_requested:{}", request.id));
        }

        async fn on_approval_result(&self, request_id: &str, result: &ApprovalResult) {
            let status = match result {
                ApprovalResult::Approved => "approved",
                ApprovalResult::Rejected { .. } => "rejected",
                ApprovalResult::Modified { .. } => "modified",
                ApprovalResult::Timeout => "timeout",
            };
            self.events
                .lock()
                .push(format!("approval_result:{}:{}", request_id, status));
        }
    }

    /// Test double that appends its label to a shared log on every handoff.
    /// Used to observe the order in which `CompositeHooks` dispatches events.
    struct LabelledHooks {
        label: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }

    #[async_trait]
    impl AgentHooks for LabelledHooks {
        async fn on_handoff(&self, _from: &str, _to: &str, _reason: &str) {
            self.log.lock().push(self.label);
        }
    }

    #[tokio::test]
    async fn test_noop_hooks() {
        let hooks = NoopHooks;
        hooks.on_message_received("test").await;
        hooks.on_llm_start(&[]).await;
    }

    #[tokio::test]
    async fn test_logging_hooks() {
        let hooks = LoggingHooks::new();
        hooks.on_message_received("test message").await;
        hooks.on_llm_start(&[ChatMessage::user("hello")]).await;
    }

    #[tokio::test]
    async fn test_logging_hooks_long_inputs() {
        // Exercises the preview-truncation paths (inputs longer than the preview limits).
        let hooks = LoggingHooks::with_prefix("[Test]");
        hooks.on_message_received(&"a".repeat(150)).await;
        hooks.on_group_chat_round(3, "speaker", &"b".repeat(120)).await;
        hooks.on_handoff("triage", "billing", "needs refund").await;
    }

    #[tokio::test]
    async fn test_recording_hooks() {
        let hooks = RecordingHooks::new();

        hooks.on_message_received("hello").await;
        hooks.on_llm_start(&[ChatMessage::user("test")]).await;

        let events = hooks.events();
        assert_eq!(events.len(), 2);
        assert!(events[0].contains("message_received"));
        assert!(events[1].contains("llm_start"));
    }

    #[tokio::test]
    async fn test_recording_state_and_approval_events() {
        let hooks = RecordingHooks::new();

        hooks.on_state_transition(Some("idle"), "busy", "user input").await;
        hooks.on_approval_result("req-1", &ApprovalResult::Approved).await;
        hooks.on_approval_result("req-2", &ApprovalResult::Timeout).await;

        assert_eq!(
            hooks.events(),
            vec![
                "state_transition:Some(\"idle\"):busy:user input".to_string(),
                "approval_result:req-1:approved".to_string(),
                "approval_result:req-2:timeout".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn test_composite_hooks_with_vec() {
        let hooks1 = Arc::new(RecordingHooks::new());
        let hooks2 = Arc::new(RecordingHooks::new());

        let composite = CompositeHooks::with_hooks(vec![
            hooks1.clone() as Arc<dyn AgentHooks>,
            hooks2.clone() as Arc<dyn AgentHooks>,
        ]);

        composite
            .on_tool_start("calculator", &serde_json::json!({}))
            .await;

        assert_eq!(hooks1.events().len(), 1);
        assert_eq!(hooks2.events().len(), 1);
    }

    #[tokio::test]
    async fn test_composite_hooks_dispatch_in_registration_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let composite = CompositeHooks::new()
            .add(Arc::new(LabelledHooks {
                label: "first",
                log: Arc::clone(&log),
            }))
            .add(Arc::new(LabelledHooks {
                label: "second",
                log: Arc::clone(&log),
            }));

        composite.on_handoff("a", "b", "reason").await;

        assert_eq!(*log.lock(), vec!["first", "second"]);
    }

    #[tokio::test]
    async fn test_empty_composite_hooks_is_noop() {
        let composite = CompositeHooks::new();
        composite.on_message_received("nobody listening").await;
        assert!(composite.hooks.is_empty());
    }

    #[tokio::test]
    async fn test_hook_timer() {
        let timer = HookTimer::start();
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let elapsed = timer.elapsed_ms();
        assert!(elapsed >= 10);
    }

    #[test]
    fn test_composite_hooks_default() {
        let hooks = CompositeHooks::default();
        assert!(hooks.hooks.is_empty());
    }
}