use std::net::IpAddr;
use std::time::Duration;
use hickory_resolver::config::{ResolveHosts, ResolverConfig, GOOGLE};
use hickory_resolver::net::runtime::TokioRuntimeProvider;
use hickory_resolver::TokioResolver;
use once_cell::sync::Lazy;
use tokio::net::lookup_host;
use tracing::warn;
use crate::error::{Result, SeerError};
/// Process-wide secondary DNS resolver, constructed on first access.
///
/// `validate_public_host` consults it only after the system lookup has
/// returned an error. It is configured with the `GOOGLE` server group over
/// UDP and TCP, a 5 second timeout, `attempts` set to 2, and the local hosts
/// file disabled.
///
/// # Panics
///
/// The initializer panics if the builder's `build()` returns an error.
static FALLBACK_RESOLVER: Lazy<TokioResolver> = Lazy::new(|| {
    let upstream = ResolverConfig::udp_and_tcp(&GOOGLE);
    let mut builder =
        TokioResolver::builder_with_config(upstream, TokioRuntimeProvider::default());

    // The mutable borrow of the builder ends after the last assignment, so no
    // explicit scope is needed before `build()`.
    let options = builder.options_mut();
    options.timeout = Duration::from_secs(5);
    options.attempts = 2;
    // Never consult the local hosts file for fallback lookups.
    options.use_hosts_file = ResolveHosts::Never;

    builder
        .build()
        .expect("hickory fallback resolver build is infallible with no TLS features")
});
/// Returns `true` when `ip` belongs to a range that must not be contacted as
/// a public host.
///
/// IPv4 ranges rejected:
/// - loopback, RFC 1918 private, link-local, multicast, broadcast,
///   unspecified and the documentation nets (via the `std` predicates)
/// - `0.0.0.0/8` ("this network")
/// - `100.64.0.0/10` (carrier-grade NAT shared space)
/// - `192.0.0.0/24` (IETF protocol assignments)
/// - `198.18.0.0/15` (benchmarking)
/// - `240.0.0.0/4` (reserved for future use)
///
/// IPv6 ranges rejected:
/// - loopback, multicast, unspecified
/// - `fc00::/7` (unique local) and `fe80::/10` (link-local)
/// - `2001:db8::/32` (documentation)
/// - any address that embeds a reserved IPv4 address as IPv4-mapped
///   (`::ffff:a.b.c.d`), IPv4-compatible (`::a.b.c.d`) or NAT64 well-known
///   prefix (`64:ff9b::a.b.c.d`)
///
/// Addresses that embed a *public* IPv4 address (for example a DNS64 answer
/// such as `64:ff9b::8.8.8.8`) are still accepted.
pub fn is_reserved_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_reserved_v4(v4),
        IpAddr::V6(v6) => {
            let seg = v6.segments();
            v6.is_loopback()
                || v6.is_multicast()
                || v6.is_unspecified()
                // fc00::/7 unique local addresses
                || (seg[0] & 0xfe00) == 0xfc00
                // fe80::/10 link-local addresses
                || (seg[0] & 0xffc0) == 0xfe80
                // 2001:db8::/32 documentation prefix
                || (seg[0] == 0x2001 && seg[1] == 0x0db8)
                // An IPv6 wrapper must not be usable to smuggle a reserved
                // IPv4 destination past the check.
                || embedded_ipv4(v6).is_some_and(is_reserved_v4)
        }
    }
}

/// IPv4 half of [`is_reserved_ip`].
fn is_reserved_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_multicast()
        || v4.is_broadcast()
        || v4.is_unspecified()
        || v4.is_documentation()
        // 0.0.0.0/8: `is_unspecified` only matches 0.0.0.0 itself
        || a == 0
        // 100.64.0.0/10
        || (a == 100 && (b & 0xC0) == 64)
        // 192.0.0.0/24
        || (a == 192 && b == 0 && c == 0)
        // 198.18.0.0/15
        || (a == 198 && (b & 0xFE) == 18)
        // 240.0.0.0/4: `is_broadcast` only matches 255.255.255.255
        || a >= 240
}

