use anyhow::Result;
use ipnetwork::IpNetwork;
use lru::LruCache;
use std::collections::HashSet;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{debug, warn};
use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::TokioAsyncResolver;
/// Destination filter used to vet outbound hosts (the log lines in this file
/// tag its decisions with `[SSRF]`).
///
/// A host is first checked against hostname rules, then resolved (through a
/// TTL-bounded LRU cache), and every resolved address is checked against the
/// blocked CIDR ranges.
#[derive(Debug)]
pub struct DestinationFilter {
/// Async DNS resolver used whenever the cache has no fresh entry.
resolver: Arc<TokioAsyncResolver>,
/// LRU cache of previous resolutions, keyed by hostname.
dns_cache: Arc<Mutex<LruCache<String, CachedResolution>>>,
/// CIDR ranges that a resolved address must not fall into.
blocked_ranges: Vec<IpNetwork>,
/// Exact hostnames that are always rejected; compared against the
/// lowercased host.
blocked_hostnames: HashSet<String>,
/// Deadline applied to each resolver lookup via `tokio::time::timeout`.
resolver_timeout: Duration,
/// Maximum age of a cache entry before it is resolved again.
cache_ttl: Duration,
}
/// One cached DNS answer together with the moment it was stored.
struct CachedResolution {
/// Addresses returned by the resolver for the hostname.
ips: Vec<IpAddr>,
/// When the answer was cached; compared against the filter's `cache_ttl`.
resolved_at: Instant,
}
/// Reasons a destination is rejected or cannot be resolved.
#[derive(Debug, thiserror::Error)]
pub enum DestinationError {
/// The hostname matched a hostname rule; carries the host as supplied.
#[error("Blocked hostname: {0}")]
BlockedHostname(String),
/// A resolved address fell inside a blocked range; carries that address.
#[error("Blocked IP range: {0}")]
BlockedIpRange(IpAddr),
/// The resolver returned an error; carries the resolver's error text.
#[error("DNS resolution failed: {0}")]
ResolutionFailed(String),
/// The lookup did not finish within the configured resolver timeout.
#[error("DNS resolution timeout")]
ResolutionTimeout,
/// Resolution succeeded but produced no addresses; carries the host.
#[error("No IP addresses found for hostname: {0}")]
NoAddressesFound(String),
}
impl DestinationFilter {
pub fn new(
cache_size: usize,
cache_ttl_seconds: u64,
resolver_timeout_seconds: u64,
) -> Result<Self> {
let mut opts = ResolverOpts::default();
opts.timeout = Duration::from_secs(resolver_timeout_seconds);
opts.attempts = 2;
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap_or_else(|_| {
warn!("Failed to load system DNS config, falling back to Google DNS");
TokioAsyncResolver::tokio(ResolverConfig::google(), opts.clone())
});
Ok(Self {
resolver: Arc::new(resolver),
dns_cache: Arc::new(Mutex::new(LruCache::new(
NonZeroUsize::new(cache_size).unwrap(),
))),
blocked_ranges: load_blocked_ranges()?,
blocked_hostnames: load_blocked_hostnames(),
resolver_timeout: Duration::from_secs(resolver_timeout_seconds),
cache_ttl: Duration::from_secs(cache_ttl_seconds),
})
}
/// Vets `host` and returns the addresses it resolves to.
///
/// The hostname rules are applied first, then the host is resolved through
/// the cache, and finally every resolved address is checked against the
/// blocked ranges. The first blocked address rejects the whole destination.
///
/// # Errors
///
/// * [`DestinationError::BlockedHostname`] - the hostname itself is blocked.
/// * [`DestinationError::NoAddressesFound`] - resolution produced no addresses.
/// * [`DestinationError::BlockedIpRange`] - a resolved address is blocked.
/// * Any error propagated from the DNS lookup.
pub async fn check_and_resolve(&self, host: &str) -> Result<Vec<IpAddr>, DestinationError> {
    debug!("[SSRF] Checking destination: {}", host);

    if self.is_hostname_blocked(host) {
        warn!("[SSRF] Blocked hostname: {}", host);
        return Err(DestinationError::BlockedHostname(host.to_string()));
    }

    let resolved = self.resolve_with_cache(host).await?;
    if resolved.is_empty() {
        return Err(DestinationError::NoAddressesFound(host.to_string()));
    }

    // Report the first offending address, in resolver order.
    let first_blocked = resolved
        .iter()
        .copied()
        .find(|addr| self.is_ip_blocked(*addr));
    if let Some(bad) = first_blocked {
        warn!("[SSRF] Blocked IP {} for hostname {}", bad, host);
        return Err(DestinationError::BlockedIpRange(bad));
    }

    debug!("[SSRF] Allowed destination {} -> {:?}", host, resolved);
    Ok(resolved)
}
/// Resolves `host`, serving the answer from the cache while it is fresh.
///
/// The cache key is the ASCII-lowercased hostname, so case variants of the
/// same name share one entry. An expired entry is evicted before the lookup,
/// so it does not linger when the refresh fails.
///
/// # Errors
///
/// * [`DestinationError::ResolutionFailed`] - the resolver returned an error.
/// * [`DestinationError::ResolutionTimeout`] - the lookup exceeded
///   `resolver_timeout`.
async fn resolve_with_cache(&self, host: &str) -> Result<Vec<IpAddr>, DestinationError> {
    // DNS names are case-insensitive; without normalizing, "Example.COM" and
    // "example.com" would each trigger their own lookup and cache slot.
    let cache_key = host.to_ascii_lowercase();

    // Scope the guard so the lock is released before the network lookup.
    {
        let mut cache = self.dns_cache.lock().await;
        let fresh = cache
            .get(cache_key.as_str())
            .filter(|cached| cached.resolved_at.elapsed() < self.cache_ttl)
            .map(|cached| cached.ips.clone());
        if let Some(ips) = fresh {
            debug!("[DNS] Cache hit for {}: {:?}", host, ips);
            return Ok(ips);
        }
        // Either absent (no-op) or expired: drop it so a stale answer does
        // not keep occupying a most-recently-used slot.
        cache.pop(cache_key.as_str());
    }

    let lookup_result = timeout(self.resolver_timeout, self.resolver.lookup_ip(host)).await;
    let ips: Vec<IpAddr> = match lookup_result {
        Ok(Ok(lookup)) => {
            let ips: Vec<IpAddr> = lookup.iter().collect();
            debug!("[DNS] Resolved {} -> {:?}", host, ips);
            ips
        }
        Ok(Err(e)) => {
            warn!("[DNS] Resolution failed for {}: {}", host, e);
            return Err(DestinationError::ResolutionFailed(e.to_string()));
        }
        // Outer `Err` means the `timeout` wrapper elapsed.
        Err(_) => {
            warn!("[DNS] Resolution timeout for {}", host);
            return Err(DestinationError::ResolutionTimeout);
        }
    };

    self.dns_cache.lock().await.put(
        cache_key,
        CachedResolution {
            ips: ips.clone(),
            resolved_at: Instant::now(),
        },
    );
    Ok(ips)
}
/// Returns `true` when `host` is rejected by name alone.
///
/// Matching is case-insensitive and ignores trailing dots, so the
/// fully-qualified form `localhost.` is treated like `localhost`. A host is
/// blocked when it is in `blocked_hostnames`, is a loopback alias, or ends
/// with one of the blocked suffixes.
fn is_hostname_blocked(&self, host: &str) -> bool {
    // Suffixes rejected regardless of what they resolve to.
    const BLOCKED_SUFFIXES: [&str; 6] = [
        ".localhost",
        ".local",
        ".internal",
        ".localdomain",
        ".home",
        ".lan",
    ];

    // "name." is the same DNS name as "name"; strip the root dot so it
    // cannot be used to slip past the exact and suffix comparisons below.
    let host_lower = host.trim_end_matches('.').to_lowercase();

    if self.blocked_hostnames.contains(&host_lower) {
        return true;
    }
    if matches!(
        host_lower.as_str(),
        "localhost" | "ip6-localhost" | "ip6-loopback"
    ) {
        return true;
    }
    BLOCKED_SUFFIXES
        .iter()
        .any(|suffix| host_lower.ends_with(*suffix))
}
/// Returns `true` when `ip` falls inside any blocked range.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are additionally checked as
/// the IPv4 address they embed. A range only ever matches addresses of its
/// own family, so without this step `::ffff:127.0.0.1` would pass every
/// IPv4 rule.
fn is_ip_blocked(&self, ip: IpAddr) -> bool {
    let unmapped = match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    };
    // Check the address as given as well, in case a range targets the mapped
    // IPv6 form directly.
    self.blocked_ranges
        .iter()
        .any(|range| range.contains(ip) || range.contains(unmapped))
}
}
fn load_blocked_ranges() -> Result<Vec<IpNetwork>> {
const BLOCKED_CIDRS: &[&str] = &[
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"169.254.0.0/16",
"169.254.169.254/32", "0.0.0.0/8", "255.255.255.255/32", "224.0.0.0/4", "240.0.0.0/4", "::1/128",
"fe80::/10",
"fc00::/7",
"fd00:ec2::254/128",
"2001:db8::/32", "ff00::/8", ];
let mut ranges = Vec::new();
for cidr in BLOCKED_CIDRS {
ranges.push(cidr.parse()?);
}
Ok(ranges)
}
/// Builds the set of exact hostnames that are always rejected.
///
/// Entries are stored lowercase.
fn load_blocked_hostnames() -> HashSet<String> {
    const BLOCKED_HOSTS: [&str; 3] = [
        "metadata.google.internal",
        "metadata.azure.com",
        "instance-data.ec2.internal",
    ];

    BLOCKED_HOSTS.iter().map(|name| name.to_string()).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter with the settings shared by every test.
    fn test_filter() -> DestinationFilter {
        DestinationFilter::new(100, 60, 5).unwrap()
    }

    #[tokio::test]
    async fn test_block_localhost() {
        let filter = test_filter();
        let result = filter.check_and_resolve("localhost").await;
        assert!(matches!(result, Err(DestinationError::BlockedHostname(_))));
    }

    #[tokio::test]
    async fn test_block_localhost_variants() {
        let filter = test_filter();
        assert!(filter.check_and_resolve("ip6-localhost").await.is_err());
        assert!(filter.check_and_resolve("test.localhost").await.is_err());
    }

    #[tokio::test]
    async fn test_block_metadata_hostname() {
        let filter = test_filter();
        let result = filter.check_and_resolve("metadata.google.internal").await;
        assert!(matches!(result, Err(DestinationError::BlockedHostname(_))));
    }

    #[cfg(feature = "network-tests")]
    #[tokio::test]
    async fn test_allow_public_domain() {
        let filter = test_filter();
        let result = filter.check_and_resolve("example.com").await;
        assert!(result.is_ok());
        let ips = result.unwrap();
        assert!(!ips.is_empty());
    }

    #[cfg(feature = "network-tests")]
    #[tokio::test]
    async fn test_dns_caching() {
        let filter = test_filter();
        let ips1 = filter.check_and_resolve("example.com").await.unwrap();
        // Inspect the cache directly instead of comparing wall-clock
        // timings, which vary with network and scheduler jitter.
        assert!(filter.dns_cache.lock().await.contains("example.com"));
        let ips2 = filter.check_and_resolve("example.com").await.unwrap();
        assert_eq!(ips1, ips2);
    }

    #[test]
    fn test_is_ip_blocked_rfc1918() {
        let filter = test_filter();
        assert!(filter.is_ip_blocked("10.0.0.1".parse().unwrap()));
        assert!(filter.is_ip_blocked("172.16.0.1".parse().unwrap()));
        assert!(filter.is_ip_blocked("192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn test_is_ip_blocked_localhost() {
        let filter = test_filter();
        assert!(filter.is_ip_blocked("127.0.0.1".parse().unwrap()));
        assert!(filter.is_ip_blocked("127.255.255.255".parse().unwrap()));
        assert!(filter.is_ip_blocked("::1".parse().unwrap()));
    }

    #[test]
    fn test_is_ip_blocked_metadata() {
        let filter = test_filter();
        assert!(filter.is_ip_blocked("169.254.169.254".parse().unwrap()));
    }

    #[test]
    fn test_is_ip_allowed_public() {
        let filter = test_filter();
        assert!(!filter.is_ip_blocked("8.8.8.8".parse().unwrap()));
        assert!(!filter.is_ip_blocked("1.1.1.1".parse().unwrap()));
        assert!(!filter.is_ip_blocked("93.184.216.34".parse().unwrap()));
    }

    #[test]
    fn test_hostname_blocked() {
        let filter = test_filter();
        assert!(filter.is_hostname_blocked("localhost"));
        // Matching is case-insensitive.
        assert!(filter.is_hostname_blocked("Localhost"));
        assert!(filter.is_hostname_blocked("test.localhost"));
        assert!(filter.is_hostname_blocked("metadata.google.internal"));
    }

    #[test]
    fn test_hostname_allowed() {
        let filter = test_filter();
        assert!(!filter.is_hostname_blocked("example.com"));
        assert!(!filter.is_hostname_blocked("google.com"));
    }
}