use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use url::Url;
/// Reasons a URL can be rejected by [`validate_safe_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UrlValidationError {
    /// The input could not be parsed as a URL; carries the parser's error text.
    InvalidUrl(String),
    /// The URL parsed, but its scheme (carried here) is neither `http` nor `https`.
    DisallowedScheme(String),
    /// The URL has no host component.
    MissingHostname,
    /// The host (carried here) is on the private/internal blocklist.
    BlockedHost(String),
}

impl std::fmt::Display for UrlValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            UrlValidationError::InvalidUrl(detail) => write!(f, "Invalid URL: {detail}"),
            UrlValidationError::DisallowedScheme(found) => write!(
                f,
                "Disallowed URL scheme: {found} (must be http or https)"
            ),
            UrlValidationError::MissingHostname => f.write_str("URL must have a hostname"),
            UrlValidationError::BlockedHost(name) => write!(
                f,
                "Blocked host: {name} (private/internal address)"
            ),
        }
    }
}

// All information lives in the `Display` text; there is no underlying source error.
impl std::error::Error for UrlValidationError {}
/// Parses `raw_url` and accepts it only if it is safe to use as an outbound request target.
///
/// The checks run in this order: the string must parse as a URL, the scheme must be
/// `http` or `https`, a host must be present, and the host must not be on the
/// private/internal blocklist (localhost names, `metadata.google.internal`, and
/// blocked IP literals).
///
/// Only the literal host text is inspected; no DNS lookup happens here. A hostname
/// that later resolves to an internal address is therefore not caught by this
/// function. [`is_blocked_ip`] can be applied to resolved addresses separately.
///
/// # Errors
/// - [`UrlValidationError::InvalidUrl`] when parsing fails.
/// - [`UrlValidationError::DisallowedScheme`] for any scheme other than `http`/`https`.
/// - [`UrlValidationError::MissingHostname`] when the URL has no host.
/// - [`UrlValidationError::BlockedHost`] when the host is blocked.
pub fn validate_safe_url(raw_url: &str) -> Result<Url, UrlValidationError> {
    let parsed = match Url::parse(raw_url) {
        Ok(url) => url,
        Err(err) => return Err(UrlValidationError::InvalidUrl(err.to_string())),
    };

    let scheme = parsed.scheme();
    if scheme != "http" && scheme != "https" {
        return Err(UrlValidationError::DisallowedScheme(scheme.to_owned()));
    }

    let host = match parsed.host_str() {
        Some(h) => h,
        None => return Err(UrlValidationError::MissingHostname),
    };
    if is_blocked_host(host) {
        return Err(UrlValidationError::BlockedHost(host.to_owned()));
    }

    Ok(parsed)
}
/// Returns `true` if a URL host string names a target that must not be contacted.
///
/// Matching is case-insensitive. A host is blocked when it is:
/// - `localhost` or any `*.localhost` name, with or without one trailing dot;
/// - `metadata.google.internal`, with or without one trailing dot;
/// - an IP literal (IPv6 optionally wrapped in `[...]`) rejected by [`is_blocked_ip`].
fn is_blocked_host(host: &str) -> bool {
    let lowered = host.to_lowercase();

    // One trailing dot denotes the same fully-qualified name, so compare without it.
    let name: &str = lowered.strip_suffix('.').unwrap_or(lowered.as_str());
    let is_reserved_name = name == "localhost"
        || name.ends_with(".localhost")
        || name == "metadata.google.internal";
    if is_reserved_name {
        return true;
    }

    // IPv6 literals arrive bracketed (e.g. "[::1]"). Strip one pair only when both
    // brackets are present; otherwise parse the lowered host as-is.
    let literal: &str = match lowered.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']').unwrap_or(lowered.as_str()),
        None => lowered.as_str(),
    };
    match literal.parse::<IpAddr>() {
        Ok(ip) => is_blocked_ip(ip),
        Err(_) => false,
    }
}
/// Returns `true` when `ip` is an address that outbound requests must not reach.
///
/// IPv4 ranges blocked:
/// - "this network" `0.0.0.0/8`, loopback `127.0.0.0/8`
/// - RFC 1918 private: `10/8`, `172.16/12`, `192.168/16`
/// - link-local `169.254/16`, which includes the `169.254.169.254` metadata endpoint
/// - CGNAT shared space `100.64/10`
/// - IETF protocol assignments `192.0.0/24`
/// - documentation TEST-NET-1/2/3, benchmarking `198.18/15`
/// - multicast `224/4`, reserved `240/4` (including broadcast)
///
/// IPv6 ranges blocked:
/// - loopback, unspecified, multicast
/// - link-local `fe80::/10`, deprecated site-local `fec0::/10`
/// - unique-local `fc00::/7`, documentation `2001:db8::/32`
///
/// IPv6 addresses carrying an IPv4 address in their low 32 bits are judged by that
/// embedded address. This covers IPv4-mapped `::ffff:a.b.c.d`, IPv4-compatible
/// `::a.b.c.d`, and the NAT64 well-known prefix `64:ff9b::/96`. Without this, e.g.
/// `[::127.0.0.1]` would slip past the IPv4 list.
///
/// This only classifies an address; it performs no DNS resolution.
pub fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_ipv4(v4),
        IpAddr::V6(v6) => is_blocked_ipv6(v6),
    }
}

