Skip to main content

synaptic_middleware/
ssrf_guard.rs

1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::SynapticError;
7
8use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
9
/// Policy knobs for [`SsrfGuardMiddleware`].
///
/// [`Default`] gives the recommended starting point: private destinations
/// blocked, empty allow/block lists, and a handful of common URL-bearing
/// argument names.
#[derive(Debug, Clone)]
pub struct SsrfGuardConfig {
    /// Reject private, loopback and other internal destinations (`true` by default).
    pub block_private: bool,
    /// Hostnames that are always rejected, even when they are otherwise public.
    pub blocklist: HashSet<String>,
    /// Hostnames that are always accepted. This list is consulted before every
    /// other rule, so an entry here overrides both `blocklist` and `block_private`.
    pub allowlist: HashSet<String>,
    /// Argument keys whose string values are always validated as `http(s)` URLs,
    /// at any nesting depth of the tool arguments.
    pub url_keys: Vec<String>,
}
22
23impl Default for SsrfGuardConfig {
24    fn default() -> Self {
25        Self {
26            block_private: true,
27            blocklist: HashSet::new(),
28            allowlist: HashSet::new(),
29            url_keys: vec![
30                "url".to_string(),
31                "uri".to_string(),
32                "endpoint".to_string(),
33                "base_url".to_string(),
34                "webhook_url".to_string(),
35            ],
36        }
37    }
38}
39
40/// Middleware that prevents SSRF (Server-Side Request Forgery) attacks.
41///
42/// Inspects tool call arguments for URLs pointing to private/loopback
43/// addresses and blocks them. Supports configurable allowlist/blocklist.
44pub struct SsrfGuardMiddleware {
45    config: SsrfGuardConfig,
46}
47
48impl SsrfGuardMiddleware {
49    pub fn new(config: SsrfGuardConfig) -> Self {
50        Self { config }
51    }
52
53    /// Check if a URL is safe to access.
54    fn check_url(&self, url: &str) -> Result<(), String> {
55        // Parse the URL to extract the host
56        let host = extract_host(url).ok_or_else(|| format!("invalid URL: {}", url))?;
57
58        // Check allowlist first
59        if self.config.allowlist.contains(&host) {
60            return Ok(());
61        }
62
63        // Check blocklist
64        if self.config.blocklist.contains(&host) {
65            return Err(format!("host '{}' is blocklisted", host));
66        }
67
68        // Check for private/loopback IPs
69        if self.config.block_private {
70            if let Ok(ip) = host.parse::<IpAddr>() {
71                if is_private_ip(&ip) {
72                    return Err(format!(
73                        "access to private/loopback address {} is blocked",
74                        ip
75                    ));
76                }
77            }
78
79            // Also check common private hostnames
80            let lower = host.to_lowercase();
81            if lower == "localhost"
82                || lower == "0.0.0.0"
83                || lower.ends_with(".local")
84                || lower.ends_with(".internal")
85                || lower == "metadata.google.internal"
86                || lower == "169.254.169.254"
87            // AWS metadata
88            {
89                return Err(format!("access to private host '{}' is blocked", host));
90            }
91        }
92
93        Ok(())
94    }
95
96    /// Scan tool arguments for URLs and validate them.
97    fn scan_args(&self, args: &Value) -> Result<(), String> {
98        match args {
99            Value::Object(map) => {
100                for (key, value) in map {
101                    if self.config.url_keys.iter().any(|k| k == key) {
102                        if let Some(url) = value.as_str() {
103                            self.check_url(url)?;
104                        }
105                    }
106                    // Recurse into nested objects
107                    self.scan_args(value)?;
108                }
109            }
110            Value::Array(arr) => {
111                for item in arr {
112                    self.scan_args(item)?;
113                }
114            }
115            Value::String(s) => {
116                // Check if string looks like a URL
117                if (s.starts_with("http://") || s.starts_with("https://")) && s.len() < 2048 {
118                    self.check_url(s)?;
119                }
120            }
121            _ => {}
122        }
123        Ok(())
124    }
125}
126
127#[async_trait]
128impl AgentMiddleware for SsrfGuardMiddleware {
129    async fn wrap_tool_call(
130        &self,
131        request: ToolCallRequest,
132        next: &dyn ToolCaller,
133    ) -> Result<Value, SynapticError> {
134        // Scan arguments for suspicious URLs
135        if let Err(reason) = self.scan_args(&request.call.arguments) {
136            return Err(SynapticError::Security(format!(
137                "SSRF blocked: {} (tool: {})",
138                reason, request.call.name
139            )));
140        }
141
142        next.call(request).await
143    }
144}
145
/// Extract the host an HTTP client would connect to from an `http(s)` URL.
///
/// The parse follows the WHATWG URL rules, so the returned host matches what
/// a standards-following client uses:
/// - Leading/trailing whitespace and control characters are trimmed, and
///   embedded tab/CR/LF characters are removed.
/// - The scheme is matched case-insensitively. Any run of `/` or `\` after it
///   is skipped.
/// - The authority ends at the first `/`, `\`, `?` or `#`.
/// - Credentials are discarded; the host follows the last `@`.
/// - A bracketed IPv6 literal is returned without brackets, in canonical form.
/// - The host is percent-decoded. IPv4 addresses in decimal, octal, hex or
///   shortened form (`2130706433`, `0x7f.1`, `0177.0.0.1`, `127.1`) are
///   returned as a dotted quad.
/// - Domain names keep their case but lose trailing dots.
///
/// Returns `None` when the string is not an `http(s)` URL or the host is
/// empty or malformed. Non-ASCII hosts are also rejected: clients map them
/// through IDNA (e.g. fullwidth `ｌｏｃａｌｈｏｓｔ` → `localhost`), which cannot
/// be reproduced without the `url` crate. Use punycode (`xn--…`) instead.
fn extract_host(url: &str) -> Option<String> {
    // Simple URL parsing without pulling in the `url` crate.
    let cleaned: String = url
        .trim_matches(|c: char| c <= ' ')
        .chars()
        .filter(|&c| !matches!(c, '\t' | '\n' | '\r'))
        .collect();
    let after_scheme = strip_http_scheme(&cleaned)?;
    let rest = after_scheme.trim_start_matches(|c: char| c == '/' || c == '\\');
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '\\' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // `user:pass@host` — credentials end at the *last* '@'.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);

    if let Some(bracketed) = host_port.strip_prefix('[') {
        let (literal, after) = bracketed.split_once(']')?;
        if !(after.is_empty() || after.starts_with(':')) {
            return None;
        }
        let v6: Ipv6Addr = literal.parse().ok()?;
        return Some(v6.to_string());
    }

    let raw_host = host_port.split(':').next()?;
    let decoded = percent_decode(raw_host);
    if decoded.iter().any(|&b| !b.is_ascii() || is_forbidden_host_byte(b)) {
        return None;
    }
    // All bytes are ASCII here, so this conversion cannot fail.
    let host = String::from_utf8(decoded).ok()?;

    if ends_in_number(&host) {
        return parse_whatwg_ipv4(&host).map(|ip| ip.to_string());
    }
    let domain = host.trim_end_matches('.');
    if domain.is_empty() {
        None
    } else {
        Some(domain.to_string())
    }
}

