use crate::config::RateLimitConfig;
use crate::state::{RateLimitResult, RateLimitStore};
use axum::{
body::Body,
http::{Request, Response, StatusCode},
};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
#[derive(Clone)]
pub struct RateLimitLayer {
store: RateLimitStore,
}
impl RateLimitLayer {
pub fn new() -> Self {
Self {
store: RateLimitStore::new(),
}
}
pub fn with_store(store: RateLimitStore) -> Self {
Self { store }
}
pub fn cleanup_expired(&self) {
self.store.cleanup_expired();
}
pub fn store(&self) -> &RateLimitStore {
&self.store
}
}
impl Default for RateLimitLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for RateLimitLayer {
type Service = RateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimitService {
inner,
store: self.store.clone(),
}
}
}
/// Middleware service that consults a [`RateLimitStore`] for every request
/// before (possibly) forwarding it to the wrapped `inner` service.
#[derive(Clone)]
pub struct RateLimitService<S> {
    inner: S,
    store: RateLimitStore,
}

impl<S> RateLimitService<S> {
    /// Wraps `inner`; rate-limit decisions are taken from `store`.
    pub fn new(inner: S, store: RateLimitStore) -> Self {
        RateLimitService { inner, store }
    }
}
/// Rate-limiting request handling.
///
/// * Over the limit: the inner service is not called. A `429 Too Many Requests`
///   JSON response carrying `Retry-After` and `X-RateLimit-Limit-Type` headers
///   is returned instead.
/// * Within the limit: the request is forwarded and the inner response is
///   annotated with `X-RateLimit-Remaining-Minute` / `X-RateLimit-Remaining-Hour`.
impl<S> Service<Request<Body>> for RateLimitService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the inner service; the rate-limit decision is
    /// made per request in `call`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let tenant_key = extract_rate_limit_key(&req);
        // NOTE(review): the plan is read straight from a request header and
        // defaults to "free". Unless an upstream auth layer sets or strips
        // `X-Tenant-Plan`, a client can select another plan's limits — confirm.
        let plan = req
            .headers()
            .get("X-Tenant-Plan")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("free");
        let config = RateLimitConfig::for_plan(plan);
        let result = self.store.check(&tenant_key, &config);

        match result {
            RateLimitResult::Exceeded {
                retry_after,
                limit_type,
            } => {
                tracing::warn!(
                    tenant = %tenant_key,
                    limit_type = %limit_type,
                    retry_after = %retry_after,
                    "Rate limit exceeded"
                );
                Box::pin(async move {
                    // NOTE(review): `expect` assumes `limit_type` is always a
                    // valid header value (and JSON-safe, since it is spliced
                    // into the body unescaped) — confirm for every variant.
                    let response = Response::builder()
                        .status(StatusCode::TOO_MANY_REQUESTS)
                        .header("Retry-After", retry_after.to_string())
                        .header("X-RateLimit-Limit-Type", limit_type)
                        .header("Content-Type", "application/json")
                        .body(Body::from(format!(
                            r#"{{"error":"Rate limit exceeded","error_code":"RATE_LIMIT_EXCEEDED","retry_after":{},"limit_type":"{}"}}"#,
                            retry_after, limit_type
                        )))
                        .expect("Failed to build rate limit response");
                    Ok(response)
                })
            }
            RateLimitResult::Allowed {
                remaining_minute,
                remaining_hour,
            } => {
                // Tower only guarantees readiness for the instance `poll_ready`
                // was driven on, i.e. `self.inner`. Calling a fresh clone would
                // skip readiness, and services such as `Buffer` or
                // `ConcurrencyLimit` panic in that case. Move the ready instance
                // into the future and leave a clone behind for the next
                // `poll_ready`/`call` cycle.
                let fresh = self.inner.clone();
                let mut ready = std::mem::replace(&mut self.inner, fresh);
                Box::pin(async move {
                    let mut response = ready.call(req).await?;
                    let headers = response.headers_mut();
                    if let Ok(v) = remaining_minute.to_string().parse() {
                        headers.insert("X-RateLimit-Remaining-Minute", v);
                    }
                    if let Ok(v) = remaining_hour.to_string().parse() {
                        headers.insert("X-RateLimit-Remaining-Hour", v);
                    }
                    Ok(response)
                })
            }
        }
    }
}
/// Derives the key under which a request is counted in the rate-limit store.
///
/// Precedence:
/// 1. A non-empty `X-Tenant-ID` header value (surrounding whitespace trimmed).
/// 2. The first entry of `X-Forwarded-For`, if `validate_ip_key` accepts it
///    (the normalized address is used).
/// 3. The literal `"anonymous"`, shared by every request lacking both.
///
/// NOTE(review): both headers are taken from the request as-is. Unless an
/// upstream proxy/auth layer sets or strips them, a client can choose its own
/// key (rotating tenant IDs, or spoofing the leftmost `X-Forwarded-For` hop) —
/// confirm the deployment guarantees this.
fn extract_rate_limit_key(req: &Request<Body>) -> String {
    let headers = req.headers();

    let tenant_id = headers
        .get("X-Tenant-ID")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        // An empty header would otherwise put every such request into a single
        // shared `""` bucket; fall through to the IP-based key instead.
        .filter(|id| !id.is_empty());
    if let Some(id) = tenant_id {
        return id.to_owned();
    }

    headers
        .get("X-Forwarded-For")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split(',').next())
        .map(str::trim)
        .and_then(validate_ip_key)
        .unwrap_or_else(|| "anonymous".to_string())
}
/// Parses `ip` as an IP address and returns its canonical textual form
/// (e.g. `"0:0:0:0:0:0:0:1"` becomes `"::1"`), or `None` if it is not a valid
/// address.
///
/// A trailing port is also accepted (`"203.0.113.7:51234"`,
/// `"[2001:db8::1]:443"`) in case a proxy appended one to the forwarded
/// address. The port is dropped so every connection from the same client maps
/// to one key instead of falling back to the shared anonymous bucket.
fn validate_ip_key(ip: &str) -> Option<String> {
    use std::net::{IpAddr, SocketAddr};
    ip.parse::<IpAddr>()
        .or_else(|_| ip.parse::<SocketAddr>().map(|sock| sock.ip()))
        .ok()
        .map(|addr| addr.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitLayer::new` starts with an empty store.
    #[test]
    fn test_layer_creation() {
        let layer = RateLimitLayer::new();
        assert!(layer.store().is_empty());
    }

    /// A layer built around a caller-provided, pre-sized store still reports
    /// that store as empty.
    #[test]
    fn test_layer_with_store() {
        let layer = RateLimitLayer::with_store(RateLimitStore::with_capacity(500));
        assert!(layer.store().is_empty());
    }
}