/// IPv4 half of [`is_blocked_ip`].
fn is_blocked_ipv4(ip: Ipv4Addr) -> bool {
    let [a, b, c, _] = ip.octets();
    match a {
        // 0.0.0.0/8 "this network" (RFC 1122), including the unspecified address.
        0 => true,
        // 10.0.0.0/8 private (RFC 1918).
        10 => true,
        // 100.64.0.0/10 CGNAT shared address space (RFC 6598).
        100 => (64..=127).contains(&b),
        // 127.0.0.0/8 loopback.
        127 => true,
        // 169.254.0.0/16 link-local, including the 169.254.169.254 metadata endpoint.
        169 => b == 254,
        // 172.16.0.0/12 private (RFC 1918).
        172 => (16..=31).contains(&b),
        // 192.0.0.0/24 IETF protocol assignments, 192.0.2.0/24 TEST-NET-1,
        // 192.168.0.0/16 private (RFC 1918).
        192 => (b == 0 && (c == 0 || c == 2)) || b == 168,
        // 198.18.0.0/15 benchmarking (RFC 2544), 198.51.100.0/24 TEST-NET-2.
        198 => b == 18 || b == 19 || (b == 51 && c == 100),
        // 203.0.113.0/24 TEST-NET-3.
        203 => b == 0 && c == 113,
        // 224.0.0.0/4 multicast and 240.0.0.0/4 reserved (incl. 255.255.255.255 broadcast):
        // never valid unicast HTTP targets.
        224..=255 => true,
        _ => false,
    }
}

