use crate::models::AppState;
use axum::{
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
/// Sliding-window request log used to rate limit a single client.
///
/// Every admitted request is recorded as a timestamp in `requests`; a new
/// request is admitted only while fewer than `limit` recorded timestamps fall
/// inside the last `window_seconds` seconds.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Timestamps of admitted requests. Expired entries are pruned by
    /// [`RateLimitInfo::is_allowed`].
    pub requests: Vec<std::time::Instant>,
    /// Maximum number of requests admitted per window.
    pub limit: u32,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
}

impl RateLimitInfo {
    /// Creates an empty limiter admitting `limit` requests per
    /// `window_seconds` seconds.
    pub fn new(limit: u32, window_seconds: u64) -> Self {
        Self {
            requests: Vec::new(),
            limit,
            window_seconds,
        }
    }

    /// Returns the instant at which the current window began, or `None` when
    /// that instant cannot be represented.
    ///
    /// `Instant - Duration` panics when the result is not representable (very
    /// large windows, or a small monotonic clock value on some platforms), so
    /// the checked form is used and `None` means "the window covers every
    /// recorded request".
    fn window_start(&self, now: std::time::Instant) -> Option<std::time::Instant> {
        now.checked_sub(std::time::Duration::from_secs(self.window_seconds))
    }

    /// Records the current request if the limit has not been reached.
    ///
    /// Prunes timestamps that fell out of the window, then returns `true` and
    /// stores the current instant when there is capacity left, or `false`
    /// (recording nothing) when the limit is already used up.
    pub fn is_allowed(&mut self) -> bool {
        let now = std::time::Instant::now();
        let window_start = self.window_start(now);
        self.requests
            .retain(|&time| window_start.map_or(true, |start| time > start));

        if self.requests.len() < self.limit as usize {
            self.requests.push(now);
            true
        } else {
            false
        }
    }

    /// Returns how many more requests can be admitted in the current window.
    ///
    /// Does not modify the log; expired timestamps are simply ignored.
    pub fn remaining(&self) -> u32 {
        let window_start = self.window_start(std::time::Instant::now());
        let in_window = self
            .requests
            .iter()
            .filter(|&&time| window_start.map_or(true, |start| time > start))
            .count();
        // Saturate rather than truncate if the log somehow exceeds u32::MAX entries.
        let in_window = u32::try_from(in_window).unwrap_or(u32::MAX);
        self.limit.saturating_sub(in_window)
    }

