use crate::types::{
Event, EventFilter, EventTopic, SubscriptionErrorResponse, SubscriptionResponse,
SubscriptionStatus, UnsubscribeResponse, WebSocketRequest,
};
use crate::{SubscriptionRequest, UnsubscribeRequest};
use futures::future::{BoxFuture, FutureExt};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use std::collections::HashMap;
use std::future::Future;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_tungstenite::tungstenite::Bytes;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, WebSocketStream};
use tracing::{debug, error, info, warn};
/// Errors returned by [`WebSocketClient`] operations.
///
/// Every variant carries a human-readable description, usually the `to_string()` of the
/// underlying error. `Clone` is derived so a stored error can be returned by value.
#[derive(Debug, Clone, thiserror::Error)]
pub enum WebSocketError {
/// Opening the WebSocket failed, or the reconnection attempt limit was reached.
#[error("Failed to connect to the server: {0}")]
ConnectionFailed(String),
/// An outgoing message could not be handed to the writer task's channel (channel closed).
#[error("Failed to send a message: {0}")]
SendFailed(String),
/// A response could not be parsed.
/// NOTE(review): not constructed in the code visible here — confirm it is still used.
#[error("Failed to parse a response: {0}")]
ParseError(String),
/// A subscribe request failed.
/// NOTE(review): not constructed in the code visible here — confirm it is still used.
#[error("Failed to subscribe: {0}")]
SubscriptionFailed(String),
/// An unsubscribe request failed (e.g. the subscription id is unknown locally).
#[error("Failed to unsubscribe: {0}")]
UnsubscriptionFailed(String),
/// Reading from the socket failed.
/// NOTE(review): not constructed in the code visible here — confirm it is still used.
#[error("Failed to read from the WebSocket: {0}")]
ReadFailed(String),
/// Anything else: request serialization failures, cancellation during reconnection, ...
#[error("Other error: {0}")]
Other(String),
}
/// Synchronous handler invoked (on the message-processing task) for each event of its topic.
pub type EventCallback = Box<dyn Fn(Event) + Send + Sync + 'static>;
/// Listener invoked with the new connection state (`true` = connected) whenever it changes.
pub type ConnectionCallback = Box<dyn Fn(bool) + Send + Sync + 'static>;
/// Asynchronous event handler; the future it returns is spawned with `tokio::spawn` per event.
pub type AsyncEventCallback = Box<dyn Fn(Event) -> BoxFuture<'static, ()> + Send + Sync + 'static>;
/// Client-side record of one subscription.
///
/// Stored in the client's subscription map under a temporary `pending-<topic>-<uuid>` request
/// id until the server confirms it, then re-keyed under the server's subscription id.
struct SubscriptionHandler {
/// Topic that was requested.
topic: EventTopic,
/// Filter sent with the subscribe request (re-sent verbatim on resubscription).
filter: EventFilter,
/// `true` while waiting for the server's subscription response.
pending: bool,
}
/// Any JSON text frame the server may send.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and keeps the first
/// one that deserializes successfully. The order below is therefore significant: a payload that
/// fits several shapes is treated as the earliest matching variant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum WebSocketMessage {
/// A published event for a subscribed topic.
Event(Event),
/// Server reply to a subscribe request.
SubscriptionResponse(SubscriptionResponse),
/// Server reply to an unsubscribe request.
UnsubscribeResponse(UnsubscribeResponse),
/// Server-side error reply.
ErrorResponse(SubscriptionErrorResponse),
}
/// Delay policy between automatic reconnection attempts; evaluated by `next_delay(attempt)`
/// with a 0-based attempt counter.
#[derive(Clone, Debug)]
pub enum BackoffStrategy {
/// The same delay before every attempt.
Constant(Duration),
/// `initial + step * attempt` (no upper bound).
Linear { initial: Duration, step: Duration },
/// `initial * factor^attempt`, multiplied by a random factor of roughly
/// `[1 - jitter, 1 + jitter)` and then capped at `max_delay`.
Exponential {
initial: Duration,
factor: f32,
max_delay: Duration,
jitter: f32, },
}
impl BackoffStrategy {
    /// Exponential backoff starting at 1 s, doubling each attempt, capped at 5 s, with
    /// ±10 % jitter.
    pub fn default_exponential() -> Self {
        Self::Exponential {
            initial: Duration::from_secs(1),
            factor: 2.0,
            max_delay: Duration::from_secs(5),
            jitter: 0.1,
        }
    }

