use std::{
sync::{Arc, atomic::{AtomicBool, Ordering}},
collections::hash_map::{HashMap, Entry},
time::Duration,
mem,
};
use tokio::{
sync::{mpsc as tokio_mpsc, Mutex as AsyncMutex, Notify},
task::JoinHandle,
net::TcpStream,
time::{MissedTickBehavior, timeout},
};
use tokio_tungstenite::{
tungstenite,
MaybeTlsStream,
};
pub use tungstenite::Error as TungsteniteError;
use futures_util::{
sink::SinkExt,
stream::{StreamExt, SplitSink},
};
use parking_lot::Mutex as SyncMutex;
type WebSocketStream = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
type WebSocketSplitSink = SplitSink<WebSocketStream, tungstenite::Message>;
/// A WebSocket client connection that passes incoming messages to a [`WebSocketHandler`]
/// and replaces itself with a fresh connection when a receive error occurs, the server
/// closes it, no message arrives within the configured timeout, a reconnection is
/// requested, or (if configured) `refresh_after` elapses.
///
/// Dropping it aborts the reconnect task and asks the background feeder task to close
/// the current connection.
#[derive(Debug)]
#[must_use = "dropping WebSocketConnection closes the connection"]
pub struct WebSocketConnection<H: WebSocketHandler> {
/// Background task that performs reconnections; aborted on drop.
task_reconnect: JoinHandle<()>,
/// Write half of the current connection; swapped for the new one during reconnection.
sink: Arc<AsyncMutex<WebSocketSplitSink>>,
/// State shared with the background tasks.
inner: Arc<ConnectionInner<H>>,
/// Reconnection flag and request signal shared with the background tasks.
reconnect_state: ReconnectState,
}
/// State shared between a [`WebSocketConnection`] and its background tasks.
#[derive(Debug)]
struct ConnectionInner<H: WebSocketHandler> {
/// Full URL (config `url_prefix` followed by the URL given to `new`), reused for every (re)connection.
url: String,
/// The user's handler, shared by all tasks.
handler: Arc<SyncMutex<H>>,
/// Sends events to the feeder task; each event is tagged with the id of the connection it came from.
message_tx: tokio_mpsc::UnboundedSender<(bool, FeederMessage)>,
/// Id the next started connection will receive. It is flipped every time a connection
/// starts, so the id of the most recently started connection is the negation of this value.
next_connection_id: AtomicBool,
}
/// Events delivered to the feeder task.
enum FeederMessage {
/// An item (message or receive error) read from a connection's read half.
Message(tungstenite::Result<tungstenite::Message>),
/// A connection's read half reached the end of its stream.
ConnectionClosed,
/// The owning `WebSocketConnection` was dropped: close the sink and stop the feeder.
DropConnectionRequest,
}
impl<H: WebSocketHandler> WebSocketConnection<H> {
pub async fn new(url: &str, handler: H) -> Result<Self, TungsteniteError> {
let config = handler.websocket_config();
let handler = Arc::new(SyncMutex::new(handler));
let url = config.url_prefix.clone() + url;
let (message_tx, message_rx) = tokio_mpsc::unbounded_channel();
let reconnect_manager = ReconnectState::new();
let connection = Arc::new(ConnectionInner {
url,
handler: Arc::clone(&handler),
message_tx,
next_connection_id: AtomicBool::new(false),
});
async fn feed_handler(
connection: Arc<ConnectionInner<impl WebSocketHandler>>,
mut message_rx: tokio_mpsc::UnboundedReceiver<(bool, FeederMessage)>,
reconnect_manager: ReconnectState,
config: WebSocketConfig,
sink: Arc<AsyncMutex<WebSocketSplitSink>>,
) {
let mut messages: HashMap<WebSocketMessage, isize> = HashMap::new();
let timeout_duration = if config.message_timeout.is_zero() {
Duration::MAX
} else {
config.message_timeout
};
loop {
match timeout(timeout_duration, message_rx.recv()).await {
Ok(Some((id, FeederMessage::Message(Ok(message))))) => {
if let Some(message) = WebSocketMessage::from_message(message) {
if reconnect_manager.is_reconnecting() {
let id_sign: isize = if id {
1
} else {
-1
};
let entry = messages.entry(message.clone());
match entry {
Entry::Occupied(mut occupied) => {
if config.ignore_duplicate_during_reconnection {
log::debug!("Skipping duplicate message.");
continue;
}
*occupied.get_mut() += id_sign;
if id_sign != occupied.get().signum() {
log::debug!("Skipping duplicate message.");
continue;
}
},
Entry::Vacant(vacant) => {
vacant.insert(id_sign);
}
}
} else {
messages.clear();
}
let messages = connection.handler.lock().handle_message(message);
let mut sink_lock = sink.lock().await;
for message in messages {
if let Err(error) = sink_lock.send(message.into_message()).await {
log::error!("Failed to send message because of an error: {}", error);
};
}
if let Err(error) = sink_lock.flush().await {
log::error!("An error occurred while flushing WebSocket sink: {error:?}");
}
}
},
Ok(Some((_, FeederMessage::Message(Err(error))))) => {
log::error!("Failed to receive message because of an error: {error:?}");
if reconnect_manager.request_reconnect() {
log::info!("Reconnecting WebSocket because there was an error while receiving a message");
}
},
Err(_) => {
log::debug!("WebSocket message timeout");
if reconnect_manager.request_reconnect() {
log::info!("Reconnecting WebSocket because of timeout");
}
},
Ok(Some((id, FeederMessage::ConnectionClosed))) => {
let current_id = !connection.next_connection_id.load(Ordering::SeqCst);
if id != current_id {
continue;
}
log::debug!("WebSocket connection closed by server");
if reconnect_manager.request_reconnect() {
log::info!("Reconnecting WebSocket because it was disconnected by the server");
}
},
Ok(Some((_, FeederMessage::DropConnectionRequest))) => {
if let Err(error) = sink.lock().await.close().await {
log::debug!("Failed to close WebSocket connection: {error:?}");
}
break;
}
Ok(None) => unreachable!("message_rx should never be closed"),
}
}
connection.handler.lock().handle_close(false);
}
async fn reconnect<H: WebSocketHandler>(
interval: Duration,
cooldown: Duration,
connection: Arc<ConnectionInner<H>>,
sink: Arc<AsyncMutex<WebSocketSplitSink>>,
reconnect_manager: ReconnectState,
no_duplicate: bool,
wait: Duration,
) {
let mut cooldown = tokio::time::interval(cooldown);
cooldown.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
let timer = if interval.is_zero() {
tokio::time::sleep(Duration::MAX)
} else {
tokio::time::sleep(interval)
};
tokio::select! {
_ = reconnect_manager.inner.reconnect_notify.notified() => {},
_ = timer => {},
}
log::debug!("Reconnection requested");
cooldown.tick().await;
reconnect_manager.inner.reconnecting.store(true, Ordering::SeqCst);
reconnect_manager.inner.reconnect_notify.notify_one();
reconnect_manager.inner.reconnect_notify.notified().await;
log::debug!("Starting reconnection process ...");
if no_duplicate {
tokio::time::sleep(wait).await;
}
match WebSocketConnection::<H>::start_connection(Arc::clone(&connection)).await {
Ok(new_sink) => {
let mut old_sink = mem::replace(&mut *sink.lock().await, new_sink);
log::debug!("New connection established");
if no_duplicate {
tokio::time::sleep(wait).await;
}
if let Err(error) = old_sink.close().await {
log::debug!("An error occurred while closing old connection: {}", error);
}
connection.handler.lock().handle_close(true);
log::debug!("Old connection closed");
},
Err(error) => {
log::error!("Failed to reconnect because of an error: {}, trying again ...", error);
reconnect_manager.inner.reconnect_notify.notify_one();
},
}
if no_duplicate {
tokio::time::sleep(wait).await;
}
reconnect_manager.inner.reconnecting.store(false, Ordering::SeqCst);
log::debug!("Reconnection process complete");
}
}
let sink_inner = Self::start_connection(Arc::clone(&connection)).await?;
let sink = Arc::new(AsyncMutex::new(sink_inner));
tokio::spawn(
feed_handler(
Arc::clone(&connection),
message_rx,
reconnect_manager.clone(),
config.clone(),
Arc::clone(&sink),
)
);
let task_reconnect = tokio::spawn(reconnect(
config.refresh_after,
config.connect_cooldown,
Arc::clone(&connection),
Arc::clone(&sink),
reconnect_manager.clone(),
config.ignore_duplicate_during_reconnection,
config.reconnection_wait,
));
Ok(Self {
task_reconnect,
sink,
inner: connection,
reconnect_state: reconnect_manager,
})
}
async fn start_connection(connection: Arc<ConnectionInner<impl WebSocketHandler>>) -> Result<WebSocketSplitSink, TungsteniteError> {
let (websocket_stream, _) = tokio_tungstenite::connect_async(connection.url.clone()).await?;
let (mut sink, mut stream) = websocket_stream.split();
let messages = connection.handler.lock().handle_start();
for message in messages {
sink.send(message.into_message()).await?;
}
sink.flush().await?;
let id = connection.next_connection_id.fetch_xor(true, Ordering::SeqCst);
tokio::spawn(async move {
while let Some(message) = stream.next().await {
if connection.message_tx.send((id, FeederMessage::Message(message))).is_err() {
log::debug!("WebSocket message receiver is closed; abandon connection");
return;
}
}
drop(connection.message_tx.send((id, FeederMessage::ConnectionClosed))); log::debug!("WebSocket stream closed");
});
Ok(sink)
}
pub async fn send_message(&self, message: WebSocketMessage) -> Result<(), TungsteniteError> {
let mut sink_lock = self.sink.lock().await;
sink_lock.send(message.into_message()).await?;
sink_lock.flush().await
}
pub fn reconnect_state(&self) -> ReconnectState {
self.reconnect_state.clone()
}
}
impl<H: WebSocketHandler> Drop for WebSocketConnection<H> {
    /// Aborts the reconnect task, then asks the feeder task to close the sink and stop.
    fn drop(&mut self) {
        let Self { task_reconnect, inner, .. } = self;
        task_reconnect.abort();
        // Id of the most recently started connection (the stored value is the *next* id).
        let latest_id = !inner.next_connection_id.load(Ordering::SeqCst);
        let request = (latest_id, FeederMessage::DropConnectionRequest);
        // A send error only means the feeder has already stopped; nothing left to do then.
        let _ = inner.message_tx.send(request);
    }
}
/// Cloneable handle to the reconnection state shared with a [`WebSocketConnection`]'s
/// background tasks; obtained via `WebSocketConnection::reconnect_state`.
#[derive(Debug, Clone)]
pub struct ReconnectState {
// Shared by every clone of this handle.
inner: Arc<ReconnectMangerInner>,
}
/// Shared data behind [`ReconnectState`].
// NOTE(review): "Manger" is a typo for "Manager"; renaming requires updating every reference.
#[derive(Debug)]
struct ReconnectMangerInner {
/// Wakes the reconnect task when a reconnection is requested.
reconnect_notify: Notify,
/// `true` while the reconnect task is replacing the connection.
reconnecting: AtomicBool,
}
impl ReconnectState {
    /// Creates a state that is not reconnecting and has no pending request.
    fn new() -> Self {
        let inner = ReconnectMangerInner {
            reconnecting: AtomicBool::new(false),
            reconnect_notify: Notify::new(),
        };
        Self { inner: Arc::new(inner) }
    }

