use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;
use crate::codec::WsCodec;
use crate::connection::{RecvError, SendError, WsConnection};
/// Exponential backoff parameters used between reconnection attempts.
///
/// The delay before retry number `n` (0-based) is `initial * multiplier^n`, clamped to
/// `[0, max]`.
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Delay used for retry number 0, i.e. the shortest non-zero backoff step.
    pub initial: Duration,
    /// Upper bound on any single delay.
    pub max: Duration,
    /// Factor by which the delay grows with each further retry.
    pub multiplier: f64,
}

impl Default for BackoffConfig {
    /// 1 s initial delay, doubling per retry, capped at 30 s.
    fn default() -> Self {
        BackoffConfig {
            initial: Duration::from_secs(1),
            max: Duration::from_secs(30),
            multiplier: 2.0,
        }
    }
}

impl BackoffConfig {
    /// Returns the delay to wait before retry number `attempt` (0-based).
    ///
    /// Computes `initial * multiplier^attempt` and clamps it to `[0, max]`. The fields are
    /// public, so this is written never to panic, whatever values they hold:
    /// - `attempt` above `i32::MAX` saturates instead of wrapping to a negative exponent
    ///   (which would shrink the delay back towards zero);
    /// - a negative product (negative `multiplier`) yields [`Duration::ZERO`];
    /// - a NaN product (e.g. `0 * inf`) or anything at/above `max` yields exactly `max`.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`; a plain `as` cast would wrap large counts to negative exponents.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let secs = self.initial.as_secs_f64() * self.multiplier.powi(exponent);
        let max_secs = self.max.as_secs_f64();

        // NaN compares false with everything; map it to `max`, as an `f64::min` clamp would.
        if secs.is_nan() || secs >= max_secs {
            return self.max;
        }
        // Negative delays (and -0.0) make no sense; `Duration` cannot represent them.
        if secs <= 0.0 {
            return Duration::ZERO;
        }
        // Here 0 < secs < max_secs <= 2^64, so the value is finite and fits in a `Duration`.
        Duration::from_secs_f64(secs)
    }
}
/// A WebSocket connection wrapper that transparently reconnects with exponential backoff.
///
/// New connections are produced by calling `connect_fn`, whose future resolves to `None`
/// when an attempt fails. `send` and `recv` (re)connect on demand.
pub struct ReconnectingWs<S, R, F, Fut>
where
    S: WsCodec + Clone,
    R: WsCodec,
    F: FnMut() -> Fut,
    Fut: Future<Output = Option<WsConnection<S, R>>>,
{
    /// Backoff parameters for the delays between failed connection attempts.
    config: BackoffConfig,
    /// Factory whose future yields a fresh connection, or `None` if the attempt failed.
    connect_fn: F,
    /// The live connection; `None` while disconnected.
    conn: Option<WsConnection<S, R>>,
    /// Connection attempts made since the last successful connect.
    attempt: u32,
    /// Type marker for `S` and `R`.
    _types: PhantomData<(S, R)>,
}

