use std::net::IpAddr;
use anyhow::{bail, Result, anyhow};
use url::Url;
#[cfg(feature = "net")]
use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
#[cfg(feature = "net")]
use trust_dns_resolver::TokioAsyncResolver;
#[cfg(feature = "net")]
use reqwest::{Client, redirect::Policy, header};
/// HTTP client that screens request targets against private/internal IP
/// ranges before sending, unless the host has been explicitly allowlisted.
///
/// Construct one through [`ShieldClient::builder`].
pub struct ShieldClient {
    /// Hostnames that bypass DNS resolution and the private-address check.
    allowlist: Vec<String>,
    /// Underlying HTTP client (present only when the `net` feature is enabled).
    #[cfg(feature = "net")]
    client: Client,
}
impl ShieldClient {
    /// Returns a fresh builder for configuring a [`ShieldClient`].
    #[must_use]
    pub fn builder() -> ShieldClientBuilder {
        ShieldClientBuilder::default()
    }

    /// Sends a `GET` request to `url_str` after vetting its target address.
    ///
    /// Allowlisted hosts are requested unchanged. For other domain hosts, the
    /// URL is rewritten to the first resolved IP so that no second DNS lookup
    /// can substitute a different address. The original authority is sent in
    /// the `Host` header.
    ///
    /// NOTE(review): for `https` URLs the connection targets the pinned IP, so
    /// reqwest performs SNI and certificate verification against the IP rather
    /// than the hostname. This will typically fail for non-allowlisted HTTPS
    /// hosts. Pinning via `ClientBuilder::resolve` or a custom `dns_resolver`
    /// would avoid this; confirm before relying on HTTPS.
    ///
    /// # Errors
    /// Returns an error if the URL is invalid or has no host, if resolution
    /// fails, if any candidate address is private/internal, or if sending the
    /// request fails.
    #[cfg(feature = "net")]
    pub async fn get(&self, url_str: &str) -> Result<reqwest::Response> {
        let (safe_url, host_header) = self.prepare_safe_request(url_str).await?;
        let mut request = self.client.get(safe_url);
        if let Some(host) = host_header {
            request = request.header(header::HOST, host);
        }
        Ok(request.send().await?)
    }

    /// Validates `url_str` and, for non-allowlisted domain hosts, pins it to a
    /// vetted IP address.
    ///
    /// Returns the URL to request. When the host was rewritten, it also returns
    /// the value for the `Host` header: `host`, or `host:port` when the URL
    /// uses a non-default port.
    ///
    /// # Errors
    /// Fails if the URL cannot be parsed or has no host, if DNS resolution
    /// fails or yields no addresses, if any candidate address is
    /// private/internal, or if the URL cannot carry an IP host.
    #[cfg(feature = "net")]
    async fn prepare_safe_request(&self, url_str: &str) -> Result<(Url, Option<String>)> {
        let mut url = Url::parse(url_str)?;
        let host_str = url
            .host_str()
            .ok_or_else(|| anyhow!("No host in URL"))?
            .to_string();

        // `Url` lowercases domain hosts, so compare case-insensitively in case an
        // allowlist entry was registered with upper-case letters.
        if self
            .allowlist
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(&host_str))
        {
            return Ok((url, None));
        }

        // IP-literal hosts need no DNS lookup. `host_str` would also wrap IPv6
        // literals in brackets, which the resolver cannot parse. Check them
        // directly and leave the URL untouched so reqwest builds `Host` itself.
        let literal_ip = match url.host() {
            Some(url::Host::Ipv4(v4)) => Some(IpAddr::V4(v4)),
            Some(url::Host::Ipv6(v6)) => Some(IpAddr::V6(v6)),
            _ => None,
        };
        if let Some(ip) = literal_ip {
            if self.is_private_ip(ip) {
                bail!("Access Denied: Private IP address detected ({})", ip);
            }
            return Ok((url, None));
        }

        // NOTE(review): a new resolver is constructed on every call. Caching one
        // on the client would avoid repeated setup.
        let resolver = TokioAsyncResolver::tokio(
            ResolverConfig::default(),
            ResolverOpts::default(),
        );
        let response = resolver.lookup_ip(&host_str).await?;
        let first_ip = response
            .iter()
            .next()
            .ok_or_else(|| anyhow!("Failed to resolve host"))?;
        // Reject if *any* answer is internal, not only the one we connect to.
        if let Some(ip) = response.iter().find(|&ip| self.is_private_ip(ip)) {
            bail!("Access Denied: Private IP address detected ({})", ip);
        }

