Skip to main content

network_protocol/transport/
tls.rs

1//! # TLS Transport Layer
2//!
3//! This file is part of the Network Protocol project.
4//!
5//! It defines the TLS transport layer for secure network communication,
6//! particularly for external untrusted connections.
7//!
8//! The TLS transport layer provides a secure channel for communication
9//! using industry-standard TLS protocol, ensuring confidentiality,
10//! integrity, and authentication of the data transmitted.
11//!
12//! ## Responsibilities
13//! - Establish secure TLS connections
14//! - Handle TLS certificates and verification
15//! - Provide secure framed transport for higher protocol layers
16//! - Compatible with existing packet codec infrastructure
17
18use std::fs::File;
19use std::io::{self, BufReader, Seek, Write};
20use std::net::SocketAddr;
21use std::path::Path;
22use std::sync::Arc;
23
24use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
26use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
27use rustls_pemfile::{certs, pkcs8_private_keys};
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::client::TlsStream as ClientTlsStream;
30use tokio_rustls::server::TlsStream as ServerTlsStream;
31use tokio_rustls::{TlsAcceptor, TlsConnector};
32use tokio_util::codec::Framed;
33use tracing::{debug, error, info, instrument, warn};
34
35use crate::core::codec::PacketCodec;
36use crate::core::packet::Packet;
37use crate::error::{ProtocolError, Result};
38use futures::{SinkExt, StreamExt};
39
40// Custom certificate verifiers
41#[derive(Debug)]
42struct CertificateFingerprint {
43    fingerprint: Vec<u8>,
44}
45
46impl ServerCertVerifier for CertificateFingerprint {
47    fn verify_server_cert(
48        &self,
49        end_entity: &CertificateDer<'_>,
50        _intermediates: &[CertificateDer<'_>],
51        _server_name: &ServerName,
52        _ocsp_response: &[u8],
53        _now: UnixTime,
54    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
55        use sha2::{Digest, Sha256};
56
57        let mut hasher = Sha256::new();
58        hasher.update(end_entity);
59        let hash = hasher.finalize();
60
61        if hash.as_slice() == self.fingerprint.as_slice() {
62            Ok(ServerCertVerified::assertion())
63        } else {
64            Err(rustls::Error::General(
65                "Pinned certificate hash mismatch".into(),
66            ))
67        }
68    }
69
70    fn verify_tls12_signature(
71        &self,
72        _message: &[u8],
73        _cert: &CertificateDer<'_>,
74        _dss: &DigitallySignedStruct,
75    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
76        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
77    }
78
79    fn verify_tls13_signature(
80        &self,
81        _message: &[u8],
82        _cert: &CertificateDer<'_>,
83        _dss: &DigitallySignedStruct,
84    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
85        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
86    }
87
88    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
89        // Accept common signature schemes
90        vec![
91            rustls::SignatureScheme::RSA_PKCS1_SHA256,
92            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
93            rustls::SignatureScheme::ED25519,
94        ]
95    }
96}
97
98#[derive(Debug)]
99struct AcceptAnyServerCert;
100
101impl ServerCertVerifier for AcceptAnyServerCert {
102    fn verify_server_cert(
103        &self,
104        _end_entity: &CertificateDer<'_>,
105        _intermediates: &[CertificateDer<'_>],
106        _server_name: &ServerName,
107        _ocsp_response: &[u8],
108        _now: UnixTime,
109    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
110        Ok(ServerCertVerified::assertion())
111    }
112
113    fn verify_tls12_signature(
114        &self,
115        _message: &[u8],
116        _cert: &CertificateDer<'_>,
117        _dss: &DigitallySignedStruct,
118    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
119        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
120    }
121
122    fn verify_tls13_signature(
123        &self,
124        _message: &[u8],
125        _cert: &CertificateDer<'_>,
126        _dss: &DigitallySignedStruct,
127    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
128        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
129    }
130
131    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
132        vec![
133            rustls::SignatureScheme::RSA_PKCS1_SHA256,
134            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
135            rustls::SignatureScheme::ED25519,
136        ]
137    }
138}
139
140/// Helper function to load a private key from PKCS8 format
141fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKeyDer<'static>> {
142    // Try to load PKCS8 keys
143    // Seek to beginning of file first
144    reader
145        .seek(std::io::SeekFrom::Start(0))
146        .map_err(ProtocolError::Io)?;
147
148    // pkcs8_private_keys returns an iterator of Results
149    let keys: std::result::Result<Vec<_>, _> = pkcs8_private_keys(reader).collect();
150    let keys =
151        keys.map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
152
153    if !keys.is_empty() {
154        return Ok(PrivateKeyDer::Pkcs8(keys[0].clone_key()));
155    }
156
157    // Note: Add support for other key formats like RSA or EC if needed
158
159    Err(ProtocolError::TlsError(
160        "No supported private key format found".into(),
161    ))
162}
163
/// TLS protocol version selector used by the server and client configurations.
///
/// The type is a plain field-less enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    /// TLS 1.2
    TLS12,
    /// TLS 1.3
    TLS13,
    /// Both TLS 1.2 and 1.3
    All,
}
173
/// TLS server configuration
///
/// Holds the file locations and options from which `load_server_config`
/// builds a `rustls::ServerConfig`. Values are set through `new` /
/// `generate_self_signed` and the consuming builder-style methods.
pub struct TlsServerConfig {
    /// Path to the PEM file holding the server certificate chain
    cert_path: String,
    /// Path to the PEM file holding the server private key
    key_path: String,
    /// Optional path to client CA certificates for mTLS
    client_ca_path: Option<String>,
    /// Whether to require client certificates (mTLS)
    require_client_auth: bool,
    /// Allowed TLS protocol versions (None = use rustls defaults)
    tls_versions: Option<Vec<TlsVersion>>,
    /// Allowed cipher suites (None = use rustls defaults)
    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
    /// ALPN protocols to advertise (default: ["h2", "http/1.1"])
    alpn_protocols: Option<Vec<Vec<u8>>>,
}
189
190impl TlsServerConfig {
191    /// Create a new TLS server configuration
192    pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
193        Self {
194            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
195            key_path: key_path.as_ref().to_string_lossy().to_string(),
196            client_ca_path: None,
197            require_client_auth: false,
198            tls_versions: None,
199            cipher_suites: None,
200            alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
201        }
202    }
203
204    /// Set allowed TLS protocol versions
205    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
206        self.tls_versions = Some(versions);
207        self
208    }
209
210    /// Set allowed cipher suites
211    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
212        self.cipher_suites = Some(cipher_suites);
213        self
214    }
215
216    /// Enable mutual TLS authentication by providing a CA certificate path
217    pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
218        self.client_ca_path = Some(client_ca_path.into());
219        self.require_client_auth = true;
220        self
221    }
222
223    /// Set whether client authentication is required (true) or optional (false)
224    pub fn require_client_auth(mut self, required: bool) -> Self {
225        self.require_client_auth = required;
226        self
227    }
228
229    /// Set ALPN protocols to advertise during TLS handshake
230    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
231        self.alpn_protocols = Some(protocols);
232        self
233    }
234
235    /// Generate a self-signed certificate for development/testing purposes
236    pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
237        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
238            .map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
239
240        // Write certificate
241        let mut cert_file = File::create(&cert_path)?;
242        let pem = cert.cert.pem();
243        cert_file.write_all(pem.as_bytes())?;
244
245        // Write private key
246        let mut key_file = File::create(&key_path)?;
247        key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
248
249        Ok(Self {
250            cert_path: cert_path.as_ref().to_string_lossy().to_string(),
251            key_path: key_path.as_ref().to_string_lossy().to_string(),
252            client_ca_path: None,
253            require_client_auth: false,
254            tls_versions: None,
255            cipher_suites: None,
256            alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
257        })
258    }
259
260    /// Load the TLS configuration from files
261    pub fn load_server_config(&self) -> Result<ServerConfig> {
262        // Load certificate
263        let cert_file = File::open(&self.cert_path)
264            .map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
265        let mut cert_reader = BufReader::new(cert_file);
266        let cert_chain: std::result::Result<Vec<_>, _> = certs(&mut cert_reader).collect();
267        let cert_chain: Vec<CertificateDer<'static>> = cert_chain
268            .map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
269
270        if cert_chain.is_empty() {
271            return Err(ProtocolError::TlsError("No certificates found".into()));
272        }
273
274        // Load private key
275        let key_file = File::open(&self.key_path)
276            .map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
277        let mut key_reader = BufReader::new(key_file);
278        let private_key = load_private_key(&mut key_reader)?;
279
280        // Validate TLS versions if specified
281        // Note: In rustls 0.22, with_safe_defaults() restricts to TLS 1.2+ (best practice)
282        if let Some(versions) = &self.tls_versions {
283            let mut has_tls13 = false;
284            let mut has_tls12 = false;
285            for v in versions {
286                match v {
287                    TlsVersion::TLS12 => has_tls12 = true,
288                    TlsVersion::TLS13 => has_tls13 = true,
289                    TlsVersion::All => {
290                        has_tls13 = true;
291                        has_tls12 = true;
292                    }
293                }
294            }
295            // Document that with_safe_defaults uses best practices
296            debug!(
297                "TLS versions requested: TLS1.2={}, TLS1.3={}",
298                has_tls12, has_tls13
299            );
300        }
301
302        // Create a server configuration with safe defaults (TLS 1.2+, modern ciphersuites)
303        let config_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
304            rustls::crypto::ring::default_provider(),
305        ))
306        .with_safe_default_protocol_versions()
307        .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?;
308
309        let cert_builder = config_builder.with_no_client_auth();
310
311        // Build config with certificates
312        let mut config = cert_builder
313            .with_single_cert(cert_chain.clone(), private_key.clone_key())
314            .map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
315
316        // Configure client authentication if required (mTLS)
317        if let Some(client_ca_path) = &self.client_ca_path {
318            // Load client CA certificates
319            let client_ca_file = File::open(client_ca_path).map_err(|e| {
320                ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
321            })?;
322            let mut client_ca_reader = BufReader::new(client_ca_file);
323            let client_ca_certs: std::result::Result<Vec<_>, _> =
324                certs(&mut client_ca_reader).collect();
325            let client_ca_certs: Vec<CertificateDer<'static>> = client_ca_certs.map_err(|_| {
326                ProtocolError::TlsError("Failed to parse client CA certificate".into())
327            })?;
328
329            if client_ca_certs.is_empty() {
330                return Err(ProtocolError::TlsError(
331                    "No client CA certificates found".into(),
332                ));
333            }
334
335            // Create client cert verifier
336            let mut client_root_store = RootCertStore::empty();
337            for cert in client_ca_certs {
338                client_root_store.add(cert).map_err(|e| {
339                    ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
340                })?;
341            }
342
343            // Create client authentication verifier using WebPkiClientVerifier
344            let client_auth = rustls::server::WebPkiClientVerifier::builder(std::sync::Arc::new(
345                client_root_store,
346            ))
347            .build()
348            .map_err(|e| {
349                ProtocolError::TlsError(format!("Failed to build client verifier: {e}"))
350            })?;
351
352            // Create new config builder with client auth
353            let new_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
354                rustls::crypto::ring::default_provider(),
355            ))
356            .with_safe_default_protocol_versions()
357            .map_err(|_| {
358                ProtocolError::TlsError("Failed to configure TLS protocol versions".into())
359            })?;
360            let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
361
362            // Build a new config with certificates and client auth
363            config = new_cert_builder
364                .with_single_cert(cert_chain, private_key.clone_key())
365                .map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
366
367            debug!("mTLS enabled with client certificate verification required");
368        }
369
370        // Configure ALPN protocols if specified
371        if let Some(protocols) = &self.alpn_protocols {
372            config.alpn_protocols = protocols.clone();
373            debug!(
374                protocol_count = protocols.len(),
375                "ALPN protocols configured"
376            );
377        }
378
379        Ok(config)
380    }
381
382    /// Calculate SHA-256 hash for a certificate to use with pinning
383    pub fn calculate_cert_hash(cert: &CertificateDer<'_>) -> Vec<u8> {
384        use sha2::{Digest, Sha256};
385        let mut hasher = Sha256::new();
386        hasher.update(cert.as_ref());
387        hasher.finalize().to_vec()
388    }
389}
390
/// TLS Client Configuration
///
/// Holds the options from which `load_client_config` builds a
/// `rustls::ClientConfig`. Values are set through `new` and the consuming
/// builder-style methods.
pub struct TlsClientConfig {
    /// Name of the server to connect to, parsed into a `ServerName` on demand
    server_name: String,
    /// When true, `load_client_config` takes the insecure build path
    insecure: bool,
    /// Optional certificate hash to pin (SHA-256 fingerprint)
    pinned_cert_hash: Option<Vec<u8>>,
    /// Optional client certificate path for mTLS
    client_cert_path: Option<String>,
    /// Optional client key path for mTLS
    client_key_path: Option<String>,
    /// Allowed TLS protocol versions (None = use rustls defaults)
    tls_versions: Option<Vec<TlsVersion>>,
    /// Allowed cipher suites (None = use rustls defaults)
    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
}
406
407impl TlsClientConfig {
408    /// Create a new TLS client configuration
409    pub fn new<S: Into<String>>(server_name: S) -> Self {
410        Self {
411            server_name: server_name.into(),
412            insecure: false,
413            pinned_cert_hash: None,
414            client_cert_path: None,
415            client_key_path: None,
416            tls_versions: None,
417            cipher_suites: None,
418        }
419    }
420
421    /// Set allowed TLS protocol versions
422    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
423        self.tls_versions = Some(versions);
424        self
425    }
426
427    /// Set allowed cipher suites
428    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
429        self.cipher_suites = Some(cipher_suites);
430        self
431    }
432
433    /// Configure client authentication for mTLS
434    pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
435        self.client_cert_path = Some(cert_path.into());
436        self.client_key_path = Some(key_path.into());
437        self
438    }
439
440    /// Allow insecure connections (skip certificate verification)
441    ///
442    /// # WARNING: Security Risk
443    /// This mode disables certificate verification entirely and should ONLY be used for:
444    /// - Development and testing
445    /// - Debugging environments
446    /// - Internal networks with certificate pinning enabled
447    ///
448    /// **NEVER** use this in production without explicit certificate pinning via `with_pinned_cert_hash()`.
449    ///
450    /// For maximum security, only use this with the "dangerous_configuration" feature enabled,
451    /// which is a strong indicator this is for testing/development only.
452    pub fn insecure(mut self) -> Self {
453        warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
454        self.insecure = true;
455        self
456    }
457
458    /// Pin a certificate by its SHA-256 hash/fingerprint
459    ///
460    /// This provides additional security by only accepting connections
461    /// from servers with the exact certificate matching this hash.
462    /// Can be combined with insecure mode for development environments where
463    /// you want to skip standard CA verification but still verify a specific cert.
464    pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
465        if hash.len() != 32 {
466            warn!(
467                "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
468                hash.len()
469            );
470        }
471        self.pinned_cert_hash = Some(hash);
472        self
473    }
474
475    /// Load the TLS client configuration
476    pub fn load_client_config(&self) -> Result<ClientConfig> {
477        self.log_tls_version_info();
478
479        if self.insecure {
480            self.build_insecure_client_config()
481        } else {
482            self.build_secure_client_config()
483        }
484    }
485
486    /// Log TLS version configuration
487    fn log_tls_version_info(&self) {
488        if let Some(versions) = &self.tls_versions {
489            let mut has_tls13 = false;
490            let mut has_tls12 = false;
491            for v in versions {
492                match v {
493                    TlsVersion::TLS12 => has_tls12 = true,
494                    TlsVersion::TLS13 => has_tls13 = true,
495                    TlsVersion::All => {
496                        has_tls13 = true;
497                        has_tls12 = true;
498                    }
499                }
500            }
501            debug!(
502                "TLS client versions requested: TLS1.2={}, TLS1.3={}",
503                has_tls12, has_tls13
504            );
505        }
506    }
507
508    /// Build secure client config with system root CAs
509    fn build_secure_client_config(&self) -> Result<ClientConfig> {
510        let root_store = self.load_system_root_certificates()?;
511        let builder = ClientConfig::builder_with_provider(std::sync::Arc::new(
512            rustls::crypto::ring::default_provider(),
513        ))
514        .with_safe_default_protocol_versions()
515        .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?
516        .with_root_certificates(root_store);
517
518        // Apply client auth directly
519        if let (Some(client_cert_path), Some(client_key_path)) =
520            (&self.client_cert_path, &self.client_key_path)
521        {
522            let (cert_chain, key) =
523                self.load_client_credentials(client_cert_path, client_key_path)?;
524            builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
525                ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
526            })
527        } else {
528            Ok(builder.with_no_client_auth())
529        }
530    }
531
532    /// Build insecure client config with custom verifier
533    fn build_insecure_client_config(&self) -> Result<ClientConfig> {
534        let builder = ClientConfig::builder_with_provider(std::sync::Arc::new(
535            rustls::crypto::ring::default_provider(),
536        ))
537        .with_safe_default_protocol_versions()
538        .map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?;
539        let verifier = self.create_custom_verifier();
540        let custom_builder = builder
541            .dangerous()
542            .with_custom_certificate_verifier(verifier);
543
544        // Apply client auth directly
545        if let (Some(client_cert_path), Some(client_key_path)) =
546            (&self.client_cert_path, &self.client_key_path)
547        {
548            let (cert_chain, key) =
549                self.load_client_credentials(client_cert_path, client_key_path)?;
550            custom_builder
551                .with_client_auth_cert(cert_chain, key)
552                .map_err(|e| {
553                    ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
554                })
555        } else {
556            Ok(custom_builder.with_no_client_auth())
557        }
558    }
559
560    /// Load system root certificates
561    fn load_system_root_certificates(&self) -> Result<RootCertStore> {
562        let mut root_store = RootCertStore::empty();
563        let native_certs = rustls_native_certs::load_native_certs()
564            .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
565
566        for cert in native_certs {
567            root_store.add(cert).map_err(|e| {
568                ProtocolError::TlsError(format!("Failed to add cert to root store: {e}"))
569            })?;
570        }
571
572        Ok(root_store)
573    }
574
575    /// Create custom certificate verifier (pinning or accept-any)
576    fn create_custom_verifier(&self) -> Arc<dyn ServerCertVerifier> {
577        if let Some(hash) = &self.pinned_cert_hash {
578            Arc::new(CertificateFingerprint {
579                fingerprint: hash.clone(),
580            })
581        } else {
582            Arc::new(AcceptAnyServerCert)
583        }
584    }
585
586    /// Load client certificate and private key
587    fn load_client_credentials(
588        &self,
589        cert_path: &str,
590        key_path: &str,
591    ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
592        // Load certificate
593        let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
594        let mut cert_reader = BufReader::new(cert_file);
595        let certs_result: std::result::Result<Vec<_>, _> =
596            rustls_pemfile::certs(&mut cert_reader).collect();
597        let certs: Vec<CertificateDer<'static>> = certs_result
598            .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
599
600        if certs.is_empty() {
601            return Err(ProtocolError::TlsError(
602                "No client certificates found".into(),
603            ));
604        }
605
606        // Load private key
607        let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
608        let mut key_reader = BufReader::new(key_file);
609        let key = load_private_key(&mut key_reader)?;
610
611        Ok((certs, key))
612    }
613
614    /// Get the server name as a rustls::ServerName
615    pub fn server_name(&self) -> Result<ServerName<'_>> {
616        ServerName::try_from(self.server_name.as_str())
617            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
618    }
619
620    /// Get the server name as an owned String
621    pub fn server_name_string(&self) -> String {
622        self.server_name.clone()
623    }
624}
625
626/// Start a TLS server on the given address
627#[instrument(skip(config))]
628pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
629    let tls_config = config.load_server_config()?;
630    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
631    let listener = TcpListener::bind(addr).await?;
632
633    info!(address=%addr, "TLS server listening");
634
635    loop {
636        let (stream, peer) = listener.accept().await?;
637        let acceptor = acceptor.clone();
638
639        tokio::spawn(async move {
640            match acceptor.accept(stream).await {
641                Ok(tls_stream) => {
642                    if let Err(e) = handle_tls_connection(tls_stream, peer).await {
643                        error!(%peer, error=%e, "Connection error");
644                    }
645                }
646                Err(e) => {
647                    error!(%peer, error=%e, "TLS handshake failed");
648                }
649            }
650        });
651    }
652}
653
654/// Handle a TLS connection
655#[instrument(skip(tls_stream), fields(peer=%peer))]
656async fn handle_tls_connection(
657    tls_stream: ServerTlsStream<TcpStream>,
658    peer: SocketAddr,
659) -> Result<()> {
660    let mut framed = Framed::new(tls_stream, PacketCodec);
661
662    info!("TLS connection established");
663
664    while let Some(packet) = framed.next().await {
665        match packet {
666            Ok(pkt) => {
667                debug!(bytes = pkt.payload.len(), "Received data");
668                on_packet(pkt, &mut framed).await?;
669            }
670            Err(e) => {
671                error!(error=%e, "Protocol error");
672                break;
673            }
674        }
675    }
676
677    info!("TLS connection closed");
678    Ok(())
679}
680
681/// Handle incoming TLS packets
682#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
683async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
684where
685    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
686{
687    // Echo the packet back (sample implementation)
688    let response = Packet {
689        version: pkt.version,
690        payload: pkt.payload,
691    };
692
693    framed.send(response).await?;
694    Ok(())
695}
696
697/// Connect to a TLS server
698pub async fn connect(
699    addr: &str,
700    config: TlsClientConfig,
701) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
702    let tls_config = Arc::new(config.load_client_config()?);
703    let connector = TlsConnector::from(tls_config);
704
705    let stream = TcpStream::connect(addr).await?;
706
707    // Create ServerName from owned string to ensure 'static lifetime
708    // Note: Box::leak() is used here to satisfy tokio_rustls' 'static requirement
709    let server_name_str = config.server_name_string();
710    let domain_static: &'static str = Box::leak(server_name_str.into_boxed_str());
711    let domain = ServerName::try_from(domain_static)
712        .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))?;
713
714    let tls_stream = connector
715        .connect(domain, stream)
716        .await
717        .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
718
719    let framed = Framed::new(tls_stream, PacketCodec);
720    Ok(framed)
721}