network_protocol/transport/
tls.rs1use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::ServerName;
25use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig};
26use rustls_pemfile::{certs, pkcs8_private_keys};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::client::TlsStream as ClientTlsStream;
29use tokio_rustls::server::TlsStream as ServerTlsStream;
30use tokio_rustls::{TlsAcceptor, TlsConnector};
31use tokio_util::codec::Framed;
32use tracing::{debug, error, info, instrument, warn};
33
34use crate::core::codec::PacketCodec;
35use crate::core::packet::Packet;
36use crate::error::{ProtocolError, Result};
37use futures::{SinkExt, StreamExt};
38
39struct CertificateFingerprint {
41 fingerprint: Vec<u8>,
42}
43
44impl rustls::client::ServerCertVerifier for CertificateFingerprint {
45 fn verify_server_cert(
46 &self,
47 end_entity: &Certificate,
48 _intermediates: &[Certificate],
49 _server_name: &ServerName,
50 _scts: &mut dyn Iterator<Item = &[u8]>,
51 _ocsp_response: &[u8],
52 _now: std::time::SystemTime,
53 ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
54 use sha2::{Digest, Sha256};
55
56 let mut hasher = Sha256::new();
57 hasher.update(&end_entity.0);
58 let hash = hasher.finalize();
59
60 if hash.as_slice() == self.fingerprint.as_slice() {
61 Ok(rustls::client::ServerCertVerified::assertion())
62 } else {
63 Err(rustls::Error::General(
64 "Pinned certificate hash mismatch".into(),
65 ))
66 }
67 }
68}
69
70struct AcceptAnyServerCert;
71
72impl rustls::client::ServerCertVerifier for AcceptAnyServerCert {
73 fn verify_server_cert(
74 &self,
75 _end_entity: &Certificate,
76 _intermediates: &[Certificate],
77 _server_name: &ServerName,
78 _scts: &mut dyn Iterator<Item = &[u8]>,
79 _ocsp_response: &[u8],
80 _now: std::time::SystemTime,
81 ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
82 Ok(rustls::client::ServerCertVerified::assertion())
83 }
84}
85
/// TLS protocol versions that may be requested for a client or server configuration.
///
/// Several variants can be combined in the `Vec` passed to `with_tls_versions`; the
/// resulting set is the union of what each variant enables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2.
    TLS12,
    /// TLS 1.3.
    TLS13,
    /// Both TLS 1.2 and TLS 1.3.
    All,
}
95
96pub struct TlsServerConfig {
98 cert_path: String,
99 key_path: String,
100 client_ca_path: Option<String>,
102 require_client_auth: bool,
104 tls_versions: Option<Vec<TlsVersion>>,
106 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
108 alpn_protocols: Option<Vec<Vec<u8>>>,
110}
111
112impl TlsServerConfig {
113 pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
115 Self {
116 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
117 key_path: key_path.as_ref().to_string_lossy().to_string(),
118 client_ca_path: None,
119 require_client_auth: false,
120 tls_versions: None,
121 cipher_suites: None,
122 alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
123 }
124 }
125
126 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
128 self.tls_versions = Some(versions);
129 self
130 }
131
132 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
134 self.cipher_suites = Some(cipher_suites);
135 self
136 }
137
138 pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
140 self.client_ca_path = Some(client_ca_path.into());
141 self.require_client_auth = true;
142 self
143 }
144
145 pub fn require_client_auth(mut self, required: bool) -> Self {
147 self.require_client_auth = required;
148 self
149 }
150
151 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
153 self.alpn_protocols = Some(protocols);
154 self
155 }
156
157 pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
159 let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
160 .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
161
162 let mut cert_file = File::create(&cert_path)?;
164 let pem = cert.cert.pem();
165 cert_file.write_all(pem.as_bytes())?;
166
167 let mut key_file = File::create(&key_path)?;
169 key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
170
171 Ok(Self {
172 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
173 key_path: key_path.as_ref().to_string_lossy().to_string(),
174 client_ca_path: None,
175 require_client_auth: false,
176 tls_versions: None,
177 cipher_suites: None,
178 alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
179 })
180 }
181
182 pub fn load_server_config(&self) -> Result<ServerConfig> {
184 let cert_file = File::open(&self.cert_path)
186 .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
187 let mut cert_reader = BufReader::new(cert_file);
188 let cert_chain = certs(&mut cert_reader)
189 .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
190
191 let cert_chain: Vec<Certificate> = cert_chain.into_iter().map(Certificate).collect();
193
194 let key_file = File::open(&self.key_path)
196 .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
197 let mut key_reader = BufReader::new(key_file);
198 let keys = pkcs8_private_keys(&mut key_reader)
199 .map_err(|_| ProtocolError::TlsError("Failed to parse private key".into()))?;
200
201 if keys.is_empty() {
202 return Err(ProtocolError::TlsError("No private keys found".into()));
203 }
204
205 let private_key = PrivateKey(keys[0].clone());
207
208 if let Some(versions) = &self.tls_versions {
211 let mut has_tls13 = false;
212 let mut has_tls12 = false;
213 for v in versions {
214 match v {
215 TlsVersion::TLS12 => has_tls12 = true,
216 TlsVersion::TLS13 => has_tls13 = true,
217 TlsVersion::All => {
218 has_tls13 = true;
219 has_tls12 = true;
220 }
221 }
222 }
223 debug!(
225 "TLS versions requested: TLS1.2={}, TLS1.3={}",
226 has_tls12, has_tls13
227 );
228 }
229
230 let config_builder = ServerConfig::builder().with_safe_defaults();
232
233 let cert_builder = config_builder.with_no_client_auth();
240
241 let mut config = cert_builder
243 .with_single_cert(cert_chain.clone(), private_key.clone())
244 .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
245
246 if let Some(client_ca_path) = &self.client_ca_path {
248 let client_ca_file = File::open(client_ca_path).map_err(|e| {
250 ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
251 })?;
252 let mut client_ca_reader = BufReader::new(client_ca_file);
253 let client_ca_certs = certs(&mut client_ca_reader).map_err(|_| {
254 ProtocolError::TlsError("Failed to parse client CA certificate".into())
255 })?;
256
257 let client_ca_certs: Vec<Certificate> =
259 client_ca_certs.into_iter().map(Certificate).collect();
260
261 let mut client_root_store = RootCertStore::empty();
263 for cert in &client_ca_certs {
264 client_root_store.add(cert).map_err(|e| {
265 ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
266 })?;
267 }
268
269 let client_auth = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new(
271 client_root_store,
272 ));
273
274 let new_builder = ServerConfig::builder().with_safe_defaults();
276 let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
277
278 config = new_cert_builder
280 .with_single_cert(cert_chain, private_key)
281 .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
282
283 debug!("mTLS enabled with client certificate verification required");
284 }
285
286 if let Some(protocols) = &self.alpn_protocols {
288 config.alpn_protocols = protocols.clone();
289 debug!(
290 protocol_count = protocols.len(),
291 "ALPN protocols configured"
292 );
293 }
294
295 Ok(config)
296 }
297}
298
299pub struct TlsClientConfig {
301 server_name: String,
302 insecure: bool,
303 pinned_cert_hash: Option<Vec<u8>>,
305 client_cert_path: Option<String>,
307 client_key_path: Option<String>,
309 tls_versions: Option<Vec<TlsVersion>>,
311 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
313}
314
315impl TlsClientConfig {
316 pub fn new<S: Into<String>>(server_name: S) -> Self {
318 Self {
319 server_name: server_name.into(),
320 insecure: false,
321 pinned_cert_hash: None,
322 client_cert_path: None,
323 client_key_path: None,
324 tls_versions: None,
325 cipher_suites: None,
326 }
327 }
328
329 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
331 self.tls_versions = Some(versions);
332 self
333 }
334
335 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
337 self.cipher_suites = Some(cipher_suites);
338 self
339 }
340
341 pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
343 self.client_cert_path = Some(cert_path.into());
344 self.client_key_path = Some(key_path.into());
345 self
346 }
347
348 pub fn insecure(mut self) -> Self {
361 warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
362 self.insecure = true;
363 self
364 }
365
366 pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
373 if hash.len() != 32 {
374 warn!(
375 "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
376 hash.len()
377 );
378 }
379 self.pinned_cert_hash = Some(hash);
380 self
381 }
382
383 pub fn calculate_cert_hash(cert: &Certificate) -> Vec<u8> {
385 use sha2::{Digest, Sha256};
386 let mut hasher = Sha256::new();
387 hasher.update(&cert.0);
388 hasher.finalize().to_vec()
389 }
390
391 fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKey> {
393 reader
396 .seek(std::io::SeekFrom::Start(0))
397 .map_err(ProtocolError::Io)?;
398
399 let keys = pkcs8_private_keys(reader)
401 .map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
402
403 if !keys.is_empty() {
404 return Ok(PrivateKey(keys[0].clone()));
405 }
406
407 Err(ProtocolError::TlsError(
410 "No supported private key format found".into(),
411 ))
412 }
413
414 pub fn load_client_config(&self) -> Result<ClientConfig> {
416 self.log_tls_version_info();
417
418 if self.insecure {
419 self.build_insecure_client_config()
420 } else {
421 self.build_secure_client_config()
422 }
423 }
424
425 fn log_tls_version_info(&self) {
427 if let Some(versions) = &self.tls_versions {
428 let mut has_tls13 = false;
429 let mut has_tls12 = false;
430 for v in versions {
431 match v {
432 TlsVersion::TLS12 => has_tls12 = true,
433 TlsVersion::TLS13 => has_tls13 = true,
434 TlsVersion::All => {
435 has_tls13 = true;
436 has_tls12 = true;
437 }
438 }
439 }
440 debug!(
441 "TLS client versions requested: TLS1.2={}, TLS1.3={}",
442 has_tls12, has_tls13
443 );
444 }
445 }
446
447 fn build_secure_client_config(&self) -> Result<ClientConfig> {
449 let root_store = self.load_system_root_certificates()?;
450 let builder = ClientConfig::builder()
451 .with_safe_defaults()
452 .with_root_certificates(root_store);
453
454 if let (Some(client_cert_path), Some(client_key_path)) =
456 (&self.client_cert_path, &self.client_key_path)
457 {
458 let (cert_chain, key) =
459 self.load_client_credentials(client_cert_path, client_key_path)?;
460 builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
461 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
462 })
463 } else {
464 Ok(builder.with_no_client_auth())
465 }
466 }
467
468 fn build_insecure_client_config(&self) -> Result<ClientConfig> {
470 let builder = ClientConfig::builder().with_safe_defaults();
471 let verifier = self.create_custom_verifier();
472 let custom_builder = builder.with_custom_certificate_verifier(verifier);
473
474 if let (Some(client_cert_path), Some(client_key_path)) =
476 (&self.client_cert_path, &self.client_key_path)
477 {
478 let (cert_chain, key) =
479 self.load_client_credentials(client_cert_path, client_key_path)?;
480 custom_builder
481 .with_client_auth_cert(cert_chain, key)
482 .map_err(|e| {
483 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
484 })
485 } else {
486 Ok(custom_builder.with_no_client_auth())
487 }
488 }
489
490 fn load_system_root_certificates(&self) -> Result<RootCertStore> {
492 let mut root_store = RootCertStore::empty();
493 let native_certs = rustls_native_certs::load_native_certs()
494 .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
495
496 for cert in native_certs {
497 root_store.add(&Certificate(cert.0)).map_err(|e| {
498 ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
499 })?;
500 }
501
502 Ok(root_store)
503 }
504
505 fn create_custom_verifier(&self) -> Arc<dyn rustls::client::ServerCertVerifier> {
507 if let Some(hash) = &self.pinned_cert_hash {
508 Arc::new(CertificateFingerprint {
509 fingerprint: hash.clone(),
510 })
511 } else {
512 Arc::new(AcceptAnyServerCert)
513 }
514 }
515
516 fn load_client_credentials(
518 &self,
519 cert_path: &str,
520 key_path: &str,
521 ) -> Result<(Vec<Certificate>, PrivateKey)> {
522 let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
524 let mut cert_reader = BufReader::new(cert_file);
525 let certs = rustls_pemfile::certs(&mut cert_reader)
526 .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
527
528 if certs.is_empty() {
529 return Err(ProtocolError::TlsError(
530 "No client certificates found".into(),
531 ));
532 }
533
534 let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
536 let mut key_reader = BufReader::new(key_file);
537 let key = Self::load_private_key(&mut key_reader)?;
538
539 let cert_chain = certs.into_iter().map(Certificate).collect();
540 Ok((cert_chain, key))
541 }
542
543 pub fn server_name(&self) -> Result<ServerName> {
545 ServerName::try_from(self.server_name.as_str())
546 .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
547 }
548}
549
550#[instrument(skip(config))]
552pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
553 let tls_config = config.load_server_config()?;
554 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
555 let listener = TcpListener::bind(addr).await?;
556
557 info!(address=%addr, "TLS server listening");
558
559 loop {
560 let (stream, peer) = listener.accept().await?;
561 let acceptor = acceptor.clone();
562
563 tokio::spawn(async move {
564 match acceptor.accept(stream).await {
565 Ok(tls_stream) => {
566 if let Err(e) = handle_tls_connection(tls_stream, peer).await {
567 error!(%peer, error=%e, "Connection error");
568 }
569 }
570 Err(e) => {
571 error!(%peer, error=%e, "TLS handshake failed");
572 }
573 }
574 });
575 }
576}
577
578#[instrument(skip(tls_stream), fields(peer=%peer))]
580async fn handle_tls_connection(
581 tls_stream: ServerTlsStream<TcpStream>,
582 peer: SocketAddr,
583) -> Result<()> {
584 let mut framed = Framed::new(tls_stream, PacketCodec);
585
586 info!("TLS connection established");
587
588 while let Some(packet) = framed.next().await {
589 match packet {
590 Ok(pkt) => {
591 debug!(bytes = pkt.payload.len(), "Received data");
592 on_packet(pkt, &mut framed).await?;
593 }
594 Err(e) => {
595 error!(error=%e, "Protocol error");
596 break;
597 }
598 }
599 }
600
601 info!("TLS connection closed");
602 Ok(())
603}
604
605#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
607async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
608where
609 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
610{
611 let response = Packet {
613 version: pkt.version,
614 payload: pkt.payload,
615 };
616
617 framed.send(response).await?;
618 Ok(())
619}
620
621#[instrument(skip(config), fields(address=%addr))]
623pub async fn connect(
624 addr: &str,
625 config: TlsClientConfig,
626) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
627 let tls_config = Arc::new(config.load_client_config()?);
628 let connector = TlsConnector::from(tls_config);
629
630 let stream = TcpStream::connect(addr).await?;
631 let domain = config.server_name()?;
632
633 let tls_stream = connector
634 .connect(domain, stream)
635 .await
636 .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
637
638 let framed = Framed::new(tls_stream, PacketCodec);
639 Ok(framed)
640}