use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use rustls_pki_types::CertificateDer;
use rustls_pki_types::ServerName;
use rustls_pki_types::pem::PemObject;
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
/// Errors that can occur while building a TLS client connector
/// (see `connector_from_ca`).
#[derive(Debug, thiserror::Error)]
pub enum TlsClientError {
/// Underlying I/O failure; `#[from]` lets `?` convert a `std::io::Error` directly.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// No certificates could be loaded from the CA file.
#[error("no certificates found in CA file")]
NoCertificates,
/// rustls rejected part of the configuration, e.g. a certificate passed to
/// `RootCertStore::add`.
#[error("TLS configuration error: {0}")]
Rustls(#[from] tokio_rustls::rustls::Error),
}
/// Errors from turning a `host:port` / `[host]:port` string into a connectable
/// socket address plus TLS server name (see `parse_remote_target`).
#[derive(Debug, thiserror::Error)]
pub enum RemoteTargetError {
/// No `:port` could be found: there is no `:` separator, a `[` has no matching
/// `]`, or the `]` is not immediately followed by `:`. Holds the full input.
#[error("missing ':port' in remote target: {0:?}")]
MissingPort(String),
/// The text after the separator is not a valid `u16` port.
#[error("invalid port in remote target {raw:?}: {source}")]
InvalidPort {
/// The full original input string.
raw: String,
#[source]
source: std::num::ParseIntError,
},
/// System name resolution failed for a host that is not an IP literal.
#[error("DNS lookup failed for {host:?}: {source}")]
DnsLookupFailed {
host: String,
#[source]
source: std::io::Error,
},
/// Name resolution succeeded but produced no addresses.
#[error("DNS lookup for {host:?} returned no addresses")]
DnsNoAddresses { host: String },
/// The host is not acceptable as a TLS server name (SNI); `message` is the
/// underlying conversion error rendered as text.
#[error("invalid SNI hostname {host:?}: {message}")]
InvalidSni { host: String, message: String },
}
/// Split a `host:port` or `[host]:port` string into its host part and numeric port.
///
/// Bracketed form: the host is everything between the leading `[` and the
/// first `]`, and that `]` must be followed immediately by `:port`. This is
/// how IPv6 literals carry a port.
///
/// Plain form: the string is split at its *last* `:`. An unbracketed IPv6
/// literal is therefore cut at its final colon (`"::1:4242"` yields host
/// `"::1"` and port `4242`), so callers should bracket IPv6 hosts.
///
/// The returned host borrows from `input` and may be empty (e.g. `":80"`).
///
/// # Errors
/// - [`RemoteTargetError::MissingPort`] if there is no `:` separator, the
///   closing `]` is missing, or the `]` is not followed by `:`.
/// - [`RemoteTargetError::InvalidPort`] if the port text is not a valid `u16`.
fn split_host_port(input: &str) -> Result<(&str, u16), RemoteTargetError> {
    // Copyable error factory shared by every "no port" failure path.
    let missing_port = || RemoteTargetError::MissingPort(input.to_string());

    let (host, port_text) = match input.strip_prefix('[') {
        Some(bracketed) => {
            let (host, after_bracket) = bracketed.split_once(']').ok_or_else(missing_port)?;
            let port_text = after_bracket.strip_prefix(':').ok_or_else(missing_port)?;
            (host, port_text)
        }
        None => input.rsplit_once(':').ok_or_else(missing_port)?,
    };

    match port_text.parse::<u16>() {
        Ok(port) => Ok((host, port)),
        Err(source) => Err(RemoteTargetError::InvalidPort {
            raw: input.to_string(),
            source,
        }),
    }
}
/// Parse a `host:port` / `[ipv6]:port` remote target into the socket address
/// to connect to and the TLS server name to present (SNI) and verify against.
///
/// - IP literals are used directly, with no DNS, and yield
///   `ServerName::IpAddress`.
/// - Any other host is first validated as a DNS server name and only then
///   resolved. The first address returned by the resolver is used.
///
/// Resolution goes through `ToSocketAddrs`, which may block the current
/// thread. Inside an async runtime, call this from a blocking context.
///
/// # Errors
/// - `MissingPort` / `InvalidPort` if the `host:port` split fails.
/// - `InvalidSni` if the host is neither an IP literal nor a valid server name.
///   This is checked before any DNS lookup is attempted.
/// - `DnsLookupFailed` / `DnsNoAddresses` if resolution fails or is empty.
pub fn parse_remote_target(
    input: &str,
) -> Result<(SocketAddr, ServerName<'static>), RemoteTargetError> {
    let (host, port) = split_host_port(input)?;

    // IP literal: parse it once and reuse it for both the socket address and
    // the SNI value; no DNS and no hostname validation needed.
    if let Ok(ip) = IpAddr::from_str(host) {
        return Ok((SocketAddr::new(ip, port), ServerName::IpAddress(ip.into())));
    }

    // Validate the hostname before resolving it, so malformed names (including
    // an empty host) are rejected without a potentially slow blocking lookup.
    let server_name =
        ServerName::try_from(host.to_string()).map_err(|e| RemoteTargetError::InvalidSni {
            host: host.to_string(),
            message: e.to_string(),
        })?;

    let socket_addr = (host, port)
        .to_socket_addrs()
        .map_err(|source| RemoteTargetError::DnsLookupFailed {
            host: host.to_string(),
            source,
        })?
        .next()
        .ok_or_else(|| RemoteTargetError::DnsNoAddresses {
            host: host.to_string(),
        })?;

    Ok((socket_addr, server_name))
}
pub fn connector_from_ca(ca_path: &Path) -> Result<TlsConnector, TlsClientError> {
let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
.map_err(|_| TlsClientError::NoCertificates)?
.collect::<Result<Vec<_>, _>>()
.map_err(|_| TlsClientError::NoCertificates)?;
if certs.is_empty() {
return Err(TlsClientError::NoCertificates);
}
let mut root_store = RootCertStore::empty();
for cert in certs {
root_store.add(cert).map_err(TlsClientError::Rustls)?;
}
let config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
/// Build a [`TlsConnector`] that performs **no** verification of the server
/// certificate or of handshake signatures.
///
/// Every peer is accepted, which leaves the connection open to interception.
/// The installed verifier logs a warning each time it is asked to check a
/// server certificate. Intended only for local development.
pub fn connector_insecure() -> TlsConnector {
    let accept_everything = Arc::new(NoVerifier);
    let insecure_config = ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(accept_everything)
        .with_no_client_auth();
    TlsConnector::from(Arc::new(insecure_config))
}
/// Server certificate verifier that accepts every certificate chain and every
/// handshake signature without inspecting them.
///
/// Installed by [`connector_insecure`]. It provides no protection against
/// man-in-the-middle attacks.
#[derive(Debug)]
struct NoVerifier;

impl tokio_rustls::rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Accepts any certificate unconditionally, emitting a `warn` event under
    /// the `muxtop::insecure` tracing target on every call.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        peer: &tokio_rustls::rustls::pki_types::ServerName<'_>,
        _ocsp: &[u8],
        _at: tokio_rustls::rustls::pki_types::UnixTime,
    ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
    {
        use tokio_rustls::rustls::client::danger::ServerCertVerified;

        tracing::warn!(
            target: "muxtop::insecure",
            server_name = ?peer,
            "TLS certificate verification disabled — only safe in local dev"
        );
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _msg: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
    ) -> Result<
        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
        tokio_rustls::rustls::Error,
    > {
        use tokio_rustls::rustls::client::danger::HandshakeSignatureValid;

        Ok(HandshakeSignatureValid::assertion())
    }

    /// Same policy as TLS 1.2: accept the signature without checking it.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &tokio_rustls::rustls::DigitallySignedStruct,
    ) -> Result<
        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
        tokio_rustls::rustls::Error,
    > {
        // Bring the trait's methods into scope so the TLS 1.2 impl can be reused.
        use tokio_rustls::rustls::client::danger::ServerCertVerifier as _;

        self.verify_tls12_signature(message, cert, dss)
    }

