// api_rate_limiter/limiter.rs

1use std::sync::Arc;
2use std::time::Duration;
3
/// Storage abstraction for the per-client request counters.
///
/// Implement this for whichever store should hold the counters (Redis, an
/// in-process map, or anything else). The `Send + Sync` supertraits let one
/// backend instance be shared between threads.
pub trait CacheBackend: Send + Sync {
    /// Looks up the counter stored under `key`; `None` means no value is present.
    fn get(&self, key: &str) -> Option<u32>;

    /// Writes `value` under `key` with the time-to-live `ttl`.
    ///
    /// # Errors
    ///
    /// Returns a backend-specific message when the write fails.
    fn set(&self, key: &str, value: u32, ttl: Duration) -> Result<(), String>;

    /// Adds `amount` to the counter under `key` and yields the updated value.
    ///
    /// # Errors
    ///
    /// Returns a backend-specific message when the update fails.
    fn incr(&self, key: &str, amount: u32) -> Result<u32, String>;
}
16
/// Distributed, IP-based rate limiter backed by a shared cache.
///
/// Every client IP gets its own counter in the cache, and at most `limit`
/// requests are admitted per `ttl`-long window.
///
/// # Type Parameters
///
/// * `B` - The storage backend. The limiting methods are implemented for
///   `B: CacheBackend`. The struct itself carries no bound, so it can be named
///   in generic code without repeating it.
pub struct RateLimiter<B> {
    /// Backend holding the per-IP counters (e.g. Redis or an in-memory cache).
    pub cache: Arc<B>,
    /// Maximum number of requests admitted per window.
    pub limit: u32,
    /// Length of the rate-limiting window, applied as the counter's TTL.
    pub ttl: Duration,
}
29
30impl<B: CacheBackend> RateLimiter<B> {
31    /// Constructs a new RateLimiter.
32    ///
33    /// # Arguments
34    ///
35    /// * `cache` - A caching backend instance wrapped in `Arc`.
36    /// * `limit` - Maximum number of allowed requests in the TTL window.
37    /// * `ttl` - Duration for the rate limiting window.
38    pub fn new(cache: Arc<B>, limit: u32, ttl: Duration) -> Self {
39        RateLimiter { cache, limit, ttl }
40    }
41
42    /// Checks whether a request from the given IP is allowed.
43    ///
44    /// This method does the following:
45    /// 1. Builds a key using the client's IP.
46    /// 2. Retrieves the current request count from the cache.
47    /// 3. If under the limit, increments the count.
48    ///    - If this is the first request, sets the TTL for that key.
49    /// 4. Returns `true` if the request is allowed, or `false` if the limit is exceeded.
50    ///
51    /// # Arguments
52    ///
53    /// * `ip` - A string slice representing the client's IP address.
54    ///
55    /// # Returns
56    ///
57    /// * `true` if the request is allowed; `false` otherwise.
58    pub fn allow(&self, ip: &str) -> bool {
59        // Use the IP as the key for rate limiting.
60        let key = format!("rate_limit:{}", ip);
61        // println!("found out key format");
62        
63        // Get the current request count, defaulting to 0 if not found.
64        // println!("current count of requests {:?}", self.cache.get(&key));
65        let current_count = self.cache.get(&key).unwrap_or(0);
66        // println!("current count of requests {}", current_count);
67
68        // If under the limit, allow the request.
69        if current_count < self.limit {
70            match self.cache.incr(&key, 1) {
71                Ok(new_count) => {
72                    if new_count == 1 {
73                        // If this is the first request, set the TTL.
74                        let _ = self.cache.set(&key, new_count, self.ttl);
75                    }
76                    true
77                }
78                Err(_) => false, // On cache errors, you might choose to block the request.
79            }
80        } else {
81            false
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use crate::cache::in_memory::InMemoryCache;
    use crate::limiter::RateLimiter;

    /// Rate-limiting window shared by the tests below.
    const WINDOW: Duration = Duration::from_secs(1);

    #[test]
    fn test_rate_limiter_allows_and_blocks() {
        // Allow 5 requests per 1-second window.
        let limiter = RateLimiter::new(Arc::new(InMemoryCache::new()), 5, WINDOW);

        // Every `allow` call counts as a request, so call it exactly once per
        // iteration and assert on that single result.
        for i in 1..=5 {
            assert!(limiter.allow("127.0.0.1"), "request {} should be allowed", i);
        }

        // The 6th request inside the same window must be rejected.
        assert!(!limiter.allow("127.0.0.1"), "request 6 should be blocked");

        // Sleep past the window with some margin; sleeping exactly one TTL
        // lands on the expiry boundary and makes the test flaky.
        thread::sleep(WINDOW + Duration::from_millis(100));

        assert!(
            limiter.allow("127.0.0.1"),
            "a request after the window expires should be allowed"
        );
    }

    #[test]
    fn test_rate_limiter_tracks_ips_independently() {
        let limiter = RateLimiter::new(Arc::new(InMemoryCache::new()), 1, WINDOW);

        assert!(limiter.allow("10.0.0.1"));
        assert!(!limiter.allow("10.0.0.1"));
        // A different client has its own counter.
        assert!(limiter.allow("10.0.0.2"));
    }
}