network_protocol/service/
tls_client.rs1use futures::{SinkExt, StreamExt};
2use rustls::pki_types::ServerName;
3use std::sync::Arc;
4use tokio::net::TcpStream;
5use tokio_rustls::client::TlsStream;
6use tokio_util::codec::Framed;
7use tracing::{debug, instrument};
8
9use crate::core::codec::PacketCodec;
10use crate::core::packet::Packet;
11use crate::error::{ProtocolError, Result};
12use crate::protocol::message::Message;
13use crate::transport::session_cache::SessionCache;
14use crate::transport::tls::TlsClientConfig;
15
16pub struct TlsClient {
22 framed: Framed<TlsStream<TcpStream>, PacketCodec>,
23 session_cache: Option<Arc<SessionCache>>,
25 session_id: Option<String>,
27}
28
29impl TlsClient {
30 #[instrument(skip(config))]
32 pub async fn connect(addr: &str, config: TlsClientConfig) -> Result<Self> {
33 Self::connect_with_session(addr, config, None).await
34 }
35
36 pub async fn connect_with_session(
53 addr: &str,
54 config: TlsClientConfig,
55 session_cache: Option<Arc<SessionCache>>,
56 ) -> Result<Self> {
57 let tls_config = config.load_client_config()?;
58 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
59
60 let stream = TcpStream::connect(addr).await?;
61
62 let server_name_str = config.server_name_string();
65 let domain_static: &'static str = Box::leak(server_name_str.into_boxed_str());
66 let domain = ServerName::try_from(domain_static)
67 .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))?;
68
69 let tls_stream = connector.connect(domain, stream).await?;
70 let framed = Framed::new(tls_stream, PacketCodec);
71
72 let session_id = format!(
73 "{}_{}",
74 addr,
75 std::time::SystemTime::now()
76 .duration_since(std::time::UNIX_EPOCH)
77 .unwrap_or_default()
78 .as_secs()
79 );
80
81 if let Some(ref _cache) = session_cache {
83 debug!("Session resumption enabled");
84 }
85
86 Ok(Self {
87 framed,
88 session_cache,
89 session_id: Some(session_id),
90 })
91 }
92
93 pub async fn send(&mut self, message: Message) -> Result<()> {
95 let bytes = bincode::serialize(&message)?;
96 let packet = Packet {
97 version: 1,
98 payload: bytes,
99 };
100
101 self.framed.send(packet).await?;
102 Ok(())
103 }
104
105 pub async fn receive(&mut self) -> Result<Message> {
107 let packet = match self.framed.next().await {
108 Some(Ok(pkt)) => pkt,
109 Some(Err(e)) => return Err(e),
110 None => {
111 return Err(crate::error::ProtocolError::Custom(
112 "Connection closed".to_string(),
113 ))
114 }
115 };
116
117 let message = bincode::deserialize(&packet.payload)?;
118 Ok(message)
119 }
120
121 pub async fn request(&mut self, message: Message) -> Result<Message> {
123 self.send(message).await?;
124 self.receive().await
125 }
126
127 pub fn session_cache(&self) -> Option<&SessionCache> {
129 self.session_cache.as_deref()
130 }
131
132 pub fn session_id(&self) -> Option<&str> {
134 self.session_id.as_deref()
135 }
136}