1use crate::connection_error::ConnectionError;
11use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
12use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
13use rustls::{
14 ClientConfig, DigitallySignedStruct, Error as RustlsError, RootCertStore, SignatureScheme,
15};
16use std::sync::Arc;
17use tokio::net::TcpStream;
18use tokio_rustls::TlsConnector;
19use tracing::{debug, warn};
20
21pub use tokio_rustls::client::TlsStream;
23
/// TLS settings for outgoing backend connections.
///
/// Construct with [`TlsConfig::default()`] or fluently via [`TlsConfig::builder()`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Whether TLS should be used for the connection. `TlsManager` itself does not read
    /// this flag — presumably callers check it before calling `handshake`; TODO confirm.
    pub use_tls: bool,
    /// Whether the server certificate is verified. When `false`, any server certificate
    /// is accepted, which is insecure and only suitable for testing.
    pub tls_verify_cert: bool,
    /// Optional path to a PEM file of CA certificates to trust in addition to the
    /// platform's native certificates.
    pub tls_cert_path: Option<String>,
}
34
35impl Default for TlsConfig {
36 fn default() -> Self {
37 Self {
38 use_tls: false,
39 tls_verify_cert: true,
40 tls_cert_path: None,
41 }
42 }
43}
44
45impl TlsConfig {
46 pub fn builder() -> TlsConfigBuilder {
58 TlsConfigBuilder::default()
59 }
60}
61
/// Fluent builder for [`TlsConfig`]; obtain one with [`TlsConfig::builder()`].
///
/// Every setter consumes the builder and returns the updated one, so discarding a
/// return value silently loses that setting — hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing unless `.build()` is called"]
pub struct TlsConfigBuilder {
    // Mirrors `TlsConfig::use_tls`.
    use_tls: bool,
    // Mirrors `TlsConfig::tls_verify_cert`.
    tls_verify_cert: bool,
    // Mirrors `TlsConfig::tls_cert_path`.
    tls_cert_path: Option<String>,
}
104
105impl Default for TlsConfigBuilder {
106 fn default() -> Self {
107 Self {
108 use_tls: false,
109 tls_verify_cert: true, tls_cert_path: None,
111 }
112 }
113}
114
115impl TlsConfigBuilder {
116 pub fn enabled(mut self, use_tls: bool) -> Self {
120 self.use_tls = use_tls;
121 self
122 }
123
124 pub fn verify_cert(mut self, verify: bool) -> Self {
131 self.tls_verify_cert = verify;
132 self
133 }
134
135 pub fn cert_path<S: Into<String>>(mut self, path: S) -> Self {
139 self.tls_cert_path = Some(path.into());
140 self
141 }
142
143 pub fn build(self) -> TlsConfig {
145 TlsConfig {
146 use_tls: self.use_tls,
147 tls_verify_cert: self.tls_verify_cert,
148 tls_cert_path: self.tls_cert_path,
149 }
150 }
151}
152
153mod rustls_backend {
156 use super::*;
157
158 #[derive(Debug)]
160 pub struct CertificateLoadResult {
161 pub root_store: RootCertStore,
162 pub sources: Vec<String>,
163 }
164
165 #[derive(Debug)]
171 pub struct NoVerifier;
172
173 impl ServerCertVerifier for NoVerifier {
174 fn verify_server_cert(
175 &self,
176 _end_entity: &CertificateDer<'_>,
177 _intermediates: &[CertificateDer<'_>],
178 _server_name: &ServerName<'_>,
179 _ocsp_response: &[u8],
180 _now: UnixTime,
181 ) -> Result<ServerCertVerified, RustlsError> {
182 Ok(ServerCertVerified::assertion())
184 }
185
186 fn verify_tls12_signature(
187 &self,
188 _message: &[u8],
189 _cert: &CertificateDer<'_>,
190 _dss: &DigitallySignedStruct,
191 ) -> Result<HandshakeSignatureValid, RustlsError> {
192 Ok(HandshakeSignatureValid::assertion())
194 }
195
196 fn verify_tls13_signature(
197 &self,
198 _message: &[u8],
199 _cert: &CertificateDer<'_>,
200 _dss: &DigitallySignedStruct,
201 ) -> Result<HandshakeSignatureValid, RustlsError> {
202 Ok(HandshakeSignatureValid::assertion())
204 }
205
206 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
207 vec![
209 SignatureScheme::RSA_PKCS1_SHA1,
210 SignatureScheme::ECDSA_SHA1_Legacy,
211 SignatureScheme::RSA_PKCS1_SHA256,
212 SignatureScheme::ECDSA_NISTP256_SHA256,
213 SignatureScheme::RSA_PKCS1_SHA384,
214 SignatureScheme::ECDSA_NISTP384_SHA384,
215 SignatureScheme::RSA_PKCS1_SHA512,
216 SignatureScheme::ECDSA_NISTP521_SHA512,
217 SignatureScheme::RSA_PSS_SHA256,
218 SignatureScheme::RSA_PSS_SHA384,
219 SignatureScheme::RSA_PSS_SHA512,
220 SignatureScheme::ED25519,
221 SignatureScheme::ED448,
222 ]
223 }
224 }
225}
226
227pub struct TlsManager {
233 config: TlsConfig,
234 cached_connector: Arc<TlsConnector>,
239}
240
241impl std::fmt::Debug for TlsManager {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 f.debug_struct("TlsManager")
244 .field("config", &self.config)
245 .field("cached_connector", &"<TlsConnector>")
246 .finish()
247 }
248}
249
250impl Clone for TlsManager {
251 fn clone(&self) -> Self {
252 Self {
253 config: self.config.clone(),
254 cached_connector: Arc::clone(&self.cached_connector),
255 }
256 }
257}
258
259impl TlsManager {
260 pub fn new(config: TlsConfig) -> Result<Self, anyhow::Error> {
266 let cert_result = Self::load_certificates_sync(&config)?;
268 let client_config = Self::create_optimized_config_inner(cert_result.root_store, &config)?;
269
270 debug!(
271 "TLS: Initialized with certificate sources: {}",
272 cert_result.sources.join(", ")
273 );
274
275 let cached_connector = Arc::new(TlsConnector::from(Arc::new(client_config)));
276
277 Ok(Self {
278 config,
279 cached_connector,
280 })
281 }
282
283 pub async fn handshake(
285 &self,
286 stream: TcpStream,
287 hostname: &str,
288 backend_name: &str,
289 ) -> Result<TlsStream<TcpStream>, anyhow::Error> {
290 use anyhow::Context;
291
292 debug!("TLS: Connecting to {} with cached config", hostname);
293
294 let domain = rustls_pki_types::ServerName::try_from(hostname)
295 .context("Invalid hostname for TLS")?
296 .to_owned();
297
298 self.cached_connector
299 .connect(domain, stream)
300 .await
301 .map_err(|e| {
302 ConnectionError::TlsHandshake {
303 backend: backend_name.to_string(),
304 source: Box::new(e),
305 }
306 .into()
307 })
308 }
309
310 fn load_certificates_sync(
312 config: &TlsConfig,
313 ) -> Result<rustls_backend::CertificateLoadResult, anyhow::Error> {
314 let mut root_store = RootCertStore::empty();
315 let mut sources = Vec::new();
316
317 if let Some(cert_path) = &config.tls_cert_path {
319 debug!("TLS: Loading custom CA certificate from: {}", cert_path);
320 Self::load_custom_certificate_sync(&mut root_store, cert_path)?;
321 sources.push("custom certificate".to_string());
322 }
323
324 let system_count = Self::load_system_certificates_sync(&mut root_store)?;
326 if system_count > 0 {
327 debug!(
328 "TLS: Loaded {} certificates from system store",
329 system_count
330 );
331 sources.push("system certificates".to_string());
332 }
333
334 if root_store.is_empty() {
336 debug!("TLS: No system certificates available, using Mozilla CA bundle fallback");
337 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
338 sources.push("Mozilla CA bundle".to_string());
339 }
340
341 Ok(rustls_backend::CertificateLoadResult {
342 root_store,
343 sources,
344 })
345 }
346
347 fn load_custom_certificate_sync(
349 root_store: &mut RootCertStore,
350 cert_path: &str,
351 ) -> Result<(), anyhow::Error> {
352 use anyhow::Context;
353
354 let cert_data = std::fs::read(cert_path)
355 .with_context(|| format!("Failed to read TLS certificate from {}", cert_path))?;
356
357 let certs = rustls_pemfile::certs(&mut cert_data.as_slice())
358 .collect::<Result<Vec<_>, _>>()
359 .context("Failed to parse TLS certificate")?;
360
361 for cert in certs {
362 root_store
363 .add(cert)
364 .context("Failed to add custom certificate to store")?;
365 }
366
367 Ok(())
368 }
369
370 fn load_system_certificates_sync(
372 root_store: &mut RootCertStore,
373 ) -> Result<usize, anyhow::Error> {
374 let cert_result = rustls_native_certs::load_native_certs();
375 let mut added_count = 0;
376
377 for cert in cert_result.certs {
378 if root_store.add(cert).is_ok() {
379 added_count += 1;
380 }
381 }
382
383 for error in cert_result.errors {
385 warn!("TLS: Certificate loading error: {}", error);
386 }
387
388 Ok(added_count)
389 }
390
391 fn create_optimized_config_inner(
393 root_store: RootCertStore,
394 config: &TlsConfig,
395 ) -> Result<ClientConfig, anyhow::Error> {
396 use anyhow::Context;
397 use rustls_backend::NoVerifier;
398
399 let mut client_config = if config.tls_verify_cert {
400 debug!("TLS: Certificate verification enabled with ring crypto provider");
401 ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
402 .with_safe_default_protocol_versions()
403 .context("Failed to create TLS config with ring provider")?
404 .with_root_certificates(root_store)
405 .with_no_client_auth()
406 } else {
407 warn!(
408 "TLS: Certificate verification DISABLED - this is insecure and should only be used for testing!"
409 );
410 ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
412 .with_safe_default_protocol_versions()
413 .context("Failed to create TLS config with ring provider")?
414 .dangerous()
415 .with_custom_certificate_verifier(Arc::new(NoVerifier))
416 .with_no_client_auth()
417 };
418
419 client_config.enable_early_data = true; client_config.resumption = rustls::client::Resumption::default(); Ok(client_config)
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(!config.use_tls);
        assert!(config.tls_verify_cert);
        assert!(config.tls_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_default() {
        let config = TlsConfig::builder().build();
        assert!(!config.use_tls);
        assert!(config.tls_verify_cert);
        assert!(config.tls_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_enabled() {
        let config = TlsConfig::builder().enabled(true).verify_cert(true).build();
        assert!(config.use_tls);
        assert!(config.tls_verify_cert);
        assert!(config.tls_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_cert_path() {
        let config = TlsConfig::builder()
            .enabled(true)
            .verify_cert(true)
            .cert_path("/path/to/cert.pem")
            .build();
        assert!(config.use_tls);
        assert!(config.tls_verify_cert);
        assert_eq!(config.tls_cert_path, Some("/path/to/cert.pem".to_string()));
    }

    #[test]
    fn test_tls_config_builder_insecure() {
        let config = TlsConfig::builder()
            .enabled(true)
            .verify_cert(false)
            .build();
        assert!(config.use_tls);
        assert!(!config.tls_verify_cert);
        assert!(config.tls_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_fluent_api() {
        let config = TlsConfig::builder()
            .enabled(true)
            .verify_cert(true)
            .cert_path("/custom/ca.pem".to_string())
            .build();
        assert!(config.use_tls);
        assert!(config.tls_verify_cert);
        assert_eq!(config.tls_cert_path, Some("/custom/ca.pem".to_string()));
    }

    #[test]
    fn test_tls_manager_creation() {
        let manager = TlsManager::new(TlsConfig::default()).unwrap();
        // Cloning must share the cached connector rather than rebuild it.
        let cloned = manager.clone();
        assert!(Arc::ptr_eq(
            &manager.cached_connector,
            &cloned.cached_connector
        ));
        assert_eq!(Arc::strong_count(&manager.cached_connector), 2);
    }

    #[test]
    fn test_tls_manager_creation_without_verification() {
        let config = TlsConfig::builder()
            .enabled(true)
            .verify_cert(false)
            .build();
        assert!(TlsManager::new(config).is_ok());
    }

    #[test]
    fn test_missing_custom_certificate_is_an_error() {
        let config = TlsConfig::builder()
            .enabled(true)
            .cert_path("/nonexistent/definitely/missing-ca.pem")
            .build();
        assert!(TlsManager::load_certificates_sync(&config).is_err());
    }

    #[test]
    fn test_certificate_loading() {
        let config = TlsConfig::default();

        let result = TlsManager::load_certificates_sync(&config).unwrap();
        assert!(!result.root_store.is_empty());
        assert!(!result.sources.is_empty());
        assert!(result
            .sources
            .iter()
            .any(|s| s.contains("Mozilla") || s.contains("system")));
    }
}