use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Per-URL reconnect throttle.
///
/// Tracks connection state for each URL and computes how long a caller should
/// wait before its next connection attempt. All state lives behind a
/// `tokio::sync::Mutex`, which is why every method takes `&self` and is `async`.
#[derive(Debug)]
pub struct ReconnectRateLimiter {
    /// Base backoff in milliseconds; scaled by the URL's failure count while disconnected.
    delay_ms: u64,
    /// Connection bookkeeping keyed by URL string.
    connections: Arc<Mutex<HashMap<String, ConnectionState>>>,
}
/// Bookkeeping kept for a single URL.
#[derive(Debug, Clone)]
struct ConnectionState {
    /// Set to `true` by `on_conn`, `false` by `on_diss`.
    connected: bool,
    /// Time of the last recorded connect/disconnect, or of state creation.
    last_attempt: Instant,
    /// Disconnects recorded since the last `on_conn`; drives the backoff multiplier.
    failure_count: u32,
    /// Incremented on every `get_turn` call.
    /// NOTE(review): no method in this file reads it — confirm whether it is still needed.
    pending_turns: u32,
}
impl Default for ConnectionState {
fn default() -> Self {
Self {
connected: false,
last_attempt: Instant::now(),
failure_count: 0,
pending_turns: 0,
}
}
}
impl ReconnectRateLimiter {
    /// Failure count beyond which the backoff stops growing.
    const MAX_BACKOFF_FAILURES: u32 = 10;

    /// Creates a limiter whose base backoff is `delay_ms` milliseconds.
    pub fn new(delay_ms: u64) -> Self {
        Self {
            delay_ms,
            connections: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns how long the caller should wait before attempting to connect to `url`.
    ///
    /// A URL marked connected gets [`Duration::ZERO`]. Otherwise the backoff is
    /// `delay_ms * (min(failure_count, 10) + 1)`. It grows linearly with the
    /// failure count and is capped at 11x the base delay. The time already
    /// elapsed since the last recorded connect/disconnect is subtracted from it.
    ///
    /// This call does not refresh the attempt timestamp; only [`Self::on_conn`]
    /// and [`Self::on_diss`] do. Callers asking before an outcome is reported
    /// are therefore measured against the same timestamp. A URL seen for the
    /// first time gets a freshly created state, so its first call waits roughly
    /// one base delay.
    pub async fn get_turn(&self, url: &str) -> Duration {
        let mut conns = self.connections.lock().await;
        let state = conns.entry(url.to_string()).or_default();
        // Saturate so a long-lived limiter cannot overflow-panic in debug builds.
        state.pending_turns = state.pending_turns.saturating_add(1);
        let backoff = if state.connected {
            Duration::ZERO
        } else {
            let multiplier =
                u64::from(state.failure_count.min(Self::MAX_BACKOFF_FAILURES)) + 1;
            // A large `delay_ms` must not overflow (panic / wrap) when scaled.
            Duration::from_millis(self.delay_ms.saturating_mul(multiplier))
        };
        backoff.saturating_sub(state.last_attempt.elapsed())
    }

    /// Records a successful connection to `url`.
    ///
    /// Marks it connected, clears its failure count and stamps the attempt time.
    pub async fn on_conn(&self, url: &str) {
        let mut conns = self.connections.lock().await;
        let state = conns.entry(url.to_string()).or_default();
        state.connected = true;
        state.failure_count = 0;
        state.last_attempt = Instant::now();
        tracing::debug!("on_conn: {} - connected", url);
    }

    /// Records a disconnect (or failed attempt) for `url`.
    ///
    /// Marks it disconnected, increments its failure count (saturating at
    /// `u32::MAX`) and stamps the attempt time.
    pub async fn on_diss(&self, url: &str) {
        let mut conns = self.connections.lock().await;
        let state = conns.entry(url.to_string()).or_default();
        state.connected = false;
        state.failure_count = state.failure_count.saturating_add(1);
        state.last_attempt = Instant::now();
        tracing::debug!(
            "on_diss: {} - disconnected (failures: {})",
            url,
            state.failure_count
        );
    }

    /// Whether `url` is currently marked connected; `false` for unknown URLs.
    pub async fn is_connected(&self, url: &str) -> bool {
        let conns = self.connections.lock().await;
        conns.get(url).map(|s| s.connected).unwrap_or(false)
    }

    /// Disconnects recorded for `url` since its last `on_conn`; `0` for unknown URLs.
    pub async fn failure_count(&self, url: &str) -> u32 {
        let conns = self.connections.lock().await;
        conns.get(url).map(|s| s.failure_count).unwrap_or(0)
    }

    /// Forgets all state for `url`, so it is treated as never seen.
    pub async fn reset(&self, url: &str) {
        let mut conns = self.connections.lock().await;
        conns.remove(url);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const URL: &str = "http://example.com";

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = ReconnectRateLimiter::new(100);
        let delay = limiter.get_turn(URL).await;
        assert!(delay <= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limiter_on_conn_diss() {
        let limiter = ReconnectRateLimiter::new(100);
        limiter.on_conn(URL).await;
        assert!(limiter.is_connected(URL).await);
        // A connected URL never has to wait.
        assert_eq!(limiter.get_turn(URL).await, Duration::ZERO);
        limiter.on_diss(URL).await;
        assert!(!limiter.is_connected(URL).await);
        assert_eq!(limiter.failure_count(URL).await, 1);
    }

    /// The backoff multiplier is `failures + 1` (linear, not exponential).
    #[tokio::test]
    async fn test_rate_limiter_backoff_grows_with_failures() {
        let limiter = ReconnectRateLimiter::new(100);
        for _ in 0..5 {
            limiter.on_diss(URL).await;
        }
        assert_eq!(limiter.failure_count(URL).await, 5);
        // 5 failures => 6 x 100ms, minus the (tiny) time since the last on_diss.
        let delay = limiter.get_turn(URL).await;
        assert!(delay > Duration::from_millis(500));
        assert!(delay <= Duration::from_millis(600));
    }

    #[tokio::test]
    async fn test_rate_limiter_backoff_is_capped() {
        let limiter = ReconnectRateLimiter::new(100);
        for _ in 0..50 {
            limiter.on_diss(URL).await;
        }
        assert_eq!(limiter.failure_count(URL).await, 50);
        // Multiplier is capped at 11, so the wait never exceeds 1100ms.
        let delay = limiter.get_turn(URL).await;
        assert!(delay > Duration::from_millis(1000));
        assert!(delay <= Duration::from_millis(1100));
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = ReconnectRateLimiter::new(100);
        limiter.on_conn(URL).await;
        limiter.on_diss(URL).await;
        limiter.reset(URL).await;
        assert!(!limiter.is_connected(URL).await);
        assert_eq!(limiter.failure_count(URL).await, 0);
    }
}