/// Extracts the IPv4 address carried in the low 32 bits of an IPv4-mapped
/// (`::ffff:0:0/96`), IPv4-compatible (`::/96`) or NAT64 well-known-prefix
/// (`64:ff9b::/96`) IPv6 address; returns `None` for every other address.
fn embedded_ipv4(v6: std::net::Ipv6Addr) -> Option<std::net::Ipv4Addr> {
    match v6.segments() {
        [0, 0, 0, 0, 0, 0 | 0xffff, hi, lo] | [0x0064, 0xff9b, 0, 0, 0, 0, hi, lo] => {
            let [a, b] = hi.to_be_bytes();
            let [c, d] = lo.to_be_bytes();
            Some(std::net::Ipv4Addr::new(a, b, c, d))
        }
        _ => None,
    }
}
/// Verifies that `host` does not point at a reserved address.
///
/// When `host` parses as an IP literal it is checked directly and no DNS
/// query is made. Otherwise the name is resolved with the system resolver
/// (`port` is only passed through to that lookup); if the system lookup
/// errors, a warning is logged and `FALLBACK_RESOLVER` is tried instead.
/// Every resolved address must pass [`is_reserved_ip`].
///
/// # Errors
///
/// Returns `SeerError::InvalidInput` when:
/// - `host` is a reserved IP literal,
/// - both the system lookup and the fallback lookup fail,
/// - resolution yields no addresses, or
/// - any resolved address is reserved.
///
/// NOTE(review): only the addresses observed during this call are checked
/// and none are returned, so a caller that resolves `host` again afterwards
/// may see different answers.
pub async fn validate_public_host(host: &str, port: u16) -> Result<()> {
    // Fast path: an IP literal needs no resolution.
    if let Ok(literal) = host.parse::<IpAddr>() {
        return if is_reserved_ip(literal) {
            Err(SeerError::InvalidInput(format!(
                "refusing to connect to reserved address: {}",
                literal
            )))
        } else {
            Ok(())
        };
    }

    let resolved: Vec<IpAddr> = match lookup_host((host, port)).await {
        Ok(socket_addrs) => socket_addrs.map(|socket_addr| socket_addr.ip()).collect(),
        Err(os_err) => {
            warn!(
                host = %host,
                error = %os_err,
                "system DNS resolution failed; retrying via hickory fallback"
            );
            // Report both failures if the fallback cannot resolve the name either.
            let response = FALLBACK_RESOLVER
                .lookup_ip(host)
                .await
                .map_err(|fallback_err| {
                    SeerError::InvalidInput(format!(
                        "DNS resolution failed for {host}: {os_err} (fallback: {fallback_err})"
                    ))
                })?;
            response.iter().collect()
        }
    };

    if resolved.is_empty() {
        return Err(SeerError::InvalidInput(format!(
            "no addresses resolved for {host}"
        )));
    }

    // The first reserved address, in resolution order, decides the error.
    match resolved.iter().find(|candidate| is_reserved_ip(**candidate)) {
        Some(ip) => Err(SeerError::InvalidInput(format!(
            "{host} resolves to reserved address {ip}"
        ))),
        None => Ok(()),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parses an address literal; panics if the test input is malformed.
    fn addr(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    // --- is_reserved_ip: addresses that must be rejected ---

    #[test]
    fn rejects_loopback_v4() {
        assert!(is_reserved_ip(addr("127.0.0.1")));
    }

    #[test]
    fn rejects_metadata_v4() {
        assert!(is_reserved_ip(addr("169.254.169.254")));
    }

    #[test]
    fn rejects_rfc1918() {
        assert!(is_reserved_ip(addr("10.0.0.1")));
        assert!(is_reserved_ip(addr("172.16.0.1")));
        assert!(is_reserved_ip(addr("192.168.1.1")));
    }

    #[test]
    fn rejects_cgnat() {
        assert!(is_reserved_ip(addr("100.64.0.1")));
    }

    #[test]
    fn rejects_benchmarking() {
        assert!(is_reserved_ip(addr("198.18.0.1")));
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert!(is_reserved_ip(addr("::1")));
    }

    #[test]
    fn rejects_ipv6_ula() {
        assert!(is_reserved_ip(addr("fd00::1")));
    }

    #[test]
    fn rejects_ipv4_mapped_loopback() {
        assert!(is_reserved_ip(addr("::ffff:127.0.0.1")));
    }

    // --- is_reserved_ip: addresses that must be accepted ---

    #[test]
    fn allows_public_v4() {
        assert!(!is_reserved_ip(addr("8.8.8.8")));
        assert!(!is_reserved_ip(addr("1.1.1.1")));
    }

    #[test]
    fn allows_public_v6() {
        assert!(!is_reserved_ip(addr("2606:4700:4700::1111")));
    }

    // --- validate_public_host: IP literals need no network access ---

    #[tokio::test]
    async fn validate_rejects_ip_literal_loopback() {
        let outcome = validate_public_host("127.0.0.1", 80).await;
        assert!(matches!(outcome, Err(SeerError::InvalidInput(_))));
    }

    #[tokio::test]
    async fn validate_rejects_ip_literal_metadata() {
        let outcome = validate_public_host("169.254.169.254", 80).await;
        assert!(matches!(outcome, Err(SeerError::InvalidInput(_))));
    }

    #[tokio::test]
    async fn validate_allows_public_ip_literal() {
        validate_public_host("8.8.8.8", 53).await.unwrap();
    }

    // --- validate_public_host: DNS path, opt-in only ---

    #[tokio::test]
    #[ignore = "requires network — hits Google DNS via hickory fallback"]
    async fn validate_rejects_unresolvable_via_fallback() {
        let msg = validate_public_host("nonexistent.host.invalid.", 443)
            .await
            .unwrap_err()
            .to_string();
        assert!(msg.contains("DNS resolution failed"), "got: {msg}");
        assert!(msg.contains("fallback"), "got: {msg}");
    }
}