Skip to main content

agent_sdk/web/
security.rs

//! URL validation and SSRF protection.
2
3use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
5use url::Url;
6
/// Default blocked hostnames for SSRF protection.
///
/// Entries are compared against the URL's host string, so IPv6 literals are
/// listed in the bracketed form in which they appear inside a URL. The
/// unbracketed spellings are kept as well so that the list also matches hosts
/// handed over without brackets.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::",   // IPv6 unspecified address (counterpart of 0.0.0.0)
    "[::]", // IPv6 unspecified address as written in a URL
    "::1",
    "[::1]",
    "169.254.169.254",          // AWS metadata
    "metadata.google.internal", // GCP metadata
    "metadata.goog",            // GCP metadata alternate
];
18
/// SSRF-aware gatekeeper for outbound URLs.
///
/// Run every URL through [`UrlValidator::validate`] before fetching it to
/// guard against Server-Side Request Forgery. With the default configuration
/// the validator rejects:
/// - Localhost and loopback addresses
/// - Private IP ranges (10.x, 172.16-31.x, 192.168.x)
/// - Cloud metadata endpoints (AWS, GCP)
///
/// Loopback and link-local destinations stay blocked even when private IPs
/// are explicitly allowed.
///
/// # Example
///
/// ```ignore
/// use agent_sdk::web::UrlValidator;
///
/// let validator = UrlValidator::new();
/// assert!(validator.validate("https://example.com").is_ok());
/// assert!(validator.validate("http://localhost").is_err());
/// ```
#[derive(Debug, Clone)]
pub struct UrlValidator {
    /// Optional allow-list: when `Some`, only these domains (and their
    /// subdomains) are accepted.
    allowed_domains: Option<Vec<String>>,
    /// Hostnames/IP literals that are always refused.
    blocked_hosts: Vec<String>,
    /// Whether hosts resolving to private IP ranges are accepted
    /// (default: `false`).
    allow_private_ips: bool,
    /// Redirect budget reported by `max_redirects()` (default: 3).
    max_redirects: usize,
    /// Whether plain `http` URLs are refused (default: `true`).
    require_https: bool,
}
49
50impl Default for UrlValidator {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl UrlValidator {
57    /// Create a new URL validator with default security settings.
58    #[must_use]
59    pub fn new() -> Self {
60        Self {
61            allowed_domains: None,
62            blocked_hosts: DEFAULT_BLOCKED_HOSTS
63                .iter()
64                .map(|&s| s.to_string())
65                .collect(),
66            allow_private_ips: false,
67            max_redirects: 3,
68            require_https: true,
69        }
70    }
71
72    /// Only allow URLs from specific domains.
73    #[must_use]
74    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
75        self.allowed_domains = Some(domains);
76        self
77    }
78
79    /// Add additional blocked hosts.
80    #[must_use]
81    pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
82        self.blocked_hosts.extend(hosts);
83        self
84    }
85
86    /// Allow private IP ranges (dangerous - use with caution).
87    #[must_use]
88    pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
89        self.allow_private_ips = allow;
90        self
91    }
92
93    /// Set maximum redirects.
94    #[must_use]
95    pub const fn with_max_redirects(mut self, max: usize) -> Self {
96        self.max_redirects = max;
97        self
98    }
99
100    /// Allow HTTP URLs (default requires HTTPS).
101    #[must_use]
102    pub const fn with_allow_http(mut self) -> Self {
103        self.require_https = false;
104        self
105    }
106
107    /// Get the maximum number of redirects allowed.
108    #[must_use]
109    pub const fn max_redirects(&self) -> usize {
110        self.max_redirects
111    }
112
113    /// Validate a URL string.
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if:
118    /// - The URL is malformed
119    /// - The scheme is not HTTP or HTTPS
120    /// - HTTPS is required but HTTP is used
121    /// - The host is blocked
122    /// - The host resolves to a private/blocked IP
123    /// - The domain is not in the allowed list
124    pub fn validate(&self, url_str: &str) -> Result<Url> {
125        let url = Url::parse(url_str).context("Invalid URL format")?;
126
127        // Check scheme
128        match url.scheme() {
129            "https" => {}
130            "http" => {
131                if self.require_https {
132                    bail!("HTTPS required, but HTTP URL provided");
133                }
134            }
135            scheme => bail!("Unsupported URL scheme: {scheme}"),
136        }
137
138        // Check host
139        let host = url.host_str().context("URL must have a host")?;
140
141        // Check blocked hosts
142        if self.blocked_hosts.iter().any(|blocked| {
143            host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
144        }) {
145            bail!("Access to host '{host}' is blocked");
146        }
147
148        // Check allowed domains
149        if let Some(ref allowed) = self.allowed_domains {
150            let is_allowed = allowed.iter().any(|domain| {
151                host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
152            });
153            if !is_allowed {
154                bail!("Host '{host}' is not in the allowed domains list");
155            }
156        }
157
158        // Resolve and check IP
159        self.validate_resolved_ip(host)?;
160
161        Ok(url)
162    }
163
164    /// Validate that the resolved IP addresses are safe.
165    ///
166    /// Fails closed: if DNS resolution returns no results (or fails), the host
167    /// is blocked to prevent DNS-rebinding attacks that rely on transient lookup
168    /// failures.
169    fn validate_resolved_ip(&self, host: &str) -> Result<()> {
170        // Try to resolve the hostname — fail closed on empty/error
171        let addrs: Vec<_> = format!("{host}:80")
172            .to_socket_addrs()
173            .map(Iterator::collect)
174            .unwrap_or_default();
175
176        if addrs.is_empty() {
177            bail!("Could not resolve host '{host}' — blocking unresolvable URLs for safety");
178        }
179
180        for addr in addrs {
181            let ip = addr.ip();
182            if !self.allow_private_ips && is_private_ip(&ip) {
183                bail!("Access to private IP address {ip} is blocked");
184            }
185            if is_loopback(&ip) {
186                bail!("Access to loopback address {ip} is blocked");
187            }
188            if is_link_local(&ip) {
189                bail!("Access to link-local address {ip} is blocked");
190            }
191        }
192
193        Ok(())
194    }
195}
196
197/// Check if an IP address is private.
198///
199/// Also handles IPv4-mapped IPv6 addresses (`::ffff:x.x.x.x`) by extracting
200/// the embedded IPv4 address and applying IPv4 checks.
201fn is_private_ip(ip: &IpAddr) -> bool {
202    match ip {
203        IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
204        IpAddr::V6(ipv6) => {
205            // Check for IPv4-mapped IPv6 addresses (::ffff:x.x.x.x)
206            if let Some(mapped_v4) = ipv6.to_ipv4_mapped() {
207                return is_private_ipv4(mapped_v4);
208            }
209            is_private_ipv6(ipv6)
210        }
211    }
212}
213
/// Report whether an IPv4 address lies in a private (non-public) range.
///
/// Recognised ranges:
/// - `10.0.0.0/8`
/// - `172.16.0.0/12`
/// - `192.168.0.0/16`
/// - `100.64.0.0/10` (Carrier-grade NAT)
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..] | [100, 64..=127, ..]
    )
}
240
/// Report whether an IPv6 address is a unique local address (`fc00::/7`).
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // A /7 prefix fixes the seven leading bits, so the first byte of a
    // unique local address is either 0xfc or 0xfd.
    matches!(ip.octets()[0], 0xfc | 0xfd)
}
247
/// Report whether `ip` is a loopback address.
///
/// An IPv4-mapped IPv6 address (e.g. `::ffff:127.0.0.1`) is judged by the
/// IPv4 address it embeds; other IPv6 addresses use the native IPv6 check.
const fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_loopback(),
            None => v6.is_loopback(),
        },
    }
}
262
/// Report whether `ip` is a link-local address.
///
/// An IPv4-mapped IPv6 address (e.g. `::ffff:169.254.x.x`) is judged by the
/// IPv4 address it embeds; other IPv6 addresses are tested against
/// `fe80::/10`.
const fn is_link_local(ip: &IpAddr) -> bool {
    // fe80::/10 — only the ten leading bits of the first segment matter.
    const LINK_LOCAL_PREFIX: u16 = 0xfe80;
    const LINK_LOCAL_MASK: u16 = 0xffc0;

    match ip {
        IpAddr::V4(v4) => v4.is_link_local(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_link_local(),
            None => v6.segments()[0] & LINK_LOCAL_MASK == LINK_LOCAL_PREFIX,
        },
    }
}
279
/// Unit tests for the validator and the IP classification helpers.
///
/// Tests marked "needs DNS" call `validate` on a host that must resolve, so
/// they depend on a working resolver and network access. All other tests are
/// rejected before resolution or exercise the pure IP helpers and run offline.
#[cfg(test)]
mod tests {
    use super::*;

