1use std::collections::HashMap;
44use std::fmt;
45use std::fs;
46use std::io::{self, BufReader, Cursor};
47use std::net::SocketAddr;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::{Duration, SystemTime};
51
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54use thiserror::Error;
55use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
56use tokio::net::TcpStream;
57
58pub use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
60pub use rustls::{ClientConfig, ServerConfig};
61
/// Errors produced while loading TLS key material, building rustls
/// configurations, or establishing TLS sessions.
#[derive(Debug, Error)]
pub enum TlsError {
    /// A certificate file could not be read from disk.
    #[error("Failed to read certificate file '{path}': {source}")]
    CertificateReadError {
        /// Path of the file that could not be read.
        path: PathBuf,
        /// Underlying I/O failure.
        #[source]
        source: io::Error,
    },

    /// A private key file could not be read from disk.
    #[error("Failed to read private key file '{path}': {source}")]
    KeyReadError {
        /// Path of the file that could not be read.
        path: PathBuf,
        /// Underlying I/O failure.
        #[source]
        source: io::Error,
    },

    /// Certificate data could not be decoded: invalid base64, unparsable PEM,
    /// or PEM data containing no certificates.
    #[error("Invalid certificate format: {0}")]
    InvalidCertificate(String),

    /// Private key data could not be decoded: invalid base64, unparsable PEM,
    /// or PEM data containing no private key.
    #[error("Invalid private key format: {0}")]
    InvalidPrivateKey(String),

    /// A certificate chain or trust store problem (e.g. an empty chain or a CA
    /// certificate that could not be added to a root store).
    #[error("Certificate chain validation failed: {0}")]
    CertificateChainError(String),

    /// The TLS handshake with the peer failed.
    #[error("TLS handshake failed: {0}")]
    HandshakeError(String),

    /// The underlying transport connection could not be established.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// The TLS configuration is incomplete or inconsistent.
    #[error("TLS configuration error: {0}")]
    ConfigError(String),

    /// A certificate is past the end of its validity period.
    #[error("Certificate expired: {0}")]
    CertificateExpired(String),

    /// A certificate's validity period has not started yet.
    #[error("Certificate not yet valid: {0}")]
    CertificateNotYetValid(String),

    /// A certificate has been revoked.
    #[error("Certificate revoked: {0}")]
    CertificateRevoked(String),

    /// The peer's certificate does not match the expected host name.
    #[error("Hostname verification failed: expected '{expected}', got '{actual}'")]
    HostnameVerificationFailed { expected: String, actual: String },

    /// mTLS requires a client certificate but the client did not send one.
    #[error("Client certificate required for mTLS but not provided")]
    ClientCertificateRequired,

    /// Generating a self-signed certificate/key pair failed.
    #[error("Failed to generate self-signed certificate: {0}")]
    SelfSignedGenerationError(String),

    /// Client and server share no ALPN protocol.
    #[error("ALPN negotiation failed: no common protocol")]
    AlpnNegotiationFailed,

    /// Any other rustls error, carried as its display text (produced by the
    /// `From<rustls::Error>` conversion).
    #[error("TLS internal error: {0}")]
    RustlsError(String),
}
137
138impl From<rustls::Error> for TlsError {
139 fn from(err: rustls::Error) -> Self {
140 TlsError::RustlsError(err.to_string())
141 }
142}
143
144pub type TlsResult<T> = std::result::Result<T, TlsError>;
146
/// Minimum TLS protocol version to negotiate.
///
/// `Tls12` enables both TLS 1.2 and TLS 1.3; `Tls13` (the default) restricts
/// connections to TLS 1.3. Serialized as `"tls12"` / `"tls13"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TlsVersion {
    /// Accept TLS 1.2 and TLS 1.3.
    Tls12,
    /// Accept TLS 1.3 only.
    #[default]
    Tls13,
}

/// Server-side client-certificate (mutual TLS) policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MtlsMode {
    /// Client certificates are not requested (default).
    #[default]
    Disabled,
    /// A client certificate is verified against `client_ca` when presented, but
    /// unauthenticated clients are still accepted. Without a `client_ca` this
    /// behaves like `Disabled`.
    Optional,
    /// Every client must present a certificate that chains to `client_ca`;
    /// building a server config without a `client_ca` is an error.
    Required,
}

/// Where a certificate (or certificate chain) is loaded from.
///
/// Serialized with an internal `"type"` tag, e.g. `{"type": "file", "path": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CertificateSource {
    /// A PEM file on disk (may contain several certificates).
    File { path: PathBuf },
    /// Inline PEM text.
    Pem { content: String },
    /// A single DER certificate, base64-encoded (standard alphabet).
    Der { content: String },
    /// A self-signed certificate generated at load time for `common_name`.
    SelfSigned { common_name: String },
}

