Skip to main content

adk_plugin/
adapted_plugin.rs

//! Adapter wrapping a legacy closure-based [`Plugin`] as an [`EnhancedPlugin`].
//!
//! The [`AdaptedPlugin`] struct bridges the existing closure-based plugin system
//! to the new trait-based [`EnhancedPlugin`] interface, enabling legacy plugins
//! to participate in the enhanced pipeline without modification.
//!
//! # Overview
//!
//! Legacy plugins use callbacks that receive only a [`CallbackContext`] for tool hooks,
//! and `(CallbackContext, LlmRequest)` / `(CallbackContext, LlmResponse)` for model hooks.
//! They cannot modify tool arguments or results directly. The adapter:
//!
//! - Delegates `name()` to the inner [`Plugin::name()`]
//! - Uses the priority passed to [`AdaptedPlugin::new`] (100 is the conventional default)
//! - Invokes legacy `before_tool` / `after_tool` callbacks for side effects,
//!   but always returns `Continue` with unchanged args/result; a legacy
//!   `before_tool` request to skip the tool (`Ok(Some(content))`) is therefore ignored
//! - Maps legacy [`BeforeModelResult`] to [`BeforeModelCallResult`]
//! - Maps legacy `AfterModelCallback` results to [`AfterModelCallResult`]
//! - Delegates `close()` to the inner [`Plugin::close()`]
//!
//! # Example
//!
//! ```rust,ignore
//! use adk_plugin::{AdaptedPlugin, Plugin, PluginConfig};
//!
//! let legacy_plugin = Plugin::new(PluginConfig {
//!     name: "my-legacy-plugin".to_string(),
//!     before_tool: Some(Box::new(|_ctx| {
//!         Box::pin(async move {
//!             tracing::info!("tool starting");
//!             Ok(None)
//!         })
//!     })),
//!     ..Default::default()
//! });
//!
//! // Wrap with the conventional default priority (100)
//! let adapted = AdaptedPlugin::new(legacy_plugin, 100);
//! ```
40
41use std::sync::Arc;
42
43use adk_core::{
44    BeforeModelResult, CallbackContext, LlmRequest, LlmResponse, Result, Tool, async_trait,
45};
46use serde_json::Value;
47
48use crate::context::PluginContext;
49use crate::enhanced_plugin::EnhancedPlugin;
50use crate::hook_result::{
51    AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
52};
53use crate::plugin::Plugin;
54
55/// Wraps a legacy closure-based [`Plugin`] as an [`EnhancedPlugin`].
56///
57/// This adapter enables existing plugins to participate in the enhanced
58/// pipeline without modification. Legacy callbacks are invoked for their
59/// side effects, but the adapter does not modify tool arguments or results
60/// (legacy callbacks don't have access to them).
61///
62/// For model hooks, the adapter maps between the legacy [`BeforeModelResult`]
63/// and the new [`BeforeModelCallResult`], preserving short-circuit semantics.
64pub struct AdaptedPlugin {
65    inner: Plugin,
66    priority: i32,
67}
68
69impl AdaptedPlugin {
70    /// Create a new adapter wrapping a legacy plugin with the given priority.
71    ///
72    /// # Arguments
73    ///
74    /// * `plugin` - The legacy [`Plugin`] to wrap
75    /// * `priority` - Execution priority (lower values execute first, default: 100)
76    ///
77    /// # Example
78    ///
79    /// ```rust,ignore
80    /// use adk_plugin::{AdaptedPlugin, Plugin, PluginConfig};
81    ///
82    /// let plugin = Plugin::new(PluginConfig {
83    ///     name: "logger".to_string(),
84    ///     ..Default::default()
85    /// });
86    ///
87    /// let adapted = AdaptedPlugin::new(plugin, 50);
88    /// assert_eq!(adapted.priority(), 50);
89    /// ```
90    pub fn new(plugin: Plugin, priority: i32) -> Self {
91        Self { inner: plugin, priority }
92    }
93}
94
95#[async_trait]
96impl EnhancedPlugin for AdaptedPlugin {
97    fn name(&self) -> &str {
98        self.inner.name()
99    }
100
101    fn priority(&self) -> i32 {
102        self.priority
103    }
104
105    async fn before_tool_call(
106        &self,
107        _tool: Arc<dyn Tool>,
108        args: Value,
109        ctx: Arc<dyn CallbackContext>,
110        _plugin_ctx: &PluginContext,
111    ) -> Result<BeforeToolCallResult> {
112        // Legacy before_tool callbacks only receive CallbackContext and return
113        // Ok(None) to continue or Ok(Some(content)) to skip. They cannot modify
114        // tool arguments. We invoke for side effects and always return Continue.
115        if let Some(callback) = self.inner.before_tool() {
116            // Invoke the legacy callback for its side effects (logging, etc.)
117            // We ignore the return value since legacy callbacks can't modify args.
118            let _ = callback(ctx).await?;
119        }
120        Ok(BeforeToolCallResult::Continue(args))
121    }
122
123    async fn after_tool_call(
124        &self,
125        _tool: Arc<dyn Tool>,
126        _args: &Value,
127        result: Value,
128        ctx: Arc<dyn CallbackContext>,
129        _plugin_ctx: &PluginContext,
130    ) -> Result<AfterToolCallResult> {
131        // Legacy after_tool callbacks only receive CallbackContext and return
132        // Ok(None) to continue or Ok(Some(content)). They cannot modify tool results.
133        // We invoke for side effects and always return Continue with unchanged result.
134        if let Some(callback) = self.inner.after_tool() {
135            let _ = callback(ctx).await?;
136        }
137        Ok(AfterToolCallResult::Continue(result))
138    }
139
140    async fn before_model_call(
141        &self,
142        request: LlmRequest,
143        ctx: Arc<dyn CallbackContext>,
144        _plugin_ctx: &PluginContext,
145    ) -> Result<BeforeModelCallResult> {
146        // Legacy before_model callbacks receive (CallbackContext, LlmRequest) and return
147        // BeforeModelResult::Continue(request) or BeforeModelResult::Skip(response).
148        // We map these to the new BeforeModelCallResult variants.
149        if let Some(callback) = self.inner.before_model() {
150            let legacy_result = callback(ctx, request).await?;
151            match legacy_result {
152                BeforeModelResult::Continue(req) => Ok(BeforeModelCallResult::Continue(req)),
153                BeforeModelResult::Skip(response) => {
154                    Ok(BeforeModelCallResult::ShortCircuit(response))
155                }
156            }
157        } else {
158            Ok(BeforeModelCallResult::Continue(request))
159        }
160    }
161
162    async fn after_model_call(
163        &self,
164        response: LlmResponse,
165        ctx: Arc<dyn CallbackContext>,
166        _plugin_ctx: &PluginContext,
167    ) -> Result<AfterModelCallResult> {
168        // Legacy after_model callbacks receive (CallbackContext, LlmResponse) and return
169        // Ok(Some(response)) to replace or Ok(None) to keep original.
170        if let Some(callback) = self.inner.after_model() {
171            let result = callback(ctx, response.clone()).await?;
172            match result {
173                Some(modified_response) => Ok(AfterModelCallResult::Continue(modified_response)),
174                None => Ok(AfterModelCallResult::Continue(response)),
175            }
176        } else {
177            Ok(AfterModelCallResult::Continue(response))
178        }
179    }
180
181    async fn close(&self) {
182        self.inner.close().await;
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PluginConfig, plugin::Plugin};
    use adk_core::{BeforeModelResult, Content, LlmRequest, LlmResponse, Part};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Mock CallbackContext returning fixed identifiers for testing.
    struct MockCallbackContext;

    impl adk_core::ReadonlyContext for MockCallbackContext {
        fn invocation_id(&self) -> &str {
            "test-invocation"
        }

        fn agent_name(&self) -> &str {
            "test-agent"
        }

        fn user_id(&self) -> &str {
            "test-user"
        }

        fn app_name(&self) -> &str {
            "test-app"
        }

        fn session_id(&self) -> &str {
            "test-session"
        }

        fn branch(&self) -> &str {
            "main"
        }

        fn user_content(&self) -> &Content {
            static CONTENT: std::sync::OnceLock<Content> = std::sync::OnceLock::new();
            CONTENT.get_or_init(|| Content::new("user"))
        }
    }

    #[async_trait]
    impl CallbackContext for MockCallbackContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    /// Mock Tool that always returns `Value::Null`.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock-tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        async fn execute(
            &self,
            _ctx: Arc<dyn adk_core::ToolContext>,
            _args: Value,
        ) -> adk_core::Result<Value> {
            Ok(Value::Null)
        }
    }

    // `name()` and `priority()` are synchronous, so these tests need no async runtime.
    #[test]
    fn test_name_delegates_to_inner() {
        let plugin = Plugin::new(PluginConfig {
            name: "my-legacy-plugin".to_string(),
            ..Default::default()
        });
        let adapted = AdaptedPlugin::new(plugin, 100);
        assert_eq!(adapted.name(), "my-legacy-plugin");
    }

    #[test]
    fn test_priority_uses_configured_value() {
        let plugin = Plugin::new(PluginConfig { name: "test".to_string(), ..Default::default() });
        let adapted = AdaptedPlugin::new(plugin, 42);
        assert_eq!(adapted.priority(), 42);
    }

    #[tokio::test]
    async fn test_before_tool_call_invokes_legacy_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_tool: Some(Box::new(move |_ctx| {
                let flag = called_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok(None)
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);
        let args = serde_json::json!({"key": "value"});

        let result = adapted.before_tool_call(tool, args.clone(), ctx, &plugin_ctx).await.unwrap();

        assert!(called.load(Ordering::SeqCst));
        match result {
            BeforeToolCallResult::Continue(returned_args) => {
                assert_eq!(returned_args, args);
            }
            _ => panic!("expected Continue"),
        }
    }

    /// A legacy `Ok(Some(content))` ("skip the tool") cannot be expressed through the
    /// adapter, so the tool call continues with the original arguments.
    #[tokio::test]
    async fn test_before_tool_call_discards_legacy_skip_request() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_tool: Some(Box::new(|_ctx| {
                Box::pin(async move { Ok(Some(Content::new("model").with_text("blocked"))) })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);
        let args = serde_json::json!({"key": "value"});

        let result = adapted.before_tool_call(tool, args.clone(), ctx, &plugin_ctx).await.unwrap();

        match result {
            BeforeToolCallResult::Continue(returned_args) => assert_eq!(returned_args, args),
            _ => panic!("expected Continue even when the legacy callback requests a skip"),
        }
    }

    #[tokio::test]
    async fn test_after_tool_call_invokes_legacy_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_tool: Some(Box::new(move |_ctx| {
                let flag = called_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok(None)
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);
        let args = serde_json::json!({"input": "test"});
        let result_val = serde_json::json!({"output": "done"});

        let result = adapted
            .after_tool_call(tool, &args, result_val.clone(), ctx, &plugin_ctx)
            .await
            .unwrap();

        assert!(called.load(Ordering::SeqCst));
        match result {
            AfterToolCallResult::Continue(returned_result) => {
                assert_eq!(returned_result, result_val);
            }
        }
    }

    #[tokio::test]
    async fn test_before_model_call_maps_continue() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_model: Some(Box::new(|_ctx, request| {
                Box::pin(async move { Ok(BeforeModelResult::Continue(request)) })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let request = LlmRequest::new("test-model", vec![]);

        let result = adapted.before_model_call(request, ctx, &plugin_ctx).await.unwrap();

        match result {
            BeforeModelCallResult::Continue(req) => {
                assert_eq!(req.model, "test-model");
            }
            _ => panic!("expected Continue"),
        }
    }

    #[tokio::test]
    async fn test_before_model_call_maps_skip_to_short_circuit() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            before_model: Some(Box::new(|_ctx, _request| {
                Box::pin(async move {
                    let response = LlmResponse {
                        content: Some(Content::new("model").with_text("cached")),
                        ..Default::default()
                    };
                    Ok(BeforeModelResult::Skip(response))
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let request = LlmRequest::new("model", vec![]);

        let result = adapted.before_model_call(request, ctx, &plugin_ctx).await.unwrap();

        match result {
            BeforeModelCallResult::ShortCircuit(resp) => {
                assert!(resp.content.is_some());
            }
            _ => panic!("expected ShortCircuit"),
        }
    }

    #[tokio::test]
    async fn test_after_model_call_maps_some_to_continue_modified() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_model: Some(Box::new(|_ctx, _response| {
                Box::pin(async move {
                    let modified = LlmResponse {
                        content: Some(Content::new("model").with_text("modified")),
                        ..Default::default()
                    };
                    Ok(Some(modified))
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let response = LlmResponse::default();

        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();

        match result {
            AfterModelCallResult::Continue(resp) => {
                let content = resp.content.unwrap();
                assert!(
                    content
                        .parts
                        .iter()
                        .any(|p| matches!(p, Part::Text { text } if text == "modified"))
                );
            }
        }
    }

    #[tokio::test]
    async fn test_after_model_call_maps_none_to_continue_unchanged() {
        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            after_model: Some(Box::new(|_ctx, _response| Box::pin(async move { Ok(None) }))),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("original")),
            ..Default::default()
        };

        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();

        match result {
            AfterModelCallResult::Continue(resp) => {
                let content = resp.content.unwrap();
                assert!(
                    content
                        .parts
                        .iter()
                        .any(|p| matches!(p, Part::Text { text } if text == "original"))
                );
            }
        }
    }

    #[tokio::test]
    async fn test_close_delegates_to_inner() {
        let closed = Arc::new(AtomicBool::new(false));
        let closed_clone = closed.clone();

        let plugin = Plugin::new(PluginConfig {
            name: "test".to_string(),
            close_fn: Some(Box::new(move || {
                let flag = closed_clone.clone();
                Box::pin(async move {
                    flag.store(true, Ordering::SeqCst);
                })
            })),
            ..Default::default()
        });

        let adapted = AdaptedPlugin::new(plugin, 100);
        adapted.close().await;

        assert!(closed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_callbacks_returns_continue_unchanged() {
        let plugin = Plugin::new(PluginConfig { name: "empty".to_string(), ..Default::default() });

        let adapted = AdaptedPlugin::new(plugin, 100);
        let ctx: Arc<dyn CallbackContext> = Arc::new(MockCallbackContext);
        let plugin_ctx = PluginContext::new();
        let tool: Arc<dyn Tool> = Arc::new(MockTool);

        // before_tool_call with no callback
        let args = serde_json::json!({"x": 1});
        let result = adapted
            .before_tool_call(tool.clone(), args.clone(), ctx.clone(), &plugin_ctx)
            .await
            .unwrap();
        match result {
            BeforeToolCallResult::Continue(v) => assert_eq!(v, args),
            _ => panic!("expected Continue"),
        }

        // after_tool_call with no callback
        let res_val = serde_json::json!({"y": 2});
        let result = adapted
            .after_tool_call(tool.clone(), &args, res_val.clone(), ctx.clone(), &plugin_ctx)
            .await
            .unwrap();
        match result {
            AfterToolCallResult::Continue(v) => assert_eq!(v, res_val),
        }

        // before_model_call with no callback
        let request = LlmRequest::new("m", vec![]);
        let result = adapted.before_model_call(request, ctx.clone(), &plugin_ctx).await.unwrap();
        match result {
            BeforeModelCallResult::Continue(req) => assert_eq!(req.model, "m"),
            _ => panic!("expected Continue"),
        }

        // after_model_call with no callback
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hi")),
            ..Default::default()
        };
        let result = adapted.after_model_call(response, ctx, &plugin_ctx).await.unwrap();
        match result {
            AfterModelCallResult::Continue(resp) => {
                assert!(resp.content.is_some());
            }
        }
    }
}