use axum::{
body::Body,
http::{HeaderValue, Request, StatusCode},
middleware::Next,
response::Response,
};
use dashmap::DashMap;
use hmac::{Hmac, Mac};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use uuid::Uuid;
use crate::impl_default_new;
/// Failure modes reported while establishing an [`AuthContext`] for a request.
#[derive(Debug, Error, Clone)]
pub enum AuthError {
/// No usable authorization header was supplied.
#[error("Missing or invalid authorization header")]
MissingAuth,
/// A token was supplied but was rejected.
#[error("Invalid or expired token")]
InvalidToken,
/// The caller is authenticated but lacks a required permission.
#[error("Insufficient permissions: {required}")]
InsufficientPermissions {
/// The permission that was required; the only field rendered in the message.
required: String,
/// The permissions the caller held (not rendered in the message).
user_permissions: Vec<String>,
},
}
/// Identity, permissions and request metadata attached to a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthContext {
/// Identifier of the user, if one is associated with the request.
pub(crate) user_id: Option<String>,
/// Permission names granted to the caller (compared verbatim by `has_permission`).
pub(crate) permissions: Vec<String>,
/// Request-scoped metadata (client IP, user agent, request id, timestamp).
pub(crate) metadata: AuthMetadata,
}
/// Request-scoped metadata carried alongside an [`AuthContext`].
///
/// The derived `Default` yields an empty `request_id` and a `timestamp` of 0;
/// use [`AuthMetadata::new`] to get a generated id and the current time.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthMetadata {
/// Client IP as text, if known.
pub(crate) client_ip: Option<String>,
/// Value of the client's user agent, if known.
pub(crate) user_agent: Option<String>,
/// Request identifier (`AuthMetadata::new` fills it with a v4 UUID).
pub(crate) request_id: String,
/// Unix timestamp in seconds (`AuthMetadata::new` fills it with the current UTC time).
pub(crate) timestamp: i64,
}
impl AuthContext {
    /// Assembles a context from its three parts.
    pub fn new(user_id: Option<String>, permissions: Vec<String>, metadata: AuthMetadata) -> Self {
        Self {
            user_id,
            permissions,
            metadata,
        }
    }

    /// Returns the user identifier, if any.
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Returns the granted permission names.
    pub fn permissions(&self) -> &[String] {
        &self.permissions
    }

    /// Returns the request metadata.
    pub fn metadata(&self) -> &AuthMetadata {
        &self.metadata
    }

    /// Returns `true` when `permission` is present verbatim (case-sensitive)
    /// in the granted permission list.
    pub fn has_permission(&self, permission: &str) -> bool {
        // Compare as `&str` so no temporary `String` is allocated per call.
        self.permissions
            .iter()
            .any(|granted| granted.as_str() == permission)
    }
}
impl AuthMetadata {
    /// Creates metadata for a new request, generating a fresh v4 UUID as the
    /// request id and stamping the current UTC time (seconds since the epoch).
    pub fn new(client_ip: Option<String>, user_agent: Option<String>) -> Self {
        let request_id = Uuid::new_v4().to_string();
        let timestamp = chrono::Utc::now().timestamp();
        Self {
            client_ip,
            user_agent,
            request_id,
            timestamp,
        }
    }

    /// Client IP as text, if recorded.
    pub fn client_ip(&self) -> Option<&str> {
        self.client_ip.as_deref()
    }

    /// User agent string, if recorded.
    pub fn user_agent(&self) -> Option<&str> {
        self.user_agent.as_deref()
    }

    /// Identifier of the request this metadata belongs to.
    pub fn request_id(&self) -> &str {
        self.request_id.as_str()
    }

    /// Unix timestamp (seconds) stored in this metadata.
    pub fn timestamp(&self) -> i64 {
        self.timestamp
    }
}
/// `Result` alias for authentication outcomes; the success type defaults to [`AuthContext`].
pub type AuthResult<T = AuthContext> = Result<T, AuthError>;
/// Newtype wrapper around an [`AuthContext`].
// NOTE(review): presumably used as an axum extractor; no `FromRequestParts`
// impl is visible in this part of the file — confirm.
#[derive(Debug)]
pub struct AuthExtractor(pub AuthContext);
/// API-key authenticator with per-client tracking of failed attempts.
///
/// Cloning is cheap: both maps live behind `Arc` and are shared between clones.
#[derive(Clone)]
pub struct ApiKeyAuth {
// Hex SHA-256 digest of each registered key -> permissions granted by that key.
valid_keys: Arc<DashMap<String, Vec<String>>>,
// Client IP -> instants of recent failed validations.
failed_attempts: Arc<DashMap<String, Vec<Instant>>>,
// `max_requests` / `window` bound how many failures are recorded per client.
rate_limit_config: RateLimitConfig,
}
impl ApiKeyAuth {
pub fn new() -> Self {
Self::with_rate_limit(RateLimitConfig {
max_requests: 5,
window: Duration::from_secs(60),
include_headers: false,
})
}
pub fn with_rate_limit(config: RateLimitConfig) -> Self {
Self {
valid_keys: Arc::new(DashMap::new()),
failed_attempts: Arc::new(DashMap::new()),
rate_limit_config: config,
}
}
fn hash_key(key: &str) -> String {
use sha2::Digest;
let mut hasher = sha2::Sha256::new();
hasher.update(key.as_bytes());
format!("{:x}", hasher.finalize())
}
pub fn add_key(&self, key: impl Into<String>, permissions: Vec<String>) {
let key_hash = Self::hash_key(&key.into());
self.valid_keys.insert(key_hash, permissions);
}
pub fn validate_key(&self, key: &str, client_ip: &str) -> Option<Vec<String>> {
let start = Instant::now();
let key_hash = Self::hash_key(key);
let is_valid = self.valid_keys.get(&key_hash).is_some();
if is_valid {
Self::apply_constant_time_delay(start);
return self.valid_keys.get(&key_hash).map(|p| p.clone());
}
let is_limited = self.is_rate_limited(client_ip);
if !is_limited {
self.record_failed_attempt(client_ip);
}
Self::apply_constant_time_delay(start);
None
}
fn apply_constant_time_delay(start: Instant) {
if cfg!(test) {
return;
}
const TARGET_DELAY_US: u64 = 100; let elapsed = start.elapsed();
if elapsed < Duration::from_micros(TARGET_DELAY_US) {
std::thread::sleep(Duration::from_micros(TARGET_DELAY_US) - elapsed);
}
}
fn is_rate_limited(&self, client_ip: &str) -> bool {
let now = Instant::now();
let window_start = now - self.rate_limit_config.window;
let entry = self.failed_attempts.get(client_ip);
if let Some(times) = entry {
let recent_attempts = times.iter().filter(|&&t| t > window_start).count();
recent_attempts >= self.rate_limit_config.max_requests as usize
} else {
false
}
}
fn record_failed_attempt(&self, client_ip: &str) {
let now = Instant::now();
let window_start = now - self.rate_limit_config.window;
let mut entry = self
.failed_attempts
.entry(client_ip.to_string())
.or_default();
let times = entry.value_mut();
times.retain(|&t| t > window_start);
times.push(now);
}
pub fn clear_failed_attempts(&self, client_ip: &str) {
self.failed_attempts.remove(client_ip);
}
}
// NOTE(review): `impl_default_new!` is defined elsewhere in the crate; presumably it
// generates `impl Default for ApiKeyAuth` delegating to `ApiKeyAuth::new()` — confirm.
impl_default_new!(ApiKeyAuth);
/// Reasons a JWT can be rejected.
#[derive(Debug, Clone)]
pub enum JwtError {
    /// The token is not structured as a JWT.
    InvalidFormat,
    /// A segment could not be base64-decoded.
    Base64DecodeError,
    /// The signature does not match.
    InvalidSignature,
    /// The token is past its expiry.
    Expired,
    /// The token is not valid yet.
    NotYetValid,
    /// The payload could not be interpreted.
    InvalidPayload,
    /// The token's timestamps are too far from the local clock.
    ClockSkew,
}

impl std::fmt::Display for JwtError {
    /// Writes the fixed human-readable description of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::InvalidFormat => "Invalid JWT format",
            Self::Base64DecodeError => "Failed to decode base64",
            Self::InvalidSignature => "Invalid JWT signature",
            Self::Expired => "JWT token expired",
            Self::NotYetValid => "JWT token not yet valid",
            Self::InvalidPayload => "Invalid JWT payload",
            Self::ClockSkew => "Clock skew too large",
        };
        f.write_str(description)
    }
}

