use crate::state::AppState;
use axum::{
body::Body,
extract::{ConnectInfo, State},
http::{HeaderValue, Request, Response, StatusCode},
middleware::Next,
response::IntoResponse,
Json,
};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use uuid::Uuid;
/// Token-bucket state tracked for a single rate-limit key.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Tokens currently available. Fractional because `RateLimiter::check`
    /// refills in proportion to elapsed time.
    tokens: f64,
    /// When `tokens` was last refilled/consumed; the next refill is computed
    /// from the time elapsed since this instant.
    last_update: Instant,
}
/// Token-bucket rate limiter keyed by arbitrary strings (the middleware in
/// this file keys it by client IP).
///
/// Cloning is cheap and every clone shares the same bucket map, because the
/// map lives behind an `Arc<RwLock<..>>`.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Per-key bucket state.
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
    /// Bucket capacity, which is also the number of tokens refilled per window.
    max_requests: u32,
    /// Length of the refill window in seconds (always 60 via `RateLimiter::new`).
    window_secs: u64,
}
impl RateLimiter {
    /// Creates a limiter that allows `requests_per_minute` requests per key,
    /// refilled continuously over a 60-second window.
    pub fn new(requests_per_minute: u32) -> Self {
        let entries = Arc::new(RwLock::new(HashMap::new()));
        Self {
            entries,
            max_requests: requests_per_minute,
            window_secs: 60,
        }
    }

    /// Tries to spend one token from `key`'s bucket.
    ///
    /// Returns `true` (and consumes a token) when at least one whole token is
    /// available, `false` otherwise. A key seen for the first time starts with
    /// a full bucket. Before the check, the bucket is topped up by
    /// `elapsed * max_requests / window_secs` tokens, capped at `max_requests`.
    pub fn check(&self, key: &str) -> bool {
        let capacity = f64::from(self.max_requests);
        let tokens_per_sec = capacity / self.window_secs as f64;

        let mut buckets = self.entries.write();
        // Sample the clock only after taking the lock so no other caller can
        // have stamped a later `last_update` than this `now`.
        let now = Instant::now();
        let bucket = buckets
            .entry(key.to_owned())
            .or_insert(RateLimitEntry {
                tokens: capacity,
                last_update: now,
            });

        let earned = now.duration_since(bucket.last_update).as_secs_f64() * tokens_per_sec;
        bucket.tokens = (bucket.tokens + earned).min(capacity);
        bucket.last_update = now;

        let allowed = bucket.tokens >= 1.0;
        if allowed {
            bucket.tokens -= 1.0;
        }
        allowed
    }

