Skip to main content

mockforge_core/
tls.rs

1//! Shared TLS utilities for MockForge protocol crates.
2//!
3//! This module provides a common [`TlsConfig`] struct and builder functions
4//! for creating `rustls` [`ServerConfig`](tokio_rustls::rustls::ServerConfig)
5//! and [`ClientConfig`](tokio_rustls::rustls::ClientConfig) instances.
6//!
7//! Protocol crates (MQTT, AMQP, SMTP, TCP, etc.) that need TLS support can
8//! use these helpers instead of duplicating certificate-loading logic.
9//!
10//! # Examples
11//!
12//! ```rust,no_run
13//! use mockforge_core::tls::TlsConfig;
14//!
15//! let config = TlsConfig::new("certs/server.pem", "certs/server-key.pem");
16//! let server_tls = mockforge_core::tls::build_server_tls_config(&config).unwrap();
17//! ```
18
19use rustls::pki_types::{CertificateDer, PrivateKeyDer};
20use rustls_pemfile::{certs, private_key};
21use std::fs::File;
22use std::io::BufReader;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25
26/// Errors that can occur during TLS configuration.
27#[derive(Debug, thiserror::Error)]
28pub enum TlsError {
29    /// The certificate file was not found at the specified path.
30    #[error("TLS certificate file not found: {0}")]
31    CertNotFound(String),
32
33    /// The private key file was not found at the specified path.
34    #[error("TLS private key file not found: {0}")]
35    KeyNotFound(String),
36
37    /// Failed to read the certificate file.
38    #[error("Failed to read certificate: {0}")]
39    CertReadError(String),
40
41    /// Failed to read the private key file.
42    #[error("Failed to read private key: {0}")]
43    KeyReadError(String),
44
45    /// The certificate file contained no valid certificates.
46    #[error("No certificates found in certificate file")]
47    NoCertificates,
48
49    /// The key file contained no valid private key.
50    #[error("No private key found in key file")]
51    NoPrivateKey,
52
53    /// A general TLS configuration error.
54    #[error("TLS configuration error: {0}")]
55    ConfigError(String),
56}
57
/// TLS configuration holding paths to certificate, key, and optional CA files.
///
/// This is a protocol-agnostic configuration struct. Protocol crates can
/// convert their own config types into `TlsConfig` before calling the
/// shared builder functions.
///
/// Only paths are stored: no file is opened or validated until one of the
/// builder functions consumes the configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TlsConfig {
    /// Path to the PEM-encoded certificate chain file.
    pub cert_path: PathBuf,
    /// Path to the PEM-encoded private key file.
    pub key_path: PathBuf,
    /// Optional path to a PEM-encoded CA certificate file for client/server verification.
    pub ca_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Create a new `TlsConfig` with cert and key paths and no CA file.
    #[must_use]
    pub fn new(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        Self {
            cert_path: cert_path.into(),
            key_path: key_path.into(),
            ca_path: None,
        }
    }

    /// Set the CA certificate path (for client auth verification or custom root CAs).
    ///
    /// Consumes `self` and returns the updated value so it can be chained after
    /// [`TlsConfig::new`]; discarding the return value loses the change, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn with_ca(mut self, ca_path: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(ca_path.into());
        self
    }
}
89
90/// Load PEM-encoded certificates from a file.
91fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
92    let file = File::open(path)
93        .map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
94    let mut reader = BufReader::new(file);
95
96    let certs_result: Vec<CertificateDer<'static>> =
97        certs(&mut reader).filter_map(|c| c.ok()).collect();
98
99    if certs_result.is_empty() {
100        return Err(TlsError::NoCertificates);
101    }
102
103    Ok(certs_result)
104}
105
106/// Load a PEM-encoded private key from a file.
107fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
108    let file = File::open(path)
109        .map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
110    let mut reader = BufReader::new(file);
111
112    private_key(&mut reader)
113        .map_err(|e| TlsError::KeyReadError(e.to_string()))?
114        .ok_or(TlsError::NoPrivateKey)
115}
116
117/// Build a rustls [`ServerConfig`](tokio_rustls::rustls::ServerConfig) from the given [`TlsConfig`].
118///
119/// If `config.ca_path` is set, client certificate verification is enabled
120/// using the CA certificates from that file. Otherwise, no client authentication
121/// is required.
122///
123/// # Errors
124///
125/// Returns [`TlsError`] if certificate/key files cannot be read or the
126/// configuration is invalid.
127pub fn build_server_tls_config(
128    config: &TlsConfig,
129) -> Result<Arc<tokio_rustls::rustls::ServerConfig>, TlsError> {
130    // Verify files exist
131    if !config.cert_path.exists() {
132        return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
133    }
134    if !config.key_path.exists() {
135        return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
136    }
137
138    let certs_vec = load_certs(&config.cert_path)?;
139    let key = load_private_key(&config.key_path)?;
140
141    let provider = rustls::crypto::ring::default_provider();
142    // Install as process-level default (ignored if already installed by another thread).
143    // This is needed because WebPkiClientVerifier::builder().build() looks up the
144    // process-level CryptoProvider internally.
145    let _ = provider.clone().install_default();
146
147    let server_config = if let Some(ca_path) = &config.ca_path {
148        // Client-auth mode: require client certificates signed by the CA.
149        if !ca_path.exists() {
150            return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
151        }
152
153        let ca_certs = load_certs(ca_path)?;
154        let mut root_store = rustls::RootCertStore::empty();
155        for cert in ca_certs {
156            root_store
157                .add(cert)
158                .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
159        }
160
161        let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
162            .build()
163            .map_err(|e| {
164                TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
165            })?;
166
167        rustls::ServerConfig::builder_with_provider(Arc::new(provider))
168            .with_safe_default_protocol_versions()
169            .map_err(|e| TlsError::ConfigError(e.to_string()))?
170            .with_client_cert_verifier(client_verifier)
171            .with_single_cert(certs_vec, key)
172            .map_err(|e| TlsError::ConfigError(e.to_string()))?
173    } else {
174        // No client auth.
175        rustls::ServerConfig::builder_with_provider(Arc::new(provider))
176            .with_safe_default_protocol_versions()
177            .map_err(|e| TlsError::ConfigError(e.to_string()))?
178            .with_no_client_auth()
179            .with_single_cert(certs_vec, key)
180            .map_err(|e| TlsError::ConfigError(e.to_string()))?
181    };
182
183    Ok(Arc::new(server_config))
184}
185
186/// Build a rustls [`ClientConfig`](tokio_rustls::rustls::ClientConfig) from the given [`TlsConfig`].
187///
188/// If `config.ca_path` is set, the CA certificates are used as trusted roots
189/// instead of the system default roots. The client certificate and key from
190/// `cert_path` / `key_path` are presented for mutual TLS if the server requests
191/// client authentication.
192///
193/// # Errors
194///
195/// Returns [`TlsError`] if certificate/key files cannot be read or the
196/// configuration is invalid.
197pub fn build_client_tls_config(
198    config: &TlsConfig,
199) -> Result<Arc<tokio_rustls::rustls::ClientConfig>, TlsError> {
200    // Verify files exist
201    if !config.cert_path.exists() {
202        return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
203    }
204    if !config.key_path.exists() {
205        return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
206    }
207
208    let certs_vec = load_certs(&config.cert_path)?;
209    let key = load_private_key(&config.key_path)?;
210
211    let provider = rustls::crypto::ring::default_provider();
212
213    // Build root cert store
214    let mut root_store = rustls::RootCertStore::empty();
215
216    if let Some(ca_path) = &config.ca_path {
217        if !ca_path.exists() {
218            return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
219        }
220        let ca_certs = load_certs(ca_path)?;
221        for cert in ca_certs {
222            root_store
223                .add(cert)
224                .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
225        }
226    } else {
227        // Use webpki roots as default trusted CAs
228        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
229    }
230
231    let client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
232        .with_safe_default_protocol_versions()
233        .map_err(|e| TlsError::ConfigError(e.to_string()))?
234        .with_root_certificates(root_store)
235        .with_client_auth_cert(certs_vec, key)
236        .map_err(|e| TlsError::ConfigError(e.to_string()))?;
237
238    Ok(Arc::new(client_config))
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate a throwaway self-signed certificate (SAN `localhost`) and key
    /// pair with `rcgen`, write both as PEM into `dir`, and return
    /// `(cert_path, key_path)`.
    ///
    /// The material is created fresh on every call; it is NOT a real credential.
    fn write_test_cert_and_key(dir: &tempfile::TempDir) -> (PathBuf, PathBuf) {
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let subject_alt_names = vec!["localhost".to_string()];
        let cert_params =
            rcgen::CertificateParams::new(subject_alt_names).expect("Failed to create cert params");
        let key_pair = rcgen::KeyPair::generate().expect("Failed to generate key pair");
        let cert = cert_params.self_signed(&key_pair).expect("Failed to self-sign cert");

        std::fs::write(&cert_path, cert.pem()).unwrap();
        std::fs::write(&key_path, key_pair.serialize_pem()).unwrap();

        (cert_path, key_path)
    }

    /// Write `contents` to the file `name` inside `dir` and return its path.
    fn write_file(dir: &tempfile::TempDir, name: &str, contents: &str) -> PathBuf {
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        path
    }

    #[test]
    fn test_tls_config_new() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem");
        assert_eq!(config.cert_path, PathBuf::from("/tmp/cert.pem"));
        assert_eq!(config.key_path, PathBuf::from("/tmp/key.pem"));
        assert!(config.ca_path.is_none());
    }

    #[test]
    fn test_tls_config_with_ca() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem").with_ca("/tmp/ca.pem");
        assert_eq!(config.ca_path, Some(PathBuf::from("/tmp/ca.pem")));
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertNotFound("/path/to/cert.pem".to_string());
        assert!(err.to_string().contains("/path/to/cert.pem"));

        let err = TlsError::NoCertificates;
        assert!(err.to_string().contains("No certificates"));

        let err = TlsError::NoPrivateKey;
        assert!(err.to_string().contains("No private key"));

        let err = TlsError::ConfigError("bad config".to_string());
        assert!(err.to_string().contains("bad config"));
    }

    #[test]
    fn test_build_server_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = write_file(&dir, "cert.pem", "placeholder");

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_cert() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = write_file(&dir, "cert.pem", "");
        let key_path = write_file(&dir, "key.pem", "");

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_server_tls_config_empty_key() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        std::fs::write(&key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }

    #[test]
    fn test_build_server_tls_config_malformed_cert_is_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let (_, key_path) = write_test_cert_and_key(&dir);
        let cert_path = write_file(
            &dir,
            "bad-cert.pem",
            "-----BEGIN CERTIFICATE-----\n%%%%not-base64%%%%\n-----END CERTIFICATE-----\n",
        );

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_err(), "malformed certificate must not be accepted");
    }

    #[test]
    fn test_build_server_tls_config_valid() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_with_client_auth() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        // Use the same cert as CA for testing
        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let ca_path = write_file(&dir, "ca.pem", "");

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_client_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, _) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_empty_cert() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = write_file(&dir, "cert.pem", "");
        let key_path = write_file(&dir, "key.pem", "");

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_client_tls_config_valid_with_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_valid_default_roots() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }
}