use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use reqwest::dns::{Addrs, Resolve, Resolving};
use reqwest::redirect::Policy;
use reqwest::{ClientBuilder, Url};
use crate::error::Error;
/// Decides whether outbound requests may target non-public IP addresses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpPolicy {
    /// Refuse loopback, private, link-local and similar destinations.
    Strict,
    /// Skip the address checks entirely.
    AllowPrivate,
}

impl Default for IpPolicy {
    /// The default policy is whatever the process environment selects
    /// (see [`IpPolicy::from_env`]).
    fn default() -> Self {
        IpPolicy::from_env()
    }
}

impl IpPolicy {
    /// Reads `HEARTBIT_ALLOW_PRIVATE_IPS` from the environment and maps it to
    /// a policy. An unset or unreadable variable yields [`IpPolicy::Strict`].
    pub fn from_env() -> Self {
        let raw = std::env::var("HEARTBIT_ALLOW_PRIVATE_IPS").ok();
        Self::from_env_value(raw.as_deref())
    }

    /// Maps the raw value of the override variable to a policy.
    ///
    /// Only `1` and `true` (ASCII case-insensitive, surrounding whitespace
    /// ignored) opt in to [`IpPolicy::AllowPrivate`]; everything else,
    /// including `None` and the empty string, stays [`IpPolicy::Strict`].
    pub(crate) fn from_env_value(value: Option<&str>) -> Self {
        let opted_in = value
            .map(str::trim)
            .is_some_and(|flag| flag == "1" || flag.eq_ignore_ascii_case("true"));
        if opted_in {
            Self::AllowPrivate
        } else {
            Self::Strict
        }
    }
}
#[derive(Debug, Clone)]
pub struct SafeUrl(Url);
impl SafeUrl {
pub async fn parse(s: &str, policy: IpPolicy) -> Result<Self, Error> {
let url = Url::parse(s).map_err(|e| Error::Agent(format!("invalid URL: {e}")))?;
let scheme = url.scheme();
if scheme != "http" && scheme != "https" {
return Err(Error::Agent(format!(
"URL scheme {scheme:?} not allowed; only http and https"
)));
}
if matches!(policy, IpPolicy::AllowPrivate) {
return Ok(Self(url));
}
let host = url
.host_str()
.ok_or_else(|| Error::Agent("URL has no host".into()))?;
let port = url.port_or_known_default().unwrap_or(80);
let bare_host = host
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
.unwrap_or(host);
if let Ok(ip) = IpAddr::from_str(bare_host) {
if is_blocked(&ip) {
return Err(reject(host));
}
return Ok(Self(url));
}
let addrs = tokio::net::lookup_host((bare_host, port))
.await
.map_err(|e| Error::Agent(format!("DNS lookup failed for {host}: {e}")))?;
let mut any = false;
for sa in addrs {
any = true;
if is_blocked(&sa.ip()) {
return Err(reject(host));
}
}
if !any {
return Err(Error::Agent(format!(
"DNS lookup for {host} returned no addresses"
)));
}
Ok(Self(url))
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_url(self) -> Url {
self.0
}
}
fn reject(host: &str) -> Error {
Error::Agent(format!(
"URL host {host} resolves to a private/loopback address; \
refused (set HEARTBIT_ALLOW_PRIVATE_IPS=1 to override)"
))
}
/// Returns `true` when `ip` must be refused, delegating to the
/// family-specific check.
fn is_blocked(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(ref addr) => is_blocked_v4(addr),
        IpAddr::V6(ref addr) => is_blocked_v6(addr),
    }
}
pub fn validate_url_sync(s: &str, policy: IpPolicy) -> Result<(), Error> {
let url = Url::parse(s).map_err(|e| Error::Agent(format!("invalid URL: {e}")))?;
let scheme = url.scheme();
if scheme != "http" && scheme != "https" {
return Err(Error::Agent(format!(
"URL scheme {scheme:?} not allowed; only http and https"
)));
}
if matches!(policy, IpPolicy::AllowPrivate) {
return Ok(());
}
let host = url
.host_str()
.ok_or_else(|| Error::Agent("URL has no host".into()))?;
let bare_host = host
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
.unwrap_or(host);
if let Ok(ip) = IpAddr::from_str(bare_host)
&& is_blocked(&ip)
{
return Err(reject(host));
}
Ok(())
}
/// Returns `true` when an IPv4 address must be refused.
///
/// Covers loopback, link-local, RFC 1918 private, multicast, unspecified,
/// broadcast and CGNAT space, plus the whole `0.0.0.0/8` "this network"
/// block and the reserved `240.0.0.0/4` block (neither is a legitimate
/// public destination).
fn is_blocked_v4(ip: &Ipv4Addr) -> bool {
    let first_octet = ip.octets()[0];
    ip.is_loopback()
        || ip.is_link_local()
        || ip.is_private()
        || ip.is_multicast()
        || ip.is_unspecified()
        || ip.is_broadcast()
        || is_cgnat_v4(ip)
        // 0.0.0.0/8: only 0.0.0.0 itself is caught by `is_unspecified`.
        || first_octet == 0
        // 240.0.0.0/4: only 255.255.255.255 is caught by `is_broadcast`.
        || first_octet >= 240
}

