Skip to main content

agent_sdk/web/
security.rs

1//! URL validation and SSRF protection.
2
3use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
5use url::Url;
6
/// Host strings rejected by default, before any DNS lookup is attempted.
///
/// [`UrlValidator::new`] copies these into its blocked-host list. The IPv6
/// loopback is listed both bare and in bracketed URL form.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    // Local machine: names and address literals.
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::1",
    "[::1]",
    // Cloud instance-metadata services.
    // AWS metadata
    "169.254.169.254",
    // GCP metadata
    "metadata.google.internal",
    // GCP metadata alternate
    "metadata.goog",
];
18
/// URL validator with SSRF protection.
///
/// Validates URLs before fetching to prevent Server-Side Request Forgery attacks.
/// By default, blocks access to:
/// - Localhost and loopback addresses
/// - Private IP ranges (10.x, 172.16-31.x, 192.168.x, carrier-grade NAT
///   100.64.0.0/10, and IPv6 unique-local fc00::/7)
/// - Link-local addresses (169.254.x.x, fe80::/10)
/// - Cloud metadata endpoints (AWS, GCP)
///
/// Plain `http://` URLs are also rejected unless
/// [`UrlValidator::with_allow_http`] is used.
///
/// # Example
///
/// ```ignore
/// use agent_sdk::web::UrlValidator;
///
/// // `validate` is async, so it must be awaited from an async context.
/// let validator = UrlValidator::new();
/// assert!(validator.validate("https://example.com").await.is_ok());
/// assert!(validator.validate("http://localhost").await.is_err());
/// ```
#[derive(Clone, Debug)]
pub struct UrlValidator {
    /// Only allow these domains (if Some). A host is accepted when it equals
    /// an entry or is a subdomain of one; `None` disables the allow-list.
    allowed_domains: Option<Vec<String>>,
    /// Block these hostnames/IPs. Checked against the URL's host string before
    /// the allow-list and before DNS resolution.
    blocked_hosts: Vec<String>,
    /// Allow private IP ranges (default: false). Loopback and link-local
    /// addresses are rejected regardless of this flag.
    allow_private_ips: bool,
    /// Maximum number of redirects to follow (default: 3). This type only
    /// stores and exposes the value; enforcing it is left to the caller.
    max_redirects: usize,
    /// Require HTTPS (default: true).
    require_https: bool,
}
49
/// A URL that has passed SSRF validation, together with the exact IP addresses
/// it resolved to.
///
/// Produced by [`UrlValidator::validate`]; on success `addresses` is never
/// empty, because an unresolvable host is rejected.
///
/// The caller must connect to one of [`ValidatedUrl::addresses`] (e.g. by
/// pinning them via [`reqwest::ClientBuilder::resolve_to_addrs`]) rather than
/// re-resolving the host. Re-resolving opens a DNS-rebinding TOCTOU hole: the
/// attacker-controlled record can pass validation here and then rebind to a
/// blocked address (`169.254.169.254`, `127.0.0.1`, …) before the connection is
/// made.
#[derive(Clone, Debug)]
pub struct ValidatedUrl {
    /// The validated URL.
    pub url: Url,
    /// The vetted socket addresses the host resolved to. Pin these for the
    /// actual request so the connection targets exactly what was validated.
    /// Each address carries the URL's effective port (explicit, or the
    /// scheme's default).
    pub addresses: Vec<SocketAddr>,
}
67
68impl Default for UrlValidator {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl UrlValidator {
75    /// Create a new URL validator with default security settings.
76    #[must_use]
77    pub fn new() -> Self {
78        Self {
79            allowed_domains: None,
80            blocked_hosts: DEFAULT_BLOCKED_HOSTS
81                .iter()
82                .map(|&s| s.to_string())
83                .collect(),
84            allow_private_ips: false,
85            max_redirects: 3,
86            require_https: true,
87        }
88    }
89
90    /// Only allow URLs from specific domains.
91    #[must_use]
92    pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
93        self.allowed_domains = Some(domains);
94        self
95    }
96
97    /// Add additional blocked hosts.
98    #[must_use]
99    pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
100        self.blocked_hosts.extend(hosts);
101        self
102    }
103
104    /// Allow private IP ranges (dangerous - use with caution).
105    #[must_use]
106    pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
107        self.allow_private_ips = allow;
108        self
109    }
110
111    /// Set maximum redirects.
112    #[must_use]
113    pub const fn with_max_redirects(mut self, max: usize) -> Self {
114        self.max_redirects = max;
115        self
116    }
117
118    /// Allow HTTP URLs (default requires HTTPS).
119    #[must_use]
120    pub const fn with_allow_http(mut self) -> Self {
121        self.require_https = false;
122        self
123    }
124
125    /// Get the maximum number of redirects allowed.
126    #[must_use]
127    pub const fn max_redirects(&self) -> usize {
128        self.max_redirects
129    }
130
131    /// Validate a URL string and return it alongside its vetted IP addresses.
132    ///
133    /// DNS resolution runs on the tokio runtime via [`tokio::net::lookup_host`]
134    /// (not the blocking `getaddrinfo` on a worker thread), and the resolved
135    /// addresses are returned so the caller can pin them for the actual request
136    /// — closing the DNS-rebinding TOCTOU window. See [`ValidatedUrl`].
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if:
141    /// - The URL is malformed
142    /// - The scheme is not HTTP or HTTPS
143    /// - HTTPS is required but HTTP is used
144    /// - The host is blocked
145    /// - The host resolves to a private/blocked IP
146    /// - The domain is not in the allowed list
147    pub async fn validate(&self, url_str: &str) -> Result<ValidatedUrl> {
148        let url = Url::parse(url_str).context("Invalid URL format")?;
149
150        // Check scheme
151        match url.scheme() {
152            "https" => {}
153            "http" => {
154                if self.require_https {
155                    bail!("HTTPS required, but HTTP URL provided");
156                }
157            }
158            scheme => bail!("Unsupported URL scheme: {scheme}"),
159        }
160
161        // Check host
162        let host = url.host_str().context("URL must have a host")?;
163
164        // Check blocked hosts
165        if self.blocked_hosts.iter().any(|blocked| {
166            host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
167        }) {
168            bail!("Access to host '{host}' is blocked");
169        }
170
171        // Check allowed domains
172        if let Some(ref allowed) = self.allowed_domains {
173            let is_allowed = allowed.iter().any(|domain| {
174                host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
175            });
176            if !is_allowed {
177                bail!("Host '{host}' is not in the allowed domains list");
178            }
179        }
180
181        // Resolve and check IP — using the URL's real port so the pinned
182        // addresses connect to the right place.
183        let port = url.port_or_known_default().unwrap_or(443);
184        let addresses = self.validate_resolved_ip(host, port).await?;
185
186        Ok(ValidatedUrl { url, addresses })
187    }
188
189    /// Resolve `host` (asynchronously) and verify every resolved IP is safe,
190    /// returning the vetted socket addresses.
191    ///
192    /// Fails closed: if DNS resolution returns no results (or fails), the host
193    /// is blocked to prevent DNS-rebinding attacks that rely on transient lookup
194    /// failures.
195    async fn validate_resolved_ip(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
196        // Resolve on the runtime (no blocking getaddrinfo on a worker thread).
197        // `"{host}:{port}"` parses bracketed IPv6 literals correctly too.
198        let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
199            .await
200            .map(Iterator::collect)
201            .unwrap_or_default();
202
203        if addrs.is_empty() {
204            bail!("Could not resolve host '{host}' — blocking unresolvable URLs for safety");
205        }
206
207        for addr in &addrs {
208            let ip = addr.ip();
209            if !self.allow_private_ips && is_private_ip(&ip) {
210                bail!("Access to private IP address {ip} is blocked");
211            }
212            if is_loopback(&ip) {
213                bail!("Access to loopback address {ip} is blocked");
214            }
215            if is_link_local(&ip) {
216                bail!("Access to link-local address {ip} is blocked");
217            }
218        }
219
220        Ok(addrs)
221    }
222}
223
224/// Check if an IP address is private.
225///
226/// Also handles IPv4-mapped IPv6 addresses (`::ffff:x.x.x.x`) by extracting
227/// the embedded IPv4 address and applying IPv4 checks.
228fn is_private_ip(ip: &IpAddr) -> bool {
229    match ip {
230        IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
231        IpAddr::V6(ipv6) => {
232            // Check for IPv4-mapped IPv6 addresses (::ffff:x.x.x.x)
233            if let Some(mapped_v4) = ipv6.to_ipv4_mapped() {
234                return is_private_ipv4(mapped_v4);
235            }
236            is_private_ipv6(ipv6)
237        }
238    }
239}
240
/// Report whether an IPv4 address is in one of the private ranges:
/// `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, or the carrier-grade NAT
/// block `100.64.0.0/10`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    // Only the first two octets are needed to decide membership.
    matches!(
        ip.octets(),
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..] | [100, 64..=127, ..]
    )
}
267
/// Report whether an IPv6 address is a unique local address (`fc00::/7`).
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // fc00::/7 fixes the top seven bits, so the first octet is 0xfc or 0xfd.
    matches!(ip.octets()[0], 0xfc | 0xfd)
}
274
/// Report whether `ip` is a loopback address.
///
/// An IPv4-mapped IPv6 address (e.g. `::ffff:127.0.0.1`) is classified by its
/// embedded IPv4 address.
const fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_loopback(),
            None => v6.is_loopback(),
        },
    }
}
289
/// Report whether `ip` is a link-local address.
///
/// Covers IPv4 link-local, native IPv6 `fe80::/10`, and IPv4-mapped IPv6
/// addresses (e.g. `::ffff:169.254.x.x`), which are classified by their
/// embedded IPv4 address.
const fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_link_local(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => embedded.is_link_local(),
            None => {
                // fe80::/10 — first octet 0xfe, top two bits of the second are `10`.
                let bytes = v6.octets();
                bytes[0] == 0xfe && (bytes[1] & 0xc0) == 0x80
            }
        },
    }
}
306
/// Unit tests for the URL validator and the IP classification helpers.
///
/// Tests that validate `example.com` or an unregistered name perform real DNS
/// lookups and therefore need a working resolver. Tests that use IP literals
/// or call the helper functions directly are network-free.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_valid_https_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").await.is_ok());
        assert!(validator.validate("https://example.com/path").await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_returns_vetted_addresses() -> Result<()> {
        // The validated result must carry the resolved addresses so the caller
        // can pin them and avoid a re-resolution (DNS-rebinding) window.
        let validator = UrlValidator::new();
        let validated = validator
            .validate("https://example.com")
            .await
            .context("example.com should validate")?;
        assert!(
            !validated.addresses.is_empty(),
            "validation must return the vetted IP addresses for pinning"
        );
        // The port carried in the pinned addresses must be the URL's port.
        assert!(validated.addresses.iter().all(|a| a.port() == 443));
        Ok(())
    }

    #[tokio::test]
    async fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    #[tokio::test]
    async fn test_http_allowed_with_flag() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").await.is_ok());
    }

    #[tokio::test]
    async fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").await.is_err());
        assert!(validator.validate("http://127.0.0.1").await.is_err());
        assert!(validator.validate("http://[::1]").await.is_err());
    }

    #[tokio::test]
    async fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").await.is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").await.is_err());
        assert!(validator.validate("").await.is_err());
        assert!(validator.validate("ftp://example.com").await.is_err());
    }

    #[tokio::test]
    async fn test_allowed_domains() {
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").await.is_ok());

        let result = validator.validate("https://other.com").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[tokio::test]
    async fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_private_ip_literal_blocked_by_default() {
        // An IP literal needs no DNS lookup, so this runs offline.
        let validator = UrlValidator::new();
        let result = validator.validate("https://10.0.0.1").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("private IP"));
    }

    #[tokio::test]
    async fn test_private_ip_literal_allowed_with_flag() -> Result<()> {
        // With private IPs allowed, the literal validates and the pinned
        // address carries the URL's explicit port.
        let validator = UrlValidator::new().with_allow_private_ips(true);
        let validated = validator
            .validate("https://10.0.0.1:8443")
            .await
            .context("private literal should validate when allowed")?;
        assert_eq!(
            validated.addresses,
            vec![SocketAddr::from(([10, 0, 0, 1], 8443))]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_loopback_blocked_even_when_private_allowed() {
        // `allow_private_ips` must not open up loopback or link-local targets.
        let validator = UrlValidator::new()
            .with_allow_http()
            .with_allow_private_ips(true);
        assert!(validator.validate("http://127.0.0.2").await.is_err());
        assert!(validator.validate("http://169.254.1.1").await.is_err());
    }

    #[test]
    fn test_is_private_ipv4() {
        // Private ranges
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));

        // Not private
        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
    }

    #[test]
    fn test_is_private_ipv4_carrier_grade_nat() {
        // 100.64.0.0/10 is treated as private, including both boundaries.
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 0)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));

        // Just outside the block on either side.
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }

    #[tokio::test]
    async fn test_unresolvable_host_blocked() {
        let validator = UrlValidator::new();
        let result = validator
            .validate("https://this-domain-does-not-exist-xyz123.example")
            .await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Could not resolve host"),
            "Expected DNS resolution failure, got: {err_msg}"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_private_detected() {
        // ::ffff:10.0.0.1 should be detected as private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_loopback_detected() {
        // ::ffff:127.0.0.1 should be detected as loopback
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001));
        assert!(is_loopback(&ip));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_link_local_detected() {
        // ::ffff:169.254.169.254 should be detected as link-local
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0xa9fe));
        assert!(is_link_local(&ip));
    }

    #[test]
    fn test_regular_ipv6_private_still_detected() {
        // fc00::1 should still be detected as private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1));
        assert!(is_private_ip(&ip));
    }

    #[test]
    fn test_native_ipv6_loopback_and_link_local_detected() {
        // ::1 is loopback; fe80::/10 is link-local up to and including febf::.
        assert!(is_loopback(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(
            0xfe80, 0, 0, 0, 0, 0, 0, 1
        ))));
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(
            0xfebf, 0, 0, 0, 0, 0, 0, 1
        ))));

        // fec0:: is just past the link-local block; a global address is neither.
        assert!(!is_link_local(&IpAddr::V6(Ipv6Addr::new(
            0xfec0, 0, 0, 0, 0, 0, 0, 1
        ))));
        let global: IpAddr = IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1));
        assert!(!is_private_ip(&global));
        assert!(!is_loopback(&global));
        assert!(!is_link_local(&global));
    }

    #[test]
    fn test_ipv4_mapped_ipv6_public_not_flagged() {
        // ::ffff:8.8.8.8 should NOT be private
        let ip: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x0808, 0x0808));
        assert!(!is_private_ip(&ip));
        assert!(!is_loopback(&ip));
        assert!(!is_link_local(&ip));
    }
}