1use crate::Plugin;
6use adk_core::{
7 BeforeModelResult, CallbackContext, Content, Event, InvocationContext, LlmRequest, LlmResponse,
8 Result, Tool,
9};
10use std::sync::Arc;
11use std::time::Duration;
12use tracing::{debug, warn};
13
/// Tunable settings for a plugin manager.
///
/// Derives `Debug` in addition to `Clone` so the configuration can be
/// logged and embedded in other `Debug` types.
#[derive(Clone, Debug)]
pub struct PluginManagerConfig {
    /// Upper bound on how long the manager waits for a single plugin's
    /// `close` to finish before giving up on it.
    pub close_timeout: Duration,
}
20
21impl Default for PluginManagerConfig {
22 fn default() -> Self {
23 Self { close_timeout: Duration::from_secs(5) }
24 }
25}
26
27pub struct PluginManager {
55 plugins: Vec<Plugin>,
56 config: PluginManagerConfig,
57}
58
59impl PluginManager {
60 pub fn new(plugins: Vec<Plugin>) -> Self {
62 Self { plugins, config: PluginManagerConfig::default() }
63 }
64
65 pub fn with_config(plugins: Vec<Plugin>, config: PluginManagerConfig) -> Self {
67 Self { plugins, config }
68 }
69
70 pub fn plugin_count(&self) -> usize {
72 self.plugins.len()
73 }
74
75 pub fn plugin_names(&self) -> Vec<&str> {
77 self.plugins.iter().map(|p| p.name()).collect()
78 }
79
80 pub async fn run_on_user_message(
84 &self,
85 ctx: Arc<dyn InvocationContext>,
86 content: Content,
87 ) -> Result<Option<Content>> {
88 let mut current_content = content;
89 let mut was_modified = false;
90
91 for plugin in &self.plugins {
92 if let Some(callback) = plugin.on_user_message() {
93 debug!(plugin = plugin.name(), "Running on_user_message callback");
94 match callback(ctx.clone(), current_content.clone()).await {
95 Ok(Some(modified)) => {
96 debug!(plugin = plugin.name(), "Content modified by plugin");
97 was_modified = true;
98 current_content = modified;
99 }
100 Ok(None) => {
101 }
103 Err(e) => {
104 warn!(plugin = plugin.name(), error = %e, "on_user_message callback failed");
105 return Err(e);
106 }
107 }
108 }
109 }
110
111 Ok(if was_modified { Some(current_content) } else { None })
112 }
113
114 pub async fn run_on_event(
118 &self,
119 ctx: Arc<dyn InvocationContext>,
120 event: Event,
121 ) -> Result<Option<Event>> {
122 let mut current_event = event;
123 let mut was_modified = false;
124
125 for plugin in &self.plugins {
126 if let Some(callback) = plugin.on_event() {
127 debug!(plugin = plugin.name(), event_id = %current_event.id, "Running on_event callback");
128 match callback(ctx.clone(), current_event.clone()).await {
129 Ok(Some(modified)) => {
130 debug!(plugin = plugin.name(), "Event modified by plugin");
131 was_modified = true;
132 current_event = modified;
133 }
134 Ok(None) => {
135 }
137 Err(e) => {
138 warn!(plugin = plugin.name(), error = %e, "on_event callback failed");
139 return Err(e);
140 }
141 }
142 }
143 }
144
145 Ok(if was_modified { Some(current_event) } else { None })
146 }
147
148 pub async fn run_before_run(&self, ctx: Arc<dyn InvocationContext>) -> Result<Option<Content>> {
152 for plugin in &self.plugins {
153 if let Some(callback) = plugin.before_run() {
154 debug!(plugin = plugin.name(), "Running before_run callback");
155 match callback(ctx.clone()).await {
156 Ok(Some(content)) => {
157 debug!(plugin = plugin.name(), "before_run returned early exit content");
158 return Ok(Some(content));
159 }
160 Ok(None) => {
161 }
163 Err(e) => {
164 warn!(plugin = plugin.name(), error = %e, "before_run callback failed");
165 return Err(e);
166 }
167 }
168 }
169 }
170
171 Ok(None)
172 }
173
174 pub async fn run_after_run(&self, ctx: Arc<dyn InvocationContext>) {
178 for plugin in &self.plugins {
179 if let Some(callback) = plugin.after_run() {
180 debug!(plugin = plugin.name(), "Running after_run callback");
181 callback(ctx.clone()).await;
182 }
183 }
184 }
185
186 pub async fn run_before_agent(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
190 for plugin in &self.plugins {
191 if let Some(callback) = plugin.before_agent() {
192 debug!(plugin = plugin.name(), "Running before_agent callback");
193 match callback(ctx.clone()).await {
194 Ok(Some(content)) => {
195 debug!(plugin = plugin.name(), "before_agent returned early exit content");
196 return Ok(Some(content));
197 }
198 Ok(None) => {
199 }
201 Err(e) => {
202 warn!(plugin = plugin.name(), error = %e, "before_agent callback failed");
203 return Err(e);
204 }
205 }
206 }
207 }
208
209 Ok(None)
210 }
211
212 pub async fn run_after_agent(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
214 for plugin in &self.plugins {
215 if let Some(callback) = plugin.after_agent() {
216 debug!(plugin = plugin.name(), "Running after_agent callback");
217 match callback(ctx.clone()).await {
218 Ok(Some(content)) => {
219 debug!(plugin = plugin.name(), "after_agent returned content");
220 return Ok(Some(content));
221 }
222 Ok(None) => {
223 }
225 Err(e) => {
226 warn!(plugin = plugin.name(), error = %e, "after_agent callback failed");
227 return Err(e);
228 }
229 }
230 }
231 }
232
233 Ok(None)
234 }
235
236 pub async fn run_before_model(
240 &self,
241 ctx: Arc<dyn CallbackContext>,
242 request: LlmRequest,
243 ) -> Result<BeforeModelResult> {
244 let mut current_request = request;
245
246 for plugin in &self.plugins {
247 if let Some(callback) = plugin.before_model() {
248 debug!(plugin = plugin.name(), "Running before_model callback");
249 match callback(ctx.clone(), current_request.clone()).await {
250 Ok(BeforeModelResult::Continue(modified)) => {
251 current_request = modified;
252 }
253 Ok(BeforeModelResult::Skip(response)) => {
254 debug!(plugin = plugin.name(), "before_model skipped model call");
255 return Ok(BeforeModelResult::Skip(response));
256 }
257 Err(e) => {
258 warn!(plugin = plugin.name(), error = %e, "before_model callback failed");
259 return Err(e);
260 }
261 }
262 }
263 }
264
265 Ok(BeforeModelResult::Continue(current_request))
266 }
267
268 pub async fn run_after_model(
270 &self,
271 ctx: Arc<dyn CallbackContext>,
272 response: LlmResponse,
273 ) -> Result<Option<LlmResponse>> {
274 let mut current_response = response;
275 let mut was_modified = false;
276
277 for plugin in &self.plugins {
278 if let Some(callback) = plugin.after_model() {
279 debug!(plugin = plugin.name(), "Running after_model callback");
280 match callback(ctx.clone(), current_response.clone()).await {
281 Ok(Some(modified)) => {
282 was_modified = true;
283 current_response = modified;
284 }
285 Ok(None) => {
286 }
288 Err(e) => {
289 warn!(plugin = plugin.name(), error = %e, "after_model callback failed");
290 return Err(e);
291 }
292 }
293 }
294 }
295
296 Ok(if was_modified { Some(current_response) } else { None })
297 }
298
299 pub async fn run_on_model_error(
301 &self,
302 ctx: Arc<dyn CallbackContext>,
303 request: LlmRequest,
304 error: String,
305 ) -> Result<Option<LlmResponse>> {
306 for plugin in &self.plugins {
307 if let Some(callback) = plugin.on_model_error() {
308 debug!(plugin = plugin.name(), "Running on_model_error callback");
309 match callback(ctx.clone(), request.clone(), error.clone()).await {
310 Ok(Some(response)) => {
311 debug!(plugin = plugin.name(), "on_model_error provided fallback response");
312 return Ok(Some(response));
313 }
314 Ok(None) => {
315 }
317 Err(e) => {
318 warn!(plugin = plugin.name(), error = %e, "on_model_error callback failed");
319 return Err(e);
320 }
321 }
322 }
323 }
324
325 Ok(None)
326 }
327
328 pub async fn run_before_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
330 for plugin in &self.plugins {
331 if let Some(callback) = plugin.before_tool() {
332 debug!(plugin = plugin.name(), "Running before_tool callback");
333 match callback(ctx.clone()).await {
334 Ok(Some(content)) => {
335 debug!(plugin = plugin.name(), "before_tool returned early exit content");
336 return Ok(Some(content));
337 }
338 Ok(None) => {
339 }
341 Err(e) => {
342 warn!(plugin = plugin.name(), error = %e, "before_tool callback failed");
343 return Err(e);
344 }
345 }
346 }
347 }
348
349 Ok(None)
350 }
351
352 pub async fn run_after_tool(&self, ctx: Arc<dyn CallbackContext>) -> Result<Option<Content>> {
354 for plugin in &self.plugins {
355 if let Some(callback) = plugin.after_tool() {
356 debug!(plugin = plugin.name(), "Running after_tool callback");
357 match callback(ctx.clone()).await {
358 Ok(Some(content)) => {
359 debug!(plugin = plugin.name(), "after_tool returned content");
360 return Ok(Some(content));
361 }
362 Ok(None) => {
363 }
365 Err(e) => {
366 warn!(plugin = plugin.name(), error = %e, "after_tool callback failed");
367 return Err(e);
368 }
369 }
370 }
371 }
372
373 Ok(None)
374 }
375
376 pub async fn run_on_tool_error(
378 &self,
379 ctx: Arc<dyn CallbackContext>,
380 tool: Arc<dyn Tool>,
381 args: serde_json::Value,
382 error: String,
383 ) -> Result<Option<serde_json::Value>> {
384 for plugin in &self.plugins {
385 if let Some(callback) = plugin.on_tool_error() {
386 debug!(
387 plugin = plugin.name(),
388 tool = tool.name(),
389 "Running on_tool_error callback"
390 );
391 match callback(ctx.clone(), tool.clone(), args.clone(), error.clone()).await {
392 Ok(Some(result)) => {
393 debug!(plugin = plugin.name(), "on_tool_error provided fallback result");
394 return Ok(Some(result));
395 }
396 Ok(None) => {
397 }
399 Err(e) => {
400 warn!(plugin = plugin.name(), error = %e, "on_tool_error callback failed");
401 return Err(e);
402 }
403 }
404 }
405 }
406
407 Ok(None)
408 }
409
410 pub async fn close(&self) {
412 debug!("Closing {} plugins", self.plugins.len());
413
414 for plugin in &self.plugins {
415 let close_future = plugin.close();
416 match tokio::time::timeout(self.config.close_timeout, close_future).await {
417 Ok(()) => {
418 debug!(plugin = plugin.name(), "Plugin closed successfully");
419 }
420 Err(_) => {
421 warn!(plugin = plugin.name(), "Plugin close timed out");
422 }
423 }
424 }
425 }
426}
427
428impl std::fmt::Debug for PluginManager {
429 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430 f.debug_struct("PluginManager")
431 .field("plugin_count", &self.plugins.len())
432 .field("plugin_names", &self.plugin_names())
433 .field("close_timeout", &self.config.close_timeout)
434 .finish()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PluginConfig;

    /// Builds a plugin that only has a name; every other setting is default.
    fn named_plugin(name: &str) -> Plugin {
        Plugin::new(PluginConfig { name: name.to_string(), ..Default::default() })
    }

    /// A freshly built manager reports its plugins' count and names in order.
    #[test]
    fn test_plugin_manager_creation() {
        let manager = PluginManager::new(vec![named_plugin("test1"), named_plugin("test2")]);

        assert_eq!(manager.plugin_count(), 2);
        assert_eq!(manager.plugin_names(), vec!["test1", "test2"]);
    }
}