use dashmap::DashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// How long a successful lookup stored in `DNS_CACHE` is served before it is re-resolved.
const DNS_CACHE_TTL: Duration = Duration::from_secs(60);
/// Size threshold for `DNS_CACHE`; compared against the cache length before a new entry is inserted.
const DNS_CACHE_MAX_ENTRIES: usize = 4096;
/// One cached lookup: the moment it was stored and the resolved socket addresses.
type DnsCacheEntry = (Instant, Arc<Vec<SocketAddr>>);
/// Process-wide lookup cache keyed by the exact `host:port` string passed to
/// `resolve_dns_cached`; the map itself is created lazily on first use.
static DNS_CACHE: std::sync::OnceLock<DashMap<String, DnsCacheEntry>> = std::sync::OnceLock::new();
/// Resolves `host_port` (a `host:port` string) via `tokio::net::lookup_host`, caching
/// non-empty results for [`DNS_CACHE_TTL`].
///
/// Fresh cache entries are returned without a lookup. Stale entries are dropped and the
/// name is re-resolved. Empty lookup results are returned to the caller but not cached.
/// When the cache has reached [`DNS_CACHE_MAX_ENTRIES`], expired entries are purged
/// first, and the whole cache is cleared only if it is still full afterwards.
///
/// # Errors
/// Propagates any I/O error from `lookup_host`; failed lookups are never cached.
pub async fn resolve_dns_cached(host_port: &str) -> std::io::Result<Vec<SocketAddr>> {
    let cache = DNS_CACHE.get_or_init(DashMap::new);

    // The read guard returned by `get` lives only inside this `match`, so it is released
    // before any mutating call below (holding it would deadlock on the shard lock).
    let stale = match cache.get(host_port) {
        Some(entry) => {
            let (inserted_at, addrs) = entry.value();
            if inserted_at.elapsed() < DNS_CACHE_TTL {
                return Ok((**addrs).clone());
            }
            true
        }
        None => false,
    };
    if stale {
        // Re-check staleness under the write lock: another task may have refreshed the
        // entry between our read and this removal, and that fresh entry must survive.
        cache.remove_if(host_port, |_, entry| entry.0.elapsed() >= DNS_CACHE_TTL);
    }

    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(host_port).await?.collect();
    if !addrs.is_empty() {
        if cache.len() >= DNS_CACHE_MAX_ENTRIES {
            // Prefer evicting expired entries; fall back to a full clear so the cache
            // stays bounded even when every entry is still fresh.
            cache.retain(|_, entry| entry.0.elapsed() < DNS_CACHE_TTL);
            if cache.len() >= DNS_CACHE_MAX_ENTRIES {
                cache.clear();
            }
        }
        cache.insert(
            host_port.to_string(),
            (Instant::now(), Arc::new(addrs.clone())),
        );
    }
    Ok(addrs)
}
/// Quick, allocation-free check for loopback, private, link-local, unspecified,
/// multicast, shared (CGNAT) and reserved addresses.
///
/// IPv4 ranges: `127/8`, `10/8`, `172.16/12`, `192.168/16`, `169.254/16`, `0/8`,
/// `224/4`, `100.64/10`, `240/4` (which includes `255.255.255.255`).
///
/// IPv6 ranges: `::1`, `::`, `fe80::/10`, `fc00::/7`, `ff00::/8`. IPv6 addresses that
/// embed an IPv4 address are classified by the embedded address:
/// - IPv4-compatible `::a.b.c.d`
/// - IPv4-mapped `::ffff:a.b.c.d`
/// - NAT64 `64:ff9b::a.b.c.d`
///
/// This prevents a literal such as `[::ffff:127.0.0.1]` from slipping past the IPv4 rules.
pub fn is_private_ip_addr_fast(ip: &IpAddr) -> bool {
    // Blocked IPv4 networks as `(mask, network)` pairs over the big-endian address.
    const V4_BLOCKED: [(u32, u32); 9] = [
        (0xFF00_0000, 0x7F00_0000), // 127.0.0.0/8 loopback
        (0xFF00_0000, 0x0A00_0000), // 10.0.0.0/8 private
        (0xFFF0_0000, 0xAC10_0000), // 172.16.0.0/12 private
        (0xFFFF_0000, 0xC0A8_0000), // 192.168.0.0/16 private
        (0xFFFF_0000, 0xA9FE_0000), // 169.254.0.0/16 link-local
        (0xFF00_0000, 0x0000_0000), // 0.0.0.0/8 "this network"
        (0xF000_0000, 0xE000_0000), // 224.0.0.0/4 multicast
        (0xFFC0_0000, 0x6440_0000), // 100.64.0.0/10 shared address space (CGNAT)
        (0xF000_0000, 0xF000_0000), // 240.0.0.0/4 reserved, incl. broadcast
    ];

    match ip {
        IpAddr::V4(ipv4) => {
            let val = u32::from(*ipv4);
            V4_BLOCKED.iter().any(|&(mask, net)| (val & mask) == net)
        }
        IpAddr::V6(ipv6) => {
            if ipv6.is_loopback() || ipv6.is_unspecified() {
                return true; // ::1 and ::
            }
            // IPv4-compatible (::a.b.c.d) and IPv4-mapped (::ffff:a.b.c.d) addresses reach
            // the embedded IPv4 host, so they must be judged by that address.
            if let Some(embedded) = ipv6.to_ipv4() {
                return is_private_ip_addr_fast(&IpAddr::V4(embedded));
            }
            let octets = ipv6.octets();
            // NAT64 well-known prefix 64:ff9b::/96 (RFC 6052) also embeds an IPv4 address.
            if octets.starts_with(&[0x00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0]) {
                let embedded = Ipv4Addr::new(octets[12], octets[13], octets[14], octets[15]);
                return is_private_ip_addr_fast(&IpAddr::V4(embedded));
            }
            let link_local = octets[0] == 0xfe && (octets[1] & 0xc0) == 0x80; // fe80::/10
            let unique_local = (octets[0] & 0xfe) == 0xfc; // fc00::/7
            let multicast = octets[0] == 0xff; // ff00::/8
            link_local || unique_local || multicast
        }
    }
}
/// Full private-address check.
///
/// Consults the fast range table in [`is_private_ip_addr_fast`] first. The bogon list
/// (`crate::bogon::ip_addr_is_bogon`) is consulted only when the fast check does not
/// already match.
pub fn is_private_ip_addr(ip: &IpAddr) -> bool {
    if is_private_ip_addr_fast(ip) {
        return true;
    }
    crate::bogon::ip_addr_is_bogon(*ip)
}
pub fn is_private_url(url_str: &str) -> bool {
let url = match url::Url::parse(url_str) {
Ok(u) => u,
Err(_) => return true, };
if let Some(host) = url.host() {
match host {
url::Host::Ipv4(ip) => {
if is_private_ip_addr_fast(&IpAddr::V4(ip))
|| crate::bogon::ip_addr_is_bogon(IpAddr::V4(ip))
{
return true;
}
}
url::Host::Ipv6(ip) => {
if is_private_ip_addr_fast(&IpAddr::V6(ip))
|| crate::bogon::ip_addr_is_bogon(IpAddr::V6(ip))
{
return true;
}
}
url::Host::Domain(d) => {
if d == "localhost"
|| d.ends_with(".local")
|| d.ends_with(".internal")
|| d.ends_with(".localdomain")
{
return true;
}
let maybe_ip = if let Some(hex) =
d.strip_prefix("0x").or_else(|| d.strip_prefix("0X"))
{
u32::from_str_radix(hex, 16).ok().map(Ipv4Addr::from)
} else if d.starts_with('0') && d.len() > 1 && d.chars().all(|c| c.is_ascii_digit())
{
u32::from_str_radix(d, 8).ok().map(Ipv4Addr::from)
} else if let Ok(n) = d.parse::<u32>() {
Some(Ipv4Addr::from(n))
} else if let Ok(ip) = d.parse::<Ipv4Addr>() {
Some(ip)
} else {
canonicalize_short_form_ipv4(d)
};
if let Some(ip) = maybe_ip {
if is_private_ip_addr_fast(&IpAddr::V4(ip))
|| crate::bogon::ip_addr_is_bogon(IpAddr::V4(ip))
{
return true;
}
}
if looks_like_malformed_ip(d) {
return true;
}
}
}
}
false
}
fn canonicalize_short_form_ipv4(domain: &str) -> Option<Ipv4Addr> {
let parts: Vec<&str> = domain.split('.').collect();
if parts.len() < 2 || parts.len() > 3 {
return None;
}
let values: Option<Vec<u32>> = parts.iter().map(|p| parse_ip_field(p)).collect();
let values = values?;
let n = values.len();
let mut acc: u32 = 0;
for &leading in &values[..n - 1] {
if leading > 0xFF {
return None;
}
acc = (acc << 8) | leading;
}
let remaining_bytes = 4 - (n - 1);
let last = values[n - 1];
let max_last = if remaining_bytes >= 4 {
u32::MAX
} else {
(1u32 << (8 * remaining_bytes as u32)) - 1
};
if last > max_last {
return None;
}
acc = (acc << (8 * remaining_bytes as u32)) | last;
Some(Ipv4Addr::from(acc))
}
/// Parses one dot-separated component of a numeric IPv4 host.
///
/// The radix is chosen from the prefix:
/// - hex: `0x`/`0X` followed by at least one digit;
/// - octal: a leading `0` with more than one character;
/// - decimal: everything else.
///
/// Returns `None` for an empty part or anything that does not parse into a `u32`.
fn parse_ip_field(part: &str) -> Option<u32> {
    let hex_digits = part.strip_prefix("0x").or_else(|| part.strip_prefix("0X"));
    let (digits, radix) = match hex_digits {
        Some("") => return None,
        Some(hex) => (hex, 16),
        None if part.len() > 1 && part.starts_with('0') => (part, 8),
        None if part.is_empty() => return None,
        None => (part, 10),
    };
    u32::from_str_radix(digits, radix).ok()
}
/// Heuristically detects hosts that look like an obfuscated IPv4 literal.
///
/// Returns `true` when the host has four or more dot-separated labels. Each label must
/// be non-empty and consist only of ASCII hex digits, `x`/`X` or `-`. Examples:
/// `0177.0.0.1`, `0x7f.0x0.0x0.0x1`, `01.02.03.04`.
///
/// The rule is deliberately broad: real names such as `dead.beef.cafe.face` also match,
/// and are treated as suspicious rather than risk letting an encoded internal address
/// through. (An earlier separate check for four all-digit, zero-prefixed labels was
/// fully subsumed by this rule and has been removed.)
fn looks_like_malformed_ip(domain: &str) -> bool {
    let is_hexish_label = |label: &str| {
        !label.is_empty()
            && label
                .chars()
                .all(|c| c.is_ascii_hexdigit() || matches!(c, '-' | 'x' | 'X'))
    };
    domain.split('.').count() >= 4 && domain.split('.').all(is_hexish_label)
}