1use std::io::BufReader;
8use std::sync::{Arc, OnceLock};
9
10use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
11use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
12use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
13
14use crate::errors::MarketDataError;
15
/// TLS settings used by [`build_rustls_config`] to construct a rustls client configuration.
///
/// `PartialEq`/`Eq` are derived so callers can cheaply tell whether a configuration
/// changed (e.g. before rebuilding a client).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TlsConfig {
    /// Optional PEM bundle of additional trust anchors.
    ///
    /// Every `CERTIFICATE` section is trusted in addition to the platform's native
    /// roots. Input that contains no certificate sections contributes nothing and is
    /// not treated as an error.
    pub root_cert_pem: Option<Vec<u8>>,

    /// Disable server certificate verification entirely.
    ///
    /// When `true`, any server certificate and any handshake signature is accepted,
    /// and `root_cert_pem` is ignored. This removes all authentication of the server,
    /// so the connection can be intercepted by anyone on the network path.
    pub accept_invalid_certs: bool,
}
31
32static PROVIDER_INSTALLED: OnceLock<()> = OnceLock::new();
33static SYSTEM_ROOTS: OnceLock<Arc<RootCertStore>> = OnceLock::new();
34
35fn install_crypto_provider() {
36 PROVIDER_INSTALLED.get_or_init(|| {
37 let _ = rustls::crypto::ring::default_provider().install_default();
40 });
41}
42
43fn system_root_store() -> &'static Arc<RootCertStore> {
44 SYSTEM_ROOTS.get_or_init(|| {
45 let mut store = RootCertStore::empty();
46 let loaded = rustls_native_certs::load_native_certs();
50 for cert in loaded.certs {
51 let _ = store.add(cert);
52 }
53 Arc::new(store)
54 })
55}
56
57pub fn build_rustls_config(tls: &TlsConfig) -> Result<Arc<ClientConfig>, MarketDataError> {
67 install_crypto_provider();
68
69 if tls.accept_invalid_certs {
70 let config = ClientConfig::builder()
73 .dangerous()
74 .with_custom_certificate_verifier(Arc::new(AlwaysTrustVerifier))
75 .with_no_client_auth();
76 return Ok(Arc::new(config));
77 }
78
79 let mut store = (**system_root_store()).clone();
81
82 if let Some(pem) = &tls.root_cert_pem {
83 let mut reader = BufReader::new(pem.as_slice());
84 for cert_result in rustls_pemfile::certs(&mut reader) {
85 let cert = cert_result.map_err(|e| {
86 MarketDataError::ConfigError(format!("invalid TLS root cert PEM: {e}"))
87 })?;
88 store.add(cert).map_err(|e| {
89 MarketDataError::ConfigError(format!("failed to add root cert: {e}"))
90 })?;
91 }
92 }
93
94 let config = ClientConfig::builder()
95 .with_root_certificates(store)
96 .with_no_client_auth();
97 Ok(Arc::new(config))
98}
99
100#[derive(Debug)]
103struct AlwaysTrustVerifier;
104
105impl ServerCertVerifier for AlwaysTrustVerifier {
106 fn verify_server_cert(
107 &self,
108 _end_entity: &CertificateDer<'_>,
109 _intermediates: &[CertificateDer<'_>],
110 _server_name: &ServerName<'_>,
111 _ocsp_response: &[u8],
112 _now: UnixTime,
113 ) -> Result<ServerCertVerified, rustls::Error> {
114 Ok(ServerCertVerified::assertion())
115 }
116
117 fn verify_tls12_signature(
118 &self,
119 _message: &[u8],
120 _cert: &CertificateDer<'_>,
121 _dss: &DigitallySignedStruct,
122 ) -> Result<HandshakeSignatureValid, rustls::Error> {
123 Ok(HandshakeSignatureValid::assertion())
124 }
125
126 fn verify_tls13_signature(
127 &self,
128 _message: &[u8],
129 _cert: &CertificateDer<'_>,
130 _dss: &DigitallySignedStruct,
131 ) -> Result<HandshakeSignatureValid, rustls::Error> {
132 Ok(HandshakeSignatureValid::assertion())
133 }
134
135 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
136 vec![
137 SignatureScheme::RSA_PKCS1_SHA256,
138 SignatureScheme::RSA_PKCS1_SHA384,
139 SignatureScheme::RSA_PKCS1_SHA512,
140 SignatureScheme::ECDSA_NISTP256_SHA256,
141 SignatureScheme::ECDSA_NISTP384_SHA384,
142 SignatureScheme::RSA_PSS_SHA256,
143 SignatureScheme::RSA_PSS_SHA384,
144 SignatureScheme::RSA_PSS_SHA512,
145 SignatureScheme::ED25519,
146 ]
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config (native roots, verification on) must always build.
    #[test]
    fn default_config_builds_rustls_config() {
        let cfg = TlsConfig::default();
        let _ = build_rustls_config(&cfg).expect("default should always build");
    }

    /// The insecure verifier path must build as well.
    #[test]
    fn accept_invalid_certs_builds_rustls_config() {
        let cfg = TlsConfig {
            accept_invalid_certs: true,
            ..Default::default()
        };
        let _ = build_rustls_config(&cfg).expect("should build");
    }

    /// Bytes containing no PEM sections yield zero certificates, which is not an error.
    #[test]
    fn non_pem_input_is_not_an_error() {
        let cfg = TlsConfig {
            root_cert_pem: Some(b"not a real pem".to_vec()),
            ..Default::default()
        };
        let result = build_rustls_config(&cfg);
        assert!(
            result.is_ok(),
            "garbage non-PEM should parse to zero certs, not error"
        );
    }

    /// A CERTIFICATE section whose body is not valid base64 must surface as a
    /// `ConfigError` instead of being silently skipped.
    #[test]
    fn malformed_pem_section_is_config_error() {
        let cfg = TlsConfig {
            root_cert_pem: Some(
                b"-----BEGIN CERTIFICATE-----\n%%%%\n-----END CERTIFICATE-----\n".to_vec(),
            ),
            ..Default::default()
        };
        let result = build_rustls_config(&cfg);
        assert!(
            matches!(result, Err(MarketDataError::ConfigError(_))),
            "malformed PEM section should be rejected as ConfigError"
        );
    }
}