use std::sync::Arc;
use std::time::Duration;
use alloy::providers::{Provider, RootProvider};
use alloy::rpc::types::{Filter, Header, Log};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
/// Retry policy used when re-establishing a dropped WebSocket connection.
///
/// See [`backoff_delay`] for how the delay for a given attempt is derived
/// from these fields.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ReconnectConfig {
    /// Base delay used for the first reconnect attempt.
    pub initial_backoff: Duration,
    /// Ceiling for the computed backoff delay.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: u32,
    /// Maximum number of reconnect attempts; `0` means retry without limit.
    pub max_attempts: u32,
}
impl Default for ReconnectConfig {
    /// 500 ms first delay, doubling after each failure up to 30 s, with no
    /// limit on the number of attempts.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(500);
        let max_backoff = Duration::from_secs(30);
        Self {
            initial_backoff,
            max_backoff,
            backoff_multiplier: 2,
            // Zero disables the attempt cap (checked in `WsManager::reconnect`).
            max_attempts: 0,
        }
    }
}
/// A WebSocket-backed RPC provider together with the endpoint and retry
/// policy needed to re-establish it.
pub struct WsManager {
    /// Endpoint this manager connected to; reused when reconnecting.
    url: String,
    /// Backoff policy consulted by `reconnect`.
    config: ReconnectConfig,
    /// Shared handle to the provider built in `connect`.
    provider: Arc<RootProvider>,
}
impl std::fmt::Debug for WsManager {
    /// Shows the URL and reconnect settings; the provider handle is omitted
    /// and rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WsManager");
        builder.field("url", &self.url);
        builder.field("config", &self.config);
        builder.finish_non_exhaustive()
    }
}
impl WsManager {
pub async fn connect(url: impl Into<String>, config: ReconnectConfig) -> crate::Result<Self> {
let url = url.into();
let ws_connect = alloy::providers::WsConnect::new(&url);
let rpc_client = alloy::rpc::client::ClientBuilder::default()
.ws(ws_connect)
.await?;
let provider = RootProvider::new(rpc_client);
tracing::info!(url = %url, "WebSocket connected");
Ok(Self {
url,
config,
provider: Arc::new(provider),
})
}
pub async fn subscribe_blocks(&self) -> crate::Result<mpsc::Receiver<Header>> {
let (tx, rx) = mpsc::channel(64);
let provider = Arc::clone(&self.provider);
let url = self.url.clone();
tracing::debug!(url = %self.url, "subscribing to blocks");
tokio::spawn(async move {
let sub = match provider.subscribe_blocks().await {
Ok(sub) => sub,
Err(e) => {
tracing::warn!(url = %url, error = %e, "block subscription failed");
return;
}
};
tracing::debug!(url = %url, "block subscription established");
let mut stream = sub.into_stream();
while let Some(block) = stream.next().await {
if tx.send(block).await.is_err() {
break; }
}
tracing::debug!(url = %url, "block subscription ended");
});
Ok(rx)
}
pub async fn subscribe_logs(&self, filter: Filter) -> crate::Result<mpsc::Receiver<Log>> {
let (tx, rx) = mpsc::channel(256);
let provider = Arc::clone(&self.provider);
let url = self.url.clone();
tracing::debug!(url = %self.url, "subscribing to logs");
tokio::spawn(async move {
let sub = match provider.subscribe_logs(&filter).await {
Ok(sub) => sub,
Err(e) => {
tracing::warn!(url = %url, error = %e, "log subscription failed");
return;
}
};
tracing::debug!(url = %url, "log subscription established");
let mut stream = sub.into_stream();
while let Some(log) = stream.next().await {
if tx.send(log).await.is_err() {
break;
}
}
tracing::debug!(url = %url, "log subscription ended");
});
Ok(rx)
}
pub async fn reconnect(&self) -> Option<Self> {
let mut delay = self.config.initial_backoff;
let mut attempts = 0u32;
loop {
attempts += 1;
if self.config.max_attempts > 0 && attempts > self.config.max_attempts {
tracing::warn!(url = %self.url, max_attempts = self.config.max_attempts, "reconnect attempts exhausted");
return None;
}
tracing::info!(url = %self.url, attempt = attempts, delay_ms = delay.as_millis() as u64, "reconnecting");
tokio::time::sleep(delay).await;
match Self::connect(self.url.clone(), self.config).await {
Ok(new_manager) => {
tracing::info!(url = %self.url, attempt = attempts, "reconnected");
return Some(new_manager);
}
Err(e) => {
tracing::warn!(url = %self.url, attempt = attempts, error = %e, "reconnect failed");
delay = (delay * self.config.backoff_multiplier).min(self.config.max_backoff);
}
}
}
}
pub fn url(&self) -> &str {
&self.url
}
pub fn provider(&self) -> &RootProvider {
&self.provider
}
}
/// Delay to wait before retry number `attempt` (0-based).
///
/// Computes `initial_backoff * backoff_multiplier^attempt` and caps the result
/// at `max_backoff`. Both the power and the `Duration` multiplication saturate
/// instead of overflowing.
///
/// A `backoff_multiplier` of `0` is treated as `1` (constant backoff), so
/// retries never collapse to a zero delay.
pub fn backoff_delay(config: &ReconnectConfig, attempt: u32) -> Duration {
    // With a zero multiplier, every delay after attempt 0 would be zero and
    // the caller would retry in a tight loop; use constant backoff instead.
    let base = config.backoff_multiplier.max(1);
    let factor = base.checked_pow(attempt).unwrap_or(u32::MAX);
    config
        .initial_backoff
        .saturating_mul(factor)
        .min(config.max_backoff)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with no attempt cap from millisecond values.
    fn config_ms(initial_ms: u64, max_ms: u64, multiplier: u32) -> ReconnectConfig {
        ReconnectConfig {
            initial_backoff: Duration::from_millis(initial_ms),
            max_backoff: Duration::from_millis(max_ms),
            backoff_multiplier: multiplier,
            max_attempts: 0,
        }
    }

    /// Asserts `backoff_delay` yields `expected_ms` for every `(attempt, expected_ms)` pair.
    fn assert_delays(config: &ReconnectConfig, cases: &[(u32, u64)]) {
        for &(attempt, expected_ms) in cases {
            assert_eq!(
                backoff_delay(config, attempt),
                Duration::from_millis(expected_ms),
                "attempt {attempt}"
            );
        }
    }

    #[test]
    fn backoff_delay_exponential() {
        let config = config_ms(100, 10_000, 2);
        assert_delays(&config, &[(0, 100), (1, 200), (2, 400), (3, 800)]);
    }

    #[test]
    fn backoff_delay_capped_at_max() {
        assert_delays(&config_ms(100, 500, 2), &[(5, 500), (10, 500)]);
    }

    #[test]
    fn backoff_delay_handles_overflow() {
        // 10^30 does not fit in a u32: the factor must saturate, then get capped.
        assert_delays(&config_ms(1_000, 60_000, 10), &[(30, 60_000)]);
    }

    #[test]
    fn backoff_delay_multiplier_one_is_constant() {
        assert_delays(&config_ms(100, 10_000, 1), &[(0, 100), (5, 100), (100, 100)]);
    }
}