braid_core/fs/
rate_limiter.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
11#[derive(Debug)]
13pub struct ReconnectRateLimiter {
14 delay_ms: u64,
16 connections: Arc<Mutex<HashMap<String, ConnectionState>>>,
18}
19
/// Bookkeeping kept for a single URL by the reconnect rate limiter.
#[derive(Clone, Debug)]
struct ConnectionState {
    /// Set by a successful connection, cleared by a disconnect.
    connected: bool,
    /// Time of the last recorded connect/disconnect (or of entry creation).
    last_attempt: Instant,
    /// Disconnects seen since the last successful connection.
    failure_count: u32,
    /// How many turns have been requested for this URL.
    pending_turns: u32,
}
31
32impl Default for ConnectionState {
33 fn default() -> Self {
34 Self {
35 connected: false,
36 last_attempt: Instant::now(),
37 failure_count: 0,
38 pending_turns: 0,
39 }
40 }
41}
42
43impl ReconnectRateLimiter {
44 pub fn new(delay_ms: u64) -> Self {
46 Self {
47 delay_ms,
48 connections: Arc::new(Mutex::new(HashMap::new())),
49 }
50 }
51
52 pub async fn get_turn(&self, url: &str) -> Duration {
56 let mut conns = self.connections.lock().await;
57 let state = conns.entry(url.to_string()).or_default();
58
59 state.pending_turns += 1;
60
61 let delay = if state.connected {
63 Duration::ZERO
64 } else {
65 let multiplier = (state.failure_count.min(10) + 1) as u64;
66 Duration::from_millis(self.delay_ms * multiplier)
67 };
68
69 let elapsed = state.last_attempt.elapsed();
71 if elapsed < delay {
72 delay - elapsed
73 } else {
74 Duration::ZERO
75 }
76 }
77
78 pub async fn on_conn(&self, url: &str) {
80 let mut conns = self.connections.lock().await;
81 let state = conns.entry(url.to_string()).or_default();
82
83 state.connected = true;
84 state.failure_count = 0;
85 state.last_attempt = Instant::now();
86
87 tracing::debug!("on_conn: {} - connected", url);
88 }
89
90 pub async fn on_diss(&self, url: &str) {
92 let mut conns = self.connections.lock().await;
93 let state = conns.entry(url.to_string()).or_default();
94
95 state.connected = false;
96 state.failure_count += 1;
97 state.last_attempt = Instant::now();
98
99 tracing::debug!(
100 "on_diss: {} - disconnected (failures: {})",
101 url,
102 state.failure_count
103 );
104 }
105
106 pub async fn is_connected(&self, url: &str) -> bool {
108 let conns = self.connections.lock().await;
109 conns.get(url).map(|s| s.connected).unwrap_or(false)
110 }
111
112 pub async fn failure_count(&self, url: &str) -> u32 {
114 let conns = self.connections.lock().await;
115 conns.get(url).map(|s| s.failure_count).unwrap_or(0)
116 }
117
118 pub async fn reset(&self, url: &str) {
120 let mut conns = self.connections.lock().await;
121 conns.remove(url);
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    const URL: &str = "http://example.com";

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = ReconnectRateLimiter::new(100);

        // A never-seen URL waits at most the base delay.
        let delay = limiter.get_turn(URL).await;
        assert!(delay <= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limiter_on_conn_diss() {
        let limiter = ReconnectRateLimiter::new(100);

        limiter.on_conn(URL).await;
        assert!(limiter.is_connected(URL).await);

        limiter.on_diss(URL).await;
        assert!(!limiter.is_connected(URL).await);
        assert_eq!(limiter.failure_count(URL).await, 1);
    }

    /// The backoff is linear: delay_ms * (failures + 1), minus elapsed time.
    #[tokio::test]
    async fn test_rate_limiter_linear_backoff() {
        let limiter = ReconnectRateLimiter::new(1_000);

        for _ in 0..5 {
            limiter.on_diss(URL).await;
        }
        assert_eq!(limiter.failure_count(URL).await, 5);

        // 5 failures -> multiplier 6 -> just under 6s remaining.
        let delay = limiter.get_turn(URL).await;
        assert!(delay <= Duration::from_millis(6_000));
        assert!(delay > Duration::from_millis(5_000));
    }

    /// The multiplier stops growing after 10 failures (cap of 11x).
    #[tokio::test]
    async fn test_rate_limiter_backoff_is_capped() {
        let limiter = ReconnectRateLimiter::new(1_000);

        for _ in 0..20 {
            limiter.on_diss(URL).await;
        }

        let delay = limiter.get_turn(URL).await;
        assert!(delay <= Duration::from_millis(11_000));
        assert!(delay > Duration::from_millis(10_000));
    }

    #[tokio::test]
    async fn test_rate_limiter_connected_has_no_delay() {
        let limiter = ReconnectRateLimiter::new(1_000);

        limiter.on_conn(URL).await;
        assert_eq!(limiter.get_turn(URL).await, Duration::ZERO);
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = ReconnectRateLimiter::new(100);

        limiter.on_conn(URL).await;
        limiter.on_diss(URL).await;
        limiter.reset(URL).await;

        assert!(!limiter.is_connected(URL).await);
        assert_eq!(limiter.failure_count(URL).await, 0);
    }
}