use idna::domain_to_ascii;
use once_cell::sync::Lazy;
use openssl::ec::EcKey;
use openssl::pkey::PKey;
use openssl::rsa::Rsa;
use regex::Regex;
use std::fs;
use std::path::Path;
/// Syntactic check for an ASCII DNS name.
///
/// Accepts an optional leading `*.` wildcard followed by one or more
/// dot-separated labels. Each label is 1-63 ASCII letters, digits or hyphens,
/// and must begin and end with a letter or digit. A trailing dot is not
/// accepted. The overall 253-byte limit is enforced separately by
/// `validate_domain_name`.
static DOMAIN_PATTERN: Lazy<Regex> = Lazy::new(|| {
    // The pattern is a constant literal, so `unwrap` can only fire if the literal itself is invalid.
    Regex::new(r"^(\*\.)?([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?$").unwrap()
});
/// Errors produced by the configuration validators in this module.
///
/// Every variant carries the offending value (a path, a domain, or a free-form
/// description), which is appended verbatim to the `Display` message.
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// Nothing exists at the given path.
    FileNotFound(String),
    /// The path exists but is not a readable regular file.
    FileNotReadable(String),
    /// A domain, host, upstream, CIDR, threshold or rate-limit setting is invalid.
    InvalidDomain(String),
    /// A certificate file does not contain the expected PEM markers.
    InvalidCertFormat(String),
    /// A private-key file lacks the expected PEM markers or failed to parse.
    InvalidKeyFormat(String),
    /// A private key is smaller than the configured minimum size.
    WeakKey(String),
    /// A path contains `..` or `~`.
    SuspiciousPath(String),
    /// A domain name is longer than 253 bytes.
    DomainTooLong(String),
    /// A non-ASCII domain whose IDNA conversion produced a punycode (`xn--`) label.
    HomographAttack(String),
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ValidationError::FileNotFound(detail) => write!(f, "File not found: {}", detail),
            ValidationError::FileNotReadable(detail) => {
                write!(f, "File not readable: {}", detail)
            }
            ValidationError::InvalidDomain(detail) => {
                write!(f, "Invalid domain name: {}", detail)
            }
            ValidationError::InvalidCertFormat(detail) => {
                write!(f, "Invalid certificate format (must be PEM): {}", detail)
            }
            ValidationError::InvalidKeyFormat(detail) => {
                write!(f, "Invalid key format (must be PEM): {}", detail)
            }
            ValidationError::WeakKey(detail) => {
                write!(f, "Weak cryptographic key: {}", detail)
            }
            ValidationError::SuspiciousPath(detail) => {
                write!(f, "Suspicious path (potential traversal): {}", detail)
            }
            ValidationError::DomainTooLong(detail) => {
                write!(f, "Domain name too long (max 253 chars): {}", detail)
            }
            ValidationError::HomographAttack(detail) => write!(
                f,
                "Domain contains suspicious Unicode characters (potential homograph attack): {}",
                detail
            ),
        }
    }
}

impl std::error::Error for ValidationError {}

