Skip to main content

lean_ctx/core/web/
url_guard.rs

//! URL validation and SSRF protection for outbound fetches.
//!
//! `ctx_url_read` accepts arbitrary URLs supplied by an agent, so every request
//! is gated here before any socket is opened: a scheme allow-list, rejection of
//! embedded credentials, and rejection of hosts that resolve to loopback /
//! private / link-local / metadata ranges. Redirect hops are re-validated by the
//! caller using the same primitives.
8
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
10
/// Why a URL was refused before any fetch was attempted.
///
/// The `String` payloads carry the offending scheme or host so the
/// `Display` message can name it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UrlError {
    /// The input was empty once surrounding whitespace was trimmed.
    Empty,
    /// The scheme is not `http`/`https`. Holds the lower-cased scheme, or the
    /// first 12 characters of the input when it has no `://` separator.
    BadScheme(String),
    /// No usable host could be extracted from the authority (also returned
    /// for a malformed authority such as an unterminated `[` or an
    /// out-of-range port).
    MissingHost,
    /// The authority contains `@`, i.e. embedded user-info.
    Credentials,
    /// The named host is, or resolves to, an address in a blocked range.
    Blocked(String),
    /// The named host could not be resolved to any address.
    Unresolvable(String),
}
21
22impl std::fmt::Display for UrlError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Empty => write!(f, "empty URL"),
26            Self::BadScheme(s) => {
27                write!(f, "unsupported scheme '{s}' (only http/https allowed)")
28            }
29            Self::MissingHost => write!(f, "URL has no host"),
30            Self::Credentials => write!(f, "URLs with embedded credentials are not allowed"),
31            Self::Blocked(h) => {
32                write!(
33                    f,
34                    "host '{h}' resolves to a blocked (private/loopback) address"
35                )
36            }
37            Self::Unresolvable(h) => write!(f, "host '{h}' could not be resolved"),
38        }
39    }
40}
41
42impl std::error::Error for UrlError {}
43
/// An http(s) URL that passed syntactic validation, split into its parts.
///
/// Constructing this value involves no DNS lookup; the host has not yet been
/// checked against the blocked address ranges.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SafeUrl {
    /// Lower-cased scheme: `http` or `https`.
    pub scheme: String,
    /// Host without the port; IPv6 literals are stored without brackets.
    pub host: String,
    /// Explicit port, or the scheme default (80 / 443) when none was given.
    pub port: u16,
    /// The authority exactly as written in the input.
    pub authority: String,
    /// `scheme://authority` followed by the rest of the input (path, query,
    /// fragment), with only the scheme case-folded.
    pub normalized: String,
}
53
54/// Validate URL *syntax* only (no DNS lookup). Call
55/// [`SafeUrl::ensure_resolves_safely`] before opening a socket.
56pub fn validate(raw: &str) -> Result<SafeUrl, UrlError> {
57    let trimmed = raw.trim();
58    if trimmed.is_empty() {
59        return Err(UrlError::Empty);
60    }
61    let Some((scheme_raw, rest)) = trimmed.split_once("://") else {
62        let head: String = trimmed.chars().take(12).collect();
63        return Err(UrlError::BadScheme(head));
64    };
65    let scheme = scheme_raw.to_ascii_lowercase();
66    if scheme != "http" && scheme != "https" {
67        return Err(UrlError::BadScheme(scheme));
68    }
69
70    let auth_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
71    let authority = &rest[..auth_end];
72    let path = &rest[auth_end..];
73    if authority.is_empty() {
74        return Err(UrlError::MissingHost);
75    }
76    if authority.contains('@') {
77        return Err(UrlError::Credentials);
78    }
79
80    let (host, port) = split_host_port(authority, &scheme)?;
81    if host.is_empty() {
82        return Err(UrlError::MissingHost);
83    }
84
85    Ok(SafeUrl {
86        scheme: scheme.clone(),
87        host,
88        port,
89        authority: authority.to_string(),
90        normalized: format!("{scheme}://{authority}{path}"),
91    })
92}
93
94fn split_host_port(authority: &str, scheme: &str) -> Result<(String, u16), UrlError> {
95    let default_port = if scheme == "https" { 443 } else { 80 };
96
97    // IPv6 literal form: `[::1]` or `[::1]:8080`.
98    if let Some(stripped) = authority.strip_prefix('[') {
99        let Some(end) = stripped.find(']') else {
100            return Err(UrlError::MissingHost);
101        };
102        let host = stripped[..end].to_string();
103        let port = match stripped[end + 1..].strip_prefix(':') {
104            Some(p) => p.parse().map_err(|_| UrlError::MissingHost)?,
105            None => default_port,
106        };
107        return Ok((host, port));
108    }
109
110    match authority.rsplit_once(':') {
111        Some((host, port_str))
112            if !port_str.is_empty() && port_str.bytes().all(|b| b.is_ascii_digit()) =>
113        {
114            let port = port_str.parse().map_err(|_| UrlError::MissingHost)?;
115            Ok((host.to_string(), port))
116        }
117        _ => Ok((authority.to_string(), default_port)),
118    }
119}
120
121impl SafeUrl {
122    /// Resolve the host and reject if *any* resolved address falls in a blocked
123    /// range. Rejecting on a single blocked result is a conservative guard
124    /// against DNS-rebinding that mixes a public and an internal address.
125    pub fn ensure_resolves_safely(&self) -> Result<(), UrlError> {
126        if let Ok(ip) = self.host.parse::<IpAddr>() {
127            return if ip_is_blocked(ip) {
128                Err(UrlError::Blocked(self.host.clone()))
129            } else {
130                Ok(())
131            };
132        }
133
134        let addrs = (self.host.as_str(), self.port)
135            .to_socket_addrs()
136            .map_err(|_| UrlError::Unresolvable(self.host.clone()))?;
137
138        let mut resolved_any = false;
139        for addr in addrs {
140            resolved_any = true;
141            if ip_is_blocked(addr.ip()) {
142                return Err(UrlError::Blocked(self.host.clone()));
143            }
144        }
145
146        if resolved_any {
147            Ok(())
148        } else {
149            Err(UrlError::Unresolvable(self.host.clone()))
150        }
151    }
152}
153
/// True for addresses an outbound fetch must never reach (SSRF guard).
///
/// IPv4 addresses are classified by [`v4_is_blocked`]. An IPv6 address is
/// blocked when it is loopback, unspecified, multicast, unique-local,
/// link-local or in the documentation prefix, or when it embeds an IPv4
/// address (see [`embedded_v4`]) that is itself blocked — otherwise an
/// internal IPv4 target could be reached through a v6 spelling.
pub fn ip_is_blocked(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4_is_blocked(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || is_unique_local_v6(v6)
                || is_link_local_v6(v6)
                || is_documentation_v6(v6)
            {
                return true;
            }
            // Dual-stack hosts and translators can expose internal v4 ranges
            // via mapped / translated addresses.
            match embedded_v4(v6) {
                Some(inner) => v4_is_blocked(inner),
                None => false,
            }
        }
    }
}

