use super::network_bind::IpFamily;
use crate::model::TlsSummary;
use anyhow::{anyhow, Context, Result};
use rustls::pki_types::ServerName;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::{lookup_host, TcpSocket};
use tokio::time::timeout;
use tokio_rustls::TlsConnector;
/// Upper bound on each individual TCP connect attempt (one per candidate address).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Upper bound on the TLS handshake once the TCP connection is established.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Installs the `ring` crypto provider as the process-wide rustls default, if none is set yet.
fn ensure_crypto_provider() {
    let provider = rustls::crypto::ring::default_provider();
    // An Err only means a default provider was already installed; keep that one.
    let _ = provider.install_default();
}
/// Connects to `hostname:port` over TCP and measures how long the TLS handshake takes.
///
/// Trust anchors are the bundled `webpki_roots` set, whatever the OS trust store yields
/// (best effort), and, if `cert_path` is given, the CA certificates loaded from that file.
///
/// `hostname` may be a DNS name or an IP literal. IPv6 literals are accepted with or
/// without surrounding brackets (e.g. `[::1]`, the form URL host strings use).
///
/// Only the TLS handshake is timed; DNS resolution and the TCP connect are excluded.
///
/// # Errors
/// Fails if the custom CA file cannot be loaded or one of its certificates cannot be
/// added, if `hostname` is not a valid server name, if no TCP connection can be
/// established, or if the handshake fails or exceeds `HANDSHAKE_TIMEOUT`.
pub async fn measure_tls_handshake(
    hostname: &str,
    port: u16,
    cert_path: Option<&std::path::Path>,
    bind_ip: Option<IpAddr>,
    family: Option<IpFamily>,
) -> Result<TlsSummary> {
    ensure_crypto_provider();

    // URL host strings wrap IPv6 literals in brackets; rustls' ServerName parser and the
    // resolver both expect the bare address.
    let host = hostname
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(hostname);

    // Validate the verification/SNI name before doing any network I/O.
    let server_name: ServerName<'static> = host
        .to_string()
        .try_into()
        .map_err(|_| anyhow!("Invalid DNS name: {}", hostname))?;

    let mut root_store = rustls::RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    // Best effort: errors from the OS trust store and individual unparsable certificates
    // are ignored, since the bundled webpki roots already provide a baseline.
    for cert in rustls_native_certs::load_native_certs().certs {
        let _ = root_store.add(cert);
    }
    // A user-supplied CA, by contrast, must load cleanly.
    if let Some(path) = cert_path {
        for cert in super::cert::load_rustls_certificates(path)? {
            root_store
                .add(cert)
                .with_context(|| format!("failed to add custom CA from {}", path.display()))?;
        }
    }
    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    let tcp_stream = connect_tcp(host, port, bind_ip, family).await?;

    let start = Instant::now();
    let tls_stream = timeout(HANDSHAKE_TIMEOUT, connector.connect(server_name, tcp_stream))
        .await
        .with_context(|| {
            format!(
                "TLS handshake with {} timed out after {:?}",
                hostname, HANDSHAKE_TIMEOUT
            )
        })?
        .with_context(|| format!("TLS handshake failed with {}", hostname))?;
    let handshake_time = start.elapsed();

    let (_, session) = tls_stream.get_ref();
    let protocol_version = session.protocol_version().map(|v| format!("{:?}", v));
    let cipher_suite = session
        .negotiated_cipher_suite()
        .map(|cs| format!("{:?}", cs.suite()));
    Ok(TlsSummary {
        handshake_time_ms: handshake_time.as_secs_f64() * 1000.0,
        protocol_version,
        cipher_suite,
    })
}
/// Resolves `hostname:port` and returns the first TCP connection that succeeds.
///
/// Resolved addresses are optionally narrowed to `family`. When `bind_ip` is given, each
/// socket is bound to it (ephemeral port) before connecting, and only addresses of the
/// same IP family as `bind_ip` are tried. Each connect attempt is capped at
/// `CONNECT_TIMEOUT`. IPv6 literals may be passed with or without brackets.
///
/// # Errors
/// Fails if DNS resolution fails or yields nothing usable after filtering, or if every
/// candidate address fails. In the latter case the last attempt's error is returned.
async fn connect_tcp(
    hostname: &str,
    port: u16,
    bind_ip: Option<IpAddr>,
    family: Option<IpFamily>,
) -> Result<tokio::net::TcpStream> {
    // Accept bracketed IPv6 literals ("[::1]") as produced by URL host strings.
    let host = hostname
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(hostname);
    // Resolve from a (host, port) pair instead of a formatted "host:port" string, which is
    // ambiguous for IPv6 literals.
    let resolved: Vec<SocketAddr> = lookup_host((host, port))
        .await
        .with_context(|| format!("DNS lookup failed for {}", hostname))?
        .collect();
    if resolved.is_empty() {
        return Err(anyhow!("DNS returned no addresses for {}", hostname));
    }
    let mut candidates: Vec<SocketAddr> = match family {
        Some(f) => resolved.into_iter().filter(|a| f.matches(a.ip())).collect(),
        None => resolved,
    };
    if candidates.is_empty() {
        return Err(anyhow!(
            "no {} address resolved for {}",
            family.map(|f| f.label()).unwrap_or("usable"),
            hostname
        ));
    }
    if let Some(ip) = bind_ip {
        // Each socket is created in the target's family and can only bind a local address
        // of that same family, so cross-family targets would always fail at bind().
        candidates.retain(|a| a.is_ipv4() == ip.is_ipv4());
        if candidates.is_empty() {
            return Err(anyhow!(
                "no address resolved for {} matches the IP family of bind address {}",
                hostname,
                ip
            ));
        }
    }

    let mut last_err: Option<anyhow::Error> = None;
    for addr in candidates {
        let created = if addr.is_ipv4() {
            TcpSocket::new_v4()
        } else {
            TcpSocket::new_v6()
        };
        let socket = match created {
            Ok(s) => s,
            Err(e) => {
                last_err = Some(anyhow!(e).context("failed to create socket"));
                continue;
            }
        };
        if let Some(ip) = bind_ip {
            if let Err(e) = socket.bind(SocketAddr::new(ip, 0)) {
                last_err = Some(anyhow!(e).context(format!("failed to bind to {}", ip)));
                continue;
            }
        }
        match timeout(CONNECT_TIMEOUT, socket.connect(addr)).await {
            Ok(Ok(stream)) => return Ok(stream),
            Ok(Err(e)) => {
                last_err = Some(anyhow!(e).context(format!("connect to {} failed", addr)));
            }
            Err(_) => {
                last_err = Some(anyhow!(
                    "connect to {} timed out after {:?}",
                    addr,
                    CONNECT_TIMEOUT
                ));
            }
        }
    }
    Err(last_err.unwrap_or_else(|| anyhow!("no addresses to try for {}", hostname)))
}
/// Splits an absolute URL into its host and port.
///
/// The port is the explicit one if present. Otherwise it is the scheme's well-known
/// default (e.g. 80 for `http`, 443 for `https`), falling back to 443 for schemes that
/// have none. Returns `None` if `url` does not parse or has no host.
///
/// Note: per the `url` crate's `host_str` docs, IPv6 literal hosts keep their brackets.
pub fn extract_host_port(url: &str) -> Option<(String, u16)> {
    let parsed = reqwest::Url::parse(url).ok()?;
    let host = parsed.host_str()?.to_owned();
    let port = match parsed.port_or_known_default() {
        Some(p) => p,
        None => 443,
    };
    Some((host, port))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_host_port() {
        // (input URL, expected host, expected port)
        let cases = [
            ("https://speed.cloudflare.com", "speed.cloudflare.com", 443),
            ("https://example.com:8443/path", "example.com", 8443),
            ("http://example.com", "example.com", 80),
        ];
        for (url, host, port) in cases {
            assert_eq!(
                extract_host_port(url),
                Some((host.to_string(), port)),
                "unexpected result for {}",
                url
            );
        }
    }
}