        // `set_ip_host` brackets IPv6 correctly. `set_host(&ip.to_string())`
        // rejects bare IPv6 text.
        url.set_ip_host(first_ip)
            .map_err(|()| anyhow!("Cannot set IP host on URL"))?;
        // `Url::port` is `None` for the scheme's default port, so this matches
        // the authority the original URL would have sent.
        let host_header = match url.port() {
            Some(port) => format!("{}:{}", host_str, port),
            None => host_str,
        };
        Ok((url, Some(host_header)))
    }

    /// Returns `true` if `ip` is loopback, private, link-local, multicast,
    /// reserved, or otherwise non-public, and must therefore be blocked.
    ///
    /// IPv4-mapped IPv6 (`::ffff:a.b.c.d`) and NAT64 (`64:ff9b::/96`) addresses
    /// are unwrapped and judged by their embedded IPv4 address. Without this,
    /// e.g. `::ffff:127.0.0.1` would slip past the IPv4 checks.
    fn is_private_ip(&self, ip: IpAddr) -> bool {
        match ip {
            IpAddr::V4(v4) => {
                let [a, b, c, _] = v4.octets();
                v4.is_loopback()
                    || v4.is_private()
                    || v4.is_link_local()
                    || v4.is_broadcast()
                    || v4.is_documentation()
                    || v4.is_unspecified()
                    || v4.is_multicast()
                    || a == 0 // 0.0.0.0/8 "this network"
                    || (a == 100 && (b & 0xc0) == 64) // 100.64.0.0/10 shared address space
                    || (a == 192 && b == 0 && c == 0) // 192.0.0.0/24 IETF protocol assignments
                    || (a == 198 && (b & 0xfe) == 18) // 198.18.0.0/15 benchmarking
                    || a >= 240 // 240.0.0.0/4 reserved
            }
            IpAddr::V6(v6) => {
                if let Some(mapped) = v6.to_ipv4_mapped() {
                    return self.is_private_ip(IpAddr::V4(mapped));
                }
                let seg = v6.segments();
                if seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0] {
                    // NAT64 well-known prefix: the low 32 bits are an IPv4 address.
                    let embedded = (u32::from(seg[6]) << 16) | u32::from(seg[7]);
                    return self.is_private_ip(IpAddr::V4(std::net::Ipv4Addr::from(embedded)));
                }
                v6.is_loopback()
                    || v6.is_unspecified()
                    || v6.is_multicast()
                    || (seg[0] & 0xfe00) == 0xfc00 // fc00::/7 unique local
                    || (seg[0] & 0xffc0) == 0xfe80 // fe80::/10 link-local
            }
        }
    }
}
/// Builder for [`ShieldClient`] that accumulates allowlisted hosts before
/// the client is constructed.
pub struct ShieldClientBuilder {
    /// Hosts registered via `allow_endpoint`, in the order they were added.
    allowlist: Vec<String>,
}
impl Default for ShieldClientBuilder {
fn default() -> Self {
Self {
allowlist: Vec::new(),
}
}
}
impl ShieldClientBuilder {
    /// Adds `host` to the set of endpoints that bypass the private-IP check.
    ///
    /// The entry is trimmed and ASCII-lowercased because `url::Url` normalizes
    /// `http`/`https` domain hosts to lowercase. An entry like
    /// `"API.example.com"` would otherwise never match a parsed URL.
    #[must_use]
    pub fn allow_endpoint(mut self, host: &str) -> Self {
        self.allowlist.push(host.trim().to_ascii_lowercase());
        self
    }

    /// Builds the client with automatic redirects disabled.
    ///
    /// With `Policy::none()`, a redirect response is returned to the caller
    /// instead of being followed. The redirect target therefore never bypasses
    /// the address check.
    ///
    /// # Errors
    /// Returns an error if the underlying `reqwest::Client` cannot be built.
    #[cfg(feature = "net")]
    pub fn build(self) -> Result<ShieldClient> {
        let client = Client::builder().redirect(Policy::none()).build()?;
        Ok(ShieldClient {
            client,
            allowlist: self.allowlist,
        })
    }

    /// Builds the client without network support: only the allowlist is kept.
    ///
    /// # Errors
    /// Never fails; returns `Result` to match the `net`-enabled signature.
    #[cfg(not(feature = "net"))]
    pub fn build(self) -> Result<ShieldClient> {
        Ok(ShieldClient {
            allowlist: self.allowlist,
        })
    }
}
#[cfg(test)]
#[cfg(feature = "net")]
mod tests {
    use super::*;

    /// Loopback and RFC 1918 addresses are flagged; a public resolver address is not.
    #[tokio::test]
    async fn test_is_private_ip() {
        let client = ShieldClient::builder()
            .build()
            .unwrap();
        for blocked in ["127.0.0.1", "192.168.1.1", "::1"] {
            assert!(client.is_private_ip(blocked.parse().unwrap()));
        }
        assert!(!client.is_private_ip("8.8.8.8".parse().unwrap()));
    }

    /// `localhost` is rejected without an allowlist entry and accepted with one.
    /// Assumes the resolver maps `localhost` to a loopback address.
    #[tokio::test]
    async fn test_prepare_safe_request() {
        let strict = ShieldClient::builder()
            .build()
            .unwrap();
        let denied = strict.prepare_safe_request("http://localhost").await;
        assert!(denied.is_err());

        let permissive = ShieldClient::builder()
            .allow_endpoint("localhost")
            .build()
            .unwrap();
        let allowed = permissive.prepare_safe_request("http://localhost").await;
        assert!(allowed.is_ok());
    }
}