Skip to main content

fraiseql_wire/connection/
tls.rs

1//! TLS configuration and support for secure connections to Postgres.
2//!
3//! This module provides TLS configuration for connecting to remote Postgres servers.
4//! TLS is recommended for all non-local connections to prevent credential interception.
5
6use crate::{Error, Result};
7use rustls::ClientConfig;
8use rustls::RootCertStore;
9use rustls_pemfile::Item;
10use std::fs;
11use std::sync::Arc;
12
/// How (and whether) TLS is used on the connection, mirroring the PostgreSQL
/// `sslmode` connection parameter.
///
/// The default is [`SslMode::Disable`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Plaintext connection; TLS is never negotiated.
    #[default]
    Disable,
    /// Encrypt the connection without verifying the server certificate.
    Require,
    /// Encrypt the connection and require a certificate signed by a trusted CA.
    VerifyCa,
    /// Like `VerifyCa`, and additionally require the hostname to match the certificate.
    VerifyFull,
}
28
29impl SslMode {
30    /// Whether this mode requires certificate verification (CA or full)
31    pub fn requires_verification(&self) -> bool {
32        matches!(self, Self::VerifyCa | Self::VerifyFull)
33    }
34}
35
36impl std::fmt::Display for SslMode {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Disable => write!(f, "disable"),
40            Self::Require => write!(f, "require"),
41            Self::VerifyCa => write!(f, "verify-ca"),
42            Self::VerifyFull => write!(f, "verify-full"),
43        }
44    }
45}
46
47impl std::str::FromStr for SslMode {
48    type Err = Error;
49
50    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
51        match s {
52            "disable" => Ok(Self::Disable),
53            "require" => Ok(Self::Require),
54            "verify-ca" => Ok(Self::VerifyCa),
55            "verify-full" => Ok(Self::VerifyFull),
56            _ => Err(Error::Config(format!(
57                "invalid sslmode '{}': expected disable, require, verify-ca, or verify-full",
58                s
59            ))),
60        }
61    }
62}
63
64/// TLS configuration for secure Postgres connections.
65///
66/// Provides a builder for creating TLS configurations with various certificate handling options.
67/// By default, server certificates are validated against system root certificates.
68///
69/// # Examples
70///
71/// ```ignore
72/// use fraiseql_wire::connection::TlsConfig;
73///
74/// // With system root certificates (production)
75/// let tls = TlsConfig::builder()
76///     .verify_hostname(true)
77///     .build()?;
78///
79/// // With custom CA certificate
80/// let tls = TlsConfig::builder()
81///     .ca_cert_path("/path/to/ca.pem")?
82///     .verify_hostname(true)
83///     .build()?;
84///
85/// // For development (danger: disables verification)
86/// let tls = TlsConfig::builder()
87///     .danger_accept_invalid_certs(true)
88///     .danger_accept_invalid_hostnames(true)
89///     .build()?;
90/// ```
91#[derive(Clone)]
92pub struct TlsConfig {
93    /// Path to CA certificate file (None = use system roots)
94    ca_cert_path: Option<String>,
95    /// Whether to verify hostname matches certificate
96    verify_hostname: bool,
97    /// Whether to accept invalid certificates (development only)
98    danger_accept_invalid_certs: bool,
99    /// Whether to accept invalid hostnames (development only)
100    danger_accept_invalid_hostnames: bool,
101    /// Compiled rustls ClientConfig
102    client_config: Arc<ClientConfig>,
103}
104
105impl TlsConfig {
106    /// Create a new TLS configuration builder.
107    ///
108    /// # Examples
109    ///
110    /// ```ignore
111    /// let tls = TlsConfig::builder()
112    ///     .verify_hostname(true)
113    ///     .build()?;
114    /// ```
115    pub fn builder() -> TlsConfigBuilder {
116        TlsConfigBuilder::default()
117    }
118
119    /// Get the rustls ClientConfig for this TLS configuration.
120    pub fn client_config(&self) -> Arc<ClientConfig> {
121        self.client_config.clone()
122    }
123
124    /// Check if hostname verification is enabled.
125    pub fn verify_hostname(&self) -> bool {
126        self.verify_hostname
127    }
128
129    /// Check if invalid certificates are accepted (development only).
130    pub fn danger_accept_invalid_certs(&self) -> bool {
131        self.danger_accept_invalid_certs
132    }
133
134    /// Check if invalid hostnames are accepted (development only).
135    pub fn danger_accept_invalid_hostnames(&self) -> bool {
136        self.danger_accept_invalid_hostnames
137    }
138}
139
140impl std::fmt::Debug for TlsConfig {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.debug_struct("TlsConfig")
143            .field("ca_cert_path", &self.ca_cert_path)
144            .field("verify_hostname", &self.verify_hostname)
145            .field(
146                "danger_accept_invalid_certs",
147                &self.danger_accept_invalid_certs,
148            )
149            .field(
150                "danger_accept_invalid_hostnames",
151                &self.danger_accept_invalid_hostnames,
152            )
153            .field("client_config", &"<ClientConfig>")
154            .finish()
155    }
156}
157
/// Builder for TLS configuration.
///
/// Provides a fluent API for constructing TLS configurations with custom
/// settings. Obtain one from `TlsConfig::builder()` or `Default::default()`.
/// The builder only holds file paths and flags. Files are read when `build`
/// is called, so the builder is cheap to clone and safe to print.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// Path to a custom CA bundle (PEM); `None` means use the system roots.
    ca_cert_path: Option<String>,
    /// Path to client certificate file (PEM format, for mTLS)
    pub(crate) client_cert_path: Option<String>,
    /// Path to client private key file (PEM format, for mTLS)
    pub(crate) client_key_path: Option<String>,
    /// Whether hostname verification is requested.
    verify_hostname: bool,
    /// Request to accept invalid server certificates (development only).
    danger_accept_invalid_certs: bool,
    /// Request to accept mismatched hostnames (development only).
    danger_accept_invalid_hostnames: bool,
}
171
172impl Default for TlsConfigBuilder {
173    fn default() -> Self {
174        Self {
175            ca_cert_path: None,
176            client_cert_path: None,
177            client_key_path: None,
178            verify_hostname: true,
179            danger_accept_invalid_certs: false,
180            danger_accept_invalid_hostnames: false,
181        }
182    }
183}
184
185impl TlsConfigBuilder {
186    /// Set the path to a custom CA certificate file (PEM format).
187    ///
188    /// If not set, system root certificates will be used.
189    ///
190    /// # Arguments
191    ///
192    /// * `path` - Path to CA certificate file in PEM format
193    ///
194    /// # Examples
195    ///
196    /// ```ignore
197    /// let tls = TlsConfig::builder()
198    ///     .ca_cert_path("/etc/ssl/certs/ca.pem")?
199    ///     .build()?;
200    /// ```
201    pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
202        self.ca_cert_path = Some(path.into());
203        self
204    }
205
206    /// Enable or disable hostname verification (default: enabled).
207    ///
208    /// When enabled, the certificate's subject alternative names (SANs) are verified
209    /// to match the server hostname.
210    ///
211    /// # Arguments
212    ///
213    /// * `verify` - Whether to verify hostname matches certificate
214    ///
215    /// # Examples
216    ///
217    /// ```ignore
218    /// let tls = TlsConfig::builder()
219    ///     .verify_hostname(true)
220    ///     .build()?;
221    /// ```
222    pub fn verify_hostname(mut self, verify: bool) -> Self {
223        self.verify_hostname = verify;
224        self
225    }
226
227    /// ⚠️ **DANGER**: Accept invalid certificates (development only).
228    ///
229    /// **NEVER use in production.** This disables certificate validation entirely,
230    /// making the connection vulnerable to man-in-the-middle attacks.
231    ///
232    /// Only use for testing with self-signed certificates.
233    ///
234    /// # Examples
235    ///
236    /// ```ignore
237    /// let tls = TlsConfig::builder()
238    ///     .danger_accept_invalid_certs(true)
239    ///     .build()?;
240    /// ```
241    pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
242        self.danger_accept_invalid_certs = accept;
243        self
244    }
245
246    /// ⚠️ **DANGER**: Accept invalid hostnames (development only).
247    ///
248    /// **NEVER use in production.** This disables hostname verification,
249    /// making the connection vulnerable to man-in-the-middle attacks.
250    ///
251    /// Only use for testing with self-signed certificates where you can't
252    /// match the hostname.
253    ///
254    /// # Examples
255    ///
256    /// ```ignore
257    /// let tls = TlsConfig::builder()
258    ///     .danger_accept_invalid_hostnames(true)
259    ///     .build()?;
260    /// ```
261    pub fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
262        self.danger_accept_invalid_hostnames = accept;
263        self
264    }
265
266    /// Set the path to a client certificate file (PEM format) for mutual TLS.
267    ///
268    /// Must be paired with `client_key_path`.
269    pub fn client_cert_path(mut self, path: impl Into<String>) -> Self {
270        self.client_cert_path = Some(path.into());
271        self
272    }
273
274    /// Set the path to a client private key file (PEM format) for mutual TLS.
275    ///
276    /// Must be paired with `client_cert_path`.
277    pub fn client_key_path(mut self, path: impl Into<String>) -> Self {
278        self.client_key_path = Some(path.into());
279        self
280    }
281
282    /// Build the TLS configuration.
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if:
287    /// - CA certificate file cannot be read
288    /// - CA certificate is invalid PEM
289    /// - Dangerous options are configured incorrectly
290    ///
291    /// # Examples
292    ///
293    /// ```ignore
294    /// let tls = TlsConfig::builder()
295    ///     .verify_hostname(true)
296    ///     .build()?;
297    /// ```
298    pub fn build(self) -> Result<TlsConfig> {
299        // Load root certificates
300        let root_store = if let Some(ca_path) = &self.ca_cert_path {
301            // Load custom CA certificate from file
302            self.load_custom_ca(ca_path)?
303        } else {
304            // Use system root certificates via rustls-native-certs
305            let result = rustls_native_certs::load_native_certs();
306
307            let mut store = RootCertStore::empty();
308            for cert in result.certs {
309                let _ = store.add_parsable_certificates(std::iter::once(cert));
310            }
311
312            // Log warnings if there were errors, but don't fail
313            if !result.errors.is_empty() && store.is_empty() {
314                return Err(Error::Config(
315                    "Failed to load any system root certificates".to_string(),
316                ));
317            }
318
319            store
320        };
321
322        // Create ClientConfig with or without client auth (mTLS)
323        let client_config = match (&self.client_cert_path, &self.client_key_path) {
324            (Some(cert_path), Some(key_path)) => {
325                let certs = self.load_client_certs(cert_path)?;
326                let key = self.load_client_key(key_path)?;
327                Arc::new(
328                    ClientConfig::builder()
329                        .with_root_certificates(root_store)
330                        .with_client_auth_cert(certs, key)
331                        .map_err(|e| {
332                            Error::Config(format!("invalid client certificate/key: {}", e))
333                        })?,
334                )
335            }
336            (Some(_), None) => {
337                return Err(Error::Config(
338                    "client certificate provided without client key (sslkey)".to_string(),
339                ));
340            }
341            (None, Some(_)) => {
342                return Err(Error::Config(
343                    "client key provided without client certificate (sslcert)".to_string(),
344                ));
345            }
346            (None, None) => Arc::new(
347                ClientConfig::builder()
348                    .with_root_certificates(root_store)
349                    .with_no_client_auth(),
350            ),
351        };
352
353        Ok(TlsConfig {
354            ca_cert_path: self.ca_cert_path,
355            verify_hostname: self.verify_hostname,
356            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
357            danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
358            client_config,
359        })
360    }
361
362    /// Load client certificate chain from a PEM file.
363    fn load_client_certs(
364        &self,
365        cert_path: &str,
366    ) -> Result<Vec<rustls_pki_types::CertificateDer<'static>>> {
367        let cert_data = fs::read(cert_path).map_err(|e| {
368            Error::Config(format!(
369                "failed to read client certificate '{}': {}",
370                cert_path, e
371            ))
372        })?;
373
374        let mut reader = std::io::Cursor::new(&cert_data);
375        let mut certs = Vec::new();
376
377        loop {
378            match rustls_pemfile::read_one(&mut reader) {
379                Ok(Some(Item::X509Certificate(cert))) => certs.push(cert),
380                Ok(Some(_)) => {}
381                Ok(None) => break,
382                Err(_) => {
383                    return Err(Error::Config(format!(
384                        "failed to parse client certificate from '{}'",
385                        cert_path
386                    )));
387                }
388            }
389        }
390
391        if certs.is_empty() {
392            return Err(Error::Config(format!(
393                "no valid certificates found in '{}'",
394                cert_path
395            )));
396        }
397
398        Ok(certs)
399    }
400
401    /// Load client private key from a PEM file.
402    fn load_client_key(&self, key_path: &str) -> Result<rustls_pki_types::PrivateKeyDer<'static>> {
403        let key_data = fs::read(key_path).map_err(|e| {
404            Error::Config(format!("failed to read client key '{}': {}", key_path, e))
405        })?;
406
407        let mut reader = std::io::Cursor::new(&key_data);
408
409        loop {
410            match rustls_pemfile::read_one(&mut reader) {
411                Ok(Some(Item::Pkcs1Key(key))) => {
412                    return Ok(rustls_pki_types::PrivateKeyDer::Pkcs1(key));
413                }
414                Ok(Some(Item::Pkcs8Key(key))) => {
415                    return Ok(rustls_pki_types::PrivateKeyDer::Pkcs8(key));
416                }
417                Ok(Some(Item::Sec1Key(key))) => {
418                    return Ok(rustls_pki_types::PrivateKeyDer::Sec1(key));
419                }
420                Ok(Some(_)) => {}
421                Ok(None) => break,
422                Err(_) => {
423                    return Err(Error::Config(format!(
424                        "failed to parse client key from '{}'",
425                        key_path
426                    )));
427                }
428            }
429        }
430
431        Err(Error::Config(format!(
432            "no valid private key found in '{}'",
433            key_path
434        )))
435    }
436
437    /// Load a custom CA certificate from a PEM file.
438    fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
439        let ca_cert_data = fs::read(ca_path).map_err(|e| {
440            Error::Config(format!(
441                "Failed to read CA certificate file '{}': {}",
442                ca_path, e
443            ))
444        })?;
445
446        let mut reader = std::io::Cursor::new(&ca_cert_data);
447        let mut root_store = RootCertStore::empty();
448        let mut found_certs = 0;
449
450        // Parse PEM file and extract certificates
451        loop {
452            match rustls_pemfile::read_one(&mut reader) {
453                Ok(Some(Item::X509Certificate(cert))) => {
454                    let _ = root_store.add_parsable_certificates(std::iter::once(cert));
455                    found_certs += 1;
456                }
457                Ok(Some(_)) => {
458                    // Skip non-certificate items (private keys, etc.)
459                }
460                Ok(None) => {
461                    // End of file
462                    break;
463                }
464                Err(_) => {
465                    return Err(Error::Config(format!(
466                        "Failed to parse CA certificate from '{}'",
467                        ca_path
468                    )));
469                }
470            }
471        }
472
473        if found_certs == 0 {
474            return Err(Error::Config(format!(
475                "No valid certificates found in '{}'",
476                ca_path
477            )));
478        }
479
480        Ok(root_store)
481    }
482}
483
484/// Parse server name from hostname for TLS SNI (Server Name Indication).
485///
486/// # Arguments
487///
488/// * `hostname` - Hostname to parse (without port)
489///
490/// # Returns
491///
492/// A string suitable for TLS server name indication
493///
494/// # Errors
495///
496/// Returns an error if the hostname is invalid.
497pub fn parse_server_name(hostname: &str) -> Result<String> {
498    // Remove trailing dot if present
499    let hostname = hostname.trim_end_matches('.');
500
501    // Validate hostname (basic check)
502    if hostname.is_empty() || hostname.len() > 253 {
503        return Err(Error::Config(format!(
504            "Invalid hostname for TLS: '{}'",
505            hostname
506        )));
507    }
508
509    // Check for invalid characters
510    if !hostname
511        .chars()
512        .all(|c| c.is_alphanumeric() || c == '-' || c == '.')
513    {
514        return Err(Error::Config(format!(
515            "Invalid hostname for TLS: '{}'",
516            hostname
517        )));
518    }
519
520    Ok(hostname.to_string())
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Write `contents` to a process-unique file in the OS temp dir and return its path.
    fn write_temp_file(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "fraiseql_wire_tls_{}_{}",
            std::process::id(),
            name
        ));
        std::fs::write(&path, contents).expect("failed to write temp file");
        path
    }

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    #[test]
    fn test_tls_config_builder_custom_ca_missing_file_fails() {
        let path = std::env::temp_dir().join(format!(
            "fraiseql_wire_tls_{}_does_not_exist.pem",
            std::process::id()
        ));
        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy().into_owned())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_builder_custom_ca_without_certificates_fails() {
        let path = write_temp_file("no_certs.pem", "this file contains no PEM sections\n");
        let result = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy().into_owned())
            .build();
        let _ = std::fs::remove_file(&path);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_server_name_valid() {
        let result = parse_server_name("localhost");
        assert!(result.is_ok());

        let result = parse_server_name("example.com");
        assert!(result.is_ok());

        let result = parse_server_name("db.internal.example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let result = parse_server_name("example.com.");
        assert_eq!(result.expect("trailing dot should be accepted"), "example.com");
    }

    #[test]
    fn test_parse_server_name_with_port_fails() {
        // SNI carries only the hostname; `host:port` must be rejected.
        assert!(parse_server_name("example.com:5432").is_err());
    }

    #[test]
    fn test_parse_server_name_invalid_fails() {
        assert!(parse_server_name("").is_err());
        assert!(parse_server_name("exa mple.com").is_err());
        assert!(parse_server_name("db_host").is_err());
        assert!(parse_server_name(&"a".repeat(254)).is_err());
    }

    #[test]
    fn test_ssl_mode_from_str() {
        assert_eq!("disable".parse::<SslMode>().unwrap(), SslMode::Disable);
        assert_eq!("require".parse::<SslMode>().unwrap(), SslMode::Require);
        assert_eq!("verify-ca".parse::<SslMode>().unwrap(), SslMode::VerifyCa);
        assert_eq!(
            "verify-full".parse::<SslMode>().unwrap(),
            SslMode::VerifyFull
        );
    }

    #[test]
    fn test_ssl_mode_from_str_invalid() {
        assert!("invalid".parse::<SslMode>().is_err());
        assert!("prefer".parse::<SslMode>().is_err());
    }

    #[test]
    fn test_ssl_mode_display() {
        assert_eq!(SslMode::Disable.to_string(), "disable");
        assert_eq!(SslMode::Require.to_string(), "require");
        assert_eq!(SslMode::VerifyCa.to_string(), "verify-ca");
        assert_eq!(SslMode::VerifyFull.to_string(), "verify-full");
    }

    #[test]
    fn test_ssl_mode_display_round_trips() {
        for mode in [
            SslMode::Disable,
            SslMode::Require,
            SslMode::VerifyCa,
            SslMode::VerifyFull,
        ] {
            assert_eq!(mode.to_string().parse::<SslMode>().unwrap(), mode);
        }
    }

    #[test]
    fn test_ssl_mode_default() {
        assert_eq!(SslMode::default(), SslMode::Disable);
    }

    #[test]
    fn test_ssl_mode_requires_verification() {
        assert!(!SslMode::Disable.requires_verification());
        assert!(!SslMode::Require.requires_verification());
        assert!(SslMode::VerifyCa.requires_verification());
        assert!(SslMode::VerifyFull.requires_verification());
    }

    #[test]
    fn test_tls_config_builder_with_client_cert_methods() {
        // Verify builder API accepts client cert and key paths
        let builder = TlsConfig::builder()
            .client_cert_path("/path/to/client.pem")
            .client_key_path("/path/to/client-key.pem");
        assert_eq!(
            builder.client_cert_path.as_deref(),
            Some("/path/to/client.pem")
        );
        assert_eq!(
            builder.client_key_path.as_deref(),
            Some("/path/to/client-key.pem")
        );
    }

    #[test]
    fn test_tls_config_builder_client_cert_without_key_fails() {
        // Providing a client cert without a key should fail
        let result = TlsConfig::builder()
            .client_cert_path("/path/to/client.pem")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_builder_client_key_without_cert_fails() {
        // Providing a client key without a cert should fail
        let result = TlsConfig::builder()
            .client_key_path("/path/to/client-key.pem")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_builder_client_files_missing_fails() {
        // Both paths set, but neither file exists: loading must fail.
        let result = TlsConfig::builder()
            .client_cert_path("/nonexistent/fraiseql-wire/client.pem")
            .client_key_path("/nonexistent/fraiseql-wire/client-key.pem")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_debug() {
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");

        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
        assert!(debug_str.contains("<ClientConfig>"));
    }
}