#![allow(unused_imports)]
use crate::connection::WebSocketClient;
use log::{warn, error, info};
use tokio::time::{sleep, Duration};
use tokio_tungstenite::tungstenite::Error;
use std::sync::Arc;
use async_trait::async_trait;
/// Something that can asynchronously attempt a single connection.
///
/// Implementors must be `Send + Sync`; [`ReconnectStrategy::reconnect`]
/// receives them as `Arc<dyn Connectable>`.
#[async_trait]
pub trait Connectable
where
    Self: Send + Sync,
{
    /// Performs one connection attempt.
    ///
    /// # Errors
    /// Returns a tungstenite [`Error`] when the attempt fails.
    async fn connect(&self) -> Result<(), Error>;
}
#[async_trait]
impl Connectable for WebSocketClient {
async fn connect(&self) -> Result<(), Error> {
let client = self.clone();
tokio::spawn(async move {
match client.connect().await {
Ok(_) => info!("Successfully connected"),
Err(e) => error!("Failed to connect: {}", e),
}
});
Ok(())
}
}
/// Retry policy for re-establishing a connection with linear backoff.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectStrategy {
    /// Maximum number of connection attempts.
    retries: u32,
    /// Delay unit; the wait after failed attempt `k` is `base_delay * k`.
    base_delay: Duration,
}
impl ReconnectStrategy {
    /// Creates a strategy that makes at most `retries` attempts, waiting
    /// `base_delay_secs * attempt` seconds after each failed attempt that is
    /// followed by another one.
    pub fn new(retries: u32, base_delay_secs: u64) -> Self {
        ReconnectStrategy {
            retries,
            base_delay: Duration::from_secs(base_delay_secs),
        }
    }

    /// Returns the maximum number of connection attempts.
    pub fn get_retries(&self) -> u32 {
        self.retries
    }

    /// Delay to wait after failed attempt number `attempt` (1-based).
    ///
    /// Grows linearly (`base_delay * attempt`) and saturates at
    /// `Duration::MAX` instead of panicking on overflow.
    fn delay_after(&self, attempt: u32) -> Duration {
        self.base_delay.saturating_mul(attempt)
    }

    /// Tries to connect `client` up to `retries` times with linear backoff.
    ///
    /// Returns `Some(())` as soon as an attempt succeeds. Returns `None` once
    /// every attempt has failed, or immediately if `retries` is 0. No delay
    /// is spent after the final attempt, since nothing follows it.
    pub async fn reconnect(&self, client: Arc<dyn Connectable>) -> Option<()> {
        for attempt in 1..=self.retries {
            warn!("Reconnection attempt {} of {}", attempt, self.retries);
            match client.connect().await {
                Ok(()) => {
                    info!("Reconnected successfully on attempt {}", attempt);
                    return Some(());
                }
                Err(e) => error!("Reconnection attempt {} failed: {}", attempt, e),
            }
            // Only back off when another attempt will actually follow.
            if attempt < self.retries {
                let delay = self.delay_after(attempt);
                warn!("Waiting for {:?} before next reconnection attempt", delay);
                sleep(delay).await;
            }
        }
        error!("Exceeded maximum reconnection attempts");
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use tokio_tungstenite::tungstenite::Error;

    /// Mock client whose every attempt fails; records how often it was called.
    #[derive(Default)]
    struct FailingClient {
        attempts: AtomicU32,
    }

    #[async_trait]
    impl Connectable for FailingClient {
        async fn connect(&self) -> Result<(), Error> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(Error::ConnectionClosed)
        }
    }

    /// Mock client whose every attempt succeeds; records how often it was called.
    #[derive(Default)]
    struct SuccessClient {
        attempts: AtomicU32,
    }

    #[async_trait]
    impl Connectable for SuccessClient {
        async fn connect(&self) -> Result<(), Error> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn test_reconnect_strategy_creation() {
        let reconnect_strategy = ReconnectStrategy::new(3, 2);
        assert_eq!(reconnect_strategy.retries, 3);
        assert_eq!(reconnect_strategy.base_delay, Duration::from_secs(2));
        assert_eq!(reconnect_strategy.get_retries(), 3);
    }

    #[tokio::test]
    async fn test_reconnect_exhausts_all_attempts() {
        // A zero base delay keeps the test fast; the backoff itself is linear.
        let reconnect_strategy = ReconnectStrategy::new(3, 0);
        let client = Arc::new(FailingClient::default());
        let reconnection_result = reconnect_strategy.reconnect(client.clone()).await;
        assert!(reconnection_result.is_none(), "Expected all reconnection attempts to fail");
        assert_eq!(client.attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_reconnect_with_zero_retries_makes_no_attempt() {
        let reconnect_strategy = ReconnectStrategy::new(0, 0);
        let client = Arc::new(FailingClient::default());
        assert!(reconnect_strategy.reconnect(client.clone()).await.is_none());
        assert_eq!(client.attempts.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        let reconnect_strategy = ReconnectStrategy::new(3, 0);
        let client = Arc::new(SuccessClient::default());
        let reconnection_result = reconnect_strategy.reconnect(client.clone()).await;
        assert!(reconnection_result.is_some(), "Expected successful reconnection");
        // Success on the first attempt must stop further retries.
        assert_eq!(client.attempts.load(Ordering::SeqCst), 1);
    }
}