mod context_filter;
mod global_instruction;
mod logging;
mod reflect_retry;
mod security;
pub use context_filter::ContextFilterPlugin;
pub use global_instruction::GlobalInstructionPlugin;
pub use logging::LoggingPlugin;
pub use reflect_retry::ReflectRetryToolPlugin;
pub use security::{AllowAllPolicy, DenyListPolicy, PolicyEngine, PolicyOutcome, SecurityPlugin};
use std::sync::Arc;
use async_trait::async_trait;
use rs_genai::prelude::FunctionCall;
use crate::context::InvocationContext;
use crate::events::Event;
/// Outcome of a single plugin hook invocation.
///
/// [`PluginManager`] stops dispatching a hook at the first plugin that returns
/// anything other than [`PluginResult::Continue`] and hands that result back to
/// its caller unchanged.
#[derive(Debug, Clone, PartialEq)]
pub enum PluginResult {
    /// Let the next plugin in the chain (if any) run.
    Continue,
    /// Stop the chain and return this value to the caller.
    ///
    /// NOTE(review): presumably used in place of the guarded operation's own
    /// result (e.g. a cached tool response) — confirm at the call sites.
    ShortCircuit(serde_json::Value),
    /// Stop the chain and refuse the operation; carries a reason message.
    Deny(String),
}

impl PluginResult {
    /// Returns `true` for [`PluginResult::Continue`].
    #[must_use]
    pub fn is_continue(&self) -> bool {
        matches!(self, Self::Continue)
    }

    /// Returns `true` for [`PluginResult::Deny`].
    #[must_use]
    pub fn is_deny(&self) -> bool {
        matches!(self, Self::Deny(_))
    }

    /// Returns `true` for [`PluginResult::ShortCircuit`].
    #[must_use]
    pub fn is_short_circuit(&self) -> bool {
        matches!(self, Self::ShortCircuit(_))
    }
}
/// A hook-based extension that can observe or intercept agent execution.
///
/// Only [`Plugin::name`] is required. Every hook has a default body that
/// returns [`PluginResult::Continue`], so an implementor overrides just the
/// hooks it cares about. The `Send + Sync + 'static` bound lets plugins be
/// shared as `Arc<dyn Plugin>` and registered via [`PluginManager::add`].
#[async_trait]
pub trait Plugin: Send + Sync + 'static {
    /// Identifier for this plugin.
    fn name(&self) -> &str;

    // --- Run-level hooks ---------------------------------------------------

    /// Dispatched by [`PluginManager::run_before_run`].
    async fn before_run(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_after_run`].
    async fn after_run(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_on_user_message`] with the message text.
    async fn on_user_message(&self, _message: &str, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    // --- Agent hooks -------------------------------------------------------

    /// Dispatched by [`PluginManager::run_before_agent`].
    async fn before_agent(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_after_agent`].
    async fn after_agent(&self, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    // --- Model hooks -------------------------------------------------------

    /// Dispatched by [`PluginManager::run_before_model`] with the outgoing request.
    async fn before_model(
        &self,
        _request: &crate::llm::LlmRequest,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_after_model`] with the model response.
    async fn after_model(
        &self,
        _response: &crate::llm::LlmResponse,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_on_model_error`] with an error description.
    async fn on_model_error(&self, _error: &str, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    // --- Tool hooks --------------------------------------------------------

    /// Dispatched by [`PluginManager::run_before_tool`] with the pending call.
    async fn before_tool(&self, _call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_after_tool`] with the call and its result.
    async fn after_tool(
        &self,
        _call: &FunctionCall,
        _result: &serde_json::Value,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    /// Dispatched by [`PluginManager::run_on_tool_error`] with the failed call and error text.
    async fn on_tool_error(
        &self,
        _call: &FunctionCall,
        _error: &str,
        _ctx: &InvocationContext,
    ) -> PluginResult {
        PluginResult::Continue
    }

    // --- Event stream ------------------------------------------------------

    /// Dispatched by [`PluginManager::run_on_event`] for each event passed to it.
    async fn on_event(&self, _event: &Event, _ctx: &InvocationContext) -> PluginResult {
        PluginResult::Continue
    }
}
/// Ordered collection of plugins whose hooks are dispatched as a chain.
///
/// Plugins are stored as `Arc<dyn Plugin>`, so cloning a manager only bumps
/// reference counts and the clone shares the same plugin instances.
#[derive(Clone, Default)]
pub struct PluginManager {
    plugins: Vec<Arc<dyn Plugin>>,
}

// `dyn Plugin` is not `Debug`, so a derive is impossible; report the plugin
// names (in registration order) instead so containing types can derive `Debug`.
impl std::fmt::Debug for PluginManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&str> = self.plugins.iter().map(|p| p.name()).collect();
        f.debug_struct("PluginManager")
            .field("plugins", &names)
            .finish()
    }
}
impl PluginManager {
    /// Creates a manager with no registered plugins.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `plugin` to the end of the chain.
    ///
    /// Registration order matters: most hooks visit plugins in this order,
    /// while the `after_*` hooks visit them in reverse.
    pub fn add(&mut self, plugin: Arc<dyn Plugin>) {
        self.plugins.push(plugin);
    }