/// Result type returned by every validator in this module.
pub type ValidationResult<T> = Result<T, ValidationError>;
/// Checks that `path` names an existing regular file that this process can open for reading.
///
/// `_name` is a human-readable label for the file's role (e.g. `"certificate"`).
/// It is currently unused and kept only for API compatibility.
///
/// # Errors
/// - [`ValidationError::SuspiciousPath`] if `path` contains `..` or `~`. This is
///   a coarse traversal / home-expansion guard, so it also rejects names like `a..pem`.
/// - [`ValidationError::FileNotFound`] if nothing exists at `path`.
/// - [`ValidationError::FileNotReadable`] if `path` is not a regular file or
///   cannot be opened for reading.
pub fn validate_file_path(path: &str, _name: &str) -> ValidationResult<()> {
    if path.contains("..") || path.contains('~') {
        return Err(ValidationError::SuspiciousPath(path.to_string()));
    }
    let path_obj = Path::new(path);
    if !path_obj.exists() {
        return Err(ValidationError::FileNotFound(path.to_string()));
    }
    if !path_obj.is_file() {
        return Err(ValidationError::FileNotReadable(format!(
            "{} is not a file",
            path
        )));
    }
    // Metadata says nothing about whether *this* process may read the file.
    // Actually opening it is the portable way to confirm read access.
    // The handle is dropped immediately.
    fs::File::open(path_obj).map_err(|_| ValidationError::FileNotReadable(path.to_string()))?;
    Ok(())
}
/// Validates that `path` is a readable file holding at least one PEM certificate block.
///
/// Only the PEM framing is checked: a `-----BEGIN CERTIFICATE-----` marker
/// followed later in the file by a `-----END CERTIFICATE-----` marker. The
/// certificate contents themselves are not parsed.
///
/// # Errors
/// - Any error from [`validate_file_path`].
/// - [`ValidationError::FileNotReadable`] if the file cannot be read as UTF-8 text.
/// - [`ValidationError::InvalidCertFormat`] if the BEGIN marker is missing, or
///   if no END marker follows it.
pub fn validate_certificate_file(path: &str) -> ValidationResult<()> {
    const BEGIN_MARKER: &str = "-----BEGIN CERTIFICATE-----";
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    validate_file_path(path, "certificate")?;
    let contents =
        fs::read_to_string(path).map_err(|_| ValidationError::FileNotReadable(path.to_string()))?;

    let begin = contents
        .find(BEGIN_MARKER)
        .ok_or_else(|| ValidationError::InvalidCertFormat(path.to_string()))?;
    // The END marker must come *after* a BEGIN marker. An END that appears only
    // earlier in the file does not close a certificate block. The slice start is
    // a char boundary because the marker is ASCII.
    if !contents[begin + BEGIN_MARKER.len()..].contains(END_MARKER) {
        return Err(ValidationError::InvalidCertFormat(path.to_string()));
    }
    Ok(())
}
/// Smallest RSA key size, in bits, that the key validators accept.
const MIN_RSA_KEY_BITS: u32 = 2048;
/// Smallest EC key size, in bits, that the key validators accept.
/// For SEC1 (`EC PRIVATE KEY`) files it is compared with the curve group's degree;
/// for PKCS#8 files it is compared with `PKey::bits()`.
/// It is signed because the SEC1 path casts the degree to `i32` before comparing.
const MIN_EC_KEY_BITS: i32 = 256;
/// Validates that `path` is a readable PEM private key of acceptable strength.
///
/// The file must contain one of the recognised PEM private-key BEGIN markers
/// (PKCS#1 RSA, PKCS#8, encrypted PKCS#8, or SEC1 EC). Its key size is then
/// checked by `validate_key_strength`.
///
/// # Errors
/// - Any error from [`validate_file_path`].
/// - [`ValidationError::FileNotReadable`] if the file cannot be read as UTF-8 text.
/// - [`ValidationError::InvalidKeyFormat`] if no recognised marker is present or
///   the key fails to parse.
/// - [`ValidationError::WeakKey`] if the key is below the minimum size.
pub fn validate_private_key_file(path: &str) -> ValidationResult<()> {
    const RECOGNISED_MARKERS: [&str; 4] = [
        "-----BEGIN RSA PRIVATE KEY-----",
        "-----BEGIN PRIVATE KEY-----",
        "-----BEGIN ENCRYPTED PRIVATE KEY-----",
        "-----BEGIN EC PRIVATE KEY-----",
    ];

    validate_file_path(path, "private key")?;
    let pem = match fs::read_to_string(path) {
        Ok(text) => text,
        Err(_) => return Err(ValidationError::FileNotReadable(path.to_string())),
    };

    let has_known_marker = RECOGNISED_MARKERS.iter().any(|marker| pem.contains(marker));
    if !has_known_marker {
        return Err(ValidationError::InvalidKeyFormat(path.to_string()));
    }
    validate_key_strength(&pem, path)
}
/// Routes a PEM private key to the strength check for its encoding.
///
/// Encrypted keys cannot be parsed without their passphrase, so their size
/// cannot be checked. They are accepted with a warning rather than being
/// misreported as malformed. This covers both kinds of encrypted key:
/// - PKCS#8 (`BEGIN ENCRYPTED PRIVATE KEY`);
/// - legacy OpenSSL-encrypted RSA/EC PEM (a `Proc-Type: 4,ENCRYPTED` header
///   inside an RSA or EC block).
///
/// # Errors
/// - [`ValidationError::WeakKey`] or [`ValidationError::InvalidKeyFormat`] from
///   the format-specific checkers.
/// - `InvalidKeyFormat` if no recognised BEGIN marker is present.
fn validate_key_strength(pem_contents: &str, path: &str) -> ValidationResult<()> {
    let pem_bytes = pem_contents.as_bytes();
    let warn_unverifiable = || {
        tracing::warn!(
            path = %path,
            "Cannot validate encrypted private key strength - ensure key meets minimum requirements"
        );
    };

    // Legacy PEM encryption keeps the RSA/EC BEGIN marker and adds this header.
    // Check it before dispatching, so the key is not handed to a parser that has
    // no passphrase to decrypt it with.
    if pem_contents.contains("Proc-Type: 4,ENCRYPTED") {
        warn_unverifiable();
        return Ok(());
    }
    if pem_contents.contains("-----BEGIN RSA PRIVATE KEY-----") {
        return validate_rsa_key_from_pem(pem_bytes, path);
    }
    if pem_contents.contains("-----BEGIN EC PRIVATE KEY-----") {
        return validate_ec_key_from_pem(pem_bytes, path);
    }
    if pem_contents.contains("-----BEGIN PRIVATE KEY-----") {
        return validate_pkcs8_key(pem_bytes, path);
    }
    if pem_contents.contains("-----BEGIN ENCRYPTED PRIVATE KEY-----") {
        warn_unverifiable();
        return Ok(());
    }
    Err(ValidationError::InvalidKeyFormat(path.to_string()))
}
/// Parses a PKCS#1 (`BEGIN RSA PRIVATE KEY`) PEM key and checks its modulus size.
///
/// # Errors
/// - [`ValidationError::InvalidKeyFormat`] if OpenSSL cannot parse the key.
/// - [`ValidationError::WeakKey`] if the modulus is shorter than `MIN_RSA_KEY_BITS`.
fn validate_rsa_key_from_pem(pem_bytes: &[u8], path: &str) -> ValidationResult<()> {
    let rsa = Rsa::private_key_from_pem(pem_bytes).map_err(|e| {
        ValidationError::InvalidKeyFormat(format!("{}: failed to parse RSA key: {}", path, e))
    })?;
    // Use the exact bit length of the modulus. `size() * 8` rounds up to whole
    // bytes, which would report e.g. a 2047-bit modulus as 2048 bits. The exact
    // length also matches what `PKey::bits()` reports on the PKCS#8 path.
    // `num_bits` is never negative for a parsed key; `max(0)` keeps the cast
    // well-defined regardless.
    let bits = rsa.n().num_bits().max(0) as u32;
    if bits < MIN_RSA_KEY_BITS {
        return Err(ValidationError::WeakKey(format!(
            "RSA key in '{}' is {} bits, minimum required is {} bits",
            path, bits, MIN_RSA_KEY_BITS
        )));
    }
    Ok(())
}
/// Parses a SEC1 (`BEGIN EC PRIVATE KEY`) PEM key and checks its curve size.
///
/// The size used is the degree of the key's curve group.
///
/// # Errors
/// - [`ValidationError::InvalidKeyFormat`] if OpenSSL cannot parse the key.
/// - [`ValidationError::WeakKey`] if the degree is below `MIN_EC_KEY_BITS`.
fn validate_ec_key_from_pem(pem_bytes: &[u8], path: &str) -> ValidationResult<()> {
    let ec_key = EcKey::private_key_from_pem(pem_bytes).map_err(|err| {
        ValidationError::InvalidKeyFormat(format!("{}: failed to parse EC key: {}", path, err))
    })?;

    let degree = ec_key.group().degree() as i32;
    if degree >= MIN_EC_KEY_BITS {
        return Ok(());
    }
    Err(ValidationError::WeakKey(format!(
        "EC key in '{}' is {} bits, minimum required is {} bits",
        path, degree, MIN_EC_KEY_BITS
    )))
}
/// Parses an unencrypted PKCS#8 (`BEGIN PRIVATE KEY`) PEM key and checks its size.
///
/// RSA keys are held to `MIN_RSA_KEY_BITS` and EC keys to `MIN_EC_KEY_BITS`,
/// both measured by `PKey::bits()`. Keys of any other algorithm are accepted
/// without a size check.
///
/// # Errors
/// - [`ValidationError::InvalidKeyFormat`] if OpenSSL cannot parse the key.
/// - [`ValidationError::WeakKey`] if an RSA or EC key is below its minimum.
fn validate_pkcs8_key(pem_bytes: &[u8], path: &str) -> ValidationResult<()> {
    let key = PKey::private_key_from_pem(pem_bytes).map_err(|err| {
        ValidationError::InvalidKeyFormat(format!(
            "{}: failed to parse PKCS#8 key: {}",
            path, err
        ))
    })?;
    let key_bits = key.bits();

    if key.rsa().is_ok() {
        if key_bits >= MIN_RSA_KEY_BITS {
            return Ok(());
        }
        return Err(ValidationError::WeakKey(format!(
            "RSA key in '{}' is {} bits, minimum required is {} bits",
            path, key_bits, MIN_RSA_KEY_BITS
        )));
    }

    let is_weak_ec = key.ec_key().is_ok() && key_bits < MIN_EC_KEY_BITS as u32;
    if is_weak_ec {
        return Err(ValidationError::WeakKey(format!(
            "EC key in '{}' is {} bits, minimum required is {} bits",
            path, key_bits, MIN_EC_KEY_BITS
        )));
    }
    Ok(())
}
/// Validates a DNS name, optionally with a leading `*.` wildcard.
///
/// Checks are applied in this order:
/// 1. The name is not empty and is at most 253 bytes.
/// 2. Non-ASCII input is converted with IDNA. Any result containing a punycode
///    (`xn--`) label is rejected as a potential homograph.
/// 3. The original string must match `DOMAIN_PATTERN`.
/// 4. Every dot-separated label is at most 63 bytes.
///
/// # Errors
/// - [`ValidationError::DomainTooLong`] for names over 253 bytes.
/// - [`ValidationError::HomographAttack`] for IDNs that encode to punycode.
/// - [`ValidationError::InvalidDomain`] for every other failure.
pub fn validate_domain_name(domain: &str) -> ValidationResult<()> {
    if domain.is_empty() {
        return Err(ValidationError::InvalidDomain("empty domain".to_string()));
    }
    if domain.len() > 253 {
        return Err(ValidationError::DomainTooLong(domain.to_string()));
    }

    if !domain.is_ascii() {
        let ascii_form = domain_to_ascii(domain).map_err(|_| {
            ValidationError::InvalidDomain(format!("{} (contains invalid Unicode)", domain))
        })?;
        if ascii_form.contains("xn--") {
            return Err(ValidationError::HomographAttack(format!(
                "{} (punycode: {})",
                domain, ascii_form
            )));
        }
    }

    if !DOMAIN_PATTERN.is_match(domain) {
        return Err(ValidationError::InvalidDomain(domain.to_string()));
    }

    match domain.split('.').find(|segment| segment.len() > 63) {
        Some(oversized) => Err(ValidationError::InvalidDomain(format!(
            "label '{}' exceeds 63 characters",
            oversized
        ))),
        None => Ok(()),
    }
}
/// Validates the TLS material referenced by the configuration.
///
/// The default certificate and key are checked only when their paths are
/// non-empty. Each `(domain, cert_path, key_path)` entry in `per_domain_certs`
/// is checked in order: domain name first, then certificate, then key.
///
/// # Errors
/// Returns the first error reported by [`validate_certificate_file`],
/// [`validate_private_key_file`] or [`validate_domain_name`].
pub fn validate_tls_config(
    cert_path: &str,
    key_path: &str,
    per_domain_certs: &[(String, String, String)],
) -> ValidationResult<()> {
    if !cert_path.is_empty() {
        validate_certificate_file(cert_path)?;
    }
    if !key_path.is_empty() {
        validate_private_key_file(key_path)?;
    }
    per_domain_certs
        .iter()
        .try_for_each(|(domain_name, cert_file, key_file)| {
            validate_domain_name(domain_name)?;
            validate_certificate_file(cert_file)?;
            validate_private_key_file(key_file)
        })
}
/// Validates a hostname using exactly the same rules as [`validate_domain_name`].
///
/// # Errors
/// Whatever [`validate_domain_name`] returns for `hostname`.
pub fn validate_hostname(hostname: &str) -> ValidationResult<()> {
    validate_domain_name(hostname)?;
    Ok(())
}
/// Error describing a target rejected by SSRF protection.
///
/// The wrapped string is the human-readable reason. `Display` prefixes it with
/// `"SSRF protection: "`.
#[derive(Debug, Clone)]
pub struct SsrfError(pub String);