impl std::error::Error for JwtError {}
/// Bearer-token (JWT, HMAC-SHA256) authenticator.
///
/// Cloning shares the token maps, which live behind `Arc`.
#[derive(Clone)]
pub struct BearerAuth {
// Raw bytes of the shared secret used as the HMAC key.
secret: Vec<u8>,
// Token -> context, populated by `register_token`.
valid_tokens: Arc<DashMap<String, AuthContext>>,
// Token -> instant recorded by `invalidate_token`.
blacklisted_tokens: Arc<DashMap<String, Instant>>,
// When set, the token's `aud` claim must match this value.
expected_audience: Option<String>,
// When set, the token's `iss` claim must match this value.
expected_issuer: Option<String>,
}
/// Errors raised while configuring authentication.
#[derive(Debug, Error)]
pub enum AuthConfigError {
/// The secret was rejected for the reason given in the message.
#[error("Invalid secret: {0}")]
InvalidSecret(String),
/// The secret is shorter than the 32-unit minimum enforced by `BearerAuth::try_new`.
#[error("Secret too short: {length} chars. Minimum 32 characters required for security.")]
SecretTooShort {
/// Observed length (`BearerAuth::try_new` reports `String::len`, i.e. bytes).
length: usize,
},
/// The secret lacks one of the required character classes.
#[error("Secret must contain at least one {required_type}")]
MissingCharacterClass {
/// Human-readable name of the missing class, e.g. "digit".
required_type: &'static str,
},
/// An I/O error; convertible from `std::io::Error` via `?`.
#[error("Configuration I/O error: {source}")]
IoError {
#[from]
source: std::io::Error,
},
/// A TOML deserialization error; convertible from `toml::de::Error` via `?`.
#[error("Configuration parse error: {source}")]
ParseError {
#[from]
source: toml::de::Error,
},
}
impl BearerAuth {
/// Creates a `BearerAuth` from `secret`, applying the checks of [`BearerAuth::try_new`].
///
/// # Panics
///
/// Panics when `try_new` rejects the secret. Prefer `try_new` when the secret
/// comes from configuration or any other runtime input.
pub fn new(secret: impl Into<String>) -> Self {
Self::try_new(secret).expect("Failed to create BearerAuth: invalid secret")
}
/// Creates a `BearerAuth` after checking the strength of `secret`.
///
/// # Errors
///
/// * [`AuthConfigError::SecretTooShort`] when the secret is shorter than 32
///   bytes (`String::len`).
/// * [`AuthConfigError::MissingCharacterClass`] for the first class that is
///   absent, checked in this order: uppercase letter, lowercase letter,
///   ASCII digit, special (non-alphanumeric) character.
pub fn try_new(secret: impl Into<String>) -> Result<Self, AuthConfigError> {
    let secret_str: String = secret.into();

    let length = secret_str.len();
    if length < 32 {
        return Err(AuthConfigError::SecretTooShort { length });
    }

    // Each required class paired with the predicate that recognises it;
    // the array order is the order in which failures are reported.
    let required_classes: [(&'static str, fn(char) -> bool); 4] = [
        ("uppercase letter", |c| c.is_uppercase()),
        ("lowercase letter", |c| c.is_lowercase()),
        ("digit", |c| c.is_ascii_digit()),
        ("special character", |c| !c.is_alphanumeric()),
    ];
    for (required_type, belongs_to_class) in required_classes.iter().copied() {
        if !secret_str.chars().any(belongs_to_class) {
            return Err(AuthConfigError::MissingCharacterClass { required_type });
        }
    }

    Ok(Self {
        secret: secret_str.into_bytes(),
        valid_tokens: Arc::new(DashMap::new()),
        blacklisted_tokens: Arc::new(DashMap::new()),
        expected_audience: None,
        expected_issuer: None,
    })
}
/// Creates a `BearerAuth` that additionally requires the token's `aud` claim
/// to match `expected_audience`.
///
/// NOTE(review): unlike `try_new`, this constructor performs no length or
/// character-class checks on `secret` — confirm that is intentional.
pub fn with_audience(secret: impl Into<String>, expected_audience: impl Into<String>) -> Self {
    let key_material = secret.into().into_bytes();
    let audience = expected_audience.into();
    Self {
        secret: key_material,
        valid_tokens: Arc::new(DashMap::new()),
        blacklisted_tokens: Arc::new(DashMap::new()),
        expected_audience: Some(audience),
        expected_issuer: None,
    }
}
/// Creates a `BearerAuth` that requires both the `aud` and the `iss` claim of
/// a token to match the given values.
///
/// NOTE(review): unlike `try_new`, this constructor performs no length or
/// character-class checks on `secret` — confirm that is intentional.
pub fn with_claims(
    secret: impl Into<String>,
    expected_audience: impl Into<String>,
    expected_issuer: impl Into<String>,
) -> Self {
    let key_material = secret.into().into_bytes();
    let audience = expected_audience.into();
    let issuer = expected_issuer.into();
    Self {
        secret: key_material,
        valid_tokens: Arc::new(DashMap::new()),
        blacklisted_tokens: Arc::new(DashMap::new()),
        expected_audience: Some(audience),
        expected_issuer: Some(issuer),
    }
}
/// Compares two byte slices via the `subtle` crate's constant-time equality.
#[cfg(feature = "security")]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.ct_eq(b).into()
}

/// Fallback comparison used when the `security` feature is disabled.
///
/// Slices of different length compare unequal immediately; for equal lengths
/// every byte pair is XOR-ed into an accumulator so the loop does not exit
/// early on the first mismatch.
#[cfg(not(feature = "security"))]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let difference = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    difference == 0
}
/// Decodes unpadded base64url (`A-Z a-z 0-9 - _`) text.
///
/// ASCII whitespace and `.` are skipped, and trailing `=` padding is
/// tolerated. Any other byte outside the base64url alphabet makes the call
/// return `None`; previously such bytes were silently decoded as if they were
/// `'A'`, so malformed input (including a malformed signature) was accepted.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Sentinel for "not part of the alphabet"; real values are 0..=63.
    const INVALID: u8 = 0xFF;

    let mut table = [INVALID; 256];
    for (value, &symbol) in ALPHABET.iter().enumerate() {
        table[usize::from(symbol)] = value as u8;
    }

    let input = input.trim_end_matches('=');
    let mut result = Vec::with_capacity(input.len() * 3 / 4);
    let mut buffer = 0u32;
    let mut bits = 0u32;
    for c in input.bytes() {
        if matches!(c, b'.' | b' ' | b'\n' | b'\r' | b'\t') {
            continue;
        }
        let val = table[usize::from(c)];
        if val == INVALID {
            return None;
        }
        // Only the low `bits` bits of `buffer` are meaningful; older bits are
        // shifted out of the top and ignored by the `as u8` truncation below.
        buffer = (buffer << 6) | u32::from(val);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            result.push((buffer >> bits) as u8);
        }
    }
    Some(result)
}
/// Verifies an HS256-signed compact JWT and returns its payload.
///
/// Returns `None` when any of the following fails:
/// * the token has exactly three `.`-separated segments;
/// * the signature is a valid HMAC-SHA256 of `header.payload` under the
///   configured secret (checked before the payload is parsed);
/// * the header declares `"alg": "HS256"`;
/// * `exp`, `iat` (60 s skew allowed) and `nbf`, when present, are numeric
///   and consistent with the current time;
/// * `aud` / `iss` match the configured expectations, when configured.
fn verify_jwt(&self, token: &str) -> Option<serde_json::Value> {
    const CLOCK_SKEW_SECONDS: i64 = 60;

    let mut segments = token.split('.');
    let (header_b64, payload_b64, signature_b64) = match (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) {
        (Some(header), Some(payload), Some(signature), None) => (header, payload, signature),
        _ => return None,
    };

    // Authenticate first: nothing from the token is interpreted until the
    // signature over the raw `header.payload` text has been verified.
    let mut mac = Hmac::<Sha256>::new_from_slice(&self.secret).ok()?;
    mac.update(header_b64.as_bytes());
    mac.update(b".");
    mac.update(payload_b64.as_bytes());
    let expected_signature = mac.finalize().into_bytes();
    let provided_signature = Self::base64url_decode(signature_b64)?;
    if provided_signature.len() != 32 {
        return None;
    }
    if !Self::constant_time_eq(expected_signature.as_slice(), &provided_signature) {
        return None;
    }

    // Only HS256 is implemented, so refuse tokens that claim anything else.
    let header_bytes = Self::base64url_decode(header_b64)?;
    let header: serde_json::Value = serde_json::from_slice(&header_bytes).ok()?;
    if header.get("alg").and_then(|v| v.as_str()) != Some("HS256") {
        return None;
    }

    // Parse the raw bytes directly; a lossy UTF-8 conversion would rewrite
    // invalid sequences instead of rejecting them.
    let payload = Self::base64url_decode(payload_b64)?;
    let payload_value: serde_json::Value = serde_json::from_slice(&payload).ok()?;

    // Reads a NumericDate claim. Absent -> Ok(None). Present but not a number
    // -> Err, so a malformed `exp` can no longer skip the expiry check.
    // Fractional values are truncated to whole seconds.
    let numeric_date = |claim: &str| -> Result<Option<i64>, ()> {
        match payload_value.get(claim) {
            None => Ok(None),
            Some(value) => value
                .as_i64()
                .or_else(|| value.as_f64().map(|seconds| seconds as i64))
                .map(Some)
                .ok_or(()),
        }
    };

    let now = chrono::Utc::now().timestamp();
    if let Some(exp) = numeric_date("exp").ok()? {
        if now > exp {
            return None;
        }
    }
    if let Some(iat) = numeric_date("iat").ok()? {
        if iat > now + CLOCK_SKEW_SECONDS {
            return None;
        }
    }
    if let Some(nbf) = numeric_date("nbf").ok()? {
        if now < nbf {
            return None;
        }
    }

    if let Some(expected_aud) = self.expected_audience.as_deref() {
        // `aud` may be a single string or an array; for an array the token is
        // acceptable when ANY element names us, not only the first one.
        let audience_matches = match payload_value.get("aud") {
            Some(serde_json::Value::String(aud)) => aud.as_str() == expected_aud,
            Some(serde_json::Value::Array(audiences)) => audiences
                .iter()
                .any(|aud| aud.as_str() == Some(expected_aud)),
            _ => false,
        };
        if !audience_matches {
            return None;
        }
    }

    if let Some(expected_iss) = self.expected_issuer.as_deref() {
        let token_iss = payload_value.get("iss").and_then(|v| v.as_str());
        if token_iss != Some(expected_iss) {
            return None;
        }
    }

    Some(payload_value)
}
/// Validates a bearer token and builds the caller's [`AuthContext`].
///
/// Returns `None` when the token has been revoked through
/// `invalidate_token` or fails JWT verification. On success the context
/// carries the `sub` claim as user id, the string entries of the
/// `permissions` array claim, and default (empty) metadata.
///
/// NOTE(review): tokens stored through `register_token` are not consulted
/// here — confirm whether opaque registered tokens should be accepted.
pub fn validate_token(&self, token: &str) -> Option<AuthContext> {
    // A revoked token stays revoked. The stored instant is the moment of
    // revocation, so the previous `now < instant` test was never true and
    // revoked tokens kept validating.
    if self.blacklisted_tokens.contains_key(token) {
        return None;
    }

    let payload = self.verify_jwt(token)?;

    let user_id = payload
        .get("sub")
        .and_then(|v| v.as_str())
        .map(String::from);
    let permissions: Vec<String> = payload
        .get("permissions")
        .and_then(|v| v.as_array())
        .map(|entries| {
            entries
                .iter()
                .filter_map(|p| p.as_str().map(String::from))
                .collect()
        })
        .unwrap_or_default();

    Some(AuthContext {
        user_id,
        permissions,
        metadata: AuthMetadata::default(),
    })
}
/// Stores `context` under `token` in the `valid_tokens` map, replacing any
/// context previously stored for the same token.
// NOTE(review): nothing in the visible part of this file reads `valid_tokens`
// back — confirm where registered tokens are consumed.
pub fn register_token(&self, token: String, context: AuthContext) {
self.valid_tokens.insert(token, context);
}
/// Adds `token` to the blacklist, recording the instant at which it was
/// invalidated. Re-invalidating a token overwrites the recorded instant.
pub fn invalidate_token(&self, token: &str) {
    let blacklist_key = token.to_string();
    let recorded_at = Instant::now();
    self.blacklisted_tokens.insert(blacklist_key, recorded_at);
}
}
/// Parameters of a sliding-window rate limit.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed inside one window.
    pub max_requests: u32,
    /// Length of the window.
    pub window: Duration,
    /// Whether rate-limit headers are added to successful responses.
    pub include_headers: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60 seconds, with rate-limit headers enabled.
    fn default() -> Self {
        const DEFAULT_MAX_REQUESTS: u32 = 100;
        const DEFAULT_WINDOW: Duration = Duration::from_secs(60);
        Self {
            include_headers: true,
            window: DEFAULT_WINDOW,
            max_requests: DEFAULT_MAX_REQUESTS,
        }
    }
}
/// Builds a runtime [`RateLimitConfig`] from its configuration-file form.
impl TryFrom<crate::config::RateLimitConfigFile> for RateLimitConfig {
    type Error = crate::config::ConfigError;