    /// Advertises the signature schemes supported by the aws-lc-rs provider's
    /// verification algorithms.
    fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
        let provider = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider();
        provider.signature_verification_algorithms.supported_schemes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Generate a throwaway self-signed certificate for `localhost`, returning
    /// `(certificate_pem, private_key_pem)`.
    fn make_self_signed_cert() -> (String, String) {
        let san = vec!["localhost".to_string()];
        let ck = rcgen::generate_simple_self_signed(san).unwrap();
        (ck.cert.pem(), ck.signing_key.serialize_pem())
    }

    /// Write `contents` to a fresh temp file. The file is deleted when the
    /// returned handle is dropped, so keep it alive for the whole test.
    fn temp_file_with(contents: &str) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(contents.as_bytes()).unwrap();
        f
    }

    #[test]
    fn test_connector_from_ca_valid() {
        let (cert_pem, _) = make_self_signed_cert();
        let f = temp_file_with(&cert_pem);
        let connector = connector_from_ca(f.path());
        assert!(connector.is_ok());
    }

    #[test]
    fn test_connector_from_ca_missing_file() {
        let result = connector_from_ca(Path::new("/nonexistent/ca.pem"));
        assert!(result.is_err());
    }

    #[test]
    fn test_connector_from_ca_empty_file() {
        let f = temp_file_with("");
        let result = connector_from_ca(f.path());
        assert!(matches!(result, Err(TlsClientError::NoCertificates)));
    }

    #[test]
    fn test_connector_from_ca_key_only_file() {
        // A PEM file containing only a private key has no CERTIFICATE sections.
        let (_, key_pem) = make_self_signed_cert();
        let f = temp_file_with(&key_pem);
        let result = connector_from_ca(f.path());
        assert!(matches!(result, Err(TlsClientError::NoCertificates)));
    }

    #[test]
    fn test_connector_insecure_builds() {
        let _connector = connector_insecure();
    }

    #[test]
    fn test_parse_remote_target_ipv4_literal() {
        let (addr, sni) = parse_remote_target("127.0.0.1:4242").unwrap();
        assert_eq!(addr, "127.0.0.1:4242".parse::<SocketAddr>().unwrap());
        match sni {
            ServerName::IpAddress(_) => {}
            other => panic!("expected ServerName::IpAddress, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_remote_target_ipv6_literal() {
        let (addr, sni) = parse_remote_target("[::1]:4242").unwrap();
        assert_eq!(addr, "[::1]:4242".parse::<SocketAddr>().unwrap());
        match sni {
            ServerName::IpAddress(_) => {}
            other => panic!("expected ServerName::IpAddress, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_remote_target_hostname_uses_dns_sni() {
        // Depends on the host resolving `localhost`; skip rather than fail if
        // the environment has no resolver.
        let result = parse_remote_target("localhost:4242");
        let (addr, sni) = match result {
            Ok(v) => v,
            Err(e) => {
                eprintln!("skipping hostname SNI test: {e}");
                return;
            }
        };
        assert_eq!(addr.port(), 4242);
        match sni {
            ServerName::DnsName(name) => {
                assert_eq!(name.as_ref(), "localhost");
            }
            other => panic!("expected ServerName::DnsName, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_remote_target_missing_port() {
        let err = parse_remote_target("127.0.0.1").unwrap_err();
        assert!(matches!(err, RemoteTargetError::MissingPort(_)));
    }

    #[test]
    fn test_parse_remote_target_invalid_port() {
        let err = parse_remote_target("127.0.0.1:notaport").unwrap_err();
        assert!(matches!(err, RemoteTargetError::InvalidPort { .. }));
    }

    #[test]
    fn test_parse_remote_target_port_out_of_range() {
        let err = parse_remote_target("127.0.0.1:70000").unwrap_err();
        assert!(matches!(err, RemoteTargetError::InvalidPort { .. }));
    }

    #[test]
    fn test_split_host_port_ipv6_no_port() {
        let err = split_host_port("[::1]").unwrap_err();
        assert!(matches!(err, RemoteTargetError::MissingPort(_)));
    }

    #[test]
    fn test_split_host_port_bracketed_strips_brackets() {
        assert_eq!(split_host_port("[::1]:443").unwrap(), ("::1", 443));
    }

    #[test]
    fn test_split_host_port_unterminated_bracket() {
        let err = split_host_port("[::1").unwrap_err();
        assert!(matches!(err, RemoteTargetError::MissingPort(_)));
    }

    #[test]
    fn test_split_host_port_junk_after_bracket() {
        let err = split_host_port("[::1]x:80").unwrap_err();
        assert!(matches!(err, RemoteTargetError::MissingPort(_)));
    }
}