1use rustls::pki_types::{CertificateDer, PrivateKeyDer};
20use rustls_pemfile::{certs, private_key};
21use std::fs::File;
22use std::io::BufReader;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25
26#[derive(Debug, thiserror::Error)]
28pub enum TlsError {
29 #[error("TLS certificate file not found: {0}")]
31 CertNotFound(String),
32
33 #[error("TLS private key file not found: {0}")]
35 KeyNotFound(String),
36
37 #[error("Failed to read certificate: {0}")]
39 CertReadError(String),
40
41 #[error("Failed to read private key: {0}")]
43 KeyReadError(String),
44
45 #[error("No certificates found in certificate file")]
47 NoCertificates,
48
49 #[error("No private key found in key file")]
51 NoPrivateKey,
52
53 #[error("TLS configuration error: {0}")]
55 ConfigError(String),
56}
57
/// Filesystem locations of the PEM-encoded material used to build TLS configurations.
///
/// `cert_path` and `key_path` are always required. `ca_path` is optional; when set it
/// supplies the trust anchors used to verify the peer: client-certificate CAs for servers,
/// or server CAs for clients.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the PEM certificate chain presented to the peer.
    pub cert_path: PathBuf,
    /// Path to the PEM private key belonging to `cert_path`.
    pub key_path: PathBuf,
    /// Optional path to a PEM bundle of CA certificates used to verify the peer.
    pub ca_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Creates a configuration from a certificate path and a key path, with no CA bundle.
    pub fn new(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        let cert_path = cert_path.into();
        let key_path = key_path.into();
        Self {
            cert_path,
            key_path,
            ca_path: None,
        }
    }

    /// Returns this configuration with `ca_path` set, replacing any CA path set earlier.
    pub fn with_ca(self, ca_path: impl Into<PathBuf>) -> Self {
        Self {
            ca_path: Some(ca_path.into()),
            ..self
        }
    }
}
89
90fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
92 let file = File::open(path)
93 .map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
94 let mut reader = BufReader::new(file);
95
96 let certs_result: Vec<CertificateDer<'static>> =
97 certs(&mut reader).filter_map(|c| c.ok()).collect();
98
99 if certs_result.is_empty() {
100 return Err(TlsError::NoCertificates);
101 }
102
103 Ok(certs_result)
104}
105
106fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
108 let file = File::open(path)
109 .map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
110 let mut reader = BufReader::new(file);
111
112 private_key(&mut reader)
113 .map_err(|e| TlsError::KeyReadError(e.to_string()))?
114 .ok_or(TlsError::NoPrivateKey)
115}
116
117pub fn build_server_tls_config(
128 config: &TlsConfig,
129) -> Result<Arc<tokio_rustls::rustls::ServerConfig>, TlsError> {
130 if !config.cert_path.exists() {
132 return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
133 }
134 if !config.key_path.exists() {
135 return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
136 }
137
138 let certs_vec = load_certs(&config.cert_path)?;
139 let key = load_private_key(&config.key_path)?;
140
141 let provider = rustls::crypto::ring::default_provider();
142 let _ = provider.clone().install_default();
146
147 let server_config = if let Some(ca_path) = &config.ca_path {
148 if !ca_path.exists() {
150 return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
151 }
152
153 let ca_certs = load_certs(ca_path)?;
154 let mut root_store = rustls::RootCertStore::empty();
155 for cert in ca_certs {
156 root_store
157 .add(cert)
158 .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
159 }
160
161 let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
162 .build()
163 .map_err(|e| {
164 TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
165 })?;
166
167 rustls::ServerConfig::builder_with_provider(Arc::new(provider))
168 .with_safe_default_protocol_versions()
169 .map_err(|e| TlsError::ConfigError(e.to_string()))?
170 .with_client_cert_verifier(client_verifier)
171 .with_single_cert(certs_vec, key)
172 .map_err(|e| TlsError::ConfigError(e.to_string()))?
173 } else {
174 rustls::ServerConfig::builder_with_provider(Arc::new(provider))
176 .with_safe_default_protocol_versions()
177 .map_err(|e| TlsError::ConfigError(e.to_string()))?
178 .with_no_client_auth()
179 .with_single_cert(certs_vec, key)
180 .map_err(|e| TlsError::ConfigError(e.to_string()))?
181 };
182
183 Ok(Arc::new(server_config))
184}
185
186pub fn build_client_tls_config(
198 config: &TlsConfig,
199) -> Result<Arc<tokio_rustls::rustls::ClientConfig>, TlsError> {
200 if !config.cert_path.exists() {
202 return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
203 }
204 if !config.key_path.exists() {
205 return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
206 }
207
208 let certs_vec = load_certs(&config.cert_path)?;
209 let key = load_private_key(&config.key_path)?;
210
211 let provider = rustls::crypto::ring::default_provider();
212
213 let mut root_store = rustls::RootCertStore::empty();
215
216 if let Some(ca_path) = &config.ca_path {
217 if !ca_path.exists() {
218 return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
219 }
220 let ca_certs = load_certs(ca_path)?;
221 for cert in ca_certs {
222 root_store
223 .add(cert)
224 .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
225 }
226 } else {
227 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
229 }
230
231 let client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
232 .with_safe_default_protocol_versions()
233 .map_err(|e| TlsError::ConfigError(e.to_string()))?
234 .with_root_certificates(root_store)
235 .with_client_auth_cert(certs_vec, key)
236 .map_err(|e| TlsError::ConfigError(e.to_string()))?;
237
238 Ok(Arc::new(client_config))
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a self-signed certificate for `localhost` plus its key pair and writes them
    /// as PEM files `cert.pem` / `key.pem` inside `dir`, returning `(cert_path, key_path)`.
    fn write_test_cert_and_key(dir: &tempfile::TempDir) -> (PathBuf, PathBuf) {
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let subject_alt_names = vec!["localhost".to_string()];
        let cert_params =
            rcgen::CertificateParams::new(subject_alt_names).expect("Failed to create cert params");
        let key_pair = rcgen::KeyPair::generate().expect("Failed to generate key pair");
        let cert = cert_params.self_signed(&key_pair).expect("Failed to self-sign cert");

        let cert_pem = cert.pem();
        let key_pem = key_pair.serialize_pem();

        std::fs::write(&cert_path, cert_pem).unwrap();
        std::fs::write(&key_path, key_pem).unwrap();

        (cert_path, key_path)
    }

    #[test]
    fn test_tls_config_new() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem");
        assert_eq!(config.cert_path, PathBuf::from("/tmp/cert.pem"));
        assert_eq!(config.key_path, PathBuf::from("/tmp/key.pem"));
        assert!(config.ca_path.is_none());
    }

    #[test]
    fn test_tls_config_with_ca() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem").with_ca("/tmp/ca.pem");
        assert_eq!(config.ca_path, Some(PathBuf::from("/tmp/ca.pem")));
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertNotFound("/path/to/cert.pem".to_string());
        assert!(err.to_string().contains("/path/to/cert.pem"));

        let err = TlsError::KeyNotFound("/path/to/key.pem".to_string());
        assert!(err.to_string().contains("/path/to/key.pem"));

        let err = TlsError::CertReadError("bad cert".to_string());
        assert!(err.to_string().contains("bad cert"));

        let err = TlsError::KeyReadError("bad key".to_string());
        assert!(err.to_string().contains("bad key"));

        let err = TlsError::NoCertificates;
        assert!(err.to_string().contains("No certificates"));

        let err = TlsError::NoPrivateKey;
        assert!(err.to_string().contains("No private key"));

        let err = TlsError::ConfigError("bad config".to_string());
        assert!(err.to_string().contains("bad config"));
    }

    #[test]
    fn test_build_server_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        std::fs::write(&cert_path, "placeholder").unwrap();

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_cert() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        std::fs::write(&cert_path, "").unwrap();
        std::fs::write(&key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    /// A valid certificate paired with an empty key file must surface `NoPrivateKey`.
    #[test]
    fn test_build_server_tls_config_empty_key() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, _) = write_test_cert_and_key(&dir);
        let empty_key_path = dir.path().join("empty_key.pem");
        std::fs::write(&empty_key_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &empty_key_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoPrivateKey)));
    }

    #[test]
    fn test_build_server_tls_config_valid() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_with_client_auth() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        // The self-signed leaf doubles as its own trust anchor.
        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_server_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    /// An existing but empty CA bundle is rejected rather than yielding an empty trust store.
    #[test]
    fn test_build_server_tls_config_empty_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let ca_path = dir.path().join("ca.pem");
        std::fs::write(&ca_path, "").unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_server_tls_config(&config);
        assert!(matches!(result, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_client_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, _) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, "/nonexistent/key.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_client_tls_config_valid_with_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(&cert_path, &ca_path).unwrap();

        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_valid_default_roots() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path);
        let result = build_client_tls_config(&config);
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn test_build_client_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);

        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        let result = build_client_tls_config(&config);
        assert!(matches!(result, Err(TlsError::CertNotFound(_))));
    }
}