network_protocol/service/
tls_client.rs1use futures::{SinkExt, StreamExt};
2use std::sync::Arc;
3use tokio::net::TcpStream;
4use tokio_rustls::client::TlsStream;
5use tokio_util::codec::Framed;
6use tracing::{debug, instrument};
7
8use crate::core::codec::PacketCodec;
9use crate::core::packet::Packet;
10use crate::error::Result;
11use crate::protocol::message::Message;
12use crate::transport::session_cache::SessionCache;
13use crate::transport::tls::TlsClientConfig;
14
15pub struct TlsClient {
21 framed: Framed<TlsStream<TcpStream>, PacketCodec>,
22 session_cache: Option<Arc<SessionCache>>,
24 session_id: Option<String>,
26}
27
28impl TlsClient {
29 #[instrument(skip(config))]
31 pub async fn connect(addr: &str, config: TlsClientConfig) -> Result<Self> {
32 Self::connect_with_session(addr, config, None).await
33 }
34
35 #[instrument(skip(config, session_cache))]
52 pub async fn connect_with_session(
53 addr: &str,
54 config: TlsClientConfig,
55 session_cache: Option<Arc<SessionCache>>,
56 ) -> Result<Self> {
57 let tls_config = config.load_client_config()?;
58 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
59
60 let stream = TcpStream::connect(addr).await?;
61 let domain = config.server_name()?;
62
63 let tls_stream = connector.connect(domain, stream).await?;
64 let framed = Framed::new(tls_stream, PacketCodec);
65
66 let session_id = format!(
67 "{}_{}",
68 addr,
69 std::time::SystemTime::now()
70 .duration_since(std::time::UNIX_EPOCH)
71 .unwrap_or_default()
72 .as_secs()
73 );
74
75 if let Some(ref _cache) = session_cache {
77 debug!("Session resumption enabled");
78 }
79
80 Ok(Self {
81 framed,
82 session_cache,
83 session_id: Some(session_id),
84 })
85 }
86
87 pub async fn send(&mut self, message: Message) -> Result<()> {
89 let bytes = bincode::serialize(&message)?;
90 let packet = Packet {
91 version: 1,
92 payload: bytes,
93 };
94
95 self.framed.send(packet).await?;
96 Ok(())
97 }
98
99 pub async fn receive(&mut self) -> Result<Message> {
101 let packet = match self.framed.next().await {
102 Some(Ok(pkt)) => pkt,
103 Some(Err(e)) => return Err(e),
104 None => {
105 return Err(crate::error::ProtocolError::Custom(
106 "Connection closed".to_string(),
107 ))
108 }
109 };
110
111 let message = bincode::deserialize(&packet.payload)?;
112 Ok(message)
113 }
114
115 pub async fn request(&mut self, message: Message) -> Result<Message> {
117 self.send(message).await?;
118 self.receive().await
119 }
120
121 pub fn session_cache(&self) -> Option<&SessionCache> {
123 self.session_cache.as_deref()
124 }
125
126 pub fn session_id(&self) -> Option<&str> {
128 self.session_id.as_deref()
129 }
130}