use std::net::ToSocketAddrs;
use url::Url;
/// Extension trait deciding whether a connection target must be protected with TLS.
pub trait TlsExt {
    /// Returns `true` if connecting to `self` requires TLS.
    ///
    /// A target is exempt only when it resolves to at least one address and every resolved
    /// address is local: an IPv4 private, link-local or loopback address, or the IPv6
    /// loopback address. Resolution errors and empty address lists require TLS.
    fn requires_tls(&self) -> bool;
}

impl<T> TlsExt for T
where
    T: ToSocketAddrs,
{
    fn requires_tls(&self) -> bool {
        match self.to_socket_addrs() {
            // Fail closed: an address we cannot resolve cannot be proven local.
            Err(_) => true,
            Ok(addrs) => {
                let mut addrs = addrs.peekable();
                // `Iterator::all` is vacuously `true` for an empty iterator, which would
                // exempt a target that resolved to nothing; require at least one address.
                let is_local = addrs.peek().is_some()
                    && addrs.all(|addr| match addr.ip() {
                        std::net::IpAddr::V4(ip) => {
                            ip.is_private() || ip.is_link_local() || ip.is_loopback()
                        }
                        // NOTE(review): unlike IPv4, IPv6 unique-local (fc00::/7) and
                        // link-local (fe80::/10) addresses still require TLS here.
                        std::net::IpAddr::V6(ip) => ip.is_loopback(),
                    });
                !is_local
            }
        }
    }
}
/// Returns `true` when `url` may be fetched without further transport checks.
///
/// `https` URLs are always accepted without any name resolution. Any other URL is accepted
/// only if the addresses its host resolves to do not require TLS (see
/// [`TlsExt::requires_tls`]). If the host cannot be resolved, the URL is rejected.
/// Port 80 is used as the fallback when neither the URL nor its scheme names a port; the
/// port plays no part in the decision.
pub(crate) fn is_safe_url(url: &Url) -> bool {
    if url.scheme() == "https" {
        return true;
    }
    url.socket_addrs(|| Some(80))
        .map(|resolved| !resolved.as_slice().requires_tls())
        .unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `addr` (an IP literal or host name) paired with an arbitrary port.
    fn test_addr(addr: &str) -> bool {
        (addr, 1234).requires_tls()
    }

    // NOTE: host names go through the system resolver. A failed lookup also counts as
    // requiring TLS, so the "global" host-name cases pass offline, but "localhost" must
    // resolve for the "local" cases to pass.

    #[test]
    fn test_v4() {
        let local = [
            "localhost",
            "127.0.0.1",
            "10.1.2.3",
            "172.16.0.1",
            "192.168.1.1",
            "169.254.1.2",
        ];
        for host in local {
            assert!(!test_addr(host), "{host} should not require TLS");
        }
        let global = ["www.google.com", "8.8.8.8"];
        for host in global {
            assert!(test_addr(host), "{host} should require TLS");
        }
    }

    #[test]
    fn test_v6() {
        let local = ["::1"];
        for host in local {
            assert!(!test_addr(host), "{host} should not require TLS");
        }
        let global = ["2003:d4:773d:7600:904e:2a90:16bb:268d"];
        for host in global {
            assert!(test_addr(host), "{host} should require TLS");
        }
    }

    #[test]
    fn test_urls() {
        let safe_urls = ["http://localhost/", "https://localhost/", "https://www.fastly.com/"];
        for url in safe_urls {
            assert!(
                is_safe_url(&url.parse().expect("cannot parse URL")),
                "{url:?} should be considered safe"
            );
        }
        let unsafe_urls = ["http://neverssl.com/", "http://8.8.8.8"];
        for url in unsafe_urls {
            assert!(
                !is_safe_url(&url.parse().expect("cannot parse URL")),
                "{url:?} must not be considered safe"
            );
        }
    }
}