    /// Drops every bucket that has been idle for at least two windows.
    ///
    /// After one idle window a bucket has fully refilled anyway, so removing
    /// it does not change the outcome of a later `check`.
    pub fn cleanup(&self) {
        let idle_limit = Duration::from_secs(self.window_secs * 2);
        let mut buckets = self.entries.write();
        let now = Instant::now();
        buckets.retain(|_, bucket| now.duration_since(bucket.last_update) < idle_limit);
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new(100) }
}
/// Stamps the request and its response with a freshly generated UUIDv4 in the
/// `x-request-id` header.
///
/// `HeaderMap::insert` replaces existing values, so any client-supplied
/// `x-request-id` is overwritten. Downstream handlers therefore see the same ID
/// that is echoed back in the response.
pub async fn request_id(mut request: Request<Body>, next: Next) -> Response<Body> {
    let id_header = HeaderValue::from_str(&Uuid::new_v4().to_string())
        .unwrap_or_else(|_| HeaderValue::from_static("unknown"));

    request.headers_mut().insert("x-request-id", id_header.clone());
    let mut response = next.run(request).await;
    response.headers_mut().insert("x-request-id", id_header);
    response
}
pub async fn shield_check(
State(state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Result<Response<Body>, impl IntoResponse> {
let source_ip = request
.headers()
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.unwrap_or("127.0.0.1")
.split(',')
.next()
.unwrap_or("127.0.0.1")
.trim()
.to_string();
let ctx = aegis_shield::RequestContext {
source_ip: source_ip.clone(),
path: request.uri().path().to_string(),
method: request.method().to_string(),
user_agent: request
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string()),
auth_user: None,
body_size: 0,
headers: std::collections::HashMap::new(),
};
match state.shield.analyze_request(&ctx) {
aegis_shield::ShieldVerdict::Allow => Ok(next.run(request).await),
aegis_shield::ShieldVerdict::Block {
reason,
threat_level,
} => {
tracing::warn!(
ip = %source_ip,
level = ?threat_level,
"Shield blocked request: {}",
reason
);
Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "Request blocked by security shield",
"reason": reason,
})),
))
}
aegis_shield::ShieldVerdict::RateLimit { delay_ms } => {
let mut response = next.run(request).await;
if let Ok(val) = HeaderValue::from_str(&delay_ms.to_string()) {
response.headers_mut().insert("x-ratelimit-delay-ms", val);
}
Ok(response)
}
}
}
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is compared case-insensitively, so `bearer abc` is accepted
/// as well as `Bearer abc`. Whitespace around the token is ignored, and a
/// missing or blank token yields `None`.
fn bearer_token(header: &str) -> Option<&str> {
    let (scheme, credentials) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Rejects requests that do not carry a valid session token.
///
/// While no users exist (`state.auth.list_users()` is empty), every request is
/// let through and a `SECURITY` warning is logged for each one. Otherwise the
/// request must carry `Authorization: Bearer <token>`, and the token must be
/// accepted by `state.auth.validate_session`.
///
/// # Errors
///
/// Returns `401 Unauthorized` with a JSON body when:
/// - the header is missing, uses another scheme, or has a blank token; or
/// - the token is not a valid session.
pub async fn require_auth(
    State(state): State<AppState>,
    request: Request<Body>,
    next: Next,
) -> Result<Response<Body>, impl IntoResponse> {
    if state.auth.list_users().is_empty() {
        tracing::warn!(
            path = %request.uri().path(),
            "SECURITY: No admin user configured — all endpoints are unauthenticated. \
             Create an admin user via POST /api/v1/auth/login or set \
             AEGIS_ADMIN_USERNAME/AEGIS_ADMIN_PASSWORD to secure the server."
        );
        return Ok(next.run(request).await);
    }
    let token = match request
        .headers()
        .get("authorization")
        .and_then(|h| h.to_str().ok())
        .and_then(bearer_token)
    {
        Some(token) => token,
        None => {
            return Err((
                StatusCode::UNAUTHORIZED,
                Json(serde_json::json!({
                    "error": "Missing or invalid Authorization header",
                    "message": "Provide a valid Bearer token in the Authorization header"
                })),
            ));
        }
    };
    match state.auth.validate_session(token) {
        Some(_user) => Ok(next.run(request).await),
        None => Err((
            StatusCode::UNAUTHORIZED,
            Json(serde_json::json!({
                "error": "Invalid or expired session token",
                "message": "Please log in again to obtain a new token"
            })),
        )),
    }
}
fn get_client_ip(request: &Request<Body>) -> String {
if let Some(forwarded) = request
.headers()
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
{
if let Some(first_ip) = forwarded.split(',').next() {
return first_ip.trim().to_string();
}
}
if let Some(real_ip) = request
.headers()
.get("x-real-ip")
.and_then(|h| h.to_str().ok())
{
return real_ip.to_string();
}
if let Some(connect_info) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
return connect_info.0.ip().to_string();
}
"unknown".to_string()
}
pub async fn rate_limit(
State(state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Result<Response<Body>, impl IntoResponse> {
let client_ip = get_client_ip(&request);
let rate_limit = state.config.rate_limit_per_minute;
if rate_limit == 0 {
return Ok(next.run(request).await);
}
if state.rate_limiter.check(&client_ip) {
Ok(next.run(request).await)
} else {
Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({
"error": "Rate limit exceeded",
"message": format!("Too many requests. Please try again later. Limit: {} requests per minute.", rate_limit),
"retry_after_seconds": 60
})),
))
}
}
pub async fn login_rate_limit(
State(state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Result<Response<Body>, impl IntoResponse> {
let client_ip = get_client_ip(&request);
let rate_limit = state.config.login_rate_limit_per_minute;
if rate_limit == 0 {
return Ok(next.run(request).await);
}
if state
.login_rate_limiter
.check(&format!("login:{}", client_ip))
{
Ok(next.run(request).await)
} else {
Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({
"error": "Too many login attempts",
"message": format!("Too many login attempts. Please try again later. Limit: {} attempts per minute.", rate_limit),
"retry_after_seconds": 60
})),
))
}
}
/// Adds a fixed set of defensive HTTP headers to every response.
///
/// The headers are written after the handler runs and replace any values it
/// set. `Strict-Transport-Security` is added only when `state.config.tls` is
/// `Some`.
///
/// NOTE(review): `X-XSS-Protection` is ignored by current browsers, and OWASP
/// now recommends `0` or omitting it. Confirm before changing the value.
pub async fn security_headers(
    State(state): State<AppState>,
    request: Request<Body>,
    next: Next,
) -> Response<Body> {
    // Headers attached to every response, as (name, value) pairs, in the
    // order they are inserted.
    const BASELINE: [(&str, &str); 5] = [
        ("content-security-policy", "default-src 'self'"),
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("x-xss-protection", "1; mode=block"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
    ];

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for (name, value) in BASELINE {
        headers.insert(name, HeaderValue::from_static(value));
    }
    // HSTS is only advertised when this server is configured to serve TLS.
    if state.config.tls.is_some() {
        headers.insert(
            "strict-transport-security",
            HeaderValue::from_static("max-age=31536000; includeSubDomains"),
        );
    }
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ServerConfig;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::{routing::get, Router};
    use tower::util::ServiceExt;

    /// Minimal handler mounted at `/` by every router under test.
    async fn handler() -> &'static str {
        "ok"
    }

    /// Builds a `GET /` request with the given headers and an empty body.
    fn root_request(headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).expect("failed to build request")
    }

    /// Drives `app` with a single request and returns its response.
    async fn send(app: Router, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request)
            .await
            .expect("failed to execute request")
    }

    /// Looks up a response header as a string (panics on non-visible-ASCII).
    fn header_str<'a>(response: &'a axum::response::Response, name: &str) -> Option<&'a str> {
        response.headers().get(name).map(|v| v.to_str().unwrap())
    }

    /// Router whose `/` route sits behind `require_auth`.
    fn auth_app(state: AppState) -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                require_auth,
            ))
            .with_state(state)
    }

    /// Router whose responses pass through `security_headers`.
    fn security_app(state: AppState) -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                security_headers,
            ))
            .with_state(state)
    }

    /// Fresh state with one admin account, so `require_auth` enforces tokens.
    fn state_with_admin() -> AppState {
        let state = AppState::new(ServerConfig::default());
        let _ = state
            .auth
            .create_user("testuser", "test@test.local", "TestPass123!", "admin");
        state
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let app = Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(request_id));
        let response = send(app, root_request(&[])).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("x-request-id"));
    }

    #[tokio::test]
    async fn test_auth_middleware_no_token() {
        let response = send(auth_app(state_with_admin()), root_request(&[])).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_token() {
        let request = root_request(&[("Authorization", "Bearer invalid_token")]);
        let response = send(auth_app(state_with_admin()), request).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_token() {
        let state = AppState::new(ServerConfig::default());
        state
            .auth
            .create_user("authtest", "auth@test.com", "TestPassword123!", "admin")
            .expect("failed to create test user");
        let login = state.auth.login("authtest", "TestPassword123!");
        let token = login.token.expect("login should return token");
        let bearer = format!("Bearer {}", token);

        let request = root_request(&[("Authorization", bearer.as_str())]);
        let response = send(auth_app(state), request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_rate_limiter_allows_requests() {
        let limiter = RateLimiter::new(10);
        (0..10).for_each(|_| assert!(limiter.check("test_client")));
        assert!(!limiter.check("test_client"));
    }

    #[test]
    fn test_rate_limiter_different_clients() {
        let limiter = RateLimiter::new(5);
        for _ in 0..5 {
            for client in ["client_a", "client_b"] {
                assert!(limiter.check(client));
            }
        }
        for client in ["client_a", "client_b"] {
            assert!(!limiter.check(client));
        }
    }

    #[test]
    fn test_rate_limiter_cleanup() {
        let limiter = RateLimiter::new(10);
        for client in ["client_1", "client_2"] {
            limiter.check(client);
        }
        limiter.cleanup();
        assert!(limiter.check("client_1"));
    }

    #[tokio::test]
    async fn test_security_headers_without_tls() {
        let state = AppState::new(ServerConfig::default());
        let response = send(security_app(state), root_request(&[])).await;
        assert_eq!(response.status(), StatusCode::OK);

        let expected = [
            ("content-security-policy", "default-src 'self'"),
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("x-xss-protection", "1; mode=block"),
            ("referrer-policy", "strict-origin-when-cross-origin"),
        ];
        for (name, value) in expected {
            assert_eq!(header_str(&response, name), Some(value), "header {}", name);
        }
        assert!(response
            .headers()
            .get("strict-transport-security")
            .is_none());
    }

    #[tokio::test]
    async fn test_security_headers_with_tls() {
        let config = ServerConfig::default().with_tls("/path/to/cert.pem", "/path/to/key.pem");
        let response = send(security_app(AppState::new(config)), root_request(&[])).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            header_str(&response, "strict-transport-security"),
            Some("max-age=31536000; includeSubDomains")
        );
    }
}