use axum::{
extract::{ConnectInfo, Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Middleware that stamps a fixed set of defensive HTTP headers onto every
/// response produced by the inner service.
///
/// The inner handler runs first; afterwards each header below is written with
/// `insert`, so a value the handler set for the same header name is replaced.
///
/// NOTE(review): `X-XSS-Protection: 1; mode=block` is kept as-is; confirm it is
/// still wanted, since current browser guidance tends to discourage it.
pub async fn security_headers(request: Request, next: Next) -> Response {
    // (header name, header value), applied in this order.
    const SECURITY_HEADERS: [(&str, &str); 7] = [
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        (
            "content-security-policy",
            "default-src 'none'; frame-ancestors 'none'",
        ),
        ("cache-control", "no-store, max-age=0"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
        (
            "Permissions-Policy",
            "geolocation=(), camera=(), microphone=()",
        ),
    ];

    let mut response = next.run(request).await;
    let response_headers = response.headers_mut();
    for (name, value) in SECURITY_HEADERS {
        response_headers.insert(name, HeaderValue::from_static(value));
    }
    response
}
/// Configuration for API-key authentication.
///
/// `Debug` is implemented by hand (rather than derived) so that the configured
/// API keys are never written to logs or panic messages.
#[derive(Clone)]
pub struct AuthConfig {
    /// Accepted API keys. While this set is empty, [`AuthConfig::is_enabled`]
    /// reports `false`.
    pub api_keys: HashSet<String>,
    /// Name of the request header that carries the credential.
    pub header_name: String,
    /// Prefix expected in front of the key inside the header value.
    pub prefix: String,
    /// Exact request paths that do not require authentication.
    pub public_paths: HashSet<String>,
}

impl std::fmt::Debug for AuthConfig {
    /// Formats the configuration with the API keys redacted; only the number
    /// of configured keys is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field(
                "api_keys",
                &format_args!("<{} redacted>", self.api_keys.len()),
            )
            .field("header_name", &self.header_name)
            .field("prefix", &self.prefix)
            .field("public_paths", &self.public_paths)
            .finish()
    }
}

impl Default for AuthConfig {
    /// No keys, the `Authorization` header with a `Bearer ` prefix, and
    /// `/health` and `/ready` as public paths.
    fn default() -> Self {
        Self {
            api_keys: HashSet::new(),
            header_name: "Authorization".to_string(),
            prefix: "Bearer ".to_string(),
            public_paths: ["/health", "/ready"]
                .iter()
                .map(|path| String::from(*path))
                .collect(),
        }
    }
}

impl AuthConfig {
    /// Builds a configuration that accepts exactly `keys`, with every other
    /// field taken from [`AuthConfig::default`].
    pub fn with_keys(keys: impl IntoIterator<Item = String>) -> Self {
        Self {
            api_keys: keys.into_iter().collect(),
            ..Self::default()
        }
    }

    /// Builder-style helper: marks `path` as public and returns the updated
    /// configuration.
    pub fn add_public_path(mut self, path: impl Into<String>) -> Self {
        self.public_paths.insert(path.into());
        self
    }

    /// Returns `true` unless `path` is an exact match for one of the public
    /// paths.
    pub fn requires_auth(&self, path: &str) -> bool {
        !self.public_paths.contains(path)
    }

    /// Returns `true` when `key` is one of the configured API keys.
    ///
    /// An empty `key` is always rejected, even if an empty string ended up in
    /// `api_keys` (for example from a blank configuration value); otherwise a
    /// header consisting of just the prefix would authenticate.
    ///
    /// NOTE(review): the lookup is a plain `HashSet::contains`, which is not a
    /// constant-time comparison.
    pub fn validate_key(&self, key: &str) -> bool {
        !key.is_empty() && self.api_keys.contains(key)
    }

    /// Returns `true` when at least one API key is configured.
    pub fn is_enabled(&self) -> bool {
        !self.api_keys.is_empty()
    }
}
/// State handed to the `api_key_auth` middleware through axum's `State`
/// extractor.
#[derive(Clone)]
pub struct AuthState {
/// Authentication settings, held in an `Arc` so clones of this state share
/// a single configuration.
pub config: Arc<AuthConfig>,
}
/// Middleware that enforces API-key authentication.
///
/// The request is forwarded untouched when its path is public or when no API
/// keys are configured (authentication disabled). Otherwise the configured
/// header must be present, readable as a string, start with the configured
/// prefix, and carry a key accepted by [`AuthConfig::validate_key`].
///
/// # Errors
///
/// Returns a `401 Unauthorized` response with `WWW-Authenticate: Bearer` when
/// the header is missing (or not representable as a string), does not start
/// with the prefix, or carries an unknown key.
pub async fn api_key_auth(
    State(auth): State<AuthState>,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    /// Builds the 401 response shared by every rejection branch.
    fn unauthorized(message: &'static str) -> Response {
        (
            StatusCode::UNAUTHORIZED,
            [(header::WWW_AUTHENTICATE, "Bearer")],
            message,
        )
            .into_response()
    }

    let config = &auth.config;
    let path = request.uri().path();

    // Public endpoints and "no keys configured" both bypass the check.
    if !config.requires_auth(path) || !config.is_enabled() {
        return Ok(next.run(request).await);
    }

    let Some(header_value) = request
        .headers()
        .get(&config.header_name)
        .and_then(|raw| raw.to_str().ok())
    else {
        return Err(unauthorized("Missing authorization header"));
    };

    let Some(api_key) = header_value.strip_prefix(config.prefix.as_str()) else {
        return Err(unauthorized("Invalid authorization header format"));
    };

    if !config.validate_key(api_key) {
        // Only the path is logged; the rejected key is deliberately left out.
        tracing::warn!(
            path = %path,
            "Invalid API key attempt"
        );
        return Err(unauthorized("Invalid API key"));
    }

    Ok(next.run(request).await)
}
/// Selects what identifies a client for rate-limiting purposes.
///
/// The enum carries no data, so it is `Copy` and comparable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum RateLimitKey {
    /// Bucket requests by the remote peer's IP address (the default).
    #[default]
    ByIp,
    /// Bucket requests by the credential sent in the `Authorization` header.
    ByApiKey,
}
/// Settings for the rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Maximum number of requests a single key may make per window.
    pub max_requests: u32,
    /// Length of one counting window.
    pub window: Duration,
    /// How requests are grouped into buckets.
    pub key_strategy: RateLimitKey,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window, bucketed by client IP.
    fn default() -> Self {
        const DEFAULT_MAX_REQUESTS: u32 = 100;
        const DEFAULT_WINDOW: Duration = Duration::from_secs(60);

        RateLimitConfig {
            key_strategy: RateLimitKey::ByIp,
            window: DEFAULT_WINDOW,
            max_requests: DEFAULT_MAX_REQUESTS,
        }
    }
}
/// Request counter for a single rate-limit key within its current window.
#[derive(Clone)]
struct RateLimitEntry {
/// Number of requests counted since `window_start`.
count: u32,
/// Instant at which the current counting window began.
window_start: Instant,
}
/// Shared state of the rate limiter: the limits plus one counter per key.
///
/// Cloning is cheap; every clone refers to the same configuration and the same
/// counter table.
#[derive(Clone)]
pub struct RateLimitState {
    /// Limits applied to every key.
    pub config: Arc<RateLimitConfig>,
    /// Per-key counters. Entries are only removed by [`RateLimitState::cleanup`].
    entries: Arc<RwLock<std::collections::HashMap<String, RateLimitEntry>>>,
}

