Skip to main content

fraiseql_wire/connection/
tls.rs

1//! TLS configuration and support for secure connections to Postgres.
2//!
3//! This module provides TLS configuration for connecting to remote Postgres servers.
4//! TLS is recommended for all non-local connections to prevent credential interception.
5
6use crate::{Result, WireError};
7use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
8use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
9use rustls::RootCertStore;
10use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
11use rustls_pemfile::Item;
12use std::fmt::Debug;
13use std::fs;
14use std::sync::Arc;
15
16/// TLS configuration for secure Postgres connections.
17///
18/// Provides a builder for creating TLS configurations with various certificate handling options.
19/// By default, server certificates are validated against system root certificates.
20///
21/// # Examples
22///
23/// ```no_run
24/// // Requires: system root certificates or a CA certificate file on disk.
25/// use fraiseql_wire::connection::TlsConfig;
26///
27/// // With system root certificates (production)
28/// let tls = TlsConfig::builder()
29///     .verify_hostname(true)
30///     .build()?;
31///
32/// // With custom CA certificate
33/// let tls = TlsConfig::builder()
34///     .ca_cert_path("/path/to/ca.pem")
35///     .verify_hostname(true)
36///     .build()?;
37///
38/// // For development (danger: disables verification)
39/// let tls = TlsConfig::builder()
40///     .danger_accept_invalid_certs(true)
41///     .danger_accept_invalid_hostnames(true)
42///     .build()?;
43/// # fraiseql_wire::Result::Ok(())
44/// ```
45#[derive(Clone)]
46pub struct TlsConfig {
47    /// Path to CA certificate file (None = use system roots)
48    ca_cert_path: Option<String>,
49    /// Whether to verify hostname matches certificate
50    verify_hostname: bool,
51    /// Whether to accept invalid certificates (development only)
52    danger_accept_invalid_certs: bool,
53    /// Whether to accept invalid hostnames (development only)
54    danger_accept_invalid_hostnames: bool,
55    /// Compiled rustls `ClientConfig`
56    client_config: Arc<ClientConfig>,
57}
58
59impl TlsConfig {
60    /// Create a new TLS configuration builder.
61    ///
62    /// # Examples
63    ///
64    /// ```no_run
65    /// // Requires: system root certificates.
66    /// use fraiseql_wire::connection::TlsConfig;
67    /// let tls = TlsConfig::builder()
68    ///     .verify_hostname(true)
69    ///     .build()?;
70    /// # fraiseql_wire::Result::Ok(())
71    /// ```
72    pub fn builder() -> TlsConfigBuilder {
73        TlsConfigBuilder::default()
74    }
75
76    /// Get the rustls `ClientConfig` for this TLS configuration.
77    pub fn client_config(&self) -> Arc<ClientConfig> {
78        self.client_config.clone()
79    }
80
81    /// Check if hostname verification is enabled.
82    pub const fn verify_hostname(&self) -> bool {
83        self.verify_hostname
84    }
85
86    /// Check if invalid certificates are accepted (development only).
87    pub const fn danger_accept_invalid_certs(&self) -> bool {
88        self.danger_accept_invalid_certs
89    }
90
91    /// Check if invalid hostnames are accepted (development only).
92    pub const fn danger_accept_invalid_hostnames(&self) -> bool {
93        self.danger_accept_invalid_hostnames
94    }
95}
96
97impl std::fmt::Debug for TlsConfig {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("TlsConfig")
100            .field("ca_cert_path", &self.ca_cert_path)
101            .field("verify_hostname", &self.verify_hostname)
102            .field(
103                "danger_accept_invalid_certs",
104                &self.danger_accept_invalid_certs,
105            )
106            .field(
107                "danger_accept_invalid_hostnames",
108                &self.danger_accept_invalid_hostnames,
109            )
110            .field("client_config", &"<ClientConfig>")
111            .finish()
112    }
113}
114
/// Builder for TLS configuration.
///
/// Provides a fluent API for constructing a [`TlsConfig`]; start from
/// `TlsConfigBuilder::default()` (or `TlsConfig::builder()`) and finish with
/// `build()`.
#[must_use = "call .build() to construct the final value"]
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// Optional PEM file with trusted CA certificates (`None` = native roots).
    ca_cert_path: Option<String>,
    /// Requested hostname-verification setting.
    verify_hostname: bool,
    /// Development-only: accept any server certificate.
    danger_accept_invalid_certs: bool,
    /// Development-only: accept certificates whose hostname does not match.
    danger_accept_invalid_hostnames: bool,
}
125
126impl Default for TlsConfigBuilder {
127    fn default() -> Self {
128        Self {
129            ca_cert_path: None,
130            verify_hostname: true,
131            danger_accept_invalid_certs: false,
132            danger_accept_invalid_hostnames: false,
133        }
134    }
135}
136
137impl TlsConfigBuilder {
138    /// Set the path to a custom CA certificate file (PEM format).
139    ///
140    /// If not set, system root certificates will be used.
141    ///
142    /// # Arguments
143    ///
144    /// * `path` - Path to CA certificate file in PEM format
145    ///
146    /// # Examples
147    ///
148    /// ```no_run
149    /// // Requires: CA certificate file at the specified path.
150    /// use fraiseql_wire::connection::TlsConfig;
151    /// let tls = TlsConfig::builder()
152    ///     .ca_cert_path("/etc/ssl/certs/ca.pem")
153    ///     .build()?;
154    /// # fraiseql_wire::Result::Ok(())
155    /// ```
156    pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
157        self.ca_cert_path = Some(path.into());
158        self
159    }
160
161    /// Enable or disable hostname verification (default: enabled).
162    ///
163    /// When enabled, the certificate's subject alternative names (SANs) are verified
164    /// to match the server hostname.
165    ///
166    /// # Arguments
167    ///
168    /// * `verify` - Whether to verify hostname matches certificate
169    ///
170    /// # Examples
171    ///
172    /// ```no_run
173    /// // Requires: system root certificates.
174    /// use fraiseql_wire::connection::TlsConfig;
175    /// let tls = TlsConfig::builder()
176    ///     .verify_hostname(true)
177    ///     .build()?;
178    /// # fraiseql_wire::Result::Ok(())
179    /// ```
180    pub const fn verify_hostname(mut self, verify: bool) -> Self {
181        self.verify_hostname = verify;
182        self
183    }
184
185    /// ⚠️ **DANGER**: Accept invalid certificates (development only).
186    ///
187    /// **NEVER use in production.** This disables certificate validation entirely,
188    /// making the connection vulnerable to man-in-the-middle attacks.
189    ///
190    /// Only use for testing with self-signed certificates.
191    ///
192    /// # Errors
193    ///
194    /// [`TlsConfigBuilder::build`] returns `WireError::Config` when this option is `true`
195    /// in a release build (`cfg(not(debug_assertions))`).
196    ///
197    /// # Examples
198    ///
199    /// ```no_run
200    /// // Requires: debug build only (returns Err in release mode).
201    /// use fraiseql_wire::connection::TlsConfig;
202    /// let tls = TlsConfig::builder()
203    ///     .danger_accept_invalid_certs(true)
204    ///     .build()?;
205    /// # fraiseql_wire::Result::Ok(())
206    /// ```
207    pub const fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
208        self.danger_accept_invalid_certs = accept;
209        self
210    }
211
212    /// ⚠️ **DANGER**: Accept invalid hostnames (development only).
213    ///
214    /// **NEVER use in production.** This disables hostname verification,
215    /// making the connection vulnerable to man-in-the-middle attacks.
216    ///
217    /// Only use for testing with self-signed certificates where you can't
218    /// match the hostname.
219    ///
220    /// # Examples
221    ///
222    /// ```no_run
223    /// // Requires: debug build only.
224    /// use fraiseql_wire::connection::TlsConfig;
225    /// let tls = TlsConfig::builder()
226    ///     .danger_accept_invalid_hostnames(true)
227    ///     .build()?;
228    /// # fraiseql_wire::Result::Ok(())
229    /// ```
230    pub const fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
231        self.danger_accept_invalid_hostnames = accept;
232        self
233    }
234
235    /// Build the TLS configuration.
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if:
240    /// - CA certificate file cannot be read
241    /// - CA certificate is invalid PEM
242    /// - Dangerous options are configured incorrectly
243    ///
244    /// # Examples
245    ///
246    /// ```no_run
247    /// // Requires: system root certificates.
248    /// use fraiseql_wire::connection::TlsConfig;
249    /// let tls = TlsConfig::builder()
250    ///     .verify_hostname(true)
251    ///     .build()?;
252    /// # fraiseql_wire::Result::Ok(())
253    /// ```
254    pub fn build(self) -> Result<TlsConfig> {
255        // SECURITY: Validate TLS configuration before creating client
256        validate_tls_security(self.danger_accept_invalid_certs)?;
257
258        let client_config = if self.danger_accept_invalid_certs {
259            // Create a client config that accepts any certificate (development only)
260            let verifier = Arc::new(NoVerifier);
261            Arc::new(
262                ClientConfig::builder()
263                    .dangerous()
264                    .with_custom_certificate_verifier(verifier)
265                    .with_no_client_auth(),
266            )
267        } else {
268            // Load root certificates
269            let root_store = if let Some(ca_path) = &self.ca_cert_path {
270                // Load custom CA certificate from file
271                self.load_custom_ca(ca_path)?
272            } else {
273                // Use system root certificates via rustls-native-certs
274                let result = rustls_native_certs::load_native_certs();
275
276                let mut store = RootCertStore::empty();
277                for cert in result.certs {
278                    let _ = store.add_parsable_certificates(std::iter::once(cert));
279                }
280
281                // Log warnings if there were errors, but don't fail
282                if !result.errors.is_empty() && store.is_empty() {
283                    return Err(WireError::Config(
284                        "Failed to load any system root certificates".to_string(),
285                    ));
286                }
287
288                store
289            };
290
291            // Create ClientConfig using the correct API for rustls 0.23
292            Arc::new(
293                ClientConfig::builder()
294                    .with_root_certificates(root_store)
295                    .with_no_client_auth(),
296            )
297        };
298
299        Ok(TlsConfig {
300            ca_cert_path: self.ca_cert_path,
301            verify_hostname: self.verify_hostname,
302            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
303            danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
304            client_config,
305        })
306    }
307
308    /// Load a custom CA certificate from a PEM file.
309    fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
310        let ca_cert_data = fs::read(ca_path).map_err(|e| {
311            WireError::Config(format!(
312                "Failed to read CA certificate file '{}': {}",
313                ca_path, e
314            ))
315        })?;
316
317        let mut reader = std::io::Cursor::new(&ca_cert_data);
318        let mut root_store = RootCertStore::empty();
319        let mut found_certs = 0;
320
321        // Parse PEM file and extract certificates
322        loop {
323            match rustls_pemfile::read_one(&mut reader) {
324                Ok(Some(Item::X509Certificate(cert))) => {
325                    let _ = root_store.add_parsable_certificates(std::iter::once(cert));
326                    found_certs += 1;
327                }
328                Ok(Some(_)) => {
329                    // Skip non-certificate items (private keys, etc.)
330                }
331                Ok(None) => {
332                    // End of file
333                    break;
334                }
335                Err(_) => {
336                    return Err(WireError::Config(format!(
337                        "Failed to parse CA certificate from '{}'",
338                        ca_path
339                    )));
340                }
341            }
342        }
343
344        if found_certs == 0 {
345            return Err(WireError::Config(format!(
346                "No valid certificates found in '{}'",
347                ca_path
348            )));
349        }
350
351        Ok(root_store)
352    }
353}
354
355/// Validate TLS configuration for security constraints.
356///
357/// Enforces that release builds cannot use `danger_accept_invalid_certs`.
358/// Development builds emit a warning but proceed.
359///
360/// # Arguments
361///
362/// * `danger_accept_invalid_certs` - Whether danger mode is enabled
363///
364/// # Errors
365///
366/// Returns `WireError::Config` if `danger_accept_invalid_certs` is set in a release build.
367fn validate_tls_security(danger_accept_invalid_certs: bool) -> Result<()> {
368    if danger_accept_invalid_certs {
369        // SECURITY: Return an error in release builds to prevent accidental production use
370        #[cfg(not(debug_assertions))]
371        return Err(WireError::Config(
372            "TLS certificate validation bypass not permitted in release builds".into(),
373        ));
374
375        // Development builds: warn but allow
376        #[cfg(debug_assertions)]
377        {
378            tracing::warn!("TLS certificate validation is DISABLED (development only)");
379            tracing::warn!("This mode is only for development with self-signed certificates");
380        }
381    }
382    Ok(())
383}
384
385/// Parse server name from hostname for TLS SNI (Server Name Indication).
386///
387/// # Arguments
388///
389/// * `hostname` - Hostname to parse (without port)
390///
391/// # Returns
392///
393/// A string suitable for TLS server name indication
394///
395/// # Errors
396///
397/// Returns an error if the hostname is invalid.
398pub fn parse_server_name(hostname: &str) -> Result<String> {
399    // Remove trailing dot if present
400    let hostname = hostname.trim_end_matches('.');
401
402    // Validate hostname (basic check)
403    if hostname.is_empty() || hostname.len() > 253 {
404        return Err(WireError::Config(format!(
405            "Invalid hostname for TLS: '{}'",
406            hostname
407        )));
408    }
409
410    // Check for invalid characters
411    if !hostname
412        .chars()
413        .all(|c| c.is_alphanumeric() || c == '-' || c == '.')
414    {
415        return Err(WireError::Config(format!(
416            "Invalid hostname for TLS: '{}'",
417            hostname
418        )));
419    }
420
421    Ok(hostname.to_string())
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Install a crypto provider for rustls tests.
    /// This is needed because multiple crypto providers (ring and aws-lc-rs)
    /// may be enabled via transitive dependencies, requiring explicit selection.
    fn install_crypto_provider() {
        // Try to install ring as the default provider, ignore if already installed
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Write `contents` to a per-process file in the system temp dir and return its path.
    fn write_temp_file(name: &str, contents: &[u8]) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "fraiseql_wire_tls_{}_{name}",
            std::process::id()
        ));
        std::fs::write(&path, contents).expect("failed to write temporary test file");
        path
    }

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    #[test]
    fn test_tls_config_builder_with_missing_ca_file() {
        install_crypto_provider();

        let err = TlsConfig::builder()
            .ca_cert_path("/nonexistent/fraiseql_wire_tls_test/ca.pem")
            .build()
            .expect_err("a missing CA file must be rejected");

        assert!(
            err.to_string().contains("Failed to read CA certificate file"),
            "unexpected error: {err}",
        );
    }

    #[test]
    fn test_tls_config_builder_with_ca_file_without_certificates() {
        install_crypto_provider();

        let path = write_temp_file("no_certs.pem", b"this file contains no PEM sections\n");
        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy().into_owned())
            .build();
        let _ = std::fs::remove_file(&path);

        let err = result.expect_err("a CA file without certificates must be rejected");
        assert!(
            err.to_string().contains("No valid certificates found"),
            "unexpected error: {err}",
        );
    }

    #[test]
    fn test_parse_server_name_valid() {
        let _name =
            parse_server_name("localhost").expect("localhost should be a valid server name");
        let _name =
            parse_server_name("example.com").expect("example.com should be a valid server name");
        let _name = parse_server_name("db.internal.example.com")
            .expect("subdomain should be a valid server name");
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let _name = parse_server_name("example.com.")
            .expect("trailing dot should be accepted as valid server name");
    }

    #[test]
    fn test_parse_server_name_with_port() {
        // The server name must be a bare hostname; ':' is rejected.
        assert!(parse_server_name("example.com:5432").is_err());
    }

    #[test]
    fn test_tls_config_debug() {
        install_crypto_provider();

        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_danger_mode_returns_error_in_release_build() {
        // This test only runs in release builds; danger mode must return an error
        let result = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build();
        assert!(
            result.is_err(),
            "danger mode must be rejected in release builds"
        );
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("not permitted in release builds"),
            "error message must explain the restriction",
        );
    }

    #[test]
    #[cfg(debug_assertions)]
    fn test_danger_mode_allowed_in_debug_build() {
        // Only meaningful in debug builds: release builds reject danger mode
        // (covered by `test_danger_mode_returns_error_in_release_build`).
        install_crypto_provider();

        let config = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("danger mode should be allowed in debug builds");

        assert!(config.danger_accept_invalid_certs());
    }

    #[test]
    fn test_normal_tls_config_works() {
        install_crypto_provider();

        let config = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("normal TLS config should build successfully");

        assert!(!config.danger_accept_invalid_certs());
    }
}
543
544/// A certificate verifier that accepts any certificate.
545///
546/// ⚠️ **DANGER**: This should ONLY be used for development/testing with self-signed certificates.
547/// Using this in production is a serious security vulnerability.
548#[derive(Debug)]
549struct NoVerifier;
550
551impl ServerCertVerifier for NoVerifier {
552    fn verify_server_cert(
553        &self,
554        _end_entity: &CertificateDer<'_>,
555        _intermediates: &[CertificateDer<'_>],
556        _server_name: &ServerName<'_>,
557        _ocsp_response: &[u8],
558        _now: UnixTime,
559    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
560        // Accept any certificate
561        Ok(ServerCertVerified::assertion())
562    }
563
564    fn verify_tls12_signature(
565        &self,
566        _message: &[u8],
567        _cert: &CertificateDer<'_>,
568        _dss: &DigitallySignedStruct,
569    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
570        Ok(HandshakeSignatureValid::assertion())
571    }
572
573    fn verify_tls13_signature(
574        &self,
575        _message: &[u8],
576        _cert: &CertificateDer<'_>,
577        _dss: &DigitallySignedStruct,
578    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
579        Ok(HandshakeSignatureValid::assertion())
580    }
581
582    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
583        // Support all common signature schemes
584        vec![
585            SignatureScheme::RSA_PKCS1_SHA256,
586            SignatureScheme::RSA_PKCS1_SHA384,
587            SignatureScheme::RSA_PKCS1_SHA512,
588            SignatureScheme::ECDSA_NISTP256_SHA256,
589            SignatureScheme::ECDSA_NISTP384_SHA384,
590            SignatureScheme::ECDSA_NISTP521_SHA512,
591            SignatureScheme::RSA_PSS_SHA256,
592            SignatureScheme::RSA_PSS_SHA384,
593            SignatureScheme::RSA_PSS_SHA512,
594            SignatureScheme::ED25519,
595        ]
596    }
597}