Skip to main content

mockforge_bench/
ssrf.rs

1//! SSRF (Server-Side Request Forgery) guard for cloud-driven runs.
2//!
3//! When `mockforge-test-runner` (or any other cloud-side caller) accepts a
4//! user-supplied target URL and dispatches HTTP traffic at it, an attacker
5//! can point the runner at internal infrastructure: metadata endpoints
6//! (`169.254.169.254`), localhost services, RFC1918 / ULA private ranges,
7//! and so on. Without a guard, we hand them a free internal-network
8//! scanner running inside our Fly.io org.
9//!
10//! [`validate_target_url`] is the single chokepoint. It:
11//!
12//! 1. Parses the URL and rejects anything that isn't `http://` or
13//!    `https://` (no `file://`, `gopher://`, etc.).
14//! 2. Resolves the hostname via DNS (asynchronously) and rejects if any
15//!    resolved IP is in a blocked range.
16//! 3. Treats literal-IP hostnames (`http://10.0.0.1/`) the same way —
17//!    parsing them as `IpAddr` directly so DNS resolution isn't needed.
18//!
19//! Blocked ranges:
20//!
//! * IPv4: loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16` —
//!   includes the AWS/GCP/Fly metadata IP `169.254.169.254`), unspecified
//!   (`0.0.0.0/8`), broadcast, RFC1918 private (`10/8`, `172.16/12`,
//!   `192.168/16`), documentation (RFC 5737: `192.0.2.0/24`,
//!   `198.51.100.0/24`, `203.0.113.0/24`), CGNAT (`100.64.0.0/10`),
//!   benchmark (`198.18/15`).
25//! * IPv6: loopback (`::1`), unspecified (`::`), link-local (`fe80::/10`),
26//!   ULA (`fc00::/7`), IPv4-mapped (`::ffff:0:0/96` — caller could smuggle
27//!   in a private v4 address).
28//!
//! There is no runtime escape hatch — production callers MUST ensure their
//! target is publicly reachable. The only relaxation is the policy value
//! passed in by the caller: tests can pass [`Policy::for_test`], which
//! permits loopback addresses so integration-test endpoints on `127.0.0.1`
//! can be reached. Every other blocked range stays blocked.
33
34use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
35
36use thiserror::Error;
37
/// Reasons a target URL can be rejected by [`validate_target_url`].
///
/// The `Display` text of each variant is produced by the `#[error(...)]`
/// attribute directly above it.
#[derive(Debug, Error)]
pub enum SsrfError {
    /// The input string could not be parsed as a URL. Carries the URL
    /// parser's error rendered as text.
    #[error("invalid URL: {0}")]
    InvalidUrl(String),

    /// The URL parsed, but its scheme is something other than `http` or
    /// `https`. Carries the offending scheme.
    #[error("URL scheme '{0}' not allowed — only http/https")]
    DisallowedScheme(String),

    /// The URL parsed but has no host component to validate.
    #[error("URL has no host component")]
    MissingHost,

    /// The DNS lookup for the hostname returned an error.
    #[error("DNS resolution failed for '{host}': {source}")]
    DnsResolutionFailed {
        /// Hostname that was being resolved.
        host: String,
        /// Underlying I/O error reported by the lookup; also exposed as
        /// the error's `source()`.
        #[source]
        source: std::io::Error,
    },

    /// The DNS lookup succeeded but yielded zero addresses. Carries the
    /// hostname that was looked up.
    #[error("DNS resolution returned no addresses for '{0}'")]
    NoAddressesResolved(String),

    /// The host is, or resolves to, an address in one of the blocked
    /// ranges.
    #[error(
        "target '{host}' resolves to {ip} which is in a blocked range ({reason}); \
         pointing the cloud runner at internal addresses is not allowed"
    )]
    BlockedAddress {
        /// Host exactly as it appeared in the URL.
        host: String,
        /// The specific address that tripped the check.
        ip: IpAddr,
        /// Human-readable name of the blocked range the address falls in.
        reason: &'static str,
    },
}
70
/// Tunable settings consumed by [`validate_target_url`].
///
/// The derived [`Default`] value has every relaxation switched off, i.e.
/// it equals [`Policy::strict`]. Tests loosen it through
/// [`Policy::for_test`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Policy {
    /// Whether IPv4/IPv6 loopback addresses pass validation. Keep this
    /// `false` in production; [`Policy::for_test`] turns it on so tests can
    /// target `127.0.0.1`.
    pub allow_loopback: bool,
}

