use std::fs::File;
use std::io::{self, BufReader, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::Arc;
use std::time::Duration;
use super::{TlsConfig, TlsStream};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, StreamOwned};
use tracing::debug;
/// TLS client stream backed by rustls.
///
/// Owns both the rustls [`ClientConnection`] and the [`TcpStream`] that
/// transports its records.
pub struct RustlsStream {
    /// rustls wrapper that owns the TLS session and its TCP transport.
    inner: StreamOwned<ClientConnection, TcpStream>,
}

impl RustlsStream {
    /// Wraps an existing rustls stream without touching its state.
    fn new(stream: StreamOwned<ClientConnection, TcpStream>) -> Self {
        Self { inner: stream }
    }

    /// Borrows the TCP socket underneath the TLS session, for socket-level
    /// operations such as timeouts and shutdown.
    fn get_ref(&self) -> &TcpStream {
        &self.inner.sock
    }
}
impl TlsStream for RustlsStream {
    /// Always `true`: this stream type only exists on top of a TLS session.
    fn is_secured(&self) -> bool {
        true
    }

    /// Sets the read timeout on the underlying TCP socket.
    ///
    /// # Errors
    /// Returns whatever error the socket reports for the timeout change.
    fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        self.get_ref().set_read_timeout(dur)
    }

    /// Sets the write timeout on the underlying TCP socket.
    ///
    /// # Errors
    /// Returns whatever error the socket reports for the timeout change.
    fn set_write_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        self.get_ref().set_write_timeout(dur)
    }

    /// Closes the TLS session and then shuts down both halves of the socket.
    ///
    /// A `close_notify` alert is queued and flushed on a best-effort basis
    /// before the socket is closed, so the peer can tell an orderly close
    /// apart from a truncated connection. Failures while sending the alert
    /// are ignored; the transport is shut down regardless.
    ///
    /// # Errors
    /// Returns the error from the TCP shutdown, if any.
    fn shutdown(&mut self) -> io::Result<()> {
        self.inner.conn.send_close_notify();
        // Drain pending TLS records (including the alert). Stop on the first
        // error or zero-length write: the connection may already be dead and
        // the socket still has to be closed below.
        while self.inner.conn.wants_write() {
            match self.inner.conn.write_tls(&mut self.inner.sock) {
                Ok(0) | Err(_) => break,
                Ok(_) => {}
            }
        }
        self.get_ref().shutdown(Shutdown::Both)
    }
}
impl Read for RustlsStream {
    /// Forwards the read to the wrapped rustls stream and returns its result
    /// unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
impl Write for RustlsStream {
    /// Forwards the write to the wrapped rustls stream and returns its result
    /// unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Forwards the flush to the wrapped rustls stream.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
#[derive(Debug)]
struct InsecureVerifier;
impl ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Creates TLS client streams that all share one immutable rustls
/// [`ClientConfig`].
pub struct RustlsConnector {
    /// Shared client configuration; cloned by reference count per connection.
    config: Arc<ClientConfig>,
}

impl RustlsConnector {
    /// Builds a connector from `tls_config`.
    ///
    /// The crypto provider is chosen at compile time: `ring` when the
    /// `security-ring` feature is enabled, `aws-lc-rs` otherwise.
    ///
    /// * `verify_hostname == true`: server certificates are verified against
    ///   the trust store produced by `load_root_store`.
    /// * `verify_hostname == false`: `InsecureVerifier` is installed.
    ///   NOTE(review): despite the flag's name this disables *all* server
    ///   certificate verification, not only the hostname check.
    ///
    /// A client certificate is configured in either mode when both
    /// `client_cert_path` and `client_key_path` are set. If only one of the
    /// two is set, no client certificate is configured.
    ///
    /// # Errors
    /// Returns an `io::Error` if the protocol versions cannot be configured
    /// or if any configured certificate/key file cannot be opened, parsed or
    /// installed.
    pub fn new(tls_config: &TlsConfig) -> io::Result<Self> {
        let provider = {
            #[cfg(feature = "security-ring")]
            {
                rustls::crypto::ring::default_provider()
            }
            #[cfg(not(feature = "security-ring"))]
            {
                rustls::crypto::aws_lc_rs::default_provider()
            }
        };

        // Load trust anchors first (verified mode only) so file problems are
        // reported before anything else.
        let root_store = if tls_config.verify_hostname {
            Some(Self::load_root_store(tls_config)?)
        } else {
            None
        };

        let builder = ClientConfig::builder_with_provider(Arc::new(provider))
            .with_safe_default_protocol_versions()
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Failed to set protocol versions: {e}"),
                )
            })?;

