Skip to main content

rs_adk/plugin/
security.rs

//! Security plugin — policy-based tool call authorization.
2
3use async_trait::async_trait;
4
5use rs_genai::prelude::FunctionCall;
6
7use super::{Plugin, PluginResult};
8use crate::context::InvocationContext;
9
/// The verdict a [`PolicyEngine`] returns for a single tool call.
///
/// `SecurityPlugin` lets the call proceed on `Allow` and blocks it on both
/// `Confirm` and `Deny`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyOutcome {
    /// The tool call may proceed.
    Allow,
    /// The tool call needs user confirmation; the string explains why.
    Confirm(String),
    /// The tool call is refused; the string is the reason.
    Deny(String),
}
20
21/// Trait for evaluating tool call policies.
22///
23/// Implementations can check tool names, arguments, user permissions,
24/// rate limits, etc.
25pub trait PolicyEngine: Send + Sync + 'static {
26    /// Evaluate whether a tool call should be allowed.
27    fn evaluate(&self, tool_name: &str, args: &serde_json::Value) -> PolicyOutcome;
28}
29
30/// Plugin that enforces tool call policies via a `PolicyEngine`.
31///
32/// Before every tool call, the security plugin consults the policy engine.
33/// If the engine returns `Deny`, the tool call is blocked. If it returns
34/// `Confirm`, the tool call is blocked with a confirmation message (in a
35/// real system, this would prompt the user).
36pub struct SecurityPlugin {
37    engine: Box<dyn PolicyEngine>,
38}
39
40impl SecurityPlugin {
41    /// Create a new security plugin with the given policy engine.
42    pub fn new(engine: impl PolicyEngine + 'static) -> Self {
43        Self {
44            engine: Box::new(engine),
45        }
46    }
47}
48
49#[async_trait]
50impl Plugin for SecurityPlugin {
51    fn name(&self) -> &str {
52        "security"
53    }
54
55    async fn before_tool(&self, call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
56        match self.engine.evaluate(&call.name, &call.args) {
57            PolicyOutcome::Allow => {
58                #[cfg(feature = "tracing-support")]
59                tracing::debug!(tool = %call.name, "[plugin:security] Tool call allowed");
60                PluginResult::Continue
61            }
62            PolicyOutcome::Confirm(msg) => {
63                #[cfg(feature = "tracing-support")]
64                tracing::warn!(tool = %call.name, reason = %msg, "[plugin:security] Tool call requires confirmation");
65                PluginResult::Deny(format!("Confirmation required: {}", msg))
66            }
67            PolicyOutcome::Deny(reason) => {
68                #[cfg(feature = "tracing-support")]
69                tracing::warn!(tool = %call.name, reason = %reason, "[plugin:security] Tool call denied");
70                PluginResult::Deny(reason)
71            }
72        }
73    }
74}
75
76/// A simple policy engine that blocks specific tool names.
77pub struct DenyListPolicy {
78    blocked_tools: Vec<String>,
79}
80
81impl DenyListPolicy {
82    /// Create a policy that denies specific tools by name.
83    pub fn new(blocked_tools: Vec<String>) -> Self {
84        Self { blocked_tools }
85    }
86}
87
88impl PolicyEngine for DenyListPolicy {
89    fn evaluate(&self, tool_name: &str, _args: &serde_json::Value) -> PolicyOutcome {
90        if self.blocked_tools.iter().any(|t| t == tool_name) {
91            PolicyOutcome::Deny(format!("Tool '{}' is blocked by policy", tool_name))
92        } else {
93            PolicyOutcome::Allow
94        }
95    }
96}
97
/// A policy engine that allows every tool call unconditionally.
///
/// This is a stateless unit type, so it is trivially `Copy` and has an
/// obvious `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAllPolicy;
100
101impl PolicyEngine for AllowAllPolicy {
102    fn evaluate(&self, _tool_name: &str, _args: &serde_json::Value) -> PolicyOutcome {
103        PolicyOutcome::Allow
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Builds an `InvocationContext` backed by the crate's mock session writer.
    fn test_ctx() -> InvocationContext {
        let (evt_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        InvocationContext::new(session)
    }

    /// Builds a `FunctionCall` for `name` with empty JSON arguments and no id.
    fn call(name: &'static str) -> FunctionCall {
        FunctionCall {
            name: name.into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    /// Test-only engine that requests confirmation for every call.
    struct ConfirmAllPolicy;

    impl PolicyEngine for ConfirmAllPolicy {
        fn evaluate(&self, _: &str, _: &serde_json::Value) -> PolicyOutcome {
            PolicyOutcome::Confirm("needs approval".into())
        }
    }

    #[test]
    fn deny_list_policy_blocks() {
        let policy = DenyListPolicy::new(vec!["dangerous_tool".into()]);
        let result = policy.evaluate("dangerous_tool", &serde_json::json!({}));
        assert!(matches!(
            result,
            PolicyOutcome::Deny(ref reason) if reason == "Tool 'dangerous_tool' is blocked by policy"
        ));
    }

    #[test]
    fn deny_list_policy_allows() {
        let policy = DenyListPolicy::new(vec!["dangerous_tool".into()]);
        let result = policy.evaluate("safe_tool", &serde_json::json!({}));
        assert!(matches!(result, PolicyOutcome::Allow));
    }

    #[test]
    fn allow_all_policy() {
        let policy = AllowAllPolicy;
        let result = policy.evaluate("anything", &serde_json::json!({}));
        assert!(matches!(result, PolicyOutcome::Allow));
    }

    #[test]
    fn security_plugin_name() {
        let plugin = SecurityPlugin::new(AllowAllPolicy);
        assert_eq!(plugin.name(), "security");
    }

    #[tokio::test]
    async fn security_plugin_denies_blocked_tool() {
        let plugin = SecurityPlugin::new(DenyListPolicy::new(vec!["rm_rf".into()]));
        let ctx = test_ctx();

        let result = plugin.before_tool(&call("rm_rf"), &ctx).await;
        assert!(result.is_deny());
        assert!(matches!(
            result,
            PluginResult::Deny(ref msg) if msg == "Tool 'rm_rf' is blocked by policy"
        ));
    }

    #[tokio::test]
    async fn security_plugin_allows_safe_tool() {
        let plugin = SecurityPlugin::new(DenyListPolicy::new(vec!["rm_rf".into()]));
        let ctx = test_ctx();

        let result = plugin.before_tool(&call("get_weather"), &ctx).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn security_plugin_blocks_when_confirmation_required() {
        let plugin = SecurityPlugin::new(ConfirmAllPolicy);
        let ctx = test_ctx();

        let result = plugin.before_tool(&call("send_email"), &ctx).await;
        assert!(result.is_deny());
        assert!(matches!(
            result,
            PluginResult::Deny(ref msg) if msg == "Confirmation required: needs approval"
        ));
    }
}