use std::sync::{Arc, Once};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector as TokioTlsConnector;
use tokio_rustls::client::TlsStream;
use crate::config::{TlsConfig, TlsVersion};
use crate::error::TlsError;
/// One-time guard for installing the process-wide rustls crypto provider.
static CRYPTO_PROVIDER_INIT: Once = Once::new();

/// Installs rustls' `ring` crypto provider as the process default, at most once
/// per process.
///
/// `install_default` returns an error when some provider is already installed;
/// that error is deliberately ignored so a provider installed earlier stays in
/// effect.
fn ensure_crypto_provider() {
    /// Performs the actual installation attempt; failure is not an error here.
    fn install_ring_provider() {
        let provider = rustls::crypto::ring::default_provider();
        let _ = provider.install_default();
    }

    CRYPTO_PROVIDER_INIT.call_once(install_ring_provider);
}
/// A [`ServerCertVerifier`] that accepts every server certificate and every
/// handshake signature without inspecting them.
///
/// Only used when `TrustServerCertificate` is enabled; it provides no
/// protection against man-in-the-middle attacks.
#[derive(Debug)]
struct DangerousServerCertVerifier;

impl DangerousServerCertVerifier {
    /// Signature schemes advertised to the server, in advertisement order.
    const ADVERTISED_SCHEMES: [SignatureScheme; 10] = [
        SignatureScheme::RSA_PKCS1_SHA256,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ECDSA_NISTP521_SHA512,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::ED25519,
    ];
}

impl ServerCertVerifier for DangerousServerCertVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        // Whatever chain the server presents is trusted unconditionally.
        let verdict = ServerCertVerified::assertion();
        Ok(verdict)
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        // Handshake signatures are accepted without being checked.
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        // Same policy as TLS 1.2: accept unconditionally.
        self.verify_tls12_signature(message, cert, dss)
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        Self::ADVERTISED_SCHEMES.to_vec()
    }
}
pub fn default_tls_config() -> Result<ClientConfig, TlsError> {
ensure_crypto_provider();
let root_store = RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
let config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(config)
}
/// TLS connector for TDS connections, built from a [`TlsConfig`].
///
/// Keeps the configuration it was built from (exposed via `config()` and used
/// for server-name overrides) alongside the tokio-rustls connector derived
/// from it.
pub struct TlsConnector {
    /// Configuration this connector was built from.
    config: TlsConfig,
    /// tokio-rustls connector wrapping the rustls `ClientConfig` built from `config`.
    inner: TokioTlsConnector,
}
impl TlsConnector {
pub fn new(config: TlsConfig) -> Result<Self, TlsError> {
let client_config = Self::build_client_config(&config)?;
let inner = TokioTlsConnector::from(Arc::new(client_config));
Ok(Self { config, inner })
}
fn build_client_config(config: &TlsConfig) -> Result<ClientConfig, TlsError> {
ensure_crypto_provider();
let versions: Vec<&'static rustls::SupportedProtocolVersion> =
Self::select_versions(config);
if config.strict_mode && config.trust_server_certificate {
return Err(TlsError::Configuration(
"TrustServerCertificate=true is not allowed in TDS 8.0 strict mode. \
Strict mode requires server certificate validation to prevent \
man-in-the-middle attacks."
.into(),
));
}
if config.trust_server_certificate {
tracing::warn!(
"TrustServerCertificate is enabled - certificate validation is DISABLED. \
This is insecure and should only be used for development/testing. \
Connections are vulnerable to man-in-the-middle attacks."
);
let mut client_config = ClientConfig::builder_with_protocol_versions(&versions)
.dangerous()
.with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier))
.with_no_client_auth();
if !config.alpn_protocols.is_empty() {
client_config.alpn_protocols = config.alpn_protocols.clone();
}
return Ok(client_config);
}
let root_store = Self::build_root_store(config)?;
let builder = ClientConfig::builder_with_protocol_versions(&versions)
.with_root_certificates(root_store);
let mut client_config = if let Some(client_auth) = &config.client_auth {
let key = match client_auth.key.as_ref() {
rustls::pki_types::PrivateKeyDer::Pkcs1(key) => {
rustls::pki_types::PrivateKeyDer::Pkcs1(key.clone_key())
}
rustls::pki_types::PrivateKeyDer::Sec1(key) => {
rustls::pki_types::PrivateKeyDer::Sec1(key.clone_key())
}
rustls::pki_types::PrivateKeyDer::Pkcs8(key) => {
rustls::pki_types::PrivateKeyDer::Pkcs8(key.clone_key())
}
_ => {
return Err(TlsError::Configuration(
"unsupported private key format".into(),
));
}
};
builder
.with_client_auth_cert(client_auth.certificates.clone(), key)
.map_err(|e| TlsError::Configuration(format!("client auth setup failed: {e}")))?
} else {
builder.with_no_client_auth()
};
if !config.alpn_protocols.is_empty() {
client_config.alpn_protocols = config.alpn_protocols.clone();
}
Ok(client_config)
}
fn build_root_store(config: &TlsConfig) -> Result<RootCertStore, TlsError> {
let mut root_store = RootCertStore::empty();
if config.trust_server_certificate {
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
} else if config.root_certificates.is_empty() {
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
} else {
for cert in &config.root_certificates {
root_store
.add(cert.clone())
.map_err(|e| TlsError::InvalidCertificate(e.to_string()))?;
}
}
Ok(root_store)
}
fn select_versions(config: &TlsConfig) -> Vec<&'static rustls::SupportedProtocolVersion> {
let mut versions = Vec::new();
if config.min_protocol_version <= TlsVersion::Tls12
&& config.max_protocol_version >= TlsVersion::Tls12
{
versions.push(&rustls::version::TLS12);
}
if config.min_protocol_version <= TlsVersion::Tls13
&& config.max_protocol_version >= TlsVersion::Tls13
{
versions.push(&rustls::version::TLS13);
}
if versions.is_empty() {
versions.push(&rustls::version::TLS12);
}
versions
}
pub async fn connect<S>(&self, stream: S, server_name: &str) -> Result<TlsStream<S>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let server_name = self.config.server_name.as_deref().unwrap_or(server_name);
let dns_name = ServerName::try_from(server_name.to_string()).map_err(|_| {
TlsError::HostnameVerification {
expected: server_name.to_string(),
actual: "invalid DNS name".to_string(),
}
})?;
tracing::debug!(server_name = %server_name, "performing TLS handshake");
let tls_stream = self
.inner
.connect(dns_name, stream)
.await
.map_err(|e| TlsError::HandshakeFailed(e.to_string()))?;
tracing::debug!("TLS handshake completed successfully");
Ok(tls_stream)
}
pub async fn connect_with_prelogin<S>(
&self,
stream: S,
server_name: &str,
) -> Result<TlsStream<crate::TlsPreloginWrapper<S>>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let server_name = self.config.server_name.as_deref().unwrap_or(server_name);
let dns_name = ServerName::try_from(server_name.to_string()).map_err(|_| {
TlsError::HostnameVerification {
expected: server_name.to_string(),
actual: "invalid DNS name".to_string(),
}
})?;
tracing::debug!(server_name = %server_name, "performing TLS handshake (PreLogin wrapped)");
let wrapper = crate::TlsPreloginWrapper::new(stream);
let mut tls_stream = self
.inner
.connect(dns_name, wrapper)
.await
.map_err(|e| TlsError::HandshakeFailed(e.to_string()))?;
tls_stream.get_mut().0.handshake_complete();
tracing::debug!("TLS handshake completed successfully (PreLogin wrapped)");
Ok(tls_stream)
}
#[must_use]
pub fn is_strict_mode(&self) -> bool {
self.config.strict_mode
}
#[must_use]
pub fn config(&self) -> &TlsConfig {
&self.config
}
}
impl std::fmt::Debug for TlsConnector {
    /// Shows only the configuration. The inner connector is omitted, and
    /// `finish_non_exhaustive` marks that omission with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TlsConnector");
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
// Tests may unwrap: a panic is the desired failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Installs the `ring` provider; the error is ignored so repeated calls
    /// across tests are harmless.
    fn setup_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    #[test]
    fn test_default_config() {
        setup_crypto_provider();
        let config = TlsConfig::default();
        let connector = TlsConnector::new(config);
        assert!(connector.is_ok());
    }

