use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::{Plugin, PluginError, RequestMeta};
use pylon_auth::AuthContext;
/// Per-key request limiter: each key may make at most `max_requests`
/// requests per `window`; further requests in that window are rejected with a
/// `RATE_LIMITED` / 429 [`PluginError`].
///
/// A key's window starts on its first request and restarts on the first
/// request that arrives more than `window` after that start. All state sits
/// behind one mutex, so a single instance can be shared between callers.
#[derive(Debug)]
pub struct RateLimitPlugin {
    max_requests: u32,
    window: Duration,
    state: Mutex<Buckets>,
}

/// Mutable limiter state, kept behind a single mutex so the counters and the
/// sweep timestamp are always updated together.
#[derive(Debug)]
struct Buckets {
    /// Per-key `(admitted requests in the current window, window start)`.
    counts: HashMap<String, (u32, Instant)>,
    /// When expired entries were last purged from `counts`.
    last_sweep: Instant,
}

impl Buckets {
    /// Drops every bucket whose window has already elapsed, at most once per
    /// `window`.
    ///
    /// An expired bucket behaves exactly like a missing one (both restart at a
    /// count of 1 on the key's next request), so removing them does not change
    /// any admit/reject decision. It only stops the map from growing without
    /// bound as new keys (e.g. peer IPs) arrive. Sweeping at most once per
    /// window keeps the amortised cost per request low while bounding the map
    /// to keys seen in roughly the last two windows.
    fn sweep_expired(&mut self, now: Instant, window: Duration) {
        if now.duration_since(self.last_sweep) > window {
            self.counts
                .retain(|_, (_, started)| now.duration_since(*started) <= window);
            self.last_sweep = now;
        }
    }
}

impl RateLimitPlugin {
    /// Creates a limiter that admits `max_requests` requests per key in each
    /// `window`. A `max_requests` of 0 rejects every request.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            max_requests,
            window,
            state: Mutex::new(Buckets {
                counts: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Counts one request against `key`, rejecting it if the key has already
    /// used up `max_requests` in its current window.
    ///
    /// # Errors
    /// Returns a `RATE_LIMITED` [`PluginError`] with status 429 when the limit
    /// has been reached.
    fn check(&self, key: &str) -> Result<(), PluginError> {
        // Every mutation below is a single assignment, so the map is never
        // left half-updated; recover from poisoning instead of turning one
        // panic elsewhere into a panic on every later request.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let now = Instant::now();
        state.sweep_expired(now, self.window);

        let entry = state.counts.entry(key.to_string()).or_insert((0, now));
        if now.duration_since(entry.1) > self.window {
            *entry = (0, now);
        }

        // Only admitted requests are counted: the counter therefore never
        // exceeds `max_requests` and cannot overflow however long a client
        // keeps retrying after being limited.
        if entry.0 >= self.max_requests {
            return Err(PluginError {
                code: "RATE_LIMITED".into(),
                message: format!(
                    "Too many requests. Limit: {} per {:?}",
                    self.max_requests, self.window
                ),
                status: 429,
            });
        }
        entry.0 += 1;
        Ok(())
    }

    /// Rate-limits a request, keyed by the caller's identity.
    ///
    /// The key is `user:<id>` when `user_id` is present and non-empty,
    /// otherwise `ip:<peer_ip>` when `peer_ip` is non-empty, otherwise a single
    /// shared `__anon__` bucket. The prefixes keep user ids and IP addresses
    /// from landing in the same bucket.
    ///
    /// # Errors
    /// Returns a `RATE_LIMITED` [`PluginError`] (status 429) when the chosen
    /// key has reached its limit.
    pub fn check_request(&self, user_id: Option<&str>, peer_ip: &str) -> Result<(), PluginError> {
        let key = match user_id {
            Some(u) if !u.is_empty() => format!("user:{u}"),
            _ if !peer_ip.is_empty() => format!("ip:{peer_ip}"),
            _ => "__anon__".to_string(),
        };
        self.check(&key)
    }
}
impl Plugin for RateLimitPlugin {
    /// Identifier of this plugin: `"rate-limit"`.
    fn name(&self) -> &str {
        "rate-limit"
    }

    /// Rate-limits a request by authenticated user only; this hook has no peer
    /// address to fall back on.
    ///
    /// Delegates to [`RateLimitPlugin::check_request`] with an empty peer IP so
    /// both hooks share one key scheme: `user:<id>` for a non-empty user id,
    /// the shared `__anon__` bucket otherwise. Keying on the raw id here would
    /// count the same user under different keys depending on which hook runs,
    /// and would put a user whose id is literally `__anon__` into the
    /// anonymous bucket.
    fn on_request(
        &self,
        _method: &str,
        _path: &str,
        auth: &AuthContext,
    ) -> Result<(), PluginError> {
        self.check_request(auth.user_id.as_deref(), "")
    }

    /// Rate-limits a request by user id when one is present, otherwise by the
    /// peer IP from `meta`, otherwise via the shared anonymous bucket.
    fn on_request_with_meta(
        &self,
        _method: &str,
        _path: &str,
        auth: &AuthContext,
        meta: &RequestMeta<'_>,
    ) -> Result<(), PluginError> {
        self.check_request(auth.user_id.as_deref(), meta.peer_ip)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_under_limit() {
        let plugin = RateLimitPlugin::new(3, Duration::from_secs(60));
        let auth = AuthContext::anonymous();
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
    }

    #[test]
    fn different_ips_use_different_buckets() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_err());
        assert!(plugin.check_request(None, "2.2.2.2").is_ok());
        assert!(plugin.check_request(None, "2.2.2.2").is_ok());
    }

    #[test]
    fn user_id_preferred_over_ip() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        assert!(plugin.check_request(Some("alice"), "1.1.1.1").is_ok());
        assert!(plugin.check_request(Some("alice"), "2.2.2.2").is_ok());
        assert!(plugin.check_request(Some("alice"), "3.3.3.3").is_err());
    }

