// tx2_link/rate_limit.rs

1use crate::error::{LinkError, Result};
2use std::time::{Duration, Instant};
3use std::collections::VecDeque;
4
/// Limits enforced by [`RateLimiter`].
///
/// Despite the `*_per_second` field names, the message and byte limits are
/// measured over a sliding window of `window_duration`; they only mean
/// "per second" when that window is one second (the default).
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of messages admitted within one window.
    pub max_messages_per_second: u32,
    /// Maximum total payload bytes admitted within one window.
    pub max_bytes_per_second: u64,
    /// Maximum number of messages admitted within the limiter's short burst
    /// interval (100 ms in [`RateLimiter`]).
    pub burst_size: u32,
    /// Length of the sliding window used for the message and byte limits.
    pub window_duration: Duration,
}

impl Default for RateLimitConfig {
    /// 1000 messages and 10 MiB per one-second window, bursts of up to 100.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        RateLimitConfig {
            window_duration: Duration::from_millis(1_000),
            burst_size: 100,
            max_bytes_per_second: 10 * MIB,
            max_messages_per_second: 1_000,
        }
    }
}

24impl RateLimitConfig {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    pub fn with_max_messages(mut self, max: u32) -> Self {
30        self.max_messages_per_second = max;
31        self
32    }
33
34    pub fn with_max_bytes(mut self, max: u64) -> Self {
35        self.max_bytes_per_second = max;
36        self
37    }
38
39    pub fn with_burst_size(mut self, size: u32) -> Self {
40        self.burst_size = size;
41        self
42    }
43
44    pub fn with_window_duration(mut self, duration: Duration) -> Self {
45        self.window_duration = duration;
46        self
47    }
48}
49
50struct MessageRecord {
51    timestamp: Instant,
52    size: u64,
53}
54
55pub struct RateLimiter {
56    config: RateLimitConfig,
57    message_history: VecDeque<MessageRecord>,
58    byte_history: VecDeque<MessageRecord>,
59    total_messages: u64,
60    total_bytes: u64,
61    total_rejected: u64,
62}
63
64impl RateLimiter {
65    pub fn new(config: RateLimitConfig) -> Self {
66        Self {
67            config,
68            message_history: VecDeque::new(),
69            byte_history: VecDeque::new(),
70            total_messages: 0,
71            total_bytes: 0,
72            total_rejected: 0,
73        }
74    }
75
76    pub fn check_and_record(&mut self, message_size: u64) -> Result<()> {
77        let now = Instant::now();
78
79        self.cleanup_old_records(now);
80
81        let messages_in_window = self.count_messages_in_window(now);
82        let bytes_in_window = self.count_bytes_in_window(now);
83
84        if messages_in_window >= self.config.max_messages_per_second {
85            self.total_rejected += 1;
86            return Err(LinkError::RateLimitExceeded(
87                format!("Message rate limit exceeded: {} msgs/sec", self.config.max_messages_per_second)
88            ));
89        }
90
91        if bytes_in_window + message_size > self.config.max_bytes_per_second {
92            self.total_rejected += 1;
93            return Err(LinkError::RateLimitExceeded(
94                format!("Byte rate limit exceeded: {} bytes/sec", self.config.max_bytes_per_second)
95            ));
96        }
97
98        let burst_count = self.count_recent_burst(now);
99        if burst_count >= self.config.burst_size {
100            self.total_rejected += 1;
101            return Err(LinkError::RateLimitExceeded(
102                format!("Burst limit exceeded: {} msgs", self.config.burst_size)
103            ));
104        }
105
106        self.record_message(now, message_size);
107
108        Ok(())
109    }
110
111    pub fn check(&mut self, message_size: u64) -> bool {
112        self.check_and_record(message_size).is_ok()
113    }
114
115    fn record_message(&mut self, timestamp: Instant, size: u64) {
116        let record = MessageRecord {
117            timestamp,
118            size,
119        };
120
121        self.message_history.push_back(record.clone());
122        self.byte_history.push_back(record);
123
124        self.total_messages += 1;
125        self.total_bytes += size;
126    }
127
128    fn cleanup_old_records(&mut self, now: Instant) {
129        let cutoff = now - self.config.window_duration;
130
131        while let Some(record) = self.message_history.front() {
132            if record.timestamp < cutoff {
133                self.message_history.pop_front();
134            } else {
135                break;
136            }
137        }
138
139        while let Some(record) = self.byte_history.front() {
140            if record.timestamp < cutoff {
141                self.byte_history.pop_front();
142            } else {
143                break;
144            }
145        }
146    }
147
148    fn count_messages_in_window(&self, now: Instant) -> u32 {
149        let cutoff = now - self.config.window_duration;
150        self.message_history.iter()
151            .filter(|r| r.timestamp >= cutoff)
152            .count() as u32
153    }
154
155    fn count_bytes_in_window(&self, now: Instant) -> u64 {
156        let cutoff = now - self.config.window_duration;
157        self.byte_history.iter()
158            .filter(|r| r.timestamp >= cutoff)
159            .map(|r| r.size)
160            .sum()
161    }
162
163    fn count_recent_burst(&self, now: Instant) -> u32 {
164        let burst_window = Duration::from_millis(100);
165        let cutoff = now - burst_window;
166
167        self.message_history.iter()
168            .filter(|r| r.timestamp >= cutoff)
169            .count() as u32
170    }
171
172    pub fn reset(&mut self) {
173        self.message_history.clear();
174        self.byte_history.clear();
175    }
176
177    pub fn get_stats(&self) -> RateLimitStats {
178        RateLimitStats {
179            total_messages: self.total_messages,
180            total_bytes: self.total_bytes,
181            total_rejected: self.total_rejected,
182            messages_in_window: self.message_history.len() as u32,
183            bytes_in_window: self.byte_history.iter().map(|r| r.size).sum(),
184        }
185    }
186
187    pub fn get_config(&self) -> &RateLimitConfig {
188        &self.config
189    }
190
191    pub fn set_config(&mut self, config: RateLimitConfig) {
192        self.config = config;
193    }
194}
195
196impl Clone for MessageRecord {
197    fn clone(&self) -> Self {
198        Self {
199            timestamp: self.timestamp,
200            size: self.size,
201        }
202    }
203}
204
/// Point-in-time snapshot returned by [`RateLimiter::get_stats`].
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Messages admitted over the limiter's lifetime.
    pub total_messages: u64,
    /// Payload bytes admitted over the limiter's lifetime.
    pub total_bytes: u64,
    /// Messages rejected by any limit over the limiter's lifetime.
    pub total_rejected: u64,
    /// Messages counted in the limiter's window when the snapshot was taken.
    pub messages_in_window: u32,
    /// Payload bytes counted in the limiter's window when the snapshot was taken.
    pub bytes_in_window: u64,
}

