1use crate::connection_error::ConnectionError;
11use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
12use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
13use rustls::{ClientConfig, DigitallySignedStruct, Error as RustlsError, RootCertStore, SignatureScheme};
14use std::sync::Arc;
15use tokio::net::TcpStream;
16use tokio_rustls::{TlsConnector, client::TlsStream};
17use tracing::{debug, warn};
18
/// TLS settings used when connecting to a backend.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Whether TLS is enabled. NOTE(review): not read in this file — presumably callers
    /// consult it before calling `TlsManager::handshake`; confirm.
    pub use_tls: bool,
    /// Verify the server certificate chain. When `false`, every certificate is accepted
    /// (insecure; intended for testing only).
    pub tls_verify_cert: bool,
    /// Optional path to a PEM file whose CA certificates are trusted in addition to the
    /// system roots.
    pub tls_cert_path: Option<String>,
}
29
30impl Default for TlsConfig {
31 fn default() -> Self {
32 Self {
33 use_tls: false,
34 tls_verify_cert: true,
35 tls_cert_path: None,
36 }
37 }
38}
39
40#[derive(Debug)]
42pub struct CertificateLoadResult {
43 pub root_store: RootCertStore,
44 pub sources: Vec<String>,
45}
46
/// Server certificate verifier that accepts any certificate and any handshake signature.
///
/// Installed only when `TlsConfig::tls_verify_cert` is `false`. It offers no protection
/// against impersonation or man-in-the-middle attacks.
#[derive(Debug)]
struct NoVerifier;
54
55impl ServerCertVerifier for NoVerifier {
56 fn verify_server_cert(
57 &self,
58 _end_entity: &CertificateDer<'_>,
59 _intermediates: &[CertificateDer<'_>],
60 _server_name: &ServerName<'_>,
61 _ocsp_response: &[u8],
62 _now: UnixTime,
63 ) -> Result<ServerCertVerified, RustlsError> {
64 Ok(ServerCertVerified::assertion())
66 }
67
68 fn verify_tls12_signature(
69 &self,
70 _message: &[u8],
71 _cert: &CertificateDer<'_>,
72 _dss: &DigitallySignedStruct,
73 ) -> Result<HandshakeSignatureValid, RustlsError> {
74 Ok(HandshakeSignatureValid::assertion())
76 }
77
78 fn verify_tls13_signature(
79 &self,
80 _message: &[u8],
81 _cert: &CertificateDer<'_>,
82 _dss: &DigitallySignedStruct,
83 ) -> Result<HandshakeSignatureValid, RustlsError> {
84 Ok(HandshakeSignatureValid::assertion())
86 }
87
88 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
89 vec![
91 SignatureScheme::RSA_PKCS1_SHA1,
92 SignatureScheme::ECDSA_SHA1_Legacy,
93 SignatureScheme::RSA_PKCS1_SHA256,
94 SignatureScheme::ECDSA_NISTP256_SHA256,
95 SignatureScheme::RSA_PKCS1_SHA384,
96 SignatureScheme::ECDSA_NISTP384_SHA384,
97 SignatureScheme::RSA_PKCS1_SHA512,
98 SignatureScheme::ECDSA_NISTP521_SHA512,
99 SignatureScheme::RSA_PSS_SHA256,
100 SignatureScheme::RSA_PSS_SHA384,
101 SignatureScheme::RSA_PSS_SHA512,
102 SignatureScheme::ED25519,
103 SignatureScheme::ED448,
104 ]
105 }
106}
107
108pub struct TlsManager {
110 config: TlsConfig,
111}
112
113impl TlsManager {
114 pub fn new(config: TlsConfig) -> Self {
116 Self { config }
117 }
118
119 pub async fn handshake(
121 &self,
122 stream: TcpStream,
123 hostname: &str,
124 backend_name: &str,
125 ) -> Result<TlsStream<TcpStream>, anyhow::Error> {
126 let cert_result = self.load_certificates().await?;
127 let client_config = self.create_optimized_config(cert_result.root_store)?;
128
129 debug!(
130 "TLS: Certificate sources: {}",
131 cert_result.sources.join(", ")
132 );
133
134 let connector = TlsConnector::from(Arc::new(client_config));
135 let domain = rustls_pki_types::ServerName::try_from(hostname)
136 .map_err(|e| anyhow::anyhow!("Invalid hostname for TLS: {}", e))?
137 .to_owned();
138
139 debug!("TLS: Connecting to {} with rustls", hostname);
140 connector.connect(domain, stream).await.map_err(|e| {
141 ConnectionError::TlsHandshake {
142 backend: backend_name.to_string(),
143 source: Box::new(e),
144 }
145 .into()
146 })
147 }
148
149 async fn load_certificates(&self) -> Result<CertificateLoadResult, anyhow::Error> {
151 let mut root_store = RootCertStore::empty();
152 let mut sources = Vec::new();
153
154 if let Some(cert_path) = &self.config.tls_cert_path {
156 debug!("TLS: Loading custom CA certificate from: {}", cert_path);
157 self.load_custom_certificate(&mut root_store, cert_path)?;
158 sources.push("custom certificate".to_string());
159 }
160
161 let system_count = self.load_system_certificates(&mut root_store)?;
163 if system_count > 0 {
164 debug!(
165 "TLS: Loaded {} certificates from system store",
166 system_count
167 );
168 sources.push("system certificates".to_string());
169 }
170
171 if root_store.is_empty() {
173 debug!("TLS: No system certificates available, using Mozilla CA bundle fallback");
174 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
175 sources.push("Mozilla CA bundle".to_string());
176 }
177
178 Ok(CertificateLoadResult {
179 root_store,
180 sources,
181 })
182 }
183
184 fn load_custom_certificate(
186 &self,
187 root_store: &mut RootCertStore,
188 cert_path: &str,
189 ) -> Result<(), anyhow::Error> {
190 let cert_data = std::fs::read(cert_path).map_err(|e| {
191 anyhow::anyhow!("Failed to read TLS certificate from {}: {}", cert_path, e)
192 })?;
193
194 let certs = rustls_pemfile::certs(&mut cert_data.as_slice())
195 .collect::<Result<Vec<_>, _>>()
196 .map_err(|e| anyhow::anyhow!("Failed to parse TLS certificate: {}", e))?;
197
198 for cert in certs {
199 root_store
200 .add(cert)
201 .map_err(|e| anyhow::anyhow!("Failed to add custom certificate to store: {}", e))?;
202 }
203
204 Ok(())
205 }
206
207 fn load_system_certificates(
209 &self,
210 root_store: &mut RootCertStore,
211 ) -> Result<usize, anyhow::Error> {
212 let cert_result = rustls_native_certs::load_native_certs();
213 let mut added_count = 0;
214
215 for cert in cert_result.certs {
216 if root_store.add(cert).is_ok() {
217 added_count += 1;
218 }
219 }
220
221 for error in cert_result.errors {
223 warn!("TLS: Certificate loading error: {}", error);
224 }
225
226 Ok(added_count)
227 }
228
229 fn create_optimized_config(
231 &self,
232 root_store: RootCertStore,
233 ) -> Result<ClientConfig, anyhow::Error> {
234 let mut config = if self.config.tls_verify_cert {
235 debug!("TLS: Certificate verification enabled with ring crypto provider");
236 ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
237 .with_safe_default_protocol_versions()
238 .map_err(|e| {
239 anyhow::anyhow!("Failed to create TLS config with ring provider: {}", e)
240 })?
241 .with_root_certificates(root_store)
242 .with_no_client_auth()
243 } else {
244 warn!("TLS: Certificate verification DISABLED - this is insecure and should only be used for testing!");
245 ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
247 .with_safe_default_protocol_versions()
248 .map_err(|e| {
249 anyhow::anyhow!("Failed to create TLS config with ring provider: {}", e)
250 })?
251 .dangerous()
252 .with_custom_certificate_verifier(Arc::new(NoVerifier))
253 .with_no_client_auth()
254 };
255
256 config.enable_early_data = true; config.resumption = rustls::client::Resumption::default(); Ok(config)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(!config.use_tls);
        assert!(config.tls_verify_cert);
        assert!(config.tls_cert_path.is_none());
    }

    #[test]
    fn test_tls_manager_creation() {
        let config = TlsConfig::default();
        let _manager = TlsManager::new(config);
    }

    #[tokio::test]
    async fn test_certificate_loading() {
        let config = TlsConfig::default();
        let manager = TlsManager::new(config);

        let result = manager.load_certificates().await.unwrap();
        assert!(!result.root_store.is_empty());
        assert!(!result.sources.is_empty());
        // Either the system store or the Mozilla fallback must have contributed.
        assert!(
            result
                .sources
                .iter()
                .any(|s| s.contains("Mozilla") || s.contains("system"))
        );
    }

    #[tokio::test]
    async fn test_unreadable_custom_certificate_is_an_error() {
        // A directory that does not exist guarantees the read fails on every platform.
        let missing = std::env::temp_dir()
            .join("tls-manager-test-missing-dir")
            .join("ca.pem");
        let config = TlsConfig {
            tls_cert_path: Some(missing.to_string_lossy().into_owned()),
            ..TlsConfig::default()
        };
        let manager = TlsManager::new(config);

        let err = manager.load_certificates().await.unwrap_err();
        assert!(err.to_string().contains("Failed to read TLS certificate"));
    }

    #[test]
    fn test_config_without_verification_builds() {
        let config = TlsConfig {
            tls_verify_cert: false,
            ..TlsConfig::default()
        };
        let manager = TlsManager::new(config);

        let client_config = manager
            .create_optimized_config(RootCertStore::empty())
            .unwrap();
        assert!(client_config.enable_early_data);
    }
}