fraiseql_wire/connection/
tls.rs1use crate::{Error, Result};
7use rustls::ClientConfig;
8use rustls::RootCertStore;
9use rustls_pemfile::Item;
10use std::fs;
11use std::sync::Arc;
12
/// Connection `sslmode` setting.
///
/// The accepted spellings are `disable`, `require`, `verify-ca` and
/// `verify-full`, named after libpq's `sslmode` values of the same spelling.
/// The default is [`SslMode::Disable`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SslMode {
    /// `disable` — the default mode.
    #[default]
    Disable,
    /// `require` — does not by itself request certificate verification
    /// (see [`SslMode::requires_verification`]).
    Require,
    /// `verify-ca` — requests certificate verification.
    VerifyCa,
    /// `verify-full` — requests certificate verification.
    VerifyFull,
}

impl SslMode {
    /// Returns `true` when this mode asks for the server certificate to be
    /// verified, i.e. for [`SslMode::VerifyCa`] and [`SslMode::VerifyFull`].
    pub fn requires_verification(&self) -> bool {
        // Spelled out exhaustively so adding a variant forces a decision here.
        match self {
            Self::VerifyCa | Self::VerifyFull => true,
            Self::Disable | Self::Require => false,
        }
    }
}

impl std::fmt::Display for SslMode {
    /// Writes the canonical lowercase spelling (`"disable"`, `"require"`,
    /// `"verify-ca"`, `"verify-full"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Disable => "disable",
            Self::Require => "require",
            Self::VerifyCa => "verify-ca",
            Self::VerifyFull => "verify-full",
        };
        f.write_str(name)
    }
}
46
47impl std::str::FromStr for SslMode {
48 type Err = Error;
49
50 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
51 match s {
52 "disable" => Ok(Self::Disable),
53 "require" => Ok(Self::Require),
54 "verify-ca" => Ok(Self::VerifyCa),
55 "verify-full" => Ok(Self::VerifyFull),
56 _ => Err(Error::Config(format!(
57 "invalid sslmode '{}': expected disable, require, verify-ca, or verify-full",
58 s
59 ))),
60 }
61 }
62}
63
64#[derive(Clone)]
92pub struct TlsConfig {
93 ca_cert_path: Option<String>,
95 verify_hostname: bool,
97 danger_accept_invalid_certs: bool,
99 danger_accept_invalid_hostnames: bool,
101 client_config: Arc<ClientConfig>,
103}
104
105impl TlsConfig {
106 pub fn builder() -> TlsConfigBuilder {
116 TlsConfigBuilder::default()
117 }
118
119 pub fn client_config(&self) -> Arc<ClientConfig> {
121 self.client_config.clone()
122 }
123
124 pub fn verify_hostname(&self) -> bool {
126 self.verify_hostname
127 }
128
129 pub fn danger_accept_invalid_certs(&self) -> bool {
131 self.danger_accept_invalid_certs
132 }
133
134 pub fn danger_accept_invalid_hostnames(&self) -> bool {
136 self.danger_accept_invalid_hostnames
137 }
138}
139
140impl std::fmt::Debug for TlsConfig {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 f.debug_struct("TlsConfig")
143 .field("ca_cert_path", &self.ca_cert_path)
144 .field("verify_hostname", &self.verify_hostname)
145 .field(
146 "danger_accept_invalid_certs",
147 &self.danger_accept_invalid_certs,
148 )
149 .field(
150 "danger_accept_invalid_hostnames",
151 &self.danger_accept_invalid_hostnames,
152 )
153 .field("client_config", &"<ClientConfig>")
154 .finish()
155 }
156}
157
/// Builder for [`TlsConfig`]; obtain one with `TlsConfig::builder()`.
///
/// Setters only record values. All file I/O (CA bundle, client certificate,
/// client key) and validation happen in `build()`.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// PEM CA bundle to trust instead of the system roots; `None` means the
    /// platform's native root store is loaded.
    ca_cert_path: Option<String>,
    /// PEM client certificate chain for mutual TLS; must be paired with
    /// `client_key_path`.
    pub(crate) client_cert_path: Option<String>,
    /// PEM client private key for mutual TLS; must be paired with
    /// `client_cert_path`.
    pub(crate) client_key_path: Option<String>,
    /// Requested hostname verification (default `true`).
    verify_hostname: bool,
    /// Requested acceptance of invalid certificates (default `false`).
    danger_accept_invalid_certs: bool,
    /// Requested acceptance of hostname mismatches (default `false`).
    danger_accept_invalid_hostnames: bool,
}

