use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy;
use crate::error::{Result, SeerError};
static DOMAIN_ALLOWLIST: Lazy<Option<HashSet<String>>> = Lazy::new(|| {
let set: HashSet<String> = std::env::var("SEER_DOMAIN_ALLOWLIST")
.ok()?
.split(',')
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty())
.collect();
if set.is_empty() {
None
} else {
Some(set)
}
});
/// Normalizes a user-supplied domain or URL into a canonical, lowercase ASCII domain name.
///
/// The steps, in order:
/// 1. Trim surrounding whitespace and lowercase.
/// 2. Drop a leading `http://` or `https://`.
/// 3. Drop everything from the first `/`, `?` or `#` (path, query, fragment).
/// 4. Drop a trailing numeric `:port`, so `https://example.com:8443/x` is accepted.
/// 5. Drop one leading `www.` and one trailing root `.`.
/// 6. Convert non-ASCII (IDN) names to punycode.
/// 7. Validate:
///    - characters are limited to `a-z`, `0-9`, `-`, `.` and `_`;
///    - total length is at most 253;
///    - every label is non-empty, has no leading or trailing `-`, and is at most 63 bytes.
/// 8. If `SEER_DOMAIN_ALLOWLIST` is configured, require the last label (TLD) to be listed.
///
/// Underscores are accepted so that record names such as `_dmarc.example.com` pass.
///
/// # Errors
///
/// * `SeerError::InvalidDomain` if any normalization or validation step fails, including IDN
///   conversion.
/// * `SeerError::DomainNotAllowed` if an allowlist is configured and the TLD is not in it.
pub fn normalize_domain(domain: &str) -> Result<String> {
    let lowered = domain.trim().to_lowercase();
    let host = host_from_input(&lowered);
    let host = host.strip_prefix("www.").unwrap_or(host);
    let host = host.strip_suffix('.').unwrap_or(host);

    // Requiring a `.` rejects bare TLDs and single-label names such as `localhost`.
    if host.is_empty() || !host.contains('.') {
        return Err(SeerError::InvalidDomain(host.to_string()));
    }

    let domain = if host.is_ascii() {
        host.to_string()
    } else {
        domain_to_ascii(host)?
    };

    let chars_ok = domain
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_'));
    // An empty label also catches `..`, a leading `.` and a leftover trailing `.`.
    let labels_ok = domain.split('.').all(|label| {
        !label.is_empty() && !label.starts_with('-') && !label.ends_with('-') && label.len() <= 63
    });
    if !chars_ok || domain.len() > 253 || !labels_ok {
        return Err(SeerError::InvalidDomain(domain));
    }

    if let Some(allowlist) = &*DOMAIN_ALLOWLIST {
        if let Some(tld) = domain.rsplit('.').next() {
            if !allowlist.contains(tld) {
                return Err(SeerError::DomainNotAllowed {
                    domain: domain.clone(),
                    tld: tld.to_string(),
                });
            }
        }
    }
    Ok(domain)
}

