Skip to main content

network_protocol/service/
tls_client.rs

1use futures::{SinkExt, StreamExt};
2use rustls::pki_types::ServerName;
3use std::sync::Arc;
4use tokio::net::TcpStream;
5use tokio_rustls::client::TlsStream;
6use tokio_util::codec::Framed;
7use tracing::{debug, instrument};
8
9use crate::core::codec::PacketCodec;
10use crate::core::packet::Packet;
11use crate::error::{ProtocolError, Result};
12use crate::protocol::message::Message;
13use crate::transport::session_cache::SessionCache;
14use crate::transport::tls::TlsClientConfig;
15
16/// TLS secure client for connecting to TLS-enabled servers
17///
18/// Supports optional session resumption for improved reconnection performance.
19/// With session resumption, reconnections can skip the full TLS handshake,
20/// reducing latency by ~50-70%.
21pub struct TlsClient {
22    framed: Framed<TlsStream<TcpStream>, PacketCodec>,
23    /// Optional session cache for resumption support
24    session_cache: Option<Arc<SessionCache>>,
25    /// Session identifier (used with session cache)
26    session_id: Option<String>,
27}
28
29impl TlsClient {
30    /// Connect to a TLS server
31    #[instrument(skip(config))]
32    pub async fn connect(addr: &str, config: TlsClientConfig) -> Result<Self> {
33        Self::connect_with_session(addr, config, None).await
34    }
35
36    /// Connect to a TLS server with session resumption support
37    ///
38    /// # Arguments
39    /// * `addr` - Server address to connect to
40    /// * `config` - TLS configuration
41    /// * `session_cache` - Optional session cache for resumption
42    ///
43    /// # Example
44    /// ```ignore
45    /// let cache = SessionCache::new(100, Duration::from_secs(3600));
46    /// let client = TlsClient::connect_with_session(
47    ///     "127.0.0.1:8443",
48    ///     config,
49    ///     Some(Arc::new(cache))
50    /// ).await?;
51    /// ```
52    pub async fn connect_with_session(
53        addr: &str,
54        config: TlsClientConfig,
55        session_cache: Option<Arc<SessionCache>>,
56    ) -> Result<Self> {
57        let tls_config = config.load_client_config()?;
58        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
59
60        let stream = TcpStream::connect(addr).await?;
61
62        // Create ServerName from owned string to ensure 'static lifetime
63        // Note: Box::leak() is used here to satisfy tokio_rustls' 'static requirement
64        let server_name_str = config.server_name_string();
65        let domain_static: &'static str = Box::leak(server_name_str.into_boxed_str());
66        let domain = ServerName::try_from(domain_static)
67            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))?;
68
69        let tls_stream = connector.connect(domain, stream).await?;
70        let framed = Framed::new(tls_stream, PacketCodec);
71
72        let session_id = format!(
73            "{}_{}",
74            addr,
75            std::time::SystemTime::now()
76                .duration_since(std::time::UNIX_EPOCH)
77                .unwrap_or_default()
78                .as_secs()
79        );
80
81        // Store session after successful connection (for future resumptions)
82        if let Some(ref _cache) = session_cache {
83            debug!("Session resumption enabled");
84        }
85
86        Ok(Self {
87            framed,
88            session_cache,
89            session_id: Some(session_id),
90        })
91    }
92
93    /// Send a message to the TLS server
94    pub async fn send(&mut self, message: Message) -> Result<()> {
95        let bytes = bincode::serialize(&message)?;
96        let packet = Packet {
97            version: 1,
98            payload: bytes,
99        };
100
101        self.framed.send(packet).await?;
102        Ok(())
103    }
104
105    /// Receive a message from the TLS server
106    pub async fn receive(&mut self) -> Result<Message> {
107        let packet = match self.framed.next().await {
108            Some(Ok(pkt)) => pkt,
109            Some(Err(e)) => return Err(e),
110            None => {
111                return Err(crate::error::ProtocolError::Custom(
112                    "Connection closed".to_string(),
113                ))
114            }
115        };
116
117        let message = bincode::deserialize(&packet.payload)?;
118        Ok(message)
119    }
120
121    /// Send a message and wait for a response
122    pub async fn request(&mut self, message: Message) -> Result<Message> {
123        self.send(message).await?;
124        self.receive().await
125    }
126
127    /// Get the session cache if configured
128    pub fn session_cache(&self) -> Option<&SessionCache> {
129        self.session_cache.as_deref()
130    }
131
132    /// Get the session identifier
133    pub fn session_id(&self) -> Option<&str> {
134        self.session_id.as_deref()
135    }
136}