/// Returns `true` when an IPv6 address must be refused.
///
/// Besides the native IPv6 ranges (loopback, multicast, unspecified,
/// link-local, unique-local), any IPv4 address carried inside a transition
/// format is extracted and run through [`is_blocked_v4`], so that e.g.
/// `64:ff9b::10.0.0.1` cannot be used to reach `10.0.0.1` behind a NAT64
/// gateway.
fn is_blocked_v6(ip: &Ipv6Addr) -> bool {
    embedded_v4(ip).is_some_and(|v4| is_blocked_v4(&v4))
        || ip.is_loopback()
        || ip.is_multicast()
        || ip.is_unspecified()
        || is_link_local_v6(ip)
        || is_ula_v6(ip)
}

/// Extracts the IPv4 address embedded in an IPv6 transition address, if any.
///
/// Recognised layouts:
/// - IPv4-mapped `::ffff:a.b.c.d` and (deprecated) IPv4-compatible `::a.b.c.d`
/// - NAT64 well-known prefix `64:ff9b::/96` (RFC 6052)
/// - 6to4 `2002:AABB:CCDD::/48`, which carries `AA.BB.CC.DD` (RFC 3056)
fn embedded_v4(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    let join = |hi: u16, lo: u16| {
        let [a, b] = hi.to_be_bytes();
        let [c, d] = lo.to_be_bytes();
        Ipv4Addr::new(a, b, c, d)
    };
    match ip.segments() {
        [0, 0, 0, 0, 0, 0 | 0xffff, hi, lo] => Some(join(hi, lo)),
        [0x64, 0xff9b, 0, 0, 0, 0, hi, lo] => Some(join(hi, lo)),
        [0x2002, hi, lo, ..] => Some(join(hi, lo)),
        _ => None,
    }
}

/// Returns `true` for carrier-grade NAT space, `100.64.0.0/10` (RFC 6598).
fn is_cgnat_v4(ip: &Ipv4Addr) -> bool {
    let [a, b, _, _] = ip.octets();
    a == 100 && (64..=127).contains(&b)
}

/// Returns `true` for IPv6 link-local unicast, `fe80::/10`.
fn is_link_local_v6(ip: &Ipv6Addr) -> bool {
    let s = ip.segments()[0];
    (s & 0xffc0) == 0xfe80
}