/// Reduces a lowercased, URL-like input to its host part.
///
/// It removes an `http(s)://` scheme, everything from the first `/`, `?` or `#`, and a trailing
/// all-digit `:port`.
fn host_from_input(input: &str) -> &str {
    let rest = input
        .strip_prefix("http://")
        .or_else(|| input.strip_prefix("https://"))
        .unwrap_or(input);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(rest);
    // Only a purely numeric suffix counts as a port. Anything else (e.g. userinfo or IPv6
    // literals) is left in place and rejected later by character validation.
    match authority.rsplit_once(':') {
        Some((host, port)) if port.bytes().all(|b| b.is_ascii_digit()) => host,
        _ => authority,
    }
}
/// Converts an internationalized (Unicode) domain to its ASCII/punycode form using `idna`.
///
/// # Errors
///
/// Returns `SeerError::InvalidDomain` with a descriptive message when `idna` rejects the input.
fn domain_to_ascii(domain: &str) -> Result<String> {
    match idna::domain_to_ascii(domain) {
        Ok(ascii) => Ok(ascii),
        Err(_) => Err(SeerError::InvalidDomain(format!(
            "invalid internationalized domain: {}",
            domain
        ))),
    }
}
/// Returns `true` when `ip` is in a private, loopback, link-local or otherwise reserved range.
///
/// This dispatches to the IPv4- or IPv6-specific check.
pub fn is_private_or_reserved_ip(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(addr) => is_private_or_reserved_ipv4(&addr),
        IpAddr::V6(addr) => is_private_or_reserved_ipv6(&addr),
    }
}
/// Returns `true` if `ip` is in a range that must not be contacted on a user's behalf.
///
/// The blocked ranges are:
/// - private (RFC 1918), loopback, and link-local (including the 169.254.169.254 metadata
///   endpoint);
/// - "this network" (0.0.0.0/8);
/// - carrier-grade NAT (100.64.0.0/10);
/// - IETF protocol assignments (192.0.0.0/24);
/// - documentation (RFC 5737);
/// - benchmarking (198.18.0.0/15);
/// - multicast, and the reserved 240.0.0.0/4 block, which includes broadcast.
fn is_private_or_reserved_ipv4(ip: &Ipv4Addr) -> bool {
    let [a, b, c, _] = ip.octets();
    ip.is_private()
        || ip.is_loopback()
        // 169.254.0.0/16, which includes the 169.254.169.254 metadata endpoint.
        || ip.is_link_local()
        // 0.0.0.0/8 ("this network"), which includes the unspecified address 0.0.0.0.
        || a == 0
        // 100.64.0.0/10 shared address space (RFC 6598).
        || (a == 100 && (b & 0xc0) == 0x40)
        // 192.0.0.0/24 IETF protocol assignments.
        || (a == 192 && b == 0 && c == 0)
        // RFC 5737 documentation ranges.
        || (a == 192 && b == 0 && c == 2)
        || (a == 198 && b == 51 && c == 100)
        || (a == 203 && b == 0 && c == 113)
        // 198.18.0.0/15 benchmarking.
        || (a == 198 && (b & 0xfe) == 18)
        // 224.0.0.0/4.
        || ip.is_multicast()
        // 240.0.0.0/4 reserved, which includes 255.255.255.255 broadcast.
        || a >= 240
}
fn is_private_or_reserved_ipv6(ip: &Ipv6Addr) -> bool {
if ip.is_loopback() {
return true;
}
if ip.is_unspecified() {
return true;
}
let segments = ip.segments();
if (segments[0] & 0xfe00) == 0xfc00 {
return true;
}
if (segments[0] & 0xffc0) == 0xfe80 {
return true;
}
if segments[0] >> 8 == 0xff {
return true;
}
if ip
.to_ipv4_mapped()
.is_some_and(|ipv4| is_private_or_reserved_ipv4(&ipv4))
{
return true;
}
false
}
/// Explains why `ip` must not be contacted.
///
/// Returns `None` if `ip` matches none of the private/reserved ranges checked here.
///
/// `validate_domain_safe` applies this check to every resolved address, so this function is
/// the effective SSRF policy. The returned text is intended for user-facing error messages.
pub fn describe_reserved_ip(ip: &IpAddr) -> Option<&'static str> {
    match ip {
        IpAddr::V4(v4) => describe_reserved_ipv4(v4),
        IpAddr::V6(v6) => describe_reserved_ipv6(v6),
    }
}

