#[cfg(feature = "portal")]
use axum::{
extract::{Request, State},
http::{HeaderMap, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[cfg(feature = "portal")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "portal")]
use std::{
collections::HashMap,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
#[cfg(feature = "portal")]
use tokio::sync::RwLock;
#[cfg(feature = "portal")]
use tracing::{debug, error, warn};
#[cfg(feature = "portal")]
/// Settings for [`RateLimiter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of requests accepted per API key within one window.
    pub default_limit: u32,
    /// Length of the sliding window, in seconds.
    pub window_size_secs: u64,
    /// Optional Redis URL. In this build it only causes a warning when the
    /// limiter is created; the in-memory backend is used either way.
    pub redis_url: Option<String>,
}
#[cfg(feature = "portal")]
impl Default for RateLimitConfig {
    /// 60 requests per 60-second window, no Redis URL.
    fn default() -> Self {
        let window_size_secs = 60;
        Self {
            redis_url: None,
            window_size_secs,
            default_limit: 60,
        }
    }
}
#[cfg(feature = "portal")]
/// Request history for one API key: one Unix timestamp (seconds) per
/// recorded request.
#[derive(Debug, Clone, Default)]
struct RateLimitEntry {
    timestamps: Vec<u64>,
}
#[cfg(feature = "portal")]
impl RateLimitEntry {
    /// Discards timestamps strictly older than `window_start`; a timestamp
    /// equal to `window_start` is still considered inside the window.
    fn prune(&mut self, window_start: u64) {
        self.timestamps.retain(|ts| window_start <= *ts);
    }

    /// Number of timestamps currently held.
    fn count(&self) -> usize {
        let held = &self.timestamps;
        held.len()
    }

    /// Records one request made at `ts`.
    fn add(&mut self, ts: u64) {
        self.timestamps.push(ts);
    }
}
#[cfg(feature = "portal")]
/// State of the in-memory backend, guarded as a whole by one lock.
#[derive(Default)]
struct MemoryStore {
    /// Request history per API key.
    entries: HashMap<String, RateLimitEntry>,
    /// Unix time (seconds) at or after which the next sweep of stale keys runs.
    next_sweep: u64,
}
#[cfg(feature = "portal")]
#[derive(Clone)]
enum RateLimitBackend {
    /// Process-local store behind an async `RwLock`; clones share it via `Arc`.
    Memory(Arc<RwLock<MemoryStore>>),
}
#[cfg(feature = "portal")]
/// Per-API-key sliding-window rate limiter.
///
/// Clones share the same counters because the backend state sits behind an `Arc`.
#[derive(Clone)]
pub struct RateLimiter {
    config: RateLimitConfig,
    backend: RateLimitBackend,
}
#[cfg(feature = "portal")]
impl RateLimiter {
    /// Creates a limiter from `config`.
    ///
    /// Only the in-memory backend exists in this build: if `config.redis_url`
    /// is set, a warning is logged and the in-memory backend is used anyway.
    pub async fn new(config: RateLimitConfig) -> Self {
        if config.redis_url.is_some() {
            warn!("Redis URL configured but redis feature not enabled. Using in-memory cache.");
        } else {
            debug!("Using in-memory rate limiting.");
        }
        let backend = RateLimitBackend::Memory(Arc::new(RwLock::new(MemoryStore::default())));
        Self { config, backend }
    }

    /// Checks `api_key` against `limit` requests per window and, if allowed,
    /// records the request.
    ///
    /// Returns `Ok(remaining)`, the number of requests still allowed in the
    /// current window after this one.
    ///
    /// # Errors
    ///
    /// Returns `Err(retry_after)`: the number of seconds (at least 1) until
    /// another request for this key would be accepted, or the window size
    /// when `limit` is 0.
    pub async fn check_rate_limit(&self, api_key: &str, limit: u32) -> Result<u32, u64> {
        let now = Self::current_timestamp();
        let window_start = now.saturating_sub(self.config.window_size_secs);
        match &self.backend {
            RateLimitBackend::Memory(store) => {
                self.check_memory(store, api_key, limit, now, window_start)
                    .await
            }
        }
    }

    /// In-memory implementation of [`Self::check_rate_limit`].
    ///
    /// The write lock is held for the whole prune/count/record step, so two
    /// concurrent requests for the same key cannot both take the last slot.
    async fn check_memory(
        &self,
        store: &RwLock<MemoryStore>,
        api_key: &str,
        limit: u32,
        now: u64,
        window_start: u64,
    ) -> Result<u32, u64> {
        let window = self.config.window_size_secs;
        let mut guard = store.write().await;

        // Keys that stop sending requests would otherwise stay in the map forever, and the
        // key is client-supplied (the X-API-Key header), so unbounded growth is possible.
        // Evict empty histories at most once per window to keep the cost amortised.
        if now >= guard.next_sweep {
            guard.entries.retain(|_, entry| {
                entry.prune(window_start);
                entry.count() > 0
            });
            guard.next_sweep = now.saturating_add(window.max(1));
        }

        let entry = guard.entries.entry(api_key.to_string()).or_default();
        entry.prune(window_start);
        let count = entry.count();
        if count >= limit as usize {
            let retry_after = Self::seconds_until_slot_frees(entry, limit, now, window);
            // NOTE(review): this logs the raw API key at debug level; consider redacting it.
            debug!(
                "Rate limit exceeded for API key {}. Retry after {} seconds.",
                api_key, retry_after
            );
            return Err(retry_after);
        }
        entry.add(now);
        // `count < limit` here, so `count + 1` cannot overflow `u32`.
        Ok(limit.saturating_sub(count as u32 + 1))
    }

    /// Seconds until a pruned `entry` holding at least `limit` timestamps accepts
    /// one more request; falls back to `window` when `limit` is 0 (never accepts).
    fn seconds_until_slot_frees(
        entry: &mut RateLimitEntry,
        limit: u32,
        now: u64,
        window: u64,
    ) -> u64 {
        // Timestamps are appended in arrival order, so they are normally ascending already;
        // sorting keeps the index below correct if the wall clock ever stepped backwards.
        entry.timestamps.sort_unstable();
        // With `count` live timestamps, `count - limit + 1` of them must expire before
        // another request fits; the last of those sits at index `count - limit`.
        let index = entry.count().saturating_sub(limit as usize);
        // `prune` keeps `ts` while `ts >= now - window`, so it is dropped once
        // `now` reaches `ts + window + 1`.
        entry.timestamps.get(index).map_or(window, |&ts| {
            ts.saturating_add(window)
                .saturating_add(1)
                .saturating_sub(now)
                .max(1)
        })
    }

    /// Current Unix time in whole seconds (0 if the clock is before the epoch).
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Limit applied per API key by the middleware.
    pub fn default_limit(&self) -> u32 {
        self.config.default_limit
    }

    /// Sliding-window length in seconds.
    pub fn window_size_secs(&self) -> u64 {
        self.config.window_size_secs
    }
}
#[cfg(feature = "portal")]
/// Rejection produced by [`require_rate_limit`], rendered as a JSON error body.
#[derive(Debug)]
pub struct RateLimitError {
    /// HTTP status sent to the client.
    pub status: StatusCode,
    /// Extra headers merged into the response (e.g. `Retry-After`).
    pub headers: HeaderMap,
    /// Message placed under the `"error"` key of the JSON body.
    pub message: String,
}
#[cfg(feature = "portal")]
impl IntoResponse for RateLimitError {
    fn into_response(self) -> Response {
        let Self {
            status,
            headers,
            message,
        } = self;
        let payload = axum::Json(serde_json::json!({ "error": message }));
        let mut response = (status, payload).into_response();
        // Applied after the JSON body so these headers are layered on top of it.
        response.headers_mut().extend(headers);
        response
    }
}
#[cfg(feature = "portal")]
/// Axum middleware that rate-limits requests per `X-API-Key` header value.
///
/// Allowed requests are forwarded to `next`. The response then gets
/// `X-RateLimit-Limit`, `X-RateLimit-Remaining` and `X-RateLimit-Reset`
/// (Unix seconds) headers.
///
/// # Errors
///
/// - `401 Unauthorized` if `X-API-Key` is missing or is not visible-ASCII text.
/// - `429 Too Many Requests`, with `Retry-After` and the `X-RateLimit-*`
///   headers, when the key has used up its limit for the current window.
pub async fn require_rate_limit(
    State(rate_limiter): State<RateLimiter>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, RateLimitError> {
    // Writes the `X-RateLimit-*` headers shared by allowed and rejected responses.
    // Integer `HeaderValue` conversions are infallible, so there is no panic path here.
    fn insert_rate_limit_headers(
        target: &mut HeaderMap,
        limit: u32,
        remaining: HeaderValue,
        reset_at: u64,
    ) {
        target.insert("X-RateLimit-Limit", HeaderValue::from(limit));
        target.insert("X-RateLimit-Remaining", remaining);
        target.insert("X-RateLimit-Reset", HeaderValue::from(reset_at));
    }

    let api_key = headers
        .get("X-API-Key")
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| RateLimitError {
            status: StatusCode::UNAUTHORIZED,
            headers: HeaderMap::new(),
            message: "Missing or invalid X-API-Key header".to_string(),
        })?;

    let limit = rate_limiter.default_limit();
    let outcome = rate_limiter.check_rate_limit(api_key, limit).await;
    // Sample the clock at decision time so the advertised reset time does not drift
    // by however long the downstream handler takes to run.
    let now = RateLimiter::current_timestamp();

    match outcome {
        Ok(remaining) => {
            let reset_at = now.saturating_add(rate_limiter.window_size_secs());
            let mut response = next.run(request).await;
            insert_rate_limit_headers(
                response.headers_mut(),
                limit,
                HeaderValue::from(remaining),
                reset_at,
            );
            Ok(response)
        }
        Err(retry_after) => {
            let mut headers = HeaderMap::new();
            headers.insert("Retry-After", HeaderValue::from(retry_after));
            insert_rate_limit_headers(
                &mut headers,
                limit,
                HeaderValue::from_static("0"),
                now.saturating_add(retry_after),
            );
            Err(RateLimitError {
                status: StatusCode::TOO_MANY_REQUESTS,
                headers,
                message: "Rate limit exceeded. Please try again later.".to_string(),
            })
        }
    }
}
#[cfg(all(test, feature = "portal"))]
mod tests {
    use super::*;

    /// Builds an in-memory limiter with the given per-key limit and window.
    async fn make_limiter(default_limit: u32, window_size_secs: u64) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            default_limit,
            window_size_secs,
            redis_url: None,
        })
        .await
    }

    #[tokio::test]
    async fn test_memory_backend_basic() {
        let limiter = make_limiter(5, 60).await;
        for i in 0..5 {
            let result = limiter.check_rate_limit("test_key", 5).await;
            assert!(result.is_ok(), "Request {} should be allowed", i + 1);
        }
        let result = limiter.check_rate_limit("test_key", 5).await;
        assert!(result.is_err(), "6th request should be rate limited");
    }

    #[tokio::test]
    async fn test_remaining_counts_down() {
        let limiter = make_limiter(3, 60).await;
        assert_eq!(limiter.check_rate_limit("k", 3).await, Ok(2));
        assert_eq!(limiter.check_rate_limit("k", 3).await, Ok(1));
        assert_eq!(limiter.check_rate_limit("k", 3).await, Ok(0));
        assert!(limiter.check_rate_limit("k", 3).await.is_err());
    }

    #[tokio::test]
    async fn test_retry_after_is_bounded_by_window() {
        let limiter = make_limiter(1, 60).await;
        assert!(limiter.check_rate_limit("k", 1).await.is_ok());
        let retry_after = limiter.check_rate_limit("k", 1).await.unwrap_err();
        // A timestamp equal to `now - window` still counts, so up to window + 1 seconds.
        assert!(
            (1..=61).contains(&retry_after),
            "unexpected retry_after {}",
            retry_after
        );
    }

    #[tokio::test]
    async fn test_zero_limit_rejects_everything() {
        let limiter = make_limiter(0, 60).await;
        assert!(limiter.check_rate_limit("k", 0).await.is_err());
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let limiter = make_limiter(2, 60).await;
        assert!(limiter.check_rate_limit("key1", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key1", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key1", 2).await.is_err());
        assert!(limiter.check_rate_limit("key2", 2).await.is_ok());
        assert!(limiter.check_rate_limit("key2", 2).await.is_ok());
    }

    #[tokio::test]
    async fn test_window_pruning() {
        let limiter = make_limiter(2, 1).await;
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
        assert!(limiter.check_rate_limit("test_key", 2).await.is_err());
        // The limiter reads the wall clock (SystemTime), so a real sleep is required here.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(limiter.check_rate_limit("test_key", 2).await.is_ok());
    }
}