Skip to main content

network_protocol/service/
tls_client.rs

1use futures::{SinkExt, StreamExt};
2use std::sync::Arc;
3use tokio::net::TcpStream;
4use tokio_rustls::client::TlsStream;
5use tokio_util::codec::Framed;
6use tracing::{debug, instrument};
7
8use crate::core::codec::PacketCodec;
9use crate::core::packet::Packet;
10use crate::error::Result;
11use crate::protocol::message::Message;
12use crate::transport::session_cache::SessionCache;
13use crate::transport::tls::TlsClientConfig;
14
15/// TLS secure client for connecting to TLS-enabled servers
16///
17/// Supports optional session resumption for improved reconnection performance.
18/// With session resumption, reconnections can skip the full TLS handshake,
19/// reducing latency by ~50-70%.
20pub struct TlsClient {
21    framed: Framed<TlsStream<TcpStream>, PacketCodec>,
22    /// Optional session cache for resumption support
23    session_cache: Option<Arc<SessionCache>>,
24    /// Session identifier (used with session cache)
25    session_id: Option<String>,
26}
27
28impl TlsClient {
29    /// Connect to a TLS server
30    #[instrument(skip(config))]
31    pub async fn connect(addr: &str, config: TlsClientConfig) -> Result<Self> {
32        Self::connect_with_session(addr, config, None).await
33    }
34
35    /// Connect to a TLS server with session resumption support
36    ///
37    /// # Arguments
38    /// * `addr` - Server address to connect to
39    /// * `config` - TLS configuration
40    /// * `session_cache` - Optional session cache for resumption
41    ///
42    /// # Example
43    /// ```ignore
44    /// let cache = SessionCache::new(100, Duration::from_secs(3600));
45    /// let client = TlsClient::connect_with_session(
46    ///     "127.0.0.1:8443",
47    ///     config,
48    ///     Some(Arc::new(cache))
49    /// ).await?;
50    /// ```
51    #[instrument(skip(config, session_cache))]
52    pub async fn connect_with_session(
53        addr: &str,
54        config: TlsClientConfig,
55        session_cache: Option<Arc<SessionCache>>,
56    ) -> Result<Self> {
57        let tls_config = config.load_client_config()?;
58        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
59
60        let stream = TcpStream::connect(addr).await?;
61        let domain = config.server_name()?;
62
63        let tls_stream = connector.connect(domain, stream).await?;
64        let framed = Framed::new(tls_stream, PacketCodec);
65
66        let session_id = format!(
67            "{}_{}",
68            addr,
69            std::time::SystemTime::now()
70                .duration_since(std::time::UNIX_EPOCH)
71                .unwrap_or_default()
72                .as_secs()
73        );
74
75        // Store session after successful connection (for future resumptions)
76        if let Some(ref _cache) = session_cache {
77            debug!("Session resumption enabled");
78        }
79
80        Ok(Self {
81            framed,
82            session_cache,
83            session_id: Some(session_id),
84        })
85    }
86
87    /// Send a message to the TLS server
88    pub async fn send(&mut self, message: Message) -> Result<()> {
89        let bytes = bincode::serialize(&message)?;
90        let packet = Packet {
91            version: 1,
92            payload: bytes,
93        };
94
95        self.framed.send(packet).await?;
96        Ok(())
97    }
98
99    /// Receive a message from the TLS server
100    pub async fn receive(&mut self) -> Result<Message> {
101        let packet = match self.framed.next().await {
102            Some(Ok(pkt)) => pkt,
103            Some(Err(e)) => return Err(e),
104            None => {
105                return Err(crate::error::ProtocolError::Custom(
106                    "Connection closed".to_string(),
107                ))
108            }
109        };
110
111        let message = bincode::deserialize(&packet.payload)?;
112        Ok(message)
113    }
114
115    /// Send a message and wait for a response
116    pub async fn request(&mut self, message: Message) -> Result<Message> {
117        self.send(message).await?;
118        self.receive().await
119    }
120
121    /// Get the session cache if configured
122    pub fn session_cache(&self) -> Option<&SessionCache> {
123        self.session_cache.as_deref()
124    }
125
126    /// Get the session identifier
127    pub fn session_id(&self) -> Option<&str> {
128        self.session_id.as_deref()
129    }
130}