use std::io::BufReader;
use std::sync::{Arc, OnceLock};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
use crate::errors::MarketDataError;
/// TLS settings consumed by [`build_rustls_config`] to build a rustls client configuration.
///
/// `Default` yields no extra roots and full certificate verification.
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// Optional PEM bytes with extra root certificate(s) to trust in addition to the
    /// platform's native roots. Ignored when `accept_invalid_certs` is `true`.
    pub root_cert_pem: Option<Vec<u8>>,
    /// When `true`, server certificates and handshake signatures are not verified at all.
    /// Insecure. NOTE(review): presumably meant only for dev/self-signed endpoints — confirm.
    pub accept_invalid_certs: bool,
}
/// Ensures the ring provider install below is attempted at most once per process.
static PROVIDER_INSTALLED: OnceLock<()> = OnceLock::new();
/// Lazily filled cache of the platform's native root certificates.
static SYSTEM_ROOTS: OnceLock<Arc<RootCertStore>> = OnceLock::new();

/// Makes sure rustls has a process-wide default crypto provider before a config is built.
///
/// Attempts to install ring's provider exactly once. If a default provider was already
/// installed elsewhere in the process, `install_default` returns an error; that is
/// deliberately ignored, because whichever provider is installed is acceptable for
/// `ClientConfig::builder()`.
fn install_crypto_provider() {
    fn try_install_ring() {
        let ring_provider = rustls::crypto::ring::default_provider();
        // Err only means a default provider is already present — nothing to do.
        let _ = ring_provider.install_default();
    }

    PROVIDER_INSTALLED.get_or_init(try_install_ring);
}
/// Returns the platform's native root certificates, loaded once and cached for the process.
///
/// Best-effort: errors reported while reading the native store are not inspected, and
/// certificates that cannot be parsed as trust anchors are skipped. The cached store can
/// therefore be empty if the platform provides no usable roots.
fn system_root_store() -> &'static Arc<RootCertStore> {
    SYSTEM_ROOTS.get_or_init(|| {
        let native = rustls_native_certs::load_native_certs();
        let mut roots = RootCertStore::empty();
        // Counts of accepted/skipped certificates are not needed; skipping bad ones is intended.
        let (_accepted, _skipped) = roots.add_parsable_certificates(native.certs);
        Arc::new(roots)
    })
}
pub fn build_rustls_config(tls: &TlsConfig) -> Result<Arc<ClientConfig>, MarketDataError> {
install_crypto_provider();
if tls.accept_invalid_certs {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(AlwaysTrustVerifier))
.with_no_client_auth();
return Ok(Arc::new(config));
}
let mut store = (**system_root_store()).clone();
if let Some(pem) = &tls.root_cert_pem {
let mut reader = BufReader::new(pem.as_slice());
for cert_result in rustls_pemfile::certs(&mut reader) {
let cert = cert_result.map_err(|e| {
MarketDataError::ConfigError(format!("invalid TLS root cert PEM: {e}"))
})?;
store.add(cert).map_err(|e| {
MarketDataError::ConfigError(format!("failed to add root cert: {e}"))
})?;
}
}
let config = ClientConfig::builder()
.with_root_certificates(store)
.with_no_client_auth();
Ok(Arc::new(config))
}
/// Server certificate verifier that accepts every certificate and every handshake signature.
///
/// Installed by [`build_rustls_config`] only when `TlsConfig::accept_invalid_certs` is set.
/// It provides **no** server authentication: anyone on the network path can impersonate
/// the server.
#[derive(Debug)]
struct AlwaysTrustVerifier;

impl ServerCertVerifier for AlwaysTrustVerifier {
    /// Accepts any certificate chain for any server name, ignoring validity time and OCSP data.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server.
    ///
    /// No signature is ever checked, so the list is deliberately broad. Any scheme left out
    /// here makes the handshake fail against servers whose key supports only that scheme
    /// (e.g. P-521 or Ed448 certificates), defeating the purpose of this verifier.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        vec![
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_builds_rustls_config() {
        let cfg = TlsConfig::default();
        let _ = build_rustls_config(&cfg).expect("default should always build");
    }

    #[test]
    fn accept_invalid_certs_builds_rustls_config() {
        let cfg = TlsConfig {
            accept_invalid_certs: true,
            ..Default::default()
        };
        let _ = build_rustls_config(&cfg).expect("should build");
    }

    /// Input with no PEM sections yields zero certificates; that is accepted, not an error.
    #[test]
    fn non_pem_root_cert_input_adds_nothing_and_builds() {
        let cfg = TlsConfig {
            root_cert_pem: Some(b"not a real pem".to_vec()),
            ..Default::default()
        };
        let cfg_ok = build_rustls_config(&cfg);
        assert!(cfg_ok.is_ok(), "garbage non-PEM should parse to zero certs, not error");
    }

    /// A well-formed PEM certificate section whose body is not a certificate must be rejected.
    #[test]
    fn malformed_certificate_is_config_error() {
        let cfg = TlsConfig {
            root_cert_pem: Some(
                b"-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n".to_vec(),
            ),
            ..Default::default()
        };
        assert!(matches!(
            build_rustls_config(&cfg),
            Err(MarketDataError::ConfigError(_))
        ));
    }
}