1use std::collections::HashMap;
44use std::fmt;
45use std::fs;
46use std::io::{self, BufReader, Cursor};
47use std::net::SocketAddr;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::{Duration, SystemTime};
51
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54use thiserror::Error;
55use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
56use tokio::net::TcpStream;
57
58pub use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
60pub use rustls::{ClientConfig, ServerConfig};
61
/// Errors produced while loading TLS material, building rustls configurations, or establishing
/// TLS connections in this module.
#[derive(Debug, Error)]
pub enum TlsError {
    /// A certificate file could not be read from disk.
    #[error("Failed to read certificate file '{path}': {source}")]
    CertificateReadError {
        /// Path of the file that failed to read.
        path: PathBuf,
        /// Underlying I/O error.
        #[source]
        source: io::Error,
    },

    /// A private-key file could not be read from disk.
    #[error("Failed to read private key file '{path}': {source}")]
    KeyReadError {
        /// Path of the file that failed to read.
        path: PathBuf,
        /// Underlying I/O error.
        #[source]
        source: io::Error,
    },

    /// Certificate data could not be decoded (malformed PEM/base64, or no certificate present).
    #[error("Invalid certificate format: {0}")]
    InvalidCertificate(String),

    /// Private-key data could not be decoded (malformed PEM/base64, or no key present).
    #[error("Invalid private key format: {0}")]
    InvalidPrivateKey(String),

    /// A certificate could not be added to a trust store, or a chain check failed.
    #[error("Certificate chain validation failed: {0}")]
    CertificateChainError(String),

    /// The TLS handshake with the peer failed.
    #[error("TLS handshake failed: {0}")]
    HandshakeError(String),

    /// The underlying transport connection could not be established.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// The supplied configuration is incomplete or inconsistent (e.g. missing key, missing
    /// client CA for required mTLS, invalid server name).
    #[error("TLS configuration error: {0}")]
    ConfigError(String),

    /// A certificate's validity period has ended.
    #[error("Certificate expired: {0}")]
    CertificateExpired(String),

    /// A certificate's validity period has not started yet.
    #[error("Certificate not yet valid: {0}")]
    CertificateNotYetValid(String),

    /// A certificate has been revoked.
    #[error("Certificate revoked: {0}")]
    CertificateRevoked(String),

    /// The peer's certificate does not match the expected host name.
    #[error("Hostname verification failed: expected '{expected}', got '{actual}'")]
    HostnameVerificationFailed { expected: String, actual: String },

    /// mTLS is required but the peer did not present a certificate.
    #[error("Client certificate required for mTLS but not provided")]
    ClientCertificateRequired,

    /// Generating a self-signed certificate/key pair failed.
    #[error("Failed to generate self-signed certificate: {0}")]
    SelfSignedGenerationError(String),

    /// Client and server share no ALPN protocol.
    #[error("ALPN negotiation failed: no common protocol")]
    AlpnNegotiationFailed,

    /// Any other error surfaced by rustls, stored as its rendered message.
    #[error("TLS internal error: {0}")]
    RustlsError(String),
}
137
138impl From<rustls::Error> for TlsError {
139 fn from(err: rustls::Error) -> Self {
140 TlsError::RustlsError(err.to_string())
141 }
142}
143
144pub type TlsResult<T> = std::result::Result<T, TlsError>;
146
/// Minimum TLS protocol version a client or server will negotiate.
///
/// Serialized in `snake_case` (`tls12`, `tls13`). Defaults to [`TlsVersion::Tls13`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TlsVersion {
    /// Allow both TLS 1.2 and TLS 1.3.
    Tls12,
    /// Allow TLS 1.3 only.
    #[default]
    Tls13,
}
161
/// How a server treats client certificates (mutual TLS).
///
/// Serialized in `snake_case`. Defaults to [`MtlsMode::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MtlsMode {
    /// Client certificates are not requested.
    #[default]
    Disabled,
    /// Client certificates are verified against `client_ca` when presented, but clients
    /// without one are still accepted. Without a `client_ca`, behaves like `Disabled`.
    Optional,
    /// Every client must present a certificate that verifies against `client_ca`;
    /// configuring this without a `client_ca` is an error.
    Required,
}
174
/// Where to obtain a certificate (chain).
///
/// Internally tagged by a `type` field in `snake_case`,
/// e.g. `{ "type": "file", "path": "..." }` or `{ "type": "self_signed", "common_name": "..." }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CertificateSource {
    /// A PEM file on disk, which may contain several certificates.
    File { path: PathBuf },
    /// Inline PEM text, which may contain several certificates.
    Pem { content: String },
    /// A single DER certificate, base64-encoded (standard alphabet).
    Der { content: String },
    /// A certificate generated at load time. Note that `load_certificates` discards the
    /// generated key; only the server builder pairs it with its own generated key.
    SelfSigned { common_name: String },
}
188
/// Where to obtain a private key.
///
/// Internally tagged by a `type` field in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PrivateKeySource {
    /// A PEM file on disk; the first key found is used.
    File { path: PathBuf },
    /// Inline PEM text; the first key found is used.
    Pem { content: String },
    /// A base64-encoded (standard alphabet) DER key, always interpreted as PKCS#8.
    Der { content: String },
}
200
/// Complete TLS settings shared by the server ([`TlsAcceptor`]) and client ([`TlsConnector`]) sides.
///
/// Note the two different "defaults": when deserialized, a missing `enabled` field becomes
/// `true` (`default_true`), whereas `TlsConfig::default()` sets `enabled` to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Whether TLS is turned on (serde default: `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Server certificate, or client certificate for mTLS (client side only presents it when
    /// `private_key` is also set).
    pub certificate: Option<CertificateSource>,

    /// Private key matching `certificate`; ignored by the server builder for `SelfSigned` certificates.
    pub private_key: Option<PrivateKeySource>,

    /// Trust anchors used by clients to verify servers; when absent, the platform's native
    /// root store is used.
    pub root_ca: Option<CertificateSource>,

    /// Trust anchors used by servers to verify client certificates (mTLS).
    pub client_ca: Option<CertificateSource>,

    /// Client-certificate policy for servers.
    #[serde(default)]
    pub mtls_mode: MtlsMode,

    /// Lowest TLS version that will be negotiated.
    #[serde(default)]
    pub min_version: TlsVersion,

    /// ALPN protocol identifiers, in preference order; empty means no ALPN.
    #[serde(default)]
    pub alpn_protocols: Vec<String>,

    /// NOTE(review): not read by `build_server_config`/`build_client_config` in this file —
    /// confirm whether OCSP stapling is implemented anywhere.
    #[serde(default)]
    pub ocsp_stapling: bool,

    /// Presumably hex SHA-256 fingerprints as produced by `certificate_fingerprint` — TODO confirm.
    /// NOTE(review): not consulted by the client/server builders in this file, so pinning is
    /// not enforced by them.
    #[serde(default)]
    pub pinned_certificates: Vec<String>,

    /// Disables server-certificate verification on the client side. Dangerous.
    #[serde(default)]
    pub insecure_skip_verify: bool,

    /// Default SNI/verification name used by `TlsConnector::connect_with_default_name`.
    pub server_name: Option<String>,

    /// Capacity of the server's in-memory session cache. `0` leaves the session storage that
    /// `ServerConfig` was built with untouched rather than disabling caching.
    #[serde(default = "default_session_cache_size")]
    pub session_cache_size: usize,

    /// Serialized via `humantime_serde`.
    /// NOTE(review): not applied by the builders in this file.
    #[serde(default = "default_session_ticket_lifetime")]
    #[serde(with = "humantime_serde")]
    pub session_ticket_lifetime: Duration,

    /// Non-zero enables `TlsAcceptor::reload`; zero (the default) disables it.
    /// Serialized via `humantime_serde`.
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub cert_reload_interval: Duration,
}
261
/// Serde default for `TlsConfig::enabled`: a config section that is present turns TLS on.
fn default_true() -> bool {
    true
}

