network_protocol/transport/
tls.rs1use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::ServerName;
25use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig};
26use rustls_pemfile::{certs, pkcs8_private_keys};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::client::TlsStream as ClientTlsStream;
29use tokio_rustls::server::TlsStream as ServerTlsStream;
30use tokio_rustls::{TlsAcceptor, TlsConnector};
31use tokio_util::codec::Framed;
32use tracing::{debug, error, info, instrument, warn};
33
34use crate::core::codec::PacketCodec;
35use crate::core::packet::Packet;
36use crate::error::{ProtocolError, Result};
37use futures::{SinkExt, StreamExt};
38
39struct CertificateFingerprint {
41 fingerprint: Vec<u8>,
42}
43
44impl rustls::client::ServerCertVerifier for CertificateFingerprint {
45 fn verify_server_cert(
46 &self,
47 end_entity: &Certificate,
48 _intermediates: &[Certificate],
49 _server_name: &ServerName,
50 _scts: &mut dyn Iterator<Item = &[u8]>,
51 _ocsp_response: &[u8],
52 _now: std::time::SystemTime,
53 ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
54 use sha2::{Digest, Sha256};
55
56 let mut hasher = Sha256::new();
57 hasher.update(&end_entity.0);
58 let hash = hasher.finalize();
59
60 if hash.as_slice() == self.fingerprint.as_slice() {
61 Ok(rustls::client::ServerCertVerified::assertion())
62 } else {
63 Err(rustls::Error::General(
64 "Pinned certificate hash mismatch".into(),
65 ))
66 }
67 }
68}
69
70struct AcceptAnyServerCert;
71
72impl rustls::client::ServerCertVerifier for AcceptAnyServerCert {
73 fn verify_server_cert(
74 &self,
75 _end_entity: &Certificate,
76 _intermediates: &[Certificate],
77 _server_name: &ServerName,
78 _scts: &mut dyn Iterator<Item = &[u8]>,
79 _ocsp_response: &[u8],
80 _now: std::time::SystemTime,
81 ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
82 Ok(rustls::client::ServerCertVerified::assertion())
83 }
84}
85
/// TLS protocol versions that can be requested on a `TlsServerConfig` or
/// `TlsClientConfig` through their `with_tls_versions` builders.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2.
    TLS12,
    /// TLS 1.3.
    TLS13,
    /// Both TLS 1.2 and TLS 1.3.
    All,
}
95
96pub struct TlsServerConfig {
98 cert_path: String,
99 key_path: String,
100 client_ca_path: Option<String>,
102 require_client_auth: bool,
104 tls_versions: Option<Vec<TlsVersion>>,
106 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
108}
109
110impl TlsServerConfig {
111 pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
113 Self {
114 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
115 key_path: key_path.as_ref().to_string_lossy().to_string(),
116 client_ca_path: None,
117 require_client_auth: false,
118 tls_versions: None,
119 cipher_suites: None,
120 }
121 }
122
123 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
125 self.tls_versions = Some(versions);
126 self
127 }
128
129 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
131 self.cipher_suites = Some(cipher_suites);
132 self
133 }
134
135 pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
137 self.client_ca_path = Some(client_ca_path.into());
138 self.require_client_auth = true;
139 self
140 }
141
142 pub fn require_client_auth(mut self, required: bool) -> Self {
144 self.require_client_auth = required;
145 self
146 }
147
148 pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
150 let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
151 .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
152
153 let mut cert_file = File::create(&cert_path)?;
155 let pem = cert.cert.pem();
156 cert_file.write_all(pem.as_bytes())?;
157
158 let mut key_file = File::create(&key_path)?;
160 key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
161
162 Ok(Self {
163 cert_path: cert_path.as_ref().to_string_lossy().to_string(),
164 key_path: key_path.as_ref().to_string_lossy().to_string(),
165 client_ca_path: None,
166 require_client_auth: false,
167 tls_versions: None,
168 cipher_suites: None,
169 })
170 }
171
172 pub fn load_server_config(&self) -> Result<ServerConfig> {
174 let cert_file = File::open(&self.cert_path)
176 .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
177 let mut cert_reader = BufReader::new(cert_file);
178 let cert_chain = certs(&mut cert_reader)
179 .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
180
181 let cert_chain: Vec<Certificate> = cert_chain.into_iter().map(Certificate).collect();
183
184 let key_file = File::open(&self.key_path)
186 .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
187 let mut key_reader = BufReader::new(key_file);
188 let keys = pkcs8_private_keys(&mut key_reader)
189 .map_err(|_| ProtocolError::TlsError("Failed to parse private key".into()))?;
190
191 if keys.is_empty() {
192 return Err(ProtocolError::TlsError("No private keys found".into()));
193 }
194
195 let private_key = PrivateKey(keys[0].clone());
197
198 if let Some(versions) = &self.tls_versions {
201 let mut has_tls13 = false;
202 let mut has_tls12 = false;
203 for v in versions {
204 match v {
205 TlsVersion::TLS12 => has_tls12 = true,
206 TlsVersion::TLS13 => has_tls13 = true,
207 TlsVersion::All => {
208 has_tls13 = true;
209 has_tls12 = true;
210 }
211 }
212 }
213 debug!(
215 "TLS versions requested: TLS1.2={}, TLS1.3={}",
216 has_tls12, has_tls13
217 );
218 }
219
220 let config_builder = ServerConfig::builder().with_safe_defaults();
222
223 let cert_builder = config_builder.with_no_client_auth();
230
231 let mut config = cert_builder
233 .with_single_cert(cert_chain.clone(), private_key.clone())
234 .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
235
236 if let Some(client_ca_path) = &self.client_ca_path {
238 let client_ca_file = File::open(client_ca_path).map_err(|e| {
240 ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
241 })?;
242 let mut client_ca_reader = BufReader::new(client_ca_file);
243 let client_ca_certs = certs(&mut client_ca_reader).map_err(|_| {
244 ProtocolError::TlsError("Failed to parse client CA certificate".into())
245 })?;
246
247 let client_ca_certs: Vec<Certificate> =
249 client_ca_certs.into_iter().map(Certificate).collect();
250
251 let mut client_root_store = RootCertStore::empty();
253 for cert in &client_ca_certs {
254 client_root_store.add(cert).map_err(|e| {
255 ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
256 })?;
257 }
258
259 let client_auth = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new(
261 client_root_store,
262 ));
263
264 let new_builder = ServerConfig::builder().with_safe_defaults();
266 let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
267
268 config = new_cert_builder
270 .with_single_cert(cert_chain, private_key)
271 .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
272
273 debug!("mTLS enabled with client certificate verification required");
274 }
275
276 Ok(config)
277 }
278}
279
280pub struct TlsClientConfig {
282 server_name: String,
283 insecure: bool,
284 pinned_cert_hash: Option<Vec<u8>>,
286 client_cert_path: Option<String>,
288 client_key_path: Option<String>,
290 tls_versions: Option<Vec<TlsVersion>>,
292 cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
294}
295
296impl TlsClientConfig {
297 pub fn new<S: Into<String>>(server_name: S) -> Self {
299 Self {
300 server_name: server_name.into(),
301 insecure: false,
302 pinned_cert_hash: None,
303 client_cert_path: None,
304 client_key_path: None,
305 tls_versions: None,
306 cipher_suites: None,
307 }
308 }
309
310 pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
312 self.tls_versions = Some(versions);
313 self
314 }
315
316 pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
318 self.cipher_suites = Some(cipher_suites);
319 self
320 }
321
322 pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
324 self.client_cert_path = Some(cert_path.into());
325 self.client_key_path = Some(key_path.into());
326 self
327 }
328
329 pub fn insecure(mut self) -> Self {
342 warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
343 self.insecure = true;
344 self
345 }
346
347 pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
354 if hash.len() != 32 {
355 warn!(
356 "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
357 hash.len()
358 );
359 }
360 self.pinned_cert_hash = Some(hash);
361 self
362 }
363
364 pub fn calculate_cert_hash(cert: &Certificate) -> Vec<u8> {
366 use sha2::{Digest, Sha256};
367 let mut hasher = Sha256::new();
368 hasher.update(&cert.0);
369 hasher.finalize().to_vec()
370 }
371
372 fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKey> {
374 reader
377 .seek(std::io::SeekFrom::Start(0))
378 .map_err(ProtocolError::Io)?;
379
380 let keys = pkcs8_private_keys(reader)
382 .map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
383
384 if !keys.is_empty() {
385 return Ok(PrivateKey(keys[0].clone()));
386 }
387
388 Err(ProtocolError::TlsError(
391 "No supported private key format found".into(),
392 ))
393 }
394
395 pub fn load_client_config(&self) -> Result<ClientConfig> {
397 self.log_tls_version_info();
398
399 if self.insecure {
400 self.build_insecure_client_config()
401 } else {
402 self.build_secure_client_config()
403 }
404 }
405
406 fn log_tls_version_info(&self) {
408 if let Some(versions) = &self.tls_versions {
409 let mut has_tls13 = false;
410 let mut has_tls12 = false;
411 for v in versions {
412 match v {
413 TlsVersion::TLS12 => has_tls12 = true,
414 TlsVersion::TLS13 => has_tls13 = true,
415 TlsVersion::All => {
416 has_tls13 = true;
417 has_tls12 = true;
418 }
419 }
420 }
421 debug!(
422 "TLS client versions requested: TLS1.2={}, TLS1.3={}",
423 has_tls12, has_tls13
424 );
425 }
426 }
427
428 fn build_secure_client_config(&self) -> Result<ClientConfig> {
430 let root_store = self.load_system_root_certificates()?;
431 let builder = ClientConfig::builder()
432 .with_safe_defaults()
433 .with_root_certificates(root_store);
434
435 if let (Some(client_cert_path), Some(client_key_path)) =
437 (&self.client_cert_path, &self.client_key_path)
438 {
439 let (cert_chain, key) =
440 self.load_client_credentials(client_cert_path, client_key_path)?;
441 builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
442 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
443 })
444 } else {
445 Ok(builder.with_no_client_auth())
446 }
447 }
448
449 fn build_insecure_client_config(&self) -> Result<ClientConfig> {
451 let builder = ClientConfig::builder().with_safe_defaults();
452 let verifier = self.create_custom_verifier();
453 let custom_builder = builder.with_custom_certificate_verifier(verifier);
454
455 if let (Some(client_cert_path), Some(client_key_path)) =
457 (&self.client_cert_path, &self.client_key_path)
458 {
459 let (cert_chain, key) =
460 self.load_client_credentials(client_cert_path, client_key_path)?;
461 custom_builder
462 .with_client_auth_cert(cert_chain, key)
463 .map_err(|e| {
464 ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
465 })
466 } else {
467 Ok(custom_builder.with_no_client_auth())
468 }
469 }
470
471 fn load_system_root_certificates(&self) -> Result<RootCertStore> {
473 let mut root_store = RootCertStore::empty();
474 let native_certs = rustls_native_certs::load_native_certs()
475 .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
476
477 for cert in native_certs {
478 root_store.add(&Certificate(cert.0)).map_err(|e| {
479 ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
480 })?;
481 }
482
483 Ok(root_store)
484 }
485
486 fn create_custom_verifier(&self) -> Arc<dyn rustls::client::ServerCertVerifier> {
488 if let Some(hash) = &self.pinned_cert_hash {
489 Arc::new(CertificateFingerprint {
490 fingerprint: hash.clone(),
491 })
492 } else {
493 Arc::new(AcceptAnyServerCert)
494 }
495 }
496
497 fn load_client_credentials(
499 &self,
500 cert_path: &str,
501 key_path: &str,
502 ) -> Result<(Vec<Certificate>, PrivateKey)> {
503 let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
505 let mut cert_reader = BufReader::new(cert_file);
506 let certs = rustls_pemfile::certs(&mut cert_reader)
507 .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
508
509 if certs.is_empty() {
510 return Err(ProtocolError::TlsError(
511 "No client certificates found".into(),
512 ));
513 }
514
515 let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
517 let mut key_reader = BufReader::new(key_file);
518 let key = Self::load_private_key(&mut key_reader)?;
519
520 let cert_chain = certs.into_iter().map(Certificate).collect();
521 Ok((cert_chain, key))
522 }
523
524 pub fn server_name(&self) -> Result<ServerName> {
526 ServerName::try_from(self.server_name.as_str())
527 .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
528 }
529}
530
531#[instrument(skip(config))]
533pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
534 let tls_config = config.load_server_config()?;
535 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
536 let listener = TcpListener::bind(addr).await?;
537
538 info!(address=%addr, "TLS server listening");
539
540 loop {
541 let (stream, peer) = listener.accept().await?;
542 let acceptor = acceptor.clone();
543
544 tokio::spawn(async move {
545 match acceptor.accept(stream).await {
546 Ok(tls_stream) => {
547 if let Err(e) = handle_tls_connection(tls_stream, peer).await {
548 error!(%peer, error=%e, "Connection error");
549 }
550 }
551 Err(e) => {
552 error!(%peer, error=%e, "TLS handshake failed");
553 }
554 }
555 });
556 }
557}
558
559#[instrument(skip(tls_stream), fields(peer=%peer))]
561async fn handle_tls_connection(
562 tls_stream: ServerTlsStream<TcpStream>,
563 peer: SocketAddr,
564) -> Result<()> {
565 let mut framed = Framed::new(tls_stream, PacketCodec);
566
567 info!("TLS connection established");
568
569 while let Some(packet) = framed.next().await {
570 match packet {
571 Ok(pkt) => {
572 debug!(bytes = pkt.payload.len(), "Received data");
573 on_packet(pkt, &mut framed).await?;
574 }
575 Err(e) => {
576 error!(error=%e, "Protocol error");
577 break;
578 }
579 }
580 }
581
582 info!("TLS connection closed");
583 Ok(())
584}
585
586#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
588async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
589where
590 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
591{
592 let response = Packet {
594 version: pkt.version,
595 payload: pkt.payload,
596 };
597
598 framed.send(response).await?;
599 Ok(())
600}
601
602#[instrument(skip(config), fields(address=%addr))]
604pub async fn connect(
605 addr: &str,
606 config: TlsClientConfig,
607) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
608 let tls_config = Arc::new(config.load_client_config()?);
609 let connector = TlsConnector::from(tls_config);
610
611 let stream = TcpStream::connect(addr).await?;
612 let domain = config.server_name()?;
613
614 let tls_stream = connector
615 .connect(domain, stream)
616 .await
617 .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
618
619 let framed = Framed::new(tls_stream, PacketCodec);
620 Ok(framed)
621}