Skip to main content

mcp_postgres/
ssrf.rs

1//! SSRF protection for outbound URL fetches (`import_from_url`).
2//!
3//! Validates that a user-supplied URL uses an allowed scheme and resolves
4//! only to public IP addresses, blocking access to loopback, private,
5//! link-local (incl. the cloud metadata endpoint 169.254.169.254), and
6//! unique-local ranges.
7
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use crate::errors::MCPError;
11
/// Return `true` if the IP must NOT be reachable from a user-controlled fetch.
///
/// Dispatches to the IPv4 or IPv6 rule set. IPv6 addresses that embed an
/// IPv4 address (IPv4-mapped `::ffff:a.b.c.d`, NAT64 `64:ff9b::a.b.c.d`) are
/// judged by the IPv4 rules applied to the embedded address.
pub const fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_v4(v4),
        IpAddr::V6(v6) => is_blocked_v6(v6),
    }
}

/// Return `true` if `ip` lies in an IPv4 range that is not publicly routable.
///
/// Ranges whose `std` helpers are still unstable (`is_shared`,
/// `is_benchmarking`, `is_reserved`) are matched by prefix here.
const fn is_blocked_v4(ip: Ipv4Addr) -> bool {
    let o = ip.octets();
    ip.is_loopback()                // 127.0.0.0/8
        || ip.is_private()          // 10/8, 172.16/12, 192.168/16
        || ip.is_link_local()       // 169.254/16 (incl. 169.254.169.254 metadata)
        || ip.is_unspecified()      // 0.0.0.0
        || ip.is_broadcast()        // 255.255.255.255
        || ip.is_documentation()    // 192.0.2/24, 198.51.100/24, 203.0.113/24
        || ip.is_multicast()        // 224.0.0.0/4
        || is_shared_v4(ip)         // 100.64/10 carrier-grade NAT
        || o[0] == 0                // 0.0.0.0/8 "this network", not just 0.0.0.0
        || o[0] >= 240              // 240.0.0.0/4 reserved
        || (o[0] == 198 && (o[1] & 0xfe) == 18)     // 198.18.0.0/15 benchmarking
        || (o[0] == 192 && o[1] == 0 && o[2] == 0)  // 192.0.0.0/24 IETF protocol assignments
}

/// 100.64.0.0/10 — RFC 6598 shared address space (no stable std helper).
const fn is_shared_v4(ip: Ipv4Addr) -> bool {
    let o = ip.octets();
    o[0] == 100 && (o[1] & 0b1100_0000) == 0b0100_0000
}