/// Where a private key is loaded from.
///
/// Serialized with an internal `"type"` tag, like [`CertificateSource`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PrivateKeySource {
    /// A PEM file on disk.
    File { path: PathBuf },
    /// Inline PEM text.
    Pem { content: String },
    /// A base64-encoded DER key; it is always interpreted as PKCS#8.
    Der { content: String },
}
200
/// Declarative TLS settings shared by the server side (`TlsAcceptor`) and the
/// client side (`TlsConnector`).
///
/// When deserialized, `enabled` defaults to `true`. Note that
/// `TlsConfig::default()` instead produces a *disabled* configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Whether TLS is turned on.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Local certificate chain. On the server this is the server identity. On the
    /// client it is sent as a client certificate only when `private_key` is also set.
    pub certificate: Option<CertificateSource>,

    /// Private key matching `certificate`. The server ignores it when
    /// `certificate` is `SelfSigned`, because a fresh key is generated instead.
    pub private_key: Option<PrivateKeySource>,

    /// Trust anchors the client uses to verify servers. When `None`, the
    /// platform's native root store is used.
    pub root_ca: Option<CertificateSource>,

    /// Trust anchors the server uses to verify client certificates (mTLS).
    pub client_ca: Option<CertificateSource>,

    /// Server-side client-certificate policy.
    #[serde(default)]
    pub mtls_mode: MtlsMode,

    /// Lowest protocol version that will be negotiated.
    #[serde(default)]
    pub min_version: TlsVersion,

    /// ALPN protocol identifiers, most preferred first. An empty list leaves
    /// ALPN unconfigured.
    #[serde(default)]
    pub alpn_protocols: Vec<String>,

    /// Request OCSP stapling.
    /// NOTE(review): not read by the server/client config builders in this file.
    #[serde(default)]
    pub ocsp_stapling: bool,

    /// Pinned certificate fingerprints. Presumably lowercase hex SHA-256, as
    /// produced by `certificate_fingerprint` (TODO confirm).
    /// NOTE(review): not read by the server/client config builders in this file.
    #[serde(default)]
    pub pinned_certificates: Vec<String>,

    /// Client side: accept any server certificate without verification.
    /// Insecure; intended for development and testing only.
    #[serde(default)]
    pub insecure_skip_verify: bool,

    /// Default server name (SNI and certificate name check) used by
    /// `TlsConnector::connect_with_default_name`.
    pub server_name: Option<String>,

    /// Server-side session cache capacity. Values greater than 0 install an
    /// in-memory cache of that many entries. Defaults to 256.
    #[serde(default = "default_session_cache_size")]
    pub session_cache_size: usize,

    /// Session ticket lifetime. Defaults to 24 hours and is serialized in
    /// humantime format (e.g. "24h").
    /// NOTE(review): not read by the server/client config builders in this file.
    #[serde(default = "default_session_ticket_lifetime")]
    #[serde(with = "humantime_serde")]
    pub session_ticket_lifetime: Duration,

    /// Certificate reload interval, serialized in humantime format. Zero (the
    /// default) means `TlsAcceptor` does not set up its reloadable config slot.
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub cert_reload_interval: Duration,
}
261
/// Serde default for `TlsConfig::enabled`: a deserialized TLS section is
/// considered enabled unless it says otherwise.
fn default_true() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}

/// Serde default for `TlsConfig::session_cache_size` (number of cached sessions).
fn default_session_cache_size() -> usize {
    const DEFAULT_SESSION_CACHE_ENTRIES: usize = 256;
    DEFAULT_SESSION_CACHE_ENTRIES
}