impl std::error::Error for SsrfError {}

impl std::fmt::Display for SsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(reason) = self;
        write!(f, "SSRF protection: {}", reason)
    }
}
/// Returns `true` if `ip` must not be used as an upstream target.
///
/// Rejected addresses are those that are not routable public unicast:
/// loopback, private, link-local, shared (CGNAT), reserved, documentation,
/// benchmarking, multicast or unspecified.
///
/// Some IPv6 forms embed an IPv4 address:
/// - IPv4-mapped `::ffff:a.b.c.d`;
/// - deprecated IPv4-compatible `::a.b.c.d`;
/// - NAT64 `64:ff9b::/96`;
/// - 6to4 `2002::/16`.
///
/// These are judged by the embedded IPv4 address, so they cannot be used to
/// reach an internal IPv4 target through the IPv6 path.
fn is_private_or_internal_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(ipv4) => is_internal_ipv4(ipv4),
        std::net::IpAddr::V6(ipv6) => is_internal_ipv6(ipv6),
    }
}

/// IPv4 half of `is_private_or_internal_ip`.
fn is_internal_ipv4(ipv4: &std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = ipv4.octets();
    ipv4.is_loopback() // 127.0.0.0/8
        || ipv4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
        || ipv4.is_link_local() // 169.254.0.0/16 (includes cloud metadata endpoints)
        || ipv4.is_multicast() // 224.0.0.0/4
        || a == 0 // 0.0.0.0/8 "this network", includes 0.0.0.0
        || a >= 240 // 240.0.0.0/4 reserved, includes 255.255.255.255
        || (a == 100 && (b & 0xC0) == 64) // 100.64.0.0/10 shared address space (RFC 6598)
        || (a == 192 && b == 0 && c == 0) // 192.0.0.0/24 IETF protocol assignments
        || (a == 198 && (b & 0xFE) == 18) // 198.18.0.0/15 benchmarking
        || (a == 192 && b == 0 && c == 2) // 192.0.2.0/24 TEST-NET-1
        || (a == 198 && b == 51 && c == 100) // 198.51.100.0/24 TEST-NET-2
        || (a == 203 && b == 0 && c == 113) // 203.0.113.0/24 TEST-NET-3
}

