fraiseql_wire/connection/
tls.rs1use crate::{Error, Result};
7use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
8use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
9use rustls::RootCertStore;
10use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
11use rustls_pemfile::Item;
12use std::fmt::Debug;
13use std::fs;
14use std::sync::Arc;
15
16#[derive(Clone)]
44pub struct TlsConfig {
45 ca_cert_path: Option<String>,
47 verify_hostname: bool,
49 danger_accept_invalid_certs: bool,
51 danger_accept_invalid_hostnames: bool,
53 client_config: Arc<ClientConfig>,
55}
56
57impl TlsConfig {
58 pub fn builder() -> TlsConfigBuilder {
68 TlsConfigBuilder::default()
69 }
70
71 pub fn client_config(&self) -> Arc<ClientConfig> {
73 self.client_config.clone()
74 }
75
76 pub fn verify_hostname(&self) -> bool {
78 self.verify_hostname
79 }
80
81 pub fn danger_accept_invalid_certs(&self) -> bool {
83 self.danger_accept_invalid_certs
84 }
85
86 pub fn danger_accept_invalid_hostnames(&self) -> bool {
88 self.danger_accept_invalid_hostnames
89 }
90}
91
92impl std::fmt::Debug for TlsConfig {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("TlsConfig")
95 .field("ca_cert_path", &self.ca_cert_path)
96 .field("verify_hostname", &self.verify_hostname)
97 .field(
98 "danger_accept_invalid_certs",
99 &self.danger_accept_invalid_certs,
100 )
101 .field(
102 "danger_accept_invalid_hostnames",
103 &self.danger_accept_invalid_hostnames,
104 )
105 .field("client_config", &"<ClientConfig>")
106 .finish()
107 }
108}
109
/// Builder for [`TlsConfig`].
///
/// Obtain one with `TlsConfig::builder()` (or `TlsConfigBuilder::default()`), chain the
/// setters, then call `build()`. Derives `Debug` and `Clone` so a partially configured
/// builder can be logged or reused as a template, consistent with `TlsConfig`.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// PEM file to use as the exclusive trust-root source, if set.
    ca_cert_path: Option<String>,
    /// Whether the server hostname should be verified.
    verify_hostname: bool,
    /// Whether certificate validation should be disabled (development only).
    danger_accept_invalid_certs: bool,
    /// Whether hostname mismatches should be tolerated.
    danger_accept_invalid_hostnames: bool,
}
119
120impl Default for TlsConfigBuilder {
121 fn default() -> Self {
122 Self {
123 ca_cert_path: None,
124 verify_hostname: true,
125 danger_accept_invalid_certs: false,
126 danger_accept_invalid_hostnames: false,
127 }
128 }
129}
130
131impl TlsConfigBuilder {
132 pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
148 self.ca_cert_path = Some(path.into());
149 self
150 }
151
152 pub fn verify_hostname(mut self, verify: bool) -> Self {
169 self.verify_hostname = verify;
170 self
171 }
172
173 pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
188 self.danger_accept_invalid_certs = accept;
189 self
190 }
191
192 pub fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
208 self.danger_accept_invalid_hostnames = accept;
209 self
210 }
211
212 pub fn build(self) -> Result<TlsConfig> {
229 validate_tls_security(self.danger_accept_invalid_certs);
231
232 let client_config = if self.danger_accept_invalid_certs {
233 let verifier = Arc::new(NoVerifier);
235 Arc::new(
236 ClientConfig::builder()
237 .dangerous()
238 .with_custom_certificate_verifier(verifier)
239 .with_no_client_auth(),
240 )
241 } else {
242 let root_store = if let Some(ca_path) = &self.ca_cert_path {
244 self.load_custom_ca(ca_path)?
246 } else {
247 let result = rustls_native_certs::load_native_certs();
249
250 let mut store = RootCertStore::empty();
251 for cert in result.certs {
252 let _ = store.add_parsable_certificates(std::iter::once(cert));
253 }
254
255 if !result.errors.is_empty() && store.is_empty() {
257 return Err(Error::Config(
258 "Failed to load any system root certificates".to_string(),
259 ));
260 }
261
262 store
263 };
264
265 Arc::new(
267 ClientConfig::builder()
268 .with_root_certificates(root_store)
269 .with_no_client_auth(),
270 )
271 };
272
273 Ok(TlsConfig {
274 ca_cert_path: self.ca_cert_path,
275 verify_hostname: self.verify_hostname,
276 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
277 danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
278 client_config,
279 })
280 }
281
282 fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
284 let ca_cert_data = fs::read(ca_path).map_err(|e| {
285 Error::Config(format!(
286 "Failed to read CA certificate file '{}': {}",
287 ca_path, e
288 ))
289 })?;
290
291 let mut reader = std::io::Cursor::new(&ca_cert_data);
292 let mut root_store = RootCertStore::empty();
293 let mut found_certs = 0;
294
295 loop {
297 match rustls_pemfile::read_one(&mut reader) {
298 Ok(Some(Item::X509Certificate(cert))) => {
299 let _ = root_store.add_parsable_certificates(std::iter::once(cert));
300 found_certs += 1;
301 }
302 Ok(Some(_)) => {
303 }
305 Ok(None) => {
306 break;
308 }
309 Err(_) => {
310 return Err(Error::Config(format!(
311 "Failed to parse CA certificate from '{}'",
312 ca_path
313 )));
314 }
315 }
316 }
317
318 if found_certs == 0 {
319 return Err(Error::Config(format!(
320 "No valid certificates found in '{}'",
321 ca_path
322 )));
323 }
324
325 Ok(root_store)
326 }
327}
328
329fn validate_tls_security(danger_accept_invalid_certs: bool) {
343 if danger_accept_invalid_certs {
344 #[cfg(not(debug_assertions))]
346 {
347 panic!("🚨 CRITICAL: TLS certificate validation bypass not allowed in release builds");
348 }
349
350 #[cfg(debug_assertions)]
352 {
353 tracing::warn!("TLS certificate validation is DISABLED (development only)");
354 tracing::warn!("This mode is only for development with self-signed certificates");
355 }
356 }
357}
358
359pub fn parse_server_name(hostname: &str) -> Result<String> {
373 let hostname = hostname.trim_end_matches('.');
375
376 if hostname.is_empty() || hostname.len() > 253 {
378 return Err(Error::Config(format!(
379 "Invalid hostname for TLS: '{}'",
380 hostname
381 )));
382 }
383
384 if !hostname
386 .chars()
387 .all(|c| c.is_alphanumeric() || c == '-' || c == '.')
388 {
389 return Err(Error::Config(format!(
390 "Invalid hostname for TLS: '{}'",
391 hostname
392 )));
393 }
394
395 Ok(hostname.to_string())
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs `ring` as the process-wide rustls crypto provider.
    ///
    /// The error returned when a provider is already installed is ignored, so every test
    /// can call this unconditionally regardless of execution order.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Writes `contents` to a per-process, per-name file in the system temp directory.
    fn write_temp_file(name: &str, contents: &[u8]) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "fraiseql_wire_tls_{}_{}",
            std::process::id(),
            name
        ));
        std::fs::write(&path, contents).expect("failed to write temporary test file");
        path
    }

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    // TODO: exercise the happy path once a test CA PEM fixture is available.
    #[test]
    #[ignore = "requires PEM file on filesystem"]
    fn test_tls_config_builder_with_custom_ca() {}

    #[test]
    fn test_custom_ca_missing_file_is_error() {
        install_crypto_provider();

        // A file inside a directory that does not exist cannot be read.
        let path = std::env::temp_dir()
            .join(format!("fraiseql_wire_tls_{}_no_such_dir", std::process::id()))
            .join("ca.pem");
        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy().into_owned())
            .build();

        assert!(result.is_err(), "an unreadable CA file must be rejected");
    }

    #[test]
    fn test_custom_ca_without_certificates_is_error() {
        install_crypto_provider();

        let path = write_temp_file("no_certs.pem", b"this file contains no PEM certificates\n");
        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy().into_owned())
            .build();
        let _ = std::fs::remove_file(&path);

        assert!(result.is_err(), "a CA file without certificates must be rejected");
    }

    #[test]
    fn test_parse_server_name_valid() {
        let _name =
            parse_server_name("localhost").expect("localhost should be a valid server name");
        let _name =
            parse_server_name("example.com").expect("example.com should be a valid server name");
        let _name = parse_server_name("db.internal.example.com")
            .expect("subdomain should be a valid server name");
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let _name = parse_server_name("example.com.")
            .expect("trailing dot should be accepted as valid server name");
    }

    #[test]
    fn test_parse_server_name_with_port() {
        // ':' is not a hostname character; ports must be stripped by the caller.
        assert!(
            parse_server_name("example.com:5432").is_err(),
            "host:port must not be accepted as a server name"
        );
    }

    #[test]
    fn test_tls_config_debug() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
    }

    #[test]
    #[cfg(not(debug_assertions))]
    #[should_panic(expected = "TLS certificate validation bypass")]
    fn test_danger_mode_panics_in_release_build() {
        let _ = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build();
    }

    // Gated like its release counterpart: without `debug_assertions`, `build` panics in
    // danger mode, so this test would fail under `cargo test --release`.
    #[test]
    #[cfg(debug_assertions)]
    fn test_danger_mode_allowed_in_debug_build() {
        install_crypto_provider();

        let config = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("danger mode should be allowed in debug builds");

        assert!(config.danger_accept_invalid_certs());
    }

    #[test]
    fn test_normal_tls_config_works() {
        install_crypto_provider();

        let config = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("normal TLS config should build successfully");

        assert!(!config.danger_accept_invalid_certs());
    }
}
510
511#[derive(Debug)]
516struct NoVerifier;
517
518impl ServerCertVerifier for NoVerifier {
519 fn verify_server_cert(
520 &self,
521 _end_entity: &CertificateDer<'_>,
522 _intermediates: &[CertificateDer<'_>],
523 _server_name: &ServerName<'_>,
524 _ocsp_response: &[u8],
525 _now: UnixTime,
526 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
527 Ok(ServerCertVerified::assertion())
529 }
530
531 fn verify_tls12_signature(
532 &self,
533 _message: &[u8],
534 _cert: &CertificateDer<'_>,
535 _dss: &DigitallySignedStruct,
536 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
537 Ok(HandshakeSignatureValid::assertion())
538 }
539
540 fn verify_tls13_signature(
541 &self,
542 _message: &[u8],
543 _cert: &CertificateDer<'_>,
544 _dss: &DigitallySignedStruct,
545 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
546 Ok(HandshakeSignatureValid::assertion())
547 }
548
549 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
550 vec![
552 SignatureScheme::RSA_PKCS1_SHA256,
553 SignatureScheme::RSA_PKCS1_SHA384,
554 SignatureScheme::RSA_PKCS1_SHA512,
555 SignatureScheme::ECDSA_NISTP256_SHA256,
556 SignatureScheme::ECDSA_NISTP384_SHA384,
557 SignatureScheme::ECDSA_NISTP521_SHA512,
558 SignatureScheme::RSA_PSS_SHA256,
559 SignatureScheme::RSA_PSS_SHA384,
560 SignatureScheme::RSA_PSS_SHA512,
561 SignatureScheme::ED25519,
562 ]
563 }
564}