/// Remainder of `s` after a case-insensitive `http:` or `https:` scheme.
fn strip_http_scheme(s: &str) -> Option<&str> {
    let (scheme, rest) = s.split_once(':')?;
    if scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https") {
        Some(rest)
    } else {
        None
    }
}

/// Percent-decode `input`: `%XX` becomes one byte; a stray `%` is kept as-is.
fn percent_decode(input: &str) -> Vec<u8> {
    let bytes = input.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                out.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    out
}

/// Value of an ASCII hex digit, or `None` if `b` is not one.
fn hex_digit(b: u8) -> Option<u8> {
    char::from(b).to_digit(16).and_then(|d| u8::try_from(d).ok())
}

/// WHATWG forbidden domain code points in the ASCII range: controls, space,
/// DEL, and URL delimiters.
fn is_forbidden_host_byte(b: u8) -> bool {
    b <= b' ' || b == 0x7f || b"#%/:<>?@[\\]^|".contains(&b)
}

/// WHATWG "ends in a number" test: must `host` be parsed as an IPv4 address?
fn ends_in_number(host: &str) -> bool {
    let mut labels: Vec<&str> = host.split('.').collect();
    if labels.len() > 1 && labels.last() == Some(&"") {
        labels.pop();
    }
    match labels.last() {
        Some(last) if !last.is_empty() => {
            last.bytes().all(|b| b.is_ascii_digit()) || parse_ipv4_number(last).is_some()
        }
        _ => false,
    }
}