/// Returns `true` for IPv6 unique-local addresses, `fc00::/7`.
fn is_ula_v6(ip: &Ipv6Addr) -> bool {
    let s = ip.segments()[0];
    (s & 0xfe00) == 0xfc00
}
/// Default response-body cap: 5 MiB (5 * 1024 * 1024 bytes).
///
/// NOTE(review): presumably meant to be passed as `max_bytes` to the capped
/// body readers for vendor responses — confirm against callers.
pub const DEFAULT_VENDOR_BODY_CAP: usize = 5 * 1024 * 1024;
pub async fn read_body_capped(
response: reqwest::Response,
max_bytes: usize,
) -> Result<(Vec<u8>, bool), Error> {
use futures::TryStreamExt;
let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
let mut truncated = false;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.try_next().await.map_err(Error::Http)? {
let remaining = max_bytes.saturating_sub(buf.len());
if remaining == 0 {
truncated = true;
break;
}
let take = chunk.len().min(remaining);
buf.extend_from_slice(&chunk[..take]);
if take < chunk.len() {
truncated = true;
break;
}
}
Ok((buf, truncated))
}
pub async fn read_text_capped(
response: reqwest::Response,
max_bytes: usize,
) -> Result<String, Error> {
let (bytes, truncated) = read_body_capped(response, max_bytes).await?;
let mut text = String::from_utf8_lossy(&bytes).into_owned();
if truncated {
text.push_str("…[truncated]");
}
Ok(text)
}
/// DNS resolver for reqwest clients that applies an [`IpPolicy`] to the
/// addresses it hands back.
pub struct SafeDnsResolver {
    /// Policy captured at construction time and applied to every lookup.
    policy: IpPolicy,
}

