1use crate::broker::MqttConfig;
6use rustls::pki_types::{CertificateDer, PrivateKeyDer};
7use rustls::ServerConfig;
8use rustls_pemfile::{certs, private_key};
9use std::fs::File;
10use std::io::BufReader;
11use std::path::Path;
12use std::sync::Arc;
13use tokio_rustls::TlsAcceptor;
14
15#[derive(Debug, thiserror::Error)]
17pub enum TlsError {
18 #[error("TLS certificate file not found: {0}")]
19 CertNotFound(String),
20 #[error("TLS private key file not found: {0}")]
21 KeyNotFound(String),
22 #[error("Failed to read certificate: {0}")]
23 CertReadError(String),
24 #[error("Failed to read private key: {0}")]
25 KeyReadError(String),
26 #[error("No certificates found in certificate file")]
27 NoCertificates,
28 #[error("No private key found in key file")]
29 NoPrivateKey,
30 #[error("TLS configuration error: {0}")]
31 ConfigError(String),
32 #[error("TLS is enabled but certificate path is not configured")]
33 CertPathNotConfigured,
34 #[error("TLS is enabled but key path is not configured")]
35 KeyPathNotConfigured,
36}
37
38fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
40 let file = File::open(path)
41 .map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
42 let mut reader = BufReader::new(file);
43
44 let certs: Vec<CertificateDer<'static>> = certs(&mut reader).filter_map(|c| c.ok()).collect();
45
46 if certs.is_empty() {
47 return Err(TlsError::NoCertificates);
48 }
49
50 Ok(certs)
51}
52
53fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
55 let file = File::open(path)
56 .map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
57 let mut reader = BufReader::new(file);
58
59 private_key(&mut reader)
60 .map_err(|e| TlsError::KeyReadError(e.to_string()))?
61 .ok_or(TlsError::NoPrivateKey)
62}
63
64pub fn create_tls_acceptor(config: &MqttConfig) -> Result<TlsAcceptor, TlsError> {
66 let cert_path = config.tls_cert_path.as_ref().ok_or(TlsError::CertPathNotConfigured)?;
67
68 let key_path = config.tls_key_path.as_ref().ok_or(TlsError::KeyPathNotConfigured)?;
69
70 if !cert_path.exists() {
72 return Err(TlsError::CertNotFound(cert_path.display().to_string()));
73 }
74 if !key_path.exists() {
75 return Err(TlsError::KeyNotFound(key_path.display().to_string()));
76 }
77
78 let certs = load_certs(cert_path)?;
80 let key = load_private_key(key_path)?;
81
82 let server_config = ServerConfig::builder()
84 .with_no_client_auth()
85 .with_single_cert(certs, key)
86 .map_err(|e| TlsError::ConfigError(e.to_string()))?;
87
88 Ok(TlsAcceptor::from(Arc::new(server_config)))
89}
90
91pub fn create_tls_acceptor_with_client_auth(config: &MqttConfig) -> Result<TlsAcceptor, TlsError> {
93 let cert_path = config.tls_cert_path.as_ref().ok_or(TlsError::CertPathNotConfigured)?;
94
95 let key_path = config.tls_key_path.as_ref().ok_or(TlsError::KeyPathNotConfigured)?;
96
97 if !cert_path.exists() {
99 return Err(TlsError::CertNotFound(cert_path.display().to_string()));
100 }
101 if !key_path.exists() {
102 return Err(TlsError::KeyNotFound(key_path.display().to_string()));
103 }
104
105 let certs = load_certs(cert_path)?;
107 let key = load_private_key(key_path)?;
108
109 let server_config = if config.tls_client_auth {
111 let ca_path = config.tls_ca_path.as_ref().ok_or_else(|| {
113 TlsError::ConfigError("Client auth requires CA certificate path".to_string())
114 })?;
115
116 if !ca_path.exists() {
117 return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
118 }
119
120 let ca_certs = load_certs(ca_path)?;
121
122 let mut root_store = rustls::RootCertStore::empty();
124 for cert in ca_certs {
125 root_store
126 .add(cert)
127 .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
128 }
129
130 let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
131 .build()
132 .map_err(|e| {
133 TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
134 })?;
135
136 ServerConfig::builder()
137 .with_client_cert_verifier(client_verifier)
138 .with_single_cert(certs, key)
139 .map_err(|e| TlsError::ConfigError(e.to_string()))?
140 } else {
141 ServerConfig::builder()
142 .with_no_client_auth()
143 .with_single_cert(certs, key)
144 .map_err(|e| TlsError::ConfigError(e.to_string()))?
145 };
146
147 Ok(TlsAcceptor::from(Arc::new(server_config)))
148}
149
#[cfg(test)]
mod tests {
    //! Unit tests for TLS acceptor construction and the PEM loading helpers.

