Skip to main content

adk_plugin/
manager.rs

1//! Plugin Manager
2//!
3//! Coordinates execution of callbacks across all registered plugins.
4
5use crate::Plugin;
6use adk_core::{
7    BeforeModelResult, CallbackContext, Content, Event, InvocationContext, LlmRequest, LlmResponse,
8    Result, Tool,
9};
10use std::sync::Arc;
11use std::time::Duration;
12use tracing::{debug, warn};
13
/// Configuration for the [`PluginManager`].
#[derive(Debug, Clone)]
pub struct PluginManagerConfig {
    /// Maximum time to wait for *each individual* plugin to close during
    /// shutdown.
    ///
    /// Plugins are closed one after another, so a full shutdown can take up to
    /// `close_timeout` multiplied by the number of registered plugins.
    pub close_timeout: Duration,
}

impl Default for PluginManagerConfig {
    /// Returns a configuration with a 5-second per-plugin close timeout.
    fn default() -> Self {
        Self { close_timeout: Duration::from_secs(5) }
    }
}
26
27/// Manages a collection of plugins and coordinates callback execution.
28///
29/// The PluginManager runs callbacks from all registered plugins in order.
30/// For callbacks that can modify data (like on_user_message), the first
31/// plugin to return a modification wins.
32///
33/// # Example
34///
35/// ```rust,ignore
36/// use adk_plugin::{Plugin, PluginManager, PluginConfig};
37///
38/// let plugins = vec![
39///     Plugin::new(PluginConfig {
40///         name: "logging".to_string(),
41///         on_event: Some(log_events()),
42///         ..Default::default()
43///     }),
44///     Plugin::new(PluginConfig {
45///         name: "metrics".to_string(),
46///         before_run: Some(start_timer()),
47///         after_run: Some(stop_timer()),
48///         ..Default::default()
49///     }),
50/// ];
51///
52/// let manager = PluginManager::new(plugins);
53/// ```
54pub struct PluginManager {
55    plugins: Vec<Plugin>,
56    config: PluginManagerConfig,
57}
58
59impl PluginManager {
60    /// Create a new plugin manager with the given plugins.
61    pub fn new(plugins: Vec<Plugin>) -> Self {
62        Self { plugins, config: PluginManagerConfig::default() }
63    }
64
65    /// Create a new plugin manager with custom configuration.
66    pub fn with_config(plugins: Vec<Plugin>, config: PluginManagerConfig) -> Self {
67        Self { plugins, config }
68    }
69
70    /// Get the number of registered plugins.
71    pub fn plugin_count(&self) -> usize {
72        self.plugins.len()
73    }
74
75    /// Get plugin names.
76    pub fn plugin_names(&self) -> Vec<&str> {
77        self.plugins.iter().map(|p| p.name()).collect()
78    }
79
80    /// Run on_user_message callbacks from all plugins.
81    ///
82    /// Returns the modified content if any plugin modified it.
83    pub async fn run_on_user_message(
84        &self,
85        ctx: Arc<dyn InvocationContext>,
86        content: Content,
87    ) -> Result<Option<Content>> {
88        let mut current_content = content;
89        let mut was_modified = false;
90
91        for plugin in &self.plugins {
92            if let Some(callback) = plugin.on_user_message() {
93                debug!(plugin = plugin.name(), "Running on_user_message callback");
94                match callback(ctx.clone(), current_content.clone()).await {
95                    Ok(Some(modified)) => {
96                        debug!(plugin = plugin.name(), "Content modified by plugin");
97                        was_modified = true;
98                        current_content = modified;
99                    }
100                    Ok(None) => {
101                        // Continue with current content
102                    }
103                    Err(e) => {
104                        warn!(plugin = plugin.name(), error = %e, "on_user_message callback failed");
105                        return Err(e);
106                    }
107                }
108            }
109        }
110
111        Ok(if was_modified { Some(current_content) } else { None })
112    }
113
114    /// Run on_event callbacks from all plugins.
115    ///
116    /// Returns the modified event if any plugin modified it.
117    pub async fn run_on_event(
118        &self,
119        ctx: Arc<dyn InvocationContext>,
120        event: Event,
121    ) -> Result<Option<Event>> {
122        let mut current_event = event;
123        let mut was_modified = false;
124
125        for plugin in &self.plugins {
126            if let Some(callback) = plugin.on_event() {
127                debug!(plugin = plugin.name(), event_id = %current_event.id, "Running on_event callback");
128                match callback(ctx.clone(), current_event.clone()).await {
129                    Ok(Some(modified)) => {
130                        debug!(plugin = plugin.name(), "Event modified by plugin");
131                        was_modified = true;
132                        current_event = modified;
133                    }
134                    Ok(None) => {
135                        // Continue with current event
136                    }
137                    Err(e) => {
138                        warn!(plugin = plugin.name(), error = %e, "on_event callback failed");
139                        return Err(e);
140                    }
141                }
142            }
143        }
144
145        Ok(if was_modified { Some(current_event) } else { None })
146    }
147
148    /// Run before_run callbacks from all plugins.
149    ///
150    /// If any plugin returns content, the run should be skipped.
151    pub async fn run_before_run(&self, ctx: Arc<dyn InvocationContext>) -> Result<Option<Content>> {
152        for plugin in &self.plugins {
153            if let Some(callback) = plugin.before_run() {
154                debug!(plugin = plugin.name(), "Running before_run callback");
155                match callback(ctx.clone()).await {
156                    Ok(Some(content)) => {
157                        debug!(plugin = plugin.name(), "before_run returned early exit content");
158                        return Ok(Some(content));
159                    }
160                    Ok(None) => {
161                        // Continue to next plugin
162                    }
163                    Err(e) => {
164                        warn!(plugin = plugin.name(), error = %e, "before_run callback failed");
165                        return Err(e);
166                    }
167                }
168            }
169        }
170
171        Ok(None)
172    }
173
174    /// Run after_run callbacks from all plugins.
175    ///
176    /// This does NOT emit events - it's for cleanup/metrics only.
177    pub async fn run_after_run(&self, ctx: Arc<dyn InvocationContext>) {
178        for plugin in &self.plugins {
179            if let Some(callback) = plugin.after_run() {
180                debug!(plugin = plugin.name(), "Running after_run callback");
181                callback(ctx.clone()).await;
182            }
183        }
184    }
185
186    /// Run before_agent callbacks from all plugins.
187    ///
188    /// If any plugin returns content, the agent run should be skipped.
189    pub async fn run_before_agent(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
190        for plugin in &self.plugins {
191            if let Some(callback) = plugin.before_agent() {
192                debug!(plugin = plugin.name(), "Running before_agent callback");
193                match callback(ctx.clone()).await {
194                    Ok(Some(content)) => {
195                        debug!(plugin = plugin.name(), "before_agent returned early exit content");
196                        return Ok(Some(content));
197                    }
198                    Ok(None) => {
199                        // Continue to next plugin
200                    }
201                    Err(e) => {
202                        warn!(plugin = plugin.name(), error = %e, "before_agent callback failed");
203                        return Err(e);
204                    }
205                }
206            }
207        }
208
209        Ok(None)
210    }
211
212    /// Run after_agent callbacks from all plugins.
213    pub async fn run_after_agent(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
214        for plugin in &self.plugins {
215            if let Some(callback) = plugin.after_agent() {
216                debug!(plugin = plugin.name(), "Running after_agent callback");
217                match callback(ctx.clone()).await {
218                    Ok(Some(content)) => {
219                        debug!(plugin = plugin.name(), "after_agent returned content");
220                        return Ok(Some(content));
221                    }
222                    Ok(None) => {
223                        // Continue to next plugin
224                    }
225                    Err(e) => {
226                        warn!(plugin = plugin.name(), error = %e, "after_agent callback failed");
227                        return Err(e);
228                    }
229                }
230            }
231        }
232
233        Ok(None)
234    }
235
236    /// Run before_model callbacks from all plugins.
237    ///
238    /// Callbacks can modify the request or skip the model call.
239    pub async fn run_before_model(
240        &self,
241        ctx: Arc<dyn CallbackContext>,
242        request: LlmRequest,
243    ) -> Result<BeforeModelResult> {
244        let mut current_request = request;
245
246        for plugin in &self.plugins {
247            if let Some(callback) = plugin.before_model() {
248                debug!(plugin = plugin.name(), "Running before_model callback");
249                match callback(ctx.clone(), current_request.clone()).await {
250                    Ok(BeforeModelResult::Continue(modified)) => {
251                        current_request = modified;
252                    }
253                    Ok(BeforeModelResult::Skip(response)) => {
254                        debug!(plugin = plugin.name(), "before_model skipped model call");
255                        return Ok(BeforeModelResult::Skip(response));
256                    }
257                    Err(e) => {
258                        warn!(plugin = plugin.name(), error = %e, "before_model callback failed");
259                        return Err(e);
260                    }
261                }
262            }
263        }
264
265        Ok(BeforeModelResult::Continue(current_request))
266    }
267
268    /// Run after_model callbacks from all plugins.
269    pub async fn run_after_model(
270        &self,
271        ctx: Arc<dyn CallbackContext>,
272        response: LlmResponse,
273    ) -> Result<Option<LlmResponse>> {
274        let mut current_response = response;
275        let mut was_modified = false;
276
277        for plugin in &self.plugins {
278            if let Some(callback) = plugin.after_model() {
279                debug!(plugin = plugin.name(), "Running after_model callback");
280                match callback(ctx.clone(), current_response.clone()).await {
281                    Ok(Some(modified)) => {
282                        was_modified = true;
283                        current_response = modified;
284                    }
285                    Ok(None) => {
286                        // Continue with current response
287                    }
288                    Err(e) => {
289                        warn!(plugin = plugin.name(), error = %e, "after_model callback failed");
290                        return Err(e);
291                    }
292                }
293            }
294        }
295
296        Ok(if was_modified { Some(current_response) } else { None })
297    }
298
299    /// Run on_model_error callbacks from all plugins.
300    pub async fn run_on_model_error(
301        &self,
302        ctx: Arc<dyn CallbackContext>,
303        request: LlmRequest,
304        error: String,
305    ) -> Result<Option<LlmResponse>> {
306        for plugin in &self.plugins {
307            if let Some(callback) = plugin.on_model_error() {
308                debug!(plugin = plugin.name(), "Running on_model_error callback");
309                match callback(ctx.clone(), request.clone(), error.clone()).await {
310                    Ok(Some(response)) => {
311                        debug!(plugin = plugin.name(), "on_model_error provided fallback response");
312                        return Ok(Some(response));
313                    }
314                    Ok(None) => {
315                        // Continue to next plugin
316                    }
317                    Err(e) => {
318                        warn!(plugin = plugin.name(), error = %e, "on_model_error callback failed");
319                        return Err(e);
320                    }
321                }
322            }
323        }
324
325        Ok(None)
326    }
327
328    /// Run before_tool callbacks from all plugins.
329    pub async fn run_before_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
330        for plugin in &self.plugins {
331            if let Some(callback) = plugin.before_tool() {
332                debug!(plugin = plugin.name(), "Running before_tool callback");
333                match callback(ctx.clone()).await {
334                    Ok(Some(content)) => {
335                        debug!(plugin = plugin.name(), "before_tool returned early exit content");
336                        return Ok(Some(content));
337                    }
338                    Ok(None) => {
339                        // Continue to next plugin
340                    }
341                    Err(e) => {
342                        warn!(plugin = plugin.name(), error = %e, "before_tool callback failed");
343                        return Err(e);
344                    }
345                }
346            }
347        }
348
349        Ok(None)
350    }
351
352    /// Run after_tool callbacks from all plugins.
353    pub async fn run_after_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
354        for plugin in &self.plugins {
355            if let Some(callback) = plugin.after_tool() {
356                debug!(plugin = plugin.name(), "Running after_tool callback");
357                match callback(ctx.clone()).await {
358                    Ok(Some(content)) => {
359                        debug!(plugin = plugin.name(), "after_tool returned content");
360                        return Ok(Some(content));
361                    }
362                    Ok(None) => {
363                        // Continue to next plugin
364                    }
365                    Err(e) => {
366                        warn!(plugin = plugin.name(), error = %e, "after_tool callback failed");
367                        return Err(e);
368                    }
369                }
370            }
371        }
372
373        Ok(None)
374    }
375
376    /// Run on_tool_error callbacks from all plugins.
377    pub async fn run_on_tool_error(
378        &self,
379        ctx: Arc<dyn CallbackContext>,
380        tool: Arc<dyn Tool>,
381        args: serde_json::Value,
382        error: String,
383    ) -> Result<Option<serde_json::Value>> {
384        for plugin in &self.plugins {
385            if let Some(callback) = plugin.on_tool_error() {
386                debug!(
387                    plugin = plugin.name(),
388                    tool = tool.name(),
389                    "Running on_tool_error callback"
390                );
391                match callback(ctx.clone(), tool.clone(), args.clone(), error.clone()).await {
392                    Ok(Some(result)) => {
393                        debug!(plugin = plugin.name(), "on_tool_error provided fallback result");
394                        return Ok(Some(result));
395                    }
396                    Ok(None) => {
397                        // Continue to next plugin
398                    }
399                    Err(e) => {
400                        warn!(plugin = plugin.name(), error = %e, "on_tool_error callback failed");
401                        return Err(e);
402                    }
403                }
404            }
405        }
406
407        Ok(None)
408    }
409
410    /// Close all plugins with timeout.
411    pub async fn close(&self) {
412        debug!("Closing {} plugins", self.plugins.len());
413
414        for plugin in &self.plugins {
415            let close_future = plugin.close();
416            match tokio::time::timeout(self.config.close_timeout, close_future).await {
417                Ok(()) => {
418                    debug!(plugin = plugin.name(), "Plugin closed successfully");
419                }
420                Err(_) => {
421                    warn!(plugin = plugin.name(), "Plugin close timed out");
422                }
423            }
424        }
425    }
426}
427
428impl std::fmt::Debug for PluginManager {
429    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430        f.debug_struct("PluginManager")
431            .field("plugin_count", &self.plugins.len())
432            .field("plugin_names", &self.plugin_names())
433            .field("close_timeout", &self.config.close_timeout)
434            .finish()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PluginConfig;

    /// Build a plugin that only has a name and no callbacks.
    fn named(name: &str) -> Plugin {
        Plugin::new(PluginConfig { name: name.to_string(), ..Default::default() })
    }

    #[test]
    fn test_plugin_manager_creation() {
        let manager = PluginManager::new(vec![named("test1"), named("test2")]);
        assert_eq!(manager.plugin_count(), 2);
        assert_eq!(manager.plugin_names(), vec!["test1", "test2"]);
    }

    #[test]
    fn test_empty_plugin_manager() {
        let manager = PluginManager::new(Vec::new());
        assert_eq!(manager.plugin_count(), 0);
        assert!(manager.plugin_names().is_empty());
    }

    #[test]
    fn test_default_config_timeout() {
        assert_eq!(PluginManagerConfig::default().close_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_with_config_uses_custom_timeout() {
        let config = PluginManagerConfig { close_timeout: Duration::from_millis(250) };
        let manager = PluginManager::with_config(vec![named("only")], config);
        assert_eq!(manager.plugin_count(), 1);

        // The Debug impl exposes the configured timeout and plugin names.
        let rendered = format!("{:?}", manager);
        assert!(rendered.contains("250ms"));
        assert!(rendered.contains("\"only\""));
    }
}