impl SafeDnsResolver {
    /// Creates a resolver that enforces `policy`.
    pub fn new(policy: IpPolicy) -> Self {
        SafeDnsResolver { policy }
    }
}
impl Resolve for SafeDnsResolver {
    /// Resolves `name` with `tokio::net::lookup_host` and, under
    /// [`IpPolicy::Strict`], drops every blocked address from the result.
    ///
    /// The returned future fails when the lookup itself fails, when it
    /// yields no addresses, or when filtering leaves none.
    fn resolve(&self, name: reqwest::dns::Name) -> Resolving {
        let policy = self.policy;
        let host = name.as_str().to_owned();
        Box::pin(async move {
            // Port 0 is a placeholder; only the IPs of the result are inspected here.
            let mut candidates: Vec<SocketAddr> =
                tokio::net::lookup_host((host.as_str(), 0)).await?.collect();
            if candidates.is_empty() {
                return Err::<Addrs, _>(
                    format!("DNS lookup for {host} returned no addresses").into(),
                );
            }

            match policy {
                IpPolicy::AllowPrivate => {}
                IpPolicy::Strict => candidates.retain(|sa| !is_blocked(&sa.ip())),
            }
            if candidates.is_empty() {
                return Err::<Addrs, _>(
                    format!(
                        "host {host} resolves to private/loopback addresses; \
                         refused at connect time (set HEARTBIT_ALLOW_PRIVATE_IPS=1 to override)"
                    )
                    .into(),
                );
            }

            let allowed: Addrs = Box::new(candidates.into_iter());
            Ok(allowed)
        }) as Pin<Box<_>>
    }
}
pub fn safe_client_builder() -> ClientBuilder {
reqwest::Client::builder()
.redirect(Policy::none())
.no_proxy()
.connect_timeout(std::time::Duration::from_secs(5))
.dns_resolver(Arc::new(SafeDnsResolver::new(IpPolicy::default())))
}
pub fn vendor_client_builder() -> ClientBuilder {
reqwest::Client::builder()
.redirect(Policy::none())
.no_proxy()
.connect_timeout(std::time::Duration::from_secs(5))
.dns_resolver(Arc::new(SafeDnsResolver::new(IpPolicy::default())))
}
// Unit tests for the policy parsing, URL vetting, client builders and the
// capped body reader. Several tests open a local TCP listener or rely on the
// system resolver.
#[cfg(test)]
mod tests {
use super::*;
// --- IpPolicy::from_env_value -------------------------------------------
#[test]
fn ip_policy_unset_is_strict() {
assert_eq!(IpPolicy::from_env_value(None), IpPolicy::Strict);
}
#[test]
fn ip_policy_one_is_allow() {
assert_eq!(IpPolicy::from_env_value(Some("1")), IpPolicy::AllowPrivate);
}
// Case and surrounding whitespace must not matter for "true".
#[test]
fn ip_policy_true_case_insensitive_is_allow() {
assert_eq!(
IpPolicy::from_env_value(Some("TRUE")),
IpPolicy::AllowPrivate
);
assert_eq!(
IpPolicy::from_env_value(Some("True")),
IpPolicy::AllowPrivate
);
assert_eq!(
IpPolicy::from_env_value(Some(" true ")),
IpPolicy::AllowPrivate
);
}
#[test]
fn ip_policy_zero_is_strict() {
assert_eq!(IpPolicy::from_env_value(Some("0")), IpPolicy::Strict);
assert_eq!(IpPolicy::from_env_value(Some("false")), IpPolicy::Strict);
}
// Anything that is not exactly "1"/"true" falls back to Strict.
#[test]
fn ip_policy_garbage_is_strict() {
assert_eq!(IpPolicy::from_env_value(Some("yesplz")), IpPolicy::Strict);
assert_eq!(IpPolicy::from_env_value(Some("")), IpPolicy::Strict);
}
// --- SafeUrl::parse: syntax and scheme ----------------------------------
#[tokio::test]
async fn safe_url_rejects_non_http_scheme() {
let err = SafeUrl::parse("file:///etc/passwd", IpPolicy::Strict)
.await
.unwrap_err();
let msg = err.to_string();
assert!(msg.contains("scheme") && msg.contains("file"), "got: {msg}");
}
#[tokio::test]
async fn safe_url_rejects_invalid_url() {
let err = SafeUrl::parse("not a url", IpPolicy::Strict)
.await
.unwrap_err();
assert!(err.to_string().contains("invalid URL"));
}
// --- SafeUrl::parse: IP-literal hosts under Strict ----------------------
#[tokio::test]
async fn safe_url_rejects_loopback_v4() {
assert!(
SafeUrl::parse("http://127.0.0.1/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_loopback_v6() {
assert!(
SafeUrl::parse("http://[::1]/", IpPolicy::Strict)
.await
.is_err()
);
}
// 169.254.169.254 is in the IPv4 link-local range.
#[tokio::test]
async fn safe_url_rejects_link_local_v4() {
assert!(
SafeUrl::parse("http://169.254.169.254/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_link_local_v6() {
assert!(
SafeUrl::parse("http://[fe80::1]/", IpPolicy::Strict)
.await
.is_err()
);
}
// One representative address from each RFC 1918 block.
#[tokio::test]
async fn safe_url_rejects_rfc1918() {
for h in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
let r = SafeUrl::parse(&format!("http://{h}/"), IpPolicy::Strict).await;
assert!(r.is_err(), "{h} should be rejected");
}
}
// Both ends of the 100.64.0.0/10 range.
#[tokio::test]
async fn safe_url_rejects_cgnat() {
assert!(
SafeUrl::parse("http://100.64.0.1/", IpPolicy::Strict)
.await
.is_err()
);
assert!(
SafeUrl::parse("http://100.127.255.1/", IpPolicy::Strict)
.await
.is_err()
);
}
// Both halves of fc00::/7.
#[tokio::test]
async fn safe_url_rejects_ula() {
assert!(
SafeUrl::parse("http://[fc00::1]/", IpPolicy::Strict)
.await
.is_err()
);
assert!(
SafeUrl::parse("http://[fd00::1]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_multicast() {
assert!(
SafeUrl::parse("http://224.0.0.1/", IpPolicy::Strict)
.await
.is_err()
);
assert!(
SafeUrl::parse("http://[ff00::1]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_unspecified() {
assert!(
SafeUrl::parse("http://0.0.0.0/", IpPolicy::Strict)
.await
.is_err()
);
assert!(
SafeUrl::parse("http://[::]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_broadcast() {
assert!(
SafeUrl::parse("http://255.255.255.255/", IpPolicy::Strict)
.await
.is_err()
);
}
// A public literal passes and round-trips unchanged.
#[tokio::test]
async fn safe_url_accepts_public_ip() {
let safe = SafeUrl::parse("http://8.8.8.8/", IpPolicy::Strict)
.await
.unwrap();
assert_eq!(safe.as_str(), "http://8.8.8.8/");
}
// --- SafeUrl::parse: IPv4-mapped IPv6 literals --------------------------
// The embedded IPv4 address decides the outcome.
#[tokio::test]
async fn safe_url_rejects_ipv4_mapped_loopback() {
assert!(
SafeUrl::parse("http://[::ffff:127.0.0.1]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_ipv4_mapped_imds() {
assert!(
SafeUrl::parse("http://[::ffff:169.254.169.254]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_rejects_ipv4_mapped_rfc1918() {
assert!(
SafeUrl::parse("http://[::ffff:10.0.0.1]/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_accepts_ipv4_mapped_public() {
let safe = SafeUrl::parse("http://[::ffff:8.8.8.8]/", IpPolicy::Strict)
.await
.unwrap();
assert!(safe.as_str().starts_with("http://[::ffff:"));
}
// --- SafeUrl::parse: host names and the AllowPrivate override -----------
// NOTE(review): depends on the system resolver mapping `localhost` to a
// loopback address.
#[tokio::test]
async fn safe_url_rejects_localhost_dns() {
assert!(
SafeUrl::parse("http://localhost/", IpPolicy::Strict)
.await
.is_err()
);
}
#[tokio::test]
async fn safe_url_allow_private_accepts_loopback() {
let safe = SafeUrl::parse("http://127.0.0.1/", IpPolicy::AllowPrivate)
.await
.unwrap();
assert_eq!(safe.as_str(), "http://127.0.0.1/");
}
// AllowPrivate returns before any DNS lookup, so this does not resolve.
#[tokio::test]
async fn safe_url_allow_private_accepts_localhost() {
let safe = SafeUrl::parse("http://localhost/", IpPolicy::AllowPrivate)
.await
.unwrap();
assert_eq!(safe.as_str(), "http://localhost/");
}
#[tokio::test]
async fn safe_url_rejection_message_mentions_override() {
let err = SafeUrl::parse("http://127.0.0.1/", IpPolicy::Strict)
.await
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("HEARTBIT_ALLOW_PRIVATE_IPS"),
"rejection message should mention the override env var; got: {msg}"
);
}
// --- Client builders ----------------------------------------------------
// A one-shot local server answers with a 302; the client must surface it
// rather than follow the Location header.
// NOTE(review): the request targets a 127.0.0.1 literal with the default
// policy; presumably the custom DNS resolver is not consulted for IP-literal
// hosts, otherwise this would fail under Strict — confirm.
#[tokio::test]
async fn safe_client_builder_does_not_follow_redirects() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
if let Ok((mut sock, _)) = listener.accept().await {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut buf = [0u8; 1024];
let _ = sock.read(&mut buf).await;
let resp = b"HTTP/1.1 302 Found\r\nLocation: /landed\r\nContent-Length: 0\r\n\r\n";
let _ = sock.write_all(resp).await;
let _ = sock.shutdown().await;
}
});
let client = safe_client_builder().build().unwrap();
let resp = client
.get(format!("http://{addr}/start"))
.send()
.await
.unwrap();
assert_eq!(resp.status().as_u16(), 302, "redirect must NOT be followed");
}
// Smoke test: the builder configuration is accepted by reqwest.
#[test]
fn vendor_client_builder_compiles_and_builds() {
let _ = vendor_client_builder().build().unwrap();
}
// --- read_body_capped ---------------------------------------------------
// The local server advertises and sends 10 MiB (160 chunks of 64 KiB); the
// reader is capped at 1 MiB.
#[tokio::test]
async fn read_body_capped_truncates_at_limit() {
use std::convert::Infallible;
use tokio::io::AsyncWriteExt;
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
if let Ok((mut sock, _)) = listener.accept().await {
let mut tmp = [0u8; 1024];
let _ = tokio::io::AsyncReadExt::read(&mut sock, &mut tmp).await;
let _ = sock
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 10485760\r\n\r\n")
.await;
let chunk = vec![b'A'; 64 * 1024];
for _ in 0..160 {
// The client hangs up once it hits the cap; stop writing then.
if sock.write_all(&chunk).await.is_err() {
break;
}
}
Ok::<_, Infallible>(())
} else {
Ok(())
}
});
// A plain client is used here; the capped reader is what is under test.
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();
let resp = client.get(format!("http://{addr}/")).send().await.unwrap();
let (bytes, truncated) = read_body_capped(resp, 1024 * 1024).await.unwrap();
assert!(truncated, "must report truncation");
// NOTE(review): read_body_capped never stores more than `max_bytes`, so
// this bound is looser than the function's actual guarantee.
assert!(
bytes.len() <= 1024 * 1024 + 64 * 1024,
"must not exceed cap by more than one chunk; got {}",
bytes.len()
);
}
}