Skip to main content

mockforge_http/
tls.rs

1//! TLS/HTTPS support for HTTP server
2//!
3//! This module provides TLS configuration and certificate loading for secure HTTP connections.
4
5use mockforge_core::config::HttpTlsConfig;
6use mockforge_core::Result;
7use std::sync::Arc;
8use std::sync::Once;
9use tokio_rustls::TlsAcceptor;
10use tracing::info;
11
12static CRYPTO_INIT: Once = Once::new();
13
14/// Initialize the rustls crypto provider.
15///
16/// This must be called before any TLS operations. It is safe to call multiple times
17/// as it uses `Once` to ensure initialization happens exactly once.
18///
19/// Uses the `ring` crypto provider for rustls.
20pub fn init_crypto_provider() {
21    CRYPTO_INIT.call_once(|| {
22        // Install the ring crypto provider as the default for rustls
23        let _ = rustls::crypto::ring::default_provider().install_default();
24    });
25}
26
27/// Create a rustls ServerConfig builder with the appropriate TLS protocol versions.
28///
29/// When `tls13_only` is true, restricts to TLS 1.3 only. Otherwise uses safe defaults (TLS 1.2+).
30fn tls_config_builder(
31    tls13_only: bool,
32) -> rustls::ConfigBuilder<rustls::server::ServerConfig, rustls::WantsVerifier> {
33    if tls13_only {
34        rustls::server::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
35    } else {
36        rustls::server::ServerConfig::builder()
37    }
38}
39
40/// Determine if TLS 1.3 only mode is requested from configuration.
41fn is_tls13_only(min_version: &str) -> bool {
42    match min_version {
43        "1.3" => {
44            info!("Enforcing TLS 1.3 only (min_version=1.3)");
45            true
46        }
47        "1.2" | "" => false,
48        other => {
49            tracing::warn!("Unsupported TLS min_version '{}', using defaults (TLS 1.2+)", other);
50            false
51        }
52    }
53}
54
55/// Load TLS acceptor from certificate and key files
56///
57/// This function loads server certificates and private keys from PEM files
58/// and creates a TLS acceptor for use with the HTTP server.
59///
60/// For mutual TLS (mTLS), provide a CA certificate file via `ca_file`.
61pub fn load_tls_acceptor(config: &HttpTlsConfig) -> Result<TlsAcceptor> {
62    use rustls_pemfile::{certs, pkcs8_private_keys};
63    use std::fs::File;
64    use std::io::BufReader;
65
66    // Ensure crypto provider is initialized
67    init_crypto_provider();
68
69    info!(
70        "Loading TLS certificate from {} and key from {}",
71        config.cert_file, config.key_file
72    );
73
74    // Load certificate chain
75    let cert_file = File::open(&config.cert_file).map_err(|e| {
76        mockforge_core::Error::generic(format!(
77            "Failed to open certificate file {}: {}",
78            config.cert_file, e
79        ))
80    })?;
81    let mut cert_reader = BufReader::new(cert_file);
82    let server_certs: Vec<rustls::pki_types::CertificateDer<'static>> = certs(&mut cert_reader)
83        .collect::<std::result::Result<Vec<_>, _>>()
84        .map_err(|e| {
85            mockforge_core::Error::generic(format!(
86                "Failed to parse certificate file {}: {}",
87                config.cert_file, e
88            ))
89        })?;
90
91    if server_certs.is_empty() {
92        return Err(mockforge_core::Error::generic(format!(
93            "No certificates found in {}",
94            config.cert_file
95        )));
96    }
97
98    // Load private key
99    let key_file = File::open(&config.key_file).map_err(|e| {
100        mockforge_core::Error::generic(format!(
101            "Failed to open private key file {}: {}",
102            config.key_file, e
103        ))
104    })?;
105    let mut key_reader = BufReader::new(key_file);
106    let pkcs8_keys: Vec<rustls::pki_types::PrivatePkcs8KeyDer<'static>> =
107        pkcs8_private_keys(&mut key_reader)
108            .collect::<std::result::Result<Vec<_>, _>>()
109            .map_err(|e| {
110                mockforge_core::Error::generic(format!(
111                    "Failed to parse private key file {}: {}",
112                    config.key_file, e
113                ))
114            })?;
115    let mut keys: Vec<rustls::pki_types::PrivateKeyDer<'static>> =
116        pkcs8_keys.into_iter().map(rustls::pki_types::PrivateKeyDer::Pkcs8).collect();
117
118    if keys.is_empty() {
119        return Err(mockforge_core::Error::generic(format!(
120            "No private keys found in {}",
121            config.key_file
122        )));
123    }
124
125    // Build TLS server configuration with version support
126    // Note: rustls uses safe defaults, so we configure during builder creation
127    // Determine mTLS mode: use mtls_mode if set, otherwise fall back to require_client_cert for backward compatibility
128    let mtls_mode = if !config.mtls_mode.is_empty() && config.mtls_mode != "off" {
129        config.mtls_mode.as_str()
130    } else if config.require_client_cert {
131        "required"
132    } else {
133        "off"
134    };
135
136    // Determine TLS protocol versions based on min_version config
137    let tls13_only = is_tls13_only(&config.min_version);
138
139    let server_config = match mtls_mode {
140        "required" => {
141            // Mutual TLS: require client certificates
142            if let Some(ref ca_file_path) = config.ca_file {
143                // Load CA certificate for client verification
144                let ca_file = File::open(ca_file_path).map_err(|e| {
145                    mockforge_core::Error::generic(format!(
146                        "Failed to open CA certificate file {}: {}",
147                        ca_file_path, e
148                    ))
149                })?;
150                let mut ca_reader = BufReader::new(ca_file);
151                let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
152                    certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
153                        |e| {
154                            mockforge_core::Error::generic(format!(
155                                "Failed to parse CA certificate file {}: {}",
156                                ca_file_path, e
157                            ))
158                        },
159                    )?;
160
161                let mut root_store = rustls::RootCertStore::empty();
162                for cert in &ca_certs {
163                    root_store.add(cert.clone()).map_err(|e| {
164                        mockforge_core::Error::generic(format!(
165                            "Failed to add CA certificate to root store: {}",
166                            e
167                        ))
168                    })?;
169                }
170
171                let client_verifier =
172                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
173                        .build()
174                        .map_err(|e| {
175                            mockforge_core::Error::generic(format!(
176                                "Failed to build client verifier: {}",
177                                e
178                            ))
179                        })?;
180
181                let key = keys.remove(0);
182
183                // Build with mTLS support (required)
184                tls_config_builder(tls13_only)
185                    .with_client_cert_verifier(client_verifier)
186                    .with_single_cert(server_certs, key)
187                    .map_err(|e| {
188                        mockforge_core::Error::generic(format!(
189                            "TLS config error (mTLS required): {}",
190                            e
191                        ))
192                    })?
193            } else {
194                return Err(mockforge_core::Error::generic(
195                    "mTLS mode 'required' requires --tls-ca (CA certificate file)",
196                ));
197            }
198        }
199        "optional" => {
200            // Mutual TLS: accept client certificates if provided, but don't require
201            if let Some(ref ca_file_path) = config.ca_file {
202                // Load CA certificate for client verification
203                let ca_file = File::open(ca_file_path).map_err(|e| {
204                    mockforge_core::Error::generic(format!(
205                        "Failed to open CA certificate file {}: {}",
206                        ca_file_path, e
207                    ))
208                })?;
209                let mut ca_reader = BufReader::new(ca_file);
210                let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
211                    certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
212                        |e| {
213                            mockforge_core::Error::generic(format!(
214                                "Failed to parse CA certificate file {}: {}",
215                                ca_file_path, e
216                            ))
217                        },
218                    )?;
219
220                let mut root_store = rustls::RootCertStore::empty();
221                for cert in &ca_certs {
222                    root_store.add(cert.clone()).map_err(|e| {
223                        mockforge_core::Error::generic(format!(
224                            "Failed to add CA certificate to root store: {}",
225                            e
226                        ))
227                    })?;
228                }
229
230                let client_verifier =
231                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
232                        .build()
233                        .map_err(|e| {
234                            mockforge_core::Error::generic(format!(
235                                "Failed to build client verifier: {}",
236                                e
237                            ))
238                        })?;
239
240                let key = keys.remove(0);
241
242                // Build with optional mTLS support
243                // Note: rustls doesn't have a built-in "optional" mode, so we use
244                // WebPkiClientVerifier which accepts any client cert that validates,
245                // but connections without certs will also work (we can't enforce optional-only)
246                // For true optional mTLS, we'd need custom verifier logic
247                tls_config_builder(tls13_only)
248                    .with_client_cert_verifier(client_verifier)
249                    .with_single_cert(server_certs, key)
250                    .map_err(|e| {
251                        mockforge_core::Error::generic(format!(
252                            "TLS config error (mTLS optional): {}",
253                            e
254                        ))
255                    })?
256            } else {
257                // Optional mTLS without CA: just standard TLS
258                info!("mTLS optional mode specified but no CA file provided, using standard TLS");
259                let key = keys.remove(0);
260                tls_config_builder(tls13_only)
261                    .with_no_client_auth()
262                    .with_single_cert(server_certs, key)
263                    .map_err(|e| {
264                        mockforge_core::Error::generic(format!("TLS config error: {}", e))
265                    })?
266            }
267        }
268        _ => {
269            // Standard TLS: no client certificate required
270            let key = keys.remove(0);
271            tls_config_builder(tls13_only)
272                .with_no_client_auth()
273                .with_single_cert(server_certs, key)
274                .map_err(|e| mockforge_core::Error::generic(format!("TLS config error: {}", e)))?
275        }
276    };
277
278    // Log cipher suite configuration (cipher suites are controlled by rustls's
279    // safe defaults and not overridable without also selecting specific crypto providers)
280    if !config.cipher_suites.is_empty() {
281        info!(
282            "Custom cipher suites specified: {:?}. Note: rustls enforces safe cipher suites; \
283             for fine-grained control, configure the rustls CryptoProvider.",
284            config.cipher_suites
285        );
286    }
287
288    info!("TLS acceptor configured successfully");
289    Ok(TlsAcceptor::from(Arc::new(server_config)))
290}
291
292/// Load TLS server configuration for use with axum-server
293///
294/// This function is similar to load_tls_acceptor but returns the ServerConfig
295/// directly for use with axum-server's RustlsConfig.
296pub fn load_tls_server_config(
297    config: &HttpTlsConfig,
298) -> std::result::Result<Arc<rustls::server::ServerConfig>, Box<dyn std::error::Error + Send + Sync>>
299{
300    use rustls_pemfile::{certs, pkcs8_private_keys};
301    use std::fs::File;
302    use std::io::BufReader;
303    use std::sync::Arc;
304
305    // Ensure crypto provider is initialized
306    init_crypto_provider();
307
308    info!(
309        "Loading TLS certificate from {} and key from {}",
310        config.cert_file, config.key_file
311    );
312
313    // Load certificate chain
314    let cert_file = File::open(&config.cert_file)
315        .map_err(|e| format!("Failed to open certificate file {}: {}", config.cert_file, e))?;
316    let mut cert_reader = BufReader::new(cert_file);
317    let server_certs: Vec<rustls::pki_types::CertificateDer<'static>> = certs(&mut cert_reader)
318        .collect::<std::result::Result<Vec<_>, _>>()
319        .map_err(|e| format!("Failed to parse certificate file {}: {}", config.cert_file, e))?;
320
321    if server_certs.is_empty() {
322        return Err(format!("No certificates found in {}", config.cert_file).into());
323    }
324
325    // Load private key
326    let key_file = File::open(&config.key_file)
327        .map_err(|e| format!("Failed to open private key file {}: {}", config.key_file, e))?;
328    let mut key_reader = BufReader::new(key_file);
329    let pkcs8_keys: Vec<rustls::pki_types::PrivatePkcs8KeyDer<'static>> =
330        pkcs8_private_keys(&mut key_reader)
331            .collect::<std::result::Result<Vec<_>, _>>()
332            .map_err(|e| format!("Failed to parse private key file {}: {}", config.key_file, e))?;
333    let mut keys: Vec<rustls::pki_types::PrivateKeyDer<'static>> =
334        pkcs8_keys.into_iter().map(rustls::pki_types::PrivateKeyDer::Pkcs8).collect();
335
336    if keys.is_empty() {
337        return Err(format!("No private keys found in {}", config.key_file).into());
338    }
339
340    // Determine TLS protocol versions based on min_version config
341    let tls13_only = is_tls13_only(&config.min_version);
342
343    // Determine mTLS mode
344    let mtls_mode = if !config.mtls_mode.is_empty() && config.mtls_mode != "off" {
345        config.mtls_mode.as_str()
346    } else if config.require_client_cert {
347        "required"
348    } else {
349        "off"
350    };
351
352    let server_config = match mtls_mode {
353        "required" => {
354            if let Some(ref ca_file_path) = config.ca_file {
355                let ca_file = File::open(ca_file_path).map_err(|e| {
356                    format!("Failed to open CA certificate file {}: {}", ca_file_path, e)
357                })?;
358                let mut ca_reader = BufReader::new(ca_file);
359                let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
360                    certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
361                        |e| format!("Failed to parse CA certificate file {}: {}", ca_file_path, e),
362                    )?;
363
364                let mut root_store = rustls::RootCertStore::empty();
365                for cert in &ca_certs {
366                    root_store.add(cert.clone()).map_err(|e| {
367                        format!("Failed to add CA certificate to root store: {}", e)
368                    })?;
369                }
370
371                let client_verifier =
372                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
373                        .build()
374                        .map_err(|e| format!("Failed to build client verifier: {}", e))?;
375
376                let key = keys.remove(0);
377
378                tls_config_builder(tls13_only)
379                    .with_client_cert_verifier(client_verifier)
380                    .with_single_cert(server_certs, key)
381                    .map_err(|e| format!("TLS config error (mTLS required): {}", e))?
382            } else {
383                return Err("mTLS mode 'required' requires CA certificate file".to_string().into());
384            }
385        }
386        "optional" => {
387            if let Some(ref ca_file_path) = config.ca_file {
388                let ca_file = File::open(ca_file_path).map_err(|e| {
389                    format!("Failed to open CA certificate file {}: {}", ca_file_path, e)
390                })?;
391                let mut ca_reader = BufReader::new(ca_file);
392                let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
393                    certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
394                        |e| format!("Failed to parse CA certificate file {}: {}", ca_file_path, e),
395                    )?;
396
397                let mut root_store = rustls::RootCertStore::empty();
398                for cert in &ca_certs {
399                    root_store.add(cert.clone()).map_err(|e| {
400                        format!("Failed to add CA certificate to root store: {}", e)
401                    })?;
402                }
403
404                let client_verifier =
405                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
406                        .build()
407                        .map_err(|e| format!("Failed to build client verifier: {}", e))?;
408
409                let key = keys.remove(0);
410
411                tls_config_builder(tls13_only)
412                    .with_client_cert_verifier(client_verifier)
413                    .with_single_cert(server_certs, key)
414                    .map_err(|e| format!("TLS config error (mTLS optional): {}", e))?
415            } else {
416                let key = keys.remove(0);
417                tls_config_builder(tls13_only)
418                    .with_no_client_auth()
419                    .with_single_cert(server_certs, key)
420                    .map_err(|e| format!("TLS config error: {}", e))?
421            }
422        }
423        _ => {
424            let key = keys.remove(0);
425            tls_config_builder(tls13_only)
426                .with_no_client_auth()
427                .with_single_cert(server_certs, key)
428                .map_err(|e| format!("TLS config error: {}", e))?
429        }
430    };
431
432    Ok(Arc::new(server_config))
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // Tests use the module-level init_crypto_provider() from super::*

    /// Create a temp file holding one PEM section labelled `label` with a placeholder body.
    ///
    /// The body is not a real DER structure, so only PEM framing/reading gets exercised.
    fn pem_file(label: &str) -> NamedTempFile {
        let file = NamedTempFile::new().unwrap();
        writeln!(file.as_file(), "-----BEGIN {}-----\nTEST\n-----END {}-----", label, label)
            .unwrap();
        file
    }

    /// Placeholder (certificate, private key) files that parse as PEM but are not usable TLS material.
    fn create_test_cert() -> (NamedTempFile, NamedTempFile) {
        (pem_file("CERTIFICATE"), pem_file("PRIVATE KEY"))
    }

    /// Build a TLS config that points at the given temp files, with TLS 1.2 and no CA file.
    fn config_for(
        cert: &NamedTempFile,
        key: &NamedTempFile,
        require_client_cert: bool,
        mtls_mode: &str,
    ) -> HttpTlsConfig {
        HttpTlsConfig {
            enabled: true,
            cert_file: cert.path().to_string_lossy().into_owned(),
            key_file: key.path().to_string_lossy().into_owned(),
            ca_file: None,
            min_version: "1.2".to_string(),
            cipher_suites: Vec::new(),
            require_client_cert,
            mtls_mode: mtls_mode.to_string(),
        }
    }

    #[test]
    fn test_tls_config_validation() {
        init_crypto_provider();
        let (cert, key) = create_test_cert();
        let config = config_for(&cert, &key, false, "off");

        // The files are read, but the placeholder key/cert must be rejected.
        assert!(load_tls_acceptor(&config).is_err());
    }

    #[test]
    fn test_mtls_requires_ca() {
        init_crypto_provider();
        let (cert, key) = create_test_cert();
        // Client certificates are required, yet there is no CA file to verify them against.
        let config = config_for(&cert, &key, true, "required");

        let err_msg = match load_tls_acceptor(&config) {
            Ok(_) => panic!("expected mTLS without a CA file to be rejected"),
            Err(e) => e.to_string(),
        };
        assert!(
            err_msg.contains("CA") || err_msg.contains("--tls-ca"),
            "Expected error message about CA certificate, got: {}",
            err_msg
        );
    }
}