Skip to main content

network_protocol/transport/
tls.rs

1//! # TLS Transport Layer
2//!
3//! This file is part of the Network Protocol project.
4//!
5//! It defines the TLS transport layer for secure network communication,
6//! particularly for external untrusted connections.
7//!
//! The TLS transport layer provides a secure communication channel
//! using the industry-standard TLS protocol, ensuring the confidentiality,
//! integrity, and authenticity of transmitted data.
11//!
12//! ## Responsibilities
13//! - Establish secure TLS connections
14//! - Handle TLS certificates and verification
15//! - Provide secure framed transport for higher protocol layers
//! - Remain compatible with the existing packet codec infrastructure
17
18use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::ServerName;
25use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig};
26use rustls_pemfile::{certs, pkcs8_private_keys};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::client::TlsStream as ClientTlsStream;
29use tokio_rustls::server::TlsStream as ServerTlsStream;
30use tokio_rustls::{TlsAcceptor, TlsConnector};
31use tokio_util::codec::Framed;
32use tracing::{debug, error, info, instrument, warn};
33
34use crate::core::codec::PacketCodec;
35use crate::core::packet::Packet;
36use crate::error::{ProtocolError, Result};
37use futures::{SinkExt, StreamExt};
38
39// Custom certificate verifiers
40struct CertificateFingerprint {
41    fingerprint: Vec<u8>,
42}
43
44impl rustls::client::ServerCertVerifier for CertificateFingerprint {
45    fn verify_server_cert(
46        &self,
47        end_entity: &Certificate,
48        _intermediates: &[Certificate],
49        _server_name: &ServerName,
50        _scts: &mut dyn Iterator<Item = &[u8]>,
51        _ocsp_response: &[u8],
52        _now: std::time::SystemTime,
53    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
54        use sha2::{Digest, Sha256};
55
56        let mut hasher = Sha256::new();
57        hasher.update(&end_entity.0);
58        let hash = hasher.finalize();
59
60        if hash.as_slice() == self.fingerprint.as_slice() {
61            Ok(rustls::client::ServerCertVerified::assertion())
62        } else {
63            Err(rustls::Error::General(
64                "Pinned certificate hash mismatch".into(),
65            ))
66        }
67    }
68}
69
70struct AcceptAnyServerCert;
71
72impl rustls::client::ServerCertVerifier for AcceptAnyServerCert {
73    fn verify_server_cert(
74        &self,
75        _end_entity: &Certificate,
76        _intermediates: &[Certificate],
77        _server_name: &ServerName,
78        _scts: &mut dyn Iterator<Item = &[u8]>,
79        _ocsp_response: &[u8],
80        _now: std::time::SystemTime,
81    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
82        Ok(rustls::client::ServerCertVerified::assertion())
83    }
84}
85
/// TLS protocol version selector used by the server and client configuration builders.
///
/// A list of these values is passed to `with_tls_versions`; duplicates are harmless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2
    TLS12,
    /// TLS 1.3
    TLS13,
    /// Both TLS 1.2 and 1.3
    All,
}
95
96/// TLS server configuration
97pub struct TlsServerConfig {
98    cert_path: String,
99    key_path: String,
100    /// Optional path to client CA certificates for mTLS
101    client_ca_path: Option<String>,
102    /// Whether to require client certificates (mTLS)
103    require_client_auth: bool,
104    /// Allowed TLS protocol versions (None = use rustls defaults)
105    tls_versions: Option<Vec<TlsVersion>>,
106    /// Allowed cipher suites (None = use rustls defaults)
107    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
108    /// ALPN protocols to advertise (default: ["h2", "http/1.1"])
109    alpn_protocols: Option<Vec<Vec<u8>>>,
110}
111
112impl TlsServerConfig {
113    /// Create a new TLS server configuration
114    pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
115        Self {
116            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
117            key_path: key_path.as_ref().to_string_lossy().to_string(),
118            client_ca_path: None,
119            require_client_auth: false,
120            tls_versions: None,
121            cipher_suites: None,
122            alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
123        }
124    }
125
126    /// Set allowed TLS protocol versions
127    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
128        self.tls_versions = Some(versions);
129        self
130    }
131
132    /// Set allowed cipher suites
133    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
134        self.cipher_suites = Some(cipher_suites);
135        self
136    }
137
138    /// Enable mutual TLS authentication by providing a CA certificate path
139    pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
140        self.client_ca_path = Some(client_ca_path.into());
141        self.require_client_auth = true;
142        self
143    }
144
145    /// Set whether client authentication is required (true) or optional (false)
146    pub fn require_client_auth(mut self, required: bool) -> Self {
147        self.require_client_auth = required;
148        self
149    }
150
151    /// Set ALPN protocols to advertise during TLS handshake
152    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
153        self.alpn_protocols = Some(protocols);
154        self
155    }
156
157    /// Generate a self-signed certificate for development/testing purposes
158    pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
159        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
160            .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
161
162        // Write certificate
163        let mut cert_file = File::create(&cert_path)?;
164        let pem = cert.cert.pem();
165        cert_file.write_all(pem.as_bytes())?;
166
167        // Write private key
168        let mut key_file = File::create(&key_path)?;
169        key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
170
171        Ok(Self {
172            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
173            key_path: key_path.as_ref().to_string_lossy().to_string(),
174            client_ca_path: None,
175            require_client_auth: false,
176            tls_versions: None,
177            cipher_suites: None,
178            alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
179        })
180    }
181
182    /// Load the TLS configuration from files
183    pub fn load_server_config(&self) -> Result<ServerConfig> {
184        // Load certificate
185        let cert_file = File::open(&self.cert_path)
186            .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
187        let mut cert_reader = BufReader::new(cert_file);
188        let cert_chain = certs(&mut cert_reader)
189            .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
190
191        // Convert to rustls Certificate type
192        let cert_chain: Vec<Certificate> = cert_chain.into_iter().map(Certificate).collect();
193
194        // Load private key
195        let key_file = File::open(&self.key_path)
196            .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
197        let mut key_reader = BufReader::new(key_file);
198        let keys = pkcs8_private_keys(&mut key_reader)
199            .map_err(|_| ProtocolError::TlsError("Failed to parse private key".into()))?;
200
201        if keys.is_empty() {
202            return Err(ProtocolError::TlsError("No private keys found".into()));
203        }
204
205        // Convert to rustls PrivateKey
206        let private_key = PrivateKey(keys[0].clone());
207
208        // Validate TLS versions if specified
209        // Note: In rustls 0.21, with_safe_defaults() restricts to TLS 1.2+ (best practice)
210        if let Some(versions) = &self.tls_versions {
211            let mut has_tls13 = false;
212            let mut has_tls12 = false;
213            for v in versions {
214                match v {
215                    TlsVersion::TLS12 => has_tls12 = true,
216                    TlsVersion::TLS13 => has_tls13 = true,
217                    TlsVersion::All => {
218                        has_tls13 = true;
219                        has_tls12 = true;
220                    }
221                }
222            }
223            // Document that with_safe_defaults uses best practices
224            debug!(
225                "TLS versions requested: TLS1.2={}, TLS1.3={}",
226                has_tls12, has_tls13
227            );
228        }
229
230        // Create a server configuration with safe defaults (TLS 1.2+, modern ciphersuites)
231        let config_builder = ServerConfig::builder().with_safe_defaults();
232
233        // Note: rustls 0.21 doesn't expose API to change ciphersuites after builder creation.
234        // with_safe_defaults() already provides excellent defaults:
235        // - TLS 1.3: CHACHA20_POLY1305, AES_128_GCM, AES_256_GCM
236        // - TLS 1.2: same + some legacy suites
237        // Custom cipher suite restriction requires building at compile time.
238
239        let cert_builder = config_builder.with_no_client_auth();
240
241        // Build config with certificates
242        let mut config = cert_builder
243            .with_single_cert(cert_chain.clone(), private_key.clone())
244            .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
245
246        // Configure client authentication if required (mTLS)
247        if let Some(client_ca_path) = &self.client_ca_path {
248            // Load client CA certificates
249            let client_ca_file = File::open(client_ca_path).map_err(|e| {
250                ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
251            })?;
252            let mut client_ca_reader = BufReader::new(client_ca_file);
253            let client_ca_certs = certs(&mut client_ca_reader).map_err(|_| {
254                ProtocolError::TlsError("Failed to parse client CA certificate".into())
255            })?;
256
257            // Convert to rustls Certificate type
258            let client_ca_certs: Vec<Certificate> =
259                client_ca_certs.into_iter().map(Certificate).collect();
260
261            // Create client cert verifier
262            let mut client_root_store = RootCertStore::empty();
263            for cert in &client_ca_certs {
264                client_root_store.add(cert).map_err(|e| {
265                    ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
266                })?;
267            }
268
269            // Create client authentication verifier
270            let client_auth = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new(
271                client_root_store,
272            ));
273
274            // Create new config builder with client auth
275            let new_builder = ServerConfig::builder().with_safe_defaults();
276            let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
277
278            // Build a new config with certificates and client auth
279            config = new_cert_builder
280                .with_single_cert(cert_chain, private_key)
281                .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
282
283            debug!("mTLS enabled with client certificate verification required");
284        }
285
286        // Configure ALPN protocols if specified
287        if let Some(protocols) = &self.alpn_protocols {
288            config.alpn_protocols = protocols.clone();
289            debug!(
290                protocol_count = protocols.len(),
291                "ALPN protocols configured"
292            );
293        }
294
295        Ok(config)
296    }
297}
298
299/// TLS Client Configuration
300pub struct TlsClientConfig {
301    server_name: String,
302    insecure: bool,
303    /// Optional certificate hash to pin (SHA-256 fingerprint)
304    pinned_cert_hash: Option<Vec<u8>>,
305    /// Optional client certificate path for mTLS
306    client_cert_path: Option<String>,
307    /// Optional client key path for mTLS
308    client_key_path: Option<String>,
309    /// Allowed TLS protocol versions (None = use rustls defaults)
310    tls_versions: Option<Vec<TlsVersion>>,
311    /// Allowed cipher suites (None = use rustls defaults)
312    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
313}
314
315impl TlsClientConfig {
316    /// Create a new TLS client configuration
317    pub fn new<S: Into<String>>(server_name: S) -> Self {
318        Self {
319            server_name: server_name.into(),
320            insecure: false,
321            pinned_cert_hash: None,
322            client_cert_path: None,
323            client_key_path: None,
324            tls_versions: None,
325            cipher_suites: None,
326        }
327    }
328
329    /// Set allowed TLS protocol versions
330    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
331        self.tls_versions = Some(versions);
332        self
333    }
334
335    /// Set allowed cipher suites
336    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
337        self.cipher_suites = Some(cipher_suites);
338        self
339    }
340
341    /// Configure client authentication for mTLS
342    pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
343        self.client_cert_path = Some(cert_path.into());
344        self.client_key_path = Some(key_path.into());
345        self
346    }
347
348    /// Allow insecure connections (skip certificate verification)
349    ///
350    /// # WARNING: Security Risk
351    /// This mode disables certificate verification entirely and should ONLY be used for:
352    /// - Development and testing
353    /// - Debugging environments
354    /// - Internal networks with certificate pinning enabled
355    ///
356    /// **NEVER** use this in production without explicit certificate pinning via `with_pinned_cert_hash()`.
357    ///
358    /// For maximum security, only use this with the "dangerous_configuration" feature enabled,
359    /// which is a strong indicator this is for testing/development only.
360    pub fn insecure(mut self) -> Self {
361        warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
362        self.insecure = true;
363        self
364    }
365
366    /// Pin a certificate by its SHA-256 hash/fingerprint
367    ///
368    /// This provides additional security by only accepting connections
369    /// from servers with the exact certificate matching this hash.
370    /// Can be combined with insecure mode for development environments where
371    /// you want to skip standard CA verification but still verify a specific cert.
372    pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
373        if hash.len() != 32 {
374            warn!(
375                "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
376                hash.len()
377            );
378        }
379        self.pinned_cert_hash = Some(hash);
380        self
381    }
382
383    /// Calculate SHA-256 hash for a certificate to use with pinning
384    pub fn calculate_cert_hash(cert: &Certificate) -> Vec<u8> {
385        use sha2::{Digest, Sha256};
386        let mut hasher = Sha256::new();
387        hasher.update(&cert.0);
388        hasher.finalize().to_vec()
389    }
390
391    /// Helper method to load a private key from PKCS8 format
392    fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKey> {
393        // Try to load PKCS8 keys
394        // Seek to beginning of file first
395        reader
396            .seek(std::io::SeekFrom::Start(0))
397            .map_err(ProtocolError::Io)?;
398
399        // We need to use pkcs8_private_keys on the BufReader directly since it implements BufRead
400        let keys = pkcs8_private_keys(reader)
401            .map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
402
403        if !keys.is_empty() {
404            return Ok(PrivateKey(keys[0].clone()));
405        }
406
407        // Note: Add support for other key formats like RSA or EC if needed
408
409        Err(ProtocolError::TlsError(
410            "No supported private key format found".into(),
411        ))
412    }
413
414    /// Load the TLS client configuration
415    pub fn load_client_config(&self) -> Result<ClientConfig> {
416        self.log_tls_version_info();
417
418        if self.insecure {
419            self.build_insecure_client_config()
420        } else {
421            self.build_secure_client_config()
422        }
423    }
424
425    /// Log TLS version configuration
426    fn log_tls_version_info(&self) {
427        if let Some(versions) = &self.tls_versions {
428            let mut has_tls13 = false;
429            let mut has_tls12 = false;
430            for v in versions {
431                match v {
432                    TlsVersion::TLS12 => has_tls12 = true,
433                    TlsVersion::TLS13 => has_tls13 = true,
434                    TlsVersion::All => {
435                        has_tls13 = true;
436                        has_tls12 = true;
437                    }
438                }
439            }
440            debug!(
441                "TLS client versions requested: TLS1.2={}, TLS1.3={}",
442                has_tls12, has_tls13
443            );
444        }
445    }
446
447    /// Build secure client config with system root CAs
448    fn build_secure_client_config(&self) -> Result<ClientConfig> {
449        let root_store = self.load_system_root_certificates()?;
450        let builder = ClientConfig::builder()
451            .with_safe_defaults()
452            .with_root_certificates(root_store);
453
454        // Apply client auth directly
455        if let (Some(client_cert_path), Some(client_key_path)) =
456            (&self.client_cert_path, &self.client_key_path)
457        {
458            let (cert_chain, key) =
459                self.load_client_credentials(client_cert_path, client_key_path)?;
460            builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
461                ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
462            })
463        } else {
464            Ok(builder.with_no_client_auth())
465        }
466    }
467
468    /// Build insecure client config with custom verifier
469    fn build_insecure_client_config(&self) -> Result<ClientConfig> {
470        let builder = ClientConfig::builder().with_safe_defaults();
471        let verifier = self.create_custom_verifier();
472        let custom_builder = builder.with_custom_certificate_verifier(verifier);
473
474        // Apply client auth directly
475        if let (Some(client_cert_path), Some(client_key_path)) =
476            (&self.client_cert_path, &self.client_key_path)
477        {
478            let (cert_chain, key) =
479                self.load_client_credentials(client_cert_path, client_key_path)?;
480            custom_builder
481                .with_client_auth_cert(cert_chain, key)
482                .map_err(|e| {
483                    ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
484                })
485        } else {
486            Ok(custom_builder.with_no_client_auth())
487        }
488    }
489
490    /// Load system root certificates
491    fn load_system_root_certificates(&self) -> Result<RootCertStore> {
492        let mut root_store = RootCertStore::empty();
493        let native_certs = rustls_native_certs::load_native_certs()
494            .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
495
496        for cert in native_certs {
497            root_store.add(&Certificate(cert.0)).map_err(|e| {
498                ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
499            })?;
500        }
501
502        Ok(root_store)
503    }
504
505    /// Create custom certificate verifier (pinning or accept-any)
506    fn create_custom_verifier(&self) -> Arc<dyn rustls::client::ServerCertVerifier> {
507        if let Some(hash) = &self.pinned_cert_hash {
508            Arc::new(CertificateFingerprint {
509                fingerprint: hash.clone(),
510            })
511        } else {
512            Arc::new(AcceptAnyServerCert)
513        }
514    }
515
516    /// Load client certificate and private key
517    fn load_client_credentials(
518        &self,
519        cert_path: &str,
520        key_path: &str,
521    ) -> Result<(Vec<Certificate>, PrivateKey)> {
522        // Load certificate
523        let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
524        let mut cert_reader = BufReader::new(cert_file);
525        let certs = rustls_pemfile::certs(&mut cert_reader)
526            .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
527
528        if certs.is_empty() {
529            return Err(ProtocolError::TlsError(
530                "No client certificates found".into(),
531            ));
532        }
533
534        // Load private key
535        let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
536        let mut key_reader = BufReader::new(key_file);
537        let key = Self::load_private_key(&mut key_reader)?;
538
539        let cert_chain = certs.into_iter().map(Certificate).collect();
540        Ok((cert_chain, key))
541    }
542
543    /// Get the server name as a rustls::ServerName
544    pub fn server_name(&self) -> Result<ServerName> {
545        ServerName::try_from(self.server_name.as_str())
546            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
547    }
548}
549
550/// Start a TLS server on the given address
551#[instrument(skip(config))]
552pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
553    let tls_config = config.load_server_config()?;
554    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
555    let listener = TcpListener::bind(addr).await?;
556
557    info!(address=%addr, "TLS server listening");
558
559    loop {
560        let (stream, peer) = listener.accept().await?;
561        let acceptor = acceptor.clone();
562
563        tokio::spawn(async move {
564            match acceptor.accept(stream).await {
565                Ok(tls_stream) => {
566                    if let Err(e) = handle_tls_connection(tls_stream, peer).await {
567                        error!(%peer, error=%e, "Connection error");
568                    }
569                }
570                Err(e) => {
571                    error!(%peer, error=%e, "TLS handshake failed");
572                }
573            }
574        });
575    }
576}
577
578/// Handle a TLS connection
579#[instrument(skip(tls_stream), fields(peer=%peer))]
580async fn handle_tls_connection(
581    tls_stream: ServerTlsStream<TcpStream>,
582    peer: SocketAddr,
583) -> Result<()> {
584    let mut framed = Framed::new(tls_stream, PacketCodec);
585
586    info!("TLS connection established");
587
588    while let Some(packet) = framed.next().await {
589        match packet {
590            Ok(pkt) => {
591                debug!(bytes = pkt.payload.len(), "Received data");
592                on_packet(pkt, &mut framed).await?;
593            }
594            Err(e) => {
595                error!(error=%e, "Protocol error");
596                break;
597            }
598        }
599    }
600
601    info!("TLS connection closed");
602    Ok(())
603}
604
605/// Handle incoming TLS packets
606#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
607async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
608where
609    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
610{
611    // Echo the packet back (sample implementation)
612    let response = Packet {
613        version: pkt.version,
614        payload: pkt.payload,
615    };
616
617    framed.send(response).await?;
618    Ok(())
619}
620
621/// Connect to a TLS server
622#[instrument(skip(config), fields(address=%addr))]
623pub async fn connect(
624    addr: &str,
625    config: TlsClientConfig,
626) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
627    let tls_config = Arc::new(config.load_client_config()?);
628    let connector = TlsConnector::from(tls_config);
629
630    let stream = TcpStream::connect(addr).await?;
631    let domain = config.server_name()?;
632
633    let tls_stream = connector
634        .connect(domain, stream)
635        .await
636        .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
637
638    let framed = Framed::new(tls_stream, PacketCodec);
639    Ok(framed)
640}