    /// Number of registered plugins.
    pub fn len(&self) -> usize {
        self.plugins.len()
    }

    /// Returns `true` when no plugins are registered.
    pub fn is_empty(&self) -> bool {
        self.plugins.is_empty()
    }

    /// Awaits `hook` for each plugin in the order `plugins` yields them.
    ///
    /// Returns the first result that is not [`PluginResult::Continue`]; later
    /// plugins are not invoked. Returns `Continue` if every plugin continued
    /// (including when there are no plugins).
    async fn run_chain<'a, I, F, Fut>(plugins: I, mut hook: F) -> PluginResult
    where
        I: IntoIterator<Item = &'a Arc<dyn Plugin>>,
        F: FnMut(&'a Arc<dyn Plugin>) -> Fut,
        Fut: std::future::Future<Output = PluginResult>,
    {
        for plugin in plugins {
            let result = hook(plugin).await;
            if !result.is_continue() {
                return result;
            }
        }
        PluginResult::Continue
    }

    /// Runs [`Plugin::before_agent`] in registration order.
    pub async fn run_before_agent(&self, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.before_agent(ctx)).await
    }

    /// Runs [`Plugin::after_agent`] in reverse registration order.
    pub async fn run_after_agent(&self, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(self.plugins.iter().rev(), |plugin| plugin.after_agent(ctx)).await
    }

