Skip to main content

shared/utils/rate_limiting/
token_bucket.rs

1use std::sync::Arc;
2
3use crate::{config::RateLimitConfig, error::CoreError};
4use dashmap::DashMap;
5
6#[derive(Clone, Default)]
7pub struct RateLimiter {
8    buckets: Arc<DashMap<String, TokenBucket>>,
9}
10impl RateLimiter {
11    pub fn check(&self, key: &str, config: &RateLimitConfig) -> Result<(), CoreError> {
12        let bucket = TokenBucket::default()
13            .with_bucket_size(config.capacity)
14            .with_refill_rate(config.capacity)
15            .with_duration(config.duration_minutes);
16
17        let mut bucket = self.buckets.entry(key.to_string()).or_insert(bucket);
18        bucket.run()
19    }
20}
21
22#[derive(Clone, Debug)]
23pub struct TokenBucket {
24    capacity: u32,
25    bucket_size: u32,
26    refill_rate: u32,
27    duration: chrono::Duration,
28    last_refill: chrono::DateTime<chrono::Utc>,
29}
30
31impl Default for TokenBucket {
32    fn default() -> Self {
33        Self {
34            capacity: 20,
35            bucket_size: 20,
36            refill_rate: 20,
37            duration: chrono::Duration::minutes(15),
38            last_refill: chrono::Utc::now(),
39        }
40    }
41}
42
43impl TokenBucket {
44    pub fn with_bucket_size(mut self, bucket_size: u32) -> Self {
45        self.bucket_size = bucket_size;
46        self.capacity = bucket_size;
47        self
48    }
49    pub fn with_refill_rate(mut self, refill_rate: u32) -> Self {
50        self.refill_rate = refill_rate;
51        self
52    }
53    pub fn with_duration(mut self, minutes: u32) -> Self {
54        self.duration = chrono::Duration::minutes(minutes as i64);
55        self
56    }
57}
58impl TokenBucket {
59    pub fn run(&mut self) -> Result<(), CoreError> {
60        let time_since_refill = chrono::Utc::now() - self.last_refill;
61        if time_since_refill >= self.duration {
62            self.bucket_size = self.refill_rate;
63            self.last_refill = chrono::Utc::now();
64        }
65
66        if self.bucket_size == 0 {
67            return Err(CoreError::RateLimitExceeded {
68                limit: self.capacity,
69                window: self.duration,
70            });
71        }
72
73        self.bucket_size -= 1;
74        Ok(())
75    }
76}