1use rustls::pki_types::{CertificateDer, PrivateKeyDer};
20use rustls::{ClientConfig, ServerConfig};
21use rustls_pemfile::{certs, private_key};
22use std::fs::File;
23use std::io::BufReader;
24use std::path::{Path, PathBuf};
25use std::sync::Arc;
26
27#[derive(Debug, thiserror::Error)]
29pub enum TlsError {
30 #[error("TLS certificate file not found: {0}")]
32 CertNotFound(String),
33
34 #[error("TLS private key file not found: {0}")]
36 KeyNotFound(String),
37
38 #[error("Failed to read certificate: {0}")]
40 CertReadError(String),
41
42 #[error("Failed to read private key: {0}")]
44 KeyReadError(String),
45
46 #[error("No certificates found in certificate file")]
48 NoCertificates,
49
50 #[error("No private key found in key file")]
52 NoPrivateKey,
53
54 #[error("TLS configuration error: {0}")]
56 ConfigError(String),
57}
58
/// Filesystem locations of the PEM files used to build rustls configurations.
///
/// `ca_path` is optional. When building a server config it enables
/// client-certificate verification against the CAs in that file. When building
/// a client config it replaces the bundled web PKI roots as the trust anchors
/// for server certificates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// PEM file holding the certificate chain presented to the peer.
    pub cert_path: PathBuf,
    /// PEM file holding the private key that matches `cert_path`.
    pub key_path: PathBuf,
    /// Optional PEM file of CA certificates used to verify the peer.
    pub ca_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Creates a configuration from a certificate path and a key path, with no CA.
    ///
    /// Paths are only stored here; nothing is opened or validated until a
    /// rustls config is built from this value.
    #[must_use]
    pub fn new(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        Self {
            cert_path: cert_path.into(),
            key_path: key_path.into(),
            ca_path: None,
        }
    }

    /// Returns this configuration with `ca_path` set, replacing any previous CA.
    ///
    /// This consumes `self`. Ignoring the returned value would silently drop the
    /// CA setting, hence `#[must_use]`.
    #[must_use]
    pub fn with_ca(mut self, ca_path: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(ca_path.into());
        self
    }
}
90
91fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
93 let file = File::open(path)
94 .map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
95 let mut reader = BufReader::new(file);
96
97 let certs_result: Vec<CertificateDer<'static>> =
98 certs(&mut reader).filter_map(|c| c.ok()).collect();
99
100 if certs_result.is_empty() {
101 return Err(TlsError::NoCertificates);
102 }
103
104 Ok(certs_result)
105}
106
107fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
109 let file = File::open(path)
110 .map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
111 let mut reader = BufReader::new(file);
112
113 private_key(&mut reader)
114 .map_err(|e| TlsError::KeyReadError(e.to_string()))?
115 .ok_or(TlsError::NoPrivateKey)
116}
117
118pub fn build_server_tls_config(config: &TlsConfig) -> Result<Arc<ServerConfig>, TlsError> {
129 if !config.cert_path.exists() {
131 return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
132 }
133 if !config.key_path.exists() {
134 return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
135 }
136
137 let certs_vec = load_certs(&config.cert_path)?;
138 let key = load_private_key(&config.key_path)?;
139
140 let provider = rustls::crypto::ring::default_provider();
141 let _ = provider.clone().install_default();
145
146 let server_config = if let Some(ca_path) = &config.ca_path {
147 if !ca_path.exists() {
149 return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
150 }
151
152 let ca_certs = load_certs(ca_path)?;
153 let mut root_store = rustls::RootCertStore::empty();
154 for cert in ca_certs {
155 root_store
156 .add(cert)
157 .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
158 }
159
160 let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
161 .build()
162 .map_err(|e| {
163 TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
164 })?;
165
166 ServerConfig::builder_with_provider(Arc::new(provider))
167 .with_safe_default_protocol_versions()
168 .map_err(|e| TlsError::ConfigError(e.to_string()))?
169 .with_client_cert_verifier(client_verifier)
170 .with_single_cert(certs_vec, key)
171 .map_err(|e| TlsError::ConfigError(e.to_string()))?
172 } else {
173 ServerConfig::builder_with_provider(Arc::new(provider))
175 .with_safe_default_protocol_versions()
176 .map_err(|e| TlsError::ConfigError(e.to_string()))?
177 .with_no_client_auth()
178 .with_single_cert(certs_vec, key)
179 .map_err(|e| TlsError::ConfigError(e.to_string()))?
180 };
181
182 Ok(Arc::new(server_config))
183}
184
185pub fn build_client_tls_config(config: &TlsConfig) -> Result<Arc<ClientConfig>, TlsError> {
197 if !config.cert_path.exists() {
199 return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
200 }
201 if !config.key_path.exists() {
202 return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
203 }
204
205 let certs_vec = load_certs(&config.cert_path)?;
206 let key = load_private_key(&config.key_path)?;
207
208 let provider = rustls::crypto::ring::default_provider();
209
210 let mut root_store = rustls::RootCertStore::empty();
212
213 if let Some(ca_path) = &config.ca_path {
214 if !ca_path.exists() {
215 return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
216 }
217 let ca_certs = load_certs(ca_path)?;
218 for cert in ca_certs {
219 root_store
220 .add(cert)
221 .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
222 }
223 } else {
224 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
226 }
227
228 let client_config = ClientConfig::builder_with_provider(Arc::new(provider))
229 .with_safe_default_protocol_versions()
230 .map_err(|e| TlsError::ConfigError(e.to_string()))?
231 .with_root_certificates(root_store)
232 .with_client_auth_cert(certs_vec, key)
233 .map_err(|e| TlsError::ConfigError(e.to_string()))?;
234
235 Ok(Arc::new(client_config))
236}
237
#[cfg(test)]
mod tests {
    //! Unit tests for TLS config construction. Certificates are generated on the
    //! fly with `rcgen` and written into per-test temporary directories.

