1use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
26use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
27use rustls_pemfile::{certs, pkcs8_private_keys};
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::client::TlsStream as ClientTlsStream;
30use tokio_rustls::server::TlsStream as ServerTlsStream;
31use tokio_rustls::{TlsAcceptor, TlsConnector};
32use tokio_util::codec::Framed;
33use tracing::{debug, error, info, instrument, warn};
34
35use crate::core::codec::PacketCodec;
36use crate::core::packet::Packet;
37use crate::error::{ProtocolError, Result};
38use futures::{SinkExt, StreamExt};
39
40#[derive(Debug)]
42struct CertificateFingerprint {
43 fingerprint: Vec<u8>,
44}
45
46impl ServerCertVerifier for CertificateFingerprint {
47 fn verify_server_cert(
48 &self,
49 end_entity: &CertificateDer<'_>,
50 _intermediates: &[CertificateDer<'_>],
51 _server_name: &ServerName,
52 _ocsp_response: &[u8],
53 _now: UnixTime,
54 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
55 use sha2::{Digest, Sha256};
56
57 let mut hasher = Sha256::new();
58 hasher.update(end_entity);
59 let hash = hasher.finalize();
60
61 if hash.as_slice() == self.fingerprint.as_slice() {
62 Ok(ServerCertVerified::assertion())
63 } else {
64 Err(rustls::Error::General(
65 "Pinned certificate hash mismatch".into(),
66 ))
67 }
68 }
69
70 fn verify_tls12_signature(
71 &self,
72 _message: &[u8],
73 _cert: &CertificateDer<'_>,
74 _dss: &DigitallySignedStruct,
75 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
76 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
77 }
78
79 fn verify_tls13_signature(
80 &self,
81 _message: &[u8],
82 _cert: &CertificateDer<'_>,
83 _dss: &DigitallySignedStruct,
84 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
85 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
86 }
87
88 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
89 vec![
91 rustls::SignatureScheme::RSA_PKCS1_SHA256,
92 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
93 rustls::SignatureScheme::ED25519,
94 ]
95 }
96}
97
98#[derive(Debug)]
99struct AcceptAnyServerCert;
100
101impl ServerCertVerifier for AcceptAnyServerCert {
102 fn verify_server_cert(
103 &self,
104 _end_entity: &CertificateDer<'_>,
105 _intermediates: &[CertificateDer<'_>],
106 _server_name: &ServerName,
107 _ocsp_response: &[u8],
108 _now: UnixTime,
109 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
110 Ok(ServerCertVerified::assertion())
111 }
112
113 fn verify_tls12_signature(
114 &self,
115 _message: &[u8],
116 _cert: &CertificateDer<'_>,
117 _dss: &DigitallySignedStruct,
118 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
119 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
120 }
121
122 fn verify_tls13_signature(
123 &self,
124 _message: &[u8],
125 _cert: &CertificateDer<'_>,
126 _dss: &DigitallySignedStruct,
127 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
128 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
129 }
130
131 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
132 vec![
133 rustls::SignatureScheme::RSA_PKCS1_SHA256,
134 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
135 rustls::SignatureScheme::ED25519,
136 ]
137 }
138}
139
140fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKeyDer<'static>> {
142 reader
145 .seek(std::io::SeekFrom::Start(0))
146 .map_err(ProtocolError::Io)?;
147
148 let keys: std::result::Result<Vec<_>, _> = pkcs8_private_keys(reader).collect();
150 let keys =
151 keys.map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
152
153 if !keys.is_empty() {
154 return Ok(PrivateKeyDer::Pkcs8(keys[0].clone_key()));
155 }
156
157 Err(ProtocolError::TlsError(
160 "No supported private key format found".into(),
161 ))
162}
163
/// TLS protocol versions a caller may request for a client or server
/// configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    /// TLS 1.2.
    TLS12,
    /// TLS 1.3.
    TLS13,
    /// Both TLS 1.2 and TLS 1.3.
    All,
}
173
174pub struct TlsServerConfig {
176 cert_path: String,
177 key_path: String,
178 client_ca_path: Option<String>,
180 require_client_auth: bool,
182 tls_versions: Option<Vec<TlsVersion>>,
184 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
186 alpn_protocols: Option<Vec<Vec<u8>>>,
188}
189
190impl TlsServerConfig {
191 pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
193 Self {
194 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
195 key_path: key_path.as_ref().to_string_lossy().to_string(),
196 client_ca_path: None,
197 require_client_auth: false,
198 tls_versions: None,
199 cipher_suites: None,
200 alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
201 }
202 }
203
204 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
206 self.tls_versions = Some(versions);
207 self
208 }
209
210 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
212 self.cipher_suites = Some(cipher_suites);
213 self
214 }
215
216 pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
218 self.client_ca_path = Some(client_ca_path.into());
219 self.require_client_auth = true;
220 self
221 }
222
223 pub fn require_client_auth(mut self, required: bool) -> Self {
225 self.require_client_auth = required;
226 self
227 }
228
229 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
231 self.alpn_protocols = Some(protocols);
232 self
233 }
234
235 pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
237 let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
238 .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
239
240 let mut cert_file = File::create(&cert_path)?;
242 let pem = cert.cert.pem();
243 cert_file.write_all(pem.as_bytes())?;
244
245 let mut key_file = File::create(&key_path)?;
247 key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
248
249 Ok(Self {
250 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
251 key_path: key_path.as_ref().to_string_lossy().to_string(),
252 client_ca_path: None,
253 require_client_auth: false,
254 tls_versions: None,
255 cipher_suites: None,
256 alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
257 })
258 }
259
260 pub fn load_server_config(&self) -> Result<ServerConfig> {
262 let cert_file = File::open(&self.cert_path)
264 .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
265 let mut cert_reader = BufReader::new(cert_file);
266 let cert_chain: std::result::Result<Vec<_>, _> = certs(&mut cert_reader).collect();
267 let cert_chain: Vec<CertificateDer<'static>> = cert_chain
268 .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
269
270 if cert_chain.is_empty() {
271 return Err(ProtocolError::TlsError("No certificates found".into()));
272 }
273
274 let key_file = File::open(&self.key_path)
276 .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
277 let mut key_reader = BufReader::new(key_file);
278 let private_key = load_private_key(&mut key_reader)?;
279
280 if let Some(versions) = &self.tls_versions {
283 let mut has_tls13 = false;
284 let mut has_tls12 = false;
285 for v in versions {
286 match v {
287 TlsVersion::TLS12 => has_tls12 = true,
288 TlsVersion::TLS13 => has_tls13 = true,
289 TlsVersion::All => {
290 has_tls13 = true;
291 has_tls12 = true;
292 }
293 }
294 }
295 debug!(
297 "TLS versions requested: TLS1.2={}, TLS1.3={}",
298 has_tls12, has_tls13
299 );
300 }
301
302 let config_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
304 rustls::crypto::ring::default_provider(),
305 ))
306 .with_safe_default_protocol_versions()
307 .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?;
308
309 let cert_builder = config_builder.with_no_client_auth();
310
311 let mut config = cert_builder
313 .with_single_cert(cert_chain.clone(), private_key.clone_key())
314 .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
315
316 if let Some(client_ca_path) = &self.client_ca_path {
318 let client_ca_file = File::open(client_ca_path).map_err(|e| {
320 ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
321 })?;
322 let mut client_ca_reader = BufReader::new(client_ca_file);
323 let client_ca_certs: std::result::Result<Vec<_>, _> =
324 certs(&mut client_ca_reader).collect();
325 let client_ca_certs: Vec<CertificateDer<'static>> = client_ca_certs.map_err(|_| {
326 ProtocolError::TlsError("Failed to parse client CA certificate".into())
327 })?;
328
329 if client_ca_certs.is_empty() {
330 return Err(ProtocolError::TlsError(
331 "No client CA certificates found".into(),
332 ));
333 }
334
335 let mut client_root_store = RootCertStore::empty();
337 for cert in client_ca_certs {
338 client_root_store.add(cert).map_err(|e| {
339 ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
340 })?;
341 }
342
343 let client_auth = rustls::server::WebPkiClientVerifier::builder(std::sync::Arc::new(
345 client_root_store,
346 ))
347 .build()
348 .map_err(|e| {
349 ProtocolError::TlsError(format!("Failed to build client verifier: {e}"))
350 })?;
351
352 let new_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
354 rustls::crypto::ring::default_provider(),
355 ))
356 .with_safe_default_protocol_versions()
357 .map_err(|_| {
358 ProtocolError::TlsError("Failed to configure TLS protocol versions".into())
359 })?;
360 let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
361
362 config = new_cert_builder
364 .with_single_cert(cert_chain, private_key.clone_key())
365 .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
366
367 debug!("mTLS enabled with client certificate verification required");
368 }
369
370 if let Some(protocols) = &self.alpn_protocols {
372 config.alpn_protocols = protocols.clone();
373 debug!(
374 protocol_count = protocols.len(),
375 "ALPN protocols configured"
376 );
377 }
378
379 Ok(config)
380 }
381
382 pub fn calculate_cert_hash(cert: &CertificateDer<'_>) -> Vec<u8> {
384 use sha2::{Digest, Sha256};
385 let mut hasher = Sha256::new();
386 hasher.update(cert.as_ref());
387 hasher.finalize().to_vec()
388 }
389}
390
391pub struct TlsClientConfig {
393 server_name: String,
394 insecure: bool,
395 pinned_cert_hash: Option<Vec<u8>>,
397 client_cert_path: Option<String>,
399 client_key_path: Option<String>,
401 tls_versions: Option<Vec<TlsVersion>>,
403 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
405}
406
407impl TlsClientConfig {
408 pub fn new<S: Into<String>>(server_name: S) -> Self {
410 Self {
411 server_name: server_name.into(),
412 insecure: false,
413 pinned_cert_hash: None,
414 client_cert_path: None,
415 client_key_path: None,
416 tls_versions: None,
417 cipher_suites: None,
418 }
419 }
420
421 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
423 self.tls_versions = Some(versions);
424 self
425 }
426
427 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
429 self.cipher_suites = Some(cipher_suites);
430 self
431 }
432
433 pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
435 self.client_cert_path = Some(cert_path.into());
436 self.client_key_path = Some(key_path.into());
437 self
438 }
439
440 pub fn insecure(mut self) -> Self {
453 warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
454 self.insecure = true;
455 self
456 }
457
458 pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
465 if hash.len() != 32 {
466 warn!(
467 "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
468 hash.len()
469 );
470 }
471 self.pinned_cert_hash = Some(hash);
472 self
473 }
474
475 pub fn load_client_config(&self) -> Result<ClientConfig> {
477 self.log_tls_version_info();
478
479 if self.insecure {
480 self.build_insecure_client_config()
481 } else {
482 self.build_secure_client_config()
483 }
484 }
485
486 fn log_tls_version_info(&self) {
488 if let Some(versions) = &self.tls_versions {
489 let mut has_tls13 = false;
490 let mut has_tls12 = false;
491 for v in versions {
492 match v {
493 TlsVersion::TLS12 => has_tls12 = true,
494 TlsVersion::TLS13 => has_tls13 = true,
495 TlsVersion::All => {
496 has_tls13 = true;
497 has_tls12 = true;
498 }
499 }
500 }
501 debug!(
502 "TLS client versions requested: TLS1.2={}, TLS1.3={}",
503 has_tls12, has_tls13
504 );
505 }
506 }
507
508 fn build_secure_client_config(&self) -> Result<ClientConfig> {
510 let root_store = self.load_system_root_certificates()?;
511 let builder = ClientConfig::builder_with_provider(std::sync::Arc::new(
512 rustls::crypto::ring::default_provider(),
513 ))
514 .with_safe_default_protocol_versions()
515 .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?
516 .with_root_certificates(root_store);
517
518 if let (Some(client_cert_path), Some(client_key_path)) =
520 (&self.client_cert_path, &self.client_key_path)
521 {
522 let (cert_chain, key) =
523 self.load_client_credentials(client_cert_path, client_key_path)?;
524 builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
525 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
526 })
527 } else {
528 Ok(builder.with_no_client_auth())
529 }
530 }
531
532 fn build_insecure_client_config(&self) -> Result<ClientConfig> {
534 let builder = ClientConfig::builder_with_provider(std::sync::Arc::new(
535 rustls::crypto::ring::default_provider(),
536 ))
537 .with_safe_default_protocol_versions()
538 .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?;
539 let verifier = self.create_custom_verifier();
540 let custom_builder = builder
541 .dangerous()
542 .with_custom_certificate_verifier(verifier);
543
544 if let (Some(client_cert_path), Some(client_key_path)) =
546 (&self.client_cert_path, &self.client_key_path)
547 {
548 let (cert_chain, key) =
549 self.load_client_credentials(client_cert_path, client_key_path)?;
550 custom_builder
551 .with_client_auth_cert(cert_chain, key)
552 .map_err(|e| {
553 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
554 })
555 } else {
556 Ok(custom_builder.with_no_client_auth())
557 }
558 }
559
560 fn load_system_root_certificates(&self) -> Result<RootCertStore> {
562 let mut root_store = RootCertStore::empty();
563 let native_certs = rustls_native_certs::load_native_certs()
564 .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
565
566 for cert in native_certs {
567 root_store.add(cert).map_err(|e| {
568 ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
569 })?;
570 }
571
572 Ok(root_store)
573 }
574
575 fn create_custom_verifier(&self) -> Arc<dyn ServerCertVerifier> {
577 if let Some(hash) = &self.pinned_cert_hash {
578 Arc::new(CertificateFingerprint {
579 fingerprint: hash.clone(),
580 })
581 } else {
582 Arc::new(AcceptAnyServerCert)
583 }
584 }
585
586 fn load_client_credentials(
588 &self,
589 cert_path: &str,
590 key_path: &str,
591 ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
592 let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
594 let mut cert_reader = BufReader::new(cert_file);
595 let certs_result: std::result::Result<Vec<_>, _> =
596 rustls_pemfile::certs(&mut cert_reader).collect();
597 let certs: Vec<CertificateDer<'static>> = certs_result
598 .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
599
600 if certs.is_empty() {
601 return Err(ProtocolError::TlsError(
602 "No client certificates found".into(),
603 ));
604 }
605
606 let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
608 let mut key_reader = BufReader::new(key_file);
609 let key = load_private_key(&mut key_reader)?;
610
611 Ok((certs, key))
612 }
613
614 pub fn server_name(&self) -> Result<ServerName<'_>> {
616 ServerName::try_from(self.server_name.as_str())
617 .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
618 }
619
620 pub fn server_name_string(&self) -> String {
622 self.server_name.clone()
623 }
624}
625
626#[instrument(skip(config))]
628pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
629 let tls_config = config.load_server_config()?;
630 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
631 let listener = TcpListener::bind(addr).await?;
632
633 info!(address=%addr, "TLS server listening");
634
635 loop {
636 let (stream, peer) = listener.accept().await?;
637 let acceptor = acceptor.clone();
638
639 tokio::spawn(async move {
640 match acceptor.accept(stream).await {
641 Ok(tls_stream) => {
642 if let Err(e) = handle_tls_connection(tls_stream, peer).await {
643 error!(%peer, error=%e, "Connection error");
644 }
645 }
646 Err(e) => {
647 error!(%peer, error=%e, "TLS handshake failed");
648 }
649 }
650 });
651 }
652}
653
654#[instrument(skip(tls_stream), fields(peer=%peer))]
656async fn handle_tls_connection(
657 tls_stream: ServerTlsStream<TcpStream>,
658 peer: SocketAddr,
659) -> Result<()> {
660 let mut framed = Framed::new(tls_stream, PacketCodec);
661
662 info!("TLS connection established");
663
664 while let Some(packet) = framed.next().await {
665 match packet {
666 Ok(pkt) => {
667 debug!(bytes = pkt.payload.len(), "Received data");
668 on_packet(pkt, &mut framed).await?;
669 }
670 Err(e) => {
671 error!(error=%e, "Protocol error");
672 break;
673 }
674 }
675 }
676
677 info!("TLS connection closed");
678 Ok(())
679}
680
681#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
683async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
684where
685 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
686{
687 let response = Packet {
689 version: pkt.version,
690 payload: pkt.payload,
691 };
692
693 framed.send(response).await?;
694 Ok(())
695}
696
697pub async fn connect(
699 addr: &str,
700 config: TlsClientConfig,
701) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
702 let tls_config = Arc::new(config.load_client_config()?);
703 let connector = TlsConnector::from(tls_config);
704
705 let stream = TcpStream::connect(addr).await?;
706
707 let server_name_str = config.server_name_string();
710 let domain_static: &'static str = Box::leak(server_name_str.into_boxed_str());
711 let domain = ServerName::try_from(domain_static)
712 .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))?;
713
714 let tls_stream = connector
715 .connect(domain, stream)
716 .await
717 .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
718
719 let framed = Framed::new(tls_stream, PacketCodec);
720 Ok(framed)
721}