Skip to main content

t_ron/
middleware.rs

1//! Security middleware — wraps bote's Dispatcher with t-ron's security gate.
2//!
3//! Intercepts `tools/call` requests, runs the full check pipeline (policy, rate
4//! limiting, payload scanning, pattern analysis), and blocks denied calls before
5//! they reach the tool handler.
6
7use crate::TRon;
8use crate::gate::{DenyCode, ToolCall, Verdict};
9use bote::Dispatcher;
10use bote::protocol::{JsonRpcRequest, JsonRpcResponse};
11
/// JSON-RPC error code attached to every response the security gate blocks.
///
/// The value sits inside the `-32000..=-32099` band that JSON-RPC 2.0 reserves
/// for implementation-defined server errors, so it cannot collide with the
/// protocol's predefined codes (parse error, invalid params, ...).
const SECURITY_DENIED: i32 = -32_001;
14
15/// Security gate wrapping a bote Dispatcher.
16///
17/// Every `tools/call` passes through t-ron's check pipeline before reaching the
18/// inner dispatcher. Non-tool methods (initialize, tools/list) pass through
19/// unmodified.
20pub struct SecurityGate {
21    tron: TRon,
22    inner: Dispatcher,
23}
24
25impl SecurityGate {
26    /// Create a new security gate.
27    #[must_use]
28    pub fn new(tron: TRon, dispatcher: Dispatcher) -> Self {
29        Self {
30            tron,
31            inner: dispatcher,
32        }
33    }
34
35    /// Access the inner dispatcher (e.g. for registering handlers).
36    #[must_use]
37    pub fn dispatcher_mut(&mut self) -> &mut Dispatcher {
38        &mut self.inner
39    }
40
41    /// Access the inner dispatcher immutably.
42    #[must_use]
43    pub fn dispatcher(&self) -> &Dispatcher {
44        &self.inner
45    }
46
47    /// Access the t-ron security monitor.
48    #[must_use]
49    pub fn tron(&self) -> &TRon {
50        &self.tron
51    }
52
53    /// Register t-ron's tool handlers with the inner dispatcher.
54    ///
55    /// **Important:** The tool *definitions* must be registered in the
56    /// `ToolRegistry` before creating the `Dispatcher`. Use
57    /// [`tools::tool_defs()`](crate::tools::tool_defs) to get the definitions
58    /// and register them alongside your application's tools. This method only
59    /// wires up the handler functions.
60    pub fn register_tool_handlers(&mut self) {
61        use crate::tools;
62        let query = self.tron.query();
63        self.inner
64            .handle("tron_status", tools::status_handler(query.clone()));
65        self.inner
66            .handle("tron_risk", tools::risk_handler(query.clone()));
67        self.inner.handle("tron_audit", tools::audit_handler(query));
68        self.inner
69            .handle("tron_policy", tools::policy_handler(&self.tron));
70    }
71
72    /// Dispatch a JSON-RPC request with security checks.
73    ///
74    /// `agent_id` identifies the calling agent — this is the identity t-ron
75    /// checks against its policy engine. Callers are responsible for
76    /// authenticating the agent and providing a trusted ID.
77    pub async fn dispatch(
78        &self,
79        request: &JsonRpcRequest,
80        agent_id: &str,
81    ) -> Option<JsonRpcResponse> {
82        if request.method == "tools/call"
83            && let Some(denied) = self.check_tool_call(request, agent_id).await
84        {
85            return Some(denied);
86        }
87        self.inner.dispatch(request)
88    }
89
90    /// Dispatch with streaming support and security checks.
91    pub async fn dispatch_streaming(
92        &self,
93        request: &JsonRpcRequest,
94        agent_id: &str,
95    ) -> bote::DispatchOutcome {
96        if request.method == "tools/call"
97            && let Some(denied) = self.check_tool_call(request, agent_id).await
98        {
99            return bote::DispatchOutcome::Immediate(Some(denied));
100        }
101        self.inner.dispatch_streaming(request)
102    }
103
104    /// Run the security check pipeline for a tools/call request.
105    /// Returns `Some(error_response)` if denied, `None` if allowed.
106    async fn check_tool_call(
107        &self,
108        request: &JsonRpcRequest,
109        agent_id: &str,
110    ) -> Option<JsonRpcResponse> {
111        let id = request.id.clone().unwrap_or(serde_json::Value::Null);
112        let tool_name = match request.params.get("name").and_then(|v| v.as_str()) {
113            Some(name) if !name.is_empty() => name,
114            _ => {
115                return Some(Self::deny_response(
116                    id,
117                    "missing or empty tool name in tools/call",
118                    DenyCode::Unauthorized,
119                ));
120            }
121        };
122        let arguments = request
123            .params
124            .get("arguments")
125            .cloned()
126            .unwrap_or(serde_json::json!({}));
127
128        let call = ToolCall {
129            agent_id: agent_id.to_string(),
130            tool_name: tool_name.to_string(),
131            params: arguments,
132            timestamp: chrono::Utc::now(),
133        };
134
135        let verdict = self.tron.check(&call).await;
136        match verdict {
137            Verdict::Deny { reason, code } => {
138                tracing::warn!(
139                    agent = agent_id,
140                    tool = tool_name,
141                    code = ?code,
142                    "security gate denied tool call: {reason}"
143                );
144                Some(Self::deny_response(id, &reason, code))
145            }
146            Verdict::Flag { reason } => {
147                tracing::info!(
148                    agent = agent_id,
149                    tool = tool_name,
150                    "security gate flagged tool call: {reason}"
151                );
152                // Flags are allowed through — they're informational.
153                None
154            }
155            Verdict::Allow => None,
156        }
157    }
158
159    /// Build a JSON-RPC error response for a denied call.
160    fn deny_response(id: serde_json::Value, reason: &str, code: DenyCode) -> JsonRpcResponse {
161        JsonRpcResponse::error(id, SECURITY_DENIED, format!("security: {reason} [{code}]"))
162    }
163}
164
#[cfg(test)]
mod tests {
    //! Behavioural tests for `SecurityGate`: denial paths, pass-through of
    //! non-tool methods, streaming dispatch, rate limiting and audit logging.