        let builder = match root_store {
            Some(roots) => builder.with_root_certificates(roots),
            None => builder
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(InsecureVerifier)),
        };

        // Client authentication is independent of how the server is verified,
        // so it is applied in both modes.
        let config = match (&tls_config.client_cert_path, &tls_config.client_key_path) {
            (Some(cert_path), Some(key_path)) => {
                Self::load_client_auth(builder, cert_path, key_path)?
            }
            _ => builder.with_no_client_auth(),
        };

        Ok(RustlsConnector {
            config: Arc::new(config),
        })
    }

    /// Reads every PEM certificate from `path`.
    ///
    /// `what` names the file's role (for example `"CA cert"`) and is only
    /// used in error messages.
    ///
    /// # Errors
    /// * the original `io::ErrorKind` if the file cannot be opened;
    /// * `InvalidData` if the contents cannot be parsed or hold no
    ///   certificate at all.
    fn read_pem_certs<P: AsRef<std::path::Path>>(
        path: P,
        what: &str,
    ) -> io::Result<Vec<CertificateDer<'static>>> {
        let file = File::open(path).map_err(|e| {
            // Keep the real kind (NotFound, PermissionDenied, ...) instead of
            // relabelling every failure as NotFound.
            io::Error::new(e.kind(), format!("Failed to open {what} file: {e}"))
        })?;
        let mut reader = BufReader::new(file);
        let certs = CertificateDer::pem_reader_iter(&mut reader)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Failed to parse {what}: {e}"),
                )
            })?;
        if certs.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("No certificates found in {what} file"),
            ));
        }
        Ok(certs)
    }

    /// Builds the trust store used to verify server certificates.
    ///
    /// With `ca_cert_path` set, only the certificates from that PEM file are
    /// trusted. Otherwise the bundled `webpki_roots` anchors are used and the
    /// platform's native certificates are added on top, best-effort.
    ///
    /// # Errors
    /// Returns an `io::Error` if the configured CA file cannot be opened or
    /// parsed, contains no certificate, or holds a certificate the store
    /// rejects.
    fn load_root_store(tls_config: &TlsConfig) -> io::Result<RootCertStore> {
        let mut root_store = RootCertStore::empty();
        if let Some(ca_cert_path) = &tls_config.ca_cert_path {
            for cert in Self::read_pem_certs(ca_cert_path, "CA cert")? {
                root_store.add(cert).map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Failed to add CA cert: {e}"),
                    )
                })?;
            }
        } else {
            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            let native_certs = rustls_native_certs::load_native_certs();
            for cert in native_certs.certs {
                // Deliberately best-effort: a native certificate the store
                // rejects must not prevent use of the remaining anchors.
                let _ = root_store.add(cert);
            }
            if let Some(e) = native_certs.errors.first() {
                debug!(
                    "Failed to load some native certs (using webpki-roots as fallback): {}",
                    e
                );
            }
        }
        Ok(root_store)
    }

    /// Finishes `builder` with a client certificate chain and private key
    /// read from the PEM files at `cert_path` and `key_path`.
    ///
    /// # Errors
    /// Returns an `io::Error` if either file cannot be opened or parsed, if
    /// the certificate file holds no certificate, or if rustls rejects the
    /// certificate/key pair.
    fn load_client_auth(
        builder: rustls::ConfigBuilder<ClientConfig, rustls::client::WantsClientCert>,
        cert_path: &str,
        key_path: &str,
    ) -> io::Result<ClientConfig> {
        let certs = Self::read_pem_certs(cert_path, "client cert")?;

        let key_file = File::open(key_path).map_err(|e| {
            io::Error::new(e.kind(), format!("Failed to open client key file: {e}"))
        })?;
        let mut key_reader = BufReader::new(key_file);
        let key = PrivateKeyDer::from_pem_reader(&mut key_reader).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Failed to parse private key: {e}"),
            )
        })?;

        builder.with_client_auth_cert(certs, key).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Failed to set client auth: {e}"),
            )
        })
    }

    /// Performs a TLS handshake over `tcp_stream` and returns the secured
    /// stream.
    ///
    /// `domain` is parsed into a rustls `ServerName` for the connection.
    ///
    /// # Errors
    /// * `InvalidInput` if `domain` is not a valid server name;
    /// * an error if the rustls connection cannot be created;
    /// * an error carrying the underlying I/O error kind (for example a
    ///   timeout) if the handshake fails.
    pub fn connect(&self, domain: &str, tcp_stream: TcpStream) -> io::Result<Box<dyn TlsStream>> {
        let server_name = ServerName::try_from(domain)
            .map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid DNS name: {domain}"),
                )
            })?
            .to_owned();

        let conn = ClientConnection::new(Arc::clone(&self.config), server_name)
            .map_err(|e| io::Error::other(format!("TLS connection error: {e}")))?;

        let mut stream = StreamOwned::new(conn, tcp_stream);
        // Drive the handshake now so failures surface here rather than on the
        // caller's first read or write. The original error kind is preserved
        // so callers can still distinguish timeouts from other failures.
        stream
            .conn
            .complete_io(&mut stream.sock)
            .map_err(|e| io::Error::new(e.kind(), format!("TLS handshake failed: {e}")))?;

        Ok(Box::new(RustlsStream::new(stream)))
    }
}
impl std::fmt::Debug for RustlsConnector {
    /// Prints a fixed placeholder instead of the TLS configuration, so the
    /// configuration's contents never end up in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RustlsConnector { config: <redacted> }")
    }
}