use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use tokio::net::lookup_host;
use crate::fetcher::ssrf::{SsrfError, SsrfLevel, validate_addresses};
tokio::task_local! {
    /// SSRF policy level that [`SsrfValidatingResolver`] applies to the
    /// addresses a hostname resolves to.
    ///
    /// Establish it around a request future with
    /// `SSRF_LEVEL.scope(level, future)`. When no level is in scope, the
    /// resolver performs no SSRF validation.
    pub static SSRF_LEVEL: SsrfLevel;
}
/// Error produced by [`SsrfValidatingResolver`] when the SSRF policy rejects
/// the addresses a hostname resolved to.
///
/// The wrapped [`SsrfError`] describes the rejection. It is also exposed as
/// this error's [`std::error::Error::source`], so cause-chain walkers can reach it.
#[derive(Debug)]
pub struct DialBlocked(pub SsrfError);

impl std::fmt::Display for DialBlocked {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(cause) = self;
        write!(f, "ssrf policy blocked dial-time address resolution: {cause}")
    }
}

impl std::error::Error for DialBlocked {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let Self(cause) = self;
        Some(cause)
    }
}
/// [`reqwest::dns::Resolve`] implementation that resolves hostnames with
/// `tokio::net::lookup_host`. It then checks every resulting IP address
/// against the SSRF policy before handing the addresses back for dialing.
///
/// The policy level comes from the [`SSRF_LEVEL`] task-local:
/// - level in scope: all resolved IPs are passed to `validate_addresses`. If
///   that fails, resolution fails with a boxed [`DialBlocked`] error.
/// - no level in scope: the resolved addresses are returned unvalidated.
///
/// NOTE(review): validation only happens when a hostname is resolved for a new
/// connection. A pooled connection opened under a permissive level could
/// presumably be reused by a request running under a stricter one. Confirm
/// that the client's pooling setup accounts for this.
#[derive(Default)]
pub struct SsrfValidatingResolver;

impl Resolve for SsrfValidatingResolver {
    fn resolve(&self, name: Name) -> Resolving {
        // Capture the policy level synchronously, in the task that requested the
        // resolution, before the returned future is ever polled. Reading it only
        // after the DNS `.await` is fail-open: if the in-flight future finishes
        // on another task, `try_with` returns `Err` there and validation is
        // skipped. hyper's pool, for example, spawns a still-connecting future
        // in the background when a pooled connection wins the checkout race.
        let captured_level = SSRF_LEVEL.try_with(|l| *l).ok();
        // `lookup_host` needs `host:port` syntax; port 0 is only a placeholder.
        // NOTE(review): presumably the connector substitutes the real port.
        let target = format!("{}:0", name.as_str());
        Box::pin(async move {
            let resolved: Vec<SocketAddr> = lookup_host(target.as_str())
                .await
                .map_err(Box::<dyn std::error::Error + Send + Sync>::from)?
                .collect();
            // Fall back to reading the task-local now, so a future that was
            // constructed outside a scope but polled inside one is still validated.
            let level = captured_level.or_else(|| SSRF_LEVEL.try_with(|l| *l).ok());
            if let Some(level) = level {
                let ips: Vec<IpAddr> = resolved.iter().map(SocketAddr::ip).collect();
                if let Err(e) = validate_addresses(&ips, level) {
                    return Err(
                        Box::new(DialBlocked(e)) as Box<dyn std::error::Error + Send + Sync>
                    );
                }
            }
            let addrs: Addrs = Box::new(resolved.into_iter());
            Ok(addrs)
        })
    }
}
pub fn shared_resolver() -> Arc<SsrfValidatingResolver> {
Arc::new(SsrfValidatingResolver)
}
/// Searches `err` and its chain of [`std::error::Error::source`] causes for a
/// [`DialBlocked`]. Returns the first one found, or `None` if the chain has none.
///
/// `err` itself is checked first, so passing a `DialBlocked` directly matches.
pub fn dial_blocked_cause<'a>(
    err: &'a (dyn std::error::Error + 'static),
) -> Option<&'a DialBlocked> {
    std::iter::successors(Some(err), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<DialBlocked>())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    // The reqwest `Name` for `localhost`. The tests below assume it resolves
    // to a loopback address on the test host.
    fn localhost() -> Name {
        "localhost".parse().unwrap()
    }

    #[tokio::test]
    async fn resolver_passes_through_when_no_context_set() {
        // No SSRF_LEVEL scope is active here.
        let outcome = SsrfValidatingResolver.resolve(localhost()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn resolver_blocks_loopback_under_strict() {
        let resolver = SsrfValidatingResolver;
        let host = localhost();
        let outcome = SSRF_LEVEL
            .scope(SsrfLevel::Strict, async { resolver.resolve(host).await })
            .await;
        match outcome {
            Ok(_) => panic!("strict must reject loopback"),
            Err(err) => assert!(
                dial_blocked_cause(&*err).is_some(),
                "expected DialBlocked in source chain, got: {err}",
            ),
        }
    }

    #[tokio::test]
    async fn resolver_allows_loopback_under_loopback_level() {
        let resolver = SsrfValidatingResolver;
        let host = localhost();
        let outcome = SSRF_LEVEL
            .scope(SsrfLevel::Loopback, async { resolver.resolve(host).await })
            .await;
        assert!(outcome.is_ok(), "loopback level must accept localhost");
    }

    #[test]
    fn dial_blocked_walks_source_chain() {
        // Wrapper whose only job is to expose its payload via `source()`, so
        // the DialBlocked sits one level down the cause chain.
        #[derive(Debug)]
        struct Wrap(DialBlocked);

        impl std::fmt::Display for Wrap {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("wrap")
            }
        }

        impl std::error::Error for Wrap {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                Some(&self.0)
            }
        }

        let rejection = SsrfError::Address {
            address: IpAddr::V4(Ipv4Addr::LOCALHOST),
            level: SsrfLevel::Strict,
            reason: "loopback IPv4",
        };
        let outer = Wrap(DialBlocked(rejection));
        assert!(dial_blocked_cause(&outer).is_some());
    }
}