    /// Runs [`Plugin::before_tool`] in registration order.
    pub async fn run_before_tool(
        &self,
        call: &FunctionCall,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.before_tool(call, ctx)).await
    }

    /// Runs [`Plugin::after_tool`] in reverse registration order.
    pub async fn run_after_tool(
        &self,
        call: &FunctionCall,
        value: &serde_json::Value,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(self.plugins.iter().rev(), |plugin| {
            plugin.after_tool(call, value, ctx)
        })
        .await
    }

    /// Runs [`Plugin::on_event`] in registration order.
    pub async fn run_on_event(&self, event: &Event, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.on_event(event, ctx)).await
    }

    /// Runs [`Plugin::on_user_message`] in registration order.
    pub async fn run_on_user_message(
        &self,
        message: &str,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.on_user_message(message, ctx)).await
    }

    /// Runs [`Plugin::before_run`] in registration order.
    pub async fn run_before_run(&self, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.before_run(ctx)).await
    }

    /// Runs [`Plugin::after_run`] in reverse registration order.
    pub async fn run_after_run(&self, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(self.plugins.iter().rev(), |plugin| plugin.after_run(ctx)).await
    }

    /// Runs [`Plugin::before_model`] in registration order.
    pub async fn run_before_model(
        &self,
        request: &crate::llm::LlmRequest,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.before_model(request, ctx)).await
    }

    /// Runs [`Plugin::after_model`] in reverse registration order.
    pub async fn run_after_model(
        &self,
        response: &crate::llm::LlmResponse,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(self.plugins.iter().rev(), |plugin| {
            plugin.after_model(response, ctx)
        })
        .await
    }

    /// Runs [`Plugin::on_model_error`] in registration order.
    pub async fn run_on_model_error(&self, error: &str, ctx: &InvocationContext) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.on_model_error(error, ctx)).await
    }

    /// Runs [`Plugin::on_tool_error`] in registration order.
    pub async fn run_on_tool_error(
        &self,
        call: &FunctionCall,
        error: &str,
        ctx: &InvocationContext,
    ) -> PluginResult {
        Self::run_chain(&self.plugins, |plugin| plugin.on_tool_error(call, error, ctx)).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `InvocationContext` backed by the crate's mock session writer.
    ///
    /// Shared by the async tests below, which previously repeated this setup.
    fn test_ctx() -> InvocationContext {
        let (evt_tx, _) = tokio::sync::broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        InvocationContext::new(session)
    }

    #[test]
    fn plugin_result_helpers() {
        assert!(PluginResult::Continue.is_continue());
        assert!(!PluginResult::Continue.is_deny());
        assert!(!PluginResult::Continue.is_short_circuit());
        assert!(PluginResult::Deny("nope".into()).is_deny());
        assert!(!PluginResult::Deny("nope".into()).is_continue());
        let val = serde_json::json!({"cached": true});
        assert!(PluginResult::ShortCircuit(val).is_short_circuit());
    }

    #[test]
    fn plugin_manager_empty() {
        let pm = PluginManager::new();
        assert!(pm.is_empty());
        assert_eq!(pm.len(), 0);
    }

    #[test]
    fn plugin_manager_add() {
        let mut pm = PluginManager::new();
        pm.add(Arc::new(LoggingPlugin::new()));
        assert_eq!(pm.len(), 1);
        assert!(!pm.is_empty());
    }

    #[test]
    fn plugin_is_object_safe() {
        // Compiles only if `Plugin` can be used as a trait object.
        fn _assert(_: &dyn Plugin) {}
    }

    /// Denies every tool call.
    struct DenyPlugin;

    #[async_trait]
    impl Plugin for DenyPlugin {
        fn name(&self) -> &str {
            "deny"
        }
        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("blocked by policy".into())
        }
    }

    /// Counts how many times `before_tool` reaches it.
    struct CountPlugin {
        count: std::sync::atomic::AtomicU32,
    }

    #[async_trait]
    impl Plugin for CountPlugin {
        fn name(&self) -> &str {
            "count"
        }
        async fn before_tool(
            &self,
            _call: &FunctionCall,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            PluginResult::Continue
        }
    }

    /// Records `before:<label>` / `after:<label>` into a shared log so tests
    /// can observe the order in which the manager visits plugins.
    struct OrderPlugin {
        label: &'static str,
        log: Arc<std::sync::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Plugin for OrderPlugin {
        fn name(&self) -> &str {
            self.label
        }
        async fn before_agent(&self, _ctx: &InvocationContext) -> PluginResult {
            self.log.lock().unwrap().push(format!("before:{}", self.label));
            PluginResult::Continue
        }
        async fn after_agent(&self, _ctx: &InvocationContext) -> PluginResult {
            self.log.lock().unwrap().push(format!("after:{}", self.label));
            PluginResult::Continue
        }
    }

    #[tokio::test]
    async fn new_hooks_default_to_continue() {
        let mut pm = PluginManager::new();
        pm.add(Arc::new(LoggingPlugin::new()));
        let ctx = test_ctx();
        assert!(pm.run_before_run(&ctx).await.is_continue());
        assert!(pm.run_after_run(&ctx).await.is_continue());
        assert!(pm.run_on_user_message("hello", &ctx).await.is_continue());
        let req = crate::llm::LlmRequest::from_text("test");
        assert!(pm.run_before_model(&req, &ctx).await.is_continue());
        assert!(pm.run_on_model_error("err", &ctx).await.is_continue());
        let call = FunctionCall {
            name: "t".into(),
            args: serde_json::json!({}),
            id: None,
        };
        assert!(pm.run_on_tool_error(&call, "err", &ctx).await.is_continue());
    }

    /// Denies every model request.
    struct ModelBlockerPlugin;

    #[async_trait]
    impl Plugin for ModelBlockerPlugin {
        fn name(&self) -> &str {
            "model-blocker"
        }
        async fn before_model(
            &self,
            _request: &crate::llm::LlmRequest,
            _ctx: &InvocationContext,
        ) -> PluginResult {
            PluginResult::Deny("model calls blocked".into())
        }
    }

    #[tokio::test]
    async fn custom_before_model_plugin() {
        let mut pm = PluginManager::new();
        pm.add(Arc::new(ModelBlockerPlugin));
        let ctx = test_ctx();
        let req = crate::llm::LlmRequest::from_text("test");
        let result = pm.run_before_model(&req, &ctx).await;
        assert!(result.is_deny());
    }

    #[tokio::test]
    async fn plugin_manager_deny_short_circuits() {
        let count_plugin = Arc::new(CountPlugin {
            count: std::sync::atomic::AtomicU32::new(0),
        });
        let mut pm = PluginManager::new();
        pm.add(Arc::new(DenyPlugin));
        pm.add(count_plugin.clone());
        let ctx = test_ctx();
        let call = FunctionCall {
            name: "dangerous_tool".into(),
            args: serde_json::json!({}),
            id: None,
        };
        let result = pm.run_before_tool(&call, &ctx).await;
        assert!(result.is_deny());
        // The deny from the first plugin must stop the chain before CountPlugin.
        assert_eq!(
            count_plugin.count.load(std::sync::atomic::Ordering::SeqCst),
            0
        );
    }

    #[tokio::test]
    async fn before_hooks_run_forward_after_hooks_run_reverse() {
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut pm = PluginManager::new();
        for label in ["a", "b"] {
            pm.add(Arc::new(OrderPlugin {
                label,
                log: log.clone(),
            }));
        }
        let ctx = test_ctx();
        assert!(pm.run_before_agent(&ctx).await.is_continue());
        assert!(pm.run_after_agent(&ctx).await.is_continue());
        assert_eq!(
            *log.lock().unwrap(),
            vec!["before:a", "before:b", "after:b", "after:a"]
        );
    }
}