/// IPv4 half of `describe_reserved_ip`.
fn describe_reserved_ipv4(v4: &Ipv4Addr) -> Option<&'static str> {
    let [a, b, c, d] = v4.octets();
    if v4.is_unspecified() {
        return Some("unspecified address (0.0.0.0) — domain has no routable IP");
    }
    if a == 0 {
        return Some("\"this network\" address (0.0.0.0/8)");
    }
    if v4.is_loopback() {
        return Some("loopback address (127.0.0.0/8)");
    }
    if v4.is_private() {
        return Some("private network (RFC 1918)");
    }
    // This must precede the generic link-local test, which also matches this address.
    // Otherwise the more specific message is unreachable.
    if (a, b, c, d) == (169, 254, 169, 254) {
        return Some("cloud metadata endpoint (169.254.169.254)");
    }
    if v4.is_link_local() {
        return Some("link-local address (169.254.0.0/16)");
    }
    if a == 100 && (b & 0xc0) == 0x40 {
        return Some("carrier-grade NAT shared address space (100.64.0.0/10)");
    }
    if a == 192 && b == 0 && c == 0 {
        return Some("IETF protocol assignments (192.0.0.0/24)");
    }
    if (a == 192 && b == 0 && c == 2)
        || (a == 198 && b == 51 && c == 100)
        || (a == 203 && b == 0 && c == 113)
    {
        return Some("documentation/test range (RFC 5737)");
    }
    if a == 198 && (b & 0xfe) == 18 {
        return Some("benchmarking range (198.18.0.0/15)");
    }
    if v4.is_broadcast() {
        return Some("broadcast address (255.255.255.255)");
    }
    if v4.is_multicast() {
        return Some("multicast address (224.0.0.0/4)");
    }
    if a >= 240 {
        return Some("reserved address (240.0.0.0/4)");
    }
    None
}