    /// Maps `requests` to `max_requests` and `window_seconds` to `window`;
    /// headers are always enabled. The conversion as written never fails.
    fn try_from(config: crate::config::RateLimitConfigFile) -> Result<Self, Self::Error> {
        let max_requests = config.requests;
        let window = Duration::from_secs(config.window_seconds);
        let converted = Self {
            max_requests,
            window,
            include_headers: true,
        };
        Ok(converted)
    }
}
/// Settings controlling whether proxy-supplied client-IP headers are used.
#[derive(Debug, Clone)]
pub struct TrustedProxyConfig {
    /// Addresses or host names of proxies considered trusted.
    pub trusted_proxies: Vec<String>,
    /// Whether proxy headers are consulted at all.
    pub enabled: bool,
}

impl Default for TrustedProxyConfig {
    /// Enabled, trusting only the local host (`127.0.0.1`, `::1`, `localhost`).
    fn default() -> Self {
        const LOCAL_PROXIES: [&str; 3] = ["127.0.0.1", "::1", "localhost"];
        Self {
            trusted_proxies: LOCAL_PROXIES.iter().map(|proxy| proxy.to_string()).collect(),
            enabled: true,
        }
    }
}
/// Derives the client IP of `req` from proxy headers.
///
/// Returns `"unknown"` when `config.enabled` is false or no usable header is
/// present. Otherwise the first non-empty entry of `x-forwarded-for` wins,
/// followed by a non-empty `x-real-ip`. Values are trimmed; an empty entry
/// no longer yields an empty string but falls through to the next source.
///
/// NOTE(review): the header values are client-controlled and
/// `config.trusted_proxies` is not consulted, so the result must not be
/// treated as authenticated — confirm how callers use it.
pub fn extract_client_ip(
    req: &axum::http::Request<axum::body::Body>,
    config: &TrustedProxyConfig,
) -> String {
    const UNKNOWN: &str = "unknown";

    /// Returns the trimmed, non-empty text of header `name`, if present.
    fn header_text<'a>(
        req: &'a axum::http::Request<axum::body::Body>,
        name: &str,
    ) -> Option<&'a str> {
        req.headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|text| !text.is_empty())
    }

    if !config.enabled {
        return UNKNOWN.to_string();
    }

    let forwarded_client = header_text(req, "x-forwarded-for")
        .and_then(|chain| chain.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    forwarded_client
        .or_else(|| header_text(req, "x-real-ip"))
        .unwrap_or(UNKNOWN)
        .to_string()
}
/// Sliding-window rate limiter keyed by an arbitrary string (e.g. client IP).
///
/// Cloning shares all state, which lives behind `Arc`.
#[derive(Clone)]
pub struct RateLimiter {
// Limit and window applied to every key.
config: RateLimitConfig,
// Key -> instants of the requests admitted inside the current window.
requests: Arc<DashMap<String, Vec<Instant>>>,
// Idempotency key -> instant at which it was last recorded.
idempotency_cache: Arc<DashMap<String, Instant>>,
// Caps the number of permits handed out by `acquire` at the same time.
semaphore: Arc<tokio::sync::Semaphore>,
}
impl RateLimiter {
pub fn new(config: Option<RateLimitConfig>) -> Self {
Self {
config: config.unwrap_or_default(),
requests: Arc::new(DashMap::new()),
idempotency_cache: Arc::new(DashMap::new()),
semaphore: Arc::new(tokio::sync::Semaphore::new(1000)),
}
}
pub fn check(&self, key: &str) -> Result<u32, RateLimitError> {
let now = Instant::now();
let window_start = now - self.config.window;
let mut entry = self.requests.entry(key.to_string()).or_default();
let times = entry.value_mut();
times.retain(|&t| t > window_start);
if times.len() >= self.config.max_requests as usize {
let retry_after = times
.first()
.map(|t| {
let elapsed = now - *t;
(self.config.window - elapsed).as_secs()
})
.unwrap_or(1);
return Err(RateLimitError {
limit: self.config.max_requests,
remaining: 0,
retry_after,
});
}
times.push(now);
Ok(self.config.max_requests - times.len() as u32)
}
pub fn check_idempotency(&self, idempotency_key: &str) -> bool {
let now = Instant::now();
let window = Duration::from_secs(60);
if let Some(existing) = self.idempotency_cache.get(idempotency_key) {
let existing_time = *existing;
let elapsed = now.saturating_duration_since(existing_time).as_secs();
if elapsed < window.as_secs() {
return true; }
}
self.idempotency_cache
.insert(idempotency_key.to_string(), now);
false }
pub fn remaining(&self, key: &str) -> u32 {
let now = Instant::now();
let window_start = now - self.config.window;
let entry = self.requests.get(key);
if let Some(times) = entry {
let active = times.iter().filter(|&&t| t > window_start).count();
self.config.max_requests - active as u32
} else {
self.config.max_requests
}
}
pub async fn acquire(&self, key: &str) -> Result<Permit, RateLimitError> {
let _remaining = self.check(key)?;
let permit = self
.semaphore
.clone()
.try_acquire_owned()
.map_err(|_| RateLimitError {
limit: self.config.max_requests,
remaining: 0,
retry_after: 1,
})?;
Ok(Permit(permit))
}
}
/// Concurrency permit returned by `RateLimiter::acquire`; wraps an owned semaphore permit.
pub struct Permit(pub tokio::sync::OwnedSemaphorePermit);
// The body is intentionally empty: no extra work is done here. The wrapped
// permit is a field and is dropped right after this runs.
impl Drop for Permit {
fn drop(&mut self) {
}
}
/// Error returned when a request is refused by the rate limiter.
#[derive(Debug, Error)]
#[error("Rate limit exceeded. Try again in {retry_after} seconds")]
pub struct RateLimitError {
/// The configured maximum number of requests per window.
pub limit: u32,
/// Requests left in the window (the limiter sets this to 0 on refusal).
pub remaining: u32,
/// Seconds the caller should wait before retrying.
pub retry_after: u64,
}
/// One audit record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLog {
/// Unique id of the record (`AuditLogger::log` generates a v4 UUID).
pub id: String,
/// Unix timestamp in seconds at which the record was created.
pub timestamp: i64,
/// User the action is attributed to, if any.
pub user_id: Option<String>,
/// Name of the action performed.
pub action: String,
/// Resource the action targeted.
pub resource: String,
/// Outcome of the action.
pub result: AuditResult,
/// Request metadata copied from the caller's `AuthContext`.
pub metadata: AuthMetadata,
}
/// Outcome of an audited action; serialized with an internal `status` tag
/// (`"success"` or `"failure"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum AuditResult {
/// The action succeeded.
#[serde(rename = "success")]
Success,
/// The action failed.
#[serde(rename = "failure")]
Failure {
/// Failure description (`AuditLogger::log` passes it through `sanitize_error_message`).
message: String,
},
}
/// In-memory audit logger that queues records to a background task.
///
/// Cloning shares all state, which lives behind `Arc`.
#[derive(Clone)]
pub struct AuditLogger {
// User id -> records written by the background task.
logs: Arc<DashMap<String, Vec<AuditLog>>>,
// Upper bound on the number of records kept per user.
max_logs_per_user: usize,
// Bounds the number of `log` calls in flight at the same time.
semaphore: Arc<tokio::sync::Semaphore>,
// Sending half of the queue drained by the background task.
queue_sender: Arc<tokio::sync::mpsc::Sender<AuditLogBatch>>,
// User id -> records that could not be queued (channel full or closed).
fallback_logs: Arc<DashMap<String, Vec<AuditLog>>>,
// Number of records routed to `fallback_logs` so far.
dropped_log_count: Arc<std::sync::atomic::AtomicU64>,
}
/// Queue message: a single audit record plus the map key it is stored under.
struct AuditLogBatch {
// Key under which the record is stored ("anonymous" when the context has no user).
user_id: String,
// The record itself.
log: AuditLog,
}
// Matches compact JWTs: three base64url segments where the first two start with `eyJ`.
static JWT_PATTERN: Lazy<regex::Regex> = Lazy::new(|| {
regex::Regex::new(r#"eyJ[A-Za-z0-9\-_]+\.eyJ[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+"#).unwrap()
});
// Matches `name = value` / `name: value` pairs (case-insensitive) whose name
// contains password/secret/token/key/auth/bearer; the value is at most 100 characters.
static SECRET_PATTERN: Lazy<regex::Regex> = Lazy::new(|| {
regex::Regex::new(r#"(?i)(password|secret|token|key|auth|bearer)\s*[:=]\s*[^,\s}\]]{1,100}"#)
.unwrap()
});
// Matches `postgresql://user:password@host/db` connection strings.
static DB_PATTERN: Lazy<regex::Regex> =
Lazy::new(|| regex::Regex::new(r#"postgresql://[^:]+:[^@]+@[^/]+/\w+"#).unwrap());
// Matches absolute paths ending in a key/certificate extension (pem, key, crt, p12, jks).
static PATH_PATTERN: Lazy<regex::Regex> =
Lazy::new(|| regex::Regex::new(r#"/[a-zA-Z0-9/_.-]+\.(pem|key|crt|p12|jks)"#).unwrap());
fn sanitize_error_message(message: &str) -> String {
let mut result = message.to_string();
result = JWT_PATTERN
.replace_all(&result, "[REDACTED_JWT]")
.to_string();
result = SECRET_PATTERN
.replace_all(&result, |caps: ®ex::Captures| {
format!("{}={}", &caps[1], "[REDACTED]")
})
.to_string();
result = regex::Regex::new(r#"(?i)(api[_-]?key|apikey)\s*[:=]\s*['\"]?[A-Za-z0-9]{20,}['\"]?"#)
.unwrap()
.replace_all(&result, "[REDACTED_API_KEY]")
.to_string();
result = regex::Regex::new(r#"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b"#)
.unwrap()
.replace_all(&result, "[REDACTED_CREDIT_CARD]")
.to_string();
result = regex::Regex::new(r#"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b"#)
.unwrap()
.replace_all(&result, "[REDACTED_SSN]")
.to_string();
result = DB_PATTERN
.replace_all(&result, "postgresql://[REDACTED]:[REDACTED]@localhost/db")
.to_string();
result = PATH_PATTERN
.replace_all(&result, "[REDACTED_PATH]")
.to_string();
const MAX_SANITIZED_LENGTH: usize = 500;
if result.len() > MAX_SANITIZED_LENGTH {
result.truncate(MAX_SANITIZED_LENGTH);
result.push_str("...[TRUNCATED]");
}
result
}
impl AuditLogger {
/// Creates a logger that keeps at most 1000 records per user.
///
/// Delegates to `with_limit`, which calls `tokio::spawn`.
pub fn new() -> Self {
Self::with_limit(1000)
}
pub fn with_limit(max_logs: usize) -> Self {
let (queue_sender, mut queue_receiver) = tokio::sync::mpsc::channel::<AuditLogBatch>(1000);
let logs: Arc<DashMap<String, Vec<AuditLog>>> = Arc::new(DashMap::new());
let fallback_logs: Arc<DashMap<String, Vec<AuditLog>>> = Arc::new(DashMap::new());
let logs_clone = logs.clone();
let fallback_logs_clone = fallback_logs.clone();
let max_logs_clone = max_logs;
tokio::spawn(async move {
while let Some(batch) = queue_receiver.recv().await {
let mut entry = logs_clone.entry(batch.user_id.clone()).or_default();
entry.push(batch.log);
if entry.len() > max_logs_clone {
entry.truncate(max_logs_clone);
}
if let Some(fallback) = fallback_logs_clone.get(&batch.user_id) {
if !fallback.is_empty() {
let mut entry = logs_clone.entry(batch.user_id.clone()).or_default();
entry.extend(fallback.iter().cloned());
if entry.len() > max_logs_clone {
entry.truncate(max_logs_clone);
}
fallback_logs_clone.remove(&batch.user_id);
}
}
}
});
Self {
logs,
max_logs_per_user: max_logs,
semaphore: Arc::new(tokio::sync::Semaphore::new(100)), queue_sender: Arc::new(queue_sender),
fallback_logs,
dropped_log_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}
pub async fn log(
&self,
context: &AuthContext,
action: impl Into<String>,
resource: impl Into<String>,
success: bool,
message: Option<String>,
) {
let permit = match tokio::time::timeout(
Duration::from_secs(1),
self.semaphore.clone().acquire_owned(),
)
.await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) | Err(_) => {
return;
}
};
let log = AuditLog {
id: Uuid::new_v4().to_string(),
timestamp: chrono::Utc::now().timestamp(),
user_id: context.user_id.clone(),
action: action.into(),
resource: resource.into(),
result: if success {
AuditResult::Success
} else {
AuditResult::Failure {
message: sanitize_error_message(
&message.unwrap_or_else(|| "Unknown error".to_string()),
),
}
},
metadata: context.metadata.clone(),
};
let user_id = context
.user_id
.clone()
.unwrap_or_else(|| "anonymous".to_string());
let log_for_fallback = log.clone();
let sender = self.queue_sender.clone();
let log_batch = AuditLogBatch {
user_id: user_id.clone(),
log,
};
match sender.try_send(log_batch) {
Ok(()) => {
#[cfg(feature = "logging")]
tracing::debug!(target: "audit", "Audit log queued for user: {}", user_id);
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
#[cfg(feature = "logging")]
tracing::warn!(target: "audit",
"Audit log channel full for user: {}, using fallback storage",
user_id
);
self.store_fallback_log(&user_id, &log_for_fallback);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
#[cfg(feature = "logging")]
tracing::error!(target: "audit",
"Audit log channel closed for user: {}, using synchronous logging",
user_id
);
self.store_fallback_log(&user_id, &log_for_fallback);
}
}
drop(permit);
}
/// Returns all records stored for `user_id`, newest first.
///
/// Records from the primary and the fallback store are merged; a record id
/// that appears more than once is returned only once (primary copy wins).
pub fn get_logs(&self, user_id: &str) -> Vec<AuditLog> {
    let mut all_logs = self
        .logs
        .get(user_id)
        .map(|entry| entry.value().clone())
        .unwrap_or_default();
    let fallback = self
        .fallback_logs
        .get(user_id)
        .map(|entry| entry.value().clone())
        .unwrap_or_default();

    if !fallback.is_empty() {
        // A set of seen ids keeps the merge linear instead of rescanning
        // `all_logs` for every fallback record.
        let mut seen_ids: std::collections::HashSet<String> =
            all_logs.iter().map(|record| record.id.clone()).collect();
        for record in fallback {
            if seen_ids.insert(record.id.clone()) {
                all_logs.push(record);
            }
        }
    }

    all_logs.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
    all_logs
}
/// Removes every record stored for `user_id`.
///
/// Both the primary and the fallback store are cleared; `get_logs` merges
/// the two, so leaving the fallback behind made "cleared" records reappear.
pub fn clear_logs(&self, user_id: &str) {
    self.logs.remove(user_id);
    self.fallback_logs.remove(user_id);
}
/// Returns the number of records in the primary store, summed over all users.
/// Records held in the fallback store are not counted.
pub fn total_log_count(&self) -> usize {
    let mut total = 0usize;
    for per_user in self.logs.iter() {
        total += per_user.value().len();
    }
    total
}
/// Stores `log` in the fallback store for `user_id` and bumps the counter
/// reported by `dropped_log_count`.
///
/// The fallback store is capped at `max_logs_per_user` records per user;
/// when the cap is exceeded the oldest records are discarded so the most
/// recent activity is retained. A warning is emitted every 100 records
/// (when the `logging` feature is enabled).
fn store_fallback_log(&self, user_id: &str, log: &AuditLog) {
    // `fetch_add` returns the value BEFORE the increment.
    let count = self
        .dropped_log_count
        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);

    {
        let mut entry = self.fallback_logs.entry(user_id.to_string()).or_default();
        entry.push(log.clone());
        if entry.len() > self.max_logs_per_user {
            // `truncate` would keep the oldest records and throw away the
            // one just pushed; drain from the front instead.
            let overflow = entry.len() - self.max_logs_per_user;
            entry.drain(..overflow);
        }
    }

    if count > 0 && count.is_multiple_of(100) {
        #[cfg(feature = "logging")]
        tracing::warn!(target: "audit",
            "High audit log drop rate: {} logs dropped due to channel congestion",
            count + 1
        );
    }
}
/// Returns the current value of the dropped-log counter (incremented by
/// `store_fallback_log` each time a record is routed to the fallback store).
pub fn dropped_log_count(&self) -> u64 {
self.dropped_log_count
.load(std::sync::atomic::Ordering::SeqCst)
}
}
// NOTE(review): `impl_default_new!` is defined elsewhere in the crate; presumably it
// generates `impl Default for AuditLogger` delegating to `AuditLogger::new()` — confirm.
// `AuditLogger::new()` ends up calling `tokio::spawn`.
impl_default_new!(AuditLogger);
pub fn auth_middleware<T: Clone + Send + Sync + 'static>(
_auth: Arc<T>,
extract_auth: impl Fn(&Request<Body>) -> AuthResult<AuthContext> + Clone + Send + 'static,
) -> impl Fn(Request<Body>, Next) -> Pin<Box<dyn Future<Output = Response> + Send>> + Clone + Send {
move |mut req: Request<Body>, next: Next| {
let extract_auth = extract_auth.clone();
Box::pin(async move {
match extract_auth(&req) {
Ok(auth_context) => {
req.extensions_mut().insert(auth_context);
next.run(req).await
}
Err(_) => {
let mut response = Response::new(Body::from("Unauthorized"));
*response.status_mut() = StatusCode::UNAUTHORIZED;
response
}
}
})
}
}
/// Builds a middleware function that applies `limiter` per client IP.
///
/// Admitted requests are forwarded; when the limiter's configuration enables
/// headers, `X-RateLimit-Limit` and `X-RateLimit-Remaining` are added to the
/// response. Refused requests get `429 Too Many Requests` with
/// `X-RateLimit-Limit`, `X-RateLimit-Remaining: 0` and `Retry-After`.
pub fn rate_limit_middleware(
    limiter: Arc<RateLimiter>,
) -> impl Fn(Request<Body>, Next) -> Pin<Box<dyn Future<Output = Response> + Send>> + Clone + Send {
    move |req: Request<Body>, next: Next| {
        let limiter = limiter.clone();
        Box::pin(async move {
            let client_ip = extract_client_ip_simple(&req);

            // Refusals are answered here; only admitted requests continue.
            let remaining = match limiter.check(&client_ip) {
                Ok(remaining) => remaining,
                Err(exceeded) => {
                    let mut rejection = Response::new(Body::from("Rate limit exceeded"));
                    *rejection.status_mut() = StatusCode::TOO_MANY_REQUESTS;
                    let headers = rejection.headers_mut();
                    headers.insert("X-RateLimit-Limit", HeaderValue::from(exceeded.limit));
                    headers.insert("X-RateLimit-Remaining", HeaderValue::from(0));
                    headers.insert("Retry-After", HeaderValue::from(exceeded.retry_after));
                    return rejection;
                }
            };

            let mut response = next.run(req).await;
            if limiter.config.include_headers {
                let headers = response.headers_mut();
                headers.insert(
                    "X-RateLimit-Limit",
                    HeaderValue::from(limiter.config.max_requests),
                );
                headers.insert("X-RateLimit-Remaining", HeaderValue::from(remaining));
            }
            response
        })
    }
}
/// Returns `true` when the IPv4 address `ip` lies inside `cidr`.
///
/// `cidr` is either `network/prefix` with a prefix of 0..=32, or a bare IPv4
/// address, which is treated as an exact match (`/32`). Anything that does
/// not parse — a malformed address, a non-numeric or out-of-range prefix,
/// an IPv6 address — yields `false`.
fn is_ip_in_range(ip: &str, cidr: &str) -> bool {
    use std::net::Ipv4Addr;

    let Ok(ip_addr) = ip.parse::<Ipv4Addr>() else {
        return false;
    };

    let (network, prefix_len) = match cidr.split_once('/') {
        // An unparsable prefix used to be read as /0 and matched every
        // address; a prefix above 32 overflowed the shift below.
        Some((network, prefix)) => match prefix.parse::<u32>() {
            Ok(bits) if bits <= 32 => (network, bits),
            _ => return false,
        },
        None => (cidr, 32),
    };
    let Ok(net_addr) = network.parse::<Ipv4Addr>() else {
        return false;
    };

    // A shift by 32 is not allowed for `u32`, so /0 is handled separately.
    let mask = if prefix_len == 0 {
        0
    } else {
        u32::MAX << (32 - prefix_len)
    };
    (u32::from(ip_addr) & mask) == (u32::from(net_addr) & mask)
}
/// Determines the client IP of `req`, or `None` when nothing usable is found.
///
/// Sources, in order:
/// 1. The first entry of `X-Forwarded-For`, but only when the directly
///    connected peer (from `ConnectInfo`) is inside one of the built-in
///    trusted proxy ranges and the entry passes `is_valid_ip`.
/// 2. `X-Real-IP`, when it passes `is_valid_ip`.
/// 3. The peer address itself.
///
/// The former `X-Forwarded-For` branch could never succeed: it required the
/// forwarded address to pass `is_valid_ip` (which rejects private and
/// loopback addresses) AND to lie inside a private/loopback range. The
/// trusted ranges describe the proxy, so they are now checked against the
/// peer. IPv6 peers are never treated as trusted proxies here.
#[inline]
fn extract_client_ip_core(req: &Request<Body>) -> Option<String> {
    use axum::extract::connect_info::ConnectInfo;

    // `is_ip_in_range` expects CIDR notation, hence the explicit /32.
    const TRUSTED_PROXY_RANGES: [&str; 4] = [
        "10.0.0.0/8",
        "172.16.0.0/12",
        "192.168.0.0/16",
        "127.0.0.1/32",
    ];

    let peer_ip = req
        .extensions()
        .get::<ConnectInfo<std::net::SocketAddr>>()
        .map(|remote| remote.0.ip());
    let peer_is_trusted_proxy = peer_ip.is_some_and(|peer| {
        let peer = peer.to_string();
        TRUSTED_PROXY_RANGES
            .iter()
            .any(|range| is_ip_in_range(&peer, range))
    });

    if peer_is_trusted_proxy {
        let forwarded_client = req
            .headers()
            .get("X-Forwarded-For")
            .and_then(|header| header.to_str().ok())
            .and_then(|chain| chain.split(',').next())
            .map(str::trim)
            .filter(|candidate| is_valid_ip(candidate));
        if let Some(ip) = forwarded_client {
            return Some(ip.to_string());
        }
    }

    let real_ip = req
        .headers()
        .get("X-Real-IP")
        .and_then(|header| header.to_str().ok())
        .filter(|candidate| is_valid_ip(candidate));
    if let Some(ip) = real_ip {
        return Some(ip.to_string());
    }

    peer_ip.map(|ip| ip.to_string())
}
/// Resolves the client IP of `req` using the default [`TrustedProxyConfig`].
///
/// The two former `#[cfg(feature = "logging")]` / `#[cfg(not(...))]` copies
/// of this function had identical bodies, so a single definition suffices.
fn extract_client_ip_simple(req: &Request<Body>) -> String {
    let proxy_config = TrustedProxyConfig::default();
    extract_client_ip_with_config(req, &proxy_config)
}
/// Resolves the client IP of `req`, returning `"unknown"` when nothing
/// usable is found.
///
/// With `proxy_config.enabled == false` the result comes straight from
/// `extract_client_ip_core`. Otherwise the first `X-Forwarded-For` entry is
/// used if it passes `is_valid_ip`, then `X-Real-IP` under the same check,
/// and finally `extract_client_ip_core`.
///
/// NOTE(review): `proxy_config.trusted_proxies` is not consulted, so the
/// headers are honoured regardless of who sent the request — confirm.
fn extract_client_ip_with_config(req: &Request<Body>, proxy_config: &TrustedProxyConfig) -> String {
    let from_core = || extract_client_ip_core(req).unwrap_or_else(|| "unknown".to_string());

    if !proxy_config.enabled {
        return from_core();
    }

    let forwarded_client = req
        .headers()
        .get("X-Forwarded-For")
        .and_then(|header| header.to_str().ok())
        .and_then(|chain| chain.split(',').next())
        .map(|first| first.trim())
        .filter(|candidate| is_valid_ip(candidate));
    let real_ip = || {
        req.headers()
            .get("X-Real-IP")
            .and_then(|header| header.to_str().ok())
            .filter(|candidate| is_valid_ip(candidate))
    };

    match forwarded_client.or_else(real_ip) {
        Some(ip) => ip.to_string(),
        None => from_core(),
    }
}
/// Returns `true` when `ip` parses as an IP address that is acceptable as a
/// client address.
///
/// Rejected IPv4: private (10/8, 172.16/12, 192.168/16), loopback (127/8),
/// link-local (169.254/16), multicast (224/4) and 0.x.x.x.
/// Rejected IPv6: loopback, unspecified, link-local (fe80::/10), unique
/// local (fc00::/7) and multicast (ff00::/8).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by the IPv4
/// rules; previously `::ffff:127.0.0.1` and similar slipped through because
/// only the IPv6 prefixes were inspected.
fn is_valid_ip(ip: &str) -> bool {
    use std::net::IpAddr;

    /// IPv4 acceptance rules shared by plain and IPv4-mapped addresses.
    fn is_acceptable_v4(addr: std::net::Ipv4Addr) -> bool {
        !(addr.is_private()
            || addr.is_loopback()
            || addr.is_link_local()
            || addr.is_multicast()
            || addr.octets()[0] == 0)
    }

    // 45 bytes is the longest textual form of an IPv6 address.
    if ip.is_empty() || ip.len() > 45 {
        return false;
    }

    match ip.parse::<IpAddr>() {
        Ok(IpAddr::V4(ipv4)) => is_acceptable_v4(ipv4),
        Ok(IpAddr::V6(ipv6)) => {
            if let Some(mapped) = ipv6.to_ipv4_mapped() {
                return is_acceptable_v4(mapped);
            }
            let first_segment = ipv6.segments()[0];
            let is_link_local = first_segment & 0xffc0 == 0xfe80;
            let is_unique_local = first_segment & 0xfe00 == 0xfc00;
            !(ipv6.is_loopback()
                || ipv6.is_unspecified()
                || ipv6.is_multicast()
                || is_link_local
                || is_unique_local)
        }
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_api_key_auth() {
let auth = ApiKeyAuth::new();
auth.add_key("test-key", vec!["read".to_string(), "write".to_string()]);
let permissions = auth.validate_key("test-key", "127.0.0.1");
assert_eq!(
permissions,
Some(vec!["read".to_string(), "write".to_string()])
);
let permissions = auth.validate_key("invalid-key", "127.0.0.1");
assert_eq!(permissions, None);
}
#[tokio::test]
async fn test_api_key_auth_rate_limiting() {
let auth = ApiKeyAuth::with_rate_limit(RateLimitConfig {
max_requests: 3,
window: Duration::from_secs(60),
include_headers: false,
});
auth.add_key("valid-key", vec!["read".to_string()]);
for i in 0..3 {
assert_eq!(
auth.validate_key(&format!("invalid-key-{}", i), "192.168.1.1"),
None
);
}
assert_eq!(auth.validate_key("invalid-key-4", "192.168.1.1"), None);
let permissions = auth.validate_key("valid-key", "192.168.1.1");
assert_eq!(permissions, Some(vec!["read".to_string()]));
assert_eq!(
auth.validate_key("another-invalid-key", "192.168.1.1"),
None
);
}
#[tokio::test]
async fn test_api_key_hashing() {
let auth = ApiKeyAuth::new();
auth.add_key("test-key", vec!["admin".to_string()]);
assert!(auth.validate_key("test-key", "127.0.0.1").is_some());
assert!(auth.validate_key("TEST-KEY", "127.0.0.1").is_none());
}
#[tokio::test]
async fn test_rate_limiter() {
let config = RateLimitConfig {
max_requests: 3,
window: Duration::from_secs(60),
include_headers: true,
};
let limiter = RateLimiter::new(Some(config));
for _ in 0..3 {
assert!(limiter.check("test-ip").is_ok());
}
assert!(limiter.check("test-ip").is_err());
}
#[tokio::test]
async fn test_audit_logger() {
let logger = AuditLogger::new();
let context = AuthContext {
user_id: Some("user-123".to_string()),
permissions: vec![],
metadata: AuthMetadata::default(),
};
logger
.log(&context, "test_action", "test_resource", true, None)
.await;
tokio::task::yield_now().await;
let logs = logger.get_logs("user-123");
assert_eq!(logs.len(), 1);
assert_eq!(logs[0].action, "test_action");
}
#[test]
fn test_ip_range_validation() {
assert!(is_ip_in_range("10.0.0.1", "10.0.0.0/8"));
assert!(is_ip_in_range("192.168.1.100", "192.168.0.0/16"));
assert!(is_ip_in_range("172.16.5.5", "172.16.0.0/12"));
assert!(is_ip_in_range("172.31.255.255", "172.16.0.0/12"));
assert!(!is_ip_in_range("8.8.8.8", "10.0.0.0/8"));
assert!(!is_ip_in_range("172.32.0.1", "172.16.0.0/12"));
assert!(!is_ip_in_range("8.8.8.8", "192.168.0.0/16"));
}
#[test]
fn test_auth_context_creation() {
let metadata = AuthMetadata::new(
Some("192.168.1.1".to_string()),
Some("TestClient/1.0".to_string()),
);
let context = AuthContext::new(
Some("user-123".to_string()),
vec!["read".to_string(), "write".to_string()],
metadata,
);
assert_eq!(context.user_id(), Some("user-123"));
assert_eq!(context.permissions().len(), 2);
assert!(context.has_permission("read"));
assert!(context.has_permission("write"));
assert!(!context.has_permission("delete"));
}
#[test]
fn test_auth_context_without_user() {
let context = AuthContext::new(None, vec![], AuthMetadata::default());
assert_eq!(context.user_id(), None);
assert!(context.permissions().is_empty());
assert!(!context.has_permission("any"));
}
#[test]
fn test_auth_metadata_creation() {
let metadata = AuthMetadata::new(
Some("10.0.0.1".to_string()),
Some("Mozilla/5.0".to_string()),
);
assert_eq!(metadata.client_ip(), Some("10.0.0.1"));
assert_eq!(metadata.user_agent(), Some("Mozilla/5.0"));
assert!(!metadata.request_id().is_empty());
assert!(metadata.timestamp() > 0);
}
#[test]
fn test_auth_error_messages() {
let missing_auth = AuthError::MissingAuth;
assert_eq!(
missing_auth.to_string(),
"Missing or invalid authorization header"
);
let invalid_token = AuthError::InvalidToken;
assert_eq!(invalid_token.to_string(), "Invalid or expired token");
let insufficient = AuthError::InsufficientPermissions {
required: "admin".to_string(),
user_permissions: vec!["read".to_string()],
};
assert!(insufficient
.to_string()
.contains("Insufficient permissions"));
assert!(insufficient.to_string().contains("admin"));
}
#[test]
fn test_bearer_auth_secret_too_short() {
let result = BearerAuth::try_new("Short1!");
assert!(result.is_err());
if let Err(AuthConfigError::SecretTooShort { length }) = result {
assert_eq!(length, 7); } else {
panic!("Expected SecretTooShort error");
}
}
#[test]
fn test_bearer_auth_missing_uppercase() {
let result = BearerAuth::try_new("lowercase123!abcdefghijklmnopqrstuvwxyz");
assert!(result.is_err());
match result {
Err(AuthConfigError::MissingCharacterClass { required_type }) => {
assert_eq!(required_type, "uppercase letter");
}
Err(_) => {
panic!("Expected MissingCharacterClass error");
}
Ok(_) => {
panic!("Expected error but got success");
}
}
}
#[test]
fn test_bearer_auth_missing_lowercase() {
let result = BearerAuth::try_new("UPPERCASE123!ABCDEFGHIJKLMNOPQRSTUVWXYZ");
assert!(result.is_err());
match result {
Err(AuthConfigError::MissingCharacterClass { required_type }) => {
assert_eq!(required_type, "lowercase letter");
}
Err(_) => {
panic!("Expected MissingCharacterClass error");
}
Ok(_) => {
panic!("Expected error but got success");
}
}
}
// A secret with no digit must fail with `MissingCharacterClass` naming
// "digit".
#[test]
fn test_bearer_auth_missing_digit() {
    let outcome = BearerAuth::try_new("LowercaseUpper!ABCDEFGHIJKLMNOPQRSTUVWXYZ");
    assert!(outcome.is_err());
    // Inspect only the error side; `None` here means construction succeeded.
    match outcome.err() {
        Some(AuthConfigError::MissingCharacterClass { required_type }) => {
            assert_eq!(required_type, "digit");
        }
        Some(_) => panic!("Expected MissingCharacterClass error"),
        None => panic!("Expected error but got success"),
    }
}
// A purely alphanumeric secret must fail with `MissingCharacterClass` naming
// "special character".
#[test]
fn test_bearer_auth_missing_special_char() {
    let outcome = BearerAuth::try_new("LowercaseUpper123ABCDEFGHIJKLMNOPQRSTUVWXYZ");
    assert!(outcome.is_err());
    // Inspect only the error side; `None` here means construction succeeded.
    match outcome.err() {
        Some(AuthConfigError::MissingCharacterClass { required_type }) => {
            assert_eq!(required_type, "special character");
        }
        Some(_) => panic!("Expected MissingCharacterClass error"),
        None => panic!("Expected error but got success"),
    }
}
// A secret containing upper, lower, digit and special characters must be
// accepted; a garbage token must still fail validation.
#[test]
fn test_bearer_auth_valid_secret() {
    let secret = "ValidSecret123!ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let bearer = BearerAuth::try_new(secret).expect("Valid secret should work");
    assert!(bearer.validate_token("invalid-token").is_none());
}
// Building with an audience must still reject an arbitrary token string.
#[test]
fn test_bearer_auth_with_audience() {
    let secret = "ValidSecret123!ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let audience = "my-api";
    let bearer = BearerAuth::with_audience(secret, audience);
    assert!(bearer.validate_token("any-token").is_none());
}
// Building with audience and issuer claims must still reject an arbitrary
// token string.
#[test]
fn test_bearer_auth_with_claims() {
    let secret = "ValidSecret123!ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let bearer = BearerAuth::with_claims(secret, "my-api", "issuer");
    assert!(bearer.validate_token("any-token").is_none());
}
// Pins the default rate-limit settings: 100 requests per 60-second window,
// with headers enabled.
#[test]
fn test_rate_limit_config_default() {
    let RateLimitConfig {
        max_requests,
        window,
        include_headers,
    } = RateLimitConfig::default();

    assert_eq!(max_requests, 100);
    assert_eq!(window, Duration::from_secs(60));
    assert!(include_headers);
}
// A hand-built `RateLimitConfig` must expose exactly the values it was given.
#[test]
fn test_rate_limit_config_custom() {
    let half_minute = Duration::from_secs(30);
    let custom = RateLimitConfig {
        max_requests: 50,
        window: half_minute,
        include_headers: false,
    };

    assert_eq!(custom.max_requests, 50);
    assert_eq!(custom.window, Duration::from_secs(30));
    assert!(!custom.include_headers);
}
// `remaining` must start at the configured maximum and drop by one after each
// `check` for the same key.
#[test]
fn test_rate_limiter_remaining() {
    let limiter = RateLimiter::new(Some(RateLimitConfig {
        max_requests: 5,
        window: Duration::from_secs(60),
        include_headers: false,
    }));
    let key = "test-ip";

    // Nothing consumed yet.
    assert_eq!(limiter.remaining(key), 5);

    // Each check consumes one slot; the check's own result is irrelevant here.
    for expected_left in [4, 3] {
        let _ = limiter.check(key);
        assert_eq!(limiter.remaining(key), expected_left);
    }
}
// `check_idempotency` must return false the first time a request id is seen
// and true on every repeat, tracking ids independently.
#[test]
fn test_rate_limiter_idempotency() {
    let limiter = RateLimiter::new(None);

    // (request id, expected return value), evaluated in order.
    let sequence = [
        ("req-123", false),
        ("req-123", true),
        ("req-123", true),
        ("req-456", false),
    ];
    for (request_id, expected) in sequence {
        assert_eq!(limiter.check_idempotency(request_id), expected);
    }
}
// With a limit of two, the first two `acquire` calls for a key must succeed
// and the third must fail.
#[tokio::test]
async fn test_rate_limiter_acquire() {
    let two_per_minute = RateLimitConfig {
        max_requests: 2,
        window: Duration::from_secs(60),
        include_headers: false,
    };
    let limiter = RateLimiter::new(Some(two_per_minute));
    let client = "ip-1";

    // Calls within the budget.
    for _ in 0..2 {
        assert!(limiter.acquire(client).await.is_ok());
    }
    // Budget exhausted.
    assert!(limiter.acquire(client).await.is_err());
}
// A fresh logger must return no entries for a user that never logged.
#[tokio::test]
async fn test_audit_logger_get_logs_empty() {
    let audit = AuditLogger::new();
    assert!(audit.get_logs("nonexistent-user").is_empty());
}
// After one entry is logged for a user, `clear_logs` must leave that user
// with no entries.
#[tokio::test]
async fn test_audit_logger_clear_logs() {
    let user = "user-to-clear";
    let audit = AuditLogger::new();
    let ctx = AuthContext::new(Some(user.to_string()), Vec::new(), AuthMetadata::default());

    audit
        .log(&ctx, "test_action", "test_resource", true, None)
        .await;
    // Yield once before inspecting the stored entries.
    tokio::task::yield_now().await;

    assert_eq!(audit.get_logs(user).len(), 1);
    audit.clear_logs(user);
    assert!(audit.get_logs(user).is_empty());
}
// After logging one entry, `total_log_count` must report at least one.
#[tokio::test]
async fn test_audit_logger_total_log_count() {
    let audit = AuditLogger::new();
    let ctx = AuthContext::new(
        Some("user-1".to_string()),
        Vec::new(),
        AuthMetadata::default(),
    );

    audit.log(&ctx, "action1", "resource1", true, None).await;
    // Yield once before reading the counter.
    tokio::task::yield_now().await;

    assert!(audit.total_log_count() >= 1);
}
// A JWT embedded in an error message must have its header and payload
// segments removed and a redaction marker left in the output.
#[test]
fn test_sanitize_jwt_token() {
    let sanitized = sanitize_error_message(
        "Invalid token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U",
    );

    // (segment that must not survive, failure message)
    let forbidden = [
        (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
            "JWT header should be redacted",
        ),
        (
            "eyJzdWIiOiIxMjM0NTY3ODkwIn0",
            "JWT payload should be redacted",
        ),
    ];
    for (segment, complaint) in forbidden {
        assert!(!sanitized.contains(segment), "{}", complaint);
    }

    // Either capitalisation of the marker is accepted.
    let has_marker = ["REDACTED", "redacted"]
        .iter()
        .any(|marker| sanitized.contains(*marker));
    assert!(has_marker, "Should contain redaction marker");
}
// A `password=` value must be removed from the message and replaced by the
// `[REDACTED]` marker.
#[test]
fn test_sanitize_password() {
    let cleaned = sanitize_error_message("Connection failed: password=secret123, user=admin");
    let leaked = cleaned.contains("secret123");
    assert!(!leaked);
    assert!(cleaned.contains("[REDACTED]"));
}
// Credentials inside a PostgreSQL URL must not survive sanitisation; the
// output must contain the fixed redacted URL form asserted below.
#[test]
fn test_sanitize_database_url() {
    let cleaned = sanitize_error_message("DB error: postgresql://user:pass123@localhost/mydb");
    let leaked = cleaned.contains("pass123");
    assert!(!leaked);
    assert!(cleaned.contains("postgresql://[REDACTED]:[REDACTED]@localhost/db"));
}
// A path to a `.key` file must be replaced by `[REDACTED_PATH]`.
#[test]
fn test_sanitize_private_key_path() {
    let cleaned = sanitize_error_message("Invalid file: /etc/ssl/private/server.key");
    let leaked = cleaned.contains(".key");
    assert!(!leaked);
    assert!(cleaned.contains("[REDACTED_PATH]"));
}
// A 607-byte message must come back no longer than 520 bytes, and must either
// fit within 500 bytes or carry a "..." truncation marker.
#[test]
fn test_sanitize_max_length() {
    let oversized = format!("Error: {}", "x".repeat(600));
    let sanitized = sanitize_error_message(&oversized);
    let length = sanitized.len();

    assert!(
        length <= 520,
        "Sanitized message should be truncated to ~520 chars max"
    );
    assert!(
        length <= 500 || sanitized.contains("..."),
        "Should indicate truncation or be under limit"
    );
}
// `is_valid_ip` must accept each of these IPv4 addresses.
// NOTE(review): 203.0.113.1 lies in TEST-NET-3 (RFC 5737), a documentation
// range rather than a routable public address — confirm accepting it is
// intended.
#[test]
fn test_is_valid_ip_public() {
    for addr in ["8.8.8.8", "1.1.1.1", "203.0.113.1"] {
        assert!(is_valid_ip(addr), "{} should be accepted", addr);
    }
}
// `is_valid_ip` must reject one address from each of the 10/8, 172.16/12 and
// 192.168/16 private ranges.
#[test]
fn test_is_valid_ip_private_ranges() {
    for addr in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
        assert!(!is_valid_ip(addr), "{} should be rejected", addr);
    }
}
// `is_valid_ip` must reject the IPv4 and IPv6 loopback addresses.
#[test]
fn test_is_valid_ip_loopback() {
    for addr in ["127.0.0.1", "::1"] {
        assert!(!is_valid_ip(addr), "{} should be rejected", addr);
    }
}
// `is_valid_ip` must reject an address from the 169.254/16 link-local range.
#[test]
fn test_is_valid_ip_link_local() {
    let link_local = "169.254.0.1";
    assert!(!is_valid_ip(link_local));
}
// `is_valid_ip` must reject addresses at both ends of the IPv4 multicast
// block.
#[test]
fn test_is_valid_ip_multicast() {
    for addr in ["224.0.0.1", "239.255.255.255"] {
        assert!(!is_valid_ip(addr), "{} should be rejected", addr);
    }
}
// `is_valid_ip` must reject the empty string.
#[test]
fn test_is_valid_ip_empty() {
    let empty = "";
    assert!(!is_valid_ip(empty));
}
// `is_valid_ip` must reject a string made of nine dot-separated groups.
#[test]
fn test_is_valid_ip_too_long() {
    assert!(!is_valid_ip("123.456.789.012.345.678.901.234.567"));
}
// `is_valid_ip` must accept this IPv6 address.
// NOTE(review): 2001:db8::/32 is the IPv6 documentation prefix (RFC 3849),
// not routable public space — confirm accepting it is intended.
#[test]
fn test_is_valid_ip_ipv6_public() {
    let v6_addr = "2001:db8::1";
    assert!(is_valid_ip(v6_addr));
}
// The rendered `RateLimitError` must mention the retry delay and the words
// "rate limit" (case-insensitively).
#[test]
fn test_rate_limit_error_message() {
    let rendered = RateLimitError {
        limit: 100,
        remaining: 0,
        retry_after: 30,
    }
    .to_string();
    let lowered = rendered.to_lowercase();

    // "30" is the `retry_after` value supplied above.
    assert!(rendered.contains("30"));
    assert!(lowered.contains("rate limit"));
}
// `AuditResult` must serialise with a `status` tag of "success" or "failure",
// and the failure form must carry its message text.
#[test]
fn test_audit_result_serialization() {
    // No local `use serde_json;` is needed: the crate name already resolves
    // through the extern prelude, so the single-component import was redundant.
    let success_json = serde_json::to_string(&AuditResult::Success)
        .expect("AuditResult::Success should serialize to JSON");
    assert!(success_json.contains("\"status\":\"success\""));

    let failure = AuditResult::Failure {
        message: "Test error".to_string(),
    };
    let failure_json =
        serde_json::to_string(&failure).expect("AuditResult::Failure should serialize to JSON");
    assert!(failure_json.contains("\"status\":\"failure\""));
    assert!(failure_json.contains("Test error"));
}
// Invalid keys must be rejected both before and after the failed-attempt
// record for the client IP is cleared.
#[tokio::test]
async fn test_api_key_clear_failed_attempts() {
    let client_ip = "192.168.1.1";
    let auth = ApiKeyAuth::with_rate_limit(RateLimitConfig {
        max_requests: 2,
        window: Duration::from_secs(60),
        include_headers: false,
    });
    auth.add_key("valid-key", vec!["read".to_string()]);

    // Two failed attempts, matching the configured `max_requests`.
    for bad_key in ["invalid-1", "invalid-2"] {
        let _ = auth.validate_key(bad_key, client_ip);
    }
    assert_eq!(auth.validate_key("invalid-3", client_ip), None);

    // Clearing the record does not make an invalid key acceptable.
    auth.clear_failed_attempts(client_ip);
    assert_eq!(auth.validate_key("invalid-4", client_ip), None);
}
// An invalid key must be rejected from a client IP that has already failed
// repeatedly and from a second IP with no prior attempts.
#[tokio::test]
async fn test_api_key_different_ips() {
    let first_ip = "192.168.1.1";
    let second_ip = "192.168.1.2";
    let auth = ApiKeyAuth::with_rate_limit(RateLimitConfig {
        max_requests: 2,
        window: Duration::from_secs(60),
        include_headers: false,
    });
    auth.add_key("valid-key", vec!["read".to_string()]);

    // Two failed attempts from the first IP, matching `max_requests`.
    for _ in 0..2 {
        let _ = auth.validate_key("invalid", first_ip);
    }
    assert_eq!(auth.validate_key("invalid", first_ip), None);
    assert_eq!(auth.validate_key("invalid", second_ip), None);
}
// A token passed to `invalidate_token` must not validate afterwards.
#[tokio::test]
async fn test_bearer_auth_blacklist() {
    let revoked = "test-token-to-blacklist";
    let bearer =
        BearerAuth::try_new("ValidSecret123!ABCDEFGHIJKLMNOPQRSTUVWXYZ").expect("Valid secret");

    bearer.invalidate_token(revoked);
    assert!(bearer.validate_token(revoked).is_none());
}
}