    #[test]
    fn test_trust_server_certificate() {
        setup_crypto_provider();
        let config = TlsConfig::new().trust_server_certificate(true);
        let connector = TlsConnector::new(config).unwrap();
        assert!(!connector.is_strict_mode());
    }

    #[test]
    fn test_strict_mode() {
        setup_crypto_provider();
        let config = TlsConfig::new().strict_mode(true);
        let connector = TlsConnector::new(config).unwrap();
        assert!(connector.is_strict_mode());
    }

    /// Strict mode must refuse to disable server certificate validation.
    #[test]
    fn test_strict_mode_rejects_trust_server_certificate() {
        setup_crypto_provider();
        let config = TlsConfig::new()
            .strict_mode(true)
            .trust_server_certificate(true);
        let result = TlsConnector::new(config);
        assert!(matches!(result, Err(TlsError::Configuration(_))));
    }

    #[test]
    fn test_select_versions_full_range() {
        let mut config = TlsConfig::default();
        config.min_protocol_version = TlsVersion::Tls12;
        config.max_protocol_version = TlsVersion::Tls13;
        let got: Vec<_> = TlsConnector::select_versions(&config)
            .iter()
            .map(|v| v.version)
            .collect();
        assert_eq!(
            got,
            vec![
                rustls::ProtocolVersion::TLSv1_2,
                rustls::ProtocolVersion::TLSv1_3
            ]
        );
    }

    /// An empty range (min > max) falls back to TLS 1.2 rather than producing
    /// an empty version list.
    #[test]
    fn test_select_versions_empty_range_falls_back_to_tls12() {
        let mut config = TlsConfig::default();
        config.min_protocol_version = TlsVersion::Tls13;
        config.max_protocol_version = TlsVersion::Tls12;
        let got: Vec<_> = TlsConnector::select_versions(&config)
            .iter()
            .map(|v| v.version)
            .collect();
        assert_eq!(got, vec![rustls::ProtocolVersion::TLSv1_2]);
    }
}