impl RateLimitState {
    /// Creates a limiter with an empty counter table.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config: Arc::new(config),
            entries: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Drops every entry whose window has already elapsed.
    ///
    /// Nothing in this type calls this automatically; without periodic calls
    /// the table grows by one entry per distinct key ever seen.
    pub async fn cleanup(&self) {
        let window = self.config.window;
        let mut entries = self.entries.write().await;
        // Read the clock only once the lock is held, so `now` cannot be older
        // than a `window_start` written by a task that held the lock earlier.
        let now = Instant::now();
        entries.retain(|_, entry| now.saturating_duration_since(entry.window_start) < window);
    }

    /// Counts one request for `key`.
    ///
    /// Returns `Ok((remaining, limit))` while the key is within its limit, or
    /// `Err((count, retry_after))` once the limit is exceeded, where
    /// `retry_after` is the time left in the current window.
    async fn check_and_increment(&self, key: String) -> Result<(u32, u32), (u32, Duration)> {
        let mut entries = self.entries.write().await;
        // Timestamp taken under the lock (see `cleanup`).
        let now = Instant::now();
        let entry = entries.entry(key).or_insert_with(|| RateLimitEntry {
            count: 0,
            window_start: now,
        });

        let mut elapsed = now.saturating_duration_since(entry.window_start);
        if elapsed >= self.config.window {
            // The previous window is over: start a fresh one at `now`.
            entry.count = 0;
            entry.window_start = now;
            elapsed = Duration::ZERO;
        }

        // Saturate instead of overflowing: a wrapped counter would drop back
        // to zero and silently lift the limit (or panic in debug builds).
        entry.count = entry.count.saturating_add(1);

        if entry.count > self.config.max_requests {
            let retry_after = self.config.window.saturating_sub(elapsed);
            Err((entry.count, retry_after))
        } else {
            Ok((
                self.config.max_requests - entry.count,
                self.config.max_requests,
            ))
        }
    }
}
/// Middleware that applies the rate limit held in [`RateLimitState`].
///
/// The bucket key is chosen by the configured strategy:
/// * `ByIp` uses the peer IP address.
/// * `ByApiKey` uses the `Authorization` header value with at most one leading
///   `"Bearer "` removed, falling back to the peer IP when the header is
///   absent or not representable as a string.
///
/// Keys are namespaced (`ip:` / `key:`) so that a client-supplied header value
/// can never land in the bucket of some IP address.
///
/// Allowed requests get `X-RateLimit-Limit` and `X-RateLimit-Remaining` on the
/// response.
///
/// NOTE(review): with `ByApiKey` the header value is used without checking
/// that it is a valid key; confirm this layer runs after authentication if
/// unauthenticated values must not create buckets.
///
/// # Errors
///
/// Returns `429 Too Many Requests` with `Retry-After` (whole seconds, rounded
/// up, at least 1) and the rate-limit headers once the limit is exceeded.
pub async fn rate_limit(
    State(state): State<RateLimitState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    let ip_key = || format!("ip:{}", addr.ip());
    let key = match state.config.key_strategy {
        RateLimitKey::ByIp => ip_key(),
        RateLimitKey::ByApiKey => request
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            // `strip_prefix` removes the scheme once; `trim_start_matches`
            // would also eat a token that itself begins with "Bearer ".
            .map(|value| format!("key:{}", value.strip_prefix("Bearer ").unwrap_or(value)))
            .unwrap_or_else(ip_key),
    };

    match state.check_and_increment(key).await {
        Ok((remaining, limit)) => {
            let mut response = next.run(request).await;
            let headers = response.headers_mut();
            // Integer -> header value conversion cannot fail, so no `unwrap`.
            headers.insert("X-RateLimit-Limit", HeaderValue::from(limit));
            headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
            Ok(response)
        }
        Err((_, retry_after)) => {
            // Round up: telling a client to wait 1s when 1.9s remain would
            // make a well-behaved client retry too early and be rejected again.
            let retry_secs = retry_after
                .as_secs()
                .saturating_add(u64::from(retry_after.subsec_nanos() > 0))
                .max(1);
            Err((
                StatusCode::TOO_MANY_REQUESTS,
                [
                    ("Retry-After", retry_secs.to_string()),
                    ("X-RateLimit-Limit", state.config.max_requests.to_string()),
                    ("X-RateLimit-Remaining", "0".to_string()),
                ],
                "Rate limit exceeded",
            )
                .into_response())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default auth config has no keys and exposes both probe endpoints.
    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.is_enabled());
        for probe in ["/health", "/ready"] {
            assert!(config.public_paths.contains(probe));
        }
    }

    /// Only the keys passed to `with_keys` are accepted.
    #[test]
    fn test_auth_config_with_keys() {
        let config = AuthConfig::with_keys(vec![String::from("key1"), String::from("key2")]);
        assert!(config.is_enabled());
        for accepted in ["key1", "key2"] {
            assert!(config.validate_key(accepted));
        }
        assert!(!config.validate_key("key3"));
    }

    /// Default and added public paths skip auth; everything else needs it.
    #[test]
    fn test_auth_config_public_paths() {
        let config = AuthConfig::default().add_public_path("/metrics");
        for open in ["/health", "/ready", "/metrics"] {
            assert!(!config.requires_auth(open));
        }
        assert!(config.requires_auth("/v1/state"));
    }

    /// The default rate limit is 100 requests per 60 seconds.
    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            max_requests,
            window,
            ..
        } = RateLimitConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window, Duration::from_secs(60));
    }

    /// The fourth request within a window is rejected, and keys are
    /// counted independently of each other.
    #[tokio::test]
    async fn test_rate_limit_state() {
        let config = RateLimitConfig {
            key_strategy: RateLimitKey::ByIp,
            window: Duration::from_secs(60),
            max_requests: 3,
        };
        let state = RateLimitState::new(config);

        for _ in 0..3 {
            assert!(state.check_and_increment(String::from("test")).await.is_ok());
        }
        assert!(state.check_and_increment(String::from("test")).await.is_err());
        assert!(state.check_and_increment(String::from("other")).await.is_ok());
    }
}