use crate::rate_limit_per_key::{PerKeyRateLimiter, RateLimitResult};
use argentor_security::RateLimiter;
use axum::{
extract::{Query, Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
use tracing::warn;
use uuid::Uuid;
/// API-key authentication settings shared by the HTTP middleware.
///
/// An empty key list disables authentication entirely (see
/// [`AuthConfig::is_enabled`]).
#[derive(Clone)]
pub struct AuthConfig {
    /// Keys accepted by `auth_middleware`; a request must present one of them.
    pub api_keys: Vec<String>,
}

impl AuthConfig {
    /// Builds a configuration that accepts exactly `api_keys`.
    pub fn new(api_keys: Vec<String>) -> Self {
        AuthConfig { api_keys }
    }

    /// Returns `true` when at least one key is configured, i.e. when requests
    /// are required to authenticate.
    pub fn is_enabled(&self) -> bool {
        self.api_keys.first().is_some()
    }
}
/// Shared state handed to every middleware in this module through axum's
/// `State` extractor (wrapped in an `Arc`).
#[derive(Clone)]
pub struct MiddlewareState {
    /// Limiter consulted by `rate_limit_middleware`.
    pub rate_limiter: Arc<RateLimiter>,
    /// API-key settings consulted by `auth_middleware`.
    pub auth: AuthConfig,
    /// Optional per-key limiter; when `None`, `per_key_rate_limit_middleware`
    /// passes every request through unchanged.
    pub per_key_rate_limiter: Option<Arc<PerKeyRateLimiter>>,
}
pub async fn auth_middleware(
State(state): State<Arc<MiddlewareState>>,
headers: HeaderMap,
query: Query<AuthQuery>,
request: Request,
next: Next,
) -> Response {
if !state.auth.is_enabled() {
return next.run(request).await;
}
let key_from_header = headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.map(std::string::ToString::to_string);
let key = key_from_header.or_else(|| query.api_key.clone());
match key {
Some(k) if state.auth.api_keys.contains(&k) => next.run(request).await,
Some(_) => {
warn!("Rejected request: invalid API key");
(StatusCode::UNAUTHORIZED, "Invalid API key").into_response()
}
None => {
warn!("Rejected request: missing API key");
(StatusCode::UNAUTHORIZED, "API key required").into_response()
}
}
}
/// Query-string parameters understood by `auth_middleware`.
#[derive(Default, serde::Deserialize)]
pub struct AuthQuery {
    /// API key supplied as `?api_key=…`; consulted only when no key is found
    /// in the request headers.
    pub api_key: Option<String>,
}
/// Applies the global [`RateLimiter`] from [`MiddlewareState`] to each request.
///
/// Every request is checked under `Uuid::nil()` as its session id.
/// NOTE(review): presumably this makes all callers share a single bucket.
/// Confirm against `RateLimiter::check` semantics.
///
/// Returns `429 Too Many Requests` with `"Rate limit exceeded"` when the check
/// fails; otherwise forwards to the next handler.
pub async fn rate_limit_middleware(
    State(state): State<Arc<MiddlewareState>>,
    request: Request,
    next: Next,
) -> Response {
    let allowed = state.rate_limiter.check(Uuid::nil()).await;
    if allowed {
        return next.run(request).await;
    }
    warn!("Rate limited request");
    (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response()
}
/// Extracts an API key from request headers.
///
/// Lookup order:
/// 1. `Authorization: Bearer <key>`. The scheme is matched case-insensitively,
///    as RFC 9110 §11.1 requires for authentication schemes.
/// 2. `x-api-key: <key>`.
///
/// Surrounding whitespace is trimmed and empty values are ignored. For example,
/// a bare `Authorization: Bearer` still lets `x-api-key` be used. Header values
/// that are not visible ASCII are skipped. Returns `None` when neither header
/// yields a key.
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    let bearer = headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| {
            let (scheme, token) = v.trim().split_once(' ')?;
            scheme.eq_ignore_ascii_case("bearer").then_some(token.trim())
        })
        .filter(|key| !key.is_empty());

    let api_key_header = || {
        headers
            .get("x-api-key")
            .and_then(|v| v.to_str().ok())
            .map(str::trim)
            .filter(|key| !key.is_empty())
    };

    bearer.or_else(api_key_header).map(str::to_owned)
}
pub async fn per_key_rate_limit_middleware(
State(state): State<Arc<MiddlewareState>>,
headers: HeaderMap,
request: Request,
next: Next,
) -> Response {
let limiter = match &state.per_key_rate_limiter {
Some(l) => l,
None => return next.run(request).await,
};
let api_key = match extract_api_key(&headers) {
Some(k) => k,
None => return next.run(request).await,
};
#[allow(clippy::unwrap_used)]
match limiter.check(&api_key) {
RateLimitResult::Allow => {
let mut response = next.run(request).await;
if let Some(stats) = limiter.stats(&api_key) {
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Limit",
stats
.config
.requests_per_minute
.to_string()
.parse()
.unwrap(),
);
let remaining = stats
.config
.requests_per_minute
.saturating_sub(stats.requests_this_minute);
headers.insert(
"X-RateLimit-Remaining",
remaining.to_string().parse().unwrap(),
);
}
response
}
RateLimitResult::Deny {
reason,
retry_after,
} => {
warn!(
api_key = %api_key,
reason = %reason,
retry_after = retry_after,
"Per-key rate limit exceeded"
);
let body = serde_json::json!({
"error": "rate_limit_exceeded",
"message": reason.to_string(),
"retry_after": retry_after,
});
let mut response = (StatusCode::TOO_MANY_REQUESTS, body.to_string()).into_response();
response
.headers_mut()
.insert("Retry-After", retry_after.to_string().parse().unwrap());
response
.headers_mut()
.insert("X-RateLimit-Remaining", "0".parse().unwrap());
response
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_config_disabled() {
        let config = AuthConfig::new(vec![]);
        assert!(!config.is_enabled());
    }

    #[test]
    fn test_auth_config_enabled() {
        let config = AuthConfig::new(vec!["key123".to_string()]);
        assert!(config.is_enabled());
    }

    #[test]
    fn test_extract_api_key_from_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer key123".parse().unwrap());
        assert_eq!(extract_api_key(&headers).as_deref(), Some("key123"));
    }

    #[test]
    fn test_extract_api_key_prefers_bearer_over_x_api_key() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer from-bearer".parse().unwrap());
        headers.insert("x-api-key", "from-header".parse().unwrap());
        assert_eq!(extract_api_key(&headers).as_deref(), Some("from-bearer"));
    }

    #[test]
    fn test_extract_api_key_falls_back_to_x_api_key() {
        let mut headers = HeaderMap::new();
        // A non-Bearer Authorization header must not shadow x-api-key.
        headers.insert("authorization", "Basic abc".parse().unwrap());
        headers.insert("x-api-key", "key456".parse().unwrap());
        assert_eq!(extract_api_key(&headers).as_deref(), Some("key456"));
    }

    #[test]
    fn test_extract_api_key_missing() {
        assert_eq!(extract_api_key(&HeaderMap::new()), None);
    }
}