    /// Returns the delay to wait before reconnection attempt `attempt` (0-based).
    ///
    /// Never panics and never yields a "negative" delay:
    /// * `Linear` saturates at `Duration::MAX` instead of panicking on overflow, and attempt
    ///   counts beyond `u32::MAX` saturate instead of being truncated.
    /// * `Exponential` clamps the exponent to `i32::MAX` (instead of wrapping negative) and
    ///   `jitter` to `[0.0, 1.0]`, so the random multiplier stays in `[1 - jitter, 1 + jitter)`.
    ///   The cap `max_delay` is applied after jitter.
    pub fn next_delay(&self, attempt: usize) -> Duration {
        match self {
            Self::Constant(duration) => *duration,
            Self::Linear { initial, step } => {
                let steps = u32::try_from(attempt).unwrap_or(u32::MAX);
                initial.saturating_add(step.saturating_mul(steps))
            }
            Self::Exponential {
                initial,
                factor,
                max_delay,
                jitter,
            } => {
                let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
                // A jitter above 1.0 would allow a negative multiplier, which the final
                // float-to-int cast would silently turn into a zero delay.
                let jitter = (*jitter).clamp(0.0_f32, 1.0);
                let base_ms = initial.as_millis() as f32 * factor.powi(exponent);
                let jitter_factor = 1.0 - jitter + rand::random::<f32>() * jitter * 2.0;
                let capped_ms = (base_ms * jitter_factor).min(max_delay.as_millis() as f32);
                // Float-to-int `as` casts saturate, so this conversion cannot panic.
                Duration::from_millis(capped_ms as u64)
            }
        }
    }
}
/// Tokio-based WebSocket client with topic subscriptions, sync/async event callbacks,
/// optional automatic reconnection and optional keep-alive pings.
///
/// Shared state lives behind `Arc`s so the background tasks spawned by `connect` (message
/// processor, writer, keep-alive) can hold their own handles.
pub struct WebSocketClient {
/// Channel into the writer task of the connection opened by `connect`.
/// NOTE(review): the message processor switches to a new sender after an automatic reconnect
/// but never writes it back here, so client-side sends after a reconnect likely go to a closed
/// channel — confirm.
sender: mpsc::Sender<Message>,
/// Subscriptions keyed by their `pending-…` request id until confirmed, then by the
/// server-assigned subscription id.
subscriptions: Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
/// Synchronous event handlers, per topic.
event_callbacks: Arc<RwLock<HashMap<EventTopic, Vec<EventCallback>>>>,
/// Listeners notified when the connected state changes.
connection_callbacks: Arc<RwLock<Vec<ConnectionCallback>>>,
/// Last known connection state.
connected: Arc<Mutex<bool>>,
/// URL used for the initial connection and every reconnection.
server_url: Arc<String>,
/// Whether the message processor reconnects after the connection drops.
auto_reconnect: Arc<Mutex<bool>>,
/// Delay policy between reconnection attempts.
backoff_strategy: Arc<Mutex<BackoffStrategy>>,
/// Reconnection attempt limit; `0` means unlimited.
max_reconnect_attempts: Arc<Mutex<usize>>,
/// Cleared by `close()`; the message-processor and keep-alive loops stop when it is `false`.
running: Arc<Mutex<bool>>,
/// Signals the current message-processor task to stop; `None` until a `connect` succeeds.
cancel_tx: Option<mpsc::Sender<()>>,
/// Asynchronous event handlers, per topic; their futures are spawned per event.
async_event_callbacks: Arc<RwLock<HashMap<EventTopic, Vec<AsyncEventCallback>>>>,
/// Serializes `connect` calls.
state_change_lock: Arc<Mutex<()>>,
/// Ping interval; when set, keep-alive is restarted on every (re)connection.
keep_alive_interval: Arc<Mutex<Option<Duration>>>,
/// Handle of the running keep-alive task (aborted when replaced or on `close()`).
keep_alive_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
impl WebSocketClient {
pub fn new(url: &str) -> Self {
let subscriptions = Arc::new(RwLock::new(HashMap::new()));
let event_callbacks = Arc::new(RwLock::new(HashMap::new()));
let connection_callbacks = Arc::new(RwLock::new(Vec::new()));
let connected = Arc::new(Mutex::new(false));
let server_url = Arc::new(url.to_string());
let auto_reconnect = Arc::new(Mutex::new(false));
let backoff_strategy = Arc::new(Mutex::new(BackoffStrategy::default_exponential()));
let max_reconnect_attempts = Arc::new(Mutex::new(0));
let running = Arc::new(Mutex::new(true));
let async_event_callbacks = Arc::new(RwLock::new(HashMap::new()));
let state_change_lock = Arc::new(Mutex::new(()));
let keep_alive_interval = Arc::new(Mutex::new(None));
let keep_alive_handle = Arc::new(Mutex::new(None));
let (sender, _) = mpsc::channel::<Message>(100);
WebSocketClient {
sender,
subscriptions,
event_callbacks,
connection_callbacks,
connected,
server_url,
auto_reconnect,
backoff_strategy,
max_reconnect_attempts,
running,
cancel_tx: None,
async_event_callbacks,
state_change_lock,
keep_alive_interval,
keep_alive_handle,
}
}
pub async fn connect(&mut self) -> Result<(), WebSocketError> {
let _connection_lock = self.state_change_lock.lock().await;
if *self.connected.lock().await {
return Ok(());
}
if let Some(cancel_tx) = self.cancel_tx.take() {
let _ = cancel_tx.send(()).await;
tokio::time::sleep(Duration::from_millis(50)).await;
}
let (cancel_tx, cancel_rx) = mpsc::channel::<()>(1);
self.cancel_tx = Some(cancel_tx);
let connect_result = Self::establish_new_connection(
&self.server_url,
&self.connected,
&self.connection_callbacks,
&self.keep_alive_handle,
&self.keep_alive_interval,
&self.running,
)
.await;
if let Err(e) = &connect_result {
self.cancel_tx = None;
return Err(e.clone());
}
let (read, sender) = connect_result.unwrap();
self.sender = sender.clone();
let _task_handle = tokio::spawn(Self::message_processor(
read,
sender,
self.subscriptions.clone(),
self.event_callbacks.clone(),
self.async_event_callbacks.clone(),
self.connection_callbacks.clone(),
self.connected.clone(),
self.keep_alive_handle.clone(),
self.keep_alive_interval.clone(),
self.running.clone(),
self.server_url.clone(),
self.auto_reconnect.clone(),
self.backoff_strategy.clone(),
self.max_reconnect_attempts.clone(),
cancel_rx,
));
Ok(())
}
/// Convenience constructor: builds a client for `url` and connects it immediately.
///
/// # Errors
/// Whatever `connect` returns.
pub async fn connect_static(url: &str) -> Result<Self, WebSocketError> {
    let mut client = Self::new(url);
    let outcome = client.connect().await;
    outcome.map(|()| client)
}
/// Opens a fresh socket and spawns its writer task. Restarts keep-alive pings when an interval
/// is configured, then marks the client connected, which notifies listeners on change.
///
/// Returns the read half of the stream plus the sender feeding the new writer task.
async fn establish_new_connection(
    server_url: &Arc<String>,
    connected: &Arc<Mutex<bool>>,
    connection_callbacks: &Arc<RwLock<Vec<ConnectionCallback>>>,
    keep_alive_handle: &Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
    keep_alive_interval: &Arc<Mutex<Option<Duration>>>,
    running: &Arc<Mutex<bool>>,
) -> Result<
    (
        futures::stream::SplitStream<
            WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
        >,
        mpsc::Sender<Message>,
    ),
    WebSocketError,
> {
    let (sink, stream) = Self::establish_connection(server_url).await?.split();
    let outbound = Self::spawn_writer_task(sink);
    if let Some(period) = *keep_alive_interval.lock().await {
        Self::restart_keep_alive(keep_alive_handle, &outbound, running, connected, period).await?;
    }
    Self::update_connection_status(connected, true, connection_callbacks).await;
    Ok((stream, outbound))
}
#[allow(clippy::too_many_arguments)]
async fn attempt_reconnection(
server_url: &Arc<String>,
connected: &Arc<Mutex<bool>>,
connection_callbacks: &Arc<RwLock<Vec<ConnectionCallback>>>,
subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
keep_alive_handle: &Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
keep_alive_interval: &Arc<Mutex<Option<Duration>>>,
running: &Arc<Mutex<bool>>,
backoff_strategy: &Arc<Mutex<BackoffStrategy>>,
max_reconnect_attempts: &Arc<Mutex<usize>>,
reconnect_attempts: &mut usize,
cancel_rx: &mut mpsc::Receiver<()>,
) -> Result<
(
futures::stream::SplitStream<
WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
>,
mpsc::Sender<Message>,
),
WebSocketError,
> {
loop {
let max_attempts = *max_reconnect_attempts.lock().await;
if max_attempts > 0 && *reconnect_attempts >= max_attempts {
error!("Maximum reconnection attempts reached ({})", max_attempts);
return Err(WebSocketError::ConnectionFailed(
"Maximum reconnection attempts reached".to_string(),
));
}
let delay = backoff_strategy
.lock()
.await
.next_delay(*reconnect_attempts);
tokio::time::sleep(delay).await;
*reconnect_attempts += 1;
info!(
"Attempting to reconnect (attempt {}, delay: {:?})...",
reconnect_attempts, delay
);
match Self::establish_new_connection(
server_url,
connected,
connection_callbacks,
keep_alive_handle,
keep_alive_interval,
running,
)
.await
{
Ok((read, sender)) => {
if let Err(e) = Self::resubscribe_all(&sender, subscriptions).await {
error!("Failed to re-subscribe: {}", e);
}
return Ok((read, sender));
}
Err(e) => {
error!("Reconnection failed: {}", e);
}
}
if let Ok(Some(())) =
tokio::time::timeout(tokio::time::Duration::from_millis(10), cancel_rx.recv()).await
{
return Err(WebSocketError::Other(
"Cancelled during reconnection".to_string(),
));
}
}
}
async fn notify_connection_status(
connected: bool,
callbacks: &Arc<RwLock<Vec<ConnectionCallback>>>,
) {
let callbacks_guard = callbacks.read().await;
for callback in callbacks_guard.iter() {
callback(connected);
}
}
/// Resolves a pending subscription from the server's subscribe reply.
///
/// The pending entry (keyed by the reply's `request_id`) is always removed. It is re-inserted
/// under the server's `subscription_id`, with `pending: false`, only when the status is
/// `Subscribed`. Replies without a request id, or for unknown ids, are logged and ignored.
async fn handle_subscription_response(
    response: SubscriptionResponse,
    subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
) {
    info!("Received subscription response: {:?}", response);
    let mut table = subscriptions.write().await;
    let request_id = match &response.request_id {
        Some(id) => id,
        None => {
            warn!("Received subscription response for unknown subscription");
            return;
        }
    };
    match table.remove(request_id) {
        None => warn!(
            "Received subscription response for unknown subscription: {:?}",
            response
        ),
        Some(pending_entry) => {
            let confirmed = matches!(response.status, SubscriptionStatus::Subscribed);
            if confirmed {
                let promoted = SubscriptionHandler {
                    topic: pending_entry.topic,
                    filter: pending_entry.filter,
                    pending: false,
                };
                table.insert(response.subscription_id.clone(), promoted);
            }
        }
    }
}
/// Re-sends a subscribe request for every known subscription on a fresh connection.
///
/// Under a single write lock, every entry (confirmed or pending) is replaced by a new pending
/// entry with a fresh `pending-<topic>-<uuid>` request id. The requests are sent afterwards.
/// If serialization or a send fails part-way, the remaining entries are still recorded as
/// pending, so a later reconnection retries them instead of losing them.
///
/// # Errors
/// `Other` if a request cannot be serialized, `SendFailed` if the writer channel is closed.
async fn resubscribe_all(
    sender: &mpsc::Sender<Message>,
    subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
) -> Result<(), WebSocketError> {
    let requests = {
        let mut subs = subscriptions.write().await;
        let previous = std::mem::take(&mut *subs);
        let mut requests = Vec::with_capacity(previous.len());
        for handler in previous.into_values() {
            let pending_id = format!("pending-{}-{}", handler.topic, uuid::Uuid::new_v4());
            requests.push(WebSocketRequest::Subscribe(SubscriptionRequest {
                topic: handler.topic.clone(),
                filter: handler.filter.clone(),
                request_id: Some(pending_id.clone()),
            }));
            subs.insert(
                pending_id,
                SubscriptionHandler {
                    topic: handler.topic,
                    filter: handler.filter,
                    pending: true,
                },
            );
        }
        requests
    };
    for request in requests {
        let message = serde_json::to_string(&request).map_err(|e| {
            WebSocketError::Other(format!("Failed to serialize request: {}", e))
        })?;
        sender
            .send(Message::Text(message.into()))
            .await
            .map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
    }
    Ok(())
}
async fn process_message(
message: Message,
sender: &mpsc::Sender<Message>,
subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<EventCallback>>>>,
async_event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<AsyncEventCallback>>>>,
) -> bool {
match message {
Message::Text(text) => {
match serde_json::from_str::<WebSocketMessage>(&text) {
Ok(WebSocketMessage::Event(event)) => {
Self::handle_event(event, event_callbacks, async_event_callbacks).await;
}
Ok(WebSocketMessage::SubscriptionResponse(response)) => {
Self::handle_subscription_response(response, subscriptions).await;
}
Ok(WebSocketMessage::UnsubscribeResponse(response)) => {
Self::handle_unsubscribe_response(response, subscriptions).await;
}
Ok(WebSocketMessage::ErrorResponse(error)) => {
warn!("Subscription error: {}", error.error);
}
Err(e) => {
error!("Failed to parse WebSocket message: {}", e);
debug!("Message content: {}", text);
}
}
false
}
Message::Binary(_) => {
debug!("Received binary message");
false
}
Message::Ping(data) => {
if let Err(e) = sender.send(Message::Pong(data)).await {
warn!("Failed to send pong: {}", e);
}
false
}
Message::Pong(_) => false, Message::Frame(_) => false, Message::Close(_) => true, }
}
async fn handle_event(
event: Event,
event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<EventCallback>>>>,
async_event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<AsyncEventCallback>>>>,
) {
let topic = event.topic();
{
let callbacks = event_callbacks.read().await;
if let Some(handlers) = callbacks.get(&topic) {
for handler in handlers {
match panic::catch_unwind(AssertUnwindSafe(|| {
handler(event.clone());
})) {
Ok(_) => {}
Err(e) => {
let panic_msg = if let Some(s) = e.downcast_ref::<&str>() {
s
} else if let Some(s) = e.downcast_ref::<String>() {
s.as_str()
} else {
"Unknown panic"
};
error!("Event handler panicked: {}", panic_msg);
}
}
}
}
}
{
let async_callbacks = async_event_callbacks.read().await;
if let Some(handlers) = async_callbacks.get(&topic) {
for handler in handlers {
let event_clone = event.clone();
let future = match panic::catch_unwind(AssertUnwindSafe(|| {
handler(event_clone.clone())
})) {
Ok(future) => future,
Err(e) => {
let panic_msg = if let Some(s) = e.downcast_ref::<&str>() {
s
} else if let Some(s) = e.downcast_ref::<String>() {
s.as_str()
} else {
"Unknown panic"
};
error!("Async event handler panicked during setup: {}", panic_msg);
continue;
}
};
tokio::spawn(async move {
match panic::catch_unwind(AssertUnwindSafe(|| async {
future.await;
})) {
Ok(f) => {
f.await;
}
Err(e) => {
let panic_msg = if let Some(s) = e.downcast_ref::<&str>() {
s
} else if let Some(s) = e.downcast_ref::<String>() {
s.as_str()
} else {
"Unknown panic"
};
error!(
"Async event handler panicked during execution: {}",
panic_msg
);
}
}
});
}
}
}
}
/// Forgets a subscription once the server reports it as `Unsubscribed`; any other status is
/// ignored.
async fn handle_unsubscribe_response(
    response: UnsubscribeResponse,
    subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
) {
    if !matches!(response.status, SubscriptionStatus::Unsubscribed) {
        return;
    }
    let mut table = subscriptions.write().await;
    table.remove(&response.subscription_id);
}
/// Performs the WebSocket handshake with `server_url`, discarding the HTTP handshake response.
///
/// # Errors
/// `ConnectionFailed` carrying the tungstenite error text.
async fn establish_connection(
    server_url: &str,
) -> Result<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>, WebSocketError> {
    let (stream, _handshake_response) = connect_async(server_url)
        .await
        .map_err(|err| WebSocketError::ConnectionFailed(err.to_string()))?;
    Ok(stream)
}
fn spawn_writer_task(
write: futures::stream::SplitSink<
WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
Message,
>,
) -> mpsc::Sender<Message> {
let (sender, mut new_receiver) = mpsc::channel::<Message>(1000);
tokio::spawn(async move {
let mut writer = write;
while let Some(msg) = new_receiver.recv().await {
if let Err(e) = writer.send(msg).await {
error!("Failed to send message: {}", e);
break;
}
}
});
sender
}
#[allow(clippy::too_many_arguments)]
async fn message_processor(
initial_stream: futures::stream::SplitStream<
WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
>,
initial_sender: mpsc::Sender<Message>,
subscriptions: Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
event_callbacks: Arc<RwLock<HashMap<EventTopic, Vec<EventCallback>>>>,
async_event_callbacks: Arc<RwLock<HashMap<EventTopic, Vec<AsyncEventCallback>>>>,
connection_callbacks: Arc<RwLock<Vec<ConnectionCallback>>>,
connected: Arc<Mutex<bool>>,
keep_alive_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
keep_alive_interval: Arc<Mutex<Option<Duration>>>,
running: Arc<Mutex<bool>>,
server_url: Arc<String>,
auto_reconnect: Arc<Mutex<bool>>,
backoff_strategy: Arc<Mutex<BackoffStrategy>>,
max_reconnect_attempts: Arc<Mutex<usize>>,
mut cancel_rx: mpsc::Receiver<()>,
) {
let mut reconnect_attempts: usize = 0;
let mut read = initial_stream;
let mut sender = initial_sender;
while *running.lock().await {
let disconnected = Self::process_messages(
&mut read,
&sender,
&subscriptions,
&event_callbacks,
&async_event_callbacks,
&connection_callbacks,
&connected,
&mut cancel_rx,
)
.await;
if disconnected {
if !*running.lock().await || !*auto_reconnect.lock().await {
return; }
match Self::attempt_reconnection(
&server_url,
&connected,
&connection_callbacks,
&subscriptions,
&keep_alive_handle,
&keep_alive_interval,
&running,
&backoff_strategy,
&max_reconnect_attempts,
&mut reconnect_attempts,
&mut cancel_rx,
)
.await
{
Ok((new_read, new_sender)) => {
read = new_read;
sender = new_sender;
reconnect_attempts = 0; }
Err(_) => return, }
}
}
}
#[allow(clippy::too_many_arguments)]
async fn process_messages(
read: &mut futures::stream::SplitStream<
WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
>,
sender: &mpsc::Sender<Message>,
subscriptions: &Arc<RwLock<HashMap<String, SubscriptionHandler>>>,
event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<EventCallback>>>>,
async_event_callbacks: &Arc<RwLock<HashMap<EventTopic, Vec<AsyncEventCallback>>>>,
connection_callbacks: &Arc<RwLock<Vec<ConnectionCallback>>>,
connected: &Arc<Mutex<bool>>,
cancel_rx: &mut mpsc::Receiver<()>,
) -> bool {
loop {
tokio::select! {
_ = cancel_rx.recv() => {
return false; }
message = read.next() => {
match message {
Some(Ok(msg)) => {
if Self::process_message(msg, sender, subscriptions, event_callbacks, async_event_callbacks).await {
Self::update_connection_status(connected, false, connection_callbacks).await;
return true; }
}
Some(Err(e)) => {
error!("WebSocket read error: {}", e);
Self::update_connection_status(connected, false, connection_callbacks).await;
return true; }
None => {
debug!("WebSocket stream ended");
Self::update_connection_status(connected, false, connection_callbacks).await;
return true; }
}
}
}
}
}
/// Sends a subscribe request for `topic` with `filter`.
///
/// Returns `Ok(())` without sending anything if an identical, already-confirmed subscription
/// exists. Otherwise a pending entry keyed by a fresh `pending-<topic>-<uuid>` request id is
/// recorded and the request is sent. If sending fails, the pending entry is removed again.
/// Without that cleanup, a failed subscribe (e.g. before `connect`) would leave a stale entry
/// that makes `on_event` believe the topic is already subscribed.
///
/// # Errors
/// `Other` if the request cannot be serialized, `SendFailed` if the writer channel is closed.
pub async fn subscribe(
    &self,
    topic: EventTopic,
    filter: EventFilter,
) -> Result<(), WebSocketError> {
    let pending_id = {
        let mut subs = self.subscriptions.write().await;
        let already_active = subs
            .values()
            .any(|handler| !handler.pending && handler.topic == topic && handler.filter == filter);
        if already_active {
            return Ok(());
        }
        let pending_id = format!("pending-{}-{}", topic, uuid::Uuid::new_v4());
        subs.insert(
            pending_id.clone(),
            SubscriptionHandler {
                topic: topic.clone(),
                filter: filter.clone(),
                pending: true,
            },
        );
        pending_id
    };
    let request = WebSocketRequest::Subscribe(SubscriptionRequest {
        topic,
        filter,
        request_id: Some(pending_id.clone()),
    });
    let outcome = match serde_json::to_string(&request) {
        Ok(message) => self
            .sender
            .send(Message::Text(message.into()))
            .await
            .map_err(|e| WebSocketError::SendFailed(e.to_string())),
        Err(e) => Err(WebSocketError::Other(format!(
            "Failed to serialize request: {}",
            e
        ))),
    };
    if outcome.is_err() {
        // Nothing reached the server, so the placeholder must not linger.
        self.subscriptions.write().await.remove(&pending_id);
    }
    outcome
}
pub async fn on_event<F>(
&self,
topic: EventTopic,
filter: Option<EventFilter>,
callback: F,
) -> Result<(), WebSocketError>
where
F: Fn(Event) + Send + Sync + 'static,
{
let subs = self.subscriptions.read().await;
let has_topic_subscription = subs.values().any(|s| s.topic == topic);
drop(subs);
if !has_topic_subscription {
self.subscribe(topic.clone(), filter.unwrap_or_default())
.await?;
}
let mut callbacks = self.event_callbacks.write().await;
callbacks
.entry(topic)
.or_insert_with(Vec::new)
.push(Box::new(callback));
Ok(())
}
pub async fn on_event_async<F, Fut>(
&self,
topic: EventTopic,
filter: Option<EventFilter>,
callback: F,
) -> Result<(), WebSocketError>
where
F: Fn(Event) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let subs = self.subscriptions.read().await;
let has_topic_subscription = subs.values().any(|s| s.topic == topic);
drop(subs);
if !has_topic_subscription {
self.subscribe(topic.clone(), filter.unwrap_or_default())
.await?;
}
let boxed_callback =
move |event: Event| -> BoxFuture<'static, ()> { callback(event).boxed() };
let mut callbacks = self.async_event_callbacks.write().await;
callbacks
.entry(topic)
.or_insert_with(Vec::new)
.push(Box::new(boxed_callback));
Ok(())
}
pub async fn on_connection_change<F>(&self, callback: F)
where
F: Fn(bool) + Send + Sync + 'static,
{
self.connection_callbacks
.write()
.await
.push(Box::new(callback));
}
/// Sends an unsubscribe request for `subscription_id` and forgets it locally once the request
/// has been queued.
///
/// # Errors
/// `UnsubscriptionFailed` if the id is unknown, `Other` if serialization fails, `SendFailed`
/// if the writer channel is closed. On any error the local entry is kept.
pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), WebSocketError> {
    let topic = self
        .subscriptions
        .read()
        .await
        .get(subscription_id)
        .map(|entry| entry.topic.clone())
        .ok_or_else(|| {
            WebSocketError::UnsubscriptionFailed("Subscription not found".to_string())
        })?;
    let payload = WebSocketRequest::Unsubscribe(UnsubscribeRequest {
        topic,
        subscription_id: subscription_id.to_string(),
    });
    let encoded =
        serde_json::to_string(&payload).map_err(|err| WebSocketError::Other(err.to_string()))?;
    if let Err(err) = self.sender.send(Message::Text(encoded.into())).await {
        return Err(WebSocketError::SendFailed(err.to_string()));
    }
    let mut table = self.subscriptions.write().await;
    table.remove(subscription_id);
    Ok(())
}
/// Unsubscribes every subscription (pending or confirmed) for `topic`.
///
/// Every id is attempted even if some fail.
///
/// # Errors
/// The error from the last failing `unsubscribe` call, if any.
pub async fn unsubscribe_topic(&self, topic: &EventTopic) -> Result<(), WebSocketError> {
    let matching_ids: Vec<String> = self
        .subscriptions
        .read()
        .await
        .iter()
        .filter_map(|(id, handler)| (handler.topic == *topic).then(|| id.clone()))
        .collect();
    let mut last_error = None;
    for id in &matching_ids {
        if let Err(err) = self.unsubscribe(id).await {
            last_error = Some(err);
        }
    }
    match last_error {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Drops all sync and async handlers registered for `topic`. The subscription itself is left
/// untouched.
pub async fn remove_event_listeners(&self, topic: &EventTopic) {
    let mut sync_handlers = self.event_callbacks.write().await;
    let mut async_handlers = self.async_event_callbacks.write().await;
    sync_handlers.remove(topic);
    async_handlers.remove(topic);
}
/// Configures automatic reconnection with a constant delay of `interval` between attempts.
///
/// `max_attempts == 0` means unlimited. Shorthand for `set_reconnect_options` with
/// `BackoffStrategy::Constant(interval)`.
pub async fn set_auto_reconnect(
    &self,
    enabled: bool,
    interval: std::time::Duration,
    max_attempts: usize,
) {
    self.set_reconnect_options(enabled, BackoffStrategy::Constant(interval), max_attempts)
        .await;
}
/// Last known connection state.
pub async fn is_connected(&self) -> bool {
    let state = self.connected.lock().await;
    *state
}
/// Shuts the client down.
///
/// Disables auto-reconnect, clears `running`, signals the message processor and aborts
/// keep-alive. Then it queues a close frame and marks the client disconnected. Failures to
/// deliver the cancel signal or the close frame are ignored deliberately (best effort).
pub async fn close(&self) -> Result<(), WebSocketError> {
    // Turn off reconnection first so the processor does not try to revive the link.
    *self.auto_reconnect.lock().await = false;
    *self.running.lock().await = false;
    if let Some(cancel) = self.cancel_tx.as_ref() {
        let _ = cancel.send(()).await;
    }
    let keep_alive_task = self.keep_alive_handle.lock().await.take();
    if let Some(task) = keep_alive_task {
        task.abort();
    }
    let _ = self.sender.send(Message::Close(None)).await;
    Self::update_connection_status(&self.connected, false, &self.connection_callbacks).await;
    Ok(())
}
/// Configures automatic reconnection: on/off, delay policy, and attempt limit
/// (`0` = unlimited).
pub async fn set_reconnect_options(
    &self,
    enabled: bool,
    strategy: BackoffStrategy,
    max_attempts: usize,
) {
    {
        let mut flag = self.auto_reconnect.lock().await;
        *flag = enabled;
    }
    {
        let mut policy = self.backoff_strategy.lock().await;
        *policy = strategy;
    }
    let mut limit = self.max_reconnect_attempts.lock().await;
    *limit = max_attempts;
}
/// Stores `new_state` and notifies connection listeners only if the state actually changed.
///
/// The state lock is released before listeners run. Returns whether the state changed.
async fn update_connection_status(
    connected: &Arc<Mutex<bool>>,
    new_state: bool,
    connection_callbacks: &Arc<RwLock<Vec<ConnectionCallback>>>,
) -> bool {
    let previous = std::mem::replace(&mut *connected.lock().await, new_state);
    let changed = previous != new_state;
    if changed {
        Self::notify_connection_status(new_state, connection_callbacks).await;
    }
    changed
}
/// Enables periodic ping frames every `interval`.
///
/// The interval is remembered, so keep-alive is restarted on every future (re)connection.
/// The keep-alive task is (re)started right away on the current sender.
pub async fn enable_keep_alive(&self, interval: Duration) -> Result<(), WebSocketError> {
    {
        let mut configured = self.keep_alive_interval.lock().await;
        *configured = Some(interval);
    }
    Self::restart_keep_alive(
        &self.keep_alive_handle,
        &self.sender,
        &self.running,
        &self.connected,
        interval,
    )
    .await
}
async fn restart_keep_alive(
keep_alive_handle: &Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
sender: &mpsc::Sender<Message>,
running: &Arc<Mutex<bool>>,
connected: &Arc<Mutex<bool>>,
interval: Duration,
) -> Result<(), WebSocketError> {
if let Some(handle) = keep_alive_handle.lock().await.take() {
handle.abort();
}
let sender = sender.clone();
let running = running.clone();
let connected = connected.clone();
let handle = tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
while *running.lock().await {
interval_timer.tick().await;
if *connected.lock().await {
if let Err(e) = sender.send(Message::Ping(Bytes::from_static(&[]))).await {
error!("Failed to send ping: {}", e);
}
}
}
});
*keep_alive_handle.lock().await = Some(handle);
Ok(())
}
}
impl Drop for WebSocketClient {
fn drop(&mut self) {
if let Some(running) = Arc::get_mut(&mut self.running) {
if let Ok(mut guard) = running.try_lock() {
*guard = false;
}
}
if let Some(cancel_tx) = self.cancel_tx.take() {
let _ = cancel_tx.try_send(());
}
if let Some(keep_alive_handle) = Arc::get_mut(&mut self.keep_alive_handle) {
if let Ok(mut guard) = keep_alive_handle.try_lock() {
if let Some(handle) = guard.take() {
handle.abort();
}
}
}
}
}