adk_plugin/enhanced_plugin.rs
1//! Enhanced plugin trait with fine-grained hooks and default no-op implementations.
2//!
3//! The [`EnhancedPlugin`] trait provides a trait-based plugin interface that allows
4//! plugin authors to implement only the hooks they need. All hook methods have default
5//! implementations that pass through inputs unchanged (identity behavior).
6//!
7//! # Overview
8//!
9//! Unlike the closure-based [`PluginConfig`](crate::PluginConfig) approach, `EnhancedPlugin`
10//! uses Rust's trait system for a more ergonomic and type-safe plugin authoring experience.
11//! Plugins can:
12//!
13//! - Intercept and modify tool call arguments before execution
14//! - Inspect and transform tool results after execution
15//! - Modify LLM requests before they are sent
16//! - Transform LLM responses after they are received
17//! - Access shared state via [`PluginContext`] across all hook invocations
18//! - Define execution priority for deterministic ordering
19//!
20//! # Examples
21//!
22//! ## Minimal plugin (no-op)
23//!
24//! ```rust
25//! use adk_core::async_trait;
26//! use adk_plugin::EnhancedPlugin;
27//!
28//! struct MyPlugin;
29//!
30//! #[async_trait]
31//! impl EnhancedPlugin for MyPlugin {
32//! fn name(&self) -> &str {
33//! "my-plugin"
34//! }
35//! }
36//! ```
37//!
38//! ## Plugin with custom priority and before-tool hook
39//!
40//! ```rust,ignore
41//! use std::sync::Arc;
42//! use adk_core::{async_trait, CallbackContext, Result, Tool};
43//! use adk_plugin::{BeforeToolCallResult, EnhancedPlugin, PluginContext};
44//! use serde_json::Value;
45//!
46//! struct ValidationPlugin;
47//!
48//! #[async_trait]
49//! impl EnhancedPlugin for ValidationPlugin {
50//! fn name(&self) -> &str {
51//! "validation"
52//! }
53//!
54//! fn priority(&self) -> i32 {
55//! 10 // Run early in the pipeline
56//! }
57//!
58//! async fn before_tool_call(
59//! &self,
60//! tool: Arc<dyn Tool>,
61//! args: Value,
62//! _ctx: Arc<dyn CallbackContext>,
63//! _plugin_ctx: &PluginContext,
64//! ) -> Result<BeforeToolCallResult> {
65//! // Inject a safety flag into all tool arguments
66//! let mut modified = args;
67//! if let Value::Object(ref mut map) = modified {
68//! map.insert("safe_mode".to_string(), Value::Bool(true));
69//! }
70//! Ok(BeforeToolCallResult::Continue(modified))
71//! }
72//! }
73//! ```
74//!
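//! ## Plugin using shared context for rate limiting
//!
//! Note: the `get` / `insert` pair below is a read-modify-write split across two
//! separate awaits, so it is not atomic. If tool calls can run concurrently, two
//! calls may read the same count, and the limit can then be exceeded. For strict
//! enforcement, keep the counter in an atomic (e.g. `std::sync::atomic::AtomicU32`,
//! updated with `fetch_add`) on the plugin struct itself.
//!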
77//! ```rust,ignore
78//! use std::sync::Arc;
79//! use adk_core::{async_trait, AdkError, CallbackContext, Result, Tool};
80//! use adk_plugin::{BeforeToolCallResult, EnhancedPlugin, PluginContext};
81//! use serde_json::Value;
82//!
83//! #[derive(Clone)]
84//! struct RateLimitState {
85//! call_count: u32,
86//! }
87//!
88//! struct RateLimitPlugin {
89//! max_calls: u32,
90//! }
91//!
92//! #[async_trait]
93//! impl EnhancedPlugin for RateLimitPlugin {
94//! fn name(&self) -> &str {
95//! "rate-limiter"
96//! }
97//!
98//! fn priority(&self) -> i32 {
99//! 5 // Security plugins run first
100//! }
101//!
102//! async fn before_tool_call(
103//! &self,
104//! _tool: Arc<dyn Tool>,
105//! args: Value,
106//! _ctx: Arc<dyn CallbackContext>,
107//! plugin_ctx: &PluginContext,
108//! ) -> Result<BeforeToolCallResult> {
109//! let mut state = plugin_ctx.get::<RateLimitState>().await
110//! .unwrap_or(RateLimitState { call_count: 0 });
111//!
112//! state.call_count += 1;
113//! plugin_ctx.insert(state.clone()).await;
114//!
115//! if state.call_count > self.max_calls {
116//! return Err(AdkError::plugin("rate limit exceeded"));
117//! }
118//!
119//! Ok(BeforeToolCallResult::Continue(args))
120//! }
121//! }
122//! ```
123
124use std::sync::Arc;
125
126use adk_core::{CallbackContext, LlmRequest, LlmResponse, Result, Tool, async_trait};
127use serde_json::Value;
128
129use crate::context::PluginContext;
130use crate::hook_result::{
131 AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
132};
133
134/// Enhanced plugin trait with fine-grained hooks and default no-op implementations.
135///
136/// Implement only the hooks you need. All methods have default implementations
137/// that pass through inputs unchanged (identity function behavior).
138///
139/// # Priority
140///
141/// Plugins execute in ascending priority order (lower values run first).
142/// The default priority is 100. Recommended ranges:
143///
144/// | Range | Use Case |
145/// |-------|----------|
146/// | 0–25 | Security plugins (auth, validation, rate limiting) |
147/// | 26–50 | Caching plugins |
148/// | 51–75 | Transformation plugins (sanitization, injection) |
149/// | 76–100 | Logging and metrics plugins |
150/// | 100+ | Application-specific plugins |
151///
152/// # Thread Safety
153///
154/// All implementations must be `Send + Sync` to support concurrent async execution.
155/// The [`PluginContext`] provides thread-safe shared state access via
156/// [`tokio::sync::RwLock`].
157#[async_trait]
158pub trait EnhancedPlugin: Send + Sync {
159 /// Unique name identifying this plugin.
160 ///
161 /// Used for logging, debugging, and error messages. Should be a short,
162 /// descriptive identifier (e.g., `"rate-limiter"`, `"cache"`, `"audit-log"`).
163 fn name(&self) -> &str;
164
165 /// Execution priority. Lower values execute first. Default: 100.
166 ///
167 /// When multiple plugins are registered, they execute hooks in ascending
168 /// priority order. Plugins with the same priority execute in registration order.
169 fn priority(&self) -> i32 {
170 100
171 }
172
173 /// Called before a tool is executed.
174 ///
175 /// Receives the tool reference, call arguments, callback context, and shared plugin context.
176 ///
177 /// # Returns
178 ///
179 /// - `Ok(BeforeToolCallResult::Continue(args))` — pass (possibly modified) args to the
180 /// next plugin in the chain, and ultimately to the tool execution.
181 /// - `Ok(BeforeToolCallResult::ShortCircuit(result))` — skip tool execution entirely
182 /// and use this synthetic result. No further plugins in the chain are invoked.
183 /// - `Err(e)` — stop the pipeline and propagate the error. The tool is not executed.
184 async fn before_tool_call(
185 &self,
186 _tool: Arc<dyn Tool>,
187 args: Value,
188 _ctx: Arc<dyn CallbackContext>,
189 _plugin_ctx: &PluginContext,
190 ) -> Result<BeforeToolCallResult> {
191 Ok(BeforeToolCallResult::Continue(args))
192 }
193
194 /// Called after a tool executes successfully.
195 ///
196 /// Receives the tool reference, the arguments that were used for execution
197 /// (after any modifications by before-hooks), the result, callback context,
198 /// and shared plugin context.
199 ///
200 /// # Returns
201 ///
202 /// - `Ok(AfterToolCallResult::Continue(result))` — pass (possibly modified) result
203 /// to the next plugin in the chain, and ultimately return to the agent.
204 /// - `Err(e)` — stop the pipeline and propagate the error.
205 async fn after_tool_call(
206 &self,
207 _tool: Arc<dyn Tool>,
208 _args: &Value,
209 result: Value,
210 _ctx: Arc<dyn CallbackContext>,
211 _plugin_ctx: &PluginContext,
212 ) -> Result<AfterToolCallResult> {
213 Ok(AfterToolCallResult::Continue(result))
214 }
215
216 /// Called before a model (LLM) call is made.
217 ///
218 /// Receives the LLM request, callback context, and shared plugin context.
219 ///
220 /// # Returns
221 ///
222 /// - `Ok(BeforeModelCallResult::Continue(request))` — pass (possibly modified) request
223 /// to the next plugin in the chain, and ultimately to the LLM provider.
224 /// - `Ok(BeforeModelCallResult::ShortCircuit(response))` — skip the model call entirely
225 /// and use this synthetic response. No further plugins in the chain are invoked.
226 /// - `Err(e)` — stop the pipeline and propagate the error. The model is not called.
227 async fn before_model_call(
228 &self,
229 request: LlmRequest,
230 _ctx: Arc<dyn CallbackContext>,
231 _plugin_ctx: &PluginContext,
232 ) -> Result<BeforeModelCallResult> {
233 Ok(BeforeModelCallResult::Continue(request))
234 }
235
236 /// Called after a model (LLM) call completes.
237 ///
238 /// Receives the LLM response, callback context, and shared plugin context.
239 ///
240 /// # Returns
241 ///
242 /// - `Ok(AfterModelCallResult::Continue(response))` — pass (possibly modified) response
243 /// to the next plugin in the chain, and ultimately return to the agent.
244 /// - `Err(e)` — stop the pipeline and propagate the error.
245 async fn after_model_call(
246 &self,
247 response: LlmResponse,
248 _ctx: Arc<dyn CallbackContext>,
249 _plugin_ctx: &PluginContext,
250 ) -> Result<AfterModelCallResult> {
251 Ok(AfterModelCallResult::Continue(response))
252 }
253
254 /// Called when the plugin is being shut down.
255 ///
256 /// Override this method to perform cleanup operations such as flushing
257 /// buffers, closing connections, or persisting state.
258 ///
259 /// The default implementation is a no-op.
260 async fn close(&self) {}
261}