use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use log::{error, info, warn};
use rustls::RootCertStore;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use std::collections::HashMap;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{Connector, connect_async_tls_with_config};
use url::Url;
/// Application-level payload exchanged with the WebSocket server.
///
/// Control frames (ping/pong/close) are handled internally by the client and are never
/// surfaced through this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
}
/// Tunable behaviour of a `WebSocketClient`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WSClientConfig {
    /// Capacity of the bounded queues between the caller and the socket I/O tasks.
    pub channel_capacity: usize,
    /// Upper bound on the duration of a single connection attempt.
    pub connection_timeout: Duration,
    /// Whether a failed or timed-out connection attempt is retried.
    pub auto_reconnect: bool,
    /// Number of retries allowed after the initial attempt when `auto_reconnect` is set.
    pub max_reconnect_attempts: u32,
    /// Pause before each retry.
    pub reconnect_delay: Duration,
}

impl Default for WSClientConfig {
    /// 100-message queues, a 30 s connect timeout, and auto-reconnect disabled.
    ///
    /// If auto-reconnect is enabled, the retry policy is up to 5 retries, 2 s apart.
    fn default() -> Self {
        Self {
            channel_capacity: 100,
            connection_timeout: Duration::from_secs(30),
            auto_reconnect: false,
            max_reconnect_attempts: 5,
            reconnect_delay: Duration::from_secs(2),
        }
    }
}
/// Fluent builder for a `WebSocketClient`, seeded with `WSClientConfig::default()`.
#[derive(Debug, Clone)]
pub struct WebSocketClientBuilder {
    config: WSClientConfig,
}

impl WebSocketClientBuilder {
    /// Creates a builder holding the default configuration.
    pub fn new() -> Self {
        Self {
            config: WSClientConfig::default(),
        }
    }

    /// Sets the capacity of the outbound and inbound message queues created on connect.
    #[must_use]
    pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
        self.config.channel_capacity = capacity;
        self
    }

    /// Sets the time limit applied to each connection attempt.
    #[must_use]
    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        self.config.connection_timeout = timeout;
        self
    }

    /// Enables or disables retrying failed connection attempts.
    #[must_use]
    pub fn with_auto_reconnect(mut self, auto_reconnect: bool) -> Self {
        self.config.auto_reconnect = auto_reconnect;
        self
    }

    /// Sets how many retries are allowed after the initial connection attempt.
    #[must_use]
    pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.config.max_reconnect_attempts = attempts;
        self
    }

    /// Sets the pause applied before each connection retry.
    #[must_use]
    pub fn with_reconnect_delay(mut self, delay: Duration) -> Self {
        self.config.reconnect_delay = delay;
        self
    }

    /// Consumes the builder and returns a disconnected client using the built configuration.
    #[must_use]
    pub fn build(self) -> WebSocketClient {
        WebSocketClient {
            sender: None,
            receiver: None,
            ws_handle: None,
            is_connected: false,
            server_url: None,
            cert_paths: None,
            config: self.config,
            cert_cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }
}