/// Extract the IPv4 address carried inside an IPv6 address, for the address
/// formats that embed one at a fixed position.
fn embedded_v4(v6: Ipv6Addr) -> Option<Ipv4Addr> {
    match v6.octets() {
        // ::ffff:a.b.c.d — IPv4-mapped (RFC 4291 section 2.5.5.2).
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => Some(Ipv4Addr::new(a, b, c, d)),
        // ::a.b.c.d — deprecated IPv4-compatible (RFC 4291 section 2.5.5.1).
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, a, b, c, d] => Some(Ipv4Addr::new(a, b, c, d)),
        // 64:ff9b::a.b.c.d — NAT64 well-known prefix (RFC 6052).
        [0x00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0, a, b, c, d] => {
            Some(Ipv4Addr::new(a, b, c, d))
        }
        // 2002:AABB:CCDD::/48 — 6to4 (RFC 3056).
        [0x20, 0x02, a, b, c, d, ..] => Some(Ipv4Addr::new(a, b, c, d)),
        _ => None,
    }
}

/// True for IPv4 addresses outside the publicly routable unicast space.
fn v4_is_blocked(v4: Ipv4Addr) -> bool {
    let o = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_documentation()
        || v4.is_multicast()
        // 0.0.0.0/8 "this network".
        || o[0] == 0
        // 100.64.0.0/10 carrier-grade NAT.
        || (o[0] == 100 && (o[1] & 0xc0) == 64)
        // 192.0.0.0/24 IETF protocol assignments.
        || (o[0] == 192 && o[1] == 0 && o[2] == 0)
        // 198.18.0.0/15 benchmarking.
        || (o[0] == 198 && (o[1] & 0xfe) == 18)
        // 240.0.0.0/4 reserved; also covers the 255.255.255.255 broadcast.
        || o[0] >= 240
}

