mockforge_mqtt/
tls.rs

1//! TLS support for MQTT connections
2//!
3//! This module provides TLS/SSL encryption for MQTT connections using rustls.
4
5use crate::broker::MqttConfig;
6use rustls::pki_types::{CertificateDer, PrivateKeyDer};
7use rustls::ServerConfig;
8use rustls_pemfile::{certs, private_key};
9use std::fs::File;
10use std::io::BufReader;
11use std::path::Path;
12use std::sync::Arc;
13use tokio_rustls::TlsAcceptor;
14
15/// Error type for TLS configuration
16#[derive(Debug, thiserror::Error)]
17pub enum TlsError {
18    #[error("TLS certificate file not found: {0}")]
19    CertNotFound(String),
20    #[error("TLS private key file not found: {0}")]
21    KeyNotFound(String),
22    #[error("Failed to read certificate: {0}")]
23    CertReadError(String),
24    #[error("Failed to read private key: {0}")]
25    KeyReadError(String),
26    #[error("No certificates found in certificate file")]
27    NoCertificates,
28    #[error("No private key found in key file")]
29    NoPrivateKey,
30    #[error("TLS configuration error: {0}")]
31    ConfigError(String),
32    #[error("TLS is enabled but certificate path is not configured")]
33    CertPathNotConfigured,
34    #[error("TLS is enabled but key path is not configured")]
35    KeyPathNotConfigured,
36}
37
38/// Load certificates from a PEM file
39fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
40    let file = File::open(path)
41        .map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
42    let mut reader = BufReader::new(file);
43
44    let certs: Vec<CertificateDer<'static>> = certs(&mut reader).filter_map(|c| c.ok()).collect();
45
46    if certs.is_empty() {
47        return Err(TlsError::NoCertificates);
48    }
49
50    Ok(certs)
51}
52
53/// Load private key from a PEM file
54fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
55    let file = File::open(path)
56        .map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
57    let mut reader = BufReader::new(file);
58
59    private_key(&mut reader)
60        .map_err(|e| TlsError::KeyReadError(e.to_string()))?
61        .ok_or(TlsError::NoPrivateKey)
62}
63
64/// Create a TLS acceptor from MQTT configuration
65pub fn create_tls_acceptor(config: &MqttConfig) -> Result<TlsAcceptor, TlsError> {
66    let cert_path = config.tls_cert_path.as_ref().ok_or(TlsError::CertPathNotConfigured)?;
67
68    let key_path = config.tls_key_path.as_ref().ok_or(TlsError::KeyPathNotConfigured)?;
69
70    // Verify files exist
71    if !cert_path.exists() {
72        return Err(TlsError::CertNotFound(cert_path.display().to_string()));
73    }
74    if !key_path.exists() {
75        return Err(TlsError::KeyNotFound(key_path.display().to_string()));
76    }
77
78    // Load certificates and private key
79    let certs = load_certs(cert_path)?;
80    let key = load_private_key(key_path)?;
81
82    // Build server config
83    let server_config = ServerConfig::builder()
84        .with_no_client_auth()
85        .with_single_cert(certs, key)
86        .map_err(|e| TlsError::ConfigError(e.to_string()))?;
87
88    Ok(TlsAcceptor::from(Arc::new(server_config)))
89}
90
91/// Create a TLS acceptor with optional client authentication
92pub fn create_tls_acceptor_with_client_auth(config: &MqttConfig) -> Result<TlsAcceptor, TlsError> {
93    let cert_path = config.tls_cert_path.as_ref().ok_or(TlsError::CertPathNotConfigured)?;
94
95    let key_path = config.tls_key_path.as_ref().ok_or(TlsError::KeyPathNotConfigured)?;
96
97    // Verify files exist
98    if !cert_path.exists() {
99        return Err(TlsError::CertNotFound(cert_path.display().to_string()));
100    }
101    if !key_path.exists() {
102        return Err(TlsError::KeyNotFound(key_path.display().to_string()));
103    }
104
105    // Load certificates and private key
106    let certs = load_certs(cert_path)?;
107    let key = load_private_key(key_path)?;
108
109    // Build server config based on client auth setting
110    let server_config = if config.tls_client_auth {
111        // Load CA certificate for client verification
112        let ca_path = config.tls_ca_path.as_ref().ok_or_else(|| {
113            TlsError::ConfigError("Client auth requires CA certificate path".to_string())
114        })?;
115
116        if !ca_path.exists() {
117            return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
118        }
119
120        let ca_certs = load_certs(ca_path)?;
121
122        // Create root cert store with CA certs
123        let mut root_store = rustls::RootCertStore::empty();
124        for cert in ca_certs {
125            root_store
126                .add(cert)
127                .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
128        }
129
130        let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
131            .build()
132            .map_err(|e| {
133                TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
134            })?;
135
136        ServerConfig::builder()
137            .with_client_cert_verifier(client_verifier)
138            .with_single_cert(certs, key)
139            .map_err(|e| TlsError::ConfigError(e.to_string()))?
140    } else {
141        ServerConfig::builder()
142            .with_no_client_auth()
143            .with_single_cert(certs, key)
144            .map_err(|e| TlsError::ConfigError(e.to_string()))?
145    };
146
147    Ok(TlsAcceptor::from(Arc::new(server_config)))
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Write `contents` to a file in the system temp dir and return its path.
    /// The process id plus a per-test `tag` keeps concurrently running tests apart.
    fn temp_file(tag: &str, contents: &str) -> PathBuf {
        let path = std::env::temp_dir()
            .join(format!("mockforge_mqtt_tls_{}_{}.pem", std::process::id(), tag));
        std::fs::write(&path, contents).expect("failed to write temp file for test");
        path
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertNotFound("/path/to/cert.pem".to_string());
        assert!(err.to_string().contains("/path/to/cert.pem"));

        let err = TlsError::NoCertificates;
        assert!(err.to_string().contains("No certificates"));
    }

    #[test]
    fn test_create_tls_acceptor_missing_cert_path() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: None,
            tls_key_path: Some(PathBuf::from("/tmp/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::CertPathNotConfigured)));
    }

    #[test]
    fn test_create_tls_acceptor_missing_key_path() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(PathBuf::from("/tmp/cert.pem")),
            tls_key_path: None,
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::KeyPathNotConfigured)));
    }

    #[test]
    fn test_create_tls_acceptor_cert_not_found() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(PathBuf::from("/nonexistent/cert.pem")),
            tls_key_path: Some(PathBuf::from("/nonexistent/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_create_tls_acceptor_key_not_found() {
        // The certificate file only has to exist; the key check fails before any parsing.
        let cert = temp_file("key_not_found_cert", "");
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(cert.clone()),
            tls_key_path: Some(PathBuf::from("/nonexistent/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        let _ = std::fs::remove_file(&cert);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_create_tls_acceptor_with_client_auth_missing_paths() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: None,
            tls_key_path: Some(PathBuf::from("/tmp/key.pem")),
            ..Default::default()
        };
        let result = create_tls_acceptor_with_client_auth(&config);
        assert!(matches!(result, Err(TlsError::CertPathNotConfigured)));

        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(PathBuf::from("/tmp/cert.pem")),
            tls_key_path: None,
            ..Default::default()
        };
        let result = create_tls_acceptor_with_client_auth(&config);
        assert!(matches!(result, Err(TlsError::KeyPathNotConfigured)));
    }

    #[test]
    fn test_load_certs_missing_file() {
        let result = load_certs(std::path::Path::new("/nonexistent/cert.pem"));
        assert!(matches!(result, Err(TlsError::CertReadError(_))));
    }

    #[test]
    fn test_load_certs_empty_file() {
        let path = temp_file("empty_certs", "");
        let result = load_certs(&path);
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_load_private_key_empty_file() {
        let path = temp_file("empty_key", "");
        let result = load_private_key(&path);
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }
}