Skip to main content

adk_plugin/
enhanced_plugin.rs

1//! Enhanced plugin trait with fine-grained hooks and default no-op implementations.
2//!
3//! The [`EnhancedPlugin`] trait provides a trait-based plugin interface that allows
4//! plugin authors to implement only the hooks they need. All hook methods have default
5//! implementations that pass through inputs unchanged (identity behavior).
6//!
7//! # Overview
8//!
9//! Unlike the closure-based [`PluginConfig`](crate::PluginConfig) approach, `EnhancedPlugin`
10//! uses Rust's trait system for a more ergonomic and type-safe plugin authoring experience.
11//! Plugins can:
12//!
13//! - Intercept and modify tool call arguments before execution
14//! - Inspect and transform tool results after execution
15//! - Modify LLM requests before they are sent
16//! - Transform LLM responses after they are received
17//! - Access shared state via [`PluginContext`] across all hook invocations
18//! - Define execution priority for deterministic ordering
19//!
20//! # Examples
21//!
22//! ## Minimal plugin (no-op)
23//!
24//! ```rust
25//! use adk_core::async_trait;
26//! use adk_plugin::EnhancedPlugin;
27//!
28//! struct MyPlugin;
29//!
30//! #[async_trait]
31//! impl EnhancedPlugin for MyPlugin {
32//!     fn name(&self) -> &str {
33//!         "my-plugin"
34//!     }
35//! }
36//! ```
37//!
38//! ## Plugin with custom priority and before-tool hook
39//!
40//! ```rust,ignore
41//! use std::sync::Arc;
42//! use adk_core::{async_trait, CallbackContext, Result, Tool};
43//! use adk_plugin::{BeforeToolCallResult, EnhancedPlugin, PluginContext};
44//! use serde_json::Value;
45//!
46//! struct ValidationPlugin;
47//!
48//! #[async_trait]
49//! impl EnhancedPlugin for ValidationPlugin {
50//!     fn name(&self) -> &str {
51//!         "validation"
52//!     }
53//!
54//!     fn priority(&self) -> i32 {
55//!         10 // Run early in the pipeline
56//!     }
57//!
58//!     async fn before_tool_call(
59//!         &self,
//!         _tool: Arc<dyn Tool>,
61//!         args: Value,
62//!         _ctx: Arc<dyn CallbackContext>,
63//!         _plugin_ctx: &PluginContext,
64//!     ) -> Result<BeforeToolCallResult> {
65//!         // Inject a safety flag into all tool arguments
66//!         let mut modified = args;
67//!         if let Value::Object(ref mut map) = modified {
68//!             map.insert("safe_mode".to_string(), Value::Bool(true));
69//!         }
70//!         Ok(BeforeToolCallResult::Continue(modified))
71//!     }
72//! }
73//! ```
74//!
75//! ## Plugin using shared context for rate limiting
76//!
77//! ```rust,ignore
78//! use std::sync::Arc;
79//! use adk_core::{async_trait, AdkError, CallbackContext, Result, Tool};
80//! use adk_plugin::{BeforeToolCallResult, EnhancedPlugin, PluginContext};
81//! use serde_json::Value;
82//!
83//! #[derive(Clone)]
84//! struct RateLimitState {
85//!     call_count: u32,
86//! }
87//!
88//! struct RateLimitPlugin {
89//!     max_calls: u32,
90//! }
91//!
92//! #[async_trait]
93//! impl EnhancedPlugin for RateLimitPlugin {
94//!     fn name(&self) -> &str {
95//!         "rate-limiter"
96//!     }
97//!
98//!     fn priority(&self) -> i32 {
99//!         5 // Security plugins run first
100//!     }
101//!
102//!     async fn before_tool_call(
103//!         &self,
104//!         _tool: Arc<dyn Tool>,
105//!         args: Value,
106//!         _ctx: Arc<dyn CallbackContext>,
107//!         plugin_ctx: &PluginContext,
108//!     ) -> Result<BeforeToolCallResult> {
109//!         let mut state = plugin_ctx.get::<RateLimitState>().await
110//!             .unwrap_or(RateLimitState { call_count: 0 });
111//!
112//!         state.call_count += 1;
113//!         plugin_ctx.insert(state.clone()).await;
114//!
115//!         if state.call_count > self.max_calls {
116//!             return Err(AdkError::plugin("rate limit exceeded"));
117//!         }
118//!
119//!         Ok(BeforeToolCallResult::Continue(args))
120//!     }
121//! }
122//! ```
123
124use std::sync::Arc;
125
126use adk_core::{CallbackContext, LlmRequest, LlmResponse, Result, Tool, async_trait};
127use serde_json::Value;
128
129use crate::context::PluginContext;
130use crate::hook_result::{
131    AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
132};
133
134/// Enhanced plugin trait with fine-grained hooks and default no-op implementations.
135///
136/// Implement only the hooks you need. All methods have default implementations
137/// that pass through inputs unchanged (identity function behavior).
138///
139/// # Priority
140///
141/// Plugins execute in ascending priority order (lower values run first).
142/// The default priority is 100. Recommended ranges:
143///
144/// | Range | Use Case |
145/// |-------|----------|
146/// | 0–25 | Security plugins (auth, validation, rate limiting) |
147/// | 26–50 | Caching plugins |
148/// | 51–75 | Transformation plugins (sanitization, injection) |
149/// | 76–100 | Logging and metrics plugins |
150/// | 100+ | Application-specific plugins |
151///
152/// # Thread Safety
153///
154/// All implementations must be `Send + Sync` to support concurrent async execution.
155/// The [`PluginContext`] provides thread-safe shared state access via
156/// [`tokio::sync::RwLock`].
157#[async_trait]
158pub trait EnhancedPlugin: Send + Sync {
159    /// Unique name identifying this plugin.
160    ///
161    /// Used for logging, debugging, and error messages. Should be a short,
162    /// descriptive identifier (e.g., `"rate-limiter"`, `"cache"`, `"audit-log"`).
163    fn name(&self) -> &str;
164
165    /// Execution priority. Lower values execute first. Default: 100.
166    ///
167    /// When multiple plugins are registered, they execute hooks in ascending
168    /// priority order. Plugins with the same priority execute in registration order.
169    fn priority(&self) -> i32 {
170        100
171    }
172
173    /// Called before a tool is executed.
174    ///
175    /// Receives the tool reference, call arguments, callback context, and shared plugin context.
176    ///
177    /// # Returns
178    ///
179    /// - `Ok(BeforeToolCallResult::Continue(args))` — pass (possibly modified) args to the
180    ///   next plugin in the chain, and ultimately to the tool execution.
181    /// - `Ok(BeforeToolCallResult::ShortCircuit(result))` — skip tool execution entirely
182    ///   and use this synthetic result. No further plugins in the chain are invoked.
183    /// - `Err(e)` — stop the pipeline and propagate the error. The tool is not executed.
184    async fn before_tool_call(
185        &self,
186        _tool: Arc<dyn Tool>,
187        args: Value,
188        _ctx: Arc<dyn CallbackContext>,
189        _plugin_ctx: &PluginContext,
190    ) -> Result<BeforeToolCallResult> {
191        Ok(BeforeToolCallResult::Continue(args))
192    }
193
194    /// Called after a tool executes successfully.
195    ///
196    /// Receives the tool reference, the arguments that were used for execution
197    /// (after any modifications by before-hooks), the result, callback context,
198    /// and shared plugin context.
199    ///
200    /// # Returns
201    ///
202    /// - `Ok(AfterToolCallResult::Continue(result))` — pass (possibly modified) result
203    ///   to the next plugin in the chain, and ultimately return to the agent.
204    /// - `Err(e)` — stop the pipeline and propagate the error.
205    async fn after_tool_call(
206        &self,
207        _tool: Arc<dyn Tool>,
208        _args: &Value,
209        result: Value,
210        _ctx: Arc<dyn CallbackContext>,
211        _plugin_ctx: &PluginContext,
212    ) -> Result<AfterToolCallResult> {
213        Ok(AfterToolCallResult::Continue(result))
214    }
215
216    /// Called before a model (LLM) call is made.
217    ///
218    /// Receives the LLM request, callback context, and shared plugin context.
219    ///
220    /// # Returns
221    ///
222    /// - `Ok(BeforeModelCallResult::Continue(request))` — pass (possibly modified) request
223    ///   to the next plugin in the chain, and ultimately to the LLM provider.
224    /// - `Ok(BeforeModelCallResult::ShortCircuit(response))` — skip the model call entirely
225    ///   and use this synthetic response. No further plugins in the chain are invoked.
226    /// - `Err(e)` — stop the pipeline and propagate the error. The model is not called.
227    async fn before_model_call(
228        &self,
229        request: LlmRequest,
230        _ctx: Arc<dyn CallbackContext>,
231        _plugin_ctx: &PluginContext,
232    ) -> Result<BeforeModelCallResult> {
233        Ok(BeforeModelCallResult::Continue(request))
234    }
235
236    /// Called after a model (LLM) call completes.
237    ///
238    /// Receives the LLM response, callback context, and shared plugin context.
239    ///
240    /// # Returns
241    ///
242    /// - `Ok(AfterModelCallResult::Continue(response))` — pass (possibly modified) response
243    ///   to the next plugin in the chain, and ultimately return to the agent.
244    /// - `Err(e)` — stop the pipeline and propagate the error.
245    async fn after_model_call(
246        &self,
247        response: LlmResponse,
248        _ctx: Arc<dyn CallbackContext>,
249        _plugin_ctx: &PluginContext,
250    ) -> Result<AfterModelCallResult> {
251        Ok(AfterModelCallResult::Continue(response))
252    }
253
254    /// Called when the plugin is being shut down.
255    ///
256    /// Override this method to perform cleanup operations such as flushing
257    /// buffers, closing connections, or persisting state.
258    ///
259    /// The default implementation is a no-op.
260    async fn close(&self) {}
261}