1mod context_filter;
9mod global_instruction;
10mod logging;
11mod reflect_retry;
12mod security;
13
14pub use context_filter::ContextFilterPlugin;
15pub use global_instruction::GlobalInstructionPlugin;
16pub use logging::LoggingPlugin;
17pub use reflect_retry::ReflectRetryToolPlugin;
18pub use security::{AllowAllPolicy, DenyListPolicy, PolicyEngine, PolicyOutcome, SecurityPlugin};
19
20use std::sync::Arc;
21
22use async_trait::async_trait;
23
24use rs_genai::prelude::FunctionCall;
25
26use crate::context::InvocationContext;
27use crate::events::Event;
28
29#[derive(Debug, Clone)]
31pub enum PluginResult {
32 Continue,
34 ShortCircuit(serde_json::Value),
36 Deny(String),
38}
39
40impl PluginResult {
41 pub fn is_continue(&self) -> bool {
43 matches!(self, Self::Continue)
44 }
45
46 pub fn is_deny(&self) -> bool {
48 matches!(self, Self::Deny(_))
49 }
50
51 pub fn is_short_circuit(&self) -> bool {
53 matches!(self, Self::ShortCircuit(_))
54 }
55}
56
57#[async_trait]
62pub trait Plugin: Send + Sync + 'static {
63 fn name(&self) -> &str;
65
66 async fn before_agent(&self, _ctx: &InvocationContext) -> PluginResult {
68 PluginResult::Continue
69 }
70
71 async fn after_agent(&self, _ctx: &InvocationContext) -> PluginResult {
73 PluginResult::Continue
74 }
75
76 async fn before_tool(&self, _call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
78 PluginResult::Continue
79 }
80
81 async fn after_tool(
83 &self,
84 _call: &FunctionCall,
85 _result: &serde_json::Value,
86 _ctx: &InvocationContext,
87 ) -> PluginResult {
88 PluginResult::Continue
89 }
90
91 async fn on_event(&self, _event: &Event, _ctx: &InvocationContext) -> PluginResult {
93 PluginResult::Continue
94 }
95
96 async fn on_user_message(&self, _message: &str, _ctx: &InvocationContext) -> PluginResult {
98 PluginResult::Continue
99 }
100
101 async fn before_run(&self, _ctx: &InvocationContext) -> PluginResult {
103 PluginResult::Continue
104 }
105
106 async fn after_run(&self, _ctx: &InvocationContext) -> PluginResult {
108 PluginResult::Continue
109 }
110
111 async fn before_model(
113 &self,
114 _request: &crate::llm::LlmRequest,
115 _ctx: &InvocationContext,
116 ) -> PluginResult {
117 PluginResult::Continue
118 }
119
120 async fn after_model(
122 &self,
123 _response: &crate::llm::LlmResponse,
124 _ctx: &InvocationContext,
125 ) -> PluginResult {
126 PluginResult::Continue
127 }
128
129 async fn on_model_error(&self, _error: &str, _ctx: &InvocationContext) -> PluginResult {
131 PluginResult::Continue
132 }
133
134 async fn on_tool_error(
136 &self,
137 _call: &FunctionCall,
138 _error: &str,
139 _ctx: &InvocationContext,
140 ) -> PluginResult {
141 PluginResult::Continue
142 }
143}
144
145#[derive(Clone, Default)]
150pub struct PluginManager {
151 plugins: Vec<Arc<dyn Plugin>>,
152}
153
154impl PluginManager {
155 pub fn new() -> Self {
157 Self::default()
158 }
159
160 pub fn add(&mut self, plugin: Arc<dyn Plugin>) {
162 self.plugins.push(plugin);
163 }
164
165 pub fn len(&self) -> usize {
167 self.plugins.len()
168 }
169
170 pub fn is_empty(&self) -> bool {
172 self.plugins.is_empty()
173 }
174
175 pub async fn run_before_agent(&self, ctx: &InvocationContext) -> PluginResult {
177 for plugin in &self.plugins {
178 let result = plugin.before_agent(ctx).await;
179 if !result.is_continue() {
180 return result;
181 }
182 }
183 PluginResult::Continue
184 }
185
186 pub async fn run_after_agent(&self, ctx: &InvocationContext) -> PluginResult {
188 for plugin in self.plugins.iter().rev() {
189 let result = plugin.after_agent(ctx).await;
190 if !result.is_continue() {
191 return result;
192 }
193 }
194 PluginResult::Continue
195 }
196
197 pub async fn run_before_tool(
199 &self,
200 call: &FunctionCall,
201 ctx: &InvocationContext,
202 ) -> PluginResult {
203 for plugin in &self.plugins {
204 let result = plugin.before_tool(call, ctx).await;
205 if !result.is_continue() {
206 return result;
207 }
208 }
209 PluginResult::Continue
210 }
211
212 pub async fn run_after_tool(
214 &self,
215 call: &FunctionCall,
216 value: &serde_json::Value,
217 ctx: &InvocationContext,
218 ) -> PluginResult {
219 for plugin in self.plugins.iter().rev() {
220 let result = plugin.after_tool(call, value, ctx).await;
221 if !result.is_continue() {
222 return result;
223 }
224 }
225 PluginResult::Continue
226 }
227
228 pub async fn run_on_event(&self, event: &Event, ctx: &InvocationContext) -> PluginResult {
230 for plugin in &self.plugins {
231 let result = plugin.on_event(event, ctx).await;
232 if !result.is_continue() {
233 return result;
234 }
235 }
236 PluginResult::Continue
237 }
238
239 pub async fn run_on_user_message(
241 &self,
242 message: &str,
243 ctx: &InvocationContext,
244 ) -> PluginResult {
245 for plugin in &self.plugins {
246 let result = plugin.on_user_message(message, ctx).await;
247 if !result.is_continue() {
248 return result;
249 }
250 }
251 PluginResult::Continue
252 }
253
254 pub async fn run_before_run(&self, ctx: &InvocationContext) -> PluginResult {
256 for plugin in &self.plugins {
257 let result = plugin.before_run(ctx).await;
258 if !result.is_continue() {
259 return result;
260 }
261 }
262 PluginResult::Continue
263 }
264
265 pub async fn run_after_run(&self, ctx: &InvocationContext) -> PluginResult {
267 for plugin in self.plugins.iter().rev() {
268 let result = plugin.after_run(ctx).await;
269 if !result.is_continue() {
270 return result;
271 }
272 }
273 PluginResult::Continue
274 }
275
276 pub async fn run_before_model(
278 &self,
279 request: &crate::llm::LlmRequest,
280 ctx: &InvocationContext,
281 ) -> PluginResult {
282 for plugin in &self.plugins {
283 let result = plugin.before_model(request, ctx).await;
284 if !result.is_continue() {
285 return result;
286 }
287 }
288 PluginResult::Continue
289 }
290
291 pub async fn run_after_model(
293 &self,
294 response: &crate::llm::LlmResponse,
295 ctx: &InvocationContext,
296 ) -> PluginResult {
297 for plugin in self.plugins.iter().rev() {
298 let result = plugin.after_model(response, ctx).await;
299 if !result.is_continue() {
300 return result;
301 }
302 }
303 PluginResult::Continue
304 }
305
306 pub async fn run_on_model_error(&self, error: &str, ctx: &InvocationContext) -> PluginResult {
308 for plugin in &self.plugins {
309 let result = plugin.on_model_error(error, ctx).await;
310 if !result.is_continue() {
311 return result;
312 }
313 }
314 PluginResult::Continue
315 }
316
317 pub async fn run_on_tool_error(
319 &self,
320 call: &FunctionCall,
321 error: &str,
322 ctx: &InvocationContext,
323 ) -> PluginResult {
324 for plugin in &self.plugins {
325 let result = plugin.on_tool_error(call, error, ctx).await;
326 if !result.is_continue() {
327 return result;
328 }
329 }
330 PluginResult::Continue
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds an [`InvocationContext`] over the crate's mock session writer.
    ///
    /// The broadcast receiver is discarded right away; none of these tests
    /// observe emitted events.
    fn mock_ctx() -> InvocationContext {
        let (evt_tx, _) = tokio::sync::broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        InvocationContext::new(session)
    }

    /// A tool call with empty JSON arguments and no id.
    fn call_named(name: &'static str) -> FunctionCall {
        FunctionCall {
            name: name.into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    #[test]
    fn plugin_result_helpers() {
        let proceed = PluginResult::Continue;
        assert!(proceed.is_continue());
        assert!(!proceed.is_deny());
        assert!(!proceed.is_short_circuit());

        let denied = PluginResult::Deny("nope".into());
        assert!(denied.is_deny());
        assert!(!denied.is_continue());

        let cached = PluginResult::ShortCircuit(serde_json::json!({"cached": true}));
        assert!(cached.is_short_circuit());
    }

    #[test]
    fn plugin_manager_empty() {
        let manager = PluginManager::new();
        assert_eq!(manager.len(), 0);
        assert!(manager.is_empty());
    }

    #[test]
    fn plugin_manager_add() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));
        assert!(!manager.is_empty());
        assert_eq!(manager.len(), 1);
    }

    #[test]
    fn plugin_is_object_safe() {
        // Compiles only if `Plugin` can be used as a trait object.
        fn _takes_trait_object(_: &dyn Plugin) {}
    }

    /// Denies every tool call.
    struct DenyPlugin;

    #[async_trait]
    impl Plugin for DenyPlugin {
        fn name(&self) -> &str {
            "deny"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("blocked by policy".into())
        }
    }

    /// Counts how many times its `before_tool` hook runs.
    struct CountPlugin {
        count: AtomicU32,
    }

    #[async_trait]
    impl Plugin for CountPlugin {
        fn name(&self) -> &str {
            "count"
        }

        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            self.count.fetch_add(1, Ordering::SeqCst);
            PluginResult::Continue
        }
    }

    /// Denies every model request.
    struct ModelBlockerPlugin;

    #[async_trait]
    impl Plugin for ModelBlockerPlugin {
        fn name(&self) -> &str {
            "model-blocker"
        }

        async fn before_model(
            &self,
            _request: &crate::llm::LlmRequest,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("model calls blocked".into())
        }
    }

    #[tokio::test]
    async fn new_hooks_default_to_continue() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(LoggingPlugin::new()));
        let ctx = mock_ctx();

        assert!(manager.run_before_run(&ctx).await.is_continue());
        assert!(manager.run_after_run(&ctx).await.is_continue());
        assert!(manager.run_on_user_message("hello", &ctx).await.is_continue());

        let request = crate::llm::LlmRequest::from_text("test");
        assert!(manager.run_before_model(&request, &ctx).await.is_continue());

        assert!(manager.run_on_model_error("err", &ctx).await.is_continue());

        let call = call_named("t");
        assert!(manager.run_on_tool_error(&call, "err", &ctx).await.is_continue());
    }

    #[tokio::test]
    async fn custom_before_model_plugin() {
        let mut manager = PluginManager::new();
        manager.add(Arc::new(ModelBlockerPlugin));
        let ctx = mock_ctx();

        let request = crate::llm::LlmRequest::from_text("test");
        let outcome = manager.run_before_model(&request, &ctx).await;
        assert!(outcome.is_deny());
    }

    #[tokio::test]
    async fn plugin_manager_deny_short_circuits() {
        let counter = Arc::new(CountPlugin {
            count: AtomicU32::new(0),
        });

        let mut manager = PluginManager::new();
        manager.add(Arc::new(DenyPlugin));
        manager.add(counter.clone());
        let ctx = mock_ctx();

        let call = call_named("dangerous_tool");
        let outcome = manager.run_before_tool(&call, &ctx).await;
        assert!(outcome.is_deny());

        // DenyPlugin is registered first, so CountPlugin must never have run.
        assert_eq!(counter.count.load(Ordering::SeqCst), 0);
    }
}