use std::sync::Arc;
use std::time::Duration;
/// Counter storage that [`RateLimiter`] keeps its per-client request counts in.
///
/// The `Send + Sync` supertraits allow one backend instance to be shared
/// between threads behind an `Arc`.
pub trait CacheBackend: Send + Sync {
    /// Returns the value stored under `key`, if any.
    fn get(&self, key: &str) -> Option<u32>;

    /// Expected to add `amount` to the counter under `key` and return the
    /// updated value.
    ///
    /// `RateLimiter` treats a returned `1` as "first request for this key"
    /// and attaches its TTL at that point.
    fn incr(&self, key: &str, amount: u32) -> Result<u32, String>;

    /// Stores `value` under `key` with a time-to-live of `ttl`.
    fn set(&self, key: &str, value: u32, ttl: Duration) -> Result<(), String>;
}
/// Fixed-window request limiter keyed by client address.
///
/// The trait bound on `B` lives on the `impl` block rather than on the struct,
/// so the type can be named without repeating the bound.
/// Its methods require `B: CacheBackend`.
#[derive(Debug)]
pub struct RateLimiter<B> {
    /// Shared store holding one request counter per client key.
    pub cache: Arc<B>,
    /// Maximum number of requests admitted per client within one window.
    pub limit: u32,
    /// Time-to-live attached to a client's counter when its first request is recorded.
    pub ttl: Duration,
}
impl<B: CacheBackend> RateLimiter<B> {
    /// Creates a limiter that admits up to `limit` requests per client in each window.
    ///
    /// * `cache` - shared store holding one counter per client key.
    /// * `limit` - maximum requests allowed per window; `0` rejects every request.
    /// * `ttl` - expiry attached to a client's counter when its first request is
    ///   recorded. The window ends when the backend expires that key.
    ///
    /// No validation is performed on `limit` or `ttl`.
    pub fn new(cache: Arc<B>, limit: u32, ttl: Duration) -> Self {
        Self { cache, limit, ttl }
    }

    /// Records one request from `ip` and returns whether it is allowed.
    ///
    /// The counter lives under the key `rate_limit:<ip>`. Requests are allowed
    /// while the post-increment count is at most `limit`. Returns `false` for
    /// every request beyond that. It also returns `false` when the backend fails
    /// to increment the counter (fail closed).
    ///
    /// The decision uses only the value returned by [`CacheBackend::incr`]. No
    /// separate read is made first: a read-then-increment sequence lets
    /// concurrent callers all observe the same below-limit count and all pass.
    /// Deciding on the post-increment value avoids that, provided the backend's
    /// `incr` is atomic.
    ///
    /// Rejected requests still increment the counter. That extra count is
    /// discarded when the key expires.
    pub fn allow(&self, ip: &str) -> bool {
        let key = format!("rate_limit:{}", ip);
        match self.cache.incr(&key, 1) {
            Ok(count) => {
                if count == 1 {
                    // First request in this window: attach the TTL so the counter resets.
                    // Best effort — if `set` fails the key may never expire.
                    // NOTE(review): if the backend's `set` overwrites the stored value, an
                    // `incr` from another caller landing between our `incr` and this `set`
                    // is lost (slight undercount). A backend-level atomic
                    // increment-with-expiry would close that gap.
                    let _ = self.cache.set(&key, count, self.ttl);
                }
                count <= self.limit
            }
            Err(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use crate::cache::in_memory::InMemoryCache;
    use crate::limiter::RateLimiter;

    const IP: &str = "127.0.0.1";

    #[test]
    fn test_rate_limiter_allows_and_blocks() {
        let ttl = Duration::from_secs(1);
        let limiter = RateLimiter::new(Arc::new(InMemoryCache::new()), 5, ttl);

        // Exactly one `allow` call per simulated request: every call consumes quota,
        // so calling it again inside a log line or assertion would double-count.
        for request in 1..=5 {
            assert!(limiter.allow(IP), "request {} should be allowed", request);
        }
        assert!(!limiter.allow(IP), "6th request should be blocked");

        // Sleep past the TTL with a margin; sleeping exactly `ttl` races the
        // expiry boundary and makes the test flaky.
        thread::sleep(ttl + Duration::from_millis(100));
        assert!(
            limiter.allow(IP),
            "request after TTL expiration should be allowed"
        );
    }

    #[test]
    fn test_rate_limiter_tracks_ips_independently() {
        let limiter = RateLimiter::new(Arc::new(InMemoryCache::new()), 1, Duration::from_secs(60));
        assert!(limiter.allow("10.0.0.1"));
        assert!(!limiter.allow("10.0.0.1"));
        assert!(
            limiter.allow("10.0.0.2"),
            "a different IP must have its own quota"
        );
    }
}