1use std::{sync::Arc, time::Duration};
11use rustls::pki_types::{CertificateDer, PrivateKeyDer};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
16pub enum CertificateError {
17 #[error("Certificate generation failed: {0}")]
18 GenerationFailed(String),
19
20 #[error("Certificate validation failed: {0}")]
21 ValidationFailed(String),
22
23 #[error("Certificate loading failed: {0}")]
24 LoadingFailed(String),
25
26 #[error("Certificate parsing failed: {0}")]
27 ParsingFailed(String),
28
29 #[error("Private key error: {0}")]
30 PrivateKeyError(String),
31
32 #[error("Certificate chain error: {0}")]
33 ChainError(String),
34
35 #[error("Certificate expired or not yet valid")]
36 ValidityError,
37
38 #[error("Unsupported certificate format")]
39 UnsupportedFormat,
40}
41
42#[derive(Debug, Clone)]
44pub struct CertificateConfig {
45 pub common_name: String,
47
48 pub subject_alt_names: Vec<String>,
50
51 pub validity_duration: Duration,
53
54 pub key_algorithm: KeyAlgorithm,
56
57 pub self_signed: bool,
59
60 pub ca_cert_path: Option<String>,
62
63 pub require_chain_validation: bool,
65}
66
/// Key algorithm a certificate is requested to use.
///
/// NOTE(review): `CertificateManager::generate_certificate` does not currently
/// read this value — confirm before relying on it.
#[derive(Clone, Copy, Debug)]
pub enum KeyAlgorithm {
    /// RSA; the payload is presumably the key size in bits — TODO confirm.
    Rsa(u32),
    /// ECDSA on the P-256 curve.
    EcdsaP256,
    /// ECDSA on the P-384 curve.
    EcdsaP384,
    /// Ed25519.
    Ed25519,
}
79
80#[derive(Debug)]
82pub struct CertificateBundle {
83 pub cert_chain: Vec<CertificateDer<'static>>,
85
86 pub private_key: PrivateKeyDer<'static>,
88
89 pub created_at: std::time::SystemTime,
91
92 pub expires_at: std::time::SystemTime,
94}
95
96pub struct CertificateManager {
98 config: CertificateConfig,
99 ca_certs: Vec<CertificateDer<'static>>,
100}
101
102impl Default for CertificateConfig {
103 fn default() -> Self {
104 Self {
105 common_name: "ant-quic-node".to_string(),
106 subject_alt_names: vec!["localhost".to_string()],
107 validity_duration: Duration::from_secs(365 * 24 * 60 * 60), key_algorithm: KeyAlgorithm::Ed25519,
109 self_signed: true,
110 ca_cert_path: None,
111 require_chain_validation: false,
112 }
113 }
114}
115
116impl CertificateManager {
117 pub fn new(config: CertificateConfig) -> Result<Self, CertificateError> {
119 let ca_certs = if let Some(ca_path) = &config.ca_cert_path {
120 Self::load_ca_certificates(ca_path)?
121 } else {
122 Vec::new()
123 };
124
125 Ok(Self {
126 config,
127 ca_certs,
128 })
129 }
130
131 pub fn generate_certificate(&self) -> Result<CertificateBundle, CertificateError> {
133 use rcgen::generate_simple_self_signed;
134
135 let subject_alt_names = vec![self.config.common_name.clone()];
138 let cert = generate_simple_self_signed(subject_alt_names)
139 .map_err(|e| CertificateError::GenerationFailed(e.to_string()))?;
140
141 let cert_der = cert.cert.der();
143 let private_key_der = cert.signing_key.serialize_der();
144
145 let created_at = std::time::SystemTime::now();
146 let expires_at = created_at + self.config.validity_duration;
147
148 Ok(CertificateBundle {
149 cert_chain: vec![CertificateDer::from(cert_der.clone())],
150 private_key: PrivateKeyDer::try_from(private_key_der)
151 .map_err(|e| CertificateError::PrivateKeyError(format!("Key conversion failed: {:?}", e)))?,
152 created_at,
153 expires_at,
154 })
155 }
156
157 pub fn load_certificate_from_pem(
159 cert_path: &str,
160 key_path: &str,
161 ) -> Result<CertificateBundle, CertificateError> {
162 use rustls_pemfile::{certs, private_key};
163
164 let cert_file = std::fs::File::open(cert_path)
166 .map_err(|e| CertificateError::LoadingFailed(format!("Failed to open cert file: {}", e)))?;
167
168 let mut cert_reader = std::io::BufReader::new(cert_file);
169 let cert_chain: Vec<CertificateDer<'static>> = certs(&mut cert_reader)
170 .collect::<Result<Vec<_>, _>>()
171 .map_err(|e| CertificateError::ParsingFailed(format!("Failed to parse certificates: {}", e)))?;
172
173 if cert_chain.is_empty() {
174 return Err(CertificateError::LoadingFailed("No certificates found in file".to_string()));
175 }
176
177 let key_file = std::fs::File::open(key_path)
179 .map_err(|e| CertificateError::LoadingFailed(format!("Failed to open key file: {}", e)))?;
180
181 let mut key_reader = std::io::BufReader::new(key_file);
182 let private_key = private_key(&mut key_reader)
183 .map_err(|e| CertificateError::ParsingFailed(format!("Failed to parse private key: {}", e)))?
184 .ok_or_else(|| CertificateError::LoadingFailed("No private key found in file".to_string()))?;
185
186 let (created_at, expires_at) = Self::extract_validity_from_cert(&cert_chain[0])?;
188
189 Ok(CertificateBundle {
190 cert_chain,
191 private_key,
192 created_at,
193 expires_at,
194 })
195 }
196
197 pub fn validate_certificate(&self, bundle: &CertificateBundle) -> Result<(), CertificateError> {
199 let now = std::time::SystemTime::now();
201 if now > bundle.expires_at {
202 return Err(CertificateError::ValidityError);
203 }
204
205 if self.config.require_chain_validation && !self.ca_certs.is_empty() {
207 self.validate_certificate_chain(&bundle.cert_chain)?;
208 }
209
210 Ok(())
211 }
212
213 #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
215 pub fn create_server_config(
216 &self,
217 bundle: &CertificateBundle,
218 ) -> Result<Arc<rustls::ServerConfig>, CertificateError> {
219 use rustls::ServerConfig;
220
221 self.validate_certificate(bundle)?;
222
223 let server_config = ServerConfig::builder()
224 .with_no_client_auth()
225 .with_single_cert(bundle.cert_chain.clone(), bundle.private_key.clone_key())
226 .map_err(|e| CertificateError::ValidationFailed(e.to_string()))?;
227
228 Ok(Arc::new(server_config))
229 }
230
231 #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
233 pub fn create_client_config(&self) -> Result<Arc<rustls::ClientConfig>, CertificateError> {
234 use rustls::ClientConfig;
235
236 let config = if self.ca_certs.is_empty() {
237 ClientConfig::builder()
239 .dangerous()
240 .with_custom_certificate_verifier(Arc::new(NoCertificateVerifier))
241 .with_no_client_auth()
242 } else {
243 let mut root_store = rustls::RootCertStore::empty();
245 for ca_cert in &self.ca_certs {
246 root_store.add(ca_cert.clone())
247 .map_err(|e| CertificateError::ValidationFailed(format!("Failed to add CA cert: {}", e)))?;
248 }
249
250 ClientConfig::builder()
251 .with_root_certificates(root_store)
252 .with_no_client_auth()
253 };
254
255 Ok(Arc::new(config))
256 }
257
258 fn load_ca_certificates(ca_path: &str) -> Result<Vec<CertificateDer<'static>>, CertificateError> {
260 use rustls_pemfile::certs;
261
262 let ca_file = std::fs::File::open(ca_path)
263 .map_err(|e| CertificateError::LoadingFailed(format!("Failed to open CA file: {}", e)))?;
264
265 let mut ca_reader = std::io::BufReader::new(ca_file);
266 let ca_certs: Vec<CertificateDer<'static>> = certs(&mut ca_reader)
267 .collect::<Result<Vec<_>, _>>()
268 .map_err(|e| CertificateError::ParsingFailed(format!("Failed to parse CA certificates: {}", e)))?;
269
270 if ca_certs.is_empty() {
271 return Err(CertificateError::LoadingFailed("No CA certificates found".to_string()));
272 }
273
274 Ok(ca_certs)
275 }
276
277 fn extract_validity_from_cert(
279 _cert: &CertificateDer<'static>,
280 ) -> Result<(std::time::SystemTime, std::time::SystemTime), CertificateError> {
281 let created_at = std::time::SystemTime::now();
284 let expires_at = created_at + Duration::from_secs(365 * 24 * 60 * 60); Ok((created_at, expires_at))
287 }
288
289 fn validate_certificate_chain(
291 &self,
292 cert_chain: &[CertificateDer<'static>],
293 ) -> Result<(), CertificateError> {
294 if cert_chain.is_empty() {
295 return Err(CertificateError::ChainError("Empty certificate chain".to_string()));
296 }
297
298 Ok(())
302 }
303}
304
/// Server-certificate verifier that accepts everything.
///
/// SECURITY: installing this disables server authentication entirely. Any peer
/// can impersonate any server. It is used by `create_client_config` when no CA
/// certificates are configured.
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
#[derive(Debug)]
struct NoCertificateVerifier;
309
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerifier {
    /// Unconditionally accepts the server's certificate.
    ///
    /// Every argument is ignored: no chain building, name matching, expiry or
    /// OCSP check takes place.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _expected_name: &rustls::pki_types::ServerName<'_>,
        _ocsp: &[u8],
        _at: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;

        let accepted = ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _msg: &[u8],
        _signer: &CertificateDer<'_>,
        _signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        let accepted = HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _msg: &[u8],
        _signer: &CertificateDer<'_>,
        _signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        let accepted = HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    /// Fixed list of signature schemes advertised to the peer.
    ///
    /// It is hard-coded rather than taken from the active crypto provider.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        use rustls::SignatureScheme as Scheme;

        Vec::from([
            Scheme::RSA_PKCS1_SHA1,
            Scheme::ECDSA_SHA1_Legacy,
            Scheme::RSA_PKCS1_SHA256,
            Scheme::ECDSA_NISTP256_SHA256,
            Scheme::RSA_PKCS1_SHA384,
            Scheme::ECDSA_NISTP384_SHA384,
            Scheme::RSA_PKCS1_SHA512,
            Scheme::ECDSA_NISTP521_SHA512,
            Scheme::RSA_PSS_SHA256,
            Scheme::RSA_PSS_SHA384,
            Scheme::RSA_PSS_SHA512,
            Scheme::ED25519,
            Scheme::ED448,
        ])
    }
}
359
360impl CertificateBundle {
361 pub fn expires_within(&self, duration: Duration) -> bool {
363 let now = std::time::SystemTime::now();
364 match now.checked_add(duration) {
365 Some(check_time) => check_time >= self.expires_at,
366 None => true, }
368 }
369
370 pub fn remaining_validity(&self) -> Option<Duration> {
372 std::time::SystemTime::now()
373 .duration_since(self.expires_at)
374 .ok()
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// PKCS#8 DER header (version 0, OID 1.3.101.112 = Ed25519) preceding a raw
    /// 32-byte private key; 16 header bytes + 32 key bytes = 48 bytes total.
    const ED25519_PKCS8_HEADER: [u8; 16] = [
        0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22, 0x04,
        0x20,
    ];

    #[test]
    fn test_default_certificate_config() {
        let defaults = CertificateConfig::default();
        assert_eq!(defaults.common_name, "ant-quic-node");
        assert_eq!(defaults.subject_alt_names, vec!["localhost"]);
        assert!(defaults.self_signed);
        assert!(!defaults.require_chain_validation);
    }

    #[test]
    fn test_certificate_manager_creation() {
        assert!(CertificateManager::new(CertificateConfig::default()).is_ok());
    }

    #[test]
    fn test_certificate_generation() {
        let manager = CertificateManager::new(CertificateConfig::default())
            .expect("default config needs no CA file");

        let bundle = manager
            .generate_certificate()
            .expect("self-signed generation should succeed");

        assert!(!bundle.cert_chain.is_empty());
        assert!(bundle.created_at < bundle.expires_at);
    }

    #[test]
    fn test_certificate_bundle_expiry_check() {
        // Header followed by the key bytes 0x00..=0x1f.
        let key_bytes: Vec<u8> = ED25519_PKCS8_HEADER
            .iter()
            .copied()
            .chain(0u8..32)
            .collect();

        let issued = std::time::SystemTime::now();
        let bundle = CertificateBundle {
            cert_chain: Vec::new(),
            private_key: PrivateKeyDer::try_from(key_bytes).expect("valid PKCS#8 Ed25519 key"),
            created_at: issued,
            expires_at: issued + Duration::from_secs(3600),
        };

        assert!(!bundle.expires_within(Duration::from_secs(1800)));
        assert!(bundle.expires_within(Duration::from_secs(7200)));
    }
}