impl Default for WebSocketClientBuilder {
    /// Equivalent to `WebSocketClientBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
pub struct WebSocketClient {
sender: Option<mpsc::Sender<MessageType>>,
receiver: Option<mpsc::Receiver<MessageType>>,
ws_handle: Option<JoinHandle<()>>,
is_connected: bool,
server_url: Option<Url>,
cert_paths: Option<(String, String, String, String, String)>,
config: WSClientConfig,
cert_cache: Arc<Mutex<HashMap<String, Arc<rustls::ClientConfig>>>>,
}
impl WebSocketClient {
pub fn new() -> Self {
Self::builder().build()
}
pub fn builder() -> WebSocketClientBuilder {
WebSocketClientBuilder::new()
}
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in file".into());
}
Ok(certs)
}
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let keys =
rustls_pemfile::pkcs8_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?;
if keys.is_empty() {
return Err("No private key found in file".into());
}
Ok(PrivateKeyDer::Pkcs8(keys.into_iter().next().unwrap()))
}
async fn create_tls_config(
&self,
cache_key: &str,
client_cert_path: &Path,
client_key_path: &Path,
ca_cert_path: &Path,
) -> Result<Arc<rustls::ClientConfig>, Box<dyn std::error::Error>> {
{
let cache = self.cert_cache.lock().await;
if let Some(config) = cache.get(cache_key) {
info!("Using cached TLS configuration");
return Ok(config.clone());
}
}
let client_certs = Self::load_certs(client_cert_path)?;
let client_key = Self::load_private_key(client_key_path)?;
let ca_certs = Self::load_certs(ca_cert_path)?;
let mut root_store = RootCertStore::empty();
for cert in ca_certs {
root_store.add(cert)?;
}
let client_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_client_auth_cert(client_certs, client_key)?;
let config = Arc::new(client_config);
{
let mut cache = self.cert_cache.lock().await;
cache.insert(cache_key.to_string(), config.clone());
}
Ok(config)
}
pub async fn connect(
&mut self,
server_url: &str,
cert_dir: &str,
client_cert_file: &str,
client_key_file: &str,
ca_cert_file: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let server_url = Url::parse(server_url)?;
self.server_url = Some(server_url.clone());
self.cert_paths = Some((
server_url.to_string(),
cert_dir.to_string(),
client_cert_file.to_string(),
client_key_file.to_string(),
ca_cert_file.to_string(),
));
let mut current_attempt = 0;
loop {
let result = self
.connect_internal(
&server_url,
cert_dir,
client_cert_file,
client_key_file,
ca_cert_file,
current_attempt,
)
.await;
match result {
Err(e) => {
let err_str = e.to_string();
if err_str.starts_with("__RETRY_CONNECTION_") {
if let Ok(next_attempt) = err_str
.trim_start_matches("__RETRY_CONNECTION_")
.parse::<u32>()
{
current_attempt = next_attempt;
tokio::time::sleep(self.config.reconnect_delay).await;
continue;
}
}
return Err(e);
}
Ok(_) => return Ok(()),
}
}
}
async fn connect_internal(
&mut self,
server_url: &Url,
cert_dir: &str,
client_cert_file: &str,
client_key_file: &str,
ca_cert_file: &str,
attempt: u32,
) -> Result<(), Box<dyn std::error::Error>> {
let cert_dir = Path::new(cert_dir);
let client_cert = cert_dir.join(client_cert_file);
let client_key = cert_dir.join(client_key_file);
let ca_cert = cert_dir.join(ca_cert_file);
info!("Client certificate: {:?}", client_cert);
info!("Client private key: {:?}", client_key);
info!("CA certificate: {:?}", ca_cert);
let cache_key = format!(
"{}:{}:{}:{}",
server_url,
client_cert.display(),
client_key.display(),
ca_cert.display()
);
info!("Loading certificates and keys...");
let tls_config = match self
.create_tls_config(&cache_key, &client_cert, &client_key, &ca_cert)
.await
{
Ok(config) => config,
Err(e) => {
error!("Failed to create TLS configuration: {}", e);
return Err(e);
}
};
let connector = Connector::Rustls(tls_config);
info!("Connecting to WebSocket server: {}", server_url);
let connection_attempt =
connect_async_tls_with_config(server_url.clone(), None, false, Some(connector));
let ws_stream = match timeout(self.config.connection_timeout, connection_attempt).await {
Ok(result) => {
match result {
Ok((stream, _)) => stream,
Err(e) => {
error!("Connection error: {}", e);
if self.config.auto_reconnect
&& attempt < self.config.max_reconnect_attempts
{
warn!(
"Reconnection attempt {}/{} in {}s",
attempt + 1,
self.config.max_reconnect_attempts,
self.config.reconnect_delay.as_secs()
);
tokio::time::sleep(self.config.reconnect_delay).await;
return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
}
return Err(e.into());
}
}
}
Err(_) => {
let err = format!(
"Connection timeout after {:?}",
self.config.connection_timeout
);
error!("{}", err);
if self.config.auto_reconnect && attempt < self.config.max_reconnect_attempts {
warn!(
"Reconnection attempt {}/{} in {}s",
attempt + 1,
self.config.max_reconnect_attempts,
self.config.reconnect_delay.as_secs()
);
tokio::time::sleep(self.config.reconnect_delay).await;
return Err(format!("__RETRY_CONNECTION_{}", attempt + 1).into());
}
return Err(err.into());
}
};
info!("Connected to WebSocket server");
let (tx_sender, mut rx_sender) = mpsc::channel::<MessageType>(self.config.channel_capacity);
let (tx_receiver, rx_receiver) = mpsc::channel::<MessageType>(self.config.channel_capacity);
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
let send_task = tokio::spawn(async move {
while let Some(message) = rx_sender.recv().await {
let ws_message = match message {
MessageType::Text(text) => Message::Text(text),
MessageType::Binary(data) => Message::Binary(data),
};
match ws_sender.send(ws_message).await {
Ok(_) => info!("Message sent"),
Err(e) => {
error!("Error sending message: {}", e);
break;
}
}
}
let _ = ws_sender.close().await;
});
let receive_task = tokio::spawn(async move {
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(msg) => {
let message = match msg {
Message::Text(text) => {
info!("Received text message: {} bytes", text.len());
MessageType::Text(text)
}
Message::Binary(data) => {
info!("Received binary message: {} bytes", data.len());
MessageType::Binary(data)
}
Message::Ping(_) | Message::Pong(_) => {
continue;
}
Message::Close(_) => {
info!("Received close frame");
break;
}
_ => continue,
};
if let Err(e) = tx_receiver.send(message).await {
error!("Error forwarding to receiver channel: {}", e);
break;
}
}
Err(e) => {
error!("Error receiving message: {}", e);
break;
}
}
}
});
let handle = tokio::spawn(async move {
tokio::select! {
_ = send_task => info!("Send task completed"),
_ = receive_task => info!("Receive task completed"),
}
});
self.sender = Some(tx_sender);
self.receiver = Some(rx_receiver);
self.ws_handle = Some(handle);
self.is_connected = true;
Ok(())
}
pub async fn reconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if let Some((url, cert_dir, client_cert, client_key, ca_cert)) = self.cert_paths.clone() {
if self.is_connected {
self.close().await;
}
self.connect(&url, &cert_dir, &client_cert, &client_key, &ca_cert)
.await
} else {
Err("No previous connection parameters available for reconnection".into())
}
}
pub async fn send_message(
&self,
message: MessageType,
) -> Result<(), Box<dyn std::error::Error>> {
if let Some(sender) = &self.sender {
sender.send(message).await?;
Ok(())
} else {
Err("Not connected to WebSocket server".into())
}
}
pub async fn send_text(&self, text: String) -> Result<(), Box<dyn std::error::Error>> {
self.send_message(MessageType::Text(text)).await
}
pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
self.send_message(MessageType::Binary(data)).await
}
pub async fn receive_message(&mut self) -> Option<MessageType> {
if let Some(receiver) = &mut self.receiver {
receiver.recv().await
} else {
None
}
}
pub async fn receive_message_timeout(
&mut self,
timeout_duration: Duration,
) -> Result<Option<MessageType>, tokio::time::error::Elapsed> {
if let Some(receiver) = &mut self.receiver {
timeout(timeout_duration, receiver.recv()).await
} else {
Ok(None)
}
}
pub fn is_connected(&self) -> bool {
self.is_connected
}
pub async fn close(&mut self) {
self.sender = None;
if let Some(handle) = self.ws_handle.take() {
let _ = handle.await;
}
self.receiver = None;
self.is_connected = false;
info!("WebSocket connection closed");
}
pub async fn ping(&self) -> Result<(), Box<dyn std::error::Error>> {
if let Some(sender) = &self.sender {
sender.send(MessageType::Binary(Vec::new())).await?;
Ok(())
} else {
Err("Not connected to WebSocket server".into())
}
}
pub async fn clear_cert_cache(&self) {
let mut cache = self.cert_cache.lock().await;
cache.clear();
info!("Certificate cache cleared");
}
pub async fn check_connection(&self) -> bool {
if !self.is_connected {
return false;
}
match self.ping().await {
Ok(_) => true,
Err(_) => false,
}
}
pub fn get_config(&self) -> &WSClientConfig {
&self.config
}
}
impl Drop for WebSocketClient {
    /// Releases the client's resources explicitly, in a fixed order.
    ///
    /// The order is: the outbound sender, then the inbound receiver, then the background
    /// task handle. Dropping a tokio `JoinHandle` detaches the task rather than aborting it.
    fn drop(&mut self) {
        drop(self.sender.take());
        drop(self.receiver.take());
        drop(self.ws_handle.take());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both payload variants keep the data they were built with.
    #[test]
    fn test_message_type() {
        let text = MessageType::Text("hello".to_string());
        let binary = MessageType::Binary(vec![1, 2, 3]);
        match text {
            MessageType::Text(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected Text variant"),
        }
        match binary {
            MessageType::Binary(b) => assert_eq!(b, vec![1, 2, 3]),
            _ => panic!("Expected Binary variant"),
        }
    }

    /// Every default configuration value is pinned.
    #[test]
    fn test_client_config_default() {
        let config = WSClientConfig::default();
        assert_eq!(config.channel_capacity, 100);
        assert_eq!(config.connection_timeout, Duration::from_secs(30));
        assert!(!config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_secs(2));
    }

    /// Every builder setter ends up in the built client's configuration.
    #[test]
    fn test_client_builder() {
        let client = WebSocketClient::builder()
            .with_channel_capacity(200)
            .with_connection_timeout(Duration::from_secs(10))
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(3)
            .with_reconnect_delay(Duration::from_millis(500))
            .build();
        assert_eq!(client.config.channel_capacity, 200);
        assert_eq!(client.config.connection_timeout, Duration::from_secs(10));
        assert!(client.config.auto_reconnect);
        assert_eq!(client.config.max_reconnect_attempts, 3);
        assert_eq!(client.config.reconnect_delay, Duration::from_millis(500));
    }

    /// A freshly built client is disconnected and exposes the default configuration.
    #[test]
    fn test_new_client_is_disconnected() {
        let client = WebSocketClient::new();
        assert!(!client.is_connected());
        assert_eq!(client.get_config().channel_capacity, 100);
    }

    /// Missing PEM files are reported as errors rather than panics.
    #[test]
    fn test_missing_pem_files_are_errors() {
        let missing = Path::new("this/path/does/not/exist.pem");
        assert!(WebSocketClient::load_certs(missing).is_err());
        assert!(WebSocketClient::load_private_key(missing).is_err());
    }
}