api_rate_limiter/limiter.rs
1use std::sync::Arc;
2use std::time::Duration;
3
/// Storage abstraction behind the rate limiter.
///
/// Implement this for whichever store holds the per-key request counters
/// (an in-process map, Redis, or anything else). Implementations must be
/// `Send + Sync` so one backend instance can be shared across threads,
/// e.g. behind an `Arc`.
pub trait CacheBackend: Send + Sync {
    /// Adds `amount` to the counter stored under `key` and returns the
    /// resulting value, or a backend error message.
    ///
    /// What happens for a key with no stored counter is up to the backend;
    /// the rate limiter expects such a key to start from zero, i.e. the
    /// first increment by 1 should return `1`.
    fn incr(&self, key: &str, amount: u32) -> Result<u32, String>;

    /// Looks up the counter stored under `key`.
    ///
    /// `None` means no counter is currently available for `key`; the rate
    /// limiter treats that the same as a count of zero.
    fn get(&self, key: &str) -> Option<u32>;

    /// Stores `value` under `key` with a time-to-live of `ttl`, or returns
    /// a backend error message.
    fn set(&self, key: &str, value: u32, ttl: Duration) -> Result<(), String>;
}
16
/// The RateLimiter struct for distributed, IP-based rate limiting.
///
/// Each client IP may make at most `limit` requests per `ttl` window; the
/// per-client counters are kept in the shared `cache` backend.
///
/// # Type Parameters:
/// * `B`: The caching backend. The struct definition itself places no bound
///   on `B` (bounds live on the impl blocks that need them); the constructor
///   and `allow` are available when `B: CacheBackend`.
pub struct RateLimiter<B> {
    /// The caching backend instance (e.g., Redis, in-memory, etc.).
    pub cache: Arc<B>,
    /// Maximum allowed requests within a TTL window.
    pub limit: u32,
    /// Duration of the rate limiting window.
    pub ttl: Duration,
}
29
30impl<B: CacheBackend> RateLimiter<B> {
31 /// Constructs a new RateLimiter.
32 ///
33 /// # Arguments
34 ///
35 /// * `cache` - A caching backend instance wrapped in `Arc`.
36 /// * `limit` - Maximum number of allowed requests in the TTL window.
37 /// * `ttl` - Duration for the rate limiting window.
38 pub fn new(cache: Arc<B>, limit: u32, ttl: Duration) -> Self {
39 RateLimiter { cache, limit, ttl }
40 }
41
42 /// Checks whether a request from the given IP is allowed.
43 ///
44 /// This method does the following:
45 /// 1. Builds a key using the client's IP.
46 /// 2. Retrieves the current request count from the cache.
47 /// 3. If under the limit, increments the count.
48 /// - If this is the first request, sets the TTL for that key.
49 /// 4. Returns `true` if the request is allowed, or `false` if the limit is exceeded.
50 ///
51 /// # Arguments
52 ///
53 /// * `ip` - A string slice representing the client's IP address.
54 ///
55 /// # Returns
56 ///
57 /// * `true` if the request is allowed; `false` otherwise.
58 pub fn allow(&self, ip: &str) -> bool {
59 // Use the IP as the key for rate limiting.
60 let key = format!("rate_limit:{}", ip);
61 // println!("found out key format");
62
63 // Get the current request count, defaulting to 0 if not found.
64 // println!("current count of requests {:?}", self.cache.get(&key));
65 let current_count = self.cache.get(&key).unwrap_or(0);
66 // println!("current count of requests {}", current_count);
67
68 // If under the limit, allow the request.
69 if current_count < self.limit {
70 match self.cache.incr(&key, 1) {
71 Ok(new_count) => {
72 if new_count == 1 {
73 // If this is the first request, set the TTL.
74 let _ = self.cache.set(&key, new_count, self.ttl);
75 }
76 true
77 }
78 Err(_) => false, // On cache errors, you might choose to block the request.
79 }
80 } else {
81 false
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use crate::cache::in_memory::InMemoryCache;
    use crate::limiter::RateLimiter;

    /// Exercises a full window against the in-memory backend: `LIMIT` requests
    /// pass, the next one is rejected, another IP keeps its own quota, and the
    /// first IP is admitted again once the window has expired.
    #[test]
    fn test_rate_limiter_allows_and_blocks() {
        const LIMIT: u32 = 5;
        const WINDOW: Duration = Duration::from_secs(1);
        let ip = "127.0.0.1";

        let cache = Arc::new(InMemoryCache::new());
        let limiter = RateLimiter::new(cache, LIMIT, WINDOW);

        // Call `allow` exactly once per request: every call consumes quota, so
        // calling it again inside a print or assert would skew the count.
        for i in 1..=LIMIT {
            assert!(limiter.allow(ip), "request {} should be allowed", i);
        }

        assert!(!limiter.allow(ip), "request {} should be blocked", LIMIT + 1);

        // Quotas are tracked per IP.
        assert!(
            limiter.allow("10.0.0.1"),
            "a different IP should have its own quota"
        );

        // Sleep slightly past the window so expiry is not a timing race.
        thread::sleep(WINDOW + Duration::from_millis(100));

        assert!(
            limiter.allow(ip),
            "a request after the window expired should be allowed"
        );
    }
}