/// WHATWG IPv4 parser. Accepts 1–4 dot-separated parts; the last part fills
/// all remaining low-order bytes.
fn parse_whatwg_ipv4(host: &str) -> Option<Ipv4Addr> {
    let mut parts: Vec<&str> = host.split('.').collect();
    if parts.len() > 1 && parts.last() == Some(&"") {
        parts.pop();
    }
    if parts.len() > 4 {
        return None;
    }
    let numbers = parts
        .iter()
        .map(|part| parse_ipv4_number(part))
        .collect::<Option<Vec<u64>>>()?;
    let (&last, leading) = numbers.split_last()?;
    if leading.iter().any(|&n| n > 255) {
        return None;
    }
    // The last part may use every byte not claimed by the leading parts.
    if last >= 1u64 << (8 * (5 - numbers.len())) {
        return None;
    }
    let value = leading
        .iter()
        .enumerate()
        .fold(last, |acc, (i, &n)| acc + (n << (8 * (3 - i))));
    u32::try_from(value).ok().map(Ipv4Addr::from)
}

/// Parse one IPv4 part: `0x`/`0X` prefix means hex, a leading `0` means octal,
/// anything else is decimal.
fn parse_ipv4_number(part: &str) -> Option<u64> {
    if part.is_empty() {
        return None;
    }
    let (digits, radix) = if let Some(hex) = part
        .strip_prefix("0x")
        .or_else(|| part.strip_prefix("0X"))
    {
        (hex, 16)
    } else if part.len() > 1 && part.starts_with('0') {
        (&part[1..], 8)
    } else {
        (part, 10)
    };
    if digits.is_empty() {
        return Some(0);
    }
    // Explicit digit check: `from_str_radix` would also accept a leading '+'.
    if !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u64::from_str_radix(digits, radix).ok()
}
160
/// Check if an IP address is private, loopback, link-local, or otherwise not
/// a destination tool calls should reach.
///
/// IPv4 addresses are rejected when they are:
/// - loopback or RFC 1918 private;
/// - link-local, which includes the `169.254.169.254` metadata endpoint;
/// - carrier-grade NAT or broadcast;
/// - in `0.0.0.0/8` ("this network", which includes the unspecified address).
///
/// IPv6 addresses are rejected when they are loopback, unspecified,
/// unique-local or link-local, or when they embed an IPv4 address that would
/// itself be rejected.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || is_v6_private(v6)
                // e.g. `[::ffff:127.0.0.1]` can be routed to the embedded IPv4 host.
                || matches!(embedded_ipv4(v6), Some(v4) if is_private_v4(&v4))
        }
    }
}

/// IPv4 rules used by [`is_private_ip`].
fn is_private_v4(ip: &Ipv4Addr) -> bool {
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || is_cgnat(ip)
        || ip.is_broadcast()
        // 0.0.0.0/8 "this network"; also covers the unspecified address 0.0.0.0.
        || ip.octets()[0] == 0
}

