use futures::{SinkExt, StreamExt};
use rustls::pki_types::ServerName;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_util::codec::Framed;
use tracing::{debug, instrument};
use crate::core::codec::PacketCodec;
use crate::core::packet::Packet;
use crate::error::{ProtocolError, Result};
use crate::protocol::message::Message;
use crate::transport::session_cache::SessionCache;
use crate::transport::tls::TlsClientConfig;
/// Client side of a TLS connection over TCP that exchanges bincode-encoded
/// [`Message`]s carried in [`Packet`]s framed by [`PacketCodec`].
pub struct TlsClient {
    // TLS-over-TCP stream wrapped in the packet framer; all send/receive goes through it.
    framed: Framed<TlsStream<TcpStream>, PacketCodec>,
    // Optional shared session cache handed in at connect time. In the visible
    // code it is only stored and exposed via a getter.
    // NOTE(review): confirm whether anything wires it into rustls resumption.
    session_cache: Option<Arc<SessionCache>>,
    // Identifier of the form "{addr}_{unix_seconds}", assigned when the
    // connection is established.
    session_id: Option<String>,
}
impl TlsClient {
    /// Opens a TLS connection to `addr` without a session cache.
    ///
    /// Equivalent to [`TlsClient::connect_with_session`] with `session_cache = None`.
    ///
    /// # Errors
    ///
    /// Propagates every error from [`TlsClient::connect_with_session`].
    #[instrument(skip(config))]
    pub async fn connect(addr: &str, config: TlsClientConfig) -> Result<Self> {
        Self::connect_with_session(addr, config, None).await
    }

    /// Connects to `addr` over TCP, performs the TLS handshake using `config`,
    /// and wraps the resulting stream in a [`PacketCodec`] framer.
    ///
    /// `session_cache` is stored on the returned client. This constructor only
    /// logs that one was supplied and does not otherwise consult it.
    ///
    /// # Errors
    ///
    /// - Any error from loading the rustls client configuration.
    /// - `ProtocolError::TlsError` if the configured server name is not a
    ///   valid TLS server name (checked before any network I/O).
    /// - Any I/O error from the TCP connect or the TLS handshake.
    pub async fn connect_with_session(
        addr: &str,
        config: TlsClientConfig,
        session_cache: Option<Arc<SessionCache>>,
    ) -> Result<Self> {
        let tls_config = config.load_client_config()?;
        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));

        // Validate the server name before opening a socket so a bad config
        // fails fast. `to_owned()` yields an owned `ServerName<'static>` (what
        // `TlsConnector::connect` requires) without leaking the string; the
        // previous `Box::leak` leaked one allocation per connection.
        let server_name = config.server_name_string();
        let domain = ServerName::try_from(server_name.as_str())
            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))?
            .to_owned();

        let stream = TcpStream::connect(addr).await?;
        let tls_stream = connector.connect(domain, stream).await?;
        let framed = Framed::new(tls_stream, PacketCodec);

        // NOTE(review): second resolution means two connections to the same
        // `addr` within one second get the same id — confirm callers accept that.
        let session_id = format!(
            "{}_{}",
            addr,
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs()
        );

        if session_cache.is_some() {
            debug!("Session resumption enabled");
        }

        Ok(Self {
            framed,
            session_cache,
            session_id: Some(session_id),
        })
    }

    /// Serializes `message` with bincode and sends it as a version-1 [`Packet`].
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or the framed sink fails to send.
    pub async fn send(&mut self, message: Message) -> Result<()> {
        let payload = bincode::serialize(&message)?;
        // Packet format version written by this client.
        let packet = Packet {
            version: 1,
            payload,
        };
        self.framed.send(packet).await?;
        Ok(())
    }

    /// Waits for the next inbound [`Packet`] and bincode-decodes its payload
    /// into a [`Message`].
    ///
    /// # Errors
    ///
    /// - `ProtocolError::Custom("Connection closed")` if the stream has ended.
    /// - Any error produced by the codec while reading a frame.
    /// - Any bincode deserialization error.
    pub async fn receive(&mut self) -> Result<Message> {
        let packet = match self.framed.next().await {
            Some(Ok(pkt)) => pkt,
            Some(Err(e)) => return Err(e),
            None => return Err(ProtocolError::Custom("Connection closed".to_string())),
        };
        let message = bincode::deserialize(&packet.payload)?;
        Ok(message)
    }

    /// Sends `message` and then returns the next message received.
    ///
    /// There is no request/response correlation: whatever message arrives next
    /// is returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`TlsClient::send`] and [`TlsClient::receive`].
    pub async fn request(&mut self, message: Message) -> Result<Message> {
        self.send(message).await?;
        self.receive().await
    }

    /// Returns the session cache supplied at connect time, if any.
    pub fn session_cache(&self) -> Option<&SessionCache> {
        self.session_cache.as_deref()
    }

    /// Returns the identifier (`"{addr}_{unix_seconds}"`) assigned at connect time.
    pub fn session_id(&self) -> Option<&str> {
        self.session_id.as_deref()
    }
}