rs_adk/plugin/
security.rs1use async_trait::async_trait;
4
5use rs_genai::prelude::FunctionCall;
6
7use super::{Plugin, PluginResult};
8use crate::context::InvocationContext;
9
/// Verdict a [`PolicyEngine`] returns for a single pending tool call.
///
/// Derives `PartialEq`/`Eq` so outcomes can be compared directly
/// (e.g. with `assert_eq!`) instead of only via `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyOutcome {
    /// The tool call may proceed.
    Allow,
    /// The tool call needs confirmation before it may run; the string explains why.
    Confirm(String),
    /// The tool call is rejected; the string is the human-readable reason.
    Deny(String),
}
20
21pub trait PolicyEngine: Send + Sync + 'static {
26 fn evaluate(&self, tool_name: &str, args: &serde_json::Value) -> PolicyOutcome;
28}
29
30pub struct SecurityPlugin {
37 engine: Box<dyn PolicyEngine>,
38}
39
40impl SecurityPlugin {
41 pub fn new(engine: impl PolicyEngine + 'static) -> Self {
43 Self {
44 engine: Box::new(engine),
45 }
46 }
47}
48
49#[async_trait]
50impl Plugin for SecurityPlugin {
51 fn name(&self) -> &str {
52 "security"
53 }
54
55 async fn before_tool(&self, call: &FunctionCall, _ctx: &InvocationContext) -> PluginResult {
56 match self.engine.evaluate(&call.name, &call.args) {
57 PolicyOutcome::Allow => {
58 #[cfg(feature = "tracing-support")]
59 tracing::debug!(tool = %call.name, "[plugin:security] Tool call allowed");
60 PluginResult::Continue
61 }
62 PolicyOutcome::Confirm(msg) => {
63 #[cfg(feature = "tracing-support")]
64 tracing::warn!(tool = %call.name, reason = %msg, "[plugin:security] Tool call requires confirmation");
65 PluginResult::Deny(format!("Confirmation required: {}", msg))
66 }
67 PolicyOutcome::Deny(reason) => {
68 #[cfg(feature = "tracing-support")]
69 tracing::warn!(tool = %call.name, reason = %reason, "[plugin:security] Tool call denied");
70 PluginResult::Deny(reason)
71 }
72 }
73 }
74}
75
76pub struct DenyListPolicy {
78 blocked_tools: Vec<String>,
79}
80
81impl DenyListPolicy {
82 pub fn new(blocked_tools: Vec<String>) -> Self {
84 Self { blocked_tools }
85 }
86}
87
88impl PolicyEngine for DenyListPolicy {
89 fn evaluate(&self, tool_name: &str, _args: &serde_json::Value) -> PolicyOutcome {
90 if self.blocked_tools.iter().any(|t| t == tool_name) {
91 PolicyOutcome::Deny(format!("Tool '{}' is blocked by policy", tool_name))
92 } else {
93 PolicyOutcome::Allow
94 }
95 }
96}
97
98pub struct AllowAllPolicy;
100
101impl PolicyEngine for AllowAllPolicy {
102 fn evaluate(&self, _tool_name: &str, _args: &serde_json::Value) -> PolicyOutcome {
103 PolicyOutcome::Allow
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Builds an `InvocationContext` over a mock session writer for plugin tests.
    fn test_ctx() -> InvocationContext {
        let (evt_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx);
        InvocationContext::new(session)
    }

    /// Builds a `FunctionCall` for `name` with empty JSON arguments and no id.
    fn tool_call(name: &'static str) -> FunctionCall {
        FunctionCall {
            name: name.into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    /// Policy that always requests confirmation, used to exercise the `Confirm` path.
    struct AlwaysConfirm;

    impl PolicyEngine for AlwaysConfirm {
        fn evaluate(&self, _tool_name: &str, _args: &serde_json::Value) -> PolicyOutcome {
            PolicyOutcome::Confirm("needs approval".into())
        }
    }

    #[test]
    fn deny_list_policy_blocks() {
        let policy = DenyListPolicy::new(vec!["dangerous_tool".into()]);
        let result = policy.evaluate("dangerous_tool", &serde_json::json!({}));
        assert!(matches!(result, PolicyOutcome::Deny(_)));
    }

    #[test]
    fn deny_list_policy_reason_names_tool() {
        let policy = DenyListPolicy::new(vec!["dangerous_tool".into()]);
        match policy.evaluate("dangerous_tool", &serde_json::json!({})) {
            PolicyOutcome::Deny(reason) => {
                assert_eq!(reason, "Tool 'dangerous_tool' is blocked by policy")
            }
            _ => panic!("blocked tool must be denied"),
        }
    }

    #[test]
    fn deny_list_policy_allows() {
        let policy = DenyListPolicy::new(vec!["dangerous_tool".into()]);
        let result = policy.evaluate("safe_tool", &serde_json::json!({}));
        assert!(matches!(result, PolicyOutcome::Allow));
    }

    #[test]
    fn allow_all_policy() {
        let policy = AllowAllPolicy;
        let result = policy.evaluate("anything", &serde_json::json!({}));
        assert!(matches!(result, PolicyOutcome::Allow));
    }

    #[tokio::test]
    async fn security_plugin_denies_blocked_tool() {
        let plugin = SecurityPlugin::new(DenyListPolicy::new(vec!["rm_rf".into()]));
        let ctx = test_ctx();

        let result = plugin.before_tool(&tool_call("rm_rf"), &ctx).await;
        assert!(result.is_deny());
    }

    #[tokio::test]
    async fn security_plugin_allows_safe_tool() {
        let plugin = SecurityPlugin::new(DenyListPolicy::new(vec!["rm_rf".into()]));
        let ctx = test_ctx();

        let result = plugin.before_tool(&tool_call("get_weather"), &ctx).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn security_plugin_denies_when_confirmation_required() {
        // `Confirm` has no interactive path in `before_tool`, so it must fail closed.
        let plugin = SecurityPlugin::new(AlwaysConfirm);
        let ctx = test_ctx();

        match plugin.before_tool(&tool_call("send_email"), &ctx).await {
            PluginResult::Deny(msg) => {
                assert_eq!(msg, "Confirmation required: needs approval")
            }
            _ => panic!("a Confirm outcome must be turned into a denial"),
        }
    }
}