/// IPv6 half of `describe_reserved_ip`.
///
/// Addresses that embed an IPv4 address are judged by the IPv4 rules.
fn describe_reserved_ipv6(v6: &Ipv6Addr) -> Option<&'static str> {
    if v6.is_loopback() {
        return Some("IPv6 loopback (::1)");
    }
    if v6.is_unspecified() {
        return Some("IPv6 unspecified address (::) — domain has no routable IP");
    }
    let seg = v6.segments();
    if (seg[0] & 0xfe00) == 0xfc00 {
        return Some("IPv6 unique local address (fc00::/7)");
    }
    if (seg[0] & 0xffc0) == 0xfe80 {
        return Some("IPv6 link-local address (fe80::/10)");
    }
    if (seg[0] & 0xffc0) == 0xfec0 {
        return Some("IPv6 site-local address (fec0::/10, deprecated)");
    }
    if seg[0] >> 8 == 0xff {
        return Some("IPv6 multicast (ff00::/8)");
    }
    if seg[0] == 0x2001 && seg[1] == 0x0db8 {
        return Some("IPv6 documentation range (2001:db8::/32)");
    }
    if let Some(v4) = v6.to_ipv4_mapped() {
        if describe_reserved_ipv4(&v4).is_some() {
            return Some("IPv4-mapped IPv6 address in private/reserved range");
        }
    }
    // `to_ipv4` additionally accepts deprecated IPv4-compatible addresses (`::a.b.c.d`).
    // Mapped addresses were already handled above.
    if let Some(v4) = v6.to_ipv4() {
        if describe_reserved_ipv4(&v4).is_some() {
            return Some("IPv4-compatible IPv6 address in private/reserved range");
        }
    }
    // NAT64 well-known prefix (RFC 6052): the low 32 bits carry the IPv4 destination.
    if seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0] {
        let o = v6.octets();
        if describe_reserved_ipv4(&Ipv4Addr::new(o[12], o[13], o[14], o[15])).is_some() {
            return Some("NAT64 address (64:ff9b::/96) embedding a private/reserved IPv4 address");
        }
    }
    None
}
/// Normalizes `domain`, resolves it, and rejects it if any resolved address is private or
/// reserved according to `describe_reserved_ip`.
///
/// The name is resolved for port 443 via `tokio::net::lookup_host`. Every returned address must
/// pass, so a name with one public and one private record is rejected. An empty answer is
/// treated as a failure (fail closed).
///
/// Note: only the normalized name is returned, not the vetted addresses. A caller that connects
/// by name resolves it again, so a record that changes between the two lookups is not covered
/// by this check.
///
/// # Errors
///
/// * Any error from `normalize_domain`.
/// * `SeerError::InvalidDomain` if resolution fails, returns no addresses, or returns a
///   private/reserved address.
pub async fn validate_domain_safe(domain: &str) -> Result<String> {
    let normalized = normalize_domain(domain)?;
    let addr = format!("{}:443", normalized);
    let socket_addrs = tokio::net::lookup_host(&addr)
        .await
        .map_err(|e| SeerError::InvalidDomain(format!("failed to resolve domain: {}", e)))?;

    let mut resolved_any = false;
    for socket_addr in socket_addrs {
        resolved_any = true;
        let ip = socket_addr.ip();
        if let Some(reason) = describe_reserved_ip(&ip) {
            return Err(SeerError::InvalidDomain(format!(
                "cannot connect to '{}': {} — {}",
                normalized, ip, reason
            )));
        }
    }
    // Previously an empty answer skipped the loop and reported the domain as safe.
    if !resolved_any {
        return Err(SeerError::InvalidDomain(format!(
            "failed to resolve domain: no addresses found for '{}'",
            normalized
        )));
    }
    Ok(normalized)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `IpAddr::V4` from four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Builds an `IpAddr::V6` from eight 16-bit segments.
    fn v6(segments: [u16; 8]) -> IpAddr {
        IpAddr::V6(Ipv6Addr::from(segments))
    }

    #[test]
    fn test_normalize_domain() {
        // (input, expected normalized output)
        let accepted = [
            ("example.com", "example.com"),
            ("EXAMPLE.COM", "example.com"),
            ("https://www.example.com/path", "example.com"),
            ("http://example.com/", "example.com"),
            (" WWW.EXAMPLE.COM ", "example.com"),
            ("example.com?query=1", "example.com"),
            ("example.com#section", "example.com"),
            ("https://example.com/path?q=1#frag", "example.com"),
            // Underscore-prefixed record names must survive validation.
            ("_dmarc.example.com", "_dmarc.example.com"),
            (
                "selector1._domainkey.example.com",
                "selector1._domainkey.example.com",
            ),
            ("_sip._tcp.example.com", "_sip._tcp.example.com"),
            // A single trailing root dot is dropped.
            ("example.com.", "example.com"),
            ("https://example.com.", "example.com"),
        ];
        for (input, expected) in accepted {
            assert_eq!(normalize_domain(input).unwrap(), expected, "input: {input:?}");
        }

        let rejected = [
            "",
            "nodots",
            "example..com",
            ".example.com",
            "-example.com",
            "example-.com",
            "example.com..",
        ];
        for input in rejected {
            assert!(normalize_domain(input).is_err(), "expected rejection: {input:?}");
        }
    }

    #[test]
    fn test_normalize_idn_domain() {
        let cases = [
            ("münchen.de", "xn--mnchen-3ya.de"),
            ("例え.jp", "xn--r8jz45g.jp"),
            ("中文.com", "xn--fiq228c.com"),
            ("https://münchen.de/path", "xn--mnchen-3ya.de"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_domain(input).unwrap(), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_allowlist_not_set_allows_all() {
        for domain in ["example.com", "example.xyz", "example.co.uk"] {
            assert!(normalize_domain(domain).is_ok(), "{domain}");
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv4() {
        let reserved = [
            v4(10, 0, 0, 1),
            v4(172, 16, 0, 1),
            v4(192, 168, 1, 1),
            v4(127, 0, 0, 1),
            v4(169, 254, 1, 1),
            v4(169, 254, 169, 254),
        ];
        for ip in &reserved {
            assert!(is_private_or_reserved_ip(ip), "{ip}");
        }

        let public = [v4(8, 8, 8, 8), v4(1, 1, 1, 1)];
        for ip in &public {
            assert!(!is_private_or_reserved_ip(ip), "{ip}");
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv6() {
        let reserved = [
            v6([0, 0, 0, 0, 0, 0, 0, 1]),
            v6([0xfc00, 0, 0, 0, 0, 0, 0, 1]),
            v6([0xfe80, 0, 0, 0, 0, 0, 0, 1]),
        ];
        for ip in &reserved {
            assert!(is_private_or_reserved_ip(ip), "{ip}");
        }

        let public = v6([0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888]);
        assert!(!is_private_or_reserved_ip(&public), "{public}");
    }
}