    /// Returns the number of whole seconds until the oldest in-window request
    /// expires, or the full window length when no request is in the window.
    pub fn reset_time(&self) -> u64 {
        let now = std::time::Instant::now();
        let window_start = self.window_start(now);
        self.requests
            .iter()
            .copied()
            .filter(|&time| window_start.map_or(true, |start| time > start))
            .min()
            .map_or(self.window_seconds, |oldest| {
                let elapsed = now.duration_since(oldest).as_secs();
                self.window_seconds.saturating_sub(elapsed)
            })
    }
}
/// Axum middleware enforcing a per-account sliding-window rate limit.
///
/// Behavior, in order:
/// - `/health` is passed through without requiring an API key.
/// - A request without an extractable API key is rejected with
///   `401 Unauthorized`.
/// - A key that matches no account, or an account whose `rate_limit` is `0`,
///   is passed through unthrottled.
/// - Otherwise the request is counted against the account's 60-second window.
///   Admitted responses get `X-RateLimit-Limit`, `X-RateLimit-Remaining` and
///   `X-RateLimit-Reset` headers; rejected requests get a
///   `429 Too Many Requests` JSON error carrying the same headers.
pub async fn rate_limit_middleware(
    axum::extract::State(state): axum::extract::State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    if request.uri().path() == "/health" {
        return Ok(next.run(request).await);
    }

    let api_key = extract_api_key(request.headers()).map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Copy out what is needed and release the read guard immediately, on every
    // path. Keeping it alive across `next.run(..)` would hold the account table
    // locked for the whole downstream request and can deadlock a handler that
    // needs the write lock.
    let accounts = state.accounts.read().await;
    let account = accounts
        .get(&api_key)
        .map(|account| (account.id.clone(), account.rate_limit));
    drop(accounts);

    let Some((account_id, rate_limit)) = account else {
        // Unknown keys are not throttled here.
        return Ok(next.run(request).await);
    };

    // A limit of 0 disables rate limiting for the account.
    if rate_limit == 0 {
        return Ok(next.run(request).await);
    }

    // Take the decision and snapshot the header values while holding the
    // limiter lock, then release it before running the downstream handler.
    let mut limiters = state.rate_limiter.lock().await;
    let limiter = limiters
        .entry(account_id)
        .or_insert_with(|| RateLimitInfo::new(rate_limit, 60));
    // Pick up limit changes made to the account since the limiter was created.
    limiter.limit = rate_limit;
    let allowed = limiter.is_allowed();
    let remaining = limiter.remaining();
    let reset = limiter.reset_time();
    drop(limiters);

    if !allowed {
        let limit_text = rate_limit.to_string();
        let reset_text = reset.to_string();
        let body = axum::Json(serde_json::json!({
            "success": false,
            "error": {
                "code": "RATE_LIMIT_EXCEEDED",
                "message": "Rate limit exceeded. Try again later."
            }
        }));
        return Ok((
            StatusCode::TOO_MANY_REQUESTS,
            [
                ("Content-Type", "application/json"),
                ("X-RateLimit-Limit", limit_text.as_str()),
                ("X-RateLimit-Remaining", "0"),
                ("X-RateLimit-Reset", reset_text.as_str()),
            ],
            body,
        )
            .into_response());
    }

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    // Integer-to-HeaderValue conversions are infallible, so no parse/unwrap is needed.
    headers.insert(
        "X-RateLimit-Limit",
        axum::http::HeaderValue::from(rate_limit),
    );
    headers.insert(
        "X-RateLimit-Remaining",
        axum::http::HeaderValue::from(remaining),
    );
    headers.insert("X-RateLimit-Reset", axum::http::HeaderValue::from(reset));
    Ok(response)
}
/// Extracts the caller's API key from the request headers.
///
/// The `x-api-key` header is preferred; `authorization` is consulted only when
/// `x-api-key` is absent. A leading `Bearer ` scheme (matched
/// case-insensitively, as HTTP auth schemes are) is removed from the value.
///
/// # Errors
///
/// Returns `Err("Missing API key")` when neither header is present or the
/// selected header value is not visible ASCII text.
fn extract_api_key(headers: &HeaderMap) -> Result<String, String> {
    const BEARER_PREFIX: &str = "Bearer ";

    headers
        .get("x-api-key")
        .or_else(|| headers.get("authorization"))
        .and_then(|value| value.to_str().ok())
        .map(|value| {
            // `get(..len)` yields `None` for short values or a non-boundary
            // split, so the slice below can never panic.
            match value.get(..BEARER_PREFIX.len()) {
                Some(scheme) if scheme.eq_ignore_ascii_case(BEARER_PREFIX) => {
                    value[BEARER_PREFIX.len()..].to_string()
                }
                _ => value.to_string(),
            }
        })
        .ok_or_else(|| "Missing API key".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five requests fit a limit of five; the sixth is rejected and nothing
    /// remains in the window.
    #[test]
    fn test_rate_limiter_basic() {
        let mut limiter = RateLimitInfo::new(5, 60);
        for i in 1..=5 {
            assert!(limiter.is_allowed(), "Request {} should be allowed", i);
        }
        assert!(!limiter.is_allowed(), "Request 6 should be rate limited");
        assert_eq!(limiter.remaining(), 0, "Remaining requests should be 0");
    }

    /// `remaining` counts down with each admitted request and `reset_time`
    /// never exceeds the window length.
    #[test]
    fn test_remaining_and_reset_track_usage() {
        let mut limiter = RateLimitInfo::new(2, 60);
        assert_eq!(limiter.remaining(), 2);
        assert_eq!(limiter.reset_time(), 60);

        assert!(limiter.is_allowed());
        assert_eq!(limiter.remaining(), 1);
        assert!(limiter.reset_time() <= 60);
    }

    /// A limiter with a limit of zero admits nothing and records nothing.
    #[test]
    fn test_zero_limit_rejects() {
        let mut limiter = RateLimitInfo::new(0, 60);
        assert!(!limiter.is_allowed());
        assert!(limiter.requests.is_empty());
        assert_eq!(limiter.remaining(), 0);
    }
}