impl Policy {
    /// The production policy: loopback is rejected along with every other
    /// blocked range.
    pub const fn strict() -> Self {
        Self::with_loopback(false)
    }

    /// The integration-test policy: loopback is accepted so tests can talk
    /// to mock servers bound to `127.0.0.1:<port>`.
    pub const fn for_test() -> Self {
        Self::with_loopback(true)
    }

    /// Single construction point shared by the named presets.
    const fn with_loopback(allow_loopback: bool) -> Self {
        Policy { allow_loopback }
    }
}
97
98/// Validate that a URL is safe for a cloud runner to hit. Returns `Ok(())`
99/// when the target is a publicly-routable HTTP/S endpoint.
100///
101/// Performs DNS resolution, so this is async. Cache results at the call
102/// site if the same URL is validated repeatedly within one request.
103pub async fn validate_target_url(url: &str, policy: Policy) -> Result<(), SsrfError> {
104    let parsed = url::Url::parse(url).map_err(|e| SsrfError::InvalidUrl(e.to_string()))?;
105
106    let scheme = parsed.scheme();
107    if scheme != "http" && scheme != "https" {
108        return Err(SsrfError::DisallowedScheme(scheme.to_string()));
109    }
110
111    let host = parsed.host_str().ok_or(SsrfError::MissingHost)?.to_string();
112    let port = parsed.port_or_known_default().unwrap_or(80);
113
114    // Literal IP — no DNS needed.
115    if let Ok(ip) = host.parse::<IpAddr>() {
116        check_ip(&host, ip, policy)?;
117        return Ok(());
118    }
119
120    // Hostname — resolve and check every resolved address. An attacker
121    // can register a public DNS name pointing at 169.254.169.254 (a.k.a.
122    // "DNS rebinding" / "0.0.0.0 day"), so we MUST inspect resolved IPs,
123    // not just the literal-IP form.
124    let lookup_target = format!("{}:{}", host, port);
125    let addrs: Vec<std::net::SocketAddr> = tokio::net::lookup_host(&lookup_target)
126        .await
127        .map_err(|source| SsrfError::DnsResolutionFailed {
128            host: host.clone(),
129            source,
130        })?
131        .collect();
132
133    if addrs.is_empty() {
134        return Err(SsrfError::NoAddressesResolved(host));
135    }
136
137    for addr in addrs {
138        check_ip(&host, addr.ip(), policy)?;
139    }
140
141    Ok(())
142}
143
144fn check_ip(host: &str, ip: IpAddr, policy: Policy) -> Result<(), SsrfError> {
145    if let Some(reason) = blocked_reason(ip, policy) {
146        return Err(SsrfError::BlockedAddress {
147            host: host.to_string(),
148            ip,
149            reason,
150        });
151    }
152    Ok(())
153}
154
155fn blocked_reason(ip: IpAddr, policy: Policy) -> Option<&'static str> {
156    match ip {
157        IpAddr::V4(v4) => blocked_reason_v4(v4, policy),
158        IpAddr::V6(v6) => blocked_reason_v6(v6, policy),
159    }
160}
161
162fn blocked_reason_v4(ip: Ipv4Addr, policy: Policy) -> Option<&'static str> {
163    if ip.is_loopback() {
164        if policy.allow_loopback {
165            return None;
166        }
167        return Some("IPv4 loopback (127.0.0.0/8)");
168    }
169    if ip.is_unspecified() {
170        return Some("IPv4 unspecified (0.0.0.0)");
171    }
172    if ip.is_broadcast() {
173        return Some("IPv4 broadcast");
174    }
175    if ip.is_link_local() {
176        return Some("IPv4 link-local (169.254.0.0/16, includes cloud metadata IP)");
177    }
178    if ip.is_private() {
179        return Some("IPv4 RFC1918 private (10/8, 172.16/12, 192.168/16)");
180    }
181    if ip.is_documentation() {
182        return Some("IPv4 documentation range (RFC5737)");
183    }
184    // CGNAT range (100.64.0.0/10) — not is_private but still
185    // not-publicly-routable. Cloud providers sometimes use it for inter-VM
186    // links, which is exactly the kind of thing we don't want to expose.
187    let octets = ip.octets();
188    if octets[0] == 100 && (64..=127).contains(&octets[1]) {
189        return Some("IPv4 CGNAT (100.64.0.0/10)");
190    }
191    // Benchmark range 198.18.0.0/15 (RFC2544).
192    if octets[0] == 198 && (octets[1] == 18 || octets[1] == 19) {
193        return Some("IPv4 benchmark (198.18.0.0/15)");
194    }
195    None
196}
197
198fn blocked_reason_v6(ip: Ipv6Addr, policy: Policy) -> Option<&'static str> {
199    if ip.is_loopback() {
200        if policy.allow_loopback {
201            return None;
202        }
203        return Some("IPv6 loopback (::1)");
204    }
205    if ip.is_unspecified() {
206        return Some("IPv6 unspecified (::)");
207    }
208    let segments = ip.segments();
209    // Link-local fe80::/10
210    if (segments[0] & 0xffc0) == 0xfe80 {
211        return Some("IPv6 link-local (fe80::/10)");
212    }
213    // ULA fc00::/7
214    if (segments[0] & 0xfe00) == 0xfc00 {
215        return Some("IPv6 unique-local (fc00::/7)");
216    }
217    // IPv4-mapped ::ffff:0:0/96 — recurse so an attacker can't smuggle a
218    // private v4 through the v6 form (`http://[::ffff:10.0.0.1]/`).
219    if let Some(v4) = ip.to_ipv4_mapped() {
220        return blocked_reason_v4(v4, policy);
221    }
222    None
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `addr` is rejected under `policy` and that the reported
    /// reason mentions `fragment`.
    fn assert_blocked(addr: &str, policy: Policy, fragment: &str) {
        let ip: IpAddr = addr.parse().unwrap();
        let reason =
            blocked_reason(ip, policy).unwrap_or_else(|| panic!("expected {addr} to be blocked"));
        assert!(
            reason.contains(fragment),
            "{addr} blocked but reason '{reason}' missing fragment '{fragment}'"
        );
    }

    /// Assert that `addr` is accepted under `policy`.
    fn assert_allowed(addr: &str, policy: Policy) {
        let ip: IpAddr = addr.parse().unwrap();
        assert!(blocked_reason(ip, policy).is_none(), "{addr} unexpectedly blocked");
    }

    #[test]
    fn default_policy_is_strict() {
        assert!(!Policy::default().allow_loopback);
        assert!(!Policy::strict().allow_loopback);
        assert!(Policy::for_test().allow_loopback);
    }

    #[test]
    fn blocks_loopback_v4_strict() {
        assert_blocked("127.0.0.1", Policy::strict(), "loopback");
        assert_blocked("127.255.255.254", Policy::strict(), "loopback");
    }

    #[test]
    fn allows_loopback_v4_in_test_policy() {
        assert_allowed("127.0.0.1", Policy::for_test());
    }

    #[test]
    fn test_policy_only_relaxes_loopback() {
        // The relaxed policy must not open up anything beyond loopback.
        assert_blocked("10.0.0.1", Policy::for_test(), "RFC1918");
        assert_blocked("169.254.169.254", Policy::for_test(), "link-local");
        assert_blocked("fd12:3456::1", Policy::for_test(), "unique-local");
    }

    #[test]
    fn blocks_unspecified_and_broadcast_v4() {
        assert_blocked("0.0.0.0", Policy::strict(), "unspecified");
        assert_blocked("255.255.255.255", Policy::strict(), "broadcast");
    }

    #[test]
    fn blocks_link_local_aws_metadata() {
        assert_blocked("169.254.169.254", Policy::strict(), "link-local");
    }

    #[test]
    fn blocks_rfc1918_ranges() {
        assert_blocked("10.0.0.1", Policy::strict(), "RFC1918");
        assert_blocked("172.16.0.1", Policy::strict(), "RFC1918");
        assert_blocked("172.31.255.255", Policy::strict(), "RFC1918");
        assert_blocked("192.168.0.1", Policy::strict(), "RFC1918");
    }

    #[test]
    fn blocks_documentation_ranges() {
        assert_blocked("192.0.2.1", Policy::strict(), "documentation");
        assert_blocked("198.51.100.1", Policy::strict(), "documentation");
        assert_blocked("203.0.113.1", Policy::strict(), "documentation");
    }

    #[test]
    fn blocks_cgnat() {
        assert_blocked("100.64.0.1", Policy::strict(), "CGNAT");
        assert_blocked("100.127.255.255", Policy::strict(), "CGNAT");
    }

    #[test]
    fn allows_ranges_outside_cgnat() {
        // 100.0.0.0/8 outside the CGNAT slice is publicly routable.
        assert_allowed("100.63.255.255", Policy::strict());
        assert_allowed("100.128.0.1", Policy::strict());
    }

    #[test]
    fn blocks_benchmark_range() {
        assert_blocked("198.18.0.1", Policy::strict(), "benchmark");
        assert_blocked("198.19.255.255", Policy::strict(), "benchmark");
    }

    #[test]
    fn allows_public_v4() {
        assert_allowed("8.8.8.8", Policy::strict());
        assert_allowed("1.1.1.1", Policy::strict());
        assert_allowed("142.250.190.78", Policy::strict()); // google.com
    }

    #[test]
    fn blocks_loopback_v6_strict() {
        assert_blocked("::1", Policy::strict(), "loopback");
    }

    #[test]
    fn allows_loopback_v6_in_test_policy() {
        assert_allowed("::1", Policy::for_test());
        // The mapped form of IPv4 loopback follows the same policy switch.
        assert_allowed("::ffff:127.0.0.1", Policy::for_test());
    }

    #[test]
    fn blocks_unspecified_v6() {
        assert_blocked("::", Policy::strict(), "unspecified");
    }

    #[test]
    fn blocks_link_local_v6() {
        assert_blocked("fe80::1", Policy::strict(), "link-local");
        assert_blocked("febf::1", Policy::strict(), "link-local");
    }

    #[test]
    fn blocks_ula() {
        assert_blocked("fc00::1", Policy::strict(), "unique-local");
        assert_blocked("fd12:3456::1", Policy::strict(), "unique-local");
    }

    #[test]
    fn blocks_ipv4_mapped_private() {
        assert_blocked("::ffff:10.0.0.1", Policy::strict(), "RFC1918");
        assert_blocked("::ffff:127.0.0.1", Policy::strict(), "loopback");
    }

    #[test]
    fn allows_public_v6() {
        assert_allowed("2606:4700:4700::1111", Policy::strict()); // cloudflare
        assert_allowed("2001:4860:4860::8888", Policy::strict()); // google
    }

    #[tokio::test]
    async fn validate_rejects_non_http_scheme() {
        let err = validate_target_url("file:///etc/passwd", Policy::strict()).await.unwrap_err();
        assert!(matches!(err, SsrfError::DisallowedScheme(s) if s == "file"));
    }

    #[tokio::test]
    async fn validate_rejects_garbage_url() {
        let err = validate_target_url("not a url", Policy::strict()).await.unwrap_err();
        assert!(matches!(err, SsrfError::InvalidUrl(_)));
    }

    #[tokio::test]
    async fn validate_rejects_literal_loopback() {
        let err = validate_target_url("http://127.0.0.1/", Policy::strict()).await.unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_numeric_encoded_loopback() {
        // 2130706433 == 0x7f000001; the URL parser normalises this host to
        // 127.0.0.1, so it must be rejected like the dotted form.
        let err = validate_target_url("http://2130706433/", Policy::strict()).await.unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_literal_metadata_ip() {
        let err = validate_target_url("http://169.254.169.254/latest/meta-data/", Policy::strict())
            .await
            .unwrap_err();
        match err {
            SsrfError::BlockedAddress { reason, .. } => assert!(reason.contains("link-local")),
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_rejects_literal_rfc1918() {
        let err = validate_target_url("http://10.0.0.1/", Policy::strict()).await.unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_bracketed_ipv6_loopback() {
        // IPv6 literals appear in URLs in bracketed form.
        let err = validate_target_url("http://[::1]:8080/", Policy::strict()).await.unwrap_err();
        match err {
            SsrfError::BlockedAddress { ip, reason, .. } => {
                assert_eq!(ip, IpAddr::V6(Ipv6Addr::LOCALHOST));
                assert!(reason.contains("loopback"));
            }
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_rejects_bracketed_ipv4_mapped_private() {
        let err = validate_target_url("http://[::ffff:10.0.0.1]/", Policy::strict())
            .await
            .unwrap_err();
        match err {
            SsrfError::BlockedAddress { reason, .. } => assert!(reason.contains("RFC1918")),
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_allows_loopback_in_test_policy() {
        validate_target_url("http://127.0.0.1:8080/", Policy::for_test()).await.unwrap();
    }

    #[tokio::test]
    async fn validate_allows_bracketed_ipv6_loopback_in_test_policy() {
        validate_target_url("http://[::1]:8080/", Policy::for_test()).await.unwrap();
    }
}