use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use crate::error::Error;
/// Upper bound on the number of connection attempts made by one `reconnect` call.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Delay slept before the first reconnect attempt; `backoff` doubles it for each later attempt.
const BASE_RECONNECT_DELAY: Duration = Duration::from_secs(1);
/// Ceiling that `backoff` applies to the delay between reconnect attempts.
const MAX_RECONNECT_DELAY: Duration = Duration::from_secs(30);
/// Concrete stream type that `open` obtains from `tokio_tungstenite::connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// A WebSocket connection that exchanges JSON values and re-opens its stream
/// when reading finds it closed or failed.
pub struct WebSocketConnection {
/// Endpoint given to `connect`; reused for every reconnect attempt.
url: String,
/// The live stream, or `None` after a close frame, end of stream, or read error.
stream: Option<WsStream>,
/// Optional callback invoked with the 1-based attempt number before each reconnect attempt.
on_reconnect: Option<Box<dyn FnMut(u32) + Send>>,
}
impl std::fmt::Debug for WebSocketConnection {
    /// Renders the connection as its URL plus a `connected` flag.
    ///
    /// The stream and the reconnect callback are not printed; only whether a
    /// stream is currently held is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_connected = self.stream.is_some();
        let mut repr = f.debug_struct("WebSocketConnection");
        repr.field("url", &self.url);
        repr.field("connected", &is_connected);
        repr.finish()
    }
}
impl WebSocketConnection {
pub async fn connect(url: &str) -> Result<Self, Error> {
let stream = open(url).await?;
Ok(Self {
url: url.to_string(),
stream: Some(stream),
on_reconnect: None,
})
}
pub fn on_reconnect<F>(&mut self, callback: F)
where
F: FnMut(u32) + Send + 'static,
{
self.on_reconnect = Some(Box::new(callback));
}
pub async fn next_message(&mut self) -> Result<Option<Value>, Error> {
loop {
let stream = match self.stream.as_mut() {
Some(s) => s,
None => match self.reconnect().await? {
Some(s) => s,
None => return Ok(None),
},
};
match stream.next().await {
Some(Ok(Message::Text(text))) => {
return Ok(Some(serde_json::from_str(&text)?));
}
Some(Ok(Message::Binary(bytes))) => {
return Ok(Some(serde_json::from_slice(&bytes)?));
}
Some(Ok(Message::Ping(payload))) => {
let _ = stream.send(Message::Pong(payload)).await;
continue;
}
Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => continue,
Some(Ok(Message::Close(_))) | None => {
self.stream = None;
continue;
}
Some(Err(err)) => {
self.stream = None;
let _ = err;
continue;
}
}
}
}
pub async fn send(&mut self, value: &Value) -> Result<(), Error> {
let Some(stream) = self.stream.as_mut() else {
return Err(Error::Network {
message: "websocket is not connected".into(),
source: None,
});
};
let text = serde_json::to_string(value)?;
stream
.send(Message::Text(text))
.await
.map_err(|e| Error::Network {
message: e.to_string(),
source: Some(Box::new(e)),
})?;
Ok(())
}
pub async fn close(mut self) -> Result<(), Error> {
if let Some(mut stream) = self.stream.take() {
let _ = stream.close(None).await;
}
Ok(())
}
async fn reconnect(&mut self) -> Result<Option<&mut WsStream>, Error> {
for attempt in 1..=MAX_RECONNECT_ATTEMPTS {
if let Some(cb) = self.on_reconnect.as_mut() {
cb(attempt);
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
match open(&self.url).await {
Ok(stream) => {
self.stream = Some(stream);
return Ok(self.stream.as_mut());
}
Err(_) if attempt < MAX_RECONNECT_ATTEMPTS => continue,
Err(_) => return Ok(None),
}
}
Ok(None)
}
}
/// Computes the delay to sleep before reconnect attempt number `attempt`.
///
/// The delay is `BASE_RECONNECT_DELAY` doubled once per attempt after the
/// first, capped at `MAX_RECONNECT_DELAY`. An `attempt` of 0 is treated like 1.
/// The exponent is limited to 16 so the multiplication stays small regardless
/// of how large `attempt` is.
fn backoff(attempt: u32) -> Duration {
    let exponent = attempt.saturating_sub(1).min(16);
    let base_ms = BASE_RECONNECT_DELAY.as_millis() as u64;
    let cap_ms = MAX_RECONNECT_DELAY.as_millis() as u64;
    let scaled_ms = base_ms * 2u64.pow(exponent);
    Duration::from_millis(std::cmp::min(scaled_ms, cap_ms))
}
async fn open(url: &str) -> Result<WsStream, Error> {
let (stream, _response) =
tokio_tungstenite::connect_async(url)
.await
.map_err(|e| Error::Network {
message: format!("websocket connect failed: {e}"),
source: Some(Box::new(e)),
})?;
Ok(stream)
}