agent_sdk/web/
security.rs

1//! URL validation and SSRF protection.
2
3use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
5use url::Url;
6
/// Default blocked hostnames and IP literals for SSRF protection.
///
/// Entries are compared against `Url::host_str()`, which renders IPv6 literals
/// inside brackets, so the bracketed forms (`[::1]`, `[::]`) are the ones that
/// match URL hosts. A host is also blocked if it is a subdomain of an entry.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::1",
    "[::1]",
    "[::]",                     // IPv6 unspecified; like 0.0.0.0 it typically reaches the local machine
    "169.254.169.254",          // AWS metadata
    "metadata.google.internal", // GCP metadata
    "metadata.goog",            // GCP metadata alternate
];
18
/// URL validator with SSRF protection.
///
/// Validates URLs before fetching to prevent Server-Side Request Forgery attacks.
/// With the default configuration it rejects:
/// - Non-HTTPS URLs (plain HTTP can be enabled with [`UrlValidator::with_allow_http`])
/// - Hosts on the blocked list and their subdomains: localhost, loopback and
///   unspecified literals, and cloud metadata endpoints (AWS `169.254.169.254`,
///   GCP `metadata.google.internal` / `metadata.goog`)
/// - Hosts resolving to private ranges: `10.0.0.0/8`, `172.16.0.0/12`,
///   `192.168.0.0/16`, carrier-grade NAT `100.64.0.0/10`, and IPv6 unique local
///   `fc00::/7`
/// - Hosts resolving to loopback or link-local addresses (these stay blocked
///   even when private IPs are allowed)
///
/// Validation returns the parsed [`Url`], not the addresses it checked, so the
/// caller's HTTP client resolves the host again when connecting. This type alone
/// therefore does not defend against DNS rebinding between the two lookups.
///
/// # Example
///
/// ```ignore
/// use agent_sdk::web::UrlValidator;
///
/// let validator = UrlValidator::new();
/// assert!(validator.validate("https://example.com").is_ok());
/// assert!(validator.validate("http://localhost").is_err());
/// ```
#[derive(Clone, Debug)]
pub struct UrlValidator {
    /// If `Some`, only these domains (and their subdomains) are accepted.
    allowed_domains: Option<Vec<String>>,
    /// Hostnames / IP literals to reject; subdomains of an entry are rejected too.
    blocked_hosts: Vec<String>,
    /// Allow private IP ranges (default: false). Loopback and link-local
    /// addresses are rejected regardless of this flag.
    allow_private_ips: bool,
    /// Maximum number of redirects to follow (default: 3). Only stored and
    /// exposed here; enforcement presumably lives in the caller's HTTP client.
    max_redirects: usize,
    /// Require HTTPS (default: true).
    require_https: bool,
}
49
50impl Default for UrlValidator {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl UrlValidator {
57    /// Create a new URL validator with default security settings.
58    #[must_use]
59    pub fn new() -> Self {
60        Self {
61            allowed_domains: None,
62            blocked_hosts: DEFAULT_BLOCKED_HOSTS
63                .iter()
64                .map(|&s| s.to_string())
65                .collect(),
66            allow_private_ips: false,
67            max_redirects: 3,
68            require_https: true,
69        }
70    }
71
72    /// Only allow URLs from specific domains.
73    #[must_use]
74    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
75        self.allowed_domains = Some(domains);
76        self
77    }
78
79    /// Add additional blocked hosts.
80    #[must_use]
81    pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
82        self.blocked_hosts.extend(hosts);
83        self
84    }
85
86    /// Allow private IP ranges (dangerous - use with caution).
87    #[must_use]
88    pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
89        self.allow_private_ips = allow;
90        self
91    }
92
93    /// Set maximum redirects.
94    #[must_use]
95    pub const fn with_max_redirects(mut self, max: usize) -> Self {
96        self.max_redirects = max;
97        self
98    }
99
100    /// Allow HTTP URLs (default requires HTTPS).
101    #[must_use]
102    pub const fn with_allow_http(mut self) -> Self {
103        self.require_https = false;
104        self
105    }
106
107    /// Get the maximum number of redirects allowed.
108    #[must_use]
109    pub const fn max_redirects(&self) -> usize {
110        self.max_redirects
111    }
112
113    /// Validate a URL string.
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if:
118    /// - The URL is malformed
119    /// - The scheme is not HTTP or HTTPS
120    /// - HTTPS is required but HTTP is used
121    /// - The host is blocked
122    /// - The host resolves to a private/blocked IP
123    /// - The domain is not in the allowed list
124    pub fn validate(&self, url_str: &str) -> Result<Url> {
125        let url = Url::parse(url_str).context("Invalid URL format")?;
126
127        // Check scheme
128        match url.scheme() {
129            "https" => {}
130            "http" => {
131                if self.require_https {
132                    bail!("HTTPS required, but HTTP URL provided");
133                }
134            }
135            scheme => bail!("Unsupported URL scheme: {scheme}"),
136        }
137
138        // Check host
139        let host = url.host_str().context("URL must have a host")?;
140
141        // Check blocked hosts
142        if self.blocked_hosts.iter().any(|blocked| {
143            host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
144        }) {
145            bail!("Access to host '{host}' is blocked");
146        }
147
148        // Check allowed domains
149        if let Some(ref allowed) = self.allowed_domains {
150            let is_allowed = allowed.iter().any(|domain| {
151                host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
152            });
153            if !is_allowed {
154                bail!("Host '{host}' is not in the allowed domains list");
155            }
156        }
157
158        // Resolve and check IP
159        self.validate_resolved_ip(host)?;
160
161        Ok(url)
162    }
163
164    /// Validate that the resolved IP addresses are safe.
165    fn validate_resolved_ip(&self, host: &str) -> Result<()> {
166        // Try to resolve the hostname
167        let addrs: Vec<_> = format!("{host}:80")
168            .to_socket_addrs()
169            .map(Iterator::collect)
170            .unwrap_or_default();
171
172        for addr in addrs {
173            let ip = addr.ip();
174            if !self.allow_private_ips && is_private_ip(&ip) {
175                bail!("Access to private IP address {ip} is blocked");
176            }
177            if is_loopback(&ip) {
178                bail!("Access to loopback address {ip} is blocked");
179            }
180            if is_link_local(&ip) {
181                bail!("Access to link-local address {ip} is blocked");
182            }
183        }
184
185        Ok(())
186    }
187}
188
189/// Check if an IP address is private.
190fn is_private_ip(ip: &IpAddr) -> bool {
191    match ip {
192        IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
193        IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
194    }
195}
196
/// Check if an IPv4 address is in a private-use or shared address range.
///
/// Matches the RFC 1918 blocks (`10.0.0.0/8`, `172.16.0.0/12`,
/// `192.168.0.0/16`) plus the carrier-grade NAT shared space `100.64.0.0/10`
/// (RFC 6598).
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    // `Ipv4Addr::is_private` covers exactly the three RFC 1918 ranges;
    // the /10 CGNAT block is first octet 100, second octet 64..=127.
    let carrier_grade_nat = matches!(ip.octets(), [100, 64..=127, _, _]);
    ip.is_private() || carrier_grade_nat
}
223
/// Check if an IPv6 address is a unique local address (`fc00::/7`, RFC 4193).
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // A /7 prefix is the top seven bits of the first byte: 1111_110x.
    let first_byte = ip.octets()[0];
    (first_byte & 0b1111_1110) == 0b1111_1100
}
230
/// Check if an IP is a loopback address (`127.0.0.0/8` for IPv4, `::1` for IPv6).
const fn is_loopback(ip: &IpAddr) -> bool {
    // `IpAddr::is_loopback` dispatches to the V4/V6 `is_loopback` per variant.
    ip.is_loopback()
}
238
/// Check if an IP is a link-local address (`169.254.0.0/16` or `fe80::/10`).
const fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V6(v6) => {
            // fe80::/10: first byte 0xfe, top two bits of the second byte `10`.
            let bytes = v6.octets();
            bytes[0] == 0xfe && (bytes[1] & 0xc0) == 0x80
        }
        IpAddr::V4(v4) => v4.is_link_local(),
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_https_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://example.com/path").is_ok());
        assert!(validator.validate("https://sub.example.com").is_ok());
    }

    #[test]
    fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    #[test]
    fn test_http_allowed_with_flag() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").is_ok());
    }

    #[test]
    fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").is_err());
        assert!(validator.validate("http://127.0.0.1").is_err());
        assert!(validator.validate("http://[::1]").is_err());
    }

    #[test]
    fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .is_err()
        );
    }

    #[test]
    fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").is_err());
        assert!(validator.validate("").is_err());
        assert!(validator.validate("ftp://example.com").is_err());
    }

    #[test]
    fn test_allowed_domains() {
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://sub.example.com").is_ok());

        let result = validator.validate("https://other.com");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[test]
    fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com");
        assert!(result.is_err());
        // "is blocked" comes from the error text, not from the host name itself.
        assert!(result.unwrap_err().to_string().contains("is blocked"));
    }

    #[test]
    fn test_blocked_host_subdomains() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);
        let err = validator
            .validate("https://sub.blocked.com")
            .unwrap_err()
            .to_string();
        assert!(err.contains("is blocked"));
    }

    #[test]
    fn test_userinfo_does_not_hide_host() {
        // The real host here is 127.0.0.1; `example.com` is only the username.
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com@127.0.0.1").is_err());
    }

    #[test]
    fn test_alternate_ipv4_notations_blocked() {
        // Decimal, hex and octal spellings of 127.0.0.1.
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://2130706433").is_err());
        assert!(validator.validate("http://0x7f.0.0.1").is_err());
        assert!(validator.validate("http://017700000001").is_err());
    }

    #[test]
    fn test_private_ip_literals_blocked_by_default() {
        let validator = UrlValidator::new().with_allow_http();
        for url in [
            "http://10.0.0.1",
            "http://172.16.5.4",
            "http://192.168.1.1",
            "http://100.64.0.1",
            "http://[fd00::1]",
        ] {
            let err = validator.validate(url).unwrap_err().to_string();
            assert!(err.contains("private IP"), "{url}: {err}");
        }
    }

    #[test]
    fn test_allow_private_ips() {
        let validator = UrlValidator::new()
            .with_allow_http()
            .with_allow_private_ips(true);
        assert!(validator.validate("http://10.0.0.1").is_ok());
        assert!(validator.validate("http://192.168.1.1/path").is_ok());
        // Loopback and link-local stay blocked even when private IPs are allowed.
        assert!(validator.validate("http://127.0.0.2").is_err());
        assert!(validator.validate("http://169.254.1.1").is_err());
    }

    #[test]
    fn test_link_local_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        let err = validator
            .validate("http://169.254.1.1")
            .unwrap_err()
            .to_string();
        assert!(err.contains("link-local"));
        assert!(validator.validate("http://[fe80::1]").is_err());
    }

    #[test]
    fn test_is_private_ipv4() {
        // Private ranges
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));

        // Carrier-grade NAT boundaries
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 0)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));

        // Not private
        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    #[test]
    fn test_is_private_ipv6() {
        assert!(is_private_ipv6(&"fc00::1".parse().unwrap()));
        assert!(is_private_ipv6(&"fdff:ffff::1".parse().unwrap()));
        assert!(!is_private_ipv6(&"fe80::1".parse().unwrap()));
        assert!(!is_private_ipv6(&"2001:db8::1".parse().unwrap()));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }
}