synaptic_middleware/
security.rs1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::SynapticError;
7
8use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
9
/// Severity assigned to a tool call by a [`SecurityAnalyzer`].
///
/// Variants are declared from least to most severe, and the derived
/// `PartialOrd`/`Ord` follow declaration order, so comparisons such as
/// `risk >= RiskLevel::High` rank severities. Do not reorder the variants.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum RiskLevel {
    /// No identified risk; the lowest level.
    None,
    /// Low risk.
    Low,
    /// Medium risk.
    Medium,
    /// High risk.
    High,
    /// The highest level.
    Critical,
}
19
20#[async_trait]
22pub trait SecurityAnalyzer: Send + Sync {
23 async fn assess(&self, tool_name: &str, args: &Value) -> Result<RiskLevel, SynapticError>;
24}
25
26#[async_trait]
28pub trait ConfirmationPolicy: Send + Sync {
29 async fn should_confirm(&self, tool_name: &str, risk: RiskLevel)
30 -> Result<bool, SynapticError>;
31}
32
33#[async_trait]
35pub trait SecurityConfirmationCallback: Send + Sync {
36 async fn confirm(
37 &self,
38 tool_name: &str,
39 args: &Value,
40 risk: RiskLevel,
41 ) -> Result<bool, SynapticError>;
42}
43
44pub struct RuleBasedAnalyzer {
46 tool_risks: HashMap<String, RiskLevel>,
47 arg_patterns: Vec<ArgPattern>,
48 default_risk: RiskLevel,
49}
50
51struct ArgPattern {
53 key: String,
54 pattern: String,
55 risk: RiskLevel,
56}
57
58impl RuleBasedAnalyzer {
59 pub fn new() -> Self {
60 Self {
61 tool_risks: HashMap::new(),
62 arg_patterns: Vec::new(),
63 default_risk: RiskLevel::Low,
64 }
65 }
66
67 pub fn with_default_risk(mut self, risk: RiskLevel) -> Self {
69 self.default_risk = risk;
70 self
71 }
72
73 pub fn with_tool_risk(mut self, tool_name: impl Into<String>, risk: RiskLevel) -> Self {
75 self.tool_risks.insert(tool_name.into(), risk);
76 self
77 }
78
79 pub fn with_arg_pattern(
82 mut self,
83 key: impl Into<String>,
84 pattern: impl Into<String>,
85 risk: RiskLevel,
86 ) -> Self {
87 self.arg_patterns.push(ArgPattern {
88 key: key.into(),
89 pattern: pattern.into(),
90 risk,
91 });
92 self
93 }
94}
95
96impl Default for RuleBasedAnalyzer {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102#[async_trait]
103impl SecurityAnalyzer for RuleBasedAnalyzer {
104 async fn assess(&self, tool_name: &str, args: &Value) -> Result<RiskLevel, SynapticError> {
105 let mut risk = self
106 .tool_risks
107 .get(tool_name)
108 .copied()
109 .unwrap_or(self.default_risk);
110
111 for pattern in &self.arg_patterns {
113 if let Some(val) = args.get(&pattern.key) {
114 let val_str = match val {
115 Value::String(s) => s.clone(),
116 other => other.to_string(),
117 };
118 if val_str.contains(&pattern.pattern) && pattern.risk > risk {
119 risk = pattern.risk;
120 }
121 }
122 }
123
124 Ok(risk)
125 }
126}
127
128pub struct ThresholdConfirmationPolicy {
130 threshold: RiskLevel,
131}
132
133impl ThresholdConfirmationPolicy {
134 pub fn new(threshold: RiskLevel) -> Self {
135 Self { threshold }
136 }
137}
138
139#[async_trait]
140impl ConfirmationPolicy for ThresholdConfirmationPolicy {
141 async fn should_confirm(
142 &self,
143 _tool_name: &str,
144 risk: RiskLevel,
145 ) -> Result<bool, SynapticError> {
146 Ok(risk >= self.threshold)
147 }
148}
149
150pub struct SecurityMiddleware {
152 analyzer: Arc<dyn SecurityAnalyzer>,
153 policy: Arc<dyn ConfirmationPolicy>,
154 callback: Arc<dyn SecurityConfirmationCallback>,
155 bypass: HashSet<String>,
157}
158
159impl SecurityMiddleware {
160 pub fn new(
161 analyzer: Arc<dyn SecurityAnalyzer>,
162 policy: Arc<dyn ConfirmationPolicy>,
163 callback: Arc<dyn SecurityConfirmationCallback>,
164 ) -> Self {
165 Self {
166 analyzer,
167 policy,
168 callback,
169 bypass: HashSet::new(),
170 }
171 }
172
173 pub fn with_bypass(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
175 self.bypass = tools.into_iter().map(|s| s.into()).collect();
176 self
177 }
178}
179
180#[async_trait]
181impl AgentMiddleware for SecurityMiddleware {
182 async fn wrap_tool_call(
183 &self,
184 request: ToolCallRequest,
185 next: &dyn ToolCaller,
186 ) -> Result<Value, SynapticError> {
187 let tool_name = &request.call.name;
188
189 if self.bypass.contains(tool_name) {
191 return next.call(request).await;
192 }
193
194 let risk = self
196 .analyzer
197 .assess(tool_name, &request.call.arguments)
198 .await?;
199
200 let needs_confirm = self.policy.should_confirm(tool_name, risk).await?;
202
203 if needs_confirm {
204 let confirmed = self
205 .callback
206 .confirm(tool_name, &request.call.arguments, risk)
207 .await?;
208 if !confirmed {
209 return Err(SynapticError::Tool(format!(
210 "tool call '{}' rejected by security policy (risk: {:?})",
211 tool_name, risk
212 )));
213 }
214 }
215
216 next.call(request).await
217 }
218}