    #[test]
    fn empty_user_id_falls_back_to_ip() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        assert!(plugin.check_request(Some(""), "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(Some(""), "1.1.1.1").is_err());
    }

    #[test]
    fn unidentified_requests_share_anon_bucket() {
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        assert!(plugin.check_request(None, "").is_ok());
        assert!(plugin.check_request(Some(""), "").is_err());
        // A request with a known peer IP is tracked in its own bucket.
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
    }

    #[test]
    fn user_and_ip_keys_do_not_collide() {
        // A user id that looks like an IP must not share the IP's bucket.
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        assert!(plugin.check_request(Some("1.1.1.1"), "").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
    }

    #[test]
    fn zero_limit_rejects_everything() {
        let plugin = RateLimitPlugin::new(0, Duration::from_secs(60));
        assert!(plugin.check_request(Some("alice"), "1.1.1.1").is_err());
        assert!(plugin.check_request(None, "1.1.1.1").is_err());
        assert!(plugin.check_request(None, "").is_err());
    }

    #[test]
    fn blocks_over_limit() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        let auth = AuthContext::anonymous();
        assert!(plugin.on_request("GET", "/", &auth).is_ok());
        assert!(plugin.on_request("GET", "/", &auth).is_ok());
        let result = plugin.on_request("GET", "/", &auth);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.code, "RATE_LIMITED");
        assert_eq!(err.status, 429);
        // Once limited, the key stays limited for the rest of the window.
        assert!(plugin.on_request("GET", "/", &auth).is_err());
    }

    #[test]
    fn separate_users_separate_limits() {
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        let alice = AuthContext::authenticated("alice".into());
        let bob = AuthContext::authenticated("bob".into());
        assert!(plugin.on_request("GET", "/", &alice).is_ok());
        assert!(plugin.on_request("GET", "/", &bob).is_ok());
        assert!(plugin.on_request("GET", "/", &alice).is_err());
        assert!(plugin.on_request("GET", "/", &bob).is_err());
    }
}