impl Default for TlsConfigBuilder {
    /// Secure defaults: hostname verification on, both `danger_*` overrides
    /// off, no custom CA and no client certificate.
    fn default() -> Self {
        Self {
            verify_hostname: true,
            danger_accept_invalid_certs: false,
            danger_accept_invalid_hostnames: false,
            ca_cert_path: None,
            client_cert_path: None,
            client_key_path: None,
        }
    }
}
184
185impl TlsConfigBuilder {
186 pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
202 self.ca_cert_path = Some(path.into());
203 self
204 }
205
206 pub fn verify_hostname(mut self, verify: bool) -> Self {
223 self.verify_hostname = verify;
224 self
225 }
226
227 pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
242 self.danger_accept_invalid_certs = accept;
243 self
244 }
245
246 pub fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
262 self.danger_accept_invalid_hostnames = accept;
263 self
264 }
265
266 pub fn client_cert_path(mut self, path: impl Into<String>) -> Self {
270 self.client_cert_path = Some(path.into());
271 self
272 }
273
274 pub fn client_key_path(mut self, path: impl Into<String>) -> Self {
278 self.client_key_path = Some(path.into());
279 self
280 }
281
282 pub fn build(self) -> Result<TlsConfig> {
299 let root_store = if let Some(ca_path) = &self.ca_cert_path {
301 self.load_custom_ca(ca_path)?
303 } else {
304 let result = rustls_native_certs::load_native_certs();
306
307 let mut store = RootCertStore::empty();
308 for cert in result.certs {
309 let _ = store.add_parsable_certificates(std::iter::once(cert));
310 }
311
312 if !result.errors.is_empty() && store.is_empty() {
314 return Err(Error::Config(
315 "Failed to load any system root certificates".to_string(),
316 ));
317 }
318
319 store
320 };
321
322 let client_config = match (&self.client_cert_path, &self.client_key_path) {
324 (Some(cert_path), Some(key_path)) => {
325 let certs = self.load_client_certs(cert_path)?;
326 let key = self.load_client_key(key_path)?;
327 Arc::new(
328 ClientConfig::builder()
329 .with_root_certificates(root_store)
330 .with_client_auth_cert(certs, key)
331 .map_err(|e| {
332 Error::Config(format!("invalid client certificate/key: {}", e))
333 })?,
334 )
335 }
336 (Some(_), None) => {
337 return Err(Error::Config(
338 "client certificate provided without client key (sslkey)".to_string(),
339 ));
340 }
341 (None, Some(_)) => {
342 return Err(Error::Config(
343 "client key provided without client certificate (sslcert)".to_string(),
344 ));
345 }
346 (None, None) => Arc::new(
347 ClientConfig::builder()
348 .with_root_certificates(root_store)
349 .with_no_client_auth(),
350 ),
351 };
352
353 Ok(TlsConfig {
354 ca_cert_path: self.ca_cert_path,
355 verify_hostname: self.verify_hostname,
356 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
357 danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
358 client_config,
359 })
360 }
361
362 fn load_client_certs(
364 &self,
365 cert_path: &str,
366 ) -> Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
367 let cert_data = fs::read(cert_path).map_err(|e| {
368 Error::Config(format!(
369 "failed to read client certificate '{}': {}",
370 cert_path, e
371 ))
372 })?;
373
374 let mut reader = std::io::Cursor::new(&cert_data);
375 let mut certs = Vec::new();
376
377 loop {
378 match rustls_pemfile::read_one(&mut reader) {
379 Ok(Some(Item::X509Certificate(cert))) => certs.push(cert),
380 Ok(Some(_)) => {}
381 Ok(None) => break,
382 Err(_) => {
383 return Err(Error::Config(format!(
384 "failed to parse client certificate from '{}'",
385 cert_path
386 )));
387 }
388 }
389 }
390
391 if certs.is_empty() {
392 return Err(Error::Config(format!(
393 "no valid certificates found in '{}'",
394 cert_path
395 )));
396 }
397
398 Ok(certs)
399 }
400
401 fn load_client_key(&self, key_path: &str) -> Result<rustls_pki_types::PrivateKeyDer<'static>> {
403 let key_data = fs::read(key_path).map_err(|e| {
404 Error::Config(format!("failed to read client key '{}': {}", key_path, e))
405 })?;
406
407 let mut reader = std::io::Cursor::new(&key_data);
408
409 loop {
410 match rustls_pemfile::read_one(&mut reader) {
411 Ok(Some(Item::Pkcs1Key(key))) => {
412 return Ok(rustls_pki_types::PrivateKeyDer::Pkcs1(key));
413 }
414 Ok(Some(Item::Pkcs8Key(key))) => {
415 return Ok(rustls_pki_types::PrivateKeyDer::Pkcs8(key));
416 }
417 Ok(Some(Item::Sec1Key(key))) => {
418 return Ok(rustls_pki_types::PrivateKeyDer::Sec1(key));
419 }
420 Ok(Some(_)) => {}
421 Ok(None) => break,
422 Err(_) => {
423 return Err(Error::Config(format!(
424 "failed to parse client key from '{}'",
425 key_path
426 )));
427 }
428 }
429 }
430
431 Err(Error::Config(format!(
432 "no valid private key found in '{}'",
433 key_path
434 )))
435 }
436
437 fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
439 let ca_cert_data = fs::read(ca_path).map_err(|e| {
440 Error::Config(format!(
441 "Failed to read CA certificate file '{}': {}",
442 ca_path, e
443 ))
444 })?;
445
446 let mut reader = std::io::Cursor::new(&ca_cert_data);
447 let mut root_store = RootCertStore::empty();
448 let mut found_certs = 0;
449
450 loop {
452 match rustls_pemfile::read_one(&mut reader) {
453 Ok(Some(Item::X509Certificate(cert))) => {
454 let _ = root_store.add_parsable_certificates(std::iter::once(cert));
455 found_certs += 1;
456 }
457 Ok(Some(_)) => {
458 }
460 Ok(None) => {
461 break;
463 }
464 Err(_) => {
465 return Err(Error::Config(format!(
466 "Failed to parse CA certificate from '{}'",
467 ca_path
468 )));
469 }
470 }
471 }
472
473 if found_certs == 0 {
474 return Err(Error::Config(format!(
475 "No valid certificates found in '{}'",
476 ca_path
477 )));
478 }
479
480 Ok(root_store)
481 }
482}
483
484pub fn parse_server_name(hostname: &str) -> Result<String> {
498 let hostname = hostname.trim_end_matches('.');
500
501 if hostname.is_empty() || hostname.len() > 253 {
503 return Err(Error::Config(format!(
504 "Invalid hostname for TLS: '{}'",
505 hostname
506 )));
507 }
508
509 if !hostname
511 .chars()
512 .all(|c| c.is_alphanumeric() || c == '-' || c == '.')
513 {
514 return Err(Error::Config(format!(
515 "Invalid hostname for TLS: '{}'",
516 hostname
517 )));
518 }
519
520 Ok(hostname.to_string())
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
        assert!(tls.client_cert_path.is_none());
        assert!(tls.client_key_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    /// `build()` loads a configured CA bundle eagerly. An unreadable path, or
    /// a file without any certificate, must be rejected there.
    #[test]
    fn test_tls_config_builder_with_custom_ca() {
        let dir = std::env::temp_dir();
        let pid = std::process::id();

        let missing = dir.join(format!("fraiseql_wire_tls_missing_ca_{}.pem", pid));
        let _ = std::fs::remove_file(&missing);
        let result = TlsConfig::builder()
            .ca_cert_path(missing.to_string_lossy())
            .build();
        assert!(result.is_err(), "unreadable CA path must fail");

        let empty = dir.join(format!("fraiseql_wire_tls_empty_ca_{}.pem", pid));
        std::fs::write(&empty, "no PEM certificate blocks here\n")
            .expect("failed to write temporary CA file");
        let result = TlsConfig::builder()
            .ca_cert_path(empty.to_string_lossy())
            .build();
        let _ = std::fs::remove_file(&empty);
        assert!(result.is_err(), "CA file without certificates must fail");
    }

    #[test]
    fn test_parse_server_name_valid() {
        assert_eq!(parse_server_name("localhost").unwrap(), "localhost");
        assert_eq!(parse_server_name("example.com").unwrap(), "example.com");
        assert_eq!(
            parse_server_name("db.internal.example.com").unwrap(),
            "db.internal.example.com"
        );
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let result = parse_server_name("example.com.");
        assert_eq!(result.expect("trailing dot should be accepted"), "example.com");
    }

    #[test]
    fn test_parse_server_name_with_port_fails() {
        let result = parse_server_name("example.com:5432");
        assert!(result.is_err(), "a host:port string is not a server name");
    }

    #[test]
    fn test_ssl_mode_from_str() {
        assert_eq!("disable".parse::<SslMode>().unwrap(), SslMode::Disable);
        assert_eq!("require".parse::<SslMode>().unwrap(), SslMode::Require);
        assert_eq!("verify-ca".parse::<SslMode>().unwrap(), SslMode::VerifyCa);
        assert_eq!(
            "verify-full".parse::<SslMode>().unwrap(),
            SslMode::VerifyFull
        );
    }

    #[test]
    fn test_ssl_mode_from_str_invalid() {
        assert!("invalid".parse::<SslMode>().is_err());
        assert!("prefer".parse::<SslMode>().is_err());
    }

    #[test]
    fn test_ssl_mode_display() {
        assert_eq!(SslMode::Disable.to_string(), "disable");
        assert_eq!(SslMode::Require.to_string(), "require");
        assert_eq!(SslMode::VerifyCa.to_string(), "verify-ca");
        assert_eq!(SslMode::VerifyFull.to_string(), "verify-full");
    }

    #[test]
    fn test_ssl_mode_default() {
        assert_eq!(SslMode::default(), SslMode::Disable);
    }

    #[test]
    fn test_ssl_mode_requires_verification() {
        assert!(!SslMode::Disable.requires_verification());
        assert!(!SslMode::Require.requires_verification());
        assert!(SslMode::VerifyCa.requires_verification());
        assert!(SslMode::VerifyFull.requires_verification());
    }

    #[test]
    fn test_tls_config_builder_with_client_cert_methods() {
        let builder = TlsConfig::builder()
            .client_cert_path("/path/to/client.pem")
            .client_key_path("/path/to/client-key.pem");
        assert_eq!(
            builder.client_cert_path.as_deref(),
            Some("/path/to/client.pem")
        );
        assert_eq!(
            builder.client_key_path.as_deref(),
            Some("/path/to/client-key.pem")
        );
    }

    #[test]
    fn test_tls_config_builder_client_cert_without_key_fails() {
        let result = TlsConfig::builder()
            .client_cert_path("/path/to/client.pem")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_builder_client_key_without_cert_fails() {
        let result = TlsConfig::builder()
            .client_key_path("/path/to/client-key.pem")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_debug() {
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
        assert!(debug_str.contains("<ClientConfig>"));
    }
}