/// IPv6 half of `is_private_or_internal_ip`.
fn is_internal_ipv6(ipv6: &std::net::Ipv6Addr) -> bool {
    if ipv6.is_loopback() || ipv6.is_unspecified() || ipv6.is_multicast() {
        return true;
    }
    let seg = ipv6.segments();
    let reserved_prefix = (seg[0] & 0xfe00) == 0xfc00 // fc00::/7 unique local
        || (seg[0] & 0xffc0) == 0xfe80 // fe80::/10 link-local
        || (seg[0] & 0xffc0) == 0xfec0 // fec0::/10 deprecated site-local
        || (seg[0] == 0x2001 && seg[1] == 0x0db8); // 2001:db8::/32 documentation
    if reserved_prefix {
        return true;
    }
    match embedded_ipv4(&seg) {
        Some(inner) => is_internal_ipv4(&inner),
        None => false,
    }
}

/// Extracts the IPv4 address carried by IPv4-mapped, IPv4-compatible, NAT64
/// (`64:ff9b::/96`) or 6to4 (`2002::/16`) IPv6 addresses, if `seg` is one of them.
fn embedded_ipv4(seg: &[u16; 8]) -> Option<std::net::Ipv4Addr> {
    let join = |hi: u16, lo: u16| std::net::Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo));
    match seg {
        [0, 0, 0, 0, 0, 0xffff, hi, lo]
        | [0, 0, 0, 0, 0, 0, hi, lo]
        | [0x64, 0xff9b, 0, 0, 0, 0, hi, lo] => Some(join(*hi, *lo)),
        [0x2002, hi, lo, ..] => Some(join(*hi, *lo)),
        _ => None,
    }
}
/// Validates an upstream target of the form `host:port` and applies SSRF guards.
///
/// `host` may be any of:
/// - a DNS name;
/// - an IPv4 literal;
/// - a bracketed IPv6 literal (`[2001:db8::1]:443`).
///
/// `port` must be a non-zero `u16`.
///
/// Host names are *not* resolved here. A name that resolves to an internal
/// address is therefore not caught by this function.
///
/// # Errors
/// Every failure is reported as [`ValidationError::InvalidDomain`]:
/// - the input is empty or not `host:port`, including unbracketed IPv6 and
///   bracketed hosts that are not IPv6 literals;
/// - an IP literal is classified as internal by `is_private_or_internal_ip`;
/// - a host name fails `validate_domain_name` or is a wildcard (`*.…`);
/// - a host name is `localhost` / `*.localhost`, names reserved for loopback
///   (RFC 6761);
/// - the port is not a non-zero `u16`.
pub fn validate_upstream(upstream: &str) -> ValidationResult<()> {
    if upstream.is_empty() {
        return Err(ValidationError::InvalidDomain("empty upstream".to_string()));
    }
    let malformed = || {
        ValidationError::InvalidDomain(format!(
            "upstream must be host:port, got {}",
            upstream
        ))
    };

    // Split on the LAST ':' so that a bracketed IPv6 literal keeps its own colons.
    let colon = upstream.rfind(':').ok_or_else(malformed)?;
    let host_part = &upstream[..colon];
    let port_str = &upstream[colon + 1..];

    let bracketed = host_part
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'));
    let host = bracketed.unwrap_or(host_part);
    let parsed_ip = host.parse::<std::net::IpAddr>().ok();

    if bracketed.is_some() {
        // Brackets are only meaningful around an IPv6 literal.
        if !matches!(parsed_ip, Some(std::net::IpAddr::V6(_))) {
            return Err(malformed());
        }
    } else if host.contains(':') {
        // An unbracketed ':' in the host is ambiguous with the port separator.
        return Err(malformed());
    }

    match parsed_ip {
        Some(ip) => {
            if is_private_or_internal_ip(&ip) {
                return Err(ValidationError::InvalidDomain(format!(
                    "SSRF protection: upstream IP {} is private/internal and not allowed",
                    ip
                )));
            }
        }
        None => {
            // `validate_domain_name` accepts `*.` wildcards, which are valid for
            // certificates but cannot be connected to.
            if host.starts_with('*') || validate_domain_name(host).is_err() {
                return Err(ValidationError::InvalidDomain(format!(
                    "invalid host in upstream: {}",
                    host
                )));
            }
            let lowered = host.to_ascii_lowercase();
            if lowered == "localhost" || lowered.ends_with(".localhost") {
                return Err(ValidationError::InvalidDomain(format!(
                    "SSRF protection: upstream host {} is a loopback name and not allowed",
                    host
                )));
            }
        }
    }

    match port_str.parse::<u16>() {
        Ok(p) if p > 0 => Ok(()),
        _ => Err(ValidationError::InvalidDomain(format!(
            "invalid port in upstream: {}",
            port_str
        ))),
    }
}
/// Validates CIDR notation `address/prefix`.
///
/// The address may be IPv4 or IPv6. IPv6 forms containing a dotted quad, such
/// as `::ffff:10.0.0.0/104`, count as IPv6. The prefix must be a plain decimal
/// number no larger than 32 for IPv4 or 128 for IPv6.
///
/// # Errors
/// [`ValidationError::InvalidDomain`] describing the first problem found:
/// - the input is not exactly one `/`-separated pair;
/// - the address does not parse;
/// - the prefix is not decimal digits or is too large for the address family.
pub fn validate_cidr(cidr: &str) -> ValidationResult<()> {
    let parts: Vec<&str> = cidr.split('/').collect();
    if parts.len() != 2 {
        return Err(ValidationError::InvalidDomain(format!(
            "invalid CIDR format: {}",
            cidr
        )));
    }
    let ip_str = parts[0];
    let prefix_str = parts[1];

    // Take the family from the parsed address. Testing for '.' misclassifies
    // IPv6 literals with an embedded dotted quad (e.g. `::ffff:1.2.3.4`).
    let is_ipv4 = match ip_str.parse::<std::net::IpAddr>() {
        Ok(addr) => addr.is_ipv4(),
        Err(_) => {
            return Err(ValidationError::InvalidDomain(format!(
                "invalid IP in CIDR: {}",
                ip_str
            )))
        }
    };

    // `u8::from_str` also accepts a leading '+', which is not CIDR syntax.
    let prefix = if !prefix_str.is_empty() && prefix_str.bytes().all(|b| b.is_ascii_digit()) {
        prefix_str.parse::<u8>().ok()
    } else {
        None
    };

    match prefix {
        None => Err(ValidationError::InvalidDomain(format!(
            "invalid prefix in CIDR: {}",
            prefix_str
        ))),
        Some(p) if is_ipv4 && p > 32 => Err(ValidationError::InvalidDomain(format!(
            "IPv4 prefix too large: {}",
            p
        ))),
        Some(p) if !is_ipv4 && p > 128 => Err(ValidationError::InvalidDomain(format!(
            "IPv6 prefix too large: {}",
            p
        ))),
        Some(_) => Ok(()),
    }
}
/// Validates that a WAF score threshold lies in the inclusive range `0..=100`.
///
/// NaN is outside every range and is therefore rejected.
///
/// # Errors
/// [`ValidationError::InvalidDomain`] when `threshold` is out of range.
pub fn validate_waf_threshold(threshold: f64) -> ValidationResult<()> {
    match threshold {
        t if (0.0..=100.0).contains(&t) => Ok(()),
        t => Err(ValidationError::InvalidDomain(format!(
            "WAF threshold must be 0-100, got {}",
            t
        ))),
    }
}
/// Validates a rate limit of `requests` per `window`; both must be non-zero.
///
/// # Errors
/// [`ValidationError::InvalidDomain`] naming the first zero field, with
/// `requests` checked before `window`.
pub fn validate_rate_limit(requests: u64, window: u64) -> ValidationResult<()> {
    let complaint = match (requests, window) {
        (0, _) => "rate limit requests must be > 0",
        (_, 0) => "rate limit window must be > 0",
        _ => return Ok(()),
    };
    Err(ValidationError::InvalidDomain(complaint.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_domain_validation_valid() {
        assert!(validate_domain_name("example.com").is_ok());
        assert!(validate_domain_name("sub.example.com").is_ok());
        assert!(validate_domain_name("*.example.com").is_ok());
        assert!(validate_domain_name("my-domain.co.uk").is_ok());
        assert!(validate_domain_name("123.456.789").is_ok());
    }

    #[test]
    fn test_domain_validation_invalid() {
        assert!(validate_domain_name("").is_err());
        assert!(validate_domain_name("-invalid.com").is_err());
        assert!(validate_domain_name("invalid-.com").is_err());
        assert!(validate_domain_name("invalid..com").is_err());
        // A 64-byte label exceeds the 63-byte DNS label limit.
        assert!(validate_domain_name(&("a".repeat(64) + ".com")).is_err());
    }

    #[test]
    fn test_domain_validation_max_length() {
        let long_domain = "a".repeat(254);
        assert!(validate_domain_name(&long_domain).is_err());

        // 63 + 1 + 63 + 1 + 63 + 1 + 61 = 253 bytes: the longest valid name.
        let max_domain = format!("{0}.{0}.{0}.{1}", "a".repeat(63), "a".repeat(61));
        assert_eq!(max_domain.len(), 253);
        assert!(
            validate_domain_name(&max_domain).is_ok(),
            "253-byte name made of valid labels should be accepted"
        );

        // Same total length, but as a single label it breaks the 63-byte label limit.
        assert!(validate_domain_name(&"a".repeat(253)).is_err());
    }

    #[test]
    fn test_homograph_attack_cyrillic_a() {
        // The leading 'а' is CYRILLIC SMALL LETTER A (U+0430), not Latin 'a'.
        let homograph = "аpple.com";
        let result = validate_domain_name(homograph);
        assert!(result.is_err(), "Homograph attack should be rejected");
        match result.unwrap_err() {
            ValidationError::HomographAttack(msg) => {
                assert!(msg.contains("xn--"), "Should show punycode: {}", msg);
            }
            e => panic!("Expected HomographAttack error, got {:?}", e),
        }
    }

    #[test]
    fn test_homograph_attack_cyrillic_o() {
        // Both 'о' characters are CYRILLIC SMALL LETTER O (U+043E).
        let homograph = "gооgle.com";
        let result = validate_domain_name(homograph);
        assert!(result.is_err(), "Homograph attack should be rejected");
        match result.unwrap_err() {
            ValidationError::HomographAttack(_) => {}
            e => panic!("Expected HomographAttack error, got {:?}", e),
        }
    }

    #[test]
    fn test_valid_ascii_domain_not_flagged() {
        assert!(validate_domain_name("apple.com").is_ok());
        assert!(validate_domain_name("google.com").is_ok());
        assert!(validate_domain_name("example.org").is_ok());
    }

    #[test]
    fn test_path_traversal_detection() {
        assert!(validate_file_path("/etc/passwd/../shadow", "test").is_err());
        assert!(validate_file_path("~/.ssh/id_rsa", "test").is_err());
    }

    #[test]
    fn test_certificate_file_validation() {
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(
            temp_file,
            "-----BEGIN CERTIFICATE-----\ndata\n-----END CERTIFICATE-----"
        )
        .unwrap();
        let path = temp_file.path().to_str().unwrap();
        assert!(validate_certificate_file(path).is_ok());

        let mut invalid_cert = NamedTempFile::new().unwrap();
        writeln!(invalid_cert, "-----BEGIN CERTIFICATE-----\ndata").unwrap();
        let path = invalid_cert.path().to_str().unwrap();
        assert!(validate_certificate_file(path).is_err());
    }

    #[test]
    fn test_private_key_invalid_format() {
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(
            temp_file,
            "-----BEGIN PRIVATE KEY-----\nnotvalidbase64!!!\n-----END PRIVATE KEY-----"
        )
        .unwrap();
        let path = temp_file.path().to_str().unwrap();
        let result = validate_private_key_file(path);
        assert!(result.is_err());
        match result.unwrap_err() {
            ValidationError::InvalidKeyFormat(_) => {}
            e => panic!("Expected InvalidKeyFormat, got {:?}", e),
        }
    }

    #[test]
    fn test_private_key_missing_markers() {
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "some random key data").unwrap();
        let path = temp_file.path().to_str().unwrap();
        let result = validate_private_key_file(path);
        assert!(result.is_err());
        match result.unwrap_err() {
            ValidationError::InvalidKeyFormat(_) => {}
            e => panic!("Expected InvalidKeyFormat, got {:?}", e),
        }
    }

    #[test]
    fn test_weak_rsa_key_rejected() {
        let weak_key = r#"-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAL6Hn9PKjkJMjH5JZvYh9zqn0f3TBB3wQmOzg0wBuRbv1u3oK0pP
lKHmC4+Y2q0Y2g5n8BaP9dUTNg8OPM0OwzMCAwEAAQJAI6H7IHmY/xPqJZhL1UBy
KQ4yW7Yf0lBmCH2JNtGJxjT9VYaW1H2h7rWdJHgUJsJklO7rXI0Y2BQzXYB0dZT9
GQIhAOrhJmGLsFyAJp0EInMWOsRmR5UHgU3ooTHcNvW8F1VVAiEAz0xKX8ILIQAJ
OqSXpCkSXlPjfYIoIH8qkRRoJ2BHIYcCIQCMGJVhJPB8lYBQVH8WdWNYXAVX3pYt
cEH5f0QrKZhC0QIgG3fwBZGa0QF9WKg9sGJQENk9bPJQRDFH3GPVY/4SJfMCIGGq
2xWoYb0sCjBMr7pFjLGf3wM8nDwLK8j7VT5nYvRN
-----END RSA PRIVATE KEY-----"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", weak_key).unwrap();
        let path = temp_file.path().to_str().unwrap();
        let result = validate_private_key_file(path);
        assert!(result.is_err(), "Weak RSA key should be rejected");
        match result.unwrap_err() {
            ValidationError::WeakKey(msg) => {
                assert!(
                    msg.contains("512 bits"),
                    "Error should mention key size: {}",
                    msg
                );
                assert!(
                    msg.contains("2048"),
                    "Error should mention minimum: {}",
                    msg
                );
            }
            e => panic!("Expected WeakKey error, got {:?}", e),
        }
    }

    #[test]
    fn test_strong_rsa_key_accepted() {
        let strong_key = r#"-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAwUMqt8OB0VTt4K4oB+K7H4+zBZ5N3UqTMdRHbWbfEvqvpOIa
1i3aHxBwP0R8/CUlWqZmUFc6lXAXk9+0+4+h3L3mJbQRCOBY3fHj1eFX8pEtT8X9
NvN4MzI7TpXQJH9FLWvJ9zq9qfb9QCGzVgqnMGdFvxp8R2DwVk1mMX1qMHLEm2pR
0gRITq3+r3k5nxq8wGrXZYK8lUjXzwYJZCrZrJLHBVp6cZF8wDqN3lqIKLm3YqmQ
lqSu7e3DY5VVzCt3p3Rl3T7g8yDLqyGvvRTz9M3lbgLnLF9Jg3cYp2VmSVzXyRPz
X3qLR7qN3lN7qG3mN7qG3mN7qG3mN7qG3mN7qQIDAQABAoIBAC3YI7K5T5G8K5lE
g3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8Fv
K7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7
PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9
T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9
Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lE
g3kLvLQBAoGBAO7k7c3mPpU8N3F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9
N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8
K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7AoGBANBvN8F9
Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lE
g3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8Fv
K7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvAoGATT5G8K5lEg3kLvLT7PzC9
N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8
K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0q
N8FvK7L8N3F9T5G8K5lEg3kLvLT7AoGAFvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9
Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lE
g3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8Fv
K7L8N3F9T5G8K5lEg3kLvLT7AoGAQx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9
N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8
K5lEg3kLvLT7PzC9N8F9Qx0qN8FvK7L8N3F9T5G8K5lEg3kLvLT7PzC9N8F9Qx0q
N8FvK7L8N3F9T5G8K5lEg3kLvLT7
-----END RSA PRIVATE KEY-----"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", strong_key).unwrap();
        let path = temp_file.path().to_str().unwrap();
        let result = validate_private_key_file(path);
        // The body of this fixture looks like placeholder data, so it may fail to
        // parse. What must never happen is a 2048-bit-sized key being flagged weak.
        assert!(
            !matches!(result, Err(ValidationError::WeakKey(_))),
            "2048-bit RSA key must not be rejected as weak: {:?}",
            result
        );
    }

    #[test]
    fn test_real_server_key_accepted() {
        let key_path = concat!(env!("CARGO_MANIFEST_DIR"), "/certs/server.key");
        if std::path::Path::new(key_path).exists() {
            let result = validate_private_key_file(key_path);
            assert!(
                result.is_ok(),
                "Real 2048-bit key should be accepted: {:?}",
                result.err()
            );
        }
    }

    #[test]
    fn test_encrypted_private_key_accepted() {
        let encrypted_key = r#"-----BEGIN ENCRYPTED PRIVATE KEY-----
MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQI3+FrUBMHiJ8CAggA
MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECBd7qQlMKDdJBIIEyInvalidData
-----END ENCRYPTED PRIVATE KEY-----"#;
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", encrypted_key).unwrap();
        let path = temp_file.path().to_str().unwrap();
        let result = validate_private_key_file(path);
        assert!(
            result.is_ok(),
            "Encrypted key should be accepted: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_file_not_found() {
        assert!(validate_file_path("/nonexistent/path/to/file.txt", "test").is_err());
    }

    #[test]
    fn test_ssrf_loopback_blocked() {
        assert!(validate_upstream("127.0.0.1:8080").is_err());
        assert!(validate_upstream("127.0.0.53:53").is_err());
        assert!(validate_upstream("127.255.255.255:80").is_err());
    }

    #[test]
    fn test_ssrf_private_ipv4_blocked() {
        assert!(validate_upstream("10.0.0.1:80").is_err());
        assert!(validate_upstream("10.255.255.255:443").is_err());
        assert!(validate_upstream("172.16.0.1:8080").is_err());
        assert!(validate_upstream("172.31.255.255:9000").is_err());
        assert!(validate_upstream("192.168.0.1:3000").is_err());
        assert!(validate_upstream("192.168.255.255:5000").is_err());
    }

    #[test]
    fn test_ssrf_link_local_blocked() {
        assert!(validate_upstream("169.254.169.254:80").is_err());
        assert!(validate_upstream("169.254.0.1:80").is_err());
    }

    #[test]
    fn test_ssrf_rfc6598_shared_address_blocked() {
        assert!(validate_upstream("100.64.0.1:80").is_err());
        assert!(validate_upstream("100.127.255.255:443").is_err());
        assert!(validate_upstream("100.100.100.100:8080").is_err());
        assert!(validate_upstream("100.128.0.1:80").is_ok());
        assert!(validate_upstream("100.63.255.255:80").is_ok());
    }

    #[test]
    fn test_ssrf_public_ip_allowed() {
        assert!(validate_upstream("8.8.8.8:53").is_ok());
        assert!(validate_upstream("1.1.1.1:443").is_ok());
        // Just outside TEST-NET-3 (203.0.113.0/24).
        assert!(validate_upstream("203.0.114.1:80").is_ok());
    }

    #[test]
    fn test_ssrf_domain_allowed() {
        assert!(validate_upstream("example.com:443").is_ok());
        assert!(validate_upstream("api.backend.local:8080").is_ok());
    }

    #[test]
    fn test_ssrf_ipv6_loopback_blocked() {
        assert!(validate_upstream("[::1]:80").is_err());
    }

    #[test]
    fn test_ssrf_unspecified_blocked() {
        assert!(validate_upstream("0.0.0.0:80").is_err());
    }
}