/// Serde default for `TlsConfig::session_cache_size` (number of cached sessions).
fn default_session_cache_size() -> usize {
    256
}

/// Serde default for `TlsConfig::session_ticket_lifetime`: one day.
fn default_session_ticket_lifetime() -> Duration {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    Duration::from_secs(SECONDS_PER_DAY)
}
273
274impl Default for TlsConfig {
275 fn default() -> Self {
276 Self {
277 enabled: false, certificate: None,
279 private_key: None,
280 root_ca: None,
281 client_ca: None,
282 mtls_mode: MtlsMode::Disabled,
283 min_version: TlsVersion::Tls13,
284 alpn_protocols: vec![],
285 ocsp_stapling: false,
286 pinned_certificates: vec![],
287 insecure_skip_verify: false,
288 server_name: None,
289 session_cache_size: default_session_cache_size(),
290 session_ticket_lifetime: default_session_ticket_lifetime(),
291 cert_reload_interval: Duration::ZERO,
292 }
293 }
294}
295
296impl TlsConfig {
297 pub fn disabled() -> Self {
299 Self::default()
300 }
301
302 pub fn self_signed(common_name: &str) -> Self {
304 Self {
305 enabled: true,
306 certificate: Some(CertificateSource::SelfSigned {
307 common_name: common_name.to_string(),
308 }),
309 private_key: None, insecure_skip_verify: true, ..Default::default()
312 }
313 }
314
315 pub fn from_pem_files<P: Into<PathBuf>>(cert_path: P, key_path: P) -> Self {
317 Self {
318 enabled: true,
319 certificate: Some(CertificateSource::File {
320 path: cert_path.into(),
321 }),
322 private_key: Some(PrivateKeySource::File {
323 path: key_path.into(),
324 }),
325 ..Default::default()
326 }
327 }
328
329 pub fn mtls_from_pem_files<P1, P2, P3>(cert_path: P1, key_path: P2, ca_path: P3) -> Self
331 where
332 P1: Into<PathBuf>,
333 P2: Into<PathBuf>,
334 P3: Into<PathBuf> + Clone,
335 {
336 let ca: PathBuf = ca_path.clone().into();
337 Self {
338 enabled: true,
339 certificate: Some(CertificateSource::File {
340 path: cert_path.into(),
341 }),
342 private_key: Some(PrivateKeySource::File {
343 path: key_path.into(),
344 }),
345 client_ca: Some(CertificateSource::File { path: ca.clone() }),
346 root_ca: Some(CertificateSource::File { path: ca }),
347 mtls_mode: MtlsMode::Required,
348 ..Default::default()
349 }
350 }
351}
352
353pub struct TlsConfigBuilder {
359 config: TlsConfig,
360}
361
362impl TlsConfigBuilder {
363 pub fn new() -> Self {
365 Self {
366 config: TlsConfig {
367 enabled: true,
368 ..Default::default()
369 },
370 }
371 }
372
373 pub fn with_cert_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
375 self.config.certificate = Some(CertificateSource::File { path: path.into() });
376 self
377 }
378
379 pub fn with_cert_pem(mut self, pem: String) -> Self {
381 self.config.certificate = Some(CertificateSource::Pem { content: pem });
382 self
383 }
384
385 pub fn with_key_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
387 self.config.private_key = Some(PrivateKeySource::File { path: path.into() });
388 self
389 }
390
391 pub fn with_key_pem(mut self, pem: String) -> Self {
393 self.config.private_key = Some(PrivateKeySource::Pem { content: pem });
394 self
395 }
396
397 pub fn with_root_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
399 self.config.root_ca = Some(CertificateSource::File { path: path.into() });
400 self
401 }
402
403 pub fn with_client_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
405 self.config.client_ca = Some(CertificateSource::File { path: path.into() });
406 self
407 }
408
409 pub fn require_client_cert(mut self, required: bool) -> Self {
411 self.config.mtls_mode = if required {
412 MtlsMode::Required
413 } else {
414 MtlsMode::Disabled
415 };
416 self
417 }
418
419 pub fn with_mtls_mode(mut self, mode: MtlsMode) -> Self {
421 self.config.mtls_mode = mode;
422 self
423 }
424
425 pub fn with_min_version(mut self, version: TlsVersion) -> Self {
427 self.config.min_version = version;
428 self
429 }
430
431 pub fn with_alpn_protocols(mut self, protocols: Vec<String>) -> Self {
433 self.config.alpn_protocols = protocols;
434 self
435 }
436
437 pub fn with_server_name(mut self, name: String) -> Self {
439 self.config.server_name = Some(name);
440 self
441 }
442
443 pub fn insecure_skip_verify(mut self) -> Self {
445 self.config.insecure_skip_verify = true;
446 self
447 }
448
449 pub fn with_pinned_certificate(mut self, fingerprint: String) -> Self {
451 self.config.pinned_certificates.push(fingerprint);
452 self
453 }
454
455 pub fn with_self_signed(mut self, common_name: &str) -> Self {
457 self.config.certificate = Some(CertificateSource::SelfSigned {
458 common_name: common_name.to_string(),
459 });
460 self.config.insecure_skip_verify = true;
461 self
462 }
463
464 pub fn with_cert_reload_interval(mut self, interval: Duration) -> Self {
466 self.config.cert_reload_interval = interval;
467 self
468 }
469
470 pub fn build(self) -> TlsConfig {
472 self.config
473 }
474}
475
476impl Default for TlsConfigBuilder {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
482pub fn load_certificates(source: &CertificateSource) -> TlsResult<Vec<CertificateDer<'static>>> {
488 match source {
489 CertificateSource::File { path } => {
490 let data = fs::read(path).map_err(|e| TlsError::CertificateReadError {
491 path: path.clone(),
492 source: e,
493 })?;
494 parse_pem_certificates(&data)
495 }
496 CertificateSource::Pem { content } => parse_pem_certificates(content.as_bytes()),
497 CertificateSource::Der { content } => {
498 let der =
499 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
500 .map_err(|e| TlsError::InvalidCertificate(format!("Invalid base64: {}", e)))?;
501 Ok(vec![CertificateDer::from(der)])
502 }
503 CertificateSource::SelfSigned { common_name } => {
504 let (cert, _key) = generate_self_signed(common_name)?;
505 Ok(vec![cert])
506 }
507 }
508}
509
510fn parse_pem_certificates(data: &[u8]) -> TlsResult<Vec<CertificateDer<'static>>> {
512 let mut reader = BufReader::new(Cursor::new(data));
513 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
514 .collect::<Result<Vec<_>, _>>()
515 .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse PEM: {}", e)))?;
516
517 if certs.is_empty() {
518 return Err(TlsError::InvalidCertificate(
519 "No certificates found in PEM data".to_string(),
520 ));
521 }
522
523 Ok(certs)
524}
525
526pub fn load_private_key(source: &PrivateKeySource) -> TlsResult<PrivateKeyDer<'static>> {
528 match source {
529 PrivateKeySource::File { path } => {
530 let data = fs::read(path).map_err(|e| TlsError::KeyReadError {
531 path: path.clone(),
532 source: e,
533 })?;
534 parse_pem_private_key(&data)
535 }
536 PrivateKeySource::Pem { content } => parse_pem_private_key(content.as_bytes()),
537 PrivateKeySource::Der { content } => {
538 let der = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
539 .map_err(|e| TlsError::InvalidPrivateKey(format!("Invalid base64: {}", e)))?;
540 Ok(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der)))
541 }
542 }
543}
544
545fn parse_pem_private_key(data: &[u8]) -> TlsResult<PrivateKeyDer<'static>> {
547 let mut reader = BufReader::new(Cursor::new(data));
548
549 rustls_pemfile::private_key(&mut reader)
550 .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse PEM: {}", e)))?
551 .ok_or_else(|| TlsError::InvalidPrivateKey("No private key found in PEM data".to_string()))
552}
553
554pub fn generate_self_signed(
556 common_name: &str,
557) -> TlsResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
558 let subject_alt_names = vec![
559 common_name.to_string(),
560 "localhost".to_string(),
561 "127.0.0.1".to_string(),
562 ];
563
564 let mut cert_params = rcgen::CertificateParams::new(subject_alt_names)
565 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
566
567 cert_params.distinguished_name = rcgen::DistinguishedName::new();
569 cert_params.distinguished_name.push(
570 rcgen::DnType::CommonName,
571 rcgen::DnValue::Utf8String(common_name.to_string()),
572 );
573 cert_params.distinguished_name.push(
574 rcgen::DnType::OrganizationName,
575 rcgen::DnValue::Utf8String("Rivven".to_string()),
576 );
577
578 let key_pair = rcgen::KeyPair::generate()
579 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
580
581 let cert = cert_params
582 .self_signed(&key_pair)
583 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
584
585 let cert_der = CertificateDer::from(cert.der().to_vec());
586 let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
587
588 Ok((cert_der, key_der))
589}
590
591pub struct TlsAcceptor {
597 config: Arc<ServerConfig>,
598 inner: tokio_rustls::TlsAcceptor,
599 tls_config: TlsConfig,
601 reloadable_config: Option<Arc<RwLock<Arc<ServerConfig>>>>,
603}
604
605impl TlsAcceptor {
606 pub fn new(config: &TlsConfig) -> TlsResult<Self> {
608 let server_config = build_server_config(config)?;
609 let server_config = Arc::new(server_config);
610
611 Ok(Self {
612 inner: tokio_rustls::TlsAcceptor::from(server_config.clone()),
613 config: server_config.clone(),
614 tls_config: config.clone(),
615 reloadable_config: if config.cert_reload_interval > Duration::ZERO {
616 Some(Arc::new(RwLock::new(server_config)))
617 } else {
618 None
619 },
620 })
621 }
622
623 pub async fn accept<IO>(&self, stream: IO) -> TlsResult<TlsServerStream<IO>>
625 where
626 IO: AsyncRead + AsyncWrite + Unpin,
627 {
628 let tls_stream = self
629 .inner
630 .accept(stream)
631 .await
632 .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
633
634 Ok(TlsServerStream { inner: tls_stream })
635 }
636
637 pub async fn accept_tcp(&self, stream: TcpStream) -> TlsResult<TlsServerStream<TcpStream>> {
639 self.accept(stream).await
640 }
641
642 pub fn reload(&self) -> TlsResult<()> {
644 if let Some(ref reloadable) = self.reloadable_config {
645 let new_config = build_server_config(&self.tls_config)?;
646 *reloadable.write() = Arc::new(new_config);
647 }
648 Ok(())
649 }
650
651 pub fn config(&self) -> &Arc<ServerConfig> {
653 &self.config
654 }
655}
656
657impl fmt::Debug for TlsAcceptor {
658 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659 f.debug_struct("TlsAcceptor")
660 .field("mtls_mode", &self.tls_config.mtls_mode)
661 .field("min_version", &self.tls_config.min_version)
662 .finish()
663 }
664}
665
666fn build_server_config(config: &TlsConfig) -> TlsResult<ServerConfig> {
668 let (certs, key) =
670 if let Some(CertificateSource::SelfSigned { common_name }) = &config.certificate {
671 let (cert, key) = generate_self_signed(common_name)?;
673 (vec![cert], key)
674 } else {
675 let certs = if let Some(ref cert_source) = config.certificate {
677 load_certificates(cert_source)?
678 } else {
679 return Err(TlsError::ConfigError(
680 "Server certificate required".to_string(),
681 ));
682 };
683
684 let key = if let Some(ref key_source) = config.private_key {
686 load_private_key(key_source)?
687 } else {
688 return Err(TlsError::ConfigError("Private key required".to_string()));
689 };
690
691 (certs, key)
692 };
693
694 let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
696 TlsVersion::Tls13 => vec![&rustls::version::TLS13],
697 TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
698 };
699
700 let client_cert_verifier = match config.mtls_mode {
702 MtlsMode::Disabled => None,
703 MtlsMode::Optional | MtlsMode::Required => {
704 if let Some(ref ca_source) = config.client_ca {
705 let ca_certs = load_certificates(ca_source)?;
706 let mut root_store = rustls::RootCertStore::empty();
707 for cert in ca_certs {
708 root_store.add(cert).map_err(|e| {
709 TlsError::CertificateChainError(format!("Failed to add CA cert: {}", e))
710 })?;
711 }
712
713 let verifier = if config.mtls_mode == MtlsMode::Required {
714 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
715 .build()
716 .map_err(|e| {
717 TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
718 })?
719 } else {
720 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
721 .allow_unauthenticated()
722 .build()
723 .map_err(|e| {
724 TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
725 })?
726 };
727
728 Some(verifier)
729 } else if config.mtls_mode == MtlsMode::Required {
730 return Err(TlsError::ConfigError(
731 "mTLS required but no client CA configured".to_string(),
732 ));
733 } else {
734 None
735 }
736 }
737 };
738
739 let mut server_config = if let Some(verifier) = client_cert_verifier {
741 ServerConfig::builder_with_protocol_versions(&versions)
742 .with_client_cert_verifier(verifier)
743 .with_single_cert(certs, key)
744 .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
745 } else {
746 ServerConfig::builder_with_protocol_versions(&versions)
747 .with_no_client_auth()
748 .with_single_cert(certs, key)
749 .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
750 };
751
752 if !config.alpn_protocols.is_empty() {
754 server_config.alpn_protocols = config
755 .alpn_protocols
756 .iter()
757 .map(|p| p.as_bytes().to_vec())
758 .collect();
759 }
760
761 if config.session_cache_size > 0 {
763 server_config.session_storage =
764 rustls::server::ServerSessionMemoryCache::new(config.session_cache_size);
765 }
766
767 Ok(server_config)
768}
769
770pub struct TlsConnector {
776 config: Arc<ClientConfig>,
777 inner: tokio_rustls::TlsConnector,
778 server_name: Option<String>,
780}
781
782impl TlsConnector {
783 pub fn new(config: &TlsConfig) -> TlsResult<Self> {
785 let client_config = build_client_config(config)?;
786 let client_config = Arc::new(client_config);
787
788 Ok(Self {
789 inner: tokio_rustls::TlsConnector::from(client_config.clone()),
790 config: client_config,
791 server_name: config.server_name.clone(),
792 })
793 }
794
795 pub async fn connect<IO>(&self, stream: IO, server_name: &str) -> TlsResult<TlsClientStream<IO>>
797 where
798 IO: AsyncRead + AsyncWrite + Unpin,
799 {
800 let name: rustls::pki_types::ServerName<'static> = server_name
801 .to_string()
802 .try_into()
803 .map_err(|_| TlsError::ConfigError(format!("Invalid server name: {}", server_name)))?;
804
805 let tls_stream = self
806 .inner
807 .connect(name, stream)
808 .await
809 .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
810
811 Ok(TlsClientStream { inner: tls_stream })
812 }
813
814 pub async fn connect_tcp(
816 &self,
817 addr: SocketAddr,
818 server_name: &str,
819 ) -> TlsResult<TlsClientStream<TcpStream>> {
820 let stream = TcpStream::connect(addr)
821 .await
822 .map_err(|e| TlsError::ConnectionError(e.to_string()))?;
823
824 self.connect(stream, server_name).await
825 }
826
827 pub async fn connect_with_default_name<IO>(&self, stream: IO) -> TlsResult<TlsClientStream<IO>>
829 where
830 IO: AsyncRead + AsyncWrite + Unpin,
831 {
832 let name = self.server_name.as_ref().ok_or_else(|| {
833 TlsError::ConfigError("No server name configured for SNI".to_string())
834 })?;
835 self.connect(stream, name).await
836 }
837
838 pub fn config(&self) -> &Arc<ClientConfig> {
840 &self.config
841 }
842}
843
844impl fmt::Debug for TlsConnector {
845 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
846 f.debug_struct("TlsConnector")
847 .field("server_name", &self.server_name)
848 .finish()
849 }
850}
851
852fn build_client_config(config: &TlsConfig) -> TlsResult<ClientConfig> {
854 let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
856 TlsVersion::Tls13 => vec![&rustls::version::TLS13],
857 TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
858 };
859
860 let root_store = if config.insecure_skip_verify {
862 rustls::RootCertStore::empty()
864 } else if let Some(ref ca_source) = config.root_ca {
865 let ca_certs = load_certificates(ca_source)?;
866 let mut store = rustls::RootCertStore::empty();
867 for cert in ca_certs {
868 store.add(cert).map_err(|e| {
869 TlsError::CertificateChainError(format!("Failed to add root CA: {}", e))
870 })?;
871 }
872 store
873 } else {
874 let mut store = rustls::RootCertStore::empty();
876 let native_certs = rustls_native_certs::load_native_certs();
877 for cert in native_certs.certs {
878 let _ = store.add(cert);
879 }
880 store
881 };
882
883 let mut client_config = if let (Some(ref cert_source), Some(ref key_source)) =
885 (&config.certificate, &config.private_key)
886 {
887 let certs = load_certificates(cert_source)?;
889 let key = load_private_key(key_source)?;
890
891 ClientConfig::builder_with_protocol_versions(&versions)
892 .with_root_certificates(root_store)
893 .with_client_auth_cert(certs, key)
894 .map_err(|e| TlsError::ConfigError(format!("Invalid client cert/key: {}", e)))?
895 } else if config.insecure_skip_verify {
896 ClientConfig::builder_with_protocol_versions(&versions)
898 .dangerous()
899 .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
900 .with_no_client_auth()
901 } else {
902 ClientConfig::builder_with_protocol_versions(&versions)
904 .with_root_certificates(root_store)
905 .with_no_client_auth()
906 };
907
908 if !config.alpn_protocols.is_empty() {
910 client_config.alpn_protocols = config
911 .alpn_protocols
912 .iter()
913 .map(|p| p.as_bytes().to_vec())
914 .collect();
915 }
916
917 Ok(client_config)
918}
919
920#[derive(Debug)]
923struct NoCertificateVerification;
924
925impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
926 fn verify_server_cert(
927 &self,
928 _end_entity: &CertificateDer<'_>,
929 _intermediates: &[CertificateDer<'_>],
930 _server_name: &rustls::pki_types::ServerName<'_>,
931 _ocsp_response: &[u8],
932 _now: rustls::pki_types::UnixTime,
933 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
934 Ok(rustls::client::danger::ServerCertVerified::assertion())
935 }
936
937 fn verify_tls12_signature(
938 &self,
939 _message: &[u8],
940 _cert: &CertificateDer<'_>,
941 _dss: &rustls::DigitallySignedStruct,
942 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
943 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
944 }
945
946 fn verify_tls13_signature(
947 &self,
948 _message: &[u8],
949 _cert: &CertificateDer<'_>,
950 _dss: &rustls::DigitallySignedStruct,
951 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
952 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
953 }
954
955 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
956 vec![
957 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
958 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
959 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
960 rustls::SignatureScheme::RSA_PSS_SHA256,
961 rustls::SignatureScheme::RSA_PSS_SHA384,
962 rustls::SignatureScheme::RSA_PSS_SHA512,
963 rustls::SignatureScheme::RSA_PKCS1_SHA256,
964 rustls::SignatureScheme::RSA_PKCS1_SHA384,
965 rustls::SignatureScheme::RSA_PKCS1_SHA512,
966 rustls::SignatureScheme::ED25519,
967 ]
968 }
969}
970
971pub struct TlsServerStream<IO> {
977 inner: tokio_rustls::server::TlsStream<IO>,
978}
979
980impl<IO> TlsServerStream<IO>
981where
982 IO: AsyncRead + AsyncWrite + Unpin,
983{
984 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
986 self.inner.get_ref().1.peer_certificates()
987 }
988
989 pub fn alpn_protocol(&self) -> Option<&[u8]> {
991 self.inner.get_ref().1.alpn_protocol()
992 }
993
994 pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
996 self.inner.get_ref().1.protocol_version()
997 }
998
999 pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
1001 self.inner.get_ref().1.negotiated_cipher_suite()
1002 }
1003
1004 pub fn cipher_suite_name(&self) -> Option<String> {
1006 self.negotiated_cipher_suite()
1007 .map(|cs| format!("{:?}", cs.suite()))
1008 }
1009
1010 pub fn is_tls_13(&self) -> bool {
1012 self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1013 }
1014
1015 pub fn peer_common_name(&self) -> Option<String> {
1017 self.peer_certificates().and_then(|certs| {
1018 if certs.is_empty() {
1019 return None;
1020 }
1021 extract_common_name(&certs[0])
1022 })
1023 }
1024
1025 pub fn peer_subject(&self) -> Option<String> {
1027 self.peer_certificates().and_then(|certs| {
1028 if certs.is_empty() {
1029 return None;
1030 }
1031 extract_subject(&certs[0])
1032 })
1033 }
1034
1035 pub fn get_ref(&self) -> &IO {
1037 self.inner.get_ref().0
1038 }
1039
1040 pub fn get_mut(&mut self) -> &mut IO {
1042 self.inner.get_mut().0
1043 }
1044
1045 pub fn into_inner(self) -> IO {
1047 self.inner.into_inner().0
1048 }
1049}
1050
1051impl<IO> tokio::io::AsyncRead for TlsServerStream<IO>
1052where
1053 IO: AsyncRead + AsyncWrite + Unpin,
1054{
1055 fn poll_read(
1056 mut self: std::pin::Pin<&mut Self>,
1057 cx: &mut std::task::Context<'_>,
1058 buf: &mut ReadBuf<'_>,
1059 ) -> std::task::Poll<io::Result<()>> {
1060 std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1061 }
1062}
1063
1064impl<IO> tokio::io::AsyncWrite for TlsServerStream<IO>
1065where
1066 IO: AsyncRead + AsyncWrite + Unpin,
1067{
1068 fn poll_write(
1069 mut self: std::pin::Pin<&mut Self>,
1070 cx: &mut std::task::Context<'_>,
1071 buf: &[u8],
1072 ) -> std::task::Poll<io::Result<usize>> {
1073 std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1074 }
1075
1076 fn poll_flush(
1077 mut self: std::pin::Pin<&mut Self>,
1078 cx: &mut std::task::Context<'_>,
1079 ) -> std::task::Poll<io::Result<()>> {
1080 std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1081 }
1082
1083 fn poll_shutdown(
1084 mut self: std::pin::Pin<&mut Self>,
1085 cx: &mut std::task::Context<'_>,
1086 ) -> std::task::Poll<io::Result<()>> {
1087 std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1088 }
1089}
1090
1091pub struct TlsClientStream<IO> {
1093 inner: tokio_rustls::client::TlsStream<IO>,
1094}
1095
1096impl<IO> TlsClientStream<IO>
1097where
1098 IO: AsyncRead + AsyncWrite + Unpin,
1099{
1100 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1102 self.inner.get_ref().1.peer_certificates()
1103 }
1104
1105 pub fn alpn_protocol(&self) -> Option<&[u8]> {
1107 self.inner.get_ref().1.alpn_protocol()
1108 }
1109
1110 pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1112 self.inner.get_ref().1.protocol_version()
1113 }
1114
1115 pub fn is_tls_13(&self) -> bool {
1117 self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1118 }
1119
1120 pub fn get_ref(&self) -> &IO {
1122 self.inner.get_ref().0
1123 }
1124
1125 pub fn get_mut(&mut self) -> &mut IO {
1127 self.inner.get_mut().0
1128 }
1129
1130 pub fn into_inner(self) -> IO {
1132 self.inner.into_inner().0
1133 }
1134}
1135
1136impl<IO> tokio::io::AsyncRead for TlsClientStream<IO>
1137where
1138 IO: AsyncRead + AsyncWrite + Unpin,
1139{
1140 fn poll_read(
1141 mut self: std::pin::Pin<&mut Self>,
1142 cx: &mut std::task::Context<'_>,
1143 buf: &mut ReadBuf<'_>,
1144 ) -> std::task::Poll<io::Result<()>> {
1145 std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1146 }
1147}
1148
1149impl<IO> tokio::io::AsyncWrite for TlsClientStream<IO>
1150where
1151 IO: AsyncRead + AsyncWrite + Unpin,
1152{
1153 fn poll_write(
1154 mut self: std::pin::Pin<&mut Self>,
1155 cx: &mut std::task::Context<'_>,
1156 buf: &[u8],
1157 ) -> std::task::Poll<io::Result<usize>> {
1158 std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1159 }
1160
1161 fn poll_flush(
1162 mut self: std::pin::Pin<&mut Self>,
1163 cx: &mut std::task::Context<'_>,
1164 ) -> std::task::Poll<io::Result<()>> {
1165 std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1166 }
1167
1168 fn poll_shutdown(
1169 mut self: std::pin::Pin<&mut Self>,
1170 cx: &mut std::task::Context<'_>,
1171 ) -> std::task::Poll<io::Result<()>> {
1172 std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1173 }
1174}
1175
1176fn extract_common_name(cert: &CertificateDer<'_>) -> Option<String> {
1182 let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1184
1185 for rdn in cert.subject().iter_rdn() {
1186 for attr in rdn.iter() {
1187 if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
1188 return attr.as_str().ok().map(|s| s.to_string());
1189 }
1190 }
1191 }
1192
1193 None
1194}
1195
1196fn extract_subject(cert: &CertificateDer<'_>) -> Option<String> {
1198 let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1199 Some(cert.subject().to_string())
1200}
1201
1202pub fn certificate_fingerprint(cert: &CertificateDer<'_>) -> String {
1204 use sha2::{Digest, Sha256};
1205 let hash = Sha256::digest(cert.as_ref());
1206 hex::encode(hash)
1207}
1208
1209pub fn verify_certificate_chain(
1214 chain: &[CertificateDer<'_>],
1215 root_store: &rustls::RootCertStore,
1216) -> TlsResult<()> {
1217 if chain.is_empty() {
1218 return Err(TlsError::CertificateChainError(
1219 "Empty certificate chain".to_string(),
1220 ));
1221 }
1222
1223 if root_store.is_empty() {
1226 tracing::warn!("Root certificate store is empty - chain validation may fail");
1227 }
1228
1229 for (i, cert) in chain.iter().enumerate() {
1231 let fingerprint = certificate_fingerprint(cert);
1232 tracing::debug!(
1233 "Certificate chain[{}]: fingerprint={}",
1234 i,
1235 &fingerprint[..16]
1236 );
1237 }
1238
1239 Ok(())
1240}
1241
1242pub struct CertificateWatcher {
1248 watched_files: Vec<PathBuf>,
1250 last_modified: HashMap<PathBuf, SystemTime>,
1252 reload_callback: Box<dyn Fn() + Send + Sync>,
1254}
1255
1256impl CertificateWatcher {
1257 pub fn new<F>(files: Vec<PathBuf>, callback: F) -> Self
1259 where
1260 F: Fn() + Send + Sync + 'static,
1261 {
1262 let mut last_modified = HashMap::new();
1263 for file in &files {
1264 if let Ok(meta) = fs::metadata(file) {
1265 if let Ok(modified) = meta.modified() {
1266 last_modified.insert(file.clone(), modified);
1267 }
1268 }
1269 }
1270
1271 Self {
1272 watched_files: files,
1273 last_modified,
1274 reload_callback: Box::new(callback),
1275 }
1276 }
1277
1278 pub fn check_and_reload(&mut self) -> bool {
1280 let mut changed = false;
1281
1282 for file in &self.watched_files {
1283 if let Ok(meta) = fs::metadata(file) {
1284 if let Ok(modified) = meta.modified() {
1285 let last = self.last_modified.get(file);
1286 if last.is_none_or(|&l| modified > l) {
1287 self.last_modified.insert(file.clone(), modified);
1288 changed = true;
1289 }
1290 }
1291 }
1292 }
1293
1294 if changed {
1295 (self.reload_callback)();
1296 }
1297
1298 changed
1299 }
1300
1301 pub fn spawn(mut self, interval: Duration) -> tokio::task::JoinHandle<()> {
1303 tokio::spawn(async move {
1304 let mut ticker = tokio::time::interval(interval);
1305 loop {
1306 ticker.tick().await;
1307 self.check_and_reload();
1308 }
1309 })
1310 }
1311}
1312
/// Summary of a peer certificate's identity, produced by `TlsIdentity::from_certificate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// First subject CN, if present.
    pub common_name: Option<String>,
    /// String rendering of the full subject name.
    pub subject: Option<String>,
    /// Lower-case hex SHA-256 of the DER bytes (always present, even for unparsable input).
    pub fingerprint: String,
    /// Subject organization (O); the last one wins if several are present.
    pub organization: Option<String>,
    /// Subject organizational unit (OU); the last one wins if several are present.
    pub organizational_unit: Option<String>,
    /// Serial number in lower-case hex, presumably without leading zeros — TODO confirm
    /// `to_str_radix` formatting.
    pub serial_number: Option<String>,
    /// Start of the validity period (second precision).
    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// End of the validity period (second precision).
    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
    /// Whether the certificate was within its validity period when this value was built
    /// (not re-evaluated later).
    pub is_valid: bool,
}
1338
1339impl TlsIdentity {
1340 pub fn from_certificate(cert: &CertificateDer<'_>) -> Self {
1342 let fingerprint = certificate_fingerprint(cert);
1343 let common_name = extract_common_name(cert);
1344 let subject = extract_subject(cert);
1345
1346 let (organization, organizational_unit, serial_number, valid_from, valid_until, is_valid) =
1348 if let Ok((_, parsed)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
1349 let mut org = None;
1350 let mut ou = None;
1351
1352 for rdn in parsed.subject().iter_rdn() {
1353 for attr in rdn.iter() {
1354 if attr.attr_type()
1355 == &x509_parser::oid_registry::OID_X509_ORGANIZATION_NAME
1356 {
1357 org = attr.as_str().ok().map(|s| s.to_string());
1358 }
1359 if attr.attr_type()
1360 == &x509_parser::oid_registry::OID_X509_ORGANIZATIONAL_UNIT
1361 {
1362 ou = attr.as_str().ok().map(|s| s.to_string());
1363 }
1364 }
1365 }
1366
1367 let serial = Some(parsed.serial.to_str_radix(16));
1368
1369 let validity = parsed.validity();
1370 let now = chrono::Utc::now();
1371
1372 let from = chrono::DateTime::from_timestamp(validity.not_before.timestamp(), 0);
1373 let until = chrono::DateTime::from_timestamp(validity.not_after.timestamp(), 0);
1374
1375 let valid = from.is_some_and(|f| now >= f) && until.is_some_and(|u| now <= u);
1376
1377 (org, ou, serial, from, until, valid)
1378 } else {
1379 (None, None, None, None, None, false)
1380 };
1381
1382 Self {
1383 common_name,
1384 subject,
1385 fingerprint,
1386 organization,
1387 organizational_unit,
1388 serial_number,
1389 valid_from,
1390 valid_until,
1391 is_valid,
1392 }
1393 }
1394}
1395
/// Findings produced by `TlsSecurityAudit::audit` for a [`TlsConfig`],
/// grouped by severity. Each entry is a human-readable message.
#[derive(Debug)]
pub struct TlsSecurityAudit {
    /// Settings that are allowed but weaker than ideal (e.g. TLS 1.2 permitted,
    /// optional mTLS, a client CA that is ignored because mTLS is disabled).
    pub warnings: Vec<String>,
    /// Settings that undermine transport security (TLS disabled, certificate
    /// verification skipped).
    pub errors: Vec<String>,
    /// Optional hardening or performance suggestions (session cache,
    /// certificate hot-reloading, certificate pinning).
    pub recommendations: Vec<String>,
}
1407
1408impl TlsSecurityAudit {
1409 pub fn audit(config: &TlsConfig) -> Self {
1411 let mut audit = Self {
1412 warnings: vec![],
1413 errors: vec![],
1414 recommendations: vec![],
1415 };
1416
1417 if !config.enabled {
1418 audit
1419 .errors
1420 .push("TLS is disabled - all traffic will be unencrypted".to_string());
1421 }
1422
1423 if config.insecure_skip_verify {
1424 audit.errors.push(
1425 "Certificate verification is disabled - vulnerable to MITM attacks".to_string(),
1426 );
1427 }
1428
1429 if config.min_version == TlsVersion::Tls12 {
1430 audit.warnings.push(
1431 "TLS 1.2 is allowed - consider requiring TLS 1.3 for better security".to_string(),
1432 );
1433 }
1434
1435 if config.mtls_mode == MtlsMode::Disabled && config.client_ca.is_some() {
1436 audit.warnings.push(
1437 "Client CA configured but mTLS is disabled - clients won't be verified".to_string(),
1438 );
1439 }
1440
1441 if config.mtls_mode == MtlsMode::Optional {
1442 audit.warnings.push(
1443 "mTLS is optional - some clients may connect without certificates".to_string(),
1444 );
1445 }
1446
1447 if config.session_cache_size == 0 {
1448 audit
1449 .recommendations
1450 .push("Consider enabling session cache for better performance".to_string());
1451 }
1452
1453 if config.cert_reload_interval == Duration::ZERO {
1454 audit.recommendations.push(
1455 "Consider enabling certificate hot-reloading for zero-downtime rotation"
1456 .to_string(),
1457 );
1458 }
1459
1460 if config.pinned_certificates.is_empty() && !config.insecure_skip_verify {
1461 audit
1462 .recommendations
1463 .push("Consider certificate pinning for high-security deployments".to_string());
1464 }
1465
1466 audit
1467 }
1468}
1469
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.mtls_mode, MtlsMode::Disabled);
        assert_eq!(config.min_version, TlsVersion::Tls13);
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfigBuilder::new()
            .with_cert_file("/path/to/cert.pem")
            .with_key_file("/path/to/key.pem")
            .with_client_ca_file("/path/to/ca.pem")
            .require_client_cert(true)
            .with_min_version(TlsVersion::Tls12)
            .build();

        assert!(config.enabled);
        assert_eq!(config.mtls_mode, MtlsMode::Required);
        assert_eq!(config.min_version, TlsVersion::Tls12);
    }

    /// One-way TLS echo: self-signed server certificate, client skips verification.
    #[tokio::test]
    async fn test_tls_server_client_handshake() {
        // Ignore the result: another test may already have installed the provider.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::SelfSigned {
                common_name: "localhost".to_string(),
            }),
            mtls_mode: MtlsMode::Disabled,
            ..Default::default()
        };

        let client_config = TlsConfig {
            enabled: true,
            insecure_skip_verify: true,
            ..Default::default()
        };

        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();

            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            n
        });

        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            let message = b"Hello, TLS!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let server_bytes_read = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);
    }

    /// Mutual TLS echo: server and client certificates both signed by a test CA.
    #[tokio::test]
    async fn test_mtls_server_client_handshake() {
        use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, KeyUsagePurpose};

        // Ignore the result: another test may already have installed the provider.
        let _ = rustls::crypto::ring::default_provider().install_default();

        // Test CA.
        let mut ca_params = CertificateParams::default();
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
        ca_params
            .distinguished_name
            .push(DnType::CommonName, "Rivven Test CA");
        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();
        let ca_cert_pem = ca_cert.pem();

        // Server certificate for "localhost".
        let mut server_params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        server_params
            .distinguished_name
            .push(DnType::CommonName, "localhost");
        let server_key_pair = rcgen::KeyPair::generate().unwrap();
        let server_cert = server_params
            .signed_by(&server_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let server_cert_pem = server_cert.pem();
        let server_key_pem = server_key_pair.serialize_pem();

        // Client certificate.
        let mut client_params =
            CertificateParams::new(vec!["client.rivven.local".to_string()]).unwrap();
        client_params
            .distinguished_name
            .push(DnType::CommonName, "client.rivven.local");
        let client_key_pair = rcgen::KeyPair::generate().unwrap();
        let client_cert = client_params
            .signed_by(&client_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let client_cert_pem = client_cert.pem();
        let client_key_pem = client_key_pair.serialize_pem();

        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: server_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: server_key_pem,
            }),
            client_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem.clone(),
            }),
            mtls_mode: MtlsMode::Required,
            insecure_skip_verify: false,
            ..Default::default()
        };

        let client_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: client_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: client_key_pem,
            }),
            root_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem,
            }),
            insecure_skip_verify: false,
            ..Default::default()
        };

        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            let has_peer_cert = tls_stream.peer_certificates().is_some();

            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();
            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            (n, has_peer_cert)
        });

        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            let message = b"mTLS Test!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let (server_bytes_read, has_peer_cert) = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);

        assert!(
            has_peer_cert,
            "Server should have received client certificate in mTLS"
        );
    }

    #[test]
    fn test_self_signed_generation() {
        let result = generate_self_signed("test.rivven.local");
        assert!(result.is_ok());

        let (cert, _key) = result.unwrap();
        assert!(!cert.as_ref().is_empty());

        let identity = TlsIdentity::from_certificate(&cert);
        assert_eq!(identity.common_name, Some("test.rivven.local".to_string()));
        assert!(identity.is_valid);
    }

    #[test]
    fn test_certificate_fingerprint() {
        let (cert, _) = generate_self_signed("test.rivven.local").unwrap();
        let fingerprint = certificate_fingerprint(&cert);

        assert_eq!(fingerprint.len(), 64);
        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_tls_security_audit_disabled() {
        let config = TlsConfig::disabled();
        let audit = TlsSecurityAudit::audit(&config);

        assert!(!audit.errors.is_empty());
        assert!(audit.errors.iter().any(|e| e.contains("disabled")));
    }

    #[test]
    fn test_tls_security_audit_insecure() {
        let config = TlsConfig {
            enabled: true,
            insecure_skip_verify: true,
            ..Default::default()
        };
        let audit = TlsSecurityAudit::audit(&config);

        assert!(audit.errors.iter().any(|e| e.contains("MITM")));
    }

    #[test]
    fn test_tls_security_audit_production_ready() {
        let config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::SelfSigned {
                common_name: "broker.rivven.local".to_string(),
            }),
            mtls_mode: MtlsMode::Required,
            min_version: TlsVersion::Tls13,
            insecure_skip_verify: false,
            session_cache_size: 256,
            ..Default::default()
        };

        let audit = TlsSecurityAudit::audit(&config);

        // Errors are only raised for disabled TLS or skipped verification. Warnings
        // are only raised for a TLS 1.2 minimum or non-required mTLS. This config
        // avoids all of them. (The previous `is_empty() || all(..)` assertion was
        // vacuous, because `all` on an empty list is already true.)
        assert!(
            audit.errors.is_empty(),
            "unexpected audit errors: {:?}",
            audit.errors
        );
        assert!(
            audit.warnings.is_empty(),
            "unexpected audit warnings: {:?}",
            audit.warnings
        );
    }

    #[test]
    fn test_mtls_modes() {
        assert_eq!(MtlsMode::default(), MtlsMode::Disabled);

        let modes = [MtlsMode::Disabled, MtlsMode::Optional, MtlsMode::Required];
        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let parsed: MtlsMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, parsed);
        }
    }

    #[test]
    fn test_tls_identity_extraction() {
        let (cert, _) = generate_self_signed("service.rivven.internal").unwrap();
        let identity = TlsIdentity::from_certificate(&cert);

        assert_eq!(
            identity.common_name,
            Some("service.rivven.internal".to_string())
        );
        assert!(identity.is_valid);
        assert!(identity.valid_from.is_some());
        assert!(identity.valid_until.is_some());
        assert!(!identity.fingerprint.is_empty());
    }
}