network_protocol/transport/
tls.rs

//! # TLS Transport Layer
//!
//! This file is part of the Network Protocol project.
//!
//! It defines the TLS transport layer for secure network communication,
//! particularly for external, untrusted connections.
//!
//! The TLS transport layer provides a secure channel for communication
//! using the industry-standard TLS protocol, ensuring the confidentiality,
//! integrity, and authenticity of the transmitted data.
//!
//! ## Responsibilities
//! - Establish secure TLS connections
//! - Handle TLS certificates and their verification
//! - Provide a secure framed transport for higher protocol layers
//! - Remain compatible with the existing packet codec infrastructure

18use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::ServerName;
25use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig};
26use rustls_pemfile::{certs, pkcs8_private_keys};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::client::TlsStream as ClientTlsStream;
29use tokio_rustls::server::TlsStream as ServerTlsStream;
30use tokio_rustls::{TlsAcceptor, TlsConnector};
31use tokio_util::codec::Framed;
32use tracing::{debug, error, info, instrument, warn};
33
34use crate::core::codec::PacketCodec;
35use crate::core::packet::Packet;
36use crate::error::{ProtocolError, Result};
37use futures::{SinkExt, StreamExt};
38
39// Custom certificate verifiers
40struct CertificateFingerprint {
41    fingerprint: Vec<u8>,
42}
43
44impl rustls::client::ServerCertVerifier for CertificateFingerprint {
45    fn verify_server_cert(
46        &self,
47        end_entity: &Certificate,
48        _intermediates: &[Certificate],
49        _server_name: &ServerName,
50        _scts: &mut dyn Iterator<Item = &[u8]>,
51        _ocsp_response: &[u8],
52        _now: std::time::SystemTime,
53    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
54        use sha2::{Digest, Sha256};
55
56        let mut hasher = Sha256::new();
57        hasher.update(&end_entity.0);
58        let hash = hasher.finalize();
59
60        if hash.as_slice() == self.fingerprint.as_slice() {
61            Ok(rustls::client::ServerCertVerified::assertion())
62        } else {
63            Err(rustls::Error::General(
64                "Pinned certificate hash mismatch".into(),
65            ))
66        }
67    }
68}
69
70struct AcceptAnyServerCert;
71
72impl rustls::client::ServerCertVerifier for AcceptAnyServerCert {
73    fn verify_server_cert(
74        &self,
75        _end_entity: &Certificate,
76        _intermediates: &[Certificate],
77        _server_name: &ServerName,
78        _scts: &mut dyn Iterator<Item = &[u8]>,
79        _ocsp_response: &[u8],
80        _now: std::time::SystemTime,
81    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
82        Ok(rustls::client::ServerCertVerified::assertion())
83    }
84}
85
/// TLS protocol version selector.
///
/// Values of this type are passed to `with_tls_versions` on the server and
/// client configurations; a list may contain several entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2
    TLS12,
    /// TLS 1.3
    TLS13,
    /// Both TLS 1.2 and 1.3
    All,
}
95
96/// TLS server configuration
97pub struct TlsServerConfig {
98    cert_path: String,
99    key_path: String,
100    /// Optional path to client CA certificates for mTLS
101    client_ca_path: Option<String>,
102    /// Whether to require client certificates (mTLS)
103    require_client_auth: bool,
104    /// Allowed TLS protocol versions (None = use rustls defaults)
105    tls_versions: Option<Vec<TlsVersion>>,
106    /// Allowed cipher suites (None = use rustls defaults)
107    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
108}
109
110impl TlsServerConfig {
111    /// Create a new TLS server configuration
112    pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
113        Self {
114            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
115            key_path: key_path.as_ref().to_string_lossy().to_string(),
116            client_ca_path: None,
117            require_client_auth: false,
118            tls_versions: None,
119            cipher_suites: None,
120        }
121    }
122
123    /// Set allowed TLS protocol versions
124    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
125        self.tls_versions = Some(versions);
126        self
127    }
128
129    /// Set allowed cipher suites
130    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
131        self.cipher_suites = Some(cipher_suites);
132        self
133    }
134
135    /// Enable mutual TLS authentication by providing a CA certificate path
136    pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
137        self.client_ca_path = Some(client_ca_path.into());
138        self.require_client_auth = true;
139        self
140    }
141
142    /// Set whether client authentication is required (true) or optional (false)
143    pub fn require_client_auth(mut self, required: bool) -> Self {
144        self.require_client_auth = required;
145        self
146    }
147
148    /// Generate a self-signed certificate for development/testing purposes
149    pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
150        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
151            .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
152
153        // Write certificate
154        let mut cert_file = File::create(&cert_path)?;
155        let pem = cert.cert.pem();
156        cert_file.write_all(pem.as_bytes())?;
157
158        // Write private key
159        let mut key_file = File::create(&key_path)?;
160        key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
161
162        Ok(Self {
163            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
164            key_path: key_path.as_ref().to_string_lossy().to_string(),
165            client_ca_path: None,
166            require_client_auth: false,
167            tls_versions: None,
168            cipher_suites: None,
169        })
170    }
171
172    /// Load the TLS configuration from files
173    pub fn load_server_config(&self) -> Result<ServerConfig> {
174        // Load certificate
175        let cert_file = File::open(&self.cert_path)
176            .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
177        let mut cert_reader = BufReader::new(cert_file);
178        let cert_chain = certs(&mut cert_reader)
179            .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
180
181        // Convert to rustls Certificate type
182        let cert_chain: Vec<Certificate> = cert_chain.into_iter().map(Certificate).collect();
183
184        // Load private key
185        let key_file = File::open(&self.key_path)
186            .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
187        let mut key_reader = BufReader::new(key_file);
188        let keys = pkcs8_private_keys(&mut key_reader)
189            .map_err(|_| ProtocolError::TlsError("Failed to parse private key".into()))?;
190
191        if keys.is_empty() {
192            return Err(ProtocolError::TlsError("No private keys found".into()));
193        }
194
195        // Convert to rustls PrivateKey
196        let private_key = PrivateKey(keys[0].clone());
197
198        // Validate TLS versions if specified
199        // Note: In rustls 0.21, with_safe_defaults() restricts to TLS 1.2+ (best practice)
200        if let Some(versions) = &self.tls_versions {
201            let mut has_tls13 = false;
202            let mut has_tls12 = false;
203            for v in versions {
204                match v {
205                    TlsVersion::TLS12 => has_tls12 = true,
206                    TlsVersion::TLS13 => has_tls13 = true,
207                    TlsVersion::All => {
208                        has_tls13 = true;
209                        has_tls12 = true;
210                    }
211                }
212            }
213            // Document that with_safe_defaults uses best practices
214            debug!(
215                "TLS versions requested: TLS1.2={}, TLS1.3={}",
216                has_tls12, has_tls13
217            );
218        }
219
220        // Create a server configuration with safe defaults (TLS 1.2+, modern ciphersuites)
221        let config_builder = ServerConfig::builder().with_safe_defaults();
222
223        // Note: rustls 0.21 doesn't expose API to change ciphersuites after builder creation.
224        // with_safe_defaults() already provides excellent defaults:
225        // - TLS 1.3: CHACHA20_POLY1305, AES_128_GCM, AES_256_GCM
226        // - TLS 1.2: same + some legacy suites
227        // Custom cipher suite restriction requires building at compile time.
228
229        let cert_builder = config_builder.with_no_client_auth();
230
231        // Build config with certificates
232        let mut config = cert_builder
233            .with_single_cert(cert_chain.clone(), private_key.clone())
234            .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
235
236        // Configure client authentication if required (mTLS)
237        if let Some(client_ca_path) = &self.client_ca_path {
238            // Load client CA certificates
239            let client_ca_file = File::open(client_ca_path).map_err(|e| {
240                ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
241            })?;
242            let mut client_ca_reader = BufReader::new(client_ca_file);
243            let client_ca_certs = certs(&mut client_ca_reader).map_err(|_| {
244                ProtocolError::TlsError("Failed to parse client CA certificate".into())
245            })?;
246
247            // Convert to rustls Certificate type
248            let client_ca_certs: Vec<Certificate> =
249                client_ca_certs.into_iter().map(Certificate).collect();
250
251            // Create client cert verifier
252            let mut client_root_store = RootCertStore::empty();
253            for cert in &client_ca_certs {
254                client_root_store.add(cert).map_err(|e| {
255                    ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
256                })?;
257            }
258
259            // Create client authentication verifier
260            let client_auth = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new(
261                client_root_store,
262            ));
263
264            // Create new config builder with client auth
265            let new_builder = ServerConfig::builder().with_safe_defaults();
266            let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
267
268            // Build a new config with certificates and client auth
269            config = new_cert_builder
270                .with_single_cert(cert_chain, private_key)
271                .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
272
273            debug!("mTLS enabled with client certificate verification required");
274        }
275
276        Ok(config)
277    }
278}
279
280/// TLS Client Configuration
281pub struct TlsClientConfig {
282    server_name: String,
283    insecure: bool,
284    /// Optional certificate hash to pin (SHA-256 fingerprint)
285    pinned_cert_hash: Option<Vec<u8>>,
286    /// Optional client certificate path for mTLS
287    client_cert_path: Option<String>,
288    /// Optional client key path for mTLS
289    client_key_path: Option<String>,
290    /// Allowed TLS protocol versions (None = use rustls defaults)
291    tls_versions: Option<Vec<TlsVersion>>,
292    /// Allowed cipher suites (None = use rustls defaults)
293    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
294}
295
296impl TlsClientConfig {
297    /// Create a new TLS client configuration
298    pub fn new<S: Into<String>>(server_name: S) -> Self {
299        Self {
300            server_name: server_name.into(),
301            insecure: false,
302            pinned_cert_hash: None,
303            client_cert_path: None,
304            client_key_path: None,
305            tls_versions: None,
306            cipher_suites: None,
307        }
308    }
309
310    /// Set allowed TLS protocol versions
311    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
312        self.tls_versions = Some(versions);
313        self
314    }
315
316    /// Set allowed cipher suites
317    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
318        self.cipher_suites = Some(cipher_suites);
319        self
320    }
321
322    /// Configure client authentication for mTLS
323    pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
324        self.client_cert_path = Some(cert_path.into());
325        self.client_key_path = Some(key_path.into());
326        self
327    }
328
329    /// Allow insecure connections (skip certificate verification)
330    ///
331    /// # WARNING: Security Risk
332    /// This mode disables certificate verification entirely and should ONLY be used for:
333    /// - Development and testing
334    /// - Debugging environments
335    /// - Internal networks with certificate pinning enabled
336    ///
337    /// **NEVER** use this in production without explicit certificate pinning via `with_pinned_cert_hash()`.
338    ///
339    /// For maximum security, only use this with the "dangerous_configuration" feature enabled,
340    /// which is a strong indicator this is for testing/development only.
341    pub fn insecure(mut self) -> Self {
342        warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
343        self.insecure = true;
344        self
345    }
346
347    /// Pin a certificate by its SHA-256 hash/fingerprint
348    ///
349    /// This provides additional security by only accepting connections
350    /// from servers with the exact certificate matching this hash.
351    /// Can be combined with insecure mode for development environments where
352    /// you want to skip standard CA verification but still verify a specific cert.
353    pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
354        if hash.len() != 32 {
355            warn!(
356                "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
357                hash.len()
358            );
359        }
360        self.pinned_cert_hash = Some(hash);
361        self
362    }
363
364    /// Calculate SHA-256 hash for a certificate to use with pinning
365    pub fn calculate_cert_hash(cert: &Certificate) -> Vec<u8> {
366        use sha2::{Digest, Sha256};
367        let mut hasher = Sha256::new();
368        hasher.update(&cert.0);
369        hasher.finalize().to_vec()
370    }
371
372    /// Helper method to load a private key from PKCS8 format
373    fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKey> {
374        // Try to load PKCS8 keys
375        // Seek to beginning of file first
376        reader
377            .seek(std::io::SeekFrom::Start(0))
378            .map_err(ProtocolError::Io)?;
379
380        // We need to use pkcs8_private_keys on the BufReader directly since it implements BufRead
381        let keys = pkcs8_private_keys(reader)
382            .map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
383
384        if !keys.is_empty() {
385            return Ok(PrivateKey(keys[0].clone()));
386        }
387
388        // Note: Add support for other key formats like RSA or EC if needed
389
390        Err(ProtocolError::TlsError(
391            "No supported private key format found".into(),
392        ))
393    }
394
395    /// Load the TLS client configuration
396    pub fn load_client_config(&self) -> Result<ClientConfig> {
397        self.log_tls_version_info();
398
399        if self.insecure {
400            self.build_insecure_client_config()
401        } else {
402            self.build_secure_client_config()
403        }
404    }
405
406    /// Log TLS version configuration
407    fn log_tls_version_info(&self) {
408        if let Some(versions) = &self.tls_versions {
409            let mut has_tls13 = false;
410            let mut has_tls12 = false;
411            for v in versions {
412                match v {
413                    TlsVersion::TLS12 => has_tls12 = true,
414                    TlsVersion::TLS13 => has_tls13 = true,
415                    TlsVersion::All => {
416                        has_tls13 = true;
417                        has_tls12 = true;
418                    }
419                }
420            }
421            debug!(
422                "TLS client versions requested: TLS1.2={}, TLS1.3={}",
423                has_tls12, has_tls13
424            );
425        }
426    }
427
428    /// Build secure client config with system root CAs
429    fn build_secure_client_config(&self) -> Result<ClientConfig> {
430        let root_store = self.load_system_root_certificates()?;
431        let builder = ClientConfig::builder()
432            .with_safe_defaults()
433            .with_root_certificates(root_store);
434
435        // Apply client auth directly
436        if let (Some(client_cert_path), Some(client_key_path)) =
437            (&self.client_cert_path, &self.client_key_path)
438        {
439            let (cert_chain, key) =
440                self.load_client_credentials(client_cert_path, client_key_path)?;
441            builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
442                ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
443            })
444        } else {
445            Ok(builder.with_no_client_auth())
446        }
447    }
448
449    /// Build insecure client config with custom verifier
450    fn build_insecure_client_config(&self) -> Result<ClientConfig> {
451        let builder = ClientConfig::builder().with_safe_defaults();
452        let verifier = self.create_custom_verifier();
453        let custom_builder = builder.with_custom_certificate_verifier(verifier);
454
455        // Apply client auth directly
456        if let (Some(client_cert_path), Some(client_key_path)) =
457            (&self.client_cert_path, &self.client_key_path)
458        {
459            let (cert_chain, key) =
460                self.load_client_credentials(client_cert_path, client_key_path)?;
461            custom_builder
462                .with_client_auth_cert(cert_chain, key)
463                .map_err(|e| {
464                    ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
465                })
466        } else {
467            Ok(custom_builder.with_no_client_auth())
468        }
469    }
470
471    /// Load system root certificates
472    fn load_system_root_certificates(&self) -> Result<RootCertStore> {
473        let mut root_store = RootCertStore::empty();
474        let native_certs = rustls_native_certs::load_native_certs()
475            .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
476
477        for cert in native_certs {
478            root_store.add(&Certificate(cert.0)).map_err(|e| {
479                ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
480            })?;
481        }
482
483        Ok(root_store)
484    }
485
486    /// Create custom certificate verifier (pinning or accept-any)
487    fn create_custom_verifier(&self) -> Arc<dyn rustls::client::ServerCertVerifier> {
488        if let Some(hash) = &self.pinned_cert_hash {
489            Arc::new(CertificateFingerprint {
490                fingerprint: hash.clone(),
491            })
492        } else {
493            Arc::new(AcceptAnyServerCert)
494        }
495    }
496
497    /// Load client certificate and private key
498    fn load_client_credentials(
499        &self,
500        cert_path: &str,
501        key_path: &str,
502    ) -> Result<(Vec<Certificate>, PrivateKey)> {
503        // Load certificate
504        let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
505        let mut cert_reader = BufReader::new(cert_file);
506        let certs = rustls_pemfile::certs(&mut cert_reader)
507            .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
508
509        if certs.is_empty() {
510            return Err(ProtocolError::TlsError(
511                "No client certificates found".into(),
512            ));
513        }
514
515        // Load private key
516        let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
517        let mut key_reader = BufReader::new(key_file);
518        let key = Self::load_private_key(&mut key_reader)?;
519
520        let cert_chain = certs.into_iter().map(Certificate).collect();
521        Ok((cert_chain, key))
522    }
523
524    /// Get the server name as a rustls::ServerName
525    pub fn server_name(&self) -> Result<ServerName> {
526        ServerName::try_from(self.server_name.as_str())
527            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
528    }
529}
530
531/// Start a TLS server on the given address
532#[instrument(skip(config))]
533pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
534    let tls_config = config.load_server_config()?;
535    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
536    let listener = TcpListener::bind(addr).await?;
537
538    info!(address=%addr, "TLS server listening");
539
540    loop {
541        let (stream, peer) = listener.accept().await?;
542        let acceptor = acceptor.clone();
543
544        tokio::spawn(async move {
545            match acceptor.accept(stream).await {
546                Ok(tls_stream) => {
547                    if let Err(e) = handle_tls_connection(tls_stream, peer).await {
548                        error!(%peer, error=%e, "Connection error");
549                    }
550                }
551                Err(e) => {
552                    error!(%peer, error=%e, "TLS handshake failed");
553                }
554            }
555        });
556    }
557}
558
559/// Handle a TLS connection
560#[instrument(skip(tls_stream), fields(peer=%peer))]
561async fn handle_tls_connection(
562    tls_stream: ServerTlsStream<TcpStream>,
563    peer: SocketAddr,
564) -> Result<()> {
565    let mut framed = Framed::new(tls_stream, PacketCodec);
566
567    info!("TLS connection established");
568
569    while let Some(packet) = framed.next().await {
570        match packet {
571            Ok(pkt) => {
572                debug!(bytes = pkt.payload.len(), "Received data");
573                on_packet(pkt, &mut framed).await?;
574            }
575            Err(e) => {
576                error!(error=%e, "Protocol error");
577                break;
578            }
579        }
580    }
581
582    info!("TLS connection closed");
583    Ok(())
584}
585
586/// Handle incoming TLS packets
587#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
588async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
589where
590    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
591{
592    // Echo the packet back (sample implementation)
593    let response = Packet {
594        version: pkt.version,
595        payload: pkt.payload,
596    };
597
598    framed.send(response).await?;
599    Ok(())
600}
601
602/// Connect to a TLS server
603#[instrument(skip(config), fields(address=%addr))]
604pub async fn connect(
605    addr: &str,
606    config: TlsClientConfig,
607) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
608    let tls_config = Arc::new(config.load_client_config()?);
609    let connector = TlsConnector::from(tls_config);
610
611    let stream = TcpStream::connect(addr).await?;
612    let domain = config.server_name()?;
613
614    let tls_stream = connector
615        .connect(domain, stream)
616        .await
617        .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
618
619    let framed = Framed::new(tls_stream, PacketCodec);
620    Ok(framed)
621}