/// IPv6 half of [`is_blocked_ip`].
fn is_blocked_ipv6(ip: Ipv6Addr) -> bool {
    if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() {
        return true;
    }
    let seg = ip.segments();
    // fe80::/10 link-local and fec0::/10 (deprecated) site-local.
    if matches!(seg[0] & 0xffc0, 0xfe80 | 0xfec0) {
        return true;
    }
    // fc00::/7 unique local.
    if seg[0] & 0xfe00 == 0xfc00 {
        return true;
    }
    // 2001:db8::/32 documentation, mirroring the IPv4 TEST-NET ranges above.
    if seg[0] == 0x2001 && seg[1] == 0x0db8 {
        return true;
    }
    // IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible (::a.b.c.d): classify by the
    // embedded IPv4 address so neither form can bypass the IPv4 blocklist.
    if let Some(v4) = ip.to_ipv4() {
        return is_blocked_ipv4(v4);
    }
    // 64:ff9b::/96 NAT64 well-known prefix (RFC 6052): a translator forwards these to
    // the IPv4 address held in the last 32 bits.
    if let [0x0064, 0xff9b, 0, 0, 0, 0, hi, lo] = seg {
        let [a, b] = hi.to_be_bytes();
        let [c, d] = lo.to_be_bytes();
        return is_blocked_ipv4(Ipv4Addr::new(a, b, c, d));
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `url` passes validation.
    #[track_caller]
    fn assert_accepted(url: &str) {
        let outcome = validate_safe_url(url);
        assert!(outcome.is_ok(), "expected {url} to be accepted, got {outcome:?}");
    }

    /// Panics unless `url` is rejected with `UrlValidationError::BlockedHost`.
    #[track_caller]
    fn assert_blocked(url: &str) {
        let err = validate_safe_url(url).expect_err("expected the URL to be rejected");
        assert!(
            matches!(err, UrlValidationError::BlockedHost(_)),
            "expected BlockedHost for {url}, got {err:?}"
        );
    }

    /// Panics unless `url` is rejected with `UrlValidationError::DisallowedScheme`.
    #[track_caller]
    fn assert_bad_scheme(url: &str) {
        let err = validate_safe_url(url).expect_err("expected the URL to be rejected");
        assert!(
            matches!(err, UrlValidationError::DisallowedScheme(_)),
            "expected DisallowedScheme for {url}, got {err:?}"
        );
    }

    /// Panics unless `url` is rejected for its scheme or for lacking a host.
    /// Opaque URLs such as `javascript:` / `data:` may legitimately hit either check.
    #[track_caller]
    fn assert_bad_scheme_or_hostless(url: &str) {
        let err = validate_safe_url(url).expect_err("expected the URL to be rejected");
        assert!(
            matches!(
                err,
                UrlValidationError::DisallowedScheme(_) | UrlValidationError::MissingHostname
            ),
            "expected DisallowedScheme or MissingHostname for {url}, got {err:?}"
        );
    }

    // --- accepted URLs ---

    #[test]
    fn accepts_https_public_url() {
        assert_accepted("https://mcp.example.com/v1/mcp");
    }

    #[test]
    fn accepts_http_public_url() {
        assert_accepted("http://mcp.example.com/v1/mcp");
    }

    #[test]
    fn accepts_url_with_port() {
        assert_accepted("https://mcp.example.com:8443/v1/mcp");
    }

    #[test]
    fn accepts_url_with_path_and_query() {
        assert_accepted("https://api.example.com/mcp?key=val");
    }

    #[test]
    fn accepts_172_32_x() {
        assert_accepted("http://172.32.0.1/path");
    }

    // --- scheme and parse failures ---

    #[test]
    fn rejects_ftp_scheme() {
        assert_bad_scheme("ftp://evil.com/file");
    }

    #[test]
    fn rejects_file_scheme() {
        assert_bad_scheme("file:///etc/passwd");
    }

    #[test]
    fn rejects_javascript_scheme() {
        assert_bad_scheme_or_hostless("javascript:alert(1)");
    }

    #[test]
    fn rejects_data_scheme() {
        assert_bad_scheme_or_hostless("data:text/plain,hello");
    }

    #[test]
    fn rejects_empty_string() {
        assert!(validate_safe_url("").is_err());
    }

    #[test]
    fn rejects_not_a_url() {
        assert!(validate_safe_url("not a url").is_err());
    }

    // --- blocked hostnames ---

    #[test]
    fn rejects_localhost() {
        assert_blocked("http://localhost/path");
    }

    #[test]
    fn rejects_localhost_with_port() {
        assert_blocked("http://localhost:8080/path");
    }

    #[test]
    fn rejects_subdomain_of_localhost() {
        assert_blocked("http://foo.localhost/path");
    }

    #[test]
    fn rejects_gce_metadata_hostname() {
        assert_blocked("http://metadata.google.internal/computeMetadata/v1/");
    }

    // --- blocked IPv4 literals ---

    #[test]
    fn rejects_127_0_0_1() {
        assert_blocked("http://127.0.0.1/path");
    }

    #[test]
    fn rejects_127_x_x_x() {
        assert_blocked("http://127.255.0.1/path");
    }

    #[test]
    fn rejects_10_x() {
        assert_blocked("http://10.0.0.1/path");
    }

    #[test]
    fn rejects_172_16_x() {
        assert_blocked("http://172.16.0.1/path");
    }

    #[test]
    fn rejects_172_31_x() {
        assert_blocked("http://172.31.255.255/path");
    }

    #[test]
    fn rejects_192_168_x() {
        assert_blocked("http://192.168.1.1/path");
    }

    #[test]
    fn rejects_link_local() {
        assert_blocked("http://169.254.1.1/path");
    }

    #[test]
    fn rejects_cloud_metadata_ip() {
        assert_blocked("http://169.254.169.254/latest/meta-data/");
    }

    #[test]
    fn rejects_unspecified_v4() {
        assert_blocked("http://0.0.0.0/path");
    }

    #[test]
    fn rejects_cgnat() {
        assert_blocked("http://100.64.0.1/path");
    }

    // --- blocked IPv6 literals ---

    #[test]
    fn rejects_ipv6_loopback() {
        assert_blocked("http://[::1]/path");
    }

    #[test]
    fn rejects_ipv6_unspecified() {
        assert_blocked("http://[::]/path");
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert_blocked("http://[fe80::1]/path");
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert_blocked("http://[fd00::1]/path");
    }

    #[test]
    fn rejects_ipv4_mapped_ipv6_private() {
        assert_blocked("http://[::ffff:127.0.0.1]/path");
    }

    #[test]
    fn rejects_ipv4_mapped_ipv6_metadata() {
        assert_blocked("http://[::ffff:169.254.169.254]/latest/meta-data/");
    }

    // --- error rendering ---

    #[test]
    fn error_display_messages() {
        let blocked_text = UrlValidationError::BlockedHost("localhost".into()).to_string();
        assert!(blocked_text.contains("private/internal"));

        let scheme_text = UrlValidationError::DisallowedScheme("ftp".into()).to_string();
        assert!(scheme_text.contains("http or https"));
    }
}