Skip to main content

fraiseql_wire/connection/
tls.rs

1//! TLS configuration and support for secure connections to Postgres.
2//!
3//! This module provides TLS configuration for connecting to remote Postgres servers.
4//! TLS is recommended for all non-local connections to prevent credential interception.
5
6use crate::{Error, Result};
7use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
8use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
9use rustls::RootCertStore;
10use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
11use rustls_pemfile::Item;
12use std::fmt::Debug;
13use std::fs;
14use std::sync::Arc;
15
16/// TLS configuration for secure Postgres connections.
17///
18/// Provides a builder for creating TLS configurations with various certificate handling options.
19/// By default, server certificates are validated against system root certificates.
20///
21/// # Examples
22///
23/// ```ignore
24/// use fraiseql_wire::connection::TlsConfig;
25///
26/// // With system root certificates (production)
27/// let tls = TlsConfig::builder()
28///     .verify_hostname(true)
29///     .build()?;
30///
31/// // With custom CA certificate
32/// let tls = TlsConfig::builder()
33///     .ca_cert_path("/path/to/ca.pem")?
34///     .verify_hostname(true)
35///     .build()?;
36///
37/// // For development (danger: disables verification)
38/// let tls = TlsConfig::builder()
39///     .danger_accept_invalid_certs(true)
40///     .danger_accept_invalid_hostnames(true)
41///     .build()?;
42/// ```
43#[derive(Clone)]
44pub struct TlsConfig {
45    /// Path to CA certificate file (None = use system roots)
46    ca_cert_path: Option<String>,
47    /// Whether to verify hostname matches certificate
48    verify_hostname: bool,
49    /// Whether to accept invalid certificates (development only)
50    danger_accept_invalid_certs: bool,
51    /// Whether to accept invalid hostnames (development only)
52    danger_accept_invalid_hostnames: bool,
53    /// Compiled rustls ClientConfig
54    client_config: Arc<ClientConfig>,
55}
56
57impl TlsConfig {
58    /// Create a new TLS configuration builder.
59    ///
60    /// # Examples
61    ///
62    /// ```ignore
63    /// let tls = TlsConfig::builder()
64    ///     .verify_hostname(true)
65    ///     .build()?;
66    /// ```
67    pub fn builder() -> TlsConfigBuilder {
68        TlsConfigBuilder::default()
69    }
70
71    /// Get the rustls ClientConfig for this TLS configuration.
72    pub fn client_config(&self) -> Arc<ClientConfig> {
73        self.client_config.clone()
74    }
75
76    /// Check if hostname verification is enabled.
77    pub fn verify_hostname(&self) -> bool {
78        self.verify_hostname
79    }
80
81    /// Check if invalid certificates are accepted (development only).
82    pub fn danger_accept_invalid_certs(&self) -> bool {
83        self.danger_accept_invalid_certs
84    }
85
86    /// Check if invalid hostnames are accepted (development only).
87    pub fn danger_accept_invalid_hostnames(&self) -> bool {
88        self.danger_accept_invalid_hostnames
89    }
90}
91
92impl std::fmt::Debug for TlsConfig {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("TlsConfig")
95            .field("ca_cert_path", &self.ca_cert_path)
96            .field("verify_hostname", &self.verify_hostname)
97            .field(
98                "danger_accept_invalid_certs",
99                &self.danger_accept_invalid_certs,
100            )
101            .field(
102                "danger_accept_invalid_hostnames",
103                &self.danger_accept_invalid_hostnames,
104            )
105            .field("client_config", &"<ClientConfig>")
106            .finish()
107    }
108}
109
/// Builder for TLS configuration.
///
/// Provides a fluent API for constructing a [`TlsConfig`]. Obtain one with
/// [`TlsConfig::builder`] (or `TlsConfigBuilder::default()`), chain setters,
/// then call `build`.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// Path to a PEM CA bundle; `None` means system root certificates.
    ca_cert_path: Option<String>,
    /// Requested hostname verification (default `true`).
    verify_hostname: bool,
    /// Bypass certificate validation entirely (development only).
    danger_accept_invalid_certs: bool,
    /// Tolerate hostname mismatches (development only).
    danger_accept_invalid_hostnames: bool,
}
119
120impl Default for TlsConfigBuilder {
121    fn default() -> Self {
122        Self {
123            ca_cert_path: None,
124            verify_hostname: true,
125            danger_accept_invalid_certs: false,
126            danger_accept_invalid_hostnames: false,
127        }
128    }
129}
130
131impl TlsConfigBuilder {
132    /// Set the path to a custom CA certificate file (PEM format).
133    ///
134    /// If not set, system root certificates will be used.
135    ///
136    /// # Arguments
137    ///
138    /// * `path` - Path to CA certificate file in PEM format
139    ///
140    /// # Examples
141    ///
142    /// ```ignore
143    /// let tls = TlsConfig::builder()
144    ///     .ca_cert_path("/etc/ssl/certs/ca.pem")?
145    ///     .build()?;
146    /// ```
147    pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
148        self.ca_cert_path = Some(path.into());
149        self
150    }
151
152    /// Enable or disable hostname verification (default: enabled).
153    ///
154    /// When enabled, the certificate's subject alternative names (SANs) are verified
155    /// to match the server hostname.
156    ///
157    /// # Arguments
158    ///
159    /// * `verify` - Whether to verify hostname matches certificate
160    ///
161    /// # Examples
162    ///
163    /// ```ignore
164    /// let tls = TlsConfig::builder()
165    ///     .verify_hostname(true)
166    ///     .build()?;
167    /// ```
168    pub fn verify_hostname(mut self, verify: bool) -> Self {
169        self.verify_hostname = verify;
170        self
171    }
172
173    /// ⚠️ **DANGER**: Accept invalid certificates (development only).
174    ///
175    /// **NEVER use in production.** This disables certificate validation entirely,
176    /// making the connection vulnerable to man-in-the-middle attacks.
177    ///
178    /// Only use for testing with self-signed certificates.
179    ///
180    /// # Examples
181    ///
182    /// ```ignore
183    /// let tls = TlsConfig::builder()
184    ///     .danger_accept_invalid_certs(true)
185    ///     .build()?;
186    /// ```
187    pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
188        self.danger_accept_invalid_certs = accept;
189        self
190    }
191
192    /// ⚠️ **DANGER**: Accept invalid hostnames (development only).
193    ///
194    /// **NEVER use in production.** This disables hostname verification,
195    /// making the connection vulnerable to man-in-the-middle attacks.
196    ///
197    /// Only use for testing with self-signed certificates where you can't
198    /// match the hostname.
199    ///
200    /// # Examples
201    ///
202    /// ```ignore
203    /// let tls = TlsConfig::builder()
204    ///     .danger_accept_invalid_hostnames(true)
205    ///     .build()?;
206    /// ```
207    pub fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
208        self.danger_accept_invalid_hostnames = accept;
209        self
210    }
211
212    /// Build the TLS configuration.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if:
217    /// - CA certificate file cannot be read
218    /// - CA certificate is invalid PEM
219    /// - Dangerous options are configured incorrectly
220    ///
221    /// # Examples
222    ///
223    /// ```ignore
224    /// let tls = TlsConfig::builder()
225    ///     .verify_hostname(true)
226    ///     .build()?;
227    /// ```
228    pub fn build(self) -> Result<TlsConfig> {
229        // SECURITY: Validate TLS configuration before creating client
230        validate_tls_security(self.danger_accept_invalid_certs);
231
232        let client_config = if self.danger_accept_invalid_certs {
233            // Create a client config that accepts any certificate (development only)
234            let verifier = Arc::new(NoVerifier);
235            Arc::new(
236                ClientConfig::builder()
237                    .dangerous()
238                    .with_custom_certificate_verifier(verifier)
239                    .with_no_client_auth(),
240            )
241        } else {
242            // Load root certificates
243            let root_store = if let Some(ca_path) = &self.ca_cert_path {
244                // Load custom CA certificate from file
245                self.load_custom_ca(ca_path)?
246            } else {
247                // Use system root certificates via rustls-native-certs
248                let result = rustls_native_certs::load_native_certs();
249
250                let mut store = RootCertStore::empty();
251                for cert in result.certs {
252                    let _ = store.add_parsable_certificates(std::iter::once(cert));
253                }
254
255                // Log warnings if there were errors, but don't fail
256                if !result.errors.is_empty() && store.is_empty() {
257                    return Err(Error::Config(
258                        "Failed to load any system root certificates".to_string(),
259                    ));
260                }
261
262                store
263            };
264
265            // Create ClientConfig using the correct API for rustls 0.23
266            Arc::new(
267                ClientConfig::builder()
268                    .with_root_certificates(root_store)
269                    .with_no_client_auth(),
270            )
271        };
272
273        Ok(TlsConfig {
274            ca_cert_path: self.ca_cert_path,
275            verify_hostname: self.verify_hostname,
276            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
277            danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
278            client_config,
279        })
280    }
281
282    /// Load a custom CA certificate from a PEM file.
283    fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
284        let ca_cert_data = fs::read(ca_path).map_err(|e| {
285            Error::Config(format!(
286                "Failed to read CA certificate file '{}': {}",
287                ca_path, e
288            ))
289        })?;
290
291        let mut reader = std::io::Cursor::new(&ca_cert_data);
292        let mut root_store = RootCertStore::empty();
293        let mut found_certs = 0;
294
295        // Parse PEM file and extract certificates
296        loop {
297            match rustls_pemfile::read_one(&mut reader) {
298                Ok(Some(Item::X509Certificate(cert))) => {
299                    let _ = root_store.add_parsable_certificates(std::iter::once(cert));
300                    found_certs += 1;
301                }
302                Ok(Some(_)) => {
303                    // Skip non-certificate items (private keys, etc.)
304                }
305                Ok(None) => {
306                    // End of file
307                    break;
308                }
309                Err(_) => {
310                    return Err(Error::Config(format!(
311                        "Failed to parse CA certificate from '{}'",
312                        ca_path
313                    )));
314                }
315            }
316        }
317
318        if found_certs == 0 {
319            return Err(Error::Config(format!(
320                "No valid certificates found in '{}'",
321                ca_path
322            )));
323        }
324
325        Ok(root_store)
326    }
327}
328
329/// Validate TLS configuration for security constraints.
330///
331/// Enforces:
332/// - Release builds cannot use `danger_accept_invalid_certs`
333/// - Production environment rejects danger mode
334///
335/// # Arguments
336///
337/// * `danger_accept_invalid_certs` - Whether danger mode is enabled
338///
339/// # Errors
340///
341/// Returns an error or panics if validation fails
342fn validate_tls_security(danger_accept_invalid_certs: bool) {
343    if danger_accept_invalid_certs {
344        // SECURITY: Panic in release builds to prevent accidental production use
345        #[cfg(not(debug_assertions))]
346        {
347            panic!("🚨 CRITICAL: TLS certificate validation bypass not allowed in release builds");
348        }
349
350        // Development builds: warn but allow
351        #[cfg(debug_assertions)]
352        {
353            tracing::warn!("TLS certificate validation is DISABLED (development only)");
354            tracing::warn!("This mode is only for development with self-signed certificates");
355        }
356    }
357}
358
359/// Parse server name from hostname for TLS SNI (Server Name Indication).
360///
361/// # Arguments
362///
363/// * `hostname` - Hostname to parse (without port)
364///
365/// # Returns
366///
367/// A string suitable for TLS server name indication
368///
369/// # Errors
370///
371/// Returns an error if the hostname is invalid.
372pub fn parse_server_name(hostname: &str) -> Result<String> {
373    // Remove trailing dot if present
374    let hostname = hostname.trim_end_matches('.');
375
376    // Validate hostname (basic check)
377    if hostname.is_empty() || hostname.len() > 253 {
378        return Err(Error::Config(format!(
379            "Invalid hostname for TLS: '{}'",
380            hostname
381        )));
382    }
383
384    // Check for invalid characters
385    if !hostname
386        .chars()
387        .all(|c| c.is_alphanumeric() || c == '-' || c == '.')
388    {
389        return Err(Error::Config(format!(
390            "Invalid hostname for TLS: '{}'",
391            hostname
392        )));
393    }
394
395    Ok(hostname.to_string())
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Install a crypto provider for rustls tests.
    /// This is needed because multiple crypto providers (ring and aws-lc-rs)
    /// may be enabled via transitive dependencies, requiring explicit selection.
    fn install_crypto_provider() {
        // Try to install ring as the default provider, ignore if already installed
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Path in the system temp dir that is unique to this test process.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "fraiseql_wire_tls_{}_{}",
            std::process::id(),
            name
        ))
    }

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    #[test]
    fn test_custom_ca_missing_file_is_config_error() {
        let path = temp_path("missing_ca.pem");
        let _ = std::fs::remove_file(&path);

        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy())
            .build();

        match result {
            Err(crate::Error::Config(msg)) => assert!(
                msg.contains("Failed to read CA certificate file"),
                "unexpected message: {msg}"
            ),
            Err(_) => panic!("expected Error::Config for a missing CA file"),
            Ok(_) => panic!("a missing CA file must be rejected"),
        }
    }

    #[test]
    fn test_custom_ca_without_certificates_is_config_error() {
        let path = temp_path("no_certs.pem");
        std::fs::write(&path, b"this file contains no PEM blocks\n")
            .expect("failed to write temp file");

        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy())
            .build();
        let _ = std::fs::remove_file(&path);

        match result {
            Err(crate::Error::Config(msg)) => assert!(
                msg.contains("No valid certificates found"),
                "unexpected message: {msg}"
            ),
            Err(_) => panic!("expected Error::Config for a CA file without certificates"),
            Ok(_) => panic!("a CA file without certificates must be rejected"),
        }
    }

    #[test]
    fn test_parse_server_name_valid() {
        for name in ["localhost", "example.com", "db.internal.example.com"] {
            let parsed = parse_server_name(name)
                .unwrap_or_else(|_| panic!("{name} should be a valid server name"));
            assert_eq!(parsed, name);
        }
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let name = parse_server_name("example.com.")
            .expect("trailing dot should be accepted as valid server name");
        assert_eq!(name, "example.com");
    }

    #[test]
    fn test_parse_server_name_with_port() {
        // parse_server_name expects a bare hostname; ':' is not a valid hostname
        // character, so host:port input is rejected.
        assert!(parse_server_name("example.com:5432").is_err());
    }

    #[test]
    fn test_parse_server_name_rejects_empty_and_invalid_chars() {
        assert!(parse_server_name("").is_err());
        assert!(parse_server_name("exa mple.com").is_err());
    }

    #[test]
    fn test_tls_config_debug() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
    }

    #[test]
    #[cfg(not(debug_assertions))]
    #[should_panic(expected = "TLS certificate validation bypass")]
    fn test_danger_mode_panics_in_release_build() {
        // This test only runs in release builds and should panic
        let _ = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build();
    }

    // Only meaningful in debug builds: release builds panic in `build`
    // (covered by `test_danger_mode_panics_in_release_build`).
    #[test]
    #[cfg(debug_assertions)]
    fn test_danger_mode_allowed_in_debug_build() {
        install_crypto_provider();

        let config = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("danger mode should be allowed in debug builds");

        assert!(config.danger_accept_invalid_certs());
    }

    #[test]
    fn test_normal_tls_config_works() {
        install_crypto_provider();

        let config = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("normal TLS config should build successfully");

        assert!(!config.danger_accept_invalid_certs());
    }
}
510
/// Server-certificate verifier that trusts every certificate unconditionally.
///
/// ⚠️ **DANGER**: only for development/testing against self-signed
/// certificates. Installing it in production removes all protection against
/// man-in-the-middle attacks.
#[derive(Debug)]
struct NoVerifier;
517
518impl ServerCertVerifier for NoVerifier {
519    fn verify_server_cert(
520        &self,
521        _end_entity: &CertificateDer<'_>,
522        _intermediates: &[CertificateDer<'_>],
523        _server_name: &ServerName<'_>,
524        _ocsp_response: &[u8],
525        _now: UnixTime,
526    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
527        // Accept any certificate
528        Ok(ServerCertVerified::assertion())
529    }
530
531    fn verify_tls12_signature(
532        &self,
533        _message: &[u8],
534        _cert: &CertificateDer<'_>,
535        _dss: &DigitallySignedStruct,
536    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
537        Ok(HandshakeSignatureValid::assertion())
538    }
539
540    fn verify_tls13_signature(
541        &self,
542        _message: &[u8],
543        _cert: &CertificateDer<'_>,
544        _dss: &DigitallySignedStruct,
545    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
546        Ok(HandshakeSignatureValid::assertion())
547    }
548
549    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
550        // Support all common signature schemes
551        vec![
552            SignatureScheme::RSA_PKCS1_SHA256,
553            SignatureScheme::RSA_PKCS1_SHA384,
554            SignatureScheme::RSA_PKCS1_SHA512,
555            SignatureScheme::ECDSA_NISTP256_SHA256,
556            SignatureScheme::ECDSA_NISTP384_SHA384,
557            SignatureScheme::ECDSA_NISTP521_SHA512,
558            SignatureScheme::RSA_PSS_SHA256,
559            SignatureScheme::RSA_PSS_SHA384,
560            SignatureScheme::RSA_PSS_SHA512,
561            SignatureScheme::ED25519,
562        ]
563    }
564}