    // Needs DNS: `example.com` must resolve to a public address.
    #[test]
    fn test_valid_https_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://example.com/path").is_ok());
    }

    #[test]
    fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    // Needs DNS: `example.com` must resolve to a public address.
    #[test]
    fn test_http_allowed_with_flag() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").is_ok());
    }

    #[test]
    fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").is_err());
        assert!(validator.validate("http://127.0.0.1").is_err());
        assert!(validator.validate("http://[::1]").is_err());
    }

    #[test]
    fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .is_err()
        );
    }

    #[test]
    fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").is_err());
        assert!(validator.validate("").is_err());
        assert!(validator.validate("ftp://example.com").is_err());
    }

    #[test]
    fn test_unsupported_scheme_message() {
        let validator = UrlValidator::new();
        let err = validator.validate("ftp://example.com").unwrap_err();
        assert!(err.to_string().contains("Unsupported URL scheme"));
    }

    // Needs DNS for the positive case; the rejection is decided offline.
    #[test]
    fn test_allowed_domains() {
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").is_ok());

        let result = validator.validate("https://other.com");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[test]
    fn test_allowed_domains_rejects_suffix_lookalike() {
        // "notexample.com" merely ends with the allowed name; it is not a
        // subdomain of it and must be refused.
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        let result = validator.validate("https://notexample.com");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[test]
    fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[test]
    fn test_blocked_hosts_cover_subdomains() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://sub.blocked.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("is blocked"));
    }

    #[test]
    fn test_is_private_ipv4() {
        // Private ranges
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));

        // Not private
        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    #[test]
    fn test_cgnat_range_is_private() {
        // 100.64.0.0/10 spans 100.64.0.0 through 100.127.255.255.
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 0)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }

    #[test]
    fn test_builder_flags() {
        let validator = UrlValidator::new()
            .with_allow_private_ips(true)
            .with_allow_http();
        assert!(validator.allow_private_ips);
        assert!(!validator.require_https);
    }

    #[test]
    fn test_unresolvable_host_blocked() {
        let validator = UrlValidator::new();
        let result = validator.validate("https://this-domain-does-not-exist-xyz123.example");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Could not resolve host"),
            "Expected DNS resolution failure, got: {err_msg}"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_private_detected() {
        // ::ffff:10.0.0.1 should be detected as private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_loopback_detected() {
        // ::ffff:127.0.0.1 should be detected as loopback
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001));
        assert!(is_loopback(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_link_local_detected() {
        // ::ffff:169.254.169.254 should be detected as link-local
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0xa9fe));
        assert!(is_link_local(&ip));
    }

    #[test]
    fn test_regular_ipv6_private_still_detected() {
        // fc00::1 should still be detected as private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_unique_local_ipv6_boundaries() {
        // fc00::/7 covers first segments 0xfc00 through 0xfdff.
        assert!(is_private_ipv6(&Ipv6Addr::new(0xfdff, 0xffff, 0, 0, 0, 0, 0, 1)));
        assert!(!is_private_ipv6(&Ipv6Addr::new(0xfbff, 0, 0, 0, 0, 0, 0, 1)));
        assert!(!is_private_ipv6(&Ipv6Addr::new(0xfe00, 0, 0, 0, 0, 0, 0, 1)));
    }

    #[test]
    fn test_native_ipv6_loopback_detected() {
        assert!(is_loopback(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(!is_loopback(&IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_native_ipv6_link_local_detected() {
        // fe80::/10 covers first segments 0xfe80 through 0xfebf.
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfebf, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfe7f, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_public_not_flagged() {
        // ::ffff:8.8.8.8 should NOT be private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0808, 0x0808));
        assert!(!is_private_ip(&ip));
        assert!(!is_loopback(&ip));
        assert!(!is_link_local(&ip));
    }
}