Skip to main content

braid_core/fs/
rate_limiter.rs

1//! Reconnection rate limiter for BraidFS.
2//!
3//! Prevents too-rapid reconnection attempts that could overload servers.
4//! Matches JS `ReconnectRateLimiter` from braidfs/index.js.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
11/// Rate limiter for reconnection attempts.
12#[derive(Debug)]
13pub struct ReconnectRateLimiter {
14    /// Base delay between reconnection attempts in milliseconds.
15    delay_ms: u64,
16    /// Track connection state per URL.
17    connections: Arc<Mutex<HashMap<String, ConnectionState>>>,
18}
19
/// Bookkeeping kept for a single URL.
#[derive(Debug, Clone)]
struct ConnectionState {
    /// `true` after `on_conn` has been called and until the next `on_diss`.
    connected: bool,
    /// Instant of the most recent `on_conn`/`on_diss` call, or of the entry's
    /// creation if neither has happened yet. `get_turn` does not update it.
    last_attempt: Instant,
    /// Disconnects recorded since the last `on_conn` (which resets it to 0).
    failure_count: u32,
    /// Count of turns requested through `get_turn`. It is only ever
    /// incremented in this file; nothing visible here decrements or reads it.
    pending_turns: u32,
}
31
32impl Default for ConnectionState {
33    fn default() -> Self {
34        Self {
35            connected: false,
36            last_attempt: Instant::now(),
37            failure_count: 0,
38            pending_turns: 0,
39        }
40    }
41}
42
43impl ReconnectRateLimiter {
44    /// Create a new rate limiter with the given base delay.
45    pub fn new(delay_ms: u64) -> Self {
46        Self {
47            delay_ms,
48            connections: Arc::new(Mutex::new(HashMap::new())),
49        }
50    }
51
52    /// Get a "turn" to attempt a connection.
53    ///
54    /// This may wait if too many rapid attempts have been made.
55    pub async fn get_turn(&self, url: &str) -> Duration {
56        let mut conns = self.connections.lock().await;
57        let state = conns.entry(url.to_string()).or_default();
58
59        state.pending_turns += 1;
60
61        // Calculate delay based on failure count
62        let delay = if state.connected {
63            Duration::ZERO
64        } else {
65            let multiplier = (state.failure_count.min(10) + 1) as u64;
66            Duration::from_millis(self.delay_ms * multiplier)
67        };
68
69        // Check if we need to wait
70        let elapsed = state.last_attempt.elapsed();
71        if elapsed < delay {
72            delay - elapsed
73        } else {
74            Duration::ZERO
75        }
76    }
77
78    /// Called when a connection is established.
79    pub async fn on_conn(&self, url: &str) {
80        let mut conns = self.connections.lock().await;
81        let state = conns.entry(url.to_string()).or_default();
82
83        state.connected = true;
84        state.failure_count = 0;
85        state.last_attempt = Instant::now();
86
87        tracing::debug!("on_conn: {} - connected", url);
88    }
89
90    /// Called when a connection is disconnected.
91    pub async fn on_diss(&self, url: &str) {
92        let mut conns = self.connections.lock().await;
93        let state = conns.entry(url.to_string()).or_default();
94
95        state.connected = false;
96        state.failure_count += 1;
97        state.last_attempt = Instant::now();
98
99        tracing::debug!(
100            "on_diss: {} - disconnected (failures: {})",
101            url,
102            state.failure_count
103        );
104    }
105
106    /// Check if a URL is currently connected.
107    pub async fn is_connected(&self, url: &str) -> bool {
108        let conns = self.connections.lock().await;
109        conns.get(url).map(|s| s.connected).unwrap_or(false)
110    }
111
112    /// Get the current failure count for a URL.
113    pub async fn failure_count(&self, url: &str) -> u32 {
114        let conns = self.connections.lock().await;
115        conns.get(url).map(|s| s.failure_count).unwrap_or(0)
116    }
117
118    /// Reset the state for a URL.
119    pub async fn reset(&self, url: &str) {
120        let mut conns = self.connections.lock().await;
121        conns.remove(url);
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A first turn for an unknown URL must never be asked to wait longer
    /// than the base delay.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = ReconnectRateLimiter::new(100);

        let delay = limiter.get_turn("http://example.com").await;
        assert!(delay <= Duration::from_millis(100));
    }

    /// `on_conn` / `on_diss` flip the connected flag, and a disconnect is
    /// counted as a failure.
    #[tokio::test]
    async fn test_rate_limiter_on_conn_diss() {
        let limiter = ReconnectRateLimiter::new(100);

        limiter.on_conn("http://example.com").await;
        assert!(limiter.is_connected("http://example.com").await);
        // While connected, a turn is granted without any delay.
        assert_eq!(
            limiter.get_turn("http://example.com").await,
            Duration::ZERO
        );

        limiter.on_diss("http://example.com").await;
        assert!(!limiter.is_connected("http://example.com").await);
        assert_eq!(limiter.failure_count("http://example.com").await, 1);
    }

    /// Repeated failures must lengthen the reported delay, not just bump the
    /// failure counter.
    #[tokio::test]
    async fn test_rate_limiter_exponential_backoff() {
        let limiter = ReconnectRateLimiter::new(100);

        // Simulate multiple failures
        for _ in 0..5 {
            limiter.on_diss("http://example.com").await;
        }

        assert_eq!(limiter.failure_count("http://example.com").await, 5);

        // Five failures give a multiplier of 6, so the target spacing is
        // 600 ms. The last failure was recorded just now, so nearly all of it
        // should remain; the lower bound is kept loose to tolerate slow
        // machines.
        let delay = limiter.get_turn("http://example.com").await;
        assert!(delay > Duration::from_millis(100));
        assert!(delay <= Duration::from_millis(600));
    }
}
161}