/// fc00::/7 — unique local addresses.
fn is_unique_local_v6(v6: Ipv6Addr) -> bool {
    (v6.segments()[0] & 0xfe00) == 0xfc00
}

/// fe80::/10 — link-local unicast.
fn is_link_local_v6(v6: Ipv6Addr) -> bool {
    (v6.segments()[0] & 0xffc0) == 0xfe80
}

/// 2001:db8::/32 — documentation prefix (mirrors the IPv4 documentation rule).
fn is_documentation_v6(v6: Ipv6Addr) -> bool {
    let s = v6.segments();
    s[0] == 0x2001 && s[1] == 0x0db8
}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify an IP literal given as text; panics on a malformed literal.
    fn blocked(literal: &str) -> bool {
        ip_is_blocked(literal.parse().unwrap())
    }

    // ---- validate: accepted inputs ----

    #[test]
    fn validates_https_with_path() {
        let parsed = validate("https://example.com/foo/bar?x=1").unwrap();
        let expected = SafeUrl {
            scheme: "https".to_string(),
            host: "example.com".to_string(),
            port: 443,
            authority: "example.com".to_string(),
            normalized: "https://example.com/foo/bar?x=1".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn validates_http_with_explicit_port() {
        let SafeUrl { port, authority, .. } = validate("http://example.com:8080/p").unwrap();
        assert_eq!(port, 8080);
        assert_eq!(authority, "example.com:8080");
    }

    #[test]
    fn validates_ipv6_literal_with_port() {
        let SafeUrl { host, port, .. } = validate("https://[2606:4700::1111]:8443/p").unwrap();
        assert_eq!((host.as_str(), port), ("2606:4700::1111", 8443));
    }

    // ---- validate: rejected inputs ----

    #[test]
    fn rejects_non_http_scheme() {
        for url in ["ftp://example.com", "file:///etc/passwd"] {
            assert!(matches!(validate(url), Err(UrlError::BadScheme(_))));
        }
    }

    #[test]
    fn rejects_empty_and_credentials() {
        let cases = [
            ("   ", UrlError::Empty),
            ("https://user:pass@example.com", UrlError::Credentials),
        ];
        for (input, expected) in cases {
            assert_eq!(validate(input), Err(expected));
        }
    }

    // ---- ip_is_blocked ----

    #[test]
    fn blocks_loopback_and_private_v4() {
        for ip in ["127.0.0.1", "10.0.0.1", "192.168.1.1", "172.16.0.1"] {
            assert!(blocked(ip), "{ip} must be blocked");
        }
    }

    #[test]
    fn blocks_metadata_and_cgnat() {
        for ip in ["169.254.169.254", "100.64.0.1", "0.0.0.0"] {
            assert!(blocked(ip));
        }
    }

    #[test]
    fn allows_public_v4_and_v6() {
        for ip in ["8.8.8.8", "1.1.1.1", "2606:4700:4700::1111"] {
            assert!(!blocked(ip));
        }
    }

    #[test]
    fn blocks_v6_internal_ranges() {
        for ip in ["::1", "fe80::1", "fc00::1", "::ffff:127.0.0.1"] {
            assert!(blocked(ip));
        }
    }

    // ---- ensure_resolves_safely (IP literals only, so no DNS is involved) ----

    #[test]
    fn ensure_resolves_safely_rejects_literal_loopback() {
        let verdict = validate("http://127.0.0.1/")
            .unwrap()
            .ensure_resolves_safely();
        assert!(matches!(verdict, Err(UrlError::Blocked(_))));
    }

    #[test]
    fn ensure_resolves_safely_allows_literal_public_ip() {
        let verdict = validate("http://8.8.8.8/")
            .unwrap()
            .ensure_resolves_safely();
        assert_eq!(verdict, Ok(()));
    }
}