Skip to main content

adk_runner/
callbacks.rs

1use adk_core::{CallbackContext, Content, Result};
2use futures::future::BoxFuture;
3use std::sync::Arc;
4
5/// Callback executed before calling the model
6pub type BeforeModelCallback = Box<
7    dyn Fn(Arc<dyn CallbackContext>) -> BoxFuture<'static, Result<Option<Content>>> + Send + Sync,
8>;
9
10/// Callback executed after model response
11pub type AfterModelCallback = Box<
12    dyn Fn(Arc<dyn CallbackContext>) -> BoxFuture<'static, Result<Option<Content>>> + Send + Sync,
13>;
14
15/// Callback executed before tool execution
16pub type BeforeToolCallback = Box<
17    dyn Fn(Arc<dyn CallbackContext>) -> BoxFuture<'static, Result<Option<Content>>> + Send + Sync,
18>;
19
20/// Callback executed after tool execution
21pub type AfterToolCallback = Box<
22    dyn Fn(Arc<dyn CallbackContext>) -> BoxFuture<'static, Result<Option<Content>>> + Send + Sync,
23>;
24
25/// Collection of all callback types for intercepting model and tool execution.
26pub struct Callbacks {
27    /// Callbacks invoked before each model call.
28    pub before_model: Vec<BeforeModelCallback>,
29    /// Callbacks invoked after each model call.
30    pub after_model: Vec<AfterModelCallback>,
31    /// Callbacks invoked before each tool execution.
32    pub before_tool: Vec<BeforeToolCallback>,
33    /// Callbacks invoked after each tool execution.
34    pub after_tool: Vec<AfterToolCallback>,
35}
36
37impl Default for Callbacks {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl Callbacks {
44    /// Create an empty callback collection.
45    pub fn new() -> Self {
46        Self {
47            before_model: Vec::new(),
48            after_model: Vec::new(),
49            before_tool: Vec::new(),
50            after_tool: Vec::new(),
51        }
52    }
53
54    /// Register a callback to run before each model call.
55    pub fn add_before_model(&mut self, callback: BeforeModelCallback) {
56        self.before_model.push(callback);
57    }
58
59    /// Register a callback to run after each model call.
60    pub fn add_after_model(&mut self, callback: AfterModelCallback) {
61        self.after_model.push(callback);
62    }
63
64    /// Register a callback to run before each tool execution.
65    pub fn add_before_tool(&mut self, callback: BeforeToolCallback) {
66        self.before_tool.push(callback);
67    }
68
69    /// Register a callback to run after each tool execution.
70    pub fn add_after_tool(&mut self, callback: AfterToolCallback) {
71        self.after_tool.push(callback);
72    }
73
74    /// Execute all before_model callbacks
75    pub async fn execute_before_model(
76        &self,
77        ctx: Arc<dyn CallbackContext>,
78    ) -> Result<Vec<Content>> {
79        let mut results = Vec::new();
80        for callback in &self.before_model {
81            if let Some(content) = callback(ctx.clone()).await? {
82                results.push(content);
83            }
84        }
85        Ok(results)
86    }
87
88    /// Execute all after_model callbacks
89    pub async fn execute_after_model(&self, ctx: Arc<dyn CallbackContext>) -> Result<Vec<Content>> {
90        let mut results = Vec::new();
91        for callback in &self.after_model {
92            if let Some(content) = callback(ctx.clone()).await? {
93                results.push(content);
94            }
95        }
96        Ok(results)
97    }
98
99    /// Execute all before_tool callbacks
100    pub async fn execute_before_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Vec<Content>> {
101        let mut results = Vec::new();
102        for callback in &self.before_tool {
103            if let Some(content) = callback(ctx.clone()).await? {
104                results.push(content);
105            }
106        }
107        Ok(results)
108    }
109
110    /// Execute all after_tool callbacks
111    pub async fn execute_after_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Vec<Content>> {
112        let mut results = Vec::new();
113        for callback in &self.after_tool {
114            if let Some(content) = callback(ctx.clone()).await? {
115                results.push(content);
116            }
117        }
118        Ok(results)
119    }
120}