1use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use std::{sync::Arc, time::Duration};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
16pub enum CertificateError {
17 #[error("Certificate generation failed: {0}")]
18 GenerationFailed(String),
19
20 #[error("Certificate validation failed: {0}")]
21 ValidationFailed(String),
22
23 #[error("Certificate loading failed: {0}")]
24 LoadingFailed(String),
25
26 #[error("Certificate parsing failed: {0}")]
27 ParsingFailed(String),
28
29 #[error("Private key error: {0}")]
30 PrivateKeyError(String),
31
32 #[error("Certificate chain error: {0}")]
33 ChainError(String),
34
35 #[error("Certificate expired or not yet valid")]
36 ValidityError,
37
38 #[error("Unsupported certificate format")]
39 UnsupportedFormat,
40}
41
42#[derive(Debug, Clone)]
44pub struct CertificateConfig {
45 pub common_name: String,
47
48 pub subject_alt_names: Vec<String>,
50
51 pub validity_duration: Duration,
53
54 pub key_algorithm: KeyAlgorithm,
56
57 pub self_signed: bool,
59
60 pub ca_cert_path: Option<String>,
62
63 pub require_chain_validation: bool,
65}
66
/// Key algorithms a certificate may be requested with.
#[derive(Debug, Clone, Copy)]
pub enum KeyAlgorithm {
    /// RSA; the `u32` is presumably the key size in bits — TODO confirm.
    Rsa(u32),
    /// ECDSA over the NIST P-256 curve.
    EcdsaP256,
    /// ECDSA over the NIST P-384 curve.
    EcdsaP384,
    /// Ed25519.
    Ed25519,
}
79
80#[derive(Debug)]
82pub struct CertificateBundle {
83 pub cert_chain: Vec<CertificateDer<'static>>,
85
86 pub private_key: PrivateKeyDer<'static>,
88
89 pub created_at: std::time::SystemTime,
91
92 pub expires_at: std::time::SystemTime,
94}
95
96pub struct CertificateManager {
98 config: CertificateConfig,
99 ca_certs: Vec<CertificateDer<'static>>,
100}
101
102impl Default for CertificateConfig {
103 fn default() -> Self {
104 Self {
105 common_name: "ant-quic-node".to_string(),
106 subject_alt_names: vec!["localhost".to_string()],
107 validity_duration: Duration::from_secs(365 * 24 * 60 * 60), key_algorithm: KeyAlgorithm::Ed25519,
109 self_signed: true,
110 ca_cert_path: None,
111 require_chain_validation: false,
112 }
113 }
114}
115
116impl CertificateManager {
117 pub fn new(config: CertificateConfig) -> Result<Self, CertificateError> {
119 let ca_certs = if let Some(ca_path) = &config.ca_cert_path {
120 Self::load_ca_certificates(ca_path)?
121 } else {
122 Vec::new()
123 };
124
125 Ok(Self { config, ca_certs })
126 }
127
128 pub fn generate_certificate(&self) -> Result<CertificateBundle, CertificateError> {
130 use rcgen::generate_simple_self_signed;
131
132 let subject_alt_names = vec![self.config.common_name.clone()];
135 let cert = generate_simple_self_signed(subject_alt_names)
136 .map_err(|e| CertificateError::GenerationFailed(e.to_string()))?;
137
138 let cert_der = cert.cert.der();
140 let private_key_der = cert.signing_key.serialize_der();
141
142 let created_at = std::time::SystemTime::now();
143 let expires_at = created_at + self.config.validity_duration;
144
145 Ok(CertificateBundle {
146 cert_chain: vec![cert_der.clone()],
147 private_key: PrivateKeyDer::try_from(private_key_der).map_err(|e| {
148 CertificateError::PrivateKeyError(format!("Key conversion failed: {e:?}"))
149 })?,
150 created_at,
151 expires_at,
152 })
153 }
154
155 pub fn load_certificate_from_pem(
157 cert_path: &str,
158 key_path: &str,
159 ) -> Result<CertificateBundle, CertificateError> {
160 use rustls_pemfile::{certs, private_key};
161
162 let cert_file = std::fs::File::open(cert_path).map_err(|e| {
164 CertificateError::LoadingFailed(format!("Failed to open cert file: {e}"))
165 })?;
166
167 let mut cert_reader = std::io::BufReader::new(cert_file);
168 let cert_chain: Vec<CertificateDer<'static>> = certs(&mut cert_reader)
169 .collect::<Result<Vec<_>, _>>()
170 .map_err(|e| {
171 CertificateError::ParsingFailed(format!("Failed to parse certificates: {e}"))
172 })?;
173
174 if cert_chain.is_empty() {
175 return Err(CertificateError::LoadingFailed(
176 "No certificates found in file".to_string(),
177 ));
178 }
179
180 let key_file = std::fs::File::open(key_path).map_err(|e| {
182 CertificateError::LoadingFailed(format!("Failed to open key file: {e}"))
183 })?;
184
185 let mut key_reader = std::io::BufReader::new(key_file);
186 let private_key = private_key(&mut key_reader)
187 .map_err(|e| {
188 CertificateError::ParsingFailed(format!("Failed to parse private key: {e}"))
189 })?
190 .ok_or_else(|| {
191 CertificateError::LoadingFailed("No private key found in file".to_string())
192 })?;
193
194 let (created_at, expires_at) = Self::extract_validity_from_cert(&cert_chain[0])?;
196
197 Ok(CertificateBundle {
198 cert_chain,
199 private_key,
200 created_at,
201 expires_at,
202 })
203 }
204
205 pub fn validate_certificate(&self, bundle: &CertificateBundle) -> Result<(), CertificateError> {
207 let now = std::time::SystemTime::now();
209 if now > bundle.expires_at {
210 return Err(CertificateError::ValidityError);
211 }
212
213 if self.config.require_chain_validation && !self.ca_certs.is_empty() {
215 self.validate_certificate_chain(&bundle.cert_chain)?;
216 }
217
218 Ok(())
219 }
220
221 #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
223 pub fn create_server_config(
224 &self,
225 bundle: &CertificateBundle,
226 ) -> Result<Arc<rustls::ServerConfig>, CertificateError> {
227 use rustls::ServerConfig;
228
229 self.validate_certificate(bundle)?;
230
231 let server_config = ServerConfig::builder()
232 .with_no_client_auth()
233 .with_single_cert(bundle.cert_chain.clone(), bundle.private_key.clone_key())
234 .map_err(|e| CertificateError::ValidationFailed(e.to_string()))?;
235
236 Ok(Arc::new(server_config))
237 }
238
239 #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
241 pub fn create_client_config(&self) -> Result<Arc<rustls::ClientConfig>, CertificateError> {
242 use rustls::ClientConfig;
243
244 let config = if self.ca_certs.is_empty() {
245 ClientConfig::builder()
247 .dangerous()
248 .with_custom_certificate_verifier(Arc::new(NoCertificateVerifier))
249 .with_no_client_auth()
250 } else {
251 let mut root_store = rustls::RootCertStore::empty();
253 for ca_cert in &self.ca_certs {
254 root_store.add(ca_cert.clone()).map_err(|e| {
255 CertificateError::ValidationFailed(format!("Failed to add CA cert: {e}"))
256 })?;
257 }
258
259 ClientConfig::builder()
260 .with_root_certificates(root_store)
261 .with_no_client_auth()
262 };
263
264 Ok(Arc::new(config))
265 }
266
267 fn load_ca_certificates(
269 ca_path: &str,
270 ) -> Result<Vec<CertificateDer<'static>>, CertificateError> {
271 use rustls_pemfile::certs;
272
273 let ca_file = std::fs::File::open(ca_path)
274 .map_err(|e| CertificateError::LoadingFailed(format!("Failed to open CA file: {e}")))?;
275
276 let mut ca_reader = std::io::BufReader::new(ca_file);
277 let ca_certs: Vec<CertificateDer<'static>> = certs(&mut ca_reader)
278 .collect::<Result<Vec<_>, _>>()
279 .map_err(|e| {
280 CertificateError::ParsingFailed(format!("Failed to parse CA certificates: {e}"))
281 })?;
282
283 if ca_certs.is_empty() {
284 return Err(CertificateError::LoadingFailed(
285 "No CA certificates found".to_string(),
286 ));
287 }
288
289 Ok(ca_certs)
290 }
291
292 fn extract_validity_from_cert(
294 _cert: &CertificateDer<'static>,
295 ) -> Result<(std::time::SystemTime, std::time::SystemTime), CertificateError> {
296 let created_at = std::time::SystemTime::now();
299 let expires_at = created_at + Duration::from_secs(365 * 24 * 60 * 60); Ok((created_at, expires_at))
302 }
303
304 fn validate_certificate_chain(
306 &self,
307 cert_chain: &[CertificateDer<'static>],
308 ) -> Result<(), CertificateError> {
309 if cert_chain.is_empty() {
310 return Err(CertificateError::ChainError(
311 "Empty certificate chain".to_string(),
312 ));
313 }
314
315 Ok(())
319 }
320}
321
/// Server-certificate verifier that accepts every certificate and signature.
///
/// Used by `create_client_config` when no CA certificates are configured.
/// It performs no verification at all.
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
#[derive(Debug)]
struct NoCertificateVerifier;
326
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerifier {
    /// Accepts the presented certificate unconditionally; every argument is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;

        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        Ok(HandshakeSignatureValid::assertion())
    }

    /// Lists the signature schemes this verifier claims to support.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        use rustls::SignatureScheme as Scheme;

        Vec::from([
            Scheme::RSA_PKCS1_SHA1,
            Scheme::ECDSA_SHA1_Legacy,
            Scheme::RSA_PKCS1_SHA256,
            Scheme::ECDSA_NISTP256_SHA256,
            Scheme::RSA_PKCS1_SHA384,
            Scheme::ECDSA_NISTP384_SHA384,
            Scheme::RSA_PKCS1_SHA512,
            Scheme::ECDSA_NISTP521_SHA512,
            Scheme::RSA_PSS_SHA256,
            Scheme::RSA_PSS_SHA384,
            Scheme::RSA_PSS_SHA512,
            Scheme::ED25519,
            Scheme::ED448,
        ])
    }
}
376
377impl CertificateBundle {
378 pub fn expires_within(&self, duration: Duration) -> bool {
380 let now = std::time::SystemTime::now();
381 match now.checked_add(duration) {
382 Some(check_time) => check_time >= self.expires_at,
383 None => true, }
385 }
386
387 pub fn remaining_validity(&self) -> Option<Duration> {
389 std::time::SystemTime::now()
390 .duration_since(self.expires_at)
391 .ok()
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected names and flags.
    #[test]
    fn test_default_certificate_config() {
        let CertificateConfig {
            common_name,
            subject_alt_names,
            self_signed,
            require_chain_validation,
            ..
        } = CertificateConfig::default();

        assert_eq!(common_name, "ant-quic-node");
        assert_eq!(subject_alt_names, vec!["localhost"]);
        assert!(self_signed);
        assert!(!require_chain_validation);
    }

    /// A manager can be built from the default configuration (no CA file involved).
    #[test]
    fn test_certificate_manager_creation() {
        let manager = CertificateManager::new(CertificateConfig::default());
        assert!(manager.is_ok());
    }

    /// Generation yields a non-empty chain whose expiry lies after its creation time.
    #[test]
    fn test_certificate_generation() {
        let manager = CertificateManager::new(CertificateConfig::default()).unwrap();

        let generated = manager.generate_certificate();
        assert!(generated.is_ok());

        let bundle = generated.unwrap();
        assert!(!bundle.cert_chain.is_empty());
        assert!(bundle.expires_at > bundle.created_at);
    }

    /// `expires_within` is false for a window shorter than the remaining
    /// lifetime and true for a longer one.
    #[test]
    fn test_certificate_bundle_expiry_check() {
        // 16-byte DER prefix followed by the 32 key bytes 0x00..=0x1f.
        let mut dummy_key: Vec<u8> = vec![
            0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22,
            0x04, 0x20,
        ];
        dummy_key.extend(0x00..=0x1f_u8);

        let now = std::time::SystemTime::now();
        let bundle = CertificateBundle {
            cert_chain: Vec::new(),
            private_key: PrivateKeyDer::try_from(dummy_key).unwrap(),
            created_at: now,
            expires_at: now + Duration::from_secs(3600),
        };

        let half_hour = Duration::from_secs(1800);
        let two_hours = Duration::from_secs(7200);
        assert!(!bundle.expires_within(half_hour));
        assert!(bundle.expires_within(two_hours));
    }
}