/// Return `true` if `ip` lies in an IPv6 range that is not publicly routable,
/// or embeds an IPv4 address that the IPv4 rules reject.
const fn is_blocked_v6(ip: Ipv6Addr) -> bool {
    // ::1, ::, and ff00::/8 multicast.
    if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() {
        return true;
    }
    // IPv4-mapped (::ffff:a.b.c.d) — apply the v4 rules to the embedded addr.
    if let Some(v4) = ip.to_ipv4_mapped() {
        return is_blocked_v4(v4);
    }
    let seg = ip.segments();
    let middle_zero = seg[2] == 0 && seg[3] == 0 && seg[4] == 0 && seg[5] == 0;
    // ::/96 — deprecated IPv4-compatible addresses (RFC 4291). Never a
    // legitimate public destination, so reject the whole prefix.
    if seg[0] == 0 && seg[1] == 0 && middle_zero {
        return true;
    }
    // 64:ff9b::/96 — NAT64 well-known prefix (RFC 6052). A NAT64 gateway
    // forwards these to the embedded IPv4 address (64:ff9b::a9fe:a9fe is
    // 169.254.169.254), so judge that address with the v4 rules.
    if seg[0] == 0x0064 && seg[1] == 0xff9b && middle_zero {
        let o = ip.octets();
        return is_blocked_v4(Ipv4Addr::new(o[12], o[13], o[14], o[15]));
    }
    let first = seg[0];
    (first & 0xfe00) == 0xfc00                   // fc00::/7 unique local
        || (first & 0xffc0) == 0xfe80            // fe80::/10 link local
        || (first & 0xffc0) == 0xfec0            // fec0::/10 deprecated site local
        || (first == 0x2001 && seg[1] == 0x0db8) // 2001:db8::/32 documentation
}
49
50/// Validate a user-supplied import URL and return the resolved, allowed
51/// `host:port` authority. Rejects non-http(s) schemes and any host that
52/// resolves to a blocked address.
53pub async fn validate_import_url(url: &str) -> Result<(), MCPError> {
54    let parsed = reqwest::Url::parse(url)
55        .map_err(|e| MCPError::InvalidParams(format!("Invalid URL: {e}")))?;
56
57    let scheme = parsed.scheme();
58    if scheme != "http" && scheme != "https" {
59        return Err(MCPError::InvalidParams(format!(
60            "URL scheme '{scheme}' is not allowed; only http and https are permitted"
61        )));
62    }
63
64    let host = parsed
65        .host_str()
66        .ok_or_else(|| MCPError::InvalidParams("URL has no host".into()))?;
67    let port = parsed
68        .port_or_known_default()
69        .ok_or_else(|| MCPError::InvalidParams("URL has no port".into()))?;
70
71    // Resolve and ensure every candidate address is public.
72    let addrs = tokio::net::lookup_host((host, port))
73        .await
74        .map_err(|e| MCPError::InvalidParams(format!("Failed to resolve host '{host}': {e}")))?;
75
76    let mut any = false;
77    for addr in addrs {
78        any = true;
79        if is_blocked_ip(addr.ip()) {
80            return Err(MCPError::InvalidParams(format!(
81                "URL host '{host}' resolves to a blocked (private/loopback/link-local) address"
82            )));
83        }
84    }
85    if !any {
86        return Err(MCPError::InvalidParams(format!(
87            "URL host '{host}' did not resolve to any address"
88        )));
89    }
90    Ok(())
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Parse a literal address; test inputs are always well-formed.
    fn ip(s: &str) -> IpAddr {
        IpAddr::from_str(s).unwrap()
    }

    /// Assert that `is_blocked_ip` yields `blocked` for every address listed.
    fn expect_all(addrs: &[&str], blocked: bool) {
        for addr in addrs {
            assert_eq!(is_blocked_ip(ip(addr)), blocked, "unexpected verdict for {addr}");
        }
    }

    /// Assert that `url` is rejected with an error whose text contains `needle`.
    async fn expect_rejected(url: &str, needle: &str) {
        let err = validate_import_url(url).await.unwrap_err();
        assert!(err.to_string().contains(needle));
    }

    #[test]
    fn test_blocked_v4() {
        expect_all(
            &[
                "127.0.0.1",
                "10.0.0.5",
                "172.16.3.4",
                "192.168.1.1",
                "169.254.169.254", // cloud metadata
                "0.0.0.0",
                "100.64.1.1", // CGNAT
            ],
            true,
        );
    }

    #[test]
    fn test_allowed_v4() {
        expect_all(&["1.1.1.1", "8.8.8.8", "93.184.216.34"], false);
    }

    #[test]
    fn test_blocked_v6() {
        expect_all(
            &[
                "::1",
                "::",
                "fc00::1",
                "fe80::1",
                "::ffff:127.0.0.1", // mapped loopback
            ],
            true,
        );
    }

    #[test]
    fn test_allowed_v6() {
        expect_all(&["2606:4700:4700::1111"], false);
    }

    #[tokio::test]
    async fn test_validate_rejects_scheme() {
        expect_rejected("file:///etc/passwd", "scheme").await;
        expect_rejected("ftp://example.com/x", "scheme").await;
    }

    #[tokio::test]
    async fn test_validate_rejects_loopback_literal() {
        expect_rejected("http://127.0.0.1:8080/x", "blocked").await;
    }
}