214pub struct TokenBucketRateLimiter {
215    capacity: u32,
216    tokens: u32,
217    refill_rate: u32,
218    last_refill: Instant,
219    total_messages: u64,
220    total_rejected: u64,
221}
222
223impl TokenBucketRateLimiter {
224    pub fn new(capacity: u32, refill_rate: u32) -> Self {
225        Self {
226            capacity,
227            tokens: capacity,
228            refill_rate,
229            last_refill: Instant::now(),
230            total_messages: 0,
231            total_rejected: 0,
232        }
233    }
234
235    pub fn check_and_consume(&mut self) -> Result<()> {
236        self.refill();
237
238        if self.tokens == 0 {
239            self.total_rejected += 1;
240            return Err(LinkError::RateLimitExceeded(
241                format!("Token bucket empty (capacity: {})", self.capacity)
242            ));
243        }
244
245        self.tokens -= 1;
246        self.total_messages += 1;
247
248        Ok(())
249    }
250
251    pub fn check(&mut self) -> bool {
252        self.check_and_consume().is_ok()
253    }
254
255    fn refill(&mut self) {
256        let now = Instant::now();
257        let elapsed = now.duration_since(self.last_refill);
258        let elapsed_secs = elapsed.as_secs_f64();
259
260        let tokens_to_add = (elapsed_secs * self.refill_rate as f64) as u32;
261
262        if tokens_to_add > 0 {
263            self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
264            self.last_refill = now;
265        }
266    }
267
268    pub fn reset(&mut self) {
269        self.tokens = self.capacity;
270        self.last_refill = Instant::now();
271    }
272
273    pub fn get_available_tokens(&self) -> u32 {
274        self.tokens
275    }
276
277    pub fn get_stats(&self) -> (u64, u64) {
278        (self.total_messages, self.total_rejected)
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Sends `count` messages of `size` bytes; true if every one was admitted.
    fn admits_all(limiter: &mut RateLimiter, count: usize, size: u64) -> bool {
        (0..count).all(|_| limiter.check_and_record(size).is_ok())
    }

    #[test]
    fn test_rate_limiter_basic() {
        let mut limiter = RateLimiter::new(
            RateLimitConfig::new().with_max_messages(10).with_max_bytes(1000),
        );

        assert!(admits_all(&mut limiter, 10, 50));
        // The eleventh message hits the message-count limit.
        assert!(limiter.check_and_record(50).is_err());
    }

    #[test]
    fn test_rate_limiter_byte_limit() {
        let mut limiter = RateLimiter::new(
            RateLimitConfig::new().with_max_messages(100).with_max_bytes(500),
        );

        assert!(limiter.check_and_record(300).is_ok());
        // 300 + 300 > 500.
        assert!(limiter.check_and_record(300).is_err());
    }

    #[test]
    fn test_rate_limiter_burst() {
        let mut limiter = RateLimiter::new(
            RateLimitConfig::new().with_max_messages(1000).with_burst_size(5),
        );

        assert!(admits_all(&mut limiter, 5, 100));
        assert!(limiter.check_and_record(100).is_err());
    }

    #[test]
    fn test_rate_limiter_window() {
        let window = Duration::from_millis(100);
        let mut limiter = RateLimiter::new(
            RateLimitConfig::new().with_max_messages(5).with_window_duration(window),
        );

        assert!(admits_all(&mut limiter, 5, 100));
        assert!(limiter.check_and_record(100).is_err());

        // Sleep past the window so every earlier record expires.
        thread::sleep(Duration::from_millis(150));
        assert!(limiter.check_and_record(100).is_ok());
    }

    #[test]
    fn test_token_bucket() {
        let mut limiter = TokenBucketRateLimiter::new(5, 10);

        assert!((0..5).all(|_| limiter.check_and_consume().is_ok()));
        assert!(limiter.check_and_consume().is_err());

        // At 10 tokens/s, 100 ms earns one token.
        thread::sleep(Duration::from_millis(100));
        limiter.refill();

        assert!(limiter.check_and_consume().is_ok());
    }

    #[test]
    fn test_rate_limiter_stats() {
        let mut limiter = RateLimiter::new(RateLimitConfig::new().with_max_messages(5));

        // Six attempts against a limit of five: exactly one is rejected.
        for _ in 0..6 {
            let _ = limiter.check_and_record(100);
        }

        let stats = limiter.get_stats();
        assert_eq!(stats.total_messages, 5);
        assert_eq!(stats.total_rejected, 1);
    }
}