/// IPv4 address embedded in one of these IPv6 forms:
/// - IPv4-mapped (`::ffff:a.b.c.d`);
/// - IPv4-compatible (`::a.b.c.d`);
/// - NAT64 well-known prefix (`64:ff9b::a.b.c.d`, RFC 6052).
fn embedded_ipv4(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    if let Some(v4) = ip.to_ipv4() {
        return Some(v4);
    }
    if ip.segments()[..6] == [0x64u16, 0xff9b, 0, 0, 0, 0] {
        let [.., a, b, c, d] = ip.octets();
        return Some(Ipv4Addr::new(a, b, c, d));
    }
    None
}

/// `100.64.0.0/10` — carrier-grade NAT shared address space (RFC 6598).
fn is_cgnat(ip: &Ipv4Addr) -> bool {
    let octets = ip.octets();
    octets[0] == 100 && (octets[1] & 0xC0) == 64
}

/// `fc00::/7` (unique local) or `fe80::/10` (link-local).
fn is_v6_private(ip: &Ipv6Addr) -> bool {
    let first = ip.segments()[0];
    (first & 0xFE00) == 0xFC00 || (first & 0xFFC0) == 0xFE80
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard built from the stock configuration.
    fn default_guard() -> SsrfGuardMiddleware {
        SsrfGuardMiddleware::new(SsrfGuardConfig::default())
    }

    /// Guard whose default configuration has been adjusted by `tweak`.
    fn guard_with(tweak: impl FnOnce(&mut SsrfGuardConfig)) -> SsrfGuardMiddleware {
        let mut config = SsrfGuardConfig::default();
        tweak(&mut config);
        SsrfGuardMiddleware::new(config)
    }

    /// Assert that `guard` rejects every URL in `urls`.
    fn assert_blocked(guard: &SsrfGuardMiddleware, urls: &[&str]) {
        for url in urls {
            assert!(guard.check_url(url).is_err(), "expected {} to be blocked", url);
        }
    }

    #[test]
    fn blocks_localhost() {
        assert_blocked(
            &default_guard(),
            &["http://localhost/api", "http://127.0.0.1/api"],
        );
    }

    #[test]
    fn blocks_private_ips() {
        assert_blocked(
            &default_guard(),
            &[
                "http://192.168.1.1/api",
                "http://10.0.0.1/api",
                "http://172.16.0.1/api",
            ],
        );
    }

    #[test]
    fn blocks_aws_metadata() {
        assert_blocked(
            &default_guard(),
            &["http://169.254.169.254/latest/meta-data/"],
        );
    }

    #[test]
    fn allows_public_urls() {
        let guard = default_guard();
        for url in &["https://api.openai.com/v1/chat", "https://example.com"] {
            assert!(guard.check_url(url).is_ok(), "expected {} to be allowed", url);
        }
    }

    #[test]
    fn allowlist_overrides_private() {
        let guard = guard_with(|config| {
            config.allowlist.insert("localhost".to_string());
        });
        assert!(guard.check_url("http://localhost/api").is_ok());
    }

    #[test]
    fn blocklist_blocks_public() {
        let guard = guard_with(|config| {
            config.blocklist.insert("evil.com".to_string());
        });
        assert_blocked(&guard, &["https://evil.com/api"]);
    }

    #[test]
    fn scans_nested_args() {
        let args = serde_json::json!({ "config": { "url": "http://127.0.0.1/steal" } });
        assert!(default_guard().scan_args(&args).is_err());
    }

    #[test]
    fn extract_host_works() {
        let cases = [
            ("https://example.com/path", Some("example.com")),
            ("http://localhost:8080/api", Some("localhost")),
            ("not-a-url", None),
        ];
        for (url, expected) in cases.iter() {
            assert_eq!(extract_host(url), expected.map(String::from), "url: {}", url);
        }
    }
}