/// Serde default for `TlsConfig::session_ticket_lifetime`: one day.
fn default_session_ticket_lifetime() -> Duration {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    Duration::from_secs(SECONDS_PER_DAY)
}
273
274impl Default for TlsConfig {
275 fn default() -> Self {
276 Self {
277 enabled: false, certificate: None,
279 private_key: None,
280 root_ca: None,
281 client_ca: None,
282 mtls_mode: MtlsMode::Disabled,
283 min_version: TlsVersion::Tls13,
284 alpn_protocols: vec![],
285 ocsp_stapling: false,
286 pinned_certificates: vec![],
287 insecure_skip_verify: false,
288 server_name: None,
289 session_cache_size: default_session_cache_size(),
290 session_ticket_lifetime: default_session_ticket_lifetime(),
291 cert_reload_interval: Duration::ZERO,
292 }
293 }
294}
295
296impl TlsConfig {
297 pub fn disabled() -> Self {
299 Self::default()
300 }
301
302 pub fn self_signed(common_name: &str) -> Self {
304 Self {
305 enabled: true,
306 certificate: Some(CertificateSource::SelfSigned {
307 common_name: common_name.to_string(),
308 }),
309 private_key: None, insecure_skip_verify: true, ..Default::default()
312 }
313 }
314
315 pub fn from_pem_files<P: Into<PathBuf>>(cert_path: P, key_path: P) -> Self {
317 Self {
318 enabled: true,
319 certificate: Some(CertificateSource::File {
320 path: cert_path.into(),
321 }),
322 private_key: Some(PrivateKeySource::File {
323 path: key_path.into(),
324 }),
325 ..Default::default()
326 }
327 }
328
329 pub fn mtls_from_pem_files<P1, P2, P3>(cert_path: P1, key_path: P2, ca_path: P3) -> Self
331 where
332 P1: Into<PathBuf>,
333 P2: Into<PathBuf>,
334 P3: Into<PathBuf> + Clone,
335 {
336 let ca: PathBuf = ca_path.clone().into();
337 Self {
338 enabled: true,
339 certificate: Some(CertificateSource::File {
340 path: cert_path.into(),
341 }),
342 private_key: Some(PrivateKeySource::File {
343 path: key_path.into(),
344 }),
345 client_ca: Some(CertificateSource::File { path: ca.clone() }),
346 root_ca: Some(CertificateSource::File { path: ca }),
347 mtls_mode: MtlsMode::Required,
348 ..Default::default()
349 }
350 }
351}
352
353pub struct TlsConfigBuilder {
359 config: TlsConfig,
360}
361
362impl TlsConfigBuilder {
363 pub fn new() -> Self {
365 Self {
366 config: TlsConfig {
367 enabled: true,
368 ..Default::default()
369 },
370 }
371 }
372
373 pub fn with_cert_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
375 self.config.certificate = Some(CertificateSource::File { path: path.into() });
376 self
377 }
378
379 pub fn with_cert_pem(mut self, pem: String) -> Self {
381 self.config.certificate = Some(CertificateSource::Pem { content: pem });
382 self
383 }
384
385 pub fn with_key_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
387 self.config.private_key = Some(PrivateKeySource::File { path: path.into() });
388 self
389 }
390
391 pub fn with_key_pem(mut self, pem: String) -> Self {
393 self.config.private_key = Some(PrivateKeySource::Pem { content: pem });
394 self
395 }
396
397 pub fn with_root_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
399 self.config.root_ca = Some(CertificateSource::File { path: path.into() });
400 self
401 }
402
403 pub fn with_client_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
405 self.config.client_ca = Some(CertificateSource::File { path: path.into() });
406 self
407 }
408
409 pub fn require_client_cert(mut self, required: bool) -> Self {
411 self.config.mtls_mode = if required {
412 MtlsMode::Required
413 } else {
414 MtlsMode::Disabled
415 };
416 self
417 }
418
419 pub fn with_mtls_mode(mut self, mode: MtlsMode) -> Self {
421 self.config.mtls_mode = mode;
422 self
423 }
424
425 pub fn with_min_version(mut self, version: TlsVersion) -> Self {
427 self.config.min_version = version;
428 self
429 }
430
431 pub fn with_alpn_protocols(mut self, protocols: Vec<String>) -> Self {
433 self.config.alpn_protocols = protocols;
434 self
435 }
436
437 pub fn with_server_name(mut self, name: String) -> Self {
439 self.config.server_name = Some(name);
440 self
441 }
442
443 pub fn insecure_skip_verify(mut self) -> Self {
445 self.config.insecure_skip_verify = true;
446 self
447 }
448
449 pub fn with_pinned_certificate(mut self, fingerprint: String) -> Self {
451 self.config.pinned_certificates.push(fingerprint);
452 self
453 }
454
455 pub fn with_self_signed(mut self, common_name: &str) -> Self {
457 self.config.certificate = Some(CertificateSource::SelfSigned {
458 common_name: common_name.to_string(),
459 });
460 self.config.insecure_skip_verify = true;
461 self
462 }
463
464 pub fn with_cert_reload_interval(mut self, interval: Duration) -> Self {
466 self.config.cert_reload_interval = interval;
467 self
468 }
469
470 pub fn build(self) -> TlsConfig {
472 self.config
473 }
474}
475
476impl Default for TlsConfigBuilder {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
482pub fn load_certificates(source: &CertificateSource) -> TlsResult<Vec<CertificateDer<'static>>> {
488 match source {
489 CertificateSource::File { path } => {
490 let data = fs::read(path).map_err(|e| TlsError::CertificateReadError {
491 path: path.clone(),
492 source: e,
493 })?;
494 parse_pem_certificates(&data)
495 }
496 CertificateSource::Pem { content } => parse_pem_certificates(content.as_bytes()),
497 CertificateSource::Der { content } => {
498 let der =
499 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
500 .map_err(|e| TlsError::InvalidCertificate(format!("Invalid base64: {}", e)))?;
501 Ok(vec![CertificateDer::from(der)])
502 }
503 CertificateSource::SelfSigned { common_name } => {
504 let (cert, _key) = generate_self_signed(common_name)?;
505 Ok(vec![cert])
506 }
507 }
508}
509
510fn parse_pem_certificates(data: &[u8]) -> TlsResult<Vec<CertificateDer<'static>>> {
512 let mut reader = BufReader::new(Cursor::new(data));
513 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
514 .collect::<Result<Vec<_>, _>>()
515 .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse PEM: {}", e)))?;
516
517 if certs.is_empty() {
518 return Err(TlsError::InvalidCertificate(
519 "No certificates found in PEM data".to_string(),
520 ));
521 }
522
523 Ok(certs)
524}
525
526pub fn load_private_key(source: &PrivateKeySource) -> TlsResult<PrivateKeyDer<'static>> {
528 match source {
529 PrivateKeySource::File { path } => {
530 let data = fs::read(path).map_err(|e| TlsError::KeyReadError {
531 path: path.clone(),
532 source: e,
533 })?;
534 parse_pem_private_key(&data)
535 }
536 PrivateKeySource::Pem { content } => parse_pem_private_key(content.as_bytes()),
537 PrivateKeySource::Der { content } => {
538 let der = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
539 .map_err(|e| TlsError::InvalidPrivateKey(format!("Invalid base64: {}", e)))?;
540 Ok(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der)))
541 }
542 }
543}
544
545fn parse_pem_private_key(data: &[u8]) -> TlsResult<PrivateKeyDer<'static>> {
547 let mut reader = BufReader::new(Cursor::new(data));
548
549 rustls_pemfile::private_key(&mut reader)
550 .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse PEM: {}", e)))?
551 .ok_or_else(|| TlsError::InvalidPrivateKey("No private key found in PEM data".to_string()))
552}
553
554pub fn generate_self_signed(
556 common_name: &str,
557) -> TlsResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
558 let subject_alt_names = vec![
559 common_name.to_string(),
560 "localhost".to_string(),
561 "127.0.0.1".to_string(),
562 ];
563
564 let mut cert_params = rcgen::CertificateParams::new(subject_alt_names)
565 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
566
567 cert_params.distinguished_name = rcgen::DistinguishedName::new();
569 cert_params.distinguished_name.push(
570 rcgen::DnType::CommonName,
571 rcgen::DnValue::Utf8String(common_name.to_string()),
572 );
573 cert_params.distinguished_name.push(
574 rcgen::DnType::OrganizationName,
575 rcgen::DnValue::Utf8String("Rivven".to_string()),
576 );
577
578 let key_pair = rcgen::KeyPair::generate()
579 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
580
581 let cert = cert_params
582 .self_signed(&key_pair)
583 .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
584
585 let cert_der = CertificateDer::from(cert.der().to_vec());
586 let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
587
588 Ok((cert_der, key_der))
589}
590
591pub struct TlsAcceptor {
597 config: Arc<ServerConfig>,
598 inner: tokio_rustls::TlsAcceptor,
599 tls_config: TlsConfig,
601 reloadable_config: Option<Arc<RwLock<Arc<ServerConfig>>>>,
603}
604
605impl TlsAcceptor {
606 pub fn new(config: &TlsConfig) -> TlsResult<Self> {
608 let server_config = build_server_config(config)?;
609 let server_config = Arc::new(server_config);
610
611 Ok(Self {
612 inner: tokio_rustls::TlsAcceptor::from(server_config.clone()),
613 config: server_config.clone(),
614 tls_config: config.clone(),
615 reloadable_config: if config.cert_reload_interval > Duration::ZERO {
616 Some(Arc::new(RwLock::new(server_config)))
617 } else {
618 None
619 },
620 })
621 }
622
623 pub async fn accept<IO>(&self, stream: IO) -> TlsResult<TlsServerStream<IO>>
628 where
629 IO: AsyncRead + AsyncWrite + Unpin,
630 {
631 let acceptor = if let Some(ref reloadable) = self.reloadable_config {
632 let config = reloadable.read().clone();
633 tokio_rustls::TlsAcceptor::from(config)
634 } else {
635 self.inner.clone()
636 };
637
638 let tls_stream = acceptor
639 .accept(stream)
640 .await
641 .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
642
643 Ok(TlsServerStream { inner: tls_stream })
644 }
645
646 pub async fn accept_tcp(&self, stream: TcpStream) -> TlsResult<TlsServerStream<TcpStream>> {
648 self.accept(stream).await
649 }
650
651 pub fn reload(&mut self) -> TlsResult<()> {
653 let new_config = build_server_config(&self.tls_config)?;
654 let new_config = Arc::new(new_config);
655
656 self.inner = tokio_rustls::TlsAcceptor::from(new_config.clone());
658 self.config = new_config.clone();
659
660 if let Some(ref reloadable) = self.reloadable_config {
661 *reloadable.write() = new_config;
662 }
663 Ok(())
664 }
665
666 pub fn config(&self) -> &Arc<ServerConfig> {
668 &self.config
669 }
670}
671
672impl fmt::Debug for TlsAcceptor {
673 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
674 f.debug_struct("TlsAcceptor")
675 .field("mtls_mode", &self.tls_config.mtls_mode)
676 .field("min_version", &self.tls_config.min_version)
677 .finish()
678 }
679}
680
681fn build_server_config(config: &TlsConfig) -> TlsResult<ServerConfig> {
683 let (certs, key) =
685 if let Some(CertificateSource::SelfSigned { common_name }) = &config.certificate {
686 let (cert, key) = generate_self_signed(common_name)?;
688 (vec![cert], key)
689 } else {
690 let certs = if let Some(ref cert_source) = config.certificate {
692 load_certificates(cert_source)?
693 } else {
694 return Err(TlsError::ConfigError(
695 "Server certificate required".to_string(),
696 ));
697 };
698
699 let key = if let Some(ref key_source) = config.private_key {
701 load_private_key(key_source)?
702 } else {
703 return Err(TlsError::ConfigError("Private key required".to_string()));
704 };
705
706 (certs, key)
707 };
708
709 let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
711 TlsVersion::Tls13 => vec![&rustls::version::TLS13],
712 TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
713 };
714
715 let client_cert_verifier = match config.mtls_mode {
717 MtlsMode::Disabled => None,
718 MtlsMode::Optional | MtlsMode::Required => {
719 if let Some(ref ca_source) = config.client_ca {
720 let ca_certs = load_certificates(ca_source)?;
721 let mut root_store = rustls::RootCertStore::empty();
722 for cert in ca_certs {
723 root_store.add(cert).map_err(|e| {
724 TlsError::CertificateChainError(format!("Failed to add CA cert: {}", e))
725 })?;
726 }
727
728 let verifier = if config.mtls_mode == MtlsMode::Required {
729 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
730 .build()
731 .map_err(|e| {
732 TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
733 })?
734 } else {
735 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
736 .allow_unauthenticated()
737 .build()
738 .map_err(|e| {
739 TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
740 })?
741 };
742
743 Some(verifier)
744 } else if config.mtls_mode == MtlsMode::Required {
745 return Err(TlsError::ConfigError(
746 "mTLS required but no client CA configured".to_string(),
747 ));
748 } else {
749 None
750 }
751 }
752 };
753
754 let mut server_config = if let Some(verifier) = client_cert_verifier {
756 ServerConfig::builder_with_protocol_versions(&versions)
757 .with_client_cert_verifier(verifier)
758 .with_single_cert(certs, key)
759 .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
760 } else {
761 ServerConfig::builder_with_protocol_versions(&versions)
762 .with_no_client_auth()
763 .with_single_cert(certs, key)
764 .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
765 };
766
767 if !config.alpn_protocols.is_empty() {
769 server_config.alpn_protocols = config
770 .alpn_protocols
771 .iter()
772 .map(|p| p.as_bytes().to_vec())
773 .collect();
774 }
775
776 if config.session_cache_size > 0 {
778 server_config.session_storage =
779 rustls::server::ServerSessionMemoryCache::new(config.session_cache_size);
780 }
781
782 Ok(server_config)
783}
784
785pub struct TlsConnector {
791 config: Arc<ClientConfig>,
792 inner: tokio_rustls::TlsConnector,
793 server_name: Option<String>,
795}
796
797impl TlsConnector {
798 pub fn new(config: &TlsConfig) -> TlsResult<Self> {
800 let client_config = build_client_config(config)?;
801 let client_config = Arc::new(client_config);
802
803 Ok(Self {
804 inner: tokio_rustls::TlsConnector::from(client_config.clone()),
805 config: client_config,
806 server_name: config.server_name.clone(),
807 })
808 }
809
810 pub async fn connect<IO>(&self, stream: IO, server_name: &str) -> TlsResult<TlsClientStream<IO>>
812 where
813 IO: AsyncRead + AsyncWrite + Unpin,
814 {
815 let name: rustls::pki_types::ServerName<'static> = server_name
816 .to_string()
817 .try_into()
818 .map_err(|_| TlsError::ConfigError(format!("Invalid server name: {}", server_name)))?;
819
820 let tls_stream = self
821 .inner
822 .connect(name, stream)
823 .await
824 .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
825
826 Ok(TlsClientStream { inner: tls_stream })
827 }
828
829 pub async fn connect_tcp(
831 &self,
832 addr: SocketAddr,
833 server_name: &str,
834 ) -> TlsResult<TlsClientStream<TcpStream>> {
835 let stream = TcpStream::connect(addr)
836 .await
837 .map_err(|e| TlsError::ConnectionError(e.to_string()))?;
838
839 self.connect(stream, server_name).await
840 }
841
842 pub async fn connect_with_default_name<IO>(&self, stream: IO) -> TlsResult<TlsClientStream<IO>>
844 where
845 IO: AsyncRead + AsyncWrite + Unpin,
846 {
847 let name = self.server_name.as_ref().ok_or_else(|| {
848 TlsError::ConfigError("No server name configured for SNI".to_string())
849 })?;
850 self.connect(stream, name).await
851 }
852
853 pub fn config(&self) -> &Arc<ClientConfig> {
855 &self.config
856 }
857}
858
859impl fmt::Debug for TlsConnector {
860 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
861 f.debug_struct("TlsConnector")
862 .field("server_name", &self.server_name)
863 .finish()
864 }
865}
866
867fn build_client_config(config: &TlsConfig) -> TlsResult<ClientConfig> {
869 let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
871 TlsVersion::Tls13 => vec![&rustls::version::TLS13],
872 TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
873 };
874
875 let root_store = if config.insecure_skip_verify {
877 rustls::RootCertStore::empty()
879 } else if let Some(ref ca_source) = config.root_ca {
880 let ca_certs = load_certificates(ca_source)?;
881 let mut store = rustls::RootCertStore::empty();
882 for cert in ca_certs {
883 store.add(cert).map_err(|e| {
884 TlsError::CertificateChainError(format!("Failed to add root CA: {}", e))
885 })?;
886 }
887 store
888 } else {
889 let mut store = rustls::RootCertStore::empty();
891 let native_certs = rustls_native_certs::load_native_certs();
892 for cert in native_certs.certs {
893 let _ = store.add(cert);
894 }
895 store
896 };
897
898 let mut client_config = if let (Some(ref cert_source), Some(ref key_source)) =
900 (&config.certificate, &config.private_key)
901 {
902 let certs = load_certificates(cert_source)?;
904 let key = load_private_key(key_source)?;
905
906 ClientConfig::builder_with_protocol_versions(&versions)
907 .with_root_certificates(root_store)
908 .with_client_auth_cert(certs, key)
909 .map_err(|e| TlsError::ConfigError(format!("Invalid client cert/key: {}", e)))?
910 } else if config.insecure_skip_verify {
911 ClientConfig::builder_with_protocol_versions(&versions)
913 .dangerous()
914 .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
915 .with_no_client_auth()
916 } else {
917 ClientConfig::builder_with_protocol_versions(&versions)
919 .with_root_certificates(root_store)
920 .with_no_client_auth()
921 };
922
923 if !config.alpn_protocols.is_empty() {
925 client_config.alpn_protocols = config
926 .alpn_protocols
927 .iter()
928 .map(|p| p.as_bytes().to_vec())
929 .collect();
930 }
931
932 Ok(client_config)
933}
934
935#[derive(Debug)]
938struct NoCertificateVerification;
939
940impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
941 fn verify_server_cert(
942 &self,
943 _end_entity: &CertificateDer<'_>,
944 _intermediates: &[CertificateDer<'_>],
945 _server_name: &rustls::pki_types::ServerName<'_>,
946 _ocsp_response: &[u8],
947 _now: rustls::pki_types::UnixTime,
948 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
949 Ok(rustls::client::danger::ServerCertVerified::assertion())
950 }
951
952 fn verify_tls12_signature(
953 &self,
954 _message: &[u8],
955 _cert: &CertificateDer<'_>,
956 _dss: &rustls::DigitallySignedStruct,
957 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
958 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
959 }
960
961 fn verify_tls13_signature(
962 &self,
963 _message: &[u8],
964 _cert: &CertificateDer<'_>,
965 _dss: &rustls::DigitallySignedStruct,
966 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
967 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
968 }
969
970 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
971 vec![
972 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
973 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
974 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
975 rustls::SignatureScheme::RSA_PSS_SHA256,
976 rustls::SignatureScheme::RSA_PSS_SHA384,
977 rustls::SignatureScheme::RSA_PSS_SHA512,
978 rustls::SignatureScheme::RSA_PKCS1_SHA256,
979 rustls::SignatureScheme::RSA_PKCS1_SHA384,
980 rustls::SignatureScheme::RSA_PKCS1_SHA512,
981 rustls::SignatureScheme::ED25519,
982 ]
983 }
984}
985
986pub struct TlsServerStream<IO> {
992 inner: tokio_rustls::server::TlsStream<IO>,
993}
994
995impl<IO> TlsServerStream<IO>
996where
997 IO: AsyncRead + AsyncWrite + Unpin,
998{
999 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1001 self.inner.get_ref().1.peer_certificates()
1002 }
1003
1004 pub fn alpn_protocol(&self) -> Option<&[u8]> {
1006 self.inner.get_ref().1.alpn_protocol()
1007 }
1008
1009 pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1011 self.inner.get_ref().1.protocol_version()
1012 }
1013
1014 pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
1016 self.inner.get_ref().1.negotiated_cipher_suite()
1017 }
1018
1019 pub fn cipher_suite_name(&self) -> Option<String> {
1021 self.negotiated_cipher_suite()
1022 .map(|cs| format!("{:?}", cs.suite()))
1023 }
1024
1025 pub fn is_tls_13(&self) -> bool {
1027 self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1028 }
1029
1030 pub fn peer_common_name(&self) -> Option<String> {
1032 self.peer_certificates().and_then(|certs| {
1033 if certs.is_empty() {
1034 return None;
1035 }
1036 extract_common_name(&certs[0])
1037 })
1038 }
1039
1040 pub fn peer_subject(&self) -> Option<String> {
1042 self.peer_certificates().and_then(|certs| {
1043 if certs.is_empty() {
1044 return None;
1045 }
1046 extract_subject(&certs[0])
1047 })
1048 }
1049
1050 pub fn get_ref(&self) -> &IO {
1052 self.inner.get_ref().0
1053 }
1054
1055 pub fn get_mut(&mut self) -> &mut IO {
1057 self.inner.get_mut().0
1058 }
1059
1060 pub fn into_inner(self) -> IO {
1062 self.inner.into_inner().0
1063 }
1064}
1065
1066impl<IO> tokio::io::AsyncRead for TlsServerStream<IO>
1067where
1068 IO: AsyncRead + AsyncWrite + Unpin,
1069{
1070 fn poll_read(
1071 mut self: std::pin::Pin<&mut Self>,
1072 cx: &mut std::task::Context<'_>,
1073 buf: &mut ReadBuf<'_>,
1074 ) -> std::task::Poll<io::Result<()>> {
1075 std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1076 }
1077}
1078
1079impl<IO> tokio::io::AsyncWrite for TlsServerStream<IO>
1080where
1081 IO: AsyncRead + AsyncWrite + Unpin,
1082{
1083 fn poll_write(
1084 mut self: std::pin::Pin<&mut Self>,
1085 cx: &mut std::task::Context<'_>,
1086 buf: &[u8],
1087 ) -> std::task::Poll<io::Result<usize>> {
1088 std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1089 }
1090
1091 fn poll_flush(
1092 mut self: std::pin::Pin<&mut Self>,
1093 cx: &mut std::task::Context<'_>,
1094 ) -> std::task::Poll<io::Result<()>> {
1095 std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1096 }
1097
1098 fn poll_shutdown(
1099 mut self: std::pin::Pin<&mut Self>,
1100 cx: &mut std::task::Context<'_>,
1101 ) -> std::task::Poll<io::Result<()>> {
1102 std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1103 }
1104}
1105
1106pub struct TlsClientStream<IO> {
1108 inner: tokio_rustls::client::TlsStream<IO>,
1109}
1110
1111impl<IO> TlsClientStream<IO>
1112where
1113 IO: AsyncRead + AsyncWrite + Unpin,
1114{
1115 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1117 self.inner.get_ref().1.peer_certificates()
1118 }
1119
1120 pub fn alpn_protocol(&self) -> Option<&[u8]> {
1122 self.inner.get_ref().1.alpn_protocol()
1123 }
1124
1125 pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1127 self.inner.get_ref().1.protocol_version()
1128 }
1129
1130 pub fn is_tls_13(&self) -> bool {
1132 self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1133 }
1134
1135 pub fn get_ref(&self) -> &IO {
1137 self.inner.get_ref().0
1138 }
1139
1140 pub fn get_mut(&mut self) -> &mut IO {
1142 self.inner.get_mut().0
1143 }
1144
1145 pub fn into_inner(self) -> IO {
1147 self.inner.into_inner().0
1148 }
1149}
1150
1151impl<IO> tokio::io::AsyncRead for TlsClientStream<IO>
1152where
1153 IO: AsyncRead + AsyncWrite + Unpin,
1154{
1155 fn poll_read(
1156 mut self: std::pin::Pin<&mut Self>,
1157 cx: &mut std::task::Context<'_>,
1158 buf: &mut ReadBuf<'_>,
1159 ) -> std::task::Poll<io::Result<()>> {
1160 std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1161 }
1162}
1163
1164impl<IO> tokio::io::AsyncWrite for TlsClientStream<IO>
1165where
1166 IO: AsyncRead + AsyncWrite + Unpin,
1167{
1168 fn poll_write(
1169 mut self: std::pin::Pin<&mut Self>,
1170 cx: &mut std::task::Context<'_>,
1171 buf: &[u8],
1172 ) -> std::task::Poll<io::Result<usize>> {
1173 std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1174 }
1175
1176 fn poll_flush(
1177 mut self: std::pin::Pin<&mut Self>,
1178 cx: &mut std::task::Context<'_>,
1179 ) -> std::task::Poll<io::Result<()>> {
1180 std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1181 }
1182
1183 fn poll_shutdown(
1184 mut self: std::pin::Pin<&mut Self>,
1185 cx: &mut std::task::Context<'_>,
1186 ) -> std::task::Poll<io::Result<()>> {
1187 std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1188 }
1189}
1190
1191fn extract_common_name(cert: &CertificateDer<'_>) -> Option<String> {
1197 let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1199
1200 for rdn in cert.subject().iter_rdn() {
1201 for attr in rdn.iter() {
1202 if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
1203 return attr.as_str().ok().map(|s| s.to_string());
1204 }
1205 }
1206 }
1207
1208 None
1209}
1210
1211fn extract_subject(cert: &CertificateDer<'_>) -> Option<String> {
1213 let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1214 Some(cert.subject().to_string())
1215}
1216
1217pub fn certificate_fingerprint(cert: &CertificateDer<'_>) -> String {
1219 use sha2::{Digest, Sha256};
1220 let hash = Sha256::digest(cert.as_ref());
1221 hex::encode(hash)
1222}
1223
1224pub fn verify_certificate_chain(
1229 chain: &[CertificateDer<'_>],
1230 root_store: &rustls::RootCertStore,
1231) -> TlsResult<()> {
1232 if chain.is_empty() {
1233 return Err(TlsError::CertificateChainError(
1234 "Empty certificate chain".to_string(),
1235 ));
1236 }
1237
1238 if root_store.is_empty() {
1241 tracing::warn!("Root certificate store is empty - chain validation may fail");
1242 }
1243
1244 for (i, cert) in chain.iter().enumerate() {
1246 let fingerprint = certificate_fingerprint(cert);
1247 tracing::debug!(
1248 "Certificate chain[{}]: fingerprint={}",
1249 i,
1250 &fingerprint[..16]
1251 );
1252 }
1253
1254 Ok(())
1255}
1256
1257pub struct CertificateWatcher {
1263 watched_files: Vec<PathBuf>,
1265 last_modified: HashMap<PathBuf, SystemTime>,
1267 reload_callback: Box<dyn Fn() + Send + Sync>,
1269}
1270
1271impl CertificateWatcher {
1272 pub fn new<F>(files: Vec<PathBuf>, callback: F) -> Self
1274 where
1275 F: Fn() + Send + Sync + 'static,
1276 {
1277 let mut last_modified = HashMap::new();
1278 for file in &files {
1279 if let Ok(meta) = fs::metadata(file) {
1280 if let Ok(modified) = meta.modified() {
1281 last_modified.insert(file.clone(), modified);
1282 }
1283 }
1284 }
1285
1286 Self {
1287 watched_files: files,
1288 last_modified,
1289 reload_callback: Box::new(callback),
1290 }
1291 }
1292
1293 pub fn check_and_reload(&mut self) -> bool {
1295 let mut changed = false;
1296
1297 for file in &self.watched_files {
1298 if let Ok(meta) = fs::metadata(file) {
1299 if let Ok(modified) = meta.modified() {
1300 let last = self.last_modified.get(file);
1301 if last.is_none_or(|&l| modified > l) {
1302 self.last_modified.insert(file.clone(), modified);
1303 changed = true;
1304 }
1305 }
1306 }
1307 }
1308
1309 if changed {
1310 (self.reload_callback)();
1311 }
1312
1313 changed
1314 }
1315
1316 pub fn spawn(mut self, interval: Duration) -> tokio::task::JoinHandle<()> {
1318 tokio::spawn(async move {
1319 let mut ticker = tokio::time::interval(interval);
1320 loop {
1321 ticker.tick().await;
1322 self.check_and_reload();
1323 }
1324 })
1325 }
1326}
1327
1328#[derive(Debug, Clone, Serialize, Deserialize)]
1334pub struct TlsIdentity {
1335 pub common_name: Option<String>,
1337 pub subject: Option<String>,
1339 pub fingerprint: String,
1341 pub organization: Option<String>,
1343 pub organizational_unit: Option<String>,
1345 pub serial_number: Option<String>,
1347 pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
1349 pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
1350 pub is_valid: bool,
1352}
1353
1354impl TlsIdentity {
1355 pub fn from_certificate(cert: &CertificateDer<'_>) -> Self {
1357 let fingerprint = certificate_fingerprint(cert);
1358 let common_name = extract_common_name(cert);
1359 let subject = extract_subject(cert);
1360
1361 let (organization, organizational_unit, serial_number, valid_from, valid_until, is_valid) =
1363 if let Ok((_, parsed)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
1364 let mut org = None;
1365 let mut ou = None;
1366
1367 for rdn in parsed.subject().iter_rdn() {
1368 for attr in rdn.iter() {
1369 if attr.attr_type()
1370 == &x509_parser::oid_registry::OID_X509_ORGANIZATION_NAME
1371 {
1372 org = attr.as_str().ok().map(|s| s.to_string());
1373 }
1374 if attr.attr_type()
1375 == &x509_parser::oid_registry::OID_X509_ORGANIZATIONAL_UNIT
1376 {
1377 ou = attr.as_str().ok().map(|s| s.to_string());
1378 }
1379 }
1380 }
1381
1382 let serial = Some(parsed.serial.to_str_radix(16));
1383
1384 let validity = parsed.validity();
1385 let now = chrono::Utc::now();
1386
1387 let from = chrono::DateTime::from_timestamp(validity.not_before.timestamp(), 0);
1388 let until = chrono::DateTime::from_timestamp(validity.not_after.timestamp(), 0);
1389
1390 let valid = from.is_some_and(|f| now >= f) && until.is_some_and(|u| now <= u);
1391
1392 (org, ou, serial, from, until, valid)
1393 } else {
1394 (None, None, None, None, None, false)
1395 };
1396
1397 Self {
1398 common_name,
1399 subject,
1400 fingerprint,
1401 organization,
1402 organizational_unit,
1403 serial_number,
1404 valid_from,
1405 valid_until,
1406 is_valid,
1407 }
1408 }
1409}
1410
/// Findings produced by [`TlsSecurityAudit::audit`] for a [`TlsConfig`],
/// grouped by severity as human-readable messages.
#[derive(Debug)]
pub struct TlsSecurityAudit {
    /// Weaker-than-ideal settings: TLS 1.2 allowed, a client CA configured while
    /// mTLS is disabled, or mTLS set to optional.
    pub warnings: Vec<String>,
    /// Settings that remove transport security: TLS disabled, or certificate
    /// verification skipped.
    pub errors: Vec<String>,
    /// Optional improvements: enabling the session cache, certificate
    /// hot-reloading, or certificate pinning.
    pub recommendations: Vec<String>,
}
1422
1423impl TlsSecurityAudit {
1424 pub fn audit(config: &TlsConfig) -> Self {
1426 let mut audit = Self {
1427 warnings: vec![],
1428 errors: vec![],
1429 recommendations: vec![],
1430 };
1431
1432 if !config.enabled {
1433 audit
1434 .errors
1435 .push("TLS is disabled - all traffic will be unencrypted".to_string());
1436 }
1437
1438 if config.insecure_skip_verify {
1439 audit.errors.push(
1440 "Certificate verification is disabled - vulnerable to MITM attacks".to_string(),
1441 );
1442 }
1443
1444 if config.min_version == TlsVersion::Tls12 {
1445 audit.warnings.push(
1446 "TLS 1.2 is allowed - consider requiring TLS 1.3 for better security".to_string(),
1447 );
1448 }
1449
1450 if config.mtls_mode == MtlsMode::Disabled && config.client_ca.is_some() {
1451 audit.warnings.push(
1452 "Client CA configured but mTLS is disabled - clients won't be verified".to_string(),
1453 );
1454 }
1455
1456 if config.mtls_mode == MtlsMode::Optional {
1457 audit.warnings.push(
1458 "mTLS is optional - some clients may connect without certificates".to_string(),
1459 );
1460 }
1461
1462 if config.session_cache_size == 0 {
1463 audit
1464 .recommendations
1465 .push("Consider enabling session cache for better performance".to_string());
1466 }
1467
1468 if config.cert_reload_interval == Duration::ZERO {
1469 audit.recommendations.push(
1470 "Consider enabling certificate hot-reloading for zero-downtime rotation"
1471 .to_string(),
1472 );
1473 }
1474
1475 if config.pinned_certificates.is_empty() && !config.insecure_skip_verify {
1476 audit
1477 .recommendations
1478 .push("Consider certificate pinning for high-security deployments".to_string());
1479 }
1480
1481 audit
1482 }
1483}
1484
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.mtls_mode, MtlsMode::Disabled);
        assert_eq!(config.min_version, TlsVersion::Tls13);
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfigBuilder::new()
            .with_cert_file("/path/to/cert.pem")
            .with_key_file("/path/to/key.pem")
            .with_client_ca_file("/path/to/ca.pem")
            .require_client_cert(true)
            .with_min_version(TlsVersion::Tls12)
            .build();

        assert!(config.enabled);
        assert_eq!(config.mtls_mode, MtlsMode::Required);
        assert_eq!(config.min_version, TlsVersion::Tls12);
    }

    /// One-way TLS: self-signed server, client skips verification, echo round-trip.
    #[tokio::test]
    async fn test_tls_server_client_handshake() {
        // The process-wide provider may already be installed by another test; ignore that.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::SelfSigned {
                common_name: "localhost".to_string(),
            }),
            mtls_mode: MtlsMode::Disabled,
            ..Default::default()
        };

        let client_config = TlsConfig {
            enabled: true,
            insecure_skip_verify: true,
            ..Default::default()
        };

        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server: accept one connection and echo back whatever the first read returns.
        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();

            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            n
        });

        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            let message = b"Hello, TLS!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let server_bytes_read = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);
    }

    /// Mutual TLS: server and client certificates signed by a throwaway test CA;
    /// the server must see the client's certificate.
    #[tokio::test]
    async fn test_mtls_server_client_handshake() {
        use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, KeyUsagePurpose};

        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        // Test CA.
        let mut ca_params = CertificateParams::default();
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
        ca_params
            .distinguished_name
            .push(DnType::CommonName, "Rivven Test CA");
        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();
        let ca_cert_pem = ca_cert.pem();

        // Server leaf certificate for "localhost".
        let mut server_params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        server_params
            .distinguished_name
            .push(DnType::CommonName, "localhost");
        let server_key_pair = rcgen::KeyPair::generate().unwrap();
        let server_cert = server_params
            .signed_by(&server_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let server_cert_pem = server_cert.pem();
        let server_key_pem = server_key_pair.serialize_pem();

        // Client leaf certificate.
        let mut client_params =
            CertificateParams::new(vec!["client.rivven.local".to_string()]).unwrap();
        client_params
            .distinguished_name
            .push(DnType::CommonName, "client.rivven.local");
        let client_key_pair = rcgen::KeyPair::generate().unwrap();
        let client_cert = client_params
            .signed_by(&client_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let client_cert_pem = client_cert.pem();
        let client_key_pem = client_key_pair.serialize_pem();

        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: server_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: server_key_pem,
            }),
            client_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem.clone(),
            }),
            mtls_mode: MtlsMode::Required,
            insecure_skip_verify: false,
            ..Default::default()
        };

        let client_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: client_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: client_key_pem,
            }),
            root_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem,
            }),
            insecure_skip_verify: false,
            ..Default::default()
        };

        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            let has_peer_cert = tls_stream.peer_certificates().is_some();

            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();
            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            (n, has_peer_cert)
        });

        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            let message = b"mTLS Test!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let (server_bytes_read, has_peer_cert) = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);

        assert!(
            has_peer_cert,
            "Server should have received client certificate in mTLS"
        );
    }

    #[test]
    fn test_self_signed_generation() {
        let (cert, _key) = generate_self_signed("test.rivven.local")
            .expect("self-signed certificate generation should succeed");
        assert!(!cert.as_ref().is_empty());

        let identity = TlsIdentity::from_certificate(&cert);
        assert_eq!(identity.common_name, Some("test.rivven.local".to_string()));
        assert!(identity.is_valid);
    }

    #[test]
    fn test_certificate_fingerprint() {
        let (cert, _) = generate_self_signed("test.rivven.local").unwrap();
        let fingerprint = certificate_fingerprint(&cert);

        assert_eq!(fingerprint.len(), 64);
        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_tls_security_audit_disabled() {
        let config = TlsConfig::disabled();
        let audit = TlsSecurityAudit::audit(&config);

        assert!(!audit.errors.is_empty());
        assert!(audit.errors.iter().any(|e| e.contains("disabled")));
    }

    #[test]
    fn test_tls_security_audit_insecure() {
        let config = TlsConfig {
            enabled: true,
            insecure_skip_verify: true,
            ..Default::default()
        };
        let audit = TlsSecurityAudit::audit(&config);

        assert!(audit.errors.iter().any(|e| e.contains("MITM")));
    }

    #[test]
    fn test_tls_security_audit_production_ready() {
        let config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::SelfSigned {
                common_name: "broker.rivven.local".to_string(),
            }),
            mtls_mode: MtlsMode::Required,
            min_version: TlsVersion::Tls13,
            insecure_skip_verify: false,
            session_cache_size: 256,
            ..Default::default()
        };

        let audit = TlsSecurityAudit::audit(&config);

        // TLS enabled with verification on => no errors; TLS 1.3 only with
        // required mTLS => no warnings. Recommendations are allowed.
        assert!(
            audit.errors.is_empty(),
            "unexpected audit errors: {:?}",
            audit.errors
        );
        assert!(
            audit.warnings.is_empty(),
            "unexpected audit warnings: {:?}",
            audit.warnings
        );
    }

    #[test]
    fn test_mtls_modes() {
        assert_eq!(MtlsMode::default(), MtlsMode::Disabled);

        // Every mode must survive a JSON round-trip.
        let modes = [MtlsMode::Disabled, MtlsMode::Optional, MtlsMode::Required];
        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let parsed: MtlsMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, parsed);
        }
    }

    #[test]
    fn test_tls_identity_extraction() {
        let (cert, _) = generate_self_signed("service.rivven.internal").unwrap();
        let identity = TlsIdentity::from_certificate(&cert);

        assert_eq!(
            identity.common_name,
            Some("service.rivven.internal".to_string())
        );
        assert!(identity.is_valid);
        assert!(identity.valid_from.is_some());
        assert!(identity.valid_until.is_some());
        assert!(!identity.fingerprint.is_empty());
    }
}