use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::{Arc, RwLock};
use axum::{
body::Body,
http::{Method, StatusCode},
response::Response,
};
/// Per-endpoint-group request budgets, expressed as requests per minute.
///
/// `NonZeroU32` makes a zero budget unrepresentable.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Budget for `GET` requests on `/experience` routes.
    pub get_experience_rpm: NonZeroU32,
    /// Budget for `POST` requests on `/experience` routes.
    pub post_experience_rpm: NonZeroU32,
    /// Budget shared by `/keys` and `/public-keys` routes.
    pub key_management_rpm: NonZeroU32,
}

impl Default for RateLimitConfig {
    /// 100 GET / 30 POST requests per minute on `/experience`, 20 on key routes.
    fn default() -> Self {
        // Every literal below is non-zero, so the unwrap can never fire.
        let per_minute = |budget: u32| NonZeroU32::new(budget).unwrap();
        Self {
            get_experience_rpm: per_minute(100),
            post_experience_rpm: per_minute(30),
            key_management_rpm: per_minute(20),
        }
    }
}
/// Sliding-window request limiter for one endpoint group.
///
/// For every client key it remembers the timestamps (seconds) of accepted
/// requests that still fall inside the window. Clones share the same state
/// because the map lives behind an `Arc`.
#[derive(Clone)]
pub struct EndpointLimiter {
    /// Accepted-request timestamps per client key.
    requests: Arc<RwLock<HashMap<String, Vec<u64>>>>,
    /// Maximum accepted requests per key within one window.
    max_requests: NonZeroU32,
    /// Window length in seconds.
    window_secs: u64,
}

impl EndpointLimiter {
    /// Creates an empty limiter allowing `max_requests` per `window_secs` per key.
    fn new(max_requests: NonZeroU32, window_secs: u64) -> Self {
        Self {
            requests: Arc::new(RwLock::new(HashMap::new())),
            max_requests,
            window_secs,
        }
    }

    /// Records a request for `key` at `now_secs` if the key is under its limit.
    ///
    /// # Errors
    ///
    /// Returns `Err(secs)` when the key has already used its budget for the
    /// current window. `secs` (at least 1) is how long until the oldest counted
    /// request leaves the window. A rejected request is not recorded.
    fn check(&self, key: &str, now_secs: u64) -> Result<(), u64> {
        let window = self.window_secs;
        // A timestamp counts while `t + window > now`. Written this way (rather than
        // `t > now - window`), it stays correct when `now < window` instead of
        // discarding early timestamps such as `t == 0`.
        let in_window = |t: u64| t.saturating_add(window) > now_secs;
        let limit = self.max_requests.get() as usize;

        // A poisoned lock only means another thread panicked while holding it; the
        // map is still structurally valid, so keep limiting instead of panicking.
        let mut requests = self
            .requests
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if let Some(timestamps) = requests.get_mut(key) {
            timestamps.retain(|&t| in_window(t));
            if timestamps.len() >= limit {
                // `limit >= 1`, so the list is non-empty here. `min` (not `first`)
                // stays right even if the clock stepped backwards between calls.
                let oldest = timestamps.iter().copied().min().unwrap_or(now_secs);
                let retry_after = oldest.saturating_add(window).saturating_sub(now_secs);
                return Err(retry_after.max(1));
            }
            timestamps.push(now_secs);
            return Ok(());
        }

        // New key. Keys are otherwise never removed, so clients that stop calling
        // (or rotate ids) would grow the map forever. When the map is about to
        // reallocate, evict keys with no timestamp left in the window first. This
        // ties the O(n) sweep to the map's own geometric growth.
        if requests.len() == requests.capacity() {
            requests.retain(|_, timestamps| timestamps.iter().copied().any(in_window));
        }
        requests.insert(key.to_owned(), vec![now_secs]);
        Ok(())
    }
}
/// Routes a request to the endpoint-group limiter that governs it.
///
/// Each group owns an independent [`EndpointLimiter`] with a 60-second window,
/// so the configured budgets act as requests-per-minute limits.
#[derive(Clone)]
pub struct RateLimiterRegistry {
    /// Limiter for `GET` on paths starting with `/experience`.
    get_experience: EndpointLimiter,
    /// Limiter for `POST` on paths starting with `/experience`.
    post_experience: EndpointLimiter,
    /// Limiter for any method on paths starting with `/keys` or `/public-keys`.
    key_management: EndpointLimiter,
}

impl RateLimiterRegistry {
    /// Builds one limiter per endpoint group from `config`, each with a 60 s window.
    pub fn new(config: &RateLimitConfig) -> Self {
        const WINDOW_SECS: u64 = 60;
        let per_minute = |budget| EndpointLimiter::new(budget, WINDOW_SECS);
        Self {
            get_experience: per_minute(config.get_experience_rpm),
            post_experience: per_minute(config.post_experience_rpm),
            key_management: per_minute(config.key_management_rpm),
        }
    }

    /// Returns the limiter for `method` + `path`, or `None` if the route is unlimited.
    ///
    /// Matching uses raw string prefixes, so e.g. `/experiences` is also treated as
    /// an `/experience` route. Methods other than GET/POST on `/experience` paths
    /// are unlimited.
    pub fn get_limiter(&self, method: &Method, path: &str) -> Option<&EndpointLimiter> {
        if path.starts_with("/experience") {
            if method == Method::GET {
                return Some(&self.get_experience);
            }
            if method == Method::POST {
                return Some(&self.post_experience);
            }
        }
        let is_key_route = path.starts_with("/keys") || path.starts_with("/public-keys");
        if is_key_route {
            Some(&self.key_management)
        } else {
            None
        }
    }