    use super::*;
    use crate::{DefaultAction, TRonConfig};
    use bote::registry::{ToolDef, ToolRegistry, ToolSchema};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Build a gate around a dispatcher that exposes a single `echo` tool.
    fn make_gate(config: TRonConfig) -> SecurityGate {
        let tron = TRon::new(config);
        let mut reg = ToolRegistry::new();
        reg.register(ToolDef {
            name: "echo".into(),
            description: "Echo input".into(),
            input_schema: ToolSchema {
                schema_type: "object".into(),
                properties: HashMap::new(),
                required: vec![],
            },
        });
        let mut dispatcher = Dispatcher::new(reg);
        dispatcher.handle(
            "echo",
            Arc::new(|params| {
                serde_json::json!({"content": [{"type": "text", "text": params.to_string()}]})
            }),
        );
        SecurityGate::new(tron, dispatcher)
    }

    /// Build a `tools/call` request for `tool_name` with the given arguments.
    fn tool_call_request(tool_name: &str, arguments: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest::new(1, "tools/call")
            .with_params(serde_json::json!({"name": tool_name, "arguments": arguments}))
    }

    #[tokio::test]
    async fn deny_unknown_agent() {
        let gate = make_gate(TRonConfig::default());
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "nobody").await.unwrap();
        assert!(resp.error.is_some());
        let err = resp.error.unwrap();
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("unauthorized"));
    }

    #[tokio::test]
    async fn allow_known_agent() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({"msg": "hello"}));
        let resp = gate.dispatch(&req, "agent-1").await.unwrap();
        assert!(resp.error.is_none());
        assert!(resp.result.is_some());
    }

    #[tokio::test]
    async fn allow_with_policy() {
        let gate = make_gate(TRonConfig::default());
        gate.tron()
            .load_policy(
                r#"
[agent."web-agent"]
allow = ["echo"]
"#,
            )
            .unwrap();
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "web-agent").await.unwrap();
        assert!(resp.error.is_none());
    }

    #[tokio::test]
    async fn deny_by_policy() {
        let gate = make_gate(TRonConfig::default());
        gate.tron()
            .load_policy(
                r#"
[agent."restricted"]
allow = ["tarang_*"]
deny = ["echo"]
"#,
            )
            .unwrap();
        let req = tool_call_request("echo", serde_json::json!({}));
        let resp = gate.dispatch(&req, "restricted").await.unwrap();
        assert!(resp.error.is_some());
    }

    #[tokio::test]
    async fn deny_injection() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request(
            "echo",
            serde_json::json!({"q": "1 UNION SELECT * FROM passwords"}),
        );
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        assert!(resp.error.is_some());
        let err = resp.error.unwrap();
        assert!(err.message.contains("injection_detected"));
    }

    #[tokio::test]
    async fn non_tool_call_passes_through() {
        let gate = make_gate(TRonConfig::default());
        // initialize should pass through regardless of agent
        let req = JsonRpcRequest::new(1, "initialize");
        let resp = gate.dispatch(&req, "unknown-agent").await.unwrap();
        assert!(resp.result.is_some());
    }

    #[tokio::test]
    async fn tools_list_passes_through() {
        let gate = make_gate(TRonConfig::default());
        let req = JsonRpcRequest::new(1, "tools/list");
        let resp = gate.dispatch(&req, "unknown-agent").await.unwrap();
        let result = resp.result.unwrap();
        let tools = result["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
    }

    #[tokio::test]
    async fn rate_limit_through_gate() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            scan_payloads: false,
            analyze_patterns: false,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({}));
        for _ in 0..60 {
            let resp = gate.dispatch(&req, "agent").await.unwrap();
            assert!(resp.error.is_none());
        }
        // 61st should be rate limited
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        assert!(resp.error.is_some());
        assert!(resp.error.unwrap().message.contains("rate_limited"));
    }

    #[tokio::test]
    async fn streaming_dispatch_denied() {
        let gate = make_gate(TRonConfig::default());
        let req = tool_call_request("echo", serde_json::json!({}));
        match gate.dispatch_streaming(&req, "nobody").await {
            bote::DispatchOutcome::Immediate(Some(resp)) => {
                let err = resp.error.expect("denied call must carry an error");
                assert_eq!(err.code, SECURITY_DENIED);
            }
            _ => panic!("expected Immediate(Some) for denied call"),
        }
    }

    #[tokio::test]
    async fn streaming_dispatch_allowed() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({}));
        match gate.dispatch_streaming(&req, "agent").await {
            bote::DispatchOutcome::Immediate(Some(resp)) => {
                assert!(resp.error.is_none());
            }
            _ => panic!("expected Immediate(Some) for allowed sync tool"),
        }
    }

    #[tokio::test]
    async fn audit_logged_through_gate() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            scan_payloads: false,
            analyze_patterns: false,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request("echo", serde_json::json!({}));
        gate.dispatch(&req, "agent-1").await;

        let query = gate.tron().query();
        assert_eq!(query.total_events().await, 1);
    }

    #[tokio::test]
    async fn deny_missing_tool_name() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        // tools/call with no "name" field
        let req =
            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({"arguments": {}}));
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        assert!(resp.error.is_some());
        assert!(resp.error.unwrap().message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_empty_tool_name() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        let req = tool_call_request("", serde_json::json!({}));
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        assert!(resp.error.is_some());
        assert!(resp.error.unwrap().message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_non_string_tool_name() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        // "name" is present but is not a JSON string.
        let req = JsonRpcRequest::new(1, "tools/call")
            .with_params(serde_json::json!({"name": 42, "arguments": {}}));
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        let err = resp.error.expect("non-string tool name must be denied");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("missing"));
    }

    #[tokio::test]
    async fn deny_non_object_params() {
        let config = TRonConfig {
            default_unknown_agent: DefaultAction::Allow,
            default_unknown_tool: DefaultAction::Allow,
            ..Default::default()
        };
        let gate = make_gate(config);
        // Positional (array) params carry no "name" key at all.
        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!(["echo"]));
        let resp = gate.dispatch(&req, "agent").await.unwrap();
        let err = resp.error.expect("tools/call without a name must be denied");
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("missing"));
    }

    // `deny_response` is synchronous, so no async runtime is needed here.
    #[test]
    fn deny_response_format() {
        let resp = SecurityGate::deny_response(
            serde_json::json!(42),
            "rate limit exceeded",
            DenyCode::RateLimited,
        );
        assert_eq!(resp.id, serde_json::json!(42));
        assert!(resp.error.is_some());
        let err = resp.error.unwrap();
        assert_eq!(err.code, SECURITY_DENIED);
        assert!(err.message.contains("rate_limited"));
        assert!(err.message.contains("rate limit exceeded"));
    }
}