// rs_adk/plugin/mod.rs

//! Plugin system — lifecycle hooks with control-flow capabilities.
//!
//! Plugins are a superset of middleware: they can observe AND control agent
//! execution. A plugin can deny a tool call, short-circuit with a custom
//! response, or simply continue. The `PluginManager` runs `before_*` and
//! `on_*` hooks in registration order and `after_*` hooks in reverse
//! registration order, and returns the first non-`Continue` result.
7
8mod context_filter;
9mod global_instruction;
10mod logging;
11mod reflect_retry;
12mod security;
13
14pub use context_filter::ContextFilterPlugin;
15pub use global_instruction::GlobalInstructionPlugin;
16pub use logging::LoggingPlugin;
17pub use reflect_retry::ReflectRetryToolPlugin;
18pub use security::{AllowAllPolicy, DenyListPolicy, PolicyEngine, PolicyOutcome, SecurityPlugin};
19
20use std::sync::Arc;
21
22use async_trait::async_trait;
23
24use rs_genai::prelude::FunctionCall;
25
26use crate::context::InvocationContext;
27use crate::events::Event;
28
29/// The result of a plugin hook — controls whether execution continues.
30#[derive(Debug, Clone)]
31pub enum PluginResult {
32    /// Continue with normal execution.
33    Continue,
34    /// Short-circuit execution with a custom value (e.g., cached response).
35    ShortCircuit(serde_json::Value),
36    /// Deny the action with a reason string.
37    Deny(String),
38}
39
40impl PluginResult {
41    /// Returns true if this result is `Continue`.
42    pub fn is_continue(&self) -> bool {
43        matches!(self, Self::Continue)
44    }
45
46    /// Returns true if this result is `Deny`.
47    pub fn is_deny(&self) -> bool {
48        matches!(self, Self::Deny(_))
49    }
50
51    /// Returns true if this result is `ShortCircuit`.
52    pub fn is_short_circuit(&self) -> bool {
53        matches!(self, Self::ShortCircuit(_))
54    }
55}
56
57/// Plugin trait — lifecycle hooks with control-flow capabilities.
58///
59/// Unlike `Middleware` (which is observe-only), plugins can deny or
60/// short-circuit execution. All hooks default to `PluginResult::Continue`.
61#[async_trait]
62pub trait Plugin: Send + Sync + 'static {
63    /// Plugin name for logging/debugging.
64    fn name(&self) -> &str;
65
66    /// Called before an agent starts execution.
67    async fn before_agent(&self, _ctx: &InvocationContext) -> PluginResult {
68        PluginResult::Continue
69    }
70
71    /// Called after an agent completes execution.
72    async fn after_agent(&self, _ctx: &InvocationContext) -> PluginResult {
73        PluginResult::Continue
74    }
75
76    /// Called before a tool is executed. Return `Deny` to prevent execution.
77    async fn before_tool(&self, _call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
78        PluginResult::Continue
79    }
80
81    /// Called after a tool completes. Can transform or deny the result.
82    async fn after_tool(
83        &self,
84        _call: &FunctionCall,
85        _result: &serde_json::Value,
86        _ctx: &InvocationContext,
87    ) -> PluginResult {
88        PluginResult::Continue
89    }
90
91    /// Called when an event is emitted.
92    async fn on_event(&self, _event: &Event, _ctx: &InvocationContext) -> PluginResult {
93        PluginResult::Continue
94    }
95
96    /// Called when a user message is received.
97    async fn on_user_message(&self, _message: &str, _ctx: &InvocationContext) -> PluginResult {
98        PluginResult::Continue
99    }
100
101    /// Called before a run starts (before the agent loop).
102    async fn before_run(&self, _ctx: &InvocationContext) -> PluginResult {
103        PluginResult::Continue
104    }
105
106    /// Called after a run completes (after the agent loop).
107    async fn after_run(&self, _ctx: &InvocationContext) -> PluginResult {
108        PluginResult::Continue
109    }
110
111    /// Called before a model generation request.
112    async fn before_model(
113        &self,
114        _request: &crate::llm::LlmRequest,
115        _ctx: &InvocationContext,
116    ) -> PluginResult {
117        PluginResult::Continue
118    }
119
120    /// Called after a model generation response.
121    async fn after_model(
122        &self,
123        _response: &crate::llm::LlmResponse,
124        _ctx: &InvocationContext,
125    ) -> PluginResult {
126        PluginResult::Continue
127    }
128
129    /// Called when a model generation fails.
130    async fn on_model_error(&self, _error: &str, _ctx: &InvocationContext) -> PluginResult {
131        PluginResult::Continue
132    }
133
134    /// Called when a tool execution fails.
135    async fn on_tool_error(
136        &self,
137        _call: &FunctionCall,
138        _error: &str,
139        _ctx: &InvocationContext,
140    ) -> PluginResult {
141        PluginResult::Continue
142    }
143}
144
145/// Manages an ordered list of plugins, running them in sequence.
146///
147/// On each hook, plugins run in order. The first non-Continue result
148/// short-circuits the remaining plugins.
149#[derive(Clone, Default)]
150pub struct PluginManager {
151    plugins: Vec<Arc<dyn Plugin>>,
152}
153
154impl PluginManager {
155    /// Create an empty plugin manager.
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    /// Add a plugin to the manager.
161    pub fn add(&mut self, plugin: Arc<dyn Plugin>) {
162        self.plugins.push(plugin);
163    }
164
165    /// Number of registered plugins.
166    pub fn len(&self) -> usize {
167        self.plugins.len()
168    }
169
170    /// Returns true if no plugins are registered.
171    pub fn is_empty(&self) -> bool {
172        self.plugins.is_empty()
173    }
174
175    /// Run before_agent hooks. Returns first non-Continue result, or Continue.
176    pub async fn run_before_agent(&self, ctx: &InvocationContext) -> PluginResult {
177        for plugin in &self.plugins {
178            let result = plugin.before_agent(ctx).await;
179            if !result.is_continue() {
180                return result;
181            }
182        }
183        PluginResult::Continue
184    }
185
186    /// Run after_agent hooks. Returns first non-Continue result, or Continue.
187    pub async fn run_after_agent(&self, ctx: &InvocationContext) -> PluginResult {
188        for plugin in self.plugins.iter().rev() {
189            let result = plugin.after_agent(ctx).await;
190            if !result.is_continue() {
191                return result;
192            }
193        }
194        PluginResult::Continue
195    }
196
197    /// Run before_tool hooks. Returns first non-Continue result, or Continue.
198    pub async fn run_before_tool(
199        &self,
200        call: &FunctionCall,
201        ctx: &InvocationContext,
202    ) -> PluginResult {
203        for plugin in &self.plugins {
204            let result = plugin.before_tool(call, ctx).await;
205            if !result.is_continue() {
206                return result;
207            }
208        }
209        PluginResult::Continue
210    }
211
212    /// Run after_tool hooks. Returns first non-Continue result, or Continue.
213    pub async fn run_after_tool(
214        &self,
215        call: &FunctionCall,
216        value: &serde_json::Value,
217        ctx: &InvocationContext,
218    ) -> PluginResult {
219        for plugin in self.plugins.iter().rev() {
220            let result = plugin.after_tool(call, value, ctx).await;
221            if !result.is_continue() {
222                return result;
223            }
224        }
225        PluginResult::Continue
226    }
227
228    /// Run on_event hooks. Returns first non-Continue result, or Continue.
229    pub async fn run_on_event(&self, event: &Event, ctx: &InvocationContext) -> PluginResult {
230        for plugin in &self.plugins {
231            let result = plugin.on_event(event, ctx).await;
232            if !result.is_continue() {
233                return result;
234            }
235        }
236        PluginResult::Continue
237    }
238
239    /// Run on_user_message hooks.
240    pub async fn run_on_user_message(
241        &self,
242        message: &str,
243        ctx: &InvocationContext,
244    ) -> PluginResult {
245        for plugin in &self.plugins {
246            let result = plugin.on_user_message(message, ctx).await;
247            if !result.is_continue() {
248                return result;
249            }
250        }
251        PluginResult::Continue
252    }
253
254    /// Run before_run hooks.
255    pub async fn run_before_run(&self, ctx: &InvocationContext) -> PluginResult {
256        for plugin in &self.plugins {
257            let result = plugin.before_run(ctx).await;
258            if !result.is_continue() {
259                return result;
260            }
261        }
262        PluginResult::Continue
263    }
264
265    /// Run after_run hooks.
266    pub async fn run_after_run(&self, ctx: &InvocationContext) -> PluginResult {
267        for plugin in self.plugins.iter().rev() {
268            let result = plugin.after_run(ctx).await;
269            if !result.is_continue() {
270                return result;
271            }
272        }
273        PluginResult::Continue
274    }
275
276    /// Run before_model hooks.
277    pub async fn run_before_model(
278        &self,
279        request: &crate::llm::LlmRequest,
280        ctx: &InvocationContext,
281    ) -> PluginResult {
282        for plugin in &self.plugins {
283            let result = plugin.before_model(request, ctx).await;
284            if !result.is_continue() {
285                return result;
286            }
287        }
288        PluginResult::Continue
289    }
290
291    /// Run after_model hooks.
292    pub async fn run_after_model(
293        &self,
294        response: &crate::llm::LlmResponse,
295        ctx: &InvocationContext,
296    ) -> PluginResult {
297        for plugin in self.plugins.iter().rev() {
298            let result = plugin.after_model(response, ctx).await;
299            if !result.is_continue() {
300                return result;
301            }
302        }
303        PluginResult::Continue
304    }
305
306    /// Run on_model_error hooks.
307    pub async fn run_on_model_error(&self, error: &str, ctx: &InvocationContext) -> PluginResult {
308        for plugin in &self.plugins {
309            let result = plugin.on_model_error(error, ctx).await;
310            if !result.is_continue() {
311                return result;
312            }
313        }
314        PluginResult::Continue
315    }
316
317    /// Run on_tool_error hooks.
318    pub async fn run_on_tool_error(
319        &self,
320        call: &FunctionCall,
321        error: &str,
322        ctx: &InvocationContext,
323    ) -> PluginResult {
324        for plugin in &self.plugins {
325            let result = plugin.on_tool_error(call, error, ctx).await;
326            if !result.is_continue() {
327                return result;
328            }
329        }
330        PluginResult::Continue
331    }
332}
333
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use tokio::sync::broadcast;

    use super::*;

    // ── Fixtures ─────────────────────────────────────────────────────────

    /// Builds a minimal `InvocationContext` backed by the mock session writer.
    fn test_context() -> InvocationContext {
        let (evt_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        InvocationContext::new(session)
    }

    /// Builds a `FunctionCall` with the given name, empty args and no id.
    fn call_named(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    /// Plugin that denies every tool call.
    struct DenyPlugin;

    #[async_trait]
    impl Plugin for DenyPlugin {
        fn name(&self) -> &str {
            "deny"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("blocked by policy".into())
        }
    }

    /// Plugin that counts how many times its `before_tool` hook ran.
    struct CountPlugin {
        count: AtomicU32,
    }

    #[async_trait]
    impl Plugin for CountPlugin {
        fn name(&self) -> &str {
            "count"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            self.count.fetch_add(1, Ordering::SeqCst);
            PluginResult::Continue
        }
    }

    /// Plugin that denies every model request.
    struct ModelBlockerPlugin;

    #[async_trait]
    impl Plugin for ModelBlockerPlugin {
        fn name(&self) -> &str {
            "model-blocker"
        }

        async fn before_model(
            &self,
            _request: &crate::llm::LlmRequest,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("model calls blocked".into())
        }
    }

    // ── Synchronous tests ────────────────────────────────────────────────

    #[test]
    fn plugin_result_helpers() {
        let proceed = PluginResult::Continue;
        assert!(proceed.is_continue());
        assert!(!proceed.is_deny());
        assert!(!proceed.is_short_circuit());

        let denied = PluginResult::Deny("nope".into());
        assert!(denied.is_deny());
        assert!(!denied.is_continue());

        let cached = PluginResult::ShortCircuit(serde_json::json!({"cached": true}));
        assert!(cached.is_short_circuit());
    }

    #[test]
    fn plugin_manager_empty() {
        let manager = PluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn plugin_manager_add() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));
        assert_eq!(manager.len(), 1);
        assert!(!manager.is_empty());
    }

    #[test]
    fn plugin_is_object_safe() {
        // Compile-time check only: this fails to build if `Plugin` cannot be
        // used as a trait object.
        fn _assert(_: &dyn Plugin) {}
    }

    // ── Async tests ──────────────────────────────────────────────────────

    /// Hooks a plugin does not override must resolve to `Continue`.
    #[tokio::test]
    async fn new_hooks_default_to_continue() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));

        let ctx = test_context();

        assert!(manager.run_before_run(&ctx).await.is_continue());
        assert!(manager.run_after_run(&ctx).await.is_continue());
        assert!(manager
            .run_on_user_message("hello", &ctx)
            .await
            .is_continue());

        let request = crate::llm::LlmRequest::from_text("test");
        assert!(manager.run_before_model(&request, &ctx).await.is_continue());

        assert!(manager.run_on_model_error("err", &ctx).await.is_continue());

        let call = call_named("t");
        assert!(manager
            .run_on_tool_error(&call, "err", &ctx)
            .await
            .is_continue());
    }

    /// A plugin overriding `before_model` can deny the model call.
    #[tokio::test]
    async fn custom_before_model_plugin() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(ModelBlockerPlugin));

        let ctx = test_context();
        let request = crate::llm::LlmRequest::from_text("test");

        let outcome = manager.run_before_model(&request, &ctx).await;
        assert!(outcome.is_deny());
    }

    /// A denying plugin stops plugins registered after it from running.
    #[tokio::test]
    async fn plugin_manager_deny_short_circuits() {
        let counter = Arc::new(CountPlugin {
            count: AtomicU32::new(0),
        });

        let mut manager = PluginManager::new();
        manager.add(Arc::new(DenyPlugin));
        manager.add(counter.clone());

        let ctx = test_context();
        let call = call_named("dangerous_tool");

        let outcome = manager.run_before_tool(&call, &ctx).await;
        assert!(outcome.is_deny());

        // The counting plugin sits behind the denier, so it must never run.
        assert_eq!(counter.count.load(Ordering::SeqCst), 0);
    }
}