use crate::core::rate_limiter::get_global_rate_limiter;
use crate::core::types::context::RequestContext;
use crate::server::state::AppState;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
use actix_web::http::StatusCode;
use actix_web::web;
use actix_web::{HttpMessage, HttpResponse, ResponseError};
use dashmap::DashMap;
use futures::future::{Ready, ready};
use std::fmt;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
/// Sliding-window request log for a single client key, used by the in-process
/// fallback limiter.
struct KeyTracker {
    /// Instants at which requests were admitted, in insertion (oldest-first) order.
    timestamps: Vec<Instant>,
}

impl KeyTracker {
    /// Creates a tracker with no recorded requests.
    fn new() -> Self {
        Self {
            timestamps: Vec::new(),
        }
    }

    /// Admits and records a request if fewer than `limit` requests were admitted
    /// within the last `window`.
    ///
    /// Returns `(true, 0)` when the request is admitted (and recorded). Otherwise the
    /// request is not recorded and `(false, retry_after)` is returned, where
    /// `retry_after` is the number of seconds until the oldest admitted request leaves
    /// the window. It is rounded *up* so that a client honouring it is not rejected
    /// again, and it is never less than 1.
    fn check_and_record(&mut self, limit: u32, window: Duration) -> (bool, u64) {
        let now = Instant::now();
        self.timestamps
            .retain(|&ts| now.duration_since(ts) < window);

        // u32 -> usize is a widening conversion on every supported target.
        if self.timestamps.len() < limit as usize {
            self.timestamps.push(now);
            return (true, 0);
        }

        // Timestamps are pushed in order, so the first one is the oldest and frees a
        // slot first. With no timestamps (limit == 0) a full window is the best hint.
        let wait = self
            .timestamps
            .first()
            .map_or(window, |&oldest| window.saturating_sub(now.duration_since(oldest)));
        // Round up: truncating e.g. 59.9s to 59 would invite a premature retry.
        let wait_secs = wait
            .as_secs()
            .saturating_add(u64::from(wait.subsec_nanos() > 0));
        (false, wait_secs.max(1))
    }
}
/// Error produced when a client exceeds its rate limit; rendered as HTTP 429.
#[derive(Debug)]
struct RateLimitError {
    /// Seconds the client should wait before retrying (sent as `Retry-After`).
    retry_after: u64,
    /// The request limit that was exceeded (sent as `X-RateLimit-Limit`).
    limit: u32,
}

impl fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Too Many Requests")
    }
}

impl ResponseError for RateLimitError {
    fn status_code(&self) -> StatusCode {
        StatusCode::TOO_MANY_REQUESTS
    }

    /// Builds the 429 response: `Retry-After` and `X-RateLimit-Limit` headers plus a
    /// JSON error envelope.
    fn error_response(&self) -> HttpResponse {
        let payload = serde_json::json!({
            "error": {
                "message": "Rate limit exceeded. Please retry after the indicated seconds.",
                "type": "rate_limit_error",
                "code": 429
            }
        });
        let mut response = HttpResponse::TooManyRequests();
        response.insert_header(("Retry-After", self.retry_after.to_string()));
        response.insert_header(("X-RateLimit-Limit", self.limit.to_string()));
        response.json(payload)
    }
}
/// Actix middleware factory that rate-limits requests per client key.
///
/// `requests_per_minute` is the budget for the in-process fallback limiter. When a
/// global rate limiter is installed, that limiter's own limit is used instead.
pub struct RateLimitMiddleware {
    /// Requests a single client key may make per fallback window.
    requests_per_minute: u32,
}

impl RateLimitMiddleware {
    /// Budget applied by the [`Default`] implementation.
    const DEFAULT_REQUESTS_PER_MINUTE: u32 = 60;

    /// Creates a middleware allowing `requests_per_minute` requests per client key.
    pub fn new(requests_per_minute: u32) -> Self {
        RateLimitMiddleware { requests_per_minute }
    }
}