    use super::*;
    use std::path::{Path, PathBuf};

    /// Writes `contents` to a file in the system temp directory and returns its path.
    ///
    /// The file name combines the process id and `name`. Tests run concurrently, so each
    /// test must pass a distinct `name`.
    fn temp_file(name: &str, contents: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "mqtt_tls_test_{}_{}",
            std::process::id(),
            name
        ));
        std::fs::write(&path, contents).expect("failed to write temporary test file");
        path
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertNotFound("/path/to/cert.pem".to_string());
        assert!(err.to_string().contains("/path/to/cert.pem"));

        let err = TlsError::NoCertificates;
        assert!(err.to_string().contains("No certificates"));
    }

    #[test]
    fn test_create_tls_acceptor_missing_cert_path() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: None,
            tls_key_path: Some(PathBuf::from("/tmp/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::CertPathNotConfigured)));
    }

    #[test]
    fn test_create_tls_acceptor_missing_key_path() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(PathBuf::from("/tmp/cert.pem")),
            tls_key_path: None,
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::KeyPathNotConfigured)));
    }

    #[test]
    fn test_create_tls_acceptor_cert_not_found() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(PathBuf::from("/nonexistent/cert.pem")),
            tls_key_path: Some(PathBuf::from("/nonexistent/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_create_tls_acceptor_key_not_found() {
        // Only the cert file's existence is checked before the key check, so empty contents suffice.
        let cert = temp_file("key_not_found_cert.pem", "");
        let config = MqttConfig {
            tls_enabled: true,
            tls_cert_path: Some(cert.clone()),
            tls_key_path: Some(PathBuf::from("/nonexistent/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor(&config);
        let _ = std::fs::remove_file(&cert);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_load_certs_missing_file() {
        let result = load_certs(Path::new("/nonexistent/cert.pem"));
        assert!(matches!(result, Err(TlsError::CertReadError(_))));
    }

    #[test]
    fn test_load_certs_empty_file() {
        let path = temp_file("empty_cert.pem", "");
        let result = load_certs(&path);
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_load_private_key_missing_file() {
        let result = load_private_key(Path::new("/nonexistent/key.pem"));
        assert!(matches!(result, Err(TlsError::KeyReadError(_))));
    }

    #[test]
    fn test_load_private_key_empty_file() {
        let path = temp_file("empty_key.pem", "");
        let result = load_private_key(&path);
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }

    #[test]
    fn test_client_auth_acceptor_missing_cert_path() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_client_auth: true,
            tls_cert_path: None,
            tls_key_path: Some(PathBuf::from("/tmp/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor_with_client_auth(&config);
        assert!(matches!(result, Err(TlsError::CertPathNotConfigured)));
    }

    #[test]
    fn test_client_auth_acceptor_cert_not_found() {
        let config = MqttConfig {
            tls_enabled: true,
            tls_client_auth: true,
            tls_cert_path: Some(PathBuf::from("/nonexistent/cert.pem")),
            tls_key_path: Some(PathBuf::from("/nonexistent/key.pem")),
            ..Default::default()
        };

        let result = create_tls_acceptor_with_client_auth(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }
}