    use super::*;

    /// Writes a freshly generated self-signed `localhost` certificate and its key
    /// into `dir` and returns `(cert_path, key_path)`.
    fn write_test_cert_and_key(dir: &tempfile::TempDir) -> (PathBuf, PathBuf) {
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let subject_alt_names = vec!["localhost".to_string()];
        let cert_params =
            rcgen::CertificateParams::new(subject_alt_names).expect("Failed to create cert params");
        let key_pair = rcgen::KeyPair::generate().expect("Failed to generate key pair");
        let cert = cert_params.self_signed(&key_pair).expect("Failed to self-sign cert");

        let cert_pem = cert.pem();
        let key_pem = key_pair.serialize_pem();

        std::fs::write(&cert_path, cert_pem).unwrap();
        std::fs::write(&key_path, key_pem).unwrap();

        (cert_path, key_path)
    }

    #[test]
    fn test_tls_config_new() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem");
        assert_eq!(config.cert_path, PathBuf::from("/tmp/cert.pem"));
        assert_eq!(config.key_path, PathBuf::from("/tmp/key.pem"));
        assert!(config.ca_path.is_none());
    }

    #[test]
    fn test_tls_config_with_ca() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem").with_ca("/tmp/ca.pem");
        assert_eq!(config.ca_path, Some(PathBuf::from("/tmp/ca.pem")));
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertNotFound("/path/to/cert.pem".to_string());
        assert!(err.to_string().contains("/path/to/cert.pem"));

        let err = TlsError::NoCertificates;
        assert!(err.to_string().contains("No certificates"));

        let err = TlsError::NoPrivateKey;
        assert!(err.to_string().contains("No private key"));

        let err = TlsError::ConfigError("bad config".to_string());
        assert!(err.to_string().contains("bad config"));
    }

    #[test]
    fn test_build_server_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        std::fs::write(&cert_path, "placeholder").unwrap();

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_cert() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        std::fs::write(&cert_path, "").unwrap();
        std::fs::write(&key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_server_tls_config_empty_key() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        // Valid certificate, but the key file holds no PEM sections at all.
        std::fs::write(&key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }

    #[test]
    fn test_build_server_tls_config_valid() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_with_client_auth() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        // The self-signed leaf doubles as its own CA for this test.
        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let ca_path = dir.path().join("ca.pem");
        std::fs::write(&ca_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_client_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        std::fs::write(&cert_path, "placeholder").unwrap();

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_empty_key() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        std::fs::write(&key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }

    #[test]
    fn test_build_client_tls_config_valid_with_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_valid_default_roots() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }
}