    /// Returns `true` while the reconnect task is replacing the connection.
    pub fn is_reconnecting(&self) -> bool {
        self.inner.reconnecting.load(Ordering::SeqCst)
    }

    /// Asks the reconnect task to replace the connection.
    ///
    /// Returns `false` (and does nothing) when a reconnection is already in progress,
    /// `true` when the request was issued.
    pub fn request_reconnect(&self) -> bool {
        let accepted = !self.is_reconnecting();
        if accepted {
            self.inner.reconnect_notify.notify_one();
        }
        accepted
    }
}
/// A data or control message exchanged with the server.
///
/// `Eq`/`Hash` are used to key the duplicate-detection map while reconnecting.
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum WebSocketMessage {
/// UTF-8 text message.
Text(String),
/// Binary message.
Binary(Vec<u8>),
/// Ping control message with its payload.
Ping(Vec<u8>),
/// Pong control message with its payload.
Pong(Vec<u8>),
}
impl WebSocketMessage {
fn from_message(message: tungstenite::Message) -> Option<Self> {
match message {
tungstenite::Message::Text(text) => Some(Self::Text(text)),
tungstenite::Message::Binary(data) => Some(Self::Binary(data)),
tungstenite::Message::Ping(data) => Some(Self::Ping(data)),
tungstenite::Message::Pong(data) => Some(Self::Pong(data)),
tungstenite::Message::Close(_) | tungstenite::Message::Frame(_) => None,
}
}
fn into_message(self) -> tungstenite::Message {
match self {
WebSocketMessage::Text(text) => tungstenite::Message::Text(text),
WebSocketMessage::Binary(data) => tungstenite::Message::Binary(data),
WebSocketMessage::Ping(data) => tungstenite::Message::Ping(data),
WebSocketMessage::Pong(data) => tungstenite::Message::Pong(data),
}
}
}
/// Application callbacks driven by a [`WebSocketConnection`].
///
/// The connection keeps the handler behind a mutex and calls these methods while holding
/// that lock. A slow callback therefore delays processing of the following messages.
pub trait WebSocketHandler: Send + 'static {
    /// Returns the connection settings; read once, when the connection is created.
    fn websocket_config(&self) -> WebSocketConfig {
        WebSocketConfig::default()
    }

    /// Called after every (re)connection is established. The returned messages are sent
    /// on the new connection before any of its incoming messages are forwarded.
    fn handle_start(&mut self) -> Vec<WebSocketMessage> {
        log::debug!("WebSocket connection started");
        vec![]
    }

    /// Called for each received text, binary, ping or pong message. Copies identified as
    /// duplicates during a reconnection are not passed in. The returned messages are sent
    /// back as replies.
    fn handle_message(&mut self, message: WebSocketMessage) -> Vec<WebSocketMessage>;

    /// Called when a connection is closed. `reconnect` is `true` when an old connection
    /// was closed after being replaced, and `false` when the whole `WebSocketConnection`
    /// shut down.
    fn handle_close(&mut self, reconnect: bool) {
        log::debug!("WebSocket connection closed; reconnect: {}", reconnect);
    }
}
/// Connection settings, returned by [`WebSocketHandler::websocket_config`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WebSocketConfig {
/// Minimum spacing between the starts of two reconnection attempts.
pub connect_cooldown: Duration,
/// Replace the connection after this long without a reconnection request; `Duration::ZERO` disables it.
pub refresh_after: Duration,
/// Prepended to the URL passed to `WebSocketConnection::new`.
pub url_prefix: String,
/// While reconnecting, skip every message already seen while reconnecting (instead of
/// matching copies between the old and new connection), and sleep `reconnection_wait`
/// around the connection swap.
pub ignore_duplicate_during_reconnection: bool,
/// Sleep used around the connection swap when `ignore_duplicate_during_reconnection` is set.
pub reconnection_wait: Duration,
/// Reconnect if nothing arrives for this long; `Duration::ZERO` disables the timeout.
pub message_timeout: Duration,
}
impl WebSocketConfig {
    /// Creates a configuration holding the default values; same as [`Default::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl Default for WebSocketConfig {
    /// Default settings:
    /// - 3 s connect cooldown and a 300 ms reconnection wait;
    /// - no periodic refresh and no message timeout;
    /// - empty URL prefix;
    /// - `ignore_duplicate_during_reconnection` off.
    fn default() -> Self {
        // `Duration::ZERO` turns the corresponding feature off.
        let disabled = Duration::ZERO;
        Self {
            url_prefix: String::new(),
            connect_cooldown: Duration::from_secs(3),
            refresh_after: disabled,
            message_timeout: disabled,
            reconnection_wait: Duration::from_millis(300),
            ignore_duplicate_during_reconnection: false,
        }
    }
}