impl Default for RateLimitMiddleware {
    /// 60 requests per minute per client key.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUESTS_PER_MINUTE)
    }
}
impl<S, B> Transform<S, ServiceRequest> for RateLimitMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type InitError = ();
    type Transform = RateLimitMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a [`RateLimitMiddlewareService`] with a fresh, empty
    /// fallback store.
    ///
    /// NOTE(review): actix typically builds one transform per worker, in which case
    /// each worker gets its own fallback store and the fallback limit is enforced
    /// per worker rather than process-wide — confirm.
    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = RateLimitMiddlewareService {
            requests_per_minute: self.requests_per_minute,
            fallback_store: Arc::new(DashMap::new()),
            service,
        };
        ready(Ok(wrapped))
    }
}
/// Service produced by [`RateLimitMiddleware`]. It applies the rate limit and then
/// forwards admitted requests to the wrapped service.
pub struct RateLimitMiddlewareService<S> {
    /// The wrapped downstream service.
    service: S,
    /// Per-client budget enforced by the fallback limiter.
    requests_per_minute: u32,
    /// Fallback limiter state keyed by client key (`api_key:…`, `user:…` or `ip:…`).
    fallback_store: Arc<DashMap<String, KeyTracker>>,
}
/// Returns just the IP portion of a `host:port` peer address.
///
/// Input that does not parse as a socket address (a bare IP, `"unknown"`, …) is
/// returned unchanged.
fn parse_peer_ip(peer: &str) -> String {
    match peer.parse::<SocketAddr>() {
        Ok(socket_addr) => socket_addr.ip().to_string(),
        Err(_) => peer.to_owned(),
    }
}
/// Chooses the rate-limit key for `req`.
///
/// The authenticated identity from [`authenticated_client_key`] is preferred. When
/// there is none, the key falls back to the network address from
/// [`network_client_key`].
fn extract_client_key(req: &ServiceRequest, trusted_proxies: &[String]) -> String {
    authenticated_client_key(req).unwrap_or_else(|| network_client_key(req, trusted_proxies))
}
/// Builds a key from the request's [`RequestContext`] extension, if one is present.
///
/// Returns `api_key:<id>` when the context carries an API key id. Otherwise it returns
/// `user:<id>` for a non-blank (trimmed) user id. It returns `None` when there is no
/// context or neither identity is available.
fn authenticated_client_key(req: &ServiceRequest) -> Option<String> {
    let extensions = req.extensions();
    let context = extensions.get::<RequestContext>()?;
    let api_key_key = context.api_key_id().map(|id| format!("api_key:{}", id));
    api_key_key.or_else(|| {
        let user_id = context.user_id.as_deref()?.trim();
        (!user_id.is_empty()).then(|| format!("user:{}", user_id))
    })
}
fn network_client_key(req: &ServiceRequest, trusted_proxies: &[String]) -> String {
let conn = req.connection_info();
let peer = conn.peer_addr().unwrap_or("unknown");
let peer_ip = parse_peer_ip(peer);
if trusted_proxies.iter().any(|p| p == &peer_ip)
&& let Some(forwarded) = req.headers().get("X-Forwarded-For")
&& let Ok(val) = forwarded.to_str()
&& let first = val.split(',').next().unwrap_or("").trim()
&& !first.is_empty()
{
return format!("ip:{}", first);
}
format!("ip:{}", peer_ip)
}
impl<S, B> Service<ServiceRequest> for RateLimitMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let app_state = req.app_data::<web::Data<AppState>>().cloned();
let trusted_proxies: Vec<String> = match app_state.as_ref() {
Some(state) => {
let cfg = state.config.load();
cfg.server().trusted_proxies.clone()
}
None => Vec::new(),
};
let requests_per_minute = self.requests_per_minute;
let start_time = Instant::now();
let path = req.path().to_string();
let method = req.method().to_string();
let client_key = extract_client_key(&req, &trusted_proxies);
if let Some(global_limiter) = get_global_rate_limiter() {
let limit = global_limiter.limit();
let fut = self.service.call(req);
let key = client_key.clone();
return Box::pin(async move {
let result = global_limiter.check_and_record(&key).await;
if !result.allowed {
let retry_after = result.retry_after_secs.unwrap_or(60);
warn!(
client = %key,
path = %path,
"Rate limit exceeded (global limiter): retry after {}s",
retry_after
);
let err = RateLimitError { retry_after, limit };
return Err(actix_web::Error::from(err));
}
debug!(
client = %key,
remaining = result.remaining,
"Rate limit check passed (global limiter)"
);
let res = fut.await?;
let duration = start_time.elapsed();
info!(
"{} {} completed in {:?} with status {}",
method,
path,
duration,
res.status()
);
Ok(res)
});
}
let fallback_store = self.fallback_store.clone();
let fut = self.service.call(req);
let key = client_key.clone();
Box::pin(async move {
let window = Duration::from_secs(60);
let (allowed, retry_after) = {
let mut tracker = fallback_store
.entry(key.clone())
.or_insert_with(KeyTracker::new);
tracker.check_and_record(requests_per_minute, window)
};
if !allowed {
warn!(
client = %key,
path = %path,
"Rate limit exceeded (fallback limiter): retry after {}s",
retry_after
);
let err = RateLimitError {
retry_after,
limit: requests_per_minute,
};
return Err(actix_web::Error::from(err));
}
debug!(
client = %key,
limit = requests_per_minute,
"Rate limit check passed (fallback limiter)"
);
let res = fut.await?;
let duration = start_time.elapsed();
info!(
"{} {} completed in {:?} with status {}",
method,
path,
duration,
res.status()
);
Ok(res)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;
    use uuid::Uuid;

    #[test]
    fn test_parse_peer_ip_ipv4_with_port() {
        assert_eq!(parse_peer_ip("127.0.0.1:1234"), "127.0.0.1");
    }

    #[test]
    fn test_parse_peer_ip_ipv4_no_port() {
        assert_eq!(parse_peer_ip("10.0.0.1"), "10.0.0.1");
    }

    #[test]
    fn test_parse_peer_ip_ipv6_with_port() {
        assert_eq!(parse_peer_ip("[::1]:8080"), "::1");
    }

    #[test]
    fn test_parse_peer_ip_unknown_falls_back() {
        assert_eq!(parse_peer_ip("unknown"), "unknown");
    }

    // The trusted-proxy tests go through `extract_client_key` so that they exercise
    // the production matching logic rather than `Iterator::any` on a local list.

    #[test]
    fn test_trusted_proxy_match() {
        let req = TestRequest::default()
            .peer_addr("10.0.0.1:1000".parse().unwrap())
            .insert_header(("X-Forwarded-For", "198.51.100.7"))
            .to_srv_request();
        let key = extract_client_key(&req, &["10.0.0.1".to_string()]);
        assert_eq!(key, "ip:198.51.100.7");
    }

    #[test]
    fn test_trusted_proxy_no_match() {
        let req = TestRequest::default()
            .peer_addr("10.0.0.2:1000".parse().unwrap())
            .insert_header(("X-Forwarded-For", "198.51.100.7"))
            .to_srv_request();
        let key = extract_client_key(&req, &["10.0.0.1".to_string()]);
        assert_eq!(key, "ip:10.0.0.2");
    }

    #[test]
    fn test_trusted_proxy_empty_list() {
        let req = TestRequest::default()
            .peer_addr("127.0.0.1:1000".parse().unwrap())
            .insert_header(("X-Forwarded-For", "198.51.100.7"))
            .to_srv_request();
        let key = extract_client_key(&req, &[]);
        assert_eq!(key, "ip:127.0.0.1");
    }

    #[test]
    fn test_extract_client_key_ignores_rotating_authorization_headers() {
        let req_a = TestRequest::default()
            .peer_addr("203.0.113.10:1000".parse().unwrap())
            .insert_header(("Authorization", "Bearer bogus-a"))
            .to_srv_request();
        let req_b = TestRequest::default()
            .peer_addr("203.0.113.10:1000".parse().unwrap())
            .insert_header(("Authorization", "Bearer bogus-b"))
            .to_srv_request();
        let key_a = extract_client_key(&req_a, &[]);
        let key_b = extract_client_key(&req_b, &[]);
        assert_eq!(key_a, "ip:203.0.113.10");
        assert_eq!(key_a, key_b);
    }

    #[test]
    fn test_extract_client_key_ignores_rotating_api_key_headers() {
        let req_a = TestRequest::default()
            .peer_addr("203.0.113.20:1000".parse().unwrap())
            .insert_header(("x-api-key", "bogus-a"))
            .to_srv_request();
        let req_b = TestRequest::default()
            .peer_addr("203.0.113.20:1000".parse().unwrap())
            .insert_header(("x-api-key", "bogus-b"))
            .to_srv_request();
        let key_a = extract_client_key(&req_a, &[]);
        let key_b = extract_client_key(&req_b, &[]);
        assert_eq!(key_a, "ip:203.0.113.20");
        assert_eq!(key_a, key_b);
    }

    #[test]
    fn test_extract_client_key_uses_trusted_forwarded_ip() {
        let req = TestRequest::default()
            .peer_addr("10.0.0.1:1000".parse().unwrap())
            .insert_header(("X-Forwarded-For", "198.51.100.7, 10.0.0.2"))
            .to_srv_request();
        let key = extract_client_key(&req, &["10.0.0.1".to_string()]);
        assert_eq!(key, "ip:198.51.100.7");
    }

    #[test]
    fn test_extract_client_key_prefers_authenticated_api_key_id() {
        let api_key_id = Uuid::new_v4();
        let req = TestRequest::default()
            .peer_addr("203.0.113.30:1000".parse().unwrap())
            .to_srv_request();
        req.extensions_mut()
            .insert(RequestContext::new().with_api_key(api_key_id));
        let key = extract_client_key(&req, &[]);
        assert_eq!(key, format!("api_key:{}", api_key_id));
    }

    #[test]
    fn test_extract_client_key_uses_authenticated_user_id_without_api_key() {
        let req = TestRequest::default()
            .peer_addr("203.0.113.40:1000".parse().unwrap())
            .to_srv_request();
        req.extensions_mut()
            .insert(RequestContext::new().with_user_id("user-123"));
        let key = extract_client_key(&req, &[]);
        assert_eq!(key, "user:user-123");
    }

    #[test]
    fn test_extract_client_key_blank_user_id_falls_back_to_ip() {
        let req = TestRequest::default()
            .peer_addr("203.0.113.50:1000".parse().unwrap())
            .to_srv_request();
        req.extensions_mut()
            .insert(RequestContext::new().with_user_id("   "));
        let key = extract_client_key(&req, &[]);
        assert_eq!(key, "ip:203.0.113.50");
    }

    #[test]
    fn test_key_tracker_rejects_once_limit_reached() {
        let mut tracker = KeyTracker::new();
        let window = Duration::from_secs(60);
        assert_eq!(tracker.check_and_record(2, window), (true, 0));
        assert_eq!(tracker.check_and_record(2, window), (true, 0));
        let (allowed, retry_after) = tracker.check_and_record(2, window);
        assert!(!allowed);
        assert!((1..=60).contains(&retry_after));
    }
}