impl<S, R, F, Fut> ReconnectingWs<S, R, F, Fut>
where
    S: WsCodec + Clone,
    R: WsCodec,
    F: FnMut() -> Fut,
    Fut: Future<Output = Option<WsConnection<S, R>>>,
{
    /// Creates a wrapper that starts disconnected.
    ///
    /// No connection is made here; the first `send`/`recv` call triggers it.
    pub fn new(config: BackoffConfig, connect_fn: F) -> Self {
        Self {
            config,
            connect_fn,
            conn: None,
            attempt: 0,
            _types: PhantomData,
        }
    }

    /// Sends `msg`, connecting first if there is no live connection.
    ///
    /// On any send error the current connection is dropped, a new one is established, and a
    /// clone of `msg` is sent again, repeating until a send succeeds.
    ///
    /// NOTE(review): every `SendError` is treated as a dead connection. If `SendError` has
    /// non-transient variants (e.g. an encoding failure), this would retry forever — confirm.
    ///
    /// # Errors
    /// Returns `SendError::Closed` only if no connection is present right after
    /// `reconnect`, which the `reconnect` in this impl never allows (it loops until it
    /// stores one).
    pub async fn send(&mut self, msg: S) -> Result<(), SendError> {
        loop {
            if self.conn.is_none() {
                self.reconnect().await;
            }
            let conn = match self.conn.as_mut() {
                Some(conn) => conn,
                None => return Err(SendError::Closed),
            };
            // Clone so the original is still available if this attempt fails and we retry.
            let outcome = conn.send(msg.clone()).await;
            if outcome.is_ok() {
                self.attempt = 0;
                return Ok(());
            }
            // Treat the failure as a dead connection; the next iteration reconnects.
            self.conn = None;
        }
    }

    /// Receives the next item, reconnecting whenever the connection closes.
    ///
    /// Messages and `RecvError::Decode` errors are handed to the caller and the connection
    /// is kept. `RecvError::Closed` or end-of-stream (`None` from the connection) drops the
    /// connection and reconnects. As written, this method never returns `None`.
    pub async fn recv(&mut self) -> Option<Result<R, RecvError>> {
        loop {
            if self.conn.is_none() {
                self.reconnect().await;
            }
            let received = match self.conn.as_mut() {
                Some(conn) => conn.recv().await,
                None => continue,
            };
            match received {
                // Deliverable results: the connection stays open.
                Some(item @ (Ok(_) | Err(RecvError::Decode(_)))) => {
                    self.attempt = 0;
                    return Some(item);
                }
                // Connection is gone: drop it so the next iteration reconnects.
                Some(Err(RecvError::Closed)) | None => self.conn = None,
            }
        }
    }

    /// Drops any current connection and calls `connect_fn` until it yields a new one.
    ///
    /// The attempt made with `attempt == 0` runs immediately. Every later attempt first
    /// sleeps for `delay_for_attempt(attempt - 1)`. On success the connection is stored
    /// and `attempt` is reset to 0. There is no retry limit.
    ///
    /// If this future is dropped while retrying, `attempt` keeps its current value, so the
    /// next call starts with a backoff sleep.
    pub async fn reconnect(&mut self) {
        self.conn = None;
        let fresh = loop {
            // Back off before every attempt except one made from a clean (zero) count.
            if let Some(previous) = self.attempt.checked_sub(1) {
                let pause = self.config.delay_for_attempt(previous);
                tokio::time::sleep(pause).await;
            }
            self.attempt = self.attempt.saturating_add(1);
            if let Some(conn) = (self.connect_fn)().await {
                break conn;
            }
        };
        self.conn = Some(fresh);
        self.attempt = 0;
    }
}
/// Builds a [`ReconnectingWs`] whose connections are opened with
/// `crate::native_client::connect::<E>` against `base_url`.
///
/// Each connection attempt clones `base_url` into a boxed `Send` future. A failed connect is
/// turned into `None` via `.ok()`, so the reconnect loop backs off and retries. The error
/// value itself is discarded.
#[cfg(feature = "native-client")]
// The boxed connect-future type has to be spelled out (twice) in the return type, which
// trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
pub fn connect_native<E>(
    base_url: String,
    config: BackoffConfig,
) -> ReconnectingWs<
    E::ClientMsg,
    E::ServerMsg,
    impl FnMut() -> std::pin::Pin<
        Box<dyn Future<Output = Option<WsConnection<E::ClientMsg, E::ServerMsg>>> + Send>,
    >,
    std::pin::Pin<
        Box<dyn Future<Output = Option<WsConnection<E::ClientMsg, E::ServerMsg>>> + Send>,
    >,
>
where
    E: crate::WsEndpoint,
{
    // Local shorthand for the boxed connect future; identical to the type in the signature.
    type BoxedConnect<C, S> =
        std::pin::Pin<Box<dyn Future<Output = Option<WsConnection<C, S>>> + Send>>;

    let make_attempt = move || {
        // The async block must own its URL, so every attempt gets its own copy.
        let target = base_url.clone();
        let attempt: BoxedConnect<E::ClientMsg, E::ServerMsg> =
            Box::pin(async move { crate::native_client::connect::<E>(&target).await.ok() });
        attempt
    };
    ReconnectingWs::new(config, make_attempt)
}