    /// Charges one request from `client_id` against the limiter for this route.
    ///
    /// Unlimited routes always yield `Ok(())`. Otherwise the result is the
    /// limiter's, where `Err(secs)` is the retry-after delay in seconds.
    pub fn check(
        &self,
        method: &Method,
        path: &str,
        client_id: &str,
        now_secs: u64,
    ) -> Result<(), u64> {
        self.get_limiter(method, path)
            .map_or(Ok(()), |limiter| limiter.check(client_id, now_secs))
    }
}
/// Derives the key used to bucket a request for rate limiting.
///
/// Sources are tried in order. A header is skipped if it is missing, cannot be
/// read as a string (`HeaderValue::to_str` fails), or is blank after trimming:
/// 1. the first comma-separated entry of `X-Forwarded-For`;
/// 2. `X-Real-IP`;
/// 3. the literal `"default"`, shared by every request that reaches this step.
///
/// NOTE(review): both headers are supplied by the client unless a trusted proxy
/// in front of this service overwrites them. If none does, callers can choose
/// their own bucket. Confirm the deployment topology.
pub fn extract_client_id(request: &axum::http::Request<axum::body::Body>) -> String {
    let headers = request.headers();
    // Only the left-most entry is used. A blank one (e.g. `", 10.0.0.1"`) is treated
    // as absent rather than producing an empty-string client id.
    let forwarded = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    let real_ip = headers
        .get("x-real-ip")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    forwarded.or(real_ip).unwrap_or("default").to_string()
}
/// Builds the `429 Too Many Requests` response sent to throttled clients.
///
/// The delay is reported twice: in the `Retry-After` header and in the JSON
/// body's `retry_after` field.
pub fn rate_limit_response(retry_after_secs: u64) -> Response<Body> {
    let seconds = retry_after_secs.to_string();
    let json = format!(
        r#"{{"error":"rate limit exceeded","error_code":"RATE_LIMIT_EXCEEDED","retry_after":{}}}"#,
        seconds
    );
    // Status and header names/values are fixed and valid, so building cannot fail.
    Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("Retry-After", seconds)
        .header("Content-Type", "application/json")
        .body(Body::from(json))
        .unwrap()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = RateLimitConfig::default();
        assert_eq!(config.get_experience_rpm.get(), 100);
        assert_eq!(config.post_experience_rpm.get(), 30);
        assert_eq!(config.key_management_rpm.get(), 20);
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        // All requests come from the same client so the per-key limit is exercised.
        let limiter = EndpointLimiter::new(NonZeroU32::new(5).unwrap(), 60);
        let now = 1000u64;
        for i in 0..5 {
            let result = limiter.check("client", now);
            assert!(result.is_ok(), "request {} should be allowed", i);
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = EndpointLimiter::new(NonZeroU32::new(2).unwrap(), 60);
        let now = 1000u64;
        assert!(limiter.check("client", now).is_ok());
        assert!(limiter.check("client", now).is_ok());
        assert_eq!(limiter.check("client", now), Err(60));
    }

    #[test]
    fn test_rate_limiter_keys_are_independent() {
        let limiter = EndpointLimiter::new(NonZeroU32::new(1).unwrap(), 60);
        assert!(limiter.check("a", 1000).is_ok());
        assert!(limiter.check("a", 1000).is_err());
        assert!(limiter.check("b", 1000).is_ok());
    }

    #[test]
    fn test_rate_limiter_window_expiry_and_retry_after() {
        let limiter = EndpointLimiter::new(NonZeroU32::new(2).unwrap(), 60);
        assert_eq!(limiter.check("c", 1000), Ok(()));
        assert_eq!(limiter.check("c", 1010), Ok(()));
        // Oldest counted request (1000) leaves the window at 1060.
        assert_eq!(limiter.check("c", 1020), Err(40));
        assert_eq!(limiter.check("c", 1060), Ok(()));
        // Now 1010 and 1060 are counted; 1010 leaves the window at 1070.
        assert_eq!(limiter.check("c", 1061), Err(9));
    }

    #[test]
    fn test_registry_routes_by_method_and_path() {
        let registry = RateLimiterRegistry::new(&RateLimitConfig::default());
        assert!(registry.get_limiter(&Method::GET, "/experience").is_some());
        assert!(registry.get_limiter(&Method::POST, "/experience/42").is_some());
        assert!(registry.get_limiter(&Method::DELETE, "/experience").is_none());
        assert!(registry.get_limiter(&Method::GET, "/keys").is_some());
        assert!(registry.get_limiter(&Method::PUT, "/public-keys/abc").is_some());
        assert!(registry.get_limiter(&Method::GET, "/health").is_none());
        assert_eq!(registry.check(&Method::GET, "/health", "c", 0), Ok(()));

        // POST and GET on /experience have separate budgets.
        for _ in 0..30 {
            assert!(registry.check(&Method::POST, "/experience", "c", 1000).is_ok());
        }
        assert!(registry.check(&Method::POST, "/experience", "c", 1000).is_err());
        assert!(registry.check(&Method::GET, "/experience", "c", 1000).is_ok());
    }

    #[test]
    fn test_extract_client_id_sources() {
        let forwarded = axum::http::Request::builder()
            .header("x-forwarded-for", "203.0.113.7, 10.0.0.1")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_client_id(&forwarded), "203.0.113.7");

        let real_ip = axum::http::Request::builder()
            .header("x-real-ip", "198.51.100.2")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_client_id(&real_ip), "198.51.100.2");

        let bare = axum::http::Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_client_id(&bare), "default");
    }

    #[test]
    fn test_rate_limit_response_headers() {
        let response = rate_limit_response(42);
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(response.headers().get("Retry-After").unwrap(), "42");
        assert_eq!(
            